Fixed text embedding errors
whitead committed Jan 22, 2024
1 parent 59ed8d3 commit 8fb8bdd
Showing 5 changed files with 182 additions and 65 deletions.
62 changes: 48 additions & 14 deletions paperqa/docs.py
@@ -23,7 +23,16 @@
)
from .paths import PAPERQA_DIR
from .readers import read_doc
from .types import Answer, CallbackFactory, Context, Doc, DocKey, PromptCollection, Text
from .types import (
Answer,
CallbackFactory,
Context,
Doc,
DocKey,
LLMResult,
PromptCollection,
Text,
)
from .utils import (
gather_with_concurrency,
guess_is_4xx,
@@ -103,7 +112,7 @@ def __init__(self, **data):
super().__init__(**data)
self._client = client
self._embedding_client = embedding_client
# run this here (instead of automateically) so it has access to privates
# run this here (instead of automatically) so it has access to privates
# If I ever figure out a better way of validating privates
# I can move this back to the decorator
Docs.make_llm_names_consistent(self)
@@ -171,11 +180,15 @@ def make_llm_names_consistent(cls, data: Any) -> Any:
data.llm_model.infer_llm_type(data._client)
data.llm = data.llm_model.name
if data.summary_llm_model is not None:
if (
data.summary_llm is None
and data.summary_llm_model is data.llm_model
):
data.summary_llm = data.llm
if data.summary_llm == "langchain":
# from langchain models - kind of hacky
data.summary_llm_model.infer_llm_type(data._client)
data.summary_llm = data.summary_llm_model.name

return data

def clear_docs(self):
@@ -188,6 +201,9 @@ def __getstate__(self):
# to be overriding the behavior on setstate/getstate anyway.
# The reason is that the other serialization methods from Pydantic -
# model_dump - will not drop private attributes.
# So - this getstate/setstate removes private attributes for pickling
# and Pydantic will handle removing private attributes for other
# serialization methods (like model_dump)
state = super().__getstate__()
# remove client from private attributes
del state["__pydantic_private__"]["_client"]
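The added comment spells out the intent of the custom getstate/setstate pair; in effect (a hedged illustration, with the pickle round-trip and variable names being mine, not the commit's):

import pickle

# private attributes such as _client are stripped from the pickled state,
# so the serialized Docs carries no live API clients
blob = pickle.dumps(docs)      # `docs` is an illustrative Docs instance
restored = pickle.loads(blob)  # comes back without the removed clients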
@@ -301,9 +317,10 @@ def add(
texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100)
if len(texts) == 0:
raise ValueError(f"Could not read document {path}. Is it empty?")
citation = asyncio.run(
chain_result = asyncio.run(
cite_chain(dict(text=texts[0].text), None),
)
citation = chain_result.text
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"

@@ -405,6 +422,7 @@ async def adoc_match(
k: int = 25,
rerank: bool | None = None,
get_callbacks: CallbackFactory = lambda x: None,
answer: Answer | None = None, # used for tracking tokens
) -> set[DocKey]:
"""Return a list of dockeys that match the query."""
matches, _ = await self.docs_index.max_marginal_relevance_search(
@@ -440,7 +458,9 @@
dict(question=query, papers="\n".join(papers)),
get_callbacks("filter"),
)
return set([d.dockey for d in matched_docs if d.docname in result])
if answer:
answer.add_tokens(result)
return set([d.dockey for d in matched_docs if d.docname in str(result)])
except AttributeError:
pass
return set([d.dockey for d in matched_docs])
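A minimal sketch of how the new answer argument is meant to be used by callers, mirroring the call added in aquery further down; the surrounding setup here is illustrative rather than taken from this commit:

# hypothetical caller (inside an async function): pass the Answer so the
# filter LLM's token usage is recorded via answer.add_tokens(result)
answer = Answer(question="What is the melting point of gallium?")
keys = await docs.adoc_match(
    answer.question,
    answer=answer,
)
if len(keys) > 0:
    answer.dockey_filter = keys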
@@ -528,6 +548,8 @@ async def aget_evidence(
async def process(match):
callbacks = get_callbacks("evidence:" + match.name)
citation = match.doc.citation
# empty result
llm_result = LLMResult(model="", date="")
if detailed_citations:
citation = match.name + ": " + citation

@@ -547,7 +569,7 @@ async def process(match):
# my best idea is see if there is a 4XX
# http code in the exception
try:
context = await summary_chain(
llm_result = await summary_chain(
dict(
question=answer.question,
# Add name so chunk is stated
@@ -557,15 +579,16 @@ async def process(match):
),
callbacks,
)
context = llm_result.text
except Exception as e:
if guess_is_4xx(str(e)):
return None
return None, llm_result
raise e
if (
"not applicable" in context.lower()
or "not relevant" in context.lower()
):
return None
return None, llm_result
if self.strip_citations:
# remove citations that collide with our grounded citations (for the answer LLM)
context = strip_citations(context)
Expand All @@ -580,13 +603,16 @@ async def process(match):
),
score=score,
)
return c
return c, llm_result

results = await gather_with_concurrency(
self.max_concurrent, [process(m) for m in matches]
)
# update token counts
[answer.add_tokens(r[1]) for r in results]

# filter out failures
contexts = [c for c in results if c is not None]
contexts = [c for c, r in results if c is not None]

answer.contexts = sorted(
contexts + answer.contexts, key=lambda x: x.score, reverse=True
@@ -646,7 +672,9 @@ async def aquery(
# comparable - one is chunks and one is docs
if key_filter or (key_filter is None and len(self.docs) > k):
keys = await self.adoc_match(
answer.question, get_callbacks=get_callbacks
answer.question,
get_callbacks=get_callbacks,
answer=answer,
)
if len(keys) > 0:
answer.dockey_filter = keys
@@ -663,7 +691,10 @@ async def aquery(
system_prompt=self.prompts.system,
)
pre = await chain(dict(question=answer.question), get_callbacks("pre"))
answer.context = answer.context + "\n\nExtra background information:" + pre
answer.add_tokens(pre)
answer.context = (
answer.context + "\n\nExtra background information:" + str(pre)
)
bib = dict()
if len(answer.context) < 10: # and not self.memory:
answer_text = (
@@ -675,14 +706,16 @@ async def aquery(
prompt=self.prompts.qa,
system_prompt=self.prompts.system,
)
answer_text = await qa_chain(
answer_result = await qa_chain(
dict(
context=answer.context,
answer_length=answer.answer_length,
question=answer.question,
),
get_callbacks("answer"),
)
answer_text = answer_result.text
answer.add_tokens(answer_result)
# it still happens
if "(Example2012Example pages 3-4)" in answer_text:
answer_text = answer_text.replace("(Example2012Example pages 3-4)", "")
@@ -709,7 +742,8 @@ async def aquery(
system_prompt=self.prompts.system,
)
post = await chain(answer.model_dump(), get_callbacks("post"))
answer.answer = post
answer.answer = post.text
answer.add_tokens(post)
answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n"
if len(bib) > 0:
answer.formatted_answer += f"\nReferences\n\n{bib_str}\n"
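Taken together, the docs.py changes follow one pattern: every chain call now returns an LLMResult whose .text is used where a bare string used to be, and each result is passed to answer.add_tokens so per-query token usage is tracked. Answer.add_tokens itself is not shown in this commit; a plausible sketch, assuming it simply keeps running totals of the prompt_count and completion_count fields added in llms.py below (the token_counts attribute name is an assumption):

# hypothetical sketch of Answer.add_tokens, not part of this diff
from pydantic import BaseModel, Field

class Answer(BaseModel):
    # ... existing fields (question, context, contexts, ...) omitted
    token_counts: dict[str, list[int]] = Field(default_factory=dict)

    def add_tokens(self, result: "LLMResult") -> None:
        # accumulate prompt/completion token counts per model name
        if result.model not in self.token_counts:
            self.token_counts[result.model] = [
                result.prompt_count,
                result.completion_count,
            ]
        else:
            self.token_counts[result.model][0] += result.prompt_count
            self.token_counts[result.model][1] += result.completion_count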
94 changes: 71 additions & 23 deletions paperqa/llms.py
@@ -1,3 +1,5 @@
import asyncio
import datetime
import re
from abc import ABC, abstractmethod
from inspect import signature
@@ -18,7 +20,7 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator

from .prompts import default_system_prompt
from .types import Embeddable
from .types import Doc, Embeddable, LLMResult, Text
from .utils import batch_iter, flatten, gather_with_concurrency


@@ -59,13 +61,14 @@ def process_llm_config(llm_config: dict) -> dict:
result = {k: v for k, v in llm_config.items() if k != "model_type"}
if "max_tokens" not in result or result["max_tokens"] == -1:
model = llm_config["model"]
# now we guess!
# now we guess - we could use tiktoken to count,
# but don't have the initiative right now
if model.startswith("gpt-4") or (
model.startswith("gpt-3.5") and "1106" in model
):
result["max_tokens"] = 4096
result["max_tokens"] = 3000
else:
result["max_tokens"] = 2048 # ?
result["max_tokens"] = 1500
return result
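The new defaults are a more conservative guess when max_tokens is unset. A quick illustration (the input configs are made up for this example):

# illustrative calls; "model_type" is stripped and max_tokens is guessed
process_llm_config({"model": "gpt-4-1106-preview", "model_type": "chat"})
# -> {"model": "gpt-4-1106-preview", "max_tokens": 3000}
process_llm_config({"model": "gpt-3.5-turbo", "model_type": "chat"})
# -> {"model": "gpt-3.5-turbo", "max_tokens": 1500}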


@@ -124,13 +127,18 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
def infer_llm_type(self, client: Any) -> str:
return "completion"

def count_tokens(self, text: str) -> int:
return len(text) // 4 # gross approximation
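The base implementation deliberately avoids a tokenizer dependency; four characters per token is a common rough estimate for English text, e.g.:

# rough heuristic used by count_tokens: ~4 characters per token
text = "hello world"   # 11 characters
len(text) // 4          # -> 2 (approximate token count)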

def make_chain(
self,
client: Any,
prompt: str,
skip_system: bool = False,
system_prompt: str = default_system_prompt,
) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]:
) -> Callable[
[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, LLMResult]
]:
"""Create a function to execute a batch of prompts
This replaces the previous use of langchain for combining prompts and LLMs.
@@ -143,7 +151,7 @@ def make_chain(
Returns:
A function to execute a prompt. Its signature is:
execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> str
execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> LLMResult
where data is a dict with keys for the input variables that will be formatted into prompt
and callbacks is a list of functions to call with each chunk of the completion.
"""
Expand All @@ -160,21 +168,39 @@ def make_chain(

async def execute(
data: dict, callbacks: list[Callable[[str], None]] | None = None
) -> str:
) -> LLMResult:
start_clock = asyncio.get_running_loop().time()
result = LLMResult(
model=self.name,
date=datetime.datetime.now().isoformat(),
)
messages = chat_prompt[:-1] + [
dict(role="user", content=chat_prompt[-1]["content"].format(**data))
]
result.prompt_count = sum(
[self.count_tokens(m["content"]) for m in messages]
) + sum([self.count_tokens(m["role"]) for m in messages])

if callbacks is None:
output = await self.achat(client, messages)
else:
completion = self.achat_iter(client, messages) # type: ignore
result = []
text_result = []
async for chunk in completion: # type: ignore
if chunk:
result.append(chunk)
if result.seconds_to_first_token == 0:
result.seconds_to_first_token = (
asyncio.get_running_loop().time() - start_clock
)
text_result.append(chunk)
[f(chunk) for f in callbacks]
output = "".join(result)
return output
output = "".join(text_result)
result.completion_count = self.count_tokens(output)
result.text = output
result.seconds_to_last_token = (
asyncio.get_running_loop().time() - start_clock
)
return result

return execute
elif self.llm_type == "completion":
@@ -185,23 +211,38 @@ async def execute(

async def execute(
data: dict, callbacks: list[Callable[[str], None]] | None = None
) -> str:
) -> LLMResult:
start_clock = asyncio.get_running_loop().time()
result = LLMResult(
model=self.name,
date=datetime.datetime.now().isoformat(),
)
formatted_prompt = completion_prompt.format(**data)
result.prompt_count = self.count_tokens(formatted_prompt)

if callbacks is None:
output = await self.acomplete(
client, completion_prompt.format(**data)
)
output = await self.acomplete(client, formatted_prompt)
else:
completion = self.acomplete_iter( # type: ignore
client,
completion_prompt.format(**data),
formatted_prompt,
)
result = []
text_result = []
async for chunk in completion: # type: ignore
if chunk:
result.append(chunk)
if result.seconds_to_first_token == 0:
result.seconds_to_first_token = (
asyncio.get_running_loop().time() - start_clock
)
text_result.append(chunk)
[f(chunk) for f in callbacks]
output = "".join(result)
return output
output = "".join(text_result)
result.completion_count = self.count_tokens(output)
result.text = output
result.seconds_to_last_token = (
asyncio.get_running_loop().time() - start_clock
)
return result

return execute
raise ValueError(f"Unknown llm_type: {self.llm_type}")
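Per the updated make_chain docstring, the execute coroutine now resolves to an LLMResult rather than a plain string, carrying the completion text plus approximate token counts and timing. A hedged usage sketch (prompt, client, and llm are illustrative; llm stands for any configured LLMModel subclass):

# hypothetical usage of the new chain API (inside an async function)
chain = llm.make_chain(client, prompt="Summarize: {text}")
result = await chain({"text": "Gallium melts just above room temperature."})
result.text                     # the completion itself
result.prompt_count             # approximate prompt tokens (via count_tokens)
result.completion_count         # approximate completion tokens
result.seconds_to_first_token   # only populated when streaming via callbacks
result.seconds_to_last_token    # wall-clock time for the whole call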
@@ -553,9 +594,16 @@ def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None:
if self._store_builder is None:
raise ValueError("You must set store_builder before adding texts")
self.class_type = type(texts[0])
vec_store_text_and_embeddings = list(
map(lambda x: (x.text, x.embedding), texts)
)
if self.class_type == Text:
vec_store_text_and_embeddings = list(
map(lambda x: (x.text, x.embedding), cast(list[Text], texts))
)
elif self.class_type == Doc:
vec_store_text_and_embeddings = list(
map(lambda x: (x.citation, x.embedding), cast(list[Doc], texts))
)
else:
raise ValueError("Only embeddings of type Text or Doc are supported")
if self._store is None:
self._store = self._store_builder( # type: ignore
vec_store_text_and_embeddings,
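This last hunk appears to be the fix named in the commit title: the vector-store adapter previously embedded x.text for every item, which cannot work for Doc objects (used by the docs index), whose embedded string is the citation. A hedged illustration of what each branch now hands to the store builder (field values are made up):

# illustrative inputs to add_texts_and_embeddings
doc = Doc(docname="Smith2020", citation="Smith et al., 2020.", dockey="abc123",
          embedding=[0.3, 0.4])
text = Text(text="Gallium melts at 29.76 C ...", name="Smith2020 chunk 1",
            doc=doc, embedding=[0.1, 0.2])
# Text chunks are stored as (chunk text, embedding);
# Doc entries are stored as (citation, embedding)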
