Skip to content

Commit

Permalink
chore(backend): Refactor copy_from method to be more generic (#4278)
Browse files Browse the repository at this point in the history
  • Loading branch information
amanape authored Oct 10, 2024
1 parent 62a58ea commit 36e304b
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 20 deletions.
39 changes: 38 additions & 1 deletion openhands/runtime/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@

import argparse
import asyncio
import io
import os
import re
import shutil
import subprocess
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
from zipfile import ZipFile

import pexpect
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException
Expand Down Expand Up @@ -760,6 +763,40 @@ async def upload_file(
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.get('/download_files')
async def download_file(path: str):
logger.info('Downloading files')
try:
if not os.path.isabs(path):
raise HTTPException(
status_code=400, detail='Path must be an absolute path'
)

if not os.path.exists(path):
raise HTTPException(status_code=404, detail='File not found')

with tempfile.TemporaryFile() as temp_zip:
with ZipFile(temp_zip, 'w') as zipf:
for root, _, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
zipf.write(
file_path, arcname=os.path.relpath(file_path, path)
)
temp_zip.seek(0) # Rewind the file to the beginning after writing
content = temp_zip.read()
# Good for small to medium-sized files. For very large files, streaming directly from the
# file chunks may be more memory-efficient.
zip_stream = io.BytesIO(content)
return StreamingResponse(
content=zip_stream,
media_type='application/zip',
headers={'Content-Disposition': f'attachment; filename={path}.zip'},
)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.get('/alive')
async def alive():
return {'status': 'ok'}
Expand Down
36 changes: 23 additions & 13 deletions openhands/runtime/client/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,19 +558,29 @@ def list_files(self, path: str | None = None) -> list[str]:
except Exception as e:
raise RuntimeError(f'List files operation failed: {str(e)}')

def zip_files_in_sandbox(self) -> bytes:
"""Zips the files in the sandbox and returns the bytes for streaming."""
sandbox_dir = os.getcwd() + self.config.workspace_mount_path_in_sandbox
with tempfile.TemporaryFile() as temp_zip:
with ZipFile(temp_zip, 'w') as zipf:
for root, _, files in os.walk(sandbox_dir):
for file in files:
file_path = os.path.join(root, file)
zipf.write(
file_path, arcname=os.path.relpath(file_path, sandbox_dir)
)
temp_zip.seek(0) # Rewind the file to the beginning after writing
return temp_zip.read()
def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
self._refresh_logs()
try:
params = {'path': path}
response = send_request_with_retry(
self.session,
'GET',
f'{self.api_url}/download_files',
params=params,
stream=True,
timeout=30,
)
if response.status_code == 200:
data = response.content
return data
else:
error_message = response.text
raise Exception(f'Copy operation failed: {error_message}')
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
except Exception as e:
raise RuntimeError(f'Copy operation failed: {str(e)}')

def _is_port_in_use_docker(self, port):
containers = self.docker_client.containers.list()
Expand Down
25 changes: 25 additions & 0 deletions openhands/runtime/remote/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,31 @@ def list_files(self, path: str | None = None) -> list[str]:
except Exception as e:
raise RuntimeError(f'List files operation failed: {str(e)}')

def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
self._wait_until_alive()
try:
params = {'path': path}
response = send_request_with_retry(
self.session,
'GET',
f'{self.runtime_url}/download_files',
params=params,
timeout=30,
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
)
if response.status_code == 200:
return response.content
else:
error_message = response.text
raise Exception(f'Copy operation failed: {error_message}')
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
except Exception as e:
raise RuntimeError(f'Copy operation failed: {str(e)}')

def send_status_message(self, message: str):
"""Sends a status message if the callback function was provided."""
if self.status_message_callback:
Expand Down
4 changes: 2 additions & 2 deletions openhands/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,6 @@ def list_files(self, path: str | None = None) -> list[str]:
raise NotImplementedError('This method is not implemented in the base class.')

@abstractmethod
def zip_files_in_sandbox(self) -> bytes:
"""Zip all files in the sandbox and return the zip file as bytes."""
def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
raise NotImplementedError('This method is not implemented in the base class.')
9 changes: 5 additions & 4 deletions openhands/server/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,11 +781,12 @@ async def security_api(request: Request):

@app.get('/api/zip-directory')
async def zip_current_workspace(request: Request):
logger.info('Zipping workspace')
runtime: Runtime = request.state.session.agent_session.runtime

try:
zip_file_bytes = runtime.zip_files_in_sandbox()
logger.info('Zipping workspace')
runtime: Runtime = request.state.session.agent_session.runtime

path = runtime.config.workspace_mount_path_in_sandbox
zip_file_bytes = runtime.copy_from(path)
zip_stream = io.BytesIO(zip_file_bytes) # Wrap to behave like a file stream
response = StreamingResponse(
zip_stream,
Expand Down
21 changes: 21 additions & 0 deletions tests/runtime/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

# ============================================================================================================================
# Bash-specific tests
Expand Down Expand Up @@ -572,6 +573,26 @@ def test_copy_non_existent_file(temp_dir, box_class):
_close_test_runtime(runtime)


def test_copy_from_directory(temp_dir, box_class):
runtime: Runtime = _load_runtime(temp_dir, box_class)
sandbox_dir = _get_sandbox_folder(runtime)
try:
temp_dir_copy = os.path.join(temp_dir, 'test_dir')
# We need a separate directory, since temp_dir is mounted to /workspace
_create_host_test_dir_with_files(temp_dir_copy)

# Initial state
runtime.copy_to(temp_dir_copy, sandbox_dir, recursive=True)

path_to_copy_from = f'{sandbox_dir}/test_dir'
result = runtime.copy_from(path=path_to_copy_from)

# Result is returned in bytes
assert isinstance(result, bytes)
finally:
_close_test_runtime(runtime)


def test_keep_prompt(box_class, temp_dir):
runtime = _load_runtime(
temp_dir,
Expand Down

0 comments on commit 36e304b

Please sign in to comment.