DiT with decorator, triton fused_AdaLN and fineGrained #552

Open
wants to merge 27 commits into base: develop
Changes from 15 commits

Commits (27)
b93c57a - DiT FFN fineGrained (YKTian-x2b, May 24, 2024)
bca3484 - DiT FFN fineGrained (YKTian-x2b, May 24, 2024)
6771b36 - clear fine_grained_FFN (YKTian-x2b, May 31, 2024)
b4d92a3 - Merge branch 'develop' into DiT_FFN_fineGrained (westfish, May 31, 2024)
ae34336 - Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX i… (YKTian-x2b, Jun 7, 2024)
668f1ac - decorator + fineGrained_qkv_ffn + triton_adaLN_fusedAdaLN (YKTian-x2b, Jun 7, 2024)
fb96011 - Merge branch 'DiT_FFN_fineGrained' of https://github.com/YKTian-x2b/P… (YKTian-x2b, Jun 7, 2024)
cb8bacb - clear up pr ing... (YKTian-x2b, Jun 12, 2024)
bdefd3b - Optional acceleration (YKTian-x2b, Jun 12, 2024)
18b5945 - Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX i… (YKTian-x2b, Jun 19, 2024)
1cc8ca4 - no reshape (YKTian-x2b, Jul 22, 2024)
3251f13 - Revert "no reshape" (YKTian-x2b, Jul 22, 2024)
55b5042 - no reshape (YKTian-x2b, Jul 22, 2024)
ee3df60 - Merge branch 'DiT_FFN_fineGrained' of https://github.com/YKTian-x2b/P… (YKTian-x2b, Jul 22, 2024)
70c6dc0 - no reshape (YKTian-x2b, Jul 22, 2024)
5d3d29f - fuse_repo triton kernel (YKTian-x2b, Jul 25, 2024)
8238c08 - Merge remote-tracking branch 'upstream/develop' into DiT_FFN_fineGrained (YKTian-x2b, Aug 5, 2024)
19cbd90 - with horizontal_fuse_pass opt (YKTian-x2b, Aug 5, 2024)
c000d4c - env (YKTian-x2b, Aug 5, 2024)
9f04a1c - ReNet (YKTian-x2b, Aug 7, 2024)
f2966f7 - new net (YKTian-x2b, Aug 7, 2024)
437cbbb - little mod (YKTian-x2b, Aug 7, 2024)
fb4d478 - pre-commit (YKTian-x2b, Aug 7, 2024)
e638313 - update largedit (YKTian-x2b, Aug 8, 2024)
c933f80 - update largedit (YKTian-x2b, Aug 8, 2024)
9903122 - INFERENCE_OPTIMIZE (YKTian-x2b, Aug 8, 2024)
f54958a - new modify_weight (YKTian-x2b, Aug 10, 2024)
@@ -17,7 +17,12 @@

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
dtype = paddle.bfloat16

# To speed up this code, call zkk and let him run for you,
# then you will get a speed increase of almost 100%.
os.environ['callZKK']= "True"
Contributor comment: Could this environment variable be changed to something else, e.g. optimize_inference_for_ditllama?


with paddle.LazyGuard():
pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
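For context, a minimal sketch of how this example script can be run end to end with the optimization switch enabled. The class label, step count, and output handling below are assumptions based on the usual DiTPipeline interface and are not part of this diff:

import os
import paddle
from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.bfloat16

# The switch must be set before the pipeline is built: the model reads it during
# construction to decide whether to create the fused qkv/w13 layers, and again at
# load time to repack the pretrained weights.
os.environ["callZKK"] = "True"

with paddle.LazyGuard():
    pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Assumed usage; the example script's actual labels and arguments may differ.
class_ids = pipe.get_label_ids(["golden retriever"])  # assumption: helper mirrors diffusers' DiTPipeline
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("result.png")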
204 changes: 169 additions & 35 deletions ppdiffusers/ppdiffusers/models/dit_llama.py
@@ -25,6 +25,9 @@
from .modeling_utils import ModelMixin
from .transformer_2d import Transformer2DModelOutput

from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual
import os

def TypePromote(x, y):
TYPE_PROMOTE_DICT = {
@@ -90,7 +93,7 @@ def forward(self, t):


class Attention(nn.Layer):
def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True):
def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, callZKK=False):
"""
Initialize the Attention module.

@@ -108,6 +111,7 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True):
wq (nn.Linear): Linear transformation for queries.
wk (nn.Linear): Linear transformation for keys.
wv (nn.Linear): Linear transformation for values.
qkv (nn.Linear): Linear transformation for queries, keys and values.
wo (nn.Linear): Linear transformation for output.
cache_k (paddle.Tensor): Cached keys for attention.
cache_v (paddle.Tensor): Cached values for attention.
@@ -120,9 +124,14 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True):
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads

self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False)
self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
self.callZKK = callZKK
if not callZKK:
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False)
self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
else:
self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias_attr=False)

self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False)

if qk_norm:
@@ -184,7 +193,15 @@ def apply_rotary_emb(xq, xk, freqs_cis):
Tuple[paddle.Tensor, paddle.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
with paddle.amp.auto_cast(enable=False):
if not os.getenv('callZKK'):
with paddle.amp.auto_cast(enable=False):
xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3)
xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3)
return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype)
else:
xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
@@ -205,7 +222,13 @@ def forward(self, x, freqs_cis):

"""
bsz, seqlen, _ = tuple(x.shape)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

if not self.callZKK:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
else:
qkv_out = self.qkv(x)
xq, xk, xv = paddle.split(qkv_out, 3, axis=-1)
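# NOTE: an equal three-way split is only correct when n_kv_heads == n_heads;
# with grouped-query attention the sections would need to be
# [n_heads * head_dim, n_kv_heads * head_dim, n_kv_heads * head_dim].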

dtype = xq.dtype

xq = self.q_norm(xq)
@@ -253,7 +276,7 @@ def forward(self, x, freqs_cis):


class FeedForward(nn.Layer):
def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, callZKK=False):
"""
Initialize the FeedForward module.

@@ -266,28 +289,85 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
dimension. Defaults to None.

Attributes:
w1 (nn.Linear): Linear transformation for the first
layer.
w1 (nn.Linear): Linear transformation for the first layer.
w2 (nn.Linear): Linear transformation for the second layer.
w3 (nn.Linear): Linear transformation for the third
layer.

w3 (nn.Linear): Linear transformation for the third layer.
w13 (nn.Linear): Linear transformation for the first and the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of))

self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False)
self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False)
self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False)
self.callZKK = callZKK
if not callZKK:
self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False)
self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False)
else:
self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False)

def compute_activation(self,
ffn1_out,
bias=None,
dequant_scales=None,
shift=None,
smooth=None,
act_method="swiglu",
compute_dtype="default",
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0):
if in_dynamic_mode():
out = paddle._C_ops.fused_bias_act(
ffn1_out,
bias,
dequant_scales,
shift,
smooth,
act_method,
compute_dtype,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound
)
return out

helper = LayerHelper("fused_bias_act")
out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype)
inputs = {}
inputs["x"] = ffn1_out
attrs = {
"act_method": act_method,
"compute_dtype": compute_dtype,
"quant_scale": quant_scale,
"quant_round_type": quant_round_type,
"quant_max_bound": quant_max_bound,
"quant_min_bound": quant_min_bound,
}
helper.append_op(
type="fused_bias_act",
inputs=inputs,
outputs={"out": out},
attrs=attrs,
)
return out

def forward(self, x):
xw1 = F.silu(self.w1(x))
xw3 = self.w3(x)
output = self.w2(xw1 * xw3)
return output
if not self.callZKK:
xw1 = F.silu(self.w1(x))
xw3 = self.w3(x)
output = self.w2(xw1 * xw3)
return output
else:
ffn1_out = self.w13(x)
ffn1_out = self.compute_activation(ffn1_out)
ffn2_out = self.w2(ffn1_out)
return ffn2_out
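For readers unfamiliar with the fused path, here is a reference sketch of what it is intended to compute, equivalent to the unfused branch above. The half-split layout of the concatenated projection is an assumption about how fused_bias_act's "swiglu" consumes the w13 output; the unfused branch remains the ground truth:

import paddle
import paddle.nn.functional as F

def swiglu_reference(x, w13_weight, w2_weight):
    # w13_weight is w1 and w3 concatenated along the output axis,
    # as done later in custom_modify_weight.
    ffn1_out = paddle.matmul(x, w13_weight)            # [..., 2 * hidden_dim]
    x1, x3 = paddle.chunk(ffn1_out, 2, axis=-1)        # assumed split: first half ~ w1, second half ~ w3
    return paddle.matmul(F.silu(x1) * x3, w2_weight)   # same result as w2(silu(w1(x)) * w3(x))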


class TransformerBlock(nn.Layer):
@@ -303,6 +383,7 @@ def __init__(
norm_eps: float,
qk_norm: bool,
fused_attn: bool,
callZKK=False,
) -> None:
"""
Initialize a TransformerBlock.
@@ -337,10 +418,11 @@ def __init__(
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn)
self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, callZKK=callZKK)
mlp_hidden_dim = int(dim * mlp_ratio)
self.feed_forward = FeedForward(
dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier
dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier, callZKK=callZKK
)
self.layer_id = layer_id
self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False)
@@ -350,6 +432,8 @@ def __init__(
nn.Silu(),
nn.Linear(min(dim, 1024), 6 * dim),
)
self.norm_eps = norm_eps
self.callZKK = callZKK

def forward(self, x, freqs_cis, adaln_input=None):
"""
@@ -370,10 +454,17 @@ def forward(self, x, freqs_cis, adaln_input=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(
6, axis=1
)
h = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
)
out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
if not self.callZKK:
h = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
)
out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
else:
attention_out = self.attention(adaptive_layer_norm(x, scale_msa, shift_msa,
weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis)
residual_out, adaLN_out = fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp,
weight=self.ffn_norm.weight, epsilon=self.norm_eps)
out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out)
else:
h = x + self.attention(self.attention_norm(x), freqs_cis)
out = h + self.feed_forward(self.ffn_norm(h))
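As a reading aid, a sketch of what the two paddle.incubate.tt ops are assumed to compute, expressed with the same unfused operations as the branch above. modulate() is assumed to be the standard DiT formulation x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1); this is inferred from the surrounding code, not from the ops' documentation:

import paddle
import paddle.nn.functional as F

def adaptive_layer_norm_ref(x, scale, shift, weight, epsilon):
    # Bias-free LayerNorm followed by the adaLN modulation.
    h = F.layer_norm(x, x.shape[-1:], weight=weight, bias=None, epsilon=epsilon)
    return h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def fused_adaLN_scale_residual_ref(x, attn_out, gate_msa, scale_mlp, shift_mlp, weight, epsilon):
    # Gated residual add, then the adaLN-modulated norm of the result.
    residual_out = x + gate_msa.unsqueeze(1) * attn_out
    adaLN_out = adaptive_layer_norm_ref(residual_out, scale_mlp, shift_mlp, weight, epsilon)
    return residual_out, adaLN_out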
@@ -435,14 +526,14 @@ def __init__(
self.num_classes = num_classes
self.learn_sigma = learn_sigma
self.qk_norm = qk_norm

self.gradient_checkpointing = True
self.fused_attn = True

self.x_embedder = nn.Linear(in_channels * patch_size**2, dim)
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob)

self.callZKK = True if os.getenv('callZKK') else False
# 2. Define transformers blocks
self.layers = nn.LayerList(
[
@@ -457,10 +548,13 @@ def __init__(
norm_eps=norm_eps,
qk_norm=qk_norm,
fused_attn=self.fused_attn,
callZKK=self.callZKK,
)
for idx in range(num_layers)
]
)

# del self.layers

# 3. Define output layers
self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
@@ -531,6 +625,21 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
)
return freqs_cis

@paddle.jit.to_static(backend="inference", with_trt=False,
cache_static_model=False,
collect_shape=False)
def transformer_blocks(self, x, adaln_input):
for i, layer in enumerate(self.layers):
if self.gradient_checkpointing and False:
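# NOTE: the `and False` keeps recompute permanently disabled inside the
# static-graph inference path; only the else branch below is ever taken.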
x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input)
else:
x = layer(
x,
self.freqs_cis[: x.shape[1]],
adaln_input,
)
return x

def forward(
self,
hidden_states: paddle.Tensor,
@@ -556,16 +665,19 @@ def forward(
adaln_input = t + y

# 2. Blocks
for i, layer in enumerate(self.layers):
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input)
else:
x = layer(
x,
self.freqs_cis[: x.shape[1]],
adaln_input,
)

if not self.callZKK:
for i, layer in enumerate(self.layers):
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input)
else:
x = layer(
x,
self.freqs_cis[: x.shape[1]],
adaln_input,
)
else:
x = self.transformer_blocks(x, adaln_input)

# 3. Output
hidden_states = self.final_layer(x, adaln_input)
output = self.unpatchify(hidden_states)
@@ -574,3 +686,25 @@ def forward(
return (output,)

return Transformer2DModelOutput(sample=output)

@classmethod
def custom_modify_weight(cls, state_dict):
# If you're not invited to zkk, you won't get any performance optimizations.
if os.getenv('callZKK'):
for key in list(state_dict.keys()):
if 'feed_forward.w1.weight' in key:
w1 = state_dict.pop(key)
w3_key = key.replace('w1', 'w3')
w3 = state_dict.pop(w3_key)
w13 = paddle.concat([w1, w3], axis=1)
state_dict[key.replace('w1', 'w13')] = w13
if 'attention.wq.weight' in key or 'attention.wk.weight' in key or 'attention.wv.weight' in key:
part = key.split('.')[-2]
layer_id = key.split('.')[1]
qkv_key = f'layers.{layer_id}.attention.qkv.weight'
if part == 'wq' and qkv_key not in state_dict:
state_dict[qkv_key] = state_dict.pop(key)
elif part in ('wk', 'wv'):
qkv = state_dict.get(qkv_key)
if qkv is not None:
state_dict[qkv_key] = paddle.concat([qkv, state_dict.pop(key)], axis=1)
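A small self-contained check of the concatenation logic above, written under the equal-heads assumption (n_kv_heads == n_heads) that the equal three-way split in Attention.forward also relies on. Shapes and names are illustrative; paddle's nn.Linear stores its weight as [in_features, out_features], which is why axis=1 is the output axis:

import paddle
import paddle.nn as nn

dim, n_heads, head_dim = 8, 2, 4
x = paddle.randn([1, 3, dim])

wq = nn.Linear(dim, n_heads * head_dim, bias_attr=False)
wk = nn.Linear(dim, n_heads * head_dim, bias_attr=False)
wv = nn.Linear(dim, n_heads * head_dim, bias_attr=False)

# Pack the three projections into one linear, exactly as custom_modify_weight does.
qkv = nn.Linear(dim, 3 * n_heads * head_dim, bias_attr=False)
qkv.weight.set_value(paddle.concat([wq.weight, wk.weight, wv.weight], axis=1))

xq, xk, xv = paddle.split(qkv(x), 3, axis=-1)
assert paddle.allclose(xq, wq(x), atol=1e-5)
assert paddle.allclose(xk, wk(x), atol=1e-5)
assert paddle.allclose(xv, wv(x), atol=1e-5)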
5 changes: 5 additions & 0 deletions ppdiffusers/ppdiffusers/models/modeling_utils.py
@@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return model

@classmethod
def custom_modify_weight(cls, state_dict):
pass

@classmethod
def _load_pretrained_model(
cls,
@@ -1130,6 +1134,7 @@ def _find_mismatched_keys(
error_msgs.append(
f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
)
cls.custom_modify_weight(state_dict)
faster_set_state_dict(model_to_load, state_dict)

missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))