test_train.py

import json

import pytest
import utils

from madewithml import train


@pytest.mark.training
def test_train_model(dataset_loc):
    # dataset_loc is supplied by a pytest fixture; each run gets its own experiment name.
    experiment_name = utils.generate_experiment_name(prefix="test_train")
    train_loop_config = {"dropout_p": 0.5, "lr": 1e-4, "lr_factor": 0.8, "lr_patience": 3}
    # Small, CPU-only run (2 epochs, 512 samples) so the test works as a quick smoke test.
    result = train.train_model(
        experiment_name=experiment_name,
        dataset_loc=dataset_loc,
        train_loop_config=json.dumps(train_loop_config),  # passed as a JSON string
        num_workers=6,
        cpu_per_worker=1,
        gpu_per_worker=0,
        num_epochs=2,
        num_samples=512,
        batch_size=256,
        results_fp=None,
    )
    # Clean up the experiment created by this test.
    utils.delete_experiment(experiment_name=experiment_name)
    # to_dict() maps each column to {row index: value}, so [0] and [1] are the first two epochs.
    train_loss_list = result.metrics_dataframe.to_dict()["train_loss"]
    assert train_loss_list[0] > train_loss_list[1]  # loss decreased
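The test above is a smoke test: it trains for two epochs on a small sample and only checks that the training loss went down between them. The same pattern can be sketched without Ray or the madewithml package. In this minimal sketch, `train_toy_model` is a hypothetical stand-in for `train.train_model` (it is not part of the project): it parses a JSON config the same way the real test passes one, fits y = 2x with full-batch gradient descent, and returns the loss recorded after each epoch.

```python
import json
from typing import List


def train_toy_model(train_loop_config: str, num_epochs: int) -> List[float]:
    """Hypothetical stand-in for train.train_model.

    Fits y = 2x with full-batch gradient descent on mean squared error and
    returns the loss measured after each epoch.
    """
    config = json.loads(train_loop_config)  # config arrives as a JSON string
    lr = config["lr"]
    xs = [i / 10 for i in range(10)]
    ys = [2.0 * x for x in xs]
    w = 0.0
    losses = []
    for _ in range(num_epochs):
        # Gradient of mean((w*x - y)^2) with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad
        losses.append(sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs))
    return losses


def test_toy_loss_decreases():
    losses = train_toy_model(json.dumps({"lr": 0.1}), num_epochs=2)
    assert losses[0] > losses[1]  # loss decreased between the first two epochs
```

In the real suite the test is tagged with `@pytest.mark.training`, so it can be selected with `pytest -m training` or excluded with `pytest -m "not training"`. Custom markers like this should be registered under `markers` in the pytest configuration to avoid an unknown-marker warning.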