Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: add openai support #105

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ dist
# IDE files
.vscode
.idea
__debug_bin
__debug_bin
test
138 changes: 138 additions & 0 deletions cli/gpt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package cli

import (
"bufio"
"context"
"fmt"
"io"
"log"
"os"
"strconv"
"strings"

"github.com/pkg/errors"
openai "github.com/replicatedhq/sbctl/pkg/openai"
sbctlutil "github.com/replicatedhq/sbctl/pkg/util"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

func GptCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gpt",
Short: "GPT",
Long: `GPT`,
SilenceUsage: true,
SilenceErrors: false,
PreRun: func(cmd *cobra.Command, args []string) {
viper.BindPFlags(cmd.Flags())
},
RunE: func(cmd *cobra.Command, args []string) error {
v := viper.GetViper()

apiKey, err := sbctlutil.GetOpenAIKey()
if err != nil {
return errors.Wrap(err, "failed to get openai key")
}

logger := createLogger("interaction.log")

client := openai.New(apiKey, 5000)

githubIssueURL := v.GetString("issue")

if githubIssueURL == "" {
return errors.New("issue is required")
} else {
issueContent := sbctlutil.GetGithubIssue()

// fmt.Println(string(resp.Choices[0].Message.Content))
scanner := bufio.NewScanner(os.Stdin)

for {
fmt.Print("You: ")
if !scanner.Scan() {
break
}
input := scanner.Text()
escapedInput := strconv.Quote(input)
logger.Printf(input)
if strings.ToLower(input) == "exit" {
break
}

if strings.ToLower(input) == "github:" {
resp, err := client.GetKubectlCmd(issueContent)
if err != nil {
return errors.Wrap(err, "failed to get kubectl command")
}

message := resp.Choices[0].Message
if message.Content != "" {
logAndPrintResponse(logger, message.Content)
}
} else {
resp, err := client.GetKubectlCmd(escapedInput)
if err != nil {
return errors.Wrap(err, "failed to get kubectl command")
}

message := resp.Choices[0].Message
if message.Content != "" {
logAndPrintResponse(logger, message.Content)
}
}
}
}

return nil
},
}
cmd.Flags().StringP("issue", "i", "", "github issue URL")
return cmd
}

func logAndPrintResponse(logger *log.Logger, message string) {
logger.Printf("AI: %s\n", message)
fmt.Printf("AI: \n%s\n", message)
}

func createLogger(logFile string) *log.Logger {
file, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal(err)
}
return log.New(file, "", log.LstdFlags)
}

func Chat(ctx context.Context) error {
apiKey, err := sbctlutil.GetOpenAIKey()
if err != nil {
return errors.Wrap(err, "failed to get openai key")
}

client := openai.New(apiKey, 5000)
chat := client.Chat(ctx)

err1 := make(chan error)
err2 := make(chan error)

go func() {
if _, err := io.Copy(os.Stdout, chat); err != nil {
err1 <- fmt.Errorf("gpt: couldn't copy: %w", err)
}
}()
go func() {
if _, err := io.Copy(chat, os.Stdin); err != nil {
err2 <- fmt.Errorf("gpt: couldn't copy: %w", err)
}
}()
select {
case <-ctx.Done():
return nil
case err := <-err1:
return err
case err := <-err2:
return err
}
}
1 change: 1 addition & 0 deletions cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func RootCmd() *cobra.Command {

cmd.AddCommand(ServeCmd())
cmd.AddCommand(ShellCmd())
cmd.AddCommand(GptCmd())

viper.BindPFlags(cmd.Flags())

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ require (
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/sashabaranov/go-openai v1.10.1
github.com/spf13/afero v1.9.3 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/sagikazarmark/crypt v0.9.0/go.mod h1:RnH7sEhxfdnPm1z+XMgSLjWTEIjyK4z2dw6+4vHTMuo=
github.com/sashabaranov/go-openai v1.10.1 h1:6WyHJaNzF266VaEEuW6R4YW+Ei0wpMnqRYPGK7fhuhQ=
github.com/sashabaranov/go-openai v1.10.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/seccomp/libseccomp-golang v0.9.1/go.mod h1:GbW5+tmTXfcxTToHLXlScSlAvWlF4P2Ca7zGrPiEpWo=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
Expand Down
133 changes: 133 additions & 0 deletions pkg/openai/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package openai

import (
"context"
"fmt"
"io"
"log"
"time"

sbctlutil "github.com/replicatedhq/sbctl/pkg/util"
"github.com/sashabaranov/go-openai"
)

type Client struct {
openai.Client
maxTokens int
}

type rw struct {
client *Client
ctx context.Context
cancel context.CancelFunc
pipeReader *io.PipeReader
pipeWriter *io.PipeWriter
}

// New returns a new Client.
func New(key string, maxTokens int) *Client {

client := openai.NewClient(key)
return &Client{
Client: *client,
maxTokens: maxTokens,
}
}

func (c *Client) GetKubectlCmd(issueContent string) (openai.ChatCompletionResponse, error) {
userMessageForModel := fmt.Sprintf("Github Issue, generate five kubectl command for debuging: ####%s####", issueContent)

resp, err := c.Client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "You are an kubenete expert that use kubectl command to debug issues. The github issue description will be delimited with #### characters.",
},
{
Role: openai.ChatMessageRoleUser,
Content: userMessageForModel,
},
},
},
)

if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
return openai.ChatCompletionResponse{}, err
}

return resp, nil
}

// Chat creates a new chat session.
func (c *Client) Chat(ctx context.Context) io.ReadWriter {
ctx, cancel := context.WithCancel(ctx)
rd, wr := io.Pipe()
return &rw{
client: c,
ctx: ctx,
cancel: cancel,
pipeReader: rd,
pipeWriter: wr,
}
}

// Read reads from the chat.
func (r *rw) Read(b []byte) (n int, err error) {
if r.ctx.Err() != nil {
return 0, r.ctx.Err()
}
return r.pipeReader.Read(b)
}

// Write writes to the chat.
func (r *rw) Write(b []byte) (n int, err error) {
if r.ctx.Err() != nil {
return 0, r.ctx.Err()
}

request := sbctlutil.GetGithubIssue()
var completion openai.ChatCompletionResponse
for {
// Generate completion
completion, err = r.client.GetKubectlCmd(request)
if err != nil {
// Rate limit error, wait and try again
log.Println("openai: too many requests, waiting for 30 seconds...")
select {
case <-time.After(30 * time.Second):
case <-r.ctx.Done():
return 0, r.ctx.Err()
}
continue
}
if err != nil {
return 0, fmt.Errorf("openai: couldn't generate completion: %w", err)
}
break
}

if len(completion.Choices) == 0 {
return 0, fmt.Errorf("openai: no choices")
}
response := completion.Choices[0].Message.Content
log.Printf("openai: request tokens %d", completion.Usage.TotalTokens)

// Write response to pipe
go func() {
response := response + "\n"
if _, err := r.pipeWriter.Write([]byte(response)); err != nil {
log.Println(fmt.Errorf("openai: failed to write to pipe: %w", err))
}
}()
return len(b), nil
}

// Close closes the chat.
func (r *rw) Close() error {
r.cancel()
return r.pipeReader.Close()
}
25 changes: 25 additions & 0 deletions pkg/util/support-bundle.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package util

import (
"fmt"
"io/ioutil"
"log"
"os"
)

var (
// sbResourceCompatibilityMap
sbResourceCompatibilityMap = map[string]string{
Expand All @@ -19,3 +26,21 @@ func GetSBCompatibleResourceName(resource string) string {
}
return resource
}

func GetOpenAIKey() (string, error) {
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
return "", fmt.Errorf("OPENAPI_API_KEY environment variable is not set")
}
return apiKey, nil
}

func GetGithubIssue() string {
filePath := "github.yaml"
content, err := ioutil.ReadFile(filePath)
if err != nil {
log.Fatalf("Failed to read file: %v", err)
}

return string(content)
}