Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

prediction.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  1. import yaml
  2. import os
  3. import json
  4. import joblib
  5. import numpy as np
  6. params_path = "params.yaml"
  7. schema_path = os.path.join("prediction_service", "schema_in.json")
  8. class NotInRange(Exception):
  9. def __init__(self, message="Values entered are not in expected range"):
  10. self.message = message
  11. super().__init__(self.message)
  12. class NotInCols(Exception):
  13. def __init__(self, message="Not in cols"):
  14. self.message = message
  15. super().__init__(self.message)
  16. def read_params(config_path=params_path):
  17. with open(config_path) as yaml_file:
  18. config = yaml.safe_load(yaml_file)
  19. return config
  20. def predict(data):
  21. config = read_params(params_path)
  22. model_dir_path = config["webapp_model_dir"]
  23. model = joblib.load(model_dir_path)
  24. prediction = model.predict(data).tolist()[0]
  25. try:
  26. if 3 <= prediction <= 8:
  27. return prediction
  28. else:
  29. raise NotInRange
  30. except NotInRange:
  31. return "Unexpected result"
  32. def get_schema(schema_path=schema_path):
  33. with open(schema_path) as json_file:
  34. schema = json.load(json_file)
  35. return schema
  36. def validate_input(dict_request):
  37. def _validate_cols(col):
  38. schema = get_schema()
  39. actual_cols = schema.keys()
  40. if col not in actual_cols:
  41. raise NotInCols
  42. def _validate_values(col, val):
  43. schema = get_schema()
  44. if not (schema[col]["min"] <= float(dict_request[col]) <= schema[col]["max"]):
  45. raise NotInRange
  46. for col, val in dict_request.items():
  47. _validate_cols(col)
  48. _validate_values(col, val)
  49. return True
  50. def form_response(dict_request):
  51. if validate_input(dict_request):
  52. data = dict_request.values()
  53. data = [list(map(float, data))]
  54. response = predict(data)
  55. return response
  56. def api_response(dict_request):
  57. try:
  58. if validate_input(dict_request):
  59. data = np.array([list(dict_request.values())])
  60. response = predict(data)
  61. response = {"response": response}
  62. return response
  63. except NotInRange as e:
  64. response = {"the_exected_range": get_schema(), "response": str(e)}
  65. return response
  66. except NotInCols as e:
  67. response = {"the_exected_cols": get_schema().keys(), "response": str(e)}
  68. return response
  69. except Exception as e:
  70. response = {"response": str(e)}
  71. return response
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...