Skip to content

Commit

Permalink
add waiting retries
Browse files Browse the repository at this point in the history
  • Loading branch information
azliu0 committed Oct 31, 2024
1 parent 24e031b commit 0ae93cd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
27 changes: 23 additions & 4 deletions modal/file_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright Modal Labs 2024
import asyncio
import io
from typing import AsyncIterator, List, Optional, Union, cast

from grpclib.exceptions import GRPCError, StreamTerminatedError

from modal_proto import api_pb2

from ._utils.async_utils import synchronize_api
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES
from .client import _Client
from .exception import FilesystemExecutionError

Expand Down Expand Up @@ -86,10 +90,25 @@ async def _consume_output(self, exec_id: str) -> AsyncIterator[Optional[bytes]]:

async def _wait(
    self,
    exec_id: str,
    *,
    max_retries: int = 10,
    retry_delay: float = 1.0,
) -> Union[bytes, str]:
    """Collect the full output of a filesystem exec, retrying transient stream failures.

    Output is read from `self._consume_output(exec_id)` until it yields `None`,
    which is treated as the end-of-output marker. If the stream finishes without
    that marker, it is simply opened again (this does not consume a retry).

    Args:
        exec_id: Identifier passed through to `self._consume_output`.
        max_retries: Total number of stream failures tolerated over the whole
            call; the budget is shared by both exception types and is never
            replenished.
        retry_delay: Seconds to sleep before reopening the stream after a
            retryable `GRPCError`. A `StreamTerminatedError` is retried
            immediately.

    Returns:
        The concatenated output as `bytes` when `self._binary` is truthy,
        otherwise the same bytes decoded as UTF-8.

    Raises:
        GRPCError: If the status is not in `RETRYABLE_GRPC_STATUS_CODES`, or
            the retry budget is exhausted.
        StreamTerminatedError: If the retry budget is exhausted.
    """
    # Gather chunks and join once at the end; repeatedly extending an
    # immutable `bytes` object would copy the whole buffer on every chunk.
    chunks: List[bytes] = []
    retries_remaining = max_retries
    completed = False
    while not completed:
        try:
            async for data in self._consume_output(exec_id):
                if data is None:
                    completed = True
                    break
                chunks.append(data)
        except GRPCError as exc:
            if retries_remaining <= 0 or exc.status not in RETRYABLE_GRPC_STATUS_CODES:
                raise
            retries_remaining -= 1
            await asyncio.sleep(retry_delay)
        except StreamTerminatedError:
            if retries_remaining <= 0:
                raise
            retries_remaining -= 1
        # NOTE(review): chunks received before a failure are kept, so this
        # assumes a reopened stream does not replay earlier output -- confirm
        # against the server-side behavior.

    output = b"".join(chunks)
    if self._binary:
        return output
    return output.decode("utf-8")
Expand Down
28 changes: 28 additions & 0 deletions test/file_io_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright Modal Labs 2024
import pytest

from grpclib import Status
from grpclib.exceptions import GRPCError

from modal.file_io import FileIO
from modal_proto import api_pb2

Expand Down Expand Up @@ -313,3 +316,28 @@ def test_invalid_mode(servicer, client):
for mode in invalid_modes:
with pytest.raises(ValueError):
FileIO.create("/test.txt", mode, client, "task-123")


def test_client_retry(servicer, client):
    """A read still returns the written content when the output stream first fails with UNAVAILABLE."""
    payload = "foo\nbar\nbaz\n"
    failures_left = 5

    async def container_filesystem_exec_get_output(servicer, stream):
        nonlocal failures_left
        request = await stream.recv_message()
        is_read = request.exec_id == READ_EXEC_ID

        if is_read and failures_left > 0:
            # Simulate a transient outage on the read's output stream.
            failures_left -= 1
            raise GRPCError(Status.UNAVAILABLE, "test")

        if is_read:
            data_batch = api_pb2.FilesystemRuntimeOutputBatch(output=[payload.encode()])
            await stream.send_message(data_batch)

        # Every non-failing request is terminated with an EOF batch.
        eof_batch = api_pb2.FilesystemRuntimeOutputBatch(eof=True)
        await stream.send_message(eof_batch)

    with servicer.intercept() as interceptor:
        interceptor.set_responder("ContainerFilesystemExec", container_filesystem_exec)
        interceptor.set_responder("ContainerFilesystemExecGetOutput", container_filesystem_exec_get_output)

        handle = FileIO.create("/test.txt", "w+", client, "task-123")
        handle.write(payload)
        result = handle.read()
        assert result == payload
        handle.close()

0 comments on commit 0ae93cd

Please sign in to comment.