Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

Groupvit.py 4.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
  1. import diffdist.functional as diff_dist
  2. import numpy as np
  3. import torch
  4. import torch.distributed as dist
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from einops import rearrange, repeat
  8. from timm.loss import SoftTargetCrossEntropy
  def dist_collect(x):
      """Collect a tensor from every process and concatenate along dim 0.

      Args:
          x: tensor of shape (mini_batch, ...).

      Returns:
          A contiguous tensor of shape (mini_batch * world_size, ...). When
          ``torch.distributed`` is unavailable or no default process group
          has been initialised (single-process runs), there is nothing to
          gather and ``x`` itself is returned (made contiguous).
      """
      x = x.contiguous()

      # Single-process fallback: dist.get_world_size() raises when no process
      # group exists, so the gather is skipped instead of crashing.
      if not (dist.is_available() and dist.is_initialized()):
          return x

      # One receive buffer per rank; zeros_like already inherits the device
      # and dtype of x.
      out_list = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
      out_list = diff_dist.all_gather(out_list, x)
      return torch.cat(out_list, dim=0).contiguous()
  class ProjectMLP(nn.Module):
      """Per-position projection head built from kernel-size-1 ``Conv1d`` layers.

      The head has ``num_layers - 1`` hidden stages (Conv1d -> BatchNorm1d ->
      ReLU) of width ``inner_dim`` followed by one output Conv1d to
      ``out_dim``. With ``num_layers < 1`` there are no hidden stages and the
      output layer is ``nn.Identity``.
      """

      def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
          super().__init__()

          # Hidden stages; `width` tracks the channel count entering the next layer.
          stages = []
          width = in_dim
          for _ in range(num_layers - 1):
              stages += [
                  nn.Conv1d(width, inner_dim, kernel_size=1),
                  nn.BatchNorm1d(inner_dim),
                  nn.ReLU(inplace=True),
              ]
              width = inner_dim
          self.linear_hidden = nn.Sequential(*stages)

          # Output layer (identity when no layers were requested).
          if num_layers >= 1:
              self.linear_out = nn.Conv1d(width, out_dim, kernel_size=1)
          else:
              self.linear_out = nn.Identity()

      def forward(self, x):
          """Project features along the channel (last) axis.

          Args:
              x (torch.Tensor): shape [B, L, C] or [B, C].

          Returns:
              torch.Tensor: same rank as the input, with the last axis
              mapped by the head ([B, L, out_dim] or [B, out_dim]).
          """
          assert x.ndim in [2, 3], x.ndim

          # [B, C] inputs get a temporary length axis: [B, C] -> [B, 1, C]
          was_2d = x.ndim == 2
          if was_2d:
              x = x.unsqueeze(1)

          # Conv1d expects channels first: [B, L, C] -> [B, C, L]
          x = x.transpose(1, 2)
          x = self.linear_out(self.linear_hidden(x))
          # back to channels last: [B, C, L] -> [B, L, C]
          x = x.transpose(1, 2)

          return x.squeeze(1) if was_2d else x
  class MultiLabelContrastive(nn.Module):
      """Symmetric image-text contrastive loss head.

      Holds a learnable log-temperature (``logit_scale``), optional image and
      text projection heads, and computes a symmetric cross-entropy loss over
      the image-text cosine-similarity matrix.

      Args:
          img_encoder_dim: input width of the image projector.
          text_encoder_dim: input width of the text projector.
          output_dim: output width of both projectors.
          contrast_temperature: initial temperature; ``logit_scale`` starts
              at ``log(1 / contrast_temperature)``.
          proj_num_layers: number of layers in each ``ProjectMLP``; ``0``
              (or less) uses ``nn.Identity`` projectors.
          multi_label: when ``> 0``, ``with_multi_label`` is true.
          share_temperature: when false and ``with_multi_label`` is true, a
              separate ``multi_label_logit_scale`` parameter is created.
          multi_label_loss_weight: stored on the module; not used by
              ``forward``.
      """

      def __init__(self,
                   img_encoder_dim,
                   text_encoder_dim,
                   output_dim=256,
                   contrast_temperature=0.07,
                   proj_num_layers=2,
                   multi_label=0,
                   share_temperature=False,
                   multi_label_loss_weight=1.0):
          super().__init__()

          self.contrast_temperature = contrast_temperature
          # Learnable log-scale, exponentiated and clamped in forward().
          self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
          self.cross_entropy = nn.CrossEntropyLoss()
          self.soft_cross_entropy = SoftTargetCrossEntropy()

          self.proj_num_layers = proj_num_layers
          self.multi_label = multi_label
          if proj_num_layers > 0:
              self.img_projector = ProjectMLP(
                  in_dim=img_encoder_dim, num_layers=proj_num_layers, out_dim=output_dim)
              self.text_projector = ProjectMLP(
                  in_dim=text_encoder_dim, num_layers=proj_num_layers, out_dim=output_dim)
              # Swap the projectors' BatchNorm layers for SyncBatchNorm.
              self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
              self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
          else:
              self.img_projector = nn.Identity()
              self.text_projector = nn.Identity()

          self.share_temperature = share_temperature
          if self.with_multi_label and not self.share_temperature:
              self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
          self.multi_label_loss_weight = multi_label_loss_weight

      @property
      def with_multi_label(self):
          """Whether multi-label mode is enabled (``multi_label > 0``)."""
          return self.multi_label > 0

      def forward(self, image_x, text_x):
          """Return the symmetric contrastive loss for paired features.

          Row ``i`` of ``image_x`` and row ``i`` of ``text_x`` form the
          positive pair; every other row is a negative. When a
          ``torch.distributed`` process group is initialised, the features of
          all ranks are gathered with ``dist_collect`` and used as additional
          negatives.

          NOTE(review): the projectors built in ``__init__`` are not applied
          here; presumably callers pass already-projected features - confirm.

          Args:
              image_x: image features, shape [B, C].
              text_x: text features, shape [B, C].

          Returns:
              Scalar tensor: mean of the image-to-text and text-to-image
              cross-entropy losses.

          Raises:
              ValueError: if the inputs are not 2-D or their shapes differ.
          """
          if image_x.ndim != 2 or image_x.shape != text_x.shape:
              raise ValueError(
                  'image_x and text_x must both have shape [B, C], got '
                  f'{tuple(image_x.shape)} and {tuple(text_x.shape)}')

          batch_size = image_x.shape[0]

          # [B, C], unit length so the dot products below are cosine similarities
          image_x = F.normalize(image_x, dim=-1)
          text_x = F.normalize(text_x, dim=-1)

          if dist.is_available() and dist.is_initialized():
              # Gather candidates from every rank; the positives of this rank
              # sit at offset batch_size * rank in the gathered tensors.
              all_image_x = dist_collect(image_x)
              all_text_x = dist_collect(text_x)
              label_offset = batch_size * dist.get_rank()
          else:
              all_image_x, all_text_x, label_offset = image_x, text_x, 0

          # Index of the matching (positive) column for each row.
          labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + label_offset

          # Similarity logits: one row per local sample, one column per candidate.
          logits_per_img = image_x @ all_text_x.t()
          logits_per_text = text_x @ all_image_x.t()

          # Cap the scale at 100 so the learnable temperature cannot blow up the logits.
          logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
          loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
          loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)
          return 0.5 * (loss_img + loss_text)
  # Smoke run: build a loss head with 196-dim image and text encoders and
  # print the loss for one batch of 64 samples.
  a = MultiLabelContrastive(196, 196)
  image_feats = torch.rand(64, 196) - 1  # uniform values shifted into [-1, 0)
  text_feats = torch.zeros(64, 196)  # all-zero text features
  print(a(image_feats, text_feats))
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...