Skip to content

Commit

Permalink
Merge branch 'plaid-x' into publish-test
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene-yang committed Feb 14, 2024
2 parents 0a09253 + 61ed531 commit 092e424
Show file tree
Hide file tree
Showing 20 changed files with 826 additions and 104 deletions.
7 changes: 5 additions & 2 deletions colbert/data/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def __iter__(self):

def __getitem__(self, item):
# TODO: Load from disk the first time this is called. Unless self.data is already not None.
return self.data[item]
if isinstance(item, list):
return [ self.data[int(i)] for i in item ]
return self.data[int(item)]

def __len__(self):
# TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data.
Expand Down Expand Up @@ -91,7 +93,8 @@ def cast(cls, obj):
if type(obj) is list:
return cls(data=obj)

if type(obj) is cls:
# if type(obj) is cls:
if isinstance(obj, cls):
return obj

assert False, f"obj has type {type(obj)} which is not compatible with cast()"
Expand Down
23 changes: 16 additions & 7 deletions colbert/data/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import ujson

from colbert.utils.utils import print_message
from colbert.utils.utils import print_message, easy_pbar
from colbert.infra.provenance import Provenance


Expand All @@ -24,9 +24,16 @@ def _load_file(self, path):
examples = []

with open(path) as f:
for line in f:
example = ujson.loads(line)[:nway]
examples.append(example)
it = easy_pbar(f, desc=f'Loading {path}', disabled=Run().config.rank != 0)
if path.endswith('.tsv'):
for i, line in enumerate(it):
examples.append(line.strip().split("\t")[:nway])
if Run().config.debug and i > 10000:
break
else:
for line in it:
example = ujson.loads(line)[:nway]
examples.append(example)

return examples

Expand All @@ -40,7 +47,7 @@ def tolist(self, rank=None, nranks=None):

if rank or nranks:
assert rank in range(nranks), (rank, nranks)
return [self.data[idx] for idx in range(0, len(self.data), nranks)] # if line_idx % nranks == rank
return [self.data[idx] for idx in range(rank, len(self.data), nranks)] # if line_idx % nranks == rank

return list(self.data)

Expand Down Expand Up @@ -74,8 +81,10 @@ def cast(cls, obj, nway=None):
if isinstance(obj, list):
return cls(data=obj, nway=nway)

if type(obj) is cls:
assert nway is None, nway
# if type(obj) is cls:
if isinstance(obj, cls):
# assert nway is None, nway
assert obj.nway == nway, nway
return obj

assert False, f"obj has type {type(obj)} which is not compatible with cast()"
4 changes: 2 additions & 2 deletions colbert/data/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class Queries:
def __init__(self, path=None, data=None):
self.path = path

if data:
assert isinstance(data, dict), type(data)
# if data:
# assert isinstance(data, dict), type(data)
self._load_data(data) or self._load_file(path)

def __len__(self):
Expand Down
13 changes: 10 additions & 3 deletions colbert/evaluation/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def load_queries(queries_path):
with open(queries_path) as f:
for line in f:
qid, query, *_ = line.strip().split('\t')
qid = int(qid)
try:
qid = int(qid)
except:
qid = qid.strip()

assert (qid not in queries), ("Query QID", qid, "is repeated!")
queries[qid] = query
Expand Down Expand Up @@ -153,7 +156,7 @@ def load_topK_pids(topK_path, qrels):


def load_collection(collection_path):
print_message("#> Loading collection...")
print_message(f"#> Loading collection {collection_path}...")

collection = []

Expand All @@ -162,7 +165,11 @@ def load_collection(collection_path):
if line_idx % (1000*1000) == 0:
print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

pid, passage, *rest = line.strip('\n\r ').split('\t')
try:
pid, passage, *rest = line.strip('\n\r ').split('\t')
except Exception as e:
print(line_idx, line)
raise e
assert pid == 'id' or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}"

if len(rest) >= 1:
Expand Down
34 changes: 29 additions & 5 deletions colbert/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from colbert.utils.utils import create_directory, print_message

from colbert.indexing.collection_indexer import encode, sample, kmeans, index
from colbert.indexing.collection_indexer import reuse_prepare, sample, kmeans, encode, finalize

class Indexer:
def __init__(self, checkpoint, config=None):
Expand Down Expand Up @@ -54,7 +54,7 @@ def erase(self):

return deleted

def prepare(self, name, collection, overwrite=False):
def prepare(self, name, collection, overwrite=False, no_kmeans=False):
assert overwrite in [True, False, 'reuse', 'resume']

self.configure(collection=collection, index_name=name, resume=overwrite=='resume')
Expand All @@ -71,18 +71,42 @@ def prepare(self, name, collection, overwrite=False):

if index_does_not_exist or overwrite != 'reuse':
self.__launch(sample, collection, nospawn=self.config.gpus == 1)
self.__launch(kmeans, collection, nospawn=True)
if not no_kmeans:
self.__launch(kmeans, collection, nospawn=True)

return self.index_path

def index(self, name, collection):
def encode(self, name, collection):
self.configure(collection=collection, index_name=name, resume=True)
self.configure(partitions=None)

self.index_path = self.config.index_path_

if self.config.reuse_centroids_from is not None:
plan_path = os.path.join(self.index_path, "plan.json")
if not os.path.exists(plan_path):
# do a light prepare step to create the plan.json file along and soft link the centriods
self.index_path = self.config.index_path_
create_directory(self.config.index_path_)
self.__launch(reuse_prepare, collection, nospawn=True)
else:
print_message(f"############# Ignoring `reuse_centroids_from` since the plan in the index directory already exists -- will only try to resume indexing only")

assert os.path.exists(self.config.index_path_), "Run first step `prepare` in advance."

self.__launch(index, collection, nospawn=self.config.gpus == 1)
self.__launch(encode, collection, nospawn=self.config.gpus == 1)

return self.index_path

def finalize(self, name, collection):
self.configure(collection=collection, index_name=name, resume=True)
self.configure(bsize=64, partitions=None)

self.index_path = self.config.index_path_



self.__launch(finalize, collection, nospawn=True)

return self.index_path

Expand Down
Loading

0 comments on commit 092e424

Please sign in to comment.