Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 3.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
  1. import json
  2. import os
  3. import random
  4. from typing import Any, Dict, List
  5. import numpy as np
  6. import torch
  7. from ray.data import DatasetContext
  8. from ray.train.torch import get_device
  9. from madewithml.config import mlflow
# Module import side effect: sets `preserve_order = True` on the execution
# options of the process-wide current Ray Data DatasetContext. Presumably this
# keeps dataset block/row ordering deterministic for reproducible runs —
# confirm against the Ray `ExecutionOptions` documentation.
DatasetContext.get_current().execution_options.preserve_order = True
  11. def set_seeds(seed: int = 42):
  12. """Set seeds for reproducibility."""
  13. np.random.seed(seed)
  14. random.seed(seed)
  15. torch.manual_seed(seed)
  16. torch.cuda.manual_seed(seed)
  17. eval("setattr(torch.backends.cudnn, 'deterministic', True)")
  18. eval("setattr(torch.backends.cudnn, 'benchmark', False)")
  19. os.environ["PYTHONHASHSEED"] = str(seed)
  20. def load_dict(path: str) -> Dict:
  21. """Load a dictionary from a JSON's filepath.
  22. Args:
  23. path (str): location of file.
  24. Returns:
  25. Dict: loaded JSON data.
  26. """
  27. with open(path) as fp:
  28. d = json.load(fp)
  29. return d
  30. def save_dict(d: Dict, path: str, cls: Any = None, sortkeys: bool = False) -> None:
  31. """Save a dictionary to a specific location.
  32. Args:
  33. d (Dict): data to save.
  34. path (str): location of where to save the data.
  35. cls (optional): encoder to use on dict data. Defaults to None.
  36. sortkeys (bool, optional): whether to sort keys alphabetically. Defaults to False.
  37. """
  38. directory = os.path.dirname(path)
  39. if directory and not os.path.exists(directory): # pragma: no cover
  40. os.makedirs(directory)
  41. with open(path, "w") as fp:
  42. json.dump(d, indent=2, fp=fp, cls=cls, sort_keys=sortkeys)
  43. fp.write("\n")
  44. def pad_array(arr: np.ndarray, dtype=np.int32) -> np.ndarray:
  45. """Pad an 2D array with zeros until all rows in the
  46. 2D array are of the same length as a the longest
  47. row in the 2D array.
  48. Args:
  49. arr (np.array): input array
  50. Returns:
  51. np.array: zero padded array
  52. """
  53. max_len = max(len(row) for row in arr)
  54. padded_arr = np.zeros((arr.shape[0], max_len), dtype=dtype)
  55. for i, row in enumerate(arr):
  56. padded_arr[i][: len(row)] = row
  57. return padded_arr
  58. def collate_fn(batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]: # pragma: no cover, air internal
  59. """Convert a batch of numpy arrays to tensors (with appropriate padding).
  60. Args:
  61. batch (Dict[str, np.ndarray]): input batch as a dictionary of numpy arrays.
  62. Returns:
  63. Dict[str, torch.Tensor]: output batch as a dictionary of tensors.
  64. """
  65. batch["ids"] = pad_array(batch["ids"])
  66. batch["masks"] = pad_array(batch["masks"])
  67. dtypes = {"ids": torch.int32, "masks": torch.int32, "targets": torch.int64}
  68. tensor_batch = {}
  69. for key, array in batch.items():
  70. tensor_batch[key] = torch.as_tensor(array, dtype=dtypes[key], device=get_device())
  71. return tensor_batch
  72. def get_run_id(experiment_name: str, trial_id: str) -> str: # pragma: no cover, mlflow functionality
  73. """Get the MLflow run ID for a specific Ray trial ID.
  74. Args:
  75. experiment_name (str): name of the experiment.
  76. trial_id (str): id of the trial.
  77. Returns:
  78. str: run id of the trial.
  79. """
  80. trial_name = f"TorchTrainer_{trial_id}"
  81. run = mlflow.search_runs(experiment_names=[experiment_name], filter_string=f"tags.trial_name = '{trial_name}'").iloc[0]
  82. return run.run_id
  83. def dict_to_list(data: Dict, keys: List[str]) -> List[Dict[str, Any]]:
  84. """Convert a dictionary to a list of dictionaries.
  85. Args:
  86. data (Dict): input dictionary.
  87. keys (List[str]): keys to include in the output list of dictionaries.
  88. Returns:
  89. List[Dict[str, Any]]: output list of dictionaries.
  90. """
  91. list_of_dicts = []
  92. for i in range(len(data[keys[0]])):
  93. new_dict = {key: data[key][i] for key in keys}
  94. list_of_dicts.append(new_dict)
  95. return list_of_dicts
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...