Skip to content

Commit

Permalink
Change test to test only db object type
Browse files Browse the repository at this point in the history
  • Loading branch information
augusto-herrmann committed May 8, 2024
1 parent f8c21d1 commit a1457ad
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions tests/test_db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pytest
import psycopg2
import pyodbc
from sqlalchemy.engine import Engine

from airflow.hooks.base import BaseHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
from fastetl.custom_functions.utils.db_connection import (
Expand Down Expand Up @@ -101,31 +103,30 @@ def test_get_hook_and_engine_by_provider(conn_id: str):
("postgres-source-conn", "engine"),
],
)
def test_db_connection(conn_id: str, use: Literal["hook", "connection", "engine"]):
"""Test that the DbConnection is as expected for the given connections.
def test_db_connection_object_type(
conn_id: str,
use: Literal["hook", "connection", "engine"],
):
"""Test that the DbConnection returns the appropriate type of object.
Args:
conn_id (str): The connection id.
use (Literal["hook", "connection", "engine"]): What object to test.
Either one of the following: "hook", "connection", or "engine".
"""
CONN_CLASS = {
"postgres": psycopg2.extensions.connection,
"mssql": pyodbc.Connection,
}
with DbConnection(conn_id=conn_id, use=use) as db_object:
conn_type = get_conn_type(conn_id)
if conn_type == "postgres":
assert (
isinstance(db_object, (PostgresHook, psycopg2.extensions.connection))
or str(db_object) == "Engine(postgresql://root:***@postgres-source/db)"
)

elif conn_type == "mssql":
assert isinstance(db_object, (MsSqlHook, pyodbc.Connection)) or str(
db_object
) == (
"Engine(mssql+pyodbc://?odbc_connect="
"Driver%3D%7BODBC+Driver+17+for+SQL+Server%7D%3B"
"Server%3Dmssql-source%2C+1433%3B+++++++++++++++++++++"
"Database%3Dmaster%3BUid%3Dsa%3BPwd%3DozoBaroF2021%3B)"
if use == "hook":
assert isinstance(db_object, BaseHook)
elif use == "connection":
conn_type = get_conn_type(conn_id)
assert isinstance(
db_object, CONN_CLASS[conn_type]
)
else:
assert isinstance(db_object, Engine)


@pytest.mark.parametrize(
Expand All @@ -135,9 +136,7 @@ def test_db_connection(conn_id: str, use: Literal["hook", "connection", "engine"
"postgres-source-fake-conn",
],
)
def test_db_fail_connection_wrong_credentials(
conn_id: str
):
def test_db_fail_connection_wrong_credentials(conn_id: str):
"""Test that the DbConnection will fail for a connection that is using
wrong credentials.
Expand Down

0 comments on commit a1457ad

Please sign in to comment.