Commit
feat: added with statement support to release memory and exposed helper functions for the tokenizer (#231)

* added context-manager support for LLM and AsyncLLMEngine
```python
with AsyncLLMEngine(...) as engine: 
    engine.schedule(...)
```
* added helper functions for the tokenizer
```python
def apply_chat_template(self, messages: List[Message]) -> Optional[str]: ...
def encode(self, text: str) -> List[int]: ...
def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ...
```
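
A rough end-to-end sketch of the new helpers (import path and model name are assumed from the updated examples; actual token ids depend on the tokenizer):
```python
from scalellm import AsyncLLMEngine, Message  # import path assumed

with AsyncLLMEngine(model="google/gemma-1.1-2b-it", devices="cuda") as engine:
    # render a conversation with the model's chat template (None if the model has no template)
    prompt = engine.apply_chat_template([Message(role="user", content="Hello!")])
    token_ids = engine.encode("Hello, world!")                  # text -> token ids
    text = engine.decode(token_ids, skip_special_tokens=True)   # token ids -> text
# GPU memory is released when the block exits
```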
guocuimi authored Jun 7, 2024
1 parent 732f02f commit 5eecbee
Showing 9 changed files with 218 additions and 105 deletions.
88 changes: 41 additions & 47 deletions examples/async_stream_chat.py
@@ -3,53 +3,47 @@

def main():
# Create an LLM engine.
engine = AsyncLLMEngine(model="google/gemma-1.1-2b-it", devices="cuda")
# start the engine loop
engine.start()

sampling_params = SamplingParams(temperature=0.7, max_tokens=1000)

messages = []
system_prompt = input("\n[System]: ")
# append the system message
if system_prompt:
messages.append(Message(role="system", content=system_prompt))

while True:
# Get the next prompt.
prompt = input("\n[User]: ")
if not prompt:
continue
if prompt == "exit" or prompt == "quit":
break

# append the user message
messages.append(Message(role="user", content=prompt))

try:
output_stream = engine.schedule_chat(
messages=messages,
sampling_params=sampling_params,
stream=True,
)
assistant_response = ""
print("\n[Assistant]: ", end="", flush=True)
for output in output_stream:
if len(output.outputs) > 0:
response = output.outputs[0].text
assistant_response += response
print(response, end="", flush=True)
print()
except KeyboardInterrupt:
# cancel the request
output_stream.cancel()
break

# append the assistant message
messages.append(Message(role="assistant", content=assistant_response))

# stop the engine
engine.stop()
with AsyncLLMEngine(model="google/gemma-1.1-2b-it", devices="cuda") as engine:
sampling_params = SamplingParams(temperature=0.7, max_tokens=1000)

messages = []
system_prompt = input("\n[System]: ")
# append the system message
if system_prompt:
messages.append(Message(role="system", content=system_prompt))

while True:
# Get the next prompt.
prompt = input("\n[User]: ")
if not prompt:
continue
if prompt == "exit" or prompt == "quit":
break

# append the user message
messages.append(Message(role="user", content=prompt))

try:
output_stream = engine.schedule_chat(
messages=messages,
sampling_params=sampling_params,
stream=True,
)
assistant_response = ""
print("\n[Assistant]: ", end="", flush=True)
for output in output_stream:
if len(output.outputs) > 0:
response = output.outputs[0].text
assistant_response += response
print(response, end="", flush=True)
print()
except KeyboardInterrupt:
# cancel the request
output_stream.cancel()
break

# append the assistant message
messages.append(Message(role="assistant", content=assistant_response))


if __name__ == "__main__":
49 changes: 22 additions & 27 deletions examples/async_stream_complete.py
@@ -3,37 +3,32 @@

def main():
# Create an LLM engine.
engine = AsyncLLMEngine(model="google/gemma-2b", devices="cuda")
# start the engine loop
engine.start()

prompt = input("\n[Prompt]: ")
while True:
if prompt == "exit":
break
with AsyncLLMEngine(model="google/gemma-2b", devices="cuda") as engine:
sampling_params = SamplingParams(
temperature=0, top_p=1.0, max_tokens=100, echo=True
)
try:
output_stream = engine.schedule(
prompt=prompt,
sampling_params=sampling_params,
stream=True,
)
for output in output_stream:
if len(output.outputs) > 0:
print(output.outputs[0].text, end="", flush=True)
print()
except KeyboardInterrupt:
# cancel the request
output_stream.cancel()
break

# Get the next prompt.
prompt = input("\n[Prompt]: ")

# stop the engine
engine.stop()
while True:
# Get the next prompt.
prompt = input("\n[Prompt]: ")
if not prompt:
continue
if prompt == "exit":
break
try:
output_stream = engine.schedule(
prompt=prompt,
sampling_params=sampling_params,
stream=True,
)
for output in output_stream:
if len(output.outputs) > 0:
print(output.outputs[0].text, end="", flush=True)
print()
except KeyboardInterrupt:
# cancel the request
output_stream.cancel()
break


if __name__ == "__main__":
5 changes: 5 additions & 0 deletions scalellm/_C/__init__.pyi
@@ -156,3 +156,8 @@ class LLMHandler:
def start(self) -> None: ...
def stop(self) -> None: ...
def run_until_complete(self) -> None: ...
def reset(self) -> None: ...
# helper functions
def apply_chat_template(self, messages: List[Message]) -> Optional[str]: ...
def encode(self, text: str) -> List[int]: ...
def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ...
12 changes: 12 additions & 0 deletions scalellm/csrc/module.cpp
@@ -141,6 +141,18 @@ PYBIND11_MODULE(PY_MODULE_NAME, m) {
py::call_guard<py::gil_scoped_release>())
.def("run_until_complete",
&LLMHandler::run_until_complete,
py::call_guard<py::gil_scoped_release>())
.def("apply_chat_template",
&LLMHandler::apply_chat_template,
py::call_guard<py::gil_scoped_release>())
.def("encode",
&LLMHandler::encode,
py::call_guard<py::gil_scoped_release>())
.def("decode",
&LLMHandler::decode,
py::call_guard<py::gil_scoped_release>())
.def("reset",
&LLMHandler::reset,
py::call_guard<py::gil_scoped_release>());

// LLMHandler::Options
28 changes: 25 additions & 3 deletions scalellm/llm.py
@@ -1,7 +1,8 @@
import os
from typing import List, Optional, Union

from scalellm._C import LLMHandler, Priority, RequestOutput, SamplingParams
from scalellm._C import (LLMHandler, Message, Priority, RequestOutput,
SamplingParams)
from scalellm.downloader import download_hf_model
from scalellm.errors import ValidationError

@@ -26,8 +27,8 @@ def __init__(
cuda_graph_max_seq_len: int = 2048,
cuda_graph_batch_sizes: Optional[List[int]] = None,
draft_cuda_graph_batch_sizes: Optional[List[int]] = None,
max_tokens_per_batch: int = 409600, # a big number to disable chunked prefill
max_seqs_per_batch: int = 2048, # a big number for better throughput
max_tokens_per_batch: int = 409600, # a big number to disable chunked prefill
max_seqs_per_batch: int = 2048, # a big number for better throughput
num_speculative_tokens: int = 0,
num_handling_threads: int = 4,
) -> None:
@@ -112,3 +113,24 @@ def callback(index: int, output: RequestOutput) -> bool:
# carry over the prompt to the output
output.prompt = prompts[index]
return outputs

def apply_chat_template(self, messages: List[Message]) -> Optional[str]:
return self._handler.apply_chat_template(messages)

def encode(self, text: str) -> List[int]:
return self._handler.encode(text)

def decode(
self, tokens: List[int], skip_special_tokens: bool = True
) -> Optional[str]:
return self._handler.decode(tokens, skip_special_tokens)

def __del__(self):
self._handler.reset()

def __enter__(self):
return self

def __exit__(self, *args):
self.__del__()
return False
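
For the synchronous LLM class, the added dunder methods make the `with` form roughly equivalent to the following try/finally pattern (sketch; constructor arguments elided):
```python
llm = LLM(model="google/gemma-2b")  # remaining constructor arguments elided
try:
    ids = llm.encode("The quick brown fox")
    text = llm.decode(ids)  # skip_special_tokens defaults to True
finally:
    # __exit__ delegates to __del__, which calls LLMHandler.reset() to join the
    # handling threads and release the scheduler, engine, tokenizers, and CUDA cache.
    llm.__del__()
```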
23 changes: 23 additions & 0 deletions scalellm/llm_engine.py
@@ -251,3 +251,26 @@ def start(self) -> None:
# stop the engine, non-blocking
def stop(self) -> None:
return self._handler.stop()

def apply_chat_template(self, messages: List[Message]) -> Optional[str]:
return self._handler.apply_chat_template(messages)

def encode(self, text: str) -> List[int]:
return self._handler.encode(text)

def decode(
self, tokens: List[int], skip_special_tokens: bool = True
) -> Optional[str]:
return self._handler.decode(tokens, skip_special_tokens)

def __del__(self):
self._handler.reset()

def __enter__(self):
self.start()
return self

def __exit__(self, *args):
self.stop()
self.__del__()
return False
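
For AsyncLLMEngine the context manager also manages the engine loop; the `with` form used in the updated examples is roughly shorthand for (sketch, scheduling code elided):
```python
engine = AsyncLLMEngine(model="google/gemma-2b", devices="cuda")
engine.start()        # done by __enter__: start the engine loop
try:
    ...               # schedule requests and consume output streams
finally:
    engine.stop()     # done by __exit__: stop the loop (non-blocking)
    engine.__del__()  # __exit__ then resets the handler, freeing GPU memory
```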
58 changes: 45 additions & 13 deletions src/handlers/llm_handler.cpp
@@ -219,19 +219,7 @@ LLMHandler::LLMHandler(const Options& options) : options_(options) {
}
}

LLMHandler::~LLMHandler() {
stop();

// stop all handling threads
// push nullptr to the queue to signal threads to exit
for (size_t i = 0; i < handling_threads_.size(); ++i) {
queue_.push(nullptr);
}
// wait for all threads to finish
for (auto& thread : handling_threads_) {
thread.join();
}
}
LLMHandler::~LLMHandler() { reset(); }

void LLMHandler::schedule_async(std::string prompt,
SamplingParams sp,
@@ -567,4 +555,48 @@ std::unique_ptr<Request> LLMHandler::create_chat_request(
tid, std::move(prompt.value()), sp, priority, stream, callback);
}

std::optional<std::string> LLMHandler::apply_chat_template(
const std::vector<Message>& conversation) {
// without chat template, return nullopt
if (chat_template_ == nullptr) {
return std::nullopt;
}
return chat_template_->apply(conversation);
}

std::vector<int32_t> LLMHandler::encode(const std::string& text) {
std::vector<int> tokens;
engine_->tokenizer()->encode(text, &tokens);
return tokens;
}

std::string LLMHandler::decode(const std::vector<int32_t>& tokens,
bool skip_special_tokens) {
return engine_->tokenizer()->decode(tokens, skip_special_tokens);
}

void LLMHandler::reset() {
stop();

// stop all handling threads
// push nullptr to the queue to signal threads to exit
for (size_t i = 0; i < handling_threads_.size(); ++i) {
queue_.push(nullptr);
}
// wait for all threads to finish
for (auto& thread : handling_threads_) {
thread.join();
}
handling_threads_.clear();

// release all underlying resources
scheduler_.reset();
engine_.reset();
tokenizers_.clear();
chat_template_.reset();

// torch::cuda::empty_cache();
c10::cuda::CUDACachingAllocator::emptyCache();
}

} // namespace llm
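
One way to sanity-check the effect of reset() from Python is to watch CUDA memory before and after the with block (sketch; exact numbers depend on the model and allocator state, and the import path is assumed):
```python
import torch
from scalellm import AsyncLLMEngine  # import path assumed

with AsyncLLMEngine(model="google/gemma-2b", devices="cuda") as engine:
    print(torch.cuda.memory_allocated())  # weights and KV cache are resident here
print(torch.cuda.memory_allocated())       # expected to drop once reset() has run
```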
13 changes: 13 additions & 0 deletions src/handlers/llm_handler.h
@@ -114,6 +114,19 @@ class LLMHandler {
// run until complete, blocking call
void run_until_complete();

// helper functions exposed to Python
// apply the chat template to the conversation and return the result
std::optional<std::string> apply_chat_template(
const std::vector<Message>& conversation);

std::vector<int32_t> encode(const std::string& text);

std::string decode(const std::vector<int32_t>& tokens,
bool skip_special_tokens);

// release underlying resources
void reset();

private:
using Task = std::function<void(size_t tid)>;
std::unique_ptr<Request> create_request(size_t tid,
