REF: Reorganize how multi-device parallelism is implemented #322

Merged
merged 33 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
6fe3330
DOC: Add missing type hints
carterbox Jun 20, 2024
fb87432
REF: Separate batch computation from batch separation
carterbox Jun 20, 2024
56ef8ee
REF: Add copy_to_host and copy_to_device methods
carterbox Jun 20, 2024
60e9de5
REF: Separate communications from DM algorithm
carterbox Jun 20, 2024
01bc190
DEV: Multi-gpu DM
carterbox Jun 24, 2024
a8ecdea
DEV: Implement object joining function
carterbox Jun 24, 2024
1986196
REF: Synchronize position updates
carterbox Jun 25, 2024
9a481c9
DEV: Refactor RPIE for new parallelism
carterbox Jun 26, 2024
4385019
DEV: Start transitioning lstsq method and fix bugs
carterbox Jun 27, 2024
f6c9c18
DEV: Transition probe variable updates
carterbox Jun 27, 2024
281dd7b
DEV: Strip comm from inside rpie solver
carterbox Jun 27, 2024
80aabce
DEV: Remove comm pool from lstsq implementation
carterbox Jun 27, 2024
dc47d7c
DEV: Revert changes to solvers
carterbox Jun 27, 2024
d60abeb
Merge branch 'multi-gpu-new2' into multi-gpu-new
carterbox Jun 27, 2024
f95c260
DEV: Merge two branches' implementations
carterbox Jun 28, 2024
4ef26d4
DOC: Improve documentation
carterbox Jun 28, 2024
0ab637e
DOC: Add missing docstring
carterbox Jun 28, 2024
def7770
TST: Fixup some tests and docs
carterbox Jun 28, 2024
d7ed4c9
TST: Fix broken test
carterbox Jun 28, 2024
aa58d4a
DOC: Improve docs
carterbox Jun 28, 2024
a4900f1
API: Mark ptycho reconstruction get function as not implemented
carterbox Jul 1, 2024
20e7b6a
REF: Avoid using copy.copy to prevent unused copies of arrays
carterbox Jul 1, 2024
97fb026
BUG: Prevent accidental type promotion
carterbox Jul 2, 2024
6ce09ce
BUG: Soft transition for join_psi
carterbox Jul 2, 2024
20a0688
BUG: Use get_result() instead of directly accessing parameters
carterbox Jul 2, 2024
178cd29
DEV: Reimplement some getter functions
carterbox Jul 3, 2024
6b47861
BUG: Fix broken conditional that skips probe orthogonalization
carterbox Jul 3, 2024
16698c4
Merge functions back together
carterbox Jul 3, 2024
d9f03b9
API: Move work off the default stream when using pool.map
carterbox Jul 8, 2024
f6611c8
REF: Reorder object probe constraints
carterbox Jul 9, 2024
5d59edb
REF: Avoid device host transfer until end of epoch
carterbox Jul 9, 2024
14338f2
Merge branch 'main' into multi-gpu-new
carterbox Jul 9, 2024
a7a40ec
BUG: Fix bugs from merge
carterbox Jul 9, 2024
72 changes: 36 additions & 36 deletions src/tike/cluster.py
@@ -17,6 +17,7 @@ def _split_gpu(
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
"""Return x[m] as a CuPy array on the current device."""
return cp.asarray(x[m], dtype=dtype)


@@ -25,6 +26,7 @@ def _split_host(
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
"""Return x[m] as a NumPy array."""
return np.asarray(x[m], dtype=dtype)


@@ -33,6 +35,7 @@ def _split_pinned(
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
"""Return x[m] as a CuPy pinned host memory array."""
pinned = cupyx.empty_pinned(shape=(len(m), *x.shape[1:]), dtype=dtype)
pinned[...] = x[m]
return pinned
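
For context on why `_split_pinned` stages data in pinned host memory: pinned (page-locked) allocations let CUDA perform host-to-device copies asynchronously on a stream. A minimal standalone sketch with hypothetical shapes (illustrative only, not part of this diff):

```python
import numpy as np
import cupy as cp
import cupyx

# Hypothetical host data to stage for transfer.
x = np.random.rand(1024, 64).astype(np.float32)

# Stage the data in pinned (page-locked) host memory, as _split_pinned does.
pinned = cupyx.empty_pinned(shape=x.shape, dtype=x.dtype)
pinned[...] = x

# A copy from pinned memory can proceed asynchronously on a stream.
d = cp.empty(x.shape, dtype=x.dtype)
stream = cp.cuda.Stream()
d.set(pinned, stream=stream)
stream.synchronize()
```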
@@ -171,22 +174,24 @@ def by_scan_stripes(


def by_scan_stripes_contiguous(
*args,
pool: tike.communicators.ThreadPool,
shape: typing.Tuple[int],
dtype: typing.List[npt.DTypeLike],
destination: typing.List[str],
scan: npt.NDArray[np.float32],
fly: int = 1,
batch_method,
batch_method: typing.Literal[
"compact", "wobbly_center", "wobbly_center_random_bootstrap"
],
num_batch: int,
) -> typing.Tuple[typing.List[npt.NDArray],
typing.List[typing.List[npt.NDArray]]]:
"""Split data by into stripes and create contiguously ordered batches.

Divide the field of view into one stripe per devices; within each stripe,
create batches according to the batch_method loading the batches into
contiguous blocks in device memory.
) -> typing.Tuple[
typing.List[npt.NDArray],
typing.List[typing.List[npt.NDArray]],
typing.List[int],
]:
"""Return the indices that will split `scan` into 2D stripes of equal count
and create contiguously ordered batches within those stripes.

Divide the field of view into one stripe per worker in `pool`; within each
stripe, create batches according to the batch_method, loading the batches
into contiguous blocks in device memory.

Parameters
----------
@@ -206,14 +211,14 @@ def by_scan_stripes_contiguous(
Returns
-------
order : List[array[int]]
The locations of the inputs in the original arrays.
For each worker in pool, the indices of the data assigned to that worker.
batches : List[List[array[int]]]
The locations of the elements of each batch
scan : List[array[float32]]
The divided 2D coordinates of the scan positions.
args : List[array[float32]] or None
Each input divided into regions or None if arg was None.

For each worker in pool, for each batch, the indices of the elements in
that batch.
stripe_start : List[int]
The coordinates of the leading edge of each stripe along the 0th
dimension in the scan coordinates, i.e. the minimum coordinate of the
scan positions in each stripe.
"""
if len(shape) != 2:
raise ValueError('The grid shape must have two dimensions.')
@@ -229,6 +234,7 @@
x=scan,
dtype=scan.dtype,
)
stripe_start = [int(np.floor(np.min(x[:, 0]))) for x in split_scan]
batches_noncontiguous: typing.List[typing.List[npt.NDArray]] = pool.map(
getattr(tike.cluster, batch_method),
split_scan,
@@ -247,26 +253,13 @@
batch_breaks,
))

split_args = []
for arg, t, dest in zip([scan, *args], dtype, destination):
if arg is None:
split_args.append(None)
else:
split_args.append(
pool.map(
_split_gpu if dest == 'gpu' else _split_pinned,
map_to_gpu_contiguous,
x=arg,
dtype=t,
))

if __debug__:
for device in batches_contiguous:
assert len(device) == num_batch, (
f"There should be {num_batch} batches, found {len(device)}"
)

return (map_to_gpu_contiguous, batches_contiguous, *split_args)
return (map_to_gpu_contiguous, batches_contiguous, stripe_start)
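
For illustration, the new `stripe_start` value can be reproduced with plain NumPy; a minimal sketch with hypothetical data (not part of this diff):

```python
import numpy as np

# Hypothetical: scan positions already split into one stripe per worker.
split_scan = [
    np.array([[0.4, 1.0], [2.7, 3.0]], dtype=np.float32),
    np.array([[3.2, 0.5], [5.9, 2.0]], dtype=np.float32),
    np.array([[6.1, 1.5], [8.8, 0.0]], dtype=np.float32),
]

# The leading edge of each stripe is the floor of the minimum scan
# coordinate along dimension 0, matching the computation in this diff.
stripe_start = [int(np.floor(np.min(x[:, 0]))) for x in split_scan]
print(stripe_start)  # [0, 3, 6]
```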


def stripes_equal_count(
@@ -306,7 +299,10 @@ def stripes_equal_count(
)


def wobbly_center(population, num_cluster):
def wobbly_center(
population: npt.ArrayLike,
num_cluster: int,
) -> typing.List[npt.NDArray]:
"""Return the indices that divide population into heterogenous clusters.

Uses a contrarian approach to clustering by maximizing the heterogeneity
@@ -382,7 +378,7 @@ def wobbly_center(population, num_cluster):


def wobbly_center_random_bootstrap(
population,
population: npt.ArrayLike,
num_cluster: int,
boot_fraction: float = 0.95,
) -> typing.List[npt.NDArray]:
@@ -466,7 +462,11 @@ def wobbly_center_random_bootstrap(
return [cp.asnumpy(xp.flatnonzero(labels == c)) for c in range(num_cluster)]


def compact(population, num_cluster, max_iter=500):
def compact(
population: npt.ArrayLike,
num_cluster: int,
max_iter: int = 500,
) -> typing.List[npt.NDArray]:
"""Return the indices that divide population into compact clusters.

Uses an approach that is inspired by the naive k-means algorithm, but it
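
As a usage illustration for these clustering helpers, a short sketch with hypothetical data (this assumes `compact` accepts a host NumPy array, as the `npt.ArrayLike` annotation suggests):

```python
import numpy as np
import tike.cluster

# Hypothetical 2D scan positions.
rng = np.random.default_rng(0)
population = rng.uniform(0, 10, size=(100, 2)).astype(np.float32)

# Divide the positions into 4 spatially compact clusters; each entry of
# `batches` is an array of indices into `population`.
batches = tike.cluster.compact(population, num_cluster=4)
assert sum(len(b) for b in batches) == len(population)
```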
24 changes: 19 additions & 5 deletions src/tike/communicators/comm.py
@@ -13,10 +13,6 @@
from .pool import ThreadPool


def _init_streams():
return [cp.cuda.Stream() for _ in range(2)]


class Comm:
"""A Ptychography communicator.

@@ -47,7 +43,6 @@ def __init__(
self.use_mpi = True
self.mpi = mpi()
self.pool = pool(gpu_count)
self.streams = self.pool.map(_init_streams)

def __enter__(self):
self.mpi.__enter__()
@@ -139,3 +134,22 @@ def Allreduce(
buf.append(
self.mpi.Allreduce(src[self.pool.workers.index(worker)]))
return buf

def swap_edges(
self,
x: typing.List[cp.ndarray],
overlap: int,
edges: typing.List[int],
) -> typing.List[cp.ndarray]:
"""Swap the region of each x with its neighbor around the given edges.

Given iterable x, a list of ND arrays; edges, the coordinates in x
along dimension -2; and overlap, the width of the region to swap around
the edge; trade [..., edge:(edge + overlap), :] between neighbors.
"""
# FIXME: Swap edges between MPI nodes
return self.pool.swap_edges(
x=x,
overlap=overlap,
edges=edges,
)
76 changes: 73 additions & 3 deletions src/tike/communicators/pool.py
@@ -89,6 +89,12 @@ def __init__(
self.num_workers) if self.num_workers > 1 else NoPoolExecutor(
self.num_workers)

def f(worker):
with self.Device(worker):
return [cp.cuda.Stream() for _ in range(2)]

self.streams = list(self.executor.map(f, self.workers))

def __enter__(self):
if self.workers[0] != cp.cuda.Device().id:
raise ValueError(
@@ -397,10 +403,74 @@ def map(
) -> list:
"""ThreadPoolExecutor.map, but wraps call in a cuda.Device context."""

def f(worker, *args):
def f(worker, streams, *args):
with self.Device(worker):
return func(*args, **kwargs)
with streams[1]:
return func(*args, **kwargs)

workers = self.workers if workers is None else workers

return list(self.executor.map(f, workers, *iterables))
return list(self.executor.map(f, workers, self.streams, *iterables))
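
The change above routes every mapped call through a per-worker non-default stream (`streams[1]`), keeping compute off the default stream so it can overlap with transfers. A standalone CuPy sketch of that pattern (hypothetical work, not this PR's code):

```python
import cupy as cp

stream = cp.cuda.Stream()  # a non-default stream
with cp.cuda.Device(0):
    with stream:
        # Kernels launched here are queued on `stream`, not on the
        # default stream, so work on other streams can overlap.
        y = cp.arange(1 << 20) ** 2
    stream.synchronize()  # block until the queued work completes
```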

def swap_edges(
self,
x: typing.List[cp.ndarray],
overlap: int,
edges: typing.List[int],
):
"""Swap [..., edge:(edge + overlap), :] between neighbors in-place

For example, given overlap=3 and edges=[0, 4, 8, 12], the following
swap would be returned:

```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5

[[0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0]]
[[1 1 1 1 0 0 0 1 2 2 2 1 1 1 1 1]]
[[2 2 2 2 2 2 2 2 1 1 1 2 3 3 3 2]]
[[3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 3]]
```

Note that the minimum swapped region is 1 wide.

"""
if overlap < 1:
msg = f"Overlap for swap_edges cannot be less than 1: {overlap}"
raise ValueError(msg)
for i in range(self.num_workers - 1):
lo = edges[i + 1]
hi = lo + overlap
temp0 = self._copy_to(x[i][..., lo:hi, :], self.workers[i + 1])
temp1 = self._copy_to(x[i + 1][..., lo:hi, :], self.workers[i])
with self.Device(self.workers[i]):
rampu = cp.linspace(
0.0,
1.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
rampd = cp.linspace(
1.0,
0.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
x[i][..., lo:hi, :] = rampd * x[i][..., lo:hi, :] + rampu * temp1
with self.Device(self.workers[i + 1]):
rampu = cp.linspace(
0.0,
1.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
rampd = cp.linspace(
1.0,
0.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
x[i + 1][..., lo:hi, :] = (
rampd * temp0 + rampu * x[i + 1][..., lo:hi, :]
)
return x
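
Because of the linspace ramps, each exchange is a linear cross-fade rather than a hard copy. A one-edge, one-dimensional NumPy sketch of the blend (hypothetical arrays; the real code operates on dimension -2 of CuPy arrays held on two devices):

```python
import numpy as np

overlap, lo = 3, 4
hi = lo + overlap
a = np.zeros(16, dtype=np.float32)  # worker i
b = np.ones(16, dtype=np.float32)   # worker i + 1

# Interior points of a linear ramp across the overlap region.
rampu = np.linspace(0.0, 1.0, overlap + 2)[1:-1]
rampd = rampu[::-1]

temp0, temp1 = a[lo:hi].copy(), b[lo:hi].copy()
a[lo:hi] = rampd * a[lo:hi] + rampu * temp1
b[lo:hi] = rampd * temp0 + rampu * b[lo:hi]

# Both sides now hold the same smooth transition across the edge.
print(a[lo:hi])  # [0.25 0.5  0.75]
print(b[lo:hi])  # [0.25 0.5  0.75]
```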
6 changes: 3 additions & 3 deletions src/tike/opt.py
@@ -381,9 +381,9 @@


def fit_line_least_squares(
y: typing.List[float],
x: typing.List[float],
) -> typing.Tuple[float, float]:
y: npt.NDArray[np.floating],
x: npt.NDArray[np.floating],
) -> typing.Tuple[np.floating, np.floating]:
"""Return the `slope`, `intercept` pair that best fits `y`, `x` to a line.

y = slope * x + intercept
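
The tightened annotations do not change the contract; an equivalent NumPy sketch of what the function computes (illustrative only, not tike's implementation):

```python
import numpy as np

def fit_line(y: np.ndarray, x: np.ndarray) -> tuple[float, float]:
    """Return (slope, intercept) minimizing ||y - (slope * x + intercept)||."""
    # Stack x with a column of ones and solve the least-squares system.
    A = np.stack([x, np.ones_like(x)], axis=-1)
    slope, intercept = np.linalg.lstsq(A, y, rcond=None)[0]
    return float(slope), float(intercept)

x = np.array([0.0, 1.0, 2.0, 3.0])
print(fit_line(2.0 * x + 1.0, x))  # (2.0, 1.0) up to rounding
```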
36 changes: 24 additions & 12 deletions src/tike/ptycho/exitwave.py
@@ -5,6 +5,7 @@
just free space propagation to the detector.

"""

from __future__ import annotations

import copy
@@ -67,7 +68,7 @@ class ExitWaveOptions:
exitwave updates in Fourier space. `1.0` for no scaling.
"""

propagation_normalization: str = 'ortho'
propagation_normalization: str = "ortho"
"""Choose the scaling of the FFT in the forward model:

"ortho" - the forward and adjoint operations are divided by sqrt(n)
@@ -78,19 +79,29 @@

"""

def copy_to_device(self, comm) -> ExitWaveOptions:
def copy_to_device(self) -> ExitWaveOptions:
"""Copy to the current GPU memory."""
options = copy.copy(self)
if self.measured_pixels is not None:
options.measured_pixels = comm.pool.bcast([self.measured_pixels])
return options
return ExitWaveOptions(
measured_pixels=cp.asarray(self.measured_pixels, dtype=bool),
noise_model=self.noise_model,
propagation_normalization=self.propagation_normalization,
step_length_start=self.step_length_start,
step_length_usemodes=self.step_length_usemodes,
step_length_weight=self.step_length_weight,
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
)

def copy_to_host(self) -> ExitWaveOptions:
"""Copy to the host CPU memory."""
options = copy.copy(self)
if self.measured_pixels is not None:
options.measured_pixels = cp.asnumpy(self.measured_pixels[0])
return options
return ExitWaveOptions(
measured_pixels=cp.asnumpy(self.measured_pixels),
noise_model=self.noise_model,
propagation_normalization=self.propagation_normalization,
step_length_start=self.step_length_start,
step_length_usemodes=self.step_length_usemodes,
step_length_weight=self.step_length_weight,
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
)
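
A round-trip usage sketch of the new copy methods (hypothetical mask; this assumes `ExitWaveOptions` can be constructed with only `measured_pixels` and that a GPU is available):

```python
import numpy as np
import cupy as cp
from tike.ptycho.exitwave import ExitWaveOptions

# Hypothetical detector mask: True where pixels were measured.
options = ExitWaveOptions(measured_pixels=np.ones((64, 64), dtype=bool))

with cp.cuda.Device(0):
    on_gpu = options.copy_to_device()  # measured_pixels -> CuPy array
back = on_gpu.copy_to_host()           # measured_pixels -> NumPy array

assert isinstance(back.measured_pixels, np.ndarray)
```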

def resample(self, factor: float) -> ExitWaveOptions:
"""Return a new `ExitWaveOptions` with the parameters rescaled."""
@@ -103,8 +114,9 @@ def resample(self, factor: float) -> ExitWaveOptions:
self.measured_pixels,
int(self.measured_pixels.shape[-1] * factor),
),
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
propagation_normalization=self.propagation_normalization )
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
propagation_normalization=self.propagation_normalization,
)


def poisson_steplength_all_modes(