Skip to content

Commit

Permalink
QR decomposition (#577)
Browse files Browse the repository at this point in the history
* QR decomposition

* Replace one map_direct with a merge chunks

* Use multiple outputs for QR

* Use recursion when R1 is too big to fit in memory

* _merge_into_single_chunk

* Add map_blocks_multiple_outputs utility function

* Add tsqr

* Add memory utilization test for qr

* Enforce tall-and-skinny for qr

* QR recursion improvements and test
  • Loading branch information
tomwhite authored Sep 23, 2024
1 parent a4020f3 commit d5b40b3
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 2 deletions.
6 changes: 6 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,9 @@
from .array_api.utility_functions import all, any

__all__ += ["all", "any"]

# extensions

from .array_api import linalg

__all__ += ["linalg"]
172 changes: 172 additions & 0 deletions cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from typing import NamedTuple

from cubed.array_api.array_object import Array
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import general_blockwise, map_direct, merge_chunks
from cubed.utils import array_memory, get_item


class QRResult(NamedTuple):
    """Pair of factors ``(Q, R)`` returned by the QR routines in this module."""

    # First factor of the decomposition.
    Q: Array
    # Second factor of the decomposition.
    R: Array


def qr(x, /, *, mode="reduced") -> QRResult:
if x.ndim != 2:
raise ValueError("qr requires x to have 2 dimensions.")

if mode != "reduced":
raise ValueError("qr only supports mode='reduced'")

if x.numblocks[1] > 1:
raise ValueError(
"qr only supports tall-and-skinny (single column chunk) arrays. "
"Consider rechunking so there is only a single column chunk."
)

return tsqr(x)


def tsqr(x) -> QRResult:
    """Direct Tall-and-Skinny QR (TSQR) factorization.

    Follows Algorithm 2 of:

        Direct QR factorizations for tall-and-skinny matrices in MapReduce
        architectures.
        Austin R. Benson, David F. Gleich, James Demmel.
        Proceedings of the IEEE International Conference on Big Data, 2013.
        https://arxiv.org/abs/1301.1071

    The function calls itself on a rechunked ``R1`` whenever
    ``_r1_is_too_big`` reports that ``R1`` is too large.
    """
    # Step 1: factorize the blocks of x.
    Q1, R1 = _qr_first_step(x)

    # Step 2: factorize R1, directly if it is small enough, otherwise by
    # recursing on a version of R1 with larger chunks along axis 0.
    if not _r1_is_too_big(R1):
        Q2, R2 = _qr_second_step(R1)
    else:
        Q2, R2 = tsqr(_rechunk_r1(R1))

    # Step 3: combine Q1 and Q2 into the final Q; R2 is the final R.
    Q = _qr_third_step(Q1, Q2)
    return QRResult(Q, R2)


def _qr_first_step(A):
    """Apply ``nxp.linalg.qr`` to ``A`` via ``map_blocks_multiple_outputs``.

    Returns a ``QRResult`` ``(Q1, R1)``. ``Q1`` is declared with the same shape
    and chunks as ``A``. ``R1`` is declared with shape ``(n * k, n)`` split
    into ``k`` chunks of ``n`` rows, where ``n`` is the column chunk size of
    ``A`` and ``k`` is its number of row blocks.
    """
    _, ncols = A.chunksize
    nblocks, _ = A.numblocks

    # R1 is declared as nblocks chunks of shape (ncols, ncols) along axis 0.
    r1_shape = (ncols * nblocks, ncols)
    r1_chunks = ((ncols,) * nblocks, (ncols,))

    Q1, R1 = map_blocks_multiple_outputs(
        nxp.linalg.qr,
        A,
        # Q1 has same shape and chunks as A
        shapes=[A.shape, r1_shape],
        dtypes=[nxp.float64] * 2,
        chunkss=[A.chunks, r1_chunks],
        # qr implementation creates internal array buffers
        extra_projected_mem=4 * A.chunkmem,
    )
    return QRResult(Q1, R1)


def _r1_is_too_big(R1):
    """Return True if the whole of ``R1`` exceeds the memory budget.

    The budget is one eighth of ``allowed_mem - reserved_mem`` taken from
    ``R1.spec``: a conservative allowance of 4 copies, doubled to give some
    slack.
    """
    spec = R1.spec
    usable_mem = spec.allowed_mem - spec.reserved_mem
    budget = usable_mem // 8  # 4 copies * 2 for slack
    return array_memory(R1.dtype, R1.shape) > budget


def _rechunk_r1(R1, split_every=4):
# expand R1's chunk size in axis 0 so that new R1 will be smaller by factor of split_every
if R1.numblocks[0] == 1:
raise ValueError(
"Can't expand R1 chunk size further. Try increasing allowed_mem"
)
chunks = (R1.chunksize[0] * split_every, R1.chunksize[1])
return merge_chunks(R1, chunks=chunks)


def _qr_second_step(R1):
    """Merge ``R1`` into one chunk and apply ``nxp.linalg.qr`` to it.

    Returns a ``QRResult`` ``(Q2, R2)``. ``Q2`` is declared with the shape of
    ``R1`` and ``R2`` with shape ``(n, n)``, where ``n = R1.shape[1]``. Both
    are declared as single-chunk arrays.
    """
    merged = _merge_into_single_chunk(R1)

    ncols = R1.shape[1]
    q2_shape = R1.shape
    r2_shape = (ncols, ncols)

    Q2, R2 = map_blocks_multiple_outputs(
        nxp.linalg.qr,
        merged,
        shapes=[q2_shape, r2_shape],
        dtypes=[nxp.float64] * 2,
        # chunks equal to the full shape, i.e. a single chunk each
        chunkss=[q2_shape, r2_shape],
        # qr implementation creates internal array buffers
        extra_projected_mem=4 * merged.chunkmem,
    )
    return QRResult(Q2, R2)


def _merge_into_single_chunk(x, split_every=4):
# do a tree merge along first axis
while x.numblocks[0] > 1:
chunks = (x.chunksize[0] * split_every,) + x.chunksize[1:]
x = merge_chunks(x, chunks)
return x


def _qr_third_step(Q1, Q2):
    """Build the final ``Q`` from ``Q1`` and ``Q2`` using ``_q_matmul``.

    The output is declared with the shape and chunks of ``Q1``. ``Q2`` is
    addressed as ``k`` row slices of ``n`` rows each, where ``n`` is the column
    chunk size of ``Q1`` and ``k`` is its number of row blocks.
    """
    _, ncols = Q1.chunksize
    nblocks, _ = Q1.numblocks

    # Slicing scheme used to pick the part of Q2 matching each block of Q1.
    q2_chunks = ((ncols,) * nblocks, (ncols,))

    return map_direct(
        _q_matmul,
        Q1,
        Q2,
        shape=Q1.shape,
        dtype=nxp.float64,
        chunks=Q1.chunks,
        extra_projected_mem=0,
        q1_chunks=Q1.chunks,
        q2_chunks=q2_chunks,
    )


def _q_matmul(x, *arrays, q1_chunks=None, q2_chunks=None, block_id=None):
    """Multiply the ``block_id`` slice of ``arrays[0]`` by that of ``arrays[1]``.

    Both operands are read through their ``zarray`` attribute, indexed with
    ``get_item(<chunks>, block_id)``. The ``x`` argument is not used.
    """
    q1_source, q2_source = arrays[0], arrays[1]
    q1_block = q1_source.zarray[get_item(q1_chunks, block_id)]
    # The second array may be stored with different chunking (e.g. as a single
    # chunk), so it is sliced according to q2_chunks instead of its own chunks.
    q2_block = q2_source.zarray[get_item(q2_chunks, block_id)]
    return q1_block @ q2_block


def map_blocks_multiple_outputs(
    func,
    *args,
    shapes,
    dtypes,
    chunkss,
    **kwargs,
):
    """Call ``general_blockwise`` for a function that has several outputs.

    Every input array is read at the same block coordinates as the output
    block being produced. One ``None`` target store is passed per entry in
    ``dtypes``. ``shapes``, ``dtypes``, ``chunkss`` and any extra keyword
    arguments are forwarded to ``general_blockwise`` unchanged.
    """

    def key_function(out_key):
        # Drop the leading element of the output key and reuse the remaining
        # block coordinates for each input array.
        block_coords = out_key[1:]
        return tuple((array.name,) + block_coords for array in args)

    num_outputs = len(dtypes)
    return general_blockwise(
        func,
        key_function,
        *args,
        shapes=shapes,
        dtypes=dtypes,
        chunkss=chunkss,
        target_stores=[None] * num_outputs,
        **kwargs,
    )
10 changes: 9 additions & 1 deletion cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def _finalize(
compile_function: Optional[Decorator] = None,
array_names=None,
) -> "FinalizedPlan":
dag = self.optimize(optimize_function, array_names).dag if optimize_graph else self.dag
dag = (
self.optimize(optimize_function, array_names).dag
if optimize_graph
else self.dag
)
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
Expand Down Expand Up @@ -501,6 +505,10 @@ def num_arrays(self) -> int:
"""Return the number of arrays in this plan."""
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))

def num_primitive_ops(self) -> int:
    """Return the number of primitive operations in this plan.

    Counts the items yielded by ``visit_nodes`` for this plan's DAG.
    """
    return sum(1 for _ in visit_nodes(self.dag))

def num_tasks(self, resume=None):
"""Return the number of tasks needed to execute this plan."""
tasks = 0
Expand Down
59 changes: 59 additions & 0 deletions cubed/tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose

import cubed
import cubed.array_api as xp
from cubed.core.plan import arrays_to_plan


def test_qr():
    """QR of a small single-column-chunk array: plan size and factor checks."""
    expected = np.arange(32, dtype=np.float64).reshape((16, 2))
    Q, R = xp.linalg.qr(xp.asarray(expected, chunks=(4, 2)))

    plan = arrays_to_plan(Q, R)._finalize()
    assert plan.num_primitive_ops() == 4

    Q, R = cubed.compute(Q, R)

    tolerance = 1e-08
    # the factors must reproduce the input
    assert_allclose(Q @ R, expected, atol=tolerance)
    # Q must be orthonormal
    assert_allclose(Q.T @ Q, np.eye(2, 2), atol=tolerance)
    # R must be upper triangular
    assert_allclose(R, np.triu(R), atol=tolerance)


def test_qr_recursion():
    """Check that qr is still correct when the recursive path is taken.

    Memory limits are tried in increasing order. A limit is skipped only if
    building the QR computation raises ValueError (not enough memory). The
    first limit that builds must use more primitive operations than the
    non-recursive case, and its result is then verified.
    """
    A = np.reshape(np.arange(128, dtype=np.float64), (64, 2))

    # find a memory setting where recursion happens
    for factor in range(4, 16):
        spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)

        # Only graph construction is guarded, so a ValueError raised while
        # verifying the result is not mistaken for a memory shortage.
        try:
            Q, R = xp.linalg.qr(xp.asarray(A, chunks=(8, 2), spec=spec))
        except ValueError:
            continue  # not enough memory

        plan_unopt = arrays_to_plan(Q, R)._finalize()
        assert plan_unopt.num_primitive_ops() > 4  # more than without recursion

        Q, R = cubed.compute(Q, R)

        assert_allclose(Q @ R, A, atol=1e-08)
        assert_allclose(Q.T @ Q, np.eye(2, 2), atol=1e-08)  # Q must be orthonormal
        assert_allclose(R, np.triu(R), atol=1e-08)  # R must be upper triangular

        break
    else:
        pytest.fail("no memory setting allowed the QR computation to be built")


def test_qr_chunking():
    """qr must reject an array with more than one chunk along the column axis."""
    two_column_chunks = xp.ones((32, 4), chunks=(4, 2))
    expected_message = (
        r"qr only supports tall-and-skinny \(single column chunk\) arrays."
    )
    with pytest.raises(ValueError, match=expected_message):
        xp.linalg.qr(two_column_chunks)
15 changes: 14 additions & 1 deletion cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ def test_sum_partial_reduce(tmp_path, spec, executor):
run_operation(tmp_path, executor, "sum_partial_reduce", b)


# Linear algebra extension


@pytest.mark.slow
def test_qr(tmp_path, spec, executor):
    """Run qr on a large random array through ``run_operation``."""
    shape = (40000, 1000)
    chunks = (5000, 1000)  # 40MB chunks
    a = cubed.random.random(shape, chunks=chunks, spec=spec)
    q, r = xp.linalg.qr(a)
    # don't optimize graph so we use as much memory as possible (reading from Zarr)
    run_operation(tmp_path, executor, "qr", q, r, optimize_graph=False)


# Multiple outputs


Expand Down Expand Up @@ -362,7 +375,7 @@ def run_operation(
# )
hist = HistoryCallback()
mem_warn = MemoryWarningCallback()
memray = MemrayCallback()
memray = MemrayCallback(mem_threshold=30_000_000)
# use None for each store to write to temporary zarr
cubed.store(
results,
Expand Down

0 comments on commit d5b40b3

Please sign in to comment.