Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

interpret.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
  1. import sys
  2. sys.path.append('../../hierarchical-dnn-interpretations') # if pip install doesn't work
  3. import acd
  4. from acd.scores import cd_propagate
  5. import numpy as np
  6. import seaborn as sns
  7. import matplotlib.colors
  8. import matplotlib.pyplot as plt
  9. import torch
  10. import viz
  11. def calc_cd_score(xtrack_t, xfeats_t, start, stop, model):
  12. with torch.no_grad():
  13. rel, irrel = cd_propagate.propagate_lstm(xtrack_t.unsqueeze(-1), model.lstm, start=start, stop=stop, my_device='cpu')
  14. rel = rel.squeeze(1)
  15. irrel = irrel.squeeze(1)
  16. rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.fc)
  17. #return rel.item()
  18. return rel.data.numpy()
  19. def plot_segs(track_segs, cd_scores, xtrack,
  20. pred=None, y=None, vabs=None, cbar=True, xticks=True, yticks=True):
  21. '''Plot a single segmentation plot
  22. '''
  23. # cm = sns.diverging_palette(22, 220, as_cmap=True, center='light')
  24. # cm = LinearSegmentedColormap.from_list(
  25. # name='orange-blue',
  26. # colors=[(222/255, 85/255, 51/255),'lightgray', (50/255, 129/255, 168/255)]
  27. # )
  28. if vabs is None:
  29. vabs = np.max(np.abs(cd_scores))
  30. norm = matplotlib.colors.Normalize(vmin=-vabs, vmax=vabs)
  31. #vabs = 1.2
  32. # plt.plot(xtrack, zorder=0, lw=2, color='#111111')
  33. for i in range(len(track_segs)):
  34. (s, e) = track_segs[i]
  35. cd_score = cd_scores[i]
  36. seq_len = e - s
  37. xs = np.arange(s, e)
  38. if seq_len > 1:
  39. cd_score = [cd_score] * seq_len
  40. col = viz.cmap(norm(cd_score[0]))
  41. while len(col) == 1:
  42. col = col[0]
  43. plt.plot(xs, xtrack[s: e], zorder=0, lw=2, color=col, alpha=0.5)
  44. plt.scatter(xs, xtrack[s: e],
  45. c=cd_score, cmap=viz.cmap, vmin=-vabs, vmax=vabs, s=6)
  46. if pred is not None:
  47. plt.title(f"Pred: {pred: .1f}, y: {y}", fontsize=24)
  48. cb = None
  49. if cbar:
  50. cb = plt.colorbar() #label='CD Score')
  51. cb.outline.set_visible(False)
  52. if not xticks:
  53. plt.xticks([])
  54. if not yticks:
  55. plt.yticks([])
  56. return cb
  57. def max_abs_sum_seg(scores_list, min_length: int=1):
  58. """
  59. score_list[i][j] is the score for the segment from i to j (inclusive)
  60. Params
  61. ------
  62. min_length
  63. Minimum allowable length for a segment
  64. """
  65. n = len(scores_list[0])
  66. res = [0]*n
  67. paths = {}
  68. for s in range(n):
  69. for e in range(s, n):
  70. if e - s >= min_length - 1:
  71. scores_list[s][e] = abs(scores_list[s][e])
  72. else:
  73. scores_list[s][e] = -10000
  74. paths[-1] = []
  75. res[0] = scores_list[0][0]
  76. paths[0] = [0]
  77. for i in (range(1, n)):
  78. cand = [res[j-1] + scores_list[j][i] for j in range(i + 1)]
  79. seg_start = np.argmax(cand)
  80. res[i] = max(cand)
  81. paths[i] = paths[seg_start - 1] + [seg_start]
  82. return res, paths
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...