Skip to content

Commit

Permalink
Write test for independent columns in generate
Browse files Browse the repository at this point in the history
  • Loading branch information
daffidwilde committed Nov 27, 2023
1 parent 1a2681e commit e537472
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions tests/mst/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,61 @@ def test_generate(nrows, params):
data.repartition.assert_called_once_with("100MB")

assert synthetic is data


@given(
st.integers(1, 100),
st.lists(st.text(), min_size=2, max_size=10, unique=True),
)
def test_generate_with_extra_independents(nrows, params):
"""Test generation executes with multiple independent columns."""

column, *order = params

prng = da.random.default_rng(0)

data = mock.MagicMock()
data.repartition.return_value = data

model = mock.MagicMock()
marginal = mock.MagicMock()
model.project.return_value.datavector.return_value = marginal

with mock.patch("centhesus.mst.MST._setup_generate") as setup, mock.patch(
"centhesus.mst.MST._synthesise_first_column"
) as first, mock.patch(
"centhesus.mst.MST._find_prerequisite_columns"
) as find, mock.patch(
"centhesus.mst.MST._synthesise_column"
) as synth:
setup.return_value = (nrows, prng, "cliques", column, order)
first.return_value = data
find.return_value = ()
synth.return_value = "independent"

synthetic = MST.generate(model, nrows)

setup.assert_called_once_with(model, nrows, None)
first.assert_called_once_with(model, column, nrows, prng)

num_subsequent_columns = len(order)
assert model.project.call_count == num_subsequent_columns
for call, col in zip(model.project.call_args_list, order):
assert call.args == ((col,),)

assert (
model.project.return_value.datavector.call_count
== num_subsequent_columns
)
for call in model.project.return_value.datavector.call_args_list:
assert call.args == ()
assert call.kwargs == {"flatten": False}

assert synth.call_count == num_subsequent_columns
for call, col in zip(synth.call_args_list, order):
assert call.args == (marginal, nrows, prng)
assert hasattr(data, col)

data.repartition.assert_called_once_with("100MB")

assert synthetic is data

0 comments on commit e537472

Please sign in to comment.