Skip to content

Commit

Permalink
feat: add document Q&A chat API server example
Browse files Browse the repository at this point in the history
Signed-off-by: Lanture1064 <[email protected]>
  • Loading branch information
Lanture1064 committed Sep 14, 2023
1 parent 3c14126 commit de7f400
Show file tree
Hide file tree
Showing 6 changed files with 370 additions and 6 deletions.
106 changes: 106 additions & 0 deletions examples/chat_with_document/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
Copyright 2023 The KubeAGI Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/gofiber/fiber/v2"
	zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai"
	"github.com/tmc/langchaingo/documentloaders"
	"github.com/tmc/langchaingo/embeddings"
	"github.com/tmc/langchaingo/schema"
	"github.com/tmc/langchaingo/textsplitter"

	"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
	"github.com/kubeagi/arcadia/pkg/vectorstores/chromadb"
)

// Workload is the JSON request body accepted by LoadHandler.
type Workload struct {
	// Document is the raw text that is split into chunks and embedded.
	Document string `json:"document"`
	// Namespace is the vector-store namespace the chunks are written to.
	Namespace string `json:"namespace"`
	// ChunkSize is copied into the text splitter's ChunkSize field.
	ChunkSize int `json:"chunk-size"`
	// ChunkOverlap is copied into the text splitter's ChunkOverlap field.
	ChunkOverlap int `json:"chunk-overlap"`
}

// HomePageGetHandler replies with a short plain-text description of the
// sample server and how to use it.
func HomePageGetHandler(c *fiber.Ctx) error {
	const greeting = "This is the home page of chat server sample. Send POST request to /chat with your question to chat with me!"

	return c.SendString(greeting)
}

// LoadHandler parses the request body into a Workload, splits the document
// text into chunks and stores their embeddings via Workload.EmbedDocument.
//
// It responds with "OK" on success. An error is returned when the body
// cannot be parsed, the document is empty, or loading/embedding fails.
func LoadHandler(c *fiber.Ctx) error {
	// Convert body to json workload.
	var workload Workload
	if err := c.BodyParser(&workload); err != nil {
		return fmt.Errorf("error parsing body to workload type: %w", err)
	}

	if workload.Document == "" {
		return errors.New("document cannot be empty")
	}

	// Expose the document text as a reader for the loader. A zero-value
	// io.Reader is nil and panics as soon as Read is called on it.
	var docReader io.Reader = strings.NewReader(workload.Document)
	loader := documentloaders.NewText(docReader)

	split := textsplitter.NewRecursiveCharacter()
	// An omitted chunk-size decodes to 0; keep the splitter's own default in
	// that case instead of overwriting it with an unusable size.
	if workload.ChunkSize > 0 {
		split.ChunkSize = workload.ChunkSize
	}
	split.ChunkOverlap = workload.ChunkOverlap

	documents, err := loader.LoadAndSplit(context.Background(), split)
	if err != nil {
		return fmt.Errorf("error loading documents: %w", err)
	}

	if err := workload.EmbedDocument(context.Background(), documents); err != nil {
		return fmt.Errorf("error embedding documents: %w", err)
	}

	return c.SendString("OK")
}

// EmbedDocument builds a ZhiPuAI embedder from the package-level apiKey,
// opens the chromaDB store at the package-level url under w.Namespace, and
// adds the given documents to it.
//
// Any error from creating the embedder, creating the store, or adding the
// documents is returned unchanged.
func (w Workload) EmbedDocument(ctx context.Context, documents []schema.Document) error {
	client := zhipuai.NewZhiPuAI(apiKey)

	// Declared with the interface type that chromadb.WithEmbedder receives.
	var embedder embeddings.Embedder
	embedder, err := zhipuaiembeddings.NewZhiPuAI(zhipuaiembeddings.WithClient(*client))
	if err != nil {
		return err
	}

	store, err := chromadb.New(
		chromadb.WithURL(url),
		chromadb.WithEmbedder(embedder),
		chromadb.WithNameSpace(w.Namespace),
	)
	if err != nil {
		return err
	}

	return store.AddDocuments(ctx, documents)
}
45 changes: 45 additions & 0 deletions examples/chat_with_document/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Copyright 2023 The KubeAGI Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"flag"
"fmt"
"github.com/spf13/cobra"
)

// NewCLI returns the root "chat" command with the "start" subcommand
// attached. Running the root command on its own does nothing and succeeds.
func NewCLI() *cobra.Command {
	root := new(cobra.Command)
	root.Use = "chat [usage]"
	root.Short = "CLI for chat server example"
	root.RunE = func(_ *cobra.Command, _ []string) error {
		return nil
	}

	root.AddCommand(NewStartCmd())

	return root
}

// main parses the standard flag set, executes the CLI and prints any
// error the command returns.
func main() {
	flag.Parse()

	err := NewCLI().Execute()
	if err == nil {
		return
	}

	fmt.Printf("Run failed, error:\n %v", err)
}
168 changes: 168 additions & 0 deletions examples/chat_with_document/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
Copyright 2023 The KubeAGI Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"context"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
"github.com/kubeagi/arcadia/pkg/vectorstores/chromadb"
"github.com/spf13/cobra"
"github.com/tmc/langchaingo/schema"
)

var (
apiKey string
addr string
url string
)

// NewStartCmd returns the "start" subcommand, which launches the chat server.
//
// Flags:
//   - --api-key (required): key used to connect to the ZhiPuAI platform.
//   - --vector-store (required): URL of the chromaDB vector database.
//   - --addr: listen address, ":8800" by default.
//
// It panics if a flag cannot be marked as required, which only happens when
// the name passed to MarkFlagRequired does not match a registered flag.
func NewStartCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "start [usage]",
		Short: "Start the server",
		RunE: func(cmd *cobra.Command, args []string) error {
			return run()
		},
	}

	cmd.Flags().StringVar(&apiKey, "api-key", "", "used to connect to ZhiPuAI platform")
	cmd.Flags().StringVar(&url, "vector-store", "", "the chromaDB vector database url")
	cmd.Flags().StringVar(&addr, "addr", ":8800", "used to listen and serve GET request, default :8800")

	if err := cmd.MarkFlagRequired("api-key"); err != nil {
		panic(err)
	}
	// The required name must be the flag's registered name ("vector-store"),
	// not the name of the variable it is bound to.
	if err := cmd.MarkFlagRequired("vector-store"); err != nil {
		panic(err)
	}

	return cmd
}

// run validates the configuration, connects to the ZhiPuAI platform and the
// chromaDB vector database, and then serves the HTTP API on addr.
//
// Routes:
//   - GET  /     home page text
//   - POST /load split and embed a document (LoadHandler)
//   - POST /chat answer the question in the request body, using the five
//     most similar stored chunks as context
//
// It returns an error when configuration is missing, a backend cannot be
// reached or created, or the HTTP listener fails.
func run() error {
	fmt.Println("Starting chat server example...")

	// Reject missing configuration before doing any network round-trips.
	if apiKey == "" {
		return fmt.Errorf("ZhiPuAI api key is empty")
	}
	if url == "" {
		return fmt.Errorf("chromaDB url is empty")
	}

	// check ZhiPuAI api key, build embedder
	fmt.Println("Connecting platform...")
	z := zhipuai.NewZhiPuAI(apiKey)
	if _, err := z.Validate(); err != nil {
		return fmt.Errorf("error validating ZhiPuAI api key: %w", err)
	}

	embedder, err := zhipuaiembeddings.NewZhiPuAI(zhipuaiembeddings.WithClient(*z))
	if err != nil {
		return fmt.Errorf("error creating embedder: %w", err)
	}

	// connect chromaDB server
	fmt.Println("Connecting vector database...")
	db, err := chromadb.New(
		chromadb.WithURL(url),
		chromadb.WithEmbedder(embedder),
		chromadb.WithNameSpace("arcadia"),
	)
	if err != nil {
		return fmt.Errorf("error creating chroma db: %w", err)
	}

	fmt.Println("Creating HTTP server...")
	app := fiber.New(fiber.Config{
		AppName:       "chat-server",
		CaseSensitive: true,
		StrictRouting: true,
		Immutable:     true,
	})

	app.Use(cors.New(cors.ConfigDefault))

	app.Get("/", HomePageGetHandler)
	app.Post("/load", LoadHandler)
	app.Post("/chat", func(c *fiber.Ctx) error {
		question := string(c.Body())
		// An empty body carries no question; refuse it instead of spending
		// an embedding lookup and an LLM call on it.
		if question == "" {
			return fiber.NewError(fiber.StatusBadRequest, "question cannot be empty")
		}

		fmt.Printf("Question:%s \n", question)
		fmt.Println("Querying similar content...")

		res, sErr := db.SimilaritySearch(context.Background(), question, 5)
		if sErr != nil {
			return fmt.Errorf("error performing similarity search: %w", sErr)
		}

		params := zhipuai.ModelParams{
			Method:      zhipuai.ZhiPuAIInvoke,
			Model:       zhipuai.ZhiPuAIPro,
			Temperature: 0.5,
			TopP:        0.7,
			Prompt:      buildPrompt(question, res),
		}

		resp, iErr := z.Invoke(params)
		if iErr != nil {
			return fmt.Errorf("error invoking ZhiPuAI: %w", iErr)
		}

		return c.SendString(resp.String())
	})

	return app.Listen(addr)
}

// buildPrompt assembles the three-message conversation sent to the model:
//
//  1. a user message asking the model to answer only from the supplied
//     context, to reply "我不确定" and then guess when the context is
//     insufficient, and to split the answer into readable paragraphs;
//  2. an assistant message acknowledging those rules;
//  3. a user message holding the question followed by the page content of
//     every document, concatenated in order without a separator.
func buildPrompt(question string, document []schema.Document) []zhipuai.Prompt {
	// Join the retrieved chunks into one context string.
	combined := ""
	for i := range document {
		combined += document[i].PageContent
	}

	return []zhipuai.Prompt{
		{
			Role: zhipuai.User,
			Content: `
我将要询问一些问题,希望你仅使用我提供的上下文信息回答。
请不要在回答中添加其他信息。
若我提供的上下文不足以回答问题,
请回复"我不确定",再做出适当的猜测。
请将回答内容分割为适于阅读的段落。
`,
		},
		{
			Role: zhipuai.Assistant,
			Content: `
好的,我将尝试仅使用你提供的上下文信息回答,并在信息不足时提供一些合理推测。
`,
		},
		{
			Role:    zhipuai.User,
			Content: "问题内容如下:" + question + "\n 以下是我提供的上下文信息:\n" + combined,
		},
	}
}
18 changes: 15 additions & 3 deletions examples/embedding/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ func main() {
}
apiKey := os.Args[1]

fmt.Printf("Connecting with apikey %s\n", apiKey)

// init embedder
embedder, err := embedding.NewZhiPuAI(
embedding.WithClient(*zhipuai.NewZhiPuAI(apiKey)),
Expand All @@ -48,14 +50,24 @@ func main() {

// add documents
err = chroma.AddDocuments(context.TODO(), []schema.Document{
{PageContent: "This is a document about cats. Cats are great."},
{PageContent: "this is a document about dogs. Dogs are great."},
{PageContent: "This is a document about cats. Cats are great.",
Metadata: map[string]interface{}{"about": "cat"}},
})

if err != nil {
panic(fmt.Errorf("error add documents to chroma db: %s", err.Error()))
}

err = chroma.AddDocuments(context.TODO(), []schema.Document{
{PageContent: "This is a document about dogs. Dogs are great.",
Metadata: map[string]interface{}{"about": "dog"}},
})

if err != nil {
panic(fmt.Errorf("error add documents to chroma db: %s", err.Error()))
}

docs, err := chroma.SimilaritySearch(context.TODO(), "cats", 5)
docs, err := chroma.SimilaritySearch(context.TODO(), "This is a photo of a cat. Cats are cute.", 5)
if err != nil {
panic(fmt.Errorf("error similarity search: %s", err.Error()))
}
Expand Down
Loading

0 comments on commit de7f400

Please sign in to comment.