Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

repeated_mnist.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.utils.data as data
from torch.utils.data import Dataset
from torchvision import datasets, transforms
  5. class TransformedDataset(data.Dataset):
  6. """
  7. Transforms a dataset.
  8. Arguments:
  9. dataset (Dataset): The whole Dataset
  10. transformer (LambdaType): (idx, sample) -> transformed_sample
  11. """
  12. def __init__(self, dataset, transformer=None, vision_transformer=None):
  13. self.dataset = dataset
  14. assert not transformer or not vision_transformer
  15. if transformer:
  16. self.transformer = transformer
  17. else:
  18. self.transformer = lambda _, data_label: (vision_transformer(data_label[0]), data_label[1])
  19. def __getitem__(self, idx):
  20. return self.transformer(idx, self.dataset[idx])
  21. def __len__(self):
  22. return len(self.dataset)
  23. def create_repeated_MNIST_dataset(self, num_repetitions= 3,
  24. add_noise = True):
  25. # num_classes = 10, input_size = 28
  26. transform = transforms.Compose(
  27. [transforms.ToTensor(),
  28. transforms.Normalize((0.1307, ), (0.3081, ))])
  29. train_dataset = datasets.MNIST("data",
  30. train=True,
  31. download=True,
  32. transform=transform)
  33. if num_repetitions > 1:
  34. train_dataset = data.ConcatDataset([train_dataset] * num_repetitions)
  35. if add_noise:
  36. dataset_noise = torch.empty((len(train_dataset), 28, 28),
  37. dtype=torch.float32).normal_(0.0, 0.1)
  38. def apply_noise(idx, sample):
  39. data, target = sample
  40. return data + dataset_noise[idx], target
  41. train_dataset = TransformedDataset(train_dataset,
  42. transformer=apply_noise)
  43. test_dataset = datasets.MNIST("data", train=False, transform=transform)
  44. return train_dataset, test_dataset
  45. def create_MNIST_dataset():
  46. return self.create_repeated_MNIST(num_repetitions=1, add_noise=False)
  47. def get_targets(dataset):
  48. """Get the targets of a dataset without any target transforms.
  49. This supports subsets and other derivative datasets."""
  50. if isinstance(dataset, TransformedDataset):
  51. return get_targets(dataset.dataset)
  52. if isinstance(dataset, data.Subset):
  53. targets = get_targets(dataset.dataset)
  54. return torch.as_tensor(targets)[dataset.indices]
  55. if isinstance(dataset, data.ConcatDataset):
  56. return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])
  57. return torch.as_tensor(dataset.targets)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...