commit #133
Open · wants to merge 1 commit into base: develop
46 changes: 43 additions & 3 deletions paddlemix/examples/blip2/run_predict.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
import sys
from dataclasses import dataclass, field
@@ -153,14 +154,53 @@ def main():
mode="test",
)
model = create_model(model_args)

decorated = paddle.amp.decorate(
models=[model.language_model], optimizers=None, level="O2"
)
[model.language_model] = decorated

model.eval()
if training_args.model_path is not None:
checkpoint = training_args.model_path
load_model(training_args, model, ckpt_dir=checkpoint, load_language_model=False)
load_model(training_args, model.language_model, ckpt_dir=LLM_LIST[model_args.text_model_name_or_path])
generated_ids, scores = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
logger.info("Generate text: {}".format(generated_text))


warm_up = 5
repeat = 10
print("**inputs")
print(inputs)

# input_spec=[
# paddle.static.InputSpec(shape=[None, 3, None, None], dtype="float32"), # pixel_values
# paddle.static.InputSpec(shape=[None, None], dtype="int64"), # input_ids
# paddle.static.InputSpec(shape=[None, None], dtype="int64"), # attention_mask
# ]
# model = paddle.jit.to_static(model.generate, input_spec=input_spec)
# paddle.jit.save(model, "./blip2/inference")
# exit(0)


for i in range(warm_up):
generated_ids, scores = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
logger.info("Generate text: {}".format(generated_text))


import datetime
import time
starttime = datetime.datetime.now()
for i in range(repeat):
generated_ids, scores = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
logger.info("Generate text: {}".format(generated_text))

endtime = datetime.datetime.now()
duringtime = endtime - starttime
ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
print(ms / repeat)  # unit: milliseconds

return model


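The benchmark added above times model.generate with datetime. A minimal sketch of the same measurement using time.perf_counter (monotonic and better suited to latency benchmarking) follows; model and inputs are the objects already defined in run_predict.py, everything else is illustrative. On GPU, synchronizing the device before reading the clock (e.g. via paddle.device.cuda.synchronize(), if available in the installed Paddle) would make the numbers more trustworthy.

import time

warm_up, repeat = 5, 10

# Warm-up so kernels and caches are settled before timing.
for _ in range(warm_up):
    generated_ids, scores = model.generate(**inputs)

start = time.perf_counter()
for _ in range(repeat):
    generated_ids, scores = model.generate(**inputs)
elapsed_ms = (time.perf_counter() - start) * 1000.0
print(elapsed_ms / repeat)  # average latency per generate() call, in milliseconds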
3 changes: 3 additions & 0 deletions paddlemix/models/blip2/eva_vit.py
@@ -431,6 +431,9 @@ def forward_features(self, x):
B = paddle.shape(x)[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand((B, -1, -1))
print(cls_tokens.dtype)
print(x.dtype)
print("cls_tokens.dtype")
x = paddle.concat((cls_tokens, x), axis=1)

if self.pos_embed is not None:
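The dtype prints added above appear to be debugging a float16/float32 mismatch at the concat after the AMP O2 decoration in run_predict.py. A minimal sketch of resolving it with an explicit cast (an assumption, not part of this patch):

# Align the class-token dtype with the patch embeddings before concatenation,
# so float16 activations under AMP O2 do not collide with float32 parameters.
cls_tokens = paddle.cast(cls_tokens, x.dtype)
x = paddle.concat((cls_tokens, x), axis=1)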
16 changes: 10 additions & 6 deletions paddlemix/models/blip2/modeling.py
@@ -767,23 +767,27 @@ def generate(

language_model_inputs = self.Qformer.language_projection(query_output)
language_attention_mask = paddle.ones(language_model_inputs.shape[:-1], dtype="int64")
if input_ids is None:
input_ids = paddle.to_tensor([[self.config.text_config.bos_token_id]]).tile([batch_size, 1])
# if input_ids is None:
# input_ids = paddle.to_tensor([[self.config.text_config.bos_token_id]]).tile([batch_size, 1])
if attention_mask is None:
attention_mask = paddle.ones_like(input_ids)
print("attention_mask")
print(attention_mask)
attention_mask = paddle.concat([language_attention_mask, attention_mask], axis=1)
# concatenate query embeddings with prompt embeddings
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
language_model_inputs = paddle.cast(language_model_inputs, dtype="float16")
inputs_embeds = paddle.concat([language_model_inputs, inputs_embeds], axis=1)

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
do_sample=False,
top_p=0.9,
decode_strategy="greedy_search",
do_sample=True,
top_p=1.0,
top_k=1,
decode_strategy="sampling",
temperature=1,
num_beams=5,
num_beams=1,
max_length=30,
min_length=8,
eos_token_id=50118,
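The decoding change above replaces 5-beam greedy_search with single-beam sampling at top_k=1, top_p=1.0 and temperature=1, which should always select the highest-probability token, i.e. behave like greedy decoding. A small, purely illustrative numpy check of that equivalence:

import numpy as np

def sample_top_k(logits, k, temperature=1.0, seed=0):
    # Keep the k largest logits, renormalize, then sample one of them.
    rng = np.random.default_rng(seed)
    logits = np.asarray(logits, dtype=np.float64) / temperature
    top = np.argsort(logits)[-k:]
    probs = np.exp(logits[top] - logits[top].max())
    probs /= probs.sum()
    return int(top[rng.choice(len(top), p=probs)])

logits = [0.1, 2.3, -0.5, 1.7]
assert sample_top_k(logits, k=1) == int(np.argmax(logits))  # top_k=1 sampling == greedy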
173 changes: 170 additions & 3 deletions paddlemix/models/blip2/modeling_opt.py
@@ -18,6 +18,7 @@
import collections
from functools import partial
from typing import Any, Dict, List
from paddle.incubate.nn import FusedMultiTransformer

import numpy as np
import paddle
@@ -462,6 +463,131 @@ def __init__(self, config: OPTConfig, decoder_layers: List[Layer]):

self.checkpoints = []


self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.hidden_size // self.num_heads

weight_file = "/root/.paddlenlp/models/facebook/opt-2.7b/model_state.pdparams"
self.state_dict = paddle.load(weight_file, return_numpy=True)

for k in self.state_dict.keys():
pass

paddle.set_default_dtype("float16")
ln_scale_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.norm1.weight".format(i)) for i in range(config.num_hidden_layers)]
ln_bias_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.norm1.bias".format(i)) for i in range(config.num_hidden_layers)]

qkv_weight_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.qkv_weight".format(i)) for i in range(config.num_hidden_layers)]

out_proj_weight_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.self_attn.out_proj.weight".format(i)) for i in range(config.num_hidden_layers)]
out_proj_bias_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.self_attn.out_proj.bias".format(i)) for i in range(config.num_hidden_layers)]

ffn_ln_scale_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.norm2.weight".format(i)) for i in range(config.num_hidden_layers)]
ffn_ln_bias_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.norm2.bias".format(i)) for i in range(config.num_hidden_layers)]

ffn1_weight_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.linear1.weight".format(i)) for i in range(config.num_hidden_layers)]
ffn1_bias_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.linear1.bias".format(i)) for i in range(config.num_hidden_layers)]
ffn2_weight_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.linear2.weight".format(i)) for i in range(config.num_hidden_layers)]
ffn2_bias_attrs = [paddle.ParamAttr(name="opt.decoder.layers.{}.linear2.bias".format(i)) for i in range(config.num_hidden_layers)]

self.transformer_block = FusedMultiTransformer(config.hidden_size,
config.num_attention_heads,
config.intermediate_size,
dropout_rate=0.0,
activation="relu",
normalize_before=True,
num_layers=config.num_hidden_layers,
nranks=1,
ring_id=-1,
ln_scale_attrs=ln_scale_attrs,
ln_bias_attrs=ln_bias_attrs,
qkv_weight_attrs=qkv_weight_attrs,
linear_weight_attrs=out_proj_weight_attrs,
linear_bias_attrs=out_proj_bias_attrs,
ffn_ln_scale_attrs=ffn_ln_scale_attrs,
ffn_ln_bias_attrs=ffn_ln_bias_attrs,
ffn1_weight_attrs=ffn1_weight_attrs,
ffn1_bias_attrs=ffn1_bias_attrs,
ffn2_weight_attrs=ffn2_weight_attrs,
ffn2_bias_attrs=ffn2_bias_attrs,
epsilon=1e-5)
self.cache_kvs = []

for i in range(self.num_layers):
ln_scale = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.norm1.weight".format(i)])
ln_scale = paddle.cast(ln_scale, "float32")
ln_bias = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.norm1.bias".format(i)])
ln_bias = paddle.cast(ln_bias, "float32")

q_weight = self.state_dict["opt.decoder.layers.{}.self_attn.q_proj.weight".format(i)]
k_weight = self.state_dict["opt.decoder.layers.{}.self_attn.k_proj.weight".format(i)]
v_weight = self.state_dict["opt.decoder.layers.{}.self_attn.v_proj.weight".format(i)]
q_bias = self.state_dict["opt.decoder.layers.{}.self_attn.q_proj.bias".format(i)]
k_bias = self.state_dict["opt.decoder.layers.{}.self_attn.k_proj.bias".format(i)]
v_bias = self.state_dict["opt.decoder.layers.{}.self_attn.v_proj.bias".format(i)]

concated_qkv_weight = np.concatenate([q_weight, k_weight, v_weight], axis=-1)
concated_qkv_weight = concated_qkv_weight.transpose(1, 0)
concated_qkv_weight = concated_qkv_weight.reshape(3, self.num_heads, self.head_size, self.hidden_size)
concated_qkv_weight = paddle.to_tensor(concated_qkv_weight)
concated_qkv_weight = paddle.cast(concated_qkv_weight, "float16")

concated_qkv_bias = np.concatenate([q_bias, k_bias, v_bias], axis=-1)
concated_qkv_bias = concated_qkv_bias.reshape(3, self.num_heads, self.head_size)
concated_qkv_bias = paddle.to_tensor(concated_qkv_bias)
concated_qkv_bias = paddle.cast(concated_qkv_bias, "float16")

out_proj_weight = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.self_attn.out_proj.weight".format(i)])
out_proj_bias = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.self_attn.out_proj.bias".format(i)])
out_proj_weight = paddle.cast(out_proj_weight, "float16")
out_proj_bias = paddle.cast(out_proj_bias, "float16")

ffn_ln_scale = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.norm2.weight".format(i)])
ffn_ln_bias = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.norm2.bias".format(i)])
ffn_ln_scale = paddle.cast(ffn_ln_scale, "float32")
ffn_ln_bias = paddle.cast(ffn_ln_bias, "float32")

ffn1_weight = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.linear1.weight".format(i)])
ffn1_bias = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.linear1.bias".format(i)])
ffn2_weight = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.linear2.weight".format(i)])
ffn2_bias = paddle.to_tensor(self.state_dict["opt.decoder.layers.{}.linear2.bias".format(i)])
ffn1_weight = paddle.cast(ffn1_weight, "float16")
ffn1_bias = paddle.cast(ffn1_bias, "float16")
ffn2_weight = paddle.cast(ffn2_weight, "float16")
ffn2_bias = paddle.cast(ffn2_bias, "float16")


# qkv_weight = paddle.concat(q_weight, k_weight, v_weight)
list_weight = [
ln_scale, ln_bias,
concated_qkv_weight, concated_qkv_bias,
out_proj_weight, out_proj_bias,
ffn_ln_scale, ffn_ln_bias,
ffn1_weight, ffn1_bias,
ffn2_weight, ffn2_bias,
]
self.transformer_block.ln_scales[i].set_value(list_weight[0])
self.transformer_block.ln_biases[i].set_value(list_weight[1])

self.transformer_block.qkv_weights[i].set_value(list_weight[2])
self.transformer_block.qkv_biases[i].set_value(list_weight[3])

self.transformer_block.linear_weights[i].set_value(list_weight[4])
self.transformer_block.linear_biases[i].set_value(list_weight[5])

self.transformer_block.ffn_ln_scales[i].set_value(list_weight[6])
self.transformer_block.ffn_ln_biases[i].set_value(list_weight[7])

self.transformer_block.ffn1_weights[i].set_value(list_weight[8])
self.transformer_block.ffn1_biases[i].set_value(list_weight[9])

self.transformer_block.ffn2_weights[i].set_value(list_weight[10])
self.transformer_block.ffn2_biases[i].set_value(list_weight[11])

paddle.set_default_dtype("float32")


def forward(
self,
tgt,
@@ -485,7 +611,37 @@ def forward(
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None

if (len(self.cache_kvs) == 0):
max_seq_length = 1024
# self.cache_kvs = [
# paddle.fluid.layers.fill_constant_batch_size_like(
# paddle.zeros(paddle.shape(tgt)[0:2], dtype='float32'),
# shape=[2, -1, self.num_heads, max_seq_length, self.head_size],
# input_dim_idx=0,
# output_dim_idx=1,
# value=0.,
# dtype="float16") for _ in range(self.num_layers)]
self.cache_kvs = [
paddle.tensor.fill_constant(
shape=[2, paddle.shape(tgt)[0], self.num_heads, max_seq_length, self.head_size],
value=0.,
dtype="float16") for _ in range(self.num_layers)]

is_decoder = cache is not None
output = paddle.cast(output, "float16")
tgt_mask = paddle.cast(tgt_mask, "float16")
# print("output", output.shape)
# `presents` is the pre-allocated, in-place cache: its shape is always the maximum shape!
hidden_states, presents = self.transformer_block(output,
attn_mask=tgt_mask,
caches=self.cache_kvs,
time_step=paddle.increment(paddle.shape(tgt_mask)[-1], -1) if is_decoder else None)
# Force `output` to be the hidden_states returned by the fused transformer op.
output = hidden_states


for i, mod in enumerate(self.layers):
break
outputs = mod(
output,
memory,
@@ -837,12 +993,16 @@ def __init__(self, config: OPTConfig):
decoder_layers.append(TransformerDecoderLayer(config))
self.decoder = TransformerDecoder(config, decoder_layers)
self.checkpoints = []

self.past_key_values_length = 0

def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
# Only entered on the first (prefill) pass, when more than one token is fed at once.
if input_shape[-1] > 1:
print("Only entered on the first (prefill) pass!")
combined_attention_mask = _make_causal_mask(
input_shape, past_key_values_length=past_key_values_length, dtype=attention_mask.dtype
)
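For context, _make_causal_mask builds the standard additive lower-triangular mask. A hypothetical numpy rendering of the idea (not the library's implementation):

import numpy as np

def make_causal_mask(tgt_len, past_len, dtype=np.float32):
    # Position i may attend to every key position up to past_len + i.
    mask = np.full((tgt_len, past_len + tgt_len), np.finfo(dtype).min, dtype=dtype)
    for i in range(tgt_len):
        mask[i, : past_len + i + 1] = 0.0
    return mask  # broadcast to [bsz, 1, tgt_len, past_len + tgt_len] before adding to scores

print(make_causal_mask(3, 0))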
@@ -950,8 +1110,15 @@ def forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

self.checkpoints = []
past_key_values_length = paddle.shape(cache[0].k)[2] if cache is not None else 0
# self.checkpoints = []
# past_key_values_length = paddle.shape(cache[0].k)[2] if cache is not None else 0


if (input_shape[1] > 1):
self.past_key_values_length = paddle.to_tensor([0])
self.past_key_values_length = self.past_key_values_length.reshape([1], name = "self.past_key_values_length")
past_key_values_length = paddle.to_tensor([self.past_key_values_length])
self.past_key_values_length += input_shape[1]

seq_length_with_past = input_shape[-1] + past_key_values_length

@@ -1194,7 +1361,7 @@ def prepare_inputs_for_generation(
self, input_ids, use_cache=False, cache=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if cache is not None:
input_ids = input_ids[:, -1:]
input_ids = input_ids[:, -1:2147483647]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache is None:
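The slice end above is changed from an open-ended -1: to an explicit 2147483647 (INT32_MAX); both select only the last token, but the explicit bound presumably keeps the slice end a concrete constant, which is friendlier to static-graph export. A small numpy check of the equivalence (illustrative only):

import numpy as np

input_ids = np.array([[11, 12, 13], [21, 22, 23]])
last_open = input_ids[:, -1:]
last_bounded = input_ids[:, -1:2147483647]
assert np.array_equal(last_open, last_bounded)  # both keep only the final column
print(last_bounded)  # [[13], [23]]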