From 4d2fd5c491109b2e8ad9246a6a7332c4cf5e7d6d Mon Sep 17 00:00:00 2001 From: ChronoBoot <146651003+ChronoBoot@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:47:00 +0100 Subject: [PATCH] Add option to not rewrite the csv files if already existing --- backend/src/data_processing/simple_load_data.py | 11 ++++++++--- backend/src/data_processing/simple_read_data.py | 3 ++- backend/src/main.py | 9 +++++++-- backend/src/models/random_forest_loan_predictor.py | 4 +++- .../tests/data_processing/test_simple_load_data.py | 7 +------ backend/utils/profiling_utils.py | 3 ++- frontend/src/main.py | 3 ++- frontend/src/ui/dash_user_interface.py | 3 +++ 8 files changed, 28 insertions(+), 15 deletions(-) diff --git a/backend/src/data_processing/simple_load_data.py b/backend/src/data_processing/simple_load_data.py index 843d2f3..54780d7 100644 --- a/backend/src/data_processing/simple_load_data.py +++ b/backend/src/data_processing/simple_load_data.py @@ -1,7 +1,6 @@ import os import logging -from dotenv import load_dotenv import requests from backend.src.data_processing.load_data_abc import LoadData import os @@ -28,7 +27,8 @@ class SimpleLoadData(LoadData): ] def __init__(self) -> None: - logging.basicConfig(level=logging.DEBUG) + log_format = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig(format=log_format, level=logging.DEBUG) logging.debug("SimpleLoadData initialized") @conditional_profile @@ -40,7 +40,7 @@ def download_file(self, url: str, filepath: str) -> None: f.write(chunk) @conditional_profile - def load(self, file_urls: list, download_path: str) -> None: + def load(self, file_urls: list, download_path: str, rewrite = False) -> None: """ Load data from Azure Blob Storage and save it to a local directory. 
@@ -58,6 +58,11 @@ def load(self, file_urls: list, download_path: str) -> None: for url in file_urls: filename = url.split('/')[-1] # Extracts the file name filepath = f"{download_path}/{filename}" + + if os.path.exists(filepath) and not rewrite: + logging.info(f"File {filename} already exists in {download_path}. Skipping download.") + continue + self.download_file(url, filepath) logging.info(f"Downloaded file {filename} from Azure Blob Storage to {filepath}") except Exception as e: diff --git a/backend/src/data_processing/simple_read_data.py b/backend/src/data_processing/simple_read_data.py index e6e3166..bf633cd 100644 --- a/backend/src/data_processing/simple_read_data.py +++ b/backend/src/data_processing/simple_read_data.py @@ -15,7 +15,8 @@ class SimpleReadData(ReadDataABC): CHUNK_SIZE = 10000 def __init__(self) -> None: - logging.basicConfig(level=logging.DEBUG) + log_format = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig(format=log_format, level=logging.DEBUG) logging.debug("SimpleReadData initialized") @conditional_profile diff --git a/backend/src/main.py b/backend/src/main.py index 41e2e6e..d5ead65 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -14,7 +14,9 @@ predictor = RandomForestLoanPredictor() loader = SimpleLoadData() reader = SimpleReadData() -logging.basicConfig(level=logging.DEBUG) + +log_format = "%(asctime)s - %(levelname)s - %(message)s" +logging.basicConfig(format=log_format, level=logging.DEBUG) app.logger = logging.getLogger(__name__) app.logger.addHandler(logging.StreamHandler()) app.logger.setLevel(logging.DEBUG) @@ -34,11 +36,14 @@ def test(): @app.route('/train', methods=['POST']) def train(): + app.logger.info('Training model...') data = request.get_json() sampling_frequency = int(data['sampling_frequency']) target_variable = data['target_variable'] + rewrite = data['rewrite'] if 'rewrite' in data else "False" + rewrite_bool = True if rewrite == "True" else False - loader.load(SimpleLoadData.CSV_URLS, 
FILES_FOLDER) + loader.load(SimpleLoadData.CSV_URLS, FILES_FOLDER, rewrite_bool) reader.write_data(FILES_FOLDER, DATA_FILE_MODEL, sampling_frequency) loans = reader.read_data(FILES_FOLDER, DATA_FILE_MODEL) diff --git a/backend/src/models/random_forest_loan_predictor.py b/backend/src/models/random_forest_loan_predictor.py index b616fd5..0cab0b1 100644 --- a/backend/src/models/random_forest_loan_predictor.py +++ b/backend/src/models/random_forest_loan_predictor.py @@ -26,7 +26,9 @@ def __init__(self) -> None: self.y_test = None self.random_state = 42 self.test_size = 0.2 - logging.basicConfig(level=logging.DEBUG) + + log_format = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig(format=log_format, level=logging.DEBUG) logging.debug("RandomForestLoanPredictor initialized") @conditional_profile diff --git a/backend/tests/data_processing/test_simple_load_data.py b/backend/tests/data_processing/test_simple_load_data.py index d5e189d..474ef2e 100644 --- a/backend/tests/data_processing/test_simple_load_data.py +++ b/backend/tests/data_processing/test_simple_load_data.py @@ -45,12 +45,7 @@ def test_load(self, mock_download_file): # Assert mock_download_file.assert_called_once_with('https://www.something.com/test.txt', 'test_download_path/test.txt') - @patch.dict('os.environ', { - 'AZURE_STORAGE_CONNECTION_STRING': 'DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey;BlobEndpoint=testendpoint', - 'AZURE_STORAGE_CONTAINER_NAME': 'test_container_name' - }) - @patch('backend.src.data_processing.simple_load_data.load_dotenv') - def test_save(self, mock_load_dotenv): + def test_save(self): # Arrange simple_load_data = SimpleLoadData() diff --git a/backend/utils/profiling_utils.py b/backend/utils/profiling_utils.py index 5f6d7c7..dc6022e 100644 --- a/backend/utils/profiling_utils.py +++ b/backend/utils/profiling_utils.py @@ -6,7 +6,8 @@ def is_profiling_enabled(): - return bool(os.getenv('ENABLE_PROFILING', False)) + enable_profiling = 
os.getenv('ENABLE_PROFILING', "False") + return True if enable_profiling == "True" else False def conditional_profile(func): if is_profiling_enabled(): diff --git a/frontend/src/main.py b/frontend/src/main.py index 5d33023..766dbce 100644 --- a/frontend/src/main.py +++ b/frontend/src/main.py @@ -172,7 +172,8 @@ def _main(FREQUENCY : int): load_dotenv() # Configure logging - logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') + log_format = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig(format=log_format, level=logging.DEBUG) # Parse command line arguments parser = argparse.ArgumentParser(description="Loan prediction application") diff --git a/frontend/src/ui/dash_user_interface.py b/frontend/src/ui/dash_user_interface.py index 4b503b8..afc7dd5 100644 --- a/frontend/src/ui/dash_user_interface.py +++ b/frontend/src/ui/dash_user_interface.py @@ -54,6 +54,9 @@ def __init__(self, categorical_values : dict, float_values: dict, loan_example: self.app.layout = self._create_layout() + log_format = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig(format=log_format, level=logging.DEBUG) + def get_nb_steps(self, min, max) -> int: """ Gets the number of steps for the user interface slider.