cifar10_trainer_test.py 680 B
import unittest

import super_gradients
from super_gradients import SgModel
from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface


class TestCifar10Trainer(unittest.TestCase):
    def test_train_cifar10(self):
        # Set up the SuperGradients trainer environment before building a model.
        super_gradients.init_trainer()
        # Checkpoints for the "test" experiment are stored locally.
        model = SgModel("test", model_checkpoints_location='local')
        # Use the library-provided CIFAR-10 dataset interface.
        cifar_10_dataset_interface = LibraryDatasetInterface(name="cifar10")
        model.connect_dataset_interface(cifar_10_dataset_interface)
        # ResNet-18 variant for CIFAR inputs with a 10-class output layer.
        model.build_model("resnet18_cifar", arch_params={'num_classes': 10})
        # One epoch is enough to smoke-test the training loop.
        model.train(training_params={"max_epochs": 1})


if __name__ == '__main__':
    unittest.main()
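The test above imports super_gradients at module level, so on a machine without the library the whole file errors out at collection time instead of reporting a skipped test. A minimal sketch of one way to guard it, assuming a skip (rather than a failure) is the desired behavior when the dependency is missing. The `requires` helper and the class name `TestCifar10TrainerGuarded` are hypothetical; the SuperGradients calls are copied unchanged from the original test, and only `unittest.skipUnless` and `importlib.util.find_spec` from the standard library are added.

```python
import importlib.util
import unittest


def requires(module_name: str):
    """Hypothetical helper: skip the decorated test unless `module_name` is importable.

    find_spec only locates the module; it does not import it.
    """
    available = importlib.util.find_spec(module_name) is not None
    return unittest.skipUnless(available, f"{module_name} is not installed")


class TestCifar10TrainerGuarded(unittest.TestCase):
    @requires("super_gradients")
    def test_train_cifar10(self):
        # Imports live inside the test so the module loads even when
        # super_gradients is absent; the body mirrors the original test.
        import super_gradients
        from super_gradients import SgModel
        from super_gradients.training.datasets.dataset_interfaces import LibraryDatasetInterface

        super_gradients.init_trainer()
        model = SgModel("test", model_checkpoints_location='local')
        model.connect_dataset_interface(LibraryDatasetInterface(name="cifar10"))
        model.build_model("resnet18_cifar", arch_params={'num_classes': 10})
        model.train(training_params={"max_epochs": 1})
```

The `if __name__ == '__main__': unittest.main()` entry point from the original file would still be needed to run this module directly; it is left out of the sketch so the block can be imported without starting a test run.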