Skip to content

Commit

Permalink
fix: 豆包支持embeddings
Browse files Browse the repository at this point in the history
Fixes #1594
  • Loading branch information
igophper authored and 江杭辉 committed Jul 22, 2024
1 parent 2a892c1 commit 34cc3e6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 27 deletions.
6 changes: 5 additions & 1 deletion relay/adaptor/doubao/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ import (
)

func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
switch meta.Mode {
case relaymode.ChatCompletions:

Check warning on line 11 in relay/adaptor/doubao/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/doubao/main.go#L10-L11

Added lines #L10 - L11 were not covered by tests
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:

Check warning on line 15 in relay/adaptor/doubao/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/doubao/main.go#L13-L15

Added lines #L13 - L15 were not covered by tests
}
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}
54 changes: 28 additions & 26 deletions relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
Expand All @@ -31,9 +32,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream

// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)

Check warning on line 36 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L36

Added line #L36 was not covered by tests
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
Expand All @@ -55,30 +55,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta)

// get request body
var requestBody io.Reader
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)

Check warning on line 60 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L58-L60

Added lines #L58 - L60 were not covered by tests
}

// do request
Expand All @@ -103,3 +82,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil
}

func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {

Check warning on line 87 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L86-L87

Added lines #L86 - L87 were not covered by tests
// no need to convert request for openai
return c.Request.Body, nil

Check warning on line 89 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L89

Added line #L89 was not covered by tests
}

// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err

Check warning on line 97 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L93-L97

Added lines #L93 - L97 were not covered by tests
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err

Check warning on line 102 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L99-L102

Added lines #L99 - L102 were not covered by tests
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil

Check warning on line 106 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L104-L106

Added lines #L104 - L106 were not covered by tests
}

0 comments on commit 34cc3e6

Please sign in to comment.