Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_yolo.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
  1. from dagshub.streaming import install_hooks
  2. install_hooks()
  3. import dagshub
  4. from dagshub.data_engine import datasources
  5. import logging
  6. import mlflow
  7. import ultralytics
  8. from utils.data import DataFunctions
  9. from utils.dagshub_yolo_cb import custom_callbacks_fn
  10. logger = logging.getLogger('root')
  11. logger.setLevel(logging.INFO)
  12. # Environment Variables
  13. DAGSHUB_USER = "yonomitt"
  14. DAGSHUB_REPO_OWNER = "yonomitt"
  15. DAGSHUB_REPO="ToothFairy"
  16. DAGSHUB_FULL_REPO=DAGSHUB_REPO_OWNER + "/" + DAGSHUB_REPO
  17. DATASOURCE_NAME = "Tooth-Segmentation"
  18. DATASOURCE_PATH = "s3://tooth-dataset/data"
  19. ANNOTATION_FILE = "s3://tooth-dataset/tooth_segmentation.json"
  20. MLFLOW_PROJECT = "Default"
  21. def get_or_create_datasource(name):
  22. try:
  23. ds = datasources.get_datasource(repo=DAGSHUB_FULL_REPO, name=name)
  24. except:
  25. ds = datasources.create(repo=DAGSHUB_FULL_REPO, name=name, path=DATASOURCE_PATH)
  26. return ds
  27. def main():
  28. logger.info('Getting or creating the datasource')
  29. ds = get_or_create_datasource(DATASOURCE_NAME)
  30. dataset_func = DataFunctions(annotation_file=ANNOTATION_FILE, yolo_dir='generated/yolo_data', label_type='segmentation')
  31. logger.info('Converting the datasource to a YOLO-compatible dataset')
  32. dataset_func.create_yolo_v8_dataset_yaml(ds, download=False)
  33. logger.info('Monkey-patching ultralytics')
  34. ultralytics.utils.callbacks.add_integration_callbacks = custom_callbacks_fn
  35. dagshub.init(repo_name=DAGSHUB_REPO, repo_owner=DAGSHUB_USER)
  36. # Load a model
  37. model = ultralytics.YOLO('yolov8l-seg.pt', task='segment') # load a pretrained model (recommended for training)
  38. with mlflow.start_run():
  39. # Train the model
  40. model.train(data='custom_yolo.yaml', epochs=100, imgsz=640, device=0, project=MLFLOW_PROJECT)
  41. if __name__ == '__main__':
  42. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...