Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

al_agent.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  1. import abc
  2. import numpy as np
  3. import tensorflow as tf
  4. from src.environment import ClassiferALEnvironmentT
  5. from src.utils.utils import (
  6. batch_sample_indices,
  7. )
  8. from src.utils.fixed_heap import (
  9. FixedHeap,
  10. )
  11. from typing import List
  12. class ClassifierALAgentT(abc.ABC):
  13. def __init__(self, env: ClassiferALEnvironmentT):
  14. self.al_environment = env
  15. @abc.abstractmethod
  16. def select_data_to_label(
  17. self,
  18. n_labels: int,
  19. ) -> List[int]:
  20. """
  21. choose up to n points to label
  22. """
  23. ...
  24. class RandomALAgent(ClassifierALAgentT):
  25. def select_data_to_label(
  26. self,
  27. n_labels: int,
  28. ) -> List[int]:
  29. """
  30. choose up to n points to label
  31. """
  32. al_manager = self.al_environment.al_manager
  33. n_to_label = min(n_labels, al_manager.num_unlabelled)
  34. unlabelled_idx, _ = al_manager.unlabelled_train_data
  35. return np.random.choice(
  36. unlabelled_idx, n_to_label, replace=False)
  37. class LeastConfidentALAgent(ClassifierALAgentT):
  38. def select_data_to_label(
  39. self,
  40. n_labels: int,
  41. ) -> List[int]:
  42. """
  43. choose up to n points to label
  44. """
  45. model = self.al_environment.model
  46. al_manager = self.al_environment.al_manager
  47. n_to_label = min(n_labels, al_manager.num_unlabelled)
  48. heap = FixedHeap(key=lambda x : x[0])
  49. unlabelled_indices, unlabelled_x = (
  50. al_manager.unlabelled_train_data)
  51. # we need to keep the original indices b/c that is the actual action we are taking
  52. for batch_indices in batch_sample_indices(unlabelled_indices.shape[0], shuffle=False):
  53. batch_original_indices = unlabelled_indices[batch_indices]
  54. batch_x = unlabelled_x[batch_indices]
  55. prediction = model(batch_x, training=False)
  56. # we get absolute value of prediction logit which is how confident
  57. # confidences = tf.math.abs(prediction)
  58. # multiclassifier confidence
  59. prediction = tf.nn.softmax(prediction) # normalize to softmax
  60. most_confident_prediction = tf.math.reduce_max(prediction, axis=1)
  61. confidences = tf.math.abs(most_confident_prediction - 0.5)
  62. for confidence, index in zip(confidences, batch_original_indices):
  63. if len(heap) < n_to_label:
  64. # push - confidnece since we want to pop most confident
  65. heap.push((-confidence, index))
  66. else:
  67. top_confidence, _ = heap.top()
  68. if confidence < -top_confidence:
  69. heap.pop()
  70. heap.push((-confidence, index))
  71. label_selection = []
  72. while len(heap) > 0:
  73. _, idx = heap.pop()
  74. label_selection.append(idx)
  75. del heap
  76. return np.array(label_selection)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...