diff --git a/ehrql/dummy_data_nextgen/query_info.py b/ehrql/dummy_data_nextgen/query_info.py index ebc48315f..8447730fc 100644 --- a/ehrql/dummy_data_nextgen/query_info.py +++ b/ehrql/dummy_data_nextgen/query_info.py @@ -249,7 +249,7 @@ def specialize(query, column) -> Node | None: if rhs is None: return lhs result = Function.And(lhs, rhs) - assert len(columns_for_query(result)) == 1 + assert len(columns_for_query(result)) <= 1 return result case Function.Or(lhs=lhs, rhs=rhs): lhs = specialize(lhs, column) @@ -257,7 +257,7 @@ def specialize(query, column) -> Node | None: if lhs is None or rhs is None: return None result = Function.Or(lhs=lhs, rhs=rhs) - assert len(columns_for_query(result)) == 1 + assert len(columns_for_query(result)) <= 1 return result # TODO: This could really use a nicer way of handling it. diff --git a/tests/unit/dummy_data_nextgen/test_query_info.py b/tests/unit/dummy_data_nextgen/test_query_info.py index 5a92e49c6..0b71f2dc1 100644 --- a/tests/unit/dummy_data_nextgen/test_query_info.py +++ b/tests/unit/dummy_data_nextgen/test_query_info.py @@ -1,6 +1,6 @@ import datetime -from ehrql import Dataset, days +from ehrql import Dataset, days, maximum_of from ehrql.codes import CTV3Code from ehrql.dummy_data_nextgen.query_info import ColumnInfo, QueryInfo, TableInfo from ehrql.tables import ( @@ -147,3 +147,53 @@ def test_query_info_ignores_complex_comparisons(): column_info = query_info.tables["patients"].columns["date_of_birth"] assert column_info.values_used == [datetime.date(2022, 10, 5)] + + +def test_query_info_specialize_bug(): + # Regression test: reproduces an error encountered in real-world ehrQL whose + # queries specialised down to comparisons between pure values and an + # unspecializable query involving a column. The result had no column references + # and so tripped the (old) `len(columns_for_query(result)) == 1` assertion. + dataset = Dataset() + + # We can't create `Value(False)` directly in the query language, so this is 
the + # easiest way + pure_value = maximum_of(0, 0) != 0 + + # We need something which is not a pure value but which `specialize` can't currently + # handle. This is probably the simplest option, but there are lots of others; a + # couple of alternatives are kept below as commented-out lines. + unspecializable_thing = events.count_for_patient() == 0 + # unspecializable_thing = events.sort_by(events.i).first_for_patient().i == 0 + # unspecializable_thing = events.i.count_distinct_for_patient() == 0 + + # This is the core of the problematic construction. It also works with `&` instead + # to hit the other assertion. + query = pure_value | (pure_value & unspecializable_thing) + + dataset.define_population(query) + + # Ensure there's at least one column reference in the dataset (doesn't matter what + # it is) so that it always tries to specialise + dataset.dob = patients.date_of_birth + QueryInfo.from_dataset(dataset._compile()) + + +def test_query_info_specialize_bug_values_used(): + dataset = Dataset() + + pure_value = maximum_of(0, 0) != 0 + + # Define the unspecializable thing using a column reference, so we can confirm that + # the query info includes the expected values_used + unspecializable_thing = ( + events.sort_by(events.date).first_for_patient().date == "2020-01-01" + ) + + query = pure_value | (pure_value & unspecializable_thing) + + dataset.define_population(query) + + query_info = QueryInfo.from_dataset(dataset._compile()) + column_info = query_info.tables["events"].columns["date"] + assert column_info.values_used == [datetime.date(2020, 1, 1)]