Add ExLlamaV2Sampler.Settings.logits_processor #634

Open · wants to merge 2 commits into master
108 changes: 108 additions & 0 deletions examples/json_schema_outlines.py
@@ -0,0 +1,108 @@
# Install Outlines:
# pip install outlines

# Download Model:
# huggingface-cli download bartowski/Phi-3.1-mini-4k-instruct-exl2 --revision 6_5 --local-dir Phi-3.1-mini-4k-instruct-exl2-6_5

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))



from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler

from outlines.processors import JSONLogitsProcessor
from outlines.models.exllamav2 import patch_tokenizer as patch_exl2_tokenizer_for_outlines

from pydantic import BaseModel, Field, RootModel
from typing import Optional, Union, Literal
from datetime import time


################################################
# Create Structured JSON Generator With Outlines
################################################

# Additional Examples: https://outlines-dev.github.io/outlines/cookbook/
# JSON Generation Docs: https://outlines-dev.github.io/outlines/reference/json/
# `outlines.processors` also supports guaranteed regex patterns and lark grammars

# Example: Home Assistant extension for natural language commands -> actions
class LightAction(BaseModel):
    entity: Literal["light"] = "light"
    action: Literal["turn_on", "turn_off", "set_brightness"]
    brightness: Optional[int] = Field(None, ge=0, le=100)
    execute_at: Optional[time] = None


class OvenAction(BaseModel):
    entity: Literal["oven"] = "oven"
    action: Literal["turn_on", "turn_off", "set_temperature"]
    temperature: Optional[float] = Field(None, ge=50, le=300)
    execute_at: Optional[time] = None


class HomeAssistantAction(BaseModel):
    instruction: Union[LightAction, OvenAction]


def create_generator(model_dir="/mnt/str/models/mistral-7b-exl2/4.0bpw"):
    config = ExLlamaV2Config(model_dir)
    config.arch_compat_overrides()
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, max_seq_len=32768, lazy=True)
    model.load_autosplit(cache, progress=True)

    print("Loading tokenizer...")
    tokenizer = ExLlamaV2Tokenizer(config)
    tokenizer.vocabulary = tokenizer.extended_piece_to_id

    # Initialize the generator with all default parameters
    return ExLlamaV2DynamicGenerator(
        model=model,
        cache=cache,
        tokenizer=tokenizer,
    )


generator = create_generator("./Phi-3.1-mini-4k-instruct-exl2-6_5")

gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.logits_processor = JSONLogitsProcessor(
    HomeAssistantAction,
    patch_exl2_tokenizer_for_outlines(generator.tokenizer)
)


rules = "JSON for an instruction with an entity (light or oven) and an action (turn_on, turn_off, set_brightness, set_temperature). *Optionally* you may set an execute_at time-of-day if the user specifies one, otherwise set it to null."
prompts = [
f"<|user|> {rules} Turn the lights lower please!<|end|><|assistant|>",
f"<|user|> {rules} I need the oven set for homemade pizza when I get home from work at 6PM.<|end|><|assistant|>",
f"<|user|> {rules} Oh no the lights are off and I can't find the switch!<|end|><|assistant|>",
]

outputs = generator.generate(
    prompt=prompts,
    gen_settings=gen_settings,
    max_new_tokens=2048,
    completion_only=True,
    encode_special_tokens=False,
    stop_conditions=[generator.tokenizer.eos_token_id],
)

# Raw JSON format
for idx, output in enumerate(outputs):
    print(output)
# Output:
# {"instruction": {"entity": "light", "action": "set_brightness", "execute_at": null}}
# {"instruction": {"entity": "oven", "action": "set_temperature", "execute_at": "18:00:00"} }
# {"instruction": {"entity": "light", "action": "turn_on"}}

# Pydantic model format
for idx, output in enumerate(outputs):
    print(repr(HomeAssistantAction.parse_raw(output)))
# Output:
# HomeAssistantAction(instruction=LightAction(entity='light', action='set_brightness', brightness=None, execute_at=None))
# HomeAssistantAction(instruction=OvenAction(entity='oven', action='set_temperature', temperature=None, execute_at=datetime.time(18, 0)))
# HomeAssistantAction(instruction=LightAction(entity='light', action='turn_on', brightness=None, execute_at=None))
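The example above notes that outlines.processors also supports regex patterns and Lark grammars. As a rough sketch (not part of this PR), the same logits_processor field could take a regex-constrained processor; this assumes Outlines exposes RegexLogitsProcessor from the same outlines.processors module with a (regex_string, tokenizer) signature analogous to JSONLogitsProcessor, which may vary between Outlines versions:

# Hypothetical regex-constrained variant of the example above.
# Assumption: RegexLogitsProcessor is importable from outlines.processors and
# takes (regex_string, tokenizer), analogous to JSONLogitsProcessor.
from outlines.processors import RegexLogitsProcessor

time_settings = ExLlamaV2Sampler.Settings()
time_settings.logits_processor = RegexLogitsProcessor(
    r"([01]\d|2[0-3]):[0-5]\d",  # 24-hour HH:MM
    patch_exl2_tokenizer_for_outlines(generator.tokenizer),
)

time_output = generator.generate(
    prompt="<|user|> Reply only with a 24-hour HH:MM time. When is noon?<|end|><|assistant|>",
    gen_settings=time_settings,
    max_new_tokens=16,
    completion_only=True,
    stop_conditions=[generator.tokenizer.eos_token_id],
)
print(time_output)  # e.g. "12:00"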
3 changes: 3 additions & 0 deletions exllamav2/generator/base.py
@@ -193,6 +193,8 @@ def generate_simple(
return_offsets = True,
add_bos = add_bos)

pre_ids = torch.empty(*ids.shape[:-1], 0)

if prompts_identical:
position_offsets = None

@@ -268,6 +270,7 @@ def generate_simple(
ExLlamaV2Sampler.sample(
logits,
gen_settings,
pre_ids,
self.sequence_ids,
random.random(),
self.tokenizer,
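A note on the pre_ids line added above: torch.empty(*ids.shape[:-1], 0) builds a zero-width (batch_size, 0) prompt tensor, so that in ExLlamaV2Sampler.sample() (see sampler.py below) generated_ids = sequence_ids[:, input_ids.shape[-1]:] covers the whole sequence. A minimal sketch of the shape arithmetic:

# Illustration only: why an empty (batch_size, 0) tensor works as the prompt slice.
import torch

ids = torch.zeros(2, 7, dtype=torch.long)       # batch of 2 prompts, 7 tokens each
pre_ids = torch.empty(*ids.shape[:-1], 0)       # shape (2, 0)
assert ids[:, pre_ids.shape[-1]:].shape == (2, 7)  # the full sequence counts as generated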
5 changes: 3 additions & 2 deletions exllamav2/generator/dynamic.py
@@ -74,7 +74,7 @@ class CachePage:
kv_position: int
kv_position_revert: int
# Specific tokens for which KV is valid assuming prev_hash
sequence: torch.Tensor
can_revert: bool
# Used by defragmenter
new_page_index: int
@@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None):
self.current_loras = loras
else:
self.current_loras = [loras]


def generate(
self,
@@ -1766,6 +1766,7 @@ def receive_logits(
ExLlamaV2Sampler.sample(
logits,
self.gen_settings,
self.sequences[0].input_ids.torch(),
self.sequences[0].sequence_ids.torch(),
self.rng.random(),
self.generator.tokenizer,
17 changes: 17 additions & 0 deletions exllamav2/generator/sampler.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Callable
import torch
import torch.nn.functional as F
from exllamav2 import ExLlamaV2Tokenizer
@@ -71,6 +72,8 @@ class Settings:
typical: float = 0
skew: float = 0

logits_processor: Optional[Callable] = None

temperature_last: bool = False

mirostat: bool = False
@@ -269,6 +272,7 @@ def apply_dry(
def sample(
logits: torch.tensor,
settings: Settings,
input_ids: torch.tensor,
sequence_ids: torch.tensor,
random: float,
tokenizer: ExLlamaV2Tokenizer,
@@ -289,6 +293,9 @@ def sample(
:param settings:
ExLlamaV2Sampler.Settings

:param input_ids:
The prompt portion of sequence_ids, shape (batch_size, seq_len)

:param sequence_ids:
Past token IDs to consider for repetition penalty etc., shape (batch_size, seq_len)

@@ -354,6 +361,16 @@ def sample(
logits = logits.unsqueeze(0)
batch_size = 1

# Apply logits processor

if settings.logits_processor:
generated_ids = sequence_ids[:, input_ids.shape[-1]:]
# normalize to 2d
logits_2d = logits.view(-1, logits.shape[-1])
generated_ids_2d = generated_ids.view(logits_2d.shape[0], generated_ids.shape[-1])
# process logits and convert back to original logits shape
logits = settings.logits_processor(generated_ids_2d, logits_2d).view(logits.shape)

# Prepare filter

logit_filter = None
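For reference, the call added in sample() above defines the contract for Settings.logits_processor: a callable that receives the generated token IDs and the logits, both flattened to 2-D, and returns logits of the same shape. A minimal hand-rolled processor (an illustration, not part of this PR or of Outlines) might look like:

# Sketch of a custom processor matching the call in sample():
#   logits = settings.logits_processor(generated_ids_2d, logits_2d)
# where generated_ids_2d is (rows, num_generated_tokens) and logits_2d is (rows, vocab_size).
# Illustration only; BanTokensProcessor is not part of this PR.
import torch

class BanTokensProcessor:
    def __init__(self, banned_token_ids: list[int]):
        self.banned = banned_token_ids

    def __call__(self, generated_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        out = logits.clone()
        out[:, self.banned] = float("-inf")  # these tokens can never be sampled
        return out

# Usage: gen_settings.logits_processor = BanTokensProcessor([bad_token_id])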
22 changes: 16 additions & 6 deletions exllamav2/generator/streaming.py
@@ -327,6 +327,7 @@ def begin_stream_ex(
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
if input_ids.shape[0] == 2:
assert gen_settings.cfg_scale is not None, "No CFG scale set"
self.input_ids = input_ids

self.position_offsets = position_offsets
self.input_mask = input_mask
@@ -500,7 +501,7 @@ def stream(self, **kwargs) -> Union[Tuple[str, bool, torch.Tensor],

if self.return_logits:
ret.append(logits)

return tuple(ret)


@@ -819,6 +820,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
logits,
gen_settings,
self.input_ids,
self.sequence_ids[:1, :],
random.random(),
self.tokenizer,
@@ -854,12 +856,12 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
for f in self.filters: f.feed(token)

# Accept token

if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
else:
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)

return token, ptokens, pprobs, prob, eos, logits.flatten(1), dev_logits


@@ -881,7 +883,15 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
self.draft_cache,
input_mask = self.input_mask,
position_offsets = self.position_offsets).float().cpu()
token, _, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
token, _, _, prob, _ = ExLlamaV2Sampler.sample(
logits,
draft_gen_settings,
self.input_ids,
draft_sequence_ids,
random.random(),
self.tokenizer,
prefix_token if k == 0 else None
)

if prob < self.speculative_prob_threshold:
self.draft_cache.current_seq_len -= 1
@@ -918,6 +928,7 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
logits,
gen_settings,
self.input_ids,
self.sequence_ids[:1, :], random.random(),
self.tokenizer,
prefix_token,
@@ -980,6 +991,7 @@ def _gen_single_token_ngram(self, gen_settings, prefix_token = None):
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
logits,
gen_settings,
self.input_ids,
self.sequence_ids[:1, :],
random.random(),
self.tokenizer,
@@ -1038,5 +1050,3 @@ def ngram_preload(self,

self.ngram_preloaded = NgramCache(self.speculative_ngram_min, self.speculative_ngram_max, None)
self.ngram_preloaded.update(input_ids)
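Since streaming.py now passes self.input_ids into every ExLlamaV2Sampler.sample() call, a processor set on the settings object should also apply when using the streaming generator. A rough usage sketch, assuming the usual ExLlamaV2StreamingGenerator API (begin_stream_ex / stream_ex), with my_processor standing in for any (generated_ids, logits) -> logits callable such as the sketch after sampler.py above:

# Illustration only: logits_processor with the streaming generator.
# model, cache and tokenizer are assumed to be loaded as in json_schema_outlines.py.
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

streamer = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = my_processor  # hypothetical callable, see note above

input_ids = tokenizer.encode("<|user|> Turn off the oven.<|end|><|assistant|>")
streamer.begin_stream_ex(input_ids, settings)

while True:
    result = streamer.stream_ex()
    print(result["chunk"], end="")
    if result["eos"]:
        break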

