Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

environment.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
  1. """
  2. Active Learning Environment interface
  3. """
  4. import abc
  5. import numpy as np
  6. import tensorflow as tf
  7. from attr import attrs, attrib
  8. from src.al_manager import Cifar10ALManager
  9. from src.model_manager import ClassifierModelManager
  10. from tf_agents.specs import array_spec
  11. from sklearn.metrics import f1_score
  12. from typing import Callable
  @attrs
  class ClassiferALEnvironmentT(abc.ABC):
      """Abstract active-learning (AL) environment for a classifier.

      Couples an AL data manager (which tracks labelled / unlabelled training
      data) with a model manager (which holds and trains the classifier).
      The labelling and training steps are implemented here; subclasses must
      provide ``get_reward`` and ``get_observation``.

      Attributes:
          al_manager: manager of the labelled / unlabelled training pool.
          model_manager: manager that owns, resets and trains the model.
      """

      al_manager: Cifar10ALManager = attrib()
      model_manager: ClassifierModelManager = attrib()

      @property
      def model(self):
          """The model currently held by ``model_manager``."""
          return self.model_manager.model

      @property
      def n_step(self):
          """Current "time" step, i.e. the number of datapoints labelled."""
          return self.al_manager.num_labelled

      def reset(self):
          """Reset both the AL data manager and the model."""
          self.al_manager.reset()
          self.model_manager.reset_model()

      def warm_start(self, n_to_label, *, rng=None):
          """Reset the environment, then label a random initial sample.

          Args:
              n_to_label: number of distinct unlabelled points to label.
              rng: optional NumPy ``Generator`` / ``RandomState`` used to draw
                  the sample (pass a seeded one for reproducible warm starts).
                  Defaults to the global ``np.random`` state, as before.

          Raises:
              ValueError: if ``n_to_label`` is larger than the number of
                  unlabelled points (sampling is without replacement).
          """
          self.reset()
          # Only the indices are needed to choose what to label.
          unlabelled_idx, _ = self.al_manager.unlabelled_train_data
          sampler = np.random if rng is None else rng
          label_indices = sampler.choice(
              unlabelled_idx, n_to_label, replace=False)
          self.label_step(label_indices)

      def train_step(self, retrain=False):
          """Train the model on the currently labelled data.

          This is a pseudo-step used to update the model; it is not an
          agent action.

          Args:
              retrain: if True, call ``model_manager.reset_model()`` before
                  training.
          """
          if retrain:
              self.model_manager.reset_model()
          # labelled_train_data unpacks to (indices, x, y); indices unused here.
          _, x, y = self.al_manager.labelled_train_data
          self.model_manager.train_model(x, y)

      def label_step(self, indices_to_label):
          """Label the given training points (one action of the AL agent).

          Args:
              indices_to_label: indices of the training points to label.

          Raises:
              ValueError: if ``al_manager.data_is_labelled`` reports the
                  requested data as already labelled.
          """
          if self.al_manager.data_is_labelled(indices_to_label):
              # ValueError subclasses Exception, so existing handlers still match.
              raise ValueError("invalid action, data is already labelled")
          self.al_manager.label_data(indices_to_label)

      @abc.abstractmethod
      def get_reward(self):
          """Return the reward based on the current state of the environment.

          The reward is not returned directly from a step because pseudo-steps
          such as ``train_step`` may be called before the reward is generated.
          """
          ...

      @abc.abstractmethod
      def get_observation(self):
          """Return the current state (observation) of the environment.

          Implementations may be model/data agnostic or not.
          """
          ...
  class BaseClassiferALEnvironment(ClassiferALEnvironmentT):
      """Minimal concrete AL environment with no reward and no observation.

      Both abstract hooks of the parent are implemented as methods that
      produce ``None``, which makes the class instantiable.
      """

      def get_reward(self):
          """Produce no reward; the result is always ``None``."""
          return None

      def get_observation(self):
          """Produce no observation; the result is always ``None``."""
          return None
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...