shelfnet_lw_example.py 3.9 KB

# ShelfNet LW 34 training on the COCO segmentation dataset.
# mIoU on COCO Seg: ~0.66
# Since the code trains on a subset of COCO Seg, there is an initial creation process for the "sub-dataset".
# This training process is optimized to enable fine-tuning on the PASCAL VOC 2012 dataset, which has only 21 classes.
# IMPORTANT: the code is optimized for a fixed initial LR. The polynomial LR schedule is quite sensitive, so we keep the
# same effective LR by dividing it by the number of GPUs (since our code base multiplies it automatically).
# P.S. - use the relevant dataset params dict depending on whether you are running on TZAG or on a V100 machine.
# (Short sketches for the per-GPU LR scaling, the TZAG/V100 params selection, and the VOC fine-tuning follow the script.)
import torch

from super_gradients.training import SgModel, MultiGPUMode
from super_gradients.training.datasets import CoCoSegmentationDatasetInterface
from super_gradients.training.sg_model.sg_model import StrictLoad
from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU

model_size_str = '34'

# COCO class ids (and names) kept in the "sub-dataset" - the 21 classes that overlap with PASCAL VOC 2012
coco_sub_classes_inclusion_tuples_list = [(0, 'background'), (5, 'airplane'), (2, 'bicycle'), (16, 'bird'),
                                          (9, 'boat'), (44, 'bottle'), (6, 'bus'), (3, 'car'), (17, 'cat'),
                                          (62, 'chair'), (21, 'cow'), (67, 'dining table'), (18, 'dog'),
                                          (19, 'horse'), (4, 'motorcycle'), (1, 'person'), (64, 'potted plant'),
                                          (20, 'sheep'), (63, 'couch'), (7, 'train'), (72, 'tv')]

# Dataset params for the TZAG machine
coco_seg_dataset_tzag_params = {
    "batch_size": 24,
    "test_batch_size": 24,
    "dataset_dir": "/data/coco/",
    "s3_link": None,
    "img_size": 608,
    "crop_size": 512
}

# Dataset params for the V100 machine
coco_seg_dataset_v100_params = {
    "batch_size": 32,
    "test_batch_size": 32,
    "dataset_dir": "/home/ubuntu/data/coco/",
    "s3_link": None,
    "img_size": 608,
    "crop_size": 512
}

shelfnet_coco_training_params = {
    "max_epochs": 150, "initial_lr": 5e-3, "loss": "shelfnet_ohem_loss",
    "optimizer": "SGD", "mixed_precision": True, "lr_mode": "poly",
    "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4, "nesterov": False},
    "load_opt_params": False, "train_metrics_list": [PixelAccuracy(), IoU(21)],
    "valid_metrics_list": [PixelAccuracy(), IoU(21)],
    "loss_logging_items_names": ["Loss1/4", "Loss1/8", "Loss1/16", "Loss"],
    "metric_to_watch": "IoU",
    "greater_metric_to_watch_is_better": True}

shelfnet_lw_arch_params = {"num_classes": 21, "load_checkpoint": True, "strict_load": StrictLoad.ON,
                           "multi_gpu_mode": "data_parallel", "load_weights_only": True,
                           "load_backbone": True, "source_ckpt_folder_name": 'resnet' + model_size_str}

data_loader_num_workers = 8 * torch.cuda.device_count()

# BUILD THE LIGHT-WEIGHT SHELFNET ARCHITECTURE FOR TRAINING
experiment_name_prefix = 'shelfnet_lw_'
experiment_name_dataset_suffix = '_coco_seg_' + str(shelfnet_coco_training_params['max_epochs']) + '_epochs_train_example'
experiment_name = experiment_name_prefix + model_size_str + experiment_name_dataset_suffix

model = SgModel(experiment_name,
                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL,
                ckpt_name='ckpt_best.pth')

coco_seg_dataset_interface = CoCoSegmentationDatasetInterface(
    dataset_params=coco_seg_dataset_tzag_params,
    cache_labels=False,
    dataset_classes_inclusion_tuples_list=coco_sub_classes_inclusion_tuples_list)

model.connect_dataset_interface(coco_seg_dataset_interface, data_loader_num_workers=data_loader_num_workers)
model.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params)

print('Training ShelfNet-LW model: ' + experiment_name)
model.train(training_params=shelfnet_coco_training_params)
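The header comment says the code base multiplies the initial LR by the number of GPUs, so the value in shelfnet_coco_training_params is meant to be divided back to keep the effective LR fixed. Below is a minimal sketch of that division, assuming the script above has already run and that the automatic multiplication works as the comment describes; the scale_initial_lr helper is introduced here only for illustration.

import torch

# Sketch (not part of the original script): divide initial_lr by the GPU count so that the
# framework's automatic per-GPU multiplication restores the intended value.
def scale_initial_lr(training_params: dict) -> dict:
    num_gpus = max(torch.cuda.device_count(), 1)  # avoid dividing by zero on a CPU-only machine
    # a shallow copy is enough - only the top-level "initial_lr" scalar is replaced
    return {**training_params, "initial_lr": training_params["initial_lr"] / num_gpus}

# e.g. with 4 GPUs: 5e-3 / 4 is passed in, and the automatic x4 multiplication restores 5e-3
model.train(training_params=scale_initial_lr(shelfnet_coco_training_params))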
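The P.S. in the header asks you to use the dataset params dict that matches the machine (TZAG vs. V100). One way to avoid editing the script per machine is to pick the dict from an environment variable; this is only a sketch that reuses the definitions from the script above, and the SEG_MACHINE variable name is made up for the example (it is not read anywhere by super_gradients).

import os

# Assumed convention for this sketch: SEG_MACHINE=v100 on the V100 box, anything else means TZAG.
machine = os.environ.get("SEG_MACHINE", "tzag").lower()
selected_dataset_params = coco_seg_dataset_v100_params if machine == "v100" else coco_seg_dataset_tzag_params

coco_seg_dataset_interface = CoCoSegmentationDatasetInterface(
    dataset_params=selected_dataset_params,
    cache_labels=False,
    dataset_classes_inclusion_tuples_list=coco_sub_classes_inclusion_tuples_list)
model.connect_dataset_interface(coco_seg_dataset_interface, data_loader_num_workers=data_loader_num_workers)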

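The header also notes that this COCO-subset run is meant to enable fine-tuning on PASCAL VOC 2012, which is why num_classes is already 21. A rough continuation sketch of such a fine-tuning stage follows; the PascalVOC2012SegmentationDataSetInterface name, its no-argument constructor, and the lower LR / shorter schedule are assumptions to verify against the installed super_gradients version. Only the arch-params keys are reused from the script above.

# Fine-tuning sketch - the VOC dataset interface name and constructor below are assumed, not verified.
from super_gradients.training.datasets import PascalVOC2012SegmentationDataSetInterface

voc_arch_params = {"num_classes": 21, "load_checkpoint": True, "strict_load": StrictLoad.ON,
                   "load_weights_only": True, "load_backbone": False,
                   "source_ckpt_folder_name": experiment_name}  # start from the COCO-subset run above

voc_model = SgModel('shelfnet_lw_' + model_size_str + '_voc_finetune_example',
                    multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
voc_model.connect_dataset_interface(PascalVOC2012SegmentationDataSetInterface(),  # assumed constructor
                                    data_loader_num_workers=data_loader_num_workers)
voc_model.build_model('shelfnet' + model_size_str, arch_params=voc_arch_params)

# Illustrative fine-tuning schedule: lower LR and fewer epochs than the COCO run.
voc_training_params = {**shelfnet_coco_training_params, "initial_lr": 5e-4, "max_epochs": 50}
voc_model.train(training_params=voc_training_params)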
