From 5437e0e89f673a729db5ecdd2c2cd8a4eba4dcfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gordon=20J=2E=20K=C3=B6hn?= Date: Mon, 14 Oct 2024 13:58:51 +0200 Subject: [PATCH] extended typed dict --- lollipop/cli/deconvolute.py | 168 ++++++++++++++++++++++-------------- 1 file changed, 105 insertions(+), 63 deletions(-) diff --git a/lollipop/cli/deconvolute.py b/lollipop/cli/deconvolute.py index f6820c1..b0cd960 100755 --- a/lollipop/cli/deconvolute.py +++ b/lollipop/cli/deconvolute.py @@ -3,7 +3,7 @@ """Implements the deconvolute command for the lollipop CLI.""" import sys -from typing import List, Tuple, Union, TypedDict +from typing import List, Tuple, Union, TypedDict, Callable import logging import time import multiprocessing @@ -22,6 +22,7 @@ level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s" ) + kernels = { "gaussian": ll.GaussianKernel, "box": ll.BoxKernel, @@ -59,47 +60,10 @@ def _get_location_data( return loc_df -def _deconvolute_bootstrap_wrapper(args): - """ - Wrapper for the deconvolute bootstrap function to allow for parallel processing, - handling the random number generator seeding. - """ - - # Get the seed - child_seed = args.pop("child_seed") - - # Initialize the default random number generator with the child seed - np.random.default_rng(child_seed) - - return _deconvolute_bootstrap(**args) - - -def _deconvolute_bootstrap( - n_cores: int, - location: str, - loc_df: pd.DataFrame, - bootstrap: int, - locations_list: List, - no_loc: bool, - no_date: bool, - date_intervals: List[Tuple], - var_dates: dict, - kernel: Union[ll.GaussianKernel, ll.BoxKernel], - kernel_params: dict, - regressor: Union[ll.NnlsReg, ll.RobustReg], - regressor_params: dict, - confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint], - confint_params: dict, - deconv_params: dict, - have_confint: bool, - confint_name: str, - namefield: str, -) -> List[pd.DataFrame]: - """ - Deconvolute the data for a given location and bootstrap iteration. - - Parameters - ---------- +class DeconvBootstrapsArgsNoSeed(TypedDict): + """Arguments for the deconvolute bootstrap function. + _deconvolute_bootstrap_ + n_cores : int The number of cores to use for parallel processing, (Only used for all locastion progress bar parr/sequc) @@ -140,6 +104,63 @@ def _deconvolute_bootstrap( The name of the confidence interval. namefield: str The column to use as 'names' for the entries in the tally table. + """ + n_cores: int + location: str + loc_df: pd.DataFrame + bootstrap: int + locations_list: List + no_loc: bool + no_date: bool + date_intervals: List[Tuple] + var_dates: dict + kernel: Union[ll.GaussianKernel, ll.BoxKernel] + kernel_params: dict + regressor: Union[ll.NnlsReg, ll.RobustReg] + regressor_params: dict + confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint] + confint_params: dict + deconv_params: dict + have_confint: bool + confint_name: str + namefield: str + +class DeconvBootstrapsArgs(DeconvBootstrapsArgsNoSeed): + """ + Arguments for the deconvolute bootstrap function. + _deconvolute_bootstrap_wrapper_ + + child_seed: np.random.SeedSequence + seed for the given location + """ + child_seed: np.random.SeedSequence + + +def _deconvolute_bootstrap_wrapper(args : DeconvBootstrapsArgs) -> Callable[[DeconvBootstrapsArgsNoSeed], List[pd.DataFrame]]: + """ + Wrapper for the deconvolute bootstrap function to allow for parallel processing, + handling the random number generator seeding. + """ + + # Get the seed + child_seed = args.pop("child_seed") + + # Initialize the default random number generator with the child seed + np.random.default_rng(child_seed) + + return _deconvolute_bootstrap(args) + + +def _deconvolute_bootstrap( + args : DeconvBootstrapsArgsNoSeed +) -> List[pd.DataFrame]: + """ + Deconvolute the data for a given location and bootstrap iteration. + + Parameters + ---------- + args : DeconvBootstrapsArgs + The arguments for the deconvolute bootstrap function. Returns ------- @@ -147,6 +168,27 @@ def _deconvolute_bootstrap( The deconvolution results for the location and bootstrap iterations. """ + # Unpack the arguments + n_cores = args["n_cores"] + location = args["location"] + loc_df = args["loc_df"] + bootstrap = args["bootstrap"] + locations_list = args["locations_list"] + no_loc = args["no_loc"] + no_date = args["no_date"] + date_intervals = args["date_intervals"] + var_dates = args["var_dates"] + kernel = args["kernel"] + kernel_params = args["kernel_params"] + regressor = args["regressor"] + regressor_params = args["regressor_params"] + confint = args["confint"] + confint_params = args["confint_params"] + deconv_params = args["deconv_params"] + have_confint = args["have_confint"] + confint_name = args["confint_name"] + namefield = args["namefield"] + # deconvolution results deconv = [] @@ -662,28 +704,28 @@ def deconvolute( # CORE DECONVOLUTION # iterate over locations # Create delayed objects for each location - args_list = [ + args_list : List[DeconvBootstrapsArgs] = [ { - "n_cores" : n_cores, - "location" : location, - "loc_df" : loc_df, - "bootstrap" : bootstrap, - "locations_list" : locations_list, - "no_loc" : no_loc, - "no_date" : no_date, - "date_intervals" : date_intervals, - "var_dates" : var_dates, - "kernel" : kernel, - "kernel_params" : kernel_params, - "regressor" : regressor, - "regressor_params" : regressor_params, - "confint" : confint, - "confint_params" : confint_params, - "deconv_params" : deconv_params, - "have_confint" : have_confint, - "confint_name" : confint_name, - "namefield" : namefield, - "child_seed" : child_seed, + "n_cores": n_cores, + "location": location, + "loc_df": loc_df, + "bootstrap": bootstrap, + "locations_list": locations_list, + "no_loc": no_loc, + "no_date": no_date, + "date_intervals": date_intervals, + "var_dates": var_dates, + "kernel": kernel, + "kernel_params": kernel_params, + "regressor": regressor, + "regressor_params": regressor_params, + "confint": confint, + "confint_params": confint_params, + "deconv_params": deconv_params, + "have_confint": have_confint, + "confint_name": confint_name, + "namefield": namefield, + "child_seed": child_seed, } for location, loc_df, child_seed in zip(locations_list, loc_dfs, seeds) ]