From 4aa4ebdda53809a0f159627010164fbb13fb02c6 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Sep 2024 14:47:40 -0400 Subject: [PATCH 1/2] Add ExLlamaV2Sampler.Settings.logits_processor --- examples/json_schema_outlines.py | 108 ++++++++++ exllamav2/generator/base.py | 3 + exllamav2/generator/dynamic.py | 5 +- exllamav2/generator/sampler.py | 13 ++ exllamav2/generator/streaming.py | 22 ++- tests/test_logits_processor.py | 326 +++++++++++++++++++++++++++++++ 6 files changed, 469 insertions(+), 8 deletions(-) create mode 100644 examples/json_schema_outlines.py create mode 100644 tests/test_logits_processor.py diff --git a/examples/json_schema_outlines.py b/examples/json_schema_outlines.py new file mode 100644 index 00000000..18de5e8e --- /dev/null +++ b/examples/json_schema_outlines.py @@ -0,0 +1,108 @@ +# Install Outlines: +# pip install outlines + +# Download Model: +# huggingface-cli download bartowski/Phi-3.1-mini-4k-instruct-exl2 --revision 6_5 --local-dir Phi-3.1-mini-4k-instruct-exl2-6_5 + +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + + +from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer +from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler + +from outlines.processors import JSONLogitsProcessor +from outlines.models.exllamav2 import patch_tokenizer as patch_exl2_tokenizer_for_outlines + +from pydantic import BaseModel, Field, RootModel +from typing import Optional, Union, Literal +from datetime import time + + +################################################ +# Create Structured JSON Generator With Outlines +################################################ + +# Additional Examples: https://outlines-dev.github.io/outlines/cookbook/ +# JSON Generation Docs: https://outlines-dev.github.io/outlines/reference/json/ +# `outlines.processors` also supports guaranteed regex patterns and lark grammars + +# Example: Home Assistant extension for natural language commands -> actions +class LightAction(BaseModel): + entity: Literal["light"] = "light" + action: Literal["turn_on", "turn_off", "set_brightness"] + brightness: Optional[int] = Field(None, ge=0, le=100) + execute_at: Optional[time] = None + + +class OvenAction(BaseModel): + entity: Literal["oven"] = "oven" + action: Literal["turn_on", "turn_off", "set_temperature"] + temperature: Optional[float] = Field(None, ge=50, le=300) + execute_at: Optional[time] = None + + +class HomeAssistantAction(BaseModel): + instruction: Union[LightAction, OvenAction] + + +def create_generator(model_dir="/mnt/str/models/mistral-7b-exl2/4.0bpw"): + config = ExLlamaV2Config(model_dir) + config.arch_compat_overrides() + model = ExLlamaV2(config) + cache = ExLlamaV2Cache(model, max_seq_len=32768, lazy=True) + model.load_autosplit(cache, progress=True) + + print("Loading tokenizer...") + tokenizer = ExLlamaV2Tokenizer(config) + tokenizer.vocabulary = tokenizer.extended_piece_to_id + + # Initialize the generator with all default parameters + return ExLlamaV2DynamicGenerator( + model=model, + cache=cache, + tokenizer=tokenizer, + ) + + +generator = create_generator("./Phi-3.1-mini-4k-instruct-exl2-6_5") + +gen_settings = ExLlamaV2Sampler.Settings() +gen_settings.logits_processor = JSONLogitsProcessor( + HomeAssistantAction, + patch_exl2_tokenizer_for_outlines(generator.tokenizer) +) + + +rules = "JSON for an instruction with an entity (light or oven) and action (turn_on, turn_off, set_brightness, set temperature). 
*Optionally* you may set an execute_at time-of-day if the user specifies, otherwise set to null" +prompts = [ + f"<|user|> {rules} Turn the lights lower please!<|end|><|assistant|>", + f"<|user|> {rules} I need the oven set for homemade pizza when I get home from work at 6PM.<|end|><|assistant|>", + f"<|user|> {rules} Oh no the lights are off and I can't find the switch!<|end|><|assistant|>", +] + +outputs = generator.generate( + prompt=prompts, + gen_settings=gen_settings, + max_new_tokens=2048, + completion_only=True, + encode_special_tokens=False, + stop_conditions=[generator.tokenizer.eos_token_id], +) + +# raw json format +for idx, output in enumerate(outputs): + print(output) +# Output: +# {"instruction": {"entity": "light", "action": "set_brightness", "execute_at": null}} +# {"instruction": {"entity": "oven", "action": "set_temperature", "execute_at": "18:00:00"} } +# {"instruction": {"entity": "light", "action": "turn_on"}} + +# pydantic model format +for idx, output in enumerate(outputs): + print(repr(HomeAssistantAction.parse_raw(output))) +# Output: +# HomeAssistantAction(instruction=LightAction(entity='light', action='set_brightness', brightness=None, execute_at=None)) +# HomeAssistantAction(instruction=OvenAction(entity='oven', action='set_temperature', temperature=None, execute_at=datetime.time(18, 0))) +# HomeAssistantAction(instruction=LightAction(entity='light', action='turn_on', brightness=None, execute_at=None)) diff --git a/exllamav2/generator/base.py b/exllamav2/generator/base.py index 236735e0..60d5a499 100644 --- a/exllamav2/generator/base.py +++ b/exllamav2/generator/base.py @@ -193,6 +193,8 @@ def generate_simple( return_offsets = True, add_bos = add_bos) + pre_ids = torch.empty(*ids.shape[:-1], 0) + if prompts_identical: position_offsets = None @@ -268,6 +270,7 @@ def generate_simple( ExLlamaV2Sampler.sample( logits, gen_settings, + pre_ids, self.sequence_ids, random.random(), self.tokenizer, diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index 03da88d6..5717b294 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -74,7 +74,7 @@ class CachePage: kv_position: int kv_position_revert: int # Specific tokens for which KV is valid assuming prev_hash - sequence: torch.Tensor + sequence: torch.Tensors can_revert: bool # Used by defragmenter new_page_index: int @@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None): self.current_loras = loras else: self.current_loras = [loras] - + def generate( self, @@ -1766,6 +1766,7 @@ def receive_logits( ExLlamaV2Sampler.sample( logits, self.gen_settings, + self.sequences[0].input_ids.torch(), self.sequences[0].sequence_ids.torch(), self.rng.random(), self.generator.tokenizer, diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index 9f795c8f..a6bc113e 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -1,5 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Optional, Callable import torch import torch.nn.functional as F from exllamav2 import ExLlamaV2Tokenizer @@ -71,6 +72,8 @@ class Settings: typical: float = 0 skew: float = 0 + logits_processor: Optional[Callable] = None + temperature_last: bool = False mirostat: bool = False @@ -269,6 +272,7 @@ def apply_dry( def sample( logits: torch.tensor, settings: Settings, + input_ids: torch.tensor, sequence_ids: torch.tensor, random: float, tokenizer: ExLlamaV2Tokenizer, @@ -289,6 +293,9 @@ def sample( 
:param settings: ExLlamaV2Sampler.Settings + :param input_ids: + The prompt portion of sequence_ids, shape (batch_size, seq_len) + :param sequence_ids: Past token IDs to consider for repetition penalty etc., shape (batch_size, seq_len) @@ -354,6 +361,12 @@ def sample( logits = logits.unsqueeze(0) batch_size = 1 + # Apply logits processor + + if settings.logits_processor: + generated_ids = sequence_ids[:, input_ids.shape[1]:] + logits = settings.logits_processor(generated_ids, logits) + # Prepare filter logit_filter = None diff --git a/exllamav2/generator/streaming.py b/exllamav2/generator/streaming.py index 8ff3b965..adde9692 100644 --- a/exllamav2/generator/streaming.py +++ b/exllamav2/generator/streaming.py @@ -327,6 +327,7 @@ def begin_stream_ex( assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1" if input_ids.shape[0] == 2: assert gen_settings.cfg_scale is not None, "No CFG scale set" + self.input_ids = input_ids self.position_offsets = position_offsets self.input_mask = input_mask @@ -500,7 +501,7 @@ def stream(self, **kwargs) -> Union[Tuple[str, bool, torch.Tensor], if self.return_logits: ret.append(logits) - + return tuple(ret) @@ -819,6 +820,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None): token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample( logits, gen_settings, + self.input_ids, self.sequence_ids[:1, :], random.random(), self.tokenizer, @@ -854,12 +856,12 @@ def _gen_single_token(self, gen_settings, prefix_token = None): for f in self.filters: f.feed(token) # Accept token - + if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1: self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1) else: self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1) - + return token, ptokens, pprobs, prob, eos, logits.flatten(1), dev_logits @@ -881,7 +883,15 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None): self.draft_cache, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu() - token, _, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None) + token, _, _, prob, _ = ExLlamaV2Sampler.sample( + logits, + draft_gen_settings, + self.input_ids, + draft_sequence_ids, + random.random(), + self.tokenizer, + prefix_token if k == 0 else None + ) if prob < self.speculative_prob_threshold: self.draft_cache.current_seq_len -= 1 @@ -918,6 +928,7 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None): token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample( logits, gen_settings, + self.input_ids, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, @@ -980,6 +991,7 @@ def _gen_single_token_ngram(self, gen_settings, prefix_token = None): token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample( logits, gen_settings, + self.input_ids, self.sequence_ids[:1, :], random.random(), self.tokenizer, @@ -1038,5 +1050,3 @@ def ngram_preload(self, self.ngram_preloaded = NgramCache(self.speculative_ngram_min, self.speculative_ngram_max, None) self.ngram_preloaded.update(input_ids) - - diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py new file mode 100644 index 00000000..15d403d8 --- /dev/null +++ b/tests/test_logits_processor.py @@ -0,0 +1,326 @@ + +import sys, os, gc, time, random +import torch + 
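+# NOTE: SamplerLogitsProcessor, defined below in this file, is a self-contained
+# reference processor (temperature scaling, top-k, top-p, top-a, token bans) that
+# implements the (generated_ids, logits) -> logits callable contract expected by
+# ExLlamaV2Sampler.Settings.logits_processor.
+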
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from exllamav2 import( + ExLlamaV2, + ExLlamaV2Config, + ExLlamaV2CacheBase, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ExLlamaV2Tokenizer, +) + +from exllamav2.generator import ( + ExLlamaV2BaseGenerator, + ExLlamaV2StreamingGenerator, + ExLlamaV2Sampler +) + +import time + +model: ExLlamaV2 +config: ExLlamaV2Config +tokenizer: ExLlamaV2Tokenizer +cache: ExLlamaV2CacheBase + + +class SamplerLogitsProcessor: + def __init__( + self, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + top_a: float = 0.0, + disallow_tokens: list = [], + ): + if temperature <= 0: + raise ValueError("Temperature must be > 0.") + if top_k < 0: + raise ValueError("top_k must be >= 0.") + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be between 0 and 1.") + if top_a < 0: + raise ValueError("top_a must be >= 0.") + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + self.top_a = top_a + self.disallow_tokens = torch.tensor(disallow_tokens) + + @torch.no_grad() + def __call__(self, input_ids, logits): + # Apply temperature scaling + if self.temperature != 1.0: + logits /= self.temperature + + # Initialize mask + mask = torch.zeros_like(logits, dtype=torch.bool) + + if self.top_p > 0.0 or self.top_a > 0.0: + probs = torch.nn.functional.softmax(logits, dim=-1) + sorted_probs, sorted_indices = probs.sort(descending=True, dim=-1) + + # Apply top-p filtering + if self.top_p > 0.0: + cumulative_probs = sorted_probs.cumsum(dim=-1) + sorted_mask = cumulative_probs > self.top_p + sorted_mask = sorted_mask.roll(shifts=1, dims=-1) + sorted_mask[..., 0] = False + mask.scatter_(-1, sorted_indices, sorted_mask) + + # Apply top-a filtering + if self.top_a > 0.0: + max_probs = sorted_probs[:, 0].unsqueeze(-1) + mask |= probs < (max_probs / self.top_a) + + # top-k: logits > kth largest value's logits + if self.top_k > 0: + threshold = logits.topk(self.top_k, dim=-1, largest=True).values[..., -1, None] + mask |= logits < threshold # Compare logits directly with the threshold + + # Filter disallowed tokens + if self.disallow_tokens is not None and self.disallow_tokens.numel() > 0: + self.disallow_tokens = self.disallow_tokens.to(device=logits.device) + mask.index_fill_(-1, self.disallow_tokens, True) + + # Apply the mask + logits.masked_fill_(mask, -float("inf")) + return logits + + + +def unload(): + global model, config, tokenizer, cache + + model.unload() + model = None + config = None + cache = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + + +def load_model(model_dir, split = None, cache_8bit = False): + global model, config, tokenizer, cache + + config = ExLlamaV2Config() + config.model_dir = model_dir + config.prepare() + + model = ExLlamaV2(config) + print(" -- Loading model: " + model_dir) + + model.load(split) + + tokenizer = ExLlamaV2Tokenizer(config) + + if cache_8bit: + print(" -- Creating 8-bit cache") + cache = ExLlamaV2Cache_8bit(model, batch_size = 4) + else: + print(" -- Creating 16-bit cache") + cache = ExLlamaV2Cache(model, batch_size = 4) + + +def test_gen_normal(prompt, max_new_tokens): + global model, config, tokenizer, cache + + print("--------------------------------") + print("Generating, normal") + print() + + generator = ExLlamaV2BaseGenerator(model, cache, tokenizer) + + settings = ExLlamaV2Sampler.Settings() + settings.logits_processor = SamplerLogitsProcessor( + temperature=0.85, + top_k=50, + top_p=0.8, + top_a=0.0, + 
disallow_tokens=[tokenizer.eos_token_id], + ) + + generator.warmup() + time_begin = time.time() + + output = generator.generate_simple(prompt, settings, max_new_tokens, seed=1234) + + time_end = time.time() + time_total = time_end - time_begin + + print(output) + print() + print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, {max_new_tokens / time_total:.2f} tokens/second") + + +def test_gen_streaming(prompt, max_new_tokens): + global model, config, tokenizer, cache + + print("--------------------------------") + print("Generating, streaming") + print() + + generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) + + settings = ExLlamaV2Sampler.Settings() + settings.logits_processor = SamplerLogitsProcessor( + temperature=0.85, + top_k=50, + top_p=0.8, + top_a=0.0, + disallow_tokens=[tokenizer.eos_token_id], + ) + + input_ids = tokenizer.encode(prompt) + prompt_tokens = input_ids.shape[-1] + + print(prompt, end = "") + sys.stdout.flush() + + time_begin_prompt = time.time() + + generator.set_stop_conditions([]) + generator.begin_stream(input_ids, settings) + + time_begin_stream = time.time() + generated_tokens = 0 + + while True: + chunk, eos, _ = generator.stream() + generated_tokens += 1 + print(chunk, end = "") + sys.stdout.flush() + if eos or generated_tokens == max_new_tokens: break + + time_end = time.time() + + time_prompt = time_begin_stream - time_begin_prompt + time_tokens = time_end - time_begin_stream + + print() + print() + print(f"Prompt processed in {time_prompt:.2f} seconds, {prompt_tokens} tokens, {prompt_tokens / time_prompt:.2f} tokens/second") + print(f"Response generated in {time_tokens:.2f} seconds, {generated_tokens} tokens, {generated_tokens / time_tokens:.2f} tokens/second") + + +def test_gen_batch(max_new_tokens): + + print("--------------------------------") + print("Generating, batched") + print() + + generator = ExLlamaV2BaseGenerator(model, cache, tokenizer) + + settings = ExLlamaV2Sampler.Settings() + settings.logits_processor = SamplerLogitsProcessor( + temperature=0.85, + top_k=50, + top_p=0.8, + top_a=0.0, + disallow_tokens=[tokenizer.eos_token_id], + ) + + generator.warmup() + time_begin = time.time() + + prompts = ["Here's how to create a powerful love potio", + "For once,", + "The events of the American Civil W", + "A bird in the hand is worth"] + + output = generator.generate_simple(prompts, settings, max_new_tokens, seed = 1234, token_healing = True) + + time_end = time.time() + time_total = time_end - time_begin + + for o in output: + print(o) + print("---") + print() + print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, throughput {4 * max_new_tokens / time_total:.2f} tokens/second") + + +def test_multicache(max_new_tokens): + + print("--------------------------------") + print("Generating, batched multi cache") + print() + + settings = ExLlamaV2Sampler.Settings() + settings.logits_processor = SamplerLogitsProcessor( + temperature=0.85, + top_k=50, + top_p=0.8, + top_a=0.0, + disallow_tokens=[tokenizer.eos_token_id], + ) + + prompts = ["Here's how to create a powerful love potion", + "For once,", + "The events of the American Civil War", + "A bird in the hand is worth"] + + caches = [ExLlamaV2Cache(model, max_seq_len = 256) for _ in range(len(prompts))] + input_ids = [] + + for i in range(len(prompts)): + + input_ids.append(tokenizer.encode(prompts[i])) + model.forward(input_ids[i][:, :-1], caches[i], input_mask = None, preprocess_only = True) + + time_begin = time.time() + + for i in 
range(max_new_tokens): + + inputs = torch.cat([x[:, -1:] for x in input_ids], dim = 0) + logits = model.forward(inputs, caches, input_mask = None).float().cpu() + + r = random.random() + for j in range(len(input_ids)): + token, _, _ = ExLlamaV2Sampler.sample(logits[j:j + 1, :, :], settings, input_ids[j], r, tokenizer) + input_ids[j] = torch.cat([input_ids[j], token], dim = 1) + + output = [tokenizer.decode(ids)[0] for ids in input_ids] + + time_end = time.time() + time_total = time_end - time_begin + + for o in output: + print(o) + print("---") + print() + print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, throughput {4 * max_new_tokens / time_total:.2f} tokens/second") + + +def tests(model_dir, cache_8bit, use_split): + + if use_split: split = [1, 24] + else: split = None + print("--------------------------------") + print(f" -- Split: {split}") + load_model(model_dir, split = split, cache_8bit = cache_8bit) + + test_gen_normal("Our story begins in the Scottish town of Auchtermuchty, where once", 150) + test_gen_streaming("Our story begins in the Scottish town of Auchtermuchty, where once", 150) + test_gen_batch(40) + if model.is_quant(): test_multicache(40) + + unload() + + +q_model_directory = "/mnt/str/models/mistral-7b-instruct-exl2/4.0bpw/" +f_model_directory = "/mnt/str/models/tinyllama-1b-ckpt503/" + +tests(q_model_directory, False, False) +tests(q_model_directory, False, True) +tests(q_model_directory, True, False) +tests(q_model_directory, True, True) +tests(f_model_directory, False, False) +tests(f_model_directory, False, True) +tests(f_model_directory, True, False) +tests(f_model_directory, True, True) From ce08f16674a67dac9ea6a770650eb02248b8364a Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 4 Oct 2024 18:01:29 -0400 Subject: [PATCH 2/2] consistent shape between logits and generated_ids --- exllamav2/generator/sampler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index a6bc113e..0aee2a5b 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -364,8 +364,12 @@ def sample( # Apply logits processor if settings.logits_processor: - generated_ids = sequence_ids[:, input_ids.shape[1]:] - logits = settings.logits_processor(generated_ids, logits) + generated_ids = sequence_ids[:, input_ids.shape[-1]:] + # normalize to 2d + logits_2d = logits.view(-1, logits.shape[-1]) + generated_ids_2d = generated_ids.view(logits_2d.shape[0], generated_ids.shape[-1]) + # process logits and convert back to original logits shape + logits = settings.logits_processor(generated_ids_2d, logits_2d).view(logits.shape) # Prepare filter
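
Taken together, the two commits define one contract for Settings.logits_processor: a callable that receives the generated token IDs (the portion of sequence_ids after the prompt) and the logits, both normalized to 2-D shapes (batch, generated_len) and (batch, vocab_size), and returns logits of the same shape. Because generated_ids excludes the prompt, stateful processors such as Outlines' JSONLogitsProcessor in examples/json_schema_outlines.py only ever see the tokens they have constrained so far. The sketch below is a minimal way to plug a custom processor into the dynamic generator without Outlines; it assumes this patch is applied, and the model directory and the BanEOSProcessor name are illustrative only, not part of the patch.

# Minimal usage sketch (assumes this patch is applied; the model path and
# BanEOSProcessor are illustrative, everything else comes from the patched API)

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler


class BanEOSProcessor:
    # Matches the logits_processor contract in sampler.py:
    # __call__(generated_ids, logits) -> logits, both 2-D as of PATCH 2/2
    def __init__(self, banned_token_ids):
        self.banned_token_ids = banned_token_ids

    def __call__(self, generated_ids, logits):
        # Mask out banned tokens; logits has shape (batch, vocab_size)
        logits[:, self.banned_token_ids] = -float("inf")
        return logits


config = ExLlamaV2Config("/path/to/model-exl2")   # hypothetical model directory
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2DynamicGenerator(model = model, cache = cache, tokenizer = tokenizer)

gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.logits_processor = BanEOSProcessor([tokenizer.eos_token_id])

output = generator.generate(
    prompt = "Once upon a time,",
    gen_settings = gen_settings,
    max_new_tokens = 64,
    completion_only = True,
)
print(output)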