test_environment.py
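from src.environment import (
    ClassiferALEnvironmentT,  # identifier spelled as defined in src.environment
)
from src.al_manager import Cifar10ALManager
from src.model_manager import ClassifierModelManager
from src.model.cifar10_model import Cifar10Model
import tensorflow as tf

# TF1-compatible session config: let TensorFlow grow GPU memory on demand
# instead of reserving all of it when the session starts.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)


def get_model():
    # Factory the environment calls to build a fresh model
    # (the argument 2 presumably matches the two classes selected below).
    return Cifar10Model(2)


model_manager = ClassifierModelManager(1)
# Appears to restrict CIFAR-10 to classes 0 and 1 with equal weights;
# the meaning of 0.1 is defined by Cifar10ALManager (not shown here).
al_manager = Cifar10ALManager([0, 1], [0.5, 0.5], 0.1)


class TestClassifierALEnv(ClassiferALEnvironmentT):
    """Minimal concrete environment: reward and observation are stubbed out
    so that only the labeling and training mechanics are exercised."""

    def get_reward(self):
        return None

    def get_observation(self):
        return None


# Smoke test: reset, label one sample, then run one training step.
env = TestClassifierALEnv(100, get_model, model_manager, al_manager)
env.reset()
env.label_step(0)
env.train_step()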
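The script depends on project classes that are not shown, so it cannot run on its own. The same smoke-test pattern can be sketched with a self-contained toy: subclass an environment whose reward and observation hooks are abstract, stub them with None, then drive reset, label_step and train_step. ToyALEnvironment, StubEnv, budget and pool_size are hypothetical names. The sketch also assumes, without confirmation from src.environment, that the real base class leaves get_reward and get_observation abstract.

```python
from abc import ABC, abstractmethod


class ToyALEnvironment(ABC):
    """Hypothetical stand-in for an active-learning environment.

    Assumption: the real ClassiferALEnvironmentT exposes reset(), label_step()
    and train_step() and leaves get_reward()/get_observation() abstract.
    """

    def __init__(self, budget, pool_size=10):
        self.budget = budget        # max number of labels we may request
        self.pool_size = pool_size  # size of the toy unlabeled pool

    def reset(self):
        # Start with every pool index unlabeled and no training done yet.
        self.unlabeled = list(range(self.pool_size))
        self.labeled = []
        self.train_calls = 0
        return self.get_observation()

    def label_step(self, action):
        # Move the chosen pool item into the labeled set, respecting the budget.
        if len(self.labeled) >= self.budget:
            raise RuntimeError("labeling budget exhausted")
        self.labeled.append(self.unlabeled.pop(action))
        return self.get_observation()

    def train_step(self):
        # A real environment would fit the model here; the toy only counts calls.
        self.train_calls += 1
        return self.get_reward()

    @abstractmethod
    def get_reward(self): ...

    @abstractmethod
    def get_observation(self): ...


class StubEnv(ToyALEnvironment):
    # Same trick as TestClassifierALEnv: stub the abstract hooks with None.
    def get_reward(self):
        return None

    def get_observation(self):
        return None


env = StubEnv(budget=100)
env.reset()
env.label_step(0)
env.train_step()
print(len(env.labeled), len(env.unlabeled), env.train_calls)  # prints: 1 9 1
```

Stubbing the hooks is what makes the subclass instantiable: calling ToyALEnvironment(100) directly raises TypeError because its abstract methods are unimplemented.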