-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support loading prompts from github repos (#8)
* docs: use monthly downloads counter * feat: support loading prompts from github repos --------- Co-authored-by: Dmitry Labazkin <[email protected]>
- Loading branch information
1 parent
2132648
commit 41facb2
Showing
4 changed files
with
133 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Utilities for loading templates from gigachain | ||
github-based hub or other extenal sources.""" | ||
|
||
import os | ||
import re | ||
import tempfile | ||
from pathlib import Path, PurePosixPath | ||
from typing import Any, Callable, Optional, Set, TypeVar, Union | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
from langchain_core.prompts.base import BasePromptTemplate | ||
from langchain_core.prompts.loading import _load_prompt_from_file | ||
|
||
DEFAULT_REF = os.environ.get("GIGACHAIN_HUB_DEFAULT_REF", "master") | ||
URL_BASE = os.environ.get( | ||
"GIGACHAIN_HUB_DEFAULT_REF", | ||
"https://raw.githubusercontent.com/ai-forever/gigachain/{ref}/hub/", | ||
) | ||
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)") | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def _load_from_giga_hub( | ||
path: Union[str, Path], | ||
loader: Callable[[str], T], | ||
valid_prefix: str, | ||
valid_suffixes: Set[str], | ||
**kwargs: Any, | ||
) -> Optional[T]: | ||
"""Load configuration from hub. Returns None if path is not a hub path.""" | ||
if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): | ||
return None | ||
ref, remote_path_str = match.groups() | ||
ref = ref[1:] if ref else DEFAULT_REF | ||
remote_path = Path(remote_path_str) | ||
if remote_path.parts[0] != valid_prefix: | ||
return None | ||
if remote_path.suffix[1:] not in valid_suffixes: | ||
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") | ||
|
||
# Using Path with URLs is not recommended, because on Windows | ||
# the backslash is used as the path separator, which can cause issues | ||
# when working with URLs that use forward slashes as the path separator. | ||
# Instead, use PurePosixPath to ensure that forward slashes are used as the | ||
# path separator, regardless of the operating system. | ||
full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) | ||
|
||
r = requests.get(full_url, timeout=5) | ||
if r.status_code != 200: | ||
raise ValueError(f"Could not find file at {full_url}") | ||
with tempfile.TemporaryDirectory() as tmpdirname: | ||
file = Path(tmpdirname) / remote_path.name | ||
with open(file, "wb") as f: | ||
f.write(r.content) | ||
return loader(str(file), **kwargs) | ||
|
||
|
||
def load_from_giga_hub(path: Union[str, Path]) -> BasePromptTemplate: | ||
"""Unified method for loading a prompt from GigaChain repo or local fs.""" | ||
if hub_result := _load_from_giga_hub( | ||
path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} | ||
): | ||
return hub_result | ||
else: | ||
raise ValueError("Prompt not found in GigaChain Hub.") |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import Generator | ||
|
||
import pytest | ||
import requests_mock | ||
from langchain_core.prompts.prompt import PromptTemplate | ||
|
||
from langchain_gigachat.tools.load_prompt import load_from_giga_hub | ||
|
||
|
||
@pytest.fixture | ||
def mock_requests_get() -> Generator: | ||
with requests_mock.Mocker() as mocker: | ||
mocker.get( | ||
"https://raw.githubusercontent.com/ai-forever/gigachain/master/hub/prompts/entertainment/meditation.yaml", | ||
text=( | ||
"input_variables: [background, topic]\n" | ||
"output_parser: null\n" | ||
"template: 'Create mediation for {topic} with {background}'\n" | ||
"template_format: f-string\n" | ||
"_type: prompt" | ||
), | ||
) | ||
yield mocker | ||
|
||
|
||
def test__load_from_giga_hub(mock_requests_get: Generator) -> None: | ||
template = load_from_giga_hub("lc://prompts/entertainment/meditation.yaml") | ||
assert isinstance(template, PromptTemplate) | ||
assert template.template == "Create mediation for {topic} with {background}" | ||
assert "background" in template.input_variables | ||
assert "topic" in template.input_variables |