Skip to content

Commit

Permalink
Merge pull request #825 from trheyi/main
Browse files Browse the repository at this point in the history
Enhance Neo API assistant with improved content handling and hook int…
  • Loading branch information
trheyi authored Jan 22, 2025
2 parents e54333a + 1eba627 commit 839f3bb
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 30 deletions.
14 changes: 7 additions & 7 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
return err
}

contents := chatMessage.NewContents()
options = ast.withOptions(options)

// Run init hook
res, err := ast.HookInit(c, ctx, messages, options)
res, err := ast.HookInit(c, ctx, messages, options, contents)
if err != nil {
chatMessage.New().
Assistant(ast.ID, ast.Name, ast.Avatar).
Expand Down Expand Up @@ -94,7 +95,7 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
}

// Only proceed with chat stream if no specific next action was handled
return ast.handleChatStream(c, ctx, messages, options)
return ast.handleChatStream(c, ctx, messages, options, contents)
}

// Execute the next action
Expand Down Expand Up @@ -170,10 +171,9 @@ func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error {
}

// handleChatStream manages the streaming chat interaction with the AI
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}) error {
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents) error {
clientBreak := make(chan bool, 1)
done := make(chan bool, 1)
contents := chatMessage.NewContents()

// Chat with AI in background
go func() {
Expand Down Expand Up @@ -220,7 +220,7 @@ func (ast *Assistant) streamChat(
// Handle error
if msg.Type == "error" {
value := msg.String()
res, hookErr := ast.HookFail(c, ctx, messages, contents.JSON(), fmt.Errorf("%s", value))
res, hookErr := ast.HookFail(c, ctx, messages, fmt.Errorf("%s", value), contents)
if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") {
value = res.Output
if res.Error != "" {
Expand All @@ -236,7 +236,7 @@ func (ast *Assistant) streamChat(
value := msg.String()
if value != "" {
// Handle stream
res, err := ast.HookStream(c, ctx, messages, contents.Data)
res, err := ast.HookStream(c, ctx, messages, contents)
if err == nil && res != nil {

if res.Next != nil {
Expand Down Expand Up @@ -271,7 +271,7 @@ func (ast *Assistant) streamChat(
// msg.Write(c.Writer)
// }

res, hookErr := ast.HookDone(c, ctx, messages, contents.Data)
res, hookErr := ast.HookDone(c, ctx, messages, contents)
if hookErr == nil && res != nil {
if res.Output != nil {
chatMessage.New().
Expand Down
71 changes: 58 additions & 13 deletions neo/assistant/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ import (

"github.com/gin-gonic/gin"
jsoniter "github.com/json-iterator/go"
"github.com/yaoapp/gou/runtime/v8/bridge"
chatctx "github.com/yaoapp/yao/neo/context"
"github.com/yaoapp/yao/neo/message"
chatMessage "github.com/yaoapp/yao/neo/message"
"rogchap.com/v8go"
)

// HookInit initialize the assistant
func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) {
func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}, contents *message.Contents) (*ResHookInit, error) {
// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Init", context, input, c.Writer)
v, err := ast.call(ctx, "Init", c, contents, context, input)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand Down Expand Up @@ -69,13 +72,13 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []
}

// HookStream Handle streaming response from LLM
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output []message.Data) (*ResHookStream, error) {
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, contents *chatMessage.Contents) (*ResHookStream, error) {

// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Stream", context, input, output, c.Writer)
v, err := ast.call(ctx, "Stream", c, contents, context, input)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand Down Expand Up @@ -133,12 +136,11 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input
}

// HookDone Handle completion of assistant response
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output []message.Data) (*ResHookDone, error) {
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, contents *chatMessage.Contents) (*ResHookDone, error) {
// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()
ctx := ast.createBackgroundContext()

v, err := ast.call(ctx, "Done", context, input, output, c.Writer)
v, err := ast.call(ctx, "Done", c, contents, context, input)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand All @@ -148,7 +150,7 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []

response := &ResHookDone{
Input: input,
Output: output,
Output: contents.Data,
}

switch v := v.(type) {
Expand Down Expand Up @@ -194,12 +196,12 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []
}

// HookFail Handle failure of assistant response
func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []message.Message, output string, err error) (*ResHookFail, error) {
func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []message.Message, err error, contents *chatMessage.Contents) (*ResHookFail, error) {
// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, callErr := ast.call(ctx, "Fail", context, input, output, err.Error(), c.Writer)
v, callErr := ast.call(ctx, "Fail", c, contents, context, input, err.Error())
if callErr != nil {
if callErr.Error() == HookErrorMethodNotFound {
return nil, nil
Expand All @@ -209,7 +211,7 @@ func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []

response := &ResHookFail{
Input: input,
Output: output,
Output: contents.Text(),
Error: err.Error(),
}

Expand Down Expand Up @@ -243,8 +245,13 @@ func (ast *Assistant) createTimeoutContext(c *gin.Context) (context.Context, con
return ctx, cancel
}

// createBackgroundContext creates a background context
func (ast *Assistant) createBackgroundContext() context.Context {
return context.Background()
}

// Call the script method
func (ast *Assistant) call(ctx context.Context, method string, context chatctx.Context, args ...any) (interface{}, error) {
func (ast *Assistant) call(ctx context.Context, method string, c *gin.Context, contents *chatMessage.Contents, context chatctx.Context, args ...any) (interface{}, error) {
if ast.Script == nil {
return nil, nil
}
Expand All @@ -255,6 +262,44 @@ func (ast *Assistant) call(ctx context.Context, method string, context chatctx.C
}
defer scriptCtx.Close()

// Add sendMessage function to the script context
scriptCtx.WithFunction("SendMessage", func(info *v8go.FunctionCallbackInfo) *v8go.Value {

// Get the message
args := info.Args()
if len(args) < 1 {
return bridge.JsException(info.Context(), "SendMessage requires at least one argument")
}

input, err := bridge.GoValue(args[0], info.Context())
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}

switch v := input.(type) {
case string:
// Check if the message is json
msg, err := message.NewString(v)
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}

// Append the message to the contents
msg.AppendTo(contents)
msg.Write(c.Writer)
return nil

case map[string]interface{}:
msg := message.New().Map(v)
msg.AppendTo(contents)
msg.Write(c.Writer)
return nil

default:
return bridge.JsException(info.Context(), "SendMessage requires a string or a map")
}
})

// Check if the method exists
if !scriptCtx.Global().Has(method) {
return nil, fmt.Errorf(HookErrorMethodNotFound)
Expand Down
1 change: 0 additions & 1 deletion neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ type API interface {
Download(ctx context.Context, fileID string) (*FileResponse, error)
ReadBase64(ctx context.Context, fileID string) (string, error)
Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}) error
HookInit(c *gin.Context, ctx chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error)
}

// ResHookInit the response of the init hook
Expand Down
45 changes: 38 additions & 7 deletions neo/message/contents.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ type Contents struct {

// Data the data of the content
type Data struct {
Type string `json:"type"` // text, function, error, ...
ID string `json:"id"` // the id of the content
Function string `json:"function"` // the function name
Bytes []byte `json:"bytes"` // the content bytes
Arguments []byte `json:"arguments"` // the function arguments
Type string `json:"type"` // text, function, error, ...
ID string `json:"id"` // the id of the content
Function string `json:"function"` // the function name
Bytes []byte `json:"bytes"` // the content bytes
Arguments []byte `json:"arguments"` // the function arguments
Props map[string]interface{} `json:"props"` // the props
}

// NewContents create a new contents
Expand Down Expand Up @@ -57,6 +58,28 @@ func (c *Contents) NewFunction(function string, arguments []byte) *Contents {
return c
}

// NewType create a new type data and append to the contents
func (c *Contents) NewType(typ string, props map[string]interface{}) *Contents {
c.Data = append(c.Data, Data{
Type: typ,
Props: props,
})
c.Current++
return c
}

// UpdateType update the type of the current content
func (c *Contents) UpdateType(typ string, props map[string]interface{}) *Contents {
if c.Current == -1 {
c.NewType(typ, props)
return c
}

c.Data[c.Current].Type = typ
c.Data[c.Current].Props = props
return c
}

// SetFunctionID set the id of the current function content
func (c *Contents) SetFunctionID(id string) *Contents {
if c.Current == -1 {
Expand Down Expand Up @@ -128,10 +151,14 @@ func (data *Data) Map() (map[string]interface{}, error) {
v["id"] = data.ID
}

if data.Bytes != nil {
if data.Bytes != nil && data.Type == "text" {
v["text"] = string(data.Bytes)
}

if data.Props != nil && data.Type != "text" {
v["props"] = data.Props
}

if data.Arguments != nil {
var vv interface{} = nil
err := jsoniter.Unmarshal(data.Arguments, &vv)
Expand All @@ -157,10 +184,14 @@ func (data *Data) MarshalJSON() ([]byte, error) {
v["id"] = data.ID
}

if data.Bytes != nil {
if data.Bytes != nil && data.Type == "text" {
v["text"] = string(data.Bytes)
}

if data.Props != nil && data.Type != "text" {
v["props"] = data.Props
}

if data.Arguments != nil {
var vv interface{} = nil
err := jsoniter.Unmarshal(data.Arguments, &vv)
Expand Down
33 changes: 31 additions & 2 deletions neo/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ func NewOpenAI(data []byte) *Message {
msg := New()
text := string(data)
data = []byte(strings.TrimPrefix(text, "data: "))
// fmt.Println("--------------------------------")
// fmt.Println(string(data))
// fmt.Println("--------------------------------")

switch {

Expand Down Expand Up @@ -187,11 +190,22 @@ func (m *Message) SetContent(content string) *Message {
// AppendTo append the contents
func (m *Message) AppendTo(contents *Contents) *Message {

// Set type
if m.Type == "" {
m.Type = "text"
}

switch m.Type {
case "text":
if m.Text != "" {
if m.IsNew {
contents.NewText([]byte(m.Text))
return m
}
contents.AppendText([]byte(m.Text))
return m
}
return m

case "tool_calls":

Expand All @@ -206,8 +220,20 @@ func (m *Message) AppendTo(contents *Contents) *Message {
}

contents.AppendFunction([]byte(m.Text))
return m

case "loading":
return m

default:
if m.IsNew {
contents.NewType(m.Type, m.Props)
return m
}
contents.UpdateType(m.Type, m.Props)
return m
}
return m

}

// Content get the content
Expand Down Expand Up @@ -271,8 +297,11 @@ func (m *Message) Map(msg map[string]interface{}) *Message {
if done, ok := msg["done"].(bool); ok {
m.IsDone = done
}
if props, ok := msg["props"].(map[string]interface{}); ok {
m.Props = props
}

if isNew, ok := msg["is_new"].(bool); ok {
if isNew, ok := msg["new"].(bool); ok {
m.IsNew = isNew
}

Expand Down

0 comments on commit 839f3bb

Please sign in to comment.