From 6fcdd7d599d31d6401389c10b6117ddc51b6e450 Mon Sep 17 00:00:00 2001
From: Jean Carlo Machado
Date: Mon, 3 Jun 2024 11:17:25 +0200
Subject: [PATCH] black

---
 ddataflow/data_source.py            |  4 +-
 ddataflow/ddataflow.py              | 66 ++++++++++++++++-------------
 ddataflow/setup/ddataflow_config.py | 16 +++----
 examples/ddataflow_config.py        | 10 ++---
 tests/test_configuration.py         |  1 +
 tests/test_sql.py                   | 10 +----
 6 files changed, 53 insertions(+), 54 deletions(-)

diff --git a/ddataflow/data_source.py b/ddataflow/data_source.py
index 30d932f..1535362 100644
--- a/ddataflow/data_source.py
+++ b/ddataflow/data_source.py
@@ -129,6 +129,4 @@ def _estimate_size(self, df: DataFrame) -> float:
         print(f"Amount of rows in dataframe to estimate size: {df.count()}")
 
         average_variable_size_bytes = 50
-        return (df.count() * len(df.columns) * average_variable_size_bytes) / (
-            1024**3
-        )
+        return (df.count() * len(df.columns) * average_variable_size_bytes) / (1024**3)
diff --git a/ddataflow/ddataflow.py b/ddataflow/ddataflow.py
index 3bbe31c..df7e88e 100644
--- a/ddataflow/ddataflow.py
+++ b/ddataflow/ddataflow.py
@@ -6,9 +6,13 @@
 from ddataflow.data_sources import DataSources
 from ddataflow.downloader import DataSourceDownloader
 from ddataflow.exceptions import WriterNotFoundException
-from ddataflow.sampling.default import build_default_sampling_for_sources, DefaultSamplerOptions
+from ddataflow.sampling.default import (
+    build_default_sampling_for_sources,
+    DefaultSamplerOptions,
+)
 from ddataflow.sampling.sampler import Sampler
 from ddataflow.utils import get_or_create_spark, using_databricks_connect
+from pyspark.sql import DataFrame
 
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler()
@@ -72,7 +76,6 @@ def __init__(
         self._snapshot_path = base_path + "/" + project_folder_name
         self._local_path = self._LOCAL_BASE_SNAPSHOT_PATH + "/" + project_folder_name
 
-
         if default_sampler:
             # set this before creating data sources
             DefaultSamplerOptions.set(default_sampler)
@@ -92,7 +95,6 @@ def __init__(
             size_limit=self._size_limit,
         )
 
-
         self._data_writers: dict = data_writers if data_writers else {}
 
         self._offline_enabled = os.getenv(self._ENABLE_OFFLINE_MODE_ENVVARIABLE, False)
@@ -116,8 +118,10 @@ def __init__(
         if self._ddataflow_enabled:
             self.set_logger_level(logging.DEBUG)
         else:
-            logger.info("DDataflow is now DISABLED."
-                        "PRODUCTION data will be used and it will write to production tables.")
+            logger.info(
+                "DDataflow is now DISABLED."
+                "PRODUCTION data will be used and it will write to production tables."
+            )
 
     @staticmethod
     def setup_project():
@@ -125,6 +129,7 @@ def setup_project():
         Sets up a new ddataflow project with empty data sources in the current directory
         """
         from ddataflow.setup.setup_project import setup_project
+
         setup_project()
 
     @staticmethod
@@ -167,18 +172,38 @@ def current_project() -> "DDataflow":
 
         return ddataflow_config.ddataflow
 
+    def source(self, name: str, debugger=False) -> DataFrame:
+        """
+        Gives access to the data source configured in the dataflow
+
+        You can also use this function in the terminal with --debugger=True to inspect the dataframe.
+ """ + self.print_status() + + logger.debug("Loading data source") + data_source: DataSource = self._data_sources.get_data_source(name) + logger.debug("Data source loaded") + df = self._get_data_from_data_source(data_source) + + if debugger: + logger.debug(f"Debugger enabled: {debugger}") + logger.debug("In debug mode now, use query to inspect it") + breakpoint() + + return df + def _get_spark(self): return get_or_create_spark() def enable(self): """ - When enabled ddataflow will read from the filtered datasoruces + When enabled ddataflow will read from the filtered data sources instead of production tables. And write to testing tables instead of production ones. """ self._ddataflow_enabled = True - def is_enabled(self): + def is_enabled(self) -> bool: return self._ddataflow_enabled def enable_offline(self): @@ -186,34 +211,14 @@ def enable_offline(self): self._offline_enabled = True self.enable() - def is_local(self): + def is_local(self) -> bool: return self._offline_enabled def disable_offline(self): """Programatically enable offline mode""" self._offline_enabled = False - def source(self, name: str, debugger=False): - """ - Gives access to the data source configured in the dataflow - - You can also use this function in the terminal with --debugger=True to inspect the dataframe. - """ - self.print_status() - - logger.debug("Loading data source") - data_source: DataSource = self._data_sources.get_data_source(name) - logger.debug("Data source loaded") - df = self._get_df_from_source(data_source) - - if debugger: - logger.debug(f"Debugger enabled: {debugger}") - logger.debug("In debug mode now, use query to inspect it") - breakpoint() - - return df - - def source_name(self, name, disable_view_creation=False): + def source_name(self, name, disable_view_creation=False) -> str: """ Given the name of a production table, returns the name of the corresponding ddataflow table when ddataflow is enabled If ddataflow is disabled get the production one. 
@@ -221,6 +226,7 @@ def source_name(self, name, disable_view_creation=False):
         logger.debug("Source name used: ", name)
         source_name = name
 
+        # the gist of ddataflow
         if self._ddataflow_enabled:
             source_name = self._get_new_table_name(name)
             if disable_view_creation:
@@ -286,7 +292,7 @@ def disable(self):
         """Disable ddtaflow overriding tables, uses production state in other words"""
         self._ddataflow_enabled = False
 
-    def _get_df_from_source(self, data_source):
+    def _get_data_from_data_source(self, data_source: DataSource) -> DataFrame:
         if not self._ddataflow_enabled:
             logger.debug("DDataflow not enabled")
             # goes directly to production without prefilters
diff --git a/ddataflow/setup/ddataflow_config.py b/ddataflow/setup/ddataflow_config.py
index 4c94893..9577909 100644
--- a/ddataflow/setup/ddataflow_config.py
+++ b/ddataflow/setup/ddataflow_config.py
@@ -5,25 +5,25 @@
     "project_folder_name": "your_project_name",
     # add here your tables or paths with customized sampling logic
     "data_sources": {
-        #"demo_tours": {
+        # "demo_tours": {
         #     # use default_sampling=True to rely on automatic sampling otherwise it will use the whole data
         #     "default_sampling": True,
-        #},
-        #"demo_tours2": {
+        # },
+        # "demo_tours2": {
         #     "source": lambda spark: spark.table("demo_tours"),
         #     # to customize the sampling logic
         #     "filter": lambda df: df.limit(500),
-        #},
-        #"/mnt/cleaned/EventName": {
+        # },
+        # "/mnt/cleaned/EventName": {
         #     "file-type": "parquet",
         #     "default_sampling": True,
-        #},
+        # },
     },
-    #"default_sampler": {
+    # "default_sampler": {
     #     # defines the amount of rows retrieved with the default sampler, used as .limit(limit) in the dataframe
     #     # default = 10000
     #     "limit": 100000,
-    #},
+    # },
     # to customize the max size of your examples uncomment the line below
     # "data_source_size_limit_gb": 3
     # to customize the location of your test datasets in your data wharehouse
diff --git a/examples/ddataflow_config.py b/examples/ddataflow_config.py
index 6e5ced8..f72162e 100644
--- a/examples/ddataflow_config.py
+++ b/examples/ddataflow_config.py
@@ -4,16 +4,16 @@
     # add here your tables or paths with customized sampling logic
     "data_sources": {
         "demo_tours": {
-            "source": lambda spark: spark.table('demo_tours'),
-            "filter": lambda df: df.limit(500)
+            "source": lambda spark: spark.table("demo_tours"),
+            "filter": lambda df: df.limit(500),
         },
         "demo_locations": {
-            "source": lambda spark: spark.table('demo_locations'),
+            "source": lambda spark: spark.table("demo_locations"),
             "default_sampling": True,
-        }
+        },
     },
     "project_folder_name": "ddataflow_demo",
 }
 
 # initialize the application and validate the configuration
-ddataflow = DDataflow(**config)
\ No newline at end of file
+ddataflow = DDataflow(**config)
diff --git a/tests/test_configuration.py b/tests/test_configuration.py
index 52a7609..ea28f0d 100644
--- a/tests/test_configuration.py
+++ b/tests/test_configuration.py
@@ -41,6 +41,7 @@
     "data_source_size_limit_gb": 3,
 }
 
+
 def test_initialize_successfully():
     """
     Tests that a correct config will not fail to be instantiated
diff --git a/tests/test_sql.py b/tests/test_sql.py
index faf3008..eaa8245 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -8,14 +8,8 @@ def test_sql():
 
     config = {
         "project_folder_name": "unit_tests",
-        "data_sources": {
-            "location": {
-                'default_sampling': True
-            }
-        },
-        'default_sampler': {
-            'limit': 2
-        }
+        "data_sources": {"location": {"default_sampling": True}},
+        "default_sampler": {"limit": 2},
    }
 
     ddataflow = DDataflow(**config)
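
Usage note: the example configuration reformatted above (examples/ddataflow_config.py) is the typical entry point for this library. Below is a minimal sketch of how it is consumed, assuming a Spark environment where a demo_tours table exists; the table name, the ddataflow_demo folder, and the enable()/source() calls come from the hunks in this patch, while the import path is an assumption not shown in the diff:

    # Sketch only: wires together the config shape and the DDataflow calls
    # that appear in this patch (enable(), source()).
    from ddataflow import DDataflow  # assumed import path

    config = {
        "data_sources": {
            "demo_tours": {
                # read the production table, but keep only 500 rows for testing
                "source": lambda spark: spark.table("demo_tours"),
                "filter": lambda df: df.limit(500),
            },
        },
        "project_folder_name": "ddataflow_demo",
    }

    ddataflow = DDataflow(**config)
    ddataflow.enable()  # read from the sampled sources instead of production tables
    df = ddataflow.source("demo_tours")  # returns a pyspark DataFrame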