Skip to content

Commit

Permalink
Remove extend command and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Dec 11, 2024
1 parent ba2dc21 commit aa94aa8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 248 deletions.
201 changes: 0 additions & 201 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def info_ts(ts_path, recombinants, verbose):
print(ti.recombinants_summary())



def summarise_base(ts, date, progress):
ti = sc2ts.TreeInfo(ts, quick=True)
node_info = "; ".join(f"{k}:{v}" for k, v in ti.node_counts().items())
Expand All @@ -294,205 +293,6 @@ def summarise_base(ts, date, progress):
print(f"{date} Start base: {node_info}", file=sys.stderr)


def parse_include_samples(fileobj):
    """
    Parse a file listing strain names to unconditionally include.

    The first whitespace-delimited token on each line is taken as the
    strain name; the remainder of the line is ignored, which allows
    trailing comments. Blank or whitespace-only lines are skipped.

    :param fileobj: An iterable of text lines (e.g. an open file).
    :return: List of strain name strings, in file order.
    """
    strains = []
    for line in fileobj:
        tokens = line.split(maxsplit=1)
        if not tokens:
            # A blank/whitespace-only line splits to []; previously this
            # raised IndexError on tokens[0]. Skip such lines instead.
            continue
        strains.append(tokens[0])
    return strains


# CLI entry point: extend an existing sc2ts ARG by one day of samples.
# NOTE(review): @dataset, @num_mismatches, @deletions_as_missing and
# @memory_limit are shared click decorators defined elsewhere in this
# module (not visible here) — presumably they each add one option of the
# same name; confirm against the decorator definitions.
@click.command()
@click.argument("base_ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("date")
@dataset
@click.argument("matches", type=click.Path(exists=True, dir_okay=False))
@click.argument("output_ts", type=click.Path(dir_okay=False))
@num_mismatches
@deletions_as_missing
@memory_limit
@click.option(
    "--hmm-cost-threshold",
    default=5,
    type=float,
    show_default=True,
    help="The maximum HMM cost for samples to be included unconditionally",
)
@click.option(
    "--min-group-size",
    default=10,
    show_default=True,
    type=int,
    help="Minimum size of groups of reconsidered samples for inclusion",
)
@click.option(
    "--min-root-mutations",
    default=2,
    show_default=True,
    type=int,
    help="Minimum number of shared mutations for reconsidered sample groups",
)
@click.option(
    "--max-mutations-per-sample",
    default=10,
    show_default=True,
    type=int,
    help=(
        "Maximum average number of mutations per sample in an inferred retrospective "
        "group tree"
    ),
)
@click.option(
    "--max-recurrent-mutations",
    default=10,
    show_default=True,
    type=int,
    help=(
        "Maximum number of recurrent mutations in an inferred retrospective "
        "group tree"
    ),
)
@click.option(
    "--retrospective-window",
    default=30,
    show_default=True,
    type=int,
    help="Number of days in the past to reconsider potential matches",
)
@click.option(
    "--max-daily-samples",
    default=None,
    type=int,
    help=(
        "The maximum number of samples to match in a single day. If the total "
        "is greater than this, randomly subsample."
    ),
)
@click.option(
    "--max-missing-sites",
    default=None,
    type=int,
    help=(
        "The maximum number of missing sites in a sample to be accepted for inclusion"
    ),
)
@click.option(
    "--include-samples",
    default=None,
    type=click.File("r"),
    help=(
        "File containing the list of strains to unconditionally include, "
        "one per line. Strains are the first white-space delimited token "
        "and the rest of the line ignored (to allow for comments etc)"
    ),
)
@click.option(
    "--random-seed",
    default=42,
    type=int,
    help="Random seed for subsampling",
    show_default=True,
)
# NOTE(review): default is 0 but the help text says "default to one" —
# presumably 0 means "run matching inline/single-threaded"; confirm
# against sc2ts.extend's handling of num_threads.
@click.option(
    "--num-threads",
    default=0,
    type=int,
    help="Number of match threads (default to one)",
)
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
@click.option(
    "-f",
    "--force",
    is_flag=True,
    flag_value=True,
    help="Force clearing newer matches from DB",
)
def extend(
    base_ts,
    date,
    dataset,
    matches,
    output_ts,
    num_mismatches,
    hmm_cost_threshold,
    min_group_size,
    min_root_mutations,
    max_mutations_per_sample,
    max_recurrent_mutations,
    retrospective_window,
    deletions_as_missing,
    memory_limit,
    max_daily_samples,
    max_missing_sites,
    include_samples,
    num_threads,
    random_seed,
    progress,
    verbose,
    log_file,
    force,
):
    """
    Extend base_ts with sequences for the specified date, using specified
    alignments and metadata databases, updating the specified matches
    database, and outputting the result to the specified file.

    The result is written to OUTPUT_TS; a summary of the samples
    processed on this date is printed to stdout on completion.
    """
    setup_logging(verbose, log_file, date=date)
    # Load the base tree sequence and report its node composition before
    # starting, so progress output has a baseline.
    base = tskit.load(base_ts)
    summarise_base(base, date, progress)
    if include_samples is not None:
        # Replace the open file handle with the parsed list of strain names.
        include_samples = parse_include_samples(include_samples)
        logger.debug(
            f"Loaded {len(include_samples)} include samples: {include_samples}"
        )
    with contextlib.ExitStack() as exit_stack:
        ds = sc2ts.Dataset(dataset)
        match_db = exit_stack.enter_context(sc2ts.MatchDb(matches))

        # Matches recorded for dates after `date` would be invalidated by
        # re-extending from this point, so they must be deleted first.
        # Ask for confirmation unless --force was given.
        newer_matches = match_db.count_newer(date)
        if newer_matches > 0:
            if not force:
                click.confirm(
                    f"Do you want to remove {newer_matches} newer matches "
                    f"from MatchDB > {date}?",
                    abort=True,
                )
            match_db.delete_newer(date)
        ts_out = sc2ts.extend(
            dataset=ds,
            base_ts=base,
            date=date,
            match_db=match_db,
            num_mismatches=num_mismatches,
            include_samples=include_samples,
            hmm_cost_threshold=hmm_cost_threshold,
            min_group_size=min_group_size,
            min_root_mutations=min_root_mutations,
            max_mutations_per_sample=max_mutations_per_sample,
            max_recurrent_mutations=max_recurrent_mutations,
            retrospective_window=retrospective_window,
            deletions_as_missing=deletions_as_missing,
            max_daily_samples=max_daily_samples,
            max_missing_sites=max_missing_sites,
            random_seed=random_seed,
            num_threads=num_threads,
            # CLI value is in GiB; convert to bytes for sc2ts.extend.
            memory_limit=memory_limit * 2**30,
            show_progress=progress,
        )
    # Record provenance and dump the extended tree sequence to OUTPUT_TS.
    add_provenance(ts_out, output_ts)
    resource_usage = f"{summarise_usage()}"
    logger.info(resource_usage)
    if progress:
        print(resource_usage, file=sys.stderr)
    # Summarise this day's processed samples by scorpio classification,
    # with columns reversed and rows sorted by total count.
    df = pd.DataFrame(
        ts_out.metadata["sc2ts"]["daily_stats"][date]["samples_processed"]
    ).set_index("scorpio")
    df = df[list(df.columns)[::-1]].sort_values("total")
    print(df)


def _run_extend(out_path, verbose, log_file, **params):
date = params["date"]
setup_logging(verbose, log_file, date=date)
Expand Down Expand Up @@ -975,7 +775,6 @@ def cli():
cli.add_command(info_matches)
cli.add_command(info_ts)

cli.add_command(extend)
cli.add_command(infer)
cli.add_command(validate)
cli.add_command(_match)
Expand Down
71 changes: 24 additions & 47 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def make_config(
filename = tmp_path / "config.toml"
with open(filename, "w") as f:
toml = tomli_w.dumps(config)
print("Generated", toml)
# print("Generated", toml)
f.write(toml)
return filename

Expand Down Expand Up @@ -214,52 +214,44 @@ def test_problematic_sites(self, tmp_path, fx_dataset, problematic):
match_db = sc2ts.MatchDb(match_db_path)
assert len(match_db) == 0


@pytest.mark.skip("stuff")
class TestExtend:

def test_first_day(self, tmp_path, fx_ts_map, fx_dataset):
ts = fx_ts_map["2020-01-01"]
ts_path = tmp_path / "ts.ts"
output_ts_path = tmp_path / "out.ts"
ts.dump(ts_path)
match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
config_file = self.make_config(
tmp_path, fx_dataset, exclude_sites=[56, 57, 58, 59, 60]
)
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"extend {ts_path} 2020-01-19 {fx_dataset.path} "
f"{match_db.path} {output_ts_path}",
f"infer {config_file} --stop 2020-01-20",
catch_exceptions=False,
)
date = "2020-01-19"
assert result.exit_code == 0
out_ts = tskit.load(output_ts_path)
out_ts.tables.assert_equals(
fx_ts_map["2020-01-19"].tables, ignore_provenance=True
)
ts_path = tmp_path / "results" / "test" / f"test_{date}.ts"
out_ts = tskit.load(ts_path)
out_ts.tables.assert_equals(fx_ts_map[date].tables, ignore_provenance=True)

def test_include_samples(self, tmp_path, fx_ts_map, fx_dataset):
ts = fx_ts_map["2020-02-01"]
ts_path = tmp_path / "ts.ts"
output_ts_path = tmp_path / "out.ts"
ts.dump(ts_path)
include_samples_path = tmp_path / "include_samples.txt"
with open(include_samples_path, "w") as f:
print("SRR11597115 This is a test strain", file=f)
print("ABCD this is a strain that doesn't exist", file=f)
match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
config_file = self.make_config(
tmp_path,
fx_dataset,
exclude_sites=[56, 57, 58, 59, 60],
include_samples=["SRR14631544", "NO_SUCH_STRAIN"],
)
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"extend {ts_path} 2020-02-02 {fx_dataset.path} "
f"{match_db.path} {output_ts_path} "
f"--include-samples={include_samples_path}",
f"infer {config_file} --stop 2020-01-02",
catch_exceptions=False,
)
assert result.exit_code == 0
ts = tskit.load(output_ts_path)
assert "SRR11597115" in ts.metadata["sc2ts"]["samples_strain"]
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 5
assert ts.num_samples == 22
date = "2020-01-01"
ts_path = tmp_path / "results" / "test" / f"test_{date}.ts"

assert result.exit_code == 0
ts = tskit.load(ts_path)
assert "SRR14631544" in ts.metadata["sc2ts"]["samples_strain"]
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 1
assert ts.num_samples == 1


@pytest.mark.skip("Broken by dataset")
Expand Down Expand Up @@ -348,18 +340,3 @@ def test_zarr(self, fx_dataset):
assert result.exit_code == 0
# Pick arbitrary field as a basic check
assert "/sample_Genbank_N" in result.stdout


class TestParseIncludeSamples:
    """Unit tests for cli.parse_include_samples."""

    @pytest.mark.parametrize(
        ["text", "parsed"],
        [
            # One strain per line.
            ("ABCD\n1234\n56", ["ABCD", "1234", "56"]),
            # Leading whitespace before the strain name is ignored.
            (" ABCD\n\t1234\n 56", ["ABCD", "1234", "56"]),
            # Only the first token counts; the rest is a comment.
            ("ABCD the rest is a comment", ["ABCD"]),
            # Empty input gives an empty list.
            ("", []),
        ],
    )
    def test_examples(self, text, parsed):
        assert cli.parse_include_samples(io.StringIO(text)) == parsed

0 comments on commit aa94aa8

Please sign in to comment.