包含以下内容:
- warp_reduce_fp32/fp16/bf16_kernel
- block_reduce_fp32_kernel
- block_all_reduce_sum_f32_f32_kernel
- block_all_reduce_sum_f32x4_f32_kernel(float4向量化版本)
- block_all_reduce_sum_f16_f16_kernel(fp16版本,使用fp16 acc)
- block_all_reduce_sum_f16_f32_kernel(fp16版本,使用fp32 acc)
- block_all_reduce_sum_f16x2_f16_kernel(fp16向量化版本,使用fp16 acc)
- block_all_reduce_sum_f16x2_f32_kernel(fp16向量化版本,使用fp32 acc)
- block_all_reduce_sum_f16x8_pack_f16_kernel(fp16向量化版本,使用fp16 acc, pack)
- block_all_reduce_sum_f16x8_pack_f32_kernel(fp16向量化版本,使用fp32 acc, pack)
- block_all_reduce_sum_bf16_bf16_kernel(bf16版本,使用bf16 acc)
- block_all_reduce_sum_bf16_f32_kernel(bf16版本,使用fp32 acc)
- block_all_reduce_sum_bf16x8_pack_bf16_kernel(bf16向量化版本,使用bf16 acc, pack)
- block_all_reduce_sum_bf16x8_pack_f32_kernel(bf16向量化版本,使用fp32 acc, pack)
- block_all_reduce_sum_bf16x2_bf16_kernel(bf16向量化版本,使用bf16 acc)
- block_all_reduce_sum_bf16x2_f32_kernel(bf16向量化版本,使用fp32 acc)
- block_all_reduce_sum_fp8_e4m3_f16_kernel(fp8_e4m3版本,使用fp16 acc)
- block_all_reduce_sum_fp8_e5m2_f16_kernel(fp8_e5m2版本,使用fp16 acc)
- block_all_reduce_sum_fp8_e4m3x16_pack_f16_kernel(fp8_e4m3向量化版本,使用fp16 acc, pack)
- block_all_reduce_sum_fp8_e5m2x16_pack_f16_kernel(fp8_e5m2向量化版本,使用fp16 acc, pack)
- block_all_reduce_sum_i8_i32_kernel(i8版本,使用i32 acc)
- block_all_reduce_sum_i8x16_pack_i32_kernel(i8向量化版本,使用i32 acc, pack)
- PyTorch bindings for block reduce fp32/fp16/bf16/fp8/i8 kernels
所有支持的block all reduce kernel:
// packed_type, acc_type, th_type, element_type, n_elements_per_pack, out_type
TORCH_BINDING_REDUCE(f32, f32, torch::kFloat32, float, 1, float)
TORCH_BINDING_REDUCE(f32x4, f32, torch::kFloat32, float, 4, float)
TORCH_BINDING_REDUCE(f16, f16, torch::kHalf, half, 1, float)
TORCH_BINDING_REDUCE(f16, f32, torch::kHalf, half, 1, float)
TORCH_BINDING_REDUCE(f16x2, f16, torch::kHalf, half, 2, float)
TORCH_BINDING_REDUCE(f16x2, f32, torch::kHalf, half, 2, float)
TORCH_BINDING_REDUCE(f16x8_pack, f16, torch::kHalf, half, 8, float)
TORCH_BINDING_REDUCE(f16x8_pack, f32, torch::kHalf, half, 8, float)
TORCH_BINDING_REDUCE(bf16, bf16, torch::kBFloat16, __nv_bfloat16, 1, float)
TORCH_BINDING_REDUCE(bf16, f32, torch::kBFloat16, __nv_bfloat16, 1, float)
TORCH_BINDING_REDUCE(bf16x2, bf16, torch::kBFloat16, __nv_bfloat16, 2, float)
TORCH_BINDING_REDUCE(bf16x2, f32, torch::kBFloat16, __nv_bfloat16, 2, float)
TORCH_BINDING_REDUCE(bf16x8_pack, bf16, torch::kBFloat16, __nv_bfloat16, 8, float)
TORCH_BINDING_REDUCE(bf16x8_pack, f32, torch::kBFloat16, __nv_bfloat16, 8, float)
TORCH_BINDING_REDUCE(fp8_e4m3, f16, torch::kFloat8_e4m3fn, __nv_fp8_storage_t, 1, float)
TORCH_BINDING_REDUCE(fp8_e4m3x16_pack, f16, torch::kFloat8_e4m3fn, __nv_fp8_storage_t, 16, float)
TORCH_BINDING_REDUCE(fp8_e5m2, f16, torch::kFloat8_e5m2, __nv_fp8_storage_t, 1, float)
TORCH_BINDING_REDUCE(fp8_e5m2x16_pack, f16, torch::kFloat8_e5m2, __nv_fp8_storage_t, 16, float)
TORCH_BINDING_REDUCE(i8, i32, torch::kInt8, int8_t, 1, int32_t)
TORCH_BINDING_REDUCE(i8x16_pack, i32, torch::kInt8, int8_t, 16, int32_t)
# 只测试Ada架构;若不指定,则默认编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 block_all_reduce.py
输出:
--------------------------------------------------------------------------------
S=1024, K=1024
out_f32f32: 2340.52685547 , time:0.01104856ms
out_f32x4f32: 2340.52539062 , time:0.01093769ms
out_f32f32_th: 2340.52563477 , time:0.01271653ms
--------------------------------------------------------------------------------
out_f16f16: 2341.14599609 , time:0.01084900ms
out_f16f32: 2340.54711914 , time:0.01082754ms
out_f16x2f32: 2340.57177734 , time:0.01104403ms
out_f16x2f16: 2340.59106445 , time:0.01104188ms
out_f16x8packf16: 2340.04199219 , time:0.01084495ms
out_f16x8packf32: 2340.54785156 , time:0.01081753ms
out_f16f16_th: 2340.00000000 , time:0.01274228ms
--------------------------------------------------------------------------------
out_bf16bf16: 2341.23437500 , time:0.01074696ms
out_bf16f32: 2341.17846680 , time:0.01436758ms
out_bf16x2f32: 2343.69921875 , time:0.01096773ms
out_bf16x2bf16: 2342.84375000 , time:0.01102543ms
out_bf16x8packf32: 2338.20483398 , time:0.01086307ms
out_bf16x8packbf16: 2337.92187500 , time:0.01081562ms
out_bf16bf16_th: 2336.00000000 , time:0.01267934ms
--------------------------------------------------------------------------------
out_f8e4m3f16: 2309.35156250 , time:0.01085925ms
out_f8e4m3x16packf16: 2309.84179688 , time:0.01080227ms
out_f8e4m3f16_th: 2310.00000000 , time:0.01269531ms
--------------------------------------------------------------------------------
out_f8e5m2f16: 2301.28515625 , time:0.01083851ms
out_f8e5m2x16packf16: 2301.70507812 , time:0.01085329ms
out_f8e5m2f16_th: 2302.00000000 , time:0.01272225ms
--------------------------------------------------------------------------------
out_i8i32: 701 , time:0.01100421ms
out_i8x16packi32: 701 , time:0.01082873ms
out_i8i32_th: 701 , time:0.02053738ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=1024, K=2048
out_f32f32: 1462.08410645 , time:0.01540542ms
out_f32x4f32: 1462.08703613 , time:0.01092076ms
out_f32f32_th: 1462.08715820 , time:0.01292896ms
--------------------------------------------------------------------------------
out_f16f16: 1462.19628906 , time:0.01528525ms
out_f16f32: 1462.07800293 , time:0.01524353ms
out_f16x2f32: 1461.92590332 , time:0.01080871ms
out_f16x2f16: 1462.66113281 , time:0.01084566ms
out_f16x8packf16: 1462.65771484 , time:0.01090646ms
out_f16x8packf32: 1462.07800293 , time:0.01088381ms
out_f16f16_th: 1462.00000000 , time:0.01281977ms
--------------------------------------------------------------------------------
out_bf16bf16: 1448.33593750 , time:0.01571774ms
out_bf16f32: 1455.72143555 , time:0.01525044ms
out_bf16x2f32: 1455.90173340 , time:0.01076150ms
out_bf16x2bf16: 1468.06250000 , time:0.01074433ms
out_bf16x8packf32: 1448.51782227 , time:0.01088881ms
out_bf16x8packbf16: 1445.67187500 , time:0.01087308ms
out_bf16bf16_th: 1456.00000000 , time:0.01274490ms
--------------------------------------------------------------------------------
out_f8e4m3f16: 1379.93945312 , time:0.01552987ms
out_f8e4m3x16packf16: 1380.08593750 , time:0.01084304ms
out_f8e4m3f16_th: 1380.00000000 , time:0.01291776ms
--------------------------------------------------------------------------------
out_f8e5m2f16: 1393.96679688 , time:0.01553226ms
out_f8e5m2x16packf16: 1394.45703125 , time:0.01087117ms
out_f8e5m2f16_th: 1394.00000000 , time:0.01287508ms
--------------------------------------------------------------------------------
out_i8i32: 837 , time:0.01561451ms
out_i8x16packi32: 837 , time:0.01086545ms
out_i8i32_th: 837 , time:0.02074814ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=1024, K=4096
out_f32f32: 739.76281738 , time:0.02775788ms
out_f32x4f32: 739.76159668 , time:0.01103473ms
out_f32f32_th: 739.76123047 , time:0.01302052ms
--------------------------------------------------------------------------------
out_f16f16: 741.25573730 , time:0.02747726ms
out_f16f32: 740.10363770 , time:0.02743196ms
out_f16x2f32: 740.54528809 , time:0.01134133ms
out_f16x2f16: 741.26928711 , time:0.01139235ms
out_f16x8packf16: 739.66699219 , time:0.01096582ms
out_f16x8packf32: 740.10540771 , time:0.01099253ms
out_f16f16_th: 740.00000000 , time:0.01312995ms
--------------------------------------------------------------------------------
out_bf16bf16: 698.61718750 , time:0.02839518ms
out_bf16f32: 733.34777832 , time:0.02742386ms
out_bf16x2f32: 732.55615234 , time:0.01130605ms
out_bf16x2bf16: 733.70312500 , time:0.01158929ms
out_bf16x8packf32: 740.38952637 , time:0.01100612ms
out_bf16x8packbf16: 721.00000000 , time:0.01092744ms
out_bf16bf16_th: 732.00000000 , time:0.01304460ms
--------------------------------------------------------------------------------
out_f8e4m3f16: 694.41406250 , time:0.02810597ms
out_f8e4m3x16packf16: 695.38085938 , time:0.01093888ms
out_f8e4m3f16_th: 695.00000000 , time:0.01317239ms
--------------------------------------------------------------------------------
out_f8e5m2f16: 745.15625000 , time:0.02799416ms
out_f8e5m2x16packf16: 743.39062500 , time:0.01095843ms
out_f8e5m2f16_th: 743.50000000 , time:0.01320124ms
--------------------------------------------------------------------------------
out_i8i32: 822 , time:0.02806473ms
out_i8x16packi32: 822 , time:0.01087379ms
out_i8i32_th: 822 , time:0.03051281ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=1024
out_f32f32: -2130.84204102 , time:0.01537204ms
out_f32x4f32: -2130.84301758 , time:0.01095867ms
out_f32f32_th: -2130.84130859 , time:0.01302028ms
--------------------------------------------------------------------------------
out_f16f16: -2130.79760742 , time:0.01521087ms
out_f16f32: -2131.24804688 , time:0.01520824ms
out_f16x2f32: -2131.21093750 , time:0.01098013ms
out_f16x2f16: -2131.38061523 , time:0.01096201ms
out_f16x8packf16: -2131.14306641 , time:0.01099825ms
out_f16x8packf32: -2131.24682617 , time:0.01098967ms
out_f16f16_th: -2132.00000000 , time:0.01285267ms
--------------------------------------------------------------------------------
out_bf16bf16: -2125.26562500 , time:0.01568747ms
out_bf16f32: -2136.03564453 , time:0.01521087ms
out_bf16x2f32: -2137.35693359 , time:0.01096845ms
out_bf16x2bf16: -2140.95312500 , time:0.01101232ms
out_bf16x8packf32: -2138.99365234 , time:0.01102495ms
out_bf16x8packbf16: -2131.59375000 , time:0.01095581ms
out_bf16bf16_th: -2144.00000000 , time:0.01276922ms
--------------------------------------------------------------------------------
out_f8e4m3f16: -2148.03710938 , time:0.01552343ms
out_f8e4m3x16packf16: -2149.79687500 , time:0.01097393ms
out_f8e4m3f16_th: -2148.00000000 , time:0.01288629ms
--------------------------------------------------------------------------------
out_f8e5m2f16: -1940.82031250 , time:0.01552033ms
out_f8e5m2x16packf16: -1941.06640625 , time:0.01102161ms
out_f8e5m2f16_th: -1941.00000000 , time:0.01291919ms
--------------------------------------------------------------------------------
out_i8i32: -733 , time:0.01552916ms
out_i8x16packi32: -733 , time:0.01085472ms
out_i8i32_th: -733 , time:0.02085137ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=2048
out_f32f32: 3670.03320312 , time:0.02772927ms
out_f32x4f32: 3670.03369141 , time:0.01098394ms
out_f32f32_th: 3670.03320312 , time:0.01311922ms
--------------------------------------------------------------------------------
out_f16f16: 3671.47851562 , time:0.02740979ms
out_f16f32: 3670.10327148 , time:0.02740788ms
out_f16x2f32: 3669.92382812 , time:0.01568246ms
out_f16x2f16: 3669.72900391 , time:0.01553798ms
out_f16x8packf16: 3669.43310547 , time:0.01100826ms
out_f16x8packf32: 3670.10571289 , time:0.01098967ms
out_f16f16_th: 3670.00000000 , time:0.01318645ms
--------------------------------------------------------------------------------
out_bf16bf16: 3645.65625000 , time:0.02837849ms
out_bf16f32: 3665.68139648 , time:0.02740622ms
out_bf16x2f32: 3663.63159180 , time:0.01556945ms
out_bf16x2bf16: 3660.78125000 , time:0.01621652ms
out_bf16x8packf32: 3667.81567383 , time:0.01097989ms
out_bf16x8packbf16: 3667.12500000 , time:0.01092386ms
out_bf16bf16_th: 3664.00000000 , time:0.01311040ms
--------------------------------------------------------------------------------
out_f8e4m3f16: 3523.93554688 , time:0.02799201ms
out_f8e4m3x16packf16: 3525.20117188 , time:0.01099467ms
out_f8e4m3f16_th: 3522.00000000 , time:0.01320076ms
--------------------------------------------------------------------------------
out_f8e5m2f16: 3771.88085938 , time:0.02798295ms
out_f8e5m2x16packf16: 3769.36523438 , time:0.01100039ms
out_f8e5m2f16_th: 3772.00000000 , time:0.01321268ms
--------------------------------------------------------------------------------
out_i8i32: 3211 , time:0.02805424ms
out_i8x16packi32: 3211 , time:0.01092052ms
out_i8i32_th: 3211 , time:0.03049159ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=4096
out_f32f32: -224.59820557 , time:0.05256987ms
out_f32x4f32: -224.59832764 , time:0.01848841ms
out_f32f32_th: -224.59960938 , time:0.01432800ms
--------------------------------------------------------------------------------
out_f16f16: -226.07080078 , time:0.05197215ms
out_f16f32: -225.04293823 , time:0.05196261ms
out_f16x2f32: -224.60795593 , time:0.01924849ms
out_f16x2f16: -222.50659180 , time:0.01940060ms
out_f16x8packf16: -226.12500000 , time:0.01106429ms
out_f16x8packf32: -225.04100037 , time:0.01104879ms
out_f16f16_th: -225.00000000 , time:0.01317596ms
--------------------------------------------------------------------------------
out_bf16bf16: -232.43750000 , time:0.05375409ms
out_bf16f32: -227.55787659 , time:0.05195498ms
out_bf16x2f32: -226.89102173 , time:0.01925349ms
out_bf16x2bf16: -238.76562500 , time:0.01973200ms
out_bf16x8packf32: -209.35949707 , time:0.01107860ms
out_bf16x8packbf16: -180.07812500 , time:0.01107240ms
out_bf16bf16_th: -228.00000000 , time:0.01309705ms
--------------------------------------------------------------------------------
out_f8e4m3f16: -264.29296875 , time:0.05309629ms
out_f8e4m3x16packf16: -264.98437500 , time:0.01105237ms
out_f8e4m3f16_th: -264.75000000 , time:0.01320648ms
--------------------------------------------------------------------------------
out_f8e5m2f16: -548.05859375 , time:0.05309629ms
out_f8e5m2x16packf16: -550.15625000 , time:0.01108718ms
out_f8e5m2f16_th: -551.00000000 , time:0.01320672ms
--------------------------------------------------------------------------------
out_i8i32: 1496 , time:0.05338764ms
out_i8x16packi32: 1496 , time:0.01122260ms
out_i8i32_th: 1496 , time:0.05263305ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=1024
out_f32f32: -3509.85327148 , time:0.02772665ms
out_f32x4f32: -3509.85058594 , time:0.01091909ms
out_f32f32_th: -3509.84814453 , time:0.01311827ms
--------------------------------------------------------------------------------
out_f16f16: -3510.24316406 , time:0.02741575ms
out_f16f32: -3509.40112305 , time:0.02740836ms
out_f16x2f32: -3510.73193359 , time:0.01125741ms
out_f16x2f16: -3511.23486328 , time:0.01131368ms
out_f16x8packf16: -3508.63037109 , time:0.01093841ms
out_f16x8packf32: -3509.39965820 , time:0.01094270ms
out_f16f16_th: -3510.00000000 , time:0.01312375ms
--------------------------------------------------------------------------------
out_bf16bf16: -3506.85937500 , time:0.02836037ms
out_bf16f32: -3507.63208008 , time:0.02740240ms
out_bf16x2f32: -3511.86767578 , time:0.01122832ms
out_bf16x2bf16: -3522.76562500 , time:0.01150727ms
out_bf16x8packf32: -3497.67138672 , time:0.01103187ms
out_bf16x8packbf16: -3492.68750000 , time:0.01093888ms
out_bf16bf16_th: -3504.00000000 , time:0.01303720ms
--------------------------------------------------------------------------------
out_f8e4m3f16: -3450.27148438 , time:0.02798390ms
out_f8e4m3x16packf16: -3450.74218750 , time:0.01089668ms
out_f8e4m3f16_th: -3450.00000000 , time:0.01320052ms
--------------------------------------------------------------------------------
out_f8e5m2f16: -3496.27734375 , time:0.02799821ms
out_f8e5m2x16packf16: -3496.75195312 , time:0.01097918ms
out_f8e5m2f16_th: -3496.00000000 , time:0.01321816ms
--------------------------------------------------------------------------------
out_i8i32: -3024 , time:0.02805233ms
out_i8x16packi32: -3024 , time:0.01106644ms
out_i8i32_th: -3024 , time:0.03049588ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=2048
out_f32f32: 2020.18762207 , time:0.05256009ms
out_f32x4f32: 2020.18798828 , time:0.01283932ms
out_f32f32_th: 2020.18493652 , time:0.01435518ms
--------------------------------------------------------------------------------
out_f16f16: 2021.24536133 , time:0.05197549ms
out_f16f32: 2020.26733398 , time:0.05196571ms
out_f16x2f32: 2020.64233398 , time:0.02826214ms
out_f16x2f16: 2022.50488281 , time:0.02805209ms
out_f16x8packf16: 2021.27587891 , time:0.01094484ms
out_f16x8packf32: 2020.26684570 , time:0.01087451ms
out_f16f16_th: 2020.00000000 , time:0.01319718ms
--------------------------------------------------------------------------------
out_bf16bf16: 2039.31250000 , time:0.05375671ms
out_bf16f32: 2022.53759766 , time:0.05195117ms
out_bf16x2f32: 2012.04919434 , time:0.02807426ms
out_bf16x2bf16: 2007.21875000 , time:0.02934122ms
out_bf16x8packf32: 2044.72705078 , time:0.01095700ms
out_bf16x8packbf16: 2029.09375000 , time:0.01089287ms
out_bf16bf16_th: 2024.00000000 , time:0.01314187ms
--------------------------------------------------------------------------------
out_f8e4m3f16: 1925.97851562 , time:0.05310130ms
out_f8e4m3x16packf16: 1924.67187500 , time:0.01106429ms
out_f8e4m3f16_th: 1927.00000000 , time:0.01334405ms
--------------------------------------------------------------------------------
out_f8e5m2f16: 2322.31445312 , time:0.05310130ms
out_f8e5m2x16packf16: 2324.93359375 , time:0.01109791ms
out_f8e5m2f16_th: 2324.00000000 , time:0.01333547ms
--------------------------------------------------------------------------------
out_i8i32: 1701 , time:0.05337167ms
out_i8x16packi32: 1701 , time:0.01122093ms
out_i8i32_th: 1701 , time:0.05257368ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=4096
out_f32f32: -2295.19458008 , time:0.10227132ms
out_f32x4f32: -2295.19702148 , time:0.03361320ms
out_f32f32_th: -2295.19946289 , time:0.02290916ms
--------------------------------------------------------------------------------
out_f16f16: -2293.83764648 , time:0.10097337ms
out_f16f32: -2296.36425781 , time:0.10095334ms
out_f16x2f32: -2297.93896484 , time:0.03533483ms
out_f16x2f16: -2297.96386719 , time:0.03572583ms
out_f16x8packf16: -2299.68701172 , time:0.01311255ms
out_f16x8packf32: -2296.36645508 , time:0.01308966ms
out_f16f16_th: -2296.00000000 , time:0.01445580ms
--------------------------------------------------------------------------------
out_bf16bf16: -2264.30468750 , time:0.10450244ms
out_bf16f32: -2293.59399414 , time:0.10095382ms
out_bf16x2f32: -2299.56005859 , time:0.03533602ms
out_bf16x2bf16: -2284.02343750 , time:0.03620267ms
out_bf16x8packf32: -2290.28173828 , time:0.01310396ms
out_bf16x8packbf16: -2282.46875000 , time:0.01368093ms
out_bf16bf16_th: -2288.00000000 , time:0.01442218ms
--------------------------------------------------------------------------------
out_f8e4m3f16: -2332.72070312 , time:0.10321760ms
out_f8e4m3x16packf16: -2329.65625000 , time:0.01123261ms
out_f8e4m3f16_th: -2330.00000000 , time:0.01445007ms
--------------------------------------------------------------------------------
out_f8e5m2f16: -2035.82812500 , time:0.10325360ms
out_f8e5m2x16packf16: -2034.17187500 , time:0.01119351ms
out_f8e5m2f16_th: -2036.00000000 , time:0.01442766ms
--------------------------------------------------------------------------------
out_i8i32: -2746 , time:0.10370731ms
out_i8x16packi32: -2746 , time:0.01133108ms
out_i8i32_th: -2746 , time:0.36144137ms
--------------------------------------------------------------------------------