Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  1. import unittest
  2. from super_gradients.training.dataloaders.dataloaders import coco2017_train, coco2017_val
  3. from super_gradients.training.metrics.detection_metrics import DetectionMetrics
  4. from super_gradients.training import Trainer, models
  5. from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
  6. class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
  7. def test_dataset_statistics_tensorboard_logger(self):
  8. """
  9. ** IMPORTANT NOTE **
  10. This test is not the usual fail/pass test - it is a visual test. The success criteria is your own visual check
  11. After launching the test, follow the log the see where was the tensorboard opened. open the tensorboard in your
  12. browser and make sure the text and plots in the tensorboard are as expected.
  13. """
  14. # Create dataset
  15. trainer = Trainer('dataset_statistics_visual_test',
  16. model_checkpoints_location='local',
  17. post_prediction_callback=YoloPostPredictionCallback())
  18. model = models.get("yolox_s")
  19. training_params = {"max_epochs": 1, # we dont really need the actual training to run
  20. "lr_mode": "cosine",
  21. "initial_lr": 0.01,
  22. "loss": "yolox_loss",
  23. "criterion_params": {"strides": [8, 16, 32], "num_classes": 80},
  24. "dataset_statistics": True,
  25. "launch_tensorboard": True,
  26. "valid_metrics_list": [
  27. DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(),
  28. normalize_targets=True,
  29. num_cls=80)],
  30. "loss_logging_items_names": ["iou", "obj", "cls", "l1", "num_fg", "Loss"],
  31. "metric_to_watch": "mAP@0.50:0.95",
  32. }
  33. trainer.train(model=model, training_params=training_params, train_loader=coco2017_train(), valid_loader=coco2017_val())
  34. if __name__ == '__main__':
  35. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file