-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* docs: add authorization methods to README for GigaChat usage (#1) * ci: add workflow for dev branch (#2) * docs: use monthly downloads counter (#3) * docs: Readme update * feat: support loading prompts from github repos (#8) * docs: use monthly downloads counter * feat: support loading prompts from github repos --------- Co-authored-by: Dmitry Labazkin <[email protected]> * Support few-shots from class description (#10) * docs: use monthly downloads counter * few_shot_examples можно задавать через pydantic схему * feat: add pydantic schema support for few_shot_examples --------- Co-authored-by: Dmitry Labazkin <[email protected]> Co-authored-by: NIK-TIGER-BILL <[email protected]> * chore: minor version up (#14) --------- Co-authored-by: Dmitry Labazkin <[email protected]> Co-authored-by: NIK-TIGER-BILL <[email protected]>
- Loading branch information
1 parent
1a8364a
commit 1d57eb3
Showing
9 changed files
with
260 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Utilities for loading templates from gigachain | ||
github-based hub or other external sources.""" | ||
|
||
import os | ||
import re | ||
import tempfile | ||
from pathlib import Path, PurePosixPath | ||
from typing import Any, Callable, Optional, Set, TypeVar, Union | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
from langchain_core.prompts.base import BasePromptTemplate | ||
from langchain_core.prompts.loading import _load_prompt_from_file | ||
|
||
# Git ref (branch, tag or commit) used when a hub path does not pin one
# with the ``lc@<ref>://`` form.
DEFAULT_REF = os.environ.get("GIGACHAIN_HUB_DEFAULT_REF", "master")
# Base URL of the hub; ``{ref}`` is replaced with the requested git ref.
# It has its own environment variable: sharing GIGACHAIN_HUB_DEFAULT_REF
# would turn a ref override such as "dev" into the entire base URL.
URL_BASE = os.environ.get(
    "GIGACHAIN_HUB_URL_BASE",
    "https://raw.githubusercontent.com/ai-forever/gigachain/{ref}/hub/",
)
# Matches ``lc://<path>`` and ``lc@<ref>://<path>``; the ``ref`` group keeps
# its leading "@".
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")

# Return type of the loader callable handed to the hub loading helper.
T = TypeVar("T")
|
||
|
||
def _load_from_giga_hub(
    path: Union[str, Path],
    loader: Callable[[str], T],
    valid_prefix: str,
    valid_suffixes: Set[str],
    **kwargs: Any,
) -> Optional[T]:
    """Download a file from the hub and build an object from it.

    Args:
        path: Hub path of the form ``lc://<path>`` or ``lc@<ref>://<path>``.
            Without ``@<ref>`` the module-level ``DEFAULT_REF`` is used.
        loader: Callable that receives the path of the downloaded temporary
            file (plus ``kwargs``) and returns the loaded object.
        valid_prefix: Required first component of the remote path.
        valid_suffixes: Allowed file extensions, without the leading dot.
        **kwargs: Extra keyword arguments forwarded to ``loader``.

    Returns:
        The value returned by ``loader``, or ``None`` when ``path`` is not a
        string, does not match the hub path pattern, or its remote path does
        not start with ``valid_prefix`` (including an empty remote path).

    Raises:
        ValueError: If the file extension is not in ``valid_suffixes`` or the
            HTTP response status is not 200.
    """
    if not isinstance(path, str):
        return None
    match = HUB_PATH_RE.match(path)
    if match is None:
        return None

    ref, remote_path_str = match.groups()
    # The captured ref still carries its leading "@".
    ref = ref[1:] if ref else DEFAULT_REF

    # The remote path becomes part of a URL, so parse it as a POSIX path from
    # the start: forward slashes are the separator on every operating system.
    remote_path = PurePosixPath(remote_path_str)
    parts = remote_path.parts
    # An empty remote path ("lc://") has no parts; treat it like any other
    # path that lacks the required prefix instead of raising IndexError.
    if not parts or parts[0] != valid_prefix:
        return None
    if remote_path.suffix[1:] not in valid_suffixes:
        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

    full_url = urljoin(URL_BASE.format(ref=ref), str(remote_path))

    r = requests.get(full_url, timeout=5)
    if r.status_code != 200:
        raise ValueError(f"Could not find file at {full_url}")
    # The loader works on file paths, so stage the payload in a temporary
    # directory under its original file name (the suffix selects the parser).
    with tempfile.TemporaryDirectory() as tmpdirname:
        file = Path(tmpdirname) / remote_path.name
        with open(file, "wb") as f:
            f.write(r.content)
        return loader(str(file), **kwargs)
|
||
|
||
def load_from_giga_hub(path: Union[str, Path]) -> BasePromptTemplate:
    """Load a prompt template from the GigaChain hub.

    Args:
        path: Hub path such as ``lc://prompts/<dir>/<name>.yaml``; an
            ``lc@<ref>://`` prefix selects a specific git ref. Only files
            under ``prompts/`` with a ``py``, ``json`` or ``yaml`` extension
            are accepted.

    Returns:
        The prompt template built from the downloaded file.

    Raises:
        ValueError: If ``path`` is not a hub path under ``prompts/``, has an
            unsupported extension, or the file cannot be downloaded.
    """
    hub_result = _load_from_giga_hub(
        path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
    )
    # Compare against None explicitly: None is the helper's "not a hub path"
    # signal, while a successfully loaded template is not guaranteed to be
    # truthy (e.g. a template class that defines ``__len__``).
    if hub_result is None:
        raise ValueError("Prompt not found in GigaChain Hub.")
    return hub_result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import Generator | ||
|
||
import pytest | ||
import requests_mock | ||
from langchain_core.prompts.prompt import PromptTemplate | ||
|
||
from langchain_gigachat.tools.load_prompt import load_from_giga_hub | ||
|
||
|
||
@pytest.fixture
def mock_requests_get() -> Generator:
    """Register a canned YAML prompt for the hub URL requested by the test.

    Yields the active ``requests_mock.Mocker`` so HTTP GETs to that URL are
    answered locally while the fixture is in use.
    """
    prompt_url = "https://raw.githubusercontent.com/ai-forever/gigachain/master/hub/prompts/entertainment/meditation.yaml"  # noqa: E501
    prompt_yaml = (
        "input_variables: [background, topic]\n"
        "output_parser: null\n"
        "template: 'Create mediation for {topic} with {background}'\n"
        "template_format: f-string\n"
        "_type: prompt"
    )
    with requests_mock.Mocker() as http_mock:
        http_mock.get(prompt_url, text=prompt_yaml)
        yield http_mock
|
||
|
||
def test__load_from_giga_hub(mock_requests_get: Generator) -> None:
    """A hub path is fetched through the mocked URL and parsed as a prompt."""
    prompt = load_from_giga_hub("lc://prompts/entertainment/meditation.yaml")

    assert isinstance(prompt, PromptTemplate)
    assert prompt.template == "Create mediation for {topic} with {background}"
    # Both placeholders declared in the YAML must be exposed as inputs.
    for variable in ("background", "topic"):
        assert variable in prompt.input_variables