
kd_trainer_utils.py 771 B

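import torch


class NormalizationAdapter(torch.nn.Module):
    """Convert inputs normalized with (mean_original, std_original) into
    inputs normalized with (mean_required, std_required), per channel."""

    def __init__(self, mean_original, std_original, mean_required, std_required):
        super().__init__()
        # Reshape per-channel stats to (C, 1, 1) to broadcast over (N, C, H, W).
        mean_original = torch.tensor(mean_original).unsqueeze(-1).unsqueeze(-1)
        std_original = torch.tensor(std_original).unsqueeze(-1).unsqueeze(-1)
        mean_required = torch.tensor(mean_required).unsqueeze(-1).unsqueeze(-1)
        std_required = torch.tensor(std_required).unsqueeze(-1).unsqueeze(-1)
        # (x + additive) * multiplier == (img - mean_required) / std_required
        # whenever x == (img - mean_original) / std_original.
        # Stored as nn.Parameter, so they are trainable unless explicitly frozen.
        self.additive = torch.nn.Parameter((mean_original - mean_required) / std_original)
        self.multiplier = torch.nn.Parameter(std_original / std_required)

    def forward(self, x):
        # Undo the original normalization and apply the required one in one affine step.
        x = (x + self.additive) * self.multiplier
        return x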
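Why the two constants are correct: if the data loader produced x = (img - mean_o) / std_o, then img = x * std_o + mean_o, so the normalization the other model expects is (img - mean_r) / std_r = (x + (mean_o - mean_r) / std_o) * (std_o / std_r). That is exactly `additive` and `multiplier` above. The sketch below checks this numerically with plain tensors, recomputing the constants the same way `NormalizationAdapter` does. The statistics are example values chosen for illustration (the widely used ImageNet mean/std for the original pipeline, and mean = std = 0.5 for a hypothetical teacher expecting inputs in [-1, 1]); the helper names `as_channel_tensor` and `adapt` are hypothetical.

```python
import torch

# Example statistics (assumed for illustration, not taken from this repo):
# the student pipeline uses ImageNet mean/std, while a hypothetical teacher
# expects inputs scaled to [-1, 1], i.e. mean = std = 0.5 per channel.
MEAN_ORIG = [0.485, 0.456, 0.406]
STD_ORIG = [0.229, 0.224, 0.225]
MEAN_REQ = [0.5, 0.5, 0.5]
STD_REQ = [0.5, 0.5, 0.5]


def as_channel_tensor(values):
    """Reshape a per-channel list to (C, 1, 1) so it broadcasts over (N, C, H, W)."""
    return torch.tensor(values).view(-1, 1, 1)


def adapt(x, mean_o, std_o, mean_r, std_r):
    """Same affine map as NormalizationAdapter.forward, with the constants inlined."""
    additive = (mean_o - mean_r) / std_o
    multiplier = std_o / std_r
    return (x + additive) * multiplier


torch.manual_seed(0)
raw = torch.rand(2, 3, 8, 8)  # fake batch of images with pixel values in [0, 1]

mean_o, std_o = as_channel_tensor(MEAN_ORIG), as_channel_tensor(STD_ORIG)
mean_r, std_r = as_channel_tensor(MEAN_REQ), as_channel_tensor(STD_REQ)

student_input = (raw - mean_o) / std_o     # what the original data loader produces
teacher_expected = (raw - mean_r) / std_r  # what the teacher model was trained on
adapted = adapt(student_input, mean_o, std_o, mean_r, std_r)

# Should be at float32 rounding level, since the map is algebraically exact.
max_err = (adapted - teacher_expected).abs().max().item()
print(f"max abs error: {max_err:.2e}")
```

Folding both steps into one affine map means the teacher can consume the student's already-normalized batch directly, without keeping a copy of the raw images around.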