Add sglang example #92

Open · wants to merge 8 commits into base: main
111 changes: 111 additions & 0 deletions models/model_upload/llms/sglang-llama-3_2-1b-instruct/1/model.py
@@ -0,0 +1,111 @@
import json
import os
from typing import Iterator

from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from google.protobuf import json_format

import sglang as sgl
from transformers import AutoTokenizer

def get_inference_params(request) -> dict:
  """Get the inference params from the request."""
  inference_params = {}
  if request.model.model_version.id != "":
    output_info = request.model.model_version.output_info
    output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)
    if "params" in output_info:
      inference_params = output_info["params"]
  return inference_params

def parse_request(request: service_pb2.PostModelOutputsRequest):
  """Extract chat messages and the generation config from the request."""
  prompts = [inp.data.text.raw for inp in request.inputs]
  inference_params = get_inference_params(request)
  temperature = inference_params.get("temperature", 0.7)
  max_tokens = inference_params.get("max_tokens", 256)
  top_p = inference_params.get("top_p", .9)

  messages = []
  for prompt in prompts:
    try:
      # A full chat history may be passed as a JSON-encoded list of messages.
      prompt = json.loads(prompt)
    except json.JSONDecodeError:
      # Plain text: wrap it as a single user message.
      prompt = [{"role": "user", "content": prompt}]
    finally:
      messages.append(prompt)

  gen_config = dict(temperature=temperature,
                    max_new_tokens=max_tokens,
                    top_p=top_p)

  return messages, gen_config

def set_output(texts: list):
  """Wrap generated texts in Output protos with a SUCCESS status."""
  assert isinstance(texts, list)
  output_protos = []
  for text in texts:
    output_protos.append(
        resources_pb2.Output(
            data=resources_pb2.Data(text=resources_pb2.Text(raw=text)),
            status=status_pb2.Status(code=status_code_pb2.SUCCESS)
        )
    )
  return output_protos

class MyRunner(ModelRunner):
  """A custom runner that loads the model and generates text using SGLang inference."""

  def load_model(self):
    """Load the SGLang engine and the tokenizer from the local checkpoints directory."""
    # If a `checkpoints` section is present in config.yaml, the weights are downloaded
    # to this path at model upload time.
    checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")
    self.pipe = sgl.Engine(model_path=checkpoints)
    self.tokenizer = AutoTokenizer.from_pretrained(checkpoints)

  def predict(self, request: service_pb2.PostModelOutputsRequest
             ) -> service_pb2.MultiOutputResponse:
    """Run one batched generation over the request inputs and return a single response."""
    messages, gen_config = parse_request(request)
    # Render each conversation into a prompt string with the model's chat template.
    prompts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    generated_text = self.pipe.generate(prompts, gen_config)
    if not isinstance(generated_text, list):
      generated_text = [generated_text]
    raw_texts = [each["text"] for each in generated_text]
    output_protos = set_output(raw_texts)

    return service_pb2.MultiOutputResponse(outputs=output_protos)

  def generate(self, request: service_pb2.PostModelOutputsRequest
              ) -> Iterator[service_pb2.MultiOutputResponse]:
    """Stream partial results for the whole batch, yielding one response per new chunk."""
    messages, gen_config = parse_request(request)
    prompts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    batch_size = len(prompts)
    outputs = [
        resources_pb2.Output(
            data=resources_pb2.Data(text=resources_pb2.Text(raw="")),
            status=status_pb2.Status(code=status_code_pb2.SUCCESS)
        ) for _ in range(batch_size)
    ]
    previous_text = {}
    for item in self.pipe.generate(prompts, gen_config, stream=True):
      prompt_idx = item.get("index", 0)

      # item["text"] holds the cumulative text for this prompt; emit only the new suffix.
      cumulative_text = item["text"]
      prev_text = previous_text.get(prompt_idx, "")
      chunk_text = cumulative_text[len(prev_text):]
      previous_text[prompt_idx] = cumulative_text

      outputs[prompt_idx].data.text.raw = chunk_text

      yield service_pb2.MultiOutputResponse(outputs=outputs)

  def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
            ) -> Iterator[service_pb2.MultiOutputResponse]:
    """Streaming requests are not implemented in this example."""
    pass
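
As a quick sanity check of the request-parsing path above, the helpers can be exercised without loading the SGLang engine at all. The sketch below is not part of the PR; it only assumes the clarifai_grpc protos already imported in model.py, and the field names mirror the ones read by get_inference_params and parse_request.

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from google.protobuf import struct_pb2

# Inference params travel as a Struct on model.model_version.output_info.params.
params = struct_pb2.Struct()
params.update({"temperature": 0.2, "max_tokens": 128, "top_p": 0.95})

request = service_pb2.PostModelOutputsRequest(
    inputs=[
        resources_pb2.Input(
            data=resources_pb2.Data(text=resources_pb2.Text(raw="What is SGLang?")))
    ],
    model=resources_pb2.Model(
        model_version=resources_pb2.ModelVersion(
            id="test-version",  # non-empty id so get_inference_params reads the params
            output_info=resources_pb2.OutputInfo(params=params))))

messages, gen_config = parse_request(request)
# messages   -> [[{'role': 'user', 'content': 'What is SGLang?'}]]
# gen_config -> {'temperature': ..., 'max_new_tokens': ..., 'top_p': ...} (Struct numbers arrive as floats)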
22 changes: 22 additions & 0 deletions models/model_upload/llms/sglang-llama-3_2-1b-instruct/config.yaml
@@ -0,0 +1,22 @@
# Config file for the SGLang runner

model:
  id: "sglang-llama3_2-1b-instruct"
  user_id: ""
  app_id: ""
  model_type_id: "text-to-text"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "4"
  cpu_memory: "16Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "24Gi"

checkpoints:
  type: "huggingface"
  repo_id: "unsloth/Llama-3.2-1B-Instruct"
  hf_token:
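
For local testing, one way (not shown in the PR) to reproduce the checkpoints/ layout that load_model() expects is to pull the repo named in the checkpoints section above with huggingface_hub; the target path below assumes the directory layout of this example.

# Hypothetical local setup: mirror what the platform does at model upload time.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="unsloth/Llama-3.2-1B-Instruct",  # same repo_id as config.yaml
    local_dir="models/model_upload/llms/sglang-llama-3_2-1b-instruct/1/checkpoints",
    # token="...",  # only needed for gated or private repos
)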
@luv-bansal (Contributor) commented on Nov 18, 2024:
I used the requirements below, with pinned dependency versions, to test locally and it worked. I think it's better to include the requirements with their versions here, because I was getting an error (I'm not sure why) when I didn't pin the dependency versions:

torch==2.4.0
tokenizers==0.20.2
transformers==4.46.2
accelerate==0.34.2
scipy==1.10.1
optimum==1.23.3
xformers==0.0.27.post2
einops==0.8.0
requests==2.32.2
packaging
ninja
protobuf==3.20.0

sglang[all]==0.3.5.post2
orjson==3.10.11
python-multipart==0.0.17

--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4/
flashinfer

18 changes: 18 additions & 0 deletions models/model_upload/llms/sglang-llama-3_2-1b-instruct/requirements.txt
@@ -0,0 +1,18 @@
torch==2.5.1
tokenizers==0.21.0
transformers>=4.47
accelerate==1.2.0
optimum==1.23.3
xformers
einops==0.8.0
requests==2.32.2
packaging
ninja
protobuf==3.20.0

sglang[all]==0.4.1.post7
orjson==3.10.11
python-multipart==0.0.17

--extra-index-url https://flashinfer.ai/whl/cu124/torch2.4/
flashinfer