
model.py

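import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class EmbeddingNet(nn.Module):
    """Wraps a backbone network and returns L2-normalized embeddings."""

    def __init__(self, backbone=None):
        super().__init__()
        if backbone is None:
            # Default: randomly initialized ResNet-50 whose final layer outputs 128-d vectors.
            backbone = models.resnet50(num_classes=128)
        self.backbone = backbone

    def forward(self, x):
        x = self.backbone(x)
        # Scale each embedding (one row per sample) to unit L2 norm.
        x = F.normalize(x, dim=1)
        return x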
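In `forward`, each embedding is L2-normalized along `dim=1`, so every row has unit length and a plain dot product between two embeddings equals their cosine similarity. The snippet below is a minimal pure-Python sketch of that step (no torch required). It assumes the defaults of `torch.nn.functional.normalize` (`p=2`, `eps=1e-12`). The helper names `l2_normalize_rows` and `cosine_similarity_matrix` are hypothetical and not part of `model.py`.

```python
import math


def l2_normalize_rows(rows, eps=1e-12):
    """Divide each row by max(L2 norm, eps), mirroring F.normalize(x, p=2, dim=1) (hypothetical helper)."""
    out = []
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        denom = max(norm, eps)  # eps clamp avoids division by zero for all-zero rows
        out.append([v / denom for v in row])
    return out


def cosine_similarity_matrix(embeddings):
    """Pairwise dot products; these equal cosine similarities because rows are unit length (hypothetical helper)."""
    return [[sum(a * b for a, b in zip(u, v)) for v in embeddings] for u in embeddings]


# Toy "backbone outputs": rows 0 and 1 point in the same direction, row 2 does not.
batch = [[3.0, 4.0], [6.0, 8.0], [0.0, 2.0]]
emb = l2_normalize_rows(batch)
sims = cosine_similarity_matrix(emb)
```

Rows 0 and 1 differ only in scale, so after normalization their similarity is 1.0, while rows 0 and 2 score 0.8. This scale invariance is why normalized embeddings are convenient for nearest-neighbour retrieval and metric-learning losses.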