Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_al_agent.py 1.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  1. from src.environment import (
  2. ClassiferALEnvironmentT,
  3. )
  4. from src.al_manager import Cifar10ALManager
  5. from src.model_manager import ClassifierModelManager
  6. from src.model.cifar10_model import Cifar10Model
  7. from src.al_agent import (
  8. RandomALAgent,
  9. LeastConfidentALAgent,
  10. )
  11. import tensorflow as tf
  # GPU memory policy: allow_growth makes TensorFlow grow its GPU memory use
  # on demand instead of claiming it all when the first session starts.
  config = tf.compat.v1.ConfigProto(
      gpu_options=tf.compat.v1.GPUOptions(allow_growth=True),
  )
  # NOTE(review): `session` is not referenced again in this script; it is
  # presumably created only so the process-wide GPU allocator is initialised
  # with allow_growth — confirm before removing.
  session = tf.compat.v1.Session(config=config)
  def get_model():
      """Build and return a new ``Cifar10Model`` instance.

      Handed to ``ClassifierModelManager`` as a factory (the function itself,
      not a model), presumably so the manager can construct models as needed.
      """
      # NOTE(review): the argument 2 is presumably the number of output
      # classes, matching the two labels [0, 1] given to Cifar10ALManager —
      # confirm against Cifar10Model's signature.
      model = Cifar10Model(2)
      return model
  # Model manager driven by the get_model factory.
  # NOTE(review): the meaning of the second argument (1) is defined by
  # ClassifierModelManager and is not visible here — TODO document.
  model_manager = ClassifierModelManager(get_model, 1)

  # Active-learning data manager.
  # NOTE(review): [0, 1] presumably selects CIFAR-10 classes 0 and 1,
  # [0.5, 0.5] looks like per-class proportions, and 0.1 like a split
  # fraction — confirm all three against Cifar10ALManager.
  al_manager = Cifar10ALManager(
      [0, 1],
      [0.5, 0.5],
      0.1,
  )
  class TestClassifierALEnv(ClassiferALEnvironmentT):
      """Minimal concrete environment used to drive the agents in this script.

      ``get_reward`` and ``get_observation`` are stubbed to return ``None``
      (presumably to satisfy the base class); the script itself only
      exercises warm start, training, labelling and evaluation.
      """

      # This module is named test_*.py and this class name starts with "Test",
      # so pytest would try to collect it as a test class (and warn, since its
      # constructor takes arguments). It is a helper, not a test case.
      __test__ = False

      def get_reward(self):
          """Stub: no reward is computed in this script."""
          return None

      def get_observation(self):
          """Stub: no observation is computed in this script."""
          return None
  env = TestClassifierALEnv(al_manager, model_manager)

  # Agent under evaluation. RandomALAgent (also imported at the top of the
  # file) is the random-selection baseline; substitute it here to compare.
  agent = LeastConfidentALAgent(env)

  # One active-learning round: warm start, train, ask the agent for 5 samples
  # to label, label them, then print the evaluation result for "test".
  # NOTE(review): warm_start(100) presumably seeds 100 initially labelled
  # samples — confirm against ClassiferALEnvironmentT.
  env.warm_start(100)
  env.train_step()
  indices_to_label = agent.select_data_to_label(5)
  env.label_step(indices_to_label)
  print(env.evaluate_model("test"))
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...