From a4db19ae164562e5f98f6a5ea13b04dd6a96481c Mon Sep 17 00:00:00 2001 From: James Braza Date: Thu, 19 Dec 2024 14:20:47 -0800 Subject: [PATCH] Consolidated LDP imports into `ldp_shims` module (#772) --- paperqa/_ldp_shims.py | 50 ++++++++++++++++++++++++++++++++++++++++++ paperqa/agents/main.py | 10 +-------- paperqa/agents/task.py | 17 +++----------- paperqa/litqa.py | 6 +---- paperqa/settings.py | 34 +++++++++++----------------- 5 files changed, 68 insertions(+), 49 deletions(-) create mode 100644 paperqa/_ldp_shims.py diff --git a/paperqa/_ldp_shims.py b/paperqa/_ldp_shims.py new file mode 100644 index 00000000..39c8ef4e --- /dev/null +++ b/paperqa/_ldp_shims.py @@ -0,0 +1,50 @@ +"""Centralized place for lazy LDP imports.""" + +__all__ = [ + "HAS_LDP_INSTALLED", + "Agent", + "Callback", + "ComputeTrajectoryMetricsMixin", + "HTTPAgentClient", + "Memory", + "MemoryAgent", + "ReActAgent", + "RolloutManager", + "SimpleAgent", + "SimpleAgentState", + "UIndexMemoryModel", + "_Memories", + "discounted_returns", + "set_training_mode", +] + +from pydantic import TypeAdapter + +try: + from ldp.agent import ( + Agent, + HTTPAgentClient, + MemoryAgent, + ReActAgent, + SimpleAgent, + SimpleAgentState, + ) + from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager + from ldp.graph.memory import Memory, UIndexMemoryModel + from ldp.graph.op_utils import set_training_mode + from ldp.utils import discounted_returns + + _Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated] + + HAS_LDP_INSTALLED = True +except ImportError: + HAS_LDP_INSTALLED = False + + class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] + """Placeholder parent class for when ldp isn't installed.""" + + class Callback: # type: ignore[no-redef] + """Placeholder parent class for when ldp isn't installed.""" + + RolloutManager = None # type: ignore[assignment,misc] + discounted_returns = None # type: ignore[assignment] diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 01612df4..5f453281 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -20,15 +20,7 @@ stop_after_attempt, ) -try: - from ldp.alg import Callback, RolloutManager -except ImportError: - - class Callback: # type: ignore[no-redef] - """Placeholder parent class for when ldp isn't installed.""" - - RolloutManager = None # type: ignore[assignment,misc] - +from paperqa._ldp_shims import Callback, RolloutManager from paperqa.docs import Docs from paperqa.settings import AgentSettings from paperqa.types import PQASession diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index e71a4cf7..96220146 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -24,21 +24,9 @@ ToolResponseMessage, ) from aviary.env import ENV_REGISTRY - -from paperqa.types import DocDetails - -from .search import SearchIndex, maybe_get_manifest - -try: - from ldp.alg import ComputeTrajectoryMetricsMixin -except ImportError: - - class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] - """Placeholder for when ldp isn't installed.""" - - from llmclient import EmbeddingModel, LiteLLMModel, LLMModel +from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin from paperqa.docs import Docs from paperqa.litqa import ( DEFAULT_EVAL_MODEL_NAME, @@ -47,10 +35,11 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] LitQAEvaluation, read_litqa_v2_from_hub, ) -from paperqa.types import PQASession +from paperqa.types import DocDetails, PQASession from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment from .models import QueryRequest +from .search import SearchIndex, maybe_get_manifest from .tools import Complete if TYPE_CHECKING: diff --git a/paperqa/litqa.py b/paperqa/litqa.py index 4bfa7956..2f5607af 100644 --- a/paperqa/litqa.py +++ b/paperqa/litqa.py @@ -10,14 +10,10 @@ from enum import StrEnum from typing import TYPE_CHECKING, Literal, Self -try: - from ldp.utils import discounted_returns -except ImportError: - discounted_returns = None # type: ignore[assignment] - from aviary.core import Message from llmclient import LiteLLMModel, LLMModel +from paperqa._ldp_shims import discounted_returns from paperqa.prompts import EVAL_PROMPT_TEMPLATE, QA_PROMPT_TEMPLATE from paperqa.settings import make_default_litellm_model_list_settings from paperqa.types import PQASession diff --git a/paperqa/settings.py b/paperqa/settings.py index 5e54fc39..907b7410 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -10,37 +10,29 @@ import anyio from aviary.core import ToolSelector +from llmclient import EmbeddingModel, LiteLLMModel, embedding_model_factory from pydantic import ( BaseModel, ConfigDict, Field, - TypeAdapter, computed_field, field_validator, model_validator, ) from pydantic_settings import BaseSettings, CliSettingsSource, SettingsConfigDict -try: - from ldp.agent import ( - Agent, - HTTPAgentClient, - MemoryAgent, - ReActAgent, - SimpleAgent, - SimpleAgentState, - ) - from ldp.graph.memory import Memory, UIndexMemoryModel - from ldp.graph.op_utils import set_training_mode - - _Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated] - - HAS_LDP_INSTALLED = True -except ImportError: - HAS_LDP_INSTALLED = False - -from llmclient import EmbeddingModel, LiteLLMModel, embedding_model_factory - +from paperqa._ldp_shims import ( + HAS_LDP_INSTALLED, + Agent, + HTTPAgentClient, + MemoryAgent, + ReActAgent, + SimpleAgent, + SimpleAgentState, + UIndexMemoryModel, + _Memories, + set_training_mode, +) from paperqa.prompts import ( CONTEXT_INNER_PROMPT, CONTEXT_OUTER_PROMPT,