Skip to content

Commit

Permalink
minor beautification of imports
Browse files (browse the repository at this point in the history)
  • Loading branch information
erikbern committed Mar 12, 2024
1 parent aab83e5 commit 4bba53a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,6 +9,12 @@

N_INFERENCE_GPU = 2

with vllm_image.imports():
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


@stub.cls(
gpu=modal.gpu.H100(count=N_INFERENCE_GPU),
Expand Down Expand Up @@ -37,9 +43,6 @@ def init(self):
model_path = f"{self.run_dir}/{run_name}/{output_dir}/merged"
print("Initializing vLLM engine on:", model_path)

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
model=model_path,
gpu_memory_utilization=0.95,
Expand All @@ -51,9 +54,6 @@ async def _stream(self, input: str):
if not input:
return

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

sampling_params = SamplingParams(
repetition_penalty=1.1,
temperature=0.2,
Expand Down

0 comments on commit 4bba53a

Please sign in to comment.