Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

SampleGeneratorFaceTemporal.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. import multiprocessing
  2. import pickle
  3. import time
  4. import traceback
  5. import cv2
  6. import numpy as np
  7. from core import mplib
  8. from core.joblib import SubprocessGenerator, ThisThreadGenerator
  9. from facelib import LandmarksProcessor
  10. from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
  11. SampleType)
  class SampleGeneratorFaceTemporal(SampleGeneratorBase):
      """Generate batches of temporal windows of face samples.

      Samples are loaded with ``SampleType.FACE_TEMPORAL_SORTED``. Each batch
      entry is a window of ``temporal_image_count`` samples taken at indices
      ``idx, idx + mult, idx + 2*mult, ...``. Here ``mult`` is drawn uniformly
      from ``[1, _MULT_MAX]``; with the current ``_MULT_MAX == 1`` the window
      is always consecutive samples.

      Each yielded item is a list of ``numpy`` arrays. Every processed output
      of every frame in a window becomes its own list position, frame-major.
      Each array stacks that position across the ``batch_size`` windows.
      """

      # Upper bound for the random stride between frames of one window. It is
      # shared by __init__ (size of the valid start-index range) and batch_func
      # (stride draw) so the two cannot get out of sync.
      _MULT_MAX = 1

      def __init__(self, samples_path, debug, batch_size,
                   temporal_image_count=3,
                   sample_process_options=None,
                   output_sample_types=None,
                   generators_count=2,
                   **kwargs):
          """Load samples and start the batch generators.

          Args:
              samples_path: Path handed to ``SampleLoader.load``.
              debug: Passed to the base class. When truthy, a single in-thread
                  generator is used instead of subprocess generators.
              batch_size: Passed to the base class; number of windows per batch.
              temporal_image_count: Number of samples (frames) per window.
              sample_process_options: Forwarded to ``SampleProcessor.process``.
                  Defaults to a fresh ``SampleProcessor.Options()`` per instance.
              output_sample_types: Forwarded to ``SampleProcessor.process``.
                  Defaults to a fresh empty list per instance.
              generators_count: Number of ``SubprocessGenerator`` workers when
                  not in debug mode.
              **kwargs: Accepted for call compatibility and ignored.

          Raises:
              ValueError: If no samples were loaded, or there are fewer samples
                  than a single window needs.
          """
          super().__init__(debug, batch_size)

          # Create defaults per instance so that no two generators share
          # (and can accidentally mutate) the same default objects.
          if sample_process_options is None:
              sample_process_options = SampleProcessor.Options()
          if output_sample_types is None:
              output_sample_types = []

          self.temporal_image_count = temporal_image_count
          self.sample_process_options = sample_process_options
          self.output_sample_types = output_sample_types
          self.generators_count = 1 if self.debug else generators_count

          samples = SampleLoader.load(SampleType.FACE_TEMPORAL_SORTED, samples_path)
          samples_len = len(samples)
          if samples_len == 0:
              raise ValueError('No training data provided.')

          mult_max = self._MULT_MAX
          # Number of samples covered by one window at the maximum stride.
          window_span = self.temporal_image_count * mult_max - (mult_max - 1)
          if samples_len < window_span:
              raise ValueError(
                  'Not enough training data: %d samples found, but a temporal '
                  'window needs %d.' % (samples_len, window_span))

          # Valid window starts are 0 .. samples_len - window_span, so the last
          # frame of any window (idx + (temporal_image_count - 1) * mult) stays
          # below samples_len.
          index_host = mplib.IndexHost(samples_len - window_span + 1)

          # Serialize once; each generator receives the pickled bytes and its
          # own index-host client.
          pickled_samples = pickle.dumps(samples, 4)
          if self.debug:
              self.generators = [
                  ThisThreadGenerator(self.batch_func, (pickled_samples, index_host.create_cli()))
              ]
          else:
              self.generators = [
                  SubprocessGenerator(self.batch_func, (pickled_samples, index_host.create_cli()))
                  for _ in range(self.generators_count)
              ]

          # Incremented before use in __next__, so the first batch comes from
          # generator 0.
          self.generator_counter = -1

      def __iter__(self):
          """Return self; this object is its own iterator."""
          return self

      def __next__(self):
          """Return the next batch, cycling round-robin over the generators."""
          self.generator_counter += 1
          generator = self.generators[self.generator_counter % len(self.generators)]
          return next(generator)

      def batch_func(self, param):
          """Endlessly yield batches of processed temporal windows.

          Args:
              param: Tuple ``(pickled_samples, index_host)``. ``pickled_samples``
                  is the ``pickle``-serialized sample list built in
                  ``__init__``. ``index_host`` is the client from
                  ``mplib.IndexHost.create_cli()``; its ``multi_get(bs)`` is
                  presumably what hands out ``bs`` window start indices
                  (confirm against ``core.mplib``).

          Yields:
              A list of ``numpy`` arrays; see the class docstring for layout.

          Raises:
              Exception: Wraps any error raised while processing a sample. The
                  message includes the sample filename and the formatted
                  traceback, and the original error is chained as the cause.
          """
          pickled_samples, index_host = param
          # Trusted input: these bytes were produced by pickle.dumps in __init__.
          samples = pickle.loads(pickled_samples)
          bs = self.batch_size
          mult_max = self._MULT_MAX

          while True:
              batches = None
              indexes = index_host.multi_get(bs)
              for n_batch in range(bs):
                  idx = indexes[n_batch]
                  temporal_samples = []
                  mult = np.random.randint(mult_max) + 1
                  for i in range(self.temporal_image_count):
                      sample = samples[idx + i * mult]
                      try:
                          temporal_samples += SampleProcessor.process(
                              [sample], self.sample_process_options,
                              self.output_sample_types, self.debug)[0]
                      except Exception as e:
                          raise Exception("Exception occured in sample %s. Error: %s"
                                          % (sample.filename, traceback.format_exc())) from e

                  # Size the per-output accumulators from the first window.
                  if batches is None:
                      batches = [[] for _ in range(len(temporal_samples))]
                  for i, item in enumerate(temporal_samples):
                      batches[i].append(item)

              yield [np.array(batch) for batch in batches]
Tip!

Press p or the left arrow key to see the previous file, or n or the right arrow key to see the next file

Comments

Loading...