-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathremove_outliers.m
31 lines (29 loc) · 1.14 KB
/
remove_outliers.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
function index = remove_outliers(waveforms, idx, cutoff, num_pcs)
%--------------------------------------------------------------------------
% remove_outliers.m - Uses PCA for each cluster to determine outliers to be
% removed. Each cluster's waveforms are projected onto their leading
% principal components, each component's projection is z-scored, and a
% waveform is flagged if |z| exceeds the cutoff on any of those
% components. A Gaussian PCA projection is assumed; by default the first
% 2 components and a cutoff of 4 standard deviations are used.
%
% Usage: index = remove_outliers(waveforms, idx);
%        index = remove_outliers(waveforms, idx, cutoff);
%        index = remove_outliers(waveforms, idx, cutoff, num_pcs);
%
% Input: waveforms * CxWxT numeric array of waveforms (W = number of
%                    waveforms; C and T are presumably channels and time
%                    samples -- confirm against callers)
%        idx       * vector of W cluster membership IDs. IDs need not be
%                    contiguous or start at 1; waveforms whose ID is NaN
%                    are never flagged.
%        cutoff    * (optional) z-score threshold in std devs; default 4.
%        num_pcs   * (optional) number of leading principal components to
%                    test; default 2.
% Output: index    * Wx1 logical vector, true for waveforms to remove.
%
% Written by Marshall Crumiller
% email: [email protected]
%--------------------------------------------------------------------------
if nargin < 3 || isempty(cutoff)
    cutoff = 4; % in std devs
end
if nargin < 4 || isempty(num_pcs)
    num_pcs = 2;
end

[C,W,T] = size(waveforms);
% One row per waveform: the T samples of each channel concatenated. Cast to
% double so the SVD below also accepts integer-typed input.
waveforms = reshape(permute(double(waveforms),[3 1 2]), C*T, W).';
index = false(W,1);

idx = idx(:);
% Iterate over the IDs actually present rather than 1:numel(unique(idx)),
% which skipped real clusters (and visited empty ones) whenever the IDs
% were not exactly 1..K. FOR iterates over columns, so use a row vector.
ids = unique(idx);
ids = ids(~isnan(ids)).';
for id = ids
    locs = find(idx == id);
    if numel(locs) < 2
        continue; % a lone waveform has zero spread, so it is never flagged
    end
    wfs = waveforms(locs,:);

    % Principal-component scores via SVD of the mean-centred data. These
    % match princomp/pca scores up to sign, which the symmetric |z| test
    % ignores; this avoids princomp (deprecated in favour of pca) and any
    % Statistics Toolbox dependency.
    centred = bsxfun(@minus, wfs, mean(wfs,1));
    [~,~,V] = svd(centred, 'econ');
    k = min(num_pcs, size(V,2)); % small clusters may yield fewer PCs
    score = centred * V(:,1:k);

    % z-score each component; zero-variance components give z = 0, not NaN.
    sd = std(score, 0, 1);
    sd(sd == 0) = 1;
    z = bsxfun(@rdivide, bsxfun(@minus, score, mean(score,1)), sd);

    index(locs(any(abs(z) > cutoff, 2))) = true;
end