Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

03a81898-f136-4abe-8b6d-5fe8eecbfc2e 3.9 KB
Raw

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  1. """Sorting components: template matching."""
  2. import numpy as np
  3. from spikeinterface.core import WaveformExtractor, get_template_channel_sparsity, get_template_extremum_channel
  4. from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks
  5. from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive
  # Structured dtype of the spike vector returned by the matching engine:
  # one record per detected spike. Field order matters for a numpy structured
  # dtype (it fixes the record layout), so the fields are listed in an
  # insertion-ordered dict and turned into the (name, dtype) pair list.
  spike_dtype = list(
      {
          "sample_index": "int64",
          "channel_index": "int64",
          "cluster_index": "int64",
          "amplitude": "float64",
          "segment_index": "int64",
      }.items()
  )
  13. from .main import BaseTemplateMatchingEngine
  class NaiveMatching(BaseTemplateMatchingEngine):
      """
      Naive template matching, kept as a low-quality baseline for benchmarking.

      It does not resolve collisions and does not take sparsity into account.
      Peaks are detected with ``DetectPeakLocallyExclusive`` and each peak is
      assigned to the template with the smallest squared Euclidean distance to
      the traces snippet around it.

      It also serves as an example of how to implement the engine hooks:
      method kwargs, margin, initialization, serialization and the main function.
      """

      default_params = {
          "waveform_extractor": None,
          "peak_sign": "neg",
          "exclude_sweep_ms": 0.1,
          "detect_threshold": 5,
          "noise_levels": None,
          "radius_um": 100,
          "random_chunk_kwargs": {},
      }

      @classmethod
      def initialize_and_check_kwargs(cls, recording, kwargs):
          """Merge ``kwargs`` into the defaults and precompute detection parameters.

          Parameters
          ----------
          recording : BaseRecording
              Recording the engine will run on; used for noise estimation,
              channel distances and the sampling frequency.
          kwargs : dict
              Overrides of ``default_params``. ``"waveform_extractor"`` must be
              supplied (not None).

          Returns
          -------
          d : dict
              Defaults updated with ``kwargs`` plus the derived entries
              ``abs_threholds``, ``neighbours_mask``, ``nbefore``, ``nafter``
              and ``exclude_sweep_size``.
          """
          d = cls.default_params.copy()
          d.update(kwargs)

          assert d["waveform_extractor"] is not None, "'waveform_extractor' must be supplied"
          we = d["waveform_extractor"]

          if d["noise_levels"] is None:
              d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"])
          # Coerce to an array so user-supplied lists/tuples are scaled element-wise;
          # a plain Python list multiplied by an int would be repeated instead.
          # (The "abs_threholds" key spelling is kept: main_function reads it.)
          d["abs_threholds"] = np.asarray(d["noise_levels"]) * d["detect_threshold"]

          # Channels closer than radius_um are neighbours for the locally-exclusive detector.
          channel_distance = get_channel_distances(recording)
          d["neighbours_mask"] = channel_distance < d["radius_um"]

          d["nbefore"] = we.nbefore
          d["nafter"] = we.nafter

          # Convert the exclusion window from milliseconds to samples (truncated).
          d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0)

          return d

      @classmethod
      def get_margin(cls, recording, kwargs):
          """Return the margin in samples: ``max(nbefore, nafter)``.

          With this margin, the ``[peak - nbefore, peak + nafter)`` snippet of
          any peak detected outside the margins stays inside the chunk.
          """
          margin = max(kwargs["nbefore"], kwargs["nafter"])
          return margin

      @classmethod
      def serialize_method_kwargs(cls, kwargs):
          """Return a copy of ``kwargs`` with the waveform extractor replaced by templates.

          The ``"waveform_extractor"`` entry is removed and replaced by
          ``"templates"``, the output of ``get_all_templates(mode="average")``.
          Presumably this keeps the kwargs lightweight for worker processes;
          TODO confirm against the base engine. The input dict is not mutated.
          """
          kwargs = dict(kwargs)
          we = kwargs.pop("waveform_extractor")
          kwargs["templates"] = we.get_all_templates(mode="average")
          return kwargs

      @classmethod
      def unserialize_in_worker(cls, kwargs):
          """Return ``kwargs`` unchanged: nothing needs rebuilding in the worker."""
          return kwargs

      @classmethod
      def main_function(cls, traces, method_kwargs):
          """Detect peaks in ``traces`` and assign each one to its closest template.

          Parameters
          ----------
          traces : np.ndarray
              2D chunk indexed as ``traces[sample, channel]``. ``method_kwargs["margin"]``
              samples are ignored at each edge for detection (presumably the
              chunk was extended by that margin on both sides).
          method_kwargs : dict
              Serialized kwargs; must contain ``peak_sign``, ``abs_threholds``,
              ``exclude_sweep_size``, ``neighbours_mask``, ``templates``,
              ``nbefore``, ``nafter`` and ``margin``.

          Returns
          -------
          spikes : np.ndarray
              Structured array of dtype ``spike_dtype``, one record per peak:
              ``sample_index`` is relative to ``traces`` (margin included),
              ``channel_index`` is the detection channel, ``cluster_index`` is the
              index of the nearest template, and ``amplitude`` is 0.0.
              ``segment_index`` is not set here and stays 0.
          """
          peak_sign = method_kwargs["peak_sign"]
          abs_threholds = method_kwargs["abs_threholds"]
          exclude_sweep_size = method_kwargs["exclude_sweep_size"]
          neighbours_mask = method_kwargs["neighbours_mask"]
          # assumes shape (num_templates, nbefore + nafter, num_channels) — TODO confirm
          templates = method_kwargs["templates"]
          nbefore = method_kwargs["nbefore"]
          nafter = method_kwargs["nafter"]
          margin = method_kwargs["margin"]

          # Detect only outside the margins, so every snippet below fits in traces.
          if margin > 0:
              peak_traces = traces[margin:-margin, :]
          else:
              peak_traces = traces
          peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(
              peak_traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask
          )
          # Shift indices back to the coordinates of the full traces chunk.
          peak_sample_ind += margin

          # np.zeros leaves amplitude at 0.0 and segment_index at 0.
          spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype)
          spikes["sample_index"] = peak_sample_ind
          spikes["channel_index"] = peak_chan_ind  # TODO need to put the channel from template

          # Naively take the closest template (sum of squared differences over
          # samples and channels); amplitude is not estimated.
          for i, sample_index in enumerate(peak_sample_ind):
              waveforms = traces[sample_index - nbefore : sample_index + nafter, :]
              dist = np.sum((templates - waveforms[None, :, :]) ** 2, axis=(1, 2))
              spikes["cluster_index"][i] = np.argmin(dist)

          return spikes
Tip!

Press p (or the left arrow key) to see the previous file, or n (or the right arrow key) to see the next file.

Comments

Loading...