From b85ea2f874bbeb6110fade7ad8f13456bc4caf70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gordon=20J=2E=20K=C3=B6hn?= Date: Mon, 14 Oct 2024 13:00:44 +0200 Subject: [PATCH] Fortified Typing of the Deconvolution Fn and Wrapper. --- lollipop/cli/deconvolute.py | 173 +++++++++++++++++++++++------------- 1 file changed, 109 insertions(+), 64 deletions(-) diff --git a/lollipop/cli/deconvolute.py b/lollipop/cli/deconvolute.py index b827b88..8eacf86 100755 --- a/lollipop/cli/deconvolute.py +++ b/lollipop/cli/deconvolute.py @@ -3,7 +3,7 @@ """Implements the deconvolute command for the lollipop CLI.""" import sys -from typing import List, Tuple, Union +from typing import List, Tuple, Union, TypedDict, Callable import logging import time import multiprocessing @@ -22,6 +22,7 @@ level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s" ) + kernels = { "gaussian": ll.GaussianKernel, "box": ll.BoxKernel, @@ -59,47 +60,10 @@ def _get_location_data( return loc_df -def _deconvolute_bootstrap_wrapper(args): - """ - Wrapper for the deconvolute bootstrap function to allow for parallel processing, - handling the random number generator seeding. 
- """ - - # Unpack the arguments - *other_args, child_seed = args - - # Initialize the default random number generator with the child seed - np.random.default_rng(child_seed) - - return _deconvolute_bootstrap(*other_args) - - -def _deconvolute_bootstrap( - n_cores: int, - location: str, - loc_df: pd.DataFrame, - bootstrap: int, - locations_list: List, - no_loc: bool, - no_date: bool, - date_intervals: List[Tuple], - var_dates: dict, - kernel: Union[ll.GaussianKernel, ll.BoxKernel], - kernel_params: dict, - regressor: Union[ll.NnlsReg, ll.RobustReg], - regressor_params: dict, - confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint], - confint_params: dict, - deconv_params: dict, - have_confint: bool, - confint_name: str, - namefield: str, -) -> List[pd.DataFrame]: - """ - Deconvolute the data for a given location and bootstrap iteration. +class DeconvBootstrapsArgsNoSeed(TypedDict): + """Arguments for the deconvolute bootstrap function. + _deconvolute_bootstrap_ - Parameters - ---------- n_cores : int The number of cores to use for parallel processing, (Only used for all locastion progress bar parr/sequc) @@ -140,6 +104,66 @@ def _deconvolute_bootstrap( The name of the confidence interval. namefield: str The column to use as 'names' for the entries in the tally table. + """ + + n_cores: int + location: str + loc_df: pd.DataFrame + bootstrap: int + locations_list: List + no_loc: bool + no_date: bool + date_intervals: List[Tuple] + var_dates: dict + kernel: Union[ll.GaussianKernel, ll.BoxKernel] + kernel_params: dict + regressor: Union[ll.NnlsReg, ll.RobustReg] + regressor_params: dict + confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint] + confint_params: dict + deconv_params: dict + have_confint: bool + confint_name: str + namefield: str + + +class DeconvBootstrapsArgs(DeconvBootstrapsArgsNoSeed): + """ + Arguments for the deconvolute bootstrap function. 
+ _deconvolute_bootstrap_wrapper_ + + child_seed: np.random.SeedSequence + seed for the given location + """ + + child_seed: np.random.SeedSequence + + +def _deconvolute_bootstrap_wrapper( + args: DeconvBootstrapsArgs, +) -> List[pd.DataFrame]: + """ + Wrapper for the deconvolute bootstrap function to allow for parallel processing, + handling the random number generator seeding. + """ + + # Get the seed + child_seed = args.pop("child_seed") + + # Initialize the default random number generator with the child seed + np.random.default_rng(child_seed) + + return _deconvolute_bootstrap(args) + + +def _deconvolute_bootstrap(args: DeconvBootstrapsArgsNoSeed) -> List[pd.DataFrame]: + """ + Deconvolute the data for a given location and bootstrap iteration. + + Parameters + ---------- + args : DeconvBootstrapsArgsNoSeed + The arguments for the deconvolute bootstrap function. Returns ------- @@ -147,6 +171,27 @@ def _deconvolute_bootstrap( The deconvolution results for the location and bootstrap iterations. 
""" + # Unpack the arguments + n_cores = args["n_cores"] + location = args["location"] + loc_df = args["loc_df"] + bootstrap = args["bootstrap"] + locations_list = args["locations_list"] + no_loc = args["no_loc"] + no_date = args["no_date"] + date_intervals = args["date_intervals"] + var_dates = args["var_dates"] + kernel = args["kernel"] + kernel_params = args["kernel_params"] + regressor = args["regressor"] + regressor_params = args["regressor_params"] + confint = args["confint"] + confint_params = args["confint_params"] + deconv_params = args["deconv_params"] + have_confint = args["have_confint"] + confint_name = args["confint_name"] + namefield = args["namefield"] + # deconvolution results deconv = [] @@ -662,29 +707,29 @@ def deconvolute( # CORE DECONVOLUTION # iterate over locations # Create delayed objects for each location - args_list = [ - ( - n_cores, - location, - loc_df, - bootstrap, - locations_list, - no_loc, - no_date, - date_intervals, - var_dates, - kernel, - kernel_params, - regressor, - regressor_params, - confint, - confint_params, - deconv_params, - have_confint, - confint_name, - namefield, - child_seed, - ) + args_list: List[DeconvBootstrapsArgs] = [ + { + "n_cores": n_cores, + "location": location, + "loc_df": loc_df, + "bootstrap": bootstrap, + "locations_list": locations_list, + "no_loc": no_loc, + "no_date": no_date, + "date_intervals": date_intervals, + "var_dates": var_dates, + "kernel": kernel, + "kernel_params": kernel_params, + "regressor": regressor, + "regressor_params": regressor_params, + "confint": confint, + "confint_params": confint_params, + "deconv_params": deconv_params, + "have_confint": have_confint, + "confint_name": confint_name, + "namefield": namefield, + "child_seed": child_seed, + } for location, loc_df, child_seed in zip(locations_list, loc_dfs, seeds) ]