Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_baseline_predictions.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
  1. import unittest
  2. import pandas as pd
  3. from pyprojroot import here
  4. from poi_analysis.baseline import BaselineRecommendationAlgorithm
  5. from poi_analysis.clean_data import clean_poi_raw, clean_user_preferences
  6. # CONFIG
  7. DROP_VENUES = ['sophie café and pub tatra house',
  8. 'waaberi restaurant & takeaway',
  9. 'croft ales',
  10. 'taste of morocco',
  11. 'farro bakery']
  12. class TestBaselineRecommendationAlgorithm(unittest.TestCase):
  13. def setUp(self):
  14. self.df_poi = clean_poi_raw(pd.read_excel(
  15. here() / "data/mottli_bristol_poi_curated/Bristol Places Datasheet v2.xlsx"
  16. ), exclude_names=DROP_VENUES)
  17. self.df_user = clean_user_preferences(pd.read_excel(
  18. here() / "data/mottli_bristol_poi_curated/Mottli - User preferences.xlsx"
  19. ), df_poi=self.df_poi)
  20. def test_score_all_users(self):
  21. bra = BaselineRecommendationAlgorithm(self.df_poi)
  22. preds = bra.score_all_users(self.df_user)
  23. # assert shape as expected
  24. self.assertTrue(preds.shape == (150, 12))
  25. # assert real values exist roughly in the region we expect
  26. self.assertTrue(preds.select_dtypes('float').mean().mean() < 0.8)
  27. self.assertTrue(preds.select_dtypes('float').mean().mean() > 0.4)
  28. def tearDown(self):
  29. pass
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...