Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add token utils to count the total tokens utilized #49

Merged
merged 6 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions e2e/factories/factories.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package factories
import (
"context"
"encoding/json"

openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1"
"github.com/basemind-ai/monorepo/shared/go/datatypes"
"github.com/basemind-ai/monorepo/shared/go/db"
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ require (
github.com/rs/zerolog v1.31.0
github.com/sethvargo/go-envconfig v0.9.0
github.com/stretchr/testify v1.8.4
github.com/tiktoken-go/tokenizer v0.1.0
golang.org/x/sync v0.3.0
google.golang.org/grpc v1.58.2
google.golang.org/protobuf v1.31.0
Expand All @@ -42,6 +43,7 @@ require (
github.com/containerd/continuity v0.4.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.9.0 // indirect
github.com/docker/cli v24.0.6+incompatible // indirect
github.com/docker/docker v24.0.6+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI=
github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/cli v24.0.6+incompatible h1:fF+XCQCgJjjQNIMjzaSmiKJSCcfcXb3TWTcc7GAneOY=
github.com/docker/cli v24.0.6+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE=
Expand Down Expand Up @@ -249,6 +251,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tiktoken-go/tokenizer v0.1.0 h1:c1fXriHSR/NmhMDTwUDLGiNhHwTV+ElABGvqhCWLRvY=
github.com/tiktoken-go/tokenizer v0.1.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg=
github.com/vmihailenco/go-tinylfu v0.2.2 h1:H1eiG6HM36iniK6+21n9LLpzx1G9R3DJa2UjUjbynsI=
github.com/vmihailenco/go-tinylfu v0.2.2/go.mod h1:CutYi2Q9puTxfcolkliPq4npPuofg9N9t8JVrjzwa3Q=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
Expand Down
36 changes: 33 additions & 3 deletions services/api-gateway/connectors/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package openai
import (
"context"
"errors"
"github.com/basemind-ai/monorepo/shared/go/datatypes"
"fmt"
"io"

"github.com/basemind-ai/monorepo/shared/go/datatypes"
"github.com/basemind-ai/monorepo/shared/go/tokenutils"

openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
Expand Down Expand Up @@ -48,7 +51,20 @@ func (c *Client) RequestPrompt(
if requestErr != nil {
return "", requestErr
}
// TODO handle token related logic here by using the response token properties.

K-A-I-L-A-S-H marked this conversation as resolved.
Show resolved Hide resolved
// Count the total number of tokens utilized for openai prompt
reqPromptString := GetRequestPromptString(promptRequest.Messages)
promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, applicationPromptConfig.PromptConfigData.ModelType)
if tokenizationErr != nil {
log.Err(tokenizationErr).Msg("failed to get prompt token count")
}

promptResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(response.Content, applicationPromptConfig.PromptConfigData.ModelType)
if tokenizationErr != nil {
log.Err(tokenizationErr).Msg("failed to get prompt token count")
}

log.Debug().Msg(fmt.Sprintf("Total tokens utilized: Request-%d, Response-%d", promptReqTokenCount, promptResTokenCount))
return response.Content, nil
}

Expand Down Expand Up @@ -78,17 +94,31 @@ func (c *Client) RequestStream(
return
}

reqPromptString := GetRequestPromptString(promptRequest.Messages)
promptReqTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(reqPromptString, applicationPromptConfig.PromptConfigData.ModelType)
if tokenizationErr != nil {
log.Err(tokenizationErr).Msg("failed to get prompt token count")
}
log.Debug().Msg(fmt.Sprintf("Total tokens utilized for request prompt - %d", promptReqTokenCount))

var promptResTokenCount int

for {
msg, receiveErr := stream.Recv()
if receiveErr != nil {
if !errors.Is(receiveErr, io.EOF) {
errChannel <- receiveErr
}
close(contentChannel)
log.Debug().Msg(fmt.Sprintf("Tokens utilized for streaming response-%d", promptResTokenCount))
return
}

// TODO handle token related logic here
streamResTokenCount, tokenizationErr := tokenutils.GetPromptTokenCount(msg.Content, applicationPromptConfig.PromptConfigData.ModelType)
if tokenizationErr != nil {
log.Err(tokenizationErr).Msg("failed to get prompt token count")
}
promptResTokenCount += streamResTokenCount
contentChannel <- msg.Content
}
}
12 changes: 11 additions & 1 deletion services/api-gateway/connectors/openai/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package openai
import (
"encoding/json"
"fmt"
"strings"

"github.com/basemind-ai/monorepo/shared/go/datatypes"
"github.com/basemind-ai/monorepo/shared/go/db"
"strings"

openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1"
)
Expand Down Expand Up @@ -86,3 +87,12 @@ func CreatePromptRequest(

return promptRequest, nil
}

// GetRequestPromptString joins the content of the given OpenAI messages
// into a single newline-separated prompt string, suitable for token
// counting.
//
// Messages that are nil or have a nil Content are skipped: Content is an
// optional proto field (*string) and dereferencing it unchecked would
// panic on an unset value. The returned string carries no trailing
// newline.
func GetRequestPromptString(messages []*openaiconnector.OpenAIMessage) string {
	var builder strings.Builder
	for _, message := range messages {
		// Guard against unset optional proto fields.
		if message == nil || message.Content == nil {
			continue
		}
		builder.WriteString(*message.Content)
		builder.WriteString("\n")
	}
	return strings.TrimRight(builder.String(), "\n")
}
41 changes: 40 additions & 1 deletion services/api-gateway/connectors/openai/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package openai_test

import (
"fmt"
"testing"

"github.com/basemind-ai/monorepo/e2e/factories"
"github.com/basemind-ai/monorepo/services/api-gateway/connectors/openai"
"github.com/basemind-ai/monorepo/shared/go/db"
"github.com/stretchr/testify/assert"
"testing"

openaiconnector "github.com/basemind-ai/monorepo/gen/go/openai/v1"
)
Expand Down Expand Up @@ -169,4 +170,42 @@ func TestUtils(t *testing.T) {
assert.Error(t, err)
})
})

t.Run("GetRequestPromptString", func(t *testing.T) {
t.Run("returns the request prompt as string", func(t *testing.T) {
floatValue := float32(1)
uintValue := uint32(1)

expectedModelParameters := &openaiconnector.OpenAIModelParameters{
Temperature: &floatValue,
TopP: &floatValue,
MaxTokens: &uintValue,
PresencePenalty: &floatValue,
FrequencyPenalty: &floatValue,
}

systemMessage := "You are a helpful chat bot."
applicationId := "12345"
userInput := "Please write an essay on Dogs."
content := fmt.Sprintf("This is what the user asked for: %s", userInput)

promptRequest := &openaiconnector.OpenAIPromptRequest{
Model: openaiconnector.OpenAIModel_OPEN_AI_MODEL_GPT3_5_TURBO_4K,
ApplicationId: &applicationId,
Parameters: expectedModelParameters,
Messages: []*openaiconnector.OpenAIMessage{
{
Content: &systemMessage,
Role: openaiconnector.OpenAIMessageRole_OPEN_AI_MESSAGE_ROLE_SYSTEM,
},
{
Content: &content,
Role: openaiconnector.OpenAIMessageRole_OPEN_AI_MESSAGE_ROLE_USER,
},
},
}
reqPromptString := openai.GetRequestPromptString(promptRequest.Messages)
assert.Equal(t, "You are a helpful chat bot.\nThis is what the user asked for: Please write an essay on Dogs.", reqPromptString)
})
})
}
24 changes: 24 additions & 0 deletions shared/go/tokenutils/tokenutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package tokenutils

import (
"github.com/basemind-ai/monorepo/shared/go/db"
"github.com/tiktoken-go/tokenizer"
)

// modelEncodingMap maps each supported OpenAI model type to the tiktoken
// encoding used to tokenize its prompts. All GPT-3.5 / GPT-4 variants
// listed here use the cl100k_base encoding; a model type absent from this
// map yields the zero Encoding, which tokenizer.Get rejects with an error.
var modelEncodingMap map[db.ModelType]tokenizer.Encoding = map[db.ModelType]tokenizer.Encoding{
db.ModelTypeGpt35Turbo: tokenizer.Cl100kBase,
db.ModelTypeGpt35Turbo16k: tokenizer.Cl100kBase,
db.ModelTypeGpt4: tokenizer.Cl100kBase,
db.ModelTypeGpt432k: tokenizer.Cl100kBase,
}

// GetPromptTokenCount returns the number of tokens the given prompt
// encodes to under the tokenizer encoding associated with modelType.
//
// It returns -1 and a non-nil error when the model type has no known
// encoding (tokenizer.Get rejects the zero Encoding) or when encoding
// the prompt fails.
func GetPromptTokenCount(prompt string, modelType db.ModelType) (int, error) {
	encoding := modelEncodingMap[modelType]
	enc, err := tokenizer.Get(encoding)
	if err != nil {
		return -1, err
	}

	// Encode returns (ids, tokens, err); the original discarded the
	// error, which could silently report a zero count on failure.
	ids, _, encodeErr := enc.Encode(prompt)
	if encodeErr != nil {
		return -1, encodeErr
	}
	return len(ids), nil
}
42 changes: 42 additions & 0 deletions shared/go/tokenutils/tokenutils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package tokenutils_test
K-A-I-L-A-S-H marked this conversation as resolved.
Show resolved Hide resolved

import (
"fmt"
"testing"

"github.com/basemind-ai/monorepo/shared/go/db"
"github.com/basemind-ai/monorepo/shared/go/tokenutils"
"github.com/stretchr/testify/assert"
)

// TestGetPromptTokenCount verifies token counting for supported model
// types and the error path for an unknown model type.
func TestGetPromptTokenCount(t *testing.T) {
	testCases := []struct {
		input    string
		expected int
	}{
		{
			input:    "Hello world!",
			expected: 3,
		},
		{
			input:    "",
			expected: 0,
		},
		{
			input:    "Goodbye world!",
			expected: 4,
		},
	}

	for _, testCase := range testCases {
		// Name subtests by input rather than by expected count: counts
		// can collide across cases and say nothing about what is tested.
		t.Run(fmt.Sprintf("counts tokens for %q", testCase.input), func(t *testing.T) {
			// Assert the error instead of discarding it, so a tokenizer
			// failure is surfaced rather than hidden behind a count of 0.
			actual, err := tokenutils.GetPromptTokenCount(testCase.input, db.ModelTypeGpt35Turbo)
			assert.NoError(t, err)
			assert.Equal(t, testCase.expected, actual)
		})
	}

	t.Run("returns an error for an unknown model type", func(t *testing.T) {
		tokenCnt, err := tokenutils.GetPromptTokenCount("Hello world!", db.ModelType("invalid"))
		assert.Error(t, err)
		assert.Equal(t, -1, tokenCnt)
	})
}
Loading