
Dask 2024.8.1 and later is very slow #1267

Open
tomwhite opened this issue Oct 7, 2024 · 7 comments
Labels: performance, upstream

Comments

tomwhite commented Oct 7, 2024

This was originally reported in #1247, and a temporary pin was introduced in #1248. I've opened this issue to track the problem so we can remove the pin.

tomwhite commented Oct 7, 2024

I've opened dask/dask#11416.

tomwhite commented

Unfortunately, it looks like Dask 2024.10.0 doesn't fix this; see https://github.com/sgkit-dev/sgkit/actions/runs/11551276595, which takes 19 minutes to run rather than 6 (with Dask 2024.8.0).

tomwhite commented

On further investigation, what's happening is that locally defined functions that are passed to Dask's map_blocks and that wrap Numba functions cause the Numba functions to be recompiled every time the (genomics) method is called. For example, in pbs:

sgkit/sgkit/stats/popgen.py

Lines 598 to 600 in 9dd940e

    p = da.map_blocks(
        lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
    )

The lambda function calls a Numba function that is recompiled each time.

In most cases it's fairly easy to rewrite the code to avoid the use of locally defined functions. For PBS we can just do:

-    p = da.map_blocks(
-        lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
-    )
+    p = da.map_blocks(_pbs_cohorts, t, ct, chunks=shape, new_axis=3, dtype=np.float64)
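
To make the failure mode concrete, here is a minimal standalone sketch of the slow and fast patterns (the scale kernel is hypothetical, not sgkit code; the assumption is that the fresh lambda object created on every call is serialised by value, so Numba's dispatch cache is never reused):

import dask.array as da
import numba
import numpy as np

@numba.njit
def scale(block, factor):
    return block * factor

x = da.ones((4, 4), chunks=(2, 2))

# Slow: a new lambda object per call, closing over the compiled function,
# forces the Numba kernel to be recompiled on each invocation.
y_slow = da.map_blocks(lambda b: scale(b, 2.0), x, dtype=np.float64)

# Fast: pass the Numba function and its extra argument directly, so
# map_blocks receives a stable, importable function.
y_fast = da.map_blocks(scale, x, 2.0, dtype=np.float64)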

The distance metrics code is more dynamic though, so it's not a simple fix:

sgkit/sgkit/distance/api.py

Lines 111 to 143 in 9dd940e

    try:
        map_func_name = f"{metric}_map_{device}"
        reduce_func_name = f"{metric}_reduce_{device}"
        map_func = getattr(metrics, map_func_name)
        reduce_func = getattr(metrics, reduce_func_name)
        n_map_param = metrics.N_MAP_PARAM[metric]
    except AttributeError:
        raise NotImplementedError(
            f"Given metric: '{metric}' is not implemented for '{device}'."
        )

    x = da.asarray(x)
    if x.ndim != 2:
        raise ValueError(f"2-dimensional array expected, got '{x.ndim}'")

    # setting this variable outside of _pairwise to avoid its recreation
    # in every iteration, which eventually leads to a significant increase
    # in dask graph serialisation/deserialisation time
    metric_param = np.empty(n_map_param, dtype=x.dtype)

    def _pairwise_cpu(f: ArrayLike, g: ArrayLike) -> ArrayLike:
        result: ArrayLike = map_func(f[:, None, :], g, metric_param)
        # Adding a new axis to help combine chunks along this axis in the
        # reduction step (see the _aggregate and _combine functions below).
        return result[..., np.newaxis]

    def _pairwise_gpu(f: ArrayLike, g: ArrayLike) -> ArrayLike:  # pragma: no cover
        result = map_func(f, g)
        return result[..., np.newaxis]

    pairwise_func = _pairwise_cpu
    if device == "gpu":
        pairwise_func = _pairwise_gpu  # pragma: no cover

tomwhite commented

I've fixed the non-distance functions in this commit: e83b52c

I'm not sure what to do about the distance functions at this point.

jeromekelleher commented

There are only two possible metrics right now ('euclidean' or 'correlation'), so I vote we make the code less clever and just code in the function names directly for those two cases?
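
Something like this, perhaps (untested sketch; attribute names inferred from the f"{metric}_map_{device}" pattern above):

# Hard-coded dispatch for the two supported metrics, so no getattr is needed.
if metric == "euclidean":
    map_func = metrics.euclidean_map_cpu if device == "cpu" else metrics.euclidean_map_gpu
    reduce_func = metrics.euclidean_reduce_cpu if device == "cpu" else metrics.euclidean_reduce_gpu
elif metric == "correlation":
    map_func = metrics.correlation_map_cpu if device == "cpu" else metrics.correlation_map_gpu
    reduce_func = metrics.correlation_reduce_cpu if device == "cpu" else metrics.correlation_reduce_gpu
else:
    raise NotImplementedError(
        f"Given metric: '{metric}' is not implemented for '{device}'."
    )
n_map_param = metrics.N_MAP_PARAM[metric]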

tomwhite commented Nov 4, 2024

That's what I thought too, but there is another wrinkle. In this diff:

tomwhite@e1119ca

previously metric_param was initialized outside the function to avoid inflating Dask graph serialization/deserialization time (see the comment).

I suppose we could have a map of (shared) empty arrays keyed by dtype, but that doesn't seem very thread-safe. Or we could initialize it in the function (sketched below), and leave a comment about how this previously caused a Dask slowdown. Another option would be to remove the code!
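
For the second option, a rough sketch (hypothetical; map_func and n_map_param would be passed in as arguments rather than captured, so the function can live at module level and serialise by reference):

def _pairwise_cpu(f, g, map_func, n_map_param):
    # NOTE: metric_param used to be allocated once, outside this function,
    # because capturing it in a closure inflated Dask graph
    # serialisation/deserialisation time. Allocating a small empty array
    # per call is cheap, and with no captured state Numba's dispatch cache
    # is reused across calls.
    metric_param = np.empty(n_map_param, dtype=f.dtype)
    result = map_func(f[:, None, :], g, metric_param)
    # Add a new axis to help combine chunks in the reduction step.
    return result[..., np.newaxis]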

jeromekelleher commented

Ah, I see. I'm reluctant to remove the code as we put quite a lot of effort into it, and it's our main usage of GPUs...

Perhaps @aktech would like to comment here? Is there an easy way to avoid using lambdas?
