training.py

  1. """Trains or fine-tunes a model for the task of monocular depth estimation
  2. Receives 1 arguments from argparse:
  3. <data_path> - Path to the dataset which is split into 2 folders - train and test.
  4. """
  5. import sys
  6. import yaml
  7. from fastai.vision.all import unet_learner, Path, resnet34, rmse, MSELossFlat
  8. from custom_data_loading import create_data
  9. from dagshub.fastai import DAGsHubLogger
  10. if __name__ == "__main__":
  11. # Check if got all needed input for argparse
  12. if len(sys.argv) != 2:
  13. print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
  14. sys.exit(0)
  15. with open(r"./src/code/params.yml") as f:
  16. params = yaml.safe_load(f)
  17. data = create_data(Path(sys.argv[1]))
  18. metrics = {'rmse': rmse}
  19. arch = {'resnet34': resnet34}
  20. loss = {'MSELossFlat': MSELossFlat()}
  21. learner = unet_learner(data,
  22. arch.get(params['architecture']),
  23. metrics=metrics.get(params['train_metric']),
  24. wd=float(params['weight_decay']),
  25. n_out=int(params['num_outs']),
  26. loss_func=loss.get(params['loss_func']),
  27. path=params['source_dir'],
  28. model_dir=params['model_dir'],
  29. cbs=DAGsHubLogger(
  30. metrics_path="logs/train_metrics.csv",
  31. hparams_path="logs/train_params.yml"))
  32. print("Training model...")
  33. learner.fine_tune(epochs=int(params['epochs']),
  34. base_lr=float(params['learning_rate']))
  35. print("Saving model...")
  36. learner.save('model')
  37. print("Done!")