test_classifier.py 463 B

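```python
from pytorch_lightning import Trainer, seed_everything
from project.lit_mnist import LitClassifier
from project.datasets.mnist import mnist


def test_lit_classifier():
    # Fix all random seeds so the short training run is reproducible.
    seed_everything(1234)

    model = LitClassifier()
    train, val, test = mnist()

    # Cap the batches per epoch to keep the test fast.
    trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2)
    trainer.fit(model, train, val)

    # Newer Lightning releases take `dataloaders=`; older ones used `test_dataloaders=`.
    results = trainer.test(model, dataloaders=test)
    # trainer.test returns a list with one dict of logged metrics per test dataloader.
    assert results[0]['test_acc'] > 0.7
```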
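The test keeps its runtime short by capping how many batches Lightning runs. As a rough sketch of that accounting (`effective_batches` is a hypothetical helper, not part of Lightning's API), an int `limit_*_batches` value caps the batch count per epoch, while a float keeps that fraction of the loader:

```python
def effective_batches(num_batches: int, limit) -> int:
    """Approximate batches run per epoch for a given limit_*_batches value.

    Hypothetical sketch: an int caps the count, a float keeps a fraction.
    """
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(limit, bool):
        raise TypeError("limit must be an int or float, not bool")
    if isinstance(limit, int):
        # An int is an absolute cap; a short loader simply runs out first.
        return min(limit, num_batches)
    if isinstance(limit, float):
        if not 0.0 <= limit <= 1.0:
            raise ValueError("a float limit must be in [0.0, 1.0]")
        # A float keeps that fraction of the loader, rounded down.
        return int(num_batches * limit)
    raise TypeError("limit must be an int or float")


# Training work for max_epochs=2 with limit_train_batches=50,
# assuming a (hypothetical) loader of 1719 batches.
print(2 * effective_batches(1719, 50))
```

The actual loader sizes depend on the batch size used inside `mnist()`, which is not shown here; any loader with at least 50 batches gives 100 training batches across the two epochs. The sketch also ignores the validation sanity-check batches that Lightning runs before training by default.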