Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

al_manager.py 4.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
  1. """
  2. Active Learning Manager classes manage the AL environment
  3. """
  4. import abc
  5. from typing import List
  6. import tensorflow as tf
  7. import numpy as np
  8. from src.utils.utils import to_onehot
  9. class Cifar10ALManager:
  10. def __init__(
  11. self,
  12. classes: List[int],
  13. class_ratio: List[float]=None, # defaults to uniform split
  14. validation_split: float=None,
  15. ):
  16. assert len(classes) == len(class_ratio)
  17. self.classes = classes # mapping from class to actual class
  18. self.num_classes = len(self.classes)
  19. self.class_ratio = class_ratio
  20. self.validation_split = validation_split
  21. self._init_dataset()
  22. self.pool_size = self.train_data[0].shape[0]
  23. self.is_labelled = np.repeat(False, self.pool_size)
  24. def _init_dataset(self):
  25. """
  26. constructs the train (pool) dataset
  27. , validation (used to compute reward if reward if relevent
  28. , and test environment
  29. """
  30. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  31. # Normalize pixel values to be between 0 and 1
  32. x_train, x_test = x_train / 255.0, x_test / 255.0
  33. y_train = y_train.flatten()
  34. y_test = y_test.flatten()
  35. def augment_dataset(x, y, classes, class_ratio):
  36. n_classes = len(classes)
  37. if class_ratio:
  38. # assumes equal class distribution from beginning
  39. class_ratio = np.array(class_ratio)/np.sum(class_ratio)
  40. else:
  41. class_ratio = np.ones(n_classes)/n_classes
  42. final_x, final_y = None, None
  43. for i, c, c_ratio in zip(np.arange(n_classes), classes, class_ratio):
  44. class_x = x[y==c]
  45. n_to_keep = int(class_x.shape[0] * c_ratio)
  46. # TODO shuffle?
  47. # instead of reusuing the class label, we start at 0,1,2...
  48. class_x = class_x[:n_to_keep]
  49. if final_x is not None:
  50. final_x = np.concatenate((final_x, class_x))
  51. final_y = np.concatenate((final_y, np.repeat(i, n_to_keep)))
  52. else:
  53. final_x = class_x
  54. final_y = np.repeat(i, n_to_keep)
  55. # 1 hot
  56. final_y = to_onehot(final_y, n_classes)
  57. final_n_points = final_y.shape[0]
  58. shuffle = np.random.permutation(final_n_points)
  59. return final_x[shuffle], final_y[shuffle]
  60. x_train, y_train = augment_dataset(x_train, y_train, self.classes, self.class_ratio)
  61. x_test, y_test = augment_dataset(x_test, y_test, self.classes, self.class_ratio)
  62. # build validation dataset
  63. if self.validation_split:
  64. num_validation = int(len(x_train) * self.validation_split)
  65. idx = np.random.choice(len(x_train), num_validation)
  66. mask = np.ones(len(x_train), np.bool)
  67. mask[idx] = 0
  68. x_val = x_train[~mask]
  69. y_val = y_train[~mask]
  70. x_train = x_train[mask]
  71. y_train = y_train[mask]
  72. # we keep as raw numpy as it's easier to index only the labelled set
  73. self.train_data = (x_train, y_train)
  74. self.test_data = (x_test, y_test)
  75. if self.validation_split:
  76. self.validation_data = (x_val, y_val)
  77. else:
  78. self.validation_split = None
  79. def reset(self):
  80. self.is_labelled = np.repeat(False, self.pool_size)
  81. def label_data(self, data_indices):
  82. self.is_labelled[data_indices] = True
  83. def data_is_labelled(self, data_indices):
  84. """
  85. returns if any data is labelled
  86. """
  87. return np.any(self.is_labelled[data_indices])
  88. @property
  89. def labelled_train_data(self):
  90. x, y = self.train_data
  91. return np.where(self.is_labelled)[0], x[self.is_labelled], y[self.is_labelled]
  92. @property
  93. def unlabelled_train_data(self):
  94. x, _ = self.train_data
  95. return np.where(~self.is_labelled)[0], x[~self.is_labelled]
  96. @property
  97. def num_labelled(self):
  98. return self.is_labelled[self.is_labelled].shape[0]
  99. @property
  100. def num_unlabelled(self):
  101. return self.is_labelled[~self.is_labelled].shape[0]
  102. def get_dataset(self, data_type: str):
  103. if data_type == "train":
  104. return self.train_data
  105. elif data_type == "test":
  106. return self.test_data
  107. elif data_type == "validation":
  108. data = self.validation_data
  109. if data is None:
  110. raise Exception("Validation data is not available")
  111. return data
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...