Skip to content

Commit

Permalink
add triggers to set trials and experiments index
Browse files Browse the repository at this point in the history
  • Loading branch information
rasca committed Nov 7, 2024
1 parent 0482643 commit b7770a3
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 135 deletions.
9 changes: 7 additions & 2 deletions flou/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ script_location = migrations
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
#_utils defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
Expand Down Expand Up @@ -83,7 +83,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
keys = root,sqlalchemy,alembic,alembic_utils

[handlers]
keys = console
Expand All @@ -106,6 +106,11 @@ level = INFO
handlers =
qualname = alembic

[logger_alembic_utils]
level = INFO
handlers =
qualname = alembic_utils

[handler_console]
class = StreamHandler
args = (sys.stderr,)
Expand Down
66 changes: 64 additions & 2 deletions flou/flou/experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from sqlalchemy import ForeignKey, text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import String
from alembic_utils.pg_function import PGFunction
from alembic_utils.pg_trigger import PGTrigger

from flou.database.models import Base
from flou.database.utils import JSONType
Expand All @@ -15,7 +17,7 @@ class Experiment(Base):
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, server_default=text("gen_random_uuid()")
)
index: Mapped[str] = mapped_column(default=0, nullable=False)
index: Mapped[int] = mapped_column(default=0, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(), nullable=False)
inputs: Mapped[dict] = mapped_column(JSONType(), default=dict, nullable=False)
Expand All @@ -24,13 +26,43 @@ class Experiment(Base):
trials: Mapped[List["Trial"]] = relationship(back_populates="experiment")


# Define the trigger function using alembic_utils
experiments_set_index = PGFunction(
schema="public",
signature="experiments_set_index()",
definition="""
RETURNS trigger AS $$
BEGIN
NEW.index := COALESCE(
(SELECT MAX(index) FROM experiments_experiments), -1
) + 1;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""",
)


# Define the trigger using alembic_utils
experiments_set_index_trigger = PGTrigger(
schema="public",
signature="experiments_set_index_trigger",
on_entity="public.experiments_experiments",
is_constraint=False,
definition="""
BEFORE INSERT ON public.experiments_experiments
FOR EACH ROW EXECUTE FUNCTION public.experiments_set_index();
""",
)


class Trial(Base):
__tablename__ = "experiments_trials"

id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, server_default=text("gen_random_uuid()")
)
index: Mapped[str] = mapped_column(default=0, nullable=False)
index: Mapped[int] = mapped_column(default=0, nullable=False)
experiment_id: Mapped[int] = mapped_column(ForeignKey("experiments_experiments.id"))
name: Mapped[str] = mapped_column(String(255), nullable=False)
ltm_id: Mapped[int] = mapped_column(ForeignKey("ltm_ltms.id"), nullable=False)
Expand All @@ -40,3 +72,33 @@ class Trial(Base):
outputs: Mapped[dict] = mapped_column(JSONType(), default=dict, nullable=False)

experiment: Mapped[Experiment] = relationship("Experiment", back_populates="trials")


# Define the trigger function using alembic_utils
trials_set_index = PGFunction(
schema="public",
signature="trials_set_index()",
definition="""
RETURNS trigger AS $$
BEGIN
NEW.index := COALESCE(
(SELECT MAX(index) FROM experiments_trials WHERE experiment_id = NEW.experiment_id), -1
) + 1;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""",
)


# Define the trigger using alembic_utils
trials_set_index_trigger = PGTrigger(
schema="public",
signature="trials_set_index_trigger",
on_entity="public.experiments_trials",
is_constraint=False,
definition="""
BEFORE INSERT ON public.experiments_trials
FOR EACH ROW EXECUTE FUNCTION public.trials_set_index();
""",
)
21 changes: 19 additions & 2 deletions flou/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from sqlalchemy.exc import InvalidRequestError

from alembic import context
from alembic_utils.replaceable_entity import register_entities
from alembic_utils.pg_trigger import PGTrigger
from alembic_utils.pg_function import PGFunction

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand All @@ -21,20 +24,34 @@
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
import flou
from flou.database.models import Base
target_metadata = Base.metadata

# Function to dynamically import all `models.py` and `models/` from apps
def import_all_models():
root_package = 'flou'
alembic_utils_entities = []


package = importlib.import_module(root_package)
for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, package.__name__ + '.'):
if modname.endswith('.models'):
importlib.import_module(modname)
models_module = importlib.import_module(modname)
# look for PGTriggers which need to be manually added to
# alembic_utils' `register_entitis`
for _, variable in models_module.__dict__.items():

if isinstance(variable, (PGTrigger, PGFunction, )):
# Ensure variable is not a subclass
if variable.__class__ in (PGTrigger, PGFunction, ):
alembic_utils_entities.append(variable)

register_entities(alembic_utils_entities) # register all entities

# setup models & triggers
try:
import_all_models()
import_all_models() # add every model to the DeclarativeBase
except InvalidRequestError:
pass # don't break on tests

Expand Down
61 changes: 0 additions & 61 deletions flou/migrations/versions/2024_11_04_2001-d17bb320f4d3_.py

This file was deleted.

40 changes: 0 additions & 40 deletions flou/migrations/versions/2024_11_06_1430-69c9354bb7ff_.py

This file was deleted.

28 changes: 0 additions & 28 deletions flou/migrations/versions/2024_11_06_1432-076b9aea5f59_.py

This file was deleted.

1 change: 1 addition & 0 deletions flou/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"watchdog",
"sqlalchemy[asyncio]",
"alembic",
"alembic_utils",
"psycopg",
]

Expand Down

0 comments on commit b7770a3

Please sign in to comment.