|
@@ -71,6 +71,24 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
|
|
model.train(train_params)
|
|
model.train(train_params)
|
|
assert lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1
|
|
assert lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1
|
|
|
|
|
|
|
|
+ def test_train_with_external_scheduler_class(self):
|
|
|
|
+ model = SgModel("external_scheduler_test", model_checkpoints_location='local')
|
|
|
|
+ dataset_params = {"batch_size": 10}
|
|
|
|
+ dataset = ClassificationTestDatasetInterface(dataset_params=dataset_params)
|
|
|
|
+ model.connect_dataset_interface(dataset)
|
|
|
|
+
|
|
|
|
+ net = ResNet18(num_classes=5, arch_params={})
|
|
|
|
+ optimizer = SGD # a class - not an instance
|
|
|
|
+ model.build_model(net)
|
|
|
|
+
|
|
|
|
+ train_params = {"max_epochs": 2,
|
|
|
|
+ "lr_warmup_epochs": 0, "initial_lr": 0.3, "loss": "cross_entropy", "optimizer": optimizer,
|
|
|
|
+ "criterion_params": {},
|
|
|
|
+ "train_metrics_list": [Accuracy(), Top5()], "valid_metrics_list": [Accuracy(), Top5()],
|
|
|
|
+ "loss_logging_items_names": ["Loss"], "metric_to_watch": "Accuracy",
|
|
|
|
+ "greater_metric_to_watch_is_better": True}
|
|
|
|
+ model.train(train_params)
|
|
|
|
+
|
|
def test_train_with_reduce_on_plateau(self):
|
|
def test_train_with_reduce_on_plateau(self):
|
|
model = SgModel("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
|
|
model = SgModel("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
|
|
dataset_params = {"batch_size": 10}
|
|
dataset_params = {"batch_size": 10}
|