diff --git a/relay/adaptor/doubao/main.go b/relay/adaptor/doubao/main.go index ea26e6ba48..dd43d06c53 100644 --- a/relay/adaptor/doubao/main.go +++ b/relay/adaptor/doubao/main.go @@ -7,8 +7,12 @@ import ( ) func GetRequestURL(meta *meta.Meta) (string, error) { - if meta.Mode == relaymode.ChatCompletions { + switch meta.Mode { + case relaymode.ChatCompletions: return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil + case relaymode.Embeddings: + return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil + default: } return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) } diff --git a/relay/controller/text.go b/relay/controller/text.go index 0d3c56b07d..52ee9949ae 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/billing" @@ -31,9 +32,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { meta.IsStream = textRequest.Stream // map model name - var isModelMapped bool meta.OriginModelName = textRequest.Model - textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) + textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) @@ -55,30 +55,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { adaptor.Init(meta) // get request body - var requestBody io.Reader - if meta.APIType == apitype.OpenAI { - // no need to convert request for openai - shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan - if shouldResetRequestBody { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - } else { - convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest) - if err != nil { - return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) - } - jsonData, err := json.Marshal(convertedRequest) - if err != nil { - return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) - } - logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) - requestBody = bytes.NewBuffer(jsonData) + requestBody, err := getRequestBody(c, meta, textRequest, adaptor) + if err != nil { + return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) } // do request @@ -103,3 +82,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) return nil } + +func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { + if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { + // no need to convert request for openai + return c.Request.Body, nil + } + + // get request body + var requestBody io.Reader + convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest) + if err != nil { + logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error()) + return nil, err + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error()) + return nil, err + } + logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData)) + requestBody = bytes.NewBuffer(jsonData) + return requestBody, nil +}