Refactor token healing initialization. #330

Open · wants to merge 3 commits into master
exllamav2/generator/streaming.py (91 changes: 40 additions, 51 deletions)
@@ -35,8 +35,8 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
     no_probs: torch.Tensor = None
     no_logits: torch.Tensor = None

-    first_token = False
-    heal_next_token = False
+    heal_prefix_token = None
+    heal_old_tail_len = None

     draft_model: ExLlamaV2 or None = None
     draft_cache: ExLlamaV2Cache or None = None
@@ -53,10 +53,12 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
     return_probabilities_k: int = 1  # Number of probabilities to return per token
     return_logits: bool = False      # Return raw logits prior to softmax, per token

-    active_loras = []
+    active_loras = None
     position_offsets = None
     input_mask = None

+    queued_logits = None
+

     def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = None, num_speculative_tokens = 5):
         super().__init__(model, cache, tokenizer)
@@ -119,7 +121,23 @@ def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.S
         self.settings = gen_settings
         self._gen_begin_reuse(input_ids, gen_settings)

-        self.heal_next_token = (token_healing and self.sequence_ids.shape[-1] >= 2)
+        self.queued_logits = []
+
+        # Initialize token healing
+        if token_healing and self.sequence_ids.shape[-1] >= max(2, self.tail_decode_tokens + 1):
+
+            # Pop the last token, remembering tail len for first stream decode
+
+            self.heal_old_tail_len = len(self.tokenizer.decode(self.sequence_ids[:, -(self.tail_decode_tokens + 1):])[0])
+            self.heal_prefix_token = self.sequence_ids[:, -1:]
+            self.sequence_ids = self.sequence_ids[:, :-1]
+            self.cache.current_seq_len -= 1
+
+            # Start filters
+
+            self.settings.begin_filters(self.tokenizer.get_id_to_piece_list()[self.heal_prefix_token])
+        else:
+            self.settings.begin_filters()


     def stream(self) -> Union[Tuple[str, bool, torch.Tensor],
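
For context: token healing pops the last prompt token and regenerates it with its decoded piece as a forced prefix, so a prompt that ends mid-word (and therefore usually mid-token) is not locked to an awkward token boundary. A minimal usage sketch, assuming model, cache and tokenizer are already constructed as in the exllamav2 examples (the prompt string is illustrative):

from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

# model, cache and tokenizer are assumed to exist; see the exllamav2 examples.
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()

# "fo" likely ends mid-token; with healing, the final token is re-sampled
# from tokens whose pieces start with the popped piece.
input_ids = tokenizer.encode("The quick brown fo")
generator.begin_stream(input_ids, settings, token_healing = True)

while True:
    chunk, eos, _ = generator.stream()
    print(chunk, end = "")
    if eos: break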
@@ -145,53 +163,16 @@ def stream(self) -> Union[Tuple[str, bool, torch.Tensor],

     def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):

-        # Token healing
-
-        if self.heal_next_token:
-
-            # Pop the last token
-
-            old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
-            last_token = self.sequence_ids[:, -1:]
-            self.sequence_ids = self.sequence_ids[:, :-1]
-            self.cache.current_seq_len -= 1
-
-            # Start filters
-
-            if self.first_token:
-
-                self.settings.begin_filters(self.tokenizer.get_id_to_piece_list()[last_token])
-                self.first_token = False
-
-            # Regenerate the last token again, with prefix
-
-            healed_token, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
-            new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
-            self.held_text += new_tail[len(old_tail):]
-
-            self.heal_next_token = False
-
-            # In case we only needed the healed token
-
-            if eos: return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
-
-        # Start filters when not healing
-
+        if self.heal_old_tail_len is not None:
+            old_tail_len = self.heal_old_tail_len
+            self.heal_old_tail_len = None
         else:
-
-            if self.first_token:
-
-                self.settings.begin_filters()
-                self.first_token = False
-
-
-        # Decode the current tail end of the sequence
-
-        old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
+            old_tail_len = len(self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0])

         # Generate a single token and append to the sequence

-        next_token, next_ptokens, next_prob, eos, next_logits = self._gen_single_token(self.settings)
+        next_token, next_ptokens, next_prob, eos, next_logits = self._gen_single_token(self.settings, prefix_token = self.heal_prefix_token)
+        self.heal_prefix_token = None

         # End immediately if it was a stop token
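
With this refactor the healing state lives in heal_prefix_token rather than a first_token flag: on the first _stream() call after begin_stream(), the sampler receives the popped token as prefix_token and may only pick tokens whose pieces extend its text; the field is then cleared so later calls sample normally. A toy sketch of what that prefix constraint means; the real logic is inside ExLlamaV2Sampler.sample(), so treat this as an illustration only:

import torch

def sample_with_prefix(logits, id_to_piece, prefix_piece):
    # Mask out every token whose piece does not extend the popped piece,
    # then sample from the remaining distribution.
    mask = torch.full_like(logits, float("-inf"))
    for token_id, piece in enumerate(id_to_piece):
        if piece.startswith(prefix_piece):
            mask[token_id] = 0.0
    probs = torch.softmax(logits + mask, dim = -1)
    return torch.multinomial(probs, 1).item()

pieces = ["fog", "fox", "food", "bar"]
logits = torch.tensor([0.2, 1.5, 0.3, 2.0])
print(pieces[sample_with_prefix(logits, pieces, "fo")])  # never "bar"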

@@ -201,7 +182,7 @@ def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch
         # Decode the tail end of the sequence with the added token to get (actual) characters added

         new_tail = self.tokenizer.decode(self.sequence_ids[:1, -(self.tail_decode_tokens + 1):])[0]
-        new_text = new_tail[len(old_tail):]
+        new_text = new_tail[old_tail_len:]

         next_token, new_text = self._catch_utf8(next_token, new_text)
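
Decoding a multi-token tail instead of single tokens is needed because, with SentencePiece-style vocabularies, the text a token contributes depends on its neighbors. Slicing by the remembered length (old_tail_len) rather than by the old tail string also stays correct when healing has just replaced the tail's last token. A toy illustration with a made-up decode() standing in for the real tokenizer:

TAIL = 2  # stand-in for tail_decode_tokens

def decode(ids):
    # Hypothetical detokenizer: the emitted text depends on token context.
    vocab = {1: " quick", 2: " bro", 3: "wn"}
    return "".join(vocab[i] for i in ids)

sequence = [1, 2]
old_tail_len = len(decode(sequence[-TAIL:]))  # len(" quick bro") == 10

sequence.append(3)                            # the model samples a token
new_tail = decode(sequence[-(TAIL + 1):])     # " quick brown"
print(new_tail[old_tail_len:])                # "wn" -- only the newly added text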

@@ -321,8 +302,6 @@ def _gen_begin(self, in_tokens, gen_settings):
         self.future_logits = None
         self.future_tokens = None

-        self.first_token = True
-


     def _gen_begin_reuse(self, in_tokens, gen_settings):

@@ -367,11 +346,21 @@ def _gen_feed_tokens(self, in_tokens, gen_settings):
         self.future_tokens = None


+    def append_logits(self, logits):
+
+        assert self.draft_model is None
+        assert logits.shape[0] == self.sequence_ids.shape[0]
+
+        self.queued_logits.append(logits)
+
     def _gen_single_token(self, gen_settings, prefix_token = None):

         if self.draft_model is None:

-            logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
+            if self.queued_logits:
+                logits = self.queued_logits.pop()
+            else:
+                logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
             token, ptokens, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_probabilities_k)

         else:
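
The new append_logits() hook and queued_logits list let a caller hand the generator precomputed logits: while the queue is non-empty, _gen_single_token() samples from a queued tensor instead of running a forward pass. The PR does not show a caller, so the following is only an illustrative sketch; precomputed_logits is a hypothetical tensor of shape (batch_size, 1, vocab_size), matching what model.forward() returns for a single position:

# begin_stream() resets queued_logits, so queue only after calling it.
generator.begin_stream(input_ids, settings)
generator.append_logits(precomputed_logits)

# The first stream() call now samples from the queued logits.
chunk, eos, _ = generator.stream()

Note that _gen_single_token() takes tensors off the queue with list.pop(), i.e. most recently queued first, so queueing a single tensor at a time is the natural usage pattern.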