prepare_dataset.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import io
  3. import requests
  4. import zipfile
  5. import shutil
  6. from content_based_recomendation.scripts.movie_lens_content_based_recomendation import filter_ratings
  7. from utils.features_extraction.movie_lens_features_extractor import FeaturesExtractor
  8. from settings import PATH_TO_DATA
  9. def download_file_from_google_drive(id, destination):
  10. URL = "https://docs.google.com/uc?export=download"
  11. session = requests.Session()
  12. response = session.get(URL, params={'id': id}, stream=True)
  13. token = get_confirm_token(response)
  14. if token:
  15. params = {'id': id, 'confirm': token}
  16. response = session.get(URL, params=params, stream=True)
  17. # save_response_content(response, destination)
  18. return response
  19. def get_confirm_token(response):
  20. for key, value in response.cookies.items():
  21. if key.startswith('download_warning'):
  22. return value
  23. return None
  24. def save_response_content(response, destination):
  25. CHUNK_SIZE = 32768
  26. with open(destination, "wb") as f:
  27. for chunk in response.iter_content(CHUNK_SIZE):
  28. if chunk: # filter out keep-alive new chunks
  29. f.write(chunk)
  30. def unpack_starts_with(zip_file, zip_skip, save_path):
  31. members = [x for x in zip_file.NameToInfo.keys() if x.startswith(zip_skip) and len(x) > len(zip_skip)]
  32. for mem in members:
  33. path = save_path + mem[len(zip_skip):]
  34. if not path.endswith('/'):
  35. read_file = zip_file.open(mem)
  36. with open(path, 'wb') as write_file:
  37. shutil.copyfileobj(read_file, write_file)
  38. else:
  39. os.makedirs(path, exist_ok=True)
  40. def main():
  41. eas_path = './dataset/raw/the-movies-dataset/'
  42. eas_zip_skip = ''
  43. eas_gdrive_id = '1Qx9FAqaIG9PbMRJ6coT_NNA9Bck3-jSZ'
  44. os.makedirs(eas_path, exist_ok=True)
  45. print('Downloading...')
  46. r = download_file_from_google_drive(eas_gdrive_id, None)
  47. print('Unzip...')
  48. z = zipfile.ZipFile(io.BytesIO(r.content))
  49. unpack_starts_with(z, eas_zip_skip, eas_path)
  50. print('Filtering')
  51. dataset_path = f'{PATH_TO_DATA}/raw/the-movies-dataset'
  52. features_extractor = FeaturesExtractor(dataset_path)
  53. data = features_extractor.run()
  54. filter_ratings(dataset_path, data)
  55. print('Done')
  56. if __name__ == '__main__':
  57. main()