feat: LanceDB integration #3739

Status: Open. Wants to merge 2 commits into base: main.
2 changes: 1 addition & 1 deletion requirements/ingest/ingest.txt
@@ -1,4 +1,4 @@
-unstructured-ingest[airtable, astradb, azure, azure-cognitive-search, bedrock, biomed, box, chroma, clarifai, confluence, couchbase, databricks-volumes, delta-table, discord, dropbox, elasticsearch, embed-huggingface, embed-octoai, embed-vertexai, embed-voyageai, gcs, github, gitlab, google-drive, hubspot, jira, kafka, kdbai, milvus, mongodb, notion, onedrive, openai, opensearch, outlook, pinecone, postgres, qdrant, reddit, remote, s3, salesforce, sftp, sharepoint, singlestore, slack, vectara, weaviate, wikipedia]
+unstructured-ingest[airtable, astradb, azure, azure-cognitive-search, bedrock, biomed, box, chroma, clarifai, confluence, couchbase, databricks-volumes, delta-table, discord, dropbox, elasticsearch, embed-huggingface, embed-octoai, embed-vertexai, embed-voyageai, gcs, github, gitlab, google-drive, hubspot, jira, kafka, kdbai, milvus, mongodb, notion, onedrive, openai, opensearch, outlook, pinecone, lancedb, postgres, qdrant, reddit, remote, s3, salesforce, sftp, sharepoint, singlestore, slack, vectara, weaviate, wikipedia]
s3fs>=2024.9.0
urllib3>=1.26.20
backoff>=2.2.1
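With the lancedb extra registered here and in setup.py below, the connector dependencies should be installable via pip; a minimal sketch, assuming the package is installed under its usual name (the extra spelling comes from this diff):

    pip install "unstructured[lancedb]"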
40 changes: 40 additions & 0 deletions setup.py
@@ -121,6 +121,46 @@ def load_requirements(file_list: Optional[Union[str, List[str]]] = None) -> List
"rst": rst_reqs,
"tsv": tsv_reqs,
"xlsx": xlsx_reqs,
# Extra requirements for data connectors
"airtable": load_requirements("requirements/ingest/airtable.in"),
"astradb": load_requirements("requirements/ingest/astradb.in"),
"azure": load_requirements("requirements/ingest/azure.in"),
"azure-cognitive-search": load_requirements(
"requirements/ingest/azure-cognitive-search.in",
),
"biomed": load_requirements("requirements/ingest/biomed.in"),
"box": load_requirements("requirements/ingest/box.in"),
"chroma": load_requirements("requirements/ingest/chroma.in"),
"clarifai": load_requirements("requirements/ingest/clarifai.in"),
"confluence": load_requirements("requirements/ingest/confluence.in"),
"delta-table": load_requirements("requirements/ingest/delta-table.in"),
"discord": load_requirements("requirements/ingest/discord.in"),
"dropbox": load_requirements("requirements/ingest/dropbox.in"),
"elasticsearch": load_requirements("requirements/ingest/elasticsearch.in"),
"gcs": load_requirements("requirements/ingest/gcs.in"),
"github": load_requirements("requirements/ingest/github.in"),
"gitlab": load_requirements("requirements/ingest/gitlab.in"),
"google-drive": load_requirements("requirements/ingest/google-drive.in"),
"hubspot": load_requirements("requirements/ingest/hubspot.in"),
"jira": load_requirements("requirements/ingest/jira.in"),
"kafka": load_requirements("requirements/ingest/kafka.in"),
"mongodb": load_requirements("requirements/ingest/mongodb.in"),
"notion": load_requirements("requirements/ingest/notion.in"),
"onedrive": load_requirements("requirements/ingest/onedrive.in"),
"opensearch": load_requirements("requirements/ingest/opensearch.in"),
"outlook": load_requirements("requirements/ingest/outlook.in"),
"pinecone": load_requirements("requirements/ingest/pinecone.in"),
"lancedb": load_requirements("requirements/ingest/lancedb.in"),
"postgres": load_requirements("requirements/ingest/postgres.in"),
"qdrant": load_requirements("requirements/ingest/qdrant.in"),
"reddit": load_requirements("requirements/ingest/reddit.in"),
"s3": load_requirements("requirements/ingest/s3.in"),
"sharepoint": load_requirements("requirements/ingest/sharepoint.in"),
"salesforce": load_requirements("requirements/ingest/salesforce.in"),
"sftp": load_requirements("requirements/ingest/sftp.in"),
"slack": load_requirements("requirements/ingest/slack.in"),
"wikipedia": load_requirements("requirements/ingest/wikipedia.in"),
"weaviate": load_requirements("requirements/ingest/weaviate.in"),
# Legacy extra requirements
"huggingface": load_requirements("requirements/huggingface.in"),
"local-inference": all_doc_reqs,
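The new extra points at requirements/ingest/lancedb.in, which is not shown in this diff view; presumably it pins the LanceDB client, along the lines of this assumed sketch:

    # requirements/ingest/lancedb.in (assumed contents, not part of the visible diff)
    lancedb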
119 changes: 119 additions & 0 deletions unstructured/ingest/connector/lancedb.py
@@ -0,0 +1,119 @@
import copy
import json
import multiprocessing as mp
import typing as t
import uuid
from dataclasses import dataclass

from unstructured.ingest.enhanced_dataclass import enhanced_field
from unstructured.ingest.enhanced_dataclass.core import _asdict
from unstructured.ingest.error import DestinationConnectionError, WriteError
from unstructured.ingest.interfaces import (
AccessConfig,
BaseConnectorConfig,
BaseDestinationConnector,
ConfigSessionHandleMixin,
IngestDocSessionHandleMixin,
WriteConfig,
)
from unstructured.ingest.logger import logger
from unstructured.ingest.utils.data_prep import batch_generator
from unstructured.staging.base import flatten_dict
from unstructured.utils import requires_dependencies

if t.TYPE_CHECKING:
import lancedb

@dataclass
class LanceDBAccessConfig(AccessConfig):
uri: str = enhanced_field(sensitive=True)

@dataclass
class SimpleLanceDBConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
table_name: str
access_config: LanceDBAccessConfig

@dataclass
class LanceDBWriteConfig(WriteConfig):
batch_size: int = 50
num_processes: int = 1

@dataclass
class LanceDBDestinationConnector(IngestDocSessionHandleMixin, BaseDestinationConnector):
write_config: LanceDBWriteConfig
connector_config: SimpleLanceDBConfig
_table: t.Optional["lancedb.Table"] = None

def to_dict(self, **kwargs):
self_cp = copy.copy(self)
if hasattr(self_cp, "_table"):
setattr(self_cp, "_table", None)
return _asdict(self_cp, **kwargs)

@property
def lancedb_table(self):
if self._table is None:
self._table = self.create_table()
return self._table

def initialize(self):
pass

@requires_dependencies(["lancedb"], extras="lancedb")
def create_table(self) -> "lancedb.Table":
import lancedb

db = lancedb.connect(self.connector_config.access_config.uri)
table = db.open_table(self.connector_config.table_name)
logger.debug(f"Connected to table: {table}")
return table

@DestinationConnectionError.wrap
def check_connection(self):
_ = self.lancedb_table

@DestinationConnectionError.wrap
@requires_dependencies(["lancedb"], extras="lancedb")
def add_batch(self, batch):
table = self.lancedb_table
try:
table.add(batch)
except Exception as error:
raise WriteError(f"LanceDB error: {error}") from error
logger.debug(f"Added {len(batch)} records to the table")

def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
logger.info(
f"Adding {len(elements_dict)} elements to destination "
f"table {self.connector_config.table_name}",
)

lancedb_batch_size = self.write_config.batch_size

logger.info(f"using {self.write_config.num_processes} processes to upload")
if self.write_config.num_processes == 1:
for chunk in batch_generator(elements_dict, lancedb_batch_size):
self.add_batch(chunk)

else:
with mp.Pool(
processes=self.write_config.num_processes,
) as pool:
pool.map(
self.add_batch, list(batch_generator(elements_dict, lancedb_batch_size))
)

    def normalize_dict(self, element_dict: dict) -> dict:
        # Flatten nested element fields, promote embeddings/text to the
        # top-level "vector"/"text" columns, and keep the remaining fields
        # both flattened and serialized as a JSON "metadata" string.
        flattened = flatten_dict(
            element_dict,
            separator="_",
            flatten_lists=True,
            remove_none=True,
        )
return {
"id": str(uuid.uuid4()),
"vector": flattened.pop("embeddings", None),
"text": flattened.pop("text", None),
"metadata": json.dumps(flattened),
**flattened,
}
52 changes: 52 additions & 0 deletions unstructured/ingest/v2/examples/examples_lancedb.py
@@ -0,0 +1,52 @@
import os
from pathlib import Path

from unstructured.ingest.v2.interfaces import ProcessorConfig
from unstructured.ingest.v2.logger import logger
from unstructured.ingest.v2.pipeline.pipeline import Pipeline
from unstructured.ingest.v2.processes.chunker import ChunkerConfig
from unstructured.ingest.v2.processes.connectors.local import (
LocalConnectionConfig,
LocalDownloaderConfig,
LocalIndexerConfig,
)
from unstructured.ingest.v2.processes.embedder import EmbedderConfig
from unstructured.ingest.v2.processes.partitioner import PartitionerConfig

# LanceDB-specific configuration classes added by this PR's v2 connector
from unstructured.ingest.v2.processes.connectors.lancedb import (
LanceDBConnectionConfig,
LanceDBUploaderConfig,
LanceDBUploadStagerConfig,
)

base_path = Path(__file__).parent.parent.parent.parent.parent
docs_path = base_path / "example-docs"
work_dir = base_path / "tmp_ingest"
output_path = work_dir / "output"
download_path = work_dir / "download"

if __name__ == "__main__":
logger.info(f"Writing all content in: {work_dir.resolve()}")

Pipeline.from_configs(
context=ProcessorConfig(work_dir=str(work_dir.resolve())),
        indexer_config=LocalIndexerConfig(
            input_path=str((docs_path / "book-war-and-peace-1p.txt").resolve())
        ),
downloader_config=LocalDownloaderConfig(download_dir=download_path),
source_connection_config=LocalConnectionConfig(),
partitioner_config=PartitionerConfig(strategy="fast"),
chunker_config=ChunkerConfig(chunking_strategy="by_title"),
embedder_config=EmbedderConfig(embedding_provider="langchain-huggingface"),
        destination_connection_config=LanceDBConnectionConfig(
            # Set LANCEDB_URI to point at your database; it falls back to the
            # local "data" directory otherwise.
            uri=os.getenv("LANCEDB_URI", "data"),
            # Set LANCEDB_TABLE to the name of an existing table, e.g. "my-table".
            table_name=os.getenv("LANCEDB_TABLE", default="my-table"),
        ),
stager_config=LanceDBUploadStagerConfig(),
uploader_config=LanceDBUploaderConfig(batch_size=10, num_of_processes=2),
).run()
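Because the destination opens an existing table, the table has to be created before running the pipeline. A hedged sketch using the lancedb client: the schema mirrors the id/vector/text/metadata shape produced by the v1 connector's normalize_dict() (the v2 stager is assumed to produce similar columns), and the 384-dimension vector is an assumption matching a common HuggingFace sentence-transformer, not something this PR specifies:

    import lancedb
    import pyarrow as pa

    schema = pa.schema(
        [
            pa.field("id", pa.string()),
            # Vector width must match the embedder's output; 384 is assumed.
            pa.field("vector", pa.list_(pa.float32(), 384)),
            pa.field("text", pa.string()),
            pa.field("metadata", pa.string()),
        ]
    )

    db = lancedb.connect("data")  # same URI the example falls back to
    db.create_table("my-table", schema=schema)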