Skip to content

Commit

Permalink
Interface for DPPY interoperability (#788)
Browse files Browse the repository at this point in the history
* added partition interface to DNDarray

* added 'locals' key to partition interface

* renamed locals to lcls to avoid global name

* corrected format of locals

* renamed dunder class attr of DNDarray to __partitioned__

* corrected split=0 case, corrected DNDarray property to be 'partitioned'

* DNDarray.__partitioned__ -> __partitions_dict__, DNDarray.partitioned -> __partitioned__

* added tests for partitioned attribute

* minor changes to test cases to check that things after the resplit are taken care of

* split=None tests

* changelog update

* added 'get' attributed to __partitioned__ to get a tile from a DNDarray

* reduced level of abstraction for __partitioned__['get']

* adding from_partitioned; aligning __partitioned__ with current spec

* updating from_partitioned function

* added nonzero split support to from partition dictionary, added tests, added factory function for building a dndarry from a partition dictionary

* Ensure  is None when virtually resplitting to None on 1 process

---------

Co-authored-by: Frank Schlimbach <[email protected]>
Co-authored-by: Claudia Comito <[email protected]>
Co-authored-by: Claudia Comito <[email protected]>
  • Loading branch information
4 people authored Feb 9, 2023
1 parent e5e8a96 commit bcea48a
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Example on 2 processes:

### Misc.
- [#761](https://github.com/helmholtz-analytics/heat/pull/761) New feature: `result_type`
- [#788](https://github.com/helmholtz-analytics/heat/pull/788) Added the partition interface `DNDarray` for use with DPPY
- [#794](https://github.com/helmholtz-analytics/heat/pull/794) New feature: `meshgrid`
- [#821](https://github.com/helmholtz-analytics/heat/pull/821) Enhancement: it is no longer necessary to load-balance an imbalanced `DNDarray` before gathering it onto all processes. In short: `ht.resplit(array, None)` now works on imbalanced arrays as well.

Expand Down
121 changes: 121 additions & 0 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
self.__ishalo = False
self.__halo_next = None
self.__halo_prev = None
self.__partitions_dict__ = None
self.__lshape_map = None

# check for inconsistencies between torch and heat devices
Expand Down Expand Up @@ -183,6 +184,24 @@ def ndim(self) -> int:
"""
return len(self.__gshape)

@property
def __partitioned__(self) -> dict:
"""
This will return a dictionary containing information useful for working with the partitioned
data. These items include the shape of the data on each process, the starting index of the data
that a process has, the datatype of the data, the local devices, as well as the global
partitioning scheme.
An example of the output and shape is shown in :func:`ht.core.DNDarray.create_partition_interface <ht.core.DNDarray.create_partition_interface>`.
Returns
-------
dictionary with the partition interface
"""
if self.__partitions_dict__ is None:
self.__partitions_dict__ = self.create_partition_interface()
return self.__partitions_dict__

@property
def size(self) -> int:
"""
Expand Down Expand Up @@ -599,6 +618,104 @@ def create_lshape_map(self, force_check: bool = False) -> torch.Tensor:
self.__lshape_map = lshape_map
return lshape_map.clone()

def create_partition_interface(self):
"""
Create a partition interface in line with the DPPY proposal. This is subject to change.
The intention of this to facilitate the usage of a general format for the referencing of
distributed datasets.
An example of the output and shape is shown below.
__partitioned__ = {
'shape': (27, 3, 2),
'partition_tiling': (4, 1, 1),
'partitions': {
(0, 0, 0): {
'start': (0, 0, 0),
'shape': (7, 3, 2),
'data': tensor([...], dtype=torch.int32),
'location': [0],
'dtype': torch.int32,
'device': 'cpu'
},
(1, 0, 0): {
'start': (7, 0, 0),
'shape': (7, 3, 2),
'data': None,
'location': [1],
'dtype': torch.int32,
'device': 'cpu'
},
(2, 0, 0): {
'start': (14, 0, 0),
'shape': (7, 3, 2),
'data': None,
'location': [2],
'dtype': torch.int32,
'device': 'cpu'
},
(3, 0, 0): {
'start': (21, 0, 0),
'shape': (6, 3, 2),
'data': None,
'location': [3],
'dtype': torch.int32,
'device': 'cpu'
}
},
'locals': [(rank, 0, 0)],
'get': lambda x: x,
}
Returns
-------
dictionary containing the partition interface as shown above.
"""
lshape_map = self.create_lshape_map()
start_idx_map = torch.zeros_like(lshape_map)

part_tiling = [1] * self.ndim
lcls = [0] * self.ndim

z = torch.tensor([0], device=self.device.torch_device, dtype=self.dtype.torch_type())
if self.split is not None:
starts = torch.cat((z, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0)
lcls[self.split] = self.comm.rank
part_tiling[self.split] = self.comm.size
else:
starts = torch.zeros(self.ndim, dtype=torch.int, device=self.device.torch_device)

start_idx_map[:, self.split] = starts

partitions = {}
base_key = [0] * self.ndim
for r in range(self.comm.size):
if self.split is not None:
base_key[self.split] = r
dat = None if r != self.comm.rank else self.larray
else:
dat = self.larray
partitions[tuple(base_key)] = {
"start": tuple(start_idx_map[r].tolist()),
"shape": tuple(lshape_map[r].tolist()),
"data": dat,
"location": [r],
"dtype": self.dtype.torch_type(),
"device": self.device.torch_device,
}

partition_dict = {
"shape": self.gshape,
"partition_tiling": tuple(part_tiling),
"partitions": partitions,
"locals": [tuple(lcls)],
"get": lambda x: x,
}

self.__partitions_dict__ = partition_dict

return partition_dict

def __float__(self) -> DNDarray:
"""
Float scalar casting.
Expand Down Expand Up @@ -1270,9 +1387,13 @@ def resplit_(self, axis: int = None):
# early out for unchanged content
if self.comm.size == 1:
self.__split = axis
if axis is None:
self.__partitions_dict__ = None
if axis == self.split:
return self

self.__partitions_dict__ = None

if axis is None:
gathered = torch.empty(
self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
Expand Down
157 changes: 154 additions & 3 deletions heat/core/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
from . import devices
from . import types


__all__ = [
"arange",
"array",
"asarray",
"empty",
"empty_like",
"eye",
"from_partitioned",
"from_partition_dict",
"full",
"full_like",
"linspace",
Expand All @@ -42,7 +43,7 @@ def arange(
dtype: Optional[Type[datatype]] = None,
split: Optional[int] = None,
device: Optional[Union[str, Device]] = None,
comm: Optional[Communication] = None
comm: Optional[Communication] = None,
) -> DNDarray:
"""
Return evenly spaced values within a given interval.
Expand Down Expand Up @@ -730,7 +731,7 @@ def __factory_like(
device: Device,
comm: Communication,
order: str = "C",
**kwargs
**kwargs,
) -> DNDarray:
"""
Abstracted '...-like' factory function for HeAT :class:`~heat.core.dndarray.DNDarray` initialization
Expand Down Expand Up @@ -792,6 +793,156 @@ def __factory_like(
return factory(shape, dtype=dtype, split=split, device=device, comm=comm, order=order, **kwargs)


def from_partitioned(x, comm: Optional[Communication] = None) -> DNDarray:
"""
Return a newly created DNDarray constructed from the '__partitioned__' attributed of the input object.
Memory of local partitions will be shared (zero-copy) as long as supported by data objects.
Currently supports numpy ndarrays and torch tensors as data objects.
Current limitations:
* Partitions must be ordered in the partition-grid by rank
* Only one split-axis
* Only one partition per rank
* Only SPMD-style __partitioned__
Parameters
----------
x : object
Requires x.__partitioned__
comm: Communication, optional
Handle to the nodes holding distributed parts or copies of this array.
See also
--------
:func:`ht.core.DNDarray.create_partition_interface <ht.core.DNDarray.create_partition_interface>`.
Raises
------
AttributeError
If not hasattr(x, "__partitioned__") or if underlying data has no dtype.
TypeError
If it finds an unsupported array types
RuntimeError
If other unsupported content is found.
Examples
--------
>>> import heat as ht
>>> a = ht.ones((44,55), split=0)
>>> b = ht.from_partitioned(a)
>>> assert (a==b).all()
>>> a[40] = 4711
>>> assert (a==b).all()
"""
comm = sanitize_comm(comm)
parted = x.__partitioned__
return __from_partition_dict_helper(parted, comm)


def from_partition_dict(parted: dict, comm: Optional[Communication] = None) -> DNDarray:
"""
Return a newly created DNDarray constructed from the '__partitioned__' attributed of the input object.
Memory of local partitions will be shared (zero-copy) as long as supported by data objects.
Currently supports numpy ndarrays and torch tensors as data objects.
Current limitations:
* Partitions must be ordered in the partition-grid by rank
* Only one split-axis
* Only one partition per rank
* Only SPMD-style __partitioned__
Parameters
----------
parted : dict
A partition dictionary used to create the new DNDarray
comm: Communication, optional
Handle to the nodes holding distributed parts or copies of this array.
See also
--------
:func:`ht.core.DNDarray.create_partition_interface <ht.core.DNDarray.create_partition_interface>`.
Raises
------
AttributeError
If not hasattr(x, "__partitioned__") or if underlying data has no dtype.
TypeError
If it finds an unsupported array types
RuntimeError
If other unsupported content is found.
Examples
--------
>>> import heat as ht
>>> a = ht.ones((44,55), split=0)
>>> b = ht.from_partition_dict(a.__partitioned__)
>>> assert (a==b).all()
>>> a[40] = 4711
>>> assert (a==b).all()
"""
comm = sanitize_comm(comm)
return __from_partition_dict_helper(parted, comm)


def __from_partition_dict_helper(parted: dict, comm: Communication):
# helper to create a DNDarray from a partition table (dictionary)
# the dictionary must be in the same form as the DNDarray.__partitioned__ property creates
if "locals" not in parted:
raise RuntimeError("Non-SPMD __partitioned__ not supported")
try:
gshape = parted["shape"]
except KeyError:
raise RuntimeError(
"partition dictionary must have a 'shape' entry, see DNDarray.create_partition_interface for more details"
)
try:
lparts = parted["locals"]
except KeyError:
raise RuntimeError(
"partition dictionary must have a 'local' entry, see DNDarray.create_partition_interface for more details"
)
if len(lparts) != 1:
raise RuntimeError("Only exactly one partition per rank supported (yet)")
parts = parted["partitions"]
lpart = parted["get"](parts[lparts[0]]["data"])
if isinstance(lpart, np.ndarray):
data = torch.from_numpy(lpart)
elif isinstance(lpart, torch.Tensor):
data = lpart
else:
raise TypeError(f"Only numpy arrays and torch tensors supported (not {type(lpart)}")
htype = types.canonical_heat_type(data.dtype)

# get split axis
gshape_list = list(gshape)
lshape_list = list(data.shape)
shape_diff = torch.tensor(
[g - l for g, l in zip(gshape_list, lshape_list)]
) # dont care about device
nz = torch.nonzero(shape_diff)

if nz.numel() > 1:
raise RuntimeError("only one split axis allowed, check the ")
elif nz.numel() == 1:
split = nz[0].item()
else:
split = None

expected = {
int(x["location"][0]): (
comm.chunk(gshape, split, x["location"][0])[1:],
(x["shape"], x["start"]),
)
for x in parts.values()
}
balanced = all(x[0][0] == x[1][0] for x in expected.values())

ret = DNDarray(
data, gshape, htype, split, devices.sanitize_device(None), sanitize_comm(comm), balanced
)
ret.__partitions_dict__ = parted

return ret


def full(
shape: Union[int, Sequence[int]],
fill_value: Union[int, float],
Expand Down
14 changes: 14 additions & 0 deletions heat/core/tests/test_dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,20 @@ def test_or(self):
ht.equal(int16_tensor | int16_vector, ht.bitwise_or(int16_tensor, int16_vector))
)

def test_partitioned(self):
a = ht.zeros((120, 120), split=0)
parted = a.__partitioned__
self.assertEqual(parted["shape"], (120, 120))
self.assertEqual(parted["partition_tiling"], (a.comm.size, 1))
self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))

a.resplit_(None)
self.assertIsNone(a.__partitions_dict__)
parted = a.__partitioned__
self.assertEqual(parted["shape"], (120, 120))
self.assertEqual(parted["partition_tiling"], (1, 1))
self.assertEqual(parted["partitions"][(0, 0)]["start"], (0, 0))

def test_redistribute(self):
# need to test with 1, 2, 3, and 4 dims
st = ht.zeros((50,), split=0)
Expand Down
Loading

0 comments on commit bcea48a

Please sign in to comment.