diff --git a/example_logit_mixing.py b/example_cfg.py
similarity index 90%
rename from example_logit_mixing.py
rename to example_cfg.py
index 76e1df72..53fcbbb6 100644
--- a/example_logit_mixing.py
+++ b/example_cfg.py
@@ -2,6 +2,7 @@
 from tokenizer import ExLlamaTokenizer
 from generator import ExLlamaGenerator
 import torch
+import torch.nn.functional as F
 import os, glob
 import cuda_ext
 
@@ -20,7 +21,6 @@
 config = ExLlamaConfig(model_config_path)               # create config from config.json
 config.model_path = model_path                          # supply path to model weights file
-config.max_input_len = 16
 
 model = ExLlama(config)                                 # create ExLlama instance and load the weights
 tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file
 
@@ -31,10 +31,10 @@
 # Configure generator
 
 generator.settings.token_repetition_penalty_max = 1.15
-generator.settings.temperature = 0.75
+generator.settings.temperature = 0.95
 generator.settings.top_k = 40
-generator.settings.top_p = 0.65
-# generator.settings.typical = 0.5
+generator.settings.top_p = 0.75
+# generator.settings.typical = 0.95
 
 # Prompts to mix
 
@@ -46,28 +46,30 @@
 
 f2 = \
 """[INST] <<SYS>>
-You are a rude and obnoxious assistant. You hate everything and everyone.
 <</SYS>>
+You are a rude and obnoxious assistant. You hate everything and everyone.
 {prompt}[/INST]"""
+
 
 prompts = \
 [
     f1.replace("{prompt}", "Tell me about Homer Simpson"),
     f2.replace("{prompt}", "Tell me about Homer Simpson"),
 ]
 
-def mixed_generation(prompts, alpha, max_new_tokens):
+def generate_cfg(prompts, alpha, max_new_tokens):
 
     ids, mask = tokenizer.encode(prompts, return_mask = True)
     generator.gen_begin(ids, mask = mask)
 
     # Sampling loop
 
-    for i in range(max_new_tokens):
+    for _ in range(max_new_tokens):
 
         logits = model.forward(generator.sequence[:, -1:], cache, input_mask = mask)
         generator.apply_rep_penalty(logits)
 
+        logits = F.log_softmax(logits, dim = -1)
         logits_mixed = (1 - alpha) * logits[0] + alpha * logits[1]
 
         sampled_token, _ = generator.sample_current(logits_mixed)
@@ -86,5 +88,5 @@ def mixed_generation(prompts, alpha, max_new_tokens):
     print(f"--------------------------------------")
     print(f"alpha = {alpha:.1f}")
     print(f"--------------------------------------")
-    output = mixed_generation(prompts, alpha, 200)
+    output = generate_cfg(prompts, alpha, 200)
     print(output[len(prompts[0]):].strip())
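
Note on the key change above: the two batch rows are now interpolated in log-probability space (`F.log_softmax`) rather than over raw logits, so both prompts contribute on a comparable scale regardless of their logit magnitudes. Below is a minimal standalone sketch of just that mixing step in plain PyTorch; the toy tensor values and variable names are illustrative and not taken from the script:

```python
import torch
import torch.nn.functional as F

# Toy last-token logits for a batch of two prompts over a 5-token vocabulary
# (row 0 = "positive" prompt, row 1 = "negative" prompt; values are made up).
logits = torch.tensor([[ 2.0, 0.5, -1.0, 0.1, 0.3],
                       [ 0.2, 1.5, -0.5, 2.2, 0.0]])

alpha = 0.4  # guidance weight: 0.0 follows row 0 alone, 1.0 follows row 1 alone

# Normalize each row before mixing; mixing raw logits would let the row
# with larger magnitudes dominate the interpolation.
log_probs = F.log_softmax(logits, dim = -1)
mixed = (1.0 - alpha) * log_probs[0] + alpha * log_probs[1]

# The mixed scores behave like ordinary logits for sampling; softmax
# recovers a proper distribution over the vocabulary.
probs = torch.softmax(mixed, dim = -1)
print(probs, probs.sum())
```

An alpha outside [0, 1] extrapolates rather than interpolates, pushing the distribution away from one prompt instead of blending toward it, which is what gives CFG-style negative guidance its effect.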