From a1457ad7112775225c2daf00ee27975a4245207f Mon Sep 17 00:00:00 2001 From: Augusto Herrmann Date: Wed, 8 May 2024 16:37:14 -0300 Subject: [PATCH] Change test to test only db object type --- tests/test_db_connection.py | 41 ++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/tests/test_db_connection.py b/tests/test_db_connection.py index 4c5ca55..01738db 100644 --- a/tests/test_db_connection.py +++ b/tests/test_db_connection.py @@ -5,7 +5,9 @@ import pytest import psycopg2 import pyodbc +from sqlalchemy.engine import Engine +from airflow.hooks.base import BaseHook from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook from fastetl.custom_functions.utils.db_connection import ( @@ -101,31 +103,30 @@ def test_get_hook_and_engine_by_provider(conn_id: str): ("postgres-source-conn", "engine"), ], ) -def test_db_connection(conn_id: str, use: Literal["hook", "connection", "engine"]): - """Test that the DbConnection is as expected for the given connections. +def test_db_connection_object_type( + conn_id: str, + use: Literal["hook", "connection", "engine"], +): + """Test that the DbConnection returns the appropriate type of object. Args: conn_id (str): The connection id. use (Literal["hook", "connection", "engine"]): What object to test. - Either one of the following: "hook", "connection", or "engine". """ + CONN_CLASS = { + "postgres": psycopg2.extensions.connection, + "mssql": pyodbc.Connection, + } with DbConnection(conn_id=conn_id, use=use) as db_object: - conn_type = get_conn_type(conn_id) - if conn_type == "postgres": - assert ( - isinstance(db_object, (PostgresHook, psycopg2.extensions.connection)) - or str(db_object) == "Engine(postgresql://root:***@postgres-source/db)" - ) - - elif conn_type == "mssql": - assert isinstance(db_object, (MsSqlHook, pyodbc.Connection)) or str( - db_object - ) == ( - "Engine(mssql+pyodbc://?odbc_connect=" - "Driver%3D%7BODBC+Driver+17+for+SQL+Server%7D%3B" - "Server%3Dmssql-source%2C+1433%3B+++++++++++++++++++++" - "Database%3Dmaster%3BUid%3Dsa%3BPwd%3DozoBaroF2021%3B)" + if use == "hook": + assert isinstance(db_object, BaseHook) + elif use == "connection": + conn_type = get_conn_type(conn_id) + assert isinstance( + db_object, CONN_CLASS[conn_type] ) + else: + assert isinstance(db_object, Engine) @pytest.mark.parametrize( @@ -135,9 +136,7 @@ def test_db_connection(conn_id: str, use: Literal["hook", "connection", "engine" "postgres-source-fake-conn", ], ) -def test_db_fail_connection_wrong_credentials( - conn_id: str -): +def test_db_fail_connection_wrong_credentials(conn_id: str): """Test that the DbConnection will fail for a connection that is using wrong credentials.