Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Predict with encoder + svm #23

Open
wants to merge 1 commit into
base: test/simple_noisy_spikes
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion spike_cluster_polisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def filter_clusters (final_clusters, total_spikes, minimal_support=0.005):

labels = [None] * total_spikes
all_waves = None
encoded_waves = None
index = 0
thrown_spikes = 0
kept_cluster = 0
Expand All @@ -107,8 +108,10 @@ def filter_clusters (final_clusters, total_spikes, minimal_support=0.005):

if all_waves is None:
all_waves = each[3]
encoded_waves = each[0]
else:
all_waves = np.vstack((all_waves, each[3]))
encoded_waves = np.vstack((encoded_waves, each[0]))
for _ in range(len(each[0])):
labels[index] = idx
index += 1
Expand All @@ -118,4 +121,4 @@ def filter_clusters (final_clusters, total_spikes, minimal_support=0.005):
if index + thrown_spikes != total_spikes:
raise Exception ("Something is off %d %d %d" % (index, thrown_spikes, total_spikes))

return all_waves, labels, thrown_spikes
return encoded_waves, all_waves, labels, thrown_spikes
7 changes: 4 additions & 3 deletions spike_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,19 +222,20 @@ def main ():
logging.critical ("Done merging spikes. Total %d clusters!!!" % (len(final_clusters)))

# Lastly, classify the results with SVM
all_waves, labels, n_thrown_spikes = filter_clusters(final_clusters, total_spikes)
encoded_waves, all_waves, labels, n_thrown_spikes = filter_clusters(final_clusters, total_spikes)
logging.critical ("Done filtering. Ended up %d clusters and threw away %d spikes!!!" % (len(np.unique(labels)), n_thrown_spikes))

logging.critical ("Started SVM classification!!!")
svm_classifier.Fit (data=all_waves, label=labels)
svm_classifier.Fit (data=encoded_waves, label=labels)
logging.critical ("Done SVM classification!!!")

scatter(all_waves, labels)

# Consume trained SVMs
new_spikes = get_sample_spikes ()
new_encoded_spikes = encoder.predict(np.array(new_spikes))
new_labels = []
for each in new_spikes:
for each in new_encoded_spikes:
res = svm_classifier.Predict (each)
new_labels.append(res)

Expand Down