Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_nas_exp.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
  1. import os
  2. import requests
  3. import torch
  4. from PIL import Image
  5. from super_gradients.training import Trainer, dataloaders, models
  6. from super_gradients.training.dataloaders.dataloaders import (
  7. coco_detection_yolo_format_train, coco_detection_yolo_format_val
  8. )
  9. from super_gradients.training.losses import PPYoloELoss
  10. from super_gradients.training.metrics import DetectionMetrics_050
  11. from super_gradients.training.models.detection_models.pp_yolo_e import (
  12. PPYoloEPostPredictionCallback
  13. )
  class config:
      """Static settings for the YOLO-NAS airplane-detection fine-tuning run.

      Values are read as class attributes (e.g. ``config.DATA_DIR``) by the
      rest of this script.
      """

      # --- Trainer ---------------------------------------------------------
      # Root folder for checkpoints, and the experiment name used under it.
      CHECKPOINT_DIR = '../models'
      EXPERIMENT_NAME = 'airplane_det_yolonas'

      # --- Dataset layout --------------------------------------------------
      # Parent data directory; every *_DIR entry below is relative to it.
      DATA_DIR = '../data/'
      TRAIN_IMAGES_DIR = 'images/train'
      TRAIN_LABELS_DIR = 'labels/train'
      VAL_IMAGES_DIR = 'images/val'
      VAL_LABELS_DIR = 'labels/val'
      # Optional test split, following the same convention:
      # TEST_IMAGES_DIR = 'test/images'
      # TEST_LABELS_DIR = 'test/labels'

      # Detection class names; the class count is derived from this list.
      CLASSES = ['airplane']
      NUM_CLASSES = len(CLASSES)

      # --- DataLoader ------------------------------------------------------
      # Extra PyTorch DataLoader kwargs (this script uses the same dict for
      # the train and validation loaders; split it if they should differ).
      DATALOADER_PARAMS = dict(batch_size=16, num_workers=2)

      # --- Model -----------------------------------------------------------
      # Architecture variant: one of 'yolo_nas_s', 'yolo_nas_m', 'yolo_nas_l'.
      MODEL_NAME = 'yolo_nas_m'
      # Pretrained checkpoint to start from ('coco' is the only option here).
      PRETRAINED_WEIGHTS = 'coco'
  def _yolo_dataset_params(images_dir, labels_dir):
      """Return the ``dataset_params`` dict for one split of the YOLO-format data."""
      return {
          'data_dir': config.DATA_DIR,
          'images_dir': images_dir,
          'labels_dir': labels_dir,
          'classes': config.CLASSES,
      }


  # Train / validation dataloaders; both reuse config.DATALOADER_PARAMS.
  train_data = coco_detection_yolo_format_train(
      dataset_params=_yolo_dataset_params(config.TRAIN_IMAGES_DIR, config.TRAIN_LABELS_DIR),
      dataloader_params=config.DATALOADER_PARAMS,
  )
  val_data = coco_detection_yolo_format_val(
      dataset_params=_yolo_dataset_params(config.VAL_IMAGES_DIR, config.VAL_LABELS_DIR),
      dataloader_params=config.DATALOADER_PARAMS,
  )
  # Instantiate the configured YOLO-NAS variant from pretrained weights,
  # requesting config.NUM_CLASSES output classes.
  model = models.get(
      config.MODEL_NAME,
      num_classes=config.NUM_CLASSES,
      pretrained_weights=config.PRETRAINED_WEIGHTS,
  )
  # Training recipe handed to Trainer.train(training_params=...).
  train_params = {
      # Log the run through SuperGradients' DagsHub logger.
      "sg_logger": "dagshub_sg_logger",
      # Params that will be passed to __init__ of the logger
      # super_gradients.common.sg_loggers.dagshub_sg_logger.DagsHubSGLogger.
      "sg_logger_params": {
          # Optional: DagsHub project as "<owner>/<repo>". If this is left
          # empty, you'll be prompted in your run to fill it in manually.
          "dagshub_repository": "DagsHub/PlaneDetector",
          # Optional: change to True to bypass logging to DVC and log all
          # artifacts only to MLflow.
          "log_mlflow_only": False,
          "save_checkpoints_remote": True,
          "save_tensorboard_remote": True,
          "save_logs_remote": True,
      },
      # NOTE: no "silent_mode" key is set, so the library's default console
      # verbosity applies.
      "average_best_models": True,
      # Learning-rate schedule: linear warm-up over lr_warmup_epochs, then
      # cosine decay (final LR ratio relative to initial_lr).
      "warmup_mode": "linear_epoch_step",
      # "warmup_initial_lr": 1e-2,  # alt: 1e-6 -- disabled; library default used
      "lr_warmup_epochs": 3,
      "initial_lr": 1e-2,
      "lr_mode": "cosine",
      "cosine_final_lr_ratio": 0.1,
      "optimizer": "SGD",
      # "optimizer_params": {"weight_decay": 0.0001},  # disabled; library default used
      "zero_weight_decay_on_bias_and_bn": True,
      "ema": True,
      "ema_params": {"decay": 0.9, "decay_type": "threshold"},
      # Full run of 100 epochs; lower this for a quick smoke test.
      "max_epochs": 100,
      "mixed_precision": True,
      "loss": PPYoloELoss(
          use_static_assigner=False,
          # NOTE: num_classes needs to be defined here (must match the dataset).
          num_classes=config.NUM_CLASSES,
          reg_max=16,
      ),
      "valid_metrics_list": [
          DetectionMetrics_050(
              score_thres=0.1,
              top_k_predictions=300,
              # NOTE: num_cls needs to be defined here (must match the dataset).
              num_cls=config.NUM_CLASSES,
              normalize_targets=True,
              post_prediction_callback=PPYoloEPostPredictionCallback(
                  score_threshold=0.01,
                  nms_top_k=1000,
                  max_predictions=300,
                  nms_threshold=0.7,
              ),
          ),
      ],
      # Presumably the metric used to pick the best checkpoint; the name must
      # match one reported by DetectionMetrics_050 -- confirm in SG docs.
      "metric_to_watch": 'mAP@0.50',
  }
  # Build the trainer for this experiment and launch fine-tuning with the
  # loaders, model and recipe defined above.
  trainer = Trainer(
      experiment_name=config.EXPERIMENT_NAME,
      ckpt_root_dir=config.CHECKPOINT_DIR,
  )
  trainer.train(
      model=model,
      training_params=train_params,
      train_loader=train_data,
      valid_loader=val_data,
  )
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...