From 80a0222ab0a88f2aa7db57125e6d2583d934ef6a Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Tue, 26 Sep 2023 09:12:54 +0530 Subject: [PATCH 1/6] (chore): add token utils to count the total tokens utilized --- .vscode/settings.json | 3 ++ e2e/factories/factories.go | 1 + go.mod | 2 ++ go.sum | 4 +++ .../api-gateway/connectors/openai/openai.go | 29 +++++++++++++++++-- shared/go/tokenutils/tokenutils.go | 15 ++++++++++ 6 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 shared/go/tokenutils/tokenutils.go diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..45ea939d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "java.configuration.updateBuildConfiguration": "disabled" +} diff --git a/e2e/factories/factories.go b/e2e/factories/factories.go index 9c360e96..cdd6acc6 100644 --- a/e2e/factories/factories.go +++ b/e2e/factories/factories.go @@ -3,6 +3,7 @@ package factories import ( "context" "encoding/json" + openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" "github.com/basemind-ai/monorepo/shared/go/datatypes" "github.com/basemind-ai/monorepo/shared/go/db" diff --git a/go.mod b/go.mod index cec6d7a3..74097e1d 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/containerd/continuity v0.4.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.9.0 // indirect github.com/docker/cli v24.0.6+incompatible // indirect github.com/docker/docker v24.0.6+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect @@ -77,6 +78,7 @@ require ( github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.1 // indirect + github.com/tiktoken-go/tokenizer v0.1.0 // indirect github.com/vmihailenco/go-tinylfu v0.2.2 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index da5ad885..4f7f537b 100644 --- a/go.sum +++ b/go.sum @@ -50,6 +50,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI= +github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/cli v24.0.6+incompatible h1:fF+XCQCgJjjQNIMjzaSmiKJSCcfcXb3TWTcc7GAneOY= github.com/docker/cli v24.0.6+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE= @@ -249,6 +251,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tiktoken-go/tokenizer v0.1.0 h1:c1fXriHSR/NmhMDTwUDLGiNhHwTV+ElABGvqhCWLRvY= +github.com/tiktoken-go/tokenizer v0.1.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg= github.com/vmihailenco/go-tinylfu v0.2.2 h1:H1eiG6HM36iniK6+21n9LLpzx1G9R3DJa2UjUjbynsI= github.com/vmihailenco/go-tinylfu v0.2.2/go.mod h1:CutYi2Q9puTxfcolkliPq4npPuofg9N9t8JVrjzwa3Q= github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= diff --git a/services/api-gateway/connectors/openai/openai.go b/services/api-gateway/connectors/openai/openai.go index e95b4634..8ba1579a 100644 --- a/services/api-gateway/connectors/openai/openai.go +++ b/services/api-gateway/connectors/openai/openai.go @@ -3,8 +3,13 @@ package openai import ( "context" "errors" - "github.com/basemind-ai/monorepo/shared/go/datatypes" + "fmt" "io" + "strings" + + "github.com/basemind-ai/monorepo/shared/go/datatypes" + "github.com/basemind-ai/monorepo/shared/go/tokenutils" + "github.com/tiktoken-go/tokenizer" openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" "github.com/rs/zerolog/log" @@ -48,7 +53,27 @@ func (c *Client) RequestPrompt( if requestErr != nil { return "", requestErr } - // TODO handle token related logic here by using the response token properties. + + // Count the total number of tokens utilized for openai prompt + var promptMessages string + for _, message := range promptRequest.Messages { + promptMessages += *message.Content + promptMessages += "\n" + } + + promptMessages = strings.TrimRight(promptMessages, "\n") + + promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(promptMessages, tokenizer.Cl100kBase) + if tokenizationErr != nil { + log.Err(tokenizationErr).Msg("failed to get prompt token count") + } + + promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(response.Content, tokenizer.Cl100kBase) + if tokenizationErr != nil { + log.Err(tokenizationErr).Msg("failed to get prompt token count") + } + + log.Debug().Msg(fmt.Sprintf("Total tokens utilized: Request-%d, Response-%d", promptReqTokenCount, promptResTokenCount)) return response.Content, nil } diff --git a/shared/go/tokenutils/tokenutils.go b/shared/go/tokenutils/tokenutils.go new file mode 100644 index 00000000..8c299af0 --- /dev/null +++ b/shared/go/tokenutils/tokenutils.go @@ -0,0 +1,15 @@ +package tokenutils + +import ( + "github.com/tiktoken-go/tokenizer" +) + +func GetPromptTokenCount(prompt string, encoding tokenizer.Encoding) (int, error) { + enc, err := tokenizer.Get(encoding) + if err != nil { + return -1, err + } + + ids, _, _ := enc.Encode(prompt) + return len(ids), nil +} From ae8b056453026c82c3df11cb17f39223578d1084 Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Thu, 28 Sep 2023 23:43:30 +0530 Subject: [PATCH 2/6] (chore): add token utils to count the total tokens for streaming response --- .../api-gateway/connectors/openai/openai.go | 22 ++++++----- shared/go/tokenutils/tokenutils.go | 19 ++++++++++ shared/go/tokenutils/tokenutils_test.go | 37 +++++++++++++++++++ 3 files changed, 68 insertions(+), 10 deletions(-) create mode 100644 shared/go/tokenutils/tokenutils_test.go diff --git a/services/api-gateway/connectors/openai/openai.go b/services/api-gateway/connectors/openai/openai.go index 8ba1579a..2aa79a3a 100644 --- a/services/api-gateway/connectors/openai/openai.go +++ b/services/api-gateway/connectors/openai/openai.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "strings" "github.com/basemind-ai/monorepo/shared/go/datatypes" "github.com/basemind-ai/monorepo/shared/go/tokenutils" @@ -55,15 +54,7 @@ func (c *Client) RequestPrompt( } // Count the total number of tokens utilized for openai prompt - var promptMessages string - for _, message := range promptRequest.Messages { - promptMessages += *message.Content - promptMessages += "\n" - } - - promptMessages = strings.TrimRight(promptMessages, "\n") - - promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(promptMessages, tokenizer.Cl100kBase) + promptReqTokenCount, tokenizationErr := tokenutils.GetRequestPromptTokenCount(promptRequest.Messages, tokenizer.Cl100kBase) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } @@ -103,6 +94,12 @@ func (c *Client) RequestStream( return } + promptReqTokenCount, tokenizationErr := tokenutils.GetRequestPromptTokenCount(promptRequest.Messages, tokenizer.Cl100kBase) + if tokenizationErr != nil { + log.Err(tokenizationErr).Msg("failed to get prompt token count") + } + log.Debug().Msg(fmt.Sprintf("Total tokens utilized for request prompt - %d", promptReqTokenCount)) + for { msg, receiveErr := stream.Recv() if receiveErr != nil { @@ -114,6 +111,11 @@ func (c *Client) RequestStream( } // TODO handle token related logic here + promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(msg.Content, tokenizer.Cl100kBase) + if tokenizationErr != nil { + log.Err(tokenizationErr).Msg("failed to get prompt token count") + } + log.Debug().Msg(fmt.Sprintf("Tokens utilized for streaming response-%d", promptResTokenCount)) contentChannel <- msg.Content } } diff --git a/shared/go/tokenutils/tokenutils.go b/shared/go/tokenutils/tokenutils.go index 8c299af0..f211dd5f 100644 --- a/shared/go/tokenutils/tokenutils.go +++ b/shared/go/tokenutils/tokenutils.go @@ -1,6 +1,9 @@ package tokenutils import ( + "strings" + + openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" "github.com/tiktoken-go/tokenizer" ) @@ -13,3 +16,19 @@ func GetPromptTokenCount(prompt string, encoding tokenizer.Encoding) (int, error ids, _, _ := enc.Encode(prompt) return len(ids), nil } + +func GetRequestPromptTokenCount(messages []*openaiconnector.OpenAIMessage, encoding tokenizer.Encoding) (int, error) { + var promptMessages string + for _, message := range messages { + promptMessages += *message.Content + promptMessages += "\n" + } + promptMessages = strings.TrimRight(promptMessages, "\n") + + promptReqTokenCount, tokenizationErr := GetPromptTokenCount(promptMessages, encoding) + if tokenizationErr != nil { + return -1, tokenizationErr + } + + return promptReqTokenCount, nil +} diff --git a/shared/go/tokenutils/tokenutils_test.go b/shared/go/tokenutils/tokenutils_test.go new file mode 100644 index 00000000..98feca2b --- /dev/null +++ b/shared/go/tokenutils/tokenutils_test.go @@ -0,0 +1,37 @@ +package tokenutils_test + +import ( + "fmt" + "testing" + + "github.com/basemind-ai/monorepo/shared/go/tokenutils" + "github.com/stretchr/testify/assert" + "github.com/tiktoken-go/tokenizer" +) + +func TestGetPromptTokenCount(t *testing.T) { + testCases := []struct { + input string + expected int + }{ + { + input: "Hello world!", + expected: 3, + }, + { + input: "", + expected: 0, + }, + { + input: "Goodbye world!", + expected: 4, + }, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("Test: %d", testCase.expected), func(t *testing.T) { + actual, _ := tokenutils.GetPromptTokenCount(testCase.input, tokenizer.Cl100kBase) + assert.Equal(t, testCase.expected, actual) + }) + } +} From 10d4da8cdd1af0869e38d67cc43312a88f862281 Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Thu, 28 Sep 2023 23:48:33 +0530 Subject: [PATCH 3/6] (chore): remove vscode config --- .vscode/settings.json | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 45ea939d..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "java.configuration.updateBuildConfiguration": "disabled" -} From 221d554b7c932b655921fa5beba705a860b6f387 Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Sat, 30 Sep 2023 00:09:52 +0530 Subject: [PATCH 4/6] (chore): address review comments --- go.mod | 2 +- .../api-gateway/connectors/openai/openai.go | 17 ++++++----- .../api-gateway/connectors/openai/utils.go | 13 ++++++++- shared/go/tokenutils/tokenutils.go | 29 +++++++------------ shared/go/tokenutils/tokenutils_test.go | 3 +- 5 files changed, 34 insertions(+), 30 deletions(-) diff --git a/go.mod b/go.mod index 74097e1d..f1b20035 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/rs/zerolog v1.31.0 github.com/sethvargo/go-envconfig v0.9.0 github.com/stretchr/testify v1.8.4 + github.com/tiktoken-go/tokenizer v0.1.0 golang.org/x/sync v0.3.0 google.golang.org/grpc v1.58.2 google.golang.org/protobuf v1.31.0 @@ -78,7 +79,6 @@ require ( github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.1 // indirect - github.com/tiktoken-go/tokenizer v0.1.0 // indirect github.com/vmihailenco/go-tinylfu v0.2.2 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect diff --git a/services/api-gateway/connectors/openai/openai.go b/services/api-gateway/connectors/openai/openai.go index 2aa79a3a..d445fcdf 100644 --- a/services/api-gateway/connectors/openai/openai.go +++ b/services/api-gateway/connectors/openai/openai.go @@ -8,7 +8,6 @@ import ( "github.com/basemind-ai/monorepo/shared/go/datatypes" "github.com/basemind-ai/monorepo/shared/go/tokenutils" - "github.com/tiktoken-go/tokenizer" openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" "github.com/rs/zerolog/log" @@ -54,12 +53,13 @@ func (c *Client) RequestPrompt( } // Count the total number of tokens utilized for openai prompt - promptReqTokenCount, tokenizationErr := tokenutils.GetRequestPromptTokenCount(promptRequest.Messages, tokenizer.Cl100kBase) + reqPromptString := GetRequestPromptString(promptRequest.Messages, tokenutils.Cl100kBase) + promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, tokenutils.Cl100kBase) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } - promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(response.Content, tokenizer.Cl100kBase) + promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(response.Content, tokenutils.Cl100kBase) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } @@ -94,12 +94,15 @@ func (c *Client) RequestStream( return } - promptReqTokenCount, tokenizationErr := tokenutils.GetRequestPromptTokenCount(promptRequest.Messages, tokenizer.Cl100kBase) + reqPromptString := GetRequestPromptString(promptRequest.Messages, tokenutils.Cl100kBase) + promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, tokenutils.Cl100kBase) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } log.Debug().Msg(fmt.Sprintf("Total tokens utilized for request prompt - %d", promptReqTokenCount)) + var promptResTokenCount int + for { msg, receiveErr := stream.Recv() if receiveErr != nil { @@ -107,15 +110,15 @@ func (c *Client) RequestStream( errChannel <- receiveErr } close(contentChannel) + log.Debug().Msg(fmt.Sprintf("Tokens utilized for streaming response-%d", promptResTokenCount)) return } - // TODO handle token related logic here - promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(msg.Content, tokenizer.Cl100kBase) + streamResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(msg.Content, tokenutils.Cl100kBase) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } - log.Debug().Msg(fmt.Sprintf("Tokens utilized for streaming response-%d", promptResTokenCount)) + promptResTokenCount += streamResTokenCount contentChannel <- msg.Content } } diff --git a/services/api-gateway/connectors/openai/utils.go b/services/api-gateway/connectors/openai/utils.go index cd79d865..f5921dc6 100644 --- a/services/api-gateway/connectors/openai/utils.go +++ b/services/api-gateway/connectors/openai/utils.go @@ -3,9 +3,11 @@ package openai import ( "encoding/json" "fmt" + "strings" + "github.com/basemind-ai/monorepo/shared/go/datatypes" "github.com/basemind-ai/monorepo/shared/go/db" - "strings" + "github.com/basemind-ai/monorepo/shared/go/tokenutils" openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" ) @@ -86,3 +88,12 @@ func CreatePromptRequest( return promptRequest, nil } + +func GetRequestPromptString(messages []*openaiconnector.OpenAIMessage, encoding tokenutils.Encoding) string { + var promptMessages string + for _, message := range messages { + promptMessages += *message.Content + promptMessages += "\n" + } + return strings.TrimRight(promptMessages, "\n") +} diff --git a/shared/go/tokenutils/tokenutils.go b/shared/go/tokenutils/tokenutils.go index f211dd5f..7e226b1b 100644 --- a/shared/go/tokenutils/tokenutils.go +++ b/shared/go/tokenutils/tokenutils.go @@ -1,12 +1,19 @@ package tokenutils import ( - "strings" - - openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" "github.com/tiktoken-go/tokenizer" ) +type Encoding = tokenizer.Encoding + +const ( + GPT2Enc Encoding = tokenizer.GPT2Enc + R50kBase Encoding = tokenizer.R50kBase + P50kBase Encoding = tokenizer.P50kBase + P50kEdit Encoding = tokenizer.P50kEdit + Cl100kBase Encoding = tokenizer.Cl100kBase +) + func GetPromptTokenCount(prompt string, encoding tokenizer.Encoding) (int, error) { enc, err := tokenizer.Get(encoding) if err != nil { @@ -16,19 +23,3 @@ func GetPromptTokenCount(prompt string, encoding tokenizer.Encoding) (int, error ids, _, _ := enc.Encode(prompt) return len(ids), nil } - -func GetRequestPromptTokenCount(messages []*openaiconnector.OpenAIMessage, encoding tokenizer.Encoding) (int, error) { - var promptMessages string - for _, message := range messages { - promptMessages += *message.Content - promptMessages += "\n" - } - promptMessages = strings.TrimRight(promptMessages, "\n") - - promptReqTokenCount, tokenizationErr := GetPromptTokenCount(promptMessages, encoding) - if tokenizationErr != nil { - return -1, tokenizationErr - } - - return promptReqTokenCount, nil -} diff --git a/shared/go/tokenutils/tokenutils_test.go b/shared/go/tokenutils/tokenutils_test.go index 98feca2b..f5fa7b5c 100644 --- a/shared/go/tokenutils/tokenutils_test.go +++ b/shared/go/tokenutils/tokenutils_test.go @@ -6,7 +6,6 @@ import ( "github.com/basemind-ai/monorepo/shared/go/tokenutils" "github.com/stretchr/testify/assert" - "github.com/tiktoken-go/tokenizer" ) func TestGetPromptTokenCount(t *testing.T) { @@ -30,7 +29,7 @@ func TestGetPromptTokenCount(t *testing.T) { for _, testCase := range testCases { t.Run(fmt.Sprintf("Test: %d", testCase.expected), func(t *testing.T) { - actual, _ := tokenutils.GetPromptTokenCount(testCase.input, tokenizer.Cl100kBase) + actual, _ := tokenutils.GetPromptTokenCount(testCase.input, tokenutils.Cl100kBase) assert.Equal(t, testCase.expected, actual) }) } From 3a1fcb9d05c3d6dd64dcca8c32732757afca1cec Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Sat, 30 Sep 2023 00:16:30 +0530 Subject: [PATCH 5/6] (chore): add tests for error case --- shared/go/tokenutils/tokenutils_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/shared/go/tokenutils/tokenutils_test.go b/shared/go/tokenutils/tokenutils_test.go index f5fa7b5c..76e29e8d 100644 --- a/shared/go/tokenutils/tokenutils_test.go +++ b/shared/go/tokenutils/tokenutils_test.go @@ -33,4 +33,9 @@ func TestGetPromptTokenCount(t *testing.T) { assert.Equal(t, testCase.expected, actual) }) } + + // test for the invalid encodings + tokenCnt, err := tokenutils.GetPromptTokenCount("Hello world!", tokenutils.Encoding("invalid")) + assert.NotNil(t, err) + assert.Equal(t, -1, tokenCnt) } From 5a72ebb0514503ffae04f1a784ed70cb727ea643 Mon Sep 17 00:00:00 2001 From: Kailash Bisht Date: Sat, 30 Sep 2023 20:27:18 +0530 Subject: [PATCH 6/6] (chore): address review comments --- .../api-gateway/connectors/openai/openai.go | 12 +++--- .../api-gateway/connectors/openai/utils.go | 3 +- .../connectors/openai/utils_test.go | 41 ++++++++++++++++++- shared/go/tokenutils/tokenutils.go | 19 ++++----- shared/go/tokenutils/tokenutils_test.go | 5 ++- 5 files changed, 59 insertions(+), 21 deletions(-) diff --git a/services/api-gateway/connectors/openai/openai.go b/services/api-gateway/connectors/openai/openai.go index d445fcdf..771514db 100644 --- a/services/api-gateway/connectors/openai/openai.go +++ b/services/api-gateway/connectors/openai/openai.go @@ -53,13 +53,13 @@ func (c *Client) RequestPrompt( } // Count the total number of tokens utilized for openai prompt - reqPromptString := GetRequestPromptString(promptRequest.Messages, tokenutils.Cl100kBase) - promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, tokenutils.Cl100kBase) + reqPromptString := GetRequestPromptString(promptRequest.Messages) + promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, applicationPromptConfig.PromptConfigData.ModelType) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } - promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(response.Content, tokenutils.Cl100kBase) + promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(response.Content, applicationPromptConfig.PromptConfigData.ModelType) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } @@ -94,8 +94,8 @@ func (c *Client) RequestStream( return } - reqPromptString := GetRequestPromptString(promptRequest.Messages, tokenutils.Cl100kBase) - promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, tokenutils.Cl100kBase) + reqPromptString := GetRequestPromptString(promptRequest.Messages) + promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, applicationPromptConfig.PromptConfigData.ModelType) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } @@ -114,7 +114,7 @@ func (c *Client) RequestStream( return } - streamResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(msg.Content, tokenutils.Cl100kBase) + streamResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(msg.Content, applicationPromptConfig.PromptConfigData.ModelType) if tokenizationErr != nil { log.Err(tokenizationErr).Msg("failed to get prompt token count") } diff --git a/services/api-gateway/connectors/openai/utils.go b/services/api-gateway/connectors/openai/utils.go index f5921dc6..873dcf5f 100644 --- a/services/api-gateway/connectors/openai/utils.go +++ b/services/api-gateway/connectors/openai/utils.go @@ -7,7 +7,6 @@ import ( "github.com/basemind-ai/monorepo/shared/go/datatypes" "github.com/basemind-ai/monorepo/shared/go/db" - "github.com/basemind-ai/monorepo/shared/go/tokenutils" openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" ) @@ -89,7 +88,7 @@ func CreatePromptRequest( return promptRequest, nil } -func GetRequestPromptString(messages []*openaiconnector.OpenAIMessage, encoding tokenutils.Encoding) string { +func GetRequestPromptString(messages []*openaiconnector.OpenAIMessage) string { var promptMessages string for _, message := range messages { promptMessages += *message.Content diff --git a/services/api-gateway/connectors/openai/utils_test.go b/services/api-gateway/connectors/openai/utils_test.go index 882f6e6e..d981c1f7 100644 --- a/services/api-gateway/connectors/openai/utils_test.go +++ b/services/api-gateway/connectors/openai/utils_test.go @@ -2,11 +2,12 @@ package openai_test import ( "fmt" + "testing" + "github.com/basemind-ai/monorepo/e2e/factories" "github.com/basemind-ai/monorepo/services/api-gateway/connectors/openai" "github.com/basemind-ai/monorepo/shared/go/db" "github.com/stretchr/testify/assert" - "testing" openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1" ) @@ -169,4 +170,42 @@ func TestUtils(t *testing.T) { assert.Error(t, err) }) }) + + t.Run("GetRequestPromptString", func(t *testing.T) { + t.Run("returns the request prompt as string", func(t *testing.T) { + floatValue := float32(1) + uintValue := uint32(1) + + expectedModelParameters := &openaiconnector.OpenAIModelParameters{ + Temperature: &floatValue, + TopP: &floatValue, + MaxTokens: &uintValue, + PresencePenalty: &floatValue, + FrequencyPenalty: &floatValue, + } + + systemMessage := "You are a helpful chat bot." + applicationId := "12345" + userInput := "Please write an essay on Dogs." + content := fmt.Sprintf("This is what the user asked for: %s", userInput) + + promptRequest := &openaiconnector.OpenAIPromptRequest{ + Model: openaiconnector.OpenAIModel_OPEN_AI_MODEL_GPT3_5_TURBO_4K, + ApplicationId: &applicationId, + Parameters: expectedModelParameters, + Messages: []*openaiconnector.OpenAIMessage{ + { + Content: &systemMessage, + Role: openaiconnector.OpenAIMessageRole_OPEN_AI_MESSAGE_ROLE_SYSTEM, + }, + { + Content: &content, + Role: openaiconnector.OpenAIMessageRole_OPEN_AI_MESSAGE_ROLE_USER, + }, + }, + } + reqPromptString := openai.GetRequestPromptString(promptRequest.Messages) + assert.Equal(t, "You are a helpful chat bot.\nThis is what the user asked for: Please write an essay on Dogs.", reqPromptString) + }) + }) } diff --git a/shared/go/tokenutils/tokenutils.go b/shared/go/tokenutils/tokenutils.go index 7e226b1b..b102510d 100644 --- a/shared/go/tokenutils/tokenutils.go +++ b/shared/go/tokenutils/tokenutils.go @@ -1,20 +1,19 @@ package tokenutils import ( + "github.com/basemind-ai/monorepo/shared/go/db" "github.com/tiktoken-go/tokenizer" ) -type Encoding = tokenizer.Encoding - -const ( - GPT2Enc Encoding = tokenizer.GPT2Enc - R50kBase Encoding = tokenizer.R50kBase - P50kBase Encoding = tokenizer.P50kBase - P50kEdit Encoding = tokenizer.P50kEdit - Cl100kBase Encoding = tokenizer.Cl100kBase -) +var modelEncodingMap map[db.ModelType]tokenizer.Encoding = map[db.ModelType]tokenizer.Encoding{ + db.ModelTypeGpt35Turbo: tokenizer.Cl100kBase, + db.ModelTypeGpt35Turbo16k: tokenizer.Cl100kBase, + db.ModelTypeGpt4: tokenizer.Cl100kBase, + db.ModelTypeGpt432k: tokenizer.Cl100kBase, +} -func GetPromptTokenCount(prompt string, encoding tokenizer.Encoding) (int, error) { +func GetPromptTokenCount(prompt string, modelType db.ModelType) (int, error) { + encoding := modelEncodingMap[modelType] enc, err := tokenizer.Get(encoding) if err != nil { return -1, err diff --git a/shared/go/tokenutils/tokenutils_test.go b/shared/go/tokenutils/tokenutils_test.go index 76e29e8d..d00cefb8 100644 --- a/shared/go/tokenutils/tokenutils_test.go +++ b/shared/go/tokenutils/tokenutils_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/basemind-ai/monorepo/shared/go/db" "github.com/basemind-ai/monorepo/shared/go/tokenutils" "github.com/stretchr/testify/assert" ) @@ -29,13 +30,13 @@ func TestGetPromptTokenCount(t *testing.T) { for _, testCase := range testCases { t.Run(fmt.Sprintf("Test: %d", testCase.expected), func(t *testing.T) { - actual, _ := tokenutils.GetPromptTokenCount(testCase.input, tokenutils.Cl100kBase) + actual, _ := tokenutils.GetPromptTokenCount(testCase.input, db.ModelTypeGpt35Turbo) assert.Equal(t, testCase.expected, actual) }) } // test for the invalid encodings - tokenCnt, err := tokenutils.GetPromptTokenCount("Hello world!", tokenutils.Encoding("invalid")) + tokenCnt, err := tokenutils.GetPromptTokenCount("Hello world!", db.ModelType("invalid")) assert.NotNil(t, err) assert.Equal(t, -1, tokenCnt) }