Skip to content

Commit

Permalink
refactor: gemini tools 功能属于beta阶段,属性时常变动且不可用,改为本地提示实现
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed Sep 19, 2024
1 parent 30d1c4b commit 5efe9a9
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 83 deletions.
1 change: 1 addition & 0 deletions example.config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ hf:

# gemini 自定义安全设置
google:
tc: false # 是否使用提示词实现的toolCall
# safes:
# - category: HARM_CATEGORY_HARASSMENT
# threshold: BLOCK_NONE
Expand Down
12 changes: 12 additions & 0 deletions internal/plugin/llm/gemini/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"chatgpt-adapter/internal/plugin"
"chatgpt-adapter/internal/vars"
"chatgpt-adapter/logger"
"chatgpt-adapter/pkg"
"encoding/json"
"errors"
"net/url"
Expand Down Expand Up @@ -99,6 +100,17 @@ func (API) Completion(ctx *gin.Context) {
return
}

tc := pkg.Config.GetBool("google.tc")
if tc && plugin.NeedToToolCall(ctx) {
if completeToolCalls(ctx, cookie, proxies, completion) {
return
}
}
if tc {
completion.Tools = nil
completion.ToolChoice = nil
}

ctx.Set(ginTokens, tokens)
r, err := build(common.GetGinContext(ctx), proxies, cookie, newMessages, completion)
if err != nil {
Expand Down
108 changes: 28 additions & 80 deletions internal/plugin/llm/gemini/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,13 @@ var (
}

emp = map[string]interface{}{
"-": map[string]string{
"type": "string",
"em": map[string]string{
"type": "string",
"description": "empty str",
},
}
)

type funcDecl struct {
Name string `json:"name"`
Description string `json:"description"`
Params *struct {
Properties map[string]interface{} `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
Type string `json:"type"`
} `json:"parameters,omitempty"`
}

func init() {
common.AddInitialized(initSafetySettings)
}
Expand All @@ -99,96 +90,53 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
completion.TopP = 0.95
}

toStrings := func(slice []interface{}) (values []string) {
values = make([]string, 0)
for _, v := range slice {
values = append(values, v.(string))
}
return
}

condition := func(str string) string {
switch str {
case "string", "boolean", "number":
return str
default:
if strings.HasPrefix(str, "array") {
return "array"
}
return "object"
}
}

// fix: type 枚举必须符合google定义,否则报错400
// https://ai.google.dev/api/rest/v1beta/Schema?hl=zh-cn#type
var fix func(keyv pkg.Keyv[interface{}]) (pkg.Keyv[interface{}], bool)
// beta功能,时常变动. 且十分不稳定,相同参数却反复出现 "500 Internal Server Error"
var fix func(pkg.Keyv[interface{}])
{
fix = func(properties pkg.Keyv[interface{}]) (pkg.Keyv[interface{}], bool) {
if properties == nil {
return nil, false
fix = func(parameters pkg.Keyv[interface{}]) {
if parameters == nil {
return
}

if properties.Has("type") {
properties.Set("type", condition(properties.GetString("type")))
// object 的 properties 不可以为空 key = {}
if !parameters.Is("type", "object") {
return
}

hasKeys := false
properties := parameters.GetKeyv("properties")
for range properties {
hasKeys = true
break
}

if !hasKeys {
// object 类型不允许空keyv
properties.Set("properties", emp)
return properties, false
parameters.Set("properties", emp)
return
}

for key := range properties {
keyv := properties.GetKeyv(key)
if keyv.Has("type") {
keyv.Set("type", condition(keyv.GetString("type")))
}
value, _ := fix(keyv.GetKeyv("properties"))
if value == nil {
if keyv.Is("type", "object") {
keyv.Set("properties", emp)
}
} else {
keyv.Set("properties", value)
if !keyv.Is("type", "object") {
continue
}
fix(keyv.GetKeyv("properties"))
}

return properties, hasKeys
return
}
}

// 参数基本与openai对齐
_funcDecls := make([]funcDecl, 0)
funcDecls := make([]pkg.Keyv[interface{}], 0)
if toolsL := len(completion.Tools); toolsL > 0 {
for _, v := range completion.Tools {
kv := v.GetKeyv("function").GetKeyv("parameters")
required := kv.GetSlice("required")
fd := funcDecl{
// 必须为 a-z、A-Z、0-9,或包含下划线和短划线,长度上限为 63 个字符
Name: strings.Replace(v.GetKeyv("function").GetString("name"), "-", "_", -1),
Description: v.GetKeyv("function").GetString("description"),
}

props, hasKeys := fix(kv.GetKeyv("properties"))
if hasKeys {
fd.Params = &struct {
Properties map[string]interface{} `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
Type string `json:"type"`
}{
Properties: props,
Required: toStrings(required),
Type: condition(kv.GetString("type")),
}
kv := v.GetKeyv("function")
{
fix(kv.GetKeyv("parameters"))
funcDecls = append(funcDecls, kv)
}

_funcDecls = append(_funcDecls, fd)
}
}

Expand Down Expand Up @@ -239,22 +187,22 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
}
}

if len(_funcDecls) > 0 && completion.Model != "gemini-1.5-pro-exp-0801" {
if len(funcDecls) > 0 && completion.Model != "gemini-1.5-pro-exp-0801" {
// 函数调用
payload["tools"] = []map[string]interface{}{
{
"functionDeclarations": _funcDecls,
"function_declarations": funcDecls,
},
}
// tool_choice
if tc, ok := completion.ToolChoice.(map[string]interface{}); ok {
var toolChoice pkg.Keyv[interface{}] = tc
if toolChoice.Is("type", "function") {
f := toolChoice.GetKeyv("function")
payload["toolConfig"] = map[string]interface{}{
"functionCallingConfig": map[string]interface{}{
payload["tool_config"] = map[string]interface{}{
"function_calling_config": map[string]interface{}{
"mode": "ANY",
"allowedFunctionNames": []string{
"allowed_function_names": []string{
f.GetString("name"),
},
},
Expand Down
74 changes: 74 additions & 0 deletions internal/plugin/llm/gemini/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,80 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, partialResponse *
return
}

func waitMessage(partialResponse *http.Response, cancel func(str string) bool) (string, error) {
defer partialResponse.Body.Close()
reader := bufio.NewReader(partialResponse.Body)
var original []byte
var block = []byte("data: ")
content := ""

for {
line, hm, err := reader.ReadLine()
original = append(original, line...)
if hm {
continue
}

if err == io.EOF {
break
}

if err != nil {
return "", err
}

if len(original) == 0 {
continue
}

if bytes.Contains(original, []byte(`"error":`)) {
return "", fmt.Errorf("%s", original)
}

if !bytes.HasPrefix(original, block) {
continue
}

var c candidatesResponse
original = bytes.TrimPrefix(original, block)
if err = json.Unmarshal(original, &c); err != nil {
logger.Error(err)
continue
}

if len(c.Candidates) == 0 {
continue
}

cond := c.Candidates[0]
if cond.Content.Role != "model" {
original = nil
continue
}

if len(cond.Content.Parts) == 0 {
continue
}

raw, ok := cond.Content.Parts[0]["text"]
if !ok {
original = nil
continue
}

original = nil
if len(raw.(string)) == 0 {
continue
}

if cancel != nil && cancel(raw.(string)) {
return content + raw.(string), nil
}
content += raw.(string)
}
return content, nil
}

func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]interface{}, tokens int, err error) {
// role类型转换
condition := func(expr string) string {
Expand Down
62 changes: 62 additions & 0 deletions internal/plugin/llm/gemini/toolcall.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package gemini

import (
"chatgpt-adapter/internal/common"
"chatgpt-adapter/internal/gin.handler/response"
"chatgpt-adapter/internal/plugin"
"chatgpt-adapter/internal/vars"
"chatgpt-adapter/logger"
"chatgpt-adapter/pkg"
"encoding/json"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)

func completeToolCalls(ctx *gin.Context, cookie, proxies string, completion pkg.ChatCompletion) bool {
logger.Info("completeTools ...")
echo := ctx.GetBool(vars.GinEcho)

exec, err := plugin.CompleteToolCalls(ctx, completion, func(message string) (string, error) {
message = strings.TrimSpace(message)
var messages []map[string]interface{}
messages = append(messages, map[string]interface{}{
"role": "user",
"parts": []interface{}{
map[string]string{
"text": message,
},
},
})

if echo {
bytes, _ := json.MarshalIndent(messages, "", " ")
logger.Infof("toolCall message: \n%s", bytes)
return "", nil
}

completion.Tools = nil
completion.ToolChoice = nil
r, err := build(common.GetGinContext(ctx), proxies, cookie, messages, completion)
if err != nil {
return "", err
}

return waitMessage(r, plugin.ToolCallCancel)
})

if err != nil {
errMessage := err.Error()
if strings.Contains(errMessage, "Login verification is invalid") {
logger.Error(err)
response.Error(ctx, http.StatusUnauthorized, errMessage)
return true
}

logger.Error(err)
response.Error(ctx, -1, errMessage)
return true
}

return exec
}
8 changes: 5 additions & 3 deletions internal/plugin/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,12 @@ func buildTemplate(ctx *gin.Context, completion pkg.ChatCompletion, template str
}

regMap := map[*regexp.Regexp]string{
regexp.MustCompile(`<\|system\|>\n *\n<\|end\|>`): "",
regexp.MustCompile(`<\|user\|>\n *\n<\|end\|>`): "",
regexp.MustCompile(`<\|assistant\|>\n *\n<\|end\|>`): "",
regexp.MustCompile(`<\|system\|>[\n|\s]+<\|end\|>`): "",
regexp.MustCompile(`<\|user\|>[\n|\s]+<\|end\|>`): "",
regexp.MustCompile(`<\|assistant\|>[\n|\s]+<\|end\|>`): "",
regexp.MustCompile(`<\|<no value>\|>\n<no value>\n<\|end\|>`): "",
regexp.MustCompile(`\n{3}`): "\n",
regexp.MustCompile(`\n{2,}<\|end\|>`): "\n<|end|>",
}
for reg, v := range regMap {
str = reg.ReplaceAllString(str, v)
Expand Down

0 comments on commit 5efe9a9

Please sign in to comment.