-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathpermuted_mnist.py
119 lines (90 loc) · 4.18 KB
/
permuted_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import division
import time
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from RMA import RMA
def plot_results(num_tasks_to_run, baseline_mlp, memoryadapted):
    """Draw the final per-task accuracies of both models on one figure.

    Args:
        num_tasks_to_run: number of tasks; the x axis spans 1..num_tasks_to_run.
        baseline_mlp: accuracy sequence of the baseline model.
        memoryadapted: accuracy sequence of the memory-adapted (RMA) model.

    Each accuracy sequence is plotted in reversed order, and the figure is
    displayed with ``plt.show()``.
    """
    # Imported inside the function, so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt

    task_ids = range(1, num_tasks_to_run + 1)

    # Baseline first, RMA second: the legend labels below rely on this order.
    for accuracies in (baseline_mlp, memoryadapted):
        plt.plot(task_ids, accuracies[::-1])
    plt.legend(["Baseline-MLP", "RMA"], loc='lower right')

    # Axis cosmetics.
    plt.xlabel("Number of Tasks")
    plt.ylabel("Accuracy (%)")
    plt.ylim([1, 100])
    plt.xticks(task_ids)

    plt.show()
def main(_):
    """Run the permuted-MNIST comparison between a baseline and an RMA model.

    Builds two ``RMA`` instances in one TensorFlow session, trains the first
    without memory and the second with memory on the same list of random
    pixel permutations, prints timing and accuracy statistics, and plots the
    final accuracies. The positional argument is ignored.
    """
    with tf.Session() as sess:
        print("\nParameters used:", args, "\n")

        # Baseline model first, then the memory adapted model.
        baseline_model = RMA(sess, args)
        rma_model = RMA(sess, args)

        # Permuted MNIST: every task is one random permutation of the
        # 784 input pixels.
        mnist = input_data.read_data_sets("/tmp/", one_hot=True)
        task_permutation = [
            np.random.permutation(784)
            for _unused in range(args.num_tasks_to_run)
        ]

        def timed_training(banner, model, use_memory):
            # Announce the run, train, and report the elapsed wall-clock time.
            print(banner)
            started_at = time.time()
            performance = training(model, mnist, task_permutation, use_memory)
            finished_at = time.time()
            elapsed = round(finished_at - started_at)
            print("Training time elapsed: ", elapsed, "s")
            return performance, elapsed

        last_performance_baseline, time_needed_baseline = timed_training(
            "\nBaseline MLP training...", baseline_model, False)
        last_performance_ma, time_needed_ma = timed_training(
            "\nMemory adapted (RMA) training...", rma_model, True)

        # Stats
        extra_seconds = time_needed_ma - time_needed_baseline
        print("\nDifference in time between using or not memory: ", extra_seconds, "s")
        accuracy_gain = np.round(
            np.mean(last_performance_ma) - np.mean(last_performance_baseline), 2)
        print("Test accuracy mean gained due to the memory: ", accuracy_gain)

        # Plot the results
        plot_results(args.num_tasks_to_run, last_performance_baseline, last_performance_ma)
def training(model, mnist, task_permutation, use_memory=True, steps_per_task=10000):
    """Train ``model`` on the permuted tasks one after another and evaluate it.

    For each task, ``steps_per_task`` mini-batches are drawn from
    ``mnist.train``, their pixel columns are reordered with that task's
    permutation, and the batch is passed to ``model.train``. After each task
    the model is tested on every task seen so far and the accuracies are
    printed.

    Reads ``num_tasks_to_run``, ``batch_size`` and ``memory_each`` from the
    module-level ``args``.

    Args:
        model: object providing ``train(x, y, n)``, ``add_to_memory(x, y)``
            and ``test(x, y)``.
        mnist: dataset object with ``train.next_batch(size)``,
            ``test.images`` and ``test.labels``.
        task_permutation: sequence of index arrays, one per task, used to
            reorder the pixel columns.
        use_memory: when true, ``args.batch_size`` is passed as the third
            argument of ``model.train`` and a batch is stored with
            ``model.add_to_memory`` every ``args.memory_each`` steps;
            otherwise ``0`` is passed and nothing is stored.
        steps_per_task: number of training steps per task (default 10000,
            the value that used to be hard-coded).

    Returns:
        List of test accuracies in percent, one per task in task order,
        measured after the last task has been trained.
    """
    num_tasks = args.num_tasks_to_run
    # Third argument of model.train: the batch size when memory is used, 0
    # otherwise. NOTE(review): presumably the number of memory samples to
    # replay -- confirm against RMA.train.
    memory_arg = args.batch_size if use_memory else 0

    # The test split does not change between evaluations, so fetch it once.
    test_images = mnist.test.images
    test_labels = mnist.test.labels

    last_performance = []
    for task in range(num_tasks):
        print("\nTraining task: ", task + 1, "/", num_tasks)
        permutation = task_permutation[task]
        for step in range(int(steps_per_task)):
            images, labels = mnist.train.next_batch(args.batch_size)
            # Permute batch elements
            images = images[:, permutation]
            model.train(images, labels, memory_arg)
            # We just store a batch sample each args.memory_each steps
            if use_memory and step % args.memory_each == 0:
                model.add_to_memory(images, labels)

        # Print test set accuracy to each task encountered so far
        is_final_task = task + 1 == num_tasks
        for test_task in range(task + 1):
            permuted_test = test_images[:, task_permutation[test_task]]
            acc = model.test(permuted_test, test_labels) * 100
            # Only the evaluation after the final task is returned.
            if is_final_task:
                last_performance.append(acc)
            print("Testing, Task: ", test_task + 1, " \tAccuracy: ", acc)
    return last_performance
if __name__ == '__main__':
    # Command-line options as (flag, type, default, help text).
    cli_options = (
        ('--num_tasks_to_run', int, 20, 'Number of task to run'),
        ('--memory_size', int, 15000, 'Memory size'),
        ('--memory_each', int, 1000, 'Add to memory after these number of steps'),
        ('--batch_size', int, 32, 'Size of batch for updates'),
        ('--learning_rate', float, 0.5, 'Learning rate'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value,
                            help=help_text)

    # `args` is a module-level global: main() and training() read it directly.
    args = parser.parse_args()
    tf.app.run()