Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_v5_coco.py 6.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  1. # Yolo v5 Detection training on CoCo2017 Dataset:
  2. # Yolo v5s train on 320x320 mAP@0.5-0.95 (confidence 0.001, test on 320x320 images) ~28.4
  3. # Yolo v5s train in 640x640 mAP@0.5-0.95 (confidence 0.001, test on 320x320 images) ~29.1
  4. # Yolo v5 Detection training on CoCo2014 Dataset:
  5. # Yolo v5s train on 320x320 mAP@0.5-0.95 (confidence 0.001, test on 320x320 images) ~28.77
  6. # batch size may need to change depending on model size and GPU (2080Ti, V100)
  7. # The code is optimized for running with an effective mini-batch of 64 examples, so depending on the number of GPUs,
  8. # you should set the "batch_accumulate" param in the training_params dict so that batch_size * gpu_num * batch_accumulate = 64.
  9. import super_gradients
  10. import argparse
  11. import torch
  12. from super_gradients.training import SgModel, MultiGPUMode
  13. from super_gradients.training.datasets import CoCoDetectionDatasetInterface, CoCo2014DetectionDatasetInterface
  14. from super_gradients.training.models.yolov5 import YoloV5PostPredictionCallback
  15. from super_gradients.training.utils.detection_utils import base_detection_collate_fn
  16. from super_gradients.training.datasets.datasets_utils import ComposedCollateFunction, MultiScaleCollateFunction
  17. from super_gradients.common.aws_connection.aws_secrets_manager_connector import AWSSecretsManagerConnector
  18. from super_gradients.training.metrics import DetectionMetrics
  19. super_gradients.init_trainer()
  20. parser = argparse.ArgumentParser()
  21. #################################
  22. # Model Options
  23. ################################
  24. parser.add_argument("--model", type=str, required=True, choices=["s", "m", "l", "x", "c"],
  25. help='on of s,m,l,x,c (small, medium, large, extra-large, custom)')
  26. parser.add_argument("--depth", type=float, help='not applicable for default models(s/m/l/x)')
  27. parser.add_argument("--width", type=float, help='not applicable for default models(s/m/l/x)')
  28. parser.add_argument("--reload", action="store_true")
  29. parser.add_argument("--max_epochs", type=int, default=300)
  30. parser.add_argument("--batch", type=int, default=64)
  31. parser.add_argument("--test-img-size", type=int, default=320)
  32. parser.add_argument("--train-img-size", type=int, default=640)
  33. parser.add_argument("--multi-scale", action="store_true")
  34. parser.add_argument("--coco2014", action="store_true")
  35. args, _ = parser.parse_known_args()
  36. models_dict = {"s": "yolo_v5s", "m": "yolo_v5m", "l": "yolo_v5l", "x": "yolo_v5x", "c": "custom_yolov5"}
  37. if args.model == "c":
  38. assert args.depth is not None and args.width is not None, "when setting model type to c (custom), depth and width flags must be set"
  39. assert 0 <= args.depth <= 1, "depth must be in the range [0,1]"
  40. assert 0 <= args.width <= 1, "width must be in the range [0,1]"
  41. else:
  42. assert args.depth is None and args.width is None, "depth and width flags have no effect when the model is not c"
  43. args.model = models_dict[args.model]
  44. distributed = super_gradients.is_distributed()
  45. if args.multi_scale:
  46. train_collate_fn = ComposedCollateFunction([base_detection_collate_fn,
  47. MultiScaleCollateFunction(target_size=args.train_img_size)])
  48. else:
  49. train_collate_fn = base_detection_collate_fn
  50. dataset_params = {"batch_size": args.batch,
  51. "test_batch_size": args.batch,
  52. "train_image_size": args.train_img_size,
  53. "test_image_size": args.test_img_size,
  54. "test_collate_fn": base_detection_collate_fn,
  55. "train_collate_fn": train_collate_fn,
  56. "test_sample_loading_method": "default", # TODO: remove when fixing distributed_data_parallel
  57. "dataset_hyper_param": {
  58. "hsv_h": 0.015, # IMAGE HSV-Hue AUGMENTATION (fraction)
  59. "hsv_s": 0.7, # IMAGE HSV-Saturation AUGMENTATION (fraction)
  60. "hsv_v": 0.4, # IMAGE HSV-Value AUGMENTATION (fraction)
  61. "degrees": 0.0, # IMAGE ROTATION (+/- deg)
  62. "translate": 0.1, # IMAGE TRANSLATION (+/- fraction)
  63. "scale": 0.5, # IMAGE SCALE (+/- gain)
  64. "shear": 0.0} # IMAGE SHEAR (+/- deg)
  65. }
  66. arch_params = {"depth_mult_factor": args.depth,
  67. "width_mult_factor": args.width
  68. }
  69. dataset_string = 'coco2017' if not args.coco2014 else 'coco2014'
  70. model_repo_bucket_name = AWSSecretsManagerConnector.get_secret_value_for_secret_key(aws_env='research',
  71. secret_name='training_secrets',
  72. secret_key='S3.MODEL_REPOSITORY_BUCKET_NAME')
  73. model = SgModel(args.model + '____' + dataset_string,
  74. model_checkpoints_location="s3://" + model_repo_bucket_name,
  75. multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
  76. post_prediction_callback=YoloV5PostPredictionCallback())
  77. devices = torch.cuda.device_count() if not distributed else 1
  78. dataset_interface_class = CoCoDetectionDatasetInterface if not args.coco2014 else CoCo2014DetectionDatasetInterface
  79. dataset_interface = dataset_interface_class(dataset_params=dataset_params)
  80. model.connect_dataset_interface(dataset_interface, data_loader_num_workers=20)
  81. model.build_model(args.model, arch_params=arch_params, load_checkpoint=args.reload)
  82. post_prediction_callback = YoloV5PostPredictionCallback()
  83. training_params = {"max_epochs": args.max_epochs,
  84. "lr_mode": "cosine",
  85. "initial_lr": 0.01,
  86. "cosine_final_lr_ratio": 0.2,
  87. "lr_warmup_epochs": 3,
  88. "batch_accumulate": 1,
  89. "warmup_bias_lr": 0.1,
  90. "loss": "yolo_v5_loss",
  91. "criterion_params": {"model": model},
  92. "optimizer": "SGD",
  93. "warmup_momentum": 0.8,
  94. "optimizer_params": {"momentum": 0.937,
  95. "weight_decay": 0.0005 * (args.batch / 64.0),
  96. "nesterov": True},
  97. "mixed_precision": False,
  98. "ema": True,
  99. "train_metrics_list": [],
  100. "valid_metrics_list": [DetectionMetrics(post_prediction_callback=post_prediction_callback,
  101. num_cls=len(
  102. dataset_interface.coco_classes))],
  103. "loss_logging_items_names": ["GIoU", "obj", "cls", "Loss"],
  104. "metric_to_watch": "mAP@0.50:0.95",
  105. "greater_metric_to_watch_is_better": True}
  106. print(f"Training Yolo v5 {args.model} on {dataset_string.upper()}:\n width-mult={args.width}, depth-mult={args.depth}, "
  107. f"train-img-size={args.train_img_size}, test-img-size={args.test_img_size} ")
  108. model.train(training_params=training_params)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...