Skip to content

Commit

Permalink
feat: support loading prompts from github repos (#8)
Browse files Browse the repository at this point in the history
* docs: use monthly downloads counter
* feat: support loading prompts from github repos

---------

Co-authored-by: Dmitry Labazkin <[email protected]>
  • Loading branch information
Rai220 and labdmitriy authored Nov 11, 2024
1 parent 2132648 commit 41facb2
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 2 deletions.
67 changes: 67 additions & 0 deletions libs/gigachat/langchain_gigachat/tools/load_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Utilities for loading templates from gigachain
github-based hub or other extenal sources."""

import os
import re
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin

import requests
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.loading import _load_prompt_from_file

DEFAULT_REF = os.environ.get("GIGACHAIN_HUB_DEFAULT_REF", "master")
URL_BASE = os.environ.get(
"GIGACHAIN_HUB_DEFAULT_REF",
"https://raw.githubusercontent.com/ai-forever/gigachain/{ref}/hub/",
)
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")

T = TypeVar("T")


def _load_from_giga_hub(
path: Union[str, Path],
loader: Callable[[str], T],
valid_prefix: str,
valid_suffixes: Set[str],
**kwargs: Any,
) -> Optional[T]:
"""Load configuration from hub. Returns None if path is not a hub path."""
if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
return None
ref, remote_path_str = match.groups()
ref = ref[1:] if ref else DEFAULT_REF
remote_path = Path(remote_path_str)
if remote_path.parts[0] != valid_prefix:
return None
if remote_path.suffix[1:] not in valid_suffixes:
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

# Using Path with URLs is not recommended, because on Windows
# the backslash is used as the path separator, which can cause issues
# when working with URLs that use forward slashes as the path separator.
# Instead, use PurePosixPath to ensure that forward slashes are used as the
# path separator, regardless of the operating system.
full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())

r = requests.get(full_url, timeout=5)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = Path(tmpdirname) / remote_path.name
with open(file, "wb") as f:
f.write(r.content)
return loader(str(file), **kwargs)


def load_from_giga_hub(path: Union[str, Path]) -> BasePromptTemplate:
"""Unified method for loading a prompt from GigaChain repo or local fs."""
if hub_result := _load_from_giga_hub(
path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
):
return hub_result
else:
raise ValueError("Prompt not found in GigaChain Hub.")
35 changes: 33 additions & 2 deletions libs/gigachat/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions libs/gigachat/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ license = "MIT"
python = ">=3.9,<4.0"
langchain-core = "^0.3"
gigachat = "^0.1.35"
types-requests = "^2.32"

[tool.poetry.group.dev]
optional = true
Expand Down Expand Up @@ -41,6 +42,7 @@ pytest = "^8.3.3"
pytest-cov = "^5.0.0"
pytest-asyncio = "^0.24.0"
pytest-mock = "^3.14.0"
requests_mock = "^1.12.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
31 changes: 31 additions & 0 deletions libs/gigachat/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Generator

import pytest
import requests_mock
from langchain_core.prompts.prompt import PromptTemplate

from langchain_gigachat.tools.load_prompt import load_from_giga_hub


@pytest.fixture
def mock_requests_get() -> Generator:
with requests_mock.Mocker() as mocker:
mocker.get(
"https://raw.githubusercontent.com/ai-forever/gigachain/master/hub/prompts/entertainment/meditation.yaml",
text=(
"input_variables: [background, topic]\n"
"output_parser: null\n"
"template: 'Create mediation for {topic} with {background}'\n"
"template_format: f-string\n"
"_type: prompt"
),
)
yield mocker


def test__load_from_giga_hub(mock_requests_get: Generator) -> None:
template = load_from_giga_hub("lc://prompts/entertainment/meditation.yaml")
assert isinstance(template, PromptTemplate)
assert template.template == "Create mediation for {topic} with {background}"
assert "background" in template.input_variables
assert "topic" in template.input_variables

0 comments on commit 41facb2

Please sign in to comment.