Skip to content

Commit

Permalink
Added merging capability
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Nov 19, 2023
1 parent 21a533d commit e0a9e3f
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 16 deletions.
12 changes: 10 additions & 2 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from .docs import Answer, Docs, PromptCollection, Doc, Text
from .docs import Answer, Docs, PromptCollection, Doc, Text, Context
from .version import __version__

__all__ = ["Docs", "Answer", "PromptCollection", "__version__", "Doc", "Text"]
__all__ = [
"Docs",
"Answer",
"PromptCollection",
"__version__",
"Doc",
"Text",
"Context",
]
83 changes: 73 additions & 10 deletions paperqa/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import (
Expand All @@ -12,15 +12,14 @@
except ImportError:
from pydantic import BaseModel, validator


from .prompts import (
citation_prompt,
default_system_prompt,
qa_prompt,
select_paper_prompt,
summary_prompt,
)
from .utils import iter_citations
from .utils import extract_doi, iter_citations

DocKey = Any
CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
Expand Down Expand Up @@ -129,20 +128,84 @@ def __str__(self) -> str:
"""Return the answer as a string."""
return self.formatted_answer

def markdown(self) -> str:
def get_citation(self, name: str) -> str:
    """Return the formatted citation for the given docname.

    Scans ``self.contexts`` in order and uses the first context whose
    ``text.name`` equals *name*.

    :param name: Name of the text chunk to look up.
    :return: The ``citation`` of that chunk's document, wrapped in parentheses.
    :raises ValueError: If no context carries a text with that name.
    """
    for context in self.contexts:
        if context.text.name == name:
            return f"({context.text.doc.citation})"
    # Raised outside any ``except`` block so the traceback is not cluttered
    # with an internal StopIteration that callers never need to see.
    raise ValueError(f"Could not find docname {name} in contexts")

def markdown(self) -> Tuple[str, str]:
"""Return the answer with footnote style citations."""
# NOTE(review): this span comes from a rendered diff without +/- markers, so
# superseded and replacement lines sit side by side (two `refs`
# initialisations, two `return` statements). Presumably the Dict-based `refs`
# and the tuple-returning tail are the current version -- confirm against the
# repository before relying on either reading.
# example: This is an answer.[^1]
# [^1]: This is the citation.
index = 1
output = self.answer
ref_list = "## References\n\n"
# Maps each cited name to the footnote number it was first given.
refs: Dict[str, int] = dict()
index = 1
for citation in iter_citations(self.answer):
refs = []
# Footnote markers that will replace this citation in the answer text.
compound = ""
for c in citation.split(","):
refs.append(c.strip("() "))
c = c.strip("() ")
# This placeholder citation gets no footnote marker.
if c == "Extra background information":
continue
# A name that was already numbered reuses its existing footnote.
if c in refs:
compound += f"[^{refs[c]}]"
continue
refs[c] = index
compound += f"[^{index}]"
index += 1
output = output.replace(citation, compound)
ref_list += "\n".join([f"[^{i}]: {r}" for i, r in enumerate(refs, start=1)])
return output + "\n\n" + ref_list
# One footnote definition per cited name: the citation text as a markdown link
# whose target is whatever extract_doi returns for that citation.
# NOTE(review): get_citation(r) is evaluated twice per reference here.
formatted_refs = "\n".join(
[
f"[^{i}]: [{self.get_citation(r)}]({extract_doi(self.get_citation(r))})"
for r, i in refs.items()
]
)
return output, formatted_refs

def combine_with(self, other: "Answer") -> "Answer":
    """
    Build a new Answer that merges this one with ``other``.

    Questions are joined with " / "; the answer, context, references and
    formatted answer strings are joined with a single space; the context
    lists are concatenated; token counts go through ``merge_token_counts``.
    The length settings are copied from this object only.
    """
    merged_fields = dict(
        question=self.question + " / " + other.question,
        answer=self.answer + " " + other.answer,
        context=self.context + " " + other.context,
        contexts=self.contexts + other.contexts,
        references=self.references + " " + other.references,
        formatted_answer=self.formatted_answer + " " + other.formatted_answer,
        # Both answers are assumed to share the same length settings.
        summary_length=self.summary_length,
        answer_length=self.answer_length,
        # First truthy value wins for memory and cost.
        # NOTE(review): cost is taken from one side rather than summed --
        # confirm this is intended.
        memory=self.memory or other.memory,
        cost=self.cost or other.cost,
        token_counts=self.merge_token_counts(self.token_counts, other.token_counts),
    )
    combined = Answer(**merged_fields)

    # The filter is only assigned when at least one side actually has one.
    if not (self.dockey_filter or other.dockey_filter):
        return combined
    own_keys = self.dockey_filter or set()
    other_keys = other.dockey_filter or set()
    combined.dockey_filter = own_keys | other_keys
    return combined

@staticmethod
def merge_token_counts(
counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]]
) -> Optional[Dict[str, List[int]]]:
"""
Merge two dictionaries of token counts.
"""
if counts1 is None and counts2 is None:
return None
if counts1 is None:
return counts2
if counts2 is None:
return counts1
merged_counts = counts1.copy()
for key, values in counts2.items():
if key in merged_counts:
merged_counts[key][0] += values[0]
merged_counts[key][1] += values[1]
else:
merged_counts[key] = values
return merged_counts
18 changes: 18 additions & 0 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,21 @@ def iter_citations(text: str) -> List[str]:
citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)"
result = re.findall(citation_regex, text, flags=re.MULTILINE)
return result


def extract_doi(reference: str) -> str:
    """
    Extract the first DOI in the reference string and return it as a link.

    :param reference: A string containing the reference.
    :return: ``https://doi.org/<DOI>`` for the first DOI found, or an empty
        string if the reference contains no DOI.
    """
    # The dot after the "10" prefix is escaped so only a literal "." matches;
    # unescaped it would accept any character (e.g. "1001234/56").
    doi_pattern = r"10\.\d{4,9}/[-._;()/:A-Z0-9]+"
    doi_match = re.search(doi_pattern, reference, re.IGNORECASE)
    if doi_match is None:
        return ""

    # The suffix character class also matches sentence punctuation, so trim
    # what most likely belongs to the surrounding text rather than the DOI.
    trailing = ".,;:"
    doi = doi_match.group().rstrip(trailing)
    # A closing parenthesis is dropped only when it has no opening partner
    # inside the DOI, since some DOIs legitimately contain "(...)".
    while doi.endswith(")") and doi.count(")") > doi.count("("):
        doi = doi[:-1].rstrip(trailing)
    return "https://doi.org/" + doi
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.12.0"
__version__ = "3.13.0"
28 changes: 25 additions & 3 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

from paperqa import Answer, Docs, PromptCollection, Text
from paperqa import Answer, Context, Doc, Docs, PromptCollection, Text
from paperqa.chains import get_score
from paperqa.readers import read_doc
from paperqa.types import Doc
from paperqa.utils import (
iter_citations,
maybe_is_html,
Expand Down Expand Up @@ -132,8 +131,31 @@ def test_markdown():
question="What was Fredic's greatest accomplishment?",
answer="Frederick Bates's greatest accomplishment was his role in resolving land disputes "
"and his service as governor of Missouri (Wiki2023 chunk 1).",
contexts=[
Context(
context="",
text=Text(
text="Frederick Bates's greatest accomplishment was his role in resolving land disputes "
"and his service as governor of Missouri (Wiki2023 chunk 1).",
name="Wiki2023 chunk 1",
doc=Doc(
name="Wiki2023",
docname="Wiki2023",
citation="WikiMedia Foundation, 2023, Accessed now",
texts=[],
),
),
score=5,
)
],
)
assert "[^1]" in answer.markdown()
m, r = answer.markdown()
print(r)
assert "[^1]" in m
answer = answer.combine_with(answer)
m2, r2 = answer.markdown()
assert m2.startswith(m)
assert r2 == r


def test_ablations():
Expand Down

0 comments on commit e0a9e3f

Please sign in to comment.