From e0a9e3f852720955dfd84587c90ef4da04ce058b Mon Sep 17 00:00:00 2001 From: Andrew White Date: Sun, 19 Nov 2023 15:44:24 -0800 Subject: [PATCH] Added merging capability --- paperqa/__init__.py | 12 +++++-- paperqa/types.py | 83 +++++++++++++++++++++++++++++++++++++------ paperqa/utils.py | 18 ++++++++++ paperqa/version.py | 2 +- tests/test_paperqa.py | 28 +++++++++++++-- 5 files changed, 127 insertions(+), 16 deletions(-) diff --git a/paperqa/__init__.py b/paperqa/__init__.py index f2595c09..fa06c892 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -1,4 +1,12 @@ -from .docs import Answer, Docs, PromptCollection, Doc, Text +from .docs import Answer, Docs, PromptCollection, Doc, Text, Context from .version import __version__ -__all__ = ["Docs", "Answer", "PromptCollection", "__version__", "Doc", "Text"] +__all__ = [ + "Docs", + "Answer", + "PromptCollection", + "__version__", + "Doc", + "Text", + "Context", +] diff --git a/paperqa/types.py b/paperqa/types.py index 1f282228..c3a01a42 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.manager import ( @@ -12,7 +12,6 @@ except ImportError: from pydantic import BaseModel, validator - from .prompts import ( citation_prompt, default_system_prompt, @@ -20,7 +19,7 @@ select_paper_prompt, summary_prompt, ) -from .utils import iter_citations +from .utils import extract_doi, iter_citations DocKey = Any CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun] @@ -129,20 +128,84 @@ def __str__(self) -> str: """Return the answer as a string.""" return self.formatted_answer - def markdown(self) -> str: + def get_citation(self, name: str) -> str: + """Return the formatted citation for the given docname.""" + try: + doc = next(filter(lambda x: x.text.name == 
name, self.contexts)).text.doc + except StopIteration: + raise ValueError(f"Could not find docname {name} in contexts") + return f"({doc.citation})" + + def markdown(self) -> Tuple[str, str]: """Return the answer with footnote style citations.""" # example: This is an answer.[^1] # [^1]: This the citation. - index = 1 output = self.answer - ref_list = "## References\n\n" + refs: Dict[str, int] = dict() + index = 1 for citation in iter_citations(self.answer): - refs = [] compound = "" for c in citation.split(","): - refs.append(c.strip("() ")) + c = c.strip("() ") + if c == "Extra background information": + continue + if c in refs: + compound += f"[^{refs[c]}]" + continue + refs[c] = index compound += f"[^{index}]" index += 1 output = output.replace(citation, compound) - ref_list += "\n".join([f"[^{i}]: {r}" for i, r in enumerate(refs, start=1)]) - return output + "\n\n" + ref_list + formatted_refs = "\n".join( + [ + f"[^{i}]: [{self.get_citation(r)}]({extract_doi(self.get_citation(r))})" + for r, i in refs.items() + ] + ) + return output, formatted_refs + + def combine_with(self, other: "Answer") -> "Answer": + """ + Combine this answer object with another, merging their context/answer. 
+ """ + combined = Answer( + question=self.question + " / " + other.question, + answer=self.answer + " " + other.answer, + context=self.context + " " + other.context, + contexts=self.contexts + other.contexts, + references=self.references + " " + other.references, + formatted_answer=self.formatted_answer + " " + other.formatted_answer, + summary_length=self.summary_length, # Assuming the same summary_length for both + answer_length=self.answer_length, # Assuming the same answer_length for both + memory=self.memory if self.memory else other.memory, + cost=self.cost if self.cost else other.cost, + token_counts=self.merge_token_counts(self.token_counts, other.token_counts), + ) + # Handling dockey_filter if present in either of the Answer objects + if self.dockey_filter or other.dockey_filter: + combined.dockey_filter = ( + self.dockey_filter if self.dockey_filter else set() + ) | (other.dockey_filter if other.dockey_filter else set()) + return combined + + @staticmethod + def merge_token_counts( + counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]] + ) -> Optional[Dict[str, List[int]]]: + """ + Merge two dictionaries of token counts. 
+ """ + if counts1 is None and counts2 is None: + return None + if counts1 is None: + return counts2 + if counts2 is None: + return counts1 + merged_counts = counts1.copy() + for key, values in counts2.items(): + if key in merged_counts: + merged_counts[key][0] += values[0] + merged_counts[key][1] += values[1] + else: + merged_counts[key] = values + return merged_counts diff --git a/paperqa/utils.py b/paperqa/utils.py index ec419ea2..6cb0d1a0 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -113,3 +113,21 @@ def iter_citations(text: str) -> List[str]: citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" result = re.findall(citation_regex, text, flags=re.MULTILINE) return result + + +def extract_doi(reference: str) -> str: + """ + Extracts DOI from the reference string using regex. + + :param reference: A string containing the reference. + :return: A string containing the DOI link or an empty string if DOI is not found. + """ + # DOI regex pattern + doi_pattern = r"10.\d{4,9}/[-._;()/:A-Z0-9]+" + doi_match = re.search(doi_pattern, reference, re.IGNORECASE) + + # If DOI is found in the reference, return the DOI link + if doi_match: + return "https://doi.org/" + doi_match.group() + else: + return "" diff --git a/paperqa/version.py b/paperqa/version.py index d1a7f1e0..62ee17b8 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "3.12.0" +__version__ = "3.13.0" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 1ecd80ff..aa3185d1 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -11,10 +11,9 @@ from langchain.llms.fake import FakeListLLM from langchain.prompts import PromptTemplate -from paperqa import Answer, Docs, PromptCollection, Text +from paperqa import Answer, Context, Doc, Docs, PromptCollection, Text from paperqa.chains import get_score from paperqa.readers import read_doc -from paperqa.types import Doc from paperqa.utils import ( iter_citations, 
maybe_is_html, @@ -132,8 +131,31 @@ def test_markdown(): question="What was Fredic's greatest accomplishment?", answer="Frederick Bates's greatest accomplishment was his role in resolving land disputes " "and his service as governor of Missouri (Wiki2023 chunk 1).", + contexts=[ + Context( + context="", + text=Text( + text="Frederick Bates's greatest accomplishment was his role in resolving land disputes " + "and his service as governor of Missouri (Wiki2023 chunk 1).", + name="Wiki2023 chunk 1", + doc=Doc( + name="Wiki2023", + docname="Wiki2023", + citation="WikiMedia Foundation, 2023, Accessed now", + texts=[], + ), + ), + score=5, + ) + ], ) - assert "[^1]" in answer.markdown() + m, r = answer.markdown() + print(r) + assert "[^1]" in m + answer = answer.combine_with(answer) + m2, r2 = answer.markdown() + assert m2.startswith(m) + assert r2 == r def test_ablations():