Skip to content

Commit

Permalink
Fix assertions in dummy data specialize
Browse files Browse the repository at this point in the history
  • Loading branch information
rebkwok committed Jan 13, 2025
1 parent 32f477b commit 13ddf08
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ehrql/dummy_data_nextgen/query_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,15 @@ def specialize(query, column) -> Node | None:
if rhs is None:
return lhs
result = Function.And(lhs, rhs)
assert len(columns_for_query(result)) == 1
assert len(columns_for_query(result)) <= 1
return result
case Function.Or(lhs=lhs, rhs=rhs):
lhs = specialize(lhs, column)
rhs = specialize(rhs, column)
if lhs is None or rhs is None:
return None
result = Function.Or(lhs=lhs, rhs=rhs)
assert len(columns_for_query(result)) == 1
assert len(columns_for_query(result)) <= 1
return result

# TODO: This could really use a nicer way of handling it.
Expand Down
52 changes: 51 additions & 1 deletion tests/unit/dummy_data_nextgen/test_query_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from ehrql import Dataset, days
from ehrql import Dataset, days, maximum_of
from ehrql.codes import CTV3Code
from ehrql.dummy_data_nextgen.query_info import ColumnInfo, QueryInfo, TableInfo
from ehrql.tables import (
Expand Down Expand Up @@ -147,3 +147,53 @@ def test_query_info_ignores_complex_comparisons():
column_info = query_info.tables["patients"].columns["date_of_birth"]

assert column_info.values_used == [datetime.date(2022, 10, 5)]


def test_query_info_specialize_bug():
    """Regression test for an assertion failure seen in real-world ehrQL.

    Some queries specialise down to a comparison between pure values and an
    unspecializable sub-query involving a column. The result then contains no
    column references at all, which used to trip the assertion that
    ``len(columns_for_query(result)) == 1``. The test passes as long as
    building the ``QueryInfo`` does not raise.
    """
    # `Value(False)` can't be written directly in the query language; this is
    # the easiest way to get an expression made only of pure values.
    constant_expr = maximum_of(0, 0) != 0

    # Something that is not a pure value but which `specialize` can't
    # currently handle. This is probably the simplest choice; alternatives
    # that work equally well include
    #   events.sort_by(events.i).first_for_patient().i == 0
    #   events.i.count_distinct_for_patient() == 0
    opaque_expr = events.count_for_patient() == 0

    dataset = Dataset()

    # The core of the problematic construction. Swapping the outer `|` for `&`
    # hits the other assertion instead.
    dataset.define_population(constant_expr | (constant_expr & opaque_expr))

    # Make sure the dataset holds at least one column reference (which one is
    # irrelevant) so that specialisation is always attempted.
    dataset.dob = patients.date_of_birth

    compiled = dataset._compile()
    QueryInfo.from_dataset(compiled)


def test_query_info_specialize_bug_values_used():
    """Check `values_used` is still recorded for the problematic query shape.

    Uses the same `pure | (pure & unspecializable)` construction as the
    specialize regression test, but builds the unspecializable part from a
    column reference so the resulting query info can be inspected.
    """
    constant_expr = maximum_of(0, 0) != 0

    # Build the unspecializable part from a real column so we can confirm
    # that the compared value shows up in that column's `values_used`.
    first_event = events.sort_by(events.date).first_for_patient()
    opaque_expr = first_event.date == "2020-01-01"

    dataset = Dataset()
    dataset.define_population(constant_expr | (constant_expr & opaque_expr))

    compiled = dataset._compile()
    info = QueryInfo.from_dataset(compiled)

    date_column = info.tables["events"].columns["date"]
    assert date_column.values_used == [datetime.date(2020, 1, 1)]

0 comments on commit 13ddf08

Please sign in to comment.