Skip to content

Commit

Permalink
fix type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Oct 27, 2023
1 parent 436d107 commit 1b9473f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
6 changes: 4 additions & 2 deletions rasa/core/brokers/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, Dict, Optional, Text, Generator

from sqlalchemy.orm import Session
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy import Column, Integer, String
from sqlalchemy import Text as SqlAlchemyText # to avoid name clash with typing.Text

Expand All @@ -22,7 +21,10 @@ class SQLEventBroker(EventBroker):
"""

Base: DeclarativeMeta = declarative_base()
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
pass

class SQLBrokerEvent(Base):
"""ORM which represents a row in the `events` table."""
Expand Down
53 changes: 27 additions & 26 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,13 @@
from rasa.shared.nlu.constants import INTENT_NAME_KEY
from rasa.utils.endpoints import EndpointConfig
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta

if TYPE_CHECKING:
import boto3.resources.factory.dynamodb.Table
from sqlalchemy.engine.url import URL
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session, Query
from sqlalchemy import Sequence
from sqlalchemy.orm import Session, Query, DeclarativeBase

Check failure on line 59 in rasa/core/tracker_store.py

View workflow job for this annotation

GitHub Actions / Code Quality

F401 [*] `sqlalchemy.orm.DeclarativeBase` imported but unused
from sqlalchemy import Sequence, Executable

Check failure on line 60 in rasa/core/tracker_store.py

View workflow job for this annotation

GitHub Actions / Code Quality

F401 [*] `sqlalchemy.Executable` imported but unused

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1013,6 +1012,9 @@ def ensure_schema_exists(session: "Session") -> None:

engine = session.get_bind()

if not isinstance(engine, Engine):
return

if is_postgresql_url(engine.url):
query = sa.exists(
sa.select(sa.text("schema_name"))
Expand Down Expand Up @@ -1041,7 +1043,10 @@ def validate_port(port: Any) -> Optional[int]:
class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
"""Store which can save and retrieve trackers from an SQL database."""

Base: DeclarativeMeta = declarative_base()
from sqlalchemy.orm import DeclarativeBase

Check failure on line 1046 in rasa/core/tracker_store.py

View workflow job for this annotation

GitHub Actions / Code Quality

F811 Redefinition of unused `DeclarativeBase` from line 59

class Base(DeclarativeBase):
pass

class SQLEvent(Base):
"""Represents an event in the SQL Tracker Store."""
Expand Down Expand Up @@ -1202,30 +1207,26 @@ def _create_database(engine: "Engine", database_name: Text) -> None:
"""Create database `db` on `engine` if it does not exist."""
import sqlalchemy.exc

conn = engine.connect()

matching_rows = (
conn.execution_options(isolation_level="AUTOCOMMIT")
.execute(
sa.text(
"SELECT 1 FROM pg_catalog.pg_database "
"WHERE datname = :database_name"
),
database_name=database_name,
with engine.connect() as connection:
matching_rows = (
connection.execution_options(isolation_level="AUTOCOMMIT")
.execute(
sa.text(
f"SELECT 1 FROM pg_catalog.pg_database "
f"WHERE datname = {database_name}"
)
)
.rowcount
)
.rowcount
)

if not matching_rows:
try:
conn.execute(f"CREATE DATABASE {database_name}")
except (
sqlalchemy.exc.ProgrammingError,
sqlalchemy.exc.IntegrityError,
) as e:
logger.error(f"Could not create database '{database_name}': {e}")

conn.close()
if not matching_rows:
try:
connection.execute(sa.text(f"CREATE DATABASE {database_name}"))
except (
sqlalchemy.exc.ProgrammingError,
sqlalchemy.exc.IntegrityError,
) as e:
logger.error(f"Could not create database '{database_name}': {e}")

@contextlib.contextmanager
def session_scope(self) -> Generator["Session", None, None]:
Expand Down
7 changes: 5 additions & 2 deletions rasa/engine/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlalchemy.engine import URL

from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import DeclarativeBase

Check failure on line 14 in rasa/engine/caching.py

View workflow job for this annotation

GitHub Actions / Code Quality

F401 [*] `sqlalchemy.orm.DeclarativeBase` imported but unused
from typing_extensions import Protocol, runtime_checkable

import rasa
Expand All @@ -20,7 +21,6 @@
from rasa.constants import MINIMUM_COMPATIBLE_VERSION
import sqlalchemy as sa
import sqlalchemy.orm
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta

from rasa.engine.storage.storage import ModelStorage
from rasa.shared.engine.caching import (
Expand Down Expand Up @@ -145,7 +145,10 @@ def from_cache(
class LocalTrainingCache(TrainingCache):
"""Caches training results on local disk (see parent class for full docstring)."""

Base: DeclarativeMeta = declarative_base()
from sqlalchemy.orm import DeclarativeBase

Check failure on line 148 in rasa/engine/caching.py

View workflow job for this annotation

GitHub Actions / Code Quality

F811 Redefinition of unused `DeclarativeBase` from line 14

class Base(DeclarativeBase):
pass

class CacheEntry(Base):
"""Stores metadata about a single cache entry."""
Expand Down

0 comments on commit 1b9473f

Please sign in to comment.