diff --git a/pkg/ssh/hostkey/callbacks.go b/pkg/ssh/hostkey/callbacks.go index f314794d..689cb0b6 100644 --- a/pkg/ssh/hostkey/callbacks.go +++ b/pkg/ssh/hostkey/callbacks.go @@ -50,7 +50,7 @@ var KnownHostsPathFromEnv = func() (string, bool) { } // KnownHostsFileCallback returns a HostKeyCallback that uses a known hosts file to verify host keys -func KnownHostsFileCallback(path string, permissive bool) (ssh.HostKeyCallback, error) { +func KnownHostsFileCallback(path string, permissive, hash bool) (ssh.HostKeyCallback, error) { if path == "/dev/null" { return InsecureIgnoreHostKeyCallback, nil } @@ -67,13 +67,13 @@ func KnownHostsFileCallback(path string, permissive bool) (ssh.HostKeyCallback, return nil, fmt.Errorf("%w: knownhosts callback: %w", ErrCheckHostKey, err) } - return wrapCallback(hkc, path, permissive), nil + return wrapCallback(hkc, path, permissive, hash), nil } // extends a knownhosts callback to not return an error when the key // is not found in the known_hosts file but instead adds it to the file as new // entry -func wrapCallback(hkc ssh.HostKeyCallback, path string, permissive bool) ssh.HostKeyCallback { +func wrapCallback(hkc ssh.HostKeyCallback, path string, permissive, hash bool) ssh.HostKeyCallback { return ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { mu.Lock() defer mu.Unlock() @@ -99,6 +99,10 @@ func wrapCallback(hkc ssh.HostKeyCallback, path string, permissive bool) ssh.Hos } knownHostsEntry := knownhosts.Normalize(remote.String()) + if hash { + knownHostsEntry = knownhosts.HashHostname(knownHostsEntry) + } + row := knownhosts.Line([]string{knownHostsEntry}, key) row = fmt.Sprintf("%s\n", strings.TrimSpace(row)) diff --git a/ssh.go b/ssh.go index 32b04150..cb76679b 100644 --- a/ssh.go +++ b/ssh.go @@ -245,14 +245,36 @@ func (c *SSH) IsWindows() bool { return c.isWindows } -func knownhostsCallback(path string, permissive bool) (ssh.HostKeyCallback, error) { - cb, err := hostkey.KnownHostsFileCallback(path, permissive) +func knownhostsCallback(path string, permissive, hash bool) (ssh.HostKeyCallback, error) { + cb, err := hostkey.KnownHostsFileCallback(path, permissive, hash) if err != nil { return nil, fmt.Errorf("%w: create host key validator: %w", ErrCantConnect, err) } return cb, nil } +func isPermissive(c *SSH) bool { + if strict := c.getConfigAll("StrictHostkeyChecking"); len(strict) > 0 && strict[0] == "no" { + log.Debugf("%s: StrictHostkeyChecking is set to 'no'", c) + return true + } + + return false +} + +func shouldHash(c *SSH) bool { + var hash bool + if hashKnownHosts := c.getConfigAll("HashKnownHosts"); len(hashKnownHosts) == 1 { + hash := hashKnownHosts[0] == "yes" + if hash { + log.Debugf("%s: HashKnownHosts is set to %q, will hash newly added keys", c, hashKnownHosts[0]) + } else { + log.Debugf("%s: HashKnownHosts is set to %q, won't hash newly added keys", c, hashKnownHosts[0]) + } + } + return hash +} + func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { if c.HostKey != "" { log.Debugf("%s: using host key from config", c) @@ -262,19 +284,15 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { knownHostsMU.Lock() defer knownHostsMU.Unlock() - var permissive bool - strict := c.getConfigAll("StrictHostkeyChecking") - if len(strict) > 0 && strict[0] == "no" { - log.Debugf("%s: StrictHostkeyChecking is set to 'no'", c) - permissive = true - } + permissive := isPermissive(c) + hash := shouldHash(c) if path, ok := hostkey.KnownHostsPathFromEnv(); ok { if path == "" { return hostkey.InsecureIgnoreHostKeyCallback, nil } log.Tracef("%s: using known_hosts file from SSH_KNOWN_HOSTS: %s", c, path) - return knownhostsCallback(path, permissive) + return knownhostsCallback(path, permissive, hash) } var khPath string @@ -295,7 +313,7 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { if khPath != "" { log.Tracef("%s: using known_hosts file from ssh config %s", c, khPath) - return knownhostsCallback(khPath, permissive) + return knownhostsCallback(khPath, permissive, hash) } log.Tracef("%s: using default known_hosts file %s", c, hostkey.DefaultKnownHostsPath) @@ -304,7 +322,7 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { return nil, err } - return knownhostsCallback(defaultPath, permissive) + return knownhostsCallback(defaultPath, permissive, hash) } func (c *SSH) clientConfig() (*ssh.ClientConfig, error) {