Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

sampler.py 5.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
  1. import numpy as np
  2. import random
  3. from src.utils.utils import (
  4. batch_sample_slices,
  5. batch_sample_indices,
  6. )
  7. from src.utils.fixed_heap import (
  8. FixedHeap,
  9. )
  10. from typing import List
  11. import numpy as np
  12. import tensorflow as tf
  13. class ActiveLearningSamplerT:
  14. """
  15. ActiveLearningSampler manages a index dataset
  16. and what is labeled/unlabled
  17. """
  18. def __init__(self, n_elements):
  19. self.labelled_idx_set = set()
  20. self.unlabelled_idx_set = set([i for i in range(n_elements)])
  21. @property
  22. def n_labelled(self):
  23. return len(self.labelled_idx_set)
  24. def label_n_elements(self, n_elements: int, **kwargs) -> int:
  25. """
  26. chooses n labeled indices to labeled
  27. returns # of new elemnts labelled
  28. """
  29. # labels
  30. assert NotADirectoryError("not implemented")
  31. def get_labelled_set(self):
  32. return self.labelled_idx_set
  33. class ALRandomSampler(ActiveLearningSamplerT):
  34. def label_n_elements(self, n_elements: int) -> int:
  35. n_sampled = min(len(self.unlabelled_idx_set), n_elements)
  36. new_labels = set(random.sample(self.unlabelled_idx_set, n_sampled))
  37. self.labelled_idx_set |= new_labels
  38. self.unlabelled_idx_set -= new_labels
  39. return n_sampled
  40. class LeastConfidenceSampler(ActiveLearningSamplerT):
  41. _batch_sampler_size = 32
  42. def __init__(self, train_data):
  43. n_elements = len(train_data)
  44. super().__init__(n_elements)
  45. self.train_data = train_data
  46. def label_n_elements(
  47. self,
  48. n_elements: int,
  49. model,
  50. ) -> int:
  51. """
  52. chooses n labeled indices to labeled
  53. returns # of new elemnts labelled
  54. """
  55. n_to_sample = min(len(self.unlabelled_idx_set), n_elements)
  56. unlabelled_indices = list(self.unlabelled_idx_set)
  57. heap = FixedHeap(key=lambda x : x[0])
  58. train_x = self.train_data
  59. # we need to keep the original indices
  60. for batch_indices in batch_sample_slices(unlabelled_indices, shuffle=False):
  61. batch_x = train_x[batch_indices]
  62. prediction = model(batch_x, training=False)
  63. # we get absolute value of prediction logit which is how confident
  64. # confidences = tf.math.abs(prediction)
  65. # multiclassifier confidence
  66. prediction = tf.nn.softmax(prediction)
  67. confidences = tf.math.reduce_max(prediction, axis=0)
  68. for confidence, index in zip(confidences, batch_indices):
  69. if len(heap) < n_to_sample:
  70. # push - confidnece since we want to pop most confident
  71. heap.push((-confidence, index))
  72. else:
  73. top_confidence, _ = heap.top()
  74. if confidence < -top_confidence:
  75. heap.pop()
  76. heap.push((-confidence, index))
  77. while len(heap) > 0:
  78. _, idx = heap.pop()
  79. self.labelled_idx_set.add(idx)
  80. self.unlabelled_idx_set.remove(idx)
  81. del heap
  82. return n_to_sample
  83. class UCBBanditSampler(ActiveLearningSamplerT):
  84. def __init__(self, train_data):
  85. self.n_elements = len(train_data)
  86. super().__init__(self.n_elements)
  87. self.samplers = [
  88. ALRandomSampler(self.n_elements),
  89. LeastConfidenceSampler(train_data)
  90. ]
  91. # we make sure we share the same set
  92. for sampler in self.samplers:
  93. sampler.unlabelled_idx_set = (
  94. self.unlabelled_idx_set)
  95. sampler.labelled_idx_set = (
  96. self.labelled_idx_set)
  97. self.n_samplers = len(self.samplers)
  98. self.q_value = np.zeros(self.n_samplers)
  99. self.arm_count = np.zeros(self.n_samplers)
  100. self.total_arm_count = 0
  101. def get_action(self, arm: int) -> str:
  102. return self.samplers[arm].__class__.__name__
  103. def label_n_elements(
  104. self,
  105. n_elements: int,
  106. model) -> (int, int):
  107. # https://lilianweng.github.io/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#ucb1
  108. # UCB 1 algorithm stolen here
  109. # if there are any actions that we have not tried, we randomly selection an action
  110. indices = np.where(self.arm_count == 0)[0]
  111. if len(indices) > 0:
  112. arm = np.random.choice(indices)
  113. else:
  114. exploration = (2*np.math.log(self.total_arm_count)/self.arm_count)**(0.5)
  115. ucb = self.q_value + exploration
  116. arm = np.argmax(ucb)
  117. sampler_selected = self.samplers[arm]
  118. # TODO add logging of which arm selected
  119. if isinstance(sampler_selected, ALRandomSampler):
  120. n_labeled = sampler_selected.label_n_elements(n_elements)
  121. if isinstance(sampler_selected, LeastConfidenceSampler):
  122. n_labeled = sampler_selected.label_n_elements(n_elements, model)
  123. return arm, n_labeled
  124. def update_q_value(self, arm: int, reward: float) -> None:
  125. self.total_arm_count += 1
  126. self.arm_count[arm] += 1
  127. # running avg
  128. # TODO we can probably can do more aggressive score decay
  129. self.q_value[arm] += (reward - self.q_value[arm])/self.arm_count[arm]
  130. class RLSampler(ActiveLearningSamplerT):
  131. pass
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...