Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

1ad40730-6e94-48ad-bd37-3fc6b28c0ca9 6.6 KB
Raw

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
  1. import numpy as np
  2. from ..core import WaveformExtractor
  3. from ..core.waveform_extractor import BaseWaveformExtractorExtension
  4. class TemplateSimilarityCalculator(BaseWaveformExtractorExtension):
  5. """Compute similarity between templates with several methods.
  6. Parameters
  7. ----------
  8. waveform_extractor: WaveformExtractor
  9. A waveform extractor object
  10. """
  11. extension_name = "similarity"
  12. def __init__(self, waveform_extractor):
  13. BaseWaveformExtractorExtension.__init__(self, waveform_extractor)
  14. def _set_params(self, method="cosine_similarity"):
  15. params = dict(method=method)
  16. return params
  17. def _select_extension_data(self, unit_ids):
  18. # filter metrics dataframe
  19. unit_indices = self.waveform_extractor.sorting.ids_to_indices(unit_ids)
  20. new_similarity = self._extension_data["similarity"][unit_indices][:, unit_indices]
  21. return dict(similarity=new_similarity)
  22. def _run(self):
  23. similarity = _compute_template_similarity(self.waveform_extractor, method=self._params["method"])
  24. self._extension_data["similarity"] = similarity
  25. def get_data(self):
  26. """
  27. Get the computed similarity.
  28. Returns
  29. -------
  30. similarity : 2d np.array
  31. 2d matrix with computed similarity values.
  32. """
  33. msg = "Template similarity is not computed. Use the 'run()' function."
  34. assert self._extension_data["similarity"] is not None, msg
  35. return self._extension_data["similarity"]
  36. @staticmethod
  37. def get_extension_function():
  38. return compute_template_similarity
  39. WaveformExtractor.register_extension(TemplateSimilarityCalculator)
  40. def _compute_template_similarity(
  41. waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None
  42. ):
  43. import sklearn.metrics.pairwise
  44. templates = waveform_extractor.get_all_templates()
  45. s = templates.shape
  46. if method == "cosine_similarity":
  47. templates_flat = templates.reshape(s[0], -1)
  48. if waveform_extractor_other is not None:
  49. templates_other = waveform_extractor_other.get_all_templates()
  50. s_other = templates_other.shape
  51. templates_other_flat = templates_other.reshape(s_other[0], -1)
  52. assert len(templates_flat[0]) == len(templates_other_flat[0]), (
  53. "Templates from second WaveformExtractor " "don't have the correct shape!"
  54. )
  55. else:
  56. templates_other_flat = None
  57. similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, templates_other_flat)
  58. # elif method == '':
  59. else:
  60. raise ValueError(f"compute_template_similarity(method {method}) not exists")
  61. return similarity
  62. def compute_template_similarity(
  63. waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None
  64. ):
  65. """Compute similarity between templates with several methods.
  66. Parameters
  67. ----------
  68. waveform_extractor: WaveformExtractor
  69. A waveform extractor object
  70. load_if_exists : bool, default: False
  71. Whether to load precomputed similarity, if is already exists.
  72. method: str, default: "cosine_similarity"
  73. The method to compute the similarity
  74. waveform_extractor_other: WaveformExtractor, default: None
  75. A second waveform extractor object
  76. Returns
  77. -------
  78. similarity: np.array
  79. The similarity matrix
  80. """
  81. if waveform_extractor_other is None:
  82. if load_if_exists and waveform_extractor.is_extension(TemplateSimilarityCalculator.extension_name):
  83. tmc = waveform_extractor.load_extension(TemplateSimilarityCalculator.extension_name)
  84. else:
  85. tmc = TemplateSimilarityCalculator(waveform_extractor)
  86. tmc.set_params(method=method)
  87. tmc.run()
  88. similarity = tmc.get_data()
  89. return similarity
  90. else:
  91. return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method)
  92. def check_equal_template_with_distribution_overlap(
  93. waveforms0, waveforms1, template0=None, template1=None, num_shift=2, quantile_limit=0.8, return_shift=False
  94. ):
  95. """
  96. Given 2 waveforms sets, check if they come from the same distribution.
  97. This is computed with a simple trick:
  98. It project all waveforms from each cluster on the normed vector going from
  99. one template to another, if the cluster are well separate enought we should
  100. have one distribution around 0 and one distribution around .
  101. If the distribution overlap too much then then come from the same distribution.
  102. Done by samuel Garcia with an idea of Crhistophe Pouzat.
  103. This is used internally by tridesclous for auto merge step.
  104. Can be also used as a distance metrics between 2 clusters.
  105. waveforms0 and waveforms1 have to be spasifyed outside this function.
  106. This is done with a combinaison of shift bewteen the 2 cluster to also check
  107. if cluster are similar with a sample shift.
  108. Parameters
  109. ----------
  110. waveforms0, waveforms1: numpy array
  111. Shape (num_spikes, num_samples, num_chans)
  112. num_spikes are not necessarly the same for custer.
  113. template0 , template1=None or numpy array
  114. The average of each cluster.
  115. If None, then computed.
  116. num_shift: int default: 2
  117. number of shift on each side to perform.
  118. quantile_limit: float in [0 1]
  119. The quantile overlap limit.
  120. Returns
  121. -------
  122. equal: bool
  123. equal or not
  124. """
  125. assert waveforms0.shape[1] == waveforms1.shape[1]
  126. assert waveforms0.shape[2] == waveforms1.shape[2]
  127. if template0 is None:
  128. template0 = np.mean(waveforms0, axis=0)
  129. if template1 is None:
  130. template1 = np.mean(waveforms1, axis=0)
  131. template0_ = template0[num_shift:-num_shift, :]
  132. width = template0_.shape[0]
  133. wfs0 = waveforms0[:, num_shift:-num_shift, :].copy()
  134. equal = False
  135. final_shift = None
  136. for shift in range(num_shift * 2 + 1):
  137. template1_ = template1[shift : width + shift, :]
  138. vector_0_1 = template1_ - template0_
  139. vector_0_1 /= np.sum(vector_0_1**2)
  140. wfs1 = waveforms1[:, shift : width + shift, :].copy()
  141. scalar_product0 = np.sum((wfs0 - template0_[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2))
  142. scalar_product1 = np.sum((wfs1 - template0_[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2))
  143. l0 = np.quantile(scalar_product0, quantile_limit)
  144. l1 = np.quantile(scalar_product1, 1 - quantile_limit)
  145. equal = l0 >= l1
  146. if equal:
  147. final_shift = shift - num_shift
  148. break
  149. if return_shift:
  150. return equal, final_shift
  151. else:
  152. return equal
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...