Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. import pprint
  2. from typing import List, Dict, Optional
  3. from label_studio_ml.model import LabelStudioMLBase
  4. from label_studio_ml.response import ModelResponse
  5. from label_studio_sdk.converter import brush
  6. from uuid import uuid4
  7. import json
  8. class DagsHubLSModel(LabelStudioMLBase):
  9. """Custom ML Backend model
  10. """
  11. def __init__(self):
  12. pass
  13. def configure(self, model, pre_hook, post_hook, ds, dp_map):
  14. self.model = model
  15. self.pre_hook = pre_hook
  16. self.post_hook = post_hook
  17. self.ds = ds
  18. self.dp_map = dp_map
  19. print(f'''\
  20. Configured with model: {model}
  21. Dataset: {ds}
  22. ''')
  23. def setup(self):
  24. self.set("model_version", "0.0.1")
  25. def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
  26. print(f'''\
  27. Run prediction on {tasks}
  28. Received context: {context}
  29. Project ID: {self.project_id}
  30. Label config: {self.label_config}
  31. Parsed JSON Label config: {self.parsed_label_config}
  32. Extra params: {self.extra_params}''')
  33. tasks = [(self.ds['path'] == self.dp_map[self.dp_map['datapoint_id'] == task['meta']['datapoint_id']].iloc[0].path).head()[0].download_file().as_posix() for task in tasks] # get local path
  34. res = self.post_hook(self.model.predict(self.pre_hook(tasks)))
  35. print(f'''\
  36. Returning: {res}
  37. ''')
  38. return res
  39. def fit(self, event, data, **kwargs):
  40. pass
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...