From e25caf733ddc774f29257766c629a474f533887c Mon Sep 17 00:00:00 2001 From: Jiarui Guo Date: Sun, 5 Jun 2022 16:05:26 -0700 Subject: [PATCH] include encoder in model output --- spike_cluster_polisher.py | 5 ++++- spike_main.py | 7 ++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/spike_cluster_polisher.py b/spike_cluster_polisher.py index dd202b3..438d17b 100644 --- a/spike_cluster_polisher.py +++ b/spike_cluster_polisher.py @@ -93,6 +93,7 @@ def filter_clusters (final_clusters, total_spikes, minimal_support=0.005): labels = [None] * total_spikes all_waves = None + encoded_waves = None index = 0 thrown_spikes = 0 kept_cluster = 0 @@ -107,8 +108,10 @@ def filter_clusters (final_clusters, total_spikes, minimal_support=0.005): if all_waves is None: all_waves = each[3] + encoded_waves = each[0] else: all_waves = np.vstack((all_waves, each[3])) + encoded_waves = np.vstack((encoded_waves, each[0])) for _ in range(len(each[0])): labels[index] = idx index += 1 @@ -118,4 +121,4 @@ def filter_clusters (final_clusters, total_spikes, minimal_support=0.005): if index + thrown_spikes != total_spikes: raise Exception ("Something is off %d %d %d" % (index, thrown_spikes, total_spikes)) - return all_waves, labels, thrown_spikes + return encoded_waves, all_waves, labels, thrown_spikes diff --git a/spike_main.py b/spike_main.py index f574017..a128bf7 100644 --- a/spike_main.py +++ b/spike_main.py @@ -222,19 +222,20 @@ def main (): logging.critical ("Done merging spikes. Total %d clusters!!!" % (len(final_clusters))) # Lastly, classify the results with SVM - all_waves, labels, n_thrown_spikes = filter_clusters(final_clusters, total_spikes) + encoded_waves, all_waves, labels, n_thrown_spikes = filter_clusters(final_clusters, total_spikes) logging.critical ("Done filtering. Ended up %d clusters and threw away %d spikes!!!" % (len(np.unique(labels)), n_thrown_spikes)) logging.critical ("Started SVM classification!!!") - svm_classifier.Fit (data=all_waves, label=labels) + svm_classifier.Fit (data=encoded_waves, label=labels) logging.critical ("Done SVM classification!!!") scatter(all_waves, labels) # Consume trained SVMs new_spikes = get_sample_spikes () + new_encoded_spikes = encoder.predict(np.array(new_spikes)) new_labels = [] - for each in new_spikes: + for each in new_encoded_spikes: res = svm_classifier.Predict (each) new_labels.append(res)