Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

SampleGeneratorFace.py 5.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
  1. import multiprocessing
  2. import time
  3. import traceback
  4. import cv2
  5. import numpy as np
  6. from core import mplib
  7. from core.interact import interact as io
  8. from core.joblib import SubprocessGenerator, ThisThreadGenerator
  9. from facelib import LandmarksProcessor
  10. from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
  11. SampleType)
'''
Expected format of the `output_sample_types` argument of SampleGeneratorFace:

output_sample_types = [
    [SampleProcessor.TypeFlags, size, (optional) {} opts],
    ...
]
'''
  19. class SampleGeneratorFace(SampleGeneratorBase):
  20. def __init__ (self, samples_path, debug=False, batch_size=1,
  21. random_ct_samples_path=None,
  22. sample_process_options=SampleProcessor.Options(),
  23. output_sample_types=[],
  24. uniform_yaw_distribution=False,
  25. generators_count=4,
  26. raise_on_no_data=True,
  27. **kwargs):
  28. super().__init__(debug, batch_size)
  29. self.initialized = False
  30. self.sample_process_options = sample_process_options
  31. self.output_sample_types = output_sample_types
  32. if self.debug:
  33. self.generators_count = 1
  34. else:
  35. self.generators_count = max(1, generators_count)
  36. samples = SampleLoader.load (SampleType.FACE, samples_path)
  37. self.samples_len = len(samples)
  38. if self.samples_len == 0:
  39. if raise_on_no_data:
  40. raise ValueError('No training data provided.')
  41. else:
  42. return
  43. if uniform_yaw_distribution:
  44. samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ]
  45. grads = 128
  46. #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2
  47. grads_space = np.linspace (-1.2, 1.2,grads)
  48. yaws_sample_list = [None]*grads
  49. for g in io.progress_bar_generator ( range(grads), "Sort by yaw"):
  50. yaw = grads_space[g]
  51. next_yaw = grads_space[g+1] if g < grads-1 else yaw
  52. yaw_samples = []
  53. for idx, pyr in samples_pyr:
  54. s_yaw = -pyr[1]
  55. if (g == 0 and s_yaw < next_yaw) or \
  56. (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
  57. (g == grads-1 and s_yaw >= yaw):
  58. yaw_samples += [ idx ]
  59. if len(yaw_samples) > 0:
  60. yaws_sample_list[g] = yaw_samples
  61. yaws_sample_list = [ y for y in yaws_sample_list if y is not None ]
  62. index_host = mplib.Index2DHost( yaws_sample_list )
  63. else:
  64. index_host = mplib.IndexHost(self.samples_len)
  65. if random_ct_samples_path is not None:
  66. ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
  67. ct_index_host = mplib.IndexHost( len(ct_samples) )
  68. else:
  69. ct_samples = None
  70. ct_index_host = None
  71. if self.debug:
  72. self.generators = [ThisThreadGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
  73. else:
  74. self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
  75. for i in range(self.generators_count) ]
  76. SubprocessGenerator.start_in_parallel( self.generators )
  77. self.generator_counter = -1
  78. self.initialized = True
  79. #overridable
  80. def is_initialized(self):
  81. return self.initialized
  82. def __iter__(self):
  83. return self
  84. def __next__(self):
  85. if not self.initialized:
  86. return []
  87. self.generator_counter += 1
  88. generator = self.generators[self.generator_counter % len(self.generators) ]
  89. return next(generator)
  90. def batch_func(self, param ):
  91. samples, index_host, ct_samples, ct_index_host = param
  92. bs = self.batch_size
  93. while True:
  94. batches = None
  95. indexes = index_host.multi_get(bs)
  96. ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
  97. t = time.time()
  98. for n_batch in range(bs):
  99. sample_idx = indexes[n_batch]
  100. sample = samples[sample_idx]
  101. ct_sample = None
  102. if ct_samples is not None:
  103. ct_sample = ct_samples[ct_indexes[n_batch]]
  104. try:
  105. x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
  106. except:
  107. raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
  108. if batches is None:
  109. batches = [ [] for _ in range(len(x)) ]
  110. for i in range(len(x)):
  111. batches[i].append ( x[i] )
  112. yield [ np.array(batch) for batch in batches]
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...