diff --git a/smart_open/ssh.py b/smart_open/ssh.py index 86008246..1dc73cbe 100644 --- a/smart_open/ssh.py +++ b/smart_open/ssh.py @@ -111,7 +111,10 @@ def parse_uri(uri_as_string): def open_uri(uri, mode, transport_params): - smart_open.utils.check_kwargs(open, transport_params) + # `connect_kwargs` is a legitimate param *only* for sftp, so this filters it out of validation + # (otherwise every call with this present complains it's not valid) + params_to_validate = {k: v for k, v in transport_params.items() if k != 'connect_kwargs'} + smart_open.utils.check_kwargs(open, params_to_validate) parsed_uri = parse_uri(uri) uri_path = parsed_uri.pop('uri_path') parsed_uri.pop('scheme') @@ -266,6 +269,11 @@ def open(path, mode='r', host=None, user=None, password=None, port=None, transpo for attempt in range(attempts): try: ssh = _SSH[key] + # Validate that the cached connection is still an active connection + # and if not, refresh the connection + if not ssh.get_transport().active: + ssh.close() + ssh = _SSH[key] = _connect_ssh(host, user, port, password, transport_params) except KeyError: ssh = _SSH[key] = _connect_ssh(host, user, port, password, transport_params)