engine.py
# engine.py - training and evaluation loops for the torchvision detection
# reference models (Faster R-CNN, Mask R-CNN, Keypoint R-CNN).
import math
import sys
import time

import torch
import torchvision.models.detection.mask_rcnn

import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Linear learning-rate warm-up over the first (up to) 1000 iterations of epoch 0.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # In training mode, detection models return a dict of losses.
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            # Mixed-precision path: scale the loss before backward, step through the scaler.
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger


def _get_iou_types(model):
    # Pick the IoU types CocoEvaluator should compute, based on the model class.
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types


@torch.inference_mode()
def evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)

    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
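For context, below is a minimal sketch of how train_one_epoch and evaluate are typically wired into a training script. Only those two functions and the accompanying utils module come from this repository; the ToyDetectionDataset, the fasterrcnn_resnet50_fpn model choice, and the hyperparameters are illustrative assumptions standing in for your own dataset, model, and settings, not part of engine.py.

# usage sketch (not part of engine.py): wiring train_one_epoch and evaluate
# into a small training script. Dataset, model, and hyperparameters below
# are placeholders chosen for illustration.
import torch
import torchvision

import utils
from engine import train_one_epoch, evaluate


class ToyDetectionDataset(torch.utils.data.Dataset):
    """Hypothetical dataset yielding (image, target) pairs in the format
    torchvision detection models expect."""

    def __init__(self, size=8):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        image = torch.rand(3, 256, 256)
        target = {
            "boxes": torch.tensor([[32.0, 32.0, 128.0, 128.0]]),
            "labels": torch.tensor([1]),
            "image_id": idx,
            "area": torch.tensor([9216.0]),
            "iscrowd": torch.tensor([0]),
        }
        return image, target


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Any torchvision detection model works; Faster R-CNN is used as an example.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    model.to(device)

    # Detection samples vary in size, so the loaders need a collate_fn that
    # keeps each batch as a list of samples (utils.collate_fn does this).
    data_loader = torch.utils.data.DataLoader(
        ToyDetectionDataset(), batch_size=2, shuffle=True, collate_fn=utils.collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
        ToyDetectionDataset(size=4), batch_size=1, shuffle=False, collate_fn=utils.collate_fn
    )

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    for epoch in range(2):
        # one pass over the training data, with loss and learning-rate logging
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10, scaler=scaler)
        lr_scheduler.step()
        # COCO-style mAP evaluation on the held-out loader
        evaluate(model, data_loader_test, device=device)


if __name__ == "__main__":
    main()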