diff --git a/mpyc/mpctools.py b/mpyc/mpctools.py index ffe8127..e6d07f5 100644 --- a/mpyc/mpctools.py +++ b/mpyc/mpctools.py @@ -10,8 +10,12 @@ import operator +runtime = None -def reduce(f, x, iv=None): +_no_value = type('', (object,), {'__repr__': lambda self: ''}) + + +def reduce(f, x, initial=_no_value): """Apply associative function f of two arguments to the items of iterable x. The applications of f are arranged in a binary tree of logarithmic depth, @@ -22,19 +26,22 @@ def reduce(f, x, iv=None): and in this case f is not required to be associative; the arguments to f may even be of different types. - If iv is provided, it is placed before the items of x (hence effectively - serves as a default when x is empty). If iv is not given and x contains - only one item, that item is returned. + If initial is provided (possibly equal to None), it is placed before the + items of x (hence effectively serves as a default when x is empty). If + initial is not given and x contains only one item, that item is returned. """ x = list(x) - if iv is not None: - x.insert(0, iv) + if initial is not _no_value: + x.insert(0, initial) + if not x: + raise TypeError('reduce() of empty sequence with no initial value') + while len(x) > 1: x[len(x)%2:] = (f(x[i], x[i+1]) for i in range(len(x)%2, len(x), 2)) return x[0] -def accumulate(x, f=operator.add, iv=None): +def accumulate(x, f=operator.add, initial=_no_value): """For associative function f of two arguments, make an iterator that returns the accumulated results over all (nonempty) prefixes of the given iterable x. @@ -46,22 +53,36 @@ def accumulate(x, f=operator.add, iv=None): the applications of f in a linear fashion, as in general it cannot be assumed that f is associative (and that the arguments to f are even of the same type). - If iv is provided, the accumulation leads off with this initial value so that - the output has one more element than the input iterable. Otherwise, the number - of elements output matches the input iterable x. + If initial is provided (possibly equal to None), the accumulation leads off + with this initial value so that the output has one more element than the input + iterable. Otherwise, the number of elements output matches the input iterable x. """ x = list(x) - if iv is not None: - x.insert(0, iv) - - def acc(i, j): - if j == i+1: - return x[i:j] - - h = (i + j)//2 - y = acc(i, h) - a = y[-1] - y.extend(f(a, b) for b in acc(h, j)) - return y + if initial is not _no_value: + x.insert(0, initial) + n = len(x) + if runtime.options.no_prss and n >= 32: + # Minimize f-complexity of acc(0, n). + # For n=2^k, k>=0: f-complexity=2n - 2 - log2 n calls, f-depth=max(2 log2 n - 1, 0) rounds. + def acc(i, j): + h = (i + j)//2 + if i < h: + acc(i, h) + a = x[h-1] + if i: + x[h-1] = f(x[i-1], a) + acc(h, j) + x[j-1] = f(a, x[j-1]) + else: + # Minimize f-depth of acc(0, n) + # For n=2^k, k>=0: f-complexity=(n/2) log2 n calls, f-depth=log2 n rounds. + def acc(i, j): + h = (i + j)//2 + if i < h: + acc(i, h) + a = x[h-1] + acc(h, j) + x[h:j] = (f(a, b) for b in x[h:j]) - yield from acc(0, len(x)) + acc(0, n) + return iter(x) diff --git a/mpyc/runtime.py b/mpyc/runtime.py index b4af226..12568e8 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -4310,6 +4310,7 @@ def hop(a): rt = Runtime(pid, parties, options) sectypes.runtime = rt asyncoro.runtime = rt + mpyc.mpctools.runtime = rt mpyc.seclists.runtime = rt mpyc.secgroups.runtime = rt mpyc.random.runtime = rt diff --git a/tests/test_mpctools.py b/tests/test_mpctools.py index 19588f5..359ab95 100644 --- a/tests/test_mpctools.py +++ b/tests/test_mpctools.py @@ -28,23 +28,33 @@ def test_reduce(self): self.assertEqual(mpc.run(mpc.output(red(mpc.schur_prod, y))), [40320]*2) self.assertEqual(mpc.run(mpc.output(red(mpc.matrix_add, z)[0])), [36]*2) self.assertEqual(mpc.run(mpc.output(red(mpc.matrix_prod, z)[1])), [5160960]*2) + self.assertRaises(TypeError, red, mpc.add, []) def test_accumulate(self): secint = mpc.SecInt() - r = range(1, 9) + r = range(1, 13) + r3 = range(1, 39) x = [secint(i) for i in r] + x3 = [secint(i) for i in r3] + a = secint(10) for acc in itertools.accumulate, mpyc.mpctools.accumulate: - self.assertEqual(mpc.run(mpc.output(list(acc(x)))), - list(itertools.accumulate(r))) + self.assertEqual(mpc.run(mpc.output(list(acc(x3)))), + list(itertools.accumulate(r3))) + mpc.options.no_prss = not mpc.options.no_prss + self.assertEqual(mpc.run(mpc.output(list(acc(x3)))), + list(itertools.accumulate(r3))) + mpc.options.no_prss = not mpc.options.no_prss self.assertEqual(mpc.run(mpc.output(list(acc(x, mpc.mul)))), list(itertools.accumulate(r, operator.mul))) self.assertEqual(mpc.run(mpc.output(list(acc(x, mpc.min)))), list(itertools.accumulate(r, min))) self.assertEqual(mpc.run(mpc.output(list(acc(x, mpc.max)))), list(itertools.accumulate(r, max))) - a = secint(10) - self.assertEqual(mpc.run(mpc.output(list(acc(itertools.repeat(a, 5), mpc.mul, secint(1))))), - [1, 10, 10**2, 10**3, 10**4, 10**5]) + a5 = itertools.repeat(a, 5) + self.assertEqual(mpc.run(mpc.output(list(acc(a5, mpc.mul, initial=secint(1))))), + [1, 10, 10**2, 10**3, 10**4, 10**5]) + self.assertRaises(TypeError, acc, None) + self.assertEqual(mpc.run(mpc.output(list(acc([])))), []) if __name__ == "__main__":