Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

main.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  1. from fastapi import FastAPI
  2. import sys
  3. sys.path.insert(0, "app")
  4. from joblib import load
  5. import snscrape.modules.twitter as sntwitter
  6. from src.config import conf
  7. from pymongo import MongoClient
  8. from datetime import datetime
  9. from fastapi import FastAPI, Depends, HTTPException, status, Form, Request
  10. from src.utils.database import Database
  11. import pandas as pd
  12. from pathlib import Path
  13. app = FastAPI()
  14. tfidf, model = load('{}/{}'.format(Path(__file__).with_name('models'), 'lr.joblib'))
  15. def scrape_account(account_name):
  16. '''
  17. return tweets from account_name
  18. '''
  19. max_tweets = 100
  20. tweets = []
  21. match = 'from:{}'.format(account_name)
  22. for i,tweet in enumerate(sntwitter.TwitterSearchScraper(match).get_items()):
  23. if i >= max_tweets:
  24. break
  25. tweets.append({
  26. "tweet_id": tweet.id,
  27. "account_name": account_name,
  28. "content": tweet.content
  29. })
  30. return tweets
  31. def predict_tweet(tweet):
  32. prediction = list(model.predict_proba(tfidf.transform([tweet['content']]))[0])
  33. labels = ['droite', 'gauche']
  34. prediction_dict = {
  35. 'position': labels[prediction.index(max(prediction))],
  36. labels[0]: prediction[0],
  37. labels[1]: prediction[1]
  38. }
  39. tweet = {**tweet, **prediction_dict}
  40. with Database() as db:
  41. db.api.insert_one(tweet.copy())
  42. return tweet
  43. @app.post("/predict")
  44. async def predict(r: Request):
  45. form = await r.form()
  46. tweet = dict(form)
  47. tweet = predict_tweet(tweet)
  48. return tweet
  49. @app.get("/predict/account/{account_name}")
  50. def predict_account(account_name: str):
  51. with Database() as db:
  52. tweets = list(db.api.find({"account_name": account_name}, {"_id":0, "gauche":1, "droite":1}))
  53. if len(tweets) > 0:
  54. df = pd.DataFrame(tweets)
  55. return {
  56. 'droite': df.droite.mean(),
  57. 'gauche': df.gauche.mean(),
  58. }
  59. else:
  60. tweets = scrape_account(account_name)
  61. for tweet in tweets:
  62. predict_tweet(tweet)
  63. predict_account(account_name)
  64. @app.get("/tweets/{account_name}")
  65. def get_tweets(account_name: str):
  66. return scrape_account(account_name)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...