@app.get('/download_files')
async def download_file(path: str):
    """Zip the file or directory at ``path`` and stream the archive back.

    Args:
        path: Absolute path to archive. A directory is walked recursively and
            each file is stored under its path relative to ``path``; a regular
            file is stored as a single member named after the file.

    Returns:
        A ``StreamingResponse`` with media type ``application/zip``.

    Raises:
        HTTPException: 400 if ``path`` is not absolute, 404 if it does not
            exist, 500 if building the archive fails.
    """
    logger.info('Downloading files')

    # Validate before entering the try block: a blanket `except Exception`
    # around these raises would turn the 400/404 into a generic 500.
    if not os.path.isabs(path):
        raise HTTPException(status_code=400, detail='Path must be an absolute path')
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail='File not found')

    try:
        # NOTE: the whole archive is held in memory. Good for small to
        # medium-sized trees; very large ones would need chunked streaming.
        zip_stream = io.BytesIO()
        with ZipFile(zip_stream, 'w') as zipf:
            if os.path.isfile(path):
                # os.walk() yields nothing for a regular file, which would
                # otherwise produce an empty archive.
                zipf.write(path, arcname=os.path.basename(path))
            else:
                for root, _, files in os.walk(path):
                    for file in files:
                        file_path = os.path.join(root, file)
                        zipf.write(
                            file_path, arcname=os.path.relpath(file_path, path)
                        )
        zip_stream.seek(0)  # Rewind so the response reads from the start
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Use only the last path component; an absolute path with slashes is not
    # a usable download filename. Fall back to a fixed name for '/'.
    archive_name = os.path.basename(os.path.normpath(path)) or 'download'
    return StreamingResponse(
        content=zip_stream,
        media_type='application/zip',
        headers={
            'Content-Disposition': f'attachment; filename="{archive_name}.zip"'
        },
    )
def copy_from(self, path: str) -> bytes:
    """Download ``path`` from the runtime as the bytes of a zip archive.

    Issues a GET to the runtime's ``/download_files`` endpoint with ``path``
    as a query parameter and returns the raw response body.

    Args:
        path: Path inside the sandbox to download.

    Returns:
        The response body (the zip archive) as bytes.

    Raises:
        TimeoutError: If the request raises ``requests.Timeout``.
        RuntimeError: If the request fails for any other reason or the
            endpoint answers with a non-200 status.
    """
    self._wait_until_alive()
    try:
        response = send_request_with_retry(
            self.session,
            'GET',
            f'{self.runtime_url}/download_files',
            params={'path': path},
            timeout=30,
            # Retry on the default exceptions except TimeoutError.
            retry_exceptions=[
                exc for exc in DEFAULT_RETRY_EXCEPTIONS if exc != TimeoutError
            ],
        )
    except requests.Timeout as e:
        raise TimeoutError('Copy operation timed out') from e
    except Exception as e:
        raise RuntimeError(f'Copy operation failed: {str(e)}') from e

    # Checked outside the try block so the message is prefixed only once
    # instead of being re-wrapped by the generic handler above.
    if response.status_code != 200:
        raise RuntimeError(f'Copy operation failed: {response.text}')
    return response.content
def test_copy_from_directory(temp_dir, box_class):
    """Copy a directory into the sandbox and check ``copy_from`` returns it zipped."""
    # Local imports keep this test self-contained within the module.
    import io
    import zipfile

    runtime: Runtime = _load_runtime(temp_dir, box_class)
    sandbox_dir = _get_sandbox_folder(runtime)
    try:
        temp_dir_copy = os.path.join(temp_dir, 'test_dir')
        # We need a separate directory, since temp_dir is mounted to /workspace
        _create_host_test_dir_with_files(temp_dir_copy)

        # Initial state
        runtime.copy_to(temp_dir_copy, sandbox_dir, recursive=True)

        path_to_copy_from = f'{sandbox_dir}/test_dir'
        result = runtime.copy_from(path=path_to_copy_from)

        # Result is returned in bytes
        assert isinstance(result, bytes)

        # The bytes must be a readable zip archive, not just any payload.
        assert zipfile.is_zipfile(io.BytesIO(result))

        # Every file created on the host should come back, named relative to
        # the copied directory (zip member names always use '/').
        expected_names = {
            os.path.relpath(os.path.join(root, name), temp_dir_copy).replace(
                os.sep, '/'
            )
            for root, _, names in os.walk(temp_dir_copy)
            for name in names
        }
        with zipfile.ZipFile(io.BytesIO(result)) as archive:
            assert archive.testzip() is None  # no corrupt members
            assert expected_names <= set(archive.namelist())
    finally:
        _close_test_runtime(runtime)