Skip to content

Commit

Permalink
Merge pull request #814 from trheyi/main
Browse files Browse the repository at this point in the history
Neo API and Conversation Management Refactoring
  • Loading branch information
trheyi authored Jan 2, 2025
2 parents fc3a695 + af30bec commit 1254414
Show file tree
Hide file tree
Showing 10 changed files with 637 additions and 352 deletions.
362 changes: 63 additions & 299 deletions neo/assistant/assistant.go

Large diffs are not rendered by default.

400 changes: 400 additions & 0 deletions neo/assistant/load.go

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions neo/assistant/assistant_test.go → neo/assistant/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func prepare(t *testing.T) {
test.Prepare(t, config.Conf)
}

func TestAssistant_LoadPath(t *testing.T) {
func TestLoad_LoadPath(t *testing.T) {
prepare(t)
defer test.Clean()

Expand All @@ -37,7 +37,7 @@ func TestAssistant_LoadPath(t *testing.T) {
assert.Error(t, err)
}

func TestAssistant_LoadStore(t *testing.T) {
func TestLoad_LoadStore(t *testing.T) {
prepare(t)
defer test.Clean()

Expand Down Expand Up @@ -79,7 +79,7 @@ func TestAssistant_LoadStore(t *testing.T) {
assert.Error(t, err)
}

func TestAssistant_Cache(t *testing.T) {
func TestLoad_Cache(t *testing.T) {
prepare(t)
defer test.Clean()

Expand Down Expand Up @@ -127,7 +127,7 @@ func TestAssistant_Cache(t *testing.T) {
assert.NotNil(t, loaded)
}

func TestAssistant_Validate(t *testing.T) {
func TestLoad_Validate(t *testing.T) {
tests := []struct {
name string
ast *Assistant
Expand Down Expand Up @@ -178,7 +178,7 @@ func TestAssistant_Validate(t *testing.T) {
}
}

func TestAssistant_Clone(t *testing.T) {
func TestLoad_Clone(t *testing.T) {
// Create a test assistant with all fields populated
original := &Assistant{
ID: "test-id",
Expand Down Expand Up @@ -233,7 +233,7 @@ func TestAssistant_Clone(t *testing.T) {
assert.Nil(t, nilAssistant.Clone())
}

func TestAssistant_Update(t *testing.T) {
func TestLoad_Update(t *testing.T) {
// Create a test assistant
ast := &Assistant{
ID: "test-id",
Expand Down
16 changes: 16 additions & 0 deletions neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"mime/multipart"

"github.com/yaoapp/gou/rag/driver"
v8 "github.com/yaoapp/gou/runtime/v8"
)

Expand All @@ -16,6 +17,19 @@ type API interface {
ReadBase64(ctx context.Context, fileID string) (string, error)
}

// RAG the RAG interface
type RAG struct {
Engine driver.Engine
Uploader driver.FileUpload
Vectorizer driver.Vectorizer
Setting RAGSetting
}

// RAGSetting the RAG setting
type RAGSetting struct {
IndexPrefix string `json:"index_prefix" yaml:"index_prefix"`
}

// Prompt a prompt
type Prompt struct {
Role string `json:"role"`
Expand Down Expand Up @@ -51,6 +65,8 @@ type Assistant struct {
Flows []map[string]interface{} `json:"flows,omitempty"` // Assistant Flows
Script *v8.Script `json:"-" yaml:"-"` // Assistant Script
API API `json:"-" yaml:"-"` // Assistant API
CreatedAt int64 `json:"created_at"` // Creation timestamp
UpdatedAt int64 `json:"updated_at"` // Last update timestamp
}

// File the file
Expand Down
44 changes: 44 additions & 0 deletions neo/assistant/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package assistant

import (
"fmt"
"strconv"
"time"
)

func getTimestamp(v interface{}) (int64, error) {
switch v := v.(type) {
case int64:
return v, nil
case int:
return int64(v), nil

case string:
if ts, err := time.Parse(time.RFC3339, v); err == nil {
return ts.UnixNano(), nil
}

// MySQL format
if ts, err := time.Parse("2006-01-02 15:04:05", v); err == nil {
return ts.UnixNano(), nil
}

// UnixNano format
if ts, err := strconv.ParseInt(v, 10, 64); err == nil {
return ts, nil
}

}
return 0, fmt.Errorf("invalid timestamp type")
}

func stringToTimestamp(v string) (int64, error) {
return strconv.ParseInt(v, 10, 64)
}

func timeToMySQLFormat(ts int64) string {
if ts == 0 {
return "0000-00-00 00:00:00"
}
return time.Unix(ts/1e9, ts%1e9).Format("2006-01-02 15:04:05")
}
111 changes: 93 additions & 18 deletions neo/load.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package neo

import (
"fmt"
"path/filepath"

"github.com/fatih/color"
"github.com/yaoapp/gou/application"
"github.com/yaoapp/gou/connector"
"github.com/yaoapp/kun/log"
"github.com/yaoapp/yao/config"
"github.com/yaoapp/yao/neo/assistant"
Expand All @@ -15,20 +17,6 @@ import (
// Neo the neo AI assistant
var Neo *DSL

// initRAG initialize the RAG instance
func (neo *DSL) initRAG() {
if neo.RAGSetting.Engine.Driver == "" {
return
}
instance, err := rag.New(neo.RAGSetting)
if err != nil {
color.Red("[Neo] Failed to initialize RAG: %v", err)
log.Error("[Neo] Failed to initialize RAG: %v", err)
return
}
neo.RAG = instance
}

// Load load AIGC
func Load(cfg config.Config) error {

Expand All @@ -38,7 +26,7 @@ func Load(cfg config.Config) error {
Option: map[string]interface{}{},
Allows: []string{},
StoreSetting: store.Setting{
Table: "yao_neo_conversation",
Prefix: "yao_neo_",
Connector: "default",
},
}
Expand All @@ -60,21 +48,94 @@ func Load(cfg config.Config) error {
Neo = &setting

// Store Setting
err = Neo.createStore()
err = Neo.initStore()
if err != nil {
return err
}

// Initialize RAG
Neo.initRAG()

// Load Built-in Assistants
// Initialize Assistant
err = Neo.initAssistant()
if err != nil {
return err
}

return nil
}

// initRAG initialize the RAG instance
func (neo *DSL) initRAG() {
if neo.RAGSetting.Engine.Driver == "" {
return
}
instance, err := rag.New(neo.RAGSetting)
if err != nil {
color.Red("[Neo] Failed to initialize RAG: %v", err)
log.Error("[Neo] Failed to initialize RAG: %v", err)
return
}

neo.RAG = instance
}

// initStore initialize the store
func (neo *DSL) initStore() error {

var err error
if neo.StoreSetting.Connector == "default" || neo.StoreSetting.Connector == "" {
neo.Store, err = store.NewXun(neo.StoreSetting)
return err
}

// other connector
conn, err := connector.Select(neo.StoreSetting.Connector)
if err != nil {
return err
}

if conn.Is(connector.DATABASE) {
neo.Store, err = store.NewXun(neo.StoreSetting)
return err

} else if conn.Is(connector.REDIS) {
neo.Store = store.NewRedis()
return nil

} else if conn.Is(connector.MONGO) {
neo.Store = store.NewMongo()
return nil
}

return fmt.Errorf("%s store connector %s not support", neo.ID, neo.StoreSetting.Connector)
}

// initAssistant initialize the assistant
func (neo *DSL) initAssistant() error {

// Set Storage
assistant.SetStorage(Neo.Store)
err = assistant.LoadBuiltIn()

// Assistant RAG
if Neo.RAG != nil {
assistant.SetRAG(
Neo.RAG.Engine(),
Neo.RAG.FileUpload(),
Neo.RAG.Vectorizer(),
assistant.RAGSetting{
IndexPrefix: Neo.RAGSetting.IndexPrefix,
},
)
}

// Load Built-in Assistants
err := assistant.LoadBuiltIn()
if err != nil {
return err
}

// Default Assistant
defaultAssistant, err := Neo.defaultAssistant()
if err != nil {
return err
Expand All @@ -83,3 +144,17 @@ func Load(cfg config.Config) error {
Neo.Assistant = defaultAssistant.API
return nil
}

// defaultAssistant get the default assistant
func (neo *DSL) defaultAssistant() (*assistant.Assistant, error) {
if neo.Use != "" {
return assistant.Get(neo.Use)
}

name := neo.Name
if name == "" {
name = "Neo"
}

return assistant.GetByConnector(neo.Connector, name)
}
14 changes: 0 additions & 14 deletions neo/neo.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,20 +346,6 @@ func (neo *DSL) chat(ast assistant.API, ctx Context, messages []map[string]inter
}
}

// defaultAssistant get the default assistant
func (neo *DSL) defaultAssistant() (*assistant.Assistant, error) {
if neo.Use != "" {
return assistant.Get(neo.Use)
}

name := neo.Name
if name == "" {
name = "Neo"
}

return assistant.GetByConnector(neo.Connector, name)
}

// updateAssistantList update the assistant list
func (neo *DSL) updateAssistantList(list []assistant.Assistant) {
lock.Lock()
Expand Down
2 changes: 1 addition & 1 deletion neo/store/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package store
type Setting struct {
Connector string `json:"connector,omitempty"` // Name of the connector used to specify data storage method
UserField string `json:"user_field,omitempty"` // User ID field name, defaults to "user_id"
Table string `json:"table,omitempty"` // Database table name
Prefix string `json:"prefix,omitempty"` // Database table name prefix
MaxSize int `json:"max_size,omitempty" yaml:"max_size,omitempty"` // Maximum storage size limit
TTL int `json:"ttl,omitempty" yaml:"ttl,omitempty"` // Time To Live in seconds
}
Expand Down
8 changes: 4 additions & 4 deletions neo/store/xun.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (conv *Xun) clean() {
}

if nums > 0 {
log.Trace("Clean the conversation table: %s %d", conv.setting.Table, nums)
log.Trace("Clean the conversation table: %s %d", conv.setting.Prefix, nums)
}
}

Expand Down Expand Up @@ -283,15 +283,15 @@ func (conv *Xun) getUserID(sid string) (string, error) {
}

func (conv *Xun) getHistoryTable() string {
return conv.setting.Table + "_history"
return conv.setting.Prefix + "history"
}

func (conv *Xun) getChatTable() string {
return conv.setting.Table + "_chat"
return conv.setting.Prefix + "chat"
}

func (conv *Xun) getAssistantTable() string {
return conv.setting.Table + "_assistant"
return conv.setting.Prefix + "assistant"
}

// UpdateChatTitle update the chat title
Expand Down
Loading

0 comments on commit 1254414

Please sign in to comment.