Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

results.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
  1. from six import BytesIO
  2. from wandb import Api
  3. import base64
  4. import csv
  5. import tempfile
  6. import os
  7. def results(project=None, run=None):
  8. r = Results(project, run)
  9. try:
  10. yield r
  11. finally:
  12. r.close()
  13. class Results(object):
  14. """Generates results to be compared in WandB
  15. with Results("project/run") as r:
  16. for truth, img_data in test_data:
  17. label, score = model.predict(img_data)
  18. img = array_to_img(img_data)
  19. r.write(input=r.encode_image(img), output=label,
  20. truth=truth, score=score)
  21. """
  22. def __init__(self, project=None, run=None):
  23. self.api = Api()
  24. self.project = project or self.api.settings("project")
  25. self.run = run or os.getenv("WANDB_RUN")
  26. self.tempfile = tempfile.NamedTemporaryFile(mode='w')
  27. self.csv = csv.writer(self.tempfile)
  28. self.csv.writerow(["input","output","probability","truth","loss"])
  29. self.rows = 0
  30. def __enter__(self):
  31. return self
  32. def __exit__(self, kind, value, extra):
  33. self.close()
  34. def encode_image(self, img, format="png"):
  35. """Accepts a PIL image and returns an encoded data uri"""
  36. buffer = BytesIO()
  37. img.save(buffer, format=format)
  38. return self.encode_data(buffer.getvalue(), format="image/%s" % format)
  39. def encode_data(self, data, format="image/png"):
  40. """Creates a data uri from raw data"""
  41. return "data:{format};base64,{img}".format(
  42. format=format,
  43. img=base64.b64encode(data).decode("UTF-8")
  44. )
  45. def write(self, **kwargs):
  46. self.rows += 1
  47. self.csv.writerow(
  48. [kwargs["input"], kwargs["output"], kwargs.get("probability"),
  49. kwargs["truth"], kwargs["loss"]])
  50. def close(self):
  51. self.tempfile.flush()
  52. self.api.push(self.project, {'results.csv': open(self.tempfile.name, "rb")},
  53. run=self.run)
  54. self.tempfile.close()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...