From 10950f7f370a0a60a929ffb90e06899f80cc1fee Mon Sep 17 00:00:00 2001 From: Berry Schoenmakers Date: Mon, 11 Mar 2024 18:09:32 +0100 Subject: [PATCH] Complete mpc.np_divide(). --- mpyc/__init__.py | 6 +++--- mpyc/runtime.py | 8 ++------ mpyc/sectypes.py | 3 +++ tests/test_runtime.py | 19 +++++++++++++++---- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/mpyc/__init__.py b/mpyc/__init__.py index 06ce6c1..1214912 100644 --- a/mpyc/__init__.py +++ b/mpyc/__init__.py @@ -29,7 +29,7 @@ and statistics (securely mimicking Python’s statistics module). """ -__version__ = '0.9.9' +__version__ = '0.9.10' __license__ = 'MIT License' import os @@ -170,10 +170,10 @@ def get_arg_parser(): if importlib.util.find_spec('winloop' if sys.platform.startswith('win32') else 'uvloop'): # uvloop (winloop) package available if options.no_uvloop or env_no_uvloop: - logging.info(f'Use of package uvloop (winloop) inside MPyC disabled.') + logging.info('Use of package uvloop (winloop) inside MPyC disabled.') elif sys.platform.startswith('win32'): from winloop import EventLoopPolicy - logging.debug(f'Load winloop') + logging.debug('Load winloop') else: from uvloop import EventLoopPolicy, _version logging.debug(f'Load uvloop version {_version.__version__}') diff --git a/mpyc/runtime.py b/mpyc/runtime.py index e517137..bbe29c2 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -1115,8 +1115,6 @@ def div(self, a, b): if f: if isinstance(b, (int, float)): c = 1/b - if c.is_integer(): - c = round(c) else: c = b.reciprocal() << f else: @@ -1140,17 +1138,15 @@ def np_divide(self, a, b): # isinstance(a, self.SecureArray) ensured if f: - if isinstance(b, (int, float)): + if isinstance(b, (int, float, np.ndarray)): c = 1/b - if c.is_integer(): - c = round(c) elif isinstance(b, self.SecureFixedPoint): c = self._rec(b) else: c = b.reciprocal() << f else: if not isinstance(b, field.array): - b = field.array(b) # TODO: see if this can be used for case f != 0 as well + b = field.array(b) c = b.reciprocal() return self.np_multiply(a, c) diff --git a/mpyc/sectypes.py b/mpyc/sectypes.py index ab7db3e..909deba 100644 --- a/mpyc/sectypes.py +++ b/mpyc/sectypes.py @@ -56,6 +56,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if op == operator.sub: return inputs[1].__rsub__(inputs[0]) + if op == operator.truediv: + return inputs[1].__rtruediv__(inputs[0]) + return op(inputs[1], inputs[0]) if op := unary_ops.get(ufunc): diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 0f06788..b23ec7d 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -238,6 +238,7 @@ def __lt__(self, other): @unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled') def test_secfxp_array(self): np.assertEqual = np.testing.assert_array_equal + np.assertAlmostEqual = np.testing.assert_allclose secfxp = mpc.SecFxp(12) a = np.array([[-1.5, 2.5], [4.5, -8.5]]) @@ -252,6 +253,8 @@ def test_secfxp_array(self): np.assertEqual(mpc.run(mpc.output(c * np.array([1.5, 2.5]))), a * np.array([1.5, 2.5])) np.assertEqual(mpc.run(mpc.output(c * secfxp(2.5))), a * 2.5) np.assertEqual(mpc.run(mpc.output(c * 2.5)), a * 2.5) + np.assertEqual(mpc.run(mpc.output(c / secfxp.field(2))), a / 2) + np.assertEqual(mpc.run(mpc.output(c / secfxp.field.array([2]))), a / 2) # NB: NumPy dispatcher converts np.int8 to int np.assertEqual(mpc.run(mpc.output(c * np.int8(2))), a * 2) @@ -278,10 +281,17 @@ def test_secfxp_array(self): f = 32 secfxp = mpc.SecFxp(2*f) c = secfxp.array(a) - np.testing.assert_allclose(mpc.run(mpc.output(c / 2.45)), a / 2.45, rtol=0, atol=2**(1-f)) - np.testing.assert_allclose(mpc.run(mpc.output(c / 2.5)), a / 2.5, rtol=0, atol=2**(2-f)) - np.testing.assert_allclose(mpc.run(mpc.output(1 / c)), 1 / a, rtol=0, atol=2**(1-f)) - np.testing.assert_allclose(mpc.run(mpc.output(c / c)), 1, rtol=0, atol=2**(3-f)) + np.assertAlmostEqual(mpc.run(mpc.output(c / 0.5)), a / 0.5, rtol=0, atol=0) + np.assertAlmostEqual(mpc.run(mpc.output(c / 2.45)), a / 2.45, rtol=0, atol=2**(1-f)) + np.assertAlmostEqual(mpc.run(mpc.output(c / 2.5)), a / 2.5, rtol=0, atol=2**(2-f)) + np.assertAlmostEqual(mpc.run(mpc.output(c / c[0, 1])), a / 2.5, rtol=0, atol=2**(3-f)) + np.assertAlmostEqual(mpc.run(mpc.output(1 / c)), 1 / a, rtol=0, atol=2**(1-f)) + np.assertAlmostEqual(mpc.run(mpc.output(secfxp(1.5) / c)), 1.5 / a, rtol=0, atol=2**(1-f)) + np.assertAlmostEqual(mpc.run(mpc.output(1.5 / c)), 1.5 / a, rtol=0, atol=2**(1-f)) + np.assertAlmostEqual(mpc.run(mpc.output(a / c)), 1, rtol=0, atol=2**(3-f)) + np.assertAlmostEqual(mpc.run(mpc.output((2*a).astype(int) / c)), 2, rtol=0, atol=2**(4-f)) + np.assertAlmostEqual(mpc.run(mpc.output(c / a)), 1, rtol=0, atol=2**(0-f)) + np.assertAlmostEqual(mpc.run(mpc.output(c / c)), 1, rtol=0, atol=2**(3-f)) np.assertEqual(mpc.run(mpc.output(np.equal(c, c))), True) np.assertEqual(mpc.run(mpc.output(np.equal(c, 0))), False) np.assertEqual(mpc.run(mpc.output(np.sum(c, axis=(-2, 1)))), np.sum(a, axis=(-2, 1))) @@ -824,6 +834,7 @@ def test_secfxp(self): self.assertAlmostEqual(mpc.run(mpc.output(c / d)), t, delta=2**(3-f)) t = -s[3] / s[2] self.assertAlmostEqual(mpc.run(mpc.output(-d / c)), t, delta=2**(3-f)) + self.assertEqual(mpc.run(mpc.output(secfxp(2) / secfxp.field(2))), 1) self.assertEqual(mpc.run(mpc.output(mpc.sgn(+a))), s[0] > 0) self.assertEqual(mpc.run(mpc.output(mpc.sgn(-a))), -(s[0] > 0))