Skip to content

Commit

Permalink
Pulling in latest fhaviary, mypy, ruff (#647)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew White <[email protected]>
  • Loading branch information
jamesbraza and whitead authored Oct 29, 2024
1 parent 34bc169 commit 8eaef2e
Show file tree
Hide file tree
Showing 16 changed files with 244 additions and 203 deletions.
13 changes: 7 additions & 6 deletions .github/renovate.json5
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
prHourlyLimit: 4,
timezone: "America/Los_Angeles",
rangeStrategy: "widen",
lockFileMaintenance: {
enabled: true,
},
"pre-commit": {
enabled: true,
},
lockFileMaintenance: { enabled: true },
"pre-commit": { enabled: true },
packageRules: [
{
// Allow 'widen' range strategy while matching aviary_internal pyproject.toml
matchPackageNames: ["openai"],
allowedVersions: "<1.47",
},
{
// TODO: remove after fhaviary supports Python 3.13
matchPackageNames: ["python"],
allowedVersions: "<=3.12",
},
],
}
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down Expand Up @@ -55,36 +55,37 @@ repos:
hooks:
- id: check-mailmap
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.20.2
rev: v0.22
hooks:
- id: validate-pyproject
additional_dependencies:
- "validate-pyproject-schema-store[all]>=2024.06.24" # Pin for Ruff's FURB154
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.4.21
rev: 0.4.27
hooks:
- id: uv-lock
- repo: https://github.com/renovatebot/pre-commit-hooks
rev: 38.122.0
rev: 38.131.1
hooks:
- id: renovate-config-validator
args: [--strict]
- repo: https://github.com/adamchainz/blacken-docs
rev: 1.19.0
rev: 1.19.1
hooks:
- id: blacken-docs
- repo: https://github.com/jsh9/markdown-toc-creator
rev: 0.0.8
hooks:
- id: markdown-toc-creator
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
args: [--pretty, --ignore-missing-imports]
additional_dependencies:
- aiohttp
- coredis
- fhaviary[llm]>=0.6 # Match pyproject.toml
- fhaviary[llm]>=0.8.2 # Match pyproject.toml
- ldp>=0.9 # Match pyproject.toml
- html2text
- httpx
Expand Down
11 changes: 8 additions & 3 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from copy import deepcopy
from typing import Any, Self, cast

from aviary.env import Environment, Frame
from aviary.message import Message
from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage
from aviary.core import (
Environment,
Frame,
Message,
Tool,
ToolRequestMessage,
ToolResponseMessage,
)

from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
Expand Down
10 changes: 6 additions & 4 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any

from aviary.message import MalformedMessageError, Message
from aviary.tools import (
from aviary.core import (
MalformedMessageError,
Message,
Tool,
ToolCall,
ToolRequestMessage,
Expand Down Expand Up @@ -40,7 +41,7 @@ class Callback: # type: ignore[no-redef]
from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch

if TYPE_CHECKING:
from aviary.env import Environment
from aviary.core import Environment
from ldp.agent import Agent, SimpleAgentState
from ldp.graph.ops import OpResult

Expand Down Expand Up @@ -234,7 +235,8 @@ async def run_aviary_agent(
while not done:
if max_timesteps is not None and timestep >= max_timesteps:
logger.warning(
f"Agent didn't finish within {max_timesteps} timesteps, just answering."
f"Agent didn't finish within {max_timesteps} timesteps, just"
" answering."
)
generate_answer_tool = next(
filter(
Expand Down
21 changes: 9 additions & 12 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ async def maybe_get_manifest(
}
if not file_loc_to_records:
raise ValueError( # noqa: TRY301
f"No mapping of file location to details extracted from manifest"
"No mapping of file location to details extracted from manifest"
f" file {filename}."
)
logger.debug(
Expand Down Expand Up @@ -593,11 +593,9 @@ async def get_directory_index( # noqa: PLR0912
index_settings = _settings.agent.index
if index_name:
warnings.warn(
(
f"The index_name argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6."
),
"The index_name argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand All @@ -620,11 +618,9 @@ async def get_directory_index( # noqa: PLR0912

if not sync_index_w_directory:
warnings.warn(
(
f"The sync_index_w_directory argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6."
),
"The sync_index_w_directory argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -686,7 +682,8 @@ async def get_directory_index( # noqa: PLR0912
)
else:
logger.debug(
f"File {rel_file_path} found in paper directory {paper_directory}."
f"File {rel_file_path} found in paper directory"
f" {paper_directory}."
)

if search_index.changed:
Expand Down
12 changes: 9 additions & 3 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Self, assert_never

from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage
from aviary.core import (
TASK_DATASET_REGISTRY,
Frame,
Message,
TaskDataset,
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY

from paperqa.types import DocDetails

Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/semantic_scholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ async def s2_title_search(
# need to check if nested under a 'data' key or not (depends on filtering)
if (
strings_similarity(
data.get("title") if "data" not in data else data["data"][0]["title"],
data.get("title", "") if "data" not in data else data["data"][0]["title"],
title,
)
< title_similarity_threshold
Expand Down
3 changes: 2 additions & 1 deletion paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def __init__(self, **kwargs):
from sentence_transformers import SentenceTransformer
except ImportError as exc:
raise ImportError(
"Please install paper-qa[local] to use SentenceTransformerEmbeddingModel."
"Please install paper-qa[local] to use"
" SentenceTransformerEmbeddingModel."
) from exc

self._model = SentenceTransformer(self.name)
Expand Down
6 changes: 2 additions & 4 deletions paperqa/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,6 @@ async def try_acquire(
raise_impossible_limits (:obj:`bool`, optional): flag will raise a
ValueError for weights that exceed the rate.
Returns:
None if the rate limit is satisfied.
Raises:
TimeoutError: if the acquire_timeout is exceeded.
ValueError: if the weight exceeds the rate limit and raise_impossible_limits is True.
Expand All @@ -352,7 +349,8 @@ async def try_acquire(

if rate_limit.amount < weight and raise_impossible_limits:
raise ValueError(
f"Weight ({weight}) > RateLimit ({rate_limit}), cannot satisfy rate limit."
f"Weight ({weight}) > RateLimit ({rate_limit}), cannot satisfy rate"
" limit."
)
while True:
elapsed = 0.0
Expand Down
12 changes: 9 additions & 3 deletions paperqa/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from math import ceil
from pathlib import Path
from typing import Literal, overload
from typing import Literal, cast, overload

import pymupdf
import tiktoken
Expand Down Expand Up @@ -172,7 +172,9 @@ def chunk_text(
f"ParsedText.content must be a `str`, not {type(parsed_text.content)}."
)

content = parsed_text.content if not use_tiktoken else parsed_text.encode_content()
content: str | list[int] = (
parsed_text.content if not use_tiktoken else parsed_text.encode_content()
)
if not content: # Avoid div0 in token calculations
raise ImpossibleParsingError(
f"No text was parsed from the document named {doc.docname!r} with ID"
Expand All @@ -195,7 +197,11 @@ def chunk_text(
]
texts.append(
Text(
text=enc.decode(split) if use_tiktoken else split,
text=(
enc.decode(cast(list[int], split))
if use_tiktoken
else cast(str, split)
),
name=f"{doc.docname} chunk {i + 1}",
doc=doc,
)
Expand Down
35 changes: 16 additions & 19 deletions paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Self, assert_never, cast

import anyio
from aviary.tools import ToolSelector
from aviary.core import ToolSelector
from pydantic import (
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -99,11 +99,9 @@ def _deprecated_field(self) -> Self:
# default is True, so we only warn if it's False
if not self.evidence_detailed_citations:
warnings.warn(
(
"The 'evidence_detailed_citations' field is deprecated and will be"
" removed in version 6. Adjust 'PromptSettings.context_inner' to remove"
" detailed citations."
),
"The 'evidence_detailed_citations' field is deprecated and will be"
" removed in version 6. Adjust 'PromptSettings.context_inner' to remove"
" detailed citations.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -259,8 +257,10 @@ class PromptSettings(BaseModel):
)
context_inner: str = Field(
default=CONTEXT_INNER_PROMPT,
description="Prompt for how to format a single context in generate answer. "
"This should at least contain key and name.",
description=(
"Prompt for how to format a single context in generate answer. "
"This should at least contain key and name."
),
)

@field_validator("summary")
Expand Down Expand Up @@ -380,7 +380,8 @@ class IndexSettings(BaseModel):
default=True,
description=(
"Whether to sync the index with the paper directory when loading an index."
" Setting to True will add or delete index files to match the source paper directory."
" Setting to True will add or delete index files to match the source paper"
" directory."
),
)

Expand Down Expand Up @@ -537,11 +538,9 @@ def _deprecated_field(self) -> Self:
value = getattr(self, deprecated_field_name)
if value != type(self).model_fields[deprecated_field_name].default:
warnings.warn(
(
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6."
),
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -667,11 +666,9 @@ def _deprecated_field(self) -> Self:
value = getattr(self, deprecated_field_name)
if value != type(self).model_fields[deprecated_field_name].default:
warnings.warn(
(
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6."
),
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down
12 changes: 7 additions & 5 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import litellm # for cost
import tiktoken
from aviary.message import Message
from aviary.core import Message
from pybtex.database import BibliographyData, Entry, Person
from pybtex.database.input.bibtex import Parser
from pybtex.scanner import PybtexSyntaxError
Expand Down Expand Up @@ -360,8 +360,10 @@ class DocDetails(Doc):
file_location: str | os.PathLike | None = None
license: str | None = Field(
default=None,
description="string indicating license."
" Should refer specifically to pdf_url (since that could be preprint). None means unknown/unset.",
description=(
"string indicating license. Should refer specifically to pdf_url (since"
" that could be preprint). None means unknown/unset."
),
)
pdf_url: str | None = None
other: dict[str, Any] = Field(
Expand Down Expand Up @@ -612,8 +614,8 @@ def formatted_citation(self) -> str:

if self.source_quality_message:
return (
f"{self.citation} This article has {self.citation_count} citations and is"
f" from a {self.source_quality_message}."
f"{self.citation} This article has {self.citation_count} citations and"
f" is from a {self.source_quality_message}."
)
return f"{self.citation} This article has {self.citation_count} citations."

Expand Down
Loading

0 comments on commit 8eaef2e

Please sign in to comment.