dataset.py 2.0 KB

"""
This file defines the dataset used for training.
"""
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from super_gradients.training import utils as core_utils
from super_gradients.training.datasets.dataset_interfaces import DatasetInterface


class UserDataset(DatasetInterface):
    """
    The user's dataset inherits from SuperGradients' DatasetInterface and must
    contain a trainset and a testset from which the data will be loaded.
    All augmentations, resizing and parsing must be done in this class.
    - Augmentations are defined below and are applied in the order they are given.
    super_gradients provides additional dataset reading tools, such as ListDataset,
    which takes a list of files corresponding to the images and labels.
    """

    def __init__(self, name="cifar10", dataset_params=None):
        # Avoid a shared mutable default; an omitted argument behaves like {}.
        # Note: dataset_params must provide 'dataset_dir' (used below).
        super(UserDataset, self).__init__(dataset_params or {})
        self.dataset_name = name
        # Per-channel (R, G, B) mean and standard deviation used for normalization.
        self.lib_dataset_params = {'mean': (0.4914, 0.4822, 0.4465), 'std': (0.2023, 0.1994, 0.2010)}

        crop_size = core_utils.get_param(self.dataset_params, 'crop_size', default_val=32)

        # Training pipeline: random augmentations first, then tensor conversion and normalization.
        transform_train = transforms.Compose([
            transforms.RandomCrop(crop_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
        ])

        # Test pipeline: no augmentation, only tensor conversion and normalization.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.lib_dataset_params['mean'], self.lib_dataset_params['std']),
        ])

        self.trainset = datasets.CIFAR10(root=self.dataset_params.dataset_dir, train=True, download=True,
                                         transform=transform_train)

        self.testset = datasets.CIFAR10(root=self.dataset_params.dataset_dir, train=False, download=True,
                                        transform=transform_test)
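
Both transform pipelines end with `transforms.Normalize`, which computes `(value - mean) / std` for each channel, using the statistics stored in `lib_dataset_params`. The following is a minimal pure-Python sketch of that arithmetic for a single RGB pixel. It is an illustration only: `normalize_pixel` and `denormalize_pixel` are hypothetical helpers, not part of torchvision or super_gradients, and the input is assumed to be already scaled to [0, 1], as `ToTensor` does for 8-bit images.

```python
# Per-channel statistics, copied from lib_dataset_params in the file above.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Hypothetical helper: apply (value - mean) / std to each channel.

    `rgb` holds channel values already scaled to the [0, 1] range.
    """
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))


def denormalize_pixel(rgb, mean=MEAN, std=STD):
    """Hypothetical inverse: map normalized values back to the [0, 1] range."""
    return tuple(v * s + m for v, m, s in zip(rgb, mean, std))


if __name__ == "__main__":
    mid_grey = (0.5, 0.5, 0.5)
    normalized = normalize_pixel(mid_grey)
    print(normalized)                     # channel values centred on the dataset mean
    print(denormalize_pixel(normalized))  # approximately the original pixel
```

A pixel equal to the dataset mean maps to zero in every channel, so the normalized inputs are roughly zero-centred with unit spread. The inverse helper can be handy when normalized images need to be displayed.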