DiT with decorator, triton fused_AdaLN and fineGrained #552

Open · wants to merge 27 commits into base: develop (changes shown from 4 commits)

Commits (27)
b93c57a  DiT FFN fineGrained (YKTian-x2b, May 24, 2024)
bca3484  DiT FFN fineGrained (YKTian-x2b, May 24, 2024)
6771b36  clear fine_grained_FFN (YKTian-x2b, May 31, 2024)
b4d92a3  Merge branch 'develop' into DiT_FFN_fineGrained (westfish, May 31, 2024)
ae34336  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX i… (YKTian-x2b, Jun 7, 2024)
668f1ac  decorator + fineGrained_qkv_ffn + triton_adaLN_fusedAdaLN (YKTian-x2b, Jun 7, 2024)
fb96011  Merge branch 'DiT_FFN_fineGrained' of https://github.com/YKTian-x2b/P… (YKTian-x2b, Jun 7, 2024)
cb8bacb  clear up pr ing... (YKTian-x2b, Jun 12, 2024)
bdefd3b  Optional acceleration (YKTian-x2b, Jun 12, 2024)
18b5945  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX i… (YKTian-x2b, Jun 19, 2024)
1cc8ca4  no reshape (YKTian-x2b, Jul 22, 2024)
3251f13  Revert "no reshape" (YKTian-x2b, Jul 22, 2024)
55b5042  no reshape (YKTian-x2b, Jul 22, 2024)
ee3df60  Merge branch 'DiT_FFN_fineGrained' of https://github.com/YKTian-x2b/P… (YKTian-x2b, Jul 22, 2024)
70c6dc0  no reshape (YKTian-x2b, Jul 22, 2024)
5d3d29f  fuse_repo triton kernel (YKTian-x2b, Jul 25, 2024)
8238c08  Merge remote-tracking branch 'upstream/develop' into DiT_FFN_fineGrained (YKTian-x2b, Aug 5, 2024)
19cbd90  with horizontal_fuse_pass opt (YKTian-x2b, Aug 5, 2024)
c000d4c  env (YKTian-x2b, Aug 5, 2024)
9f04a1c  ReNet (YKTian-x2b, Aug 7, 2024)
f2966f7  new net (YKTian-x2b, Aug 7, 2024)
437cbbb  little mod (YKTian-x2b, Aug 7, 2024)
fb4d478  pre-commit (YKTian-x2b, Aug 7, 2024)
e638313  update largedit (YKTian-x2b, Aug 8, 2024)
c933f80  update largedit (YKTian-x2b, Aug 8, 2024)
9903122  INFERENCE_OPTIMIZE (YKTian-x2b, Aug 8, 2024)
f54958a  new modify_weight (YKTian-x2b, Aug 10, 2024)
74 changes: 73 additions & 1 deletion ppdiffusers/ppdiffusers/models/dit_llama.py
@@ -252,6 +252,43 @@ def forward(self, x, freqs_cis):
        return self.wo(output)


class FeedForward_kai(nn.Layer):
    def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of))

        self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False)

    def compute_activation(self, ffn1_out):
        origin_batch_size = ffn1_out.shape[0]
        origin_seq_len = ffn1_out.shape[1]
        ffn1_out = ffn1_out.reshape([origin_batch_size * origin_seq_len, ffn1_out.shape[-1]])
Review comment: these two reshape calls are not a good fit here; it would be better to extend the fused_bias_act implementation instead.
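For context, the "swiglu" mode of paddle._C_ops.fused_bias_act used below gates one half of the last dimension with the other. A minimal unfused sketch (illustrative only, not part of the PR; it assumes the fused kernel splits the last dimension into gate and up halves, in that order):

import paddle
import paddle.nn.functional as F

def swiglu_reference(ffn1_out):
    # Split [..., 2 * hidden_dim] into gate/up halves and apply SiLU gating;
    # this mirrors what the fused kernel is expected to compute, minus the fusion.
    gate, up = paddle.chunk(ffn1_out, chunks=2, axis=-1)
    return F.silu(gate) * up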

        res = paddle._C_ops.fused_bias_act(
            ffn1_out,
            None,
            None,
            None,
            None,
            "swiglu",
            "default",
            -1,
            0,
            0,
            0
        )
        return res.reshape([origin_batch_size, origin_seq_len, res.shape[-1]])

    def forward(self, x):
        ffn1_out = self.w13(x)
        ffn1_out = self.compute_activation(ffn1_out)
        ffn2_out = self.w2(ffn1_out)
        return ffn2_out
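As a quick sanity check on the sizing above (hypothetical numbers, chosen only for illustration): with dim = 1152 and mlp_ratio = 4.0, hidden_dim arrives as 4608, becomes int(2 * 4608 / 3) = 3072, and is already a multiple of 256, so w13 is a Linear(1152, 6144) and w2 a Linear(3072, 1152). The rounding rule as one helper:

def swiglu_hidden_dim(dim, mlp_ratio=4.0, multiple_of=256, ffn_dim_multiplier=None):
    # Reproduces the sizing logic of FeedForward_kai.__init__ for a given mlp_ratio;
    # example values only.
    hidden_dim = int(dim * mlp_ratio)
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(swiglu_hidden_dim(1152))  # 3072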


class FeedForward(nn.Layer):
    def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
        """
@@ -339,7 +376,7 @@ def __init__(
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn)
        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.feed_forward = FeedForward(
+        self.feed_forward = FeedForward_kai(
            dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier
        )
        self.layer_id = layer_id
@@ -574,3 +611,38 @@ def forward(
            return (output,)

        return Transformer2DModelOutput(sample=output)

    @classmethod
    def custom_modify_weight(cls, state_dict):
        # print("kai==================================")
        # print(state_dict.keys())
        import re
        w1_pattern = r"layers\.(\d+)\.feed_forward\.w1.weight$"
        w3_pattern = r"layers\.(\d+)\.feed_forward\.w3.weight$"
        keys_to_add = []
        w1_keys_to_del = []
        w3_keys_to_del = []
        for key in state_dict.keys():
            if re.match(w1_pattern, key):
                w1_keys_to_del.append(key)
            w3_match = re.match(w3_pattern, key)
            if w3_match:
                w13_key = 'layers.' + w3_match.group(1) + '.feed_forward.w13.weight'
                keys_to_add.append(w13_key)
                w3_keys_to_del.append(key)

        assert len(keys_to_add) == len(w1_keys_to_del) == len(w3_keys_to_del)

        for ii in range(len(keys_to_add)):
            w13_key = keys_to_add[ii]
            w1_key = w1_keys_to_del[ii]
            w3_key = w3_keys_to_del[ii]
            state_dict[w13_key] = paddle.concat([state_dict[w1_key], state_dict[w3_key]], axis=1)
            state_dict.pop(w3_key)
            state_dict.pop(w1_key)

        # print(state_dict.keys())
        # exit()
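A quick way to see why axis=1 is the right concat axis (illustrative only, not part of the PR; it assumes paddle.nn.Linear stores weights as [in_features, out_features], so stacking w1 and w3 along the output axis makes the fused w13 layer emit [w1(x), w3(x)] on the last dimension, matching the swiglu split above):

import paddle

dim, hidden = 8, 16                      # toy sizes for illustration
w1 = paddle.randn([dim, hidden])
w3 = paddle.randn([dim, hidden])
w13 = paddle.concat([w1, w3], axis=1)    # [dim, 2 * hidden], as in custom_modify_weight

x = paddle.randn([2, dim])
fused = paddle.matmul(x, w13)
ref = paddle.concat([paddle.matmul(x, w1), paddle.matmul(x, w3)], axis=-1)
assert paddle.allclose(fused, ref).item()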



5 changes: 5 additions & 0 deletions ppdiffusers/ppdiffusers/models/modeling_utils.py
@@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

        return model

    @classmethod
    def custom_modify_weight(cls, state_dict):
        pass

    @classmethod
    def _load_pretrained_model(
        cls,
@@ -1130,6 +1134,7 @@ def _find_mismatched_keys(
                error_msgs.append(
                    f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
                )
            cls.custom_modify_weight(state_dict)
            faster_set_state_dict(model_to_load, state_dict)

        missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))
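The hook added here is a no-op on the base class and runs right before faster_set_state_dict, so a subclass can rewrite checkpoint keys in place before loading. A hypothetical minimal override (example names only, not from the PR; a real model would be a ppdiffusers ModelMixin subclass, a bare nn.Layer stands in here to show the shape of the override):

import paddle.nn as nn

class TinyModel(nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @classmethod
    def custom_modify_weight(cls, state_dict):
        # Rename a legacy checkpoint key before the state dict is applied.
        if "old_proj.weight" in state_dict:
            state_dict["proj.weight"] = state_dict.pop("old_proj.weight")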