Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic implementation of VectorSearch Step and KnowledgeBases #1006

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,695 changes: 1,695 additions & 0 deletions pdm.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ text-clustering = [

# minhash
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
lancedb = [
"lancedb>=0.13.0",
]

[project.urls]
Documentation = "https://distilabel.argilla.io/"
Expand All @@ -123,4 +126,4 @@ select = ["E", "W", "F", "I", "C", "B"]
ignore = ["E501", "B905", "B008"]

[tool.pytest.ini_options]
testpaths = ["tests"]
testpaths = ["tests"]
3 changes: 3 additions & 0 deletions src/distilabel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
# Docs page for the custom errors
DISTILABEL_DOCS_URL: Final[str] = "https://distilabel.argilla.io/latest/"

# Argilla related constants: names of the environment variables that hold the
# Argilla API URL and the Argilla API key, respectively.
ARGILLA_API_URL_ENV_VAR_NAME: Final[str] = "ARGILLA_API_URL"
ARGILLA_API_KEY_ENV_VAR_NAME: Final[str] = "ARGILLA_API_KEY"

__all__ = [
"DISTILABEL_METADATA_KEY",
Expand Down
14 changes: 14 additions & 0 deletions src/distilabel/knowledge_bases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

56 changes: 56 additions & 0 deletions src/distilabel/knowledge_bases/argilla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List

import pkg_resources
from argilla import Query, Similar
from pydantic import Field

from distilabel.knowledge_bases.base import KnowledgeBase
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.utils.argilla import ArgillaBase


class ArgillaKnowledgeBase(KnowledgeBase, ArgillaBase):
    """Knowledge base that runs vector similarity queries against an Argilla dataset.

    Attributes:
        vector_field: name of the Argilla vector field that similarity queries are
            run against.
    """

    vector_field: RuntimeParameter[str] = Field(
        None, description="The name of the field containing the vector."
    )

    def load(self) -> None:
        """Run the shared `ArgillaBase` loading logic and fetch the configured dataset.

        Raises:
            ValueError: if the installed `argilla` distribution is older than 2.3.0.
        """
        ArgillaBase.load(self)

        # Compare parsed versions rather than raw strings: lexicographically
        # "2.10.0" < "2.3.0", which would wrongly reject newer releases.
        installed_version = pkg_resources.parse_version(
            pkg_resources.get_distribution("argilla").version
        )
        if installed_version < pkg_resources.parse_version("2.3.0"):
            raise ValueError(
                "Argilla version must be 2.3.0 or higher to use ArgillaKnowledgeBase."
            )

        self._dataset = self._client.datasets(
            name=self.dataset_name, workspace=self.dataset_workspace
        )

    def unload(self) -> None:
        """Drop the references to the Argilla client and dataset."""
        self._client = None
        self._dataset = None

    def vector_search(
        self, query_vector: List[float], n_retrieved_documents: int
    ) -> List[Dict[str, Any]]:
        """Return the records most similar to `query_vector`.

        Args:
            query_vector: vector compared against the `vector_field` of the dataset.
            n_retrieved_documents: maximum number of records to return.

        Returns:
            The matching records as flattened dictionaries.
        """
        similarity_query = Query(
            similar=Similar(name=self.vector_field, value=query_vector)
        )
        return self._dataset.records(
            query=similarity_query,
            limit=n_retrieved_documents,
        ).to_list(flatten=True)

    @property
    def columns(self) -> List[str]:
        """Keys of the first flattened record, or an empty list if there are no records."""
        sample = self._dataset.records(limit=1).to_list(flatten=True)
        # An empty dataset yields no record to inspect; indexing `[0]` would raise
        # `IndexError`, so report "no columns" instead.
        if not sample:
            return []
        return list(sample[0].keys())
67 changes: 67 additions & 0 deletions src/distilabel/knowledge_bases/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from distilabel.mixins.runtime_parameters import RuntimeParametersMixin
from distilabel.utils.serialization import _Serializable


class KnowledgeBase(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
    """Abstract interface of a searchable document store.

    Subclasses must implement `load`, `unload` and `columns`, and override whichever
    of `keyword_search`, `vector_search` and `hybrid_search` they support. The search
    methods that are not overridden raise `NotImplementedError`.
    """

    @abstractmethod
    def load(self) -> None:
        """Acquire whatever resources the knowledge base needs before searching."""
        pass

    @abstractmethod
    def unload(self) -> None:
        """Release the resources acquired in `load`."""
        pass

    def keyword_search(
        self, keywords: List[str], n_retrieved_documents: int
    ) -> List[Dict[str, Any]]:
        """Search by keywords only.

        Args:
            keywords: the keywords to search for.
            n_retrieved_documents: maximum number of documents to return.

        Raises:
            NotImplementedError: always, unless overridden by the subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `keyword_search`."
        )

    def vector_search(
        self, vector: List[float], n_retrieved_documents: int
    ) -> List[Dict[str, Any]]:
        """Search by vector similarity only.

        Args:
            vector: the query vector.
            n_retrieved_documents: maximum number of documents to return.

        Raises:
            NotImplementedError: always, unless overridden by the subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `vector_search`."
        )

    def hybrid_search(
        self, vector: List[float], keywords: List[str], n_retrieved_documents: int
    ) -> List[Dict[str, Any]]:
        """Search combining a query vector and keywords.

        Args:
            vector: the query vector.
            keywords: the keywords to search for.
            n_retrieved_documents: maximum number of documents to return.

        Raises:
            NotImplementedError: always, unless overridden by the subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `hybrid_search`."
        )

    def search(
        self,
        vector: Optional[List[float]] = None,
        keywords: Optional[List[str]] = None,
        n_retrieved_documents: int = 5,
    ) -> List[Dict[str, Any]]:
        """Dispatch to the search method matching the arguments that were provided.

        Args:
            vector: optional query vector.
            keywords: optional keywords.
            n_retrieved_documents: maximum number of documents to return.

        Returns:
            The result of `hybrid_search` when both `vector` and `keywords` are
            given, of `vector_search` when only `vector` is given, or of
            `keyword_search` when only `keywords` is given.

        Raises:
            ValueError: if neither `vector` nor `keywords` is provided.
        """
        has_vector = vector is not None
        has_keywords = keywords is not None

        # Arguments are passed positionally so that overrides remain free to name
        # their parameters differently.
        if has_vector and has_keywords:
            return self.hybrid_search(vector, keywords, n_retrieved_documents)
        if has_vector:
            return self.vector_search(vector, n_retrieved_documents)
        if has_keywords:
            return self.keyword_search(keywords, n_retrieved_documents)
        raise ValueError("Either vector or keywords must be provided.")

    @property
    @abstractmethod
    def columns(self) -> List[str]:
        """Names of the columns available in the retrieved documents."""
        pass
63 changes: 63 additions & 0 deletions src/distilabel/knowledge_bases/lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from datetime import timedelta
from typing import Any, Dict, List, Optional

import lancedb
from lancedb import DBConnection
from lancedb.table import Table
from pydantic import Field, PrivateAttr

from distilabel.knowledge_bases.base import KnowledgeBase


class LanceDBKnowledgeBase(KnowledgeBase):
    """Knowledge base that runs vector searches against a table of a LanceDB database.

    The optional connection settings are forwarded to LanceDB only when they are
    set, so LanceDB's own defaults apply otherwise.
    """

    uri: str = Field(..., description="The URI of the LanceDB database.")
    table_name: str = Field(..., description="The name of the table to use.")
    api_key: Optional[str] = Field(
        None, description="The API key to use to connect to the LanceDB database."
    )
    region: Optional[str] = Field(
        None, description="The region of the LanceDB database."
    )
    read_consistency_interval: Optional[timedelta] = Field(
        None, description="The read consistency interval of the LanceDB database."
    )
    request_thread_pool_size: Optional[int] = Field(
        None, description="The request thread pool size of the LanceDB database."
    )
    index_cache_size: Optional[int] = Field(
        None, description="The index cache size of the LanceDB database."
    )

    # Both handles are `None` until `load` is called and again after `unload`.
    _db: Optional[DBConnection] = PrivateAttr(None)
    _tbl: Optional[Table] = PrivateAttr(None)

    def load(self) -> None:
        """Connect to the database at `uri` and open the table `table_name`."""
        # Previously these fields were declared but never passed to LanceDB. Only
        # forward the ones that were explicitly set, so unset options keep the
        # library defaults instead of being overridden with `None`.
        connect_options: Dict[str, Any] = {
            "api_key": self.api_key,
            "region": self.region,
            "read_consistency_interval": self.read_consistency_interval,
            "request_thread_pool_size": self.request_thread_pool_size,
        }
        connect_kwargs = {
            name: value for name, value in connect_options.items() if value is not None
        }
        self._db = lancedb.connect(self.uri, **connect_kwargs)

        open_kwargs: Dict[str, Any] = {}
        if self.index_cache_size is not None:
            open_kwargs["index_cache_size"] = self.index_cache_size
        self._tbl = self._db.open_table(name=self.table_name, **open_kwargs)

    def unload(self) -> None:
        """Close the connection if one is open and drop the database/table handles."""
        if self._db is not None:
            # NOTE(review): `close()` is invoked only when the connection object
            # exposes it — confirm which LanceDB connection types provide it.
            close = getattr(self._db, "close", None)
            if callable(close):
                close()
        self._db = None
        self._tbl = None

    def vector_search(
        self, query_vector: List[float], n_retrieved_documents: int
    ) -> List[Dict[str, Any]]:
        """Return up to `n_retrieved_documents` rows matching `query_vector`.

        Args:
            query_vector: the vector to search the table with.
            n_retrieved_documents: maximum number of rows to return.

        Returns:
            The search results as a list of dictionaries.
        """
        search_query = self._tbl.search(query_vector)
        return search_query.limit(n_retrieved_documents).to_list()

    @property
    def columns(self) -> List[str]:
        """Names of the fields in the schema of the opened table."""
        return self._tbl.schema.names
102 changes: 5 additions & 97 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING

from pydantic import Field, PrivateAttr, SecretStr

try:
import argilla as rg
except ImportError:
pass

from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import Step, StepInput
from distilabel.utils.argilla import ArgillaBase

if TYPE_CHECKING:
from argilla import Argilla, Dataset

from distilabel.steps.typing import StepColumns, StepOutput


_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
_ARGILLA_API_KEY_ENV_VAR_NAME = "ARGILLA_API_KEY"


class ArgillaBase(Step, ABC):
class ArgillaStepBase(Step, ArgillaBase, ABC):
"""Abstract step that provides a class to subclass from, that contains the boilerplate code
required to interact with Argilla, as well as some extra validations on top of it. It also defines
the abstract methods that need to be implemented in order to add a new dataset type as a step.
Expand Down Expand Up @@ -67,69 +51,9 @@ class ArgillaBase(Step, ABC):
- dynamic, based on the `inputs` value provided
"""

dataset_name: RuntimeParameter[str] = Field(
default=None, description="The name of the dataset in Argilla."
)
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The workspace where the dataset will be created in Argilla. Defaults "
"to `None` which means it will be created in the default workspace.",
)

api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Argilla API.",
)

_client: Optional["Argilla"] = PrivateAttr(...)
_dataset: Optional["Dataset"] = PrivateAttr(...)

def model_post_init(self, __context: Any) -> None:
"""Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
super().model_post_init(__context)

if importlib.util.find_spec("argilla") is None:
raise ImportError(
"Argilla is not installed. Please install it using `pip install argilla"
" --upgrade`."
)

def _client_init(self) -> None:
"""Initializes the Argilla API client with the provided `api_url` and `api_key`."""
try:
self._client = rg.Argilla( # type: ignore
api_url=self.api_url,
api_key=self.api_key.get_secret_value(), # type: ignore
headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
if isinstance(self.api_url, str)
and "hf.space" in self.api_url
and "HF_TOKEN" in os.environ
else {},
)
except Exception as e:
raise DistilabelUserError(
f"Failed to initialize the Argilla API: {e}",
page="sections/how_to_guides/advanced/argilla/",
) from e

@property
def _dataset_exists_in_workspace(self) -> bool:
"""Checks if the dataset already exists in Argilla in the provided workspace if any.

Returns:
`True` if the dataset exists, `False` otherwise.
"""
return (
self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace,
)
is not None
)
def load(self) -> None:
super().load()

@property
def outputs(self) -> "StepColumns":
Expand All @@ -138,22 +62,6 @@ def outputs(self) -> "StepColumns":
"""
return []

def load(self) -> None:
"""Method to perform any initialization logic before the `process` method is
called. For example, to load an LLM, establish a connection to a database, etc.
"""
super().load()

if self.api_url is None or self.api_key is None:
raise DistilabelUserError(
"`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
" provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
" and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`.",
page="sections/how_to_guides/advanced/argilla/",
)

self._client_init()

@property
@abstractmethod
def inputs(self) -> "StepColumns": ...
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/argilla/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
pass

from distilabel.errors import DistilabelUserError
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.argilla.base import ArgillaStepBase
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
Expand All @@ -33,7 +33,7 @@
from distilabel.steps.typing import StepOutput


class PreferenceToArgilla(ArgillaBase):
class PreferenceToArgilla(ArgillaStepBase):
"""Creates a preference dataset in Argilla.

Step that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/argilla/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
pass

from distilabel.errors import DistilabelUserError
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.argilla.base import ArgillaStepBase
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class TextGenerationToArgilla(ArgillaBase):
class TextGenerationToArgilla(ArgillaStepBase):
"""Creates a text generation dataset in Argilla.

`Step` that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down
Loading
Loading