From 7f857358e46e825494ba927dffb33c3afa0d762e Mon Sep 17 00:00:00 2001 From: RK Date: Thu, 16 Jan 2025 19:13:27 +0530 Subject: [PATCH] feat(query): Add custom lancedb query generation for lancedb search (#518) Fixes: #511 This PR implements the `CustomStrategy` trait for LanceDB, enabling flexible and customizable vector similarity search capabilities. The implementation follows the successful pattern established in the pgvector integration while leveraging LanceDB-specific features. Signed-off-by: shamb0 Co-authored-by: Timon Vonk --- Cargo.lock | 1 + .../src/search_strategies/custom_strategy.rs | 115 ++++++-------- swiftide-integrations/src/lancedb/retrieve.rs | 130 +++++++++++----- .../src/pgvector/retrieve.rs | 2 +- swiftide/Cargo.toml | 1 + swiftide/tests/lancedb.rs | 143 +++++++++++++++++- 6 files changed, 276 insertions(+), 116 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9eb9b9aa..686ec29a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8521,6 +8521,7 @@ dependencies = [ "async-openai", "async-trait", "document-features", + "lancedb", "mockall", "qdrant-client", "serde", diff --git a/swiftide-core/src/search_strategies/custom_strategy.rs b/swiftide-core/src/search_strategies/custom_strategy.rs index d2e9ea6c..0d6c9615 100644 --- a/swiftide-core/src/search_strategies/custom_strategy.rs +++ b/swiftide-core/src/search_strategies/custom_strategy.rs @@ -1,64 +1,28 @@ -//! Generic vector search strategy framework for customizable query generation. -//! -//! Provides core abstractions for vector similarity search: -//! - Generic query type parameter for retriever-specific implementations -//! - Flexible query generation through closure-based configuration -//! -//! This module implements a strategy pattern for vector similarity search, -//! allowing different retrieval backends to provide their own query generation -//! logic while maintaining a consistent interface. The framework emphasizes -//! composition over inheritance, enabling configuration through closures -//! rather than struct fields. +//! Implements a flexible vector search strategy framework using closure-based configuration. +//! Supports both synchronous and asynchronous query generation for different retrieval backends. use crate::querying::{self, states, Query}; use anyhow::{anyhow, Result}; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::sync::Arc; -/// A type alias for query generation functions. -/// -/// The query generator takes a pending query state and produces a -/// retriever-specific query type. All configuration parameters should -/// be captured in the closure's environment. +// Function type for generating retriever-specific queries type QueryGenerator = Arc) -> Result + Send + Sync>; -/// `CustomStrategy` provides a flexible way to generate retriever-specific search queries. -/// -/// This struct implements a strategy pattern for vector similarity search, allowing -/// different retrieval backends to provide their own query generation logic. Configuration -/// is managed through the query generation closure, promoting a more flexible and -/// composable design. -/// -/// # Type Parameters -/// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`) -/// -/// # Examples -/// ```ignore -/// // Define search configuration -/// const MAX_SEARCH_RESULTS: i64 = 5; -/// -/// // Create a custom search strategy -/// let strategy = CustomStrategy::from_query(|query_node| { -/// let mut builder = QueryBuilder::new(); -/// -/// // Configure search parameters within the closure -/// builder.push(" LIMIT "); -/// builder.push_bind(MAX_SEARCH_RESULTS); -/// -/// Ok(builder) -/// }); -/// ``` -/// -/// # Implementation Notes -/// - Search configuration (like result limits and vector fields) should be defined -/// in the closure's scope -/// - Implementers are responsible for validating configuration values -/// - The query generator has access to the full query state for maximum flexibility +// Function type for async query generation +type AsyncQueryGenerator = Arc< + dyn Fn(&Query) -> Pin> + Send>> + + Send + + Sync, +>; + +/// Implements the strategy pattern for vector similarity search, allowing retrieval backends +/// to define custom query generation logic through closures. pub struct CustomStrategy { - /// The query generation function now returns a `Q` query: Option>, - - /// `PhantomData` to handle the generic parameter + async_query: Option>, _marker: PhantomData, } @@ -68,51 +32,60 @@ impl Default for CustomStrategy { fn default() -> Self { Self { query: None, + async_query: None, _marker: PhantomData, } } } -// Manual Clone implementation instead of derive impl Clone for CustomStrategy { fn clone(&self) -> Self { Self { - query: self.query.clone(), // Arc clone is fine + query: self.query.clone(), + async_query: self.async_query.clone(), _marker: PhantomData, } } } impl CustomStrategy { - /// Creates a new `CustomStrategy` with a query generation function. - /// - /// The provided closure should contain all necessary configuration for - /// query generation. This design allows for more flexible configuration - /// management compared to struct-level fields. - /// - /// # Parameters - /// * `query` - A closure that generates retriever-specific queries + /// Creates a new strategy with a synchronous query generator. pub fn from_query( query: impl Fn(&Query) -> Result + Send + Sync + 'static, ) -> Self { Self { query: Some(Arc::new(query)), + async_query: None, + _marker: PhantomData, + } + } + + /// Creates a new strategy with an asynchronous query generator. + pub fn from_async_query( + query: impl Fn(&Query) -> F + Send + Sync + 'static, + ) -> Self + where + F: Future> + Send + 'static, + { + Self { + query: None, + async_query: Some(Arc::new(move |q| Box::pin(query(q)))), _marker: PhantomData, } } - /// Gets the query builder, which can then be used to build the actual query + /// Generates a query using either the sync or async generator. + /// Returns error if no query generator is set. /// /// # Errors - /// This function will return an error if: - /// - No query function has been set (use `from_query` to set a query function). - /// - The query function fails while processing the provided `query_node`. - pub fn build_query(&self, query_node: &Query) -> Result { - match &self.query { - Some(query_fn) => Ok(query_fn(query_node)?), - None => Err(anyhow!( - "No query function has been set. Use from_query() to set a query function." - )), + /// Returns an error if: + /// * No query generator has been configured + /// * The configured query generator fails during query generation + pub async fn build_query(&self, query_node: &Query) -> Result { + match (&self.query, &self.async_query) { + (Some(query_fn), _) => query_fn(query_node), + (_, Some(async_fn)) => async_fn(query_node).await, + _ => Err(anyhow!("No query function has been set.")), } } } diff --git a/swiftide-integrations/src/lancedb/retrieve.rs b/swiftide-integrations/src/lancedb/retrieve.rs index 0de3b9b6..9c68aafb 100644 --- a/swiftide-integrations/src/lancedb/retrieve.rs +++ b/swiftide-integrations/src/lancedb/retrieve.rs @@ -1,6 +1,5 @@ use anyhow::Result; -use arrow::datatypes::SchemaRef; -use arrow_array::StringArray; +use arrow_array::{RecordBatch, StringArray}; use async_trait::async_trait; use futures_util::TryStreamExt; use itertools::Itertools; @@ -8,7 +7,10 @@ use lancedb::query::{ExecutableQuery, QueryBase}; use swiftide_core::{ document::Document, indexing::Metadata, - querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, + querying::{ + search_strategies::{CustomStrategy, SimilaritySingleEmbedding}, + states, Query, + }, Retrieve, }; @@ -66,43 +68,7 @@ impl Retrieve> for LanceDB { .try_collect::>() .await?; - let mut documents = vec![]; - - for batch in batches { - let schema: SchemaRef = batch.schema(); - - for row_idx in 0..batch.num_rows() { - let mut metadata = Metadata::default(); - let mut content = String::new(); - - for (col_idx, field) in schema.fields().iter().enumerate() { - let column = batch.column(col_idx); - - if let Some(array) = column.as_any().downcast_ref::() { - if field.name() == "chunk" { - // Extract the "content" field - content = array.value(row_idx).to_string(); - } else { - // Assume other fields are part of the metadata - let value = array.value(row_idx).to_string(); - metadata.insert(field.name().clone(), value); - } - } else { - // Handle other array types as necessary - // TODO: Can't we just downcast to serde::Value or fail? - } - } - - documents.push(Document::new( - content, - if metadata.is_empty() { - None - } else { - Some(metadata) - }, - )); - } - } + let documents = Self::retrieve_from_record_batches(&batches); Ok(query.retrieved_documents(documents)) } @@ -124,6 +90,90 @@ impl Retrieve for LanceDB { } } +#[async_trait] +impl Retrieve> for LanceDB { + /// Implements vector similarity search for LanceDB using a custom query strategy. + /// + /// # Type Parameters + /// * `VectorQuery` - LanceDB's query type for vector similarity search + async fn retrieve( + &self, + search_strategy: &CustomStrategy, + query: Query, + ) -> Result> { + // Build the custom query using both strategy and query state + let query_builder = search_strategy.build_query(&query).await?; + + // Execute the query using the builder's built-in methods + let batches = query_builder + .execute() + .await? + .try_collect::>() + .await?; + + let documents = Self::retrieve_from_record_batches(&batches); + + Ok(query.retrieved_documents(documents)) + } +} + +impl LanceDB { + /// Retrieves documents from Arrow `RecordBatches` by processing each row and extracting content + /// and metadata fields. + /// + /// The function expects a "chunk" field to contain the main document content, while all other + /// string fields are treated as metadata. Non-string fields are currently skipped + fn retrieve_from_record_batches(batches: &[RecordBatch]) -> Vec { + let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); + let mut documents = Vec::with_capacity(total_rows); + + let process_batch = |batch: &RecordBatch, documents: &mut Vec| { + for row_idx in 0..batch.num_rows() { + let schema = batch.schema(); + + let (content, metadata): (String, Option) = { + let mut metadata = Metadata::default(); + let mut content = String::new(); + + for (col_idx, field) in schema.as_ref().fields().iter().enumerate() { + if let Some(array) = + batch.column(col_idx).as_any().downcast_ref::() + { + let value = array.value(row_idx).to_string(); + + if field.name() == "chunk" { + content = value; + } else { + metadata.insert(field.name().to_string(), value); + } + } else { + // Handle other array types as necessary + // TODO: Can't we just downcast to serde::Value or fail? + } + } + + ( + content, + if metadata.is_empty() { + None + } else { + Some(metadata) + }, + ) + }; + + documents.push(Document::new(content, metadata)); + } + }; + + batches + .iter() + .for_each(|batch| process_batch(batch, &mut documents)); + + documents + } +} + #[cfg(test)] mod test { use swiftide_core::{ diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs index a67349d7..6cdfcba4 100644 --- a/swiftide-integrations/src/pgvector/retrieve.rs +++ b/swiftide-integrations/src/pgvector/retrieve.rs @@ -166,7 +166,7 @@ impl Retrieve>> for P let pool = self.get_pool().await?; // Build the custom query using both strategy and query state - let mut query_builder = search_strategy.build_query(&query)?; + let mut query_builder = search_strategy.build_query(&query).await?; // Execute the query using the builder's built-in methods let results = query_builder diff --git a/swiftide/Cargo.toml b/swiftide/Cargo.toml index d8853fc5..edac2fed 100644 --- a/swiftide/Cargo.toml +++ b/swiftide/Cargo.toml @@ -120,6 +120,7 @@ serde_json = { workspace = true } tokio = { workspace = true } arrow-array = { workspace = true } sqlx = { workspace = true } +lancedb = { workspace = true } [lints] workspace = true diff --git a/swiftide/tests/lancedb.rs b/swiftide/tests/lancedb.rs index f2f1b583..37c335bc 100644 --- a/swiftide/tests/lancedb.rs +++ b/swiftide/tests/lancedb.rs @@ -1,11 +1,16 @@ +use anyhow::Context; +use lancedb::query::{self as lance_query_builder, QueryBase}; use swiftide::indexing; use swiftide::indexing::{ transformers::{metadata_qa_code::NAME as METADATA_QA_CODE_NAME, ChunkCode, MetadataQACode}, EmbeddedField, }; -use swiftide::query::{self, states, Query}; +use swiftide::query::{self as swift_query_pipeline, states, Query}; use swiftide_indexing::{loaders, transformers, Pipeline}; -use swiftide_integrations::{fastembed::FastEmbed, lancedb::LanceDB}; +use swiftide_integrations::{ + fastembed::FastEmbed, + lancedb::{self as lance_integration, LanceDB}, +}; use swiftide_query::{answers, query_transformers, response_transformers}; use swiftide_test_utils::{mock_chat_completions, openai_client}; use temp_dir::TempDir; @@ -55,11 +60,11 @@ async fn test_lancedb() { .await .unwrap(); - let strategy = query::search_strategies::SimilaritySingleEmbedding::from_filter( + let strategy = swift_query_pipeline::search_strategies::SimilaritySingleEmbedding::from_filter( "filter = \"true\"".to_string(), ); - let query_pipeline = query::Pipeline::from_search_strategy(strategy) + let query_pipeline = swift_query_pipeline::Pipeline::from_search_strategy(strategy) .then_transform_query(query_transformers::GenerateSubquestions::from_client( openai_client.clone(), )) @@ -87,3 +92,133 @@ async fn test_lancedb() { codefile.to_str().unwrap() ); } + +#[test_log::test(tokio::test)] +async fn test_lancedb_retrieve_dynamic_search() { + // Setup temporary directory and file for testing + let tempdir = TempDir::new().unwrap(); + let codefile = tempdir.child("main.rs"); + let code = "fn main() { println!(\"Hello, World!\"); }"; + std::fs::write(&codefile, code).unwrap(); + + // Setup mock servers to simulate API responses + let mock_server = MockServer::start().await; + mock_chat_completions(&mock_server).await; + + let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); + + let fastembed = FastEmbed::try_default().unwrap(); + + let lancedb = LanceDB::builder() + .uri(tempdir.child("lancedb").to_str().unwrap()) + .vector_size(384) + .with_vector(EmbeddedField::Combined) + .with_metadata(METADATA_QA_CODE_NAME) + .with_metadata("filter") + .with_metadata("path") + .table_name("swiftide_test") + .build() + .unwrap(); + + Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) + .then_chunk(ChunkCode::try_for_language("rust").unwrap()) + .then(MetadataQACode::new(openai_client.clone())) + .then(|mut node: indexing::Node| { + // Add path to metadata, by default, storage will store all metadata fields + node.metadata + .insert("path", node.path.display().to_string()); + node.metadata + .insert("filter".to_string(), "true".to_string()); + Ok(node) + }) + .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) + .log_nodes() + .then_store_with(lancedb.clone()) + .run() + .await + .unwrap(); + + // Create the custom query strategy for vector similarity search + let create_vector_search_strategy = + |lancedb: &LanceDB, + table_name: String| + -> swift_query_pipeline::search_strategies::CustomStrategy< + lance_query_builder::VectorQuery, + > { + let table_name = table_name.clone(); + let lancedb = lancedb.clone(); + + swift_query_pipeline::search_strategies::CustomStrategy::from_async_query( + move |query_node| { + // Create owned copies for the async block + let table_name = table_name.clone(); + let lancedb = lancedb.clone(); + + let embedding = if let Some(embedding) = &query_node.embedding { + embedding.clone() + } else { + panic!("Query embedding not found"); + }; + + // Return a Future using async block syntax + Box::pin(async move { + // Create a new connection for each query execution + let connection = lancedb.get_connection().await?; + + // Open the table within the query execution context + let vector_table = connection + .open_table(&table_name) + .execute() + .await + .context("Failed to open vector search table")?; + + let vector_field = + lance_integration::VectorConfig::from(EmbeddedField::Combined) + .field_name(); + + // Build and return the query + let query_builder = vector_table + .query() + .nearest_to(embedding.as_slice())? + .column(&vector_field) + .limit(20); + + Ok(query_builder) + // Connection is dropped here when query_builder is executed + }) + }, + ) + }; + + let vector_search_strategy = + create_vector_search_strategy(&lancedb, "swiftide_test".to_string()); + + let query_pipeline = + swift_query_pipeline::Pipeline::from_search_strategy(vector_search_strategy) + .then_transform_query(query_transformers::GenerateSubquestions::from_client( + openai_client.clone(), + )) + .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) + .then_retrieve(lancedb.clone()) + .then_transform_response(response_transformers::Summary::from_client( + openai_client.clone(), + )) + .then_answer(answers::Simple::from_client(openai_client.clone())); + + let result: Query = query_pipeline.query("What is swiftide?").await.unwrap(); + + dbg!(&result); + + assert_eq!( + result.answer(), + "\n\nHello there, how may I assist you today?" + ); + + let retrieved_document = result.documents().first().unwrap(); + assert_eq!(retrieved_document.content(), code); + + assert_eq!( + retrieved_document.metadata().get("path").unwrap(), + codefile.to_str().unwrap() + ); +}