#257 allow using an external Optimizer (initialized outside)

Merged
Ofri Masad merged 1 commit into Deci-AI:master from deci-ai:feature/SG-184_external_optimizer
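With this change, the "optimizer" entry of train_params can be an already-constructed torch.optim.Optimizer instance rather than only a string name such as "SGD". A minimal sketch of the intended usage, reusing model and train_params from the recipe below; since the diff itself is not reproduced here, the model.net attribute and the exact accepted types are assumptions:

# Hedged sketch (not taken from the diff): assumes this PR lets
# train_params["optimizer"] accept an externally initialized
# torch.optim.Optimizer in place of a string name, and that the built
# network is reachable as model.net after build_model().
external_optimizer = torch.optim.SGD(model.net.parameters(),
                                     lr=1e-2, momentum=0.9, weight_decay=5e-4)
train_params["optimizer"] = external_optimizer
# "optimizer_params" becomes redundant: the instance already carries its hyperparameters.
train_params.pop("optimizer_params", None)
model.train(training_params=train_params)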
  1. """
  2. TODO: REFACTOR AS YAML FILES RECIPE
  3. Train DDRNet23 according to the paper
  4. Usage:
  5. python -m torch.distributed.launch --nproc_per_node=4 ddrnet_segmentation_example.py [-s for slim]
  6. [-d $n for decinet_$n backbone] --pretrained_bb_path <path>
  7. Training time:
  8. DDRNet23: 19H (on 4 x 2080Ti)
  9. DDRNet23 slim: 13H (on 4 x 2080Ti)
  10. Validation mIoU:
  11. DDRNet23: 78.94 (paper: 79.1±0.3)
  12. DDRNet23 slim: 76.79 (paper: 77.3±0.4)
  13. Official git repo:
  14. https://github.com/ydhongHIT/DDRNet
  15. Paper:
  16. https://arxiv.org/pdf/2101.06085.pdf
  17. Pretained checkpoints:
  18. Backbones (trained by the original authors):
  19. s3://deci-model-safe-research/DDRNet/DDRNet23_bb_imagenet.pth
  20. s3://deci-model-safe-research/DDRNet/DDRNet23s_bb_imagenet.pth
  21. Segmentation (trained using this recipe:
  22. s3://deci-model-safe-research/DDRNet/DDRNet23_new/ckpt_best.pth
  23. s3://deci-model-safe-research/DDRNet/DDRNet23s_new/ckpt_best.pth
  24. Comments:
  25. * Pretrained backbones were used
  26. * To pretrain the backbone on imagenet - see ddrnet_classification_example
  27. """
  28. import torch
  29. from super_gradients.training.metrics.segmentation_metrics import IoU, PixelAccuracy
  30. import super_gradients
  31. from super_gradients.training import SgModel, MultiGPUMode
  32. import argparse
  33. import torchvision.transforms as transforms
  34. from super_gradients.training.transforms.transforms import RandomFlip, RandomRescale, CropImageAndMask, \
  35. PadShortToCropSize
  36. from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
  37. from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CITYSCAPES_IGNORE_LABEL
  38. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CityscapesDatasetInterface
  39. parser = argparse.ArgumentParser()
  40. super_gradients.init_trainer()
  41. parser.add_argument("--reload", action="store_true")
  42. parser.add_argument("--max_epochs", type=int, default=485)
  43. parser.add_argument("--batch", type=int, default=3)
  44. parser.add_argument("--img_size", type=int, default=1024)
  45. parser.add_argument("--experiment_name", type=str, default="ddrnet_23")
  46. parser.add_argument("--pretrained_bb_path", type=str)
  47. parser.add_argument("-s", "--slim", action="store_true", help='train the slim version of DDRNet23')
  48. args, _ = parser.parse_known_args()
  49. distributed = super_gradients.is_distributed()
  50. devices = torch.cuda.device_count() if not distributed else 1
  51. dataset_params = {
  52. "batch_size": args.batch,
  53. "val_batch_size": args.batch,
  54. "dataset_dir": "/home/ofri/cityscapes/",
  55. "crop_size": args.img_size,
  56. "img_size": args.img_size,
  57. "image_mask_transforms_aug": transforms.Compose([
  58. # ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5), # TODO - add
  59. RandomFlip(),
  60. RandomRescale(scales=(0.5, 2.0)),
  61. PadShortToCropSize(args.img_size, fill_mask=CITYSCAPES_IGNORE_LABEL,
  62. fill_image=(CITYSCAPES_IGNORE_LABEL, 0, 0)),
  63. # Legacy padding color that works best with this recipe
  64. CropImageAndMask(crop_size=args.img_size, mode="random"),
  65. ]),
  66. "image_mask_transforms": transforms.Compose([]) # no transform for evaluation
  67. }
  68. # num_classes for IoU includes the ignore label
  69. train_metrics_list = [PixelAccuracy(ignore_label=CITYSCAPES_IGNORE_LABEL),
  70. IoU(num_classes=20, ignore_index=CITYSCAPES_IGNORE_LABEL)]
  71. valid_metrics_list = [PixelAccuracy(ignore_label=CITYSCAPES_IGNORE_LABEL),
  72. IoU(num_classes=20, ignore_index=CITYSCAPES_IGNORE_LABEL)]
  73. train_params = {"max_epochs": args.max_epochs,
  74. "initial_lr": 1e-2,
  75. "loss": DDRNetLoss(ignore_label=CITYSCAPES_IGNORE_LABEL, num_pixels_exclude_ignored=False),
  76. "lr_mode": "poly",
  77. "ema": True, # unlike the paper (not specified in paper)
  78. "average_best_models": True,
  79. "optimizer": "SGD",
  80. "mixed_precision": False,
  81. "optimizer_params":
  82. {"weight_decay": 5e-4,
  83. "momentum": 0.9},
  84. "train_metrics_list": train_metrics_list,
  85. "valid_metrics_list": valid_metrics_list,
  86. "loss_logging_items_names": ["main_loss", "aux_loss", "Loss"],
  87. "metric_to_watch": "IoU",
  88. "greater_metric_to_watch_is_better": True
  89. }
  90. arch_params = {"num_classes": 19, "aux_head": True, "sync_bn": True}
  91. checkpoint_params = {"load_checkpoint": args.reload,
  92. "load_weights_only": args.pretrained_bb_path is not None,
  93. "load_backbone": args.pretrained_bb_path is not None,
  94. "external_checkpoint_path": args.pretrained_bb_path}
  95. model = SgModel(args.experiment_name,
  96. multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
  97. device='cuda')
  98. dataset_interface = CityscapesDatasetInterface(dataset_params=dataset_params, cache_labels=False)
  99. model.connect_dataset_interface(dataset_interface, data_loader_num_workers=8 * devices)
  100. model.build_model(architecture="ddrnet_23_slim" if args.slim else "ddrnet_23",
  101. arch_params=arch_params,
  102. checkpoint_params=checkpoint_params)
  103. model.train(training_params=train_params)