import diffdist.functional as diff_dist
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.loss import SoftTargetCrossEntropy

def dist_collect(x):
    """Gather a tensor from all GPUs (differentiable all_gather).

    Args:
        x: shape (mini_batch, ...)
    Returns:
        shape (mini_batch * num_gpu, ...)
    """
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous()
                for _ in range(dist.get_world_size())]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()
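
# Sketch only (assumes torch.distributed has been initialized): dist_collect is
# what turns the per-GPU contrastive loss into a global one. Each rank would
# compare its local features against features gathered from every rank, e.g.
#
#   all_text = dist_collect(text_x)                # [B * world_size, C]
#   logits_per_img = image_x @ all_text.t()        # [B, B * world_size]
#   labels = torch.arange(B, device=image_x.device) + B * dist.get_rank()
#
# Because diffdist.all_gather is differentiable, gradients flow back through the
# gathered copies to their source ranks.
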
class ProjectMLP(nn.Module):

    def __init__(self, in_dim=256, inner_dim=4096, out_dim=256, num_layers=2):
        super(ProjectMLP, self).__init__()
        # hidden layers
        linear_hidden = []
        for i in range(num_layers - 1):
            linear_hidden.append(nn.Conv1d(in_dim if i == 0 else inner_dim, inner_dim, kernel_size=1))
            linear_hidden.append(nn.BatchNorm1d(inner_dim))
            linear_hidden.append(nn.ReLU(inplace=True))
        self.linear_hidden = nn.Sequential(*linear_hidden)

        self.linear_out = nn.Conv1d(
            in_dim if num_layers == 1 else inner_dim, out_dim,
            kernel_size=1) if num_layers >= 1 else nn.Identity()
    def forward(self, x):
        """
        Args:
            x (torch.Tensor): output of transformers, shape [B, L, C] or [B, C]

        Returns:
            torch.Tensor: projected features, shape [B, L, out_dim]
                (or [B, out_dim] if the input was 2-D)
        """
        assert x.ndim in [2, 3], x.ndim
        add_dim = False
        if x.ndim == 2:
            # [B, C] -> [B, 1, C]
            x = x.unsqueeze(1)
            add_dim = True

        x = rearrange(x, 'b l c -> b c l')
        x = self.linear_hidden(x)
        x = self.linear_out(x)
        x = rearrange(x, 'b c l -> b l c')

        if add_dim:
            x = x.squeeze(1)

        return x
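
# Usage sketch (the dimensions below are made up for illustration): ProjectMLP
# projects pooled or per-token features to the shared embedding size, keeping
# the token dimension intact when the input is 3-D.
#
#   proj = ProjectMLP(in_dim=768, inner_dim=4096, out_dim=256, num_layers=2)
#   tokens = torch.randn(8, 197, 768)    # [B, L, C]
#   proj(tokens).shape                   # torch.Size([8, 197, 256])
#   pooled = torch.randn(8, 768)         # [B, C]
#   proj(pooled).shape                   # torch.Size([8, 256])
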

class MultiLabelContrastive(nn.Module):

    def __init__(self,
                 img_encoder_dim,
                 text_encoder_dim,
                 output_dim=256,
                 contrast_temperature=0.07,
                 proj_num_layers=2,
                 multi_label=0,
                 share_temperature=False,
                 multi_label_loss_weight=1.0):
        super().__init__()

        self.contrast_temperature = contrast_temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.cross_entropy = nn.CrossEntropyLoss()
        self.soft_cross_entropy = SoftTargetCrossEntropy()

        self.proj_num_layers = proj_num_layers
        self.multi_label = multi_label
        if proj_num_layers > 0:
            self.img_projector = ProjectMLP(
                in_dim=img_encoder_dim, num_layers=proj_num_layers, out_dim=output_dim)
            self.text_projector = ProjectMLP(
                in_dim=text_encoder_dim, num_layers=proj_num_layers, out_dim=output_dim)
            self.img_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.img_projector)
            self.text_projector = nn.SyncBatchNorm.convert_sync_batchnorm(self.text_projector)
        else:
            self.img_projector = nn.Identity()
            self.text_projector = nn.Identity()

        self.share_temperature = share_temperature
        if self.with_multi_label and not self.share_temperature:
            self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
        self.multi_label_loss_weight = multi_label_loss_weight

    @property
    def with_multi_label(self):
        return self.multi_label > 0
    def forward(self, image_x, text_x):
        batch_size = image_x.shape[0]
        # labels: the matching pair sits on the diagonal of the similarity matrix
        # (in the distributed setting this would be offset by batch_size * rank)
        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device)

        # [B, C]
        image_x = F.normalize(image_x, dim=-1)
        text_x = F.normalize(text_x, dim=-1)

        # cosine-similarity logits; the distributed version compares against
        # dist_collect(...)-gathered features instead of the local batch only
        logits_per_img = image_x @ text_x.t()
        logits_per_text = text_x @ image_x.t()

        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
        loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)

        loss = 0.5 * (loss_img + loss_text)

        return loss
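
# Note: the projection heads built in __init__ are not applied in this trimmed
# forward; in the full training pipeline one would normally project the encoder
# outputs first, roughly:
#
#   image_x = self.img_projector(image_feats)   # [B, output_dim]
#   text_x = self.text_projector(text_feats)    # [B, output_dim]
#   loss = self.forward(image_x, text_x)
#
# (variable names here are illustrative, not from the original code)
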
# quick smoke test with random image features and an all-zero text input:
# both logit matrices are zero, so each cross-entropy is ln(64) and the
# printed loss is ~4.159
a = MultiLabelContrastive(196, 196)
print(a(torch.rand(64, 196) - 1, torch.zeros(64, 196)))