diff --git a/rasa/core/brokers/sql.py b/rasa/core/brokers/sql.py index 0517682c8bf7..0d85ae13f539 100644 --- a/rasa/core/brokers/sql.py +++ b/rasa/core/brokers/sql.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional, Text, Generator from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta from sqlalchemy import Column, Integer, String from sqlalchemy import Text as SqlAlchemyText # to avoid name clash with typing.Text @@ -22,7 +21,10 @@ class SQLEventBroker(EventBroker): """ - Base: DeclarativeMeta = declarative_base() + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass class SQLBrokerEvent(Base): """ORM which represents a row in the `events` table.""" diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 7f21dd372140..2f3dea3e2e8a 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -51,14 +51,13 @@ from rasa.shared.nlu.constants import INTENT_NAME_KEY from rasa.utils.endpoints import EndpointConfig import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta if TYPE_CHECKING: import boto3.resources.factory.dynamodb.Table from sqlalchemy.engine.url import URL from sqlalchemy.engine.base import Engine - from sqlalchemy.orm import Session, Query - from sqlalchemy import Sequence + from sqlalchemy.orm import Session, Query, DeclarativeBase + from sqlalchemy import Sequence, Executable logger = logging.getLogger(__name__) @@ -1013,6 +1012,9 @@ def ensure_schema_exists(session: "Session") -> None: engine = session.get_bind() + if not isinstance(engine, Engine): + return + if is_postgresql_url(engine.url): query = sa.exists( sa.select(sa.text("schema_name")) @@ -1041,7 +1043,10 @@ def validate_port(port: Any) -> Optional[int]: class SQLTrackerStore(TrackerStore, SerializedTrackerAsText): """Store which can save and retrieve trackers from an SQL database.""" - Base: DeclarativeMeta = declarative_base() + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass class SQLEvent(Base): """Represents an event in the SQL Tracker Store.""" @@ -1202,30 +1207,26 @@ def _create_database(engine: "Engine", database_name: Text) -> None: """Create database `db` on `engine` if it does not exist.""" import sqlalchemy.exc - conn = engine.connect() - - matching_rows = ( - conn.execution_options(isolation_level="AUTOCOMMIT") - .execute( - sa.text( - "SELECT 1 FROM pg_catalog.pg_database " - "WHERE datname = :database_name" - ), - database_name=database_name, + with engine.connect() as connection: + matching_rows = ( + connection.execution_options(isolation_level="AUTOCOMMIT") + .execute( + sa.text( + f"SELECT 1 FROM pg_catalog.pg_database " + f"WHERE datname = {database_name}" + ) + ) + .rowcount ) - .rowcount - ) - if not matching_rows: - try: - conn.execute(f"CREATE DATABASE {database_name}") - except ( - sqlalchemy.exc.ProgrammingError, - sqlalchemy.exc.IntegrityError, - ) as e: - logger.error(f"Could not create database '{database_name}': {e}") - - conn.close() + if not matching_rows: + try: + connection.execute(sa.text(f"CREATE DATABASE {database_name}")) + except ( + sqlalchemy.exc.ProgrammingError, + sqlalchemy.exc.IntegrityError, + ) as e: + logger.error(f"Could not create database '{database_name}': {e}") @contextlib.contextmanager def session_scope(self) -> Generator["Session", None, None]: diff --git a/rasa/engine/caching.py b/rasa/engine/caching.py index d9e73122e43f..0ecdf39bc5b8 100644 --- a/rasa/engine/caching.py +++ b/rasa/engine/caching.py @@ -11,6 +11,7 @@ from sqlalchemy.engine import URL from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import DeclarativeBase from typing_extensions import Protocol, runtime_checkable import rasa @@ -20,7 +21,6 @@ from rasa.constants import MINIMUM_COMPATIBLE_VERSION import sqlalchemy as sa import sqlalchemy.orm -from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta from rasa.engine.storage.storage import ModelStorage from rasa.shared.engine.caching import ( @@ -145,7 +145,10 @@ def from_cache( class LocalTrainingCache(TrainingCache): """Caches training results on local disk (see parent class for full docstring).""" - Base: DeclarativeMeta = declarative_base() + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass class CacheEntry(Base): """Stores metadata about a single cache entry."""