Skip to content

Commit

Permalink
This commit adds the timeout in the asyncssh.connect() call. The us…
Browse files Browse the repository at this point in the history
…er chooses the time to wait using the `connectTimeout` option. If the `retry` is enabled, the time to wait is increased at each attempt.

Updated documentation checksum because the `connectTimeout` option was added in the `SSHConnector` schema
  • Loading branch information
LanderOtto committed Jan 22, 2025
1 parent eb19e9a commit 2ba9820
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:
python -m pip install -r docs/requirements.txt
- name: "Build documentation and check for consistency"
env:
CHECKSUM: "b59239241d3529a179df6158271dd00ba7a86e807a37a11ac8e078ad9c377f94"
CHECKSUM: "32fa2a0dd0bbb96a69946d22eebf3bed279697f7a1cac093e7cbad2e7e0edfec"
run: |
cd docs
HASH="$(make checksum | tail -n1)"
Expand Down
5 changes: 5 additions & 0 deletions streamflow/deployment/connector/schemas/ssh.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@
"description": "Perform a strict validation of the host SSH keys (and return exception if key is not recognized as valid)",
"default": true
},
"connectTimeout": {
"type": "integer",
"description": "Time (in seconds) to wait for establish the connection. When an attempt fails, the time to wait is increased.",
"default": 30
},
"dataTransferConnection": {
"oneOf": [
{
Expand Down
24 changes: 18 additions & 6 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def _get_connection(
port=port,
tunnel=await self._get_connection(config.tunnel),
username=config.username,
connect_timeout=config.connect_timeout * (self.connection_attempts + 1),
)

def _get_param_from_file(self, file_path: str):
Expand Down Expand Up @@ -114,13 +115,14 @@ async def get_connection(self) -> asyncssh.SSHClientConnection:
self._connecting = True
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
except (ConnectionError, asyncssh.Error, asyncio.TimeoutError) as err:
await self.close()
raise
if isinstance(err, asyncio.TimeoutError):
raise asyncio.TimeoutError(
f"The SSH connection attempt to {self.get_hostname()} took too long."
)
else:
raise
finally:
self._connect_event.set()
else:
Expand Down Expand Up @@ -213,6 +215,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
ConnectionError,
ConnectionLost,
DisconnectError,
asyncio.TimeoutError,
) as exc:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
Expand Down Expand Up @@ -336,6 +339,7 @@ def __init__(
self,
check_host_key: bool,
client_keys: MutableSequence[str],
connect_timeout: int,
hostname: str,
password_file: str | None,
ssh_key_passphrase_file: str | None,
Expand All @@ -344,6 +348,7 @@ def __init__(
):
self.check_host_key: bool = check_host_key
self.client_keys: MutableSequence[str] = client_keys
self.connect_timeout: int = connect_timeout
self.hostname: str = hostname
self.password_file: str | None = password_file
self.ssh_key_passphrase_file: str | None = ssh_key_passphrase_file
Expand All @@ -359,6 +364,7 @@ def __init__(
nodes: MutableSequence[Any],
username: str | None = None,
checkHostKey: bool = True,
connectTimeout: int = 30,
dataTransferConnection: str | MutableMapping[str, Any] | None = None,
file: str | None = None,
maxConcurrentSessions: int = 10,
Expand Down Expand Up @@ -399,6 +405,7 @@ def __init__(
template_map=services_map,
)
self.checkHostKey: bool = checkHostKey
self.connect_timeout: int = connectTimeout
self.passwordFile: str | None = passwordFile
self.maxConcurrentSessions: int = maxConcurrentSessions
self.maxConnections: int = maxConnections
Expand Down Expand Up @@ -543,6 +550,11 @@ def _get_config(self, node: str | MutableMapping[str, Any]):
if "tunnel" in node
else self.tunnel if hasattr(self, "tunnel") else None
),
connect_timeout=(
node["connect_timeout"]
if "connect_timeout" in node
else self.connect_timeout
),
)

def _get_ssh_client_process(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def test_schema_generation():
"""Check that the `streamflow schema` command generates a correct JSON Schema."""
assert (
hashlib.sha256(SfSchema().dump("v1.0", False).encode()).hexdigest()
== "f8e3f739678510fc34afe419b215b54d1467d84ee6433fbb0c107bc30eb1f062"
== "bed6608171b77a8d7665532a6ea2405f53e9bab45c6d7719e052856eeff0f6fb"
)
assert (
hashlib.sha256(SfSchema().dump("v1.0", True).encode()).hexdigest()
== "b91f949c055e3f5de305751540725eeba7e1a6deb1082c11bca3c6e7cfa09929"
== "7ccfaf9c38100ed943ebc3b57dbb3edfe7e2512e4784d87858c4dd470970768b"
)


Expand Down

0 comments on commit 2ba9820

Please sign in to comment.