-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add document Q&A chat API server example
Signed-off-by: Lanture1064 <[email protected]>
- Loading branch information
1 parent
3c14126
commit de7f400
Showing
6 changed files
with
370 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
/* | ||
Copyright 2023 The KubeAGI Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package main | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"io" | ||
|
||
"github.com/gofiber/fiber/v2" | ||
zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" | ||
"github.com/tmc/langchaingo/documentloaders" | ||
"github.com/tmc/langchaingo/embeddings" | ||
"github.com/tmc/langchaingo/schema" | ||
"github.com/tmc/langchaingo/textsplitter" | ||
|
||
"github.com/kubeagi/arcadia/pkg/llms/zhipuai" | ||
"github.com/kubeagi/arcadia/pkg/vectorstores/chromadb" | ||
) | ||
|
||
type Workload struct { | ||
Document string `json:"document"` | ||
Namespace string `json:"namespace"` | ||
ChunkSize int `json:"chunk-size"` | ||
ChunkOverlap int `json:"chunk-overlap"` | ||
} | ||
|
||
func HomePageGetHandler(c *fiber.Ctx) error { | ||
return c.SendString("This is the home page of chat server sample. Send POST request to /chat with your question to chat with me!") | ||
} | ||
|
||
func LoadHandler(c *fiber.Ctx) error { | ||
// Convert body to json workload | ||
var workload Workload | ||
err := c.BodyParser(&workload) | ||
if err != nil { | ||
return errors.New("error parsing body to workload type" + err.Error()) | ||
} | ||
|
||
if workload.Document == "" { | ||
return errors.New("document cannot be empty") | ||
} | ||
|
||
var docReader io.Reader | ||
|
||
_, err = docReader.Read([]byte(workload.Document)) | ||
|
||
if err != nil { | ||
return errors.New("Error reading document:" + err.Error()) | ||
} | ||
|
||
loader := documentloaders.NewText(docReader) | ||
|
||
split := textsplitter.NewRecursiveCharacter() | ||
split.ChunkSize = workload.ChunkSize | ||
split.ChunkOverlap = workload.ChunkOverlap | ||
|
||
documents, err := loader.LoadAndSplit(context.Background(), split) | ||
if err != nil { | ||
return errors.New("Error loading documents:" + err.Error()) | ||
} | ||
|
||
err = workload.EmbedDocument(context.Background(), documents) | ||
if err != nil { | ||
return errors.New("Error embedding documents:" + err.Error()) | ||
} | ||
|
||
return c.SendString("OK") | ||
} | ||
|
||
func (w Workload) EmbedDocument(ctx context.Context, documents []schema.Document) error { | ||
var embedder embeddings.Embedder | ||
var err error | ||
|
||
embedder, err = zhipuaiembeddings.NewZhiPuAI( | ||
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)), | ||
) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
chroma, err := chromadb.New( | ||
chromadb.WithURL(url), | ||
chromadb.WithEmbedder(embedder), | ||
chromadb.WithNameSpace(w.Namespace), | ||
) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return chroma.AddDocuments(ctx, documents) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
Copyright 2023 The KubeAGI Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package main | ||
|
||
import ( | ||
"flag" | ||
"fmt" | ||
"github.com/spf13/cobra" | ||
) | ||
|
||
func NewCLI() *cobra.Command { | ||
cli := &cobra.Command{ | ||
Use: "chat [usage]", | ||
Short: "CLI for chat server example", | ||
RunE: func(cmd *cobra.Command, args []string) error { | ||
return nil | ||
}, | ||
} | ||
|
||
cli.AddCommand(NewStartCmd()) | ||
|
||
return cli | ||
} | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
if err := NewCLI().Execute(); err != nil { | ||
fmt.Printf("Run failed, error:\n %v", err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
/* | ||
Copyright 2023 The KubeAGI Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package main | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"github.com/gofiber/fiber/v2" | ||
"github.com/gofiber/fiber/v2/middleware/cors" | ||
zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" | ||
"github.com/kubeagi/arcadia/pkg/llms/zhipuai" | ||
"github.com/kubeagi/arcadia/pkg/vectorstores/chromadb" | ||
"github.com/spf13/cobra" | ||
"github.com/tmc/langchaingo/schema" | ||
) | ||
|
||
var ( | ||
apiKey string | ||
addr string | ||
url string | ||
) | ||
|
||
func NewStartCmd() *cobra.Command { | ||
cmd := &cobra.Command{ | ||
Use: "start [usage]", | ||
Short: "Start the server", | ||
RunE: func(cmd *cobra.Command, args []string) error { | ||
return run() | ||
}, | ||
} | ||
|
||
cmd.Flags().StringVar(&apiKey, "api-key", "", "used to connect to ZhiPuAI platform") | ||
cmd.Flags().StringVar(&url, "vector-store", "", "the chromaDB vector database url") | ||
cmd.Flags().StringVar(&addr, "addr", ":8800", "used to listen and serve GET request, default :8800") | ||
|
||
if err := cmd.MarkFlagRequired("api-key"); err != nil { | ||
panic(err) | ||
} | ||
if err := cmd.MarkFlagRequired("url"); err != nil { | ||
panic(err) | ||
} | ||
|
||
return cmd | ||
} | ||
|
||
func run() error { | ||
fmt.Println("Starting chat server example...") | ||
|
||
// check ZhiPuAI api key, build embedder | ||
if apiKey == "" { | ||
return fmt.Errorf("ZhiPuAI api key is empty") | ||
} | ||
|
||
fmt.Println("Connecting platform...") | ||
z := zhipuai.NewZhiPuAI(apiKey) | ||
_, err := z.Validate() | ||
if err != nil { | ||
return fmt.Errorf("error validating ZhiPuAI api key: %s", err.Error()) | ||
} | ||
|
||
embedder, err := zhipuaiembeddings.NewZhiPuAI(zhipuaiembeddings.WithClient(*z)) | ||
if err != nil { | ||
return fmt.Errorf("error creating embedder: %s", err.Error()) | ||
} | ||
|
||
// check scheme & host, connect chromaDB server | ||
if url == "" { | ||
return fmt.Errorf("chromaDB scheme is empty") | ||
} | ||
|
||
fmt.Println("Connecting vector database...") | ||
db, err := chromadb.New( | ||
chromadb.WithURL(url), | ||
chromadb.WithEmbedder(embedder), | ||
chromadb.WithNameSpace("arcadia"), | ||
) | ||
if err != nil { | ||
return fmt.Errorf("error creating chroma db: %s", err.Error()) | ||
} | ||
|
||
fmt.Println("Creating HTTP server...") | ||
app := fiber.New(fiber.Config{ | ||
AppName: "chat-server", | ||
CaseSensitive: true, | ||
StrictRouting: true, | ||
Immutable: true, | ||
}) | ||
|
||
app.Use(cors.New(cors.ConfigDefault)) | ||
|
||
app.Get("/", HomePageGetHandler) | ||
app.Post("/load", LoadHandler) | ||
app.Post("/chat", func(c *fiber.Ctx) error { | ||
fmt.Printf("Question:%s \n", c.Body()) | ||
fmt.Println("Querying similar content...") | ||
|
||
question := string(c.Body()) | ||
|
||
res, sErr := db.SimilaritySearch(context.Background(), question, 5) | ||
if sErr != nil { | ||
return fmt.Errorf("error performing similarity search: %s", sErr.Error()) | ||
} | ||
|
||
prompt := buildPrompt(question, res) | ||
|
||
params := zhipuai.ModelParams{ | ||
Method: zhipuai.ZhiPuAIInvoke, | ||
Model: zhipuai.ZhiPuAIPro, | ||
Temperature: 0.5, | ||
TopP: 0.7, | ||
Prompt: prompt, | ||
} | ||
|
||
resp, iErr := z.Invoke(params) | ||
if iErr != nil { | ||
return fmt.Errorf("error invoking ZhiPuAI: %s", iErr.Error()) | ||
} | ||
|
||
return c.SendString(resp.String()) | ||
}) | ||
|
||
return app.Listen(addr) | ||
} | ||
|
||
func buildPrompt(question string, document []schema.Document) []zhipuai.Prompt { | ||
premise := zhipuai.Prompt{ | ||
Role: zhipuai.User, | ||
Content: ` | ||
我将要询问一些问题,希望你仅使用我提供的上下文信息回答。 | ||
请不要在回答中添加其他信息。 | ||
若我提供的上下文不足以回答问题, | ||
请回复"我不确定",再做出适当的猜测。 | ||
请将回答内容分割为适于阅读的段落。 | ||
`, | ||
} | ||
reply := zhipuai.Prompt{ | ||
Role: zhipuai.Assistant, | ||
Content: ` | ||
好的,我将尝试仅使用你提供的上下文信息回答,并在信息不足时提供一些合理推测。 | ||
`, | ||
} | ||
|
||
var info string | ||
for _, doc := range document { | ||
info += doc.PageContent | ||
} | ||
|
||
requirement := zhipuai.Prompt{ | ||
Role: zhipuai.User, | ||
Content: "问题内容如下:" + question + "\n 以下是我提供的上下文信息:\n" + info, | ||
} | ||
|
||
return []zhipuai.Prompt{premise, reply, requirement} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.