Skip to content

Commit

Permalink
change coercion again...
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed May 29, 2024
1 parent 2d6805a commit 53f985d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
8 changes: 1 addition & 7 deletions src/density_estimation/density_estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,7 @@ function density_estimation(
)
outcome, parents = read_density_variables(density_file)
dataset = TargetedEstimation.instantiate_dataset(dataset_file)
TargetedEstimation.coerce_types!(dataset, parents)
# Continuous and Counts except Binary outcomes are treated as continuous
if elscitype(dataset[!, outcome]) <: Infinite && !(TargetedEstimation.isbinary(outcome, dataset))
TargetedEstimation.coerce_types!(dataset, [outcome], rules=:discrete_to_continuous)
else
TargetedEstimation.coerce_types!(dataset, [outcome], rules=:few_to_finite)
end
coerce_parents_and_outcome!(dataset, parents, outcome=outcome)

X, y = X_y(dataset, parents, outcome)
density_estimators = get_density_estimators(mode, X, y)
Expand Down
3 changes: 0 additions & 3 deletions src/estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ function estimate_from_simulated_data(
origin_dataset = TargetedEstimation.instantiate_dataset(origin_dataset)
sample_size = sample_size !== nothing ? sample_size : nrow(origin_dataset)
estimands = TargetedEstimation.instantiate_estimands(estimands_config, origin_dataset)
for estimand in estimands
TargetedEstimation.coerce_types!(origin_dataset, estimand)
end
estimators_spec = TargetedEstimation.instantiate_estimators(estimators_config)
sampler = get_sampler(sampler_config, estimands)
statistics = []
Expand Down
1 change: 1 addition & 0 deletions src/samplers/density_estimate_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function sample_from(sampler::DensityEstimateSampler, origin_dataset; n=100)
sampled_dataset = sample_from(origin_dataset, sampler.all_parents_set; n=n)

for (outcome, (parents, file)) in sampler.density_mapping
coerce_parents_and_outcome!(sampled_dataset, parents; outcome=nothing)
conditional_density_estimate = Simulations.sieve_neural_net_density_estimator(file)
sampled_dataset[!, outcome] = sample_from(
conditional_density_estimate,
Expand Down
12 changes: 12 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@
### Misc Functions ###
########################################################################

function coerce_parents_and_outcome!(dataset, parents; outcome=nothing)
TargetedEstimation.coerce_types!(dataset, parents)
if outcome !== nothing
# Continuous and Counts except Binary outcomes are treated as continuous
if elscitype(dataset[!, outcome]) <: Infinite && !(TargetedEstimation.isbinary(outcome, dataset))
TargetedEstimation.coerce_types!(dataset, [outcome], rules=:discrete_to_continuous)
else
TargetedEstimation.coerce_types!(dataset, [outcome], rules=:few_to_finite)
end
end
end

function variables_from_dataset(dataset)
confounders = Set([])
outcome_extra_covariates = Set(["Genetic-Sex", "Age-Assessment"])
Expand Down

0 comments on commit 53f985d

Please sign in to comment.