Skip to content

Commit

Permalink
N function in pedigree sims
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesbecher committed Oct 3, 2024
1 parent e2932a2 commit cf0d463
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions msprime/pedigrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def sim_pedigree_backward(
builder,
rng,
*,
population_size,
N_function,
num_samples,
end_time,
):
Expand All @@ -412,7 +412,7 @@ def sim_pedigree_backward(
# Because we don't have overlapping generations we can add the ancestors
# to the pedigree once the previous generation has been produced
for time in range(0, end_time):
population = np.arange(population_size, dtype=np.int32)
population = np.arange(N_function(time), dtype=np.int32)
parents = rng.choice(population, (len(ancestors), 2))
unique_parents = np.unique(parents)
parent_ids = np.searchsorted(unique_parents, parents).astype(np.int32)
Expand All @@ -432,15 +432,16 @@ def sim_pedigree_forward(
builder,
rng,
*,
population_size,
N_function,
end_time,
):
population = np.array([tskit.NULL], dtype=np.int32)

# To make the semantics compatible with dtwf, the end_time means the
# *end* of generation end_time
for time in reversed(range(end_time + 1)):
N = population_size # This could be derived from the Demography
# N = p_size(time) # This could be derived from the Demography
N = N_function(time) # This could be derived from the Demography
# NB this is *with* replacement, so 1 / N chance of selfing
parents = rng.choice(population, (N, 2))
population = builder.add_individuals(parents=parents, time=time)
Expand All @@ -459,26 +460,37 @@ def sim_pedigree(
# Internal utility for generating pedigree data. This function is not
# part of the public API and subject to arbitrary changes/removal
# in the future.
num_samples = population_size if num_samples is None else num_samples

# allow for population_size to be a single value or a function
if not callable(population_size):

def N_function(_):
return population_size

else:
N_function = population_size

num_samples = N_function(0) if num_samples is None else num_samples
builder = PedigreeBuilder()
rng = np.random.RandomState(random_seed)

if direction == "forward":
if num_samples != population_size:
if num_samples != N_function(0):
raise ValueError(
"num_samples must be equal to population_size for forward simulation"
"if at all specified, num_samples must be equal to population_size "
"at generation 0 for forward simulation"
)
tables = sim_pedigree_forward(
builder,
rng,
population_size=population_size,
N_function=N_function,
end_time=end_time,
)
elif direction == "backward":
tables = sim_pedigree_backward(
builder,
rng,
population_size=population_size,
N_function=N_function,
num_samples=num_samples,
end_time=end_time,
)
Expand Down

0 comments on commit cf0d463

Please sign in to comment.