Skip to content

Commit

Permalink
Add tests for overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Dec 11, 2024
1 parent da26220 commit 05e1920
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
11 changes: 0 additions & 11 deletions tests/data/testrun-conf.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,3 @@ matches_dir = "testrun"
[extend_parameters]
min_group_size = 2
num_threads = 1

[[override]]
start = "2020"
stop = "2020-03-01"
parameters.max_missing_sites = 1000

[[override]]
start = "2020-01-02"
stop = "2020-01-03"
parameters.max_missing_sites = 750
parameters.min_group_size = 200
72 changes: 71 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def make_config(
log_dir="logs",
matches_dir="matches",
exclude_sites=list(),
override=list(),
**kwargs,
):
config = {
Expand All @@ -171,11 +172,12 @@ def make_config(
"matches_dir": str(tmp_path / matches_dir),
"exclude_sites": exclude_sites,
"extend_parameters": {**kwargs},
"override": override,
}
filename = tmp_path / "config.toml"
with open(filename, "w") as f:
toml = tomli_w.dumps(config)
# print("Generated", toml)
print("Generated", toml)
f.write(toml)
return filename

Expand Down Expand Up @@ -253,6 +255,74 @@ def test_include_samples(self, tmp_path, fx_ts_map, fx_dataset):
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 1
assert ts.num_samples == 1

def test_override(self, tmp_path, fx_ts_map, fx_dataset):
hmm_cost_threshold = 47
config_file = self.make_config(
tmp_path,
fx_dataset,
exclude_sites=[56, 57, 58, 59, 60],
override=[
{
"start": "2020-01-01",
"stop": "2020-01-02",
"parameters": {"hmm_cost_threshold": hmm_cost_threshold},
}
],
)
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"infer {config_file} --stop 2020-01-02",
catch_exceptions=False,
)
assert result.exit_code == 0
date = "2020-01-01"
ts_path = tmp_path / "results" / "test" / f"test_{date}.ts"
ts = tskit.load(ts_path)
params = json.loads(ts.provenance(-1).record)["parameters"]
assert params["hmm_cost_threshold"] == hmm_cost_threshold

assert "SRR14631544" in ts.metadata["sc2ts"]["samples_strain"]
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 1
assert ts.num_samples == 1

def test_multiple_override(self, tmp_path, fx_ts_map, fx_dataset):
hmm_cost_threshold = 3
config_file = self.make_config(
tmp_path,
fx_dataset,
exclude_sites=[56, 57, 58, 59, 60],
# Overrides get applied sequentially, and last overlapping value wins.
override=[
{
"start": "2020-01-01",
"stop": "2020-01-02",
"parameters": {"hmm_cost_threshold": 123},
},
{
"start": "2020",
"stop": "2020-07-01",
"parameters": {"hmm_cost_threshold": hmm_cost_threshold},
},
],
hmm_cost_threshold=4000,
)
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"infer {config_file} --stop 2020-01-02",
catch_exceptions=False,
)
assert result.exit_code == 0
date = "2020-01-01"
ts_path = tmp_path / "results" / "test" / f"test_{date}.ts"
ts = tskit.load(ts_path)
params = json.loads(ts.provenance(-1).record)["parameters"]
assert params["hmm_cost_threshold"] == hmm_cost_threshold

assert "SRR14631544" not in ts.metadata["sc2ts"]["samples_strain"]
assert ts.num_samples == 0


@pytest.mark.skip("Broken by dataset")
class TestRunRematchRecombinants:
Expand Down

0 comments on commit 05e1920

Please sign in to comment.