Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

visualize.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  1. import itertools
  2. import matplotlib.colors
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. from typing import List, Text
  6. def plot_confusion_matrix(cm: np.array,
  7. target_names: List[Text],
  8. title: Text = 'Confusion matrix',
  9. cmap: matplotlib.colors.LinearSegmentedColormap = None,
  10. normalize: bool = True):
  11. """
  12. given a sklearn confusion matrix (cm), make a nice plot
  13. Arguments
  14. ---------
  15. cm: confusion matrix from sklearn.metrics.confusion_matrix
  16. target_names: given classification classes such as [0, 1, 2]
  17. the class names, for example: ['high', 'medium', 'low']
  18. title: the text to display at the top of the matrix
  19. cmap: the gradient of the values displayed from matplotlib.pyplot.cm
  20. see http://matplotlib.org/examples/color/colormaps_reference.html
  21. plt.get_cmap('jet') or plt.cm.Blues
  22. normalize: If False, plot the raw numbers
  23. If True, plot the proportions
  24. Usage
  25. -----
  26. plot_confusion_matrix(cm = cm, # confusion matrix created by
  27. # sklearn.metrics.confusion_matrix
  28. normalize = True, # show proportions
  29. target_names = y_labels_vals, # list of names of the classes
  30. title = best_estimator_name) # title of graph
  31. Citiation
  32. ---------
  33. http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
  34. """
  35. accuracy = np.trace(cm) / float(np.sum(cm))
  36. misclass = 1 - accuracy
  37. if cmap is None:
  38. cmap = plt.get_cmap('Blues')
  39. plt.figure(figsize=(8, 6))
  40. plt.imshow(cm, interpolation='nearest', cmap=cmap)
  41. plt.title(title)
  42. plt.colorbar()
  43. if target_names is not None:
  44. tick_marks = np.arange(len(target_names))
  45. plt.xticks(tick_marks, target_names, rotation=45)
  46. plt.yticks(tick_marks, target_names)
  47. if normalize:
  48. cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  49. thresh = cm.max() / 1.5 if normalize else cm.max() / 2
  50. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  51. if normalize:
  52. plt.text(j, i, "{:0.4f}".format(cm[i, j]),
  53. horizontalalignment="center",
  54. color="white" if cm[i, j] > thresh else "black")
  55. else:
  56. plt.text(j, i, "{:,}".format(cm[i, j]),
  57. horizontalalignment="center",
  58. color="white" if cm[i, j] > thresh else "black")
  59. plt.tight_layout()
  60. plt.ylabel('True label')
  61. plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
  62. return plt.gcf()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...