diff --git a/goldenretriever/data/datasets.py b/goldenretriever/data/datasets.py index 1dd7e64..ea49e80 100644 --- a/goldenretriever/data/datasets.py +++ b/goldenretriever/data/datasets.py @@ -494,11 +494,17 @@ def load_fn( if max_positives != -1: positives = positives[:max_positives] - negatives = list(set([n["text"] for n in sample["negative_ctxs"]])) + if "negative_ctxs" in sample: + negatives = list(set([n["text"] for n in sample["negative_ctxs"]])) + else: + negatives = [] if max_negatives != -1: negatives = negatives[:max_negatives] - hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]])) + if "hard_negative_ctxs" in sample: + hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]])) + else: + hard_negatives = [] if max_hard_negatives != -1: hard_negatives = hard_negatives[:max_hard_negatives] diff --git a/goldenretriever/version.py b/goldenretriever/version.py index 000a79a..ac0ebe7 100644 --- a/goldenretriever/version.py +++ b/goldenretriever/version.py @@ -4,7 +4,7 @@ _MINOR = "9" # On main and in a nightly release the patch should be one ahead of the last # released build. -_PATCH = "1" +_PATCH = "2" # This is mainly for nightly builds which have the suffix ".dev$DATE". See # https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = os.environ.get("GOLDENRETRIEVER_VERSION_SUFFIX", "")