1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
import argparse
import json
import logging
import logging.config
import os

# Configure logging before the third-party imports that follow. dictConfig()
# defaults to disable_existing_loggers=True, so configuring first keeps any
# loggers those modules create at import time enabled.
# All records go to stdout through one "console" handler. The handler itself
# passes DEBUG and above, but the root logger is capped at ERROR unless the
# CLI's --log-level lowers it.
logging.config.dictConfig({
    "version": 1,
    "formatters": {
        "standard": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "stream": "ext://sys.stdout",
            "formatter": "standard",
        },
    },
    # No "propagate" key here: dictConfig does not apply it to the root
    # logger, which has no parent to propagate to.
    "root": {
        "level": "ERROR",
        "handlers": [
            "console",
        ],
    },
})
# These imports deliberately come after the logging.config.dictConfig() call:
# with its default disable_existing_loggers=True, loggers created before that
# call would be disabled. Presumably these modules create loggers at import
# time — TODO confirm before reordering.
from label_studio_ml.api import init_app
from ls_model_server import SquirrelDetectorLSModel
# Optional JSON file next to this script that supplies default model kwargs.
_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')


def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
    """Load model initialization kwargs from a JSON config file.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.
            Defaults to ``config.json`` in this script's directory.

    Returns:
        The parsed JSON object as a ``dict``, or an empty ``dict`` when
        ``config_path`` does not exist.

    Raises:
        TypeError: If the file's top-level JSON value is not an object.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    if not os.path.exists(config_path):
        return {}
    # JSON text is UTF-8 (RFC 8259), so don't depend on the locale's default encoding.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    # Use an explicit check rather than `assert`, which `python -O` strips out.
    if not isinstance(config, dict):
        raise TypeError(
            'Expected a JSON object at the top level of %s, got %s'
            % (config_path, type(config).__name__))
    return config
def _parse_kv(token):
    """Split a ``KEY=VAL`` command-line token into a ``(key, value)`` pair.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=`` (for example, URLs with query strings).

    Raises:
        argparse.ArgumentTypeError: If the token has no ``=`` or an empty key.
            argparse reports this as a normal usage error.
    """
    key, sep, value = token.partition('=')
    if not sep or not key:
        raise argparse.ArgumentTypeError('expected KEY=VAL, got %r' % token)
    return key, value


def _coerce_value(value):
    """Convert a command-line string to ``bool``, ``int`` or ``float`` when it
    looks like one; otherwise return the string unchanged."""
    if value in ('True', 'true'):
        return True
    if value in ('False', 'false'):
        return False
    # Try int() before float() so "-5" stays an int. str.isdigit() would reject
    # the sign, and it accepts characters such as "²" that int() cannot parse.
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value


def _redis_settings():
    """Return the Redis-related ``init_app`` kwargs, read from the environment."""
    return dict(
        redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'),
        redis_host=os.environ.get('REDIS_HOST', 'localhost'),
        # Environment values are strings; normalize so the port is always an int.
        redis_port=int(os.environ.get('REDIS_PORT', 6379)),
    )


# Run as a script: parse CLI flags and start the development server.
# Imported (e.g. by uWSGI): only build the module-level `app`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Label studio')
    parser.add_argument(
        '-p', '--port', dest='port', type=int, default=9090,
        help='Server port')
    parser.add_argument(
        '--host', dest='host', type=str, default='0.0.0.0',
        help='Server host')
    parser.add_argument(
        '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+',
        type=_parse_kv,
        help='Additional LabelStudioMLBase model initialization kwargs')
    parser.add_argument(
        '-d', '--debug', dest='debug', action='store_true',
        help='Switch debug mode')
    parser.add_argument(
        '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
        help='Logging level')
    parser.add_argument(
        '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
        help='Directory where models are stored (relative to the project directory)')
    parser.add_argument(
        '--check', dest='check', action='store_true',
        help='Validate model instance before launching server')
    args = parser.parse_args()

    # setup logging level
    if args.log_level:
        logging.root.setLevel(args.log_level)

    # Command-line KEY=VAL pairs override values loaded from config.json.
    kwargs = get_kwargs_from_config()
    if args.kwargs:
        kwargs.update((key, _coerce_value(value)) for key, value in args.kwargs)

    if args.check:
        print('Check "' + SquirrelDetectorLSModel.__name__ + '" instance creation..')
        # Instantiate once so constructor errors surface before the server starts.
        SquirrelDetectorLSModel(**kwargs)

    app = init_app(
        model_class=SquirrelDetectorLSModel,
        model_dir=os.environ.get('MODEL_DIR', args.model_dir),
        **_redis_settings(),
        **kwargs
    )
    app.run(host=args.host, port=args.port, debug=args.debug)
else:
    # for uWSGI use
    # NOTE(review): unlike the __main__ path, kwargs from config.json are not
    # passed here — confirm whether that is intended.
    app = init_app(
        model_class=SquirrelDetectorLSModel,
        model_dir=os.environ.get('MODEL_DIR', os.path.dirname(__file__)),
        **_redis_settings()
    )
|