Update tokenizer (#3061)
* change tokenizer path in converter

* update

* remove sentencepiece tokenizer

* fix

* fix

* remove meta_llama.py

* fix
lvhan028 authored Jan 27, 2025
1 parent 894af4d commit 26622b8
Showing 13 changed files with 71 additions and 446 deletions.
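Taken together, these diffs move tokenizer construction out of the engines: the caller now builds a single Tokenizer and hands it to whichever backend it creates. Below is a minimal sketch of the new calling convention, assuming a hypothetical local model directory; the constructor signatures are the ones shown in the hunks that follow.

    from lmdeploy.messages import TurbomindEngineConfig
    from lmdeploy.tokenizer import Tokenizer
    from lmdeploy.turbomind import TurboMind

    model_path = './models/internlm2-chat-7b'  # hypothetical HF-format model directory

    # The tokenizer is now built once by the caller ...
    tokenizer = Tokenizer(model_path)

    # ... and passed into the backend engine instead of being created inside it.
    tm_engine = TurboMind.from_pretrained(model_path,
                                          tokenizer=tokenizer,
                                          engine_config=TurbomindEngineConfig())

    # The PyTorch backend follows the same pattern:
    #   from lmdeploy.pytorch.engine import Engine
    #   pt_engine = Engine(model_path, tokenizer=tokenizer, engine_config=PytorchEngineConfig())
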
6 changes: 4 additions & 2 deletions benchmark/profile_generation.py
@@ -15,6 +15,7 @@

from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger

get_logger('lmdeploy').setLevel('ERROR')
@@ -115,12 +116,13 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
f'n_prompt_token: {input_seqlen}, '
f'n_completion_token: {output_seqlen}, '
f'test_round: {test_round}, warmup_round: {warmup_round}')
tokenizer = Tokenizer(model_path)
if isinstance(engine_config, TurbomindEngineConfig):
from lmdeploy.turbomind import TurboMind
tm_model = TurboMind.from_pretrained(model_path, engine_config=engine_config)
tm_model = TurboMind.from_pretrained(model_path, tokenizer=tokenizer, engine_config=engine_config)
elif isinstance(engine_config, PytorchEngineConfig):
from lmdeploy.pytorch.engine import Engine
tm_model = Engine(model_path, engine_config)
tm_model = Engine(model_path, tokenizer=tokenizer, engine_config=engine_config)

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
7 changes: 3 additions & 4 deletions benchmark/profile_throughput.py
@@ -66,16 +66,15 @@ def sample_requests(
class Engine:

def __init__(self, model_path: str, engine_config: Union[PytorchEngineConfig, TurbomindEngineConfig]):
self.tokenizer = Tokenizer(model_path)
if isinstance(engine_config, TurbomindEngineConfig):
from lmdeploy.turbomind import TurboMind
tm_model = TurboMind.from_pretrained(model_path, engine_config=engine_config)
tm_model = TurboMind.from_pretrained(model_path, tokenizer=self.tokenizer, engine_config=engine_config)
elif isinstance(engine_config, PytorchEngineConfig):
from lmdeploy.pytorch.engine import Engine as PytorchEngine
tm_model = PytorchEngine(model_path, engine_config=engine_config)
tm_model = PytorchEngine(model_path, tokenizer=self.tokenizer, engine_config=engine_config)

self.tm_model = tm_model
self.tokenizer = tm_model.tokenizer

self.pbar = None

async def _inference(self, req_queue: Queue, session_id: int, temperature: float, top_p: float, top_k: int,
6 changes: 3 additions & 3 deletions lmdeploy/cli/utils.py
@@ -121,9 +121,9 @@ def model_format(parser, default: str = None):
return parser.add_argument('--model-format',
type=str,
default=default,
choices=['hf', 'llama', 'awq', 'gptq'],
help='The format of input model. `hf` means `hf_llama`, `llama` '
'means `meta_llama`, `awq` represents the quantized model by AWQ,'
choices=['hf', 'awq', 'gptq'],
help='The format of input model. `hf` means `hf_llama`, '
'`awq` represents the quantized model by AWQ,'
' and `gptq` refers to the quantized model by GPTQ')

@staticmethod
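With `llama` gone from the choices, a converter or CLI invocation that still passes it will fail argparse validation. A small sketch of the narrowed flag; treating model_format as a static helper on ArgumentHelper is an assumption here, since the hunk shows only the method, not the class it belongs to.

    import argparse

    from lmdeploy.cli.utils import ArgumentHelper

    parser = argparse.ArgumentParser()
    ArgumentHelper.model_format(parser)   # registers --model-format with choices hf / awq / gptq

    args = parser.parse_args(['--model-format', 'awq'])    # accepted
    # parser.parse_args(['--model-format', 'llama'])       # rejected: the meta_llama format was removed
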
9 changes: 4 additions & 5 deletions lmdeploy/messages.py
@@ -149,11 +149,10 @@ class TurbomindEngineConfig:
The `auto` option will use FP16 precision for FP32 and FP16
models, and BF16 precision for BF16 models.
model_format (str): the layout of the deployed model. It can be one
of the following values [hf, meta_llama, awq, gptq],`hf` meaning
huggingface model(.bin, .safetensors), `meta_llama` being
meta llama's format(.pth), `awq` and `gptq` meaning the quantized
model by AWQ and GPTQ, respectively. If it is not specified,
i.e. None, it will be extracted from the input model
of the following values [hf, awq, gptq], `hf` meaning
huggingface model(.bin, .safetensors), `awq` and `gptq` meaning
the quantized model by AWQ and GPTQ, respectively. If it is not
specified, i.e. None, it will be extracted from the input model
tp (int): the number of GPU cards used in tensor parallelism,
default to 1
session_len (int): the max session length of a sequence, default to
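The trimmed docstring mirrors the values that model_format still accepts on the engine config. A minimal sketch, with field values chosen purely for illustration:

    from lmdeploy.messages import TurbomindEngineConfig

    # 'hf' (huggingface .bin/.safetensors), 'awq' and 'gptq' remain valid;
    # 'meta_llama' (.pth checkpoints) is no longer an accepted format.
    config = TurbomindEngineConfig(model_format='awq', tp=1)

    # Leaving model_format as None lets the engine infer the format from the input model.
    default_config = TurbomindEngineConfig()
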
7 changes: 5 additions & 2 deletions lmdeploy/pytorch/chat.py
@@ -64,8 +64,11 @@ def run_chat(model_path: str,
trust_remote_code (bool): trust remote code.
"""
from lmdeploy.pytorch.engine import Engine
tm_model = Engine.from_pretrained(model_path, engine_config=engine_config, trust_remote_code=trust_remote_code)
tokenizer = tm_model.tokenizer
tokenizer = Tokenizer(model_path)
tm_model = Engine.from_pretrained(model_path,
tokenizer=tokenizer,
engine_config=engine_config,
trust_remote_code=trust_remote_code)
generator = tm_model.create_instance()
adapter_name = None
if engine_config.adapters is not None:
14 changes: 6 additions & 8 deletions lmdeploy/pytorch/engine/engine.py
@@ -103,12 +103,14 @@ class Engine:
Args:
model_path (str): The hugging face model path.
tokenizer (lmdeploy.Tokenizer): an instance of lmdeploy.Tokenizer
engine_config (PytorchEngineConfig): The config of the Engine.
trust_remote_code (bool): Trust remote code.
"""

def __init__(self,
model_path: str,
tokenizer: object,
engine_config: PytorchEngineConfig = None,
trust_remote_code: bool = True) -> None:
if engine_config is None:
@@ -124,6 +126,7 @@ def __init__(self,
logger=logger)
checker.handle()

self.tokenizer = tokenizer
adapters = engine_config.adapters
self.engine_config = engine_config
self.tp = engine_config.tp
@@ -172,6 +175,7 @@ def __init__(self,
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
tokenizer: object,
engine_config: PytorchEngineConfig = None,
trust_remote_code: bool = True,
**kwargs):
@@ -188,23 +192,17 @@ def from_pretrained(cls,
on huggingface.co, such as "InternLM/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
tokenizer (lmdeploy.Tokenizer): an instance of lmdeploy.Tokenizer
engine_config (PytorchEngineConfig): Pytorch engine config.
trust_remote_code (bool): Trust remote code
"""
if len(kwargs) > 0:
logger.debug(f'Get unexpected kwargs: {kwargs}')
return cls(model_path=pretrained_model_name_or_path,
tokenizer=tokenizer,
engine_config=engine_config,
trust_remote_code=trust_remote_code)

@property
def tokenizer(self):
"""create tokenizer."""
from lmdeploy.tokenizer import Tokenizer
if not hasattr(self, '_tokenizer'):
self._tokenizer = Tokenizer(self.model_path)
return self._tokenizer

def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig):
"""download adapters."""
download_dir = engine_config.download_dir
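With the lazy tokenizer property removed, Engine no longer creates its own tokenizer; the caller supplies one through the new argument and it is stored on the instance. A minimal sketch, assuming a hypothetical model path and a default-constructed PytorchEngineConfig:

    from lmdeploy.messages import PytorchEngineConfig
    from lmdeploy.pytorch.engine import Engine
    from lmdeploy.tokenizer import Tokenizer

    model_path = 'internlm/internlm2-chat-7b'   # hypothetical model id or local path
    tokenizer = Tokenizer(model_path)

    engine = Engine.from_pretrained(model_path,
                                    tokenizer=tokenizer,
                                    engine_config=PytorchEngineConfig(),
                                    trust_remote_code=True)

    # The tokenizer handed in above is the one the engine exposes afterwards.
    assert engine.tokenizer is tokenizer
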
10 changes: 7 additions & 3 deletions lmdeploy/serve/async_engine.py
@@ -18,6 +18,7 @@

import tqdm

from lmdeploy import Tokenizer
from lmdeploy.logger import RequestLogger
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
@@ -267,6 +268,7 @@ def __init__(self,

logger.info(f'updated chat_template_config={chat_template_config}')

self.tokenizer = Tokenizer(model_path)
# build backend engine
if backend == 'turbomind':
self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs)
@@ -284,7 +286,6 @@ def __init__(self,
self.stop_words = self.stop_words[0][0].tolist()
self.backend = backend
self.instance_num = self.backend_config.max_batch_size
self.tokenizer = self.engine.tokenizer
self.id2step = {}
self.id2inst = {}
self.free_insts: asyncio.Queue = None
@@ -311,7 +312,10 @@ def _build_turbomind(self,
**kwargs):
"""Innter build method for turbomind backend."""
from lmdeploy import turbomind as tm
self.engine = tm.TurboMind.from_pretrained(model_path, engine_config=backend_config, **kwargs)
self.engine = tm.TurboMind.from_pretrained(model_path,
tokenizer=self.tokenizer,
engine_config=backend_config,
**kwargs)
self.backend_config = self.engine.engine_config
self.hf_tm_cfg = self.engine.config

@@ -321,7 +325,7 @@ def _build_pytorch(self,
**kwargs):
"""Innter build method for pytorch backend."""
from lmdeploy.pytorch.engine import Engine
self.engine = Engine(model_path=model_path, engine_config=backend_config)
self.engine = Engine(model_path=model_path, tokenizer=self.tokenizer, engine_config=backend_config)
self.backend_config = self.engine.engine_config
self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None)

179 changes: 11 additions & 168 deletions lmdeploy/tokenizer.py
@@ -5,8 +5,6 @@
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import torch

from lmdeploy.utils import get_logger

# this file will be copied to triton server, make sure all
@@ -37,149 +35,8 @@ def as_tuple(self) -> Tuple:
return (self.ids_offset, self.prev_tokens, self.prefix_offset, self.read_offset)


class SentencePieceTokenizer:
"""Tokenizer of sentencepiece.
Args:
model_file (str): the path of the tokenizer model
"""

def __init__(self, model_file: str):
from sentencepiece import SentencePieceProcessor
self.model = SentencePieceProcessor(model_file=model_file)
self._prefix_space_tokens = None
# for stop words
self._maybe_decode_bytes: bool = None
# TODO maybe lack a constant.py
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.logger = get_logger('lmdeploy')

@property
def vocab_size(self):
"""vocabulary size."""
return self.model.vocab_size()

@property
def bos_token_id(self):
"""begine of the sentence token id."""
return self.model.bos_id()

@property
def eos_token_id(self):
"""end of the sentence token id."""
return self.model.eos_id()

@property
def prefix_space_tokens(self):
"""tokens without prefix space."""
if self._prefix_space_tokens is None:
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
self._prefix_space_tokens = {i for i, tok in enumerate(vocab) if tok.startswith('▁')}
return self._prefix_space_tokens

def _maybe_add_prefix_space(self, tokens, decoded):
"""maybe add prefix space for incremental decoding."""
if len(tokens) and not decoded.startswith(' ') and\
tokens[0] in self.prefix_space_tokens:
return ' ' + decoded
else:
return decoded

def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if token == ' ': # ' ' is special
token = '▁'
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
indexes = [i for i, voc in enumerate(vocab) if token in voc]
if len(indexes) > self.max_indexes_num:
indexes = self.encode(token, add_bos=False)[-1:]
self.logger.warning(f'There are too many(>{self.max_indexes_num}) possible '
f'indexes may decoding {token}, we will use {indexes} only')
self._indexes_tokens_deque.append((token, indexes))
return indexes

def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
return self.model.Encode(s, add_bos=add_bos, **kwargs)

def decode(self, t: Sequence[int], offset: Optional[int] = None, skip_special_tokens: bool = True, **kwargs):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
skip_special_tokens (boo): not used in SentencePieceTokenizer.
Returns:
str: text of decoding tokens
"""
if isinstance(t, torch.Tensor):
t = t.tolist()
t = t[offset:]
out_string = self.model.Decode(t)
if offset:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string

def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Default to be True.
Returns:
str: decoding output string of the current round.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
"""
out_string = self.model.Decode(all_input_ids)
if state.prev_tokens is not None:
out_string = self._maybe_add_prefix_space(all_input_ids, out_string)
state.prev_tokens = [] # not None for the above condition
return out_string, state

def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
Args:
s (str): prompts
Returns:
list[int]: token ids
"""
import addict
add_bos = False
add_eos = False

input_ids = self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
return addict.Addict(input_ids=input_ids)


class HuggingFaceTokenizer:
"""Tokenizer of sentencepiece.
"""A wrapper of transformers' AutoTokenizer.
Args:
model_dir (str): the directory of the tokenizer model
@@ -521,33 +378,19 @@ class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Args:
model_file (str): the path of the tokenizer model
model_path (str): the path of the tokenizer model
"""

def __init__(self, model_file: str):
if model_file.endswith('.model'):
model_folder = osp.split(model_file)[0]
def __init__(self, model_path: str):
from transformers.models.auto.tokenization_auto import get_tokenizer_config
tokenizer_config = get_tokenizer_config(model_path, trust_remote_code=True)
config_tokenizer_class = tokenizer_config.get('tokenizer_class')
if config_tokenizer_class == 'ChatGLM4Tokenizer':
self.model = ChatGLM4Tokenizer(model_path)
elif config_tokenizer_class == 'ChatGLMTokenizer':
self.model = ChatGLMTokenizer(model_path)
else:
model_folder = model_file
model_file = osp.join(model_folder, 'tokenizer.model')
tokenizer_config_file = osp.join(model_folder, 'tokenizer_config.json')

model_file_exists = osp.exists(model_file)
config_exists = osp.exists(tokenizer_config_file)
use_hf_model = config_exists or not model_file_exists
self.logger = get_logger('lmdeploy')
if not use_hf_model:
self.model = SentencePieceTokenizer(model_file)
else:
from transformers.models.auto.tokenization_auto import get_tokenizer_config
tokenizer_config = get_tokenizer_config(model_folder, trust_remote_code=True)
config_tokenizer_class = tokenizer_config.get('tokenizer_class')
if config_tokenizer_class == 'ChatGLM4Tokenizer':
self.model = ChatGLM4Tokenizer(model_folder)
elif config_tokenizer_class == 'ChatGLMTokenizer':
self.model = ChatGLMTokenizer(model_folder)
else:
self.model = HuggingFaceTokenizer(model_folder)
self.model = HuggingFaceTokenizer(model_path)

@property
def vocab_size(self):
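After this change, Tokenizer takes a model directory (or repo id) rather than a .model file: it reads tokenizer_config.json via get_tokenizer_config, dispatches to the ChatGLM wrappers when the config names them, and otherwise falls back to HuggingFaceTokenizer; the SentencePiece path is removed entirely. A minimal usage sketch, assuming a hypothetical HF-format model directory and that the facade's encode/decode methods are unchanged by this commit:

    from lmdeploy.tokenizer import Tokenizer

    # Hypothetical path; it must be an HF-format directory now, since a bare
    # sentencepiece tokenizer.model file is no longer accepted.
    tok = Tokenizer('./models/internlm2-chat-7b')

    print(tok.vocab_size)               # property kept on the facade
    ids = tok.encode('hello world')     # assumed to proxy to the underlying HF tokenizer
    text = tok.decode(ids)
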