Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

dataset_statistics_test.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  1. import unittest
  2. from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
  3. from super_gradients.training.metrics.detection_metrics import DetectionMetrics
  4. from super_gradients.training import Trainer, models
  5. from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
  6. class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
  7. def test_dataset_statistics_tensorboard_logger(self):
  8. """
  9. ** IMPORTANT NOTE **
  10. This test is not the usual fail/pass test - it is a visual test. The success criteria is your own visual check
  11. After launching the test, follow the log the see where was the tensorboard opened. open the tensorboard in your
  12. browser and make sure the text and plots in the tensorboard are as expected.
  13. """
  14. # Create dataset
  15. trainer = Trainer('dataset_statistics_visual_test',
  16. model_checkpoints_location='local',
  17. post_prediction_callback=YoloPostPredictionCallback())
  18. model = models.get("yolox_s")
  19. training_params = {"max_epochs": 1, # we dont really need the actual training to run
  20. "lr_mode": "cosine",
  21. "initial_lr": 0.01,
  22. "loss": "yolox_loss",
  23. "criterion_params": {"strides": [8, 16, 32], "num_classes": 80},
  24. "dataset_statistics": True,
  25. "launch_tensorboard": True,
  26. "valid_metrics_list": [
  27. DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
  28. normalize_targets=True,
  29. num_cls=80)],
  30. "metric_to_watch": "mAP@0.50:0.95",
  31. }
  32. trainer.train(model=model, training_params=training_params, train_loader=coco2017_train(), valid_loader=coco2017_val())
  33. if __name__ == '__main__':
  34. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...