Skip to content

Commit

Permalink
Merge pull request #12 from Deleh/feature/handle_broken_ssh_sessions
Browse files Browse the repository at this point in the history
Handle Broken SSH Sessions
  • Loading branch information
fmessmer authored Aug 1, 2024
2 parents 301997d + 109ee8d commit 42eac21
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ sessions:
- `os: string` {linux, windows, online} (mandatory): Operating system of the host. Hosts of type `online` will only be checked for network availability.
- `user: string` (optional, default: robot): User on the host machine used for sending SSH commands.
- `port: int` (optional, default: none): The port that is checked to determine if a service on the host is already up.
- `ssh_port: int` (optional, default: `22`): The port that is used for SSH connections to the host.
- `hostname: string` (optional, default: `<key>` of `hosts` section): The hostname of the host PC.
- `check_nfs: bool` (optional, default: true): Whether the host should be checked for NFS status. Only supported on Linux.

Expand Down
42 changes: 28 additions & 14 deletions robmuxinator/robmuxinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,16 @@ def format(self, record):
DEFAULT_USER = "robot"
DEFAULT_HOST = socket.gethostname()
DEFAULT_PORT = None # default port None disables port check
DEFAULT_SSH_PORT = 22


class SSHClient:
"""Handle commands over ssh tunnel"""

def __init__(self, user, hostname, port=DEFAULT_PORT):
def __init__(self, user, hostname, port=DEFAULT_SSH_PORT):
self._user = user
self._hostname = hostname

if port is not None:
self._port = port
else:
self._port = 22
self._port = port

# check if user has sudo privileges
self._sudo_user = True if os.getuid() == 0 else False
Expand Down Expand Up @@ -207,6 +204,7 @@ def send_cmd(self, cmd, wait_for_exit_status=True, get_pty=False):
return returncode, stdout, stderr
except Exception as e:
logger.error("{}".format(e))
self.ssh_cli = None
return 1, None, None

def send_keys(self, session_name, keys):
Expand Down Expand Up @@ -279,12 +277,13 @@ def __init__(self, hostname, user, port=DEFAULT_PORT):
self._hostname = hostname
self._user = user
self._port = port
self._ssh_port = DEFAULT_SSH_PORT

def get_hostname(self):
return self._hostname

def get_port(self):
return self._port
def get_ssh_port(self):
return self._ssh_port

def shutdown(self, timeout=30):
pass
Expand Down Expand Up @@ -337,9 +336,10 @@ def wait_for_host(self, timeout=60):
class LinuxHost(Host):
"""Handle linux hosts"""

def __init__(self, hostname, user, port=DEFAULT_PORT, check_nfs=True):
def __init__(self, hostname, user, port=DEFAULT_PORT, ssh_port=DEFAULT_SSH_PORT, check_nfs=True):
super().__init__(hostname, user, port)
self._ssh_client = SSHClient(user, hostname, port)
self._ssh_port = ssh_port
self._ssh_client = SSHClient(user, hostname, ssh_port)
self._check_nfs = check_nfs

def shutdown(self, timeout=60):
Expand Down Expand Up @@ -391,7 +391,16 @@ def wait_for_host(self, timeout=60):
)
return False

logger.info(" {} nfs is up".format(self._hostname))
# Send an initial 'echo' command to verify if sending commands works
logger.info(" {} sending initial command".format(self._hostname))
ret = 1
while ret != 0:
ret, _, _ = self._ssh_client.send_cmd("echo", get_pty=True)
if ret != 0:
logger.error(" {} sending initial command failed".format(self._hostname))
time.sleep(0.25)
logger.info(" {} sending initial command succeeded".format(self._hostname))

return True


Expand Down Expand Up @@ -800,6 +809,11 @@ def main():
else:
port = DEFAULT_PORT

if "ssh_port" in yaml_hosts[key]:
ssh_port = yaml_hosts[key]["ssh_port"]
else:
ssh_port = DEFAULT_SSH_PORT

if "check_nfs" in yaml_hosts[key]:
check_nfs = yaml_hosts[key]["check_nfs"]
else:
Expand All @@ -819,7 +833,7 @@ def main():

if yaml_hosts[key]["os"].lower().strip() == "linux":
hosts[key] = LinuxHost(
hostname, user, port, check_nfs
hostname, user, port, ssh_port, check_nfs
)
elif yaml_hosts[key]["os"].lower().strip() == "windows":
hosts[key] = WindowsHost(hostname, user, port)
Expand Down Expand Up @@ -886,7 +900,7 @@ def main():
if key in args.sessions:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()),
key,
yaml_sessions[key],
envs
Expand All @@ -895,7 +909,7 @@ def main():
else:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_ssh_port()),
key,
yaml_sessions[key],
envs
Expand Down

0 comments on commit 42eac21

Please sign in to comment.