Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

register_model.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  1. import argparse
  2. import cloudpickle
  3. import mlflow
  4. import os
  5. import torch
  6. from sys import version_info
  7. from get_or_create_mlflow_experiment import get_experiment_id
  8. import model_wrapper
  9. from model_wrapper import SquirrelDetectorWrapper
  10. mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
  11. PYTHON_VERSION = "{major}.{minor}.1".format(major=version_info.major,
  12. minor=version_info.minor)
  13. conda_env = {
  14. 'channels': ['defaults'],
  15. 'dependencies': [
  16. 'python~={}'.format(PYTHON_VERSION),
  17. 'pip',
  18. {
  19. 'pip': [
  20. 'mlflow',
  21. 'pillow',
  22. 'cloudpickle=={}'.format(cloudpickle.__version__),
  23. 'torch>=1.12.0'
  24. ],
  25. },
  26. ],
  27. 'name': 'squirrel_env'
  28. }
  29. def main():
  30. parser = argparse.ArgumentParser('Creates/gets an MLflow experiment and registers a YOLOv5 model to the Model Registry')
  31. parser.add_argument('--name', help='MLflow experiment name')
  32. parser.add_argument('--model', help='Path to saved YOLOv5 PyTorch model')
  33. parser.add_argument('--model-name', help='Registered model name')
  34. args = parser.parse_args()
  35. artifacts = { 'path': args.model }
  36. model = SquirrelDetectorWrapper()
  37. exp_id = get_experiment_id(args.name)
  38. cloudpickle.register_pickle_by_value(model_wrapper)
  39. with mlflow.start_run(experiment_id=exp_id):
  40. mlflow.pyfunc.log_model(
  41. 'finetuned',
  42. python_model=model,
  43. conda_env=conda_env,
  44. artifacts=artifacts,
  45. registered_model_name=args.model_name
  46. )
  47. if __name__ == '__main__':
  48. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...