Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Type Safety: Migrate from Sequence to TypedDict #3

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 109 additions & 64 deletions lollipop/cli/deconvolute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Implements the deconvolute command for the lollipop CLI."""

import sys
from typing import List, Tuple, Union
from typing import List, Tuple, Union, TypedDict, Callable
import logging
import time
import multiprocessing
Expand All @@ -22,6 +22,7 @@
level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s"
)


kernels = {
"gaussian": ll.GaussianKernel,
"box": ll.BoxKernel,
Expand Down Expand Up @@ -59,47 +60,10 @@ def _get_location_data(
return loc_df


def _deconvolute_bootstrap_wrapper(args):
    """
    Seed-aware entry point used to run ``_deconvolute_bootstrap`` in parallel.

    ``args`` is an iterable whose last element is the child seed; every
    element before it is forwarded positionally to ``_deconvolute_bootstrap``.
    """

    # Split the trailing seed off from the arguments to forward
    *forwarded, seed = args

    # Build a generator from the child seed (the returned object is not kept)
    np.random.default_rng(seed)

    return _deconvolute_bootstrap(*forwarded)


def _deconvolute_bootstrap(
n_cores: int,
location: str,
loc_df: pd.DataFrame,
bootstrap: int,
locations_list: List,
no_loc: bool,
no_date: bool,
date_intervals: List[Tuple],
var_dates: dict,
kernel: Union[ll.GaussianKernel, ll.BoxKernel],
kernel_params: dict,
regressor: Union[ll.NnlsReg, ll.RobustReg],
regressor_params: dict,
confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint],
confint_params: dict,
deconv_params: dict,
have_confint: bool,
confint_name: str,
namefield: str,
) -> List[pd.DataFrame]:
"""
Deconvolute the data for a given location and bootstrap iteration.
class DeconvBootstrapsArgsNoSeed(TypedDict):
"""Arguments for the deconvolute bootstrap function.
_deconvolute_bootstrap_

Parameters
----------
n_cores : int
The number of cores to use for parallel processing,
(Only used for the all-locations progress bar: parallel vs. sequential)
Expand Down Expand Up @@ -140,13 +104,94 @@ def _deconvolute_bootstrap(
The name of the confidence interval.
namefield: str
The column to use as 'names' for the entries in the tally table.
"""

n_cores: int
location: str
loc_df: pd.DataFrame
bootstrap: int
locations_list: List
no_loc: bool
no_date: bool
date_intervals: List[Tuple]
var_dates: dict
kernel: Union[ll.GaussianKernel, ll.BoxKernel]
kernel_params: dict
regressor: Union[ll.NnlsReg, ll.RobustReg]
regressor_params: dict
confint: Union[ll.confints.NullConfint, ll.confints.WaldConfint]
confint_params: dict
deconv_params: dict
have_confint: bool
confint_name: str
namefield: str


class DeconvBootstrapsArgs(DeconvBootstrapsArgsNoSeed):
    """
    Arguments for ``_deconvolute_bootstrap_wrapper``.

    Same keys as ``DeconvBootstrapsArgsNoSeed`` plus the seed entry below.

    child_seed: np.random.SeedSequence
        seed for the given location
    """

    # Handed to np.random.default_rng by the wrapper before the remaining
    # keys are forwarded to _deconvolute_bootstrap.
    child_seed: np.random.SeedSequence


def _deconvolute_bootstrap_wrapper(
    args: DeconvBootstrapsArgs,
) -> List[pd.DataFrame]:
    """
    Wrapper for the deconvolute bootstrap function to allow for parallel processing,
    handling the random number generator seeding.

    Parameters
    ----------
    args : DeconvBootstrapsArgs
        The arguments for ``_deconvolute_bootstrap`` plus the ``child_seed``
        entry. The mapping is not modified.

    Returns
    -------
    List[pd.DataFrame]
        The value returned by ``_deconvolute_bootstrap`` for the arguments
        without ``child_seed``.
    """
    from typing import cast

    # Read the seed without popping it, so the caller's dict stays intact and
    # the same args can be submitted again (pop would raise KeyError then).
    child_seed = args["child_seed"]

    # Forward every entry except the seed
    bootstrap_args = cast(
        DeconvBootstrapsArgsNoSeed,
        {key: value for key, value in args.items() if key != "child_seed"},
    )

    # Initialize the default random number generator with the child seed
    # NOTE(review): default_rng() only constructs and returns a new Generator;
    # it does not seed any global state, and the result is discarded here.
    # Confirm whether the generator should be handed to _deconvolute_bootstrap.
    np.random.default_rng(child_seed)

    return _deconvolute_bootstrap(bootstrap_args)


def _deconvolute_bootstrap(args: DeconvBootstrapsArgsNoSeed) -> List[pd.DataFrame]:
"""
Deconvolute the data for a given location and bootstrap iteration.

Parameters
----------
args : DeconvBootstrapsArgsNoSeed
The arguments for the deconvolute bootstrap function.

Returns
-------
List[pd.DataFrame]
The deconvolution results for the location and bootstrap iterations.
"""

# Unpack the arguments
n_cores = args["n_cores"]
location = args["location"]
loc_df = args["loc_df"]
bootstrap = args["bootstrap"]
locations_list = args["locations_list"]
no_loc = args["no_loc"]
no_date = args["no_date"]
date_intervals = args["date_intervals"]
var_dates = args["var_dates"]
kernel = args["kernel"]
kernel_params = args["kernel_params"]
regressor = args["regressor"]
regressor_params = args["regressor_params"]
confint = args["confint"]
confint_params = args["confint_params"]
deconv_params = args["deconv_params"]
have_confint = args["have_confint"]
confint_name = args["confint_name"]
namefield = args["namefield"]

# deconvolution results
deconv = []

Expand Down Expand Up @@ -662,29 +707,29 @@ def deconvolute(
# CORE DECONVOLUTION
# iterate over locations
# Create delayed objects for each location
args_list = [
(
n_cores,
location,
loc_df,
bootstrap,
locations_list,
no_loc,
no_date,
date_intervals,
var_dates,
kernel,
kernel_params,
regressor,
regressor_params,
confint,
confint_params,
deconv_params,
have_confint,
confint_name,
namefield,
child_seed,
)
args_list: List[DeconvBootstrapsArgs] = [
{
"n_cores": n_cores,
"location": location,
"loc_df": loc_df,
"bootstrap": bootstrap,
"locations_list": locations_list,
"no_loc": no_loc,
"no_date": no_date,
"date_intervals": date_intervals,
"var_dates": var_dates,
"kernel": kernel,
"kernel_params": kernel_params,
"regressor": regressor,
"regressor_params": regressor_params,
"confint": confint,
"confint_params": confint_params,
"deconv_params": deconv_params,
"have_confint": have_confint,
"confint_name": confint_name,
"namefield": namefield,
"child_seed": child_seed,
}
for location, loc_df, child_seed in zip(locations_list, loc_dfs, seeds)
]

Expand Down
Loading