
training.py

```python
"""Trains or fine-tunes a model for the task of monocular depth estimation.

Receives 1 argument from the command line:
<data_path> - Path to the dataset, which is split into 2 folders: train and test.
"""
import sys

from fastai.vision.all import unet_learner, Path, resnet34, rmse, MSELossFlat

from src.code.custom_data_loading import create_data
from dagshub.fastai import DAGsHubLogger

if __name__ == "__main__":
    # Check that exactly one command-line argument (the data path) was given
    if len(sys.argv) != 2:
        print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
        sys.exit(1)  # non-zero exit status signals a usage error

    # Build the DataLoaders from the train/test folders
    data = create_data(Path(sys.argv[1]))

    # Hyperparameters: weight decay, base learning rate, number of epochs
    wd, lr, ep = 1e-2, 1e-3, 1

    # U-Net with a ResNet-34 encoder, trained as pixel-wise regression:
    # MSE is the loss being optimized, RMSE is the reported metric
    learner = unet_learner(data,
                           resnet34,
                           metrics=rmse,
                           wd=wd,
                           n_out=3,
                           loss_func=MSELossFlat(),
                           path='src/',
                           model_dir='models',
                           # Write metrics and hyperparameters to files DAGsHub can track
                           cbs=DAGsHubLogger(
                               metrics_path="train_metrics.csv",
                               hparams_path="train_params.yml"
                           ))

    print("Training model...")
    learner.fine_tune(epochs=ep, base_lr=lr)
    print("Saving model...")
    learner.save('model')  # written to <path>/<model_dir>/model.pth, i.e. src/models/model.pth
    print("Done!")
```
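The docstring describes the input as an argparse argument, but the script actually reads `sys.argv` by hand. A minimal sketch of the same single-argument interface using the standard-library `argparse` module follows; `parse_args` is a hypothetical helper name, not part of this repository. argparse generates the usage message and exits with status 2 on missing arguments automatically.

```python
import argparse
from pathlib import Path


def parse_args(argv=None):
    """Parse the single positional <data_path> argument (hypothetical helper)."""
    parser = argparse.ArgumentParser(
        description="Train or fine-tune a monocular depth estimation model."
    )
    # type=Path converts the raw string into a pathlib.Path
    parser.add_argument(
        "data_path",
        type=Path,
        help="dataset directory containing train/ and test/ folders",
    )
    # argv=None makes argparse read sys.argv[1:]; passing a list helps testing
    return parser.parse_args(argv)
```

In the script, the manual length check and `Path(sys.argv[1])` could then collapse to `data = create_data(parse_args().data_path)`.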
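`create_data` comes from `src/code/custom_data_loading.py`, which is not shown here, so its exact expectations are unknown. Going only by the docstring (a data path split into `train` and `test` folders), a hedged sketch of a pre-flight check could fail fast with a clear message before any model is built. `check_dataset_layout` is hypothetical, and it assumes the folders are named exactly `train` and `test`.

```python
from pathlib import Path


def check_dataset_layout(data_path, required=("train", "test")):
    """Return data_path as a Path if it contains every required subfolder.

    Hypothetical helper: assumes the split folders are named exactly
    'train' and 'test', as the training script's docstring describes.
    """
    data_path = Path(data_path)
    # Collect every required folder that is absent (or not a directory)
    missing = [name for name in required if not (data_path / name).is_dir()]
    if missing:
        raise FileNotFoundError(
            f"{data_path} is missing folder(s): {', '.join(missing)}"
        )
    return data_path
```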
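The learner optimizes `MSELossFlat()` (mean squared error over flattened predictions and targets) but reports `rmse`. RMSE is the square root of MSE, so both rank models the same way, while RMSE is expressed in the target's own units (here, depth). The plain-Python sketch below illustrates the relationship on flat lists; it is not fastai's tensor implementation. Note that fastai's `rmse` metric accumulates predictions over the whole validation set before taking the root, which differs from averaging per-batch RMSE values.

```python
import math


def mse(preds, targets):
    """Mean squared error over flattened predictions and targets."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def rmse(preds, targets):
    """Root mean squared error: the square root of the MSE."""
    return math.sqrt(mse(preds, targets))
```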
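`learner.fine_tune(epochs=ep, base_lr=lr)` is fastai's transfer-learning recipe rather than a plain training loop. In fastai v2 it first freezes the pretrained encoder and trains for `freeze_epochs` (default 1) at `base_lr`, then halves `base_lr`, unfreezes all layers, and trains for `epochs` more with discriminative learning rates from `base_lr / lr_mult` (default `lr_mult=100`) up to the halved `base_lr`, using one-cycle scheduling in both phases. The helper below, `fine_tune_plan`, is hypothetical and trains nothing; it only spells out that schedule under fastai v2's defaults, which are worth checking against the installed version.

```python
def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """Describe the two phases of fastai's Learner.fine_tune (hypothetical helper).

    Mirrors fastai v2's default arguments; returns one dict per phase.
    """
    phases = [
        # Phase 1: pretrained encoder frozen; the new head trains at base_lr
        {"phase": "frozen", "epochs": freeze_epochs, "max_lr": base_lr},
    ]
    # Phase 2 runs at half the base rate with every layer unfrozen
    half_lr = base_lr / 2
    phases.append(
        # The earliest layer group gets half_lr / lr_mult, the last gets half_lr
        {"phase": "unfrozen", "epochs": epochs,
         "min_lr": half_lr / lr_mult, "max_lr": half_lr}
    )
    return phases
```

With the script's settings (`ep=1`, `lr=1e-3`), this gives one frozen epoch at `1e-3` followed by one unfrozen epoch spanning `5e-6` to `5e-4`, so `ep=1` actually means two epochs of training in total.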