Skip to content

Commit

Permalink
Refactor & unittest asyncoro.MessageExchanger.
Browse files Browse the repository at this point in the history
  • Loading branch information
lschoe authored Mar 7, 2024
1 parent 12e678a commit e77f9a8
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# -- Project information -----------------------------------------------------

project = 'MPyC'
copyright = '2018 - 2023, Berry Schoenmakers'
copyright = '2018 - 2024, Berry Schoenmakers'
author = 'Berry Schoenmakers'

from mpyc.__init__ import __version__
Expand Down
21 changes: 6 additions & 15 deletions mpyc/asyncoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, rt, peer_pid=None):
The connection between the two parties will be set up with one party
listening (as server) for the other party to connect (as client).
If peer_pid=None, party rt.pid starts as server and the peer start as
If peer_pid=None, party rt.pid starts as server and the peer starts as
client, and the other way around otherwise. Once the connection is made,
the client will immediately send its pid to the server.
"""
Expand All @@ -36,12 +36,6 @@ def __init__(self, rt, peer_pid=None):
self.transport = None
self.nbytes_sent = 0

def _key_transport_done(self):
rt = self.runtime
rt.parties[self.peer_pid].protocol = self
if all(p.protocol is not None for p in rt.parties):
rt.parties[rt.pid].protocol.set_result(None)

def connection_made(self, transport):
"""Called when a connection is made.
Expand All @@ -55,7 +49,7 @@ def connection_made(self, transport):
if not rt.options.no_prss:
pid_keys.extend(rt._prss_keys_to_peer(self.peer_pid)) # send PRSS keys
transport.writelines(pid_keys)
self._key_transport_done()
rt.set_protocol(self.peer_pid, self)

def send(self, pc, payload):
"""Send payload labeled with pc to the peer.
Expand Down Expand Up @@ -84,19 +78,19 @@ def data_received(self, data):
return

peer_pid = int.from_bytes(data[:2], 'little')
del data[:2]
rt = self.runtime
if not rt.options.no_prss:
len_packet = rt._prss_keys_from_peer(peer_pid)
if len(data) < len_packet:
if len(data) < len_packet + 2:
return

# record new protocol peer
self.peer_pid = peer_pid
del data[:2]
if not rt.options.no_prss:
rt._prss_keys_from_peer(peer_pid, data) # store PRSS keys from peer
del data[:len_packet]
self._key_transport_done()
rt.set_protocol(self.peer_pid, self)

while len(data) >= 12:
pc, payload_size = unpack_from('<qI', data)
Expand Down Expand Up @@ -128,10 +122,7 @@ def connection_lost(self, exc):
if exc:
raise exc

rt = self.runtime
rt.parties[self.peer_pid].protocol = None
if all(p.protocol is None for p in rt.parties if p.pid != rt.pid):
rt.parties[rt.pid].protocol.set_result(None)
self.runtime.unset_protocol(self.peer_pid)

def close_connection(self):
"""Close connection with the peer."""
Expand Down
8 changes: 2 additions & 6 deletions mpyc/fingroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,12 +1386,8 @@ def _calculate_gap(l):
Gap must be a multiple of 4.
"""
gap = l
while True:
gap1 = round(3.5 * l * math.log(gap))
if gap != gap1:
gap = gap1
else:
break
while gap != (gap := round(3.5 * l * math.log(gap))):
pass
return gap - gap%4


Expand Down
14 changes: 13 additions & 1 deletion mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ async def barrier(self, name=None):

async def throttler(self, load_percentage=1.0, name=None):
"""Throttle runtime by given percentage (default 1.0), using optional name for barrier."""
assert 0.0 <= load_percentage <= 1.0, 'percentage as decimal fraction between 0.0 and 1.0'
if not 0.0 <= load_percentage <= 1.0:
raise ValueError('percentage required as decimal fraction between 0.0 and 1.0')

self.aggregate_load += load_percentage * 10000
if self.aggregate_load < 10000:
return
Expand Down Expand Up @@ -4172,6 +4174,16 @@ async def np_unit_vector(self, a, n):
u = np.roll(u, c)
return u

def set_protocol(self, peer_pid, protocol):
self.parties[peer_pid].protocol = protocol
if all(p.protocol is not None for p in self.parties if p.pid != self.pid):
self.parties[self.pid].protocol.set_result(None)

def unset_protocol(self, peer_pid):
self.parties[peer_pid].protocol = None
if all(p.protocol is None for p in self.parties if p.pid != self.pid):
self.parties[self.pid].protocol.set_result(None)


@dataclass
class Party:
Expand Down
83 changes: 83 additions & 0 deletions tests/test_asyncoro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import unittest
from unittest.mock import Mock
from asyncio import Future, Transport
import mpyc.asyncoro
from mpyc.runtime import Party, Runtime, mpc


class Arithmetic(unittest.TestCase):

def test_message_exchanger(self):
# two parties, each with its own MessageExchanger()
rt0 = Runtime(0, [Party(0), Party(1)], mpc.options) # NB: with its own Party() instances
rt1 = Runtime(1, [Party(0), Party(1)], mpc.options) # NB: with its own Party() instances
mx0 = mpyc.asyncoro.MessageExchanger(rt0, 1) # client
mx1 = mpyc.asyncoro.MessageExchanger(rt1) # server

# test: client connects with server
rt0.parties[0].protocol = Future()
rt0.parties[1].protocol = None
rt1.parties[0].protocol = None
rt1.parties[1].protocol = Future()
transport0 = Mock(Transport)
transport1 = Mock(Transport)

def _writelines(s):
transport0.s = b''.join(s)
transport0.writelines = _writelines

def _write(s):
transport1.s = s
transport1.write = _write

mx0.connection_made(transport0)
mx1.connection_made(transport1)
data = transport0.s
mx1.data_received(data[:1])
mx1.data_received(data[1:5])
mx1.data_received(data[5:])

# test: message from server received after client expects it
pc0 = rt0._program_counter[0]
pc1 = rt1._program_counter[0]
self.assertEqual(pc0, pc1)
payload = b'123'
mx1.send(pc1, payload)
fut = mx0.receive(pc0)
data = transport1.s
mx0.data_received(data[:12])
mx0.data_received(data[12:])
self.assertEqual(fut.result(), payload)

# message from server received before client expects it
pc0 += 1
pc1 += 1
payload = b'456'
mx1.send(pc1, payload)
data = transport1.s
mx0.data_received(data[:12])
mx0.data_received(data[12:])
msg = mx0.receive(pc0)
self.assertEqual(msg, payload)

# close connections
rt0.parties[0].protocol = Future()
rt1.parties[1].protocol = Future()
mx0.close_connection()
self.assertRaises(Exception, mx0.connection_lost, Exception())
mx1.close_connection()
mx0.connection_lost(None)
mx1.connection_lost(None)

def test_gather_futures(self):
self.assertEqual(mpc.run(mpyc.asyncoro.gather_shares(mpc, None)), None)
mpc.options.no_async = False
fut = Future()
gut = mpyc.asyncoro.gather_shares(mpc, [fut, fut])
fut.set_result(42)
self.assertEqual(mpc.run(gut), [42, 42])
mpc.options.no_async = True


if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,23 @@ def test_async(self):
a = mpc.SecInt()(7)
b = a * a
mpc.run(mpc.barrier())
mpc.run(mpc.throttler(0.5))
mpc.run(mpc.throttler(0.5))
self.assertRaises(ValueError, mpc.run, mpc.throttler(1.5))
self.assertEqual(mpc.run(mpc.output(b)), 49)
self.assertEqual(mpc.run(mpc.output(mpc.scalar_mul(a, [-a, a]))), [-49, 49])
mpc.options.no_async = True

@unittest.skipIf(mpc.options.no_prss, 'PRSS (pseudorandom secret sharing) disabled')
def test_prss_keys(self):
from mpyc.runtime import Party, Runtime
p0 = Party(0)
p1 = Party(1)
rt0 = Runtime(0, [p0, p1], mpc.options)
rt1 = Runtime(1, [p0, p1], mpc.options)
rt1._prss_keys_from_peer(0, rt0._prss_keys_to_peer(1)[0])
self.assertEqual(rt0._prss_keys, rt1._prss_keys)

def test_io(self):
x = ({4, 3}, [1 - 1j, 2.5], 0, range(7))
self.assertEqual(mpc.run(mpc.transfer(x))[0], x)
Expand Down

0 comments on commit e77f9a8

Please sign in to comment.