From def3640e5392f81d3001f3f6798671e1e07c03a5 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Mon, 20 Nov 2023 08:29:06 -0800 Subject: [PATCH] Added markdown output with citations (#209) * Added markdown output with citations * Added merging capability * Fixed extra parantheses --- paperqa/__init__.py | 12 +++++- paperqa/contrib/zotero.py | 3 +- paperqa/types.py | 88 +++++++++++++++++++++++++++++++++++++-- paperqa/utils.py | 30 ++++++++++++- paperqa/version.py | 2 +- tests/test_paperqa.py | 67 +++++++++++++++++++++++++++-- 6 files changed, 188 insertions(+), 14 deletions(-) diff --git a/paperqa/__init__.py b/paperqa/__init__.py index f2595c099..fa06c8925 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -1,4 +1,12 @@ -from .docs import Answer, Docs, PromptCollection, Doc, Text +from .docs import Answer, Docs, PromptCollection, Doc, Text, Context from .version import __version__ -__all__ = ["Docs", "Answer", "PromptCollection", "__version__", "Doc", "Text"] +__all__ = [ + "Docs", + "Answer", + "PromptCollection", + "__version__", + "Doc", + "Text", + "Context", +] diff --git a/paperqa/contrib/zotero.py b/paperqa/contrib/zotero.py index b0d6240cb..a390cd1c3 100644 --- a/paperqa/contrib/zotero.py +++ b/paperqa/contrib/zotero.py @@ -14,8 +14,7 @@ except ImportError: raise ImportError("Please install pyzotero: `pip install pyzotero`") from ..paths import PAPERQA_DIR -from ..types import StrPath -from ..utils import count_pdf_pages +from ..utils import StrPath, count_pdf_pages class ZoteroPaper(BaseModel): diff --git a/paperqa/types.py b/paperqa/types.py index aba3dcf9e..4f17ff88f 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,5 +1,4 @@ -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.manager import ( @@ -13,7 +12,6 @@ except ImportError: from pydantic import BaseModel, validator - from .prompts import ( citation_prompt, default_system_prompt, @@ -21,8 +19,8 @@ select_paper_prompt, summary_prompt, ) +from .utils import extract_doi, iter_citations -StrPath = Union[str, Path] DocKey = Any CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun] CallbackFactory = Callable[[str], Union[None, List[BaseCallbackHandler]]] @@ -129,3 +127,85 @@ class Answer(BaseModel): def __str__(self) -> str: """Return the answer as a string.""" return self.formatted_answer + + def get_citation(self, name: str) -> str: + """Return the formatted citation for the gien docname.""" + try: + doc = next(filter(lambda x: x.text.name == name, self.contexts)).text.doc + except StopIteration: + raise ValueError(f"Could not find docname {name} in contexts") + return doc.citation + + def markdown(self) -> Tuple[str, str]: + """Return the answer with footnote style citations.""" + # example: This is an answer.[^1] + # [^1]: This the citation. + output = self.answer + refs: Dict[str, int] = dict() + index = 1 + for citation in iter_citations(self.answer): + compound = "" + for c in citation.split(","): + c = c.strip("() ") + if c == "Extra background information": + continue + if c in refs: + compound += f"[^{refs[c]}]" + continue + refs[c] = index + compound += f"[^{index}]" + index += 1 + output = output.replace(citation, compound) + formatted_refs = "\n".join( + [ + f"[^{i}]: [{self.get_citation(r)}]({extract_doi(self.get_citation(r))})" + for r, i in refs.items() + ] + ) + return output, formatted_refs + + def combine_with(self, other: "Answer") -> "Answer": + """ + Combine this answer object with another, merging their context/answer. + """ + combined = Answer( + question=self.question + " / " + other.question, + answer=self.answer + " " + other.answer, + context=self.context + " " + other.context, + contexts=self.contexts + other.contexts, + references=self.references + " " + other.references, + formatted_answer=self.formatted_answer + " " + other.formatted_answer, + summary_length=self.summary_length, # Assuming the same summary_length for both + answer_length=self.answer_length, # Assuming the same answer_length for both + memory=self.memory if self.memory else other.memory, + cost=self.cost if self.cost else other.cost, + token_counts=self.merge_token_counts(self.token_counts, other.token_counts), + ) + # Handling dockey_filter if present in either of the Answer objects + if self.dockey_filter or other.dockey_filter: + combined.dockey_filter = ( + self.dockey_filter if self.dockey_filter else set() + ) | (other.dockey_filter if other.dockey_filter else set()) + return combined + + @staticmethod + def merge_token_counts( + counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]] + ) -> Optional[Dict[str, List[int]]]: + """ + Merge two dictionaries of token counts. + """ + if counts1 is None and counts2 is None: + return None + if counts1 is None: + return counts2 + if counts2 is None: + return counts1 + merged_counts = counts1.copy() + for key, values in counts2.items(): + if key in merged_counts: + merged_counts[key][0] += values[0] + merged_counts[key][1] += values[1] + else: + merged_counts[key] = values + return merged_counts diff --git a/paperqa/utils.py b/paperqa/utils.py index dd28e01b1..6cb0d1a0e 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -2,12 +2,13 @@ import math import re import string -from typing import BinaryIO, List +from pathlib import Path +from typing import BinaryIO, List, Union import pypdf from langchain.base_language import BaseLanguageModel -from .types import StrPath +StrPath = Union[str, Path] def name_in_text(name: str, text: str) -> bool: @@ -105,3 +106,28 @@ def strip_citations(text: str) -> str: # Remove the citations from the text text = re.sub(citation_regex, "", text, flags=re.MULTILINE) return text + + +def iter_citations(text: str) -> List[str]: + # Combined regex for identifying citations (see unit tests for examples) + citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" + result = re.findall(citation_regex, text, flags=re.MULTILINE) + return result + + +def extract_doi(reference: str) -> str: + """ + Extracts DOI from the reference string using regex. + + :param reference: A string containing the reference. + :return: A string containing the DOI link or a message if DOI is not found. + """ + # DOI regex pattern + doi_pattern = r"10.\d{4,9}/[-._;()/:A-Z0-9]+" + doi_match = re.search(doi_pattern, reference, re.IGNORECASE) + + # If DOI is found in the reference, return the DOI link + if doi_match: + return "https://doi.org/" + doi_match.group() + else: + return "" diff --git a/paperqa/version.py b/paperqa/version.py index d1a7f1e0d..62ee17b83 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "3.12.0" +__version__ = "3.13.0" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index ae7f4e073..aa3185d1b 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -11,11 +11,11 @@ from langchain.llms.fake import FakeListLLM from langchain.prompts import PromptTemplate -from paperqa import Answer, Docs, PromptCollection, Text +from paperqa import Answer, Context, Doc, Docs, PromptCollection, Text from paperqa.chains import get_score from paperqa.readers import read_doc -from paperqa.types import Doc from paperqa.utils import ( + iter_citations, maybe_is_html, maybe_is_text, name_in_text, @@ -29,7 +29,36 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: print(token) -# Assume strip_citations is imported or defined in this file. +def test_iter_citations(): + text = ( + "Yes, COVID-19 vaccines are effective. Various studies have documented the " + "effectiveness of COVID-19 vaccines in preventing severe disease, " + "hospitalization, and death. The BNT162b2 vaccine has shown effectiveness " + "ranging from 65% to -41% for the 5-11 years age group and 76% to 46% for the " + "12-17 years age group, after the emergence of the Omicron variant in New York " + "(Dorabawila2022EffectivenessOT). Against the Delta variant, the effectiveness " + "of the BNT162b2 vaccine was approximately 88% after two doses " + "(Bernal2021EffectivenessOC pg. 1-3).\n\n" + "Vaccine effectiveness was also found to be 89% against hospitalization and " + "91% against emergency department or urgent care clinic visits " + "(Thompson2021EffectivenessOC pg. 3-5, Goo2031Foo pg. 3-4). In the UK " + "vaccination program, vaccine effectiveness was approximately 56% in " + "individuals aged ≥70 years between 28-34 days post-vaccination, increasing to " + "approximately 58% from day 35 onwards (Marfé2021EffectivenessOC).\n\n" + "However, it is important to note that vaccine effectiveness can decrease over " + "time. For instance, the effectiveness of COVID-19 vaccines against severe " + "COVID-19 declined to 64% after 121 days, compared to around 90% initially " + "(Chemaitelly2022WaningEO, Foo2019Bar). Despite this, vaccines still provide " + "significant protection against severe outcomes." + ) + ref = [ + "(Dorabawila2022EffectivenessOT)", + "(Bernal2021EffectivenessOC pg. 1-3)", + "(Thompson2021EffectivenessOC pg. 3-5, Goo2031Foo pg. 3-4)", + "(Marfé2021EffectivenessOC)", + "(Chemaitelly2022WaningEO, Foo2019Bar)", + ] + assert list(iter_citations(text)) == ref def test_single_author(): @@ -97,6 +126,38 @@ def test_citations_with_nonstandard_chars(): ) +def test_markdown(): + answer = Answer( + question="What was Fredic's greatest accomplishment?", + answer="Frederick Bates's greatest accomplishment was his role in resolving land disputes " + "and his service as governor of Missouri (Wiki2023 chunk 1).", + contexts=[ + Context( + context="", + text=Text( + text="Frederick Bates's greatest accomplishment was his role in resolving land disputes " + "and his service as governor of Missouri (Wiki2023 chunk 1).", + name="Wiki2023 chunk 1", + doc=Doc( + name="Wiki2023", + docname="Wiki2023", + citation="WikiMedia Foundation, 2023, Accessed now", + texts=[], + ), + ), + score=5, + ) + ], + ) + m, r = answer.markdown() + print(r) + assert "[^1]" in m + answer = answer.combine_with(answer) + m2, r2 = answer.markdown() + assert m2.startswith(m) + assert r2 == r + + def test_ablations(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf")