Skip to content

Latest commit

 

History

History

reduce

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

Reduce

0x00 说明

包含以下内容:

  • warp_reduce_fp32/fp16/bf16_kernel
  • block_reduce_fp32_kernel
  • block_all_reduce_sum_f32_f32_kernel
  • block_all_reduce_sum_f32x4_f32_kernel(float4向量化版本)
  • block_all_reduce_sum_f16_f16_kernel(fp16版本,使用fp16 acc)
  • block_all_reduce_sum_f16_f32_kernel(fp16版本,使用fp32 acc)
  • block_all_reduce_sum_f16x2_f16_kernel(fp16向量化版本,使用fp16 acc)
  • block_all_reduce_sum_f16x2_f32_kernel(fp16向量化版本,使用fp32 acc)
  • block_all_reduce_sum_f16x8_pack_f16_kernel(fp16向量化版本,使用fp16 acc, pack)
  • block_all_reduce_sum_f16x8_pack_f32_kernel(fp16向量化版本,使用fp32 acc, pack)
  • block_all_reduce_sum_bf16_bf16_kernel(bf16版本,使用bf16 acc)
  • block_all_reduce_sum_bf16_f32_kernel(bf16版本,使用fp32 acc)
  • block_all_reduce_sum_bf16x8_pack_bf16_kernel(bf16向量化版本,使用bf16 acc, pack)
  • block_all_reduce_sum_bf16x8_pack_f32_kernel(bf16向量化版本,使用fp32 acc, pack)
  • block_all_reduce_sum_bf16x2_bf16_kernel(bf16向量化版本,使用bf16 acc)
  • block_all_reduce_sum_bf16x2_f32_kernel(bf16向量化版本,使用fp32 acc)
  • block_all_reduce_sum_fp8_e4m3_f16_kernel(fp8_e4m3版本,使用fp16 acc)
  • block_all_reduce_sum_fp8_e5m2_f16_kernel(fp8_e5m2版本,使用fp16 acc)
  • block_all_reduce_sum_fp8_e4m3x16_pack_f16_kernel(fp8_e4m3向量化版本,使用fp16 acc, pack)
  • block_all_reduce_sum_fp8_e5m2x16_pack_f16_kernel(fp8_e5m2向量化版本,使用fp16 acc, pack)
  • block_all_reduce_sum_i8_i32_kernel(i8版本,使用i32 acc)
  • block_all_reduce_sum_i8x16_pack_i32_kernel(i8向量化版本,使用i32 acc, pack)
  • PyTorch bindings for block reduce fp32/fp16/bf16/fp8/i8 kernels

所有支持的block all reduce kernel:

// packed_type, acc_type, th_type, element_type, n_elements_per_pack, out_type
TORCH_BINDING_REDUCE(f32,              f32,  torch::kFloat32,       float,              1,  float)
TORCH_BINDING_REDUCE(f32x4,            f32,  torch::kFloat32,       float,              4,  float)
TORCH_BINDING_REDUCE(f16,              f16,  torch::kHalf,          half,               1,  float)
TORCH_BINDING_REDUCE(f16,              f32,  torch::kHalf,          half,               1,  float)
TORCH_BINDING_REDUCE(f16x2,            f16,  torch::kHalf,          half,               2,  float)
TORCH_BINDING_REDUCE(f16x2,            f32,  torch::kHalf,          half,               2,  float)
TORCH_BINDING_REDUCE(f16x8_pack,       f16,  torch::kHalf,          half,               8,  float)
TORCH_BINDING_REDUCE(f16x8_pack,       f32,  torch::kHalf,          half,               8,  float)
TORCH_BINDING_REDUCE(bf16,             bf16, torch::kBFloat16,      __nv_bfloat16,      1,  float)
TORCH_BINDING_REDUCE(bf16,             f32,  torch::kBFloat16,      __nv_bfloat16,      1,  float)
TORCH_BINDING_REDUCE(bf16x2,           bf16, torch::kBFloat16,      __nv_bfloat16,      2,  float)
TORCH_BINDING_REDUCE(bf16x2,           f32,  torch::kBFloat16,      __nv_bfloat16,      2,  float)
TORCH_BINDING_REDUCE(bf16x8_pack,      bf16, torch::kBFloat16,      __nv_bfloat16,      8,  float)
TORCH_BINDING_REDUCE(bf16x8_pack,      f32,  torch::kBFloat16,      __nv_bfloat16,      8,  float)
TORCH_BINDING_REDUCE(fp8_e4m3,         f16,  torch::kFloat8_e4m3fn, __nv_fp8_storage_t, 1,  float)
TORCH_BINDING_REDUCE(fp8_e4m3x16_pack, f16,  torch::kFloat8_e4m3fn, __nv_fp8_storage_t, 16, float)
TORCH_BINDING_REDUCE(fp8_e5m2,         f16,  torch::kFloat8_e5m2,   __nv_fp8_storage_t, 1,  float)
TORCH_BINDING_REDUCE(fp8_e5m2x16_pack, f16,  torch::kFloat8_e5m2,   __nv_fp8_storage_t, 16, float)
TORCH_BINDING_REDUCE(i8,               i32,  torch::kInt8,          int8_t,             1,  int32_t)
TORCH_BINDING_REDUCE(i8x16_pack,       i32,  torch::kInt8,          int8_t,             16, int32_t)

测试

# 只测试Ada架构;若不指定,则默认编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada 
python3 block_all_reduce.py

输出:

--------------------------------------------------------------------------------
                                        S=1024, K=1024
               out_f32f32: 2340.52685547  , time:0.01104856ms
             out_f32x4f32: 2340.52539062  , time:0.01093769ms
            out_f32f32_th: 2340.52563477  , time:0.01271653ms
--------------------------------------------------------------------------------
               out_f16f16: 2341.14599609  , time:0.01084900ms
               out_f16f32: 2340.54711914  , time:0.01082754ms
             out_f16x2f32: 2340.57177734  , time:0.01104403ms
             out_f16x2f16: 2340.59106445  , time:0.01104188ms
         out_f16x8packf16: 2340.04199219  , time:0.01084495ms
         out_f16x8packf32: 2340.54785156  , time:0.01081753ms
            out_f16f16_th: 2340.00000000  , time:0.01274228ms
--------------------------------------------------------------------------------
             out_bf16bf16: 2341.23437500  , time:0.01074696ms
              out_bf16f32: 2341.17846680  , time:0.01436758ms
            out_bf16x2f32: 2343.69921875  , time:0.01096773ms
           out_bf16x2bf16: 2342.84375000  , time:0.01102543ms
        out_bf16x8packf32: 2338.20483398  , time:0.01086307ms
       out_bf16x8packbf16: 2337.92187500  , time:0.01081562ms
          out_bf16bf16_th: 2336.00000000  , time:0.01267934ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: 2309.35156250  , time:0.01085925ms
     out_f8e4m3x16packf16: 2309.84179688  , time:0.01080227ms
         out_f8e4m3f16_th: 2310.00000000  , time:0.01269531ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: 2301.28515625  , time:0.01083851ms
     out_f8e5m2x16packf16: 2301.70507812  , time:0.01085329ms
         out_f8e5m2f16_th: 2302.00000000  , time:0.01272225ms
--------------------------------------------------------------------------------
                out_i8i32: 701            , time:0.01100421ms
         out_i8x16packi32: 701            , time:0.01082873ms
             out_i8i32_th: 701            , time:0.02053738ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=1024, K=2048
               out_f32f32: 1462.08410645  , time:0.01540542ms
             out_f32x4f32: 1462.08703613  , time:0.01092076ms
            out_f32f32_th: 1462.08715820  , time:0.01292896ms
--------------------------------------------------------------------------------
               out_f16f16: 1462.19628906  , time:0.01528525ms
               out_f16f32: 1462.07800293  , time:0.01524353ms
             out_f16x2f32: 1461.92590332  , time:0.01080871ms
             out_f16x2f16: 1462.66113281  , time:0.01084566ms
         out_f16x8packf16: 1462.65771484  , time:0.01090646ms
         out_f16x8packf32: 1462.07800293  , time:0.01088381ms
            out_f16f16_th: 1462.00000000  , time:0.01281977ms
--------------------------------------------------------------------------------
             out_bf16bf16: 1448.33593750  , time:0.01571774ms
              out_bf16f32: 1455.72143555  , time:0.01525044ms
            out_bf16x2f32: 1455.90173340  , time:0.01076150ms
           out_bf16x2bf16: 1468.06250000  , time:0.01074433ms
        out_bf16x8packf32: 1448.51782227  , time:0.01088881ms
       out_bf16x8packbf16: 1445.67187500  , time:0.01087308ms
          out_bf16bf16_th: 1456.00000000  , time:0.01274490ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: 1379.93945312  , time:0.01552987ms
     out_f8e4m3x16packf16: 1380.08593750  , time:0.01084304ms
         out_f8e4m3f16_th: 1380.00000000  , time:0.01291776ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: 1393.96679688  , time:0.01553226ms
     out_f8e5m2x16packf16: 1394.45703125  , time:0.01087117ms
         out_f8e5m2f16_th: 1394.00000000  , time:0.01287508ms
--------------------------------------------------------------------------------
                out_i8i32: 837            , time:0.01561451ms
         out_i8x16packi32: 837            , time:0.01086545ms
             out_i8i32_th: 837            , time:0.02074814ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=1024, K=4096
               out_f32f32: 739.76281738   , time:0.02775788ms
             out_f32x4f32: 739.76159668   , time:0.01103473ms
            out_f32f32_th: 739.76123047   , time:0.01302052ms
--------------------------------------------------------------------------------
               out_f16f16: 741.25573730   , time:0.02747726ms
               out_f16f32: 740.10363770   , time:0.02743196ms
             out_f16x2f32: 740.54528809   , time:0.01134133ms
             out_f16x2f16: 741.26928711   , time:0.01139235ms
         out_f16x8packf16: 739.66699219   , time:0.01096582ms
         out_f16x8packf32: 740.10540771   , time:0.01099253ms
            out_f16f16_th: 740.00000000   , time:0.01312995ms
--------------------------------------------------------------------------------
             out_bf16bf16: 698.61718750   , time:0.02839518ms
              out_bf16f32: 733.34777832   , time:0.02742386ms
            out_bf16x2f32: 732.55615234   , time:0.01130605ms
           out_bf16x2bf16: 733.70312500   , time:0.01158929ms
        out_bf16x8packf32: 740.38952637   , time:0.01100612ms
       out_bf16x8packbf16: 721.00000000   , time:0.01092744ms
          out_bf16bf16_th: 732.00000000   , time:0.01304460ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: 694.41406250   , time:0.02810597ms
     out_f8e4m3x16packf16: 695.38085938   , time:0.01093888ms
         out_f8e4m3f16_th: 695.00000000   , time:0.01317239ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: 745.15625000   , time:0.02799416ms
     out_f8e5m2x16packf16: 743.39062500   , time:0.01095843ms
         out_f8e5m2f16_th: 743.50000000   , time:0.01320124ms
--------------------------------------------------------------------------------
                out_i8i32: 822            , time:0.02806473ms
         out_i8x16packi32: 822            , time:0.01087379ms
             out_i8i32_th: 822            , time:0.03051281ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=2048, K=1024
               out_f32f32: -2130.84204102 , time:0.01537204ms
             out_f32x4f32: -2130.84301758 , time:0.01095867ms
            out_f32f32_th: -2130.84130859 , time:0.01302028ms
--------------------------------------------------------------------------------
               out_f16f16: -2130.79760742 , time:0.01521087ms
               out_f16f32: -2131.24804688 , time:0.01520824ms
             out_f16x2f32: -2131.21093750 , time:0.01098013ms
             out_f16x2f16: -2131.38061523 , time:0.01096201ms
         out_f16x8packf16: -2131.14306641 , time:0.01099825ms
         out_f16x8packf32: -2131.24682617 , time:0.01098967ms
            out_f16f16_th: -2132.00000000 , time:0.01285267ms
--------------------------------------------------------------------------------
             out_bf16bf16: -2125.26562500 , time:0.01568747ms
              out_bf16f32: -2136.03564453 , time:0.01521087ms
            out_bf16x2f32: -2137.35693359 , time:0.01096845ms
           out_bf16x2bf16: -2140.95312500 , time:0.01101232ms
        out_bf16x8packf32: -2138.99365234 , time:0.01102495ms
       out_bf16x8packbf16: -2131.59375000 , time:0.01095581ms
          out_bf16bf16_th: -2144.00000000 , time:0.01276922ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: -2148.03710938 , time:0.01552343ms
     out_f8e4m3x16packf16: -2149.79687500 , time:0.01097393ms
         out_f8e4m3f16_th: -2148.00000000 , time:0.01288629ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: -1940.82031250 , time:0.01552033ms
     out_f8e5m2x16packf16: -1941.06640625 , time:0.01102161ms
         out_f8e5m2f16_th: -1941.00000000 , time:0.01291919ms
--------------------------------------------------------------------------------
                out_i8i32: -733           , time:0.01552916ms
         out_i8x16packi32: -733           , time:0.01085472ms
             out_i8i32_th: -733           , time:0.02085137ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=2048, K=2048
               out_f32f32: 3670.03320312  , time:0.02772927ms
             out_f32x4f32: 3670.03369141  , time:0.01098394ms
            out_f32f32_th: 3670.03320312  , time:0.01311922ms
--------------------------------------------------------------------------------
               out_f16f16: 3671.47851562  , time:0.02740979ms
               out_f16f32: 3670.10327148  , time:0.02740788ms
             out_f16x2f32: 3669.92382812  , time:0.01568246ms
             out_f16x2f16: 3669.72900391  , time:0.01553798ms
         out_f16x8packf16: 3669.43310547  , time:0.01100826ms
         out_f16x8packf32: 3670.10571289  , time:0.01098967ms
            out_f16f16_th: 3670.00000000  , time:0.01318645ms
--------------------------------------------------------------------------------
             out_bf16bf16: 3645.65625000  , time:0.02837849ms
              out_bf16f32: 3665.68139648  , time:0.02740622ms
            out_bf16x2f32: 3663.63159180  , time:0.01556945ms
           out_bf16x2bf16: 3660.78125000  , time:0.01621652ms
        out_bf16x8packf32: 3667.81567383  , time:0.01097989ms
       out_bf16x8packbf16: 3667.12500000  , time:0.01092386ms
          out_bf16bf16_th: 3664.00000000  , time:0.01311040ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: 3523.93554688  , time:0.02799201ms
     out_f8e4m3x16packf16: 3525.20117188  , time:0.01099467ms
         out_f8e4m3f16_th: 3522.00000000  , time:0.01320076ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: 3771.88085938  , time:0.02798295ms
     out_f8e5m2x16packf16: 3769.36523438  , time:0.01100039ms
         out_f8e5m2f16_th: 3772.00000000  , time:0.01321268ms
--------------------------------------------------------------------------------
                out_i8i32: 3211           , time:0.02805424ms
         out_i8x16packi32: 3211           , time:0.01092052ms
             out_i8i32_th: 3211           , time:0.03049159ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=2048, K=4096
               out_f32f32: -224.59820557  , time:0.05256987ms
             out_f32x4f32: -224.59832764  , time:0.01848841ms
            out_f32f32_th: -224.59960938  , time:0.01432800ms
--------------------------------------------------------------------------------
               out_f16f16: -226.07080078  , time:0.05197215ms
               out_f16f32: -225.04293823  , time:0.05196261ms
             out_f16x2f32: -224.60795593  , time:0.01924849ms
             out_f16x2f16: -222.50659180  , time:0.01940060ms
         out_f16x8packf16: -226.12500000  , time:0.01106429ms
         out_f16x8packf32: -225.04100037  , time:0.01104879ms
            out_f16f16_th: -225.00000000  , time:0.01317596ms
--------------------------------------------------------------------------------
             out_bf16bf16: -232.43750000  , time:0.05375409ms
              out_bf16f32: -227.55787659  , time:0.05195498ms
            out_bf16x2f32: -226.89102173  , time:0.01925349ms
           out_bf16x2bf16: -238.76562500  , time:0.01973200ms
        out_bf16x8packf32: -209.35949707  , time:0.01107860ms
       out_bf16x8packbf16: -180.07812500  , time:0.01107240ms
          out_bf16bf16_th: -228.00000000  , time:0.01309705ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: -264.29296875  , time:0.05309629ms
     out_f8e4m3x16packf16: -264.98437500  , time:0.01105237ms
         out_f8e4m3f16_th: -264.75000000  , time:0.01320648ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: -548.05859375  , time:0.05309629ms
     out_f8e5m2x16packf16: -550.15625000  , time:0.01108718ms
         out_f8e5m2f16_th: -551.00000000  , time:0.01320672ms
--------------------------------------------------------------------------------
                out_i8i32: 1496           , time:0.05338764ms
         out_i8x16packi32: 1496           , time:0.01122260ms
             out_i8i32_th: 1496           , time:0.05263305ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=4096, K=1024
               out_f32f32: -3509.85327148 , time:0.02772665ms
             out_f32x4f32: -3509.85058594 , time:0.01091909ms
            out_f32f32_th: -3509.84814453 , time:0.01311827ms
--------------------------------------------------------------------------------
               out_f16f16: -3510.24316406 , time:0.02741575ms
               out_f16f32: -3509.40112305 , time:0.02740836ms
             out_f16x2f32: -3510.73193359 , time:0.01125741ms
             out_f16x2f16: -3511.23486328 , time:0.01131368ms
         out_f16x8packf16: -3508.63037109 , time:0.01093841ms
         out_f16x8packf32: -3509.39965820 , time:0.01094270ms
            out_f16f16_th: -3510.00000000 , time:0.01312375ms
--------------------------------------------------------------------------------
             out_bf16bf16: -3506.85937500 , time:0.02836037ms
              out_bf16f32: -3507.63208008 , time:0.02740240ms
            out_bf16x2f32: -3511.86767578 , time:0.01122832ms
           out_bf16x2bf16: -3522.76562500 , time:0.01150727ms
        out_bf16x8packf32: -3497.67138672 , time:0.01103187ms
       out_bf16x8packbf16: -3492.68750000 , time:0.01093888ms
          out_bf16bf16_th: -3504.00000000 , time:0.01303720ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: -3450.27148438 , time:0.02798390ms
     out_f8e4m3x16packf16: -3450.74218750 , time:0.01089668ms
         out_f8e4m3f16_th: -3450.00000000 , time:0.01320052ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: -3496.27734375 , time:0.02799821ms
     out_f8e5m2x16packf16: -3496.75195312 , time:0.01097918ms
         out_f8e5m2f16_th: -3496.00000000 , time:0.01321816ms
--------------------------------------------------------------------------------
                out_i8i32: -3024          , time:0.02805233ms
         out_i8x16packi32: -3024          , time:0.01106644ms
             out_i8i32_th: -3024          , time:0.03049588ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=4096, K=2048
               out_f32f32: 2020.18762207  , time:0.05256009ms
             out_f32x4f32: 2020.18798828  , time:0.01283932ms
            out_f32f32_th: 2020.18493652  , time:0.01435518ms
--------------------------------------------------------------------------------
               out_f16f16: 2021.24536133  , time:0.05197549ms
               out_f16f32: 2020.26733398  , time:0.05196571ms
             out_f16x2f32: 2020.64233398  , time:0.02826214ms
             out_f16x2f16: 2022.50488281  , time:0.02805209ms
         out_f16x8packf16: 2021.27587891  , time:0.01094484ms
         out_f16x8packf32: 2020.26684570  , time:0.01087451ms
            out_f16f16_th: 2020.00000000  , time:0.01319718ms
--------------------------------------------------------------------------------
             out_bf16bf16: 2039.31250000  , time:0.05375671ms
              out_bf16f32: 2022.53759766  , time:0.05195117ms
            out_bf16x2f32: 2012.04919434  , time:0.02807426ms
           out_bf16x2bf16: 2007.21875000  , time:0.02934122ms
        out_bf16x8packf32: 2044.72705078  , time:0.01095700ms
       out_bf16x8packbf16: 2029.09375000  , time:0.01089287ms
          out_bf16bf16_th: 2024.00000000  , time:0.01314187ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: 1925.97851562  , time:0.05310130ms
     out_f8e4m3x16packf16: 1924.67187500  , time:0.01106429ms
         out_f8e4m3f16_th: 1927.00000000  , time:0.01334405ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: 2322.31445312  , time:0.05310130ms
     out_f8e5m2x16packf16: 2324.93359375  , time:0.01109791ms
         out_f8e5m2f16_th: 2324.00000000  , time:0.01333547ms
--------------------------------------------------------------------------------
                out_i8i32: 1701           , time:0.05337167ms
         out_i8x16packi32: 1701           , time:0.01122093ms
             out_i8i32_th: 1701           , time:0.05257368ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                                        S=4096, K=4096
               out_f32f32: -2295.19458008 , time:0.10227132ms
             out_f32x4f32: -2295.19702148 , time:0.03361320ms
            out_f32f32_th: -2295.19946289 , time:0.02290916ms
--------------------------------------------------------------------------------
               out_f16f16: -2293.83764648 , time:0.10097337ms
               out_f16f32: -2296.36425781 , time:0.10095334ms
             out_f16x2f32: -2297.93896484 , time:0.03533483ms
             out_f16x2f16: -2297.96386719 , time:0.03572583ms
         out_f16x8packf16: -2299.68701172 , time:0.01311255ms
         out_f16x8packf32: -2296.36645508 , time:0.01308966ms
            out_f16f16_th: -2296.00000000 , time:0.01445580ms
--------------------------------------------------------------------------------
             out_bf16bf16: -2264.30468750 , time:0.10450244ms
              out_bf16f32: -2293.59399414 , time:0.10095382ms
            out_bf16x2f32: -2299.56005859 , time:0.03533602ms
           out_bf16x2bf16: -2284.02343750 , time:0.03620267ms
        out_bf16x8packf32: -2290.28173828 , time:0.01310396ms
       out_bf16x8packbf16: -2282.46875000 , time:0.01368093ms
          out_bf16bf16_th: -2288.00000000 , time:0.01442218ms
--------------------------------------------------------------------------------
            out_f8e4m3f16: -2332.72070312 , time:0.10321760ms
     out_f8e4m3x16packf16: -2329.65625000 , time:0.01123261ms
         out_f8e4m3f16_th: -2330.00000000 , time:0.01445007ms
--------------------------------------------------------------------------------
            out_f8e5m2f16: -2035.82812500 , time:0.10325360ms
     out_f8e5m2x16packf16: -2034.17187500 , time:0.01119351ms
         out_f8e5m2f16_th: -2036.00000000 , time:0.01442766ms
--------------------------------------------------------------------------------
                out_i8i32: -2746          , time:0.10370731ms
         out_i8x16packi32: -2746          , time:0.01133108ms
             out_i8i32_th: -2746          , time:0.36144137ms
--------------------------------------------------------------------------------