Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

externel_dataset_interface_test.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
  1. import torch
  2. import unittest
  3. import numpy as np
  4. import tensorflow.keras as keras
  5. from super_gradients.training.datasets.dataset_interfaces.dataset_interface import ExternalDatasetInterface
  6. class DataGenerator(keras.utils.Sequence):
  7. def __init__(self, batch_size=1, dim=(320, 320), n_channels=3,
  8. n_classes=1000, shuffle=True):
  9. self.dim = dim
  10. self.batch_size = batch_size
  11. self.list_IDs = np.ones(1000)
  12. self.n_channels = n_channels
  13. self.n_classes = n_classes
  14. self.shuffle = shuffle
  15. self.on_epoch_end()
  16. def __len__(self):
  17. dataset_len = 32
  18. return dataset_len
  19. def __getitem__(self, index):
  20. indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
  21. list_IDs_temp = [self.list_IDs[k] for k in indices]
  22. X, y = self.__data_generation(list_IDs_temp)
  23. return X.squeeze(axis=0), y.squeeze(axis=0)
  24. def on_epoch_end(self):
  25. self.indices = np.arange(len(self.list_IDs))
  26. if self.shuffle:
  27. np.random.shuffle(self.indices)
  28. def __data_generation(self, list_IDs_temp):
  29. X = np.ones((self.batch_size, self.n_channels, *self.dim), dtype=np.float32)
  30. y = np.ones((self.batch_size, 1), dtype=np.float32)
  31. return X, y
  32. class TestExternalDatasetInterface(unittest.TestCase):
  33. def setUp(self):
  34. params = {'dim': (256, 256),
  35. 'batch_size': 1,
  36. 'n_classes': 1000,
  37. 'n_channels': 3,
  38. 'shuffle': True}
  39. training_generator = DataGenerator(**params)
  40. testing_generator = DataGenerator(**params)
  41. external_num_classes = 1000
  42. external_dataset_params = {'batch_size': 16,
  43. "val_batch_size": 16}
  44. self.dim = params['dim'][0]
  45. self.n_channels = params['n_channels']
  46. self.batch_size = external_dataset_params['batch_size']
  47. self.val_batch_size = external_dataset_params['val_batch_size']
  48. self.test_external_dataset_interface = ExternalDatasetInterface(train_loader=training_generator,
  49. val_loader=testing_generator,
  50. num_classes=external_num_classes,
  51. dataset_params=external_dataset_params)
  52. def test_get_data_loaders(self):
  53. train_loader, val_loader, _, num_classes = self.test_external_dataset_interface.get_data_loaders()
  54. for batch_idx, (inputs, targets) in enumerate(train_loader):
  55. self.assertListEqual([self.batch_size, self.n_channels, self.dim, self.dim], list(inputs.shape))
  56. self.assertListEqual([self.batch_size, 1], list(targets.shape))
  57. self.assertEqual(torch.Tensor, type(inputs))
  58. self.assertEqual(torch.Tensor, type(targets))
  59. for batch_idx, (inputs, targets) in enumerate(val_loader):
  60. self.assertListEqual([self.val_batch_size, self.n_channels, self.dim, self.dim], list(inputs.shape))
  61. self.assertListEqual([self.val_batch_size, 1], list(targets.shape))
  62. self.assertEqual(torch.Tensor, type(inputs))
  63. self.assertEqual(torch.Tensor, type(targets))
  64. if __name__ == '__main__':
  65. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...