Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_bert.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  1. import os
  2. import sys
  3. import json
  4. import yaml
  5. import torch
  6. import importlib
  7. import pandas as pd
  8. from pathlib import Path
  9. from dotenv import load_dotenv
  10. import logging
  11. logging.basicConfig(
  12. level=logging.DEBUG,
  13. format="%(asctime)s [%(levelname)s] %(message)s",
  14. handlers=[
  15. logging.FileHandler("debug.log"),
  16. logging.StreamHandler()
  17. ]
  18. )
  19. load_dotenv('envs/.env')
  20. with open('params.yaml', 'r') as f:
  21. PARAMS = yaml.safe_load(f)
  22. def start_training(bert_model, pretrained_model, method='basic'):
  23. try:
  24. model_module = importlib.import_module(f'model.{bert_model}.{method}')
  25. model = model_module.Model(
  26. **PARAMS[bert_model], **PARAMS[bert_model][method],
  27. pretrained_model=pretrained_model
  28. )
  29. except Exception as e:
  30. raise e
  31. if torch.cuda.is_available():
  32. device = torch.device('cuda', PARAMS.get('gpu', 0))
  33. else:
  34. device = torch.device('cpu')
  35. model.to(device)
  36. df = pd.read_csv('data/all.csv')
  37. try:
  38. dataloader_module = importlib.import_module(f'data_loader.{bert_model}_dataloaders')
  39. except Exception as e:
  40. raise e
  41. dataloader = dataloader_module.DataFrameDataLoader(
  42. df, pretrained_model=pretrained_model,
  43. do_lower_case=PARAMS[bert_model]['do_lower_case'],
  44. batch_size=PARAMS['train']['batch_size'],
  45. shuffle=PARAMS['validate']['shuffle'], max_len=PARAMS[bert_model]['max_len']
  46. )
  47. try:
  48. trainer_module = importlib.import_module(f'training.{bert_model}')
  49. bert_model_name = f'{bert_model}-{pretrained_model}-{method}'
  50. trainer = trainer_module.Trainer(model, dataloader, method=bert_model_name, mode='train')
  51. except Exception as e:
  52. raise e
  53. results, losses = trainer.train()
  54. columns = list(losses[0].keys())
  55. losses_df = pd.DataFrame(losses, columns=columns)
  56. return results, losses_df
  57. if __name__ == '__main__':
  58. bert_model, pretrained_model, method = sys.argv[1], sys.argv[2], sys.argv[3]
  59. try:
  60. results, losses_df = start_training(bert_model, pretrained_model, method)
  61. except Exception as e:
  62. logging.error(e)
  63. raise e
  64. results_path = Path(
  65. os.getenv('OUTPUT_PATH'),
  66. f'{bert_model}-{pretrained_model}-{method}_{os.getenv("RESULTS_PATH")}'
  67. )
  68. with open(results_path, 'w') as f:
  69. json.dump(results, f)
  70. plots_path = Path(
  71. os.getenv('OUTPUT_PATH'),
  72. f'{bert_model}-{pretrained_model}-{method}_{os.getenv("PLOTS_PATH")}'
  73. )
  74. losses_df.to_csv(plots_path, index=False)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...