unilstm.py

"""Train a unidirectional LSTM for next-word prediction on pre-embedded
context/target pairs, logging the run to MLflow on DagsHub."""
import torch
import torch.nn as nn
import pickle
import matplotlib.pyplot as plt
import mlflow

from config import *

mlflow.set_tracking_uri('https://dagshub.com/SHENSHENZYC/next-word-prediction-with-LSTM.mlflow')

with open(EMBEDDED_CONTEXT_TRAIN_PATH, 'rb') as f:
    X_train = pickle.load(f)
with open(EMBEDDED_TARGET_TRAIN_PATH, 'rb') as f:
    y_train = pickle.load(f)

# create batches for training data
train_loader = torch.utils.data.DataLoader(dataset=list(zip(X_train, y_train)),
                                           batch_size=batch_size, shuffle=True)


# create an LSTM model for next-word prediction
class NWP_LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, word_vector_size, num_layers):
        super(NWP_LSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # lstm layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, word_vector_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # hidden state: short-term memories
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # cell state: long-term memories
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # only use the output from the last position in the sequence
        out = self.fc(out)
        return out


#with dagshub_logger(metrics_path='logs/train_metrics.csv', hparams_path='logs/train_params.yml') as logger:
model = NWP_LSTM(input_size, hidden_size, word_vector_size, num_layers)

# loss and optimizer
criterion = nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# log hyperparameters
mlflow.start_run()
mlflow.log_params({'model_class': 'unidirectional LSTM',
                   'input size': input_size,
                   'optimizer': 'adam',
                   'criterion': 'cosine embedding',
                   'epochs': num_epochs,
                   'batch size': batch_size,
                   'learning rate': learning_rate,
                   'hidden layer size': hidden_size,
                   'number of LSTM layers': num_layers,
                   'context window size': CONTEXT_WINDOW})

# train the model
n_total_steps = len(train_loader)
steps = []
losses = []
for epoch in range(num_epochs):
    for i, (contexts, targets) in enumerate(train_loader):
        # forward pass; a target of 1 tells the cosine embedding loss to pull
        # preds toward targets. Size the target to the actual batch, since the
        # last batch may be smaller than batch_size.
        preds = model(contexts)
        loss = criterion(preds, targets, torch.ones(contexts.size(0)))

        # backward and optimize
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (i + 1) % 100 == 0:
            # print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
            steps.append(epoch + 1 + (i + 1) / n_total_steps)
            losses.append(loss.item())
            # log against the global step so successive batches within an
            # epoch are not written to the same step
            mlflow.log_metric(key='training_loss', value=loss.item(),
                              step=epoch * n_total_steps + i + 1)

# test loss
with open(EMBEDDED_CONTEXT_TEST_PATH, 'rb') as f:
    X_test = torch.tensor(pickle.load(f))
with open(EMBEDDED_TARGET_TEST_PATH, 'rb') as f:
    y_test = torch.tensor(pickle.load(f))

model.eval()
with torch.no_grad():  # no gradients needed for evaluation
    test_preds = model(X_test)
    test_loss = criterion(test_preds, y_test, torch.ones(X_test.size(0)))
mlflow.log_metric(key='test_loss', value=test_loss.item())

# plot the training loss curve
plt.plot(steps, losses)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.savefig('plots/train_losses_lstm.png')

torch.save(model, LSTM_MODEL_PATH)
mlflow.log_artifact(LSTM_MODEL_PATH)
mlflow.end_run()
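
For reference, here is a minimal inference sketch showing how the saved model might be used to predict a word; it is not part of the original script. It assumes the same config names (LSTM_MODEL_PATH) and a hypothetical embedding matrix and vocabulary list, since the embedding step is handled outside this file. Because the network was trained with a cosine embedding loss, candidate words are ranked by cosine similarity to the predicted vector.

import torch
import torch.nn.functional as F

from config import *  # assumed, as above, to provide LSTM_MODEL_PATH

# the NWP_LSTM class must be defined/importable in this scope, since the full
# model object was saved; newer PyTorch versions may also need
# torch.load(..., weights_only=False) to unpickle it
model = torch.load(LSTM_MODEL_PATH)
model.eval()

def predict_next_word(context_vectors, embedding_matrix, vocab):
    # context_vectors: (CONTEXT_WINDOW, input_size) embeddings of the context words
    # embedding_matrix: (vocab_size, word_vector_size) word embeddings (hypothetical)
    # vocab: list of words aligned with embedding_matrix rows (hypothetical)
    with torch.no_grad():
        pred = model(context_vectors.unsqueeze(0))  # add batch dim -> (1, word_vector_size)
    # rank vocabulary words by cosine similarity to the predicted embedding
    sims = F.cosine_similarity(pred, embedding_matrix, dim=1)
    return vocab[sims.argmax().item()]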