Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

training.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  1. import torch
  2. import sys
  3. from fastai2.vision.all import *
  4. from torchvision.utils import save_image
  5. class ImageImageDataLoaders(DataLoaders):
  6. "Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
  7. @classmethod
  8. @delegates(DataLoaders.from_dblock)
  9. def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, **kwargs):
  10. "Create from list of `fnames` in `path`s with `label_func`."
  11. dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
  12. splitter=RandomSplitter(valid_pct, seed=seed),
  13. get_y=label_func,
  14. item_tfms=item_tfms,
  15. batch_tfms=batch_tfms)
  16. res = cls.from_dblock(dblock, fnames, path=path, **kwargs)
  17. return res
  18. def get_y_fn(x):
  19. y = str(x.absolute()).replace('.jpg', '_depth.png')
  20. y = Path(y)
  21. return y
  22. def create_data(data_path):
  23. fnames = get_files(data_path/'train', extensions='.jpg')
  24. data = ImageImageDataLoaders.from_label_func(data_path/'train', seed=42, bs=4, num_workers=0, fnames=fnames, label_func=get_y_fn)
  25. return data
  26. if __name__ == "__main__":
  27. if len(sys.argv) < 2:
  28. print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
  29. sys.exit(0)
  30. data = create_data(Path(sys.argv[1]))
  31. learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/')
  32. learner.fine_tune(1)
  33. learner.save('model')
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...