Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added timeout to SSHConnector #644

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:
python -m pip install -r docs/requirements.txt
- name: "Build documentation and check for consistency"
env:
CHECKSUM: "b59239241d3529a179df6158271dd00ba7a86e807a37a11ac8e078ad9c377f94"
CHECKSUM: "32fa2a0dd0bbb96a69946d22eebf3bed279697f7a1cac093e7cbad2e7e0edfec"
run: |
cd docs
HASH="$(make checksum | tail -n1)"
Expand Down
5 changes: 5 additions & 0 deletions streamflow/deployment/connector/schemas/ssh.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@
"description": "Perform a strict validation of the host SSH keys (and return exception if key is not recognized as valid)",
"default": true
},
"connectTimeout": {
"type": "integer",
"description": "Time (in seconds) to wait for establish the connection. When an attempt fails, the time to wait is increased.",
"default": 30
},
"dataTransferConnection": {
"oneOf": [
{
Expand Down
24 changes: 18 additions & 6 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def _get_connection(
port=port,
tunnel=await self._get_connection(config.tunnel),
username=config.username,
connect_timeout=config.connect_timeout * (self.connection_attempts + 1),
)

def _get_param_from_file(self, file_path: str):
Expand Down Expand Up @@ -114,13 +115,14 @@ async def get_connection(self) -> asyncssh.SSHClientConnection:
self._connecting = True
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
except (ConnectionError, asyncssh.Error, asyncio.TimeoutError) as err:
await self.close()
raise
if isinstance(err, asyncio.TimeoutError):
raise asyncio.TimeoutError(
f"The SSH connection attempt to {self.get_hostname()} took too long."
)
else:
raise
finally:
self._connect_event.set()
else:
Expand Down Expand Up @@ -213,6 +215,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
ConnectionError,
ConnectionLost,
DisconnectError,
asyncio.TimeoutError,
) as exc:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Expand Down Expand Up @@ -336,6 +339,7 @@ def __init__(
self,
check_host_key: bool,
client_keys: MutableSequence[str],
connect_timeout: int,
hostname: str,
password_file: str | None,
ssh_key_passphrase_file: str | None,
Expand All @@ -344,6 +348,7 @@ def __init__(
):
self.check_host_key: bool = check_host_key
self.client_keys: MutableSequence[str] = client_keys
self.connect_timeout: int = connect_timeout
self.hostname: str = hostname
self.password_file: str | None = password_file
self.ssh_key_passphrase_file: str | None = ssh_key_passphrase_file
Expand All @@ -359,6 +364,7 @@ def __init__(
nodes: MutableSequence[Any],
username: str | None = None,
checkHostKey: bool = True,
connectTimeout: int = 30,
dataTransferConnection: str | MutableMapping[str, Any] | None = None,
file: str | None = None,
maxConcurrentSessions: int = 10,
Expand Down Expand Up @@ -399,6 +405,7 @@ def __init__(
template_map=services_map,
)
self.checkHostKey: bool = checkHostKey
self.connect_timeout: int = connectTimeout
self.passwordFile: str | None = passwordFile
self.maxConcurrentSessions: int = maxConcurrentSessions
self.maxConnections: int = maxConnections
Expand Down Expand Up @@ -543,6 +550,11 @@ def _get_config(self, node: str | MutableMapping[str, Any]):
if "tunnel" in node
else self.tunnel if hasattr(self, "tunnel") else None
),
connect_timeout=(
node["connect_timeout"]
if "connect_timeout" in node
else self.connect_timeout
),
)

def _get_ssh_client_process(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def test_schema_generation():
"""Check that the `streamflow schema` command generates a correct JSON Schema."""
assert (
hashlib.sha256(SfSchema().dump("v1.0", False).encode()).hexdigest()
== "f8e3f739678510fc34afe419b215b54d1467d84ee6433fbb0c107bc30eb1f062"
== "bed6608171b77a8d7665532a6ea2405f53e9bab45c6d7719e052856eeff0f6fb"
)
assert (
hashlib.sha256(SfSchema().dump("v1.0", True).encode()).hexdigest()
== "b91f949c055e3f5de305751540725eeba7e1a6deb1082c11bca3c6e7cfa09929"
== "7ccfaf9c38100ed943ebc3b57dbb3edfe7e2512e4784d87858c4dd470970768b"
)


Expand Down
Loading