Skip to content

Commit

Permalink
extended typed dict
Browse files Browse the repository at this point in the history
  • Loading branch information
gordonkoehn committed Oct 14, 2024
1 parent 5a41b23 commit 5437e0e
Showing 1 changed file with 105 additions and 63 deletions.
168 changes: 105 additions & 63 deletions lollipop/cli/deconvolute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Implements the deconvolute command for the lollipop CLI."""

import sys
from typing import List, Tuple, Union, TypedDict
from typing import List, Tuple, Union, TypedDict, Callable
import logging
import time
import multiprocessing
Expand All @@ -22,6 +22,7 @@
level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s"
)


kernels = {
"gaussian": ll.GaussianKernel,
"box": ll.BoxKernel,
Expand Down Expand Up @@ -59,47 +60,10 @@ def _get_location_data(
return loc_df


def _deconvolute_bootstrap_wrapper(args):
"""
Wrapper for the deconvolute bootstrap function to allow for parallel processing,
handling the random number generator seeding.
"""

# Get the seed
child_seed = args.pop("child_seed")

# Initialize the default random number generator with the child seed
np.random.default_rng(child_seed)

return _deconvolute_bootstrap(**args)


def _deconvolute_bootstrap(
n_cores: int,
location: str,
loc_df: pd.DataFrame,
bootstrap: int,
locations_list: List,
no_loc: bool,
no_date: bool,
date_intervals: List[Tuple],
var_dates: dict,
kernel: Union[ll.GaussianKernel, ll.BoxKernel],
kernel_params: dict,
regressor: Union[ll.NnlsReg, ll.RobustReg],
regressor_params: dict,
confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint],
confint_params: dict,
deconv_params: dict,
have_confint: bool,
confint_name: str,
namefield: str,
) -> List[pd.DataFrame]:
"""
Deconvolute the data for a given location and bootstrap iteration.
Parameters
----------
class DeconvBootstrapsArgsNoSeed(TypedDict):
"""Arguments for the deconvolute bootstrap function.
_deconvolute_bootstrap_
n_cores : int
The number of cores to use for parallel processing,
(Only used for all locastion progress bar parr/sequc)
Expand Down Expand Up @@ -140,13 +104,91 @@ def _deconvolute_bootstrap(
The name of the confidence interval.
namefield: str
The column to use as 'names' for the entries in the tally table.
"""
n_cores: int
location: str
loc_df: pd.DataFrame
bootstrap: int
locations_list: List
no_loc: bool
no_date: bool
date_intervals: List[Tuple]
var_dates: dict
kernel: Union[ll.GaussianKernel, ll.BoxKernel]
kernel_params: dict
regressor: Union[ll.NnlsReg, ll.RobustReg]
regressor_params: dict
confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint]
confint_params: dict
deconv_params: dict
have_confint: bool
confint_name: str
namefield: str

class DeconvBootstrapsArgs(DeconvBootstrapsArgsNoSeed):
"""
Arguments for the deconvolute bootstrap function.
_deconvolute_bootstrap_wrapper_
child_seed: np.random.SeedSequence
seed for the given location
"""
child_seed: np.random.SeedSequence


def _deconvolute_bootstrap_wrapper(args : DeconvBootstrapsArgs) -> Callable[[DeconvBootstrapsArgsNoSeed], List[pd.DataFrame]]:
"""
Wrapper for the deconvolute bootstrap function to allow for parallel processing,
handling the random number generator seeding.
"""

# Get the seed
child_seed = args.pop("child_seed")

# Initialize the default random number generator with the child seed
np.random.default_rng(child_seed)

return _deconvolute_bootstrap(args)


def _deconvolute_bootstrap(
args : DeconvBootstrapsArgsNoSeed
) -> List[pd.DataFrame]:
"""
Deconvolute the data for a given location and bootstrap iteration.
Parameters
----------
args : DeconvBootstrapsArgs
The arguments for the deconvolute bootstrap function.
Returns
-------
List[pd.DataFrame]
The deconvolution results for the location and bootstrap iterations.
"""

# Unpack the arguments
n_cores = args["n_cores"]
location = args["location"]
loc_df = args["loc_df"]
bootstrap = args["bootstrap"]
locations_list = args["locations_list"]
no_loc = args["no_loc"]
no_date = args["no_date"]
date_intervals = args["date_intervals"]
var_dates = args["var_dates"]
kernel = args["kernel"]
kernel_params = args["kernel_params"]
regressor = args["regressor"]
regressor_params = args["regressor_params"]
confint = args["confint"]
confint_params = args["confint_params"]
deconv_params = args["deconv_params"]
have_confint = args["have_confint"]
confint_name = args["confint_name"]
namefield = args["namefield"]

# deconvolution results
deconv = []

Expand Down Expand Up @@ -662,28 +704,28 @@ def deconvolute(
# CORE DECONVOLUTION
# iterate over locations
# Create delayed objects for each location
args_list = [
args_list : List[DeconvBootstrapsArgs] = [
{
"n_cores" : n_cores,
"location" : location,
"loc_df" : loc_df,
"bootstrap" : bootstrap,
"locations_list" : locations_list,
"no_loc" : no_loc,
"no_date" : no_date,
"date_intervals" : date_intervals,
"var_dates" : var_dates,
"kernel" : kernel,
"kernel_params" : kernel_params,
"regressor" : regressor,
"regressor_params" : regressor_params,
"confint" : confint,
"confint_params" : confint_params,
"deconv_params" : deconv_params,
"have_confint" : have_confint,
"confint_name" : confint_name,
"namefield" : namefield,
"child_seed" : child_seed,
"n_cores": n_cores,
"location": location,
"loc_df": loc_df,
"bootstrap": bootstrap,
"locations_list": locations_list,
"no_loc": no_loc,
"no_date": no_date,
"date_intervals": date_intervals,
"var_dates": var_dates,
"kernel": kernel,
"kernel_params": kernel_params,
"regressor": regressor,
"regressor_params": regressor_params,
"confint": confint,
"confint_params": confint_params,
"deconv_params": deconv_params,
"have_confint": have_confint,
"confint_name": confint_name,
"namefield": namefield,
"child_seed": child_seed,
}
for location, loc_df, child_seed in zip(locations_list, loc_dfs, seeds)
]
Expand Down

0 comments on commit 5437e0e

Please sign in to comment.