diff --git a/neo/assistant/api.go b/neo/assistant/api.go index 8ca92147e9..e6fa029854 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -51,10 +51,11 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, return err } + contents := chatMessage.NewContents() options = ast.withOptions(options) // Run init hook - res, err := ast.HookInit(c, ctx, messages, options) + res, err := ast.HookInit(c, ctx, messages, options, contents) if err != nil { chatMessage.New(). Assistant(ast.ID, ast.Name, ast.Avatar). @@ -94,7 +95,7 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, } // Only proceed with chat stream if no specific next action was handled - return ast.handleChatStream(c, ctx, messages, options) + return ast.handleChatStream(c, ctx, messages, options, contents) } // Execute the next action @@ -170,10 +171,9 @@ func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error { } // handleChatStream manages the streaming chat interaction with the AI -func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}) error { +func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents) error { clientBreak := make(chan bool, 1) done := make(chan bool, 1) - contents := chatMessage.NewContents() // Chat with AI in background go func() { @@ -220,7 +220,7 @@ func (ast *Assistant) streamChat( // Handle error if msg.Type == "error" { value := msg.String() - res, hookErr := ast.HookFail(c, ctx, messages, contents.JSON(), fmt.Errorf("%s", value)) + res, hookErr := ast.HookFail(c, ctx, messages, fmt.Errorf("%s", value), contents) if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") { value = res.Output if res.Error != "" { @@ -236,7 +236,7 @@ func (ast *Assistant) streamChat( value := msg.String() if value != "" { // Handle stream - res, err := ast.HookStream(c, ctx, messages, contents.Data) + res, err := ast.HookStream(c, ctx, messages, contents) if err == nil && res != nil { if res.Next != nil { @@ -271,7 +271,7 @@ func (ast *Assistant) streamChat( // msg.Write(c.Writer) // } - res, hookErr := ast.HookDone(c, ctx, messages, contents.Data) + res, hookErr := ast.HookDone(c, ctx, messages, contents) if hookErr == nil && res != nil { if res.Output != nil { chatMessage.New(). diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go index ed9c32cf07..dcfa2c6714 100644 --- a/neo/assistant/hooks.go +++ b/neo/assistant/hooks.go @@ -7,17 +7,20 @@ import ( "github.com/gin-gonic/gin" jsoniter "github.com/json-iterator/go" + "github.com/yaoapp/gou/runtime/v8/bridge" chatctx "github.com/yaoapp/yao/neo/context" "github.com/yaoapp/yao/neo/message" + chatMessage "github.com/yaoapp/yao/neo/message" + "rogchap.com/v8go" ) // HookInit initialize the assistant -func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) { +func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}, contents *message.Contents) (*ResHookInit, error) { // Create timeout context ctx, cancel := ast.createTimeoutContext(c) defer cancel() - v, err := ast.call(ctx, "Init", context, input, c.Writer) + v, err := ast.call(ctx, "Init", c, contents, context, input) if err != nil { if err.Error() == HookErrorMethodNotFound { return nil, nil @@ -69,13 +72,13 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input [] } // HookStream Handle streaming response from LLM -func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output []message.Data) (*ResHookStream, error) { +func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, contents *chatMessage.Contents) (*ResHookStream, error) { // Create timeout context ctx, cancel := ast.createTimeoutContext(c) defer cancel() - v, err := ast.call(ctx, "Stream", context, input, output, c.Writer) + v, err := ast.call(ctx, "Stream", c, contents, context, input) if err != nil { if err.Error() == HookErrorMethodNotFound { return nil, nil @@ -133,12 +136,11 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input } // HookDone Handle completion of assistant response -func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output []message.Data) (*ResHookDone, error) { +func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, contents *chatMessage.Contents) (*ResHookDone, error) { // Create timeout context - ctx, cancel := ast.createTimeoutContext(c) - defer cancel() + ctx := ast.createBackgroundContext() - v, err := ast.call(ctx, "Done", context, input, output, c.Writer) + v, err := ast.call(ctx, "Done", c, contents, context, input) if err != nil { if err.Error() == HookErrorMethodNotFound { return nil, nil @@ -148,7 +150,7 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input [] response := &ResHookDone{ Input: input, - Output: output, + Output: contents.Data, } switch v := v.(type) { @@ -194,12 +196,12 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input [] } // HookFail Handle failure of assistant response -func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []message.Message, output string, err error) (*ResHookFail, error) { +func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []message.Message, err error, contents *chatMessage.Contents) (*ResHookFail, error) { // Create timeout context ctx, cancel := ast.createTimeoutContext(c) defer cancel() - v, callErr := ast.call(ctx, "Fail", context, input, output, err.Error(), c.Writer) + v, callErr := ast.call(ctx, "Fail", c, contents, context, input, err.Error()) if callErr != nil { if callErr.Error() == HookErrorMethodNotFound { return nil, nil @@ -209,7 +211,7 @@ func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input [] response := &ResHookFail{ Input: input, - Output: output, + Output: contents.Text(), Error: err.Error(), } @@ -243,8 +245,13 @@ func (ast *Assistant) createTimeoutContext(c *gin.Context) (context.Context, con return ctx, cancel } +// createBackgroundContext creates a background context +func (ast *Assistant) createBackgroundContext() context.Context { + return context.Background() +} + // Call the script method -func (ast *Assistant) call(ctx context.Context, method string, context chatctx.Context, args ...any) (interface{}, error) { +func (ast *Assistant) call(ctx context.Context, method string, c *gin.Context, contents *chatMessage.Contents, context chatctx.Context, args ...any) (interface{}, error) { if ast.Script == nil { return nil, nil } @@ -255,6 +262,44 @@ func (ast *Assistant) call(ctx context.Context, method string, context chatctx.C } defer scriptCtx.Close() + // Add sendMessage function to the script context + scriptCtx.WithFunction("SendMessage", func(info *v8go.FunctionCallbackInfo) *v8go.Value { + + // Get the message + args := info.Args() + if len(args) < 1 { + return bridge.JsException(info.Context(), "SendMessage requires at least one argument") + } + + input, err := bridge.GoValue(args[0], info.Context()) + if err != nil { + return bridge.JsException(info.Context(), err.Error()) + } + + switch v := input.(type) { + case string: + // Check if the message is json + msg, err := message.NewString(v) + if err != nil { + return bridge.JsException(info.Context(), err.Error()) + } + + // Append the message to the contents + msg.AppendTo(contents) + msg.Write(c.Writer) + return nil + + case map[string]interface{}: + msg := message.New().Map(v) + msg.AppendTo(contents) + msg.Write(c.Writer) + return nil + + default: + return bridge.JsException(info.Context(), "SendMessage requires a string or a map") + } + }) + // Check if the method exists if !scriptCtx.Global().Has(method) { return nil, fmt.Errorf(HookErrorMethodNotFound) diff --git a/neo/assistant/types.go b/neo/assistant/types.go index f18ad02871..a6c167c6bd 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -25,7 +25,6 @@ type API interface { Download(ctx context.Context, fileID string) (*FileResponse, error) ReadBase64(ctx context.Context, fileID string) (string, error) Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}) error - HookInit(c *gin.Context, ctx chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) } // ResHookInit the response of the init hook diff --git a/neo/message/contents.go b/neo/message/contents.go index f7a6440aff..aa5ca064ee 100644 --- a/neo/message/contents.go +++ b/neo/message/contents.go @@ -21,11 +21,12 @@ type Contents struct { // Data the data of the content type Data struct { - Type string `json:"type"` // text, function, error, ... - ID string `json:"id"` // the id of the content - Function string `json:"function"` // the function name - Bytes []byte `json:"bytes"` // the content bytes - Arguments []byte `json:"arguments"` // the function arguments + Type string `json:"type"` // text, function, error, ... + ID string `json:"id"` // the id of the content + Function string `json:"function"` // the function name + Bytes []byte `json:"bytes"` // the content bytes + Arguments []byte `json:"arguments"` // the function arguments + Props map[string]interface{} `json:"props"` // the props } // NewContents create a new contents @@ -57,6 +58,28 @@ func (c *Contents) NewFunction(function string, arguments []byte) *Contents { return c } +// NewType create a new type data and append to the contents +func (c *Contents) NewType(typ string, props map[string]interface{}) *Contents { + c.Data = append(c.Data, Data{ + Type: typ, + Props: props, + }) + c.Current++ + return c +} + +// UpdateType update the type of the current content +func (c *Contents) UpdateType(typ string, props map[string]interface{}) *Contents { + if c.Current == -1 { + c.NewType(typ, props) + return c + } + + c.Data[c.Current].Type = typ + c.Data[c.Current].Props = props + return c +} + // SetFunctionID set the id of the current function content func (c *Contents) SetFunctionID(id string) *Contents { if c.Current == -1 { @@ -128,10 +151,14 @@ func (data *Data) Map() (map[string]interface{}, error) { v["id"] = data.ID } - if data.Bytes != nil { + if data.Bytes != nil && data.Type == "text" { v["text"] = string(data.Bytes) } + if data.Props != nil && data.Type != "text" { + v["props"] = data.Props + } + if data.Arguments != nil { var vv interface{} = nil err := jsoniter.Unmarshal(data.Arguments, &vv) @@ -157,10 +184,14 @@ func (data *Data) MarshalJSON() ([]byte, error) { v["id"] = data.ID } - if data.Bytes != nil { + if data.Bytes != nil && data.Type == "text" { v["text"] = string(data.Bytes) } + if data.Props != nil && data.Type != "text" { + v["props"] = data.Props + } + if data.Arguments != nil { var vv interface{} = nil err := jsoniter.Unmarshal(data.Arguments, &vv) diff --git a/neo/message/message.go b/neo/message/message.go index e632a4b38a..a659d7a6a8 100644 --- a/neo/message/message.go +++ b/neo/message/message.go @@ -86,6 +86,9 @@ func NewOpenAI(data []byte) *Message { msg := New() text := string(data) data = []byte(strings.TrimPrefix(text, "data: ")) + // fmt.Println("--------------------------------") + // fmt.Println(string(data)) + // fmt.Println("--------------------------------") switch { @@ -187,11 +190,22 @@ func (m *Message) SetContent(content string) *Message { // AppendTo append the contents func (m *Message) AppendTo(contents *Contents) *Message { + // Set type + if m.Type == "" { + m.Type = "text" + } + switch m.Type { case "text": if m.Text != "" { + if m.IsNew { + contents.NewText([]byte(m.Text)) + return m + } contents.AppendText([]byte(m.Text)) + return m } + return m case "tool_calls": @@ -206,8 +220,20 @@ func (m *Message) AppendTo(contents *Contents) *Message { } contents.AppendFunction([]byte(m.Text)) + return m + + case "loading": + return m + + default: + if m.IsNew { + contents.NewType(m.Type, m.Props) + return m + } + contents.UpdateType(m.Type, m.Props) + return m } - return m + } // Content get the content @@ -271,8 +297,11 @@ func (m *Message) Map(msg map[string]interface{}) *Message { if done, ok := msg["done"].(bool); ok { m.IsDone = done } + if props, ok := msg["props"].(map[string]interface{}); ok { + m.Props = props + } - if isNew, ok := msg["is_new"].(bool); ok { + if isNew, ok := msg["new"].(bool); ok { m.IsNew = isNew }