add prompt template and context modifiers to BaseRAGQA (#7699)
Co-authored-by: Szymon Dudycz <[email protected]>
GitOrigin-RevId: 92678ea50a9aab5dfb5b8c6d371305f532289154
2 people authored and Manul from Pathway committed Dec 17, 2024
1 parent a25d36f commit 6ade4ac
Showing 7 changed files with 482 additions and 247 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,10 +6,16 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
## [Unreleased]

### Added
- `pw.xpacks.llm.prompts.RAGPromptTemplate`, a set of prompt utilities for verifying templates and creating UDFs from prompt strings or callables.
- `pw.xpacks.llm.question_answering.BaseContextProcessor`, which streamlines developing and tuning how retrieved context documents are presented to the LLM.
- `pw.io.kafka.read` now supports the `with_metadata` flag, which attaches the metadata of Kafka messages to the table entries (see the sketch after this list).
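
A minimal sketch of the new flag; the broker settings, topic name, and format below are placeholders chosen for illustration, not part of this change:

```python
import pathway as pw

# Placeholder broker settings for illustration only.
rdkafka_settings = {
    "bootstrap.servers": "localhost:9092",
    "group.id": "example-group",
    "auto.offset.reset": "beginning",
}

# With with_metadata=True, each row also carries the metadata of the
# originating Kafka message alongside the payload.
messages = pw.io.kafka.read(
    rdkafka_settings,
    topic="events",
    format="raw",
    with_metadata=True,
)
```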

### Changed
- `pw.io.sharepoint.read` now explicitly terminates with an error if reading a row fails the specified number of times (the default is `8`).
- `pw.xpacks.llm.prompts.prompt_qa` and other prompts now expect 'context' and 'query' fields instead of 'docs'.
- Removed support for `short_prompt_template` and `long_prompt_template` in `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer`. These prompt variants are no longer accepted during construction or in requests.
- `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` now accepts user-created prompts. Templates are verified to include 'context' and 'query' placeholders.
- `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` can take a `BaseContextProcessor` that controls how context documents are presented to the LLM. Defaults to `pw.xpacks.llm.question_answering.SimpleContextProcessor`, which filters out metadata fields and joins the documents with newlines (see the sketch after this list).
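
A minimal sketch of constructing `BaseRAGQuestionAnswerer` with a user-created template, reusing the fake models from the test changes in this commit; the `prompt_template` keyword name is an assumption here, so check the constructor signature for the exact parameter:

```python
import pathway as pw
from pathway.xpacks.llm import llms
from pathway.xpacks.llm.question_answering import BaseRAGQuestionAnswerer
from pathway.xpacks.llm.vector_store import VectorStoreServer


@pw.udf
def fake_embeddings_model(x: str) -> list[float]:
    return [1.0, 1.0, 0.0]


class FakeChatModel(llms.BaseChat):
    async def __wrapped__(self, *args, **kwargs) -> str:
        return "Text"

    def _accepts_call_arg(self, arg_name: str) -> bool:
        return True


docs = pw.debug.table_from_rows(
    schema=pw.schema_from_types(data=bytes, _metadata=dict),
    rows=[("test".encode("utf-8"), {"path": "test_module.py"})],
)
vector_server = VectorStoreServer(docs, embedder=fake_embeddings_model)

# The template must contain the 'context' and 'query' placeholders;
# construction is verified to fail otherwise.
template = "Answer from the context only.\n{context}\n\nQuestion: {query}"

# `prompt_template` is an assumed keyword name for this sketch.
rag_app = BaseRAGQuestionAnswerer(
    llm=FakeChatModel(),
    indexer=vector_server,
    default_llm_name="gpt-4o-mini",
    prompt_template=template,
)
```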

### Fixed
- The input of `pw.io.fs.read` and `pw.io.s3.read` is now correctly persisted when already processed objects are deleted or modified (see the sketch below).
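
A sketch of the affected pattern; `pw.io.fs.read` is used as in the connector docs, while the persistence wiring below follows the general Pathway persistence API and is an assumption to verify against your version:

```python
import pathway as pw

# Read a directory of files; with persistence enabled, deletions and
# modifications of already processed objects are now correctly replayed.
files = pw.io.fs.read("./data", format="binary", with_metadata=True)

# Assumed persistence wiring; verify the exact constructors in the
# persistence API docs for your Pathway version.
backend = pw.persistence.Backend.filesystem("./PersistenceStorage")
pw.run(persistence_config=pw.persistence.Config(backend))
```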
62 changes: 7 additions & 55 deletions integration_tests/webserver/test_llm_xpack.py
@@ -17,8 +17,7 @@
 import pathway as pw
 from pathway.internals.udfs.caches import InMemoryCache
 from pathway.tests.utils import wait_result_with_checker
-from pathway.xpacks.llm import llms
-from pathway.xpacks.llm.question_answering import BaseRAGQuestionAnswerer
+from pathway.xpacks.llm.tests.utils import build_vector_store, create_build_rag_app
 from pathway.xpacks.llm.vector_store import VectorStoreClient, VectorStoreServer

 PATHWAY_HOST = "127.0.0.1"
@@ -258,26 +257,6 @@ def checker() -> bool:
     )


-def build_vector_store(embedder) -> VectorStoreServer:
-    """From a given embedder, with a single doc."""
-    docs = pw.debug.table_from_rows(
-        schema=pw.schema_from_types(data=bytes, _metadata=dict),
-        rows=[
-            (
-                "test".encode("utf-8"),
-                {"path": "test_module.py"},
-            )
-        ],
-    )
-
-    vector_server = VectorStoreServer(
-        docs,
-        embedder=embedder,
-    )
-
-    return vector_server
-
-
 @pytest.mark.parametrize(
     "cache_strategy_cls",
     [
@@ -321,33 +300,6 @@ def checker() -> bool:
     )


-def build_rag_app(port: int) -> BaseRAGQuestionAnswerer:
-    @pw.udf
-    def fake_embeddings_model(x: str) -> list[float]:
-        return [1.0, 1.0, 0.0]
-
-    class FakeChatModel(llms.BaseChat):
-        async def __wrapped__(self, *args, **kwargs) -> str:
-            return "Text"
-
-        def _accepts_call_arg(self, arg_name: str) -> bool:
-            return True
-
-    chat = FakeChatModel()
-
-    vector_server = build_vector_store(fake_embeddings_model)
-
-    rag_app = BaseRAGQuestionAnswerer(
-        llm=chat,
-        indexer=vector_server,
-        default_llm_name="gpt-4o-mini",
-    )
-
-    rag_app.build_server(host=PATHWAY_HOST, port=port)
-
-    return rag_app
-
-
 @pytest.mark.parametrize("input", [1, 2, 3, 99])
 @pytest.mark.parametrize(
     "async_mode",
@@ -357,7 +309,7 @@ def test_serve_callable(port: int, input: int, async_mode: bool):
     TEST_ENDPOINT = "test_add_1"
     expected = input + 1

-    rag_app = build_rag_app(port)
+    rag_app = create_build_rag_app(port)

     if async_mode:

@@ -406,7 +358,7 @@ def test_serve_callable_cache(port: int, input: int, async_mode: bool):
     TEST_ENDPOINT = "test_add_1"
     expected = input + 1

-    rag_app = build_rag_app(port)
+    rag_app = create_build_rag_app(port)
     setattr(rag_app, "num_calls", 0)

     if async_mode:
@@ -478,7 +430,7 @@ def test_serve_callable_symmetric(port: int, input: Any):
     TEST_ENDPOINT = "symmetric"
     expected = input

-    rag_app = build_rag_app(port)
+    rag_app = create_build_rag_app(port)

     UType: TypeAlias = int | dict | str | list | None

@@ -526,7 +478,7 @@ def test_serve_callable_nested_async_typing(
     TEST_ENDPOINT = "nested"
     expected = [{"name": name, "value": dc}]

-    rag_app = build_rag_app(port)
+    rag_app = create_build_rag_app(port)

     if typed:

@@ -569,9 +521,9 @@ def checker() -> bool:

 def test_serve_callable_with_search(port: int):
     TEST_ENDPOINT = "custom_search"
-    expected = "test"  # set in the docs part of `build_rag_app`
+    expected = "test"  # set in the docs part of `build_vector_store`

-    rag_app = build_rag_app(port)
+    rag_app = create_build_rag_app(port)

     @rag_app.serve_callable(route=f"/{TEST_ENDPOINT}")
     async def return_top_doc_text(query):