Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_utils.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. import tempfile
  2. from pathlib import Path
  3. import numpy as np
  4. import pytest
  5. import torch
  6. from madewithml import utils
  def test_set_seed():
      """Calling ``utils.set_seeds()`` twice must replay the same NumPy random draws."""

      def draw_pair():
          # Two consecutive draws, so the check covers a sequence, not only the first sample.
          return np.random.randn(2, 3), np.random.randn(2, 3)

      utils.set_seeds()
      first_run = draw_pair()
      utils.set_seeds()
      second_run = draw_pair()
      for before, after in zip(first_run, second_run):
          assert np.array_equal(before, after)
  def test_save_and_load_dict():
      """Round-trip a dict through ``utils.save_dict`` and ``utils.load_dict``.

      The dict is written to ``d.json`` inside a throwaway temporary directory and
      read back. The loaded object must equal the original in full, not just at
      a single key.
      """
      original = {"hello": "world"}
      with tempfile.TemporaryDirectory() as dp:
          fp = Path(dp, "d.json")
          utils.save_dict(d=original, path=fp)
          # Use a separate name instead of rebinding the input, so the comparison
          # is made against the untouched original.
          loaded = utils.load_dict(path=fp)
          assert loaded == original
  def test_pad_array():
      """Ragged rows are right-padded with zeros up to the length of the longest row."""
      ragged = np.array([[1, 2], [1, 2, 3]], dtype="object")
      expected = np.array([[1, 2, 0], [1, 2, 3]])
      result = utils.pad_array(ragged)
      assert np.array_equal(result, expected)
  def test_collate_fn():
      """``utils.collate_fn`` pads ragged fields and converts every field to a typed tensor.

      Expected output: ``ids`` and ``masks`` zero-padded to the longest row as
      ``torch.int32``, and ``targets`` unchanged as ``torch.int64``.
      """
      batch = {
          "ids": np.array([[1, 2], [1, 2, 3]], dtype="object"),
          "masks": np.array([[1, 1], [1, 1, 1]], dtype="object"),
          "targets": np.array([3, 1]),
      }
      processed_batch = utils.collate_fn(batch)
      expected_batch = {
          "ids": torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.int32),
          "masks": torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32),
          "targets": torch.tensor([3, 1], dtype=torch.int64),
      }
      # Iterate the expected keys rather than ``batch`` (the input, which
      # collate_fn receives and may modify), so a missing output field
      # surfaces as a KeyError.
      for key, expected in expected_batch.items():
          # assert_close compares integer tensors exactly, checks shape and dtype,
          # and reports the mismatching elements on failure. allclose, by contrast,
          # applies float tolerances and only returns a bool.
          # check_device=False moves both tensors to CPU before comparing.
          # NOTE(review): collate_fn may place outputs on an accelerator; confirm.
          torch.testing.assert_close(processed_batch[key], expected, check_device=False)
  @pytest.mark.parametrize(
      "d, keys, expected",
      [
          ({"a": [1, 2], "b": [1, 2]}, ["a", "b"], [{"a": 1, "b": 1}, {"a": 2, "b": 2}]),
          ({"a": [1, 2], "b": [1, 2]}, ["a"], [{"a": 1}, {"a": 2}]),
      ],
  )
  def test_dict_to_list(d, keys, expected):
      """``utils.dict_to_list`` turns a dict of columns into a list of row dicts.

      Each row dict contains only the requested ``keys``. The parameter is named
      ``expected`` rather than ``list`` so the builtin is not shadowed.
      """
      assert utils.dict_to_list(d, keys=keys) == expected
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...