diff --git a/ci/check_inference.py b/ci/check_inference.py
index 5c06cb6..a316223 100644
--- a/ci/check_inference.py
+++ b/ci/check_inference.py
@@ -2,15 +2,11 @@
 
 if __name__ == "__main__":
-
-    with open(".last_run_name", "r") as f:
-        run_name = f.read().strip()
-
     prompt = """[INST] Using the schema context below, generate a SQL query that answers the question.
     CREATE TABLE head (age INTEGER)
     How many heads of the departments are older than 56 ? [/INST] """
 
-    p = subprocess.Popen(["modal", "run", "src.inference", "--run-name", run_name, "--prompt", prompt], stdout=subprocess.PIPE)
+    p = subprocess.Popen(["modal", "run", "src.inference", "--prompt", prompt], stdout=subprocess.PIPE)
     output = ""
     for line in iter(p.stdout.readline, b''):
diff --git a/src/inference.py b/src/inference.py
index f575aab..fe46f64 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -16,6 +16,11 @@
 
 from vllm.utils import random_uuid
 
+def get_model_path_from_run(path: Path) -> Path:
+    with (path / "config.yml").open() as f:
+        return path / yaml.safe_load(f.read())["output_dir"] / "merged"
+
+
 @stub.cls(
     gpu=modal.gpu.H100(count=N_INFERENCE_GPU),
     image=vllm_image,
@@ -31,16 +36,16 @@ def __init__(self, run_name: str = "", run_dir: str = "/runs") -> None:
 
     @modal.enter()
     def init(self):
         if self.run_name:
-            run_name = self.run_name
+            path = Path(self.run_dir) / self.run_name
+            model_path = get_model_path_from_run(path)
         else:  # Pick the last run automatically
-            run_name = VOLUME_CONFIG[self.run_dir].listdir("/")[-1].path
-
-        # Grab the output dir (usually "lora-out")
-        with open(f"{self.run_dir}/{run_name}/config.yml") as f:
-            output_dir = yaml.safe_load(f.read())["output_dir"]
+            run_paths = list(Path(self.run_dir).iterdir())
+            for path in sorted(run_paths, reverse=True):
+                model_path = get_model_path_from_run(path)
+                if model_path.exists():
+                    break
 
-        model_path = f"{self.run_dir}/{run_name}/{output_dir}/merged"
         print("Initializing vLLM engine on:", model_path)
 
         engine_args = AsyncEngineArgs(
@@ -88,6 +93,11 @@ async def completion(self, input: str):
         async for text in self._stream(input):
             yield text
 
+    @modal.method()
+    async def non_streaming(self, input: str):
+        output = [text async for text in self._stream(input)]
+        return "".join(output)
+
     @modal.web_endpoint()
     async def web(self, input: str):
         return StreamingResponse(self._stream(input), media_type="text/event-stream")
diff --git a/src/train.py b/src/train.py
index 1c7df7b..7a62131 100644
--- a/src/train.py
+++ b/src/train.py
@@ -150,4 +150,6 @@ def main(
 
     print(f"Training complete. Run tag: {run_name}")
     print(f"To inspect weights, run `modal volume ls example-runs-vol {run_name}`")
-    print(f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`")
+    print(
+        f"To run sample inference, run `modal run -q src.inference --run-name {run_name}`"
+    )