models.py

import torch
import torch.nn as nn


class FinetunedLLM(nn.Module):  # pragma: no cover, torch model
    """Model architecture for a Large Language Model (LLM) that we will fine-tune."""

    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super(FinetunedLLM, self).__init__()
        self.llm = llm
        self.dropout = torch.nn.Dropout(dropout_p)
        self.fc1 = torch.nn.Linear(embedding_dim, num_classes)

    def forward(self, batch):
        ids, masks = batch["ids"], batch["masks"]
        seq, pool = self.llm(input_ids=ids, attention_mask=masks)
        z = self.dropout(pool)
        z = self.fc1(z)
        return z
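
The forward pass unpacks the backbone's output into a (sequence_output, pooled_output) tuple, so the wrapped llm is expected to return a tuple rather than an output object. Below is a minimal usage sketch under that assumption, using a Hugging Face BERT backbone loaded with return_dict=False; the checkpoint name, dropout probability, and number of classes are illustrative choices, not values taken from this file.

# Usage sketch (assumptions: Hugging Face BERT backbone with return_dict=False;
# "bert-base-uncased", dropout_p=0.5, and num_classes=4 are hypothetical values).
from transformers import BertModel, BertTokenizer

llm = BertModel.from_pretrained("bert-base-uncased", return_dict=False)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

model = FinetunedLLM(
    llm=llm,
    dropout_p=0.5,
    embedding_dim=llm.config.hidden_size,  # 768 for bert-base-uncased
    num_classes=4,                         # hypothetical number of labels
)

# Tokenize a sample input and build the batch dict with the keys forward() expects.
encoded = tokenizer(["a sample input"], return_tensors="pt", padding=True)
batch = {"ids": encoded["input_ids"], "masks": encoded["attention_mask"]}
logits = model(batch)  # shape: (batch_size, num_classes)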