Skip to content

Commit

Permalink
Move database ownership to System
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-thom committed Aug 22, 2024
1 parent 07181a0 commit fa2bb13
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 40 deletions.
30 changes: 24 additions & 6 deletions src/infrasys/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import shutil
import sqlite3
from operator import itemgetter
from collections import defaultdict
from datetime import datetime
Expand Down Expand Up @@ -33,16 +34,20 @@
)
from infrasys.time_series_manager import TimeSeriesManager, TIME_SERIES_KWARGS
from infrasys.time_series_models import SingleTimeSeries, TimeSeriesData, TimeSeriesMetadata
from infrasys.utils.sqlite import backup, create_in_memory_db, restore


class System:
"""Implements behavior for systems"""

DB_FILENAME = "time_series_metadata.db"

def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = None,
auto_add_composed_components: bool = False,
con: Optional[sqlite3.Connection] = None,
time_series_manager: Optional[TimeSeriesManager] = None,
uuid: Optional[UUID] = None,
**kwargs: Any,
Expand All @@ -60,6 +65,8 @@ def __init__(
The default behavior is to raise an ISOperationNotAllowed when this condition occurs.
This handles values that are components, such as generator.bus, and lists of
components, such as subsystem.generators, but not any other form of nested components.
con : None | sqlite3.Connection
Users should not pass this. De-serialization (from_json) will pass a Connection.
time_series_manager : None | TimeSeriesManager
Users should not pass this. De-serialization (from_json) will pass a constructed
manager.
Expand All @@ -79,8 +86,11 @@ def __init__(
self._name = name
self._description = description
self._component_mgr = ComponentManager(self._uuid, auto_add_composed_components)
self._con = con or create_in_memory_db()
time_series_kwargs = {k: v for k, v in kwargs.items() if k in TIME_SERIES_KWARGS}
self._time_series_mgr = time_series_manager or TimeSeriesManager(**time_series_kwargs)
self._time_series_mgr = time_series_manager or TimeSeriesManager(
self._con, **time_series_kwargs
)
self._data_format_version: Optional[str] = None
# Note to devs: if you add new fields, add support in to_json/from_json as appropriate.

Expand Down Expand Up @@ -127,10 +137,9 @@ def to_json(self, filename: Path | str, overwrite=False, indent=None, data=None)
msg = f"{filename=} already exists. Choose a different path or set overwrite=True."
raise ISFileExists(msg)

if not filename.parent.exists():
filename.parent.mkdir()

filename.parent.mkdir(exist_ok=True)
time_series_dir = filename.parent / (filename.stem + "_time_series")
time_series_dir.mkdir(exist_ok=True)
system_data = {
"name": self.name,
"description": self.description,
Expand Down Expand Up @@ -161,7 +170,8 @@ def to_json(self, filename: Path | str, overwrite=False, indent=None, data=None)
json.dump(data, f_out, indent=indent)
logger.info("Wrote system data to {}", filename)

self._time_series_mgr.serialize(self._make_time_series_directory(filename))
backup(self._con, time_series_dir / self.DB_FILENAME)
self._time_series_mgr.serialize(time_series_dir)

@classmethod
def from_json(
Expand Down Expand Up @@ -257,12 +267,20 @@ def from_dict(
"""
system_data = data if "system" not in data else data["system"]
ts_kwargs = {k: v for k, v in kwargs.items() if k in TIME_SERIES_KWARGS}
ts_path = (
time_series_parent_dir
if isinstance(time_series_parent_dir, Path)
else Path(time_series_parent_dir)
)
con = create_in_memory_db()
restore(con, ts_path / data["time_series"]["directory"] / System.DB_FILENAME)
time_series_manager = TimeSeriesManager.deserialize(
data["time_series"], time_series_parent_dir, **ts_kwargs
con, data["time_series"], ts_path, **ts_kwargs
)
system = cls(
name=system_data.get("name"),
description=system_data.get("description"),
con=con,
time_series_manager=time_series_manager,
uuid=UUID(system_data["uuid"]),
**kwargs,
Expand Down
17 changes: 11 additions & 6 deletions src/infrasys/time_series_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Manages time series arrays"""

import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, Type
Expand Down Expand Up @@ -32,15 +33,21 @@ def _process_time_series_kwarg(key: str, **kwargs: Any) -> Any:
class TimeSeriesManager:
"""Manages time series for a system."""

def __init__(self, storage: Optional[TimeSeriesStorageBase] = None, **kwargs) -> None:
def __init__(
self,
con: sqlite3.Connection,
storage: Optional[TimeSeriesStorageBase] = None,
initialize: bool = True,
**kwargs,
) -> None:
base_directory: Path | None = _process_time_series_kwarg("time_series_directory", **kwargs)
self._read_only = _process_time_series_kwarg("time_series_read_only", **kwargs)
self._storage = storage or (
InMemoryTimeSeriesStorage()
if _process_time_series_kwarg("time_series_in_memory", **kwargs)
else ArrowTimeSeriesStorage.create_with_temp_directory(base_directory=base_directory)
)
self._metadata_store = TimeSeriesMetadataStore()
self._metadata_store = TimeSeriesMetadataStore(con, initialize=initialize)

# TODO: create parsing mechanism? CSV, CSV + JSON

Expand Down Expand Up @@ -245,11 +252,11 @@ def _get_by_metadata(
def serialize(self, dst: Path | str, src: Optional[Path | str] = None) -> None:
"""Serialize the time series data to dst."""
self._storage.serialize(dst, src)
self._metadata_store.backup(dst)

@classmethod
def deserialize(
cls,
con: sqlite3.Connection,
data: dict[str, Any],
parent_dir: Path | str,
**kwargs: Any,
Expand All @@ -269,9 +276,7 @@ def deserialize(
storage = ArrowTimeSeriesStorage.create_with_temp_directory()
storage.serialize(src=time_series_dir, dst=storage.get_time_series_directory())

mgr = cls(storage=storage, **kwargs)
mgr.metadata_store.restore(time_series_dir)
return mgr
return cls(con, storage=storage, initialize=False, **kwargs)

def _handle_read_only(self) -> None:
if self._read_only:
Expand Down
27 changes: 4 additions & 23 deletions src/infrasys/time_series_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Sequence
from uuid import UUID

Expand All @@ -28,11 +27,11 @@ class TimeSeriesMetadataStore:
"""Stores time series metadata in a SQLite database."""

TABLE_NAME = "time_series_metadata"
DB_FILENAME = "time_series_metadata.db"

def __init__(self):
self._con = sqlite3.connect(":memory:")
self._create_metadata_table()
def __init__(self, con: sqlite3.Connection, initialize: bool = True):
self._con = con
if initialize:
self._create_metadata_table()
self._supports_sqlite_json = _does_sqlite_support_json()
if not self._supports_sqlite_json:
# This is true on Ubuntu 22.04, which is used by GitHub runners as of March 2024.
Expand Down Expand Up @@ -126,24 +125,6 @@ def add(
]
self._insert_rows(rows)

def backup(self, directory: Path | str) -> None:
"""Backup the database to a file in directory."""
path = directory if isinstance(directory, Path) else Path(directory)
filename = path / self.DB_FILENAME
with sqlite3.connect(filename) as con:
self._con.backup(con)
con.close()
logger.info("Backed up the time series metadata to {}", filename)

def restore(self, directory: Path | str) -> None:
"""Restore the database from a file to memory."""
path = directory if isinstance(directory, Path) else Path(directory)
filename = path / self.DB_FILENAME
with sqlite3.connect(filename) as con:
con.backup(self._con)
con.close()
logger.info("Restored the time series metadata to memory")

def get_time_series_counts(self) -> "TimeSeriesCounts":
"""Return summary counts of components and time series."""
query = f"""
Expand Down
22 changes: 22 additions & 0 deletions src/infrasys/utils/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
"""Utility functions for SQLite"""

import sqlite3
from pathlib import Path
from typing import Any, Sequence

from loguru import logger


def backup(src_con: sqlite3.Connection, filename: Path | str) -> None:
"""Backup a database to a file."""
with sqlite3.connect(filename) as dst_con:
src_con.backup(dst_con)
dst_con.close()
logger.info("Backed up the database to {}.", filename)


def restore(dst_con: sqlite3.Connection, filename: Path | str) -> None:
"""Restore a database from a file."""
with sqlite3.connect(filename) as src_con:
src_con.backup(dst_con)
src_con.close()
logger.info("Restored the database from {}.", filename)


def create_in_memory_db(database: str = ":memory:") -> sqlite3.Connection:
"""Create an in-memory database."""
return sqlite3.connect(database)


def execute(cursor: sqlite3.Cursor, query: str, params: Sequence[str] = ()) -> Any:
"""Execute a SQL query."""
logger.trace("SQL query: {query} {params=}", query)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arrow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_copy_files(tmp_path):
system.to_json(filename)

logger.info("Starting deserialization")
system2 = SimpleSystem.from_json(filename, base_directory=tmp_path)
system2 = SimpleSystem.from_json(filename)
gen1b = system2.get_component(SimpleGenerator, gen1.name)
time_series = system2.time_series.get(gen1b)
time_series_fpath = (
Expand Down
6 changes: 2 additions & 4 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_serialization(tmp_path):
system.to_json(filename, overwrite=True)
system2 = SimpleSystem.from_json(filename)
for key, val in system.__dict__.items():
if key not in ("_component_mgr", "_time_series_mgr"):
if key not in ("_component_mgr", "_time_series_mgr", "_con"):
assert getattr(system2, key) == val

components2 = list(system2.iter_all_components())
Expand Down Expand Up @@ -195,16 +195,14 @@ def test_system_save(tmp_path, simple_system_with_time_series):
simple_system = simple_system_with_time_series
custom_folder = "my_system"
fpath = tmp_path / custom_folder
fname = "test_system"
fname = "test_system.json"
simple_system.save(fpath, filename=fname)
assert os.path.exists(fpath), f"Folder {fpath} was not created successfully"
assert os.path.exists(fpath / fname), f"Serialized system {fname} was not created successfully"

fname = "test_system"
with pytest.raises(FileExistsError):
simple_system.save(fpath, filename=fname)

fname = "test_system"
simple_system.save(fpath, filename=fname, overwrite=True)
assert os.path.exists(fpath), f"Folder {fpath} was not created successfully"
assert os.path.exists(fpath / fname), f"Serialized system {fname} was not created successfully"
Expand Down

0 comments on commit fa2bb13

Please sign in to comment.