From 88ea641df9cc02b92a50fbb3d99dcd93803f0d90 Mon Sep 17 00:00:00 2001 From: Lanture1064 Date: Tue, 19 Sep 2023 13:27:31 +0800 Subject: [PATCH] feat: add SSE invoke support for API server sample Signed-off-by: Lanture1064 --- examples/chat_with_document/README.md | 50 +++++++++++++- examples/chat_with_document/handler.go | 91 +++++++++++++++++++++++++- examples/chat_with_document/start.go | 1 + go.mod | 2 +- 4 files changed, 140 insertions(+), 4 deletions(-) diff --git a/examples/chat_with_document/README.md b/examples/chat_with_document/README.md index 219315272..e280aac76 100644 --- a/examples/chat_with_document/README.md +++ b/examples/chat_with_document/README.md @@ -119,7 +119,7 @@ curl --request POST \ |---------|-----------|--------|--------------| | content | Yes | string | chat content | -#### Request Body +#### Request body ```json { @@ -148,4 +148,52 @@ curl --request POST \ "msg":"操作成功", "success":true } +``` + +### Stream chat with document + +Example: + +```shell +curl -X POST \ + -H "Content-Type: application/json" \ + -d '{"content": "KubeBB 有哪些核心套件?"}' \ + http://localhost:8800/sse +``` + +#### URL + +- `POST /sse` + +#### Parameter + +| Name | Must have | Type | Description | +|---------|-----------|--------|--------------| +| content | Yes | string | chat content | + +#### Request body + +```json +{ + "content": "KubeBB 有哪些核心套件?" +} +``` + +#### Response + +```shell + KubeBB 的核心套件包括: + +1. 内核 Kit:提供声明式的组件生命周期管理和组件市场,并通过 Tekton 流水线强化低代码平台组件与底座服务的集成。 +2. 底座 Kit:提供开箱即用的云原生服务门户,包括用户、OIDC 认证、权限、审计、租户管理、门户服务等基础组件以及证书管理、Nginx Ingress 等集群组件。 +3. 
低码 Kit:依托 Low-Code Engine 和具有 Git 特性的关系数据库 Dolt 打造,并借助底座门户的菜单和路由资源以及内核套件的组件管理能力,实现组件开发、测试到上线的全链路能力。 + +关于 KubeBB 套件之间的关系,可以类比为: + +- Kubernetes ~ 操作系统内核 +- Core ~ 软件安装器 +- 底座 Kit ~ 操作系统的系统软件,如 GUI、用户系统、网络等 +- 低码组件开发 Kit ~ 操作系统软件开发工具 + + finish: ``` \ No newline at end of file diff --git a/examples/chat_with_document/handler.go b/examples/chat_with_document/handler.go index f7646dcba..a1c0b76a3 100644 --- a/examples/chat_with_document/handler.go +++ b/examples/chat_with_document/handler.go @@ -17,23 +17,30 @@ limitations under the License. package main import ( + "bufio" "bytes" "context" "errors" "fmt" + "time" + "github.com/gofiber/fiber/v2" - zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" + "github.com/r3labs/sse/v2" "github.com/tmc/langchaingo/documentloaders" "github.com/tmc/langchaingo/embeddings" "github.com/tmc/langchaingo/textsplitter" + "github.com/valyala/fasthttp" + zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" "github.com/kubeagi/arcadia/pkg/llms/zhipuai" "github.com/kubeagi/arcadia/pkg/vectorstores/chromadb" ) const ( - _defaultChunkSize = 2048 + _defaultChunkSize = 1024 _defaultChunkOverlap = 128 + _defaultTimeout = 300 * time.Second + APITokenTTLSeconds = 3 * 60 ) type Workload struct { @@ -169,3 +176,106 @@ func (w Workload) EmbedAndStoreDocument(ctx context.Context) error { return chroma.AddDocuments(ctx, documents) } +
+// StreamQueryHandler answers a chat question and streams the model output back
+// to the client as it is generated.
+//
+// The request body is parsed into a Chat. Up to five results of a similarity
+// search for chat.Content are fetched from the chroma vector store and combined
+// with the question by buildPrompt; the resulting prompt is sent to ZhiPuAI
+// using the SSE invoke method. Every "add" event is forwarded to the response
+// verbatim, and an "error", "interrupted" or "finish" event is appended as a
+// trailing "<event>: <data>" line.
+//
+// An error is returned for an unparsable body, empty content, or any failure
+// while preparing the request. Failures that happen once streaming has begun
+// are only logged.
+func StreamQueryHandler(c *fiber.Ctx) error {
+	var chat Chat
+	// Report a malformed body as such instead of discarding the parse error.
+	if err := c.BodyParser(&chat); err != nil {
+		return err
+	}
+	if chat.Content == "" {
+		return errors.New("content cannot be empty")
+	}
+
+	embedder, err := zhipuaiembeddings.NewZhiPuAI(
+		zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)),
+	)
+	if err != nil {
+		return err
+	}
+
+	fmt.Println("Connecting vector database...")
+	db, err := chromadb.New(
+		chromadb.WithURL(url),
+		chromadb.WithEmbedder(embedder),
+		chromadb.WithNameSpace(namespace),
+	)
+	if err != nil {
+		return fmt.Errorf("error creating chroma db: %w", err)
+	}
+
+	res, err := db.SimilaritySearch(context.Background(), chat.Content, 5)
+	if err != nil {
+		return fmt.Errorf("error performing similarity search: %w", err)
+	}
+
+	params := zhipuai.ModelParams{
+		Method:      zhipuai.ZhiPuAISSEInvoke,
+		Model:       zhipuai.ZhiPuAIPro,
+		Temperature: 0.5,
+		TopP:        0.7,
+		Prompt:      buildPrompt(chat.Content, res),
+	}
+
+	apiURL := zhipuai.BuildAPIURL(params.Model, params.Method)
+	token, err := zhipuai.GenerateToken(apiKey, APITokenTTLSeconds)
+	if err != nil {
+		return err
+	}
+
+	// Switch the response to streaming mode only after every fallible setup
+	// step has succeeded, so error responses do not carry streaming headers.
+	c.Set("Content-Type", "text/event-stream")
+	c.Set("Cache-Control", "no-cache")
+	c.Set("Connection", "keep-alive")
+	c.Set("Transfer-Encoding", "chunked")
+
+	c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
+		// Set once a flush fails. bufio.Writer errors are sticky, so nothing
+		// written after that point can reach the client.
+		clientGone := false
+
+		iErr := zhipuai.Stream(apiURL, token, params, _defaultTimeout, func(event *sse.Event) {
+			if clientGone {
+				return
+			}
+
+			switch string(event.Event) {
+			case "add":
+				// The payload is data, not a format string: print it verbatim
+				// so '%' characters in the model output are not mangled.
+				fmt.Fprint(w, string(event.Data))
+				fmt.Print(string(event.Data))
+			case "error", "interrupted", "finish":
+				fmt.Fprintf(w, "\n\n %s: %s", event.Event, event.Data)
+			}
+
+			// Write errors surface here as well, because the writer keeps
+			// its first error.
+			if err := w.Flush(); err != nil {
+				// Refreshing the page in a browser opens a new SSE connection
+				// and leaves this one dead; stop writing to it.
+				fmt.Printf("Error while flushing: %v. Closing http connection.\n", err)
+				clientGone = true
+			}
+		})
+		if iErr != nil {
+			fmt.Printf("Error while invoking: %v. Closing http connection.\n", iErr)
+		}
+	}))
+
+	return nil
+}
 diff --git a/examples/chat_with_document/start.go b/examples/chat_with_document/start.go index 19dab0370..87105b9b1 100644 --- a/examples/chat_with_document/start.go +++ b/examples/chat_with_document/start.go @@ -109,6 +109,7 @@ func run() error { app.Use(cors.New(cors.ConfigDefault)) app.Get("/", HomePageGetHandler) + app.Post("/sse", StreamQueryHandler) app.Post("/load", LoadHandler) app.Post("/chat", QueryHandler) diff --git a/go.mod b/go.mod index 96a9149f0..3ec0af4ac 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/spf13/cobra v1.4.0 github.com/stretchr/testify v1.8.4 github.com/tmc/langchaingo v0.0.0-20230829032728-c85d3967da08 + github.com/valyala/fasthttp v1.49.0 k8s.io/api v0.24.2 k8s.io/apimachinery v0.24.2 k8s.io/client-go v0.24.2 @@ -81,7 +82,6 @@ require ( github.com/rivo/uniseg v0.2.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.49.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect