Update logit mixing example, rename to CFG
turboderp committed Jul 24, 2023
1 parent b2e3982 commit e8a544f
Showing 1 changed file with 10 additions and 8 deletions.
example_logit_mixing.py → example_cfg.py (18 changes: 10 additions & 8 deletions)
@@ -2,6 +2,7 @@
 from tokenizer import ExLlamaTokenizer
 from generator import ExLlamaGenerator
 import torch
+import torch.nn.functional as F
 import os, glob
 import cuda_ext

@@ -20,7 +21,6 @@

 config = ExLlamaConfig(model_config_path)  # create config from config.json
 config.model_path = model_path             # supply path to model weights file
-config.max_input_len = 16

 model = ExLlama(config)                       # create ExLlama instance and load the weights
 tokenizer = ExLlamaTokenizer(tokenizer_path)  # create tokenizer from tokenizer model file
@@ -31,10 +31,10 @@
 # Configure generator

 generator.settings.token_repetition_penalty_max = 1.15
-generator.settings.temperature = 0.75
+generator.settings.temperature = 0.95
 generator.settings.top_k = 40
-generator.settings.top_p = 0.65
-# generator.settings.typical = 0.5
+generator.settings.top_p = 0.75
+# generator.settings.typical = 0.95

 # Prompts to mix

@@ -46,28 +46,30 @@

 f2 = \
 """[INST] <<SYS>>
-You are a rude and obnoxious assistant. You hate everything and everyone.
 <</SYS>>
+You are a rude and obnoxious assistant. You hate everything and everyone.
 {prompt}[/INST]"""

+
 prompts = \
 [
     f1.replace("{prompt}", "Tell me about Homer Simpson"),
     f2.replace("{prompt}", "Tell me about Homer Simpson"),
 ]

-def mixed_generation(prompts, alpha, max_new_tokens):
+def generate_cfg(prompts, alpha, max_new_tokens):

     ids, mask = tokenizer.encode(prompts, return_mask = True)
     generator.gen_begin(ids, mask = mask)

     # Sampling loop

-    for i in range(max_new_tokens):
+    for _ in range(max_new_tokens):

         logits = model.forward(generator.sequence[:, -1:], cache, input_mask = mask)
         generator.apply_rep_penalty(logits)

+        logits = F.log_softmax(logits, dim = -1)
         logits_mixed = (1 - alpha) * logits[0] + alpha * logits[1]

         sampled_token, _ = generator.sample_current(logits_mixed)
@@ -86,5 +88,5 @@ def mixed_generation(prompts, alpha, max_new_tokens):
     print(f"--------------------------------------")
     print(f"alpha = {alpha:.1f}")
     print(f"--------------------------------------")
-    output = mixed_generation(prompts, alpha, 200)
+    output = generate_cfg(prompts, alpha, 200)
     print(output[len(prompts[0]):].strip())
