Skip to content

Commit

Permalink
Transport with as close-a-type as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
cphyc committed Nov 3, 2023
1 parent 2bb07ab commit c9fc44d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
1 change: 0 additions & 1 deletion yt/frontends/ramses/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,6 @@ def _initialize_oct_handler(self):
RAMSESDomainFile(self.dataset, i + 1)
for i in parallel_objects(cpu_list, method="sequential")
]
print(self.domains)

total_octs = sum(
dom.local_oct_count for dom in self.domains # + dom.ngridbound.sum()
Expand Down
40 changes: 28 additions & 12 deletions yt/utilities/parallel_tools/parallel_analysis_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import traceback
from functools import wraps
from io import StringIO
from typing import Literal, Union
from typing import Any, Literal, Union

import numpy as np
from more_itertools import always_iterable
Expand Down Expand Up @@ -166,6 +166,15 @@ def get_mpi_type(dtype):
return val


def get_transport_dtype(dtype: type) -> Union[Any, str]:
mpi_type = get_mpi_type(dtype)
if mpi_type is None:
# Fallback to char
return MPI.CHAR, "c"
else:
return mpi_type, dtype.str


class ObjectIterator:
"""
This is a generalized class that accepts a list of objects and then
Expand Down Expand Up @@ -749,7 +758,6 @@ class Communicator:
comm = None
_grids = None
_distributed = None
__tocast = "c"

def __init__(self, comm=None):
self.comm = comm
Expand Down Expand Up @@ -1127,10 +1135,12 @@ def merge_quadtree_buffers(self, qt, merge_style):

def send_array(self, arr, dest, tag=0):
if not isinstance(arr, np.ndarray):
self.comm.send((None, None), dest=dest, tag=tag)
self.comm.send((None, None, None), dest=dest, tag=tag)
self.comm.send(arr, dest=dest, tag=tag)
return
tmp = arr.view(self.__tocast) # Cast to CHAR
mpi_type, transport_dtype = get_transport_dtype(arr.dtype)
tmp = arr.view(transport_dtype)

# communicate type and shape and optionally units
if isinstance(arr, YTArray):
unit_metadata = (str(arr.units), arr.units.registry.lut)
Expand All @@ -1140,13 +1150,17 @@ def send_array(self, arr, dest, tag=0):
unit_metadata += ("YTArray",)
else:
unit_metadata = ()
self.comm.send((arr.dtype.str, arr.shape) + unit_metadata, dest=dest, tag=tag)
self.comm.Send([arr, MPI.CHAR], dest=dest, tag=tag)
self.comm.send(
(arr.dtype.str, arr.shape, transport_dtype) + unit_metadata,
dest=dest,
tag=tag,
)
self.comm.Send([arr, mpi_type], dest=dest, tag=tag)
del tmp

def recv_array(self, source, tag=0):
metadata = self.comm.recv(source=source, tag=tag)
dt, ne = metadata[:2]
dt, ne, transport_dt = metadata[:3]
if ne is None and dt is None:
return self.comm.recv(source=source, tag=tag)
arr = np.empty(ne, dtype=dt)
Expand All @@ -1156,8 +1170,9 @@ def recv_array(self, source, tag=0):
arr = ImageArray(arr, units=metadata[2], registry=registry)
else:
arr = YTArray(arr, metadata[2], registry=registry)
tmp = arr.view(self.__tocast)
self.comm.Recv([tmp, MPI.CHAR], source=source, tag=tag)
mpi_type = get_mpi_type(transport_dt)
tmp = arr.view(transport_dt)
self.comm.Recv([tmp, mpi_type], source=source, tag=tag)
return arr

def alltoallv_array(self, send, total_size, offsets, sizes):
Expand All @@ -1170,7 +1185,8 @@ def alltoallv_array(self, send, total_size, offsets, sizes):
recv = np.array(recv)
return recv
offset = offsets[self.comm.rank]
tmp_send = send.view(self.__tocast)
mpi_type, transport_dtype = get_transport_dtype(send.dtype)
tmp_send = send.view(transport_dtype)
recv = np.empty(total_size, dtype=send.dtype)
if isinstance(send, YTArray):
# We assume send.units is consistent with the units
Expand All @@ -1183,9 +1199,9 @@ def alltoallv_array(self, send, total_size, offsets, sizes):
dtr = send.dtype.itemsize / tmp_send.dtype.itemsize # > 1
roff = [off * dtr for off in offsets]
rsize = [siz * dtr for siz in sizes]
tmp_recv = recv.view(self.__tocast)
tmp_recv = recv.view(transport_dtype)
self.comm.Allgatherv(
(tmp_send, tmp_send.size, MPI.CHAR), (tmp_recv, (rsize, roff), MPI.CHAR)
(tmp_send, tmp_send.size, mpi_type), (tmp_recv, (rsize, roff), mpi_type)
)
return recv

Expand Down

0 comments on commit c9fc44d

Please sign in to comment.