diff --git a/paddlemix/triton_ops/__init__.py b/paddlemix/triton_ops/__init__.py
index f10d9daaf..9e87dd73a 100644
--- a/paddlemix/triton_ops/__init__.py
+++ b/paddlemix/triton_ops/__init__.py
@@ -14,35 +14,33 @@
 __all__ = []
 
-try:
-    from .triton_ops import (
-        adaptive_layer_norm,
-        fused_adaLN_scale_residual,
-        fused_rotary_emb,
-        paddle_use_triton,
-        rms_norm,
-        split_concat,
-        triton_split,
-        weight_only_int8,
-    )
-    from .triton_utils import (
-        get_dtype_str,
-        paddle_custom_op_head_part,
-        tune_and_invoke_part,
-    )
-    __all__ += [
-        "paddle_custom_op_head_part",
-        "tune_and_invoke_part",
-        "paddle_use_triton",
-        "weight_only_int8",
-        "adaptive_layer_norm",
-        "fused_adaLN_scale_residual",
-        "rms_norm",
-        "get_dtype_str",
-        "fused_rotary_emb",
-        "split_concat",
-        "triton_split",
-    ]
-except:
-    pass
+from .rotary_emb import apply_rotary_emb_triton
+from .triton_ops import (
+    adaptive_layer_norm,
+    fused_adaLN_scale_residual,
+    fused_rotary_emb,
+    paddle_use_triton,
+    split_concat,
+    triton_split,
+    weight_only_int8,
+)
+from .triton_utils import (
+    get_dtype_str,
+    paddle_custom_op_head_part,
+    tune_and_invoke_part,
+)
+
+__all__ += [
+    "paddle_custom_op_head_part",
+    "apply_rotary_emb_triton",
+    "tune_and_invoke_part",
+    "paddle_use_triton",
+    "weight_only_int8",
+    "adaptive_layer_norm",
+    "fused_adaLN_scale_residual",
+    "get_dtype_str",
+    "fused_rotary_emb",
+    "split_concat",
+    "triton_split",
+]
diff --git a/paddlemix/triton_ops/rotary_emb.py b/paddlemix/triton_ops/rotary_emb.py
new file mode 100644
index 000000000..5c994b6a4
--- /dev/null
+++ b/paddlemix/triton_ops/rotary_emb.py
@@ -0,0 +1,175 @@
+
+
+
+import paddle
+import triton
+import triton.language as tl
+from paddle import _C_ops
+from paddle.base.framework import OpProtoHolder
+from paddle.base.layer_helper import LayerHelper
+from paddle.framework import in_dynamic_or_pir_mode
+
+from .triton_utils import get_dtype_str, paddle_use_triton, rendering_common_template
+
+
+@paddle_use_triton(
+    key=["1"],
+)
+def apply_rotary_emb_kernel(
+    q_ptr,
+    k_ptr,
+    cos_ptr,
+    sin_ptr,
+    outq_ptr,
+    outk_ptr,
+    batch,
+    num_heads,
+    seq_len,
+    head_dim,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Compute the range of elements handled by this program instance.
+    b_pid = tl.program_id(axis=0)  # which block within the grid
+    h_pid = tl.program_id(axis=1)
+    s_pid = tl.program_id(axis=2)
+
+    block_start = b_pid * num_heads * seq_len * head_dim + h_pid * seq_len * head_dim + s_pid * head_dim
+    read_offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = read_offsets < n_elements
+    even_mask = tl.arange(0, BLOCK_SIZE) % 2 == 0
+    q0 = tl.load(q_ptr + read_offsets, mask=mask & even_mask)  # 0,2,4,6,8
+    q1 = tl.load(q_ptr + read_offsets + 1, mask=mask & even_mask)  # 1,3,5,7,9
+
+    k0 = tl.load(k_ptr + read_offsets, mask=mask & even_mask)  # 0,2,4,6,8
+    k1 = tl.load(k_ptr + read_offsets + 1, mask=mask & even_mask)  # 1,3,5,7,9
+
+    # Load cos and sin.
+    block_cs_start = s_pid * head_dim
+    read_cs_offsets = block_cs_start + tl.arange(0, BLOCK_SIZE)
+    cs_mask = read_cs_offsets < (seq_len * head_dim)
+    cos0 = tl.load(cos_ptr + read_cs_offsets, mask=cs_mask & even_mask)  # 0,2,4,6,8
+    cos1 = tl.load(cos_ptr + read_cs_offsets + 1, mask=cs_mask & even_mask)  # 1,3,5,7,9
+    sin0 = tl.load(sin_ptr + read_cs_offsets, mask=cs_mask & even_mask)  # 0,2,4,6,8
+    sin1 = tl.load(sin_ptr + read_cs_offsets + 1, mask=cs_mask & even_mask)  # 1,3,5,7,9
+
+    oq0 = tl.cast(tl.cast(q0, tl.float32) * cos0 - tl.cast(q1, tl.float32) * sin0, tl.float16)
+    oq1 = tl.cast(tl.cast(q1, tl.float32) * cos1 + tl.cast(q0, tl.float32) * sin1, tl.float16)
+
+    ok0 = tl.cast(tl.cast(k0, tl.float32) * cos0 - tl.cast(k1, tl.float32) * sin0, tl.float16)
+    ok1 = tl.cast(tl.cast(k1, tl.float32) * cos1 + tl.cast(k0, tl.float32) * sin1, tl.float16)
+
+    # Store the results back to global memory.
+    tl.store(outq_ptr + read_offsets, oq0, mask=mask & even_mask)
+    tl.store(outq_ptr + read_offsets + 1, oq1, mask=mask & even_mask)
+    tl.store(outk_ptr + read_offsets, ok0, mask=mask & even_mask)
+    tl.store(outk_ptr + read_offsets + 1, ok1, mask=mask & even_mask)
+
+
+def apply_rotary_emb_triton(
+    q,
+    k,
+    cos,
+    sin,
+):
+    batch = q.shape[0]
+    num_heads = q.shape[1]
+    seq_len = q.shape[2]
+    head_dim = q.shape[3]
+    n_elements = batch * num_heads * seq_len * head_dim
+
+    prepare_attr_for_triton_kernel = """
+    // These names must match the kernel's formal parameter names exactly.
+    int batch = q.dims()[0];
+    int num_heads = q.dims()[1];
+    int seq_len = q.dims()[2];
+    int head_dim = q.dims()[3];
+    int n_elements = batch * num_heads * seq_len * head_dim;
+    """
+
+    assert head_dim == 64, "apply_rotary_emb_triton currently only supports head_dim == 64"
+    BLOCK_SIZE = head_dim
+    op_name = "apply_rotary_emb_triton"
+    op_name += get_dtype_str(q.dtype)
+    op_name += f"_{BLOCK_SIZE}"
+    # Create the output tensors.
+
+    # apply_rotary_emb_kernel_config = [
+    #     {"num_warps": 2},
+    #     {"num_warps": 4},
+    #     {"num_warps": 8},
+    #     {"num_warps": 16},
+    #     {"num_warps": 32},
+    # ]
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        outq = paddle.empty_like(q)
+        outk = paddle.empty_like(k)
+
+        prepare_ptr_for_triton_kernel = """
+        // These names must match the kernel's formal parameter names exactly.
+        auto q_ptr = get_tensor_ptr(q);
+        auto k_ptr = get_tensor_ptr(k);
+        auto cos_ptr = get_tensor_ptr(cos);
+        auto sin_ptr = get_tensor_ptr(sin);
+
+        auto out0_tensor = paddle::empty(q.shape(), q.dtype(), q.place());
+        auto out1_tensor = paddle::empty(k.shape(), k.dtype(), k.place());
+        auto outq_ptr = get_tensor_ptr(out0_tensor);
+        auto outk_ptr = get_tensor_ptr(out1_tensor);
+        """
+        return_tensor_names = "out0_tensor, out1_tensor"
+
+        template_used = rendering_common_template(
+            apply_rotary_emb_triton, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel, return_tensor_names
+        )
+
+        grid = ("batch", "num_heads", "seq_len")
+        apply_rotary_emb_kernel[(op_name, template_used, grid)](
+            q_ptr=q,
+            k_ptr=k,
+            cos_ptr=cos,
+            sin_ptr=sin,
+            outq_ptr=outq,
+            outk_ptr=outk,
+            batch=batch,
+            num_heads=num_heads,
+            seq_len=seq_len,
+            head_dim=head_dim,
+            n_elements=n_elements,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+    if in_dynamic_or_pir_mode():
+        # print(f"== we are in dynamic mode, op_name: {op_name}")
+        outs = _C_ops._run_custom_op(
+            op_name,
+            q,
+            k,
+            cos,
+            sin,
+        )
+        return outs[0], outs[1]
+    else:
+        # print(f"== we are in dynamic to static mode, op_name: {op_name}")
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "q": q,
+            "k": k,
+            "cos": cos,
+            "sin": sin,
+        }
+        outq = helper.create_variable_for_type_inference(dtype=q.dtype)
+        outk = helper.create_variable_for_type_inference(dtype=q.dtype)
+
+        helper.append_op(
+            type=op_name,
+            inputs=inputs,
+            outputs={"out0_tensor": outq, "out1_tensor": outk},
+        )
+        return outq, outk
+
+
diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py
index 1e0c2fab9..ec3526da5 100644
--- a/paddlemix/triton_ops/triton_ops.py
+++ b/paddlemix/triton_ops/triton_ops.py
@@ -1,3 +1,5 @@
+
+
 # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
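Note on the new kernel above: apply_rotary_emb_kernel rotates interleaved even/odd pairs of q and k, with cos/sin indexed per (seq_len, head_dim) position and shared across batch and heads. The snippet below is a minimal pure-Paddle sketch of that same rotation, useful as a numerical reference when validating the Triton path; the function name apply_rotary_emb_reference and the assumed layouts (q/k of shape [batch, num_heads, seq_len, head_dim], cos/sin of shape [seq_len, head_dim]) are illustrative and not part of the PR.

    import paddle

    def apply_rotary_emb_reference(q, k, cos, sin):
        # q, k: [batch, num_heads, seq_len, head_dim]; cos, sin: [seq_len, head_dim].
        cos = cos.astype("float32").reshape([1, 1, q.shape[2], q.shape[3]])
        sin = sin.astype("float32").reshape([1, 1, q.shape[2], q.shape[3]])

        def rotate(x):
            x32 = x.astype("float32")
            x0, x1 = x32[..., 0::2], x32[..., 1::2]            # even / odd lanes
            out0 = x0 * cos[..., 0::2] - x1 * sin[..., 0::2]   # matches oq0 / ok0
            out1 = x1 * cos[..., 1::2] + x0 * sin[..., 1::2]   # matches oq1 / ok1
            # Re-interleave the two halves back into the original layout.
            return paddle.stack([out0, out1], axis=-1).reshape(x.shape).astype(x.dtype)

        return rotate(q), rotate(k)

Comparing this against apply_rotary_emb_triton with paddle.max(paddle.abs(ref - triton_out)) mirrors the correctness checks used in the docstrings elsewhere in triton_ops.py.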
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,263 +26,14 @@ from paddle.framework import in_dynamic_or_pir_mode from .triton_utils import ( - SubstituteTemplate, - build_package, - compile_file, - extract_triton_kernel, - find_so_path, get_dtype_str, get_op_name_with_suffix, - get_pointer_hint, - get_value_hint, - link_file, - multi_process_do, - paddle_custom_op_head_part, - python_path, - rename_c_to_cu, + paddle_use_triton, + rendering_common_template, tune_and_invoke_part, ) -class KernelInterface: - def __init__( - self, - func, - custom_op_template, - other_config, - key_args=["1"], - ): - self.func = func - self.key_args = key_args - - import inspect - - signature = inspect.signature(func) - self.arg_names = [v.name for v in signature.parameters.values()] - for ele in self.arg_names: - assert self.arg_names.count(ele) == 1 - arg_defaults = [v.default for v in signature.parameters.values()] - - # self.annotations = { - # name: ty for name, ty in func.__annotations__.items() - # } - self.annotations = dict(func.__annotations__) - - self.constexprs = [ - self.arg_names.index(name) - for name in self.arg_names - if self.annotations.get(name) == triton.language.core.constexpr - ] - - self.arg_exclude_constexpr = [ - self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs - ] - - import textwrap - - py_script = textwrap.dedent(inspect.getsource(func)) - - import re - - pat = r"def\s" + func.__name__ - func_begin = re.findall(pat, py_script) - assert len(func_begin) == 1 - func_begin = func_begin[0] - py_script = py_script[py_script.find(func_begin) :] - - def decorator(*args, **kwargs): - all_input = [] - - for i in range(len(args)): - all_input.append(args[i]) - - position_arguments_num = len(all_input) - for i in range(position_arguments_num, len(self.arg_names)): - if self.arg_names[i] in kwargs.keys(): - all_input.append(kwargs[self.arg_names[i]]) - else: - # means this input is not specified, it muse be a tl.constexpr. - assert i in self.constexprs - all_input.append(None) - - dtypes = [] - x_list = [] - const_args = [self.arg_names[i] for i in self.constexprs] - # we dont allow there are two strings in const_args, and one is a substring of the other. - for i in const_args: - for j in const_args: - if i != j and i.find(j) != -1: - raise ValueError( - f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, please modify your triton kernel arguments names to avoid this." 
- ) - - const_hint_dict = {} - for i in range(len(all_input)): - ele = all_input[i] - if ( - type(ele) == paddle.Tensor - or type(ele) == paddle.base.framework.EagerParamBase - or type(ele) == paddle.base.framework.Parameter - or type(ele) == paddle.base.framework.Variable - or type(ele) == paddle.base.libpaddle.pir.Value - ): - dtypes.append(ele.dtype) - elif i in self.constexprs: - const_hint_dict[self.arg_names[i]] = ele - else: - x_list.append(ele) - - op_name = self.op_name - - python_package_name = f"{op_name}_package" - - generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR", None) - print("the kernel cache dir is:", generated_dir) - assert ( - generated_dir is not None - ), "TRITON_KERNEL_CACHE_DIR is None, please set it such as export TRITON_KERNEL_CACHE_DIR=/tmp/haha " - generated_dir = f"{generated_dir}/{op_name}" - os.makedirs(generated_dir, exist_ok=True) - - py_script_file = f"{generated_dir}/triton_kernels.py" - extract_triton_kernel(func, py_script_file) - - address_hint = get_pointer_hint(dtypes) - value_hint = get_value_hint(x_list) - const_args = [f"{{{ele}}}" for ele in const_args] - const_args = ",".join(const_args) - - lanuch_grid = list(self.grid) - for i in range(len(lanuch_grid)): - ele = lanuch_grid[i] - if type(ele) == str: - for key in const_hint_dict.keys(): - if key in ele: - ele = ele.replace(key, f"{{{key}}}") - else: - ele = str(ele) - - lanuch_grid[i] = ele - if len(lanuch_grid) < 3: - lanuch_grid += ["1"] * (3 - len(lanuch_grid)) - lanuch_grid = ",".join(lanuch_grid) - - op_dict = {"op_name": op_name, "reset_zero_when_tune": ""} - op_dict["triton_kernel_args"] = ",".join(self.arg_exclude_constexpr) - op_dict["key"] = ",".join(self.key_args) - # when tunning, we need to reset the out to zero. - if "reset_zero_when_tune" in other_config.keys(): - op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"] - - paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu" - so_path = find_so_path(generated_dir, python_package_name) - - if so_path is None: - print("== we do not find so_path, we need to compile it") - with open(paddle_custom_op_file_path, "w") as f: - f.write( - SubstituteTemplate( - custom_op_template, - op_dict, - ) - ) - f.close() - - # ahead of time compile command. - aot_template = ( - f"""{python_path} {compile_file} {py_script_file} -n {func.__name__} -o {generated_dir}/{op_name}_kernel --out-name {op_name}_kernel """ - + """ -w {num_warps} -ns {num_stages} """ - + f""" -s"{address_hint} {value_hint} {const_args}" """ - + f""" -g "{lanuch_grid}" """ - ) - all_tune_config = list(self.tune_config) - if len(all_tune_config) == 0: - # when user do not specify config, we use const_hint_dict as config. - all_tune_config = [const_hint_dict] - # reset const_hint_dict as empty. - const_hint_dict = {} - codegen_commands = [] - for config in all_tune_config: - for key in const_hint_dict.keys(): - if const_hint_dict[key] is not None: - if key not in config.keys(): - config[key] = const_hint_dict[key] - else: - raise ValueError(f"you specify {key} both in arguments and config, this is wrong.") - else: - assert key in config.keys(), f"you must specify {key} in your config." - if "num_warps" not in config.keys(): - config["num_warps"] = 4 - if "num_stages" not in config.keys(): - config["num_stages"] = 4 - - for key in config: - assert config[key] is not None, f"{key} must be specified." 
- codegen_command = aot_template.format( - **config, - ) - print(codegen_command) - codegen_commands.append(codegen_command) - multi_process_do(codegen_commands) - - link_command = f"{python_path} {link_file} {generated_dir}/*.h -o {generated_dir}/{op_name}_kernel" - re = os.system(link_command) - assert re == 0 - - # rename the .c file to .cu - rename_c_to_cu(generated_dir) - # build the package to so, not install - build_package(generated_dir, python_package_name) - - if op_name not in OpProtoHolder.instance().op_proto_map.keys(): - so_path = find_so_path(generated_dir, python_package_name) - print("== we find so_path: ", so_path) - assert so_path is not None - paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) - - self.decorator = decorator - - def __getitem__(self, op_name_and_grid): - assert len(op_name_and_grid) >= 2, "len(op_name_and_grid) must >= 2." - self.op_name = op_name_and_grid[0] - self.grid = op_name_and_grid[1] - if len(op_name_and_grid) == 2: - self.tune_config = {} - else: - self.tune_config = op_name_and_grid[2] - return self.decorator - - -def paddle_use_triton(custom_op_template, other_config={}, key=[]): - - index = custom_op_template.find("PD_BUILD_OP") - - body = custom_op_template[:index] - - if body.find("${op_name}_InferShape") == -1: - body += "std::vector> ${op_name}_InferShape(const std::vector& A_shape) {return {A_shape};}" - - if body.find("${op_name}_InferDtype") == -1: - body += ( - "std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) {return {A_dtype};}" - ) - - tail = custom_op_template[index:] - - tail += """ - .SetKernelFn(PD_KERNEL(${op_name}_func)) - .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype)) - .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape)); - """ - - custom_op_template = paddle_custom_op_head_part + body + tail - - def decorator(func): - return KernelInterface(func, custom_op_template, other_config, key) - - return decorator - - def get_wint8_kernel_config(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -373,7 +126,6 @@ def get_wint8_kernel_config(): @paddle_use_triton( - custom_op_template=triton_wint8_template, other_config=wint8_kernel_other_config, key=["M", "N", "K"], ) @@ -533,7 +285,7 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True): triton_uint_qweight = (triton_qweight.astype("int32") + 128).astype("uint8") for i in range(100): - triton_output = paddlemix.custom_ops.weight_only_int8( + triton_output = paddlemix.triton_ops.weight_only_int8( activation, triton_uint_qweight, triton_scale, @@ -543,7 +295,7 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True): starttime = datetime.datetime.now() for i in range(100): - triton_output = paddlemix.custom_ops.weight_only_int8( + triton_output = paddlemix.triton_ops.weight_only_int8( activation, triton_uint_qweight, triton_scale, @@ -611,7 +363,7 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True): "SPLIT_K", ) - wint8_kernel[(op_name, grid, get_wint8_kernel_config())]( + wint8_kernel[(triton_wint8_template, op_name, grid, get_wint8_kernel_config())]( x, qweight, output, @@ -650,67 +402,7 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True): return out -########################### adaptive layer norm ############################### -fused_adaLN_scale_residual_template = ( - """ - - -std::vector ${op_name}_func( - const paddle::Tensor &x, - const paddle::Tensor &mha_out, - const paddle::Tensor &gate_msa, - const paddle::Tensor &scale_mlp, - const 
paddle::Tensor &shift_mlp, - paddle::optional &weight, - paddle::optional &bias, - float epsilon) { - int M = x.dims()[0] * x.dims()[1]; - int N = x.dims()[2]; - int seq_size = x.dims()[1]; - auto resi_out = paddle::empty(x.shape(), x.dtype(), x.place()); - auto adaLN_out = paddle::empty(x.shape(), x.dtype(), x.place()); - - auto x_ptr = get_tensor_ptr(x); - auto mha_out_ptr = get_tensor_ptr(mha_out); - auto resi_out_ptr = get_tensor_ptr(resi_out); - auto adaLN_out_ptr = get_tensor_ptr(adaLN_out); - auto gate_msa_ptr = get_tensor_ptr(gate_msa); - auto scale_mlp_ptr = get_tensor_ptr(scale_mlp); - auto shift_mlp_ptr = get_tensor_ptr(shift_mlp); - CUdeviceptr weight_ptr = (CUdeviceptr)(nullptr); - if (weight) { - weight_ptr = get_tensor_ptr(*weight); - } - CUdeviceptr bias_ptr = (CUdeviceptr)(nullptr); - if (bias) { - bias_ptr = get_tensor_ptr(*bias); - } - auto run_stream = adaLN_out.stream(); -""" - + tune_and_invoke_part - + """ - return {resi_out, adaLN_out}; -} - -std::vector> ${op_name}_InferShape( - const std::vector& A_shape) { - return {A_shape, A_shape}; -} - -std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) { - return {A_dtype, A_dtype}; -} - -PD_BUILD_OP(${op_name}) - .Inputs({"x", "mha_out", "gate_msa", "scale_mlp", "shift_mlp", paddle::Optional("weight"), paddle::Optional("bias")}) - .Outputs({"resi_out", "adaLN_out"}) - .Attrs({"epsilon: float"}) -""" -) - - @paddle_use_triton( - custom_op_template=fused_adaLN_scale_residual_template, key=["M"], ) def fused_adaLN_scale_residual_kernel( @@ -813,7 +505,7 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil for i in range(100): - resi_out_triton, adaLN_out_triton = paddlemix.custom_ops.fused_adaLN_scale_residual(x, mha_out, gate_msa, scale_mlp_x, shift_mlp_x, weight, bias, epsilon) + resi_out_triton, adaLN_out_triton = paddlemix.triton_ops.fused_adaLN_scale_residual(x, mha_out, gate_msa, scale_mlp_x, shift_mlp_x, weight, bias, epsilon) for i in range(100): resi_out_paddle, adaLN_out_paddle = paddle_fused_adaLN(x, mha_out, gate_msa, hidd, scale_mlp_x, shift_mlp_x, weight, bias, epsilon) @@ -853,6 +545,12 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil seq_size = x.shape[1] N_npo2 = triton.next_power_of_2(N) + prepare_attr_for_triton_kernel = """ + int M = x.dims()[0] * x.dims()[1]; + int N = x.dims()[2]; + int seq_size = x.dims()[1]; + """ + # baseline. 
if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None: resi_out_paddle = mha_out * gate_msa.unsqueeze(axis=1) + x @@ -875,28 +573,54 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil if op_name not in OpProtoHolder.instance().op_proto_map.keys(): resi_out = paddle.empty_like(x) adaLN_out = paddle.empty_like(x) + prepare_ptr_for_triton_kernel = """ + auto resi_out = paddle::empty(x.shape(), x.dtype(), x.place()); + auto adaLN_out = paddle::empty(x.shape(), x.dtype(), x.place()); + + auto x_ptr = get_tensor_ptr(x); + auto mha_out_ptr = get_tensor_ptr(mha_out); + auto resi_out_ptr = get_tensor_ptr(resi_out); + auto adaLN_out_ptr = get_tensor_ptr(adaLN_out); + auto gate_msa_ptr = get_tensor_ptr(gate_msa); + auto scale_mlp_ptr = get_tensor_ptr(scale_mlp); + auto shift_mlp_ptr = get_tensor_ptr(shift_mlp); + CUdeviceptr weight_ptr = (CUdeviceptr)(nullptr); + if (weight) weight_ptr = get_tensor_ptr(*weight); + CUdeviceptr bias_ptr = (CUdeviceptr)(nullptr); + if (bias) bias_ptr = get_tensor_ptr(*bias); + """ + + return_tensor_names = "resi_out, adaLN_out" + template_used = rendering_common_template( + fused_adaLN_scale_residual, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names, + ) + grid = ("M",) - fused_adaLN_scale_residual_kernel[(op_name, grid, fused_adaLN_scale_residual_kernel_config)]( - x, - mha_out, - gate_msa, - scale_mlp, - shift_mlp, - shift_mlp, - shift_mlp, - resi_out, - adaLN_out, - -1, - N, - -1, - epsilon, + fused_adaLN_scale_residual_kernel[(op_name, template_used, grid, fused_adaLN_scale_residual_kernel_config)]( + x_ptr=x, + mha_out_ptr=mha_out, + gate_msa_ptr=gate_msa, + scale_mlp_ptr=scale_mlp, + shift_mlp_ptr=shift_mlp, + # weight_ptr and bias_ptr may be None, so use shift_mlp. 
+ weight_ptr=shift_mlp, + bias_ptr=shift_mlp, + resi_out_ptr=resi_out, + adaLN_out_ptr=adaLN_out, + M=-1, + N=N, + seq_size=-1, + epsilon=epsilon, N_npo2=N_npo2, weight_attr=weight_attr, bias_attr=bias_attr, ) if in_dynamic_or_pir_mode(): - print(f"== we are in dynamic mode, op_name: {op_name}") + #print(f"== we are in dynamic mode, op_name: {op_name}") outs = _C_ops._run_custom_op( op_name, x, @@ -910,7 +634,7 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil ) return outs[0], outs[1] else: - print(f"== we are in dynamic to static mode, op_name: {op_name}") + #print(f"== we are in dynamic to static mode, op_name: {op_name}") helper = LayerHelper(op_name, **locals()) inputs = { "x": x, @@ -934,50 +658,7 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil return resi_out, adaLN_out -triton_adaptive_layer_norm_template = ( - """ - -std::vector ${op_name}_func( - const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &shift, - paddle::optional &weight, - paddle::optional &bias, - float epsilon) { - int M = x.dims()[0] * x.dims()[1]; - int N = x.dims()[2]; - int seq_size = x.dims()[1]; - auto y = paddle::empty(x.shape(), x.dtype(), x.place()); - - auto x_ptr = get_tensor_ptr(x); - auto y_ptr = get_tensor_ptr(y); - auto scale_ptr = get_tensor_ptr(scale); - auto shift_ptr = get_tensor_ptr(shift); - CUdeviceptr weight_ptr = (CUdeviceptr)(nullptr); - if (weight) { - weight_ptr = get_tensor_ptr(*weight); - } - CUdeviceptr bias_ptr = (CUdeviceptr)(nullptr); - if (bias) { - bias_ptr = get_tensor_ptr(*bias); - } - auto run_stream = y.stream(); -""" - + tune_and_invoke_part - + """ - return {y}; -} - -PD_BUILD_OP(${op_name}) - .Inputs({"x", "scale", "shift", paddle::Optional("weight"), paddle::Optional("bias")}) - .Outputs({"out"}) - .Attrs({"epsilon: float"}) -""" -) - - @paddle_use_triton( - custom_op_template=triton_adaptive_layer_norm_template, key=["M"], ) def adaptive_layer_norm_kernel( @@ -1052,7 +733,7 @@ def modulate(x, shift, scale): scale_msa_x = paddle.rand([batch, hidd], dtype=dtype) for i in range(100): - mt_result = paddlemix.custom_ops.adaptive_layer_norm(x, scale_msa_x, shift_msa_x, weight, bias) + mt_result = paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa_x, shift_msa_x, weight, bias) for i in range(100): baseline = modulate(paddle.nn.functional.layer_norm(x, [hidd], weight, bias, 1e-5), shift_msa_x, scale_msa_x) @@ -1083,6 +764,12 @@ def modulate(x, shift, scale): seq_size = x.shape[1] BLOCK_SIZE = triton.next_power_of_2(N) + prepare_attr_for_triton_kernel = """ + int M = x.dims()[0] * x.dims()[1]; + int N = x.dims()[2]; + int seq_size = x.dims()[1]; + """ + # baseline. 
if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None: norm_hidden_states = paddle.nn.functional.layer_norm(x, [N], weight, bias, epsilon) @@ -1103,18 +790,34 @@ def modulate(x, shift, scale): if op_name not in OpProtoHolder.instance().op_proto_map.keys(): y = paddle.empty_like(x) + prepare_ptr_for_triton_kernel = """ + auto y = paddle::empty(x.shape(), x.dtype(), x.place()); + auto x_ptr = get_tensor_ptr(x); + auto y_ptr = get_tensor_ptr(y); + auto scale_ptr = get_tensor_ptr(scale); + auto shift_ptr = get_tensor_ptr(shift); + CUdeviceptr weight_ptr = (CUdeviceptr)(nullptr); + if (weight) weight_ptr = get_tensor_ptr(*weight); + CUdeviceptr bias_ptr = (CUdeviceptr)(nullptr); + if (bias) bias_ptr = get_tensor_ptr(*bias); + """ + return_tensor_names = "y" + template_used = rendering_common_template( + adaptive_layer_norm, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel, return_tensor_names + ) + grid = ("M",) - adaptive_layer_norm_kernel[(op_name, grid, adaptive_layer_norm_kernel_config)]( - x, - y, - y, - y, - y, - y, - -1, - N, - -1, - epsilon, + adaptive_layer_norm_kernel[(op_name, template_used, grid, adaptive_layer_norm_kernel_config)]( + x_ptr=x, + y_ptr=y, + weight_ptr=y, + bias_ptr=y, + scale_ptr=y, + shift_ptr=y, + M=-1, + N=N, + seq_size=-1, + epsilon=epsilon, BLOCK_SIZE=BLOCK_SIZE, weight_attr=weight_attr, bias_attr=bias_attr, @@ -1132,189 +835,16 @@ def modulate(x, shift, scale): "weight@OPTIONAL": weight, "bias@OPTIONAL": bias, } - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type=op_name, - inputs=inputs, - attrs={ - "epsilon": epsilon, - }, - outputs={"out": out}, - ) - return out - - -rms_norm_template = ( - """ - -std::vector ${op_name}_func( - const paddle::Tensor &x, - paddle::optional &weight, - paddle::optional &bias, - float epsilon) { - int M = x.dims()[0] * x.dims()[1] * x.dims()[2]; - int N = x.dims()[3]; - auto y = paddle::empty(x.shape(), x.dtype(), x.place()); - - auto x_ptr = get_tensor_ptr(x); - auto y_ptr = get_tensor_ptr(y); - CUdeviceptr weight_ptr = (CUdeviceptr)(nullptr); - if (weight) { - weight_ptr = get_tensor_ptr(*weight); - } - CUdeviceptr bias_ptr = (CUdeviceptr)(nullptr); - if (bias) { - bias_ptr = get_tensor_ptr(*bias); - } - auto run_stream = y.stream(); -""" - + tune_and_invoke_part - + """ - return {y}; -} - -PD_BUILD_OP(${op_name}) - .Inputs({"x", paddle::Optional("weight"), paddle::Optional("bias")}) - .Outputs({"out"}) - .Attrs({"epsilon: float"}) -""" -) - - -@paddle_use_triton( - custom_op_template=rms_norm_template, - key=["M"], -) -def rms_norm_kernel( - x_ptr, - y_ptr, - weight_ptr, - bias_ptr, - M, - N, - epsilon, - BLOCK_SIZE_M: tl.constexpr, - N_npo2: tl.constexpr, - weight_attr: tl.constexpr, - bias_attr: tl.constexpr, -): - row = tl.program_id(axis=0) - - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_an = tl.arange(0, N_npo2) - - # compute var - all_offs = (row * BLOCK_SIZE_M + offs_am[:, None]) % M * N + offs_an[None, :] - - x_eles = tl.load(x_ptr + all_offs, mask=offs_an[None, :] < N, other=0.0).to(tl.float32) - var = tl.sum(x_eles * x_eles, axis=1) / N - - resi_hat = x_eles / tl.sqrt(var[:, None] + epsilon) - - if weight_attr: - weights = tl.load(weight_ptr + offs_an, mask=offs_an < N, other=0.0) - resi_hat = resi_hat * weights - - if bias_attr: - bias = tl.load(bias_ptr + offs_an, mask=offs_an < N, other=0.0) - resi_hat = resi_hat + bias - - tl.store(y_ptr + all_offs, resi_hat, mask=offs_an[None, :] < N) - - -def rms_norm(x, weight=None, bias=None, epsilon=1e-05): - """ - Examples: - 
- import os - os.environ["CUDA_VISIBLE_DEVICES"] = "2" - import paddle - - batch = 2 - seq = 3600 - num_heads = 1 - head_dim = 64*30 - dtype= "float16" - x = paddle.rand([batch, seq, num_heads, head_dim], dtype=dtype) - weight = paddle.rand([head_dim], dtype=dtype) - bias = paddle.rand([head_dim], dtype=dtype) - - for i in range(100): - baseline = paddle.incubate.nn.functional.fused_rms_norm(x, weight, bias, 1e-5, begin_norm_axis=3) - - for i in range(100): - mt_result = paddlemix.custom_ops.rms_norm(x,weight,bias,1e-5) - - - baseline = baseline[0] - print(paddle.max(paddle.abs(baseline-mt_result))) - - """ - - assert len(x.shape) == 4, "x should be 4-dim." - weight_attr = 0 - if weight is not None: - assert len(weight.shape) == 1, "weight should be 1-dim" - assert weight.shape[-1] == x.shape[-1], "x and weight should have same shape[-1]" - weight_attr = 1 - bias_attr = 0 - if bias is not None: - assert len(bias.shape) == 1, "bias should be 1-dim" - assert bias.shape[-1] == x.shape[-1], "x and bias should have same shape[-1]" - bias_attr = 1 - - M = x.shape[0] * x.shape[1] * x.shape[2] - N = x.shape[3] - N_npo2 = triton.next_power_of_2(N) - - op_name = "triton_rms_norm" - op_name += get_dtype_str(x.dtype) - op_name += f"_{N_npo2}" - - rms_norm_kernel_config = [] - if N_npo2 <= 64: - rms_norm_kernel_config.append({"BLOCK_SIZE_M": 4, "num_warps": 1}) - else: - rms_norm_kernel_config.append({"BLOCK_SIZE_M": 1, "num_warps": 4}) - - if op_name not in OpProtoHolder.instance().op_proto_map.keys(): - y = paddle.empty_like(x) - grid = ("((M+BLOCK_SIZE_M-1)/BLOCK_SIZE_M)",) - rms_norm_kernel[(op_name, grid, rms_norm_kernel_config)]( - x, - y, - weight, - x, - -1, # M, - N, - epsilon, - N_npo2=N_npo2, - weight_attr=weight_attr, - bias_attr=bias_attr, - ) - - if in_dynamic_or_pir_mode(): - print(f"== we are in dynamic mode, op_name: {op_name}") - outs = _C_ops._run_custom_op(op_name, x, weight, bias, epsilon) - return outs[0] - else: - print(f"== we are in dynamic to static mode, op_name: {op_name}") - helper = LayerHelper(op_name, **locals()) - inputs = { - "x": x, - "weight@OPTIONAL": weight, - "bias@OPTIONAL": bias, - } - out = helper.create_variable_for_type_inference(dtype=x.dtype) + y = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type=op_name, inputs=inputs, attrs={ "epsilon": epsilon, }, - outputs={"out": out}, + outputs={"y": y}, ) - return out + return y fused_rotary_emb_template = ( @@ -1385,7 +915,6 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): @paddle_use_triton( - custom_op_template=fused_rotary_emb_template, key=["M"], ) def fused_rotary_emb_kernel( @@ -1513,7 +1042,7 @@ def fused_rotary_emb( k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) grid = ("M",) - fused_rotary_emb_kernel[(op_name, grid, fused_rotary_emb_kernel_config)]( + fused_rotary_emb_kernel[(op_name, fused_rotary_emb_template, grid, fused_rotary_emb_kernel_config)]( x, q_out_tensor, k_out_tensor, @@ -1532,7 +1061,7 @@ def fused_rotary_emb( ) if in_dynamic_or_pir_mode(): - print(f"== we are in dynamic mode, op_name: {op_name}") + #print(f"== we are in dynamic mode, op_name: {op_name}") outs = _C_ops._run_custom_op( op_name, x, @@ -1545,7 +1074,7 @@ def fused_rotary_emb( ) return outs[0], outs[1], outs[2] else: - print(f"== we are in dynamic to static mode, op_name: {op_name}") + #print(f"== we are in dynamic to static mode, op_name: 
{op_name}") helper = LayerHelper(op_name, **locals()) inputs = { "x": x, @@ -1569,66 +1098,7 @@ def fused_rotary_emb( return q_out, k_out, v_out -########################### split concat ############################### -split_concat_template = ( - """ -std::vector ${op_name}_func( - const paddle::Tensor &x, - const paddle::Tensor &y) { - - int batch = x.dims()[0]; - - int seq_qkv = x.dims()[1]; - int seq_eqkv = y.dims()[1]; - int output_hidden = x.dims()[2] / 3; - - auto qkv = get_tensor_ptr(x); - auto eqkv = get_tensor_ptr(y); - - auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place()); - auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place()); - auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place()); - - auto out0 = get_tensor_ptr(out0_tensor); - auto out1 = get_tensor_ptr(out1_tensor); - auto out2 = get_tensor_ptr(out2_tensor); - - - auto run_stream = out0_tensor.stream(); - -""" - + tune_and_invoke_part - + """ - return {out0_tensor, out1_tensor, out2_tensor}; -} - -std::vector> ${op_name}_InferShape( - const std::vector& A_shape, const std::vector& B_shape) { - - int64_t seq1 = A_shape[1]; - int64_t seq2 = B_shape[1]; - int64_t seq = -1; - if (seq1 > 0 && seq2 > 0){ - seq = seq1 + seq2; - } - std::vector out_shape = {A_shape[0], seq, A_shape[2]/3}; - - return {out_shape, out_shape, out_shape}; -} - -std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) { - return {A_dtype, A_dtype, A_dtype}; -} - -PD_BUILD_OP(${op_name}) - .Inputs({"x", "y"}) - .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"}) -""" -) - - @paddle_use_triton( - custom_op_template=split_concat_template, key=["1"], ) def split_concat_kernel( @@ -1671,6 +1141,28 @@ def split_concat_kernel( tl.store(write_ptr, read_data, mask=mask) +########################### split concat ############################### +d2s_split_concat_infer_shape_dtype = """ +std::vector> ${op_name}_InferShape( + const std::vector& A_shape, const std::vector& B_shape) { + + int64_t seq1 = A_shape[1]; + int64_t seq2 = B_shape[1]; + int64_t seq = -1; + if (seq1 > 0 && seq2 > 0){ + seq = seq1 + seq2; + } + std::vector out_shape = {A_shape[0], seq, A_shape[2]/3}; + + return {out_shape, out_shape, out_shape}; +} + +std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) { + return {A_dtype, A_dtype, A_dtype}; +} +""" + + def split_concat(x, y): assert len(x.shape) == 3 assert len(y.shape) == 3 @@ -1692,6 +1184,15 @@ def split_concat(x, y): hidd_x = x.shape[2] seq_eqkv = y.shape[1] ouput_hidden = hidd_x // 3 + + prepare_attr_for_triton_kernel = """ + int batch = x.dims()[0]; + int seq_qkv = x.dims()[1]; + int hidd_x = x.dims()[2]; + int seq_eqkv = y.dims()[1]; + int output_hidden = hidd_x / 3; + """ + BLOCK_SIZE = triton.next_power_of_2(ouput_hidden) op_name = "split_concat" op_name += get_dtype_str(x.dtype) @@ -1701,14 +1202,44 @@ def split_concat(x, y): out0 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype) out1 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype) out2 = paddle.empty(shape=[batch, seq_qkv + seq_eqkv, ouput_hidden], dtype=x.dtype) + + prepare_ptr_for_triton_kernel = """ + auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place()); + auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place()); + auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, 
output_hidden}, x.dtype(), x.place()); + auto qkv = get_tensor_ptr(x); + auto eqkv = get_tensor_ptr(y); + auto out0 = get_tensor_ptr(out0_tensor); + auto out1 = get_tensor_ptr(out1_tensor); + auto out2 = get_tensor_ptr(out2_tensor); + """ + return_tensor_names = "out0_tensor,out1_tensor,out2_tensor" + + template_used = rendering_common_template( + split_concat, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names, + d2s_split_concat_infer_shape_dtype, + ) + grid = ("3", "batch", "seq_qkv + seq_eqkv") # -1 means this value does not matter for triton compilation - split_concat_kernel[(op_name, grid)]( - out0, out1, out2, x, y, -1, seq_qkv, seq_eqkv, ouput_hidden, BLOCK_SIZE=BLOCK_SIZE # batch, + split_concat_kernel[(op_name, template_used, grid)]( + out0=out0, + out1=out1, + out2=out2, + qkv=x, + eqkv=y, + batch=-1, + seq_qkv=seq_qkv, + seq_eqkv=seq_eqkv, + output_hidden=ouput_hidden, + BLOCK_SIZE=BLOCK_SIZE, ) if in_dynamic_or_pir_mode(): - print(f"== we are in dynamic mode, op_name: {op_name}") + #print(f"== we are in dynamic mode, op_name: {op_name}") outs = _C_ops._run_custom_op( op_name, x, @@ -1716,7 +1247,7 @@ def split_concat(x, y): ) return outs[0], outs[1], outs[2] else: - print(f"== we are in dynamic to static mode, op_name: {op_name}") + #print(f"== we are in dynamic to static mode, op_name: {op_name}") helper = LayerHelper(op_name, **locals()) inputs = { "x": x, @@ -1785,7 +1316,6 @@ def split_concat(x, y): @paddle_use_triton( - custom_op_template=triton_split_template, key=["1"], ) def triton_split_kernel( @@ -1831,12 +1361,12 @@ def triton_split(x, num_or_sections=[-1, -1], axis=1): out1 = paddle.empty(shape=[output_batch, output_seq1, output_hidden], dtype=x.dtype) grid = ("output_batch", "output_seq0+output_seq1") - triton_split_kernel[(op_name, grid)]( + triton_split_kernel[(op_name, triton_split_template, grid)]( out0, out1, x, output_seq0, output_seq1, output_batch, output_hidden, BLOCK_SIZE=2048 ) if in_dynamic_or_pir_mode(): - print(f"== we are in dynamic mode, op_name: {op_name}") + #print(f"== we are in dynamic mode, op_name: {op_name}") outs = _C_ops._run_custom_op( op_name, x, @@ -1845,7 +1375,7 @@ def triton_split(x, num_or_sections=[-1, -1], axis=1): ) return outs[0], outs[1] else: - print(f"== we are in dynamic to static mode, op_name: {op_name}") + #print(f"== we are in dynamic to static mode, op_name: {op_name}") helper = LayerHelper(op_name, **locals()) inputs = { "x": x, @@ -1863,3 +1393,7 @@ def triton_split(x, num_or_sections=[-1, -1], axis=1): outputs={"out0_tensor": out0, "out1_tensor": out1}, ) return out0, out1 + + + + diff --git a/paddlemix/triton_ops/triton_utils.py b/paddlemix/triton_ops/triton_utils.py index 766e03a8a..bacf1686b 100644 --- a/paddlemix/triton_ops/triton_utils.py +++ b/paddlemix/triton_ops/triton_utils.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
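The triton_utils.py changes that follow move KernelInterface/paddle_use_triton here and add rendering_common_template, which generates the C++ custom-op wrapper from the Python wrapper's signature: a parameter defaulting to None becomes an optional tensor input, a float default becomes a float attribute, and everything else a plain tensor input. A rough illustration is sketched below; my_norm, prepare_attr, prepare_ptr and the output name "y" are placeholders, and the rendered C++ is abbreviated.

    def my_norm(x, weight=None, bias=None, epsilon=1e-05):
        ...

    # rendering_common_template(my_norm, prepare_attr, prepare_ptr, "y") renders a
    # wrapper roughly of this shape, which is then compiled and registered as a
    # Paddle custom op:
    #
    #   std::vector<paddle::Tensor> ${op_name}_func(
    #       const paddle::Tensor& x,
    #       paddle::optional<paddle::Tensor>& weight,
    #       paddle::optional<paddle::Tensor>& bias,
    #       float epsilon) { ... }
    #
    #   PD_BUILD_OP(${op_name})
    #       .Inputs({"x", paddle::Optional("weight"), paddle::Optional("bias")})
    #       .Outputs({"y"})
    #       .Attrs({"epsilon: float"})
    #       .SetKernelFn(PD_KERNEL(${op_name}_func))
    #       .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    #       .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));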
+import inspect import os import re import sys import paddle import triton +from paddle.base.framework import OpProtoHolder compile_file = triton.__path__[0] + "/tools/compile.py" link_file = triton.__path__[0] + "/tools/link.py" @@ -313,8 +315,8 @@ def get_pointer_hint(dtypes): (cudaEventRecord(beg[repeat_id])); } - auto flush_l2_cache = paddle::full( - {10 * 1024 * 1024}, 0, paddle::DataType::INT32, x.place()); + // auto flush_l2_cache = paddle::full( + // {10 * 1024 * 1024}, 0, paddle::DataType::INT32, x.place()); // std::cout << &flush_l2_cache << std::endl; // this is used when out is need to be reset to zero, such as split-k gemm. ${reset_zero_when_tune}; @@ -354,3 +356,310 @@ def get_pointer_hint(dtypes): assert(status == CUDA_SUCCESS); } """ + + +common_template = ( + """ +std::vector ${op_name}_func(${input_and_attr}) { + ${prepare_attr_for_triton_kernel} + ${prepare_ptr_for_triton_kernel} + auto run_stream = ${arbitary_output_name}.stream(); + """ + + tune_and_invoke_part + + """ + return {${return_tensor_names}}; +} + + +${d2s_infer_shape_dtype_part} + +PD_BUILD_OP(${op_name}) + .Inputs({${paddle_input_sig}}) + .Outputs({${paddle_output_sig}}) + .Attrs({${paddle_attr_sig}}) + .SetKernelFn(PD_KERNEL(${op_name}_func)) + .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype)) + .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape)); +""" +) + + +def rendering_common_template( + func, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names, + d2s_infer_shape_dtype_part=None, +): + signature = inspect.signature(func) + arg_names = [v.name for v in signature.parameters.values()] + arg_defaults = [v.default for v in signature.parameters.values()] + input_and_attr = "" + paddle_input_sig = "" + paddle_attr_sig = "" + + for i in range(len(arg_names)): + if arg_defaults[i] == None: + input_and_attr += f"paddle::optional & {arg_names[i]}," + paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),""" + elif type(arg_defaults[i]) == float: + input_and_attr += f"float {arg_names[i]}," + paddle_attr_sig += f""""{arg_names[i]}: float",""" + else: + input_and_attr += f"const paddle::Tensor & {arg_names[i]}," + paddle_input_sig += f""""{arg_names[i]}",""" + input_and_attr = input_and_attr[:-1] + paddle_input_sig = paddle_input_sig[:-1] + if len(paddle_attr_sig) > 1: + paddle_attr_sig = paddle_attr_sig[:-1] + + paddle_output_sig = "" + arbitary_output_name = "" + for name in return_tensor_names.split(","): + name = name.strip() + arbitary_output_name = name + paddle_output_sig += f""""{name}",""" + paddle_output_sig = paddle_output_sig[:-1] + + if d2s_infer_shape_dtype_part is None: + d2s_infer_shape_dtype_part = """ + std::vector> ${op_name}_InferShape(const std::vector& A_shape) {return {${tmp1}};} + std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) {return {${tmp2}};} + """ + tmp1 = ",".join(["A_shape"] * len(return_tensor_names.split(","))) + tmp2 = ",".join(["A_dtype"] * len(return_tensor_names.split(","))) + tmp_dict = {"tmp1": tmp1, "tmp2": tmp2} + d2s_infer_shape_dtype_part = SubstituteTemplate(d2s_infer_shape_dtype_part, tmp_dict) + + result_str = SubstituteTemplate( + common_template, + { + "input_and_attr": input_and_attr, + "prepare_attr_for_triton_kernel": prepare_attr_for_triton_kernel, + "prepare_ptr_for_triton_kernel": prepare_ptr_for_triton_kernel, + "return_tensor_names": return_tensor_names, + "arbitary_output_name": arbitary_output_name, + "d2s_infer_shape_dtype_part": d2s_infer_shape_dtype_part, + "paddle_input_sig": 
paddle_input_sig, + "paddle_output_sig": paddle_output_sig, + "paddle_attr_sig": paddle_attr_sig, + }, + ) + + return paddle_custom_op_head_part + result_str + + +class KernelInterface: + def __init__( + self, + func, + other_config, + key_args=["1"], + ): + self.func = func + self.key_args = key_args + + signature = inspect.signature(func) + self.arg_names = [v.name for v in signature.parameters.values()] + for ele in self.arg_names: + assert self.arg_names.count(ele) == 1 + arg_defaults = [v.default for v in signature.parameters.values()] + + # self.annotations = { + # name: ty for name, ty in func.__annotations__.items() + # } + self.annotations = dict(func.__annotations__) + + self.constexprs = [ + self.arg_names.index(name) + for name in self.arg_names + if self.annotations.get(name) == triton.language.core.constexpr + ] + + self.arg_exclude_constexpr = [ + self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs + ] + + import textwrap + + py_script = textwrap.dedent(inspect.getsource(func)) + + import re + + pat = r"def\s" + func.__name__ + func_begin = re.findall(pat, py_script) + assert len(func_begin) == 1 + func_begin = func_begin[0] + py_script = py_script[py_script.find(func_begin) :] + + def decorator(*args, **kwargs): + all_input = [] + + for i in range(len(args)): + all_input.append(args[i]) + + position_arguments_num = len(all_input) + for i in range(position_arguments_num, len(self.arg_names)): + if self.arg_names[i] in kwargs.keys(): + all_input.append(kwargs[self.arg_names[i]]) + else: + # means this input is not specified, it muse be a tl.constexpr. + assert i in self.constexprs + all_input.append(None) + + dtypes = [] + x_list = [] + const_args = [self.arg_names[i] for i in self.constexprs] + # we dont allow there are two strings in const_args, and one is a substring of the other. + for i in const_args: + for j in const_args: + if i != j and i.find(j) != -1: + raise ValueError( + f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, please modify your triton kernel arguments names to avoid this." 
+ ) + + const_hint_dict = {} + for i in range(len(all_input)): + ele = all_input[i] + if ( + type(ele) == paddle.Tensor + or type(ele) == paddle.base.framework.EagerParamBase + or type(ele) == paddle.base.framework.Parameter + or type(ele) == paddle.base.framework.Variable + or type(ele) == paddle.base.libpaddle.pir.Value + ): + dtypes.append(ele.dtype) + elif i in self.constexprs: + const_hint_dict[self.arg_names[i]] = ele + else: + x_list.append(ele) + + op_name = self.op_name + + python_package_name = f"{op_name}_package" + + generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR", None) + print("the kernel cache dir is:", generated_dir) + assert ( + generated_dir is not None + ), "TRITON_KERNEL_CACHE_DIR is None, please set it such as export TRITON_KERNEL_CACHE_DIR=/tmp/haha " + generated_dir = f"{generated_dir}/{op_name}" + os.makedirs(generated_dir, exist_ok=True) + + py_script_file = f"{generated_dir}/triton_kernels.py" + extract_triton_kernel(func, py_script_file) + + address_hint = get_pointer_hint(dtypes) + value_hint = get_value_hint(x_list) + const_args = [f"{{{ele}}}" for ele in const_args] + const_args = ",".join(const_args) + + lanuch_grid = list(self.grid) + for i in range(len(lanuch_grid)): + ele = lanuch_grid[i] + if type(ele) == str: + for key in const_hint_dict.keys(): + if key in ele: + ele = ele.replace(key, f"{{{key}}}") + else: + ele = str(ele) + + lanuch_grid[i] = ele + if len(lanuch_grid) < 3: + lanuch_grid += ["1"] * (3 - len(lanuch_grid)) + lanuch_grid = ",".join(lanuch_grid) + + op_dict = {"op_name": op_name, "reset_zero_when_tune": ""} + op_dict["triton_kernel_args"] = ",".join(self.arg_exclude_constexpr) + op_dict["key"] = ",".join(self.key_args) + # when tunning, we need to reset the out to zero. + if "reset_zero_when_tune" in other_config.keys(): + op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"] + + paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu" + so_path = find_so_path(generated_dir, python_package_name) + + if so_path is None: + print("== we do not find so_path, we need to compile it") + with open(paddle_custom_op_file_path, "w") as f: + f.write( + SubstituteTemplate( + self.custom_op_template, + op_dict, + ) + ) + f.close() + + # ahead of time compile command. + aot_template = ( + f"""{python_path} {compile_file} {py_script_file} -n {func.__name__} -o {generated_dir}/{op_name}_kernel --out-name {op_name}_kernel """ + + """ -w {num_warps} -ns {num_stages} """ + + f""" -s"{address_hint} {value_hint} {const_args}" """ + + f""" -g "{lanuch_grid}" """ + ) + all_tune_config = list(self.tune_config) + if len(all_tune_config) == 0: + # when user do not specify config, we use const_hint_dict as config. + all_tune_config = [const_hint_dict] + # reset const_hint_dict as empty. + const_hint_dict = {} + codegen_commands = [] + for config in all_tune_config: + for key in const_hint_dict.keys(): + if const_hint_dict[key] is not None: + if key not in config.keys(): + config[key] = const_hint_dict[key] + else: + raise ValueError(f"you specify {key} both in arguments and config, this is wrong.") + else: + assert key in config.keys(), f"you must specify {key} in your config." + if "num_warps" not in config.keys(): + config["num_warps"] = 4 + if "num_stages" not in config.keys(): + config["num_stages"] = 4 + + for key in config: + assert config[key] is not None, f"{key} must be specified." 
+ codegen_command = aot_template.format( + **config, + ) + print(codegen_command) + codegen_commands.append(codegen_command) + multi_process_do(codegen_commands) + + link_command = f"{python_path} {link_file} {generated_dir}/*.h -o {generated_dir}/{op_name}_kernel" + re = os.system(link_command) + assert re == 0 + + # rename the .c file to .cu + rename_c_to_cu(generated_dir) + # build the package to so, not install + build_package(generated_dir, python_package_name) + + if op_name not in OpProtoHolder.instance().op_proto_map.keys(): + so_path = find_so_path(generated_dir, python_package_name) + print("== we find so_path: ", so_path) + assert so_path is not None + paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) + + self.decorator = decorator + + def __getitem__(self, op_name_and_grid): + assert len(op_name_and_grid) >= 3, "len(op_name_and_grid) must >= 3." + self.op_name = op_name_and_grid[0] + self.custom_op_template = op_name_and_grid[1] + self.grid = op_name_and_grid[2] + if len(op_name_and_grid) == 3: + self.tune_config = {} + else: + self.tune_config = op_name_and_grid[3] + + return self.decorator + + +def paddle_use_triton(other_config={}, key=[]): + def decorator(func): + return KernelInterface(func, other_config, key) + + return decorator + diff --git a/ppdiffusers/examples/vctrl/infer_cogvideox_i2v_vctrl_cli.py b/ppdiffusers/examples/vctrl/infer_cogvideox_i2v_vctrl_cli.py index cd61cca81..9fa6c0565 100644 --- a/ppdiffusers/examples/vctrl/infer_cogvideox_i2v_vctrl_cli.py +++ b/ppdiffusers/examples/vctrl/infer_cogvideox_i2v_vctrl_cli.py @@ -16,7 +16,7 @@ import gc import os import re - +import datetime import numpy as np import paddle from decord import VideoReader @@ -30,6 +30,9 @@ VCtrlModel, ) +os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" + + def write_mp4(video_path, samples, fps=8): clip = ImageSequenceClip(samples, fps=fps) @@ -198,21 +201,103 @@ def parse_args(): validation_control_images = [ref_image] + validation_control_images num_frames = len(validation_control_images) num_frames = min(num_frames, args.max_frame) - video = pipeline( - image=ref_image, - prompt=args.prompt, - num_inference_steps=args.num_inference_steps, - num_frames=num_frames, - guidance_scale=args.guidance_scale, - generator=paddle.Generator().manual_seed(42), - conditioning_frames=validation_control_images[:num_frames], - conditioning_frame_indices=list(range(num_frames)), - conditioning_scale=args.conditioning_scale, - width=args.width, - height=args.height, - task=args.task, - conditioning_masks=validation_mask_images[:num_frames] if args.task == "mask" else None, - vctrl_layout_type=args.vctrl_layout_type, - ).frames[0] - final_result.append(video) - save_vid_side_by_side(final_result, validation_control_images[:num_frames], args.output_dir, fps=args.fps) + + + # pipeline.vctrl = paddle.incubate.jit.inference( + # pipeline.vctrl, + # save_model_dir="./tmp/vctrl/vctrl_block", + # enable_new_ir=True, + # cache_static_model=True, + # exp_enable_use_cutlass=False, + # delete_pass_lists=[], + # ) + + + + if True: + print("Benchmarking...") + warm_up = 1 + repeat_times = 2 + sumtime = 0.0 + times = repeat_times + warm_up + for i in range(times): + if i > 0: + paddle.device.synchronize() + starttime = datetime.datetime.now() + with paddle.no_grad(): + video = pipeline( + image=ref_image, + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + num_frames=num_frames, + guidance_scale=args.guidance_scale, + generator=paddle.Generator().manual_seed(42), + 
conditioning_frames=validation_control_images[:num_frames], + conditioning_frame_indices=list(range(num_frames)), + conditioning_scale=args.conditioning_scale, + width=args.width, + height=args.height, + task=args.task, + conditioning_masks=validation_mask_images[:num_frames] if args.task == "mask" else None, + vctrl_layout_type=args.vctrl_layout_type, + ).frames[0] + if i > 0: + paddle.device.synchronize() + endtime = datetime.datetime.now() + + final_result.append(video) + save_vid_side_by_side(final_result, validation_control_images[:num_frames], args.output_dir, fps=args.fps) + + if i > 0: + duringtime = endtime - starttime + duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + sumtime += duringtime + print("Single end to end time : ", duringtime, "ms") + inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3) + print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB") + print(f"Single ave end to end time : ", sumtime / repeat_times, "ms") + + else: + # breakpoint() + # print(pipeline.transformer) + # breakpoint() + + + + + # pipeline.transformer = paddle.incubate.jit.inference( + # pipeline.transformer, + # save_model_dir="./tmp/vctrl/transformer_block", + # enable_new_ir=False, + # cache_static_model=False, + # exp_enable_use_cutlass=False, + # delete_pass_lists=[], + # ) + # pipeline.transformer.transformer_blocks = paddle.incubate.jit.inference( + # pipeline.transformer.transformer_blocks, + # save_model_dir="./tmp/vctrl/transformer_block", + # enable_new_ir=False, + # cache_static_model=False, + # exp_enable_use_cutlass=False, + # delete_pass_lists=[], + # ) + video = pipeline( + image=ref_image, + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + num_frames=num_frames, + guidance_scale=args.guidance_scale, + generator=paddle.Generator().manual_seed(42), + conditioning_frames=validation_control_images[:num_frames], + conditioning_frame_indices=list(range(num_frames)), + conditioning_scale=args.conditioning_scale, + width=args.width, + height=args.height, + task=args.task, + conditioning_masks=validation_mask_images[:num_frames] if args.task == "mask" else None, + vctrl_layout_type=args.vctrl_layout_type, + ).frames[0] + + + final_result.append(video) + save_vid_side_by_side(final_result, validation_control_images[:num_frames], args.output_dir, fps=args.fps) diff --git a/ppdiffusers/examples/vctrl/script/infer_cogvideox_i2v_pose_vctrl.sh b/ppdiffusers/examples/vctrl/script/infer_cogvideox_i2v_pose_vctrl.sh index 71834854e..b5b523ff8 100644 --- a/ppdiffusers/examples/vctrl/script/infer_cogvideox_i2v_pose_vctrl.sh +++ b/ppdiffusers/examples/vctrl/script/infer_cogvideox_i2v_pose_vctrl.sh @@ -12,15 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
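Two environment variables gate the new Triton code paths end to end: INFERENCE_OPTIMIZE_TRITON (set at the top of the CLI above) switches the wrappers in triton_ops.py off their pure-Paddle baselines, and TRITON_KERNEL_CACHE_DIR (exported in the script below) is where KernelInterface writes and compiles the generated kernels; KernelInterface asserts if it is unset. A minimal pure-Python equivalent, with a placeholder cache path, would be:

    import os

    # Enable the Triton fast paths in paddlemix.triton_ops (otherwise the wrappers
    # fall back to the pure-Paddle baselines).
    os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
    # Directory for the generated .cu / .so artifacts; the path is a placeholder.
    os.environ.setdefault("TRITON_KERNEL_CACHE_DIR", "/tmp/triton_kernel_cache")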
+export CUDA_VISIBLE_DEVICES=1 +export FLAGS_enable_pir_api=0 + + +export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/output/changwenbin/Research/TensorRT-10.3.0.26/lib/:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/output/changwenbin/Paddle/paddle/phi/kernels/fusion/cutlass/conv2d/build:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/output/changwenbin/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH +export TRITON_KERNEL_CACHE_DIR=/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/ppdiffusers/examples/vctrl/tmp/triton_kernel + +# nsys profile -o vctrl_pose_triton_static \ python infer_cogvideox_i2v_vctrl_cli.py \ --pretrained_model_name_or_path "paddlemix/cogvideox-5b-i2v-vctrl" \ --vctrl_path "vctrl_pose_5b_i2v.pdparams" \ --vctrl_config "vctrl_configs/cogvideox_5b_i2v_vctrl_config.json" \ - --control_video_path "guide_values_1.mp4" \ - --ref_image_path "reference_image_1.jpg" \ - --control_mask_video_path 'mask_values_1.mp4' \ + --control_video_path "pose/guide_values_0.mp4" \ + --ref_image_path "pose/reference_image_0.jpg" \ --output_dir "infer_outputs/pose2video" \ - --prompt "" \ + --prompt "An animated character with blue hair and a playful expression dances energetically on a reflective stage, wearing a black dress with white lace and ruffles, and black stockings. She is surrounded by a futuristic setting with geometric shapes and neon lights in shades of blue, red, and orange. As she dances, her outfit changes slightly, including a black top with a heart-shaped neckline and a purple tail. Her dynamic poses and expressions convey joy and confidence. The backdrop's neon lights and geometric patterns enhance the vibrant and lively atmosphere of the performance." \ --task "pose" \ --width 480 \ --height 720 \ diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index 8894b1447..e8570679f 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -23,6 +23,7 @@ from ..utils.import_utils import is_ppxformers_available from ..utils.paddle_utils import maybe_allow_in_graph from .lora import LoRACompatibleLinear, LoRALinearLayer +from .embeddings import apply_rotary_emb logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -528,8 +529,9 @@ def forward( hidden_states: paddle.Tensor, encoder_hidden_states: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor] = None, - **cross_attention_kwargs, + image_rotary_emb: Optional[tuple[paddle.Tensor, paddle.Tensor]] = None, ) -> paddle.Tensor: + cross_attention_kwargs = {"image_rotary_emb" : image_rotary_emb} r""" The forward method of the `Attention` class. 
@@ -2195,25 +2197,57 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - # Apply RoPE if needed + # # Apply RoPE if needed if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - + # from ppdiffusers import apply_rotary_emb + + # paddle.save(image_rotary_emb,"/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/image_rotary_emb.pd") + + # paddle.save(query,"/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/allq.pd") + # breakpoint() + # breakpoint() + # import paddlemix + # q ,k= paddlemix.triton_ops.apply_rotary_emb_triton(query[:, :, text_seq_length:],key[:, :, text_seq_length:],image_rotary_emb[0],image_rotary_emb[1]) + # # k = paddlemix.triton_ops.apply_rotary_emb_triton(key[:, :, text_seq_length:],image_rotary_emb[0],image_rotary_emb[1]) + # query[:, :, text_seq_length:] = q[0] + # key[:, :, text_seq_length:] = k[0] query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - + # query[:, :, text_seq_length:] = apply_rotary_emb_triton(query[:, :, text_seq_length:], image_rotary_emb) + # key[:, :, text_seq_length:] = apply_rotary_emb_triton(key[:, :, text_seq_length:], image_rotary_emb) + + + # cos,sin = image_rotary_emb + # cos = paddle.cast(cos,dtype="float16") + # sin = paddle.cast(sin,dtype="float16") + # breakpoint() + # cos = cos[None,None] + # sin = sin[None,None] + # from paddle.incubate.nn.functional import fused_rotary_position_embedding + # query = query.transpose([0,2,1,3]) + # key = key.transpose([0,2,1,3]) + # query[:,text_seq_length:,:], key[:,text_seq_length:, :], _ = fused_rotary_position_embedding(query[:,text_seq_length:,:], key[:, text_seq_length:,:],None, sin=sin, cos=cos,use_neox_rotary_style=False) + # breakpoint() # NOTE: There is diff between paddle's and torch's sdpa # paddle needs input: [batch_size, seq_len, num_heads, head_dim] # torch needs input: [batch_size, num_heads, seq_len, head_dim] - hidden_states = F.scaled_dot_product_attention_( - query.transpose([0, 2, 1, 3]), - key.transpose([0, 2, 1, 3]), - value.transpose([0, 2, 1, 3]), - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) + # hidden_states = F.scaled_dot_product_attention_( + # query.transpose([0, 2, 1, 3]), + # key.transpose([0, 2, 1, 3]), + # value.transpose([0, 2, 1, 3]), + # attn_mask=attention_mask, + # dropout_p=0.0, + # is_causal=False, + # ) + import paddlemix + hidden_states = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton( + query, + key, + value, + is_causal=False, + tensor_layout="NHD") + hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim]) diff --git a/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py index 8f3b2beb5..563e456b2 100644 --- a/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py +++ b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py @@ -33,7 +33,7 @@ from ppdiffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero from ppdiffusers.utils import logging from ppdiffusers.utils.paddle_utils import maybe_allow_in_graph - +import nvtx logger = logging.get_logger(__name__) @@ -89,10 +89,22 @@ def __init__( ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, + layer_idx = -1, ): super().__init__() self.norm1 = 
diff --git a/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py
index 8f3b2beb5..563e456b2 100644
--- a/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py
+++ b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d_vctrl.py
@@ -33,7 +33,7 @@
 from ppdiffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 from ppdiffusers.utils import logging
 from ppdiffusers.utils.paddle_utils import maybe_allow_in_graph
-
+import nvtx
 
 logger = logging.get_logger(__name__)
@@ -89,10 +89,22 @@ def __init__(
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        layer_idx = -1,
     ):
         super().__init__()
         self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
+        self.layer_idx = layer_idx
+        self.silu = paddle.nn.Silu()
+        self.linear1 = paddle.nn.Linear(in_features=time_embed_dim, out_features=6 * dim, bias_attr=True)
+        self.norm3 = paddle.nn.LayerNorm(
+            normalized_shape=dim, epsilon=norm_eps, weight_attr=norm_elementwise_affine, bias_attr=norm_elementwise_affine
+        )
+
+        self.linear1.weight = self.norm1.linear.weight
+        self.linear1.bias = self.norm1.linear.bias
+        self.norm3.weight = self.norm1.norm.weight
+        self.norm3.bias = self.norm1.norm.bias
+        # breakpoint()
         self.attn1 = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
@@ -105,6 +117,17 @@
         )
 
         self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+        self.linear2 = paddle.nn.Linear(in_features=time_embed_dim, out_features=6 * dim, bias_attr=True)
+        self.norm4 = paddle.nn.LayerNorm(
+            normalized_shape=dim, epsilon=norm_eps, weight_attr=norm_elementwise_affine, bias_attr=norm_elementwise_affine
+        )
+
+        self.linear2.weight = self.norm2.linear.weight
+        self.linear2.bias = self.norm2.linear.bias
+        self.norm4.weight = self.norm2.norm.weight
+        self.norm4.bias = self.norm2.norm.bias
+
         self.ff = FeedForward(
             dim,
             dropout=dropout,
@@ -113,7 +136,14 @@
             inner_dim=ff_inner_dim,
             bias=ff_bias,
         )
-
+        # @paddle.incubate.jit.inference(
+        #     save_model_dir="./tmp/vctrl/transformer_block",
+        #     enable_new_ir=False,
+        #     cache_static_model=False,
+        #     exp_enable_use_cutlass=False,
+        #     delete_pass_lists=[],
+        #     switch_ir_optim = False
+        # )
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -121,25 +151,67 @@
         temb: paddle.Tensor,
         image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
     ) -> paddle.Tensor:
+
+        paddle.device.synchronize()
+        transformer_block_nvtx = nvtx.start_range(message="block", color="red")
         text_seq_length = encoder_hidden_states.shape[1]
-
-        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
-            hidden_states, encoder_hidden_states, temb
-        )
-
+        # breakpoint()
+
+        # norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+        #     hidden_states, encoder_hidden_states, temb
+        # )
+
+        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear1(self.silu(temb)).chunk(chunks=6, axis=1)
+        norm_hidden_states = self.norm3(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        norm_encoder_hidden_states = self.norm3(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+        gate_msa, enc_gate_msa = gate, enc_gate
+
+        # # breakpoint()
+        # if self.layer_idx ==0:
+        #     self.attn1 = paddle.incubate.jit.inference(self.attn1,
+        #         save_model_dir=f"/root/.cache/paddle/inference_models/forward_{self.layer_idx}/",
+        #         cache_static_model = False,
+        #         enable_new_ir=False,
+        #         switch_ir_optim = False
+        #     )
+
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
-        hidden_states = hidden_states + gate_msa * attn_hidden_states
-        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+        # paddle.device.synchronize()
+        # transformer_block_norm_nvtx = nvtx.start_range(message="norm", color="green")
+
+
+        # hidden_states = hidden_states + gate_msa * attn_hidden_states
+        # encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+        # norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+        #     hidden_states, encoder_hidden_states, temb
+        # )
-        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
-            hidden_states, encoder_hidden_states, temb
+        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear2(self.silu(temb)).chunk(chunks=6, axis=1)
+        # norm_hidden_states = self.norm4(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        # norm_encoder_hidden_states = self.norm4(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+        gate_ff, enc_gate_ff = gate[:, None, :], enc_gate[:, None, :]
+
+        import paddlemix
+        hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
             hidden_states, attn_hidden_states, gate_msa, scale, shift, epsilon=1e-05
         )
+
+        encoder_hidden_states, norm_encoder_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+            encoder_hidden_states, attn_encoder_hidden_states, enc_gate_msa, enc_scale, enc_shift, epsilon=1e-05
+        )
+
+        # paddle.device.synchronize()
+        # nvtx.end_range(transformer_block_norm_nvtx)
 
         norm_hidden_states = paddle.concat(x=[norm_encoder_hidden_states, norm_hidden_states], axis=1)
@@ -148,6 +220,10 @@
         hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
         encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
 
+        paddle.device.synchronize()
+        nvtx.end_range(transformer_block_nvtx)
+
+
         return hidden_states, encoder_hidden_states
@@ -258,7 +334,7 @@ def __init__(
             use_positional_embeddings=not use_rotary_positional_embeddings,
             use_learned_positional_embeddings=use_learned_positional_embeddings,
         )
-        self.embedding_dropout = paddle.nn.Dropout(p=dropout)
+        # self.embedding_dropout = paddle.nn.Dropout(p=dropout)
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
         self.transformer_blocks = paddle.nn.LayerList(
@@ -273,10 +349,20 @@ def __init__(
                     attention_bias=attention_bias,
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
+                    layer_idx = ii,
                 )
-                for _ in range(num_layers)
+                for ii in range(num_layers)
             ]
         )
+
+        # self.transformer_blocks = paddle.incubate.jit.inference(
+        #     self.transformer_blocks,
+        #     enable_new_ir=False,
+        #     cache_static_model=False,
+        #     exp_enable_use_cutlass=False,
+        #     delete_pass_lists=[],
+        # )
+
         self.norm_final = paddle.nn.LayerNorm(
             normalized_shape=inner_dim,
             epsilon=norm_eps,
@@ -389,24 +475,58 @@ def forward(
         timestep_cond: Optional[paddle.Tensor] = None,
         image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
         block_vctrl_residuals: Optional[List[paddle.Tensor]] = None,
-        vctrl_layout_type: Optional[str] = "even",
+        # vctrl_layout_type: Optional[str] = "even",
         return_dict: bool = True,
     ):
+
+        import nvtx
+
+        # paddle.device.synchronize()
+        # vctrl_qian = nvtx.start_range(message="A", color="green")
+
+        vctrl_layout_type = "even"
         batch_size, num_frames, channels, height, width = tuple(hidden_states.shape)
+
+
+        # paddle.device.synchronize()
+        # vctrl_B = nvtx.start_range(message="B", color="yellow")
         timesteps = timestep
         t_emb = self.time_proj(timesteps)
-
-        t_emb = t_emb.to(dtype=hidden_states.dtype)
+        # paddle.device.synchronize()
+        # nvtx.end_range(vctrl_B)
+
+        # paddle.device.synchronize()
+        # vctrl_C = nvtx.start_range(message="C", color="blue")
+        # t_emb = t_emb.to(dtype=hidden_states.dtype)
+        t_emb = paddle.cast(t_emb,dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
-
+        # paddle.device.synchronize()
+        # nvtx.end_range(vctrl_C)
+
+
+
+        # paddle.device.synchronize()
+        # vctrl_D = nvtx.start_range(message="D", color="red")
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
-        hidden_states = self.embedding_dropout(hidden_states)
-
+        # hidden_states = self.embedding_dropout(hidden_states)
+        # paddle.device.synchronize()
+        # nvtx.end_range(vctrl_D)
+
+        # paddle.device.synchronize()
+        # vctrl_E = nvtx.start_range(message="E", color="green")
         text_seq_length = tuple(encoder_hidden_states.shape)[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
+        # paddle.device.synchronize()
+        # nvtx.end_range(vctrl_E)
+
+        # paddle.device.synchronize()
+        # nvtx.end_range(vctrl_qian)
+
+        # paddle.device.synchronize()
+        # vctrl_blocks = nvtx.start_range(message="blocks", color="green")
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
@@ -456,7 +576,9 @@ def custom_forward(*inputs):
                     )
                 else:
                     raise ValueError(f"vctrl_layout_type {vctrl_layout_type} is not supported.")
-
+        # paddle.device.synchronize()
+        # nvtx.end_range(vctrl_blocks)
+
         if not self.config.use_rotary_positional_embeddings:
             hidden_states = self.norm_final(hidden_states)
         else:
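The transformer-block changes above unpack CogVideoXLayerNormZero into an explicit Linear + LayerNorm pair whose weights are tied to the original submodules, so the gate/scale/shift math can be handed to paddlemix.triton_ops.fused_adaLN_scale_residual. The following is a minimal sketch of the intended equivalence with hypothetical shapes; the fused call is shown only as a comment because its exact kernel behavior is defined by the op in paddlemix, not here.

import paddle

# Hypothetical sizes for illustration only.
batch, seq, dim, time_embed_dim = 2, 8, 16, 32

temb = paddle.randn([batch, time_embed_dim])
x = paddle.randn([batch, seq, dim])
attn_out = paddle.randn([batch, seq, dim])

linear = paddle.nn.Linear(time_embed_dim, 6 * dim)
norm = paddle.nn.LayerNorm(dim, epsilon=1e-05)
silu = paddle.nn.Silu()

# Same decomposition as in the patched block: one projection, six chunks.
shift, scale, gate, _, _, _ = linear(silu(temb)).chunk(chunks=6, axis=1)

# Unfused reference: gated residual add, then adaLN modulation.
residual = x + gate[:, None, :] * attn_out
modulated = norm(residual) * (1 + scale)[:, None, :] + shift[:, None, :]

# The fused Triton op used in the diff is expected to produce the same
# (residual, modulated) pair in a single kernel:
#   residual, modulated = paddlemix.triton_ops.fused_adaLN_scale_residual(
#       x, attn_out, gate, scale, shift, epsilon=1e-05)

Tying linear.weight/bias and norm.weight/bias to the original norm1/norm2 submodules (as the diff does) keeps the checkpoint loadable without any key remapping.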
diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py
index 152422d01..171ebbb01 100644
--- a/ppdiffusers/ppdiffusers/models/embeddings.py
+++ b/ppdiffusers/ppdiffusers/models/embeddings.py
@@ -22,7 +22,8 @@
 from ..utils import USE_PEFT_BACKEND
 from .activations import FP32SiLU, get_activation
 from .lora import LoRACompatibleLinear
-
+import nvtx
+
 
 def get_timestep_embedding(
     timesteps: paddle.Tensor,
@@ -75,15 +76,15 @@ def get_2d_sincos_pos_embed(
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
 
-    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
-    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
+    grid_h = paddle.arange(grid_size[0], dtype='float32') / (grid_size[0] / base_size) / interpolation_scale
+    grid_w = paddle.arange(grid_size[1], dtype='float32') / (grid_size[1] / base_size) / interpolation_scale
+    grid = paddle.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = paddle.stack(grid, axis=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
     if cls_token and extra_tokens > 0:
-        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+        pos_embed = paddle.concat([paddle.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
     return pos_embed
 
 
@@ -95,7 +96,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
 
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    emb = paddle.concat([emb_h, emb_w], axis=1)  # (H*W, D)
     return emb
 
 
@@ -106,17 +107,17 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
-    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega = paddle.arange(embed_dim // 2, dtype='float64')
     omega /= embed_dim / 2.0
     omega = 1.0 / 10000**omega  # (D/2,)
 
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+    pos = paddle.cast(pos.reshape([-1]),dtype="float64")  # (M,)
 
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
+    out = paddle.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+    emb_sin = paddle.sin(out)  # (M, D/2)
+    emb_cos = paddle.cos(out)  # (M, D/2)
 
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    emb = paddle.concat([emb_sin, emb_cos], axis=1)  # (M, D)
     return emb
 
 
@@ -919,11 +920,11 @@ def forward(self, caption):
 
 def get_3d_sincos_pos_embed(
     embed_dim: int,
-    spatial_size: Union[int, Tuple[int, int]],
+    spatial_size: int or tuple,
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
-) -> np.ndarray:
+) -> paddle.Tensor:
     """
     Args:
         embed_dim (`int`):
@@ -934,23 +935,32 @@
     """
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
-    if isinstance(spatial_size, int):
-        spatial_size = spatial_size, spatial_size
+    if isinstance(spatial_size, int):  # False
+        spatial_size = (spatial_size, spatial_size)
     embed_dim_spatial = 3 * embed_dim // 4
     embed_dim_temporal = embed_dim // 4
-    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
-    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)
-    grid = np.stack(grid, axis=0)
+
+    # Generate spatial grid
+    grid_h = paddle.arange(spatial_size[1], dtype='float32') / spatial_interpolation_scale
+    grid_w = paddle.arange(spatial_size[0], dtype='float32') / spatial_interpolation_scale
+
+    # grid = paddle.meshgrid(grid_w, grid_h)
+    grid = [grid_w.unsqueeze(0).tile([grid_h.shape[0], 1]),grid_h.unsqueeze(1).tile([1, grid_w.shape[0]])]
+    grid = paddle.stack(grid, axis=0)
     grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
     pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
-    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+
+    # Generate temporal grid
+    grid_t = paddle.arange(temporal_size, dtype='float32') / temporal_interpolation_scale
     pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
-    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
-    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)
-    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
-    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)
-    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+
+    # Combine spatial and temporal embeddings
+    pos_embed_spatial = pos_embed_spatial.unsqueeze(0)
+    pos_embed_spatial = pos_embed_spatial.tile([temporal_size, 1, 1])
+    pos_embed_temporal = pos_embed_temporal.unsqueeze(1)
+    pos_embed_temporal = pos_embed_temporal.tile([1, spatial_size[0] * spatial_size[1], 1])
+    pos_embed = paddle.concat([pos_embed_temporal, pos_embed_spatial], axis=-1)
+
     return pos_embed
 
 
@@ -1006,6 +1016,8 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
         num_patches = post_patch_height * post_patch_width * post_time_compression_frames
 
+        # paddle.device.synchronize()
+        # emb_3d = nvtx.start_range(message="3d", color="red")
         pos_embedding = get_3d_sincos_pos_embed(
             self.embed_dim,
             (post_patch_width, post_patch_height),
@@ -1013,9 +1025,16 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
         )
+        # paddle.device.synchronize()
+        # nvtx.end_range(emb_3d)
+
+        # paddle.device.synchronize()
+        # emb_to_tensor = nvtx.start_range(message="to_tensor", color="yellow")
         pos_embedding = paddle.to_tensor(data=pos_embedding).flatten(start_axis=0, stop_axis=1)
         joint_pos_embedding = paddle.zeros([1, self.max_text_seq_length + num_patches, self.embed_dim])
         joint_pos_embedding[0, self.max_text_seq_length :] = pos_embedding
+        # paddle.device.synchronize()
+        # nvtx.end_range(emb_to_tensor)
         return joint_pos_embedding
 
     def forward(self, text_embeds: paddle.Tensor, image_embeds: paddle.Tensor):
@@ -1026,14 +1045,35 @@
             image_embeds (`torch.Tensor`):
                 Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
         """
-        text_embeds = self.text_proj(text_embeds)
+        # paddle.device.synchronize()
+        # emb_A = nvtx.start_range(message="emb_a", color="green")
+        text_embeds = self.text_proj(text_embeds)
+        # paddle.device.synchronize()
+        # nvtx.end_range(emb_A)
+
+
+        # paddle.device.synchronize()
+        # emb_B = nvtx.start_range(message="emb_b", color="yellow")
         batch, num_frames, channels, height, width = image_embeds.shape
         image_embeds = image_embeds.reshape([-1, channels, height, width])
         image_embeds = self.proj(image_embeds)
+        # paddle.device.synchronize()
+        # nvtx.end_range(emb_B)
+
+        # paddle.device.synchronize()
+        # emb_C = nvtx.start_range(message="emb_c", color="red")
         image_embeds = image_embeds.reshape([batch, num_frames] + image_embeds.shape[1:])
+        # paddle.device.synchronize()
+        # nvtx.end_range(emb_C)
+
+        # paddle.device.synchronize()
+        # emb_D = nvtx.start_range(message="emb_d", color="blue")
         image_embeds = image_embeds.flatten(3).transpose([0, 1, 3, 2])  # [batch, num_frames, height x width, channels]
         image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        # paddle.device.synchronize()
+        # nvtx.end_range(emb_D)
+
         embeds = paddle.concat(x=[text_embeds, image_embeds], axis=1).contiguous()
 
         if self.use_positional_embeddings or self.use_learned_positional_embeddings:
@@ -1078,10 +1118,14 @@ def get_3d_rotary_pos_embed(
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
""" + # breakpoint() start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) - grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + # grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + # grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + # grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + grid_h = paddle.linspace(start[0], stop[0], grid_size[0], dtype="float32") + grid_w = paddle.linspace(start[1], stop[1], grid_size[1], dtype="float32") + grid_t = paddle.linspace(0, temporal_size, temporal_size, dtype="float32") dim_t = embed_dim // 4 dim_h = embed_dim // 8 * 3 dim_w = embed_dim // 8 * 3 @@ -1162,7 +1206,9 @@ def apply_rotary_emb( x_rotated = paddle.concat(x=[-x_imag, x_real], axis=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - out = (x.astype(dtype="float32") * cos + x_rotated.astype(dtype="float32") * sin).to(x.dtype) + # out = (x.astype(dtype="float32") * cos + x_rotated.astype(dtype="float32") * sin).to(x.dtype) + out = paddle.cast((x.astype(dtype="float32") * cos + x_rotated.astype(dtype="float32") * sin),x.dtype) + return out else: x_rotated = paddle.as_complex(x=x.astype(dtype="float32").reshape(*tuple(x.shape)[:-1], -1, 2)) diff --git a/ppdiffusers/ppdiffusers/models/vctrl.py b/ppdiffusers/ppdiffusers/models/vctrl.py index a459a292b..6e7e875ec 100644 --- a/ppdiffusers/ppdiffusers/models/vctrl.py +++ b/ppdiffusers/ppdiffusers/models/vctrl.py @@ -209,8 +209,19 @@ def __init__( attention_out_bias: bool = True, ): super().__init__() + self.norm1 = VCtrlLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - + # breakpoint() + self.silu1 = paddle.nn.Silu() + self.linear1 = paddle.nn.Linear(in_features=time_embed_dim, out_features=3 * dim, bias_attr=True) + self.norm3 = paddle.nn.LayerNorm( + normalized_shape=dim, epsilon=1e-05, weight_attr=True, bias_attr=True + ) + self.linear1.weight = self.norm1.linear.weight + self.linear1.bias = self.norm1.linear.bias + self.norm3.weight = self.norm1.norm.weight + self.norm3.bias = self.norm1.norm.bias + self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, @@ -222,6 +233,18 @@ def __init__( processor=VCtrlAttnProcessor2_0(), ) self.norm2 = VCtrlLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.silu2 = paddle.nn.Silu() + self.linear2 = paddle.nn.Linear(in_features=time_embed_dim, out_features=3 * dim, bias_attr=True) + self.norm4 = paddle.nn.LayerNorm( + normalized_shape=dim, epsilon=1e-05, weight_attr=True, bias_attr=True + ) + + self.linear2.weight = self.norm2.linear.weight + self.linear2.bias = self.norm2.linear.bias + self.norm4.weight = self.norm2.norm.weight + self.norm4.bias = self.norm2.norm.bias + # breakpoint() self.ff = FeedForward( dim, dropout=dropout, @@ -238,12 +261,30 @@ def forward( image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, ) -> paddle.Tensor: - norm_hidden_states, gate_msa = self.norm1(hidden_states, temb) + # breakpoint() + + # norm_hidden_states, gate_msa = self.norm1(hidden_states, temb) + shift, scale, gate = self.linear1(self.silu1(temb)).chunk(chunks=3, axis=1) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + 
diff --git a/ppdiffusers/ppdiffusers/models/vctrl.py b/ppdiffusers/ppdiffusers/models/vctrl.py
index a459a292b..6e7e875ec 100644
--- a/ppdiffusers/ppdiffusers/models/vctrl.py
+++ b/ppdiffusers/ppdiffusers/models/vctrl.py
@@ -209,8 +209,19 @@ def __init__(
         attention_out_bias: bool = True,
     ):
         super().__init__()
+
         self.norm1 = VCtrlLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
+        # breakpoint()
+        self.silu1 = paddle.nn.Silu()
+        self.linear1 = paddle.nn.Linear(in_features=time_embed_dim, out_features=3 * dim, bias_attr=True)
+        self.norm3 = paddle.nn.LayerNorm(
+            normalized_shape=dim, epsilon=1e-05, weight_attr=True, bias_attr=True
+        )
+        self.linear1.weight = self.norm1.linear.weight
+        self.linear1.bias = self.norm1.linear.bias
+        self.norm3.weight = self.norm1.norm.weight
+        self.norm3.bias = self.norm1.norm.bias
+
         self.attn1 = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
@@ -222,6 +233,18 @@
             processor=VCtrlAttnProcessor2_0(),
         )
         self.norm2 = VCtrlLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+        self.silu2 = paddle.nn.Silu()
+        self.linear2 = paddle.nn.Linear(in_features=time_embed_dim, out_features=3 * dim, bias_attr=True)
+        self.norm4 = paddle.nn.LayerNorm(
+            normalized_shape=dim, epsilon=1e-05, weight_attr=True, bias_attr=True
+        )
+
+        self.linear2.weight = self.norm2.linear.weight
+        self.linear2.bias = self.norm2.linear.bias
+        self.norm4.weight = self.norm2.norm.weight
+        self.norm4.bias = self.norm2.norm.bias
+        # breakpoint()
         self.ff = FeedForward(
             dim,
             dropout=dropout,
@@ -238,12 +261,30 @@ def forward(
         image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
     ) -> paddle.Tensor:
 
-        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)
+        # breakpoint()
+
+        # norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)
+        shift, scale, gate = self.linear1(self.silu1(temb)).chunk(chunks=3, axis=1)
+        norm_hidden_states = self.norm3(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        gate_msa = gate
         attn_hidden_states = self.attn1(hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb)
-        hidden_states = hidden_states + gate_msa * attn_hidden_states
-        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)
+        # hidden_states = hidden_states + gate[:, None, :] * attn_hidden_states
+
+
+        # norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)
+        shift, scale, gate = self.linear2(self.silu2(temb)).chunk(chunks=3, axis=1)
+        # norm_hidden_states = self.norm4(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        gate_ff = gate[:, None, :]
+
+
+        import paddlemix
+        hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+            hidden_states, attn_hidden_states, gate_msa, scale, shift, epsilon=1e-05
+        )
+
+
         ff_output = self.ff(norm_hidden_states)
         hidden_states = hidden_states + gate_ff * ff_output
         return hidden_states
@@ -336,32 +377,42 @@ def forward(
         sample: paddle.Tensor,
         timestep: Union[paddle.Tensor, float, int],
         v_cond: paddle.Tensor,
-        v_cond_scale: float = 1.0,
         image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
         return_dict: bool = True,
-    ) -> Union[VCtrlModelOutput, Tuple[Tuple[paddle.Tensor, ...], paddle.Tensor]]:
+    ) -> Union[VCtrlModelOutput, list[list[paddle.Tensor, ...], paddle.Tensor]]:
         dtype = sample.dtype
         timesteps = timestep
         t_emb = self.time_proj(timesteps)
-        t_emb = t_emb.to(dtype=dtype)
+        # t_emb = t_emb.to(dtype=dtype)
+        t_emb = paddle.cast(t_emb,dtype=dtype)
         t_emb = self.time_embedding(t_emb)
         sample = self.sample_patch_embed(sample)
         v_cond = self.cond_patch_embed(v_cond)
-        mean_latents, std_latents = paddle.mean(x=sample, axis=(1, 2), keepdim=True), paddle.std(
-            x=sample.to(dtype="float32"), axis=(1, 2), keepdim=True
-        ).to(dtype=dtype)
-        mean_control, std_control = paddle.mean(x=v_cond, axis=(1, 2), keepdim=True), paddle.std(
-            x=v_cond.to(dtype="float32"), axis=(1, 2), keepdim=True
-        ).to(dtype=dtype)
+        # mean_latents, std_latents = paddle.mean(x=sample, axis=(1, 2), keepdim=True), paddle.std(
+        #     x=sample.to(dtype="float32"), axis=(1, 2), keepdim=True
+        # ).to(dtype=dtype)
+        # mean_control, std_control = paddle.mean(x=v_cond, axis=(1, 2), keepdim=True), paddle.std(
+        #     x=v_cond.to(dtype="float32"), axis=(1, 2), keepdim=True
+        # ).to(dtype=dtype)
+
+        mean_latents, std_latents = paddle.mean(x=sample, axis=(1, 2), keepdim=True), paddle.cast(paddle.std(
+            x=paddle.cast(sample,dtype="float32"), axis=(1, 2), keepdim=True
+        ),dtype=dtype)
+        mean_control, std_control = paddle.mean(x=v_cond, axis=(1, 2), keepdim=True), paddle.cast(paddle.std(
+            x=paddle.cast(v_cond,dtype="float32"), axis=(1, 2), keepdim=True
+        ),dtype=dtype)
         v_cond = (v_cond - mean_control) * (std_latents / (std_control + 1e-05)) + mean_latents
         hidden_states = sample + v_cond
-        hidden_states = hidden_states.to(dtype=dtype)
+        # hidden_states = hidden_states.to(dtype=dtype)
+        hidden_states = paddle.cast(hidden_states,dtype=dtype)
+        # breakpoint()
+        # self.modify_state_dict(self.state_dict())
         features = []
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
@@ -380,8 +431,32 @@ def custom_forward(*inputs):
                 hidden_states = block(hidden_states=hidden_states, temb=t_emb, image_rotary_emb=image_rotary_emb)
             features.append(hidden_states)
-        features = [(feature * v_cond_scale) for feature in features]
+        features = [(feature * 1.0) for feature in features]
         if not return_dict:
             return features
         return VCtrlModelOutput(vctrl_block_samples=features)
+
+
+    # @classmethod
+    # # def custom_modify_weight(cls, model_to_load, state_dict):
+    # def modify_state_dict():
+    #     # NOTE:(changwenbin,zhoukangkang) SD3 num_layers is 24
+    #     sd3_num_layers = 7
+    #     for i in range(sd3_num_layers):
+    #         map_sd3 = [
+    #             (f"{i}.linear1.weight", f"{i}.norm1.linear.weight"),
+    #             (f"{i}.linear1.bias", f"{i}.norm1.linear.bias"),
+    #             (f"{i}.norm3.weight", f"{i}.norm1.norm.weight"),
+    #             (f"{i}.norm3.bias", f"{i}.norm1.norm.bias"),
+    #             (f"{i}.linear2.weight", f"{i}.norm2.linear.weight"),
+    #             (f"{i}.linear2.bias", f"{i}.norm2.linear.bias"),
+    #             (f"{i}.norm4.weight", f"{i}.norm2.norm.weight"),
+    #             (f"{i}.norm4.bias", f"{i}.norm2.norm.bias"),
+
+    #         ]
+    #         for to_, from_ in map_sd3:
+    #             if "transformer_blocks." + from_ in self.state_dict():
+    #                 state_dict["transformer_blocks." + to_] = state_dict["transformer_blocks." + from_]
+    #             else:
+    #                 print(f"Warning!!: '{from_}' not found in state_dict")
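In the patched VCtrl.forward above, the control branch is rescaled to match the latent statistics before being added, with the standard deviation computed in float32 and cast back; the Tensor.to(dtype=...) calls are replaced with paddle.cast in the same hunks. The helper below is a standalone sketch of that normalization, mirroring the added lines; the function name and shapes are hypothetical.

import paddle

def match_control_to_latents(sample, v_cond, dtype="float16"):
    # Mean/std over the (frame, token) axes, std in float32 then cast back,
    # exactly as in the patched VCtrl.forward.
    mean_latents = paddle.mean(sample, axis=(1, 2), keepdim=True)
    std_latents = paddle.cast(
        paddle.std(paddle.cast(sample, "float32"), axis=(1, 2), keepdim=True), dtype
    )
    mean_control = paddle.mean(v_cond, axis=(1, 2), keepdim=True)
    std_control = paddle.cast(
        paddle.std(paddle.cast(v_cond, "float32"), axis=(1, 2), keepdim=True), dtype
    )
    v_cond = (v_cond - mean_control) * (std_latents / (std_control + 1e-05)) + mean_latents
    return paddle.cast(sample + v_cond, dtype)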
diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_vctrl.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_vctrl.py
index ff6fb484b..f98b1fc01 100644
--- a/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_vctrl.py
+++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_vctrl.py
@@ -20,7 +20,7 @@
 import paddle
 import paddlenlp
 import PIL
-
+import nvtx
 from ppdiffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from ppdiffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from ppdiffusers.models import (
@@ -736,25 +736,38 @@ def map_frame_latent(indices):
                 latent_model_input = paddle.concat(x=[latent_model_input, latent_image_input], axis=2)
                 timestep = t.expand(shape=tuple(latent_model_input.shape)[0])
                 control_model_input = latent_model_input
+                # breakpoint()
+
+
+                # paddle.device.synchronize()
+                # vctrl_nvtx = nvtx.start_range(message="vctrl", color="red")
+
                 vctrl_block_samples = self.vctrl(
                     control_model_input,
                     timestep,
                     v_cond=v_cond,
-                    v_cond_scale=conditioning_scale,
                     image_rotary_emb=v_cond_rotary_emb,
                     return_dict=False,
                 )
+                # paddle.device.synchronize()
+                # nvtx.end_range(vctrl_nvtx)
+
+                # paddle.device.synchronize()
+                # transformer_nvtx = nvtx.start_range(message="transformer", color="yellow")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     block_vctrl_residuals=vctrl_block_samples,
-                    vctrl_layout_type=vctrl_layout_type,
+                    # vctrl_layout_type=vctrl_layout_type,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )
+                # paddle.device.synchronize()
+                # nvtx.end_range(transformer_nvtx)
+
                 noise_pred = noise_pred.astype(dtype="float32")
 
                 if use_dynamic_cfg:
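The profiling hooks sprinkled through this patch all follow the same synchronize + nvtx.start_range / end_range pattern (paired with the nsys command in the launch script). When experimenting locally, a small context manager keeps that pattern tidy; this is a convenience sketch outside the patch, assuming the same nvtx Python package used above.

import contextlib
import nvtx
import paddle

@contextlib.contextmanager
def nvtx_range(message, color="blue", sync=True):
    # Synchronize so the range brackets the GPU work it wraps, matching the
    # paddle.device.synchronize() + nvtx pattern used in the hunks above.
    if sync:
        paddle.device.synchronize()
    rng = nvtx.start_range(message=message, color=color)
    try:
        yield
    finally:
        if sync:
            paddle.device.synchronize()
        nvtx.end_range(rng)

# usage (inside the denoising loop, for example):
# with nvtx_range("vctrl"):
#     vctrl_block_samples = self.vctrl(control_model_input, timestep, v_cond=v_cond,
#                                      image_rotary_emb=v_cond_rotary_emb, return_dict=False)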