Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model_training.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Train a model on the data from a DagsHub Data Engine datasource.

NOTE: Pulling the repository with DVC is not required for this, since the images and
annotations are all fetched from Data Engine. Furthermore, if you did any annotation work
in Label Studio, the annotations stored in the repository might have diverged from the ones
in Data Engine; this script always uses the Data Engine annotations.
"""
  7. import logging
  8. import os
  9. from pathlib import Path
  10. import dagshub
  11. import dagshub.mlflow
  12. import mlflow
  13. from dagshub.common.determine_repo import determine_repo
  14. from dagshub.data_engine import datasources
  15. from dagshub.data_engine.model.datasource import Datasource
  16. from dagshub_annotation_converter.image.exporters import YoloExporter
  17. from dagshub_annotation_converter.image.importers import DagshubDatasourceImporter
  18. from ultralytics import YOLO
  19. logger = logging.getLogger(__name__)
  20. DATASOURCE_NAME = "COCO128" # Name of the Data Engine datasource
  21. def build_training_data(ds: Datasource, training_dir: str) -> Path:
  22. """
  23. Download the datapoints + generate the annotations from a Data Engine datasource, returns a path to the directory
  24. """
  25. target_dir = Path(training_dir)
  26. target_dir.mkdir(parents=True, exist_ok=True)
  27. # Download the datapoints
  28. query_result = ds.all()
  29. query_result.download_files(target_dir=target_dir)
  30. # Reimport the annotations
  31. importer = DagshubDatasourceImporter(query_result)
  32. exporter = YoloExporter(data_dir=target_dir, meta_file=target_dir / "yolo.yaml",
  33. annotation_type="bbox")
  34. proj = importer.parse()
  35. exporter.export(proj)
  36. return target_dir
  37. def get_datasource(ds_name: str) -> Datasource:
  38. current_repo, branch = determine_repo()
  39. ds = datasources.get_datasource(current_repo.full_name, ds_name)
  40. return ds
  41. def train_model():
  42. model = YOLO("yolov8n.pt")
  43. model.train(data="yolo.yaml", device="mps", epochs=3)
  44. # After this the model should be uploaded as an artifact to the MLflow run
  45. def main():
  46. current_repo, branch = determine_repo()
  47. # Run dagshub.init to log training metrics into MLflow
  48. dagshub.init(repo_owner=current_repo.owner, repo_name=current_repo.repo_name, mlflow=True)
  49. # Turn on autolog so YOLO training gets logged
  50. mlflow.autolog()
  51. # Turn on reliability patch so mlflow logging failing doesn't crash training
  52. dagshub.mlflow.patch_mlflow()
  53. with mlflow.start_run():
  54. ds = get_datasource(DATASOURCE_NAME)
  55. training_dir = build_training_data(ds, "training/")
  56. os.chdir(training_dir)
  57. train_model()
  58. if __name__ == "__main__":
  59. logging.basicConfig(level=logging.INFO)
  60. # Disable noisy GQL logs
  61. from gql.transport.requests import log as requests_logger
  62. requests_logger.setLevel(logging.WARNING)
  63. # Change current directory to the root of the repository
  64. os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
  65. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...