feat: add rerankers to xpacks
Co-authored-by: Szymon Dudycz <[email protected]>
GitOrigin-RevId: 63e5945ed6309ec64c696f0b18f0db173bf0bb74
2 people authored and Manul from Pathway committed May 10, 2024
1 parent daf2655 commit 633469e
Showing 6 changed files with 400 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- `pathway.stdlib.indexing.vector_document_index`, with a few predefined instances of `pathway.stdlib.indexing.data_index.DataIndex`
- `pathway.stdlib.indexing.bm25`, with implementations of `pathway.stdlib.indexing.data_index.InnerIndex` based on BM25 index provided by Tantivy
- `pathway.stdlib.indexing.full_text_document_index`, with a predefined instance of `pathway.stdlib.indexing.data_index.DataIndex`
- Introduced the `rerankers` module under `xpacks.llm`, which includes a few re-ranking strategies and utility functions for RAG applications.

### Changed
- **BREAKING**: `windowby` generates IDs of produced rows differently than in the previous version.
6 changes: 6 additions & 0 deletions python/pathway/xpacks/llm/__init__.py
@@ -1,11 +1,13 @@
# Copyright © 2024 Pathway
from ._typing import Doc, DocTransformer, DocTransformerCallable # isort: skip

from . import (
    embedders,
    llms,
    parsers,
    prompts,
    question_answering,
    rerankers,
    splitters,
    vector_store,
)
@@ -17,5 +19,9 @@
"prompts",
"question_answering",
"splitters",
"rerankers",
"vector_store",
"Doc",
"DocTransformer",
"DocTransformerCallable",
]
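
Note: this commit registers the new `rerankers` module in the package, but the module's own diff does not load on this page. As a hedged sketch of where a reranking step sits in a Pathway RAG pipeline, the snippet below uses a hand-rolled scoring UDF as a stand-in; `overlap_score`, the schema, and the table contents are illustrative and not part of this commit.

import pathway as pw

# Illustrative stand-in for a reranking strategy; the real strategies live in
# pathway.xpacks.llm.rerankers, whose contents are not visible in this diff.
@pw.udf
def overlap_score(doc_text: str, query: str) -> float:
    # Score a candidate document by simple term overlap with the query.
    doc_terms = set(doc_text.lower().split())
    query_terms = set(query.lower().split())
    return len(doc_terms & query_terms) / max(len(query_terms), 1)


class Candidates(pw.Schema):
    query: str
    doc_text: str


candidates = pw.debug.table_from_rows(
    Candidates,
    [
        ("rerank retrieved documents", "how to rerank retrieved documents"),
        ("rerank retrieved documents", "unrelated release notes"),
    ],
)

# Attach a score column; downstream code can sort or filter on it.
scored = candidates.select(
    pw.this.doc_text,
    score=overlap_score(pw.this.doc_text, pw.this.query),
)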
13 changes: 13 additions & 0 deletions python/pathway/xpacks/llm/_typing.py
@@ -0,0 +1,13 @@
from typing import Callable, Iterable, TypeAlias, Union

import pathway as pw

Doc: TypeAlias = dict[str, str | dict]


DocTransformerCallable: TypeAlias = Union[
    Callable[[Iterable[Doc]], Iterable[Doc]],
    Callable[[Iterable[Doc], float], Iterable[Doc]],
]

DocTransformer: TypeAlias = Union[pw.UDF, DocTransformerCallable]
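
For intuition, a callable satisfying the two-argument variant of `DocTransformerCallable` might look like the sketch below (illustrative, not part of the commit; it reuses the `Doc` alias defined above and assumes documents may carry a numeric "score" entry):

from typing import Iterable

# Matches Callable[[Iterable[Doc], float], Iterable[Doc]]:
# takes documents plus one float knob, returns documents.
def drop_low_scores(docs: Iterable[Doc], threshold: float) -> Iterable[Doc]:
    # Docs without a "score" entry are kept as-is (assumed convention).
    return [doc for doc in docs if float(doc.get("score", threshold)) >= threshold]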
44 changes: 44 additions & 0 deletions python/pathway/xpacks/llm/_utils.py
@@ -4,6 +4,10 @@
import functools
import threading
from collections.abc import Callable
from typing import Any

import pathway as pw
from pathway.xpacks.llm import llms


# https://stackoverflow.com/a/75094151
@@ -41,3 +45,43 @@ def wrapper(*args, **kwargs):
        return wrapper
    else:
        return func


def _check_model_accepts_arg(model_name: str, provider: str, arg: str):
    from litellm import get_supported_openai_params

    response: list[str] = get_supported_openai_params(
        model=model_name, custom_llm_provider=provider
    )

    return arg in response


def _check_llm_accepts_arg(llm: pw.UDF, arg: str) -> bool:
    try:
        model_name = llm.kwargs["model"]  # type: ignore
    except KeyError:
        return False

    if isinstance(llm, llms.OpenAIChat):
        return _check_model_accepts_arg(model_name, "openai", arg)
    elif isinstance(llm, llms.LiteLLMChat):
        provider = model_name.split("/")[0]
        model = "/".join(
            model_name.split("/")[1:]
        )  # handle case: replicate/meta/meta-llama-3-8b
        return _check_model_accepts_arg(model, provider, arg)
    elif isinstance(llm, llms.CohereChat):
        return _check_model_accepts_arg(model_name, "cohere", arg)

    return False


def _check_llm_accepts_logit_bias(llm: pw.UDF) -> bool:
    return _check_llm_accepts_arg(llm, "logit_bias")


def _extract_value(data: Any | pw.Json) -> Any:
    if isinstance(data, pw.Json):
        return data.value
    return data
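
Taken together, these private helpers let the xpack probe a model's parameter support and normalize column values. A sketch of how they compose (illustrative only; calling private helpers directly is for demonstration, `litellm` must be installed, and "gpt-3.5-turbo" is just an example model name):

import pathway as pw
from pathway.xpacks.llm import llms
from pathway.xpacks.llm._utils import _check_llm_accepts_logit_bias, _extract_value

chat = llms.OpenAIChat(model="gpt-3.5-turbo")

# Consults litellm's capability table for the "logit_bias" parameter.
print(_check_llm_accepts_logit_bias(chat))

# _extract_value unwraps pw.Json and passes plain values through unchanged.
assert _extract_value(pw.Json(5)) == 5
assert _extract_value("plain string") == "plain string"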
27 changes: 21 additions & 6 deletions python/pathway/xpacks/llm/question_answering.py
@@ -5,7 +5,7 @@
import pathway as pw
from pathway.internals import ColumnReference, Table
from pathway.stdlib.indexing import DataIndex
-from pathway.xpacks.llm import llms, prompts
+from pathway.xpacks.llm import Doc, llms, prompts
from pathway.xpacks.llm.llms import prompt_chat_single_qa
from pathway.xpacks.llm.prompts import prompt_qa_geometric_rag
from pathway.xpacks.llm.vector_store import VectorStoreServer
@@ -212,11 +212,26 @@ class AIResponseType(Enum):

 @pw.udf
 def _filter_document_metadata(
-    docs: pw.Json, metadata_keys: list[str] = ["path"]
-) -> list[dict]:
-    """Utility function to filter context document metadata to keep the keys in the
-    provided `metadata_keys` list."""
-    doc_ls: list[dict[str, str | dict]] = docs.value  # type: ignore
+    docs: pw.Json | list[pw.Json] | list[Doc], metadata_keys: list[str] = ["path"]
+) -> list[Doc]:
+    """Filter context document metadata to keep the keys in the
+    provided `metadata_keys` list.
+    Works on both ColumnReference and list of pw.Json."""
+    if isinstance(docs, pw.Json):
+        doc_ls: list[Doc] = docs.as_list()
+    elif isinstance(docs, list) and all([isinstance(dc, dict) for dc in docs]):
+        doc_ls = docs  # type: ignore
+    elif all([isinstance(dc, pw.Json) for dc in docs]):
+        doc_ls = [dc.as_dict() for dc in docs]  # type: ignore
+    else:
+        raise ValueError(
+            """`docs` argument is not instance of (pw.Json | list[pw.Json] | list[Doc]).
+            Please check your pipeline. Using `pw.reducers.tuple` may help."""
+        )
+
+    if isinstance(doc_ls[0], list | tuple):  # unpack if needed
+        doc_ls = doc_ls[0]

     filtered_docs = []
     for doc in doc_ls:
[diff truncated: the rest of this file and the sixth changed file (the new rerankers module) did not load]
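
A sketch of how the relaxed `_filter_document_metadata` is applied to a retrieved-context column (illustrative only; it is a private helper, normally invoked inside the question-answering pipeline rather than imported directly, and the document contents below are made up):

import pathway as pw
from pathway.xpacks.llm.question_answering import _filter_document_metadata

class Context(pw.Schema):
    docs: pw.Json

context = pw.debug.table_from_rows(
    Context,
    [
        (
            [
                {"text": "chunk one", "metadata": {"path": "a.md", "owner": "tmp"}},
                {"text": "chunk two", "metadata": {"path": "b.md"}},
            ],
        )
    ],
)

# Keeps only the "path" key of each document's metadata (the default).
filtered = context.select(docs=_filter_document_metadata(pw.this.docs))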
