1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
- """Sorting components: template matching."""
- import numpy as np
- from spikeinterface.core import WaveformExtractor, get_template_channel_sparsity, get_template_extremum_channel
- from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks
- from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive
# Record layout of the spikes built by NaiveMatching.main_function: each field
# name is paired with its numpy dtype string, in field order.
spike_dtype = list(
    zip(
        ("sample_index", "channel_index", "cluster_index", "amplitude", "segment_index"),
        ("int64", "int64", "int64", "float64", "int64"),
    )
)
- from .main import BaseTemplateMatchingEngine
class NaiveMatching(BaseTemplateMatchingEngine):
    """
    This is a naive template matching that does not resolve collisions
    and does not take sparsity into account.
    It just minimizes the distance to templates for detected peaks.
    It is implemented for benchmarking against this low-quality template matching,
    and also as an example of how to deal with method_kwargs, margin, init, func, ...
    """

    default_params = {
        "waveform_extractor": None,
        "peak_sign": "neg",
        "exclude_sweep_ms": 0.1,
        "detect_threshold": 5,
        "noise_levels": None,
        "radius_um": 100,
        "random_chunk_kwargs": {},
    }

    @classmethod
    def initialize_and_check_kwargs(cls, recording, kwargs):
        """Merge ``kwargs`` over ``default_params`` and precompute detection parameters.

        Parameters
        ----------
        recording : recording object
            Used to estimate noise levels (when ``noise_levels`` is None), to
            compute channel distances and to read the sampling frequency.
        kwargs : dict
            User parameters; any key overrides the matching ``default_params`` entry.

        Returns
        -------
        dict
            The merged parameters plus the derived entries ``abs_threholds``,
            ``neighbours_mask``, ``nbefore``, ``nafter`` and ``exclude_sweep_size``.

        Raises
        ------
        AssertionError
            If no ``waveform_extractor`` is supplied.
        """
        d = cls.default_params.copy()
        d.update(kwargs)

        assert d["waveform_extractor"] is not None, "'waveform_extractor' must be supplied"
        we = d["waveform_extractor"]

        if d["noise_levels"] is None:
            d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"])

        # np.asarray so that a user-supplied list/tuple of noise levels is scaled
        # element-wise; with a plain list, `*` would repeat the list instead.
        # The key spelling "abs_threholds" is kept because main_function reads it.
        d["abs_threholds"] = np.asarray(d["noise_levels"]) * d["detect_threshold"]

        channel_distance = get_channel_distances(recording)
        # Neighbourhood mask: True where the channel distance is below radius_um.
        d["neighbours_mask"] = channel_distance < d["radius_um"]

        d["nbefore"] = we.nbefore
        d["nafter"] = we.nafter

        # Convert the exclusion window from milliseconds to a number of samples.
        d["exclude_sweep_size"] = int(
            d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0
        )

        return d

    @classmethod
    def get_margin(cls, recording, kwargs):
        """Return the margin (in samples) required on each side of a trace chunk.

        ``main_function`` detects peaks only inside ``traces[margin:-margin]`` and
        then reads ``traces[peak - nbefore:peak + nafter]``, so the margin must be
        at least as large as both ``nbefore`` and ``nafter``.
        """
        return max(kwargs["nbefore"], kwargs["nafter"])

    @classmethod
    def serialize_method_kwargs(cls, kwargs):
        """Return a copy of ``kwargs`` with the waveform extractor replaced by templates.

        The ``waveform_extractor`` entry is removed (from the copy only) and a
        ``templates`` entry holding its average templates is added.
        """
        kwargs = dict(kwargs)
        we = kwargs.pop("waveform_extractor")
        kwargs["templates"] = we.get_all_templates(mode="average")
        return kwargs

    @classmethod
    def unserialize_in_worker(cls, kwargs):
        """Return ``kwargs`` unchanged; no per-worker reconstruction is needed."""
        return kwargs

    @classmethod
    def main_function(cls, traces, method_kwargs):
        """Detect peaks in ``traces`` and assign each one to the closest template.

        Parameters
        ----------
        traces : np.ndarray
            Trace chunk indexed as ``traces[sample, channel]``, padded with
            ``margin`` samples on each side.
        method_kwargs : dict
            Parameters produced by ``initialize_and_check_kwargs`` and
            ``serialize_method_kwargs``; must also contain a ``"margin"`` entry.

        Returns
        -------
        np.ndarray
            Structured array of dtype ``spike_dtype`` with one entry per detected
            peak. ``sample_index`` is relative to the start of ``traces``;
            ``amplitude`` and ``segment_index`` are left at 0 here.
        """
        peak_sign = method_kwargs["peak_sign"]
        abs_threholds = method_kwargs["abs_threholds"]
        exclude_sweep_size = method_kwargs["exclude_sweep_size"]
        neighbours_mask = method_kwargs["neighbours_mask"]
        # Assumed shape: (num_units, nbefore + nafter, num_channels); TODO confirm.
        templates = method_kwargs["templates"]
        nbefore = method_kwargs["nbefore"]
        nafter = method_kwargs["nafter"]
        margin = method_kwargs["margin"]

        # traces[0:-0] would be empty, so a zero margin needs its own branch.
        peak_traces = traces[margin:-margin, :] if margin > 0 else traces

        peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(
            peak_traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask
        )
        # Shift indices back into the coordinate frame of the padded traces.
        peak_sample_ind += margin

        # np.zeros leaves amplitude (not estimated here) and segment_index at 0.
        spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype)
        spikes["sample_index"] = peak_sample_ind
        spikes["channel_index"] = peak_chan_ind  # TODO need to put the channel from template

        # Naively take the closest template (squared Euclidean distance).
        for i, sample_index in enumerate(peak_sample_ind):
            waveforms = traces[sample_index - nbefore : sample_index + nafter, :]
            dist = np.sum((templates - waveforms[None, :, :]) ** 2, axis=(1, 2))
            spikes["cluster_index"][i] = np.argmin(dist)

        return spikes
|