Skip to content

Commit

Permalink
fix: 豆包支持embeddings
Browse files Browse the repository at this point in the history
Fixes #1594
  • Loading branch information
igophper authored and 江杭辉 committed Jul 22, 2024
1 parent 6209ff9 commit 2c958af
Show file tree
Hide file tree
Showing 13 changed files with 143 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Publish Docker image (English)
name: Publish Docker image (amd64, English)

on:
push:
Expand Down Expand Up @@ -51,7 +51,6 @@ jobs:
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Publish Docker image
name: Publish Docker image (amd64)

on:
push:
Expand Down Expand Up @@ -56,7 +56,6 @@ jobs:
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
69 changes: 69 additions & 0 deletions .github/workflows/docker-image-arm64.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name: Publish Docker image (arm64)

on:
push:
tags:
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3

- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION
- name: Set up QEMU
uses: docker/setup-qemu-action@v2

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2

- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM --platform=$BUILDPLATFORM node:16 as builder
FROM node:16 as builder

WORKDIR /web
COPY ./VERSION .
Expand Down
1 change: 0 additions & 1 deletion common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")

var GeminiVersion = env.String("GEMINI_VERSION", "v1")


var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)

var RelayProxy = env.String("RELAY_PROXY", "")
Expand Down
13 changes: 11 additions & 2 deletions relay/adaptor/anthropic/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package anthropic
import (
"errors"
"fmt"
"io"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

type Adaptor struct {
Expand All @@ -31,6 +33,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
}
req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15")

// https://x.com/alexalbert__/status/1812921642143900036
// claude-3-5-sonnet can support 8k context
if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
}

return nil
}

Expand Down
6 changes: 5 additions & 1 deletion relay/adaptor/doubao/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ import (
)

func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:
}
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}
22 changes: 11 additions & 11 deletions relay/adaptor/vertexai/claude/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"

type Request struct {
// AnthropicVersion must be "vertex-2023-10-16"
AnthropicVersion string `json:"anthropic_version"`
AnthropicVersion string `json:"anthropic_version"`
// Model string `json:"model"`
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}
1 change: 0 additions & 1 deletion relay/adaptor/vertexai/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ func init() {
}
}


type innerAIAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
Expand Down
1 change: 0 additions & 1 deletion relay/adaptor/vertexai/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ type ApplicationDefaultCredentials struct {
UniverseDomain string `json:"universe_domain"`
}


var Cache = cache.New(50*time.Minute, 55*time.Minute)

const defaultScope = "https://www.googleapis.com/auth/cloud-platform"
Expand Down
2 changes: 1 addition & 1 deletion relay/channeltype/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var ChannelBaseURLs = []string{
"https://api.together.xyz", // 39
"https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
"", // 42
"", // 42
}

func init() {
Expand Down
54 changes: 28 additions & 26 deletions relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
Expand All @@ -31,9 +32,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream

// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
Expand All @@ -55,30 +55,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta)

// get request body
var requestBody io.Reader
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}

// do request
Expand All @@ -103,3 +82,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil
}

func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}

// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}
28 changes: 15 additions & 13 deletions relay/meta/relay_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ import (
)

type Meta struct {
Mode int
ChannelType int
ChannelId int
TokenId int
TokenName string
UserId int
Group string
ModelMapping map[string]string
BaseURL string
APIKey string
APIType int
Config model.ChannelConfig
IsStream bool
Mode int
ChannelType int
ChannelId int
TokenId int
TokenName string
UserId int
Group string
ModelMapping map[string]string
BaseURL string
APIKey string
APIType int
Config model.ChannelConfig
IsStream bool
// OriginModelName is the model name from the raw user request
OriginModelName string
// ActualModelName is the model name after mapping
ActualModelName string
RequestURLPath string
PromptTokens int // only for DoResponse
Expand Down

0 comments on commit 2c958af

Please sign in to comment.