Commit

Fixed a bug where the server system prompt did not work
zhycheng614 committed Jan 4, 2025
1 parent 282366f commit 3a8afac
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions nexa/gguf/server/nexa_service.py
@@ -98,8 +98,8 @@
 chat_format = None
 completion_template = None
 hostname = socket.gethostname()
-chat_completion_system_prompt = [{"role": "system", "content": "You are a helpful assistant"}]
-function_call_system_prompt = [{"role": "system", "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"}]
+default_chat_completion_system_prompt = [{"role": "system", "content": "You are a helpful assistant"}]
+default_function_call_system_prompt = [{"role": "system", "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"}]
 model_path = None
 whisper_model_path = "faster-whisper-tiny"  # by default, use tiny whisper model
 n_ctx = None
@@ -495,7 +495,7 @@ async def load_whisper_model(custom_whisper_model_path=None):
         raise ValueError(f"Failed to load Whisper model: {str(e)}")
 
 def nexa_run_text_generation(
-    prompt, temperature, stop_words, max_new_tokens, top_k, top_p, logprobs=None, stream=False, is_chat_completion=True
+    prompt, temperature, stop_words, max_new_tokens, top_k, top_p, messages=[], logprobs=None, stream=False, is_chat_completion=True, **kwargs
 ) -> Dict[str, Any]:
     global model, chat_format, completion_template
     if model is None:
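
A side note on the new signature: messages=[] is a mutable default argument, which Python evaluates once at function definition time. It looks harmless here because nexa_run_text_generation only reassigns messages rather than mutating it in place, but the conventional safer idiom is a None sentinel. A minimal sketch of the pitfall (illustrative only, not part of this commit):

    def risky(messages=[]):
        messages.append("x")      # mutates the one shared default list
        return messages

    def safe(messages=None):
        if messages is None:      # fresh list on every call
            messages = []
        messages.append("x")
        return messages

    print(risky())  # ['x']
    print(risky())  # ['x', 'x'] -- state leaked from the previous call
    print(safe())   # ['x']
    print(safe())   # ['x']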
@@ -506,9 +506,10 @@ def nexa_run_text_generation(
 
     if is_chat_completion:
         if is_local_path or is_huggingface or is_modelscope:  # do not add system prompt if local path or huggingface or modelscope
-            messages = [{"role": "user", "content": prompt}]
+            pass
         else:
-            messages = chat_completion_system_prompt + [{"role": "user", "content": prompt}]
+            if messages[0]['role'] != 'system':
+                messages = default_chat_completion_system_prompt + messages
 
     params = {
         'messages': messages,
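
This hunk is the core of the fix: previously the handler rebuilt messages from the last prompt alone and always prepended the hardcoded system prompt, discarding any system message the client sent. Now a client-supplied system message is preserved, and the default is prepended only when the first message is not a system message. A standalone sketch of the same logic (names are illustrative; like the diff, it assumes messages is non-empty, which the chat endpoint below enforces):

    default_chat_completion_system_prompt = [
        {"role": "system", "content": "You are a helpful assistant"}
    ]

    def inject_default_system_prompt(messages):
        # Respect an existing system message; otherwise fall back to the default.
        # Assumes messages is non-empty; the endpoint validates that first.
        if messages[0]["role"] != "system":
            return default_chat_completion_system_prompt + messages
        return messages

    print(inject_default_system_prompt(
        [{"role": "user", "content": "hi"}]
    ))  # default system prompt is prepended

    print(inject_default_system_prompt(
        [{"role": "system", "content": "Answer in French."},
         {"role": "user", "content": "hi"}]
    ))  # unchanged: the custom system prompt is kept, which is the fix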
@@ -1113,24 +1114,18 @@ async def text_chat_completions(request: ChatCompletionRequest):
             detail="The model that is loaded is not an NLP model. Please use an NLP model for text chat completion."
         )
 
-    generation_kwargs = GenerationRequest(
-        prompt="" if len(request.messages) == 0 else request.messages[-1].content,
-        temperature=request.temperature,
-        max_new_tokens=request.max_tokens,
-        stop_words=request.stop_words,
-        logprobs=request.logprobs,
-        top_logprobs=request.top_logprobs,
-        stream=request.stream,
-        top_k=request.top_k,
-        top_p=request.top_p
-    ).dict()
+    if not request.messages:
+        raise HTTPException(
+            status_code=400,
+            detail="No messages provided in the request."
+        )
 
     if request.stream:
         start_time = time.perf_counter()
-        streamer = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs)
+        streamer = nexa_run_text_generation(None, max_new_tokens=request.max_tokens, is_chat_completion=True, **request.dict())
         return StreamingResponse(_resp_async_generator(streamer, start_time), media_type="application/x-ndjson")
 
-    result = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs)
+    result = nexa_run_text_generation(None, max_new_tokens=request.max_tokens, is_chat_completion=True, **request.dict())
     return {
         "id": str(uuid.uuid4()),
         "object": "chat.completion",
@@ -1335,7 +1330,7 @@ async def function_call(request: FunctionCallRequest):
             status_code=400,
             detail="The model that is loaded is not an NLP model. Please use an NLP model for function calling."
         )
-    messages = function_call_system_prompt + [
+    messages = default_function_call_system_prompt + [
         {"role": msg.role, "content": msg.content} for msg in request.messages
     ]
     tools = [tool.dict() for tool in request.tools]
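
Note that this last hunk is only the rename: unlike the chat-completions path, the function-call endpoint still prepends its default system prompt unconditionally, so a client-supplied system message would follow the default rather than replace it. A sketch of the resulting message order (prompt text abbreviated):

    default_function_call_system_prompt = [
        {"role": "system", "content": "A chat between a curious user and ..."}  # abbreviated
    ]

    request_messages = [{"role": "user", "content": "Book a flight to Tokyo"}]

    # No "already has a system message?" check on this path:
    messages = default_function_call_system_prompt + request_messages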