From 79a53cdaa9824d83ed05ea3578d1cfad161d5570 Mon Sep 17 00:00:00 2001
From: Yifei Wang <1277495324@qq.com>
Date: Mon, 3 Feb 2025 10:55:19 -0800
Subject: [PATCH] Implemented function calling server API.

---
 nexa/cli/entry.py                |  9 ++++
 nexa/gguf/server/nexa_service.py | 78 ++++++++++++++++++++++++++++----
 2 files changed, 77 insertions(+), 10 deletions(-)

diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py
index 621f329a..7ca998c5 100644
--- a/nexa/cli/entry.py
+++ b/nexa/cli/entry.py
@@ -159,6 +159,7 @@ def run_ggml_server(args):
     model_type = kwargs.pop("model_type", None)
     hf = kwargs.pop('huggingface', False)
     ms = kwargs.pop('modelscope', False)
+    use_function_calling = kwargs.pop('function_calling', False)
 
     run_type = None
     if model_type:
@@ -194,6 +195,7 @@ def run_ggml_server(args):
         huggingface=hf,
         modelscope=ms,
         projector_local_path_arg=projector_local_path,
+        function_calling=use_function_calling,
         **kwargs
     )
@@ -579,6 +581,13 @@ def main():
     server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
     server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes")
     server_parser.add_argument("--nctx", type=int, default=2048, help="Maximum context length of the model you're using")
+    server_parser.add_argument(
+        "-fc",
+        "--function_calling",
+        action="store_true",
+        help="Enable function calling for NLP models."
+    )
+
 
     # Other commands
     pull_parser = subparsers.add_parser("pull", help="Pull a model from official or hub.")
diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py
index 5128272a..19d51252 100644
--- a/nexa/gguf/server/nexa_service.py
+++ b/nexa/gguf/server/nexa_service.py
@@ -2,10 +2,7 @@
 import logging
 import os
 from pathlib import Path
-import queue
-import shutil
 import socket
-import threading
 import time
 import uuid
 from typing import List, Optional, Dict, Any, Union, Literal
@@ -13,8 +10,6 @@
 import multiprocessing
 from PIL import Image
 import tempfile
-import concurrent
-import tqdm
 import uvicorn
 from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Query
 from fastapi.middleware.cors import CORSMiddleware
@@ -54,7 +49,6 @@
 from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
 from nexa.general import add_model_to_list, default_use_processes, download_file_with_progress, get_model_info, is_model_exists, pull_model
 from nexa.gguf.llama.llama import Llama
-from nexa.gguf.nexa_inference_tts import NexaTTSInference
 # temporarily disabled NexaOmniVlmInference and NexaAudioLMInference
 # from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
 # from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
@@ -223,6 +217,42 @@ class TextToSpeechRequest(BaseModel):
     sampling_rate: int = 24000
     language: Optional[str] = "en" # Only for 'outetts'
 
+class FunctionCallRequest(BaseModel):
+    """
+    Represents the request schema for an OpenAI-style function calling API.
+
+    Attributes:
+        tools (List[Dict[str, Any]]):
+            Defines the available function calls that can be executed.
+        messages (List[Dict[str, Any]]):
+            A list of messages representing the conversation history.
+
+    Both fields default to a small integer-addition example for quick testing.
+    """
+    tools: List[Dict[str, Any]] = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_integer",
+                "description": "Returns the addition of input integers.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "num1": {"type": "integer", "description": "An integer to add."},
+                        "num2": {"type": "integer", "description": "An integer to add."}
+                    },
+                    "required": ["num1", "num2"],
+                    "additionalProperties": False
+                },
+                "strict": True
+            }
+        }
+    ]
+    messages: List[Dict[str, Any]] = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Please calculate the sum of 42 and 100."}
+    ]
+
 # New request class for embeddings
 class EmbeddingRequest(BaseModel):
     input: Union[str, List[str]] = Field(..., description="The input text to get embeddings for. Can be a string or an array of strings.")
@@ -325,6 +355,14 @@ def to_json(self):
 # helper functions
 async def load_model():
     global model, chat_format, completion_template, model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path
+    global use_function_calling
+
+    if use_function_calling and model_type != "NLP":
+        raise ValueError(
+            "Function calling is only supported for NLP models. "
+            "Please ensure that you are using a compatible NLP model before enabling this feature."
+        )
+
     if is_local_path:
         if model_type == "Multimodal":
             if not projector_path:
@@ -353,9 +391,14 @@ async def load_model():
         downloaded_path, model_type = pull_model(model_path)
 
     print(f"model_type: {model_type}")
-
+    if use_function_calling:
+        print('Function calling option is enabled')
+
     if model_type == "NLP" or model_type == "Text Embedding":
-        if model_path in NEXA_RUN_MODEL_MAP_FUNCTION_CALLING:
+        if model_type == "NLP" and use_function_calling:
+            from nexa.gguf.nexa_inference_text import NexaTextInference
+            model = NexaTextInference(model_path=model_path, function_calling=True)
+        elif model_path in NEXA_RUN_MODEL_MAP_FUNCTION_CALLING:
             chat_format = "chatml-function-calling"
             with suppress_stdout_stderr():
                 try:
@@ -725,12 +768,13 @@ def load_audio_from_bytes(audio_bytes: bytes):
         a = librosa.resample(a, orig_sr=sr, target_sr=SAMPLING_RATE)
     return a
 
-def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, modelscope=False, projector_local_path_arg=None, **kwargs):
-    global model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path
+def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, modelscope=False, projector_local_path_arg=None, function_calling=False, **kwargs):
+    global model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path, use_function_calling
     is_local_path = is_local_path_arg
     is_huggingface = huggingface
     is_modelscope = modelscope
     projector_path = projector_local_path_arg
+    use_function_calling = function_calling
     if is_local_path_arg or huggingface or modelscope:
         if not model_path_arg:
             raise ValueError("model_path must be provided when using --local_path or --huggingface or --modelscope")
@@ -1437,6 +1481,7 @@ async def txt2speech(request: TextToSpeechRequest):
             or model.sampling_rate != request.sampling_rate
             or model.language != request.language
         ):
+            from nexa.gguf.nexa_inference_tts import NexaTTSInference
             model = NexaTTSInference(
                 model_path=model_path,
                 tts_engine= 'bark' if 'bark' in model_path.lower() else 'outetts',
@@ -1475,6 +1520,19 @@ async def txt2speech(request: TextToSpeechRequest):
         )
         raise HTTPException(status_code=500, detail=str(e))
 
+@app.post("/v1/func_calling", tags=["Function Calling"])
+async def function_calling(request: FunctionCallRequest):
+    try:
+        json_response = model.function_calling(messages=request.messages, tools=request.tools)
+
+        return {
+            "created": time.time(),
+            "response": json_response
+        }
+    except Exception as e:
+        logging.error(f"Error in function calling: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @app.post("/v1/img2img", tags=["Computer Vision"])
 async def img2img(request: ImageGenerationRequest):
     try:
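
A quick way to exercise the new endpoint end to end. The sketch below is an assumption-laden client, not part of the patch: it assumes the server was started with the new flag (e.g. "nexa server <your-nlp-model> -fc", where the model path is a placeholder) and is listening on the default port 8000. The payload simply mirrors the defaults baked into FunctionCallRequest, and the expected reply is the envelope built by the handler: {"created": <timestamp>, "response": <model output>}.

# Minimal client sketch for the new /v1/func_calling endpoint.
# Assumes a local server started with: nexa server <your-nlp-model> -fc
# (the model path is a placeholder for any function-calling-capable NLP model).
import requests

payload = {
    "tools": [{
        "type": "function",
        "function": {
            "name": "add_integer",
            "description": "Returns the addition of input integers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "num1": {"type": "integer", "description": "An integer to add."},
                    "num2": {"type": "integer", "description": "An integer to add."},
                },
                "required": ["num1", "num2"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    }],
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Please calculate the sum of 42 and 100."},
    ],
}

resp = requests.post("http://localhost:8000/v1/func_calling", json=payload)
resp.raise_for_status()
body = resp.json()
print(body["created"])   # POSIX timestamp produced by time.time() on the server
print(body["response"])  # whatever model.function_calling() returned

Because both request fields carry defaults on the Pydantic model, posting an empty JSON object ({}) drives the same built-in addition example.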
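
The default tools entry also doubles as a template: the handler forwards whatever OpenAI-style function schemas the client supplies straight to model.function_calling(). Below is a hypothetical second tool for illustration only; get_weather and all of its fields are invented here and do not appear anywhere in this patch.

# Hypothetical tool definition following the same OpenAI-style schema as the
# add_integer default above; the function name and parameters are invented.
get_weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Returns the current temperature for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name."},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city"],
            "additionalProperties": False,
        },
        "strict": True,
    },
}

# Swapped into the same request shape accepted by /v1/func_calling.
payload = {
    "tools": [get_weather_tool],
    "messages": [{"role": "user", "content": "How hot is it in Austin, in celsius?"}],
}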