Skip to content

Commit

Permalink
Added clear_cache, do_CAR, invert_sign to GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Aug 11, 2024
1 parent 1d11e34 commit 81fb2a5
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 14 deletions.
3 changes: 3 additions & 0 deletions kilosort/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def set_parameters(self):

params = settings.copy()
params['save_preprocessed_copy'] = self.run_box.save_preproc_check.isChecked()
params['clear_cache'] = self.run_box.clear_cache_check.isChecked()
params['do_CAR'] = self.run_box.do_CAR_check.isChecked()
params['invert_sign'] = self.run_box.invert_sign_check.isChecked()

assert params

Expand Down
37 changes: 34 additions & 3 deletions kilosort/gui/run_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, parent):
self.run_all_button = QtWidgets.QPushButton("Run")
self.spike_sort_button = QtWidgets.QPushButton("Spikesort")
self.save_preproc_check = QtWidgets.QCheckBox("Save Preprocessed Copy")
self.clear_cache_check = QtWidgets.QCheckBox("Clear PyTorch Cache")
self.do_CAR_check = QtWidgets.QCheckBox("CAR")
self.invert_sign_check = QtWidgets.QCheckBox("Invert Sign")

self.buttons = [
self.run_all_button
Expand All @@ -44,7 +47,7 @@ def __init__(self, parent):
self.remote_widgets = None

self.progress_bar = QtWidgets.QProgressBar()
self.layout.addWidget(self.progress_bar, 3, 0, 2, 2)
self.layout.addWidget(self.progress_bar, 5, 0, 3, 4)

self.setup()

Expand All @@ -64,8 +67,36 @@ def setup(self):
"""
self.save_preproc_check.setToolTip(preproc_text)

self.layout.addWidget(self.run_all_button, 0, 0, 2, 2)
self.layout.addWidget(self.save_preproc_check, 2, 0, 1, 2)
self.clear_cache_check.setCheckState(QtCore.Qt.CheckState.Unchecked)
cache_text = """
If enabled, force pytorch to free up memory reserved for its cache in
between memory-intensive operations.
Note that setting `clear_cache=True` is NOT recommended unless you
encounter GPU out-of-memory errors, since this can result in slower
sorting.
"""
self.clear_cache_check.setToolTip(cache_text)

self.do_CAR_check.setCheckState(QtCore.Qt.CheckState.Checked)
car_text = """
If enabled, apply common average reference during preprocessing
(recommended).
"""
self.do_CAR_check.setToolTip(car_text)

self.invert_sign_check.setCheckState(QtCore.Qt.CheckState.Unchecked)
invert_sign_text = """
If enabled, flip positive/negative values in data to conform to
standard expected by Kilosort4. This is NOT recommended unless you
know your data is using the opposite sign.
"""
self.invert_sign_check.setToolTip(invert_sign_text)

self.layout.addWidget(self.run_all_button, 0, 0, 3, 4)
self.layout.addWidget(self.save_preproc_check, 3, 0, 1, 2)
self.layout.addWidget(self.clear_cache_check, 3, 2, 1, 2)
self.layout.addWidget(self.do_CAR_check, 4, 0, 1, 2)
self.layout.addWidget(self.invert_sign_check, 4, 2, 1, 2)

self.setLayout(self.layout)

Expand Down
22 changes: 11 additions & 11 deletions kilosort/gui/sorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,13 @@ def run(self):
try:
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {self.data_path}")
clear_cache = settings['clear_cache']
if clear_cache:
logger.info('clear_cache=True')
logger.info('-'*40)

tic0 = time.time()

# TODO: make these options in GUI
do_CAR=True
invert_sign=False

if not do_CAR:
logger.info("Skipping common average reference.")

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
f'Largest value of chanMap exceeds channel count of data, '
Expand All @@ -74,9 +70,13 @@ def run(self):
data_dtype = settings['data_dtype']
device = self.device
save_preprocessed_copy = settings['save_preprocessed_copy']
do_CAR = settings['do_CAR']
invert_sign = settings['invert_sign']
if not do_CAR:
logger.info("Skipping common average reference.")

ops = initialize_ops(settings, probe, data_dtype, do_CAR,
invert_sign, device, save_preprocessed_copy)
invert_sign, device, save_preprocessed_copy)
# Remove some stuff that doesn't need to be printed twice,
# then pretty-print format for log file.
ops_copy = ops.copy()
Expand All @@ -94,7 +94,7 @@ def run(self):
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
file_object=self.file_object
file_object=self.file_object, clear_cache=clear_cache
)

# Check scale of data for log file
Expand All @@ -113,7 +113,7 @@ def run(self):
# Sort spikes and save results
st, tF, Wall0, clu0 = detect_spikes(
ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar
progress_bar=self.progress_bar, clear_cache=clear_cache
)

self.Wall0 = Wall0
Expand All @@ -123,7 +123,7 @@ def run(self):

clu, Wall = cluster_spikes(
st, tF, ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar
progress_bar=self.progress_bar, clear_cache=clear_cache
)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
Expand Down

0 comments on commit 81fb2a5

Please sign in to comment.