Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test-data-interface.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  1. import torch
  2. from super_gradients.training.datasets.dataset_interfaces import DatasetInterface
  3. from super_gradients.training.sg_model import SgModel
  4. from torchvision.models import resnet18
  5. import numpy as np
  6. class TestDatasetInterface(DatasetInterface):
  7. def __init__(self, dataset_params={}, image_size=32, batch_size=5):
  8. super(TestDatasetInterface, self).__init__(dataset_params)
  9. self.trainset = torch.utils.data.TensorDataset(torch.Tensor(np.zeros((batch_size, 3, image_size, image_size))),
  10. torch.LongTensor(np.zeros((batch_size))))
  11. self.testset = self.trainset
  12. def get_data_loaders(self, batch_size_factor=1, num_workers=8, train_batch_size=None, test_batch_size=None,
  13. distributed_sampler=False):
  14. self.trainset.classes = [0, 1, 2, 3, 4]
  15. return super().get_data_loaders(batch_size_factor=batch_size_factor,
  16. num_workers=num_workers,
  17. train_batch_size=train_batch_size,
  18. test_batch_size=test_batch_size,
  19. distributed_sampler=distributed_sampler)
  20. # ------------------ Loading The Model From Model.py----------------
  21. arch_params = {'num_classes': 1000}
  22. model = resnet18()
  23. sg_classification_model = SgModel('Client_model_training',
  24. model_checkpoints_location='local', device='cpu')
  25. # if a torch.nn.Module is provided when building the model, the model will be integrated into deci model class
  26. sg_classification_model.build_model(model, arch_params=arch_params, load_checkpoint=False)
  27. # ------------------ Loading The Dataset From Dataset.py----------------
  28. dataset = TestDatasetInterface()
  29. sg_classification_model.connect_dataset_interface(dataset)
  30. # ------------------ Loading The Loss From Loss.py -----------------
  31. loss = 'cross_entropy'
  32. # ------------------ Training -----------------
  33. train_params = {"max_epochs": 100,
  34. "lr_mode": "step",
  35. "lr_updates": [30, 60, 90, 100],
  36. "lr_decay_factor": 0.1,
  37. "initial_lr": 0.025, "loss": loss}
  38. sg_classification_model.train(train_params)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...