-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsweeper.py
170 lines (147 loc) · 6.54 KB
/
sweeper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import time
from mps import MPS
from feature_tensor import FeatureTensor
import numpy as np
import math
import tensornetwork as tn
from tensornetwork.network_components import Node, BaseNode
from typing import Any, List, Optional, Text, Type, Union, Dict, Sequence
import threading
class Sweeper:
    """
    Represents the sweeping algorithm. It should be reinstantiated
    per epoch.

    Each call to ``translate`` takes one gradient-descent step on the bond
    tensor at ``mps.output_idx`` and then moves the output label one site to
    the right, refreshing the cached left projections for every sample.
    """

    def __init__(self, mps: MPS, Xs_batch: np.ndarray, Ys_batch: np.ndarray, feature_map=None):
        """
        :param mps: the MPS being trained
        :param Xs_batch: input batch; column ``i`` (``Xs_batch[:, i]``) is sample ``i``
        :param Ys_batch: targets; column ``i`` (``Ys_batch[:, i]``) is the target of sample ``i``
        :param feature_map: stored on the instance (see note below)
        """
        self.mps = mps
        self.Xs = Xs_batch
        self.Ys = Ys_batch
        # Left-most projection node for each vector in Xs_batch.
        # Cached so it does not have to be recomputed after every translation.
        self.projections: List[Node] = Xs_batch.shape[1] * [None]
        # NOTE(review): never read by this class -- FeatureTensor is built
        # without it. Confirm whether it should be forwarded.
        self.feature_map = feature_map

    def __leftmost__(self):
        """Return True when ``mps.output_idx`` is 0, i.e. the current bond is the left-most one."""
        return self.mps.output_idx == 0

    def __contract_with_right(self, input_tensor: List[Node]) -> Union[BaseNode, Node]:
        """
        Contract the MPS tensors to the right of the current bond with the
        matching feature nodes of one sample.

        The nodes returned by ``mps.get_right()`` are paired, in order, with
        ``input_tensor[output_idx + 2:]`` over their ``'in'`` edges.

        :param input_tensor: feature nodes for one sample
        :return: the contracted node, with its axis named ``'in'``
        """
        right = self.mps.get_right()
        MPS.connect_mps_tensors(right)
        # The bond covers sites output_idx and output_idx + 1, so the
        # right-hand environment starts two sites further on.
        offset = self.mps.output_idx + 2
        for i, mps_node in enumerate(right):
            mps_node['in'] ^ input_tensor[offset + i]['in']
        contracted_right = tn.contractors.greedy(right + input_tensor[offset:])
        contracted_right.add_axis_names(['in'])
        return contracted_right

    def __projection__(self, input_idx: int) -> List[Union[Node, BaseNode]]:
        """
        Build the nodes that surround the current bond for one sample.

        :param input_idx: index of the sample (column of ``self.Xs``)
        :return: ``[site1, site2, right]`` at the left-most bond, otherwise
            ``[left, site1, site2, right]``. ``left`` is the cached
            ``self.projections[input_idx]``; ``site1``/``site2`` are the feature
            nodes at ``output_idx`` and ``output_idx + 1``; ``right`` is the
            result of ``__contract_with_right``.
        """
        input_tensor = FeatureTensor(self.Xs[:, input_idx]).get_nodes()
        output_idx = self.mps.output_idx
        site_nodes = [input_tensor[output_idx], input_tensor[output_idx + 1]]
        right = self.__contract_with_right(input_tensor)
        if self.__leftmost__():
            return site_nodes + [right]
        return [self.projections[input_idx]] + site_nodes + [right]

    def __gradient__(self, input_idx: int) -> np.ndarray:
        """
        Per-sample gradient with respect to the bond tensor.

        With ``f`` the MPS output for sample ``input_idx`` (projection nodes
        contracted with the bond) and ``y = self.Ys[:, input_idx]``, this
        returns the outer product of the residual ``f - y`` with the
        projection nodes, i.e. the gradient of ``0.5 * ||f - y||^2`` with
        respect to the bond.

        :param input_idx: index of the sample (column of ``self.Xs``/``self.Ys``)
        :return: array with axes ``[r, in1, in2, out]`` at the left-most bond,
            otherwise ``[l, r, in1, in2, out]``
        """
        proj = self.__projection__(input_idx)
        bond = self.mps.get_contracted_bond()
        leftmost = self.__leftmost__()
        # proj and bond_axes have the same length (3 or 4), pairing each
        # projection node with the bond axis it attaches to.
        bond_axes = ['in1', 'in2', 'r'] if leftmost else ['l', 'in1', 'in2', 'r']
        for node, axis in zip(proj, bond_axes):
            node['in'] ^ bond[axis]
        residual = tn.contractors.auto(proj + [bond]) - tn.Node(self.Ys[:, input_idx])
        residual.add_axis_names(['out'])
        # Copies of the projection nodes whose 'in' edges are left dangling:
        # nothing connects them to `residual`, so the contraction below is an
        # outer product.
        proj2 = tn.replicate_nodes(proj)
        edges = [node['in'] for node in proj2]
        if leftmost:
            in1_edge, in2_edge, r_edge = edges
            edge_order = [r_edge, in1_edge, in2_edge]
        else:
            l_edge, in1_edge, in2_edge, r_edge = edges
            edge_order = [l_edge, r_edge, in1_edge, in2_edge]
        edge_order.append(residual['out'])
        return tn.contractors.auto([residual] + proj2, output_edge_order=edge_order).tensor

    def __precompute_projections__(self, input_idx: int):
        """
        Refresh the cached left projection ``self.projections[input_idx]``.

        Called after the MPS has been updated and its output label moved one
        site to the right: the site just left behind (``output_idx - 1``) is
        absorbed into the cached projection.

        :param input_idx: index of the sample (column of ``self.Xs``)
        """
        # output_idx has already been advanced by update_bond, so the site to
        # absorb is the one just to its left.
        output_idx = self.mps.output_idx - 1
        input_tensor = FeatureTensor(self.Xs[:, input_idx]).get_nodes()
        above = self.mps.tensors[output_idx].copy()
        if output_idx == 0:
            feature_node = input_tensor[0].copy()
            feature_node['in'] ^ above['in']
            nodes = [feature_node, above]
        else:
            left = self.projections[input_idx]
            feature_node = input_tensor[output_idx]
            left['in'] ^ above['l']
            feature_node['in'] ^ above['in']
            nodes = [left, feature_node, above]
        precomputed_proj = tn.contractors.auto(nodes)
        precomputed_proj.add_axis_names(['in'])
        self.projections[input_idx] = precomputed_proj

    def __add_newbond_names__(self, bond_tensor):
        """
        Name the axes of a freshly built bond node according to the current
        ``mps.output_idx``.

        There is no ``'l'`` axis at the left-most bond and no ``'r'`` axis at
        the right-most one (``output_idx == len(mps.tensors) - 2``).

        :param bond_tensor: bond node whose axes are already in the listed order
        """
        # NOTE(review): __gradient__ always yields an 'r' axis away from the
        # left edge; the right-most branch only matches if the bond there has
        # no 'r' axis -- confirm against MPS.get_contracted_bond.
        output_idx = self.mps.output_idx
        if output_idx == 0:
            bond_tensor.add_axis_names(['r', 'in1', 'in2', 'out'])
        elif output_idx == len(self.mps.tensors) - 2:
            bond_tensor.add_axis_names(['l', 'in1', 'in2', 'out'])
        else:
            bond_tensor.add_axis_names(['l', 'r', 'in1', 'in2', 'out'])

    def translate(self, alpha: float, max_singular_values: int, num_threads: int) -> bool:
        """
        Take one gradient-descent step on the current bond tensor, then move
        the output label one site to the right.

        The per-sample gradients from ``__gradient__`` are summed over every
        sample of the batch. The sample indices are split into
        ``num_threads`` contiguous chunks whose sizes differ by at most one,
        so no sample is dropped when the batch size is not a multiple of
        ``num_threads``. Each chunk is summed on a worker thread, and any
        exception raised by a worker propagates to the caller.

        :param alpha: step size applied to the summed gradient
        :param max_singular_values: forwarded to ``MPS.update_bond``
        :param num_threads: number of worker threads (must be >= 1)
        :return: ``mps.output_idx < len(mps.tensors) - 2`` after the update
        :raises ValueError: if ``num_threads`` < 1
        """
        from concurrent.futures import ThreadPoolExecutor

        if num_threads < 1:
            raise ValueError('num_threads must be at least 1, got {}'.format(num_threads))
        num_samples = self.Xs.shape[1]
        bond = self.mps.get_contracted_bond()
        bond_shape = tuple(bond.shape)
        chunks = np.array_split(np.arange(num_samples), num_threads)

        def chunk_gradient(indices):
            # Sum of the per-sample gradients for one chunk of sample indices.
            grad = np.zeros(bond_shape)
            for i in indices:
                grad = grad + self.__gradient__(int(i))
            return grad

        total_grad = np.zeros(bond_shape)
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            # pool.map re-raises a worker's exception here instead of
            # leaving the caller blocked or summing a silently-zero chunk.
            for partial in pool.map(chunk_gradient, chunks):
                total_grad += partial

        new_bond = tn.Node(bond.tensor - alpha * total_grad)
        self.__add_newbond_names__(new_bond)
        self.mps.update_bond(new_bond, max_singular_values)
        # The output label has moved right; absorb the site left behind into
        # each sample's cached left projection.
        for i in range(num_samples):
            self.__precompute_projections__(i)
        return self.mps.output_idx < len(self.mps.tensors) - 2