Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made 3 important hidden parameters visible to the user #809

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kilosort/gui/settings_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def set_cached_field_values(self):
# List of floats gets cached as list of strings, so
# have to convert back.
d = str([float(s) for s in v])
elif k == 'loc_range' or k == 'long_range':
d = str([int(s) for s in v])
else:
d = str(v)
else:
Expand Down
45 changes: 45 additions & 0 deletions kilosort/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,16 @@
"""
},

'max_peels': {
'gui_name': 'max peels', 'type': int, 'min': 1, 'max': 10000, 'exclude': [],
'default': 100, 'step': 'spike detection',
'description':
"""
Number of iterations to do over each batch of data in the matching
pursuit step. More iterations should detect more overlapping spikes.
"""
},

'templates_from_data': {
'gui_name': 'templates from data', 'type': bool, 'min': None, 'max': None,
'exclude': [], 'default': True, 'step': 'spike detection',
Expand All @@ -308,6 +318,28 @@
"""
},

'loc_range': {
'gui_name': 'loc range', 'type': list, 'min': None, 'max': None,
'exclude': [], 'default': [4, 5], 'step': 'spike detection',
'description':
"""
Number of channels and time steps, respectively, to use for local
maximum detection when detecting spikes to compute universal
templates from data (only used if templates_from_data is True).
"""
},

'long_range': {
'gui_name': 'loc range', 'type': list, 'min': None, 'max': None,
'exclude': [], 'default': [6, 30], 'step': 'spike detection',
'description':
"""
Number of channels and time steps, respectively, to use for peak
isolation when detecting spikes to compute universal templates from
data (only used if templates_from_data is True).
"""
},

'n_templates': {
'gui_name': 'n templates', 'type': int, 'min': 1, 'max': np.inf,
'exclude': [], 'default': 6, 'step': 'spike detection',
Expand Down Expand Up @@ -384,6 +416,19 @@
},


'drift_smoothing': {
'gui_name': 'drift smoothing', 'type': list, 'min': None, 'max': None,
'exclude': [], 'default': [0.5, 0.5, 0.5], 'step': 'preprocessing',
'description':
"""
Amount of gaussian smoothing to apply to the spatiotemporal drift
estimation, for correlation, time (units of registration blocks),
and y (units of batches) axes. The y smoothing has no effect
for `nblocks = 1`. Adjusting smoothing for the correlation axis
is not recommended.
"""
},

### POSTPROCESSING
'duplicate_spike_ms': {
'gui_name': 'duplicate spike ms', 'type': float, 'min': 0, 'max': np.inf,
Expand Down
5 changes: 4 additions & 1 deletion kilosort/spikedetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ def extract_snippets(X, nt, twav_min, Th_single_ch, loc_range=[4,5],
def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25,
device=torch.device('cuda')):

loc_range = ops['settings']['loc_range']
long_range = ops['settings']['long_range']
clips = np.zeros((500000,nt), 'float32')
i = 0
for j in range(0, bfile.n_batches, nskip):
X = bfile.padded_batch_to_torch(j, ops)

clips_new = extract_snippets(X, nt=nt, twav_min=twav_min,
Th_single_ch=Th_single_ch, device=device)
Th_single_ch=Th_single_ch, device=device,
loc_range=loc_range,long_range=long_range)

nnew = len(clips_new)

Expand Down
7 changes: 5 additions & 2 deletions kilosort/template_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):

Xres = X.clone()
lam = 20

for t in range(100):
max_peels = ops['settings']['max_peels']
for t in range(max_peels):
# Cf = 2 * B - nm.unsqueeze(-1)
Cf = torch.relu(B)**2 /nm.unsqueeze(-1)
#a = 1 + lam
Expand All @@ -163,6 +163,9 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
if len(xs)==0:
#print('iter %d'%t)
break
elif len(xs) > 0 and t == max_peels - 1:
logger.debug(f'Reached last iteration of matching pursuit with {len(xs)} spikes detected.')
logger.debug(f'Consider increasing the \'max_peels\' parameter. Current value = {max_peels}')

iX = xs[:,:1]
iY = imax[iX]
Expand Down