Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect stabilization #31

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pipeline/tests/predict.tsv
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
datetime reads kbps kkmers kmers_prop pgs pgs_pass pg alt_pg bm bm_ST bm_serotype pen_sus pen_pred pen_bm cro_sus cro_pred cro_bm tmp_sus tmp_pred tmp_bm ery_sus ery_pred ery_bm tet_sus tet_pred tet_bm flags
2017-05-26 17:29:52 0 0 0 0.000 0 0 NA NA NA NA NA 0.000 R NA (NA) 0.000 R NA (NA) 0.000 R NA (NA) 0.000 R NA (NA) 0.000 R NA (NA) []
datetime reads kbps kkmers kmers_prop pgs pgs_pass pg alt_pg bm bm_ST bm_serotype pen_ssc pen_pred pen_bm cro_ssc cro_pred cro_bm tmp_ssc tmp_pred tmp_bm ery_ssc ery_pred ery_bm tet_ssc tet_pred tet_bm flags
2017-05-26 17:29:52 0 0 0 0.000 0 0 NA NA NA NA NA 0.000 R NA (NA) 0.000 R NA (NA) 0.000 R NA (NA) 0.000 R NA (NA) 0.000 S NA (NA) []
2017-05-26 17:30:52 137 319 19 0.060 0.608 1 3 16 8QTW4 63 15A 0.141 R R (0.25) 0.813 S S (0.094) 0.659 S! S (0.19) 0.133 R R (32) 0.141 R r (NA) [D:bm, D:pg, S:bm, S:pg]
2017-05-26 17:31:52 278 646 39 0.061 0.590 1 3 2 8QTW4 63 15A 0.125 R R (0.25) 0.844 S S (0.094) 0.746 S! S (0.19) 0.312 R! R (32) 0.312 R! r (NA) []
2017-05-26 17:32:53 414 973 59 0.060 0.672 1 3 2 8QTW4 63 15A 0.135 R R (0.25) 0.824 S S (0.094) 0.754 S! S (0.19) 0.276 R! R (32) 0.276 R! r (NA) []
Expand Down
65 changes: 65 additions & 0 deletions scripts/rase_plot_timeline.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,71 @@ DfToAnts <- function(df) {
}


#' Mark resistance predictions that belong to the final states.
#'
#' The column is read as a factor over the levels "S", "S!", "R!", "R".
#' Scanning from the end of the column, the (up to) two most recent
#' distinct states are taken as the final states.
#'
#' @param col Column of resistance predictions
#'
#' @return Logical vector, TRUE where the value equals one of the final states.
#'
ResPredIsFinal <- function(col) {
  states <- factor(col, levels = c("S", "S!", "R!", "R"))
  # Up to two most recent distinct states, newest first.
  last.two <- head(unique(rev(states)), 2)
  matches.newest <- states == last.two[1]
  matches.second <- states == last.two[2]
  matches.newest | matches.second
}


#' Find the stabilization point
#'
#' The stabilization point is the first element of the last run of equal
#' values in the column, i.e. the position from which the value no longer
#' changes.
#'
#' @param col Atomic vector (logical, numeric or character) of values
#'
#' @return Logical vector of the same length as \code{col}; TRUE only at the
#'   first position of the last run (empty for empty input).
#'
StabilizationPoint <- function(col) {
  # detect last block
  r <- rle(col)
  n.runs <- length(r$values)
  # Replace the run values with a genuine logical mask. Assigning into
  # r$values[] would keep the type of col, so character input would turn
  # into "FALSE"/"TRUE" strings and break cumsum() below.
  r$values <- seq_len(n.runs) == n.runs
  in.last.run <- inverse.rle(r)
  # detect its first value: the only place where the running count is 1
  cumsum(in.last.run) == 1
}

#' Find detection points
#'
#' A detection point is a position where the value is exactly 1 higher than
#' the value before it; the first element is compared against 0.
#'
#' @param col Numeric or logical vector
#'
#' @return Logical vector, TRUE where the step from the previous value is +1.
#'
DetectionPoints <- function(col) {
  # Column shifted by one position, padded with 0 at the front.
  previous <- c(0, head(col, -1))
  step <- col - previous
  step == 1
}


#' Find losing points
#'
#' A losing point is a position where the value is exactly 1 lower than
#' the value before it; the first element is compared against 0.
#'
#' @param col Numeric or logical vector
#'
#' @return Logical vector, TRUE where the step from the previous value is -1.
#'
LosingPoints <- function(col) {
  # Column shifted by one position, padded with 0 at the front.
  previous <- c(0, head(col, -1))
  step <- col - previous
  step == -1
}


#' Mark category predictions that equal the final prediction.
#'
#' The final state is the last value of the column.
#'
#' @param col Column of category predictions
#'
#' @return Logical vector, TRUE where the value equals the last value.
#'
CatPredIsFinal <- function(col) {
  # The first element of the reversed column is the most recent prediction.
  last.state <- rev(col)[1]
  col == last.state
}


#' Create a data frame with flag points
#'
#' @param df Input dataframe
Expand Down
43 changes: 35 additions & 8 deletions src/rase/rase_prediction_add_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@
import re


def load_tsv_dl(tsv_fn):
    """Load a TSV file as a list of dictionaries (1 dict per data line).

    The first line of the file is used as the header; every following line
    becomes one dictionary mapping column name to the (string) cell value.

    Args:
        tsv_fn: Path of the tab-separated file to read.

    Returns:
        List of per-row dictionaries, in file order (empty if the file has
        no data lines).
    """
    # newline="" lets the csv module do its own line-ending handling, as its
    # documentation requires; otherwise line breaks embedded in quoted
    # fields are altered by universal-newline translation.
    with open(tsv_fn, newline="") as f:
        return list(csv.DictReader(f, delimiter="\t"))


def get_cols(tsv_dl):
    """Get column names from a list of dictionaries.

    The names are taken from the keys of the first record, in their
    original order.

    Args:
        tsv_dl: List of per-row dictionaries.

    Returns:
        List of column names; an empty list when there are no records
        (e.g. a TSV file that contains only a header line).
    """
    if not tsv_dl:
        # No records to inspect; avoid an IndexError on tsv_dl[0].
        return []
    return list(tsv_dl[0])


def extract_flag_cols(cols):
# regular expressions for categories to be flagged
"""Extract columns for which flags are to be computed.
"""
res_flagging = [
re.compile(r"^bm$"),
re.compile(r"^serotype$"),
Expand All @@ -41,22 +52,38 @@ def extract_flag_cols(cols):
return flag_cols


def load_tsv_dl(tsv_fn):
    """Load a TSV file as a list of dictionaries (1 dict per data line).

    Args:
        tsv_fn: Path of the tab-separated file to read; its first line is
            used as the header.

    Returns:
        List of per-row dictionaries, in file order.
    """
    # The csv module expects files opened with newline="" so that it can
    # handle line endings (including those inside quoted fields) itself.
    with open(tsv_fn, newline="") as f:
        return list(csv.DictReader(f, delimiter="\t"))


def get_stabilization_point(tsv_dl, key):
    """Find the point of stabilization of one column.

    Args:
        tsv_dl: List of per-row dictionaries.
        key: Column whose values are examined.

    Returns:
        Index in ``tsv_dl`` of the first record of the trailing run of
        records whose ``key`` value equals that of the last record;
        0 when every record carries that value or when there are no records.
    """
    if not tsv_dl:
        return 0
    final_value = tsv_dl[-1][key]
    # Walk backwards to the most recent record that still differs; the
    # stabilization point is the record right after it.
    for idx in range(len(tsv_dl) - 1, -1, -1):
        if tsv_dl[idx][key] != final_value:
            return idx + 1
    return 0


def add_flags_line():
# Compute the flags for one record and print the record with its flags.
# NOTE(review): takes no parameters but reads flag_cols, i,
# stabilization_points, flags, rec, prev_rec and last_rec, none of which
# are defined in this block -- presumably they come from an enclosing or
# module scope; confirm before calling.
for k in flag_cols:
if i == stabilization_points[k]:
# S: record index i is the stabilization point recorded for column k
flags.append('S:{}'.format(k))
assert rec[k] != prev_rec[k]
if rec[k] != prev_rec[k]:
if rec[k] == last_rec[k]:
# detected: the value changed to the one held by last_rec
flags.append('D:{}'.format(k))
elif prev_rec[k] == last_rec[k]:
# lost: the previous value matched last_rec, the current one does not
flags.append('L:{}'.format(k))
flags.sort()
# Render the list as "[D:bm, S:pg]" by dropping the quotes from its repr.
flags_str = str(flags).replace("'", "")
# One tab-separated output line: all record values followed by the flags.
print(*rec.values(), flags_str, sep="\t")


def add_flags(tsv_fn):
"""Read RASE prediction output and add flags.
"""
tsv_dl = load_tsv_dl(tsv_fn)
cols = get_cols(tsv_dl)
flag_cols = extract_flag_cols(cols)
Expand Down