Skip to content

Commit

Permalink
Fix Oracle Provider different count query and feature query when a sq…
Browse files Browse the repository at this point in the history
…l_manipulator is used, Issue #1831 (#1909)

* Adjusted Query Function to include SQL Manipulation in the Count Query for numberMatched retrieval. Now there queries should be the same in any case. Added a function for less repetitions and added a Test.

* Made process_query_with_sql_manipulator_sup more concise and removed duplications
  • Loading branch information
Moritz-Langer authored Jan 22, 2025
1 parent 51c6a95 commit d9dceac
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 32 deletions.
97 changes: 65 additions & 32 deletions pygeoapi/provider/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,38 @@ def _get_srid_from_crs(self, crs):

return srid

def _process_query_with_sql_manipulator_sup(
self, db, sql_query, bind_variables, extra_params, **query_args
):
"""
Apply the SQL manipulation plugin to process the SQL query.
:param db: Database connection instance
:param sql_query: The SQL query to process
:param bind_variables: Query bind variables
:param extra_params: Additional parameters for manipulation
:param query_args: Other dynamic arguments required for processing
:return: Processed SQL query and bind variables
"""
if self.sql_manipulator:
LOGGER.debug(f"sql_manipulator: {self.sql_manipulator}")
manipulation_class = _class_factory(self.sql_manipulator)

# Pass all arguments to the process_query method
sql_query, bind_variables = manipulation_class.process_query(
db=db,
sql_query=sql_query,
bind_variables=bind_variables,
sql_manipulator_options=self.sql_manipulator_options,
**query_args,
extra_params=extra_params,
)

for placeholder in ["#HINTS#", "#JOIN#", "#WHERE#"]:
sql_query = sql_query.replace(placeholder, "")

return sql_query, bind_variables

def query(
self,
offset=0,
Expand Down Expand Up @@ -695,9 +727,36 @@ def query(
# because of getFields ...
sql_query = f"SELECT COUNT(1) AS hits \
FROM {self.table} \
{where_dict['clause']}"
{where_dict['clause']} #WHERE#"

# Assign where_dict["properties"] to bind_variables
bind_variables = {**where_dict["properties"]}

# Default values for the process_query function (sql_manipulator)
query_args = {
"offset": offset,
"limit": limit,
"resulttype": resulttype,
"bbox": bbox,
"datetime_": datetime_,
"properties": properties,
"sortby": sortby,
"skip_geometry": skip_geometry,
"select_properties": select_properties,
"crs_transform_spec": crs_transform_spec,
"q": q,
"language": language,
"filterq": filterq,
}

# Apply the SQL manipulation plugin
extra_params["geom"] = self.geom
sql_query, bind_variables = self._process_query_with_sql_manipulator_sup( # noqa: E501
db, sql_query, bind_variables, extra_params, **query_args
)

try:
cursor.execute(sql_query, where_dict["properties"])
cursor.execute(sql_query, bind_variables)
except oracledb.Error as err:
LOGGER.error(
f"Error executing sql_query: {sql_query}: {err}"
Expand Down Expand Up @@ -795,36 +854,10 @@ def query(
# Create dictionary for sql bind variables
bind_variables = {**where_dict["properties"], **paging_bind}

# SQL manipulation plugin
if self.sql_manipulator:
LOGGER.debug("sql_manipulator: " + self.sql_manipulator)
manipulation_class = _class_factory(self.sql_manipulator)
sql_query, bind_variables = manipulation_class.process_query(
db,
sql_query,
bind_variables,
self.sql_manipulator_options,
offset,
limit,
resulttype,
bbox,
datetime_,
properties,
sortby,
skip_geometry,
select_properties,
crs_transform_spec,
q,
language,
filterq,
extra_params=extra_params
)

# Clean up placeholders that aren't used by the
# manipulation class.
sql_query = sql_query.replace("#HINTS#", "")
sql_query = sql_query.replace("#JOIN#", "")
sql_query = sql_query.replace("#WHERE#", "")
# Apply the SQL manipulation plugin
sql_query, bind_variables = self._process_query_with_sql_manipulator_sup( # noqa: E501
db, sql_query, bind_variables, extra_params, **query_args
)

LOGGER.debug(f"SQL Query: {sql_query}")
LOGGER.debug(f"Bind variables: {bind_variables}")
Expand Down
9 changes: 9 additions & 0 deletions tests/test_oracle_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def process_query(

if sql_query.find(" WHERE ") == -1:
sql_query = sql_query.replace("#WHERE#", f" WHERE {sql}")

else:
sql_query = sql_query.replace("#WHERE#", f" AND {sql}")

Expand Down Expand Up @@ -644,6 +645,14 @@ def test_extra_params_are_passed_to_sql_manipulator(config_manipulator):
assert not response['features']


def test_query_count_sql_manipulator(config_manipulator):
"""Test query number of hits"""
p = OracleProvider(config_manipulator)
result = p.query(resulttype="hits")

assert result.get("numberMatched") == 1


@pytest.fixture()
def database_connection_pool(config_db_conn):
os.environ["ORACLE_POOL_MIN"] = "2" # noqa: F841
Expand Down

0 comments on commit d9dceac

Please sign in to comment.