feat(query): Add custom lancedb query generation for lancedb search (#…
…518)

Fixes: #511

This PR implements the `CustomStrategy<Q>` trait for LanceDB, enabling
flexible and customizable vector similarity search capabilities. The
implementation follows the successful pattern established in the
pgvector integration while leveraging LanceDB-specific features.

Signed-off-by: shamb0 <[email protected]>
Co-authored-by: Timon Vonk <[email protected]>
shamb0 and timonv authored Jan 16, 2025
1 parent f83f3f0 commit 7f85735
Showing 6 changed files with 276 additions and 116 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

115 changes: 44 additions & 71 deletions swiftide-core/src/search_strategies/custom_strategy.rs
@@ -1,64 +1,28 @@
-//! Generic vector search strategy framework for customizable query generation.
-//!
-//! Provides core abstractions for vector similarity search:
-//! - Generic query type parameter for retriever-specific implementations
-//! - Flexible query generation through closure-based configuration
-//!
-//! This module implements a strategy pattern for vector similarity search,
-//! allowing different retrieval backends to provide their own query generation
-//! logic while maintaining a consistent interface. The framework emphasizes
-//! composition over inheritance, enabling configuration through closures
-//! rather than struct fields.
+//! Implements a flexible vector search strategy framework using closure-based configuration.
+//! Supports both synchronous and asynchronous query generation for different retrieval backends.
 use crate::querying::{self, states, Query};
 use anyhow::{anyhow, Result};
+use std::future::Future;
 use std::marker::PhantomData;
+use std::pin::Pin;
 use std::sync::Arc;

-/// A type alias for query generation functions.
-///
-/// The query generator takes a pending query state and produces a
-/// retriever-specific query type. All configuration parameters should
-/// be captured in the closure's environment.
+// Function type for generating retriever-specific queries
 type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;

-/// `CustomStrategy` provides a flexible way to generate retriever-specific search queries.
-///
-/// This struct implements a strategy pattern for vector similarity search, allowing
-/// different retrieval backends to provide their own query generation logic. Configuration
-/// is managed through the query generation closure, promoting a more flexible and
-/// composable design.
-///
-/// # Type Parameters
-/// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
-///
-/// # Examples
-/// ```ignore
-/// // Define search configuration
-/// const MAX_SEARCH_RESULTS: i64 = 5;
-///
-/// // Create a custom search strategy
-/// let strategy = CustomStrategy::from_query(|query_node| {
-///     let mut builder = QueryBuilder::new();
-///
-///     // Configure search parameters within the closure
-///     builder.push(" LIMIT ");
-///     builder.push_bind(MAX_SEARCH_RESULTS);
-///
-///     Ok(builder)
-/// });
-/// ```
-///
-/// # Implementation Notes
-/// - Search configuration (like result limits and vector fields) should be defined
-///   in the closure's scope
-/// - Implementers are responsible for validating configuration values
-/// - The query generator has access to the full query state for maximum flexibility
+// Function type for async query generation
+type AsyncQueryGenerator<Q> = Arc<
+    dyn Fn(&Query<states::Pending>) -> Pin<Box<dyn Future<Output = Result<Q>> + Send>>
+        + Send
+        + Sync,
+>;
+
+/// Implements the strategy pattern for vector similarity search, allowing retrieval backends
+/// to define custom query generation logic through closures.
 pub struct CustomStrategy<Q> {
-    /// The query generation function now returns a `Q`
     query: Option<QueryGenerator<Q>>,
-
-    /// `PhantomData` to handle the generic parameter
+    async_query: Option<AsyncQueryGenerator<Q>>,
     _marker: PhantomData<Q>,
 }

@@ -68,51 +32,60 @@ impl<Q> Default for CustomStrategy<Q> {
     fn default() -> Self {
         Self {
             query: None,
+            async_query: None,
             _marker: PhantomData,
         }
     }
 }

-// Manual Clone implementation instead of derive
 impl<Q> Clone for CustomStrategy<Q> {
     fn clone(&self) -> Self {
         Self {
-            query: self.query.clone(), // Arc clone is fine
+            query: self.query.clone(),
+            async_query: self.async_query.clone(),
             _marker: PhantomData,
         }
     }
 }

 impl<Q: Send + Sync + 'static> CustomStrategy<Q> {
-    /// Creates a new `CustomStrategy` with a query generation function.
-    ///
-    /// The provided closure should contain all necessary configuration for
-    /// query generation. This design allows for more flexible configuration
-    /// management compared to struct-level fields.
-    ///
-    /// # Parameters
-    /// * `query` - A closure that generates retriever-specific queries
+    /// Creates a new strategy with a synchronous query generator.
     pub fn from_query(
         query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
     ) -> Self {
         Self {
             query: Some(Arc::new(query)),
+            async_query: None,
             _marker: PhantomData,
         }
     }

+    /// Creates a new strategy with an asynchronous query generator.
+    pub fn from_async_query<F>(
+        query: impl Fn(&Query<states::Pending>) -> F + Send + Sync + 'static,
+    ) -> Self
+    where
+        F: Future<Output = Result<Q>> + Send + 'static,
+    {
+        Self {
+            query: None,
+            async_query: Some(Arc::new(move |q| Box::pin(query(q)))),
+            _marker: PhantomData,
+        }
+    }
+
-    /// Gets the query builder, which can then be used to build the actual query
+    /// Generates a query using either the sync or async generator.
+    /// Returns error if no query generator is set.
     ///
     /// # Errors
-    /// This function will return an error if:
-    /// - No query function has been set (use `from_query` to set a query function).
-    /// - The query function fails while processing the provided `query_node`.
-    pub fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
-        match &self.query {
-            Some(query_fn) => Ok(query_fn(query_node)?),
-            None => Err(anyhow!(
-                "No query function has been set. Use from_query() to set a query function."
-            )),
+    /// Returns an error if:
+    /// * No query generator has been configured
+    /// * The configured query generator fails during query generation
+    pub async fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
+        match (&self.query, &self.async_query) {
+            (Some(query_fn), _) => query_fn(query_node),
+            (_, Some(async_fn)) => async_fn(query_node).await,
+            _ => Err(anyhow!("No query function has been set.")),
         }
     }
 }
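
The sync/async dispatch in `build_query` above can be tried out in isolation with the standard library alone. The sketch below is an illustration rather than the crate's code: `PendingQuery`, `Strategy`, `BoxedFuture` and `block_on` are hypothetical stand-ins (for `Query<states::Pending>`, `CustomStrategy<Q>` and a real async runtime), and `String` replaces `anyhow::Error` so that no dependencies are needed.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for swiftide's `Query<states::Pending>`.
pub struct PendingQuery {
    pub text: String,
}

type BoxedFuture<Q> = Pin<Box<dyn Future<Output = Result<Q, String>> + Send>>;
type SyncGen<Q> = Arc<dyn Fn(&PendingQuery) -> Result<Q, String> + Send + Sync>;
type AsyncGen<Q> = Arc<dyn Fn(&PendingQuery) -> BoxedFuture<Q> + Send + Sync>;

/// Simplified analogue of `CustomStrategy<Q>`: holds at most one generator.
pub struct Strategy<Q> {
    sync_gen: Option<SyncGen<Q>>,
    async_gen: Option<AsyncGen<Q>>,
}

impl<Q> Default for Strategy<Q> {
    fn default() -> Self {
        Self { sync_gen: None, async_gen: None }
    }
}

impl<Q: Send + 'static> Strategy<Q> {
    /// Stores a synchronous generator.
    pub fn from_query(
        f: impl Fn(&PendingQuery) -> Result<Q, String> + Send + Sync + 'static,
    ) -> Self {
        Self { sync_gen: Some(Arc::new(f)), async_gen: None }
    }

    /// Stores an asynchronous generator, boxing its future so it fits the alias.
    pub fn from_async_query<F>(f: impl Fn(&PendingQuery) -> F + Send + Sync + 'static) -> Self
    where
        F: Future<Output = Result<Q, String>> + Send + 'static,
    {
        let boxed = move |q: &PendingQuery| -> BoxedFuture<Q> { Box::pin(f(q)) };
        Self { sync_gen: None, async_gen: Some(Arc::new(boxed)) }
    }

    /// Prefers the sync generator, falls back to the async one, else errors.
    pub async fn build_query(&self, q: &PendingQuery) -> Result<Q, String> {
        match (&self.sync_gen, &self.async_gen) {
            (Some(f), _) => f(q),
            (_, Some(f)) => f(q).await,
            _ => Err("No query function has been set.".to_string()),
        }
    }
}

/// Waker that does nothing; enough for futures that never truly suspend.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Tiny polling executor for this demo only; real code would use a runtime.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
        std::thread::yield_now();
    }
}
```

Under these assumptions, `block_on(strategy.build_query(&query))` returns the generator's result for either constructor, and a `Strategy::default()` with no generator yields the "No query function has been set." error. Because the async closure's future must be `'static`, it has to clone whatever it needs from the query before entering the `async move` block.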
130 changes: 90 additions & 40 deletions swiftide-integrations/src/lancedb/retrieve.rs
@@ -1,14 +1,16 @@
 use anyhow::Result;
-use arrow::datatypes::SchemaRef;
-use arrow_array::StringArray;
+use arrow_array::{RecordBatch, StringArray};
 use async_trait::async_trait;
 use futures_util::TryStreamExt;
 use itertools::Itertools;
 use lancedb::query::{ExecutableQuery, QueryBase};
 use swiftide_core::{
     document::Document,
     indexing::Metadata,
-    querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
+    querying::{
+        search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
+        states, Query,
+    },
     Retrieve,
 };

@@ -66,43 +68,7 @@ impl Retrieve<SimilaritySingleEmbedding<String>> for LanceDB {
             .try_collect::<Vec<_>>()
             .await?;

-        let mut documents = vec![];
-
-        for batch in batches {
-            let schema: SchemaRef = batch.schema();
-
-            for row_idx in 0..batch.num_rows() {
-                let mut metadata = Metadata::default();
-                let mut content = String::new();
-
-                for (col_idx, field) in schema.fields().iter().enumerate() {
-                    let column = batch.column(col_idx);
-
-                    if let Some(array) = column.as_any().downcast_ref::<StringArray>() {
-                        if field.name() == "chunk" {
-                            // Extract the "content" field
-                            content = array.value(row_idx).to_string();
-                        } else {
-                            // Assume other fields are part of the metadata
-                            let value = array.value(row_idx).to_string();
-                            metadata.insert(field.name().clone(), value);
-                        }
-                    } else {
-                        // Handle other array types as necessary
-                        // TODO: Can't we just downcast to serde::Value or fail?
-                    }
-                }
-
-                documents.push(Document::new(
-                    content,
-                    if metadata.is_empty() {
-                        None
-                    } else {
-                        Some(metadata)
-                    },
-                ));
-            }
-        }
+        let documents = Self::retrieve_from_record_batches(&batches);

         Ok(query.retrieved_documents(documents))
     }
@@ -124,6 +90,90 @@ impl Retrieve<SimilaritySingleEmbedding> for LanceDB {
     }
 }

+#[async_trait]
+impl<Q: ExecutableQuery + Send + Sync + 'static> Retrieve<CustomStrategy<Q>> for LanceDB {
+    /// Implements vector similarity search for LanceDB using a custom query strategy.
+    ///
+    /// # Type Parameters
+    /// * `VectorQuery` - LanceDB's query type for vector similarity search
+    async fn retrieve(
+        &self,
+        search_strategy: &CustomStrategy<Q>,
+        query: Query<states::Pending>,
+    ) -> Result<Query<states::Retrieved>> {
+        // Build the custom query using both strategy and query state
+        let query_builder = search_strategy.build_query(&query).await?;
+
+        // Execute the query using the builder's built-in methods
+        let batches = query_builder
+            .execute()
+            .await?
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        let documents = Self::retrieve_from_record_batches(&batches);
+
+        Ok(query.retrieved_documents(documents))
+    }
+}
+
+impl LanceDB {
+    /// Retrieves documents from Arrow `RecordBatches` by processing each row and extracting content
+    /// and metadata fields.
+    ///
+    /// The function expects a "chunk" field to contain the main document content, while all other
+    /// string fields are treated as metadata. Non-string fields are currently skipped
+    fn retrieve_from_record_batches(batches: &[RecordBatch]) -> Vec<Document> {
+        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
+        let mut documents = Vec::with_capacity(total_rows);
+
+        let process_batch = |batch: &RecordBatch, documents: &mut Vec<Document>| {
+            for row_idx in 0..batch.num_rows() {
+                let schema = batch.schema();
+
+                let (content, metadata): (String, Option<Metadata>) = {
+                    let mut metadata = Metadata::default();
+                    let mut content = String::new();
+
+                    for (col_idx, field) in schema.as_ref().fields().iter().enumerate() {
+                        if let Some(array) =
+                            batch.column(col_idx).as_any().downcast_ref::<StringArray>()
+                        {
+                            let value = array.value(row_idx).to_string();
+
+                            if field.name() == "chunk" {
+                                content = value;
+                            } else {
+                                metadata.insert(field.name().to_string(), value);
+                            }
+                        } else {
+                            // Handle other array types as necessary
+                            // TODO: Can't we just downcast to serde::Value or fail?
+                        }
+                    }
+
+                    (
+                        content,
+                        if metadata.is_empty() {
+                            None
+                        } else {
+                            Some(metadata)
+                        },
+                    )
+                };
+
+                documents.push(Document::new(content, metadata));
+            }
+        };
+
+        batches
+            .iter()
+            .for_each(|batch| process_batch(batch, &mut documents));
+
+        documents
+    }
+}
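
The row-wise extraction in `retrieve_from_record_batches` can be illustrated without Arrow. In the sketch below, `Column`, `Batch`, `Doc` and `docs_from_batches` are hypothetical stand-ins for Arrow arrays, `RecordBatch`, swiftide's `Document` and the new helper. It mirrors the same rules under those assumptions: the "chunk" column becomes the content, other string columns become metadata, and non-string columns are skipped.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for an Arrow column; only string columns carry values here.
pub enum Column {
    Strings(Vec<String>),
    /// Any non-string column (for example an embedding vector), which is skipped.
    Other,
}

/// Hypothetical stand-in for an Arrow `RecordBatch`.
pub struct Batch {
    pub fields: Vec<String>,
    pub columns: Vec<Column>,
    pub num_rows: usize,
}

/// Hypothetical stand-in for swiftide's `Document`.
#[derive(Debug, PartialEq)]
pub struct Doc {
    pub content: String,
    pub metadata: Option<BTreeMap<String, String>>,
}

/// Builds one `Doc` per row across all batches.
pub fn docs_from_batches(batches: &[Batch]) -> Vec<Doc> {
    // Pre-size the output from the total row count, as the helper above does.
    let total_rows: usize = batches.iter().map(|b| b.num_rows).sum();
    let mut docs = Vec::with_capacity(total_rows);

    for batch in batches {
        for row in 0..batch.num_rows {
            let mut content = String::new();
            let mut metadata = BTreeMap::new();

            for (name, column) in batch.fields.iter().zip(&batch.columns) {
                // Non-string columns fall through and are ignored.
                if let Column::Strings(values) = column {
                    let value = values[row].clone();
                    if name.as_str() == "chunk" {
                        content = value;
                    } else {
                        metadata.insert(name.clone(), value);
                    }
                }
            }

            // An empty metadata map is reported as `None`, matching the diff.
            let metadata = if metadata.is_empty() { None } else { Some(metadata) };
            docs.push(Doc { content, metadata });
        }
    }

    docs
}
```

Keeping the conversion in one function, as the commit does, lets both the `SimilaritySingleEmbedding` and `CustomStrategy` retrievers share it instead of duplicating the loop.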

#[cfg(test)]
mod test {
use swiftide_core::{
2 changes: 1 addition & 1 deletion swiftide-integrations/src/pgvector/retrieve.rs
@@ -166,7 +166,7 @@ impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for P
         let pool = self.get_pool().await?;

         // Build the custom query using both strategy and query state
-        let mut query_builder = search_strategy.build_query(&query)?;
+        let mut query_builder = search_strategy.build_query(&query).await?;

         // Execute the query using the builder's built-in methods
         let results = query_builder
1 change: 1 addition & 0 deletions swiftide/Cargo.toml
@@ -120,6 +120,7 @@ serde_json = { workspace = true }
 tokio = { workspace = true }
 arrow-array = { workspace = true }
 sqlx = { workspace = true }
+lancedb = { workspace = true }

 [lints]
 workspace = true