Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_ls.py 4.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
from urllib.parse import quote

import dagshub
import mlflow
from label_studio_ml.model import LabelStudioMLBase
from ultralytics import YOLO
  6. class YoloLS(LabelStudioMLBase):
  7. def __init__(self, **kwargs):
  8. super(YoloLS, self).__init__(**kwargs)
  9. # pre-initialize your variables here
  10. from_name, schema = list(self.parsed_label_config.items())[0]
  11. self.from_name = from_name
  12. self.to_name = schema['to_name'][0]
  13. self.labels = schema['labels']
  14. self.user = os.getenv("DAGSHUB_USER_NAME")
  15. self.token = os.getenv("DAGSHUB_TOKEN")
  16. self.repo = os.getenv("DAGSHUB_REPO_NAME")
  17. self.host = os.getenv("DAGSHUB_CLIENT_HOST")
  18. dagshub.auth.add_app_token(token=self.token, host=self.host)
  19. dagshub.init(repo_name=self.repo, repo_owner=self.user)
  20. self.model = YOLO('yolov8n-seg.pt', task='segment')
  21. # client = mlflow.MlflowClient()
  22. # name = 'YOLOv8 Custom'
  23. # version = client.get_latest_versions(name=name)[0].version
  24. # self.model_version = f'{name}:{version}'
  25. # model_uri = f'models:/{name}/{version}'
  26. # self.model = mlflow.pyfunc.load_model(model_uri)
  27. def image_uri_to_https(self, uri):
  28. if uri.startswith('http'):
  29. return uri
  30. elif uri.startswith('repo://'):
  31. link_data = uri.split("repo://")[-1].split("/")
  32. commit, tree_path = link_data[0], "/".join(link_data[1:])
  33. return f"{self.host}/api/v1/repos/{self.user}/{self.repo}/raw/{commit}/{tree_path}"
  34. elif uri.startswith('/'):
  35. return f"{self.host}{uri}"
  36. raise FileNotFoundError(f'Unknown URI {uri}')
  37. def predict(self, tasks, **kwargs):
  38. """ This is where inference happens:
  39. model returns the list of predictions based on input list of tasks
  40. :param tasks: Label Studio tasks in JSON format
  41. """
  42. results = []
  43. for task in tasks:
  44. uri = task['data']['image']
  45. url = self.image_uri_to_https(uri)
  46. if self.host != "https://dagshub.com":
  47. url = url.replace("https://",f"https://{self.user}:{self.token}@")
  48. preds = self.model.predict(url)[0]
  49. lowest_conf = 2.0
  50. img_results = []
  51. boxes = preds.boxes.cpu().numpy()
  52. masks = preds.masks
  53. for i in range(len(boxes.cls)):
  54. conf = float(boxes.conf[i])
  55. if conf < lowest_conf:
  56. lowest_conf = conf
  57. img_results.append({
  58. 'type': 'polygonlabels',
  59. 'to_name': self.to_name,
  60. 'from_name': self.from_name,
  61. 'image_rotation': 0,
  62. 'original_height': preds.orig_shape[0],
  63. 'original_width': preds.orig_shape[1],
  64. 'value': {
  65. 'closed': True,
  66. 'points': (masks.xyn[i] * 100).astype(float).tolist(),
  67. 'polygonlabels': [self.labels[int(boxes.cls[i])]],
  68. },
  69. 'score': conf
  70. })
  71. # img_results = {
  72. # 'type': 'polygonlabels',
  73. # 'to_name': self.to_name,
  74. # 'from_name': self.from_name,
  75. # 'image_rotation': 0,
  76. # 'original_height': preds.orig_shape[0],
  77. # 'original_width': preds.orig_shape[1],
  78. # 'value': {
  79. # 'closed': True,
  80. # 'points': (masks.xyn[i] * 100).astype(float).tolist(),
  81. # 'polygonlabels': self.labels[int(boxes.cls[i])],
  82. # },
  83. # 'score': conf
  84. # }
  85. results.append({
  86. 'result': img_results,
  87. 'score': lowest_conf
  88. })
  89. return results
  90. def fit(self, event, data, **kwargs):
  91. """ This is where training happens: train your model given list of completions,
  92. then returns dict with created links and resources
  93. :param completions: aka annotations, the labeling results from Label Studio
  94. :param workdir: current working directory for ML backend
  95. """
  96. # save some training outputs to the job result
  97. return {'random': 1}
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...