Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

eval.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
  1. import sys
  2. import yaml
  3. import torch
  4. from torchvision import transforms
  5. from fastai.vision.all import unet_learner, Path, resnet34, MSELossFlat, get_files, L, PILImageBW
  6. from custom_data_loading import create_data
  7. from eval_metric_calculation import compute_eval_metrics
  8. from dagshub import dagshub_logger
  9. from tqdm import tqdm
  10. if __name__ == "__main__":
  11. if len(sys.argv) < 2:
  12. print("usage: %s <test_data_path>" % sys.argv[0], file=sys.stderr)
  13. sys.exit(0)
  14. with open(r"./src/code/params.yml") as f:
  15. params = yaml.safe_load(f)
  16. data_path = Path(sys.argv[1])
  17. data = create_data(data_path)
  18. arch = {'resnet34': resnet34}
  19. loss = {'MSELossFlat': MSELossFlat()}
  20. learner = unet_learner(data,
  21. arch.get(params['architecture']),
  22. n_out=int(params['num_outs']),
  23. loss_func=loss.get(params['loss_func']),
  24. path='src/',
  25. model_dir='models')
  26. learner = learner.load('model')
  27. filenames = get_files(Path(data_path), extensions='.jpg')
  28. test_files = L([Path(i) for i in filenames])
  29. for i, sample in tqdm(enumerate(test_files.items),
  30. desc="Predicting on test images",
  31. total=len(test_files.items)):
  32. pred = learner.predict(sample)[0]
  33. pred = PILImageBW.create(pred).convert('L')
  34. pred.save("src/eval/" + str(sample.stem) + "_pred.png")
  35. if i < 10:
  36. pred.save("src/eval/examples/" + str(sample.stem) + "_pred.png")
  37. print("Calculating metrics...")
  38. metrics = compute_eval_metrics(test_files)
  39. with dagshub_logger(
  40. metrics_path="logs/test_metrics.csv",
  41. should_log_hparams=False
  42. ) as logger:
  43. # Metric logging
  44. logger.log_metrics(metrics)
  45. print("Evaluation Done!")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...