From 6e55a741a15b6e86d0db62f6a6a26b6c7e23002f Mon Sep 17 00:00:00 2001 From: mxdlzg Date: Thu, 24 Oct 2024 10:14:35 +0800 Subject: [PATCH] feat: update Gemini adaptor to support custom response format --- relay/adaptor/gemini/main.go | 17 ++++++++++++++++- relay/adaptor/gemini/model.go | 14 ++++++++------ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 51fd6aa801..d6ab45d489 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,11 +4,12 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -28,6 +29,11 @@ const ( VisionMaxImageNum = 16 ) +var mimeTypeMap = map[string]string{ + "json_object": "application/json", + "text": "text/plain", +} + // Setting safety to the lowest possible values since Gemini is already powerless enough func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { geminiRequest := ChatRequest{ @@ -56,6 +62,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { MaxOutputTokens: textRequest.MaxTokens, }, } + if textRequest.ResponseFormat != nil { + if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { + geminiRequest.GenerationConfig.ResponseMimeType = mimeType + } + if textRequest.ResponseFormat.JsonSchema != nil { + geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema + geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] + } + } if textRequest.Tools != nil { functions := make([]model.Function, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index f7179ea48e..f6a3b25042 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -65,10 +65,12 @@ type ChatTools struct { } type ChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` }