diff --git a/src/infrasys/system.py b/src/infrasys/system.py
index deaa055..f40437e 100644
--- a/src/infrasys/system.py
+++ b/src/infrasys/system.py
@@ -2,6 +2,7 @@
 
 import json
 import shutil
+import sqlite3
 from operator import itemgetter
 from collections import defaultdict
 from datetime import datetime
@@ -33,16 +34,20 @@
 )
 from infrasys.time_series_manager import TimeSeriesManager, TIME_SERIES_KWARGS
 from infrasys.time_series_models import SingleTimeSeries, TimeSeriesData, TimeSeriesMetadata
+from infrasys.utils.sqlite import backup, create_in_memory_db, restore
 
 
 class System:
     """Implements behavior for systems"""
 
+    DB_FILENAME = "time_series_metadata.db"
+
     def __init__(
         self,
         name: Optional[str] = None,
         description: Optional[str] = None,
         auto_add_composed_components: bool = False,
+        con: Optional[sqlite3.Connection] = None,
         time_series_manager: Optional[TimeSeriesManager] = None,
         uuid: Optional[UUID] = None,
         **kwargs: Any,
@@ -60,6 +65,8 @@ def __init__(
             The default behavior is to raise an ISOperationNotAllowed when this condition
             occurs. This handles values that are components, such as generator.bus, and
             lists of components, such as subsystem.generators, but not any other form of
             nested components.
+        con : None | sqlite3.Connection
+            Users should not pass this. De-serialization (from_json) will pass a Connection.
         time_series_manager : None | TimeSeriesManager
             Users should not pass this. De-serialization (from_json) will pass a
             constructed manager.
@@ -79,8 +86,11 @@ def __init__(
         self._name = name
         self._description = description
         self._component_mgr = ComponentManager(self._uuid, auto_add_composed_components)
+        self._con = con or create_in_memory_db()
         time_series_kwargs = {k: v for k, v in kwargs.items() if k in TIME_SERIES_KWARGS}
-        self._time_series_mgr = time_series_manager or TimeSeriesManager(**time_series_kwargs)
+        self._time_series_mgr = time_series_manager or TimeSeriesManager(
+            self._con, **time_series_kwargs
+        )
         self._data_format_version: Optional[str] = None
         # Note to devs: if you add new fields, add support in to_json/from_json as appropriate.
 
@@ -127,10 +137,9 @@ def to_json(self, filename: Path | str, overwrite=False, indent=None, data=None)
             msg = f"{filename=} already exists. Choose a different path or set overwrite=True."
             raise ISFileExists(msg)
 
-        if not filename.parent.exists():
-            filename.parent.mkdir()
-
+        filename.parent.mkdir(exist_ok=True)
         time_series_dir = filename.parent / (filename.stem + "_time_series")
+        time_series_dir.mkdir(exist_ok=True)
         system_data = {
             "name": self.name,
             "description": self.description,
@@ -161,7 +170,8 @@ def to_json(self, filename: Path | str, overwrite=False, indent=None, data=None)
             json.dump(data, f_out, indent=indent)
             logger.info("Wrote system data to {}", filename)
 
-        self._time_series_mgr.serialize(self._make_time_series_directory(filename))
+        backup(self._con, time_series_dir / self.DB_FILENAME)
+        self._time_series_mgr.serialize(time_series_dir)
 
     @classmethod
     def from_json(
@@ -257,12 +267,20 @@ def from_dict(
         """
         system_data = data if "system" not in data else data["system"]
         ts_kwargs = {k: v for k, v in kwargs.items() if k in TIME_SERIES_KWARGS}
+        ts_path = (
+            time_series_parent_dir
+            if isinstance(time_series_parent_dir, Path)
+            else Path(time_series_parent_dir)
+        )
+        con = create_in_memory_db()
+        restore(con, ts_path / data["time_series"]["directory"] / System.DB_FILENAME)
         time_series_manager = TimeSeriesManager.deserialize(
-            data["time_series"], time_series_parent_dir, **ts_kwargs
+            con, data["time_series"], ts_path, **ts_kwargs
        )
         system = cls(
             name=system_data.get("name"),
             description=system_data.get("description"),
+            con=con,
             time_series_manager=time_series_manager,
             uuid=UUID(system_data["uuid"]),
             **kwargs,
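The save path above snapshots the System-owned, in-memory SQLite connection next to the Arrow time series files using sqlite3's online backup API. A minimal stdlib sketch of that mechanism (the table and paths are illustrative, not the infrasys schema):

```python
# What backup(self._con, time_series_dir / self.DB_FILENAME) relies on:
# sqlite3.Connection.backup() copies a live in-memory DB to a file atomically.
import sqlite3
from pathlib import Path

con = sqlite3.connect(":memory:")  # what create_in_memory_db() returns
con.execute("CREATE TABLE t(x)")
con.execute("INSERT INTO t VALUES (1)")
con.commit()

db_file = Path("my_system_time_series") / "time_series_metadata.db"
db_file.parent.mkdir(exist_ok=True)
dst = sqlite3.connect(db_file)
con.backup(dst)  # full snapshot; the in-memory DB stays usable afterward
dst.close()
```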
diff --git a/src/infrasys/time_series_manager.py b/src/infrasys/time_series_manager.py
index 7013387..a705ccc 100644
--- a/src/infrasys/time_series_manager.py
+++ b/src/infrasys/time_series_manager.py
@@ -1,5 +1,6 @@
 """Manages time series arrays"""
 
+import sqlite3
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Optional, Type
@@ -32,7 +33,13 @@ def _process_time_series_kwarg(key: str, **kwargs: Any) -> Any:
 class TimeSeriesManager:
     """Manages time series for a system."""
 
-    def __init__(self, storage: Optional[TimeSeriesStorageBase] = None, **kwargs) -> None:
+    def __init__(
+        self,
+        con: sqlite3.Connection,
+        storage: Optional[TimeSeriesStorageBase] = None,
+        initialize: bool = True,
+        **kwargs,
+    ) -> None:
         base_directory: Path | None = _process_time_series_kwarg("time_series_directory", **kwargs)
         self._read_only = _process_time_series_kwarg("time_series_read_only", **kwargs)
         self._storage = storage or (
@@ -40,7 +47,7 @@ def __init__(self, storage: Optional[TimeSeriesStorageBase] = None, **kwargs) ->
             if _process_time_series_kwarg("time_series_in_memory", **kwargs)
             else ArrowTimeSeriesStorage.create_with_temp_directory(base_directory=base_directory)
         )
-        self._metadata_store = TimeSeriesMetadataStore()
+        self._metadata_store = TimeSeriesMetadataStore(con, initialize=initialize)
 
         # TODO: create parsing mechanism? CSV, CSV + JSON
 
@@ -245,11 +252,11 @@ def _get_by_metadata(
     def serialize(self, dst: Path | str, src: Optional[Path | str] = None) -> None:
         """Serialize the time series data to dst."""
         self._storage.serialize(dst, src)
-        self._metadata_store.backup(dst)
 
     @classmethod
     def deserialize(
         cls,
+        con: sqlite3.Connection,
         data: dict[str, Any],
         parent_dir: Path | str,
         **kwargs: Any,
@@ -269,9 +276,7 @@ def deserialize(
         storage = ArrowTimeSeriesStorage.create_with_temp_directory()
         storage.serialize(src=time_series_dir, dst=storage.get_time_series_directory())
 
-        mgr = cls(storage=storage, **kwargs)
-        mgr.metadata_store.restore(time_series_dir)
-        return mgr
+        return cls(con, storage=storage, initialize=False, **kwargs)
 
     def _handle_read_only(self) -> None:
         if self._read_only:
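With this change the manager no longer backs up or restores its own metadata database; System owns the connection, restores it first, and passes it in with initialize=False so the schema is not re-created. A stdlib sketch of that load path (illustrative filename, not the infrasys API):

```python
# Copy an on-disk snapshot into a fresh in-memory connection, then hand that
# connection to a consumer that assumes the schema already exists.
import sqlite3

def load_metadata_db(db_file: str) -> sqlite3.Connection:
    mem = sqlite3.connect(":memory:")
    src = sqlite3.connect(db_file)
    src.backup(mem)  # copies the file contents into the in-memory DB
    src.close()
    return mem  # pass to TimeSeriesManager.deserialize(con, ...)
```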
diff --git a/src/infrasys/time_series_metadata_store.py b/src/infrasys/time_series_metadata_store.py
index e1bb195..44cffcb 100644
--- a/src/infrasys/time_series_metadata_store.py
+++ b/src/infrasys/time_series_metadata_store.py
@@ -1,12 +1,12 @@
 """Stores time series metadata in a SQLite database."""
 
 import hashlib
+import itertools
 import json
 import os
 import sqlite3
 from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Sequence
 from uuid import UUID
 
 from loguru import logger
@@ -20,17 +20,18 @@
     TYPE_METADATA,
 )
 from infrasys.time_series_models import TimeSeriesMetadata
+from infrasys.utils.sqlite import execute
 
 
 class TimeSeriesMetadataStore:
     """Stores time series metadata in a SQLite database."""
 
     TABLE_NAME = "time_series_metadata"
-    DB_FILENAME = "time_series_metadata.db"
 
-    def __init__(self):
-        self._con = sqlite3.connect(":memory:")
-        self._create_metadata_table()
+    def __init__(self, con: sqlite3.Connection, initialize: bool = True):
+        self._con = con
+        if initialize:
+            self._create_metadata_table()
         self._supports_sqlite_json = _does_sqlite_support_json()
         if not self._supports_sqlite_json:
             # This is true on Ubuntu 22.04, which is used by GitHub runners as of March 2024.
@@ -59,7 +60,7 @@ def _create_metadata_table(self):
         ]
         schema_text = ",".join(schema)
         cur = self._con.cursor()
-        self._execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})")
+        execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})")
         self._create_indexes(cur)
         self._con.commit()
         logger.debug("Created in-memory time series metadata table")
@@ -72,12 +73,12 @@ def _create_indexes(self, cur) -> None:
         # 1c. time series for one component with all user attributes
         # 2. Optimize for checks at system.add_time_series. Use all fields and attribute hash.
         # 3. Optimize for returning all metadata for a time series UUID.
-        self._execute(
+        execute(
             cur,
             f"CREATE INDEX by_c_vn_tst_hash ON {self.TABLE_NAME} "
             f"(component_uuid, variable_name, time_series_type, user_attributes_hash)",
         )
-        self._execute(cur, f"CREATE INDEX by_ts_uuid ON {self.TABLE_NAME} (time_series_uuid)")
+        execute(cur, f"CREATE INDEX by_ts_uuid ON {self.TABLE_NAME} (time_series_uuid)")
 
     def add(
         self,
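The initialize flag exists because a restored connection already contains the table, and running the DDL a second time on the same connection would raise. A standalone illustration (generic schema, not the real column list):

```python
# Restoring a backup brings the table with it; re-running CREATE TABLE fails.
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE time_series_metadata(id INTEGER PRIMARY KEY)")
try:
    con.execute("CREATE TABLE time_series_metadata(id INTEGER PRIMARY KEY)")
except sqlite3.OperationalError as exc:
    print(exc)  # table time_series_metadata already exists
```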
""" attribute_hash = _compute_user_attribute_hash(metadata.user_attributes) - where_clause = self._make_where_clause( + where_clause, params = self._make_where_clause( components, metadata.variable_name, metadata.type, @@ -102,7 +103,7 @@ def add( cur = self._con.cursor() query = f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE {where_clause}" - res = self._execute(cur, query).fetchone() + res = execute(cur, query, params=params).fetchone() if res[0] > 0: msg = f"Time series with {metadata=} is already stored." raise ISAlreadyAttached(msg) @@ -124,24 +125,6 @@ def add( ] self._insert_rows(rows) - def backup(self, directory: Path | str) -> None: - """Backup the database to a file in directory.""" - path = directory if isinstance(directory, Path) else Path(directory) - filename = path / self.DB_FILENAME - with sqlite3.connect(filename) as con: - self._con.backup(con) - con.close() - logger.info("Backed up the time series metadata to {}", filename) - - def restore(self, directory: Path | str) -> None: - """Restore the database from a file to memory.""" - path = directory if isinstance(directory, Path) else Path(directory) - filename = path / self.DB_FILENAME - with sqlite3.connect(filename) as con: - con.backup(self._con) - con.close() - logger.info("Restored the time series metadata to memory") - def get_time_series_counts(self) -> "TimeSeriesCounts": """Return summary counts of components and time series.""" query = f""" @@ -164,10 +147,10 @@ def get_time_series_counts(self) -> "TimeSeriesCounts": ,resolution """ cur = self._con.cursor() - rows = self._execute(cur, query).fetchall() + rows = execute(cur, query).fetchall() time_series_type_count = {(x[0], x[1], x[2], x[3]): x[4] for x in rows} - time_series_count = self._execute( + time_series_count = execute( cur, f"SELECT COUNT(DISTINCT time_series_uuid) from {self.TABLE_NAME}" ).fetchall()[0][0] @@ -216,10 +199,8 @@ def get_metadata( def has_time_series(self, time_series_uuid: UUID) -> bool: """Return True if there is time series matching the UUID.""" cur = self._con.cursor() - query = ( - f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE time_series_uuid = '{time_series_uuid}'" - ) - row = self._execute(cur, query).fetchone() + query = f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE time_series_uuid = ?" 
@@ -249,22 +230,23 @@ def has_time_series_metadata(
             )
         )
 
-        where_clause = self._make_where_clause(
+        where_clause, params = self._make_where_clause(
             (component,), variable_name, time_series_type, **user_attributes
         )
         query = f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE {where_clause}"
         cur = self._con.cursor()
-        res = self._execute(cur, query).fetchone()
+        res = execute(cur, query, params=params).fetchone()
         return res[0] > 0
 
     def list_existing_time_series(self, time_series_uuids: list[UUID]) -> set[UUID]:
         """Return the UUIDs that are present."""
         cur = self._con.cursor()
-        uuids = ",".join([f"'{x}'" for x in time_series_uuids])
+        params = tuple(str(x) for x in time_series_uuids)
+        uuids = ",".join(itertools.repeat("?", len(time_series_uuids)))
         query = (
             f"SELECT time_series_uuid FROM {self.TABLE_NAME} WHERE time_series_uuid IN ({uuids})"
         )
-        rows = self._execute(cur, query).fetchall()
+        rows = execute(cur, query, params=params).fetchall()
         return {UUID(x[0]) for x in rows}
 
     def list_missing_time_series(self, time_series_uuids: list[UUID]) -> set[UUID]:
@@ -291,12 +273,12 @@ def list_metadata(
             )
         ]
 
-        where_clause = self._make_where_clause(
+        where_clause, params = self._make_where_clause(
             components, variable_name, time_series_type, **user_attributes
         )
         query = f"SELECT metadata FROM {self.TABLE_NAME} WHERE {where_clause}"
         cur = self._con.cursor()
-        rows = self._execute(cur, query).fetchall()
+        rows = execute(cur, query, params=params).fetchall()
         return [_deserialize_time_series_metadata(x[0]) for x in rows]
 
     def _list_metadata_no_sql_json(
@@ -314,10 +296,10 @@ def _list_metadata_no_sql_json(
         The first element of each tuple is the database id field that uniquely identifies
         the row.
         """
-        where_clause = self._make_where_clause(components, variable_name, time_series_type)
+        where_clause, params = self._make_where_clause(components, variable_name, time_series_type)
         query = f"SELECT id, metadata FROM {self.TABLE_NAME} WHERE {where_clause}"
         cur = self._con.cursor()
-        rows = self._execute(cur, query).fetchall()
+        rows = execute(cur, query, params).fetchall()
 
         metadata_list = []
         for row in rows:
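For IN (...) queries the placeholder list has to be generated, one ? per value, while the values travel separately, mirroring list_existing_time_series above (toy table and values for illustration):

```python
# Build "?,?,?" for the IN clause; pass the values as the params sequence.
import itertools
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t(u TEXT)")
con.executemany("INSERT INTO t VALUES (?)", [("a",), ("b",), ("c",)])

wanted = ["a", "c", "x"]
placeholders = ",".join(itertools.repeat("?", len(wanted)))
rows = con.execute(f"SELECT u FROM t WHERE u IN ({placeholders})", wanted).fetchall()
assert {r[0] for r in rows} == {"a", "c"}
```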
""" - where_clause = self._make_where_clause(components, variable_name, time_series_type) + where_clause, params = self._make_where_clause(components, variable_name, time_series_type) query = f"SELECT id, metadata FROM {self.TABLE_NAME} WHERE {where_clause}" cur = self._con.cursor() - rows = self._execute(cur, query).fetchall() + rows = execute(cur, query, params).fetchall() metadata_list = [] for row in rows: @@ -342,13 +324,13 @@ def list_rows( ) raise ISOperationNotAllowed(msg) - where_clause = self._make_where_clause( + where_clause, params = self._make_where_clause( components, variable_name, time_series_type, **user_attributes ) cols = "*" if columns is None else ",".join(columns) query = f"SELECT {cols} FROM {self.TABLE_NAME} WHERE {where_clause}" cur = self._con.cursor() - rows = self._execute(cur, query).fetchall() + rows = execute(cur, query, params=params).fetchall() return rows def remove( @@ -371,39 +353,36 @@ def remove( ): ts_uuids.add(metadata.time_series_uuid) ids.append(id_) - id_str = ",".join([str(x) for x in ids]) + params = [str(x) for x in ids] + id_str = ",".join(itertools.repeat("?", len(ids))) query = f"DELETE FROM {self.TABLE_NAME} WHERE id IN ({id_str})" - self._execute(cur, query) - count_deleted = self._execute(cur, "SELECT changes()").fetchall()[0][0] + execute(cur, query, params=params) + count_deleted = execute(cur, "SELECT changes()").fetchall()[0][0] if count_deleted != len(ids): msg = f"Bug: Unexpected length mismatch {len(ts_uuids)=} {count_deleted=}" raise Exception(msg) self._con.commit() return list(ts_uuids) - where_clause = self._make_where_clause( + where_clause, params = self._make_where_clause( components, variable_name, time_series_type, **user_attributes ) query = f"SELECT time_series_uuid FROM {self.TABLE_NAME} WHERE {where_clause}" - uuids = [UUID(x[0]) for x in self._execute(cur, query).fetchall()] + uuids = [UUID(x[0]) for x in execute(cur, query, params=params).fetchall()] query = f"DELETE FROM {self.TABLE_NAME} WHERE ({where_clause})" - self._execute(cur, query) + execute(cur, query, params=params) self._con.commit() - count_deleted = self._execute(cur, "SELECT changes()").fetchall()[0][0] + count_deleted = execute(cur, "SELECT changes()").fetchall()[0][0] if len(uuids) != count_deleted: msg = f"Bug: Unexpected length mismatch: {len(uuids)=} {count_deleted=}" raise Exception(msg) return uuids - def sql(self, query: str) -> list[tuple]: + def sql(self, query: str, params: Sequence[str] = ()) -> list[tuple]: """Run a SQL query on the time series metadata table.""" cur = self._con.cursor() - return self._execute(cur, query).fetchall() - - def _execute(self, cursor: sqlite3.Cursor, query: str) -> Any: - logger.trace("SQL query: {}", query) - return cursor.execute(query) + return execute(cur, query, params=params).fetchall() def _insert_rows(self, rows: list[tuple]) -> None: cur = self._con.cursor() @@ -414,11 +393,16 @@ def _insert_rows(self, rows: list[tuple]) -> None: finally: self._con.commit() - def _make_components_str(self, *components: Component) -> str: + def _make_components_str(self, params: list[str], *components: Component) -> str: if not components: msg = "At least one component must be passed." raise ISOperationNotAllowed(msg) - or_clause = "OR ".join([f"component_uuid = '{x.uuid}'" for x in components]) + + or_clause = "OR ".join((itertools.repeat("component_uuid = ? 
", len(components)))) + + for component in components: + params.append(str(component.uuid)) + return f"({or_clause})" def _make_where_clause( @@ -428,22 +412,36 @@ def _make_where_clause( time_series_type: Optional[str], attribute_hash: Optional[str] = None, **user_attributes: str, - ) -> str: - component_str = self._make_components_str(*components) - var_str = "" if variable_name is None else f"AND variable_name = '{variable_name}'" - ts_str = "" if time_series_type is None else f"AND time_series_type = '{time_series_type}'" - ua_str = ( - f"AND {_make_user_attribute_filter(user_attributes)}" - if attribute_hash is None and user_attributes - else "" - ) - if ua_str: + ) -> tuple[str, list[str]]: + params: list[str] = [] + component_str = self._make_components_str(params, *components) + + if variable_name is None: + var_str = "" + else: + var_str = "AND variable_name = ?" + params.append(variable_name) + + if time_series_type is None: + ts_str = "" + else: + ts_str = "AND time_series_type = ?" + params.append(time_series_type) + + if attribute_hash is None and user_attributes: _raise_if_unsupported_sql_operation() + ua_hash_filter = _make_user_attribute_filter(user_attributes, params) + ua_str = f"AND {ua_hash_filter}" + else: + ua_str = "" - ua_hash = ( - f"AND {_make_user_attribute_hash_filter(attribute_hash)}" if attribute_hash else "" - ) - return f"({component_str} {var_str} {ts_str}) {ua_str} {ua_hash}" + if attribute_hash: + ua_hash_filter = _make_user_attribute_hash_filter(attribute_hash, params) + ua_hash = f"AND {ua_hash_filter}" + else: + ua_hash = "" + + return f"({component_str} {var_str} {ts_str}) {ua_str} {ua_hash}", params def _try_time_series_metadata_by_full_params( self, @@ -455,7 +453,7 @@ def _try_time_series_metadata_by_full_params( ) -> list[tuple] | None: assert variable_name is not None assert time_series_type is not None - where_clause = self._make_where_clause( + where_clause, params = self._make_where_clause( (component,), variable_name, time_series_type, @@ -464,7 +462,7 @@ def _try_time_series_metadata_by_full_params( ) query = f"SELECT {column} FROM {self.TABLE_NAME} WHERE {where_clause}" cur = self._con.cursor() - rows = self._execute(cur, query).fetchall() + rows = execute(cur, query, params=params).fetchall() if not rows: return None @@ -529,14 +527,18 @@ class TimeSeriesCounts: time_series_type_count: dict[tuple[str, str, str, str], int] -def _make_user_attribute_filter(user_attributes: dict[str, Any]) -> str: +def _make_user_attribute_filter(user_attributes: dict[str, Any], params: list[str]) -> str: attrs = _make_user_attribute_dict(user_attributes) - text = "AND ".join([f"metadata->>'$.user_attributes.{k}' = '{v}'" for k, v in attrs.items()]) - return f"({text})" + items = [] + for key, val in attrs.items(): + items.append(f"metadata->>'$.user_attributes.{key}' = ? ") + params.append(val) + return "AND ".join(items) -def _make_user_attribute_hash_filter(attribute_hash) -> str: - return f"user_attributes_hash = '{attribute_hash}'" +def _make_user_attribute_hash_filter(attribute_hash: str, params: list[str]) -> str: + params.append(attribute_hash) + return "user_attributes_hash = ?" 
diff --git a/src/infrasys/utils/__init__.py b/src/infrasys/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/infrasys/utils/sqlite.py b/src/infrasys/utils/sqlite.py
new file mode 100644
index 0000000..6aeec82
--- /dev/null
+++ b/src/infrasys/utils/sqlite.py
@@ -0,0 +1,34 @@
+"""Utility functions for SQLite"""
+
+import sqlite3
+from pathlib import Path
+from typing import Any, Sequence
+
+from loguru import logger
+
+
+def backup(src_con: sqlite3.Connection, filename: Path | str) -> None:
+    """Backup a database to a file."""
+    with sqlite3.connect(filename) as dst_con:
+        src_con.backup(dst_con)
+    dst_con.close()
+    logger.info("Backed up the database to {}.", filename)
+
+
+def restore(dst_con: sqlite3.Connection, filename: Path | str) -> None:
+    """Restore a database from a file."""
+    with sqlite3.connect(filename) as src_con:
+        src_con.backup(dst_con)
+    src_con.close()
+    logger.info("Restored the database from {}.", filename)
+
+
+def create_in_memory_db(database: str = ":memory:") -> sqlite3.Connection:
+    """Create an in-memory database."""
+    return sqlite3.connect(database)
+
+
+def execute(cursor: sqlite3.Cursor, query: str, params: Sequence[str] = ()) -> Any:
+    """Execute a SQL query."""
+    logger.trace("SQL query: {} params={}", query, params)
+    return cursor.execute(query, params)
diff --git a/tests/test_arrow_storage.py b/tests/test_arrow_storage.py
index e61caf6..1529dec 100644
--- a/tests/test_arrow_storage.py
+++ b/tests/test_arrow_storage.py
@@ -56,7 +56,7 @@ def test_copy_files(tmp_path):
     system.to_json(filename)
 
     logger.info("Starting deserialization")
-    system2 = SimpleSystem.from_json(filename, base_directory=tmp_path)
+    system2 = SimpleSystem.from_json(filename)
     gen1b = system2.get_component(SimpleGenerator, gen1.name)
     time_series = system2.time_series.get(gen1b)
     time_series_fpath = (
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index f382459..d696645 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -59,7 +59,7 @@ def test_serialization(tmp_path):
     system.to_json(filename, overwrite=True)
     system2 = SimpleSystem.from_json(filename)
     for key, val in system.__dict__.items():
-        if key not in ("_component_mgr", "_time_series_mgr"):
+        if key not in ("_component_mgr", "_time_series_mgr", "_con"):
             assert getattr(system2, key) == val
 
     components2 = list(system2.iter_all_components())
@@ -195,16 +195,14 @@ def test_system_save(tmp_path, simple_system_with_time_series):
     simple_system = simple_system_with_time_series
     custom_folder = "my_system"
     fpath = tmp_path / custom_folder
-    fname = "test_system"
+    fname = "test_system.json"
     simple_system.save(fpath, filename=fname)
     assert os.path.exists(fpath), f"Folder {fpath} was not created successfully"
     assert os.path.exists(fpath / fname), f"Serialized system {fname} was not created successfully"
 
-    fname = "test_system"
     with pytest.raises(FileExistsError):
         simple_system.save(fpath, filename=fname)
 
-    fname = "test_system"
     simple_system.save(fpath, filename=fname, overwrite=True)
     assert os.path.exists(fpath), f"Folder {fpath} was not created successfully"
    assert os.path.exists(fpath / fname), f"Serialized system {fname} was not created successfully"
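A quick round trip through the new utils module (signatures as defined in the diff above; the snapshot filename is illustrative and is written to the working directory, and the example assumes the patched infrasys package is installed):

```python
# Backup an in-memory DB to disk, then restore it into a fresh connection.
from infrasys.utils.sqlite import backup, create_in_memory_db, restore

src = create_in_memory_db()
src.execute("CREATE TABLE t(x)")
src.execute("INSERT INTO t VALUES (42)")
src.commit()

backup(src, "snapshot.db")  # write the in-memory DB to disk

dst = create_in_memory_db()
restore(dst, "snapshot.db")  # load it back into a fresh connection
assert dst.execute("SELECT x FROM t").fetchone() == (42,)
```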