Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

yolo_ls.py 4.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  1. import dagshub
  2. import mlflow
  3. import os
  4. from label_studio_ml.model import LabelStudioMLBase
  5. from ultralytics import YOLO
  class YoloLS(LabelStudioMLBase):
      """Label Studio ML backend that pre-annotates images with YOLOv8 segmentation polygons.

      The labeling-config tag names and label list are read from
      ``self.parsed_label_config``; the DagsHub user, token and repository are
      read from the ``DAGSHUB_USER_NAME``, ``DAGSHUB_TOKEN`` and
      ``DAGSHUB_REPO_NAME`` environment variables.
      """

      # Scheme prefix of DagsHub repository URIs: "repo://<commit>/<path>".
      _REPO_SCHEME = 'repo://'

      def __init__(self, **kwargs):
          """ Read the labeling config, authenticate against DagsHub and load the model.
          :param kwargs: forwarded unchanged to ``LabelStudioMLBase.__init__``
          """
          super().__init__(**kwargs)

          # Only the first entry of the parsed labeling config is used.
          from_name, schema = list(self.parsed_label_config.items())[0]
          self.from_name = from_name
          self.to_name = schema['to_name'][0]
          self.labels = schema['labels']

          # DagsHub coordinates come from the environment; each is None when unset.
          self.user = os.getenv("DAGSHUB_USER_NAME")
          self.token = os.getenv("DAGSHUB_TOKEN")
          self.repo = os.getenv("DAGSHUB_REPO_NAME")

          dagshub.auth.add_app_token(token=self.token)
          dagshub.init(repo_name=self.repo, repo_owner=self.user)

          # Stock segmentation checkpoint; its class indices are mapped onto
          # self.labels by position in predict().
          self.model = YOLO('yolov8n-seg.pt', task='segment')

      def image_uri_to_https(self, uri):
          """ Translate a task image URI into a URL the model can download.
          :param uri: either an ``http``/``https`` URL (returned unchanged) or a
              ``repo://<commit>/<path>`` URI
          :return: the HTTP(S) URL; ``repo://`` URIs are mapped to the DagsHub
              raw-file API of ``self.user``/``self.repo``
          :raises FileNotFoundError: if the URI uses neither scheme
          """
          if uri.startswith('http'):
              return uri
          if uri.startswith(self._REPO_SCHEME):
              # Strip only the leading scheme so a path that happens to contain
              # "repo://" again is left intact.
              commit, _, tree_path = uri[len(self._REPO_SCHEME):].partition('/')
              return f"https://dagshub.com/api/v1/repos/{self.user}/{self.repo}/raw/{commit}/{tree_path}"
          raise FileNotFoundError(f'Unknown URI {uri}')

      def predict(self, tasks, **kwargs):
          """ This is where inference happens:
          model returns the list of predictions based on input list of tasks
          :param tasks: Label Studio tasks in JSON format; the image URI is read
              from ``task['data']['image']``
          :return: one dict per task with ``result`` (list of polygon regions)
              and ``score`` (lowest confidence among them, ``0.0`` if there are none)
          """
          results = []
          for task in tasks:
              url = self.image_uri_to_https(task['data']['image'])
              # The model is called with a single image, so only the first
              # result is relevant.
              preds = self.model.predict(url)[0]
              regions = self._regions_from_prediction(preds)
              results.append({
                  'result': regions,
                  # default=0.0 avoids leaking a sentinel score for images
                  # without any usable detection.
                  'score': min((region['score'] for region in regions), default=0.0),
              })
          return results

      def _regions_from_prediction(self, preds):
          """ Convert one model result into Label Studio ``polygonlabels`` regions.
          :param preds: a single result object returned by ``self.model.predict``
          :return: list of region dicts, one per detection that maps to a
              configured label
          """
          masks = preds.masks
          if masks is None:
              # No segmentation output for this image -> nothing to annotate.
              return []

          boxes = preds.boxes.cpu().numpy()
          height, width = preds.orig_shape[0], preds.orig_shape[1]

          regions = []
          for i, (cls_id, conf) in enumerate(zip(boxes.cls, boxes.conf)):
              label_index = int(cls_id)
              if not 0 <= label_index < len(self.labels):
                  # The model predicted a class the labeling config has no
                  # label for; skip it instead of failing the whole batch.
                  continue
              regions.append({
                  'type': 'polygonlabels',
                  'to_name': self.to_name,
                  'from_name': self.from_name,
                  'image_rotation': 0,
                  'original_height': height,
                  'original_width': width,
                  'value': {
                      'closed': True,
                      # xyn holds normalized polygon coordinates (per the
                      # Ultralytics docs); scaled by 100 to get percentages.
                      'points': (masks.xyn[i] * 100).astype(float).tolist(),
                      'polygonlabels': [self.labels[label_index]],
                  },
                  'score': float(conf),
              })
          return regions

      def fit(self, event, data, **kwargs):
          """ This is where training would happen; currently a placeholder that
          performs no training and returns a fixed dummy payload.
          :param event: unused
          :param data: unused
          :param kwargs: unused
          :return: the constant dict ``{'random': 1}``
          """
          # save some training outputs to the job result
          return {'random': 1}
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...