forked from qwopqwop200/GPTQ-for-LLaMa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
llama_inference.py
142 lines (117 loc) · 4.11 KB
/
llama_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import time
import torch
import torch.nn as nn
from gptq import *
from modelutils import *
from quant import *
from transformers import AutoTokenizer
# Device used for inference: the __main__ section moves both the model and the
# encoded prompt onto it before calling generate().
DEV = torch.device('cuda', 0)
def get_llama(model):
    """Load an unquantized LLaMA causal LM via ``LlamaForCausalLM.from_pretrained``.

    Args:
        model: Name or path accepted by ``from_pretrained``.

    Returns:
        The loaded model (dtype chosen with ``torch_dtype='auto'``) with a
        ``seqlen`` attribute set to 2048.

    Side effects:
        ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
        replaced with no-ops for the rest of the process and never restored.
    """
    import torch

    def _skip_init(*args, **kwargs):
        pass

    # Neutralise random initialisation before transformers builds the modules;
    # presumably done for load speed, since pretrained weights replace them.
    for init_fn_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_fn_name, _skip_init)

    # Imported only after patching, matching the original ordering.
    from transformers import LlamaForCausalLM

    llama = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
    llama.seqlen = 2048
    return llama
def load_quant(model, checkpoint, wbits, groupsize, device, warmup_autotune = True):
    """Build a LLaMA skeleton with quantized layers and load a checkpoint into it.

    Args:
        model: Name or path handed to ``LlamaConfig.from_pretrained``; only the
            config is read from it here.
        checkpoint: Path to the quantized state dict. Files ending in
            ``.safetensors`` are read with safetensors; anything else with
            ``torch.load``.
        wbits: Bit width forwarded to ``make_quant``.
        groupsize: Group size forwarded to ``make_quant``.
        device: Only used for ``.safetensors`` checkpoints: ``-1`` loads onto
            ``"cpu"``, any other value is passed to safetensors' ``load_file``
            as the target device. Ignored for ``torch.load`` checkpoints.
        warmup_autotune: When true, call ``autotune_warmup(model)`` after loading.

    Returns:
        The model in eval mode with ``seqlen`` set to 2048.

    Side effects:
        Replaces ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_``
        with no-ops for the rest of the process, sets
        ``transformers.modeling_utils._init_weights = False``, and leaves the
        torch default dtype as ``torch.float``.
    """
    # Explicit import: the module name ``transformers`` is otherwise only in
    # scope if one of the file's star imports happens to re-export it.
    import transformers
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Disable random weight initialisation while building the skeleton;
    # presumably a speed-up, since the checkpoint overwrites the weights below.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    transformers.modeling_utils._init_weights = False

    # Create parameters in fp16. Restore the float32 default in ``finally`` so a
    # failing constructor cannot leave the whole process allocating fp16 tensors.
    torch.set_default_dtype(torch.half)
    try:
        model = LlamaForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float)
    model = model.eval()

    # Hand every layer reported by find_layers to make_quant, except the
    # output head 'lm_head', which is deliberately left unquantized.
    layers = find_layers(model)
    layers.pop('lm_head', None)
    make_quant(model, layers, wbits, groupsize)

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        if device == -1:
            device = "cpu"
        model.load_state_dict(safe_load(checkpoint, device))
    else:
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only pass checkpoints from trusted sources.
        model.load_state_dict(torch.load(checkpoint))

    make_quant_attn(model)
    if warmup_autotune:
        autotune_warmup(model)
    model.seqlen = 2048
    print('Done.')
    return model
if __name__ == '__main__':
    import argparse

    from datautils import *

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model', type=str,
        help='llama model to load'
    )
    parser.add_argument(
        '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
        help='#bits to use for quantization; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--groupsize', type=int, default=-1,
        help='Groupsize to use for quantization; default uses full row.'
    )
    parser.add_argument(
        '--load', type=str, default='',
        help='Load quantized model.'
    )
    # Required: the prompt goes straight into tokenizer.encode(), which cannot
    # encode None. Without this, omitting --text crashes inside the tokenizer
    # instead of producing a clear usage error.
    parser.add_argument(
        '--text', type=str, required=True,
        help='input text'
    )
    parser.add_argument(
        '--min_length', type=int, default=10,
        help='The minimum length of the sequence to be generated.'
    )
    parser.add_argument(
        '--max_length', type=int, default=50,
        help='The maximum length of the sequence to be generated.'
    )
    parser.add_argument(
        '--top_p', type=float, default=0.95,
        help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.'
    )
    parser.add_argument(
        '--temperature', type=float, default=0.8,
        help='The value used to modulate the next token probabilities.'
    )
    parser.add_argument(
        '--device', type=int, default=-1,
        help='The device used to load the model when using safetensors. Default device is "cpu" or specify, 0,1,2,3,... for GPU device.'
    )

    args = parser.parse_args()

    # argparse always yields a str here (type=str); the conversion only matters
    # if args is populated programmatically with a pathlib.Path.
    if not isinstance(args.load, str):
        args.load = args.load.as_posix()

    if args.load:
        model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.device)
    else:
        model = get_llama(args.model)
        model.eval()

    # --device only controls where the checkpoint is first loaded; inference
    # always runs on DEV.
    model.to(DEV)
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            do_sample=True,
            min_length=args.min_length,
            max_length=args.max_length,
            top_p=args.top_p,
            temperature=args.temperature,
        )
    print(tokenizer.decode(generated_ids[0].tolist()))