
tumor_classifier.py

```python
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms

MODEL_PATH = Path(__file__).parent / "models" / "best_model.pt"

# Make sure this matches your training preprocessing
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Index order must match the label order used during training
class_names = ["glioma", "meningioma", "notumor", "pituitary"]


def load_model():
    # Assumes best_model.pt holds a full pickled model (saved with torch.save(model)),
    # not a state_dict. PyTorch >= 2.6 defaults to weights_only=True, which refuses
    # to unpickle arbitrary objects, so opt out explicitly. Only load files you trust.
    model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
    model.eval()  # disable dropout and use running batch-norm statistics
    return model


def predict_image(model, image: Image.Image) -> str:
    # MRI scans are often grayscale or RGBA; Normalize above expects 3 channels
    image = image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(image_tensor)
    predicted = torch.argmax(outputs, dim=1).item()
    return class_names[predicted]
```