Added rest_server connector #2
base: main

New file (@@ -0,0 +1,116 @@):

```python
import json
from typing import Any, Dict, List, Optional, Union

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage,
                                     SystemMessage)
from langchain_core.outputs import ChatGeneration, ChatResult


class ChatRESTServer(BaseChatModel):
    """Chat model backed by an OpenAI-style REST chat endpoint."""

    model: Optional[str] = 'llama3'
    base_url: str = 'http://10.32.2.2:8672'
    timeout: Optional[int] = None
    """Timeout for the request stream"""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-rest-server"

    def _convert_messages_to_rest_server_messages(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Union[str, List[str]]]]:
        """Map LangChain message objects to role/content dictionaries."""
        chat_messages: List = []
        for message in messages:
            match message:
                case HumanMessage():
                    role = "user"
                case AIMessage():
                    role = "assistant"
                case SystemMessage():
                    role = "system"
                case _:
                    raise ValueError("Received unsupported message type.")

            if isinstance(message.content, str):
                content = message.content
            else:
                raise ValueError(
                    "Unsupported message content type. "
                    "Must have type 'text'."
                )
            chat_messages.append(
                {
                    "role": role,
                    "content": content
                }
            )
        return chat_messages

    def _create_chat(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        payload = {
            "model": self.model,
            "messages": self._convert_messages_to_rest_server_messages(
                messages)
        }
```

Reviewer: `max_tokens`, `temperature`, and other parameters must be configurable by the user; consider adding these fields to that dictionary.

Author: What we have on our server now does not support configuration of any parameters. When it becomes available, I will add and test it.
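
If parameter configuration does land on the server, one way to thread sampling options into the payload might look like the sketch below. The `temperature` and `max_tokens` field names are assumptions based on the OpenAI-style route used here, and the subclass and `_build_payload` helper are illustrative, not part of this PR:

```python
from typing import Optional


class ChatRESTServerWithSampling(ChatRESTServer):
    # Hypothetical fields; the target server does not accept these yet.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

    def _build_payload(self, messages) -> dict:
        # Hypothetical helper: same payload as _create_chat, plus any
        # sampling parameters the user actually set.
        payload = {
            "model": self.model,
            "messages": self._convert_messages_to_rest_server_messages(
                messages),
        }
        if self.temperature is not None:
            payload["temperature"] = self.temperature
        if self.max_tokens is not None:
            payload["max_tokens"] = self.max_tokens
        return payload
```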

```python
        response = requests.post(
            url=f'{self.base_url}/v1/chat/completions',
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=self.timeout
        )
```

Reviewer: Some older models use /v1/completions with manual tokens. Also, some newer models provide more control over the prompt in raw /v1/completions mode, which could be extremely beneficial. I have no clue how to generalize that, but it is something to keep in mind, because some users may have a custom LLM that operates with custom tokens. There are also other modes, like tokenization and so on.

Author: I think that is what the LLM classes in LangChain are for; `BaseChatModel` subclasses assume chat exactly, not custom tokens. Maybe it makes sense to let the developer choose the endpoint, but then I think the whole point of this class is lost: there are too many additional settings, and I would rather write my own class at that point. But I don't know, it's debatable.
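
If endpoint choice were ever adopted, the lightest option might be a single overridable field rather than a new class hierarchy. A sketch under that assumption (the `endpoint` field and the subclass name are hypothetical):

```python
class ChatRESTServerCustomRoute(ChatRESTServer):
    # Hypothetical field: lets a deployment target a different route,
    # e.g. "/v1/completions", while keeping the chat route as default.
    endpoint: str = "/v1/chat/completions"

    # _create_chat would then build the request URL as:
    #     url = f"{self.base_url}{self.endpoint}"
```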

```python
        response.encoding = "utf-8"
        match response.status_code:
            case 200:
                pass  # Success; fall through and parse the body.
            case 404:
                raise ValueError(
                    "CustomWeb call failed with status code 404. "
                    "Maybe you need to connect to the corporate network."
                )
            case _:
                optional_detail = response.text
                raise ValueError(
                    f"CustomWeb call failed with status code "
                    f"{response.status_code}. "
                    f"Details: {optional_detail}"
                )
        return json.loads(response.text)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        response = self._create_chat(messages, stop, **kwargs)
        chat_generation = ChatGeneration(
            message=AIMessage(
                content=response['choices'][0]['message']['content']),
            generation_info=response,
        )
        return ChatResult(generations=[chat_generation])

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system,
        which is used for tracing and makes it possible to monitor LLMs.
        """
        return {
            "model_name": self.model,
            "url": self.base_url
        }
```
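
A minimal usage sketch, assuming the class above is importable; the address and model name are placeholders rather than a real deployment:

```python
from langchain_core.messages import HumanMessage, SystemMessage

# Placeholder address; point this at a reachable server.
chat = ChatRESTServer(
    model="llama3",
    base_url="http://localhost:8672",
    timeout=60,
)

result = chat.invoke([
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="Say hello in one sentence."),
])
print(result.content)  # The assistant's reply text.
```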
New file (@@ -0,0 +1,5 @@):

```
# LangChain
langchain>=0.2.12

# Network
requests~=2.32.3
```
Reviewer: General suggestions:

- It would be easier for me to review this if there were a simple, small usage example. Tests would not hurt either, since they are easily generated by LLMs.
- Custom errors (`ChatRESTServerError` or similar) instead of `ValueError`s would be better, but that requires writing an exceptions module, which is quite easy to do.
- A custom logger module with logging would not hurt too.

Author: With custom errors I agree, but `ValueError` is what similar LangChain classes raise in these situations. I thought about tests, but they are not possible, because this is an open library and I can't use any internal LLM URL. Logging tools for Runnables are built into LangChain.
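
The internal-URL obstacle could arguably be sidestepped by mocking the HTTP layer, so no real endpoint is needed. A minimal pytest-style sketch using unittest.mock; the fake response shape mirrors the `choices` parsing in `_generate`, and the test name is illustrative:

```python
from unittest.mock import MagicMock, patch

from langchain_core.messages import HumanMessage


def test_generate_parses_first_choice():
    fake = MagicMock()
    fake.status_code = 200
    fake.text = '{"choices": [{"message": {"content": "hi"}}]}'

    # Patch requests.post so no network access happens.
    with patch("requests.post", return_value=fake):
        chat = ChatRESTServer(base_url="http://example.invalid")
        result = chat.invoke([HumanMessage(content="hello")])

    assert result.content == "hi"
```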