TorchPQ is a Python library for Approximate Nearest Neighbor Search (ANNS) and Maximum Inner Product Search (MIPS) on GPU using the Product Quantization (PQ) algorithm. TorchPQ is implemented mainly with PyTorch, with some extra CUDA kernels to accelerate clustering, indexing and searching.
- Make sure you have the latest version of PyTorch installed: https://pytorch.org/get-started/locally/
- Install a version of the CuPy library that matches your CUDA version, for example:
pip install cupy-cuda90
pip install cupy-cuda100
pip install cupy-cuda101
...
For a full list of cupy-cuda versions, please see the CuPy Installation Guide.
- Install TorchPQ:
pip install torchpq
Inverted File Product Quantization (IVFPQ) is an ANN search algorithm designed for fast and efficient vector search in million- or even billion-scale vector sets. Check the original paper for more details.
from torchpq import IVFPQ
import torch
n_data = 1000000 # number of data points
d_vector = 128 # dimensionality / number of features
index = IVFPQ(
    d_vector=d_vector,
    n_subvectors=64,
    n_cq_clusters=1024,
    n_pq_clusters=256,
    blocksize=128,
    distance="euclidean",
)
x = torch.randn(d_vector, n_data, device="cuda:0")
index.train(x)
There are some important parameters that need to be explained:
- d_vector: dimensionality of input vectors. There are 2 constraints on d_vector: (1) it needs to be divisible by n_subvectors; (2) it needs to be a multiple of 4.*
- n_subvectors: number of subquantizers; essentially this is the byte size of each quantized vector, 64 bytes per vector in the above example.**
- n_cq_clusters: number of coarse quantizer clusters
- n_pq_clusters: number of product quantizer clusters, this is assumed to be 256 throughout the entire project, and should NOT be changed.
- blocksize: initial capacity assigned to each Voronoi cell of the coarse quantizer. n_cq_clusters * blocksize is the number of vectors that can be stored initially. If any cell reaches its capacity, that cell is expanded automatically. If you need to add vectors frequently, a larger value for blocksize is recommended.
Remember that the shape of any tensor that contains data points has to be [d_vector, n_data].
* The second constraint may be removed in the future.
** The actual byte size is (n_subvectors + 9) bytes: 8 bytes for the ID and 1 byte for the is_empty flag.
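The constraints and byte counts above can be checked with a few lines of plain Python. The helper ivfpq_budget is hypothetical (it is not a TorchPQ function); it only restates the rules from the parameter list and footnotes and applies them to the example configuration.

```python
def ivfpq_budget(d_vector, n_subvectors, n_cq_clusters, blocksize, n_data):
    """Hypothetical helper (not part of TorchPQ): checks the documented
    d_vector constraints and estimates the storage used by the codes."""
    if d_vector % n_subvectors != 0:
        raise ValueError("d_vector must be divisible by n_subvectors")
    if d_vector % 4 != 0:
        raise ValueError("d_vector must be a multiple of 4")
    bytes_per_vector = n_subvectors + 9   # PQ codes + 8-byte ID + 1-byte is_empty flag
    raw_bytes_per_vector = d_vector * 4   # the same vector stored as float32
    return {
        "subvector_dim": d_vector // n_subvectors,
        "initial_capacity": n_cq_clusters * blocksize,
        "bytes_per_vector": bytes_per_vector,
        "total_code_bytes": bytes_per_vector * n_data,
        "compression_ratio": raw_bytes_per_vector / bytes_per_vector,
    }

budget = ivfpq_budget(d_vector=128, n_subvectors=64, n_cq_clusters=1024,
                      blocksize=128, n_data=1_000_000)
print(budget["subvector_dim"])                # 2
print(budget["initial_capacity"])             # 131072
print(budget["bytes_per_vector"])             # 73
print(budget["total_code_bytes"])             # 73000000
print(round(budget["compression_ratio"], 2))  # 7.01
```

The estimate covers only the per-vector codes; the coarse centroids, the PQ codebooks and any unused slots left over from blocksize over-allocation come on top of it.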
ids = torch.arange(n_data, device="cuda")
index.add(x, input_ids=ids)
Each ID in ids needs to be a unique int64 (torch.long) value that identifies a vector in x.
If input_ids is not provided, it will be set to torch.arange(n_data, device="cuda") + previous_max_id.
index.remove(ids)
index.remove(ids) will virtually remove the vectors with the specified ids from storage.
It ignores ids that don't exist.
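The exact storage layout is internal to TorchPQ, but the behaviour described above (vectors are flagged rather than physically deleted, and unknown ids are skipped) can be pictured with a small pure-Python model. ToyStore is a hypothetical class written for illustration only; its per-slot is_empty flag mirrors the 1-byte is_empty field mentioned in the byte-size footnote.

```python
class ToyStore:
    """Hypothetical model of virtual removal (not TorchPQ code): every slot
    keeps its code, its ID and an is_empty flag; removal only flips the flag."""

    def __init__(self):
        self.codes, self.ids, self.is_empty = [], [], []
        self.slot_of = {}  # id -> slot index, for fast lookup

    def add(self, codes, ids):
        for code, id_ in zip(codes, ids):
            if id_ in self.slot_of:
                raise ValueError(f"duplicate id {id_}")  # IDs must be unique
            self.slot_of[id_] = len(self.ids)
            self.codes.append(code)
            self.ids.append(id_)
            self.is_empty.append(False)

    def remove(self, ids):
        for id_ in ids:
            slot = self.slot_of.pop(id_, None)  # unknown ids are silently ignored
            if slot is not None:
                self.is_empty[slot] = True      # data stays in place, only flagged

    def live_ids(self):
        return [i for i, empty in zip(self.ids, self.is_empty) if not empty]

store = ToyStore()
store.add(codes=["a", "b", "c"], ids=[10, 11, 12])
store.remove([11, 999])   # 999 does not exist and is skipped
print(store.live_ids())   # [10, 12]
print(len(store.ids))     # 3 (the slot that held id 11 is still allocated)
```

In this model, remove never frees memory; flagged slots are simply skipped when results are collected.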
index.n_probe = 32
n_query = 10000
query = torch.randn(d_vector, n_query, device="cuda:0")
topk_values, topk_ids = index.topk(query, k=100)
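Conceptually, n_probe controls how many Voronoi cells of the coarse quantizer are scanned for each query: larger values raise recall at the cost of speed. The pure-Python toy below illustrates only that two-stage structure. It uses exact distances instead of PQ codes, the function names are hypothetical, and it is not how TorchPQ's CUDA kernels are organised.

```python
def sq_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def nearest(vec, centroids):
    """Index of the centroid closest to vec."""
    return min(range(len(centroids)), key=lambda c: sq_l2(vec, centroids[c]))

def build_cells(data, centroids):
    """Inverted lists: assign every (id, vector) pair to its nearest centroid."""
    cells = {}
    for id_, vec in data:
        cells.setdefault(nearest(vec, centroids), []).append((id_, vec))
    return cells

def ivf_search(query, centroids, cells, n_probe, k):
    """Return (scores, ids) of the k best matches, with score = -squared L2."""
    # Stage 1: rank cells by centroid distance and keep the n_probe closest.
    probe = sorted(range(len(centroids)),
                   key=lambda c: sq_l2(query, centroids[c]))[:n_probe]
    # Stage 2: score only the vectors stored in the probed cells.
    candidates = [(-sq_l2(query, vec), id_)
                  for c in probe for id_, vec in cells.get(c, [])]
    candidates.sort(reverse=True)  # higher score = closer
    top = candidates[:k]
    return [s for s, _ in top], [i for _, i in top]

centroids = [[0.0, 0.0], [10.0, 0.0]]
data = [(0, [0.0, 0.0]), (1, [6.0, 0.0]), (2, [-1.0, 0.0]), (3, [10.0, 1.0])]
cells = build_cells(data, centroids)  # ids 0, 2 -> cell 0; ids 1, 3 -> cell 1
query = [4.0, 0.0]                    # its closest centroid is cell 0

print(ivf_search(query, centroids, cells, n_probe=1, k=1))  # ([-16.0], [0])
print(ivf_search(query, centroids, cells, n_probe=2, k=1))  # ([-4.0], [1])
```

With n_probe=1 the true nearest neighbour (id 1) is missed because it lives in a neighbouring cell; probing one more cell recovers it. This is the recall/speed trade-off that a sweep over n_probe explores.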
- when distance="inner", topk_values are the inner products between the queries and their top-k closest data points.
- when distance="euclidean", topk_values are the negative squared L2 distances between the queries and their top-k closest data points.
- when distance="manhattan", topk_values are the negative L1 distances between the queries and their top-k closest data points.
- when distance="cosine", topk_values are the cosine similarities between the queries and their top-k closest data points.
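Negating the two distances makes all four conventions share one rule: a larger topk_value always means a closer match. The sketch below writes out the exact quantities in plain Python; the function name score is hypothetical, and since TorchPQ computes its values from the compressed codes, its topk_values should be read as approximations of these.

```python
import math

def score(q, x, distance):
    """Exact versions of the documented topk_values conventions
    (hypothetical helper): in every case, larger means closer."""
    if distance == "inner":
        return sum(a * b for a, b in zip(q, x))
    if distance == "euclidean":
        return -sum((a - b) ** 2 for a, b in zip(q, x))  # negative squared L2
    if distance == "manhattan":
        return -sum(abs(a - b) for a, b in zip(q, x))    # negative L1
    if distance == "cosine":
        dot = sum(a * b for a, b in zip(q, x))
        norms = math.sqrt(sum(a * a for a in q)) * math.sqrt(sum(b * b for b in x))
        return dot / norms
    raise ValueError(f"unknown distance: {distance}")

q, x = [1.0, 2.0], [3.0, 4.0]
print(score(q, x, "inner"))             # 11.0
print(score(q, x, "euclidean"))         # -8.0
print(score(q, x, "manhattan"))         # -4.0
print(round(score(q, x, "cosine"), 4))  # 0.9839
```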
You can use IVFPQ as a vector codec for lossy compression of vectors:
code = index.encode(query) # compression
reconstruction = index.decode(code) # reconstruction
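To make the codec idea concrete, here is a toy product quantizer in plain Python: each vector is split into n_subvectors contiguous subvectors, each subvector is replaced by the index of its nearest codeword (one byte when there are at most 256 codewords), and decoding concatenates the selected codewords. The tiny hand-written codebooks and the function names are hypothetical; in practice the codebooks are learned from training data, typically with k-means.

```python
def sq_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def pq_encode(vec, codebooks):
    """codebooks[m] holds the codewords for subvector m; returns one byte per subvector."""
    sub_dim = len(vec) // len(codebooks)
    code = []
    for m, book in enumerate(codebooks):
        sub = vec[m * sub_dim:(m + 1) * sub_dim]
        code.append(min(range(len(book)), key=lambda j: sq_l2(sub, book[j])))
    return bytes(code)  # valid as long as every codebook has <= 256 codewords

def pq_decode(code, codebooks):
    """Concatenate the chosen codeword of every subvector (lossy reconstruction)."""
    out = []
    for m, j in enumerate(code):
        out.extend(codebooks[m][j])
    return out

codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],   # codewords for dimensions 0-1
    [[0.0, 0.0], [2.0, -2.0]],  # codewords for dimensions 2-3
]
vec = [0.9, 1.2, 1.8, -2.1]
code = pq_encode(vec, codebooks)
print(list(code))                  # [1, 1]
print(pq_decode(code, codebooks))  # [1.0, 1.0, 2.0, -2.0]
```

The reconstruction is close to, but not equal to, the input; that gap is the quantization error that makes the compression lossy.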
Most TorchPQ modules inherit from torch.nn.Module, which means you can save and load them just like a regular PyTorch model.
# Save to PATH
torch.save(index.state_dict(), PATH)
# Load from PATH
index.load_state_dict(torch.load(PATH))
from torchpq.kmeans import KMeans
import torch
n_data = 1000000 # number of data points
d_vector = 128 # dimensionality / number of features
x = torch.randn(d_vector, n_data, device="cuda")
kmeans = KMeans(n_clusters=4096, distance="euclidean")
labels = kmeans.fit(x)
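For readers new to the algorithm, the plain-Python sketch below shows the textbook Lloyd iteration for k-means: assign each point to its nearest centroid, then move each centroid to the mean of its points. Like the benchmark setup further down, it initialises centroids from randomly chosen data points and runs a fixed number of iterations. It works on a list of points for readability rather than the [d_vector, n_data] layout, and lloyd_kmeans is a hypothetical name, not the TorchPQ implementation.

```python
import random

def sq_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def lloyd_kmeans(points, n_clusters, n_iter=15, seed=0):
    """Textbook Lloyd's k-means. Returns (centroids, labels)."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, n_clusters)]  # init from data
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point.
        labels = [min(range(n_clusters), key=lambda c: sq_l2(p, centroids[c]))
                  for p in points]
        # Update step: move each centroid to the mean of its assigned points.
        for c in range(n_clusters):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the previous centroid if the cluster is empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids, labels

points = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
          [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
centroids, labels = lloyd_kmeans(points, n_clusters=2)
print(labels)  # the two groups get different labels; which is 0 depends on the seed
```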
Notice that the shape of the tensor that contains data points has to be [d_vector, n_data]; this is consistent throughout TorchPQ.
Sometimes we have multiple independent datasets that need to be clustered. Instead of running multiple KMeans sequentially, we can perform multiple k-means runs concurrently with MultiKMeans:
from torchpq.kmeans import MultiKMeans
import torch
n_data = 1000000
n_kmeans = 16
d_vector = 64
x = torch.randn(n_kmeans, d_vector, n_data, device="cuda")
kmeans = MultiKMeans(n_clusters=256, distance="euclidean")
labels = kmeans.fit(x)
labels = kmeans.predict(x)
All experiments were performed with a Tesla T4 GPU.
Faiss is one of the most well-known ANN search libraries, and it also has a GPU implementation of IVFPQ, so we ran some comparison experiments against faiss.
How to read the plot:
- the plot format follows the style of ann-benchmarks
- X axis is recall@1, Y axis is queries/second
- the closer to the top right corner the better
- indexes with the same parameters from different libraries have similar colors.
- different libraries have different line styles (TorchPQ is a solid line with circle markers, faiss is a dashed line with triangle markers)
- each node on a line represents a different n_probe, starting from 1 at the leftmost node and doubling at each subsequent node (n_probe = 1, 2, 4, 8, 16, ...).
Summary:
- for all the IVF16384 variants, TorchPQ outperforms faiss when n_probe > 16.
- for IVF4096, TorchPQ has lower recall@1 than faiss; this could be caused by not encoding residuals. An option to encode residuals will be added soon.
coming soon...
Performing K-Means clustering on float32 data randomly sampled from a normal distribution.
- Number of iterations is set to 15.
- Tolerance is set to 0 in order to perform all 15 iterations of K-Means
- Initial centroids are randomly chosen from training data
- All runs are performed on a Tesla T4 GPU
Contestants:
- TorchPQ.kmeans.KMeans
- faiss.Clustering
- KeOps