1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
- import sys
- sys.path.append('../../hierarchical-dnn-interpretations') # if pip install doesn't work
- import acd
- from acd.scores import cd_propagate
- import numpy as np
- import seaborn as sns
- import matplotlib.colors
- import matplotlib.pyplot as plt
- import torch
- import viz
def calc_cd_score(xtrack_t, xfeats_t, start, stop, model):
    """Contextual-decomposition (CD) score of a time window of ``xtrack_t``.

    The window is passed as ``start``/``stop`` to
    ``cd_propagate.propagate_lstm`` over ``model.lstm``. The resulting
    relevant/irrelevant split is then pushed through ``model.fc`` with
    ``cd_propagate.propagate_conv_linear``. Everything runs under
    ``torch.no_grad()``.

    Params
    ------
    xtrack_t
        Input track tensor. A trailing axis of size 1 is appended before it
        is fed to the LSTM (presumably shape (seq_len,) or similar; TODO
        confirm against callers).
    xfeats_t
        Not used by this function; kept so existing call sites still work.
    start, stop
        Window bounds forwarded unchanged to ``propagate_lstm``. Whether
        ``stop`` is inclusive is decided by acd; verify there.
    model
        Object exposing ``.lstm`` and ``.fc`` submodules.

    Returns
    -------
    NumPy array holding the "relevant" part at the output of ``model.fc``.
    """
    with torch.no_grad():
        lstm_input = xtrack_t.unsqueeze(-1)
        relevant, irrelevant = cd_propagate.propagate_lstm(
            lstm_input, model.lstm, start=start, stop=stop, my_device='cpu'
        )
        # Drop axis 1 of both parts before the final linear layer.
        relevant, irrelevant = cd_propagate.propagate_conv_linear(
            relevant.squeeze(1), irrelevant.squeeze(1), model.fc
        )
        return relevant.data.numpy()
def plot_segs(track_segs, cd_scores, xtrack,
              pred=None, y=None, vabs=None, cbar=True, xticks=True, yticks=True):
    '''Plot a single segmentation plot.

    ``xtrack`` is drawn one segment at a time. Each segment ``(s, e)``
    covers sample indices ``s .. e-1`` and is drawn twice: once as a
    translucent line and once as scatter points. Both are coloured by that
    segment's CD score through ``viz.cmap``, on a colour scale symmetric
    around zero.

    Params
    ------
    track_segs
        Sequence of ``(start, end)`` pairs with ``end`` exclusive (they are
        used as ``np.arange(start, end)`` and ``xtrack[start:end]``).
    cd_scores
        One score per segment, aligned by index with ``track_segs``. An
        entry may be a scalar or a single-element array, e.g. the output of
        ``calc_cd_score``.
    xtrack
        Sequence of values plotted against sample index.
    pred, y
        When ``pred`` is not None, the title shows ``pred`` (one decimal)
        and ``y``.
    vabs
        Half-width of the colour range ``[-vabs, vabs]``. Defaults to the
        largest absolute score.
    cbar
        Whether to add a colorbar.
    xticks, yticks
        When False, the ticks on that axis are removed.

    Returns
    -------
    The matplotlib colorbar if ``cbar`` is true, otherwise None.
    '''
    if vabs is None:
        vabs = np.max(np.abs(cd_scores))
    norm = matplotlib.colors.Normalize(vmin=-vabs, vmax=vabs)

    for i, (s, e) in enumerate(track_segs):
        xs = np.arange(s, e)
        if len(xs) == 0:
            # Empty segment: nothing to draw, and scatter would reject a
            # colour value for zero points.
            continue
        # Reduce the score to one scalar so that plain-scalar scores work for
        # single-sample segments too (indexing a float raises TypeError).
        score = np.ravel(cd_scores[i])[0]
        col = viz.cmap(norm(score))
        # Unwrap any singleton nesting so ``color=`` receives a single RGBA.
        while len(col) == 1:
            col = col[0]
        plt.plot(xs, xtrack[s: e], zorder=0, lw=2, color=col, alpha=0.5)
        plt.scatter(xs, xtrack[s: e],
                    c=[score] * len(xs), cmap=viz.cmap,
                    vmin=-vabs, vmax=vabs, s=6)

    if pred is not None:
        plt.title(f"Pred: {pred: .1f}, y: {y}", fontsize=24)
    cb = None
    if cbar:
        cb = plt.colorbar()
        cb.outline.set_visible(False)
    if not xticks:
        plt.xticks([])
    if not yticks:
        plt.yticks([])
    return cb
-
-
-
def max_abs_sum_seg(scores_list, min_length: int = 1):
    """Find the segmentation of positions ``0..n-1`` with the largest total |score|.

    ``scores_list[i][j]`` is the score for the segment from ``i`` to ``j``
    (inclusive). Only the upper triangle (``i <= j``) is read, and the input
    is not modified. A segmentation splits ``0..n-1`` into contiguous
    segments. Its value is the sum of the absolute scores of those segments.
    It is solved by dynamic programming over prefixes in O(n^2) time.

    Params
    ------
    scores_list
        Square matrix (nested lists or 2-D array) of segment scores. ``n``
        is taken from ``len(scores_list[0])``.
    min_length
        Minimum allowable length for a segment. Shorter segments score
        ``-inf``, so they are never chosen when a valid segmentation exists.

    Returns
    -------
    res
        ``res[i]`` is the best total for the prefix ``0..i``. It is ``-inf``
        when no segmentation of that prefix respects ``min_length``.
    paths
        ``paths[i]`` is the list of segment start indices of that best
        segmentation. Each segment runs from its start to the next start
        minus 1, and the last one runs to ``i``. ``paths[-1]`` is ``[]``.
        Empty input returns ``([], {-1: []})``.
    """
    if len(scores_list) == 0:
        return [], {-1: []}
    n = len(scores_list[0])

    def seg_score(start, end):
        # |score| of segment start..end; -inf if shorter than min_length, so
        # it can never beat a valid choice regardless of score magnitude.
        if end - start < min_length - 1:
            return -np.inf
        return abs(scores_list[start][end])

    res = [0] * n
    paths = {-1: []}  # empty prefix: no segments
    for i in range(n):
        # Candidate j: last segment is j..i, preceded by the best split of
        # 0..j-1 (the empty prefix, j == 0, contributes 0).
        cand = [(res[j - 1] if j > 0 else 0) + seg_score(j, i)
                for j in range(i + 1)]
        seg_start = int(np.argmax(cand))  # first maximum wins ties
        res[i] = cand[seg_start]
        paths[i] = paths[seg_start - 1] + [seg_start]
    return res, paths
|