Skip to content

Commit

Permalink
Enhance mpyc.mpctools.accumulate().
Browse files Browse the repository at this point in the history
  • Loading branch information
lschoe authored Feb 20, 2024
1 parent 58332e6 commit d4c4dae
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 29 deletions.
67 changes: 44 additions & 23 deletions mpyc/mpctools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@

import operator

runtime = None

def reduce(f, x, iv=None):
_no_value = type('', (object,), {'__repr__': lambda self: '<no value>'})


def reduce(f, x, initial=_no_value):
"""Apply associative function f of two arguments to the items of iterable x.
The applications of f are arranged in a binary tree of logarithmic depth,
Expand All @@ -22,19 +26,22 @@ def reduce(f, x, iv=None):
and in this case f is not required to be associative; the arguments to f
may even be of different types.
If iv is provided, it is placed before the items of x (hence effectively
serves as a default when x is empty). If iv is not given and x contains
only one item, that item is returned.
If initial is provided (possibly equal to None), it is placed before the
items of x (hence effectively serves as a default when x is empty). If
initial is not given and x contains only one item, that item is returned.
"""
x = list(x)
if iv is not None:
x.insert(0, iv)
if initial is not _no_value:
x.insert(0, initial)
if not x:
raise TypeError('reduce() of empty sequence with no initial value')

while len(x) > 1:
x[len(x)%2:] = (f(x[i], x[i+1]) for i in range(len(x)%2, len(x), 2))
return x[0]


def accumulate(x, f=operator.add, iv=None):
def accumulate(x, f=operator.add, initial=_no_value):
"""For associative function f of two arguments, make an iterator that returns
the accumulated results over all (nonempty) prefixes of the given iterable x.
Expand All @@ -46,22 +53,36 @@ def accumulate(x, f=operator.add, iv=None):
the applications of f in a linear fashion, as in general it cannot be assumed
that f is associative (and that the arguments to f are even of the same type).
If iv is provided, the accumulation leads off with this initial value so that
the output has one more element than the input iterable. Otherwise, the number
of elements output matches the input iterable x.
If initial is provided (possibly equal to None), the accumulation leads off
with this initial value so that the output has one more element than the input
iterable. Otherwise, the number of elements output matches the input iterable x.
"""
x = list(x)
if iv is not None:
x.insert(0, iv)

def acc(i, j):
if j == i+1:
return x[i:j]

h = (i + j)//2
y = acc(i, h)
a = y[-1]
y.extend(f(a, b) for b in acc(h, j))
return y
if initial is not _no_value:
x.insert(0, initial)
n = len(x)
if runtime.options.no_prss and n >= 32:
# Minimize f-complexity of acc(0, n).
# For n=2^k, k>=0: f-complexity=2n - 2 - log2 n calls, f-depth=max(2 log2 n - 1, 0) rounds.
def acc(i, j):
h = (i + j)//2
if i < h:
acc(i, h)
a = x[h-1]
if i:
x[h-1] = f(x[i-1], a)
acc(h, j)
x[j-1] = f(a, x[j-1])
else:
# Minimize f-depth of acc(0, n)
# For n=2^k, k>=0: f-complexity=(n/2) log2 n calls, f-depth=log2 n rounds.
def acc(i, j):
h = (i + j)//2
if i < h:
acc(i, h)
a = x[h-1]
acc(h, j)
x[h:j] = (f(a, b) for b in x[h:j])

yield from acc(0, len(x))
acc(0, n)
return iter(x)
1 change: 1 addition & 0 deletions mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4310,6 +4310,7 @@ def hop(a):
rt = Runtime(pid, parties, options)
sectypes.runtime = rt
asyncoro.runtime = rt
mpyc.mpctools.runtime = rt
mpyc.seclists.runtime = rt
mpyc.secgroups.runtime = rt
mpyc.random.runtime = rt
Expand Down
22 changes: 16 additions & 6 deletions tests/test_mpctools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,33 @@ def test_reduce(self):
self.assertEqual(mpc.run(mpc.output(red(mpc.schur_prod, y))), [40320]*2)
self.assertEqual(mpc.run(mpc.output(red(mpc.matrix_add, z)[0])), [36]*2)
self.assertEqual(mpc.run(mpc.output(red(mpc.matrix_prod, z)[1])), [5160960]*2)
self.assertRaises(TypeError, red, mpc.add, [])

def test_accumulate(self):
secint = mpc.SecInt()
r = range(1, 9)
r = range(1, 13)
r3 = range(1, 39)
x = [secint(i) for i in r]
x3 = [secint(i) for i in r3]
a = secint(10)
for acc in itertools.accumulate, mpyc.mpctools.accumulate:
self.assertEqual(mpc.run(mpc.output(list(acc(x)))),
list(itertools.accumulate(r)))
self.assertEqual(mpc.run(mpc.output(list(acc(x3)))),
list(itertools.accumulate(r3)))
mpc.options.no_prss = not mpc.options.no_prss
self.assertEqual(mpc.run(mpc.output(list(acc(x3)))),
list(itertools.accumulate(r3)))
mpc.options.no_prss = not mpc.options.no_prss
self.assertEqual(mpc.run(mpc.output(list(acc(x, mpc.mul)))),
list(itertools.accumulate(r, operator.mul)))
self.assertEqual(mpc.run(mpc.output(list(acc(x, mpc.min)))),
list(itertools.accumulate(r, min)))
self.assertEqual(mpc.run(mpc.output(list(acc(x, mpc.max)))),
list(itertools.accumulate(r, max)))
a = secint(10)
self.assertEqual(mpc.run(mpc.output(list(acc(itertools.repeat(a, 5), mpc.mul, secint(1))))),
[1, 10, 10**2, 10**3, 10**4, 10**5])
a5 = itertools.repeat(a, 5)
self.assertEqual(mpc.run(mpc.output(list(acc(a5, mpc.mul, initial=secint(1))))),
[1, 10, 10**2, 10**3, 10**4, 10**5])
self.assertRaises(TypeError, acc, None)
self.assertEqual(mpc.run(mpc.output(list(acc([])))), [])


if __name__ == "__main__":
Expand Down

0 comments on commit d4c4dae

Please sign in to comment.