diff --git a/src/density_estimation/density_estimation.jl b/src/density_estimation/density_estimation.jl index ac8f624..83af1fa 100644 --- a/src/density_estimation/density_estimation.jl +++ b/src/density_estimation/density_estimation.jl @@ -45,13 +45,7 @@ function density_estimation( ) outcome, parents = read_density_variables(density_file) dataset = TargetedEstimation.instantiate_dataset(dataset_file) - TargetedEstimation.coerce_types!(dataset, parents) - # Continuous and Counts except Binary outcomes are treated as continuous - if elscitype(dataset[!, outcome]) <: Infinite && !(TargetedEstimation.isbinary(outcome, dataset)) - TargetedEstimation.coerce_types!(dataset, [outcome], rules=:discrete_to_continuous) - else - TargetedEstimation.coerce_types!(dataset, [outcome], rules=:few_to_finite) - end + coerce_parents_and_outcome!(dataset, parents, outcome=outcome) X, y = X_y(dataset, parents, outcome) density_estimators = get_density_estimators(mode, X, y) diff --git a/src/estimation.jl b/src/estimation.jl index e5bf41f..002a7f7 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -21,9 +21,6 @@ function estimate_from_simulated_data( origin_dataset = TargetedEstimation.instantiate_dataset(origin_dataset) sample_size = sample_size !== nothing ? sample_size : nrow(origin_dataset) estimands = TargetedEstimation.instantiate_estimands(estimands_config, origin_dataset) - for estimand in estimands - TargetedEstimation.coerce_types!(origin_dataset, estimand) - end estimators_spec = TargetedEstimation.instantiate_estimators(estimators_config) sampler = get_sampler(sampler_config, estimands) statistics = [] diff --git a/src/samplers/density_estimate_sampler.jl b/src/samplers/density_estimate_sampler.jl index 36deb18..226f43c 100644 --- a/src/samplers/density_estimate_sampler.jl +++ b/src/samplers/density_estimate_sampler.jl @@ -46,6 +46,7 @@ function sample_from(sampler::DensityEstimateSampler, origin_dataset; n=100) sampled_dataset = sample_from(origin_dataset, sampler.all_parents_set; n=n) for (outcome, (parents, file)) in sampler.density_mapping + coerce_parents_and_outcome!(sampled_dataset, parents; outcome=nothing) conditional_density_estimate = Simulations.sieve_neural_net_density_estimator(file) sampled_dataset[!, outcome] = sample_from( conditional_density_estimate, diff --git a/src/utils.jl b/src/utils.jl index 275b8ce..9dc5af6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,18 @@ ### Misc Functions ### ######################################################################## +function coerce_parents_and_outcome!(dataset, parents; outcome=nothing) + TargetedEstimation.coerce_types!(dataset, parents) + if outcome !== nothing + # Continuous and Counts except Binary outcomes are treated as continuous + if elscitype(dataset[!, outcome]) <: Infinite && !(TargetedEstimation.isbinary(outcome, dataset)) + TargetedEstimation.coerce_types!(dataset, [outcome], rules=:discrete_to_continuous) + else + TargetedEstimation.coerce_types!(dataset, [outcome], rules=:few_to_finite) + end + end +end + function variables_from_dataset(dataset) confounders = Set([]) outcome_extra_covariates = Set(["Genetic-Sex", "Age-Assessment"])