Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#669 Hotfix/sg 645 regression tests essential fixes

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:hotfix/SG-645_limit_tests_forward_passes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
  1. import unittest
  2. from super_gradients.common.object_names import Models
  3. from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
  4. from super_gradients.training.metrics.detection_metrics import DetectionMetrics
  5. from super_gradients.training import Trainer, models
  6. from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
  7. class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
  8. def test_dataset_statistics_tensorboard_logger(self):
  9. """
  10. ** IMPORTANT NOTE **
  11. This test is not the usual fail/pass test - it is a visual test. The success criteria is your own visual check
  12. After launching the test, follow the log the see where was the tensorboard opened. open the tensorboard in your
  13. browser and make sure the text and plots in the tensorboard are as expected.
  14. """
  15. # Create dataset
  16. trainer = Trainer("dataset_statistics_visual_test")
  17. model = models.get(Models.YOLOX_S)
  18. training_params = {
  19. "max_epochs": 1, # we dont really need the actual training to run
  20. "lr_mode": "cosine",
  21. "initial_lr": 0.01,
  22. "loss": "yolox_loss",
  23. "criterion_params": {"strides": [8, 16, 32], "num_classes": 80},
  24. "dataset_statistics": True,
  25. "launch_tensorboard": True,
  26. "valid_metrics_list": [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True, num_cls=80)],
  27. "metric_to_watch": "mAP@0.50:0.95",
  28. }
  29. trainer.train(model=model, training_params=training_params, train_loader=coco2017_train(), valid_loader=coco2017_val())
  30. if __name__ == "__main__":
  31. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file