Skip to content

Commit

Permalink
Merge pull request #496 from trheyi/main
Browse files Browse the repository at this point in the history
[add] chat.completions ( dev )
  • Loading branch information
trheyi authored Nov 11, 2023
2 parents 213111a + c84d34a commit 274f926
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 4 deletions.
9 changes: 9 additions & 0 deletions moapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ var dsl = []byte(`
"process": "moapi.images.Generations",
"in": ["$payload.model", "$payload.prompt", ":payload"],
"out": { "status": 200, "type": "application/json" }
},
{
"path": "/chat/completions",
"guard": "query-jwt",
"method": "GET",
"process": "moapi.chat.Completions",
"processHandler": true,
"out": { "status": 200, "type": "text/event-stream" }
}
]
}
Expand Down
104 changes: 104 additions & 0 deletions moapi/process.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package moapi

import (
"context"
"io"
"net/http"
"strings"

"github.com/gin-gonic/gin"
jsoniter "github.com/json-iterator/go"
"github.com/yaoapp/gou/process"
"github.com/yaoapp/kun/exception"
"github.com/yaoapp/kun/utils"
Expand All @@ -10,6 +17,7 @@ import (
func init() {
process.RegisterGroup("moapi", map[string]process.Handler{
"images.generations": ImagesGenerations,
"chat.completions": ChatCompletions,
})
}

Expand Down Expand Up @@ -46,3 +54,99 @@ func ImagesGenerations(process *process.Process) interface{} {

return res
}

// ChatCompletions chat completions
func ChatCompletions(process *process.Process) interface{} {

return func(c *gin.Context) {

option := map[string]interface{}{}
query := c.Query("payload")
err := jsoniter.UnmarshalFromString(query, &option)
if err != nil {
exception.New("ChatCompletions error: %s", 400, err).Throw()
}

// option := payload
// model := "gpt-3.5-turbo"
// messages := []map[string]interface{}{
// {
// "role": "system",
// "content": "You are a helpful assistant.",
// },
// {
// "role": "user",
// "content": "Hello!",
// },
// // }

// option["messages"] = messages
// option["model"] = model

delete(option, "context")
model, ok := option["model"].(string)
if !ok || model == "" {
exception.New("ChatCompletions error: model is required", 400).Throw()
}

ai, err := openai.NewMoapi(model)
if err != nil {
exception.New("ChatCompletions error: %s", 400, err).Throw()
}

if v, ok := option["stream"].(bool); ok && v {

chanStream := make(chan []byte, 1)
chanError := make(chan error, 1)

defer func() {
close(chanStream)
close(chanError)
}()

ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

go ai.Stream(ctx, "/v1/chat/completions", option, func(data []byte) int {

if (string(data)) == "\n" || string(data) == "" {
return 1 // HandlerReturnOk
}

chanStream <- data
if strings.HasSuffix(string(data), "[DONE]") {
return 0 // HandlerReturnBreak0
}
return 1 // HandlerReturnOk
})

c.Header("Content-Type", "text/event-stream")
c.Stream(func(w io.Writer) bool {
select {
case err := <-chanError:
if err != nil {
c.JSON(http.StatusInternalServerError, err.Error())
}
return false

case msg := <-chanStream:

if string(msg) == "\n" {
return true
}

message := strings.TrimLeft(string(msg), "data: ")
c.SSEvent("message", message)
return true

case <-ctx.Done():
return false
}
})

return
}

return
}
}
10 changes: 10 additions & 0 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,16 @@ func (openai OpenAI) GetContent(response interface{}) (string, *exception.Except
return "", exception.New("response format error, %#v", 500, response)
}

// Post post request
func (openai OpenAI) Post(path string, payload map[string]interface{}) (interface{}, *exception.Exception) {
return openai.post(path, payload)
}

// Stream post request
func (openai OpenAI) Stream(ctx context.Context, path string, payload map[string]interface{}, cb func(data []byte) int) *exception.Exception {
return openai.stream(ctx, path, payload, cb)
}

// post post request
func (openai OpenAI) post(path string, payload map[string]interface{}) (interface{}, *exception.Exception) {

Expand Down
6 changes: 3 additions & 3 deletions sui/api/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestBlockGet(t *testing.T) {
}

assert.IsType(t, []core.IBlock{}, res)
assert.Equal(t, 6, len(res.([]core.IBlock)))
assert.Equal(t, 7, len(res.([]core.IBlock)))
assert.Equal(t, "ColumnsTwo", res.([]core.IBlock)[0].(*local.Block).ID)
assert.Equal(t, "Hero", res.([]core.IBlock)[1].(*local.Block).ID)
assert.Equal(t, "Image", res.([]core.IBlock)[2].(*local.Block).ID)
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestBlockExport(t *testing.T) {
t.Fatal(err)
}
assert.IsType(t, &core.BlockLayoutItems{}, res)
assert.Equal(t, 2, len(res.(*core.BlockLayoutItems).Categories))
assert.Equal(t, 3, len(res.(*core.BlockLayoutItems).Categories))
}

func TestBlockMedia(t *testing.T) {
Expand Down Expand Up @@ -210,7 +210,7 @@ func TestTemplateComponentGet(t *testing.T) {
}

assert.IsType(t, []core.IComponent{}, res)
assert.Equal(t, 5, len(res.([]core.IComponent)))
assert.Equal(t, 6, len(res.([]core.IComponent)))
assert.Equal(t, "Box", res.([]core.IComponent)[0].(*local.Component).ID)
assert.Equal(t, "Card", res.([]core.IComponent)[1].(*local.Component).ID)
assert.Equal(t, "Nav", res.([]core.IComponent)[2].(*local.Component).ID)
Expand Down
2 changes: 1 addition & 1 deletion sui/storages/local/block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestBlockLayoutItems(t *testing.T) {
t.Fatalf("BlockLayoutItems error: %v", err)
}

assert.Equal(t, 2, len(items.Categories))
assert.Equal(t, 3, len(items.Categories))

tmpl, err = tests.Demo.GetTemplate("website-ai")
if err != nil {
Expand Down

0 comments on commit 274f926

Please sign in to comment.