Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_pipeline.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  1. from pathlib import Path
  2. import pytest
  3. import torch
  4. import torchvision.transforms as transforms
  5. import yaml
  6. from torch.utils.data import DataLoader
  7. from torchvision.datasets import ImageFolder
  8. from src.inspect import visualize_batch
  9. from src.train import train_model, tune
  10. from src.utils import TumorClassifier
  @pytest.fixture(autouse=True)
  def patch_visualize(monkeypatch):
      """Swap ``src.inspect.visualize_batch`` for a no-op in every test.

      ``autouse=True`` applies this to each test automatically so that no
      plotting happens during the test run.

      NOTE(review): ``monkeypatch.setattr`` with a dotted string target rebinds
      the attribute on the ``src.inspect`` module object only. Any module that
      already executed ``from src.inspect import visualize_batch`` (this test
      file does, at import time) keeps its own reference to the real function.
      Confirm how the code under test reaches ``visualize_batch`` so that this
      patch actually takes effect.
      """

      def _silent_visualize(*_args, **_kwargs):
          return None

      monkeypatch.setattr("src.inspect.visualize_batch", _silent_visualize)
  @pytest.fixture(scope="module")
  def config():
      """Load ``params.yaml`` once per test module.

      The path is relative, so it is resolved against the working directory
      pytest was launched from.

      Returns:
          The object produced by ``yaml.safe_load`` for the file. The
          ``dataloaders`` fixture in this module indexes it as a nested mapping
          (``config["data"]["raw_data_path"]``).
      """
      params_path = Path("params.yaml")
      with params_path.open(encoding="utf-8") as stream:
          parsed = yaml.safe_load(stream)
      return parsed
  @pytest.fixture(scope="module")
  def dataloaders(config):
      """Build small-batch train/validation loaders over the raw image folders.

      Reads ``config["data"]["raw_data_path"]`` and loads its ``Training`` and
      ``Testing`` sub-directories with ``torchvision.datasets.ImageFolder``.
      Every image is resized to 224x224, converted to a tensor, and normalized
      per channel with the commonly used ImageNet mean/std values.

      Returns:
          ``(train_loader, val_loader)``. Both use batch size 4 and in-process
          loading (``num_workers=0``). Only the training loader shuffles.
      """
      root = Path(config["data"]["raw_data_path"])
      preprocess = transforms.Compose([
          transforms.Resize((224, 224)),
          transforms.ToTensor(),
          transforms.Normalize(
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225],
          ),
      ])

      def _make_loader(split, shuffle):
          # One ImageFolder per split, all sharing the same preprocessing.
          dataset = ImageFolder(root / split, transform=preprocess)
          return DataLoader(dataset, batch_size=4, shuffle=shuffle, num_workers=0)

      return _make_loader("Training", True), _make_loader("Testing", False)
  @pytest.fixture(scope="module")
  def device():
      """Choose a torch device, preferring Apple MPS, then CUDA, then CPU."""
      if torch.backends.mps.is_available():
          backend = "mps"
      elif torch.cuda.is_available():
          backend = "cuda"
      else:
          backend = "cpu"
      return torch.device(backend)
  def test_dataloaders(dataloaders):
      """Both loaders yield well-formed batches and share one class mapping.

      For the first batch of the training AND validation loader, check that:
        * the batch size is between 1 and the loader's configured ``batch_size``;
        * images are ``(B, 3, 224, 224)``;
        * there is exactly one label per image, and labels are integers.
      Also check that both datasets assign identical class indices. Each
      ImageFolder builds ``class_to_idx`` from its own sub-directory names, so a
      missing or misspelled class folder in one split would silently shift the
      labels.
      """
      train_loader, val_loader = dataloaders

      assert train_loader.dataset.class_to_idx == val_loader.dataset.class_to_idx

      for loader in (train_loader, val_loader):
          images, labels = next(iter(loader))
          assert 0 < images.shape[0] <= loader.batch_size
          assert images.shape[1:] == (3, 224, 224)
          # One label per image in the batch.
          assert labels.shape == (images.shape[0],)
          assert isinstance(labels[0].item(), int)
  def test_tune_returns_params(dataloaders, device):
      """``tune`` returns a dict of hyper-parameters for ``TumorClassifier``.

      The key check is intentionally loose: at least one of ``"lr"`` or
      ``"dropout"`` must be present.

      NOTE(review): ``tune`` runs on the full raw-data loaders here. If it
      performs a complete search this test may be slow, so confirm its cost.
      """
      train_loader, val_loader = dataloaders
      best = tune(TumorClassifier, train_loader, val_loader, device)

      assert isinstance(best, dict)
      expected_keys = ("lr", "dropout")  # example keys the search is expected to produce
      assert any(key in best for key in expected_keys)
  def test_train_model(dataloaders, device):
      """Smoke-test ``train_model`` with ``epochs=1`` and ``register=False``.

      Passes a minimal hyper-parameter set and only checks that a dict
      containing ``"accuracy"`` or ``"f1_score"`` comes back.
      """
      train_loader, val_loader = dataloaders
      # NOTE(review): the classifier uses its default constructor, so the
      # "dropout" entry below only matters if train_model applies it. Confirm.
      classifier = TumorClassifier()
      # Minimal settings to keep the test run short.
      hyperparams = dict(lr=0.001, dropout=0.5, epochs=1, lr_decay=0.95)

      metrics = train_model(
          classifier, train_loader, val_loader, device, hyperparams, register=False
      )

      assert isinstance(metrics, dict)
      assert any(name in metrics for name in ("accuracy", "f1_score"))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...