Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

plot_precision.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
  1. import sys
  2. from sklearn import svm, datasets
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.metrics import average_precision_score
  5. from sklearn.metrics import precision_recall_curve
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. def set_trace():
  9. """A Poor mans break point"""
  10. # without this in iPython debugger can generate strange characters.
  11. from IPython.core.debugger import Pdb
  12. Pdb().set_trace(sys._getframe().f_back)
  13. def precision_recall(X,y):
  14. random_state = np.random.RandomState(0)
  15. n_samples, n_features = X.shape
  16. X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
  17. # Limit to the two first classes, and split into training and test
  18. X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
  19. test_size=.5,
  20. random_state=random_state)
  21. # Create a simple classifier
  22. classifier = svm.LinearSVC(random_state=random_state)
  23. classifier.fit(X_train, y_train)
  24. y_score = classifier.decision_function(X_test)
  25. average_precision = average_precision_score(y_test, y_score)
  26. print('Average precision-recall score: {0:0.2f}'.format(average_precision))
  27. '''def plot_precision_curve(y_test, y_score):
  28. precision, recall, _ = precision_recall_curve(y_test, y_score)
  29. plt.step(recall, precision, color='b', alpha=0.2, where='post')
  30. plt.fill_between(recall, precision, step='post', alpha=0.2, color='b')
  31. plt.xlabel('Recall')
  32. plt.ylabel('Precision')
  33. plt.ylim([0.0, 1.05])
  34. plt.xlim([0.0, 1.0])
  35. plt.title( '2-class Precision-Recall curve: AP={0:0.2f}'.format(average_precision))
  36. set_trace()'''
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...