Implemented function calling server API.
xsxszab committed Feb 3, 2025
1 parent dd42c73 commit 79a53cd
Showing 2 changed files with 77 additions and 10 deletions.
9 changes: 9 additions & 0 deletions nexa/cli/entry.py
@@ -159,6 +159,7 @@ def run_ggml_server(args):
model_type = kwargs.pop("model_type", None)
hf = kwargs.pop('huggingface', False)
ms = kwargs.pop('modelscope', False)
use_function_calling = kwargs.pop('function_calling', False)

run_type = None
if model_type:
@@ -194,6 +195,7 @@ def run_ggml_server(args):
huggingface=hf,
modelscope=ms,
projector_local_path_arg=projector_local_path,
use_function_calling=use_function_calling,
**kwargs
)

@@ -579,6 +581,13 @@ def main():
server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes")
server_parser.add_argument("--nctx", type=int, default=2048, help="Maximum context length of the model you're using")
server_parser.add_argument(
"-fc",
"--function_calling",
action="store_true",
help="Switch NLP model to handle function calling tasks."
)


# Other commands
pull_parser = subparsers.add_parser("pull", help="Pull a model from official or hub.")
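With the new -fc flag wired through run_ggml_server, the server can also be launched programmatically via run_nexa_ai_service (defined in the file below). A minimal sketch, assuming "llama3.2" is a valid model identifier and that host/port are forwarded to the server through **kwargs the same way the CLI forwards its parsed arguments:

from nexa.gguf.server.nexa_service import run_nexa_ai_service

# Start the server with the NLP model in function calling mode;
# equivalent to passing -fc / --function_calling on the command line.
run_nexa_ai_service(
    model_path_arg="llama3.2",  # hypothetical model identifier
    function_calling=True,
    host="0.0.0.0",             # assumed to be forwarded via **kwargs
    port=8000,
)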
78 changes: 68 additions & 10 deletions nexa/gguf/server/nexa_service.py
@@ -2,19 +2,14 @@
import logging
import os
from pathlib import Path
import queue
import shutil
import socket
import threading
import time
import uuid
from typing import List, Optional, Dict, Any, Union, Literal
import base64
import multiprocessing
from PIL import Image
import tempfile
import concurrent
import tqdm
import uvicorn
from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Query
from fastapi.middleware.cors import CORSMiddleware
@@ -54,7 +49,6 @@
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
from nexa.general import add_model_to_list, default_use_processes, download_file_with_progress, get_model_info, is_model_exists, pull_model
from nexa.gguf.llama.llama import Llama
from nexa.gguf.nexa_inference_tts import NexaTTSInference
# temporarily disabled NexaOmniVlmInference and NexaAudioLMInference
# from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
# from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
@@ -223,6 +217,42 @@ class TextToSpeechRequest(BaseModel):
sampling_rate: int = 24000
language: Optional[str] = "en" # Only for 'outetts'

class FunctionCallRequest(BaseModel):
"""
Represents the request schema for an OpenAI-style function calling API.
Attributes:
tools (List[Dict[str, Any]]):
Defines the available function calls that can be executed.
messages (List[Dict[str, Any]]):
A list of messages representing the conversation history.
model_path (str):
The path to the model used for function calling.
"""
tools: List[Dict[str, Any]] = [
{
"type": "function",
"function": {
"name": "add_integer",
"description": "Returns the addition of input integers.",
"parameters": {
"type": "object",
"properties": {
"num1": {"type": "integer", "description": "An integer to add."},
"num2": {"type": "integer", "description": "An integer to add."}
},
"required": ["num1", "num2"],
"additionalProperties": False
},
"strict": True
}
}
]
messages: List[Dict[str, Any]] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Please calculate the sum of 42 and 100."}
]

# New request class for embeddings
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]] = Field(..., description="The input text to get embeddings for. Can be a string or an array of strings.")
@@ -325,6 +355,14 @@ def to_json(self):
# helper functions
async def load_model():
global model, chat_format, completion_template, model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path
global use_function_calling

if use_function_calling and model_type != "NLP":
raise ValueError(
"Function calling is only supported for NLP models. "
"Please ensure that you are using a compatible NLP model before enabling this feature."
)

if is_local_path:
if model_type == "Multimodal":
if not projector_path:
@@ -353,9 +391,14 @@ async def load_model():
downloaded_path, model_type = pull_model(model_path)

print(f"model_type: {model_type}")

if use_function_calling:
print('Function calling option is enabled')

if model_type == "NLP" or model_type == "Text Embedding":
if model_path in NEXA_RUN_MODEL_MAP_FUNCTION_CALLING:
if model_type == "NLP" and use_function_calling:
from nexa.gguf.nexa_inference_text import NexaTextInference
model = NexaTextInference(model_path=model_path, function_calling=True)
elif model_path in NEXA_RUN_MODEL_MAP_FUNCTION_CALLING:
chat_format = "chatml-function-calling"
with suppress_stdout_stderr():
try:
@@ -725,12 +768,13 @@ def load_audio_from_bytes(audio_bytes: bytes):
a = librosa.resample(a, orig_sr=sr, target_sr=SAMPLING_RATE)
return a

def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, modelscope=False, projector_local_path_arg=None, **kwargs):
global model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path
def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, modelscope=False, projector_local_path_arg=None, function_calling=False, **kwargs):
global model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path, use_function_calling
is_local_path = is_local_path_arg
is_huggingface = huggingface
is_modelscope = modelscope
projector_path = projector_local_path_arg
use_function_calling = function_calling
if is_local_path_arg or huggingface or modelscope:
if not model_path_arg:
raise ValueError("model_path must be provided when using --local_path or --huggingface or --modelscope")
@@ -1437,6 +1481,7 @@ async def txt2speech(request: TextToSpeechRequest):
or model.sampling_rate != request.sampling_rate
or model.language != request.language
):
from nexa.gguf.nexa_inference_tts import NexaTTSInference
model = NexaTTSInference(
model_path=model_path,
tts_engine='bark' if 'bark' in model_path.lower() else 'outetts',
Expand Down Expand Up @@ -1475,6 +1520,19 @@ async def txt2speech(request: TextToSpeechRequest):
)
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/func_calling", tags=["Function Calling"])
async def function_calling(request: FunctionCallRequest):
try:
json_response = model.function_calling(messages=request.messages, tools=request.tools)

return {
"created": time.time(),
"response": json_response
}
except Exception as e:
logging.error(f"Error in function calling: {e}")
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/img2img", tags=["Computer Vision"])
async def img2img(request: ImageGenerationRequest):
try:
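With the server running in function calling mode, the new endpoint can be exercised with any HTTP client. A minimal sketch in Python, assuming the server listens on localhost:8000 (the default port above); the payload mirrors the FunctionCallRequest defaults, and the endpoint wraps the model output under the "response" key:

import requests

payload = {
    "tools": [{
        "type": "function",
        "function": {
            "name": "add_integer",
            "description": "Returns the addition of input integers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "num1": {"type": "integer", "description": "An integer to add."},
                    "num2": {"type": "integer", "description": "An integer to add."},
                },
                "required": ["num1", "num2"],
            },
        },
    }],
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Please calculate the sum of 42 and 100."},
    ],
}

# POST to the new endpoint added in this commit.
resp = requests.post("http://localhost:8000/v1/func_calling", json=payload)
resp.raise_for_status()
print(resp.json()["response"])  # the model's function-call output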
