diff --git a/.gitignore b/.gitignore index 66fd13c..aedf692 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.dll *.so *.dylib +*.idea # Test binary, built with `go test -c` *.test diff --git a/executor_structures.go b/executor_structures.go new file mode 100644 index 0000000..153f1ce --- /dev/null +++ b/executor_structures.go @@ -0,0 +1,25 @@ +package sshExecutor + +// Ssh term sizes +const ( + xTermHeight = 80 + xTermWidth = 40 +) + +// Params for scp file sending +type FilePathParams struct { + RootFolder string + FolderName string + FolderRights string + FileName string + FileRights string + Content string +} + +// Params for ssh connection +type ConnectParams struct { + Host string + Port string + User string + Psw string +} diff --git a/executor_test.go b/executor_test.go new file mode 100644 index 0000000..b02136e --- /dev/null +++ b/executor_test.go @@ -0,0 +1,598 @@ +package sshExecutor + +import ( + errorLib "github.com/NGRsoftlab/error-lib" + + "fmt" + "github.com/stretchr/testify/assert" + "os" + "testing" + "time" +) + +type testCase struct { + name string + inputOpt []string // not used + inputData interface{} + outputData interface{} + failError error + mustFail bool +} + +// ok values for tests +// TODO: change to real creds for ok tests (!) +const okHost, okPort, okUser, okPsw = "127.0.0.1", "22", "test", "test" +const connTimeout, cmdTimeout = time.Second * 30, time.Second * 30 + +///////////////////////////////////////////////// +func TestLocalExecContext(t *testing.T) { + t.Parallel() + + type testInfo struct { + command string + params []string + timeout time.Duration + } + + testCases := []*testCase{ + { + name: "bad command", + inputData: testInfo{ + command: "hhhh", + params: []string{"test"}, + timeout: cmdTimeout, + }, + failError: errorLib.GlobalErrors.ErrSshCommands(), + mustFail: true, + }, + { + name: "ok case (same windows&linux)", + inputData: testInfo{ + command: "arp", + params: []string{"-a"}, + timeout: cmdTimeout, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + output, err := LocalExecContext((tc.inputData).(testInfo).timeout, + (tc.inputData).(testInfo).command, + (tc.inputData).(testInfo).params...) + + tt.Log("out:", string(output)) + + if tc.mustFail { + assert.Equal(tt, tc.failError, err) + assert.Error(tt, err) + return + } + if !assert.NoError(tt, err) { + return + } + }) + } +} + +///////////////////////////////////////////////// +func TestGetConnection(t *testing.T) { + t.Parallel() + + type testInfo struct { + connParams ConnectParams + timeout time.Duration + } + + testCases := []*testCase{ + { + name: "bad empty auth data", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: "", + Psw: "", + }, + timeout: connTimeout, + }, + failError: errorLib.GlobalErrors.ErrBadAuthData(), + mustFail: true, + }, + { + name: "bad host connection", + inputData: testInfo{ + connParams: ConnectParams{ + Host: "999.999.999.999", + Port: "22", + User: okUser, + Psw: okPsw, + }, + timeout: connTimeout, + }, + failError: errorLib.GlobalErrors.ErrBadIpOrPort(), + mustFail: true, + }, + { + name: "bad username/password connection", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: "___", + Psw: "___", + }, + timeout: connTimeout, + }, + failError: errorLib.GlobalErrors.ErrBadAuthData(), + mustFail: true, + }, + { + name: "bad port connection (not free, not ssh port)", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: "8123", + User: "___", + Psw: "___", + }, + timeout: connTimeout, + }, + failError: errorLib.GlobalErrors.ErrConnectionTimeout(), + mustFail: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + _, err := GetSshConnection((tc.inputData).(testInfo).connParams, (tc.inputData).(testInfo).timeout) + + if tc.mustFail { + assert.Equal(tt, tc.failError, err) + assert.Error(tt, err) + return + } + if !assert.NoError(tt, err) { + return + } + }) + } +} + +///////////////////////////////////////////////// +func TestGetSudoCommandsWithoutErrOut(t *testing.T) { + t.Parallel() + + type testInfo struct { + connParams ConnectParams + commands []string + } + + testCases := []*testCase{ + { + name: "bad empty auth data", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: "", + Psw: "", + }, + commands: []string{}, + }, + failError: errorLib.GlobalErrors.ErrBadAuthData(), + mustFail: true, + }, + { + name: "bad ip", + inputData: testInfo{ + connParams: ConnectParams{ + Host: "999.999.99.99", + Port: okPort, + User: okUser, + Psw: okPsw, + }, + commands: []string{}, + }, + failError: errorLib.GlobalErrors.ErrBadIpOrPort(), + mustFail: true, + }, + { + name: "bad port (not ssh or not free)", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: "8123", + User: okUser, + Psw: okPsw, + }, + commands: []string{}, + }, + failError: errorLib.GlobalErrors.ErrConnectionTimeout(), + mustFail: true, + }, + { + name: "bad command", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + commands: []string{"adgsgsdgdfsjn earyery34"}, + }, + failError: errorLib.GlobalErrors.ErrSshCommands(), + mustFail: true, + }, + { + name: "bad endless command", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + commands: []string{"sudo journalctl -fu test_test_test"}, + }, + failError: errorLib.GlobalErrors.ErrConnectionTimeout(), + mustFail: true, + }, + { + name: "ok case", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + commands: []string{"sudo arp -a"}, + }, + }, + { + name: "ok case (no commands)", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + commands: []string{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + output, err := GetSudoCommandsOutWithoutErr((tc.inputData).(testInfo).connParams, + connTimeout, + cmdTimeout, + (tc.inputData).(testInfo).commands...) + + tt.Log("out:", string(output)) + + if tc.mustFail { + assert.Equal(tt, tc.failError, err) + assert.Error(tt, err) + return + } + if !assert.NoError(tt, err) { + return + } + }) + } +} + +///////////////////////////////////////////////// +func TestGetCommandOutWithErr(t *testing.T) { + t.Parallel() + + type testInfo struct { + connParams ConnectParams + kill chan *os.Signal + command string + } + + testCases := []*testCase{ + { + name: "bad empty auth data", + + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: "", + Psw: "", + }, + kill: make(chan *os.Signal), + command: "", + }, + failError: errorLib.GlobalErrors.ErrBadAuthData(), + mustFail: true, + }, + { + name: "bad ip", + inputData: testInfo{ + connParams: ConnectParams{ + Host: "999.999.99.99", + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "", + }, + failError: errorLib.GlobalErrors.ErrBadIpOrPort(), + mustFail: true, + }, + { + name: "bad port (not ssh or not free)", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: "8123", + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "", + }, + failError: errorLib.GlobalErrors.ErrConnectionTimeout(), + mustFail: true, + }, + { + name: "bad command", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "hhhhhhhhhh", + }, + failError: errorLib.GlobalErrors.ErrSshCommands(), + mustFail: true, + }, + { + name: "bad endless command", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "journalctl -fu test_test", + }, + failError: errorLib.GlobalErrors.ErrConnectionTimeout(), + mustFail: true, + }, + { + name: "ok case", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "arp -a", + }, + }, + { + name: "ok case (empty command)", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + output, errOut, duration, err := GetCommandOutWithErr((tc.inputData).(testInfo).connParams, + (tc.inputData).(testInfo).kill, + connTimeout, + cmdTimeout, + (tc.inputData).(testInfo).command) + + tt.Log("out:", string(output)) + tt.Log("errOut:", string(errOut)) + tt.Log("duration:", duration) + + if tc.mustFail { + assert.Equal(tt, tc.failError, err) + assert.Error(tt, err) + return + } + if !assert.NoError(tt, err) { + return + } + }) + } +} + +///////////////////////////////////////////////// +func TestSendFileWithScp(t *testing.T) { + t.Parallel() + + type testInfo struct { + connParams ConnectParams + kill chan *os.Signal + filePathParams FilePathParams + } + + testCases := []*testCase{ + { + name: "bad empty auth data", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: "", + Psw: "", + }, + kill: make(chan *os.Signal), + filePathParams: FilePathParams{ + RootFolder: "", + FolderName: "", + FileName: "", + FolderRights: "755", + FileRights: "777", + Content: "", + }, + }, + failError: errorLib.GlobalErrors.ErrBadAuthData(), + mustFail: true, + }, + { + name: "bad root path", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + filePathParams: FilePathParams{ + RootFolder: "@@@", + FolderName: "", + FileName: "", + FolderRights: "755", + FileRights: "777", + Content: "", + }, + }, + failError: errorLib.GlobalErrors.ErrSshCommands(), + mustFail: true, + }, + { + name: "ok case", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + filePathParams: FilePathParams{ + RootFolder: "/home", + FolderName: "", + FileName: "test.txt", + FolderRights: "755", + FileRights: "777", + Content: "hi guys", + }, + }, + }, + { + name: "ok case new folder", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + filePathParams: FilePathParams{ + RootFolder: "/home", + FolderName: "test_scp", + FileName: "test.txt", + FolderRights: "755", + FileRights: "777", + Content: "hi guys", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + duration, err := SendFileWithScp((tc.inputData).(testInfo).connParams, + (tc.inputData).(testInfo).kill, + connTimeout, + cmdTimeout, + (tc.inputData).(testInfo).filePathParams) + + tt.Log("duration:", duration) + + if tc.mustFail { + assert.Equal(tt, tc.failError, err) + assert.Error(tt, err) + return + } + if !assert.NoError(tt, err) { + return + } + }) + } +} + +///////////////////////////////////////////////// +func TestGetCommandOutWithErr2(t *testing.T) { + t.Parallel() + + type testInfo struct { + connParams ConnectParams + kill chan *os.Signal + command string + } + + testCases := []*testCase{ + { + name: "ok?", + inputData: testInfo{ + connParams: ConnectParams{ + Host: okHost, + Port: okPort, + User: okUser, + Psw: okPsw, + }, + kill: make(chan *os.Signal), + command: "python3" + " " + fmt.Sprintf("%v%v/%v", "/home", + "test", "mg8-----.py"), + }, + mustFail: true, + failError: errorLib.GlobalErrors.ErrSshCommands(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + output, errOut, duration, err := GetCommandOutWithErr((tc.inputData).(testInfo).connParams, + (tc.inputData).(testInfo).kill, + connTimeout, + cmdTimeout, + (tc.inputData).(testInfo).command) + + tt.Log("out:", string(output)) + tt.Log("errOut:", string(errOut)) + tt.Log("duration:", duration) + + if tc.mustFail { + assert.Equal(tt, tc.failError, err) + assert.Error(tt, err) + return + } + if !assert.NoError(tt, err) { + return + } + }) + } +} diff --git a/executor_utils.go b/executor_utils.go new file mode 100644 index 0000000..bd5e6da --- /dev/null +++ b/executor_utils.go @@ -0,0 +1,217 @@ +package sshExecutor + +import ( + errorLib "github.com/NGRsoftlab/error-lib" + "github.com/NGRsoftlab/ngr-logging" + + "bufio" + "context" + "golang.org/x/crypto/ssh" + "io" + "net" + "os" + "os/exec" + "time" +) + +// There are some standard scenarios of using ssh or scp +///////////////////////////////////////////////////////////////////// + +// Recovering from panic +func recoverExecutor() { + if r := recover(); r != nil { + logging.Logger.Warning("executor recovered from: ", r) + } +} + +// StdPipe to text +func scanPipe(r io.Reader) string { + scanner := bufio.NewScanner(r) + var scanningText string + for scanner.Scan() { + scanningText = scanningText + "\n" + scanner.Text() + } + return scanningText +} + +// Get ssh client config (ssh.ClientConfig) +func makeSshClientConfig(user, password string, timeout time.Duration) *ssh.ClientConfig { + return &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ + ssh.Password(password), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, + Timeout: timeout, + } +} + +// Get ssh.TerminalModes obj +func makeSshTerminalModes() ssh.TerminalModes { + return ssh.TerminalModes{ + ssh.ECHO: 0, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } +} + +///////////////////////////////////////////////////////////////////// + +// Local execution command with context +func LocalExecContext(timeout time.Duration, command string, params ...string) (output []byte, err error) { + logging.Logger.Debug(command, " ::: ", params) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, command, params...) + output, err = cmd.Output() + if err != nil { + logging.Logger.Warning("Error local exec: ", err, " ::: ", string(output)) + return output, errorLib.GlobalErrors.ErrSshCommands() + } + return +} + +// Get result sudo (!) commands... output from ssh connection +func GetSudoCommandsOutWithoutErr(connParams ConnectParams, + timeoutConn, timeoutCmd time.Duration, commands ...string) (output []byte, err error) { + conn, err := GetSshConnection(connParams, timeoutConn) + if err != nil { + logging.Logger.Error(err) + return output, err + } + + defer func() { + err := conn.Close() + if err != nil { + logging.Logger.Warning("bad conn close: ", err) + } + }() + + ctx, cancel := context.WithTimeout( + context.Background(), + timeoutCmd) + + go func(ctx context.Context) { + defer cancel() + output, err = conn.SendSudoCommandsWithoutErrOut(commands...) + }(ctx) + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + logging.Logger.Error("ssh sudo commands timeout") + return output, errorLib.GlobalErrors.ErrConnectionTimeout() + case context.Canceled: + logging.Logger.Info("ssh conn canceled by timeout") + } + } + + if err != nil { + logging.Logger.Error("ssh sudo commands error: ", err) + return output, errorLib.GlobalErrors.ErrSshCommands() + } + + return output, nil +} + +// Get result command output (with errOut) from ssh connection +func GetCommandOutWithErr(connParams ConnectParams, kill chan *os.Signal, + timeoutConn, timeoutCmd time.Duration, command string) (output []byte, errOutput []byte, duration time.Duration, err error) { + defer func() { + kill = nil + }() + + conn, err := GetSshConnection(connParams, timeoutConn) + if err != nil { + logging.Logger.Error(err) + return output, errOutput, 0, err + } + + defer func() { + err := conn.Close() + if err != nil { + logging.Logger.Warning("bad conn close: ", err) + } + }() + + ctx, cancel := context.WithTimeout( + context.Background(), + timeoutCmd) + + go func(ctx context.Context) { + defer cancel() + output, errOutput, duration, err = conn.SendOneCommandWithErrOut(kill, command) + }(ctx) + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + logging.Logger.Error("ssh command timeout") + kill <- &os.Kill + return output, errOutput, duration, errorLib.GlobalErrors.ErrConnectionTimeout() + case context.Canceled: + logging.Logger.Info("ssh conn canceled by timeout") + } + } + + if err != nil { + logging.Logger.Error("ssh command error: ", err) + return output, []byte(err.Error()), duration, errorLib.GlobalErrors.ErrSshCommands() + } + + return output, errOutput, duration, nil +} + +// Get result command output (with errOut) from ssh connection +func SendFileWithScp(connParams ConnectParams, kill chan *os.Signal, + timeoutConn, timeoutCmd time.Duration, pathParams FilePathParams) (time.Duration, error) { + defer func() { + kill = nil + }() + + conn, err := GetSshConnection(connParams, timeoutConn) + if err != nil { + logging.Logger.Error(err) + return 0, err + } + + defer func() { + err := conn.Close() + if err != nil { + logging.Logger.Warning("bad conn close: ", err) + } + }() + + ctx, cancel := context.WithTimeout( + context.Background(), + timeoutCmd) + + var duration time.Duration + + go func(ctx context.Context) { + defer cancel() + duration, err = conn.SendScpFile(kill, pathParams) + }(ctx) + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + logging.Logger.Error("ssh command timeout") + kill <- &os.Kill + return duration, errorLib.GlobalErrors.ErrConnectionTimeout() + case context.Canceled: + logging.Logger.Info("ssh conn canceled by timeout") + } + } + + if err != nil { + logging.Logger.Error("ssh command error: ", err) + return duration, errorLib.GlobalErrors.ErrSshCommands() + } + + return duration, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..bd132d5 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/NGRsoftlab/ssh-executor + +go 1.13 + +require ( + github.com/NGRsoftlab/error-lib v1.0.0 + github.com/NGRsoftlab/ngr-logging v1.0.0 + github.com/sirupsen/logrus v1.8.1 // indirect + github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210503195802-e9a32991a82e +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..56cb76e --- /dev/null +++ b/go.sum @@ -0,0 +1,29 @@ +github.com/NGRsoftlab/error-lib v1.0.0 h1:s8e+QLbr5e4S3VDHAy8HDATXKqWpqYtQiUD9SialEWw= +github.com/NGRsoftlab/error-lib v1.0.0/go.mod h1:RbbAZ5CPZziD++EecZFG2fIbh+K1VIWpxNrg/TAziL4= +github.com/NGRsoftlab/ngr-logging v1.0.0 h1:Yp42kvw/bofZ6xXC5jPlPx1HNabZQY9cvzGnB0earJY= +github.com/NGRsoftlab/ngr-logging v1.0.0/go.mod h1:99kZ+XwSK7rKRitmhZvqOdYnPf9Qepywt3zjJJcJDME= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20210503195802-e9a32991a82e h1:8foAy0aoO5GkqCvAEJ4VC4P3zksTg4X4aJCDpZzmgQI= +golang.org/x/crypto v0.0.0-20210503195802-e9a32991a82e/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ssh_connection.go b/ssh_connection.go new file mode 100644 index 0000000..618aed2 --- /dev/null +++ b/ssh_connection.go @@ -0,0 +1,274 @@ +package sshExecutor + +import ( + errorLib "github.com/NGRsoftlab/error-lib" + "github.com/NGRsoftlab/ngr-logging" + + "bufio" + "context" + "fmt" + "golang.org/x/crypto/ssh" + "io" + "os" + "strings" + "sync" + "time" +) + +// Ssh connection struct +type Connection struct { + *ssh.Client + password string +} + +// Getting ssh connection +func GetSshConnection(connParams ConnectParams, timeout time.Duration) (connection *Connection, err error) { + sshConfig := makeSshClientConfig(connParams.User, connParams.Psw, timeout) + + ctx, cancel := context.WithTimeout( + context.Background(), + timeout) + + var connClient *ssh.Client + + go func(ctx context.Context) { + defer cancel() + connClient, err = ssh.Dial("tcp", fmt.Sprintf("%v:%v", connParams.Host, connParams.Port), sshConfig) + if err != nil { + logging.Logger.Error("bad ssh conn: ", err) + } + }(ctx) + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + logging.Logger.Error("ssh conn timeout") + return nil, errorLib.GlobalErrors.ErrConnectionTimeout() + case context.Canceled: + logging.Logger.Info("ssh conn canceled by timeout") + } + } + + if err != nil || connClient == nil { + if strings.Contains(err.Error(), "unable to authenticate") { + return nil, errorLib.GlobalErrors.ErrBadAuthData() + } else { + return nil, errorLib.GlobalErrors.ErrBadIpOrPort() + } + } + + return &Connection{connClient, connParams.Psw}, nil +} + +// Connection checking for sudo password ask (with recovery, be careful) +func (conn *Connection) SendSudoPassword(in io.WriteCloser, out io.Reader, output *[]byte) { + // recovery + defer recoverExecutor() + + var ( + line string + r = bufio.NewReader(out) + ) + for { + b, err := r.ReadByte() + if err != nil { + break + } + + *output = append(*output, b) + + if b == byte('\n') { + line = "" + continue + } + + line += string(b) + + if strings.HasPrefix(line, "[sudo] password for ") && strings.HasSuffix(line, ": ") { + _, err = in.Write([]byte(conn.password + "\n")) + if err != nil { + break + } + } + } +} + +// Sending many commands (may be with SUDO, without strErr output) +func (conn *Connection) SendSudoCommandsWithoutErrOut(commands ...string) ([]byte, error) { + session, err := conn.NewSession() + if err != nil { + logging.Logger.Error(err) + return nil, err + } + + err = session.RequestPty("xterm", xTermHeight, xTermWidth, makeSshTerminalModes()) + if err != nil { + logging.Logger.Error(err) + return nil, err + } + + in, err := session.StdinPipe() + if err != nil { + logging.Logger.Error(err) + return nil, err + } + + out, err := session.StdoutPipe() + if err != nil { + logging.Logger.Error(err) + return nil, err + } + + var output []byte + + go conn.SendSudoPassword(in, out, &output) + + commandsString := strings.Join(commands, "; ") + _, err = session.Output(commandsString) + if err != nil { + return nil, err + } + + return output, nil +} + +// Sending one command (no SUDO, with strErr output, with killChan) +func (conn *Connection) SendOneCommandWithErrOut(kill chan *os.Signal, command string) ([]byte, []byte, time.Duration, error) { + start := time.Now() + + session, err := conn.NewSession() + if err != nil { + logging.Logger.Error("Error session ssh: ", err) + return nil, nil, time.Since(start), err + } + + if err := session.RequestPty("xterm", xTermHeight, xTermWidth, makeSshTerminalModes()); err != nil { + _ = session.Close() + logging.Logger.Error("Error terminal: ", err) + return nil, nil, time.Since(start), err + } + + stdout, err := session.StdoutPipe() + if err != nil { + logging.Logger.Error("Error session stdout: ", err) + return nil, nil, time.Since(start), err + } + + stderr, err := session.StderrPipe() + if err != nil { + logging.Logger.Error("Error session stderr: ", err) + return nil, nil, time.Since(start), err + } + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + defer wg.Done() + err = session.Run(command) + + logging.Logger.Info("SEND SIGNAL", err) + if kill != nil { + kill <- &os.Kill + } + return + }() + + if kill != nil { + select { + case <-kill: + _ = session.Close() + kill = nil + } + } + + wg.Wait() + + stdOut := []byte(scanPipe(stdout)) + stdErr := []byte(scanPipe(stderr)) + + if err != nil { + return stdOut, stdErr, time.Since(start), err + } + return stdOut, stdErr, time.Since(start), nil +} + +// Sending file (content = file string content) to rootFolder/folderName/fileName with scp +func (conn *Connection) SendScpFile(kill chan *os.Signal, pathParams FilePathParams) (time.Duration, error) { + start := time.Now() + + session, err := conn.NewSession() + if err != nil { + logging.Logger.Error("Error session ssh: ", err) + return time.Since(start), err + } + + defer func() { _ = session.Close() }() + + go func() { + w, _ := session.StdinPipe() + defer func() { + if w != nil { + _ = w.Close() + } + }() + + if pathParams.FolderName != "" { + _, err = fmt.Fprintln(w, fmt.Sprintf("D0%s", pathParams.FolderRights), + 0, pathParams.FolderName) // mkdir (d-dir) + if err != nil { + logging.Logger.Error("Error session scp: ", err) + return + } + } + + if pathParams.FileName != "" { + _, err = fmt.Fprintln(w, fmt.Sprintf("C0%s", pathParams.FileRights), + len(pathParams.Content), pathParams.FileName) // touch file (c-create) + if err != nil { + logging.Logger.Error("Error session scp: ", err) + return + } + _, err = fmt.Fprint(w, pathParams.Content) // add content to file + if err != nil { + logging.Logger.Error("Error session scp: ", err) + return + } + } + + _, err = fmt.Fprint(w, "\x00") // transfer end with \x00 + if err != nil { + logging.Logger.Error("Error session scp: ", err) + return + } + }() + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + defer wg.Done() + err = session.Run("/usr/bin/scp -tr " + pathParams.RootFolder) + logging.Logger.Info("SEND SIGNAL", err) + if kill != nil { + kill <- &os.Kill + } + return + }() + + if kill != nil { + select { + case <-kill: + _ = session.Close() + kill = nil + } + } + + wg.Wait() + + if err != nil { + return time.Since(start), err + } + return time.Since(start), nil +}