Skip to content

Commit

Permalink
Fix for llama2 api server
Browse files: browse the repository at this point in the history
  • Loading branch information
arjunsuresh committed Jul 16, 2024
1 parent b1ef8f1 commit 99ee8b6
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions language/llama2-70b/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -73,16 +73,26 @@ def main():

sut_cls = sut_map[args.scenario.lower()]

sut = sut_cls(
model_path=args.model_path,
dtype=args.dtype,
batch_size=args.batch_size,
dataset_path=args.dataset_path,
total_sample_count=args.total_sample_count,
device=args.device,
api_server=args.api_server,
api_model_name=args.api_model_name,
)
if args.vllm:
sut = sut_cls(
model_path=args.model_path,
dtype=args.dtype,
batch_size=args.batch_size,
dataset_path=args.dataset_path,
total_sample_count=args.total_sample_count,
device=args.device,
api_server=args.api_server,
api_model_name=args.api_model_name,
)
else:
sut = sut_cls(
model_path=args.model_path,
dtype=args.dtype,
batch_size=args.batch_size,
dataset_path=args.dataset_path,
total_sample_count=args.total_sample_count,
device=args.device
)

# Start sut before loadgen starts
sut.start()
Expand Down

0 comments on commit 99ee8b6

Please sign in to comment.