From 2d85cedf8e6fd5048b76888f0d112df05445562a Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Tue, 12 Mar 2024 17:31:35 -0400 Subject: [PATCH 1/3] Better detection of last model --- src/inference.py | 24 ++++++++++++++++-------- src/train.py | 4 +++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/inference.py b/src/inference.py index f575aab..e3b93ab 100644 --- a/src/inference.py +++ b/src/inference.py @@ -31,16 +31,19 @@ def __init__(self, run_name: str = "", run_dir: str = "/runs") -> None: @modal.enter() def init(self): if self.run_name: - run_name = self.run_name + path = Path(self.run_dir) / self.run_name + with (path / "config.yml").open() as f: + output_dir = path / yaml.safe_load(f.read())["output_dir"] else: # Pick the last run automatically - run_name = VOLUME_CONFIG[self.run_dir].listdir("/")[-1].path - - # Grab the output dir (usually "lora-out") - with open(f"{self.run_dir}/{run_name}/config.yml") as f: - output_dir = yaml.safe_load(f.read())["output_dir"] - - model_path = f"{self.run_dir}/{run_name}/{output_dir}/merged" + run_paths = list(Path(self.run_dir).iterdir()) + for path in sorted(run_paths, reverse=True): + with (path / "config.yml").open() as f: + output_dir = path / yaml.safe_load(f.read())["output_dir"] + if output_dir.exists(): + break + + model_path = output_dir / "merged" print("Initializing vLLM engine on:", model_path) engine_args = AsyncEngineArgs( @@ -88,6 +91,11 @@ async def completion(self, input: str): async for text in self._stream(input): yield text + @modal.method() + async def non_streaming(self, input: str): + output = [text async for text in self._stream(input)] + return "".join(output) + @modal.web_endpoint() async def web(self, input: str): return StreamingResponse(self._stream(input), media_type="text/event-stream") diff --git a/src/train.py b/src/train.py index 1c7df7b..7a62131 100644 --- a/src/train.py +++ b/src/train.py @@ -150,4 +150,6 @@ def main( print(f"Training complete. Run tag: {run_name}") print(f"To inspect weights, run `modal volume ls example-runs-vol {run_name}`") - print(f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`") + print( + f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`" + ) From da5871d71410ab2c35b92ddc957f43d9098ad6c2 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Tue, 12 Mar 2024 17:36:23 -0400 Subject: [PATCH 2/3] No need to pass --run-name --- ci/check_inference.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ci/check_inference.py b/ci/check_inference.py index 5c06cb6..a316223 100644 --- a/ci/check_inference.py +++ b/ci/check_inference.py @@ -2,15 +2,11 @@ if __name__ == "__main__": - - with open(".last_run_name", "r") as f: - run_name = f.read().strip() - prompt = """[INST] Using the schema context below, generate a SQL query that answers the question. CREATE TABLE head (age INTEGER) How many heads of the departments are older than 56 ? [/INST] """ - p = subprocess.Popen(["modal", "run", "src.inference", "--run-name", run_name, "--prompt", prompt], stdout=subprocess.PIPE) + p = subprocess.Popen(["modal", "run", "src.inference", "--prompt", prompt], stdout=subprocess.PIPE) output = "" for line in iter(p.stdout.readline, b''): From b274a3a46222301a251ff7cd82590150d03da7a5 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Tue, 12 Mar 2024 17:39:54 -0400 Subject: [PATCH 3/3] cleanup, use util function --- src/inference.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/inference.py b/src/inference.py index e3b93ab..fe46f64 100644 --- a/src/inference.py +++ b/src/inference.py @@ -16,6 +16,11 @@ from vllm.utils import random_uuid +def get_model_path_from_run(path: Path) -> Path: + with (path / "config.yml").open() as f: + return path / yaml.safe_load(f.read())["output_dir"] / "merged" + + @stub.cls( gpu=modal.gpu.H100(count=N_INFERENCE_GPU), image=vllm_image, @@ -32,18 +37,15 @@ def __init__(self, run_name: str = "", run_dir: str = "/runs") -> None: def init(self): if self.run_name: path = Path(self.run_dir) / self.run_name - with (path / "config.yml").open() as f: - output_dir = path / yaml.safe_load(f.read())["output_dir"] + model_path = get_model_path_from_run(path) else: # Pick the last run automatically run_paths = list(Path(self.run_dir).iterdir()) for path in sorted(run_paths, reverse=True): - with (path / "config.yml").open() as f: - output_dir = path / yaml.safe_load(f.read())["output_dir"] - if output_dir.exists(): + model_path = get_model_path_from_run(path) + if model_path.exists(): break - model_path = output_dir / "merged" print("Initializing vLLM engine on:", model_path) engine_args = AsyncEngineArgs(