Skip to content

Commit

Permalink
Use SQL query params for user variables
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-thom committed Aug 22, 2024
1 parent 7f1d821 commit 07181a0
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 58 deletions.
137 changes: 79 additions & 58 deletions src/infrasys/time_series_metadata_store.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Stores time series metadata in a SQLite database."""

import hashlib
import itertools
import json
import os
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Sequence
from uuid import UUID

from loguru import logger
Expand All @@ -20,6 +21,7 @@
TYPE_METADATA,
)
from infrasys.time_series_models import TimeSeriesMetadata
from infrasys.utils.sqlite import execute


class TimeSeriesMetadataStore:
Expand Down Expand Up @@ -59,7 +61,7 @@ def _create_metadata_table(self):
]
schema_text = ",".join(schema)
cur = self._con.cursor()
self._execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})")
execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})")
self._create_indexes(cur)
self._con.commit()
logger.debug("Created in-memory time series metadata table")
Expand All @@ -72,12 +74,12 @@ def _create_indexes(self, cur) -> None:
# 1c. time series for one component with all user attributes
# 2. Optimize for checks at system.add_time_series. Use all fields and attribute hash.
# 3. Optimize for returning all metadata for a time series UUID.
self._execute(
execute(
cur,
f"CREATE INDEX by_c_vn_tst_hash ON {self.TABLE_NAME} "
f"(component_uuid, variable_name, time_series_type, user_attributes_hash)",
)
self._execute(cur, f"CREATE INDEX by_ts_uuid ON {self.TABLE_NAME} (time_series_uuid)")
execute(cur, f"CREATE INDEX by_ts_uuid ON {self.TABLE_NAME} (time_series_uuid)")

def add(
self,
Expand All @@ -92,7 +94,7 @@ def add(
Raised if the time series metadata is already stored.
"""
attribute_hash = _compute_user_attribute_hash(metadata.user_attributes)
where_clause = self._make_where_clause(
where_clause, params = self._make_where_clause(
components,
metadata.variable_name,
metadata.type,
Expand All @@ -102,7 +104,7 @@ def add(
cur = self._con.cursor()

query = f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE {where_clause}"
res = self._execute(cur, query).fetchone()
res = execute(cur, query, params=params).fetchone()
if res[0] > 0:
msg = f"Time series with {metadata=} is already stored."
raise ISAlreadyAttached(msg)
Expand Down Expand Up @@ -164,10 +166,10 @@ def get_time_series_counts(self) -> "TimeSeriesCounts":
,resolution
"""
cur = self._con.cursor()
rows = self._execute(cur, query).fetchall()
rows = execute(cur, query).fetchall()
time_series_type_count = {(x[0], x[1], x[2], x[3]): x[4] for x in rows}

time_series_count = self._execute(
time_series_count = execute(
cur, f"SELECT COUNT(DISTINCT time_series_uuid) from {self.TABLE_NAME}"
).fetchall()[0][0]

Expand Down Expand Up @@ -216,10 +218,8 @@ def get_metadata(
def has_time_series(self, time_series_uuid: UUID) -> bool:
"""Return True if there is a time series matching the UUID."""
cur = self._con.cursor()
query = (
f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE time_series_uuid = '{time_series_uuid}'"
)
row = self._execute(cur, query).fetchone()
query = f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE time_series_uuid = ?"
row = execute(cur, query, params=(str(time_series_uuid),)).fetchone()
return row[0] > 0

def has_time_series_metadata(
Expand Down Expand Up @@ -249,22 +249,23 @@ def has_time_series_metadata(
)
)

where_clause = self._make_where_clause(
where_clause, params = self._make_where_clause(
(component,), variable_name, time_series_type, **user_attributes
)
query = f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE {where_clause}"
cur = self._con.cursor()
res = self._execute(cur, query).fetchone()
res = execute(cur, query, params=params).fetchone()
return res[0] > 0

def list_existing_time_series(self, time_series_uuids: list[UUID]) -> set[UUID]:
"""Return the UUIDs that are present."""
cur = self._con.cursor()
uuids = ",".join([f"'{x}'" for x in time_series_uuids])
params = tuple(str(x) for x in time_series_uuids)
uuids = ",".join(itertools.repeat("?", len(time_series_uuids)))
query = (
f"SELECT time_series_uuid FROM {self.TABLE_NAME} WHERE time_series_uuid IN ({uuids})"
)
rows = self._execute(cur, query).fetchall()
rows = execute(cur, query, params=params).fetchall()
return {UUID(x[0]) for x in rows}

def list_missing_time_series(self, time_series_uuids: list[UUID]) -> set[UUID]:
Expand All @@ -291,12 +292,12 @@ def list_metadata(
)
]

where_clause = self._make_where_clause(
where_clause, params = self._make_where_clause(
components, variable_name, time_series_type, **user_attributes
)
query = f"SELECT metadata FROM {self.TABLE_NAME} WHERE {where_clause}"
cur = self._con.cursor()
rows = self._execute(cur, query).fetchall()
rows = execute(cur, query, params=params).fetchall()
return [_deserialize_time_series_metadata(x[0]) for x in rows]

def _list_metadata_no_sql_json(
Expand All @@ -314,10 +315,10 @@ def _list_metadata_no_sql_json(
The first element of each tuple is the database id field that uniquely identifies the
row.
"""
where_clause = self._make_where_clause(components, variable_name, time_series_type)
where_clause, params = self._make_where_clause(components, variable_name, time_series_type)
query = f"SELECT id, metadata FROM {self.TABLE_NAME} WHERE {where_clause}"
cur = self._con.cursor()
rows = self._execute(cur, query).fetchall()
rows = execute(cur, query, params).fetchall()

metadata_list = []
for row in rows:
Expand All @@ -342,13 +343,13 @@ def list_rows(
)
raise ISOperationNotAllowed(msg)

where_clause = self._make_where_clause(
where_clause, params = self._make_where_clause(
components, variable_name, time_series_type, **user_attributes
)
cols = "*" if columns is None else ",".join(columns)
query = f"SELECT {cols} FROM {self.TABLE_NAME} WHERE {where_clause}"
cur = self._con.cursor()
rows = self._execute(cur, query).fetchall()
rows = execute(cur, query, params=params).fetchall()
return rows

def remove(
Expand All @@ -371,39 +372,36 @@ def remove(
):
ts_uuids.add(metadata.time_series_uuid)
ids.append(id_)
id_str = ",".join([str(x) for x in ids])
params = [str(x) for x in ids]
id_str = ",".join(itertools.repeat("?", len(ids)))
query = f"DELETE FROM {self.TABLE_NAME} WHERE id IN ({id_str})"
self._execute(cur, query)
count_deleted = self._execute(cur, "SELECT changes()").fetchall()[0][0]
execute(cur, query, params=params)
count_deleted = execute(cur, "SELECT changes()").fetchall()[0][0]
if count_deleted != len(ids):
msg = f"Bug: Unexpected length mismatch {len(ts_uuids)=} {count_deleted=}"
raise Exception(msg)
self._con.commit()
return list(ts_uuids)

where_clause = self._make_where_clause(
where_clause, params = self._make_where_clause(
components, variable_name, time_series_type, **user_attributes
)
query = f"SELECT time_series_uuid FROM {self.TABLE_NAME} WHERE {where_clause}"
uuids = [UUID(x[0]) for x in self._execute(cur, query).fetchall()]
uuids = [UUID(x[0]) for x in execute(cur, query, params=params).fetchall()]

query = f"DELETE FROM {self.TABLE_NAME} WHERE ({where_clause})"
self._execute(cur, query)
execute(cur, query, params=params)
self._con.commit()
count_deleted = self._execute(cur, "SELECT changes()").fetchall()[0][0]
count_deleted = execute(cur, "SELECT changes()").fetchall()[0][0]
if len(uuids) != count_deleted:
msg = f"Bug: Unexpected length mismatch: {len(uuids)=} {count_deleted=}"
raise Exception(msg)
return uuids

def sql(self, query: str) -> list[tuple]:
def sql(self, query: str, params: Sequence[str] = ()) -> list[tuple]:
"""Run a SQL query on the time series metadata table."""
cur = self._con.cursor()
return self._execute(cur, query).fetchall()

def _execute(self, cursor: sqlite3.Cursor, query: str) -> Any:
logger.trace("SQL query: {}", query)
return cursor.execute(query)
return execute(cur, query, params=params).fetchall()

def _insert_rows(self, rows: list[tuple]) -> None:
cur = self._con.cursor()
Expand All @@ -414,11 +412,16 @@ def _insert_rows(self, rows: list[tuple]) -> None:
finally:
self._con.commit()

def _make_components_str(self, *components: Component) -> str:
def _make_components_str(self, params: list[str], *components: Component) -> str:
if not components:
msg = "At least one component must be passed."
raise ISOperationNotAllowed(msg)
or_clause = "OR ".join([f"component_uuid = '{x.uuid}'" for x in components])

or_clause = "OR ".join((itertools.repeat("component_uuid = ? ", len(components))))

for component in components:
params.append(str(component.uuid))

return f"({or_clause})"

def _make_where_clause(
Expand All @@ -428,22 +431,36 @@ def _make_where_clause(
time_series_type: Optional[str],
attribute_hash: Optional[str] = None,
**user_attributes: str,
) -> str:
component_str = self._make_components_str(*components)
var_str = "" if variable_name is None else f"AND variable_name = '{variable_name}'"
ts_str = "" if time_series_type is None else f"AND time_series_type = '{time_series_type}'"
ua_str = (
f"AND {_make_user_attribute_filter(user_attributes)}"
if attribute_hash is None and user_attributes
else ""
)
if ua_str:
) -> tuple[str, list[str]]:
params: list[str] = []
component_str = self._make_components_str(params, *components)

if variable_name is None:
var_str = ""
else:
var_str = "AND variable_name = ?"
params.append(variable_name)

if time_series_type is None:
ts_str = ""
else:
ts_str = "AND time_series_type = ?"
params.append(time_series_type)

if attribute_hash is None and user_attributes:
_raise_if_unsupported_sql_operation()
ua_hash_filter = _make_user_attribute_filter(user_attributes, params)
ua_str = f"AND {ua_hash_filter}"
else:
ua_str = ""

ua_hash = (
f"AND {_make_user_attribute_hash_filter(attribute_hash)}" if attribute_hash else ""
)
return f"({component_str} {var_str} {ts_str}) {ua_str} {ua_hash}"
if attribute_hash:
ua_hash_filter = _make_user_attribute_hash_filter(attribute_hash, params)
ua_hash = f"AND {ua_hash_filter}"
else:
ua_hash = ""

return f"({component_str} {var_str} {ts_str}) {ua_str} {ua_hash}", params

def _try_time_series_metadata_by_full_params(
self,
Expand All @@ -455,7 +472,7 @@ def _try_time_series_metadata_by_full_params(
) -> list[tuple] | None:
assert variable_name is not None
assert time_series_type is not None
where_clause = self._make_where_clause(
where_clause, params = self._make_where_clause(
(component,),
variable_name,
time_series_type,
Expand All @@ -464,7 +481,7 @@ def _try_time_series_metadata_by_full_params(
)
query = f"SELECT {column} FROM {self.TABLE_NAME} WHERE {where_clause}"
cur = self._con.cursor()
rows = self._execute(cur, query).fetchall()
rows = execute(cur, query, params=params).fetchall()
if not rows:
return None

Expand Down Expand Up @@ -529,14 +546,18 @@ class TimeSeriesCounts:
time_series_type_count: dict[tuple[str, str, str, str], int]


def _make_user_attribute_filter(user_attributes: dict[str, Any]) -> str:
def _make_user_attribute_filter(user_attributes: dict[str, Any], params: list[str]) -> str:
attrs = _make_user_attribute_dict(user_attributes)
text = "AND ".join([f"metadata->>'$.user_attributes.{k}' = '{v}'" for k, v in attrs.items()])
return f"({text})"
items = []
for key, val in attrs.items():
items.append(f"metadata->>'$.user_attributes.{key}' = ? ")
params.append(val)
return "AND ".join(items)


def _make_user_attribute_hash_filter(attribute_hash) -> str:
return f"user_attributes_hash = '{attribute_hash}'"
def _make_user_attribute_hash_filter(attribute_hash: str, params: list[str]) -> str:
params.append(attribute_hash)
return "user_attributes_hash = ?"


def _make_user_attribute_dict(user_attributes: dict[str, Any]) -> dict[str, Any]:
Expand Down
Empty file added src/infrasys/utils/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions src/infrasys/utils/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Utility functions for SQLite"""

import sqlite3
from typing import Any, Sequence

from loguru import logger


def execute(cursor: sqlite3.Cursor, query: str, params: Sequence[str] = ()) -> Any:
    """Execute a SQL query, logging it at TRACE level first.

    Parameters
    ----------
    cursor
        SQLite cursor on which the statement is run.
    query
        SQL text. Values should be supplied through ``?`` placeholders rather than
        being formatted into the string.
    params
        Values bound to the placeholders in ``query``. Defaults to no parameters.

    Returns
    -------
    Any
        Whatever ``cursor.execute`` returns, so callers can chain ``fetchone()`` or
        ``fetchall()``.
    """
    # loguru renders the message with str.format(*args, **kwargs). Named fields such
    # as "{query}" need keyword arguments and "{params=}" is f-string-only syntax, so
    # positional "{}" fields are used and both values are passed explicitly.
    logger.trace("SQL query: {} params={}", query, params)
    return cursor.execute(query, params)

0 comments on commit 07181a0

Please sign in to comment.