Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
@@ -6,7 +6,6 @@ from super_gradients import Trainer
 import torch
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from torch.utils.data import TensorDataset, DataLoader
 from super_gradients.training.metrics import Accuracy
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.sg_trainer_exceptions import IllegalDataloaderInitialization
 
 
 
 
 class InitializeWithDataloadersTest(unittest.TestCase):
 class InitializeWithDataloadersTest(unittest.TestCase):
@@ -26,27 +25,8 @@ class InitializeWithDataloadersTest(unittest.TestCase):
         label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
         label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))
 
 
-    def test_initialization_rules(self):
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, test_loader=self.testcase_testloader,
-                classes=self.testcase_classes)
-
     def test_train_with_dataloaders(self):
     def test_train_with_dataloaders(self):
-        trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local")
+        trainer = Trainer(experiment_name="test_name")
         model = models.get("resnet18", num_classes=5)
         model = models.get("resnet18", num_classes=5)
         trainer.train(model=model,
         trainer.train(model=model,
                       training_params={"max_epochs": 2,
                       training_params={"max_epochs": 2,
Discard
Tip!

Press p or to see the previous file or, n or to see the next file