Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
@@ -38,7 +38,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup(self):
     def test_lr_warmup(self):
         # Define Model
         # Define Model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -60,7 +60,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup_with_lr_scheduling(self):
     def test_lr_warmup_with_lr_scheduling(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -85,7 +85,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_warmup_initial_lr(self):
     def test_warmup_initial_lr(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
+        trainer = Trainer("test_warmup_initial_lr")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -107,7 +107,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_custom_lr_warmup(self):
     def test_custom_lr_warmup(self):
         # Define model
         # Define model
         net = LeNet()
         net = LeNet()
-        trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("custom_lr_warmup_test")
 
 
         lrs = []
         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
Discard
Tip!

Press p or to see the previous file or, n or to see the next file