diff --git a/compute_sdk/globus_compute_sdk/sdk/shell_function.py b/compute_sdk/globus_compute_sdk/sdk/shell_function.py index 2ee78e6bb..060a9ff43 100644 --- a/compute_sdk/globus_compute_sdk/sdk/shell_function.py +++ b/compute_sdk/globus_compute_sdk/sdk/shell_function.py @@ -10,6 +10,7 @@ def __init__( stderr: str, returncode: int, exception_name: t.Optional[str] = None, + run_dir: t.Optional[str] = None, ): """ @@ -34,6 +35,7 @@ def __init__( self.stderr = stderr self.returncode = returncode self.exception_name = exception_name + self.run_dir = run_dir def __str__(self): return f"Command {self.cmd} returned with exit status: {self.returncode}" @@ -53,6 +55,7 @@ def __init__( stderr: t.Optional[str] = None, walltime: t.Optional[float] = None, snippet_lines=1000, + run_dir: t.Optional[str] = None, ): """Initialize a ShellFunction @@ -84,6 +87,7 @@ def __init__( if walltime: assert walltime >= 0, f"Negative walltime={walltime} is not allowed" self.snippet_lines = snippet_lines + self.run_dir = run_dir @property def __name__(self): @@ -125,11 +129,11 @@ def execute_cmd_line( sandbox_error_message = None - run_dir = None - # run_dir takes priority over sandboxing - if os.environ.get("GC_TASK_SANDBOX_DIR"): - run_dir = os.environ["GC_TASK_SANDBOX_DIR"] - else: + sandbox_dir = os.environ.get("GC_TASK_SANDBOX_DIR") + # run_dir takes precedence over sandbox_dir + run_dir = self.run_dir or sandbox_dir + + if not run_dir: sandbox_error_message = ( "WARNING: Task sandboxing will not work due to " "endpoint misconfiguration. Please enable sandboxing " @@ -194,6 +198,7 @@ def execute_cmd_line( stderr_snippet, returncode, exception_name=exception_name, + run_dir=run_dir, ) def __call__(