Added rest_server connector #2

Open · wants to merge 2 commits into main
116 changes: 116 additions & 0 deletions protollm/connectors/rest_server.py

General Suggestions:
It would be easier for me to review this if there were a small usage example. Tests would also not hurt, since they are easy to generate with LLMs.
Custom errors (ChatRESTServerError or something similar) instead of ValueError would be better, but that requires writing an exceptions module, which is quite easy to do.
A custom logger module with logging calls would not hurt either.

Author:

I agree about custom errors, but ValueError is what similar LangChain classes raise in these situations. I thought about tests, but they are not possible, because this is an open library and I can't use any internal LLM URL.
Logging tools for Runnables are built into LangChain.
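
A minimal usage sketch of the kind the reviewer asks for, assuming the connector is importable as protollm.connectors.rest_server.ChatRESTServer (the path in this diff) and that a compatible server is reachable; the URL, model name, and timeout below are placeholder values, not values from the PR:

    from langchain_core.messages import HumanMessage, SystemMessage

    from protollm.connectors.rest_server import ChatRESTServer

    # Placeholder server address; any host serving /v1/chat/completions works.
    chat = ChatRESTServer(model="llama3",
                          base_url="http://localhost:8672",
                          timeout=60)

    reply = chat.invoke([
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What does this connector do?"),
    ])
    print(reply.content)  # plain string content of the returned AIMessage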

@@ -0,0 +1,116 @@
import json
from typing import Any, Dict, List, Optional, Union

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage,
                                     SystemMessage)
from langchain_core.outputs import ChatGeneration, ChatResult


class ChatRESTServer(BaseChatModel):
    model: Optional[str] = 'llama3'
    base_url: str = 'http://10.32.2.2:8672'
    timeout: Optional[int] = None
    """Timeout for the request stream"""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-rest-server"

    def _convert_messages_to_rest_server_messages(
            self, messages: List[BaseMessage]
    ) -> List[Dict[str, Union[str, List[str]]]]:
        chat_messages: List = []
        for message in messages:
            role = ""
            match message:
                case HumanMessage():
                    role = "user"
                case AIMessage():
                    role = "assistant"
                case SystemMessage():
                    role = "system"
                case _:
                    raise ValueError("Received unsupported message type.")

            content = ""
            if isinstance(message.content, str):
                content = message.content
            else:
                raise ValueError(
                    "Unsupported message content type. "
                    "Must have type 'text' "
                )
            chat_messages.append(
                {
                    "role": role,
                    "content": content
                }
            )
        return chat_messages
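
    # Example of the mapping above (readable from the method's code): the input
    # [SystemMessage(content="hi"), HumanMessage(content="hello")] becomes
    # [{"role": "system", "content": "hi"}, {"role": "user", "content": "hello"}].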

    def _create_chat(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        payload = {


max_tokens, temperature, and other parameters should be configurable by the user; consider adding these fields to this dictionary.

Author:

What we have on our server now does not support configuring any parameters. Once that becomes available, I will add and test it.

"model": self.model,
"messages": self._convert_messages_to_rest_server_messages(
messages)
}
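
        # Per the review thread above: fields such as max_tokens or temperature
        # could be added to this payload once the server supports them, e.g.
        # payload["temperature"] = 0.7 (hypothetical, not yet supported).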
        response = requests.post(
            url=f'{self.base_url}/v1/chat/completions',
jrzkaminski (Aug 28, 2024):

Some older models use /v1/completions with manual tokens. Also, some newer models provide more control over the prompt in raw /v1/completions mode, which could be extremely beneficial. I have no clue how to generalize that, but it's something to keep in mind, because some users may have a custom LLM that uses custom tokens. There are also other modes like tokenization and so on.

Author:

I think that's a job for the LLM classes in LangChain; BaseChatModel subclasses assume a chat interface, not custom tokens. Maybe it makes sense to let the developer choose the endpoint, but then I think the whole point of this class is lost. There are too many additional settings, and at that point I would choose to write my own class. But I don't know; it's debatable.

            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=self.timeout
        )
        response.encoding = "utf-8"
        match response.status_code:
            case 200:
                pass  # Status code is 200, no action needed
            case 404:
                raise ValueError(
                    "CustomWeb call failed with status code 404. "
                    "Maybe you need to connect to the corporate network."
                )
            case _:
                optional_detail = response.text
                raise ValueError(
                    f"CustomWeb call failed with status code "
                    f"{response.status_code}. "
                    f"Details: {optional_detail}"
                )
        return json.loads(response.text)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        response = self._create_chat(messages, stop, **kwargs)
        # The server is expected to return an OpenAI-style payload:
        # {"choices": [{"message": {"content": ...}}], ...}
        chat_generation = ChatGeneration(
            message=AIMessage(
                content=response['choices'][0]['message']['content']),
            generation_info=response,
        )
        return ChatResult(generations=[chat_generation])

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing and makes it possible to monitor LLMs.
        """
        return {
            "model_name": self.model,
            "url": self.base_url
        }
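
On the custom-errors suggestion in the review above, a minimal sketch of what an exceptions module could look like; the module path and class names are hypothetical, not part of this PR:

    # protollm/exceptions.py (hypothetical module)


    class ChatRESTServerError(Exception):
        """Base error for the ChatRESTServer connector."""


    class ChatRESTServerResponseError(ChatRESTServerError):
        """Raised when the REST server returns a non-200 status code."""

        def __init__(self, status_code: int, detail: str = ""):
            self.status_code = status_code
            self.detail = detail
            super().__init__(
                f"REST server call failed with status code {status_code}. "
                f"Details: {detail}"
            )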
5 changes: 5 additions & 0 deletions requirements.txt
@@ -0,0 +1,5 @@
# LangChain
langchain>=0.2.12

# Network
requests~=2.32.3
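
On the tests discussion above: the internal LLM URL does not have to block unit tests, because requests.post can be mocked. A minimal sketch, assuming the module path from this diff and pytest as the runner:

    from unittest.mock import MagicMock, patch

    from langchain_core.messages import HumanMessage

    from protollm.connectors.rest_server import ChatRESTServer


    def test_generate_parses_choices():
        # Fake a successful /v1/chat/completions response.
        fake = MagicMock()
        fake.status_code = 200
        fake.text = '{"choices": [{"message": {"content": "hi"}}]}'

        with patch("protollm.connectors.rest_server.requests.post",
                   return_value=fake):
            chat = ChatRESTServer(base_url="http://example.invalid")
            reply = chat.invoke([HumanMessage(content="hello")])

        assert reply.content == "hi"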