Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

Model.py 1.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  1. # 1. Library imports
  2. import pandas as pd
  3. from sklearn.ensemble import RandomForestClassifier
  4. from pydantic import BaseModel
  5. import joblib
  # 2. Measurements of a single iris flower.
  # All four fields are required and are validated/coerced to float by
  # pydantic when an instance is built. Declaration order is kept unchanged
  # because pydantic uses it for field/serialisation order.
  class IrisSpecies(BaseModel):
      # Sepal dimensions
      sepal_length: float
      sepal_width: float
      # Petal dimensions
      petal_length: float
      petal_width: float
  # 3. Class for training the model and making predictions.
  class IrisModel:
      """Random-forest classifier for the iris data set.

      On construction the data set is read and a cached model is loaded from
      disk; if loading fails for any reason a fresh model is trained and the
      cache file is (re)written.
      """

      def __init__(self, data_path='iris.csv', model_path='iris_model.pkl'):
          """Load the data set and the cached model, training one if needed.

          Args:
              data_path: CSV file containing the feature columns plus a
                  ``species`` label column. Defaults to ``'iris.csv'``,
                  resolved relative to the current working directory.
              model_path: File used to cache the trained model via joblib.
                  Defaults to ``'iris_model.pkl'`` (also CWD-relative).
          """
          self.df = pd.read_csv(data_path)
          self.model_fname_ = model_path
          try:
              # NOTE(security): joblib.load unpickles the file, which can run
              # arbitrary code -- only point model_path at a trusted file.
              self.model = joblib.load(self.model_fname_)
          except Exception:
              # Any load failure (missing, unreadable or incompatible file)
              # falls back to retraining and overwriting the cache.
              self.model = self._train_model()
              joblib.dump(self.model, self.model_fname_)

      def _train_model(self):
          """Fit and return a RandomForestClassifier.

          Every column of ``self.df`` except ``species`` is used as a feature;
          ``species`` is the target label.
          """
          X = self.df.drop('species', axis=1)
          y = self.df['species']
          return RandomForestClassifier().fit(X, y)

      def predict_species(self, sepal_length, sepal_width, petal_length, petal_width):
          """Predict the species of a single flower.

          Returns:
              A ``(species, probability)`` tuple: the predicted label and the
              highest class probability the model assigns to this sample.
          """
          # Label the sample with the feature names the model was fitted with
          # (if it recorded any). scikit-learn warns when a model fitted on a
          # named DataFrame receives unnamed input. Values are still matched
          # to features by position, exactly as with the former plain list.
          columns = getattr(self.model, 'feature_names_in_', None)
          data_in = pd.DataFrame(
              [[sepal_length, sepal_width, petal_length, petal_width]],
              columns=columns,
          )
          prediction = self.model.predict(data_in)
          probability = self.model.predict_proba(data_in).max()
          return prediction[0], probability
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...