From 5c07a4516f40213aeb9a1e4290e0dd71c5cc785f Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Fri, 29 Sep 2023 11:50:40 +0300 Subject: [PATCH] Add ssh.AuthMethods for passing in custom authentication (#123) * Add ssh.AuthMethods for passing in custom authentication Signed-off-by: Kimmo Lehto * Add a test for passing in a private key Signed-off-by: Kimmo Lehto --------- Signed-off-by: Kimmo Lehto --- cmd/rigtest/rigtest.go | 37 +++++++++++++++++++++------- ssh.go | 55 ++++++++++++++++++++++++++++++++++++++++-- test/test.sh | 11 +++++++++ 3 files changed, 92 insertions(+), 11 deletions(-) diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go index d8bcb33c..b2eee227 100644 --- a/cmd/rigtest/rigtest.go +++ b/cmd/rigtest/rigtest.go @@ -105,6 +105,7 @@ func main() { pwd := flag.String("pass", "", "winrm password") https := flag.Bool("https", false, "use https for winrm") connectOnly := flag.Bool("connect", false, "just connect and quit") + sshKey := flag.String("ssh-private-key", "", "ssh private key") fn := fmt.Sprintf("test_%s.txt", time.Now().Format("20060102150405")) @@ -159,16 +160,34 @@ func main() { var h *Host switch *proto { case "ssh": - h = &Host{ - Connection: rig.Connection{ - SSH: &rig.SSH{ - Address: address, - Port: port, - User: *usr, - KeyPath: kp, - PasswordCallback: passfunc, + if *sshKey != "" { + // test with private key in a string + authM, err := rig.ParseSSHPrivateKey([]byte(*sshKey), rig.DefaultPasswordCallback) + if err != nil { + panic(err) + } + h = &Host{ + Connection: rig.Connection{ + SSH: &rig.SSH{ + Address: address, + Port: port, + User: *usr, + AuthMethods: authM, + }, }, - }, + } + } else { + h = &Host{ + Connection: rig.Connection{ + SSH: &rig.SSH{ + Address: address, + Port: port, + User: *usr, + KeyPath: kp, + PasswordCallback: passfunc, + }, + }, + } } case "winrm": h = &Host{ diff --git a/ssh.go b/ssh.go index fc778c10..b602a700 100644 --- a/ssh.go +++ b/ssh.go @@ -35,7 +35,15 @@ type SSH struct { HostKey string `yaml:"hostKey,omitempty"` Bastion *SSH `yaml:"bastion,omitempty"` PasswordCallback PasswordCallback `yaml:"-"` - name string + + // AuthMethods can be used to pass in a list of ssh.AuthMethod objects + // for example to use a private key from memory: + // ssh.PublicKeys(privateKey) + // For convenience, you can use ParseSSHPrivateKey() to parse a private key: + // authMethods, err := rig.ParseSSHPrivateKey(key, rig.DefaultPassphraseCallback) + AuthMethods []ssh.AuthMethod `yaml:"-"` + + name string isWindows bool knowOs bool @@ -338,7 +346,7 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { return knownhostsCallback(defaultPath, permissive, hash) } -func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { +func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop config := &ssh.ClientConfig{ User: c.User, } @@ -360,6 +368,11 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { } } + if len(c.AuthMethods) > 0 { + log.Tracef("%s: using %d passed-in auth methods", c, len(c.AuthMethods)) + config.Auth = c.AuthMethods + } + for _, keyPath := range c.keyPaths { if am, ok := authMethodCache.Load(keyPath); ok { switch authM := am.(type) { @@ -695,3 +708,41 @@ func (c *SSH) ExecInteractive(cmd string) error { return nil } + +// ParseSSHPrivateKey is a convenience utility to parses a private key and +// return []ssh.AuthMethod to be used in SSH{} AuthMethods field. This +// way you can avoid importing golang.org/x/crypto/ssh in your code +// and handle the passphrase prompt in a callback function. +func ParseSSHPrivateKey(key []byte, callback PasswordCallback) ([]ssh.AuthMethod, error) { + signer, err := ssh.ParsePrivateKey(key) + if err == nil { + return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil + } + var ppErr *ssh.PassphraseMissingError + if !errors.As(err, &ppErr) { + return nil, fmt.Errorf("failed to parse key: %w", err) + } + if callback == nil { + return nil, fmt.Errorf("key is encrypted and no callback provided: %w", err) + } + pass, err := callback() + if err != nil { + return nil, fmt.Errorf("failed to get passphrase: %w", err) + } + signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(pass)) + if err != nil { + return nil, fmt.Errorf("failed to parse key with passphrase: %w", err) + } + return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil +} + +// DefaultPasswordCallback is a default implementation for PasswordCallback +func DefaultPasswordCallback() (string, error) { + fmt.Print("Enter passphrase: ") + pass, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() + if err != nil { + return "", fmt.Errorf("failed to read password: %w", err) + } + return string(pass), nil +} diff --git a/test/test.sh b/test/test.sh index 9e6b83c6..1307b964 100755 --- a/test/test.sh +++ b/test/test.sh @@ -170,6 +170,17 @@ rig_test_key_from_path() { RET=$exit_code } +rig_test_key_from_memory() { + color_echo "- Testing connecting using a key from string" + make create-host + mv .ssh/identity .ssh/identity2 + set +e + ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -ssh-private-key "$(cat .ssh/identity2)" -connect + local exit_code=$? + set -e + RET=$exit_code +} + rig_test_key_from_default_location() { color_echo "- Testing keypath from default location" make create-host