Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

tests.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
  1. from fastapi.testclient import TestClient
  2. import sys
  3. from src.config import conf
  4. from src.utils.database import Database
  5. import urllib
  6. import re
  7. from main import app
  8. test_client = TestClient(app)
  9. def assert_predict(input, response):
  10. '''
  11. check status_code and response object
  12. '''
  13. json = response.json()
  14. assert response.status_code == 200
  15. assert json['content'] == input
  16. total = 0
  17. targets = ('droite', 'gauche')
  18. for target in targets:
  19. assert target in json
  20. assert type(json[target]) == float
  21. assert json[target] >= 0 and json[target] <= 1
  22. total += json[target]
  23. assert total > 0.99 and total < 1.01
  24. def test_predict():
  25. '''
  26. we test ours predictions and api
  27. '''
  28. num_of_tweets_by_side = 50
  29. min_score = 0.75
  30. with Database() as db:
  31. left_tweets = db.tweets.find({'group': {'$in': conf['parties']['gauche']}}, {'content':1}).limit(num_of_tweets_by_side)
  32. right_tweets = db.tweets.find({'group': {'$in': conf['parties']['droite']}}, {'content':1}).limit(num_of_tweets_by_side)
  33. custom_inputs = {
  34. "gauche" : [
  35. "j'aime la gauche",
  36. "vive l'égalité et la solidarité entre les peuples"
  37. ],
  38. "droite" : [
  39. "j'aime la droite",
  40. "il faut réaffirmer notre souveraineté nationale"
  41. ]
  42. }
  43. db_inputs = {
  44. "gauche": left_tweets,
  45. "droite": right_tweets
  46. }
  47. # for each custom inputs prediction must be valid or test will fails
  48. for side, inputs in custom_inputs.items():
  49. for input in inputs:
  50. response = test_client.post("/predict", data = {"content":input})
  51. json = response.json()
  52. assert_predict(input, response)
  53. assert json['position'] == side
  54. # calculate valid score of our predictions with model data, score must be > 0.75 because our model has 0.85 score
  55. db_inputs_valid = 0
  56. for side, inputs in db_inputs.items():
  57. for input in inputs:
  58. input = input['content']
  59. response = test_client.post("/predict", data = {"content":input})
  60. json = response.json()
  61. assert_predict(input, response)
  62. if json['position'] == side:
  63. db_inputs_valid += 1
  64. score = db_inputs_valid/num_of_tweets_by_side
  65. assert(score > min_score)
  66. def test_predict_account():
  67. '''
  68. same tests but for account
  69. '''
  70. accounts_gauche = ("JLMelenchon",)
  71. accounts_droite = ("MLP_officiel",)
  72. min_score = 0.75
  73. for account in accounts_gauche:
  74. response = test_client.get("/predict/account/{}".format(account))
  75. print(response.json())
  76. assert response.status_code == 200
  77. json = response.json()
  78. total = 0
  79. targets = ('droite', 'gauche')
  80. for target in targets:
  81. assert target in json
  82. assert type(json[target]) == float
  83. assert json[target] >= 0 and json[target] <= 1
  84. total += json[target]
  85. assert total > 0.99 and total < 1.01
##
# Note on encoding and decoding a query parameter:
#   remove URLs: input = re.sub(r'http\S+', '', input['content'])
#   encode:      input = urllib.parse.quote_plus(input)
#   decode:      input = urllib.parse.unquote_plus(input)
# I switched from GET to POST requests because GET could not handle all the
# parsing issues (e.g. "(2/2)") without adding a lot of regex.
##
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...