-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.go
154 lines (126 loc) · 3.01 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package sshmgr
import (
"errors"
"io"
"net"
"sync/atomic"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// errClientClosed is returned by (*Client).Close when the client holds no
// outstanding references, i.e. when it has already been fully released.
var errClientClosed = errors.New("client already closed")
// Client is a shared, reference-counted ssh client managed by this package.
//
// atime must stay the first field. It is accessed with the 64-bit
// sync/atomic functions, which on 386, ARM and 32-bit MIPS require a
// 64-bit-aligned operand. Only the first word of an allocated struct is
// guaranteed to be aligned that way; placed after the pointer and interface
// fields, atime would sit at offset 12 on those platforms and the atomic
// calls would panic.
type Client struct {
	// atime is a Unix timestamp (seconds) refreshed by updateAtime; accessed atomically.
	atime int64
	// refs counts the outstanding references to this client; accessed atomically.
	refs int32

	client *ssh.Client
	conn   net.Conn
}
// Close releases one reference to the client, notifying the manager that it
// can be removed once no references remain. It returns errClientClosed if
// the client currently holds no references (e.g. Close was called more times
// than the client was acquired).
//
// The zero check and the decrement form a single compare-and-swap. With a
// separate load and add, two concurrent Close calls on a client holding one
// reference could both pass the check and drive the count negative.
func (c *Client) Close() (err error) {
	for {
		n := c.refcount()
		if n <= 0 {
			return errClientClosed
		}
		// Refresh atime before the count drops, so the timestamp is already
		// current if this release takes refs to zero.
		// NOTE(review): presumably the manager reads atime for idle eviction — confirm.
		c.updateAtime()
		if atomic.CompareAndSwapInt32(&c.refs, n, n-1) {
			return nil
		}
		// Another goroutine changed refs between the load and the CAS; retry.
	}
}
// CombinedOutput opens a fresh session on the remote host, exports every
// entry of envs into it, runs cmd, and returns everything the command wrote
// to stdout and stderr together. The session is closed before returning.
func (c *Client) CombinedOutput(cmd string, envs map[string]string) (data []byte, err error) {
	session, err := c.client.NewSession()
	if err != nil {
		return nil, err
	}
	defer session.Close()

	for key, value := range envs {
		if setErr := session.Setenv(key, value); setErr != nil {
			return nil, setErr
		}
	}
	return session.CombinedOutput(cmd)
}
// readCloser couples a reader of a running command's output with the ssh
// session producing it, so that closing the reader also closes the session.
type readCloser struct {
	io.Reader
	s *ssh.Session
	// pr, when non-nil, is the read end of the pipe the session writes into.
	// Closing it makes blocked writes on the other end fail instead of
	// hanging when the caller stops reading early.
	pr *io.PipeReader
}

// Close closes the output pipe (if any) and then the ssh session, returning
// the session's Close error.
func (r readCloser) Close() (err error) {
	if r.pr != nil {
		// (*io.PipeReader).Close always returns nil.
		_ = r.pr.Close()
	}
	return r.s.Close()
}

// CombinedReader is like CombinedOutput but streams the output instead of
// buffering it. It starts cmd in a new session, after exporting envs into
// it. It returns a reader yielding the command's stdout and stderr
// interleaved in the order they arrive.
//
// The command keeps running in the background. After all output has been
// consumed, Read returns io.EOF if the command succeeded. Otherwise it
// returns the error from (*ssh.Session).Wait (e.g. *ssh.ExitError).
// Callers must Close the returned reader, which also closes the session.
//
// Both streams are written into a single io.Pipe rather than read through
// StdoutPipe/StderrPipe. stdout and stderr share a fixed amount of ssh
// buffering, so reading them one after the other, or waiting for the
// command to exit before reading, can block the remote command and
// deadlock once enough output is produced.
func (c *Client) CombinedReader(cmd string, envs map[string]string) (reader io.ReadCloser, err error) {
	s, err := c.client.NewSession()
	if err != nil {
		return nil, err
	}
	// Release the session on every failure path; on success the returned
	// reader owns it.
	defer func() {
		if err != nil {
			_ = s.Close()
		}
	}()

	for name, value := range envs {
		if err = s.Setenv(name, value); err != nil {
			return nil, err
		}
	}

	pr, pw := io.Pipe()
	// io.PipeWriter gates parallel Writes sequentially, so the session's
	// stdout and stderr copiers can safely share it.
	s.Stdout = pw
	s.Stderr = pw
	if err = s.Start(cmd); err != nil {
		_ = pw.Close()
		return nil, err
	}
	go func() {
		// Hand the command's outcome to the reader: a nil error becomes
		// io.EOF, anything else is returned by Read once output is drained.
		_ = pw.CloseWithError(s.Wait())
	}()
	return readCloser{Reader: pr, s: s, pr: pr}, nil
}
// incr atomically takes one more reference on the client and returns the
// updated reference count.
func (c *Client) incr() (r int32) {
	const delta int32 = 1
	r = atomic.AddInt32(&c.refs, delta)
	return r
}
// decr atomically drops one reference on the client and returns the updated
// reference count. It does not itself prevent the count from going negative.
func (c *Client) decr() (r int32) {
	const delta int32 = -1
	r = atomic.AddInt32(&c.refs, delta)
	return r
}
// updateAtime atomically stores the current wall-clock time, in Unix
// seconds, as the client's access time.
func (c *Client) updateAtime() {
	now := time.Now().Unix()
	atomic.StoreInt64(&c.atime, now)
}
// refcount atomically reads the number of references currently held on the
// client.
func (c *Client) refcount() (r int32) {
	r = atomic.LoadInt32(&c.refs)
	return r
}
// SFTPClient is an sftp client tied to a managed Client. It embeds
// *sftp.Client, so sftp operations can be called on it directly.
type SFTPClient struct {
	*sftp.Client
	client *Client
}

// Close releases the managed ssh Client this SFTPClient is tied to,
// notifying the manager.
// NOTE(review): only the managed Client is released here; the embedded
// *sftp.Client is not closed — confirm its lifetime is handled elsewhere.
func (s *SFTPClient) Close() (err error) {
	err = s.client.Close()
	return err
}

// Lock overrides the embedded client's channel lock with a no-op.
func (s *SFTPClient) Lock() {}

// Unlock overrides the embedded client's channel unlock with a no-op.
func (s *SFTPClient) Unlock() {}
// newClient builds a Client for config. It dials NetAddr:Port over TCP
// (port "22" when config.Port is empty), bounded by config.DialTimeout. It
// then performs the ssh handshake with the client configuration produced by
// newSSHClientConfig. The returned Client holds no references yet.
//
// The address is assembled with net.JoinHostPort, so a bare IPv6 literal
// such as "::1" becomes "[::1]:22" instead of the invalid "::1:22". An
// already-bracketed NetAddr such as "[::1]" yields the same address as
// before. Hostnames and IPv4 addresses are unaffected.
func newClient(config ClientConfig) (client *Client, err error) {
	if config.Port == "" {
		config.Port = "22"
	}
	host := config.NetAddr
	// Strip existing brackets so JoinHostPort does not double them.
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, config.Port)

	sshConfig, err := newSSHClientConfig(config)
	if err != nil {
		return nil, err
	}
	conn, err := net.DialTimeout("tcp", addr, config.DialTimeout)
	if err != nil {
		return nil, err
	}
	c, chans, reqs, err := ssh.NewClientConn(conn, addr, sshConfig)
	if err != nil {
		return nil, err
	}
	client = &Client{
		conn:   conn,
		client: ssh.NewClient(c, chans, reqs),
	}
	return client, nil
}