From acf90d3d499a1f249a0459427cbbad78ce40f43d Mon Sep 17 00:00:00 2001 From: Becky Smith Date: Fri, 10 Jan 2025 12:06:29 +0000 Subject: [PATCH] test and not-quite-right fix for assertion error in specialize() --- ehrql/dummy_data_nextgen/query_info.py | 14 +++++---- .../dummy_data_nextgen/test_query_info.py | 31 ++++++++++++++++++- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/ehrql/dummy_data_nextgen/query_info.py b/ehrql/dummy_data_nextgen/query_info.py index ebc48315f..232090205 100644 --- a/ehrql/dummy_data_nextgen/query_info.py +++ b/ehrql/dummy_data_nextgen/query_info.py @@ -328,12 +328,14 @@ def specialize(query, column) -> Node | None: if lhs is None or rhs is None: return None return type(comp)(lhs=lhs, rhs=rhs) - case SelectColumn() as q: - if column == q: - assert len(columns_for_query(q)) == 1 - return q - else: - return None + case SelectColumn() as q1: + # a SelectColumn() query can be a simple select from a SelectTable source, + # but if it is from an EventTable, it can will be a PickOneRowPerPatient + # with a Sort on a source SelectTable. + if set(columns_for_query(q1)) == {column}: + assert len(columns_for_query(q1)) == 1 + return q1 + return None case _: fields = query.__dataclass_fields__ specialized = {} diff --git a/tests/unit/dummy_data_nextgen/test_query_info.py b/tests/unit/dummy_data_nextgen/test_query_info.py index 5a92e49c6..a1589c69a 100644 --- a/tests/unit/dummy_data_nextgen/test_query_info.py +++ b/tests/unit/dummy_data_nextgen/test_query_info.py @@ -1,6 +1,6 @@ import datetime -from ehrql import Dataset, days +from ehrql import Dataset, case, days, when from ehrql.codes import CTV3Code from ehrql.dummy_data_nextgen.query_info import ColumnInfo, QueryInfo, TableInfo from ehrql.tables import ( @@ -147,3 +147,32 @@ def test_query_info_ignores_complex_comparisons(): column_info = query_info.tables["patients"].columns["date_of_birth"] assert column_info.values_used == [datetime.date(2022, 10, 5)] + + +def test_query_info_with_nested_case_statements(): + # This test reproduces an error encountered in real-world ehrQL which was over-using + # case statements for boolean series (i.e. using case to return True/False on an + # already bool series, and then also filtering by == True/False on that case statement). + # QueryInfo.specialize turns those sorts of queries into e.g. EQ(lhs=Value(True), rhs=Value(False)) + # (with no column references). + # This exposed a bug where a SelectColumn on an EventTable rather than a PatientTable + # was reduced to None, and we ended up with a resulting query with no column references + # in it. + dataset = Dataset() + + has_dob = case( + when(patients.date_of_birth.is_not_null()).then(True), otherwise=False + ) + first_date = events.sort_by(events.date).first_for_patient().date + + query = case( + when( + (has_dob == False) | ((has_dob == True) & (first_date == "2020-01-01")) + ).then(True), + otherwise=False, + ) + dataset.define_population(patients.exists_for_patient() & query) + + query_info = QueryInfo.from_dataset(dataset._compile()) + column_info = query_info.tables["events"].columns["date"] + assert column_info.values_used == [datetime.date(2020, 1, 1)]