From 4bba53a41f421cebdfcaff863208bfd7e1d8b057 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Tue, 12 Mar 2024 09:46:20 -0400 Subject: [PATCH] minor beautification of imports --- src/inference.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/inference.py b/src/inference.py index c9f58dd..f575aab 100644 --- a/src/inference.py +++ b/src/inference.py @@ -9,6 +9,12 @@ N_INFERENCE_GPU = 2 +with vllm_image.imports(): + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.sampling_params import SamplingParams + from vllm.utils import random_uuid + @stub.cls( gpu=modal.gpu.H100(count=N_INFERENCE_GPU), @@ -37,9 +43,6 @@ def init(self): model_path = f"{self.run_dir}/{run_name}/{output_dir}/merged" print("Initializing vLLM engine on:", model_path) - from vllm.engine.arg_utils import AsyncEngineArgs - from vllm.engine.async_llm_engine import AsyncLLMEngine - engine_args = AsyncEngineArgs( model=model_path, gpu_memory_utilization=0.95, @@ -51,9 +54,6 @@ async def _stream(self, input: str): if not input: return - from vllm.sampling_params import SamplingParams - from vllm.utils import random_uuid - sampling_params = SamplingParams( repetition_penalty=1.1, temperature=0.2,