model_wrapper.py

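import mlflow
import torch


class SquirrelDetectorWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper around a custom-trained YOLOv5 squirrel detector."""

    def load_context(self, context):
        # Load the custom YOLOv5 weights logged as an MLflow artifact under the key 'path'
        self.model = torch.hub.load('ultralytics/yolov5', 'custom', path=context.artifacts['path'])

    def predict(self, context, img):
        # .xywh[0]: detections for the first image, rows of (x_center, y_center, w, h, conf, class)
        objs = self.model(img).xywh[0]
        # Move to CPU first, because .numpy() raises an error on CUDA tensors
        return objs.cpu().numpy()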
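The wrapper's `predict` returns the YOLOv5 `xywh` detections for the first image as a NumPy array, one row per detection: center x, center y, width, height, confidence, and class index. As a minimal sketch of consuming that output, the hypothetical helper `to_corner_boxes` below (not part of the original file) keeps detections at or above an assumed confidence threshold of 0.5 and converts the center-based boxes to corner coordinates.

```python
import numpy as np


def to_corner_boxes(objs: np.ndarray, min_conf: float = 0.5) -> np.ndarray:
    """Hypothetical helper: filter xywh detections and convert them to corner format.

    Assumes each row is [x_center, y_center, width, height, confidence, class],
    the layout of a YOLOv5 ``results.xywh[i]`` tensor converted to NumPy.
    Returns rows of [x1, y1, x2, y2, confidence, class].
    """
    # Normalize input to an (N, 6) float array; an empty input becomes shape (0, 6)
    objs = np.asarray(objs, dtype=float).reshape(-1, 6)
    # Keep only detections whose confidence (column 4) meets the assumed threshold
    kept = objs[objs[:, 4] >= min_conf]
    xc, yc, w, h = kept[:, 0], kept[:, 1], kept[:, 2], kept[:, 3]
    # Center/size -> top-left and bottom-right corners
    corners = np.stack([xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2], axis=1)
    # Re-attach confidence and class columns
    return np.hstack([corners, kept[:, 4:6]])


# Illustrative input only: two made-up detections, one below the threshold
dets = np.array([
    [100.0, 50.0, 40.0, 20.0, 0.9, 0.0],
    [10.0, 10.0, 4.0, 4.0, 0.3, 0.0],
])
print(to_corner_boxes(dets))
```

Corner format is what many drawing and IoU utilities expect, so converting once after `predict` keeps the wrapper itself unchanged; the threshold is a parameter because a suitable cut-off depends on the trained model.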