Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(query): Add custom lancedb query generation for lancedb search #518

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

115 changes: 44 additions & 71 deletions swiftide-core/src/search_strategies/custom_strategy.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,28 @@
//! Generic vector search strategy framework for customizable query generation.
//!
//! Provides core abstractions for vector similarity search:
//! - Generic query type parameter for retriever-specific implementations
//! - Flexible query generation through closure-based configuration
//!
//! This module implements a strategy pattern for vector similarity search,
//! allowing different retrieval backends to provide their own query generation
//! logic while maintaining a consistent interface. The framework emphasizes
//! composition over inheritance, enabling configuration through closures
//! rather than struct fields.
//! Implements a flexible vector search strategy framework using closure-based configuration.
//! Supports both synchronous and asynchronous query generation for different retrieval backends.

use crate::querying::{self, states, Query};
use anyhow::{anyhow, Result};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;

/// A type alias for query generation functions.
///
/// The query generator takes a pending query state and produces a
/// retriever-specific query type. All configuration parameters should
/// be captured in the closure's environment.
// Function type for generating retriever-specific queries
type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;

/// `CustomStrategy` provides a flexible way to generate retriever-specific search queries.
///
/// This struct implements a strategy pattern for vector similarity search, allowing
/// different retrieval backends to provide their own query generation logic. Configuration
/// is managed through the query generation closure, promoting a more flexible and
/// composable design.
///
/// # Type Parameters
/// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
///
/// # Examples
/// ```ignore
/// // Define search configuration
/// const MAX_SEARCH_RESULTS: i64 = 5;
///
/// // Create a custom search strategy
/// let strategy = CustomStrategy::from_query(|query_node| {
/// let mut builder = QueryBuilder::new();
///
/// // Configure search parameters within the closure
/// builder.push(" LIMIT ");
/// builder.push_bind(MAX_SEARCH_RESULTS);
///
/// Ok(builder)
/// });
/// ```
///
/// # Implementation Notes
/// - Search configuration (like result limits and vector fields) should be defined
/// in the closure's scope
/// - Implementers are responsible for validating configuration values
/// - The query generator has access to the full query state for maximum flexibility
// Function type for async query generation
type AsyncQueryGenerator<Q> = Arc<
dyn Fn(&Query<states::Pending>) -> Pin<Box<dyn Future<Output = Result<Q>> + Send>>
+ Send
+ Sync,
>;

/// Implements the strategy pattern for vector similarity search, allowing retrieval backends
/// to define custom query generation logic through closures.
pub struct CustomStrategy<Q> {
/// The query generation function now returns a `Q`
query: Option<QueryGenerator<Q>>,

/// `PhantomData` to handle the generic parameter
async_query: Option<AsyncQueryGenerator<Q>>,
_marker: PhantomData<Q>,
}

Expand All @@ -68,51 +32,60 @@ impl<Q> Default for CustomStrategy<Q> {
fn default() -> Self {
Self {
query: None,
async_query: None,
_marker: PhantomData,
}
}
}

// Manual Clone implementation instead of derive
impl<Q> Clone for CustomStrategy<Q> {
fn clone(&self) -> Self {
Self {
query: self.query.clone(), // Arc clone is fine
query: self.query.clone(),
async_query: self.async_query.clone(),
_marker: PhantomData,
}
}
}

impl<Q: Send + Sync + 'static> CustomStrategy<Q> {
/// Creates a new `CustomStrategy` with a query generation function.
///
/// The provided closure should contain all necessary configuration for
/// query generation. This design allows for more flexible configuration
/// management compared to struct-level fields.
///
/// # Parameters
/// * `query` - A closure that generates retriever-specific queries
/// Creates a new strategy with a synchronous query generator.
pub fn from_query(
query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
) -> Self {
Self {
query: Some(Arc::new(query)),
async_query: None,
_marker: PhantomData,
}
}

/// Creates a new strategy with an asynchronous query generator.
pub fn from_async_query<F>(
query: impl Fn(&Query<states::Pending>) -> F + Send + Sync + 'static,
) -> Self
where
F: Future<Output = Result<Q>> + Send + 'static,
{
Self {
query: None,
async_query: Some(Arc::new(move |q| Box::pin(query(q)))),
_marker: PhantomData,
}
}

/// Gets the query builder, which can then be used to build the actual query
/// Generates a query using either the sync or async generator.
/// Returns error if no query generator is set.
///
/// # Errors
/// This function will return an error if:
/// - No query function has been set (use `from_query` to set a query function).
/// - The query function fails while processing the provided `query_node`.
pub fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
match &self.query {
Some(query_fn) => Ok(query_fn(query_node)?),
None => Err(anyhow!(
"No query function has been set. Use from_query() to set a query function."
)),
/// Returns an error if:
/// * No query generator has been configured
/// * The configured query generator fails during query generation
pub async fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
match (&self.query, &self.async_query) {
(Some(query_fn), _) => query_fn(query_node),
(_, Some(async_fn)) => async_fn(query_node).await,
_ => Err(anyhow!("No query function has been set.")),
}
}
}
130 changes: 90 additions & 40 deletions swiftide-integrations/src/lancedb/retrieve.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use anyhow::Result;
use arrow::datatypes::SchemaRef;
use arrow_array::StringArray;
use arrow_array::{RecordBatch, StringArray};
use async_trait::async_trait;
use futures_util::TryStreamExt;
use itertools::Itertools;
use lancedb::query::{ExecutableQuery, QueryBase};
use swiftide_core::{
document::Document,
indexing::Metadata,
querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
querying::{
search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
states, Query,
},
Retrieve,
};

Expand Down Expand Up @@ -66,43 +68,7 @@ impl Retrieve<SimilaritySingleEmbedding<String>> for LanceDB {
.try_collect::<Vec<_>>()
.await?;

let mut documents = vec![];

for batch in batches {
let schema: SchemaRef = batch.schema();

for row_idx in 0..batch.num_rows() {
let mut metadata = Metadata::default();
let mut content = String::new();

for (col_idx, field) in schema.fields().iter().enumerate() {
let column = batch.column(col_idx);

if let Some(array) = column.as_any().downcast_ref::<StringArray>() {
if field.name() == "chunk" {
// Extract the "content" field
content = array.value(row_idx).to_string();
} else {
// Assume other fields are part of the metadata
let value = array.value(row_idx).to_string();
metadata.insert(field.name().clone(), value);
}
} else {
// Handle other array types as necessary
// TODO: Can't we just downcast to serde::Value or fail?
}
}

documents.push(Document::new(
content,
if metadata.is_empty() {
None
} else {
Some(metadata)
},
));
}
}
let documents = Self::retrieve_from_record_batches(&batches);

Ok(query.retrieved_documents(documents))
}
Expand All @@ -124,6 +90,90 @@ impl Retrieve<SimilaritySingleEmbedding> for LanceDB {
}
}

#[async_trait]
impl<Q: ExecutableQuery + Send + Sync + 'static> Retrieve<CustomStrategy<Q>> for LanceDB {
/// Implements vector similarity search for LanceDB using a custom query strategy.
///
/// # Type Parameters
/// * `VectorQuery` - LanceDB's query type for vector similarity search
async fn retrieve(
&self,
search_strategy: &CustomStrategy<Q>,
query: Query<states::Pending>,
) -> Result<Query<states::Retrieved>> {
// Build the custom query using both strategy and query state
let query_builder = search_strategy.build_query(&query).await?;

// Execute the query using the builder's built-in methods
let batches = query_builder
.execute()
.await?
.try_collect::<Vec<_>>()
.await?;

let documents = Self::retrieve_from_record_batches(&batches);

Ok(query.retrieved_documents(documents))
}
}

impl LanceDB {
/// Retrieves documents from Arrow `RecordBatches` by processing each row and extracting content
/// and metadata fields.
///
/// The function expects a "chunk" field to contain the main document content, while all other
/// string fields are treated as metadata. Non-string fields are currently skipped
fn retrieve_from_record_batches(batches: &[RecordBatch]) -> Vec<Document> {
let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
let mut documents = Vec::with_capacity(total_rows);

let process_batch = |batch: &RecordBatch, documents: &mut Vec<Document>| {
for row_idx in 0..batch.num_rows() {
let schema = batch.schema();

let (content, metadata): (String, Option<Metadata>) = {
let mut metadata = Metadata::default();
let mut content = String::new();

for (col_idx, field) in schema.as_ref().fields().iter().enumerate() {
if let Some(array) =
batch.column(col_idx).as_any().downcast_ref::<StringArray>()
{
let value = array.value(row_idx).to_string();

if field.name() == "chunk" {
content = value;
} else {
metadata.insert(field.name().to_string(), value);
}
} else {
// Handle other array types as necessary
// TODO: Can't we just downcast to serde::Value or fail?
}
}

(
content,
if metadata.is_empty() {
None
} else {
Some(metadata)
},
)
};

documents.push(Document::new(content, metadata));
}
};

batches
.iter()
.for_each(|batch| process_batch(batch, &mut documents));

documents
}
}

#[cfg(test)]
mod test {
use swiftide_core::{
Expand Down
2 changes: 1 addition & 1 deletion swiftide-integrations/src/pgvector/retrieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for P
let pool = self.get_pool().await?;

// Build the custom query using both strategy and query state
let mut query_builder = search_strategy.build_query(&query)?;
let mut query_builder = search_strategy.build_query(&query).await?;

// Execute the query using the builder's built-in methods
let results = query_builder
Expand Down
1 change: 1 addition & 0 deletions swiftide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ serde_json = { workspace = true }
tokio = { workspace = true }
arrow-array = { workspace = true }
sqlx = { workspace = true }
lancedb = { workspace = true }

[lints]
workspace = true
Expand Down
Loading
Loading