Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_public_api.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. test_wandb
  5. ----------------------------------
  6. Tests for the `wandb.apis.PublicApi` module.
  7. """
  8. import datetime
  9. import pytest
  10. import os
  11. import yaml
  12. import tempfile
  13. from .api_mocks import *
  14. from .api_mocks import _run, _query
  15. from click.testing import CliRunner
  16. import git
  17. import json
  18. from .utils import git_repo
  19. import h5py
  20. import numpy as np
  21. import wandb
  22. from wandb import Api
  23. from six import StringIO
  24. api = Api()
  25. def test_run_from_path(request_mocker, query_run_v2, query_download_h5):
  26. run_mock = query_run_v2(request_mocker)
  27. query_download_h5(request_mocker)
  28. run = api.run("test/test/test")
  29. assert run.summary_metrics == {"acc": 100, "loss": 0}
  30. def test_run_history(request_mocker, query_run_v2, query_download_h5):
  31. run_mock = query_run_v2(request_mocker)
  32. query_download_h5(request_mocker)
  33. run = api.run("test/test/test")
  34. assert run.history(pandas=False)[0] == {'acc': 10, 'loss': 90}
  35. def test_run_config(request_mocker, query_run_v2, query_download_h5):
  36. run_mock = query_run_v2(request_mocker)
  37. query_download_h5(request_mocker)
  38. run = api.run("test/test/test")
  39. assert run.config == {'epochs': 10}
  40. def test_run_history_system(request_mocker, query_run_v2, query_download_h5):
  41. run_mock = query_run_v2(request_mocker)
  42. query_download_h5(request_mocker)
  43. run = api.run("test/test/test")
  44. assert run.history(stream="system", pandas=False) == [
  45. {'cpu': 10}, {'cpu': 20}, {'cpu': 30}]
  46. def test_run_summary(request_mocker, query_run_v2, upsert_run, query_download_h5):
  47. run_mock = query_run_v2(request_mocker)
  48. query_download_h5(request_mocker)
  49. update_mock = upsert_run(request_mocker)
  50. run = api.run("test/test/test")
  51. run.summary.update({"cool": 1000})
  52. assert update_mock.called
  53. def test_runs_from_path(request_mocker, query_runs_v2, query_download_h5):
  54. runs_mock = query_runs_v2(request_mocker)
  55. query_download_h5(request_mocker)
  56. runs = api.runs("test/test")
  57. assert len(runs) == 4
  58. assert len(runs.runs) == 2
  59. assert runs[0].summary_metrics == {"acc": 100, "loss": 0}
  60. def test_runs_from_path_index(mocker, request_mocker, query_runs_v2, query_download_h5):
  61. runs_mock = query_runs_v2(request_mocker)
  62. query_download_h5(request_mocker)
  63. runs = api.runs("test/test")
  64. assert len(runs) == 4
  65. run_mock = mocker.patch.object(runs, 'more')
  66. run_mock.side_effect = [True, False]
  67. assert runs[3]
  68. assert len(runs.runs) == 4
  69. def test_read_advanced_summary(request_mocker, upsert_run, query_download_h5, query_upload_h5):
  70. run = _run()
  71. run["summaryMetrics"] = json.dumps({
  72. "special": {"_type": "numpy.ndarray", "min": 0, "max": 20},
  73. "normal": 32,
  74. "nested": {"deep": {"_type": "numpy.ndarray", "min": 0, "max": 20}}})
  75. _query('project', {'run': run})(request_mocker)
  76. file = os.path.join(tempfile.gettempdir(), "test.h5")
  77. with h5py.File(file, 'w') as h5:
  78. h5["summary/special"] = np.random.rand(100)
  79. h5["summary/nested.deep"] = np.random.rand(100)
  80. query_download_h5(request_mocker, content=open(file, "rb").read())
  81. api.flush()
  82. run = api.run("test/test/test")
  83. assert len(run.summary["special"]) == 100
  84. assert len(run.summary["nested"]["deep"]) == 100
  85. update_mock = upsert_run(request_mocker)
  86. h5_mock = query_upload_h5(request_mocker)
  87. run.summary.update({"nd_time": np.random.rand(1000)})
  88. assert len(run.summary["nd_time"]) == 1000
  89. # TODO: this passes locally, but fails consistently in CI?!?
  90. #assert h5_mock.called
  91. del run.summary["nd_time"]
  92. assert list(run.summary._h5["summary"].keys()) == [
  93. "nested.deep", "special"]
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...