diff --git a/.gitignore b/.gitignore index f7db26d..bb29339 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ dist # IDE files .vscode .idea -__debug_bin \ No newline at end of file +__debug_bin +test diff --git a/cli/gpt.go b/cli/gpt.go new file mode 100644 index 0000000..c3e94de --- /dev/null +++ b/cli/gpt.go @@ -0,0 +1,138 @@ +package cli + +import ( + "bufio" + "context" + "fmt" + "io" + "log" + "os" + "strconv" + "strings" + + "github.com/pkg/errors" + openai "github.com/replicatedhq/sbctl/pkg/openai" + sbctlutil "github.com/replicatedhq/sbctl/pkg/util" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func GptCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "gpt", + Short: "GPT", + Long: `GPT`, + SilenceUsage: true, + SilenceErrors: false, + PreRun: func(cmd *cobra.Command, args []string) { + viper.BindPFlags(cmd.Flags()) + }, + RunE: func(cmd *cobra.Command, args []string) error { + v := viper.GetViper() + + apiKey, err := sbctlutil.GetOpenAIKey() + if err != nil { + return errors.Wrap(err, "failed to get openai key") + } + + logger := createLogger("interaction.log") + + client := openai.New(apiKey, 5000) + + githubIssueURL := v.GetString("issue") + + if githubIssueURL == "" { + return errors.New("issue is required") + } else { + issueContent := sbctlutil.GetGithubIssue() + + // fmt.Println(string(resp.Choices[0].Message.Content)) + scanner := bufio.NewScanner(os.Stdin) + + for { + fmt.Print("You: ") + if !scanner.Scan() { + break + } + input := scanner.Text() + escapedInput := strconv.Quote(input) + logger.Printf(input) + if strings.ToLower(input) == "exit" { + break + } + + if strings.ToLower(input) == "github:" { + resp, err := client.GetKubectlCmd(issueContent) + if err != nil { + return errors.Wrap(err, "failed to get kubectl command") + } + + message := resp.Choices[0].Message + if message.Content != "" { + logAndPrintResponse(logger, message.Content) + } + } else { + resp, err := client.GetKubectlCmd(escapedInput) + if err != nil { + return errors.Wrap(err, "failed to get kubectl command") + } + + message := resp.Choices[0].Message + if message.Content != "" { + logAndPrintResponse(logger, message.Content) + } + } + } + } + + return nil + }, + } + cmd.Flags().StringP("issue", "i", "", "github issue URL") + return cmd +} + +func logAndPrintResponse(logger *log.Logger, message string) { + logger.Printf("AI: %s\n", message) + fmt.Printf("AI: \n%s\n", message) +} + +func createLogger(logFile string) *log.Logger { + file, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal(err) + } + return log.New(file, "", log.LstdFlags) +} + +func Chat(ctx context.Context) error { + apiKey, err := sbctlutil.GetOpenAIKey() + if err != nil { + return errors.Wrap(err, "failed to get openai key") + } + + client := openai.New(apiKey, 5000) + chat := client.Chat(ctx) + + err1 := make(chan error) + err2 := make(chan error) + + go func() { + if _, err := io.Copy(os.Stdout, chat); err != nil { + err1 <- fmt.Errorf("gpt: couldn't copy: %w", err) + } + }() + go func() { + if _, err := io.Copy(chat, os.Stdin); err != nil { + err2 <- fmt.Errorf("gpt: couldn't copy: %w", err) + } + }() + select { + case <-ctx.Done(): + return nil + case err := <-err1: + return err + case err := <-err2: + return err + } +} diff --git a/cli/root.go b/cli/root.go index 75a7a36..f1ff4e6 100644 --- a/cli/root.go +++ b/cli/root.go @@ -26,6 +26,7 @@ func RootCmd() *cobra.Command { cmd.AddCommand(ServeCmd()) cmd.AddCommand(ShellCmd()) + cmd.AddCommand(GptCmd()) viper.BindPFlags(cmd.Flags()) diff --git a/go.mod b/go.mod index b35062b..ef35a16 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,7 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect + github.com/sashabaranov/go-openai v1.10.1 github.com/spf13/afero v1.9.3 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect diff --git a/go.sum b/go.sum index 4739913..90d76f3 100644 --- a/go.sum +++ b/go.sum @@ -997,6 +997,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/sagikazarmark/crypt v0.9.0/go.mod h1:RnH7sEhxfdnPm1z+XMgSLjWTEIjyK4z2dw6+4vHTMuo= +github.com/sashabaranov/go-openai v1.10.1 h1:6WyHJaNzF266VaEEuW6R4YW+Ei0wpMnqRYPGK7fhuhQ= +github.com/sashabaranov/go-openai v1.10.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/seccomp/libseccomp-golang v0.9.1/go.mod h1:GbW5+tmTXfcxTToHLXlScSlAvWlF4P2Ca7zGrPiEpWo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= diff --git a/pkg/openai/openai.go b/pkg/openai/openai.go new file mode 100644 index 0000000..b069a4c --- /dev/null +++ b/pkg/openai/openai.go @@ -0,0 +1,133 @@ +package openai + +import ( + "context" + "fmt" + "io" + "log" + "time" + + sbctlutil "github.com/replicatedhq/sbctl/pkg/util" + "github.com/sashabaranov/go-openai" +) + +type Client struct { + openai.Client + maxTokens int +} + +type rw struct { + client *Client + ctx context.Context + cancel context.CancelFunc + pipeReader *io.PipeReader + pipeWriter *io.PipeWriter +} + +// New returns a new Client. +func New(key string, maxTokens int) *Client { + + client := openai.NewClient(key) + return &Client{ + Client: *client, + maxTokens: maxTokens, + } +} + +func (c *Client) GetKubectlCmd(issueContent string) (openai.ChatCompletionResponse, error) { + userMessageForModel := fmt.Sprintf("Github Issue, generate five kubectl command for debuging: ####%s####", issueContent) + + resp, err := c.Client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are an kubenete expert that use kubectl command to debug issues. The github issue description will be delimited with #### characters.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: userMessageForModel, + }, + }, + }, + ) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + return openai.ChatCompletionResponse{}, err + } + + return resp, nil +} + +// Chat creates a new chat session. +func (c *Client) Chat(ctx context.Context) io.ReadWriter { + ctx, cancel := context.WithCancel(ctx) + rd, wr := io.Pipe() + return &rw{ + client: c, + ctx: ctx, + cancel: cancel, + pipeReader: rd, + pipeWriter: wr, + } +} + +// Read reads from the chat. +func (r *rw) Read(b []byte) (n int, err error) { + if r.ctx.Err() != nil { + return 0, r.ctx.Err() + } + return r.pipeReader.Read(b) +} + +// Write writes to the chat. +func (r *rw) Write(b []byte) (n int, err error) { + if r.ctx.Err() != nil { + return 0, r.ctx.Err() + } + + request := sbctlutil.GetGithubIssue() + var completion openai.ChatCompletionResponse + for { + // Generate completion + completion, err = r.client.GetKubectlCmd(request) + if err != nil { + // Rate limit error, wait and try again + log.Println("openai: too many requests, waiting for 30 seconds...") + select { + case <-time.After(30 * time.Second): + case <-r.ctx.Done(): + return 0, r.ctx.Err() + } + continue + } + if err != nil { + return 0, fmt.Errorf("openai: couldn't generate completion: %w", err) + } + break + } + + if len(completion.Choices) == 0 { + return 0, fmt.Errorf("openai: no choices") + } + response := completion.Choices[0].Message.Content + log.Printf("openai: request tokens %d", completion.Usage.TotalTokens) + + // Write response to pipe + go func() { + response := response + "\n" + if _, err := r.pipeWriter.Write([]byte(response)); err != nil { + log.Println(fmt.Errorf("openai: failed to write to pipe: %w", err)) + } + }() + return len(b), nil +} + +// Close closes the chat. +func (r *rw) Close() error { + r.cancel() + return r.pipeReader.Close() +} diff --git a/pkg/util/support-bundle.go b/pkg/util/support-bundle.go index a310e1b..52e2068 100644 --- a/pkg/util/support-bundle.go +++ b/pkg/util/support-bundle.go @@ -1,5 +1,12 @@ package util +import ( + "fmt" + "io/ioutil" + "log" + "os" +) + var ( // sbResourceCompatibilityMap sbResourceCompatibilityMap = map[string]string{ @@ -19,3 +26,21 @@ func GetSBCompatibleResourceName(resource string) string { } return resource } + +func GetOpenAIKey() (string, error) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return "", fmt.Errorf("OPENAPI_API_KEY environment variable is not set") + } + return apiKey, nil +} + +func GetGithubIssue() string { + filePath := "github.yaml" + content, err := ioutil.ReadFile(filePath) + if err != nil { + log.Fatalf("Failed to read file: %v", err) + } + + return string(content) +}