experimental.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Experimental modules
"""
import numpy as np
import torch
import torch.nn as nn

from models.common import Conv
from utils.downloads import attempt_download


class CrossConv(nn.Module):
    # Cross Convolution Downsample
    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, (1, k), (1, s))
        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class Sum(nn.Module):
    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, n, weight=False):  # n: number of inputs
        super().__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # iter object
        if weight:
            self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True)  # layer weights

    def forward(self, x):
        y = x[0]  # no weight
        if self.weight:
            w = torch.sigmoid(self.w) * 2
            for i in self.iter:
                y = y + x[i + 1] * w[i]
        else:
            for i in self.iter:
                y = y + x[i + 1]
        return y


class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super().__init__()
        groups = len(k)
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))


class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        y = []
        for module in self:
            y.append(module(x, augment, profile, visualize)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output


def attempt_load(weights, map_location=None, inplace=True, fuse=True):
    from models.yolo import Detect, Model

    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
        if fuse:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
        else:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # without layer fuse

    # Compatibility updates
    for m in model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = inplace  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print(f'Ensemble created with {weights}\n')
        for k in ['names']:
            setattr(model, k, getattr(model[-1], k))
        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
        return model  # return ensemble
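
A minimal usage sketch for the modules above, assuming this file is importable as models.experimental from inside a YOLOv5 checkout; the tensor shapes and the 'yolov5s.pt' checkpoint name below are illustrative, not part of this file:

import torch
from models.experimental import CrossConv, MixConv2d, Sum, attempt_load

x = torch.randn(1, 64, 32, 32)  # dummy feature map: batch 1, 64 channels, 32x32 (illustrative)

print(CrossConv(64, 64, k=3, s=1, shortcut=True)(x).shape)  # residual 1xk + kx1 conv pair -> [1, 64, 32, 32]
print(MixConv2d(64, 64, k=(1, 3), s=1)(x).shape)            # parallel 1x1 and 3x3 branches, concatenated -> [1, 64, 32, 32]
print(Sum(2, weight=True)([x, x]).shape)                    # learned weighted sum of two inputs -> [1, 64, 32, 32]

# attempt_load fetches the checkpoint via attempt_download if it is missing and returns a single
# fused FP32 model, or an Ensemble when a list of weights is given (checkpoint name is illustrative).
model = attempt_load('yolov5s.pt', map_location='cpu')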