Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_ls.py 4.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
  1. import dagshub
  2. import mlflow
  3. import os
  4. from label_studio_ml.model import LabelStudioMLBase
  5. from ultralytics import YOLO
  class YoloLS(LabelStudioMLBase):
      """Label Studio ML backend that serves YOLOv8 segmentation predictions.

      Each detected instance is returned as a Label Studio ``polygonlabels``
      region. Images referenced by ``repo://<commit>/<path>`` URIs are resolved
      against the DagsHub repository configured via environment variables
      (see ``__init__``).
      """

      # Public DagsHub host; used when DAGSHUB_CLIENT_HOST is unset or empty.
      _DEFAULT_HOST = "https://dagshub.com"

      def __init__(self, **kwargs):
          """Read the labeling config, authenticate to DagsHub and load the model.

          Environment variables read:
              DAGSHUB_USER_NAME: repository owner; also the basic-auth user for
                  self-hosted DagsHub image URLs.
              DAGSHUB_TOKEN: DagsHub app token.
              DAGSHUB_REPO_NAME: repository name.
              DAGSHUB_CLIENT_HOST: DagsHub host URL (default https://dagshub.com).
          """
          super().__init__(**kwargs)
          # Only the first control tag of the labeling config is used.
          from_name, schema = next(iter(self.parsed_label_config.items()))
          self.from_name = from_name
          self.to_name = schema['to_name'][0]
          self.labels = schema['labels']
          self.user = os.getenv("DAGSHUB_USER_NAME")
          self.token = os.getenv("DAGSHUB_TOKEN")
          self.repo = os.getenv("DAGSHUB_REPO_NAME")
          # Fall back to the public host (an unset value used to produce
          # "None/api/..." URLs) and drop any trailing slash so URL building
          # and the host comparison in _authorized_url behave consistently.
          self.host = (os.getenv("DAGSHUB_CLIENT_HOST") or self._DEFAULT_HOST).rstrip('/')
          dagshub.auth.add_app_token(token=self.token, host=self.host)
          dagshub.init(repo_name=self.repo, repo_owner=self.user)
          self.model = YOLO('yolov8n-seg.pt', task='segment')
          # NOTE: to serve a model from the MLflow registry instead, something
          # like the following was previously sketched here:
          #   client = mlflow.MlflowClient()
          #   version = client.get_latest_versions(name='YOLOv8 Custom')[0].version
          #   self.model = mlflow.pyfunc.load_model(f'models:/YOLOv8 Custom/{version}')

      def image_uri_to_https(self, uri):
          """Translate a Label Studio image URI into a fetchable URL.

          Args:
              uri: Either an ``http://``/``https://`` URL (returned unchanged) or
                  a DagsHub ``repo://<commit>/<path>`` URI.

          Returns:
              The URL as a string.

          Raises:
              FileNotFoundError: if the URI scheme is not recognized.
          """
          if uri.startswith(('http://', 'https://')):
              return uri
          if uri.startswith('repo://'):
              commit, _, tree_path = uri[len('repo://'):].partition('/')
              return f"{self.host}/api/v1/repos/{self.user}/{self.repo}/raw/{commit}/{tree_path}"
          raise FileNotFoundError(f'Unknown URI {uri}')

      def _authorized_url(self, uri):
          """Resolve *uri* and embed basic-auth credentials for self-hosted DagsHub.

          Credentials are added only when the URL points at ``self.host`` over
          HTTPS, so an image hosted elsewhere never receives the token.
          """
          from urllib.parse import quote

          url = self.image_uri_to_https(uri)
          targets_our_host = url.startswith(self.host + '/')
          if self.host != self._DEFAULT_HOST and targets_our_host and url.startswith('https://'):
              # Quote user/token so characters like '@', ':' or '/' can't break the URL.
              creds = f"{quote(self.user or '', safe='')}:{quote(self.token or '', safe='')}"
              url = url.replace("https://", f"https://{creds}@", 1)
          return url

      def _predict_one(self, uri):
          """Run the model on one image and return its Label Studio prediction dict."""
          preds = self.model.predict(self._authorized_url(uri))[0]
          boxes = preds.boxes.cpu().numpy()
          masks = preds.masks
          regions = []
          # Segmentation models pair a mask with every box; masks is None when
          # nothing was detected.
          if masks is not None:
              for cls_id, conf, polygon in zip(boxes.cls, boxes.conf, masks.xyn):
                  label_idx = int(cls_id)
                  # Assumes YOLO class ids index into the config's label list
                  # (TODO confirm mapping); ids beyond it are skipped rather
                  # than aborting the whole batch with IndexError.
                  if label_idx >= len(self.labels):
                      continue
                  regions.append({
                      'type': 'polygonlabels',
                      'to_name': self.to_name,
                      'from_name': self.from_name,
                      'image_rotation': 0,
                      'original_height': preds.orig_shape[0],
                      'original_width': preds.orig_shape[1],
                      'value': {
                          'closed': True,
                          # xyn is normalized to [0, 1]; Label Studio expects percent.
                          'points': (polygon * 100).astype(float).tolist(),
                          'polygonlabels': [self.labels[label_idx]],
                      },
                      'score': float(conf),
                  })
          # Task score is the least confident region; 0.0 when there are none
          # (previously an out-of-range 2.0 sentinel leaked through).
          score = min((region['score'] for region in regions), default=0.0)
          return {'result': regions, 'score': score}

      def predict(self, tasks, **kwargs):
          """Run inference for a batch of Label Studio tasks.

          Args:
              tasks: Label Studio tasks in JSON form; each must provide
                  ``task['data']['image']``.

          Returns:
              One dict per task, in order, with ``result`` (list of
              ``polygonlabels`` regions) and ``score`` (lowest region
              confidence, or 0.0 when nothing was detected).
          """
          return [self._predict_one(task['data']['image']) for task in tasks]

      def fit(self, event, data, **kwargs):
          """Training hook invoked by Label Studio; training is not implemented.

          Args:
              event: presumably the name of the Label Studio event that
                  triggered training (TODO confirm).
              data: the payload Label Studio sent with the event.

          Returns:
              A placeholder dict; no model is trained or saved.
          """
          # TODO: implement training; the returned dict is a placeholder.
          return {'random': 1}
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...