From 364c3b22beb8a009de2164c65d72afe678690656 Mon Sep 17 00:00:00 2001
From: jiachengyang <101832757+jc-bytedance@users.noreply.github.com>
Date: Wed, 3 Apr 2024 16:53:49 -0700
Subject: [PATCH] [DTensor & DModule & DOptim] feature updates (#20)

## In this PR, we update some features in our DTensor & DModule & DOptim implementations, Yo~

### DTensor Updates:
1. Support more DTensor ops.
2. Sharding strategy updates.

### DModule Updates:
1. Decouple uneven support and run check.
2. Reduce some CPU overhead.

### DOptim Updates:
1. More friendly API.
2. Unit test updates.
3. Reorder some communication for better results.

### Other Updates/fixes:
1. Some minor updates to our nanoGPT model and test results.
---
 python/example/nanogpt_4D_finetune/README.md | 16 +-
 ...etune_4d_forcebf16_train_loss_bf16_200.jpg | Bin 0 -> 30464 bytes
 ...inetune_4d_forcebf16_val_loss_bf16_200.jpg | Bin 0 -> 30660 bytes
 .../nanogpt_4D_finetune/sharding_plan.py | 5 +-
 .../vescale/ddp/distributed_data_parallel.py | 22 +-
 python/vescale/ddp/grad_buffer.py | 53 +-
 python/vescale/dmodule/_dmodule.py | 4 +-
 python/vescale/dmodule/_grad_sync.py | 31 +-
 python/vescale/dmodule/_hook.py | 15 +-
 .../vescale/dmodule/placements_interface.py | 12 +-
 python/vescale/dtensor/__init__.py | 4 +-
 python/vescale/dtensor/_collective_utils.py | 41 +-
 python/vescale/dtensor/_dispatch_bypass.py | 45 +-
 python/vescale/dtensor/_dispatch_patch.py | 28 +-
 python/vescale/dtensor/_utils.py | 17 +-
 python/vescale/dtensor/api.py | 250 +++++----
 python/vescale/dtensor/dispatch.py | 8 +-
 python/vescale/dtensor/dtensor.py | 20 +-
 python/vescale/dtensor/ops/embedding_ops.py | 80 ++-
 python/vescale/dtensor/ops/math_ops.py | 42 +-
 python/vescale/dtensor/ops/pointwise_ops.py | 15 +-
 python/vescale/dtensor/ops/random_ops.py | 16 +-
 python/vescale/dtensor/ops/tensor_ops.py | 187 +++++--
 python/vescale/dtensor/ops/utils.py | 18 +
 python/vescale/dtensor/ops/view_ops.py | 178 ++++---
 python/vescale/dtensor/redistribute.py | 66 ++-
 python/vescale/initialize/deferred_init.py | 8 +-
 python/vescale/optim/base_optimizer.py | 13 +-
 python/vescale/optim/distributed_optimizer.py | 320 ++++++++----
 test/dmodule/test_plans.py | 76 ++-
 test/dtensor/general/test_api.py | 18 +-
 test/dtensor/general/test_dtensor.py | 294 ++++-------
 test/dtensor/ops/test_math_ops.py | 327 ++++++++++++
 test/dtensor/ops/test_pointwise_ops.py | 11 +
 test/dtensor/ops/test_tensor_ops.py | 483 ++++++++++++++++++
 test/parallel/ddp_optim/test_clip_grads.py | 29 +-
 test/parallel/ddp_optim/test_ddp.py | 20 +-
 test/parallel/ddp_optim/test_grad_sync.py | 13 +-
 38 files changed, 2085 insertions(+), 700 deletions(-)
 create mode 100644 python/example/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_train_loss_bf16_200.jpg
 create mode 100644 python/example/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg
 create mode 100644 test/dtensor/ops/test_math_ops.py
 create mode 100644 test/dtensor/ops/test_tensor_ops.py

diff --git a/python/example/nanogpt_4D_finetune/README.md b/python/example/nanogpt_4D_finetune/README.md
index 8d70e9a..4de2467 100644
--- a/python/example/nanogpt_4D_finetune/README.md
+++ b/python/example/nanogpt_4D_finetune/README.md
@@ -2,10 +2,10 @@ In this example, we demonstrate how to finetune a pre-trained GPT2 using veScale. The example is built upon @karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT/) project.

 With near-zero change in the model code and minimal changes in the training code, we can finetune a pre-trained GPT2 on the Shakespeare dataset and utilize multiple GPUs via 4D parallelism: Data, Tensor, Sequence, and Optimizer Parallelism. The correctness of our implementation is verified by comparing both the training and the validation loss with the single GPU result produced by nanoGPT. The difference is negligible when the computation is conducted using fp32, and is ~1% when using bf16.

-## Prerequisites
+## Prerequisite

 ```
-pip3 install datasets tiktoken
+pip3 install tiktoken datasets
 ```

 ## Run

@@ -35,6 +35,18 @@ Here are the training Loss and validation loss curves plot for fp32 runs that la

 ![figure](./figures/nanoGPT_finetune_4d_train_loss_fp32_200.jpg)

+For the bf16 runs, in `base_train.py`, instead of using `torch.amp.autocast`, we cast the model to bf16 directly, and both the gradients and the optimizer states are cast to bf16 automatically. For a fair comparison, we modify veScale to store both the gradients and the optimizer states in bf16 instead of fp32.
+
+![figure](./figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg)
+
+
+![figure](./figures/nanoGPT_finetune_4d_forcebf16_train_loss_bf16_200.jpg)
+
+## Differences from the upstream nanoGPT
+
+1. When training with bf16 (`--dtype='bfloat16'`), the model is cast to bf16 and we remove the usage of `amp.autocast`.
+2. Sampling mini-batches is done on rank 0, and the indices are then broadcast to the other ranks. This ensures that both `base_train.py` and `finetune_4D.py` work on the identical batch in every iteration.
+
 ## Caveats

 1. `torch.compile` for veScale is still experimental. We run the single GPU baseline with the `compile` flag off.
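To make the bf16 comparison above concrete, here is a minimal sketch of the two precision modes it contrasts. This is an illustration written against plain PyTorch, not code taken from this PR; the toy `net` module, shapes, and the assumption of an available CUDA device are made up for the example:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, device="cuda")

# Baseline nanoGPT style: keep fp32 parameters, run selected ops in bf16 via autocast.
net = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 1)).cuda()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = net(x).float().mean()
loss.backward()
print(net[0].weight.grad.dtype)  # torch.float32: grads stay in fp32

# "Force bf16" style used for the curves above: cast the whole model to bf16,
# so parameters, activations, and hence gradients are bf16 end to end.
net_bf16 = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 1)).to("cuda", torch.bfloat16)
loss = net_bf16(x.to(torch.bfloat16)).float().mean()
loss.backward()
print(net_bf16[0].weight.grad.dtype)  # torch.bfloat16: grads are bf16
```

The second mode is what makes it a fair comparison against a veScale run that keeps its gradients and optimizer state in bf16.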
zdtfXh^YFXcJtKZOjM?*dN zRX>ZYg}kummb+@%>MR6P4D3Do%Fm*BdaZ!!WtE?dEu{BAnMUbduxMBFU@ZTb=H!aEzKeXpe0SSglpa!XiJy$3!FUmMm{RwcO`nagSfwP6$YlAg46L??xz>Xxo}Iq)u6%Z)bitUfCF zfjxvVJ7gVHM59j66@lD_gxndDwA!p246NNTL;D4%TKB^5!kz_p&`Xd&@}Gfdx26k< z!z8&CBd>4<8g-dK;g`0ETf!k$Pa<~&whM$2&I6YYuy%HEF_Hb?E z(Vj{OeN#Y1E{_n|Emh9&7Ij=))NW#3nT4UT?;VJm)w)C|G`_G`Nw!1Nz zsVc-lHU!6MjmdgC)4IJjJrUH2h!XcEJ$=Rm8hTIDlDejX0Nd7m6{@9NF_~Upr)R9+ zkH>`v2q&>V3b%S61j*Vldc+Q)LeH{!ZI|7Tb=%z}OZ$c+fawa*ejQ#`D*tb;aR1q= z$lN%Ihk+UHA&`3!canyxrnH#4Xfz!9<97XK3t{q=>#yoe52!m-7=Yj}{89Z!3vUWp zfj_GM?}h*0>m2U@5+vNc=aR_yYX4Hm#S>RTQFV?4K5_6V5h=ICLj=dn>T48;T^d4w?(0Z95v4U{u zdW+t`cn8Sd16?z+vj6vy{(sX;V5+tj%urVe|5HB7oF1BcpkhW|<2ypB36dF-3C~V9 z;B*fh2Dp*k`8T7&)i2}k>c`WzbjrRZaf$qg`iW1SQpJ#IjlNe$_zLYHS+=_(Jf*rp z{JR#8fS{wbY4^Y~{x3DW{?*LJNb095Di>TGEqq$H^hTkF=zplAj3^d-DlGcJr-B;v zh;L#K-cA!GeT@R+#UuxtZFc?1UW_4>rNjLtbtiRY1_8NKe;WEIC)*!v zcxh2W57%`R`lv6E)D!>uBWt8O73%5kc)X9HcO*H;Oi+kdw?PzZtZ+(QBW@Gh0@2+) z{1-W?e(N;r^y-Lt-EQ943O$iLQ_MNn^ivTH(wBu)fRLhBvpw3g)4} zErtiq9l|fj82fFY8@vj)Bv>*Ft*&eXEjg-LzwE7A;Fd_S?Rv1}A(e}JKawO~`mO(d zVcx&Urd8bZDq~^JU^se^6eevEZ&4jfq}@he)2rHL?}6^y+w-1k(jbZ~Uu_!Xh_^a! zNwP{Fr$)X_$Z*Ck2_$Hku91+iHOrH1;$Mp? z!!A@?%bz;el|yt0+DI_dV`h-nRp_>J$Gy2)|BCSZrD->LpfA5m^`{a0asI0zti|s* z4cqj8rDM6j+GSe7O)_tx@XLRlOuWB|T$1{$UFffY05*ObugSt?36zEaBMeJ3$`9Iq z1N~|*)OA%ophBOC7*QE@nF*A7h}^{ALK3^UNFP94Eha@5#)qDlv9Ah%(g2f&;7w87 zdw`lO^%{L!;!20l*%?ri#s6DTa#hF#(?g!B>cocup!~;~RZTUm)nVJit7M7pa;C9F z*kNK)Vl%>OzJT&kT*6xVM zi>g1iP{Ot!vUk_s&|R>NwlduVe081JAwt7AB%txMdLmHgl}6IUoO~K=`zo|*4~j8j z8ywZpR^bm}0FbbHEoq!!%_+ga_ZBKCf|gG%?A6v{a$_a4>DTy5;z=0qu>&r{$r3xfgoWjzY`mXTeOM0Y0#E&{^^`I|B0ID}zQb@=4NX zQ^v^fpVqA;y+fhR1!6>d9fO@qv)=cBCA0&Bakh1*bf5oyU8(f+jX0!2qdHW195s9( z7DEwNQ$uK2QM&CUBp;<}&&IgC(svxjMgCB(T+T0}AOztOL6TBqS5dhQd368r0R0;#1|bB(r)sHjqIgPr6YtmytP%|4 z^R0A4s}5|knh5#Nr*j_WXvZkJSeWzz*!*EgQJ(C% zL$$|dq1oJ=8tAz7qLd5P1yv|U!K7J{(85oOHNlR;B$!bZ;dw7r#{40CNQ5ru74wRi z&%zxT6M3*pzF|aPGEqq~A$t7ojQQf6%}^b{XVk07htQI`djNWc4s;&M-+0?XcUkl9 zfgfg5V>3d_;Xi@-d*GWEJOnb{UznMO7|irX!G6`lo;SiC`azfPJU#oF?(>y5W488n zbGUfIH%0DmEok7eBiwEY|Emy3w?_)K35->rgYbwq&9DSD=_`O>hf>ZEjV1C&Bx=wx_M-rntfTN1s_BM-KO-b-Y^xM%S23l%dqa8M zhj7b*+Pe`0CC3yRcB0T5W``T0h7YN(>U8jq$+YqsL4@BZ4UUBREOF%<5BG>z~(Q10c+$uSm}sn*6j zv~*dEd*C&g>J+^kZqdLaG3f6+_~S4?(d*hIDCDvXBKqRw31%N)PhS6O?NU!PdGj2) znrx~(zR&xG{q)S+2Lh%Kr-{U%iZ@pkhC0s}=zsc=!ccakT%PyC(@wLCR^>bLT8-9$ zmb5Xm20B_+^{N6);`i*@H}&Dm9DIxw9^g*>yl=Gj^G@(yp5QB*NY z5S8>%Kr7Q>)>Z!are7CxOz=1wZS4F{xfeI%f)L_%2@Y=7>?6W#CmTp6c>1NgmT>6S zr>`*VzDQl22dH!muk3Cx_+GWY$Czkh&j26Y18;ux7s^^i%rDG!2q?5+^}Yu^R4LXz zV({;(ZJ^8tKT_gIen+xq2#?N>iHrMmLfe!8i_6yT+62*jlxUS0jU_A!(Xi)b_HQ)wZ2vqCUgrA7 zbFvn#wfF8(Zfo#av7P#a^G-jnP`mQV5$Rk>;U?ILDFU}wUc*cxHw zqUjAChS*~r_SwggpDMfhiqy>$!?3xA0PDpa0nLWqz-;qaG76=XX+3bHOgC3Z--k6p z!O+Q&$P=7rTRWkjq$Y1h(7-d5{ca69){TWTBM0=_Fkb!WC?^^rvFH>f7xn+?OjrLW zs-~p*-9mb=?}1TO6e?ed?;yC#_}W1J9(ew^B&>b}3EGFUdDqbpeiy3*k7oRCU^)FB z@Ay*%KrndO-Q_Ay5}m{M<#VbjdFrjzHf!9`Hu5NSBV(e-H8^mx4;~w_$ff8y7Y=61 zi0Jq`hW`W0|MD>ndU^Nhw&%_yH;?L00`BZOw0=oL+bR$cxOk#G)|+KG&(uR7w1N&j z$Va$kq1w|MxdN<=E{>P(fw1Y|&j~H59`^v3{$xhQJy2~Yv4xNiT^fYmiKu*f*$`zW zH^#ZDNloPHhq3QKBN4%S`?;y*?C!ObTrlf!lYEhtuy|pyrRu`ac`^944NZJMm*@u> zrk5T_ZE+x{Ls?XTadK_ zI~t$lXrZerBr0kceL?P#L3AS!2(pb>oU=UFB8qCRNj}f1_k!$>%Gc#NZWH&-zDn{% zCE1ar_zH{3B26wyrf;;}&x}$qg(}f8(Hb8N~~C)<)V}X;0jpQiS2Di$O_JN4LgK z)h|dbwBxG7z6vsiJ)+(8-h(HO+z0^^zWGWTz4BH@&9)cL z5>D>zVz-@m0SA(0XV>StQGHN1%`!b?SB`eKA8x*m$@RE^*G0YszyQHLx)(!Lcg(v`gqk*=v8Ac!lc) zp`sD2t%RD^3^49(oJ#G)E~Nedc{63i{?1^sxax5m5hIK8Nw2LvAf{WeO`04KZFe3j 
zcLf{9T>DxVW2%2;{*j<(`{iYXm8YFAwK3FD;%k4YcV##m093#>%)VfAcZ!=Z_7&dtmEw3uLV{ z;~Z)MJvS^g^{R}$X76}jyHk~>o^^t9G~7jHe>pG+2wW+3$hQlkZqU`yrcd0ct=t;! z>^9go<`~Y~u&?VG^wKRu1Yat%ruu1BR3vt_Fi)cK@7OnZRrt@3vk4h|6sdf1?2T&3 z!eqmTB_J~+s^x_L(k)hx#1O-*v$%jIR)L7Zz1>9o#1lc1V#@r-BEu-e{VVC4+MLC& zC@)-Qz;vcF=!zvg9T+7wxmk;emy|u*C<2sk>MjB{nPbdCGgUu4`OXYNI_h@1Bd(Zy zC41?bXXWbhJWINd*4p}IG<_)~1>q*NJTQJh?bkubK&=O~TdhD~oY zYPoT^BFG6bCPOYtw4_oEpZBL6i9Y9M!ycqer%Z`##54w>E!3TlP&nzg#E557bs6Ez zZ+oER66U)j|7eJnL0NnOAk8t0s~aP1J0;=idW*>o+|8&3E3%uP5DFO*{d^YVY*O?< zG)#W2og7#BAIqDT$jaAnVz*|vF&ypyvHeDYl(QgkCpep0?5z6lxETLfmav^VjgBdg zTyOnf-2)1Hg_$M4SEM1Qjist9C{KXS}y5&7r7I4(u zQm*ADN)i|_NK&TA({X@fJs!M?>OW2W22ImX3ohP2 zZi?3TVw9P%IoJr87RfhDIO$W0B19|u;Mlb*=1uN21iUq}yn2&s6nqM-b)8+4jTBAX*V)UrU<->Vc)#VIvjXy_I@s8CZWTyJgNm^}cK4tIDg ziR@cs^&BBZf%u1_`)V_72W)ZGy+kyB?L@WLg7=GxCi(|>)Y}suAsstDcaF>yIkd{? zja(QG2ZIr_9$AdNxT-xjsy@0e1xT)TQco>Vu`Kv!Kf7aWUqzJ1y7$PA`kNz=Ku{Gc ztDLVj*0bS_&0!My563P9)N4ozo|FqL&vA$>J(902S@g zGX?{rlT~vzniIP-G&MnZ8s>s#&c*|6=hWU-V0N301d?oDW(J>6(!1VOX1-bZy9)FU z-=-}-LcmpZY#CJwi^o^TA__>}F-Yt2ACVqa6U>op5=0kVMt6C*)SKr%_R0BLdooB+ zl&b;hdulNlPQ5o!Dyf&k1ECVk|&{X>}(PDnZqerhd(N4nt9;H&x=a;qv>#2&>3QccEiF2n) zB;R>YAE{oVUV#KkLB(`P7^pg0KaJEeGT3sE$u_iNR0(q{ZR#hunGVv)&c&L(rSt!m)5ODGXuD5}|ZOiD*~zTyy%N zj;@R=#bzWJ?`b&gL-jIHigfDbgS~F>qK@u1f0V9~o{W#+c*ffO9tJU350 z(*2|hthdt>ouaW59bd`JBFi-q0x`J-B?h9qt|s7IqLXbk-Qvl}CWW%97HyMb7R(wl z;)RsJY(OPvOII$m^_ufpdHK>s%b;zf?bRNL&vtQoamv7EIheo|Of2UkBSKzD4xx`y z+*fiMb4ygMoZ5B1^;8J!-P<>*sjF$ceyT+Ma*p`3be6Zm8Y+b~pQ!u0{PUe+Z2GZ# zplVL(`i0iqy5=KdnW{HW_+vqtcnAlQ6nJhHa(7*U@^uq?2W;sgu_Arb#@bQODcXFy z%7rd?!FPm*=h4m_RSOs|D(ir}SlGDWJZREflFx8QZ_;e)z2%H6P%$Sz^)Ysqw*rkLZ;%*c}ZroIT2 zUy5!;x*Ir(d8AC8M+82#~+S- zkuvbEk{UaQw--3(@q=e@ejJ>#Q)wg!Ra3C^9Ba2uHFq~C^q;a}v%mv-0DuZ$PYtHx zZ)?poe9*-Xh&#V}H)O|qV)#H{Aag^ZXL?@FUXg_-GKFO-HRO~m8D&JfZqnxT+PbI! 
zqN6$M98)6G_PducDWgn$4%Y4m1g>cG~gzp>rQrO`4N1-^}LL{<@94&F{-TJMb~6&JQuwWd-$?5_D_lcT4zXjnvuL5=&z%@T7x@&x#>zGa`(a^}9^^oT{DLUcQKwi!7 zd6j?wn&Xi#5#jW>GMw>NNg8|OpdFinI#G16x2adSSzVo6VpDYyqtt{~<2b;-zvg62 zNNwTGnv;gD((YPrB^(}HvaC2R#HZ0Q0_s>Jj9vR|Pwc!ewo~+IvM_H#FS`t*IK^L?#(uz-|Iu4EtTl4EWT{O%sOF%TDHS!5=ZjXI1GLMVOw`qKBC zM5sYV9Lp|wW^mI@sqJ$C@^yUI{g@}@5XHJFBQJA= zFBQ|UG`QOO9@xFrQ>&ber7g_08Un8qgwSZIaDkZHs6$c|y0xdAnMg-e%Atrg=>*wX z4th#h-=+25aqzo`4J>2g#4C>I?r){xNK8~g@hk;MISL)DuBPTaOtlT=DLT+pfBH5_BB;zOrjYY}Hy!8)wg+@v_NbWO8Xr0t4B;be zTrqyA{#bPnpurY8jQKtI_L3)kWlb^q*O7W~Bp?kIxx6jJht)%$ulQ=RGpuq91Z+8$ zbipE{@B{Ho-PBO@!29h_fnc zinb+@X&209BW}g@aJ5%N5lXutyRI*%7SN9WRtB$Ic|(KFM*lF?r+AOeW*_p&Gq%G`d*rsOixHL=VK# zD9H=@DI;7QJn$7oPHIM{YrYCAetnCK9rC2V2{XO+?t*R-5qhJ=_-qcvG3zmd;AA`CAJnp>j|=Q>KfX!-XxBSjAg`mO#;i- zmLF<8hMGGUI0_-|nB0bq9r4L9Atb4<#(UJKeT?NLi^0G$ZFCpf^4y-*(}}E9&FRt< zdo)=~8q;zrJZ@i2`WMV}8{+Z#?5Bx|t*AzqFPA#6Qlcq_lE%Cmjy>H4o^3nso5sY) z2hk^aoR;jv@5YOTxB?+lEOll#Bumsb&|E#xv(D3q=TBwF46*l`a?s%AM5MyrS|plc zb7s{(yCz-@zT?FWjG>SoJlRB^5%IG~&a&7jeFoQqWbq}#!|a}7YjB2mA=uTEeXnSA zzLzT2TqK?zTMFP3G+$^Zf0H}*jt*+Nq9-PF7H`A(v|@gK5R9)}N?$TOfRwky#O~83 zs%0b#n1CDj&f`Vn9$C9KL>%{NGx)-us31Lj$tSc%p5&x0PHNz|(JhAE$LhS+dnnv3IBd!n9Q|)vkTdRvND6 z2PcmFJ+t$FJO=qIZ0SUa<^|88I3J%Sc7sOUKZI`@h&hQ9Yt+n$s=Udi;heVXQJ5Mz zDM?7GH8GTu#eYfs&y@FHOp*L*|DUHxfT+c8On(wadZ^>VxA(@lFYdM$L)#SHk>9n4xs=U6 zk;4LX6Z#MqP0sQJE7oR5&iC*^D}%9Kd3fga`4w_N*jXla-q-^oJql(NBw zgblUH3=?pR=Rx}~}(-9jnivPGRBzTK=2?d;B<{|1J z))kZmQv#z849N6y49(pcs`PP%K*Yc>75_(au?sfU$?D(c9FRXm)5(yoAvjgK1J1}{ zPsErITSN3mBmz(29jlgSWf>jzHL~Z#M!c&hEo_ap;+=eYR|@ChsoJ*gyn~fi9O8!H znP-e9A{sq{Wb^|NLdUX0Sj^-OOvY+@I){XiwNNB)Kh1>%;q$WJZl3{kS%ZB)(hi#c zBks5xk6?6O!KulF*>MA}`JrM148QMT0f?K9t znd|CngJ&lr%h=xTJ|G@;q8Sm7F~lP1sV&^*8CGl#W_5d4p^*`Ue~4q+JpbL(Llu>A zCB8c#H%F>eAc&nWUMe(nZ{j)FmA^y#@w8>HoH_BUUgfF4bVH+sWDXndgDx%7`e}2$ z30D2))LGfC;I?=PXD(vk?IyM7jF;UC@r-I^W6Lz3!Sh{Z^P1p&g#o$iXZt^KpFbq_ z)=sI;&iFxiMfjzLu#~l{VvvbDLXTTvoYLDNR`9$VaJ&i+Pw0JFU-igM)t$S_t?r1l zn-P_Y+g%=fMk1JNEzVc_>FES+wGyjJJbYAWu@~GzCWa_TcDiO1!J(!%Q zNzps$1UFG4m+6Z~b8u8u;YQOHK*$YMCQ+cHj%+`8HRtx;i%hX>!dax1mbg@OC+IQy z@DUN10(CGX%H_OGXjopL`j!yq;noB=?$HR6V}(FGg6g)~R6Cgug*pVqTPXP%+6`Ta zC7!mG?uYXP|I6~as-_0d2pPFu3SDWPXPC8y#)u!J>g+~@t5pQ8b`=Ocj@kS+1k7=ea0=Zvl^K) z^d&kZ(9JY>G)@1ukm~81@ghh#*`Pxa)U@->p6P2yo(fkiA6e4|UghglVznNk=>9cZ z6(ojxz;AW5^o;DwLJ8KwTc9Q~hJdsP!bET@YCzU3uW&$P{&uWW$LOH7 zOU)|3Cb7`*s^*WB+bs}C<|v>PS|DTLYO~SpgcN4(_oA_eu%s}e=me_-hx*A}AGwqM zi0{FfYr~;VC}azE1!6cG&dqb5zObU9|7kTYgw==%pxfgR4~vmKKOCE&kU!rjhS-t6 z5!~B1mwT#jnkPBrGg{93X@wQ(a3dxo>*6_K>}Yk-O5Bj#r{;|hMCM)5=3RZiC_Ma!LGfRq ziu_mNxbTQ6w$W1*PKwKPvO7JmSo14)DQLosn*LW?tyQnDULeacfytmOI9J0e-|)&? zO_qTA_l)UZIVZpz>oZd{f$2yg2}r};v4WfoZ|Wro+(LjSHCsk5)mnD$+~G-Ruhhf* zt+WRM3(yg!)`QABe7gp?bEXDevbbUKU_!*YFgh9FyxzSB{K@uPH$NL)%)&VraLSdR z*UZiFr@L#36}juupX)@o=g^Ma|N9gFR!6-`3ylyFLSCpPcP3zm6Zzf#V!#7M;NtY%wZ>o1!YpM4VSmSeHZDA&KLdl6SDYWRh^K%zy8GHKc<;8aOZDuB$1!B0wTsCx(|NAZ+l z_`3qCQrf0aZyx!oH`iNHr+iV;$HktvUmD(2da$ijhdinC|59e~RS#1N;P%(I$nO=a z$;|%-CJfkl*<4#$i^Lqesu7AHhN9T*$hlm9hJ&ZKWbi{5D@4_xcy3-h-8!*G<}>y) zjd;bcZDs^{trR zd`m>Y6N&1-fE#S+|xd>n;yC;V<$b4u)bw(~Rf47qy09qi>$y`omRO=);$mj;3F^Nsj@ z56ge&Ir-*aqmQUNj$q<8MGcaJi3cmo5ajd4^80pZSu;`OfI{Xzx=GbvyDI-kbp46% z_NR(Jnr*aD%Z*QZhI1B#`0 z8+**!_a@5vLkjK(JASQbzyd&uznKvce% zWjK4es?dOcQ>i6PVj+Z3*jGn?KJmGi z)T$A!juflUi5(C_jh9BHyUp5H)y#6lSrx^*@S3mh%v5aN@hKcYjxt5u!-*vtsAw3T zz56M8#<#5-mW+~bX;VVEPDjsgQ)1?Gc@J3eV_m-S=DPWfoBsWV$eb`a@3LSVHF+&u zL!bAgY1j#D>+Z@BZnz0+82CdEOQBY^FvwJ<*X)X0_?QpXbX_aDzC35Zo*JQxC%yvb z-UC`s#VJV(sTowHwK`F!7_oKZ9bx>L`P`e6Xo_aCg4Diw8-{_? 
z6N+d}PS_&XldQ04`6pimaM#`@Ss=LxA|18&+Qz)E>#O#3U%9r)5UYyR_kq458uICl6eng zM{CLJ{|PYucvpdtbum>Jzwlx@G;RJzfS7EQ;%aN?33>1*%5^hQ--1@=ec_XrmzXS< z&*-%*%o-XMDL!-jME>gYjp;hA%mh^s2{TDu;_w5`6-m=`xKjYSBdAQ@n!ZAbrcAb` zAAcH7caVJ1$kQl-=Na|Y5+wKcZul!{{C_lM?IpF+Q?e+X^AC28XD`WGsMZw_QA?dM z3*f)Oa`wO)U$8uUf@EXRsFw%Y-(0el5RTKDFq!7HUf?(VxS#I!$&|qr)66rh)!)ew zH2EfFVWk0?!{9|!CMGURxuVI41h||F%Q%zYE>} zDzg8F-4}o20sg7tpIkzMHixtunYxcQ+O`X}V_Kp?2#VYuuXQlI-ViL!P^?|cU5h8@ z_n`5g+@?r1F)=fv^XxKcE1sj2kS>bd11t9cMT_yd2q{ZSj2q`^p>^6^Q(rs1s~FBs zLj519rN5N__;+$9|C4J_|3pnM`IGwpbF(S_8HLe{u~i+2nB4e)fS^`=vMHt)f(c7^ zbC-ZB{rQPJFE>^{{aX@k4_`q(d2Q_T6?Th>l~(F6r01Mml1R2cWtc@x)Aw1EFCh?X zLhtD}mdO2+9FspCaX4J@2wgrn9Gf@bG>%43dfP~?;b1NpWM{X2JPj%_^9sUrEw8Nf zK*oylDzvJ{_&A|^8EO5tAZJF(Z|~)@y{$R2De8jh;?f~sNwiwHYV2r9ypxTMQ)Cc_ zF*ur{XLJjy3-tcUmWe*pHSEp1%;!T}PT8_5Y*C=$qJw7Wt)c6DT%+oOWth`a_#v*#8%hM2w+|dOqIyfAZ{nx&z9+hLN4suGoeor! zutA{RCOtnKt?5lZ^k~cRl>F8sV1WU*7Y8t{WDby@N(-m0**6+@mPl;A4)i zlC$hE`f5U{{No()Zt*BAU(d0e=~#-GM)0Q;WsFrY5J2|a)7>JL%q=s@1&1dCFf#^_ zKW^It{kzCYvcKFJ>pPi64mfBRKz~qFJHt)ef30x-=TQDKXSAqRu?S8pAw@qN8f9os zAQ;50>p4Ovt9QGDA@@LS2&KJsSw;c(`YEXwiBHD1)lsJV$HM{jeU%210F9a3AcVeC z`mkn!9e;j{&W0WE4>V7Vboy9_f@K0BaB6lX z`&}Az4Z>Y$P~!>Xk5`Z~1H+ye0yU2BI#FbKqB&0f=Ei-!B?!x#lxZ3brhmLbheaso zxtp#IHm*Dzb`Z9$Vv2WjcDwDz!k?_DJ$U@M$3>uCUtn`D0*mzb-7NnH&V_zsW;>-k zfNwm<8oBj{P~Ea6GzkA;k^Dyw=TEltpDO-z$#07;yWp>9QiM zyKY|*gtg*9>MKGRdVp<{_h=&*Mu#Om!9tm}EM>YP+N*ERis>{cHRBkz4LIHnV))Tk zBDb3>b>&66y{6>>AfOLI*g6dmBXy`cO^Nz0TrEsR@rWS%_VlO|bF_!< z#586yu<--e*S6HA?Vs;O2#G}>xffX+xnznZm4j&(_YFv)!hFXZ#nn(7>~q zeSWr*_Y)|3b_s!ReN;%!D7zlYmlM3H+e6j%U-8|?w=DktGKYxR(j`gsg$S=w60}U2 zr>60t6yO)(IcZhp8P)X5u6+F;mPC+DdPEP2t2c3pJFtm*DiI@G@2O-pZA3J$#|p9s zum!N`=nIZ!4yfs7hOTDR6nW3_z`OA17}x(UE9bvrwP8m189k{S8J^e0rVu69Lp$ zW7(MsUF0s^ktDwR?cEWM?=k9QPt=aoPi9rj>9z|K-3N$X=<0<%e zB%md-szm9oZ3Q#q8qAX+3487;sP_$L(LdaDGJn@As3bTzufKcVa+&=`CugypREG+0 zd-qAPogGNdSYjcvO<7(hRt!s9Nuom$`b&htsxjA~L4I)w)-&PtlrP?Hzg{&@&)3C|_Rr_vI z>)3hlr>a5S4{Jxkur-P&Zt_1O9p9OC!f;k_x8WBU#^ z6t0kIe4LsCjd*odv|sG#9#dji*Ql-iG}t^jK$Br7{MDA&$hKRN$;md}&!F$hX)9(? za$(#{3m?N2cj7j`LAsTm-7QWdJ(6`;*RR)Ro206?!r7Ag{#baE90F^?<9r}rY}<6C z&SjuHSV6{X`dx?d;uZ1WA@|Lr5KM&Nyxw>x0rQ`}`mknP=w2Di@gyGs(A=&%WH~Da%*aq??Hy39-p1mXtJ2wQ7CU&G&%fCyLL6KRbakI;Puwyy-N&Tl3kOM+dZaf&H_wUd0%p2}&0rLr z0$1#-1N5I9b{Rkj%^Cp@$iKfj*AUQzaviH=hjlQUrZ7Y?YQdG)5lg&&qcK{Ac2!ef z2# zGa2%j!bo)@L2Q2Z%bmT$g z9{-u3ECH37@D0&#W3{AwqezA5L<||p3)6SojUe%M(d%eZwU2s(&vO#+3g# zDft`!TqFAGJBXRh0q>p*eZSl_#jxOj&`&g>WxI~e|EIXGj*4U3-fV(23DQUeN#hpW zgKKaIu8oA?!67&Vf`t&YfhITv2(E#~2?T<>2X|=#H13`0cc0AM`+l?LzFBvE^H;6f ztIz4GUFYmmXYcR(sxqkrUYz6-9WedmoJ zqnXVu1a)^oKg(N#0N6zV4+*L)-D9{{EOoSgIh`f-M*1N2gkrhbnbfY8VYIfVN0Z56 zgZ)CIgI8Yhge>8ZF09n(@^%TysXeoOBVEXmC&V5kk5$&Z;(kNff+PLXk5$0>g7Ue! 
zm0*Jl^&SO`N|Nt@vA#EoN)i8u9VK<=)>yR%RzhbhO^AXAov^{6FWjqj#0VUP+3t%>V11KzabVSqA;f=w<%W;n0>l%F>^AxbiAh9hoR9 zNq^cN;hFQ##E$tHr4+Hep7}%4g*RuQbKkS_?vA=o=L!)>3a*7ppisAT<^>tGhP}Ku zU1J?=(R)hKN+u>s$ci3|`dhsuMn7k=pL>q+7jh0udxy%dz4Tm+{ZJl>>U#|mRs-j| zw{Kc)-5e3W02RBZ;&&oh3W@ZE`Fe|Xr-?h^POxoay$w8Phx9#~n>>qzkQ;%z05|rZ z`B_oLH`(cTx`QNn*@j-L6a4W0_6wlZG!a^DG{-yFrYDovC*?!_POXh>WQ(4~uJBmY zwEyZ>&b|o@+t~}8>IifYZE3nKmTV(3LuWc3yqeSS`Y?daMFYfvxE?>*Cl^`wbKf(d zZ%<8%6_NblA?iNvIF|p7W0wd{64vGUehQ;c_Rb6kGsb4jlvy8;(;)Cjo{ zy8M_NhocN7VKIUrXBy~P`De^?}-G<>(HWQrRwlKRsas%(%N|2C5_7(BWH3jVB zx5y_6sxV}8^mcqn)#vkm`v=8hjl$@0Q;u&n&iO1(mLy!2Z2GK63uQzYxX-!^r; zd{M_c8uRmbwqARjoqNx~Cb3eI!(e-y#ar{qhw^$7=1&;Ak*d;wy)zjp@}I_MObK>( zYwARlSw5j}?C)X&=%Mgi-Rf`0Kb__L?ce{gF%%vq2<-9uzI66L{gxj7H`?v*&;PF^ zT*ZC?zPn0%SY@8D(7l$(ar%C~UOY%*nJ{B506I7`--V!n<3kzC*Vx|Q)Lg<5x_H55 zw4CU2R8u|I4VLaO(Z!qY&RLb6#-KpP471mkI9ec!kf^AlGX}jf`zLbU!xjC`yS>=3 z6@xe2V@vwTveO#8#tjWMCvHM38L)yxzSI>dBvD7+xLTFatewJQ&d%+X7dQV?-}i?2 z&y{S=F^T1s*5jJfj~mhiCN$+plKJZA>d^&rx_it1+L(4!yw#IACRIz>;W7Zs!4K^D zu{mm;*D}~%lObyyPfnfjO%$+XRO)%$9rE2%i1yKM+jrW{aw+OxUHa{y%p!a!HB^&{ z6nF!A?{g356tR6oW>zOtVDVOFtU76FlKTf^TEVxlwj^cPa>Dm#*CbRAdlWQE?^2s7 zhvIIEd6h7RC$0op>2UVm-UQzXYzulJIXh>BA|K0=NA6`L4w!Ce2XrK491C;M4mXh+ zr*EN1sKGNvm>xM92gsOcOps(jxubZT5xmRK9i%R91Mv@%d213w#8x6jU7UU<6 zIRwRnxt&M*b9vn5+su+DW1N^&U%5Nnr@ol!#q{lHcTU+Tw75O+5LJ-0{&*PGDs*tv z$J4EClFw`VvLS3GlL5v|X#$pI1Id7l#xs<` zYFtZm=VOuU3e^+S+8nz=hHDbF5A@MJJ#$b^MD)_m^yRb&THDlLfNjBBw-Z+CrIM%4 zMDMIkN?pHXO6arA>L{_Ap7ETb$g^i_q$uTgl<gO8+-$m7x${!zh=+2tl~bQYO{dK0tQYV5XgvTVY**ucyv8XXDnUHdK=^tLhV?Mv zeM}_>2#FdHxn5A@iw+&3;?))qd@3y_3aWKYW|y`>@lh|%Z{78FBas9*?VKki(+ghC z^mm^t4Mch>x}!!h+D^<(%4F3^Q?7b2HJJym6!qdV*Db?a+la%kOl#IEw+JL z0VN>a3Ah}a@qK9$pKu&!DC2IcCm485Uk)-ITX@@azjoo35WU~~s2mq!5*q(`$$JqZ zm(^4N)Pt>cUPMJlOIMak66y zpMoXsDhy!;?5zI+e2x(TEgKVVf#q|}or`zs1RvS8F0Km3$eqzY_NDgyVMO(_@c^YV z9%v^0gQM??<9x1W<0WJu>e7kwz|!~n>gsXKWUeuDcrD%su;T1etCOBtl_^!)XH1^{ zSS6zHtK{;o{VP$;nwz7&c_JciaiwiwLusn3qeQSt%B#q&=Q#{eIkz&=eU~EcOUG(o zX0b}AG^D&@cfQlyP%<}rQ94yNQs*2vPxXW;vn?;Pvqn?5eE)lOmex&%=uHe8h|Y^0)=3%3Z=k7Ylw9NXIVye_tQ7#q=eIa1b1TPFD}at-&pnAZ7S%>vJ3(F(5wqncaWTtFM~d=#~VHjF&epXwZ7U|X@s89KCya9 z8w&eD$HJVw04jLC#W41H-sfOvYia%cm(8@jWRT8mzA$@<#7;iON=~F3S(r{ z)Gzzh>||@LJ@tF%+9t1t#!MEv)Ts%>!Tw8IFzEpB1(p*%162#ixe(z=A(zuQQQ`P= zul=4Cn%ym>E)8}4uhI9LFp>7KXDV)D3YOZQsi^~Ex%gncc`)^tRZ)N$ zvnEE~lG5O>HI#!`YZ@{?JWGcwWqO_tcLW%9v_-W&bCg!`Z8?cSPUUftgmLH zaSq5s%bZ|vK-}b^U86mb+a36&ytr5T0$$5%R0>c>q517;UVQ#W%dW?C~Bz<QlPq>V9~dQLd=+NX^~@hU*8FhE~FBi&Ijx{d*m@R+O%!XTB<~)3rYT|IM*ME-LWJ?93ac#pVdbS|^;ce- zB@!Y@I(w~O0JZ8#j@IEoJAwhjc@iCbj8pP$AF;?y#*09EXFl~@Q#C94$JJl>!mk>u z(dq4Yo;@Q0@?Ag4j?`Wc_dASOtvkKNuj>QTcLgSfuGKC#H4<>ZR zP7H6JRV}=N$J*qvUBa&*P*ZCJV}rOsW)o;FHA@8#G~1l$>2D<>+k4 zwYXowqjPqyup_>Vmvw4xxmOAMP&V;V8C<81A?<|OYR~<Jb@!Pyv&lFC;hMAFYHe2J2^Ff zs(RLH)tTkWH#d>lhSm`=TCX+HB~jd4jY2$n9D6I4?O#Q+bKjyGG;zF* zCC${ZQMZLgMb;4T&9T+rdFgMHvl}sp*p~$)mw0wjGJ{;2Z z2qsw+vv#+3?Z}FzW5=j3BwX;(2ez@*6dSc>P*vob?o_8UPJr$@*f=fA2NKpQdu43g za$UTv2P0j7%IT>a`-3l_e?FWm6+pXZnQB_mZ(y`Kn zVfQyw*wl6~RU&!G$uM!ikl)p`G@8``>@X@j%P|+MEm!bzv=C|t$7OF0J_5+3Q2N=B zGXuDU%0C9l;|`*sd4m_0808oB7;kTjf79k={NI+VAm?A3_P=A3oKoqz!$PZTvmC=Zz`Pj_4p=o^#zPjct~sQUk#H~-D>uZsU)<+}g0 zAIpnmV z%XgxvG;Rg~R%kYkbAI4(IT~UUEIN!5P`F{b(!aUd&AOm(Myo#Pj-qLKTfevI1pdB> z2^d}S{e$gKbfZgxMo!7vC)XL&CWz2Ab@D?7`$`qH6q4fZ>a`z)?3l=Dyi->6xNb2} z>{3ayb&T>}dFDuud2v~AbRQ%6k3OJArtO#=@0JHH$DpHIw7vp*mS_FmruoM<|D`kk zL`OO*YqtdDrVB5pRJS7!h(?{CvR2V5QRzuBLV&HHe2)h*;3u@x`={pg}_UQ+{m%QGD8r4I}OAq1iciuFD#~b zlb9N?U*rbmgXg$$8WW%6MhYP%aocZfF}YBr0c@xMkKr5Q#lXX2$(5&LU$>@>e*xOT 
z=tYI1JO-Vu{i8AW_ieVf(P@s~oBr71zwD+k_Yn1tAKHE2sK1%3TG!+uGSLC26#xI{QKI3E9^Va7y;w#Ah2Xdv zUye0EpK|3YXLc~|PwDbcg|=5(wr5%p1$xB_cky0!o)Gp5TI#oMXpAKK9V~Q7`4=F| z#Qhfl%PZ^VRqTb@NyX*qv?Nl&=NF)>B=>Lohu_=K`j&ra)F^_6g8ZBHPjpqxA{Ll4 z-(?xhVXZ^hVBq&kHBg8B zx5kqt^i)E_xs(6pq!eImjY^64KF8{mb+9hzar|~&RC%5y(K7>@nA36MB>D+mFsX_f zsr2|s;kunx->j&ICmB1jOmHCUf4|GWqMu&;F97F1BRemM+BRA5a31o#bP%hUN|=7s zncM`eSsqa-4c`4s&#F*ZEW!j!WeO?S(d-$^Ppa`mS6>b2!mv;}e{P(=!CCYdW{AZ? zUN(lE^dx*6mFR9vN=$lz+ZX2p-p%_5diswQ6t-TjGV?$Jdlain;0jI6l$9c<7d0;} zxF<&F-8D7<+ZVikUI8gpYNpRoFAaRv`;uTf^$Yc4eRQ5?m^+o{Lo%|Ld~KUSg4_Q~ z|LgwtRx6v)oB|w7;_5mX5`*TBy-07c?TN}s9 zj6`GiM6K&%e8`EFx}*^W3B&&Sa>W-43R4~ir=s=y-+#JX-W6y$K@wAWIo{3{A_9}M z3l=W)m0Rm-!(Dy#e&D-SBS>@ignDcU$*(}oY5?iwrkM(VK2Vamv-z2ib*-PD`E{Om zo~2_tC^3m4z&^WpiXz~1k~u+ibhGDKW%d#^iw7v=%u%NbMmrnmq!&7H_(}=k2Pgp^Ur)=ahLe?L z$KcW%tVGMu(X{De0x)JU)c_x*s!XX|Lm>ss(2zNI)2A zLRV90;8KUC6Fb8f_{hPF@z;4@AOU(KedT@MKVZgJrk}PWyvtE&a_|K6msA9`z|YO& z<$j`vYWZF-9+QR(4Y{VD7K>Dby^emhJ{*XD#77f7$+-~?miM;QyVsF0K>|&g0VqgU#_y$?5oIHz7FlYLh}-bFYk6G z@br|xL7rte`ZyBg4Lkdm))CE{s)6W_(^fL5m4gwK5Y@7oITxFx zEjt~2Z0ooxRty^aWy=xZQ>(lk6vg3LtW8cQ4ZL-{Ty0*nij5gI>X6g|*)p#`^3wM( zxI%xKApplF?Rkyu#%p3{s7+E5koJBPxLC508%?JDdtN{IXlO~}3v3L*!$49#Eu%Ou zE>wh|Y%(}-WQ3;;9OY|8Z}g19LAg(`s9mJ937Nl9=p~{Er}%g4>C{-krZ@bmbMyvd z8$H1Iqv+f@%uH4SkhGe%@}a_+7^d3Dsh|26J7f}T=uZ#(x70{BCu2fhQlWLdF@I5H zpWSq7;18$Sv6D@7nZU_Fc&Rzlav4U>;O2K)hG&Ln=LxIR>pk0{ZBXZwtDkUUTDyhr zE1x3j`O%vx_2=sME>if(U$;CI-e<=@T-<@R%8GQM3F1!}5|XtDU3`}H{R~C)!)wYO zu0qn-d?dzMLw&X~mDX8>@ml4r@%Ga7YJb}jJW$t0b3D`NZ3^e|zkeYbq};-aNjWV^ z($*X=JB2XUH1gr52m7ZiK3v`kmmz!e8e^+bGJO3|0Zo~|`=e_rq*uj09MYp&C8Y62 zPfA{TST!R1L$VD#HZ1dmSfky&=DXR3J0A)l z9exjrMVf6JlZGs4hy)s#$KMg_U~!n~MkcatRR9SXv<9ESdawx01y zJ4>iOObM5aGNaexZCmeZP2q3Rr(AtiKFOlU8)ytYT41s_k=V5LvRQW}HLpUjm6i^= zOQDsCUpmwZZ9-?Z+~nLDMBic&&5?WExx0KQMYc!-9|4At>Q2|BiOYzK^bbZmy)eXO zM1L?1h#jGM@cpG65~1{ldYgha3|C zo@Nlgig}^khtD$xDxXbjYpQm3s)3HhD-)hQu5oVCz*Gb)EIDt+7H(!aUIxSUdPGB_ z!h;y-W_*9jklOV3$?qm1P>np6Po{t+P!$wE&$6zBulLd)Nj-P@*voj)ww{UYG5PNQ zus~!JzSti2D810{c*79)il{wzkiY*42Y9fLOq(6F*ycqq(XeZ%$0C6F$=mZVGvK~` zYFK5A0y|R-F>NH1WM_6r`|d`p@(JEBzMV)|bbkH>pCarXx3W_{DQy=aBf;1B_REF? 
z6|S%o!x}VGxk|38?gcW>sb`eC7r#7@+ev+Z&`0dZ7L&Q1rz~UZJ%5PBZ*1U{vED=vS+t=m&gb(&|Kg5BBlzvhguRBK{805T%;%OjAHGB_ zXDna??3TQ@)Dt><2T{>%#_CoD!#;OmQ@Y_M5sI`C8oTKM<4a{`qhk8`_E5&!+Ns@P zn&%pw+hA|(xJry^2K*xW_(2O3&Xy8yO1+Uy$7V_57!rUFZ14py!qin5fdRF}cuh5kS5RKFO%aZAT8LdIncy?}4c6~wm l{-KlJ4^|xGBK!hiJq_`;lmN&5-@&{82s8W-&oO@e_#c9B@WKE9 literal 0 HcmV?d00001 diff --git a/python/example/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg b/python/example/nanogpt_4D_finetune/figures/nanoGPT_finetune_4d_forcebf16_val_loss_bf16_200.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5a068f4a70a9e2eef41225f47f557132d41f7ab GIT binary patch literal 30660 zcmeFZ1yo$iwl2JIhd^)-65QP(1Pj64T^e@_q>+sTmjJ=tEhIFa;10ndxVyW<>wVri z`|P{Vx%Zv3$N%m>{(szN)ab?PRn;}CW>w8Ozi;-_&!=VJ*&As&X#fTW1~7;I0Z(&) zB!G;Fh=hoMjD&=Qf`W{Sj*EeghK5dr^BfD8jF^I)jF^;^l7^Lzl8Tv{l$4&2ftj6y zo12@0PC$g8Q<#;DoAcL3U{Fv{(9zHdF)#=@Uy#1w{I_3EtpGMM%rdM$91Ilziwy&Z z4fE6qPyhf70<^ckcKEL^7+5%X1Vkic6jU_m1CVC`EDRhREIb?n0z5qQX+P-a06aDV z&I=9+#OEp|NL0?aoB?q;$keZ@+wfFJk7>9}T>?>1@d*fth-qKa(K9e|^YHTV3kXWS zmXel{m3yP6uA!-=t)pvZZeeL<4YF}{bNBG{^7aY(B{(EB>_d2b!pFp<c3qoJ*B2AEALgRtbKlQ%AaB*tsNqy9;%Cbuj-v=$;qn`-fPpnW|S3l8pD(Z`9 z8OYCXB2H$3H>|zO#uvuD=gs4oX-_lHn|0~<-Lo5g(UQO0WB;g0q%8A9t22s)+qU?e z%tEXtuQ8Cu1cm-4lRHi&#Y13LPe_hBVU8PSTU`_L>zV?bcoT19;+Yg@c3it&AElHC z8m=m?lQ)0)63+E)Psdd%Y~;M|Zaj`<2oWtTOK{N3E};)XUZl%g+1Z%w;jLdMS6WWE znGnl2iE7$hQe9`h@S=Y@i}(x=_B3fb3XksT$t0YBydJOMv0kr_o1wZPr78dRN>;i@ zv|j;mbJOA}zJ3Bah!r*_Gd7F9Y>oyF%YuKiB`I zjT41B4pZY}WFS@!E(|SRK(=kNU}fDZ6T@1t*|;OlB94Dxzp2cCd3t*FPPj+i8W~(i z9fN(T4J=crB>e{}Z*9soXXO0SyL41wA-B5fUMC0CXi!<;T-h*Ke%FP=00a)Now9Jd0t|n+wd99f@j{Vqc7Q?3O@nV4;5KmF>ipW z+|N8>MAm|2v)a3Yp~)U2%zHQ;)HPvp%Ib{i=|Qc1t#pJHEOeTmr@4PZ$aptcS!jRs z5m58Ep(CS$lWdOB2Z8Z-D4ak34(EuX6}VTD{TZ1->E zhXd+$b#>lbjgu9NMEN)Mur%s$(E>dZQ;jXpt^>ueyc`TPwcq~8!KZ(=JiruNoSm6s zNqR{-f;fcN@1>Heh0a@PeHpuo?%5880*3nqsi;}laI3!dB1!wop-J^`(Zc*nILXR|KV`4FWi zdCY)MU0iIRWCl^A4hQ+zMVN5Lmgl!A@Ms9Gv`%=!ISusNtGPS&5BcmhRDuG6qQw!v z)rJFSNg=jbB-#=VcXSEu&-3lmO(J9c!Gh}$46qFNmszn;O#%bcc0IyMdvs>bMDaJX)r!Aq#vRF#r zXxHR3Ani(-WEd$bHnH6;+0<6S;CupROJiCzT1N z{BZBO)@?melkCyE9@E@nsfblacUDG$rd`+m1YqF{h?%S0;4oF?`V05Sp2UPU^R0;y z>Fo(ZARoW%ax~K92M(-DbKQ)W-z(*A8}d(I_-P)Pex9;Fpc{95Yd1QC)2$IFi5b|4MWp0XJ{h<$YqLo?mDg3RsSh<;begr* zO%N|05_u(-DpO6ZkIYt#gP#sM%m@#3eb5_u>sAk5o(MaGD-35)GD$pYqF@mt%a9mw z64yvMNCONrqyTTa(2HJfKP}nOd)igR$@7`Y=AtpH8Y(?{Jy))qPnya_UxTKJDqB`P z?-SUt!P5!TI#Phs)Csk(r5q8A+wmBJsQ^*EK$p(-d)T56vFMD~fk%KxpW1z2Wjj=H z6y){C?!`_Dcds-Zso-)X37gDmsb7LHrzm}!s+VGt$rHnR50)FMrifmAo$ximScn(I zeTkpRP6Ap-eDMkV&2R|%ZDIr$LtKi-h$SZ8?{N7%rz&elIW|?1dnrEM?s_M zANWwi5K@9yej-PuXP%7xC0XLACL#R1ksfaFd~WEu{&hr-!L>C15F{h9-R8MotC>_( zU|6mL(jejUy5hE~x4Gc}{r`{4hOfn{o&o(Vh^=%Elq?v1bl&tK=>@!e)FIk4v{Tp^ zwajo(76#Aj$!4S?g@?S$n^%|gK8%8Myt4_MydNOF}h z0cUdIia80hCDDtq5hObMHTk)Lh+fL$_w}bb{V&qJlY-1XYQd%4+!!bk?P-EnXeYtj zZ>GhLF0!5gio6b^Q}PqfhhViQK=z<4FB{YK5ywF5@s|qM$0tDN2{^eG`;Rp$qr-u% zKitdvW1a2o|7qxd#_0c`N4ZIw((14058`3?0aT#v3HT*h;|a)N(tiS4roi{_W3J`g zjc#!u!*>r>Pe9L}tKanfX~g55*oyRB)y)RM!zH-gpZy8=kNWu!x;Ve&{sai;J^|l{ zuT{S9J$g$$0e)G>+2~^bvHk6T=s+Fc0(h9d<>>1pu4Cg9kf{P*vbv{-!}_DP|6!&p zhOUgx`q}RGpMYSJ!Joo8MTwRGzNjo*d{N(JMVF!mEbNs@Jkx+`|8!{g=#s7T@d zzQFwOvVD|XWPJiYmx6c5Aa3q|yx|6@!53RuyPlPXQ2&e<%PE=?Z>Bi@^et&6&iG4W zkJ2J`*4d7DrR}H`F8ZIfRv3|^aAS#Hupp=P$44Kl&8cwnK%HZVNVaBS6WvA_XvT@VT;S5v}Dp8Sjw6WEw zugm7mf<)tJtW!76Vi$ZwB%Mv*v%{`37mWq3QI(aaimp z08u_uGSb~DI7YbF^$aaPE81fR)hgJt)E54zCZLu52$0YK>I`I9DvCA7(s^df2W+pC z6-OX+HSsGyhnQoL2l@x2o-^XKFc}BLh+D%IG@ZQm=pMrzgD|D-g}r?OOhdTOOvS}% 
zulP0A*N$9&a0RJTp9t9)QXP+sZ@P~kEq?>$dbQf0;2KoKB{Fuk!L?MF>HCT_JLfMf z$Q0OK)z*rc)yHVgt8}04_TCu;5Ha-%g>VJw@sbl{POBYXjOp$qDes&Zluu2GmzxyO zYZ=dST`9`-&=mku7Heq|h1*PcwN)kaHnj?_j$x)az0WN@R}%5QJH2B=%Jg3{A_krA zXlm6?)zv*Wd~@7kL25S4DJ>f(%Z%rdC_w9fwQeQU@&4Hax99O!eFeLv8Ns(xYBzR0 zt_@*<4JwHbbM9^H&N^3MELK()`}I^Z*4x!-weL;xdaXn7)%vnSkmxauRz???Hh33I zYq&3?kTO4qS|qkD-_5nG8?t_m|HMywX&gGZJJ+XYAXl zV0EP=Xrjhx9#7OUvluSQ770e5fa;nZ6G6ko!kYMGc98##AJ9jutZ-0q z6wbW*wJvi>CC-uZPJ58`M+9k{D2&+Yu!x|TFEJ2**EyJqH*~O=RSu#eGFCR^Czf8K zDxNzNT%$u0_V?Pt9g^v(6?DI>u$y%#;RGRwE(o)iMB`ddDcmrl!05%oP$NYer>h_H zR?do+SQmD8GnKFqV~C=qSOiiq8XvZ$NT{~LK-955`fNgugxTIX+pl?z={e2r;?)lD z^*x4vzQ1=qV-U8A>Gx;ZHRSj*epCR7?ti?VYpvr`)#%rIAT5R1D~2k)1O^6j87EygS~PFSliPUyg$8ub$Zu7$z8lYQL`Pd&6?4f zvnCzM%~=Th@B58j_Pmr2mUQZK0?o_Y|2G=mJnFxo14B9+^U$(MR8sqdt zsYa?CA!KaYV^)DIN(!9t^0B@pd9lEDr;F%rthAGiWGayTo#tC@^)&o!UFZecg5a1= za~W0tNh4fhJk`E91Z0nB|HVXfvYR^@r|q^vIU9CBo|LcP*V+xC>R1`7 z@}ZRqSB=KWstcGx-FX8L&dx}OJyUOki{h0_xX+q@A;e7yY5P_*l7(`d*}jq%f`6r^ z+g?28Tkn_#WV(1~qfWnn`7Nk*R=t(o|GoVEp6+x$6R3mZn!E^!0!tlJjGt6 z*wSNQQ8DEcVBie(RnBo?LYHH1fcYojrUspvywRgz216uI&&AjwFEST>Rch8-}pPX@r)}Y39MAzV0kq? zwwF8s`5u#_(|^1%0;50JYkZw-FK}0kUAXCo!iBea*K5}P zVproqx9Uc-1V-nG-=Ws>GW7KO*btXn@Up6Cv1M>1_v;f-!`iZq;Rm&!q1Jj2!^x7q zaCk+5p)^XHhqcr9Elj0#2GIM3 zM@rD_`%uE#nhEZ^qUJBZtkRxd+QrgI30Ue!S{-L&-sWYlqG`XKelCM9?r$iK3PS^M z`pLnaR8d`na|U>-g4{8fI))3F@mDS|Uf0O7DYr{66nAAYYkXNYBZtgi=R|&gah#C= zm8xboVt{ggh!jT<>uwwIwm2IzH~f?>EOlP%>%7c6@;r1wgNy8%NbMi0AmQtXLH!r6 zV3VYX`e^HEDm)jNCt2<8k;-fujlnfCCN#(Rks)fIW8cKZ%IIUyr2`i)b6**eF*+Nq(R|*Q|D&oUTHngbi6=E?w*wR=wCM5j3SIg$E#fZuR)eTtcz)iqbEQ@ zZQ>cYvi=DG-$Fg0W0`w@2k<^q@e{CP{_C3fW3_$?=e0uD^g7!i-drjVNPHz?RqAq!xzo5m)-GUZz2x z*>uBY<(GzG3~=weuBA&nr3dQT%4mPo&1Urh%)FJ9ld~tlhMDqCSk>wVb;maKq{eH! zW^Zs=XIJy=`GtMfod}OZ_NAaMeQxZZ4v&jNBQ*~McTx9UO24gk@#j!O_%9Z=CdO4z zje(46=Hn_;wX;*{S3<|F59X-sd{q2u0lQ2-Wf6D0(VG&UF!GvAPVEw5VK8B4_cX9r zB2j>qftT8AsLgv``{mSbe;sE|#+GaDQ#ve+X4J=+=7qcQ^A-${pm4%~mCft3VjZhg ztyC?Q5hBtg>Ny^sAK&t3{IU*oA(0dyplx$L_a^(CDdFrrp1jJ5u!n9QD5=phr(~fA zXNcRwI;&BdhQZQPB?0mohAyZJumR2bRVEb9fi5|`HET!9m3!)IC*c6|c%?OEo7=ZO zHJWYt)`gIU=uIMZ^RkjhDS3uVGdbRdqM4n~M!Q-ZpxAyG3WqooM%f?Z^=qOt`Hd%- zA1v^i>wLPi9LahY;09d6Xvs_EZdy=*>f)W+R7gtA_lP0?KGHN#yqma=5@-yqA8SQZ zN0LDkC?TZ0Dle_7xzl_#qJC7`+#GMq8`b+tH-c(fDtjRzErYJ;LH#dc>Uo6$VYv|D@1^77Da#Q;z ze63j1O52L$R-zo_94zBWtRm;^({qz(UD+w^Ro@$9t#tv`yO{)IJQ3BN-u85}@b6WU z&^}OFzTL7VobhCpTV8hou{Ja08O-VG@zi|o08u2u>go?TVDsOAJqF@hm#5r^qrMOM zT!N3IZB^$bt&=@ba9X?JO_vcLjt0RV^Q?!He)k#$5g3d&wBGg*MPoO^4;wWFbNubW zYMD2!U4SSXrgocO!4NER#d)tGlb*Rg6+VD}8w`(il{Qt&1SfkUbJF?1Qxgz~>v*ENN)M|&k{U`!wq z5nXMT+RjdplaD&^$bYg-rS}LTg^{neLcRNaOVG5*NFEg1#6jYZKCb)-W-ZVD$ z))5q^SyGypHaqM!0qHncc(PuPy-IW$1varid*07FPmocudk?&b$f(q3aQ1%_(7Pi} zW{(%;w~B3BJZxfHXk^2uEJ->>TlJB*>?^{ij7nBqvSGWT9xrl22Ek6Uexyp2Y|QxZJc|mDZer)urC>(J#K;M=lVT;*2WW=ry|;yq;6rA` zL()-!O%m^I6}}dM96>UdJu|_2TzEWgTs`JTn5j|yjipGI)d|-9eYbB)>hGxVclJ%C zWwj#@Xj43Sk%hFJiSVkR^?jz29ZNuGC>0I_ZVoU+`w~d867bJ_u*C`Jb3gZl*Ib_1 zTsz}O9aRF_F+CqsAlSesF31S(ELr*dggwNXxp&>d_r)(PP}4mW=)%{uf7Z6?-O0fnfN> zU5+y0SSK##v@*52Ss&)uvR3xT9MLB%r(a?w)S`D=?U)diJ93C{Axd%|hx2~+`+RG* zzbWQ%`G0sHey0CZe^^M4yy~4p5n(jVFVNHPdmi0IG5AZB*tS>+c&Q2e$fNsFGrT;$ zHp}G+c%eP;*x&pFWIq->0q+lnRL8l-x@?gPlusr5vqFbU|vMo5-0R2JR{bvrpzY)XegApc{YE^tkYIVkjnCgTbWB zD~cL*a9D9@JW^G$C}(DJ>5c3!O}e%r#s)kAZfgZHD=c_x^~VR0v+HoC1-3_R(G*;d z6aJ_~=x9}AY`gdfu*53y?A9nFv8^GnFM2tuI9^|Uom3vm45x{f)rc8gH~NuPH9N1c zx`d*Di_?co;EBOiic=kAPZiknnI>S^)qpfEzq-bSs*w*t+nl03L!4CPGN9Nd>zTWj zf{g08-TGKXE~vsR5+x=@bOUqyCnffkll01gs`k#5vsYZXz5Ux%$IzM_5H8#%^J(Ol 
zv&mnai;Qlz{cLLr%dI2#!%Ba63cw*m1AQ0EFxSdR86fNdP~{H0mn8-E+|Lhh7H4f# z-mv8OQ#*yI#J1G}*kOCK`Nvs7g^go&8&hoVFYsy$oINp9sA!z?B0dli-UmvPh7y?y zAQk>PE0|cS$8bwFo?50o?dUHP^A}umYT`-L?}xT9hIXB}wA||zgO4}&6z&Mpn3KNh z$zaA5$*R4}Tf?qAjo6jzRx-4OV!HwrXog?wjtf(0=RK3&`9_c>Yr1r9n&OG0_M;$# zkyJ-FAA?$+WoNFn5uD8F4LP%JYfHh~Q-6N`f?oqZ6Gm7|hWXhKGtMs35V|EIcO*08 z$Qa2$@2e7exX?+~mnR-o*=yf`LZ!28**IUy7xSu{FiZ2fO^2jnO9!klMy= z7>daZuTUh-5%k*2FmX_G)r_V{H_4LD^<_Dnl|hCfM8tB20KXUXVmI zA}C|RMjiuj1wQc3`5fiWRJru+rICOLP^65MdMz)nR&G5u_(&z6A&L|!;abq;2rphZ zB}w%C5wSZEM`G^{E(|42yEGqe}t_2T=Yrmz5LJK|uoi7Ur$s zeI{{X+Cc=Kjq@kQwNaR@S*1E~1fd9+amgk#6^5uArMziCGktLR@a{dOCe; zpMY3})ZadUA}a9UUc{}(_QsgUBn1O9*G%LmAQofzA=sVbfjPBN@b4IMxe?AxM27Gi zv&k*?^awxJ{kguEPJ%~o*DBH8yUVv~CG%wxVtBhUkIT3UOIgtjp_Yf*xx^zHXaSiI zEfN$5?|y=}x_Z%XxY$Iz=qlVl0vZO2 zBSt$~NiqeUSOgE}`Cwg$6!$?fVdk@X0z<4ZS|kC8#0@(5&0M9~$}+>UpCIDhG17!F;pN|lZ7F|bzb`dQYRF0SfC*1LwRhbCz7V??v(34ncm4?5Ws)E{GZw$XH?CS@$& zl@s;1tQ?hBu4xK(<@?SaqZrz0?`o}SK*&@s|5Bm95srw|N$2XNz{b021^w=Z<@Kp% zg~Hv&9B_LxNX)IImHLNsL%VwrHHyRZ=Q^WkO{qtUg=UjJjH(%6XDam60_e zS}c5_?6?XgTt6U<_F4U2wwki_y|&Tb@J{OZVdxh=&E(okH1@SP{$2g}siL`U-ZR68LKepM&8~^Ax?!E>=B#h0%XEsanR$3t zlE;fBxoPr_2gO$7ECd7LFF(V2h_jcD9k%t52E(uV)0`O0Z{M<}*mK8zFDTAuqp{OS z;?_(&mZ7E+X*p)>o6eqQX;#Ud)$^)S4Ovoe%pm7*wDC>U&J=(4Np88Af`g9ye(LB0 zcf$p8&0QLM;hEC6=v`Vf|Fw_q!*<<=NZ%Rl<)t4-rl9`j0lCqo`I!91{&33omoY@c zgbSGu^l%@h-3Kb@Iz3bEkD{~QMvuP<@gnv*@YL@grZ-jZNm_Yk|Moqk0x67VV2=V_ zIu@Bnet@s#TOA`(e;mx$&AX<8crhm!ycW!mYOl&OeJ(^k{Mb3Wm}O zqQpjd!7#k*Bz?y1}tkQuXc%}etf=6CT2^9;b^6EWBWyg%3ca1 ziRQig&lvRGy2i{j*9zO%Mad$i%P^<~(QBDtM#(xi$B*u%oLkc@kW&?U?*92X%1;IE zLL)Xy-#){)W^{%3?g!#{hrUvNj_%A1D+&(oGK{ogb9PQxDvicYeyyUDw6ivXwONkN z7#EBFfMo_bJ-MSR>Pp9Wf(mv%i|&*^cAJO0?>-B5d*K={F2*O>EdfB5hZL_gk53m2d+8qTOAtn4 z2JYT6uG2KVa{$QsMPv}u)-qmQy**AGn#yvSntr9{HD)1Y8nF4A1`*!PY=StQF3fJ< zE=TKgb3LE+S2?okK>mZ<+Go8z9*Br2>(T5Oa;tuL_mV3&=RYfxrsT-!bW+*a8RYP~ zwV2WmdQ_(TIan(pS!jELoC`LmHjN$yo|kk4?9%Za)YP1uH8$CiIq=v^0+ps%gf_1> z$SH7H4hv`ZlVo0Tv?|525)1o>(->4o&3%gN5wZOm(IFVTS!GsiJKOBJWbDcU`35PP zS=B*JTx@MhqQpVsC)&{^rx=-*sLq^!`x4}E$SQyHPU*-3#CsvaXC`to>(BMJF}k7B zmiMx;t2@KO*d)NbHX7fIcbq1@HqUJ%{zH7<-LkFH;kq|*AhSdOq_Tl zAvqDmZ2uE7Hm%K*Hu_Q=FHAdF5}nc}lyNhQtVEdX8k|@Q`kq5Z zF&&KsQ)!}cA-Jq9C>bshx0s-89M>ewksz1vq5J>? 
za~WNZc~sy@tQlIaBS?D)arCq&wd6r73M+76NBe$?a?7}x9n`it*CB|Qf|>K>NQ(Uv zn{d{*&-PAQo-(icrM-T5cL6Hnhdf7k^ET6(sdGD;Ny%SkbUQF=T)ac5A(^k1 zGqYTvq^?@O{Q7)9F?8@GW1uG8Jdn_ej&w535q=xn8ef;^M+U+Jq;ONxo4CL0XZ#Nr z(K8&m+Yy8eG419FOD~!ccM-lTi==lgl!fomY06!H_?QI^nd!~j5?e*qEW9*uzHn~{ z(Q#OfpXI*eq`#q2^?!!uTW{)`5Y$$YhK0KABRr-LB9_iFaU1>AIikmyPVVF!8R_qQ zpkqjQ++CNw(8n@WXZD`e9OOtVXe;9^WE?!0=?W0|1Ns0v@WUE*DWr1iar^`^!I?Y3 zFYuhuJT8JSk&!t^PWZkcMdmYJ-QNtO@^5f&GW)l*TAEQ=AtW6f8UiT9#nC*&cf}}I zX1cL|COnG1p!FtAmHm^upQNJom8+%MO+N-yK7R|BzV_}yL= zxAag3*B{7Xm4CePzazD_Z7u{gljDTgwglEl*=HyeKN@v6@43qyE4-416ki#Ah?$a| z9Nx9fFT5I&?}Y;MlK(zEA4=&D_zjKE%m!aSMm%&sn)=h_mi1o+zIqXeq}a8|2&N3Z z85bUsdGxafAJev6hXZv|>SsI@Hy02nI-m3etTL#|ukMo1E~p77W_*C#`K*jkN)Pp& z$*YUEZz~}=lAe6Ckb>Zv5G$I_Q*87cx$Wn$`k07o;^YzPJYDii?Mi+~z(rnUV{Y2N8N*PVR*W~O^Hk}h;W(gk>8bPn!Oro` zC-YvT^mKtur)6TUjkr1DGU}BAlJZA?v8i5!7L#3>J(-}56O8>-%u2^)>&ouE4%XVU z2>)I=jXI9tsL9Vx`DP#_5o@?H#|tzA&XEN}`HglWH6eMN!8Ub(Q;xY$|0$&Ja&9}Bc@)zIB z5rf;UmY3qZn;VA7z@rw~7y`luE5&bNjj%!VRuVKXs?5ku@W*-BdJ;Ku67RLKmt_NL!d+j3Ss3Y5K`?U7>RJC}m>!>z2aI?hU3TSj)#T0~wMP7d)hB>%$Ue(5NS@)&&16y!1mRLZ=WVsbY zd#hgK%)_KDWq1G*1INEb8Z-w#dmR&V{mQ|0{D3#>Rt*oZ-)Y=TC=V4hEV1x?eng`8 zOS0oMhzFOPcEkaKKVEV1jn{U=G|HZAj5{DyS)Fv zq>PkjLwp#QZ*?cHQT?SA>EY!H_3}F08TK^NDGz?esvn$I1MjS5L!se30;_3lQ?&G~ zU^)XGY4}zJB!Y?JTJ~N8oD_T*997DA=&ecGT^!WJJ;aOj?G+0T8FfeMRd`V}gRc@x zWV{TzmpepCwGYz$H(&nH;bt&VoC*Q`b$84kG)$wJ9)jT zZNzcEZQyPH{*w5u-%xpo)9uU-jO*-n;GoN5e<4k_M4=ZUb(n46Fg3%?O=8ftq zxvf)E1++SbIMw;l`Wt#Q<}gXK96!X7N|tPAMU#b{Y=|LI`D0H&S@h}(wbd=G1Wz_H zsy}pUr#UB>Wq8}R%)jRR(`}?!d1eb z72S#b#F@#z=f?L7&N7!$j` zBG5bt70QYIWzy!)>OiYVuC98M+IY?eQd0MV-yR(Od$ZZ7w;`eVy9XpT zx(11RQjg{HbNRudX#rg_?^YJDRJ=0pq{k~$Y0-!>CrPu(fZk<|;V&`|JK&==a1@1k z%pFwQmb#w+Vj$#BW}HVz9nvyh!Fn)`@FEi-j)06ExwTLsd;5xNLMi`eKMTjD{pI$e zqM%*xjSps`)!!wOnW26>Z@D^Y>GmhWOevZpF++%|F|w47#$u&4V$dT=bP`o|;zZI% z4J3n;uA9(P36=XCe$&U|lGOVp-e<#qcu(_tAH@|k#@7<{$o;#qffbQgsGYPI9daSD zj-%C73z3mVFUK-PMl6hZr)wLOC1$0s@FX1pdpwqGv!uwv^F=Wu%f$78cNwK&EbJ&T+p z?GamUjf-z8&x__87_T_8vp-f9>P<3f%KCJOw7GtE z2$}^4-;3oh_?0P))C@3175L8C%C7m6=P|-`8UEIp|De3JsfbY~?zEGcaVGZ8tiF9Y zwXZ#g^uM(27S8X~jrKUZx=dBkP`f#W%oi?4N6!dnD!q3V$G5+C zHF0%!$n8;;Ad39~0?Z2B z&mD(Q;OUB=H9PN?g=BiB{IF*kW?iq|o4?c2`3vofbm6jBlydtNtHs^I8TIw7nDpt6;9DU<61&B4l%z~TUMC}yd~fUvkBXI_tVhIrS%awr@wBsFyj^Ncpy|lQIn|jL0i{NJ)8sFjM~wL{p>(FE(B9^d zx%I%o*A+}CZxS&L44h`DU1FI$+tmyeR?0kDK;-W6)*SK*_(zIgtz;0Ww> z(xCp%sv3Efz%T21JD!${a$Et^K2~H-dS|7Eu+0t4$E~&lDVC{1DXUBZFReAqh0H$2 za_-i{U+pDK?Cv(I&&K`IQMOC$D$|t4Mezeu^w`bV$GKZtL33P~!!|!n&DPTwUR8lP zr*NJ%KjBEY*|$eH(ij!kV&(oNVik&X?#xDi7XIk>>RzsNzCVQHR65w8eByp$Gg-bM$ zz4z>$!p|vX^q_@};qNWn}{eoUB<|Cocl8Uz$dGFdU(l>AOU1G`(JsA&~xH9&fwosu8H>^vC z(1S9N9`03m`_~g;pXhAk>Tw>D2nnWF?(nalIjBk+C!@DMlYSK!H&tg1mwr5bAKGjH zH9XXzd~2b0MgZ{Omi(52iGBW~|XgVl4V^4YP2#~~b8zbCN3&nUhw9Ye}ki_y^MK5PEr z3UcuFq>VA$nr-a==dLc(^KRbNy*8f-DZZ@gH~nI>8q6)lcy+Sdm66doi6O`$$VCKJ zPMr5#9&F8AMfz+q0nXRX4@{a3SRo0rkEoD9lUjLATMwjawV>P0##;qF36?TU|2<3USGF!GIPiYXPp|Wm%gJ;iKT+n6^Eu@Mp_YlhMd&S3mx!f+g>v20 zXMO4X3;xq5Kx~0{f487M_P6bu5%MrS!y+G;Y35tfyyCj+=MSG{44U$!u=yJ93+TKz z#&lXhzQ>dTl>Vgje|Dt*y-~x9Sk>cjE&A}*r+M>voMXX4nm5f4B!iuF_cq5LC0?kf zn`a#`Rjd}N)`wPMi8l~iHIM-83xI!e4;b*sVHGR3mf$zLTg zR`*o;ax9u(>P04(`YKI&_jr;B)&+Y0VKYiN zHW;i+jR`Nx?0?5Vp`kzbL%Gp(Y~DpB#e1(!#J2pmeH&9=-M{SwwSU7tc%*4;Gep)U zJB~E0gCq?8@vQJ2@#vlI?89)a&RV%AJ9iw`HO@NQ^3Thny+xAR>De(k)+!cO7M)17 znSqVorVv6aq@aKs4Z?e|RWv@M>gz97AHW+K54lQNWv~^ule3()weR17KixQ%##LLC zN_s*(dm6r0=fw)jMK@`?Q`av~-mb1^UM{KQr*?Y88R~z2rvI9V#W8FGQxmj{y~@@-NJlaT%3FuFQ0!6Yk`ZTyZxOef&*g*DmgggVVXWA)9E4Ebt|iq-5oo 
zq7Wl!fcP>W1>SD?e(O=En8@lL5lrecM%3rkp;F(Jgu$k0ac6O!Wm(rdj}6|)|HTsy zv>b^|g0)%(#)1$z)Nt6lHL+Iog{yb-_J*2OEsN7Fl(qyOv_K$C!_WKEOzdVsz1}mZ z>0Fpfs@dRHE@(mm|A;7!VO#s$E^1wMw0PeQEi@?y7VSj^scKY3>0`ysdz#tBE2F*J zyC0|F^yU5komR}hX1)BHLOL|*Rsx|$3Duh5-i-0&wG>My^;_OMM$S*wkyNNRkfU4l_PUAoCTzv4lLA~nUrWCz)?HY#xh>3RUgw;X^kgJl^?gQy z8m{fHi1s{kT1pD{UrBrYo#hXsr2_pWCR>zp$84R0LW9bvrIDbCXWsBG$i0=EdQCCt zMs1)fgMmxO29wHV{2rs1Ymv_%r=|DW;G-||>gz&wkpfoG1qeF*h;lk zc-X}`g=v5WboGL+t<9RY(;@GK;Tp~$N_~?18e;$QOffwUZ-NaCdjb*O{9WW&#I*8~ zv)j&4SER!rQBUu_AU^5ybN=Bq&9dlY@*M#_s_|+8MmVcIAVEb#OSRe+;%GnMgCY&i5Y?Nr=DkVCX1UT5U^$0N zxc)VZN(?d&4Z)IMI>sFR@U6ZvuG>VAbOnpVhcwtCOOmqQ1@ciyoSs_$B&pp@#0&XO z>9rYipsCv9_+s|!FHU?uCLbxrGiNID-I92`W^@S6Inx6fUS_r{h=f9Cz5D-`mGiG# zGk8kM(;h#wIjG(L8ao!12_mP0W&Y*o?-EuEp3r%_FR6{^vxfW zB;3P06}Ayck|unRz10}flNh`Bxu)=Di0(9cdp^F3Xf4rLl1*advOFQz_Ez~vRAt(y zPwW^8`|W#EnR%!RN@mMZ zh<&p<_81R?L08EHhgRNLB@5l}{Z%h*59ALYTV9Il`8pmx0U~c)ys8wsXPUKE?tQqy zYb=^EowJPnqdI%ttC~a%c{`6F$I1!kn`t9Q6Dvy&dk{R_V~RAvjYc?&oF`Za{Hej@ zeT#P|nY`{dDPCRt7AKT8+q>d&X%VL^zDFS~=JO}uIc-}!lOLR-CI$B|3f^4c=FXn1 z&eYxQ!UDX~@cSo#y+*Tl-C_S;#W=>@^&RFnOBYt|iLct)-9De9*ODj}%zC#ax`f_N zn9AauDK_b42x`t&u18fp^ttwng2SSH*ngnxF~v{c^SNW}H*ST7_xSEbG)V(pKHnqr z1DZxoGNq+m$PqF@?&cHc@r8klB%+;UGO=6|2Du-Rv4n=>n47p`Pq$35BmC@*h{e8^e zBWCe&YX%v?2DD4bNhAV_26FSr1zdcxN{8AZ|uT;Sq z6!Gn~T`fG=LloxTX>BpAL}^+4W5ID%mgFd9iUZ49is*hAsUPx?oc?qG4@c>>;+SI^ zBgYxJlQ+dMMX+0q&j$oN(leHC%?Y9&-38Giz>wKaV|6^RY39}>DgO(V?AbgkjJZe{ znaDe)?8cw-xoM;ym~CksDkC|G6eVzX&3-%n7z>L$n7NzEa>Y;cEQv09C@L(yRBsMz z;@}5E-uW`w`z@CcsbV?FDA}oguf@Bl{w9U1M)bzPd z!(Ri&#N%FNN)cU2i=-&R$q9KJZpvn1!3&TyF|R6q;8@x@E~Ed9P?&2qgQ<(02+P~V zn^?(}XJ`yG7N|4g2qGDxcEu#fCWM0rm_{!Cll(ov?D`MCEzx()+Rq}Sa-xxsCDcP+ zS`XUtPTw4xKRuKu|27PPl70UVL;CyAVY;u=UM2P4Dq({>4g@UF{#S8d8P(Rdt{oaS z(gLMukpM*s#frO>QY1hL?$8$Z6bV`g(&7yiZE@G)Zbge1*Ayo>1d4mV?6c3=-`)G% z?;HD!d(XW;GFHYr#+qx5xz?O--sgGVCk1-llUon2Fj?WbZaz~C;;xy?v+=912xa|W zb|`tzN5qJ~cwu^IUHoWWVJoQ~z2sxMEZ$j0HewMmdNyVnqi7W0km={JlSBBBb!}9O zuy~NnAl{Z8Wcl&lj*|zxZmPck<6`XD!Z|Z5-8K^jwN7-a!4Z@yroS}}c?v9p6ZTkw;5x(5D%Pmnj z3CvJIeDoJX@E<1A!`d`5lW*wHOzSO>^9?j!zAa{~AyT~#qzbDMii2RQ*=d%#%eI<* z@XPxsE?D)N6wCfx+S|O5YH#d>J}_a*c5-EfB2V!~9TSFX;pF+ni|buC@^fZvlu?A# zd{Qtj-3(5l`s&QW#myHT7V91TLP~?knZ3-w+h(nNTl)x5AA6!8<1-&dAP(WtD z0OTb{Z}m4#&H5>HHHWr*`6$oSMhEL@+9MMzCezZP$R(=1XQrr~meTE)?nw5ZfT`m{ zNyo6k5or8H8;z3{BsP9a!PMv42al=Rdi=fkmV(44^p5g_2MX!EtYyG@*02=`nNqIg z$IOKX-#9$+v02ySDx{sL=+5R>TKf22W8Xabfi0wO6)i)+3PsnuyTl>Wc1h#tKDt~n zrYu2;RBP{)Y^yVFaVBvU>!S7LW+<;e(?RM%y@K}(?6_K_SuRdUPHU$FpY1sCps+!JYd8Wxpvx`kM@9Dob zoCF@QbMA7o-`%g>Bg5JNV%TG>l9M`Vva}s2-9|$zi)ljya=^_hRy@{fyr}ml;LVh$ z93<@e7*nfuEpNFVW+k49Ux#0b+Qg|=C ztuC%1IXAB8PLY#$o>B2cRoDiK+1r<-O=je&cEp~dReG~c&v zQ1s>vO|cU3r_zVtl7OhFmqcxTLk7F9-=0dUe&r?&@PMI)Zzbd}?(}?vQjEV_2-xq94xSCtCrfST z-PS!$hed`O2qgts+iT6M@cSPC6sFd#AxoMb7v23~LxvG-LZv7plBa7&v z{l2sWC6CBT{22;6!npm^SBxSn$fla59QyM4J;mZ{?9Ehp;%W zDRTV9C$fC;@>M}nyr`wO zVlV6La&>-T#47!}X#%sQ$3&=m8Jqn(HCtInwS$o|m-w@uQL{4v7if$BA&m6%L{3VbKBYsbv9ST%iem|*cQ%dVL-(1j><6eMwl=$RbMAeQ#M@c6yOORaKo#@E{ z85hJQo<8p^$K7SYNas`J+-gYzZ7FD?(_1Bdsd({4JXi+p$@czAWoM2|Gu8-cSGu%t zb>sHby*c5_f| za>;?~2uh%*_bsi8(|cWX(8AdSa^;y=!Q-BVJPJEyjD&m3UTwg~^ZV2hUjG zRabJ$AWy#sxrU0)&RT5q(D*62u1o?sAZELKB((8S@+LP|u@~%V12YsReNXgh~k6 z6Q39L@z=-`7qEB!IA9$u>CW*8)~L~Qgfa|&F88;|@jLKG1XXs#^jDkTkQA!K*uqeLrEB2g|q~)FgPh(bnyha^Uy3s{U!X?%n zC$JjL!`K5dmu}}>-R%^G$6c?@b!G8#yN!FDMeDKfqm*mB} zkLAV$lri!{vtJ5@ggvdXoW;wxx@S>is}Zo6mIO!=#+z-`7(?YeK4bQJ7!J+Uzt3nD zZ*#xb$%&N@@?dMn2mXWy_f~xi?TqBxlX=#f=L?o}&yPGsqGw}zyHAzroLVg<-?@$V zW+0aF89q@cW@{iYrHx`6wwGF-r-`xqP0NG%%TGw(0)j+%5@%xodkV^V%XhQ8U_;^5 
zpFD7x%X8<`Jtj|>=uiUG;U3L!*^+_-%GogUv6TP$v>@c{sb%vpdYNC@)_qfhm?}|} z2yiF4(cGbx-h)IDm#JVsEoiOkLv@M~bRtiK{Ej0>?ch&e|0w^F($#zmx#akg zgpFe+_El)CaeFKNWOnt;`>LUMENpIqes#wR~R3-%%VRqbUNKVQTL!jr5@MH-G0tWUJ1Iu@G%+;nV9~{Ql7~wY3b!yA#vL8TX%^v zSkeBG_Oe9y$8)qyH^|yzFCT%fAi1c_Ph{kTh$kXFLtk)F97--~4wyE9yCc^tviX@| z12TVSTySxh=LD-HpPU6vC{}**8^7lg^4;%YD~bs4#>KBVK%?NSR1HdO)1VrBn9$Os zXYo?>(Ulc~kJ3r9u4gg8he9tal{HLDp9#jk(ec24F+eEzAw40A&j0mFxe+#RcPUvlTLNA<$y^R48a9*wF#OlouwD-{kDwa4Qh9RBj6UhJa6 znaHbvVbQYD3lE!EUS^9>r0@Gq$9U@~On?}7j;gl3^4yB!zN3!ZacDy%`aqo4VMxMl z|C70h`UW3(2Y`m;ydItIT_QAK+iidUaP27l<=rE6b&^)$^3@ITKt?Lbtd94_&!<#*4u$8kXI5hag-OvR!pI`}9NPzRyD2t5@YTug1v6{a^*VaOb6Og zq1!=XsAz0i(_9y%h(B91z|7hw-#Z(I2UyP)0pUy*W$4+-lX}$BWXzSPH^k#mb;+Aq zFuxX&itU#sv_2s;b>NIAHCywf0R@e*F0iJ3OkJ(81|lY0LtG)$*0pXIlDSDX57WzQ zvTPpNp!G-m5L&J6@siHhxIj_Re3qx=`;5c|NPSBbu|&9Kgy`w@3j0lUlFi~S(clF~ z&#r}My6N!wsm4G#32&^fY2jUU`f4 zQtadZ%&zDkEB$}r$g!YrXsg4oZ}PO#P!*%vZ{QK^-~2*3j#Es3Qr#&mM8REZe+l<- zTe&tIbn1u1@%x=uj)-dU>ncS+rA zCN*%Thd^WK$kai$?K1Dh3te@s2cyv--#r}?-sm+XjTxPY zI*nrMjg-yRJAlQXXGY=b;f3zi#dg{mQxitRp*=G4gm(RpF|9vM+Q%>xF*bGz7=G*ftf-+77lp6Wwimc*ETG_KPO^50@K}20zyv%y z&pRN)O8B=;=3jU!TI#>vc_P>>jLTx4a)!Mm+(StS_^Jig)&UHhjx^X}4@%2Av*)|JApJj81wOZN?`iE8A@$+7@qrCp-J6I1DE!Um{tqGPKb`S+zx;o7n6SI_8}7WH7{c}AfE5&xa!oI( zz&f;YwILRXs}(k3}4cX`fqR>3?`b5&m^CMN6;#-x^J#kImwK!tse0*|TLwbMfJm|Wr6=-9Ji*K2lBYORVI#GG-^ z9c;Dqw7&n*`G#ViBs0AK_kCR@up7xT(d16?!;W8uJs6>QWZ0#v!H37lNvz&z{`-4j{9pb_kw?I8WEVOO+v~s!-tP~zY0>0&aq(8 zyTzAHhq=Q4EA1X&+B05s82ki;C1X0#xCB3G+)3#k2vKlmLOi%>dxB*W>?yfDxlIDa z+tA#mM^@+s!qz@_pO!G$ki9bMPI(uEb3bBVG@Y>FKU_eRo_fZ^GfVlsGB{Fmf)!G! zXXtGH;hce0M8!P%>ZLQ}R^{>ITU*p6aZ4e4R%4G#s)}?FE0?M>B}}RE&orE(yb4f< z3*;lVU+#?@r|3wkqQ%R-X{DY%hx$lhE2MmfyKRZ3t!_Lz^(7xIE#7^VY!X)^OKaco zcH>f&kE$ccH*k;i39$$F3C99K$Eu#;v`?yjvB~_j?SN=cfAO8L<9`Wp_apU9CU^UBEEK*zkHt; zVTD_;Tz14f`o_DEV!Z0;xfCJE$j;|aUzFjzIV56cK&F0e0$x8H-b6GEN?>C!*P436 z5e$Fa_X;#S!jTI`_C)}RAZlLYUcrG!AXRbmy0P}+Dru__8IHD)y6c%~c}@R=G7RLW7kdf8pKK** zizr$$W-_SfU+9pC=wX5>*PAJ*r&5szl%bAs{8;wqb8yO+QXZ?<)Lu$Z3YGUo{R9m0 zXc4yrD~;#M66owJU}GGSooo_J!eLwzVlA&~ZdxfV*+leM4nY!M>D#|LgF$XU5=2^wv9wb+_jzl4W3SKs;nL7T&-$|0Am+b4@V2%Rb}@ILQjf+2Rq~* zN)%{ic{&R?1a9=n;LZ|T14H+N;&P+|tKW;E@nV#Mbz2BYg?c6h>YcGrZ?ddUTGzt@ zWrsDjMw`^HBLQQsk|tw2$TuuX8R8&=MTM)|-xGilZwE!mtI~+B_!41jX(xs?Rn4Mx zp~p1wMDBa6pF~)Kz_h8?c2B3;k;Xa26BRxUijJE79~=c~LNVs6=8e(KUL}M87hlq` znD#FZpW_%mmP2@l10~!WEJ~vW%PTlI!@$g|Ykl*Os=f?5 z&-G}}+lZ%oA3g{Ju9~FTtduFzTHl6G{`Q^{sm~wq=SGC4nw|#Vacex|l zSsBtEy1%OvU5W~#PXTzMBbVwW)c4`-Zcf$d2O7ww_ybZ-RWkaxsC!W^V~iPf$%bdn zdmWkW-73V}@$tgl%9;dxa{SM}28llHYrAj0GTCqmy=6_|Wx6@JKmXMhQA*979kP~1a;@UP>Lc##kC?v9PFd(T$OqT^_k;Vqo3m7tUK#1 zzUK2NX8+sn!x3}#AxCe^1Z#@9sTQ0#1*uO2&79Ldo3{4~+veam8WzA{>3T_As-oF7 zqO%{2dWjZgu!XveTLU~K)Nq7DxT>%gJ11bf4~NtuRVxn73_nsl#aBNlk!dp}^Zudt z;f+}Ii{J~K2vSvqw_Q$4L)hB)0e_T*^M)C=IZh0~{B>r-#uZn#LbgKAN7nd0(P$XL zyTIanB#>MGV>FhUTMR->yI5Dg0%J)eK9P2>=E$1c+~2=T>c~FtIGq3r{6;cuOKrMY zLqW2cQ0&0zTe8twe|l`5NyDYV*GP< zQsny!q0H28d8`kg^Ytk!xHt!q^#f1KKs`exUX$8kgikh7K}}X^a@1=X_Dmo8@w474 znubRHU}M<#IyW0_zIE6^t(swGV*?`?Wyu&INh)~O;}jSw3q`&5b@U$XU-F@FpdbW& zR=0PUmUu){A)vx1%TYi;6LSbmg-9dOQ?)+Hu^_Z`QLHu{UKE~Pq)9j zRv6#DT0EinwzrIzlXCCnN@+%+!}6V7#!>H|05Ss`q)7@zt0Tjg=oNy&{Pai`9=?v` zWcQUs3HMxMJE_~_RhJyls#%pV3Q%9;{+41tu^4cCigejW3^~?z5wLlKLrq%&F~^R+ zV*!ptF16FM%#{y0A*!$zrq7Dx#!Z`|@5~jMr&_eug)Z^t;uFdGxkcUmHoSN!AGh30 zsqY!xx1SKk%l~<}iM3yn7N6X?WtSk*YZV41+BV@zFsM(1P-`mZ_yo*9S_2c5u6e$# zX@8>Pm-U4TfvbeE<9IgN^up7H3!R8$B&6QPa2eU}nPYZ(uV int: @@ -75,7 +77,9 @@ def reset(self): """ self.params_with_grad = set() self.communication_handle = None + 
self.partial_grad_communication_handle = None self.communication_issued = False + self.partial_grad_communication_issued = False def shard_buffer(self, buffer: torch.Tensor): """ @@ -88,6 +92,23 @@ def shard_buffer(self, buffer: torch.Tensor): ] return sharded_buffer + def all_reduce_partial_grad( + self, partial_main_grad, model_parallel_device_mesh: DeviceMesh, placements: Sequence[Placement] + ): + # wait for the last partial grad all-reduce finish + if self.partial_grad_communication_handle is not None and self.partial_grad_communication_issued: + self.partial_grad_communication_handle.wait() + + # TODO: there may be other invalid cases, we should add more checks here. + partial_mesh_idxes = [i for i, p in enumerate(placements) if p.is_partial()] + assert len(partial_mesh_idxes) == 1, "currently, we only consider a single Partial on the same mesh dim." + model_parallel_pg = model_parallel_device_mesh.get_dim_groups(partial_mesh_idxes[0]) + + self.partial_grad_communication_handle = dist.all_reduce( + partial_main_grad, group=model_parallel_pg, async_op=True + ) + self.partial_grad_communication_issued = True + def start_grad_sync(self): """ Initiates grad sync (all-reduce or reduce-scatter) communication operation @@ -97,6 +118,11 @@ def start_grad_sync(self): communication call. When overlap_grad_reduce is set to False, makes synchronous call. """ + + # We must wait until all partial grad in this bucket is all-reduced. + if self.partial_grad_communication_handle is not None and self.partial_grad_communication_issued: + self.partial_grad_communication_handle.wait() + assert ( self.communication_handle is None and not self.communication_issued ), "Should not have multiple communication calls in flight at once" @@ -164,6 +190,19 @@ def register_grad_ready(self, param: torch.nn.Parameter): if len(self.params_with_grad) == len(self.params): self.start_grad_sync() + def register_partial_grad_ready( + self, + param: torch.nn.Parameter, + model_parallel_device_mesh: DeviceMesh, + placements: Sequence[Placement], + ): + """ + Immediately trigger partial gradient all-reduce in an async way. + """ + assert param in self.params, "Param is not in the bucket" + assert any(p.is_partial() for p in placements), "Param's grad should be partial sharded" + self.all_reduce_partial_grad(param.main_grad, model_parallel_device_mesh, placements) + class GradBuffer: """ @@ -411,3 +450,15 @@ def register_grad_ready(self, param: torch.nn.Parameter): if self.is_last_microbatch: bucket = self.param_to_bucket[param] bucket.register_grad_ready(param) + + def register_partial_grad_ready( + self, + param: torch.nn.Parameter, + model_parallel_device_mesh: DeviceMesh, + placements: Sequence[Placement], + ): + """ + Immediately trigger partial gradient all-reduce in an async way. 
+ """ + bucket = self.param_to_bucket[param] + bucket.register_partial_grad_ready(param, model_parallel_device_mesh, placements) diff --git a/python/vescale/dmodule/_dmodule.py b/python/vescale/dmodule/_dmodule.py index 846bdfc..984cca2 100644 --- a/python/vescale/dmodule/_dmodule.py +++ b/python/vescale/dmodule/_dmodule.py @@ -254,7 +254,9 @@ def _distribute_parameter( # regular intialization if is_sharded: - dt = DTensor.from_local(t, device_mesh, pi.placements, run_check=pi.run_check) + dt = DTensor.from_local( + t, device_mesh, pi.placements, run_check=pi.run_check, support_uneven=pi.support_uneven + ) else: dt = distribute_tensor(t, device_mesh, pi.placements) return nn.Parameter(dt, requires_grad=param.requires_grad) if is_param else dt diff --git a/python/vescale/dmodule/_grad_sync.py b/python/vescale/dmodule/_grad_sync.py index 2d6ae8a..679ebd8 100644 --- a/python/vescale/dmodule/_grad_sync.py +++ b/python/vescale/dmodule/_grad_sync.py @@ -18,10 +18,9 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. ################################################################################ -"""This file handles gradient allreduce for DModule +"""This file handles gradient allreduce for DModule with no DDP NOTE: -- If wrapped by DDP, it is called after DDP.finish_grad_sync() - `generate_grad_sync_list` is not recommended to be placed into a param.grad pre-hook, because: i) having multiple hooks on param.grad complicates the design and debugging ii) gradient accumlation will repeatedly fire param.grad pre-hook, degrading performance @@ -47,21 +46,12 @@ def generate_grad_sync_list(candidate: List[Tuple[str, DTensor]]) -> List[Tuple[ for fqn, param in candidate: assert param.requires_grad assert isinstance(param.data, DTensor) - if hasattr(param, "main_grad"): - if param.main_grad is None: - continue - grad_spec = getattr(param.main_grad, "_spec", None) - assert grad_spec is not None, "DDP's .main_grad must save DTensor .grad's _spec" - placements = grad_spec.placements - fqn += ".main_grad" - grad = param.main_grad - else: - assert hasattr(param, "grad") - if param.grad is None: - continue - placements = param.grad.placements - fqn += ".grad" - grad = param.grad + assert hasattr(param, "grad") + if param.grad is None: + continue + placements = param.grad.placements + fqn += ".grad" + grad = param.grad if any(p.is_partial() for p in placements): grad_sync_list.append((fqn, grad)) return grad_sync_list @@ -122,11 +112,8 @@ def sync_gradients(grad_sync_list: List[Tuple[str, Union[Tensor, DTensor]]], dev # get local tensors to allreduce + get process group to allreduce local_gradients = [] partial_mesh_idxes = set() - for fqn, grad in grad_sync_list: - if fqn.endswith("main_grad"): - local_gradients.append(grad.data) - else: - local_gradients.append(grad._local_tensor) + for _, grad in grad_sync_list: + local_gradients.append(grad._local_tensor) partial_mesh_idxes.update([i for i, p in enumerate(grad._spec.placements) if p.is_partial()]) assert len(partial_mesh_idxes) == 1, "currently, we only consider a single Partial on the same mesh dim." 
partial_pg = device_mesh.get_dim_groups(partial_mesh_idxes.pop()) diff --git a/python/vescale/dmodule/_hook.py b/python/vescale/dmodule/_hook.py index c06bf60..f187ed9 100644 --- a/python/vescale/dmodule/_hook.py +++ b/python/vescale/dmodule/_hook.py @@ -51,7 +51,14 @@ def _convert_by_pi( return x return x.redistribute(device_mesh, pi.placements, async_op=pi.async_op) if isinstance(x, torch.Tensor): - return DTensor.from_local(x, device_mesh, pi.placements, run_check=pi.run_check, async_input=pi.async_op) + return DTensor.from_local( + x, + device_mesh, + pi.placements, + run_check=pi.run_check, + support_uneven=pi.support_uneven, + async_input=pi.async_op, + ) if not raise_err: logging.info("binding a placement %s with a %s obj: %s. The placement is ignored.", pi.placements, type(x), x) return x @@ -198,9 +205,9 @@ def _hook( output_pis: FwdPIs, ): if isinstance(output, Sequence) and isinstance(output_pis, Sequence): - assert len(output) == len( - output_pis - ), f"Mismatched actual output size: {output} vs. plaments size: {output_pis}!" + assert ( + len(output) == len(output_pis) + ), f"Mismatched actual output size: {[x._spec if isinstance(x, DTensor) else x for x in output]} vs. plaments size: {output_pis}!" return [PostHookOutput._convert(o, pi, device_mesh) for o, pi in zip(output, output_pis)] if isinstance(output, DTensor) and output_pis[0] is not None: return PostHookOutput._convert(output, output_pis[0], device_mesh) diff --git a/python/vescale/dmodule/placements_interface.py b/python/vescale/dmodule/placements_interface.py index 009c9df..0b1863e 100644 --- a/python/vescale/dmodule/placements_interface.py +++ b/python/vescale/dmodule/placements_interface.py @@ -34,7 +34,7 @@ class PlacementsInterface: async_op: bool = True # flag for DTensor.redistribute/from_local defer_reshard: bool = False # flag for deferred resharding mode run_check: bool = True # flag for DTensor.from_local - skippable_op: bool = True # flag for DTensor.redistribute # TODO: to enable + support_uneven: bool = True # flag for DTensor.from_local grad: Optional[Sequence[Placement]] = None # the placement to enforce on this tensor.grad @classmethod @@ -43,9 +43,13 @@ def from_placements(cls, placements: Any) -> Any: return placements return cls(placements) - def normalize_placements(self, mesh_ndim: int) -> None: - self.placements = normalize_placements(self.placements, mesh_ndim) - self.grad = normalize_placements(self.grad, mesh_ndim) + def normalize_placements(self, mesh_ndim: int, *, tensor_ndim: int = 0, none_as_replicate: bool = False) -> None: + self.placements = normalize_placements( + self.placements, mesh_ndim, tensor_ndim=tensor_ndim, none_as_replicate=none_as_replicate + ) + self.grad = normalize_placements( + self.grad, mesh_ndim, tensor_ndim=tensor_ndim, none_as_replicate=none_as_replicate + ) def is_none(self) -> bool: """Is it equivalent to `None` placements; diff --git a/python/vescale/dtensor/__init__.py b/python/vescale/dtensor/__init__.py index b657626..65c04d8 100644 --- a/python/vescale/dtensor/__init__.py +++ b/python/vescale/dtensor/__init__.py @@ -54,7 +54,9 @@ def _dtensor_init_helper( device_mesh = device_mesh or mesh_resources.get_current_mesh() device = device_mesh.device_type # get placements - placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, none_as_replicate=True) + placements: Tuple[Placement] = normalize_placements( + placements, device_mesh.ndim, tensor_ndim=len(global_shape), none_as_replicate=True + ) # get local tensor shape local_shape = 
compute_local_shape(global_shape, device_mesh, placements) # initialize the local tensor diff --git a/python/vescale/dtensor/_collective_utils.py b/python/vescale/dtensor/_collective_utils.py index ba52f6b..472b675 100644 --- a/python/vescale/dtensor/_collective_utils.py +++ b/python/vescale/dtensor/_collective_utils.py @@ -63,12 +63,17 @@ def mesh_scatter( Returns: A :class:`Work` object """ + # if rank is not part of mesh, simply return output + if mesh.get_coordinate() is None: + return output + # TODO: Ideally we should use the meta tensor way # (to register a meta kernel for the collective op) # so that it would avoid the communication. Need to # remove the check below once that is done. if output.is_meta: return None + dim_group = mesh.get_dim_groups(mesh_dim) assert isinstance(dim_group, ProcessGroup) # src need to be global rank @@ -107,6 +112,10 @@ def mesh_all_to_all( mesh_dim: int = 0, async_op: bool = False, ) -> Optional[Work]: + # if rank is not part of mesh, simply return None + if mesh.get_coordinate() is None: + return None + dim_group = mesh.get_dim_groups(mesh_dim) assert isinstance(dim_group, ProcessGroup) @@ -155,12 +164,16 @@ def mesh_broadcast( Args: tensor (torch.Tensor): tensor to broadcast. mesh_dim (int, optional): indicate which mesh dimension we want - to scatter on, we by default choose the first rank on the + to broadcast on, we by default choose the first rank on the mesh dimension as source of truth. Returns: A :class:`Tensor` object """ + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + dim_group = mesh.get_dim_groups(mesh_dim) assert isinstance(dim_group, ProcessGroup) # src need to be global rank @@ -190,15 +203,13 @@ def mesh_reduce_scatter( First peform all_reduce on the tensor, then split the tensor at scatter_dim and scatter them to a device mesh dimension. 
""" - my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(dim=mesh_dim) - - if my_coordinate is None: - # if rank is not part of mesh, simply return local_tensor, - # which should be an empty tensor + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: return tensor + # for now, we only support that size at `scatter_dim`` is divisable by # the mesh size at `mesh_dim` + num_chunks = mesh.size(dim=mesh_dim) assert ( tensor.size(scatter_dim) % num_chunks == 0 ), f"tensor size at {scatter_dim} is not divisable by the mesh size at {mesh_dim}" @@ -219,20 +230,16 @@ def mesh_all_gather( all_gather all shards and return a tensor that is replicated on the previously sharded mesh dimension """ - my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(dim=mesh_dim) + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor # for now, we only support that global size at `scatter_dim` is equal with # the multuple of mesh size at `mesh_dim` and local_tensor size at `scatter_dim` + num_chunks = mesh.size(dim=mesh_dim) assert ( tensor.size(scatter_dim) * num_chunks == global_size[scatter_dim] ), f"global tensor size at {scatter_dim} is not equal with the multiply of mesh size at {mesh_dim} and local_tensor size at {scatter_dim}" - - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor - return tensor - tensor = tensor.contiguous() output = funcol.all_gather_tensor(tensor, gather_dim=scatter_dim, group=mesh._dim_group_infos[mesh_dim][1]) return output @@ -244,6 +251,10 @@ def mesh_all_reduce( reduce_op: c10d.ReduceOp.RedOpType, mesh_dim: int, ) -> torch.Tensor: + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + return funcol.all_reduce(tensor, reduceOp=reduce_op.name, group=mesh._dim_group_infos[mesh_dim][1]) diff --git a/python/vescale/dtensor/_dispatch_bypass.py b/python/vescale/dtensor/_dispatch_bypass.py index e665fdd..133500f 100644 --- a/python/vescale/dtensor/_dispatch_bypass.py +++ b/python/vescale/dtensor/_dispatch_bypass.py @@ -15,7 +15,6 @@ from vescale.dtensor.op_schema import ( DTensorSpec, OpInfo, - OpSchema, OutputSharding, ) from vescale.dtensor.placement_types import TensorMeta @@ -99,36 +98,64 @@ def __init__(self): aten._to_copy.default: BypassOpShardingProp.copy_handler, aten._local_scalar_dense.default: BypassOpShardingProp.scalar_handler, aten.equal.default: BypassOpShardingProp.scalar_handler, + aten.nonzero.default: BypassOpShardingProp.nonzero_handler, } def apply(self, op_info: OpInfo) -> bool: is_bypass = op_info.schema.op in self.op_handlers if is_bypass: - op_info.output_sharding = self.op_handlers[op_info.schema.op](op_info.schema) + op_info.output_sharding = self.op_handlers[op_info.schema.op](op_info) return True else: return False @staticmethod - def copy_handler(op_schema: OpSchema) -> OutputSharding: + def nonzero_handler(op_info: OpInfo) -> OutputSharding: + """ + Bypass nonzero because the output shape is dynamic. + We allow only replication on the input/ouput. 
+ """ + op_schema = op_info.schema + input_spec = op_schema.args_schema[0] + all_replicate = all(p.is_replicate() for p in input_spec.placements) + assert all_replicate, "input placement has to be replicate" + input_local = op_info.local_args[0] + output_local = torch.nonzero(input_local) + out_tensor_meta = TensorMeta( + shape=output_local.shape, + stride=output_local.stride(), + dtype=output_local.dtype, + ) + return OutputSharding( + output_spec=DTensorSpec( + mesh=op_info.schema.args_spec[0].mesh, + placements=op_info.schema.args_spec[0].placements, + tensor_meta=out_tensor_meta, + ) + ) + + @staticmethod + def copy_handler(op_info: OpInfo) -> OutputSharding: + op_schema = op_info.schema kwargs = op_schema.gen_fake_kwargs() dtype = kwargs["dtype"] + args_spec0 = op_schema.args_spec[0] out_tensor_meta = TensorMeta( - shape=op_schema.args_spec[0].tensor_meta.shape, - stride=op_schema.args_spec[0].tensor_meta.stride, + shape=args_spec0.tensor_meta.shape, + stride=args_spec0.tensor_meta.stride, dtype=dtype, ) return OutputSharding( output_spec=DTensorSpec( - mesh=op_schema.args_spec[0].mesh, - placements=op_schema.args_spec[0].placements, + mesh=args_spec0.mesh, + placements=args_spec0.placements, tensor_meta=out_tensor_meta, ) ) @staticmethod - def scalar_handler(op_schema: OpSchema) -> OutputSharding: - return OutputSharding(None, [op_schema]) + def scalar_handler(op_info: OpInfo) -> OutputSharding: + return OutputSharding(None, [op_info.schema]) _bypass_op_sharding_prop = BypassOpShardingProp() diff --git a/python/vescale/dtensor/_dispatch_patch.py b/python/vescale/dtensor/_dispatch_patch.py index 6327e66..cd0dd42 100644 --- a/python/vescale/dtensor/_dispatch_patch.py +++ b/python/vescale/dtensor/_dispatch_patch.py @@ -40,8 +40,9 @@ def hack_for_special_op( kwargs: Dict[str, object], ): new_args = list(args) + op_name = str(op_call) if ( - str(op_call) == "aten.index_put.default" + op_name == "aten.index_put.default" and not isinstance(args[2], dtensor.DTensor) and isinstance(args[2], torch.Tensor) and isinstance(args[0], dtensor.DTensor) @@ -51,7 +52,7 @@ def hack_for_special_op( new_args[2] = dtensor.DTensor.from_local(new_args[2], device_mesh, sharding) return tuple(new_args), kwargs elif ( - str(op_call) in ["aten.scatter_.value", "aten.scatter.value", "aten.scatter_.src", "aten.scatter.src"] + op_name in ["aten.scatter_.value", "aten.scatter.value", "aten.scatter_.src", "aten.scatter.src"] and not isinstance(args[0], dtensor.DTensor) and isinstance(args[0], torch.Tensor) and isinstance(args[2], dtensor.DTensor) @@ -59,6 +60,29 @@ def hack_for_special_op( device_mesh = args[2]._spec.mesh new_args[0] = dtensor.DTensor.from_local(new_args[0], device_mesh, [Replicate()]) return tuple(new_args), kwargs + elif ( + str(op_call) == "aten.eq.Tensor" + and not isinstance(args[1], dtensor.DTensor) + and isinstance(args[0], dtensor.DTensor) + and isinstance(args[1], torch.Tensor) + ): + device_mesh = args[0]._spec.mesh + new_args[1] = dtensor.DTensor.from_local(new_args[1], device_mesh, [Replicate()]) + return tuple(new_args), kwargs + # hack to DTensorialize the index of aten.index.Tensor op. 
+ elif op_call in [aten.index.Tensor] and isinstance(args[0], dtensor.DTensor): + device_mesh = args[0]._spec.mesh + new_args = [] + new_args.append(args[0]) + new_args.append( + [ + dtensor.DTensor.from_local(x, device_mesh, [Replicate()], run_check=False) + if isinstance(x, torch.Tensor) and not isinstance(x, dtensor.DTensor) + else x + for x in args[1] + ] + ) + return tuple(new_args), kwargs else: return args, kwargs diff --git a/python/vescale/dtensor/_utils.py b/python/vescale/dtensor/_utils.py index 92d2444..22dada9 100644 --- a/python/vescale/dtensor/_utils.py +++ b/python/vescale/dtensor/_utils.py @@ -9,7 +9,7 @@ ################################################################################ import warnings -from typing import List, Sequence, Tuple, Optional, Dict, Set +from typing import List, Sequence, Tuple, Optional, Dict, Set, Union import torch import torch.distributed._functional_collectives as funcol @@ -131,14 +131,19 @@ def is_same_shape_across_ranks(tensor_shape: ShapeType, device_mesh: DeviceMesh, def gather_local_tensor_shape( - self_local_tensor: torch.Tensor, device_mesh: DeviceMesh, placements: Sequence[Placement], shard_only: bool = True + self_local_tensor: Union[torch.Tensor, torch.Size], + device_mesh: DeviceMesh, + placements: Sequence[Placement], + shard_only: bool = False, ) -> Optional[Dict[int, List[List[int]]]]: """All gather local tensor shapes per mesh dimension. - When `shard_only is True`, all gather only sharded mesh dim.""" + When `shard_only is True`, all gather only sharded mesh dim. Otherwise, all gather all mesh dims.""" if device_mesh.get_coordinate() is None: # if rank is not part of mesh return None - self_local_shape = torch.tensor([list(self_local_tensor.shape)], dtype=torch.int64, device=device_mesh.device_type) + _shape: torch.Size = self_local_tensor if isinstance(self_local_tensor, torch.Size) else self_local_tensor.shape + self_local_shape = torch.tensor([list(_shape)], dtype=torch.int64, device="cpu", pin_memory=True) + self_local_shape = self_local_shape.to(device_mesh.device_type, non_blocking=True) meshdim_localtensor_shape = {} for mesh_dim, place in enumerate(placements): if shard_only and not isinstance(place, (Shard, InterleavedShard)): @@ -153,7 +158,9 @@ def gather_local_tensor_shape( if type(stacked_local_shape) is funcol.AsyncCollectiveTensor: # synchronously wait for any pending collectives to get the result tensor stacked_local_shape = stacked_local_shape.trigger_wait() - stacked_local_shape = stacked_local_shape.elem # type: ignore[attr-defined] + if hasattr(stacked_local_shape, "elem"): + stacked_local_shape = stacked_local_shape.elem # type: ignore[attr-defined] + meshdim_localtensor_shape[mesh_dim] = stacked_local_shape.detach().cpu().tolist() return meshdim_localtensor_shape diff --git a/python/vescale/dtensor/api.py b/python/vescale/dtensor/api.py index c796f64..34bf200 100644 --- a/python/vescale/dtensor/api.py +++ b/python/vescale/dtensor/api.py @@ -43,7 +43,7 @@ def normalize_placements( - placements: Optional[Sequence[Placement]], mesh_ndim: int, none_as_replicate: bool = False + placements: Optional[Sequence[Placement]], mesh_ndim: int, *, tensor_ndim: int = 0, none_as_replicate: bool = False ) -> Optional[Tuple[Placement]]: """ normalize a placements to be valid. 
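`gather_local_tensor_shape` above stages the shape tensor on host memory and moves it to the mesh device with a non-blocking copy. A small sketch of that staging; unlike the diff, it only requests pinned memory when CUDA is available so it stays runnable on CPU-only builds:

```python
import torch

local_shape = torch.Size([4, 8])   # assumed local shape
pin = torch.cuda.is_available()
shape_t = torch.tensor([list(local_shape)], dtype=torch.int64, device="cpu", pin_memory=pin)
shape_t = shape_t.to("cuda" if pin else "cpu", non_blocking=True)
```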
@@ -64,6 +64,9 @@ def normalize_placements( for p in placements: if not isinstance(p, Placement): raise ValueError(f"Unsupported placements = {placements}!") + if isinstance(p, (Shard, InterleavedShard)) and p.dim < 0: + # normalize shard dim to be positive + p.dim += tensor_ndim return tuple(placements) @@ -101,7 +104,8 @@ def forward(ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement] if not async_output and type(local_tensor) is funcol.AsyncCollectiveTensor: # synchronously wait for any pending collectives to get the result tensor local_tensor = local_tensor.trigger_wait() - local_tensor = local_tensor.elem # type: ignore[attr-defined] + if hasattr(local_tensor, "elem"): + local_tensor = local_tensor.elem # type: ignore[attr-defined] # We need to return a fresh Tensor object there as autograd metadata # will be inplaced into it. So we don't want to pollute the Tensor # object stored in the _local_tensor of this DTensor. @@ -145,62 +149,72 @@ def forward( run_check: bool, shape: Optional[torch.Size] = None, stride: Optional[Tuple[int, ...]] = None, + support_uneven: bool = True, async_input: bool = True, ) -> "DTensor": ctx.previous_placement = placements ctx.previous_device_mesh = device_mesh ctx.async_input = async_input - if shape and stride: # use given global shape and stride - tensor_shape, tensor_stride = shape, stride - elif not shape and not stride: # use inferred global shape and stride - if run_check: # support uneven shard - meshdim_localtensor_shape = gather_local_tensor_shape(input, device_mesh, placements) + # infer global shape and stride + if (shape is None) != (stride is None): + raise ValueError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time!", + ) + elif shape and stride: # use given global shape and stride + tensor_shape, tensor_stride = torch.Size(shape), tuple(stride) + elif all( + p.is_replicate() or p.is_partial() for p in placements + ): # for all replicate/partial tensor, infer from local tensor + tensor_shape, tensor_stride = input.shape, input.stride() + else: # infer sharded global shape and stride + if support_uneven: # support uneven shard + meshdim_localtensor_shape = gather_local_tensor_shape(input, device_mesh, placements, shard_only=True) + assert meshdim_localtensor_shape is not None, "Out-of-mesh is impossible to support uneven sharding!" 
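`normalize_placements` above now folds negative shard dims into the valid range using `tensor_ndim`. The arithmetic, as a pure-Python sketch with assumed values:

```python
tensor_ndim = 3     # rank of the local tensor
shard_dim = -1      # Shard(-1) as passed by the user
if shard_dim < 0:
    shard_dim += tensor_ndim
assert shard_dim == 2   # Shard(-1) on a 3-D tensor becomes Shard(2)
```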
global_shape, global_stride = compute_global_tensor_info( input, device_mesh, placements, meshdim_localtensor_shape ) else: # assume even shard global_shape, global_stride = compute_global_tensor_info(input, device_mesh, placements) tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) - else: - raise ValueError( - f"Found shape:{shape}, stride:{stride}.", - "Please pass both shape and stride at the same time.", - ) + # if global rank is not participating in the device mesh, we simply: + # - set the local tensor to an empty tensor + # - set global shape/stride as the global tensor if device_mesh.get_coordinate() is None: - # if the global rank is not participating in the device mesh, we - # simply set the local tensor to an empty tensor - # TODO: set global shape/stride as 0 as well, and simplify code input = input.new_empty(0, requires_grad=input.requires_grad) + # runtime checking for in-mesh ranks elif run_check: - # Assume global tensor_shape/tensor_stride are the same across ranks - # TODO: add assertion for Inferred local shape == actual local shape - # TODO: See if we need to make this run_check logic have a corresponding backward. + # per placement check for idx, placement in enumerate(placements): if placement.is_replicate(): # broadcast rank 0 tensor to all ranks - # only broadcast if run_check is True - input = input.contiguous() - input = mesh_broadcast(input, device_mesh, mesh_dim=idx) + input = mesh_broadcast(input.contiguous(), device_mesh, mesh_dim=idx) elif placement.is_interleaved_shard(): if input.shape[placement.dim] % placement.interleaved_size != 0: raise ValueError( f"Tensor size at dim {placement.dim} is not divisible by {placement.interleaved_size}" ) + # [conservative] global tensor_shape/tensor_stride should be the same across ranks + # meshdim_localtensor_shape = gather_local_tensor_shape( + # tensor_shape, device_mesh, placements, shard_only=False + # ) + # for stacked_local_shape in meshdim_localtensor_shape.values(): + # assert stacked_local_shape.count(stacked_local_shape[0]) == len( + # stacked_local_shape + # ), "The global tensor shape must be the same across ranks!" # We want a fresh Tensor object that shares memory with the input tensor - dist_tensor = DTensor( + return DTensor( input.view_as(input), device_mesh, placements, shape=tensor_shape, dtype=input.dtype, - # requires_grad of the dist tensor depends on if input requires_grad or not requires_grad=input.requires_grad, stride=tensor_stride, ) - return dist_tensor @staticmethod # type: ignore[override] @@ -219,11 +233,12 @@ def backward(ctx, grad_output: "DTensor"): if not async_input and type(local_tensor) is funcol.AsyncCollectiveTensor: # synchronously wait for any pending collectives to get the result tensor local_tensor = local_tensor.trigger_wait() - local_tensor = local_tensor.elem # type: ignore[attr-defined] + if hasattr(local_tensor, "elem"): + local_tensor = local_tensor.elem # type: ignore[attr-defined] # TODO: backward is also differentiable now, add a test # to test higher level gradients. 
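When `support_uneven` is False, the forward above infers the global shape purely from the local shape, assuming every rank holds an equally sized shard. With assumed numbers:

```python
local_shape = [4, 8]        # identical local shape on every rank, tensor dim 0 sharded
mesh_size_on_dim = 4        # mesh size along the mesh dim that shards tensor dim 0
global_shape = list(local_shape)
global_shape[0] *= mesh_size_on_dim
assert global_shape == [16, 8]   # no gather is needed for the even case
```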
- return local_tensor.view_as(local_tensor), None, None, None, None, None, None + return local_tensor.view_as(local_tensor), None, None, None, None, None, None, None def from_local( @@ -234,6 +249,7 @@ def from_local( run_check: bool = True, shape: Optional[torch.Size] = None, stride: Optional[Tuple[int, ...]] = None, + support_uneven: bool = True, async_input: bool = True, ) -> "DTensor": """ @@ -252,58 +268,52 @@ def from_local( `device_mesh` from the first rank of each dimension of the `device_mesh`. Keyword args: - run_check (bool, optional): indicate whether to run check across ranks - to check meta information and data. + run_check (bool, optional): indicate whether to run check across ranks to check meta information and data. If True (default), ensure correctness at cost of extra communication across ranks: - - allgather shapes for uneven sharding - - broadcast data for `Replicate` `placements` (the data on first rank of - the device mesh dimension will be broadcasted to other ranks.) + - broadcast data for `Replicate` `placements` (the data on first rank of + the device mesh dimension will be broadcasted to other ranks.) + - gather global shapes for other `placements` to ensure the same across ranks If False, no correctness guarantee but communication free. shape (torch.Size, optional): the global shape of DTensor. If given, use this as overriding global shape to build DTensor; This is useful when local shape of `local_tensor` are different across the ranks (i.e., uneven sharding). - If not given, `shape` will be inferred either assuming the DTensor is evenly sharded - across ranks or gathering other ranks's shape to build global shape. + If not given, `shape` will be inferred at the cost of communication (see `support_uneven`). stride (tuple[int], optional): the global stride of DTensor. Usage is same as `shape`. - async_input (bool, optional): indicate whether to get async input grad when + `shape` and `stride` must be given togather or not given togather. + support_uneven (bool, optional): indicate whether to support uneven sharding at the cost + of extra communication across ranks. + If True (default), use gather communication to infer global shape (that can be unevenly sharded). + If False, use local shape to infer global shape (that must be evenly sharded). + async_input (bool, optional): indicate whether to get asynchrounous input grad when backwarding `from_local`. 
Returns: A :class:`DTensor` object - Example: + Example of uneven sharding: - # manual given shape and stride (support uneven sharding) - saved_shape, saved_stride = dinput.shape, dinput.stride() - out = dinput.to_local() - dout = from_local(out, mesh, placements, shape=saved_shape, stride=saved_stride) + # manually given shape and stride (support uneven sharding) + >>> saved_shape, saved_stride = dinput.shape, dinput.stride() + >>> out = dinput.to_local() + >>> dout = from_local(out, mesh, placements, shape=saved_shape, stride=saved_stride) - # auto inferred shape and stride (support uneven sharding) - out = dinput.to_local() - dout = from_local(out, mesh, placements) + # auto inferred shape and stride (support uneven sharding) with gather communication overhead + >>> out = dinput.to_local() + >>> dout = from_local(out, mesh, placements) - # manual given shape and stride (support uneven sharding), - # without run_check's extra communication - saved_shape, saved_stride = dinput.shape, dinput.stride() - out = dinput.to_local() - dout = from_local(out, mesh, placements, run_check=False, shape=saved_shape, stride=saved_stride) + # auto inferred shape and stride (only even sharding) without gather communication overhead + >>> out = dinput.to_local() + >>> dout = from_local(out, mesh, placements, support_uneven=False) - # auto inferred shape and stride (only even sharding), - # without run_check's extra communication - out = dinput.to_local() - dout = from_local(out, mesh, placements, run_check=False) - - .. note:: `from_local` is differentiable, the `requires_grad` of the created - `DTensor` object will depend on if `local_tensor` requires_grad or not. + .. note:: + - `from_local` is differentiable + - the `requires_grad` of the created `DTensor` object will depend on if `local_tensor` requires_grad or not. """ assert type(local_tensor) is not DTensor assert type(getattr(local_tensor, "data", None)) is not DTensor - if VESCALE_DISABLE_RUN_CHECK: - run_check = False - # if same shape/dtype, no need to run_check, if not, must allgather # the metadatas to check the size/dtype across ranks # There should be no data communication unless there's replication @@ -317,7 +327,9 @@ def from_local( local_tensor = local_tensor.to(device_type) # validate placements - placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, none_as_replicate=True) + placements: Tuple[Placement] = normalize_placements( + placements, device_mesh.ndim, tensor_ndim=local_tensor.ndim, none_as_replicate=True + ) # TODO: fix later # if any(p.is_partial() for p in placements if p is not None): @@ -328,8 +340,28 @@ def from_local( # `from_local` is differentiable, and the gradient of the dist tensor this function # created should flow back the gradients to the local_tensor, so we call an autograd # function to construct the dist tensor instead. + + if VESCALE_DISABLE_RUN_CHECK: + run_check = False + + if device_mesh.get_coordinate() is None and support_uneven: + warnings.warn( + "Out-of-mesh rank uses `DTensor.from_local` under uneven sharding support, which is impossible!" + " We set `support_uneven` as `False`!" 
+ " If uneven sharding does happen, out-of-mesh rank can only assume even sharding, which disgrees with in-mesh ranks!", + UserWarning, + ) + support_uneven = False + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func - local_tensor, device_mesh, placements, run_check, shape, stride, async_input + local_tensor, + device_mesh, + placements, + run_check, + shape, + stride, + support_uneven, + async_input, ) @@ -363,6 +395,8 @@ def to_local( .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned will depend on if the `DTensor` requires_grad or not. """ + if grad_placements is not None: + grad_placements = normalize_placements(grad_placements, dtensor.mesh.ndim, tensor_ndim=dtensor.ndim) return _ToTorchTensor.apply(dtensor, grad_placements, async_output) @@ -418,7 +452,9 @@ def distribute_tensor( tensor = tensor.to(device_type) # validate placements - placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, none_as_replicate=True) + placements: Tuple[Placement] = normalize_placements( + placements, device_mesh.ndim, tensor_ndim=tensor.ndim, none_as_replicate=True + ) # validate tensor type if isinstance(tensor, DTensor): @@ -437,19 +473,17 @@ def distribute_tensor( ) return tensor - local_tensor = tensor - - # distribute the tensor according to the placements. - for idx, placement in enumerate(placements): - if placement.is_interleaved_shard(): - interleaved_shard = cast(InterleavedShard, placement) - if interleaved_shard.dim < 0: - # normalize interleaved shard placement dim - interleaved_shard.dim += tensor.ndim - my_coordinate = device_mesh.get_coordinate() - # if rank is not part of mesh, we simply return an empty tensor - output = local_tensor.new_empty(0, requires_grad=local_tensor.requires_grad) - if my_coordinate is not None: + my_coordinate = device_mesh.get_coordinate() + # if rank is not part of mesh, we simply create an empty local tensor + if my_coordinate is None: + local_tensor = tensor.new_empty(0, requires_grad=tensor.requires_grad) + else: + local_tensor = tensor + # distribute the tensor according to the placements. + for idx, placement in enumerate(placements): + if placement.is_interleaved_shard(): + interleaved_shard = cast(InterleavedShard, placement) + assert interleaved_shard.dim >= 0 scatter_tensor_list = interleaved_shard._split_tensor( local_tensor, num_chunks=device_mesh.size(idx), contiguous=True ) @@ -457,33 +491,27 @@ def distribute_tensor( mesh_scatter( output=output, scatter_list=scatter_tensor_list, mesh=device_mesh, mesh_dim=idx, async_op=False ) - local_tensor = output - elif placement.is_shard(): - shard = cast(Shard, placement) - if shard.dim < 0: - # normalize shard placement dim - shard.dim += tensor.ndim - local_tensor = _scatter_tensor_by_shard(local_tensor, device_mesh, idx, shard) - elif placement.is_replicate(): - placement = cast(Replicate, placement) - local_tensor = _replicate_tensor(local_tensor, device_mesh, idx) - elif placement.is_partial(): - my_coordinate = device_mesh.get_coordinate() - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor - local_tensor = local_tensor.new_empty(0, requires_grad=local_tensor.requires_grad) - # we zero out all other ranks of the current mesh dim - # and leave only 1 rank (by default, rank 0) have the data, to perform a "zero cost" shard. 
- is_req_grad = local_tensor.requires_grad - local_tensor = local_tensor.contiguous() - if my_coordinate and my_coordinate[idx] != 0: - with torch.no_grad(): - local_tensor.zero_() # inplace memset to zero - local_tensor = local_tensor.requires_grad_(is_req_grad) - else: - raise RuntimeError( - f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" - ) + local_tensor = output + elif placement.is_shard(): + shard = cast(Shard, placement) + assert shard.dim >= 0 + local_tensor = _scatter_tensor_by_shard(local_tensor, device_mesh, idx, shard) + elif placement.is_replicate(): + placement = cast(Replicate, placement) + local_tensor = _replicate_tensor(local_tensor, device_mesh, idx) + elif placement.is_partial(): + # we zero out all other ranks of the current mesh dim + # and leave only 1 rank (by default, rank 0) have the data, to perform a "zero cost" shard. + local_tensor = local_tensor.contiguous() + if my_coordinate[idx] != 0: + is_req_grad = local_tensor.requires_grad + with torch.no_grad(): + local_tensor.zero_() # inplace memset to zero + local_tensor = local_tensor.requires_grad_(is_req_grad) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) assert local_tensor is not None, "distributing a tensor should not be None" # detach the local tensor passed to DTensor since after the construction @@ -518,30 +546,26 @@ def redistribute_dtensor( placements (List[:class:`Placement`], optional): the new placements that describes how to place the DTensor into the DeviceMesh, must have the same number of elements as `device_mesh.ndim`. + async_op (bool, optional): whether this redistribute is asynchronous in communication (for both forward and backward). + - True: the default asynchronous behavior for performance + - False: mostly used for third-party plugin op that doesn't accept asynchronous collective tensor. Returns: A :class:`DTensor` object - .. note:: `redistribute_dtensor` is differentiable. + .. Note:: + - `redistribute_dtensor` is differentiable (i.e., redistribute happen for both forward and backward) + - This redistribute API currently only supports out of place redistribution, i.e. it always create a new DTensor object and leave the original one unchanged. """ - # NOTE: This redistribute API currently only supports out - # of place redistribution, i.e. it always create a new - # DTensor object and leave the original one unchanged. # if device_mesh is not specified, use the current device_mesh device_mesh = device_mesh or dtensor.device_mesh - # raise error if new placements not specified + + # check new placements for not specified if placements is None: raise RuntimeError("placements is needed for redistribute!") - for placement in placements: - if isinstance(placement, (Shard, InterleavedShard)) and placement.dim < 0: - # normalize shard dim to be positive - placement.dim += dtensor.ndim - - # Early return the original DTensor if the placements are the same. 
- if dtensor._spec.placements == placements: - return dtensor + placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, tensor_ndim=dtensor.ndim) return Redistribute.apply(dtensor, device_mesh, placements, async_op) diff --git a/python/vescale/dtensor/dispatch.py b/python/vescale/dtensor/dispatch.py index ce7f1e1..ba21802 100644 --- a/python/vescale/dtensor/dispatch.py +++ b/python/vescale/dtensor/dispatch.py @@ -104,9 +104,15 @@ def unwrap_to_op_info( if isinstance(a, dtensor.DTensor): new_arg.append(a._local_tensor) new_schema.append(a._spec) + if mesh is not None: + if mesh != a.device_mesh: + raise NotImplementedError(f"{op_call}: DTensor does not support cross-mesh operation yet!") + else: + mesh = a.device_mesh else: new_arg.append(a) new_schema.append(a) + args_schema.append(new_schema) local_args.append(new_arg) else: @@ -131,7 +137,7 @@ def unwrap_to_op_info( kwargs_schema[k] = v local_kwargs[k] = v - assert mesh is not None, "found no DeviceMesh from dtensor args for {op_call}!" + assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!" op_info = OpInfo( mesh, diff --git a/python/vescale/dtensor/dtensor.py b/python/vescale/dtensor/dtensor.py index 735ef40..054b8f9 100644 --- a/python/vescale/dtensor/dtensor.py +++ b/python/vescale/dtensor/dtensor.py @@ -9,7 +9,8 @@ ################################################################################ import warnings -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, List, Union +from numbers import Number import torch @@ -187,11 +188,13 @@ def from_local( run_check: bool = True, shape: Optional[torch.Size] = None, stride: Optional[Tuple[int, ...]] = None, + support_uneven: bool = True, async_input: bool = True, ) -> "DTensor": # we have to do this to avoid circle import. from vescale.dtensor.api import from_local + # TODO: moving impl code here for performance, as here is on the critial path but api function is less used return from_local( local_tensor, device_mesh, @@ -199,6 +202,7 @@ def from_local( run_check=run_check, shape=shape, stride=stride, + support_uneven=support_uneven, async_input=async_input, ) @@ -210,6 +214,7 @@ def to_local( ) -> torch.Tensor: from vescale.dtensor.api import to_local + # TODO: moving impl code here for performance, as here is on the critial path but api function is NEVER used return to_local(self, grad_placements=grad_placements, async_output=async_output) def redistribute( @@ -220,6 +225,7 @@ def redistribute( ) -> "DTensor": from vescale.dtensor.api import redistribute_dtensor + # TODO: moving impl code here for performance, as here is on the critial path but api function is rarely used return redistribute_dtensor(self, device_mesh=device_mesh, placements=placements, async_op=async_op) def requires_grad_(self, mode=True): @@ -229,3 +235,15 @@ def requires_grad_(self, mode=True): def retain_grad(self) -> None: self._local_tensor.retain_grad() return super().retain_grad() + + def tolist(self) -> Union[List, Number]: + """ + Returns the dtensor as a (nested) list. + For scalars, a standard Python number is returned, just like with item(). + Tensors are automatically moved to the CPU first if necessary. + + Note: + - This operation is not differentiable. + - This operation is not dispatched but a torch function. 
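`DTensor.tolist` above simply forwards to `self._local_tensor.tolist()`, so the result reflects the local shard, not the global tensor, and follows the usual `torch.Tensor.tolist` semantics:

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
assert t.tolist() == [[1, 2], [3, 4]]   # nested Python lists
assert torch.tensor(7).tolist() == 7    # scalars collapse to plain Python numbers
```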
+ """ + return self._local_tensor.tolist() diff --git a/python/vescale/dtensor/ops/embedding_ops.py b/python/vescale/dtensor/ops/embedding_ops.py index a464225..b3c6506 100644 --- a/python/vescale/dtensor/ops/embedding_ops.py +++ b/python/vescale/dtensor/ops/embedding_ops.py @@ -9,10 +9,17 @@ ################################################################################ # implement matrix related ops for distributed tensor +import copy + import torch from vescale.dtensor.op_schema import OpSchema, OutputSharding -from vescale.dtensor.ops.utils import register_prop_rule +from vescale.dtensor.ops.utils import ( + register_prop_rule, + is_tensor_all_replicate, + is_tensor_all_replicate_except_sharded_at_dim, + is_tensor_partial, +) from vescale.dtensor.placement_types import DTensorSpec, Partial, Replicate, Shard aten = torch.ops.aten @@ -65,20 +72,57 @@ def embedding_dense_backward_rules(op_schema: OpSchema) -> OutputSharding: grad_output, indices = op_schema.args_schema[:2] assert isinstance(grad_output, DTensorSpec) assert isinstance(indices, DTensorSpec) - if grad_output.placements == indices.placements: - # The embedding table is replicated, and input/oupput activations are - # sharded. In this case, gradients for the embedding table should be - # Partial. - return OutputSharding(output_spec=DTensorSpec(mesh=indices.mesh, placements=(Partial(),))) - elif grad_output.placements == [Partial()] and indices.placements == [Replicate()]: - # The embedding table is replicated and the indices is also replicated - # (local is a more precise term). This is postional embedding. In this - # case, gradients for the embmedding table should be Partial. - return OutputSharding(output_spec=DTensorSpec(mesh=indices.mesh, placements=(Partial(),))) - elif all(placement.is_replicate() for placement in indices.placements): - # BWD for colwise sharding case - return OutputSharding(output_spec=DTensorSpec(mesh=indices.mesh, placements=(Shard(1),))) - else: - raise NotImplementedError( - "Unsupported embedding dense backward schema:\n" f"grad_output - {grad_output}\n" f"indices - {indices}" - ) + + mesh = grad_output.mesh + + # Situation 1: All replicate + if is_tensor_all_replicate(grad_output) and is_tensor_all_replicate(indices): + return OutputSharding(output_spec=DTensorSpec(mesh=mesh, placements=tuple([Replicate()] * mesh.ndim))) + + # Situation 2: Colwise sharding + if is_tensor_all_replicate_except_sharded_at_dim( + spec=grad_output, tensor_dim=grad_output.ndim - 1 + ) and is_tensor_all_replicate(indices): + result_placements = [] + for p in grad_output.placements: + if p.is_shard(): + tmp_p = copy.deepcopy(p) + tmp_p.dim = 1 + result_placements.append(tmp_p) + else: + result_placements.append(p) + return OutputSharding(output_spec=DTensorSpec(mesh=mesh, placements=tuple(result_placements))) + + # Situation 3: Sharded on dims other than hidden dim + sharded_on_no_hidden_flag = False + sharded_on_no_hidden_mesh_dims = [] + for mesh_idx, idx_p in enumerate(indices.placements): + grad_out_p = grad_output.placements[mesh_idx] + if idx_p.is_partial() or grad_out_p.is_partial(): + sharded_on_no_hidden_flag = False + break + if idx_p.is_replicate() or grad_out_p.is_replicate(): + continue + if idx_p != grad_out_p: + sharded_on_no_hidden_flag = False + break + sharded_on_no_hidden_flag = True + sharded_on_no_hidden_mesh_dims.append(mesh_idx) + + if sharded_on_no_hidden_flag: + result_placements = [Replicate()] * mesh.ndim + for mesh_idx in sharded_on_no_hidden_mesh_dims: + result_placements[mesh_idx] = 
Partial() + return OutputSharding(output_spec=DTensorSpec(mesh=mesh, placements=tuple(result_placements))) + + # Situation 4: grad_output is partial, but indices is replicate + if ( + is_tensor_all_replicate(indices) + and is_tensor_partial(grad_output) + and not any(p.is_shard() for p in grad_output.placements) + ): + return OutputSharding(output_spec=grad_output) + + raise NotImplementedError( + "Unsupported embedding dense backward schema:\n" f"grad_output - {grad_output}\n" f"indices - {indices}" + ) diff --git a/python/vescale/dtensor/ops/math_ops.py b/python/vescale/dtensor/ops/math_ops.py index c0af1f0..78a4ff4 100644 --- a/python/vescale/dtensor/ops/math_ops.py +++ b/python/vescale/dtensor/ops/math_ops.py @@ -15,7 +15,6 @@ from vescale.dtensor import DeviceMesh from vescale.dtensor.op_schema import OpSchema, OutputSharding, RuntimeSchemaInfo, OpStrategy, PlacementStrategy -from vescale.dtensor.ops.common_rules import pointwise_rule from vescale.dtensor.ops.utils import ( as_list, generate_redistribute_costs, @@ -190,6 +189,36 @@ def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate ) +@register_op_strategy([aten.mse_loss.default]) +def mse_loss_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + input_strategy, label_strategy = op_schema.args_schema + assert isinstance(input_strategy, OpStrategy) + assert input_strategy.strategies[0].output_spec.placements == label_strategy.strategies[0].output_spec.placements + reduce_dims = list(range(input_strategy.output_ndim)) + out = common_reduction_strategy( + mesh, + input_strategy, + reduce_dims, + keep_dim=False, + reduction_linear=True, + reduction_op=c10d.ReduceOp.AVG, + ) + input_spec = out.strategies[0].input_specs[0] + label_spec = label_strategy.strategies[0].output_spec + out.strategies[0].input_specs = (input_spec, label_spec) + return out + + +@register_op_strategy([aten.mse_loss_backward.default]) +def mse_loss_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + grad_strategy, lhs_strategy, rhs_strategy, _ = op_schema.args_schema + + if any(placement.is_partial() for placement in grad_strategy.strategies[0].output_spec.placements): + raise RuntimeError("MSE meet backward with partial, if that you need to call allreduce before backward") + + return OpStrategy([PlacementStrategy(output_spec=lhs_strategy.strategies[0].output_spec)]) + + @register_op_strategy([aten.argmax.default, aten.argmin.default], schema_info=RuntimeSchemaInfo(1)) def arg_max_min(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: args_schema = op_schema.args_schema @@ -320,7 +349,9 @@ def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding: out_dim_map = out_spec.dim_map if softmax_dim < len(grad_out_dim_map) and (grad_out_dim_map[softmax_dim] >= 0 or out_dim_map[softmax_dim] >= 0): raise RuntimeError("Cannot run _softmax_backward_data on sharding dimension!") - return pointwise_rule(op_schema) + # (Hongyu), support P on grad_out_spec + # return pointwise_rule(op_schema) + return OutputSharding(grad_out_spec) @register_op_strategy( @@ -468,10 +499,11 @@ def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding: sharded_input_mesh_dims[input_dim] = [] sharded_input_mesh_dims[input_dim].append(i) assert len(sharded_input_mesh_dims) == 1, "input of layernorm must be sharded along only one dim" - param_grad_placements = [Replicate()] * weight.mesh.ndim + param_grad_placements = [Replicate()] * weight.mesh.ndim if weight is not None else None sharded_input_dim = 
list(sharded_input_mesh_dims.keys())[0] - for mesh_dim in sharded_input_mesh_dims[sharded_input_dim]: - param_grad_placements[mesh_dim] = Partial() + if param_grad_placements is not None: + for mesh_dim in sharded_input_mesh_dims[sharded_input_dim]: + param_grad_placements[mesh_dim] = Partial() weight_grad = ( DTensorSpec( diff --git a/python/vescale/dtensor/ops/pointwise_ops.py b/python/vescale/dtensor/ops/pointwise_ops.py index a9c0484..e5de8a8 100644 --- a/python/vescale/dtensor/ops/pointwise_ops.py +++ b/python/vescale/dtensor/ops/pointwise_ops.py @@ -55,6 +55,8 @@ aten.to.dtype, aten.add.Tensor, aten.add_.Tensor, + aten.neg.default, + aten.neg_.default, ] pointwise_ops = [ @@ -298,9 +300,7 @@ aten.nan_to_num.out, aten.nan_to_num_.default, aten.ne.Scalar, - aten.neg.default, aten.neg.out, - aten.neg_.default, aten.nextafter.default, aten.nextafter.out, aten.nextafter_.default, @@ -323,6 +323,7 @@ aten.rad2deg.out, aten.rad2deg_.default, aten.relu.default, + aten.triu.default, aten.relu_.default, aten.remainder.Scalar, aten.remainder.Scalar_Tensor, @@ -406,6 +407,16 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False) -> StrategyType: + # (Hongyu): allow pointwise P mul/div R + if op_schema.op in [aten.mul.Tensor, aten.div.Tensor]: + placements_a = op_schema.args_schema[0].strategies[0].output_spec.placements + if isinstance(op_schema.args_schema[1], float): + linearity = True + elif isinstance(op_schema.args_schema[1], OpStrategy): + spec_b = op_schema.args_schema[1].strategies[0].output_spec + if len(placements_a) == 1 and placements_a[0].is_partial() and spec_b.is_replicated(): + linearity = True + max_shards_strategy_index = -1 max_shards = -1 # handle broadcasting diff --git a/python/vescale/dtensor/ops/random_ops.py b/python/vescale/dtensor/ops/random_ops.py index 7dc9ad5..144bf1b 100644 --- a/python/vescale/dtensor/ops/random_ops.py +++ b/python/vescale/dtensor/ops/random_ops.py @@ -17,7 +17,7 @@ aten = torch.ops.aten -@register_op_strategy([aten.normal_.default, aten.uniform_.default, aten.native_dropout.default]) +@register_op_strategy([aten.normal_.default, aten.uniform_.default]) def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: self_strategy = op_schema.args_schema[0] assert isinstance(self_strategy, OpStrategy) @@ -31,3 +31,17 @@ def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: random_strategy.strategies.append(PlacementStrategy(output_spec=arg_spec)) return random_strategy + + +# (Hongyu) allow partial placements for dropout +@register_op_strategy(aten.native_dropout.default) +def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + random_strategy = OpStrategy([]) + for arg_strategy in self_strategy.strategies: + arg_spec = arg_strategy.output_spec + random_strategy.strategies.append(PlacementStrategy(output_spec=arg_spec)) + + return random_strategy diff --git a/python/vescale/dtensor/ops/tensor_ops.py b/python/vescale/dtensor/ops/tensor_ops.py index 3751bae..5022a79 100644 --- a/python/vescale/dtensor/ops/tensor_ops.py +++ b/python/vescale/dtensor/ops/tensor_ops.py @@ -25,7 +25,6 @@ from vescale.dtensor._diff import EnablePartialMode from vescale.dtensor.ops.common_rules import pointwise_rule from vescale.dtensor.ops.utils import ( - generate_redistribute_costs, is_tensor_dim_sharded, is_tensor_partial, normalize_dim, @@ -38,6 +37,7 @@ Partial, Placement, 
Replicate, + InterleavedShard, Shard, TensorMeta, ) @@ -52,6 +52,7 @@ aten.clone.default, aten.contiguous.default, aten.copy_.default, + aten.cumsum.default, aten.detach.default, aten.equal.default, aten.fill_.Scalar, @@ -323,51 +324,58 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType return OpStrategy([PlacementStrategy(replicate_spec)]) -@register_op_strategy([aten.select.int]) -def index_select(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """Only allow replication on the input/ouput""" - ( - input, - dim, - _, - ) = op_schema.args_schema - dim = normalize_dim(input.output_ndim, dim) - input_spec: DTensorSpec = input.strategies[0].output_spec - has_shard_on_dim = any(placement.is_shard(dim=dim) for placement in input_spec.placements) - assert not has_shard_on_dim, "currently not support shard on select dim" - new_placements = input_spec.placements - for pm in new_placements: - if isinstance(pm, Shard) and pm.dim > dim: - pm.dim -= 1 - output_spec = DTensorSpec(mesh, tuple(new_placements)) - return OpStrategy([PlacementStrategy(output_spec)]) +@register_prop_rule(aten.select.int) +def _prop_select(op_schema: OpSchema) -> OutputSharding: + tensor, dim = op_schema.args_schema[:2] + assert isinstance(tensor, DTensorSpec) + assert isinstance(dim, int) + placements: Sequence[Placement] = tensor.placements + assert all(not p.is_shard(dim) for p in placements), "DTensor does not support select on sharded dimension." + + # select will remove one dimension, decrement dim of Shard placements by 1 + # if they are larger than dim. + new_placements: List[Placement] = [] + for p in placements: + # Using isinstance instead of is_shard so that mypy won't complain + # about accessing dim attribute. + if isinstance(p, Shard) and p.dim > dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + + return OutputSharding(output_spec=DTensorSpec(mesh=tensor.mesh, placements=tuple(new_placements))) + + +@register_prop_rule(aten.gather.default, schema_info=RuntimeSchemaInfo(1)) +def prop_gather(op_schema: OpSchema) -> OutputSharding: + values_spec, dim, indices_spec = op_schema.args_schema + + assert isinstance(values_spec, DTensorSpec) + assert isinstance(dim, int) + assert isinstance(indices_spec, DTensorSpec) + + return OutputSharding( + output_spec=DTensorSpec( + mesh=values_spec.mesh, + placements=values_spec.placements, + ), + ) @register_op_strategy([aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src]) def scatter_value(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: value, _, index, src = op_schema.args_schema - value_target = DTensorSpec(mesh, [Replicate()], value.strategies[0].output_spec.tensor_meta) - index_target = DTensorSpec(mesh, [Replicate()], index.strategies[0].output_spec.tensor_meta) - src_target = ( - DTensorSpec(mesh, [Replicate()], src.strategies[0].output_spec.tensor_meta) - if isinstance(src, OpStrategy) - else src - ) - - redistribute_value_costs = [] - # TODO: change to vescale stype redistribution - redistribute_value_costs.append(generate_redistribute_costs(value, value_target)) - redistribute_index_costs = [] - # TODO: change to vescale stype redistribution - redistribute_index_costs.append(generate_redistribute_costs(index, index_target)) - redistribute_costs = [[x + y for x, y in zip(redistribute_value_costs[0], redistribute_index_costs[0])]] if isinstance(src, OpStrategy): - redistribute_src_costs = [] - # TODO: change to vescale stype redistribution - 
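`_prop_select` above shifts shard dims because `select` removes one tensor dimension: any `Shard(d)` with `d` greater than the selected dim becomes `Shard(d - 1)`. The bookkeeping as a plain-int sketch (`shift_shard_dims` is illustrative only):

```python
def shift_shard_dims(shard_dims, selected_dim):
    # decrement every sharded dim that sits after the removed dimension
    return [d - 1 if d > selected_dim else d for d in shard_dims]

assert shift_shard_dims([0, 2, 3], selected_dim=1) == [0, 1, 2]
```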
redistribute_src_costs.append(generate_redistribute_costs(src, src_target)) - redistribute_costs = [[x + y for x, y in zip(redistribute_costs[0], redistribute_src_costs[0])]] + src_target = src.strategies[0].output_spec + else: + src_target = src + value_target = value.strategies[0].output_spec + index_target = index.strategies[0].output_spec - output_spec = DTensorSpec(mesh=mesh, placements=[Replicate()]) + if isinstance(src, OpStrategy): + output_spec = DTensorSpec(mesh=mesh, placements=src_target.placements) + else: + output_spec = DTensorSpec(mesh=mesh, placements=[Replicate()]) input_specs = [value_target, index_target] if isinstance(src, OpStrategy): input_specs.append(src_target) @@ -376,7 +384,6 @@ def scatter_value(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: PlacementStrategy( output_spec=output_spec, input_specs=input_specs, - redistribute_cost=redistribute_costs, ) ] ) @@ -385,7 +392,7 @@ def scatter_value(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: @register_op_strategy([aten.index_put_.default, aten.index_put.default]) def index_put(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: """Set the Output with the index sharding""" - (input, index_list, value) = op_schema.args_schema + value = op_schema.args_schema[2] value_spec: DTensorSpec = value.strategies[0].output_spec output_spec = DTensorSpec(mesh, tuple(value_spec.placements)) @@ -593,7 +600,7 @@ def place(vp: Placement, ip: Placement) -> Placement: return result -@register_prop_rule(aten.cat.default, schema_info=RuntimeSchemaInfo(1, needs_pytree=True)) +@register_prop_rule([aten.cat.default, aten.stack.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)) def cat_rule(op_schema: OpSchema) -> OutputSharding: # torch.cat requires all tensors must either have the same shape (except # in the concatenating dimension) or be "empty". "Empty" here strictly means @@ -637,7 +644,9 @@ def is_empty(spec: DTensorSpec) -> bool: need_reshard = False tensor_list_specs_after: List[DTensorSpec] = [] for spec in tensor_list_specs: - if not is_empty(spec) and (is_tensor_dim_sharded(spec, dim=dim) or is_tensor_partial(spec)): + if not is_empty(spec) and ( + is_tensor_dim_sharded(spec, dim=dim) + ): # Hongyu: allow torch.cat DTensors with Partial placements need_reshard = True tensor_list_specs_after.append( DTensorSpec( @@ -782,6 +791,98 @@ def size_split(N, i): return OutputSharding(output_spec_list) +@register_prop_rule([aten.unbind.int], schema_info=RuntimeSchemaInfo(1)) +def unbind_rule(op_schema: OpSchema) -> OutputSharding: + output_spec_list: List[DTensorSpec] = [] + input_spec = cast(DTensorSpec, op_schema.args_schema[0]) + ndim = input_spec.ndim + dim = cast(int, op_schema.args_schema[1]) + dim = normalize_dim(dim, ndim) + + # TODO: tensor to unbind cannot have Partial + # in its placements for now. Will need to + # support in future. 
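The `unbind` rule above emits one output spec per slice because unbinding along `dim` produces `input_spec.shape[dim]` tensors, each with that dimension removed. Local torch semantics:

```python
import torch

x = torch.randn(3, 4, 5)
outs = torch.unbind(x, dim=1)
assert len(outs) == 4            # one tensor per index along dim 1
assert outs[0].shape == (3, 5)   # the unbound dim disappears
```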
+ if input_spec.sums: + raise NotImplementedError( + f"splitting distributed tensor with " f"Partial placement is not implemented!\n" f"DTensorSpec={input_spec}" + ) + + # TODO: just like slice op, unbind replicates before + # splitting on a sharded dimension + need_reshard = False + if is_tensor_dim_sharded(input_spec, dim=dim): + need_reshard = True + input_spec = DTensorSpec( + mesh=input_spec.mesh, + placements=unshard_tensor_dim(input_spec.placements, dim=dim), + tensor_meta=input_spec.tensor_meta, + ) + + if need_reshard: + return OutputSharding( + None, + schema_suggestions=[ + OpSchema( + op=op_schema.op, + args_schema=(input_spec,) + op_schema.args_schema[1:], + kwargs_schema=op_schema.kwargs_schema, + ), + ], + ) + + # we calculate output placements here. + output_placements = [] + for p in input_spec.placements: + if p.is_shard(): + sharded_dim = normalize_dim(p.dim, ndim) + if sharded_dim < dim: + output_placements.append(p) + else: + if isinstance(p, InterleavedShard): + output_placements.append(InterleavedShard(sharded_dim - 1, p.interleaved_size)) + else: + output_placements.append(Shard(sharded_dim - 1)) + else: + output_placements.append(p) + + output_size_list = input_spec.shape[dim] + output_spec_list = [ + DTensorSpec( + mesh=input_spec.mesh, + placements=tuple(output_placements), + ) + for _ in range(output_size_list) + ] + return OutputSharding(output_spec_list) + + +@register_prop_rule([aten.index_add.default, aten.index_add_.default], schema_info=RuntimeSchemaInfo(1)) +def index_add_rule(op_schema: OpSchema) -> OutputSharding: + input_spec = cast(DTensorSpec, op_schema.args_schema[0]) + ndim = input_spec.ndim + dim = cast(int, op_schema.args_schema[1]) + dim = normalize_dim(dim, ndim) + index_spec = cast(DTensorSpec, op_schema.args_schema[2]) + src_spec = cast(DTensorSpec, op_schema.args_schema[3]) + + if not index_spec.is_replicated(): + raise RuntimeError("index must be replicate for index_add op") + + if src_spec.sums or input_spec.sums: + # TODO(wjw): maybe we should allow partial here. 
+ raise NotImplementedError("src and input can not be partial for index_add op") + + if src_spec.ndim != input_spec.ndim: + raise RuntimeError("invalid index_add op detected") + + assert not is_tensor_dim_sharded(input_spec, dim) and not is_tensor_dim_sharded( + src_spec, dim + ), "src or input can not be sharded on the index dim for adding" + for input_p, src_p in zip(input_spec.placements, src_spec.placements): + assert input_p == src_p, "src and input should be samley sharded on dims other than the index dim" + return OutputSharding(input_spec) + + @register_prop_rule(aten.alias.default) def _prop_aten_alias(op_schema: OpSchema) -> OutputSharding: output_spec = cast(DTensorSpec, op_schema.args_schema[0]) diff --git a/python/vescale/dtensor/ops/utils.py b/python/vescale/dtensor/ops/utils.py index d16d382..9398c87 100644 --- a/python/vescale/dtensor/ops/utils.py +++ b/python/vescale/dtensor/ops/utils.py @@ -138,6 +138,24 @@ def is_tensor_all_replicate(spec: DTensorSpec) -> bool: return all(p.is_replicate() for p in spec.placements) +def is_tensor_all_replicate_except_sharded_at_dim( + spec: DTensorSpec, + tensor_dim: int, + exclude_interleaved_shard: bool = False, +) -> bool: + for p in spec.placements: + if p.is_replicate(): + continue + if p.is_partial(): + return False + if exclude_interleaved_shard and p.is_interleaved_shard(): + return False + if p.is_shard(): + if p.dim != tensor_dim: + return False + return True + + def map_placements_after_broadcast( placements: Tuple[Placement, ...], shape: torch.Size, diff --git a/python/vescale/dtensor/ops/view_ops.py b/python/vescale/dtensor/ops/view_ops.py index 8930d45..36a3475 100644 --- a/python/vescale/dtensor/ops/view_ops.py +++ b/python/vescale/dtensor/ops/view_ops.py @@ -561,39 +561,20 @@ def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: return (tuple(out_shape), output_placements, shardable_dims) -def register_prop_rule_map( - aten_op_overload: torch._ops.OpOverload, - local_op_name: Callable[..., torch.Tensor], - schema_info: Optional[RuntimeSchemaInfo] = None, -) -> None: - spec: Op = ops[local_op_name] +def _reshape_prop(op_schema: OpSchema, spec: Op) -> OutputSharding: + rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0]) + mesh = input_dtensor_spec.mesh - @register_prop_rule(aten_op_overload, schema_info=schema_info) - def reshape_prop(op_schema: OpSchema) -> OutputSharding: - rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) - input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0]) - mesh = input_dtensor_spec.mesh - - assert isinstance(input_dtensor_spec, DTensorSpec), "Expected first input to be a DTensorSpec" - global_in_shape = input_dtensor_spec.shape - assert global_in_shape is not None, "Shape required." - - if TORCH_VERSION_BIGGER_THAN_2_2: - from torch._subclasses.fake_tensor import unset_fake_temporarily - from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing - - with disable_proxy_modes_tracing(), unset_fake_temporarily(): - ( - global_out_shape, - shard_out, - shardable_dims, - ) = propagate_shape_and_sharding( - input_dtensor_spec.placements, - tuple(global_in_shape), - rules, - mesh.shape, - ) - else: + assert isinstance(input_dtensor_spec, DTensorSpec), "Expected first input to be a DTensorSpec" + global_in_shape = input_dtensor_spec.shape + assert global_in_shape is not None, "Shape required." 
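The `index_add` rule above requires a replicated index and matching sharding of `input` and `src` away from the index dim; the local op it ultimately lowers to accumulates rows of `src` into `input` at the positions given by `index`. Plain torch, no sharding:

```python
import torch

x = torch.zeros(4, 2)
src = torch.ones(2, 2)
index = torch.tensor([0, 3])
out = torch.index_add(x, 0, index, src)   # rows 0 and 3 receive src rows
assert out.sum().item() == 4
assert out[1:3].abs().sum().item() == 0
```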
+ + if TORCH_VERSION_BIGGER_THAN_2_2: + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing + + with disable_proxy_modes_tracing(), unset_fake_temporarily(): ( global_out_shape, shard_out, @@ -604,60 +585,83 @@ def reshape_prop(op_schema: OpSchema) -> OutputSharding: rules, mesh.shape, ) + else: + ( + global_out_shape, + shard_out, + shardable_dims, + ) = propagate_shape_and_sharding( + input_dtensor_spec.placements, + tuple(global_in_shape), + rules, + mesh.shape, + ) + + if shard_out is not None: + # no reshard needed + output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out)) + + # We only need the local shape to lower the call into the local op + args = op_schema.args_schema + shape_argnum = spec.shape_argnum + if shape_argnum is not None: + # compute the local shape from the global shape, then return + # a resharding even if we don't really reshard, the only reason + # for this type of resharding is to lower the global shape to + # local shape + local_out_shape = compute_local_shape(list(global_out_shape), mesh, shard_out) + + suggested_schema = OpSchema( + op=op_schema.op, + args_schema=args[:shape_argnum] + (tuple(local_out_shape),) + args[shape_argnum + 1 :], + kwargs_schema=op_schema.kwargs_schema, + ) + return OutputSharding( + output_spec=output_dtensor_spec, + schema_suggestions=[suggested_schema], + needs_redistribute=True, + ) - if shard_out is not None: - # no reshard needed - output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out)) - - # We only need the local shape to lower the call into the local op - args = op_schema.args_schema - shape_argnum = spec.shape_argnum - if shape_argnum is not None: - # compute the local shape from the global shape, then return - # a resharding even if we don't really reshard, the only reason - # for this type of resharding is to lower the global shape to - # local shape - local_out_shape = compute_local_shape(list(global_out_shape), mesh, shard_out) - - suggested_schema = OpSchema( + return OutputSharding(output_spec=output_dtensor_spec) + + else: + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... + # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + suggested_placements = [ + p if not isinstance(p, (Shard, InterleavedShard)) or shardable_dims[p.dim][mesh_dim] else Replicate() + for mesh_dim, p in enumerate(input_dtensor_spec.placements) + ] + return OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( op=op_schema.op, - args_schema=args[:shape_argnum] + (tuple(local_out_shape),) + args[shape_argnum + 1 :], + args_schema=( + DTensorSpec( + placements=tuple(suggested_placements), + mesh=input_dtensor_spec.mesh, + tensor_meta=input_dtensor_spec.tensor_meta, + ), + ) + + op_schema.args_schema[1:], kwargs_schema=op_schema.kwargs_schema, ) - return OutputSharding( - output_spec=output_dtensor_spec, - schema_suggestions=[suggested_schema], - needs_redistribute=True, - ) + ], + ) - return OutputSharding(output_spec=output_dtensor_spec) - else: - # TODO: optimize this. we shouldn't simply blindly replicate - # unshardable dims ... 
- # FIXME: this can be wrong for situations where we have - # [Shard(0), Shard(0)] - suggested_placements = [ - p if not isinstance(p, (Shard, InterleavedShard)) or shardable_dims[p.dim][mesh_dim] else Replicate() - for mesh_dim, p in enumerate(input_dtensor_spec.placements) - ] - return OutputSharding( - output_spec=None, - schema_suggestions=[ - OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - placements=tuple(suggested_placements), - mesh=input_dtensor_spec.mesh, - tensor_meta=input_dtensor_spec.tensor_meta, - ), - ) - + op_schema.args_schema[1:], - kwargs_schema=op_schema.kwargs_schema, - ) - ], - ) +def register_prop_rule_map( + aten_op_overload: torch._ops.OpOverload, + local_op_name: Callable[..., torch.Tensor], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> None: + spec: Op = ops[local_op_name] + + @register_prop_rule(aten_op_overload, schema_info=schema_info) + def reshape_prop(op_schema: OpSchema) -> OutputSharding: + return _reshape_prop(op_schema, spec) register_prop_rule_map(aten.squeeze.default, torch.squeeze) @@ -667,3 +671,17 @@ def reshape_prop(op_schema: OpSchema) -> OutputSharding: register_prop_rule_map(aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)) register_prop_rule_map(aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)) register_prop_rule_map(aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)) + + +@register_prop_rule([aten.expand_as.default]) +def expand_as_prop(op_schema: OpSchema) -> OutputSharding: + source, dst = op_schema.args_schema + global_out_shape = dst.tensor_meta.shape + new_op_schema = OpSchema( + op=aten.expand.default, args_schema=(source, tuple(global_out_shape)), kwargs_schema=op_schema.kwargs_schema + ) + expand_sharding_out = _reshape_prop(new_op_schema, ops[Tensor.expand]) + expand_sharding_out.output_spec.placements = dst.placements + expand_sharding_out.needs_redistribute = False + expand_sharding_out.suggested_schema = None + return expand_sharding_out diff --git a/python/vescale/dtensor/redistribute.py b/python/vescale/dtensor/redistribute.py index c7266a2..d7f4711 100644 --- a/python/vescale/dtensor/redistribute.py +++ b/python/vescale/dtensor/redistribute.py @@ -91,14 +91,13 @@ def _reshard_to_replicate_with_pad_one_dim( This function all_gather all shards and return a tensor that is replicated on the previously sharded mesh dimension """ + # if rank is not part of mesh, we simply return local_tensor, which should be an empty tensor my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(dim=mesh_dim) - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor return local_tensor + num_chunks = mesh.size(dim=mesh_dim) + # check if it needs to pad input tensor before all_gather full_chunk_size = (size[shard_dim] + num_chunks - 1) // num_chunks chunk_sizes = [ @@ -156,14 +155,14 @@ def _scatter_tensor_by_shard(tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: i shard and scatter a tensor on a mesh dimension (use coordinate 0 on the mesh dimension as source of truth) """ + # if rank is not part of mesh, we simply return an empty tensor my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(dim=mesh_dim) - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor return tensor.new_empty(0, requires_grad=tensor.requires_grad) - scatter_list, pad_sizes = shard_spec._split_tensor(tensor, num_chunks, with_padding=True, contiguous=True) + scatter_list, 
pad_sizes = shard_spec._split_tensor( + tensor, num_chunks=mesh.size(dim=mesh_dim), with_padding=True, contiguous=True + ) output = torch.empty_like(scatter_list[my_coordinate[mesh_dim]]) mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) @@ -180,9 +179,8 @@ def _replicate_tensor(tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int) -> Replicate (broadcast) a torch.Tensor on a mesh dimension (use the first coordinate on the mesh dimension as source of truth) """ - my_coordinate = mesh.get_coordinate() - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor + # if rank is not part of mesh, we simply return an empty tensor + if mesh.get_coordinate() is None: return tensor.new_empty(0, requires_grad=tensor.requires_grad) tensor = tensor.contiguous() @@ -200,14 +198,12 @@ def _reduce_scatter_to_shard_with_pad( """ reduce and scatter a tensor on a mesh dimension """ + # if rank is not part of mesh, we simply return tensor, which should be an empty tensor my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(dim=mesh_dim) - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor return tensor + num_chunks = mesh.size(dim=mesh_dim) is_padded = tensor.size(shard_spec.dim) % num_chunks != 0 if is_padded: scattered_list, pad_sizes = shard_spec._split_tensor(tensor, num_chunks, with_padding=True, contiguous=True) @@ -232,6 +228,12 @@ def redistribute_local_tensor( the target DTensorSpec, which involves the necessary collective calls to transform the local shard of the DTensor from its current spec to the target spec. """ + device_mesh = current_spec.mesh + + # if rank is not part of mesh, we simply return local_tensor, which should be an empty tensor + my_coordinate = device_mesh.get_coordinate() + if my_coordinate is None: + return local_tensor if current_spec.mesh != target_spec.mesh: # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same @@ -245,17 +247,7 @@ def redistribute_local_tensor( sorted_placements = _decompose_reshard(sorted_placements) sorted_placements.sort(key=_replicate_then_shard) - device_mesh = current_spec.mesh - for i, (current, target) in sorted_placements: - my_coordinate = device_mesh.get_coordinate() - num_chunks = device_mesh.size(dim=i) - - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor - return local_tensor - if current == target: # short cut, just use the original local tensor new_local_tensor = local_tensor @@ -323,14 +315,13 @@ def redistribute_local_tensor( mesh_dim=i, ) shards = target_placement._split_tensor( - tensor=replicate_local_tensor, num_chunks=num_chunks, contiguous=False + tensor=replicate_local_tensor, num_chunks=device_mesh.size(dim=i), contiguous=False ) new_local_tensor = shards[my_coordinate[i]].clone() - pass elif current.is_replicate(): shards = target_placement._split_tensor( tensor=local_tensor, - num_chunks=num_chunks, + num_chunks=device_mesh.size(dim=i), contiguous=False, ) new_local_tensor = shards[my_coordinate[i]].clone() @@ -351,7 +342,7 @@ def redistribute_local_tensor( # split the tensor and return the corresponding cloned local shard shards, _ = target_placement._split_tensor( local_tensor, - num_chunks, + num_chunks=device_mesh.size(dim=i), with_padding=False, contiguous=False, ) @@ -366,7 +357,6 @@ def redistribute_local_tensor( if shard_spec.dim != target_placement.dim: # TODO: enable this with all_to_all 
raise NotImplementedError("Changing sharding dim is not supported yet!") - elif target.is_partial(): if current.is_partial(): mode = _get_current_dispatch_mode() @@ -377,7 +367,6 @@ def redistribute_local_tensor( # P -> R partial_spec = cast(Partial, current) new_local_tensor = mesh_all_reduce(local_tensor, device_mesh, partial_spec.reduce_op, i) - elif current.is_replicate(): # For replicate -> partial, we zero out all other ranks of the current mesh dim # and leave only 1 rank have the data, to perform a "zero cost" reshard. @@ -422,7 +411,6 @@ def redistribute_local_tensor( ) if my_coordinate[i] != 0: new_local_tensor = new_local_tensor.zero_() - pass elif current.is_shard(): # For sharded tensor -> partial, we reduce the tensor, # then follow a same way as the second case. @@ -458,6 +446,11 @@ def forward( # type: ignore[override] current_spec = input._spec ctx.current_spec = current_spec ctx.async_op = async_op + + # Early return the original DTensor if the placements are the same. + if input._spec.placements == placements: + return input + target_spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=input._spec.tensor_meta) local_tensor = input._local_tensor @@ -500,9 +493,14 @@ def backward(ctx, grad_output: "dtensor.DTensor"): else: target_placements.append(target) target_spec = DTensorSpec(previous_spec.mesh, tuple(target_placements), tensor_meta=previous_spec.tensor_meta) - local_tensor = grad_output._local_tensor - output = redistribute_local_tensor(local_tensor, current_spec, target_spec, async_op) + + # Short cut the local tensor if the placements are the same. + if current_spec == target_spec: + output = local_tensor + else: + output = redistribute_local_tensor(local_tensor, current_spec, target_spec, async_op) + output_dtensor = dtensor.DTensor( output, target_spec.mesh, diff --git a/python/vescale/initialize/deferred_init.py b/python/vescale/initialize/deferred_init.py index fd35164..00fa6c1 100644 --- a/python/vescale/initialize/deferred_init.py +++ b/python/vescale/initialize/deferred_init.py @@ -113,7 +113,9 @@ def materialize_dtensor( device_mesh = device_mesh or mesh_resources.get_current_mesh() device = device_mesh.device_type # get placements - placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, none_as_replicate=True) + placements: Tuple[Placement] = normalize_placements( + placements, device_mesh.ndim, tensor_ndim=tensor.ndim, none_as_replicate=True + ) # get local tensor shape local_shape = compute_local_shape(global_shape, device_mesh, placements) torch_device = torch.device(device) @@ -171,7 +173,9 @@ def materialize_dparameter( device_mesh = device_mesh or mesh_resources.get_current_mesh() device = device_mesh.device_type # get placements - placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, none_as_replicate=True) + placements: Tuple[Placement] = normalize_placements( + placements, device_mesh.ndim, tensor_ndim=param.data.ndim, none_as_replicate=True + ) # get local tensor shape local_shape = compute_local_shape(global_shape, device_mesh, placements) torch_device = torch.device(device) diff --git a/python/vescale/optim/base_optimizer.py b/python/vescale/optim/base_optimizer.py index 51d8a08..634ea35 100644 --- a/python/vescale/optim/base_optimizer.py +++ b/python/vescale/optim/base_optimizer.py @@ -156,7 +156,7 @@ def __init__( self, optimizer, models: Union[torch.nn.Module, List[torch.nn.Module]], - grad_hook: Optional[GradOptimizerHookBase] = BasicOptimizerHook(), + grad_hook: 
Optional[GradOptimizerHookBase] = BasicOptimizerHook,
     ) -> None:
         super().__init__(optimizer=optimizer)
         self.models = models
@@ -175,17 +175,22 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
 
         for m in self.models:
-            if not DModule.is_dmodule(m):
-                logging.warning("module has no `finish_grad_sync` method defined, skip allreducing grads")
-                continue
             # if module is wrapped by DDP, we needn't handle partial grad sync. DDP will do it.
             if isinstance(m, DDP):
                 continue
+            if not DModule.is_dmodule(m):
+                logging.warning("module has no `finish_grad_sync` method defined, skip allreducing grads")
+                continue
             m.finish_grad_sync()
         return self.optimizer.step(closure)
 
     def zero_grad(self, set_to_none: bool = True) -> None:
+        from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+
         self.optimizer.zero_grad(set_to_none=set_to_none)
+        for m in self.models:
+            if isinstance(m, DDP):
+                m.zero_grad_buffer(zero_buffer=True)
 
     def state_dict(self):
         return self.optimizer.state_dict()
diff --git a/python/vescale/optim/distributed_optimizer.py b/python/vescale/optim/distributed_optimizer.py
index a05778c..f9ca030 100644
--- a/python/vescale/optim/distributed_optimizer.py
+++ b/python/vescale/optim/distributed_optimizer.py
@@ -7,7 +7,9 @@
 """Megatron distributed optimizer."""
 
 import math
-from typing import Dict, Sequence, Any
+import inspect
+from dataclasses import dataclass
+from typing import Dict, Sequence, Tuple, Optional, Any
 
 import torch
 import torch.distributed as dist
@@ -45,6 +47,87 @@ def __repr__(self) -> str:
         return "Range(%d,%d [%d])" % (self.start, self.end, self.size)
 
 
+@dataclass
+class OptimizerStateSpec:
+    """This class represents the mapping between the local flattened 1D tensor
+    and the global original DTensor in DOptimizer. It is used for
+    loading or saving optimizer states using OmniStore (PyTorch DCP)
+    and for load-time checkpoint resharding when changing tp size or dp size.
+
+    For example, a linear layer in veScale is a DTensor(size=[1024, 1024]).
+    It is first split into two parts along dim=0 with tensor parallel size = 2:
+
+    tensor_part_0 = DTensor(size=[512, 1024])
+    tensor_part_1 = DTensor(size=[512, 1024])
+
+    Then each part's optimizer states are initialized in DOptimizer separately.
+
+    Assume dp=2.
+    For the process with dp=0 tp=0, the flattened tensor is torch.Tensor(size=[262144])
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(0, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+
+    For the process with dp=1 tp=0, the flattened tensor is torch.Tensor(size=[262144])
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(256, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+
+    For the process with dp=0 tp=1, the flattened tensor is torch.Tensor(size=[262144])
+    mapping to [512:768, 0:1024] in the original DTensor
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(512, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+
+    For the process with dp=1 tp=1, the flattened tensor is torch.Tensor(size=[262144])
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(768, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+    """
+
+    # The original DTensor shape
+    global_shape: Tuple[int]
+    # The local tensor shape ***before being flattened into a 1D tensor***
+    local_shape: Tuple[int]
+    # The local tensor's offset with respect to the original DTensor
+    global_offset: Tuple[int]
+    # The unflattened local tensor after creating a view with local_shape on the flattened 1D tensor in DOptimizer
+    # NOTE: In order to support TP resharding and states across dp ranks, we defer the reshaping from 1D to local_shape
+    # until generating the saving plan using OmniStore (PyTorch DCP)
+    local_tensor: torch.Tensor
+    # If the current optimizer state is sharded by multiple dp ranks,
+    # we should record all ranks and their ranges
+    dp_ranks_ranges: Optional[Dict[int, Range]]
+
+
+def convert_dict_with_sharded(
+    param_state: dict,
+    global_shape: Tuple[int],
+    local_shape: Tuple[int],
+    global_offset: Tuple[int],
+    dp_ranks_ranges: Optional[Dict[int, Range]],
+):
+    new_param_state = {}
+    for k, v in param_state.items():
+        if isinstance(v, torch.Tensor) and v.dim() >= 1:
+            # Don't unflatten the tensor here; see the comments above
+            if not dp_ranks_ranges:
+                if math.prod(local_shape) != math.prod(v.shape):
+                    print(f"rank={dist.get_rank()} name={k} global shape={global_shape}\
+                          local_shape={local_shape} global_offset={global_offset} real shape={v.shape}")
+                    raise AssertionError()
+            new_param_state[k] = OptimizerStateSpec(
+                global_shape, local_shape, global_offset, v, dp_ranks_ranges
+            )  # , process_group)
+        else:
+            new_param_state[k] = v
+    return new_param_state
+
+
+def convert_dict_sharded_to_tensor(param_state: dict, range_1d: Optional[Range]):
+    for k, v in param_state.items():
+        if isinstance(v, OptimizerStateSpec):
+            # If the state is distributed across multiple dp ranks,
+            # get only my part
+            if range_1d:
+                param_state[k] = v.local_tensor.flatten()[range_1d.start : range_1d.end]
+            else:
+                param_state[k] = v.local_tensor.flatten()
+    return param_state
+
+
 class DistributedOptimizer(OptimizerBase):
     """Distributed optimizer, for all data types (bf16, and fp32).
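To make the flattened-state mapping above concrete, here is a minimal sketch using plain PyTorch only. The variable names mirror the `OptimizerStateSpec` fields and the numbers follow the dp=2, tp=2 example from the docstring; nothing in this snippet is part of the patch itself.

```python
import torch

# Docstring example: global parameter (1024, 1024), tp=2 and dp=2,
# so each rank owns 256 rows of the 1024-row parameter.
global_shape = (1024, 1024)
local_shape = (256, 1024)
global_offset = (256, 0)  # e.g. the dp=1, tp=0 rank

# DOptimizer keeps the state flattened: 256 * 1024 == 262144 elements.
flat_state = torch.randn(256 * 1024)

# Deferred reshape: only performed when building a save plan, not on the hot path.
local_view = flat_state.view(local_shape)

# Slice of the (conceptual) global tensor that this shard covers.
rows = slice(global_offset[0], global_offset[0] + local_shape[0])
cols = slice(global_offset[1], global_offset[1] + local_shape[1])

global_like = torch.zeros(global_shape)
global_like[rows, cols] = local_view  # one rank's contribution to the full state
```

When a single parameter's state is additionally split across dp ranks, `dp_ranks_ranges` records the `[start, end)` slice of the flattened tensor owned by each dp rank, which is the slice that `convert_dict_sharded_to_tensor` extracts via `range_1d`.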
@@ -292,6 +375,22 @@ def __init__( self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges] self.optimizer.load_state_dict(self.optimizer.state_dict()) + def build_param_sharding_info_for_checkpoint(self, model: DDP, dtype, gbuf_world_all_ranges): + param_world_index_map = model.grad_buffer_param_index_map[dtype] + for param, param_world_indexes in param_world_index_map.items(): + if param not in self.param_shard_info: + self.param_shard_info[param] = [] + for gbuf_world_range in gbuf_world_all_ranges: + param_world_start, param_world_end, _ = param_world_indexes + param_local_start = max(0, param_world_start - gbuf_world_range.start) + param_local_end = min(gbuf_world_range.size, param_world_end - gbuf_world_range.start) + + # Add param, if within local gbuf range. + if param_local_end > param_local_start: + self.param_shard_info[param].append(param_local_end - param_local_start) + else: + self.param_shard_info[param].append(0) + def build_model_gbuf_param_range_map(self, model: DDP, dtype, gbuf_world_range, bucket_offset): """ Build mapping from param reference to grad buffer shard ranges. @@ -382,6 +481,9 @@ def build_model_gbuf_range(self, model, dtype, bucket_index): # Local DP's ranges. gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank] + # Get parameter sharding info of all ranks, for checkpointing. + self.build_param_sharding_info_for_checkpoint(model, dtype, gbuf_world_all_ranges) + # Get each param's ranges. param_range_map = self.build_model_gbuf_param_range_map(model, dtype, gbuf_world_range, bucket.offset) @@ -553,6 +655,10 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o shard_model_param.shared = model_param.shared shard_main_param.shared = model_param.shared + # copy sharded info from DTensor + shard_model_param._spec = None if not isinstance(model_param, DTensor) else model_param._spec + shard_main_param._spec = None if not isinstance(model_param, DTensor) else model_param._spec + # Add to group. model_float16_params_this_group.append(model_param) shard_float16_params_this_group.append(shard_model_param) @@ -615,99 +721,129 @@ def state_dict(self): optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate checkpoint file by calling 'save_parameter_state()'. """ + # all gather ddp module + if self.overlap_param_gather: + for m in self.models: + self._param_all_gather(m) + # we disable all pre_forward hook needed for param sync, and reenable them + # at the end of subsequent forward. + self._disable_pre_hook() - state_dict = {} + optimizer_state = self.optimizer.state_dict() - # Optimizer state (do not store parameter state here). 
- state_dict["optimizer"] = {k: v for k, v in self.optimizer.state_dict().items() if k != "state"} - for param_group in state_dict["optimizer"]["param_groups"]: - del param_group["params"] + distributed_state = { + "param_group_meta": optimizer_state["param_groups"], + } + self.prefix_sum_param_groups = [] + param_groups = self.optimizer.state_dict()["param_groups"] + + for i, _ in enumerate(param_groups): + if i == 0: + self.prefix_sum_param_groups.append(0) + else: + self.prefix_sum_param_groups.append( + len(param_groups[i - 1]["params"]) + self.prefix_sum_param_groups[i - 1] + ) + for i, model in enumerate(self.models): + # key is name, + # value is dtype + param_dtype = {} + for dtype, param_maps in self.model_gbuf_ranges[i].items(): + if dtype not in distributed_state: + distributed_state[dtype] = {} + for param_map in param_maps: + for param in param_map["param_map"].keys(): + param_dtype[param] = dtype + + for param in model.parameters(): + if param in param_dtype.keys(): + dtype = param_dtype[param] + param_key = self.param_to_name[param] + group_id, local_id_in_group = self.model_param_group_index_map[param] + distributed_state[dtype][param_key] = convert_dict_with_sharded( + optimizer_state["state"][self.prefix_sum_param_groups[group_id] + local_id_in_group], + self.param_global_shape_info[param], + self.param_local_shape_info[param], + self.param_global_offset_info[param], + self.param_across_dp_ranks_info.get(param), + ) + # If it is mix percision training, we should save master fp32 weights + if not all(not group for group in self.shard_fp32_from_float16_groups): + for group in self.shard_fp32_from_float16_groups: + for param in group: + original_param = self.param_to_origin_param_for_shard_fp32_from_float16_groups[param] + name = self.param_to_name[original_param] + distributed_state[torch.float32][name]["shard_fp32_from_float16_groups"] = OptimizerStateSpec( + self.param_global_shape_info[original_param], + self.param_local_shape_info[original_param], + self.param_global_offset_info[original_param], + param, + self.param_across_dp_ranks_info.get(original_param), + ) - return state_dict + return distributed_state def load_state_dict(self, state_dict): - """Load the state dict. - - As detailed in state_dict(), the state dict contains all non- - parameter-related variables. This method is notably longer than - state_dict(), because the Torch optimizers state has yet to be - allocated at this point, and so we must do a cross referencing between - the optimizers state (and the ordering it expects for parameter state) - and this DP rank's shards. The optimizer at this point does not contain - any tensor dimension information, so we must get these dimensions from - the DP shards mapped during DistributedOptimizer.__init__(). - - The tensor parameter state is loaded via load_parameter_state(), and - so this method also must populate the loaded state dict with dummy - tensor data (i.e., via torch.empty() below). This will be overwritten - during load_parameter_state(). - - ** Note: Torch optimizer's state structure. ** - The Torch optimizer stores its state in two levels. The top level is a - list of groups, where each group contains a list of integer indexes - (corresponding to parameters) that index into a master parameter list - that is shared by all groups. As such, three values are necessary for - maintaining this ordering: - - - group_index : The group to which a parameter belongs. - - group_order : The index of a parameter within its group. 
- - state_order : The index of a parameter within the shared parameter - list. """ + Load the state dict. + """ + optimizer_state = {"param_groups": state_dict["param_group_meta"]} + original_optimizer_state = self.optimizer.state_dict() + # update params + for i, param_group in enumerate(optimizer_state["param_groups"]): + # Just assign param indices, assign param directly leading to deepcopy error + if len(param_group["params"]) != len(original_optimizer_state["param_groups"][i]["params"]): + param_group["params"] = original_optimizer_state["param_groups"][i]["params"] + # resume optimizer state: + optimizer_state["state"] = {} + param_index = 0 - # Get the Torch optimizer's state dict. - # - This 'inner' optimizer at this point is unallocated, and only - # contains an integer odering of parameters within each group, and - # the ordering of parameters within its flattened parameter state - # list. - inner_state_dict = self.optimizer.state_dict() - state_dict_param_groups = [ - { - **group, - "params": list(inner_state_dict["param_groups"][idx]["params"]), - } - for idx, group in enumerate(state_dict["optimizer"]["param_groups"]) - ] - - # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below) - # - Real data is overwritten during load_parameter_state(). - state_dict_state = [] - for gbuf_range_maps in self.model_gbuf_ranges: - for gbuf_range_map_for_all_buckets in gbuf_range_maps.values(): - for gbuf_range_map in gbuf_range_map_for_all_buckets: - for model_param, param_range_map in gbuf_range_map["param_map"].items(): - # Get parameter ordering information (see method docstring - # for details). - group_index, group_order = self.model_param_group_index_map[model_param] - state_order = inner_state_dict["param_groups"][group_index]["params"][group_order] - - # Allocate dummy tensors. - numel = len(param_range_map["gbuf_world"]) - init_shard = lambda: torch.empty( - (numel,), dtype=torch.float32, device=torch.cuda.current_device() - ) - - state_dict_state.append( - ( - state_order, - { - "exp_avg": init_shard(), - "exp_avg_sq": init_shard(), - }, - ) - ) - - # Sort by state order (see method docstring for details). - state_dict_state.sort(key=lambda s: s[0]) - state_dict_state = {s[0]: s[1] for s in state_dict_state} - - # Optimizer. 
- self.optimizer.load_state_dict( - { - "state": state_dict_state, - "param_groups": state_dict_param_groups, - } - ) + for i, model in enumerate(self.models): + param_dtype = {} + for dtype, param_maps in self.model_gbuf_ranges[i].items(): + for param_map in param_maps: + for param in param_map["param_map"].keys(): + param_dtype[param] = dtype + param_list = [] + for param in model.parameters(): + if param in param_dtype.keys(): + dtype = param_dtype[param] + ranges = self.param_across_dp_ranks_info.get(param) + name = self.param_to_name[param] + param_list.append((param, dtype, ranges, name)) + + for param_info in param_list: + if param_info[2]: + my_range = param_info[2][self.current_global_rank] + else: + my_range = None + group_id, local_id_in_group = self.model_param_group_index_map[param_info[0]] + optimizer_state["state"][self.prefix_sum_param_groups[group_id] + local_id_in_group] = ( + convert_dict_sharded_to_tensor(state_dict[param_info[1]][param_info[3]], my_range) + ) + param_index += 1 + + self.optimizer.load_state_dict(optimizer_state) + + if not all(not group for group in self.shard_fp32_from_float16_groups): + for group in self.shard_fp32_from_float16_groups: + for param in group: + original_param = self.param_to_origin_param_for_shard_fp32_from_float16_groups[param] + name = self.param_to_name[original_param] + # The weights have been flatten into 1D and get range based on current rank (if necessary) + # in the "resume optimizer state loop + param.copy_(state_dict[torch.float32][name]["shard_fp32_from_float16_groups"]) + # state_dict['shard_fp32_from_float16_groups'] + # optimizer_state['shard_fp32_from_float16_groups'] + # TODO: Copy data for the main params. + # for current_group, saved_group in zip( + # self.shard_fp32_from_float16_groups, + # state_dict["shard_fp32_from_float16_groups"]): + # for current_param, saved_param in zip(current_group, saved_group): + # if isinstance(current_param.data, DTensor): + # current_param.data._local_tensor.copy_(saved_param.data) + # else: + # current_param.data.copy_(saved_param.data) def zero_grad(self, set_to_none=True): """ @@ -1107,3 +1243,13 @@ def get_main_grads_for_grad_norm(self): grads_for_norm.append(grad._local_tensor if isinstance(grad, DTensor) else grad) return grads_for_norm + + +def initialize_optimizer_state(optimizer: DistributedOptimizer): + optimizer._copy_model_grads_to_main_grads() + orig_optimizer = optimizer.optimizer + for group in orig_optimizer.param_groups: + param_list = inspect.signature(orig_optimizer._init_group).parameters + num_params = len(param_list) + args = [group] + [[] for i in range(num_params - 1)] + orig_optimizer._init_group(*args) diff --git a/test/dmodule/test_plans.py b/test/dmodule/test_plans.py index 6af91eb..616b339 100644 --- a/test/dmodule/test_plans.py +++ b/test/dmodule/test_plans.py @@ -20,14 +20,14 @@ from torch import nn from torch.testing._internal.common_utils import run_tests -from common_dtensor import DTensorTestBase, with_comms_device +from common_dtensor import DTensorTestBase, with_comms_device, with_comms from vescale.dmodule.api import parallelize_module from vescale.dtensor.device_mesh import DeviceMesh from vescale.dtensor.placement_types import Replicate, Shard -config = {"seq_length": 8, "head_size": 4, "hidden_size": 4 * 4, "n_head": 4, "batch_size": 4} +CONFIG = {"batch_size": 4, "seq_length": 4, "hidden_size": 4} class MLP(nn.Module): @@ -71,18 +71,14 @@ def forward(self, x): } -class DMLP(nn.Module): +class Block(nn.Module): def __init__(self, config): 
super().__init__() - self.fc1 = nn.Linear(config["hidden_size"], config["hidden_size"] * 4) - self.gelu = torch.nn.GELU() - self.fc2 = nn.Linear(config["hidden_size"] * 4, config["hidden_size"]) + self.ln = nn.LayerNorm(config["hidden_size"], bias=False) + self.mlp = MLP(config) def forward(self, x): - x = self.fc1(x) - x = self.gelu(x) - x = self.fc2(x) - return x + return self.mlp(self.ln(x)) class DModuleTestPlans(DTensorTestBase): @@ -94,20 +90,20 @@ def _run_plan(self, param_sharding_plan, fwd_resharding_plan, devce_type): device_mesh = DeviceMesh(devce_type, list(range(self.world_size))) # create golden model (local replicate) - mlp_golden = MLP(config) + mlp_golden = MLP(CONFIG) mlp_golden.to(devce_type) for name, param in mlp_golden.named_parameters(): dist.all_reduce(param, async_op=False) # create dmodule (by plans) - dmlp = DMLP(config) + dmlp = MLP(CONFIG) dmlp.to(devce_type) dmlp.load_state_dict(mlp_golden.state_dict()) parallelize_module(dmlp, device_mesh, {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}) # create data (local replicate) input_golden = torch.randn( - config["batch_size"] * config["seq_length"], config["hidden_size"], device=devce_type, requires_grad=False + CONFIG["batch_size"] * CONFIG["seq_length"], CONFIG["hidden_size"], device=devce_type, requires_grad=False ) dist.all_reduce(input_golden, async_op=False) input_tensor = input_golden.detach().clone() @@ -146,11 +142,59 @@ def test_cuda(self): def test_wrong_plan(self): device_mesh = DeviceMesh("cuda", list(range(self.world_size))) # create dmodule (by plans) - dmlp = DMLP(config) + mlp = MLP(CONFIG) with self.assertRaises(KeyError): - parallelize_module(dmlp, device_mesh, {"parameters": param_sharding_plan1, "forward": fwd_resharding_plan1}) + parallelize_module(mlp, device_mesh, {"parameters": param_sharding_plan1, "forward": fwd_resharding_plan1}) with self.assertRaises(KeyError): - parallelize_module(dmlp, device_mesh, {"parameter": param_sharding_plan1, "forwards": fwd_resharding_plan1}) + parallelize_module(mlp, device_mesh, {"parameter": param_sharding_plan1, "forwards": fwd_resharding_plan1}) + + @with_comms + def test_tp_plan(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + sharding_plan = { + "parameter": { + "mlp.fc1.weight": [Shard(0)], + "mlp.fc1.bias": [Shard(0)], + "mlp.fc2.weight": [Shard(1)], + "mlp.fc2.bias": [Replicate()], + }, + "forward": { + "input": [[Replicate()]], + "ln.input": [[Replicate()]], # no SP + "mlp.input": [[Replicate()]], + "mlp.fc2.output": [[Replicate()]], + }, + } + + dmodel = parallelize_module(Block(CONFIG), device_mesh, sharding_plan) + input = torch.ones((CONFIG["batch_size"], CONFIG["seq_length"], CONFIG["hidden_size"]), requires_grad=True) + output = dmodel(input).to_local() + output.sum().backward() + + @with_comms + def test_tp_sp_plan(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + sharding_plan = { + "parameter": { + "mlp.fc1.weight": [Shard(0)], + "mlp.fc1.bias": [Shard(0)], + "mlp.fc2.weight": [Shard(1)], + "mlp.fc2.bias": [Replicate()], + }, + "forward": { + "input": [[Replicate()]], + "ln.input": [[Shard(1)]], # SP + "mlp.input": [[Replicate()]], + "mlp.fc2.output": [[Replicate()]], + }, + } + + dmodel = parallelize_module(Block(CONFIG), device_mesh, sharding_plan) + input = torch.ones((CONFIG["batch_size"], CONFIG["seq_length"], CONFIG["hidden_size"]), requires_grad=True) + output = dmodel(input).to_local() + output.sum().backward() if __name__ == "__main__": 
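The only difference between the two plans exercised above is the `ln.input` entry: `[[Shard(1)]]` re-shards the LayerNorm input along the sequence dimension instead of replicating it. Below is a rough illustration of what that placement means for per-rank shapes. It assumes a 1-D mesh of 4 ranks with one process per rank (e.g. launched via `torchrun --nproc_per_node=4`), and uses veScale APIs that appear elsewhere in this patch; the tensor shape is illustrative only.

```python
import torch
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.api import distribute_tensor
from vescale.dtensor.placement_types import Replicate, Shard

mesh = DeviceMesh("cuda", [0, 1, 2, 3])  # assumed 4-GPU mesh

act = torch.ones(4, 8, 16)  # (batch, seq, hidden)

no_sp = distribute_tensor(act, mesh, [Replicate()])  # "ln.input": [[Replicate()]]
sp = distribute_tensor(act, mesh, [Shard(1)])        # "ln.input": [[Shard(1)]]

print(no_sp.to_local().shape)  # torch.Size([4, 8, 16]) on every rank
print(sp.to_local().shape)     # torch.Size([4, 2, 16]) -- sequence dim split across ranks
```

Sharding the sequence dimension in front of the LayerNorm is what turns the plain TP plan into the TP+SP plan checked by `test_tp_sp_plan`.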
diff --git a/test/dtensor/general/test_api.py b/test/dtensor/general/test_api.py index eea7bd5..259887a 100644 --- a/test/dtensor/general/test_api.py +++ b/test/dtensor/general/test_api.py @@ -8,13 +8,16 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ +from typing import List from common_dtensor import DTensorTestBase, with_comms import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from vescale import DeviceMesh, Replicate, Shard, distribute_tensor +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.api import distribute_tensor +from vescale.dtensor.placement_types import Replicate, Shard class MyModel(nn.Module): @@ -113,6 +116,19 @@ def test_jit_script_func(self): out = my_jit_add(dtensor) self.assertEqual(out.to_local(), torch.tan(tensor_to_distribute + 1)) + @with_comms + def test_tolist(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + global_shape = (self.world_size, self.world_size) + for shard_spec in [[Replicate()], [Shard(0)], [Shard(1)]]: + dist_tensor = distribute_tensor(torch.ones(global_shape, dtype=torch.float32), device_mesh, shard_spec) + nested_list = dist_tensor.tolist() + self.assertTrue(isinstance(nested_list, List)) + self.assertTrue(isinstance(nested_list[0], List)) + self.assertTrue(nested_list[0][0] == 1.0) + shape = (len(nested_list), len(nested_list[0])) + self.assertEqual(shape, dist_tensor._local_tensor.shape) + if __name__ == "__main__": run_tests() diff --git a/test/dtensor/general/test_dtensor.py b/test/dtensor/general/test_dtensor.py index 20a62f0..e05e143 100644 --- a/test/dtensor/general/test_dtensor.py +++ b/test/dtensor/general/test_dtensor.py @@ -90,15 +90,16 @@ def test_dtensor_constructor(self): def test_dtensor_stride(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) shard0_spec = [Shard(0)] + shard1_spec = [Shard(1)] + local_tensor = torch.randn(4, 8) - global_shape = torch.Size([self.world_size * 4, 8]) + # global_shape = torch.Size([self.world_size * 4, 8]) dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec) # won't affect stride self.assertEqual(dist_tensor.stride(), (8, 1)) - shard1_spec = [Shard(1)] local_tensor = torch.randn(8, 4) - global_shape = torch.Size([8, self.world_size * 4]) + # global_shape = torch.Size([8, self.world_size * 4]) dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec) # will affect stride after DT initialized self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1)) @@ -106,14 +107,16 @@ def test_dtensor_stride(self): # if initialized from a transposed mat (Even Sharding) local_tensor = torch.randn(8, 4, 8) local_tensor_t = local_tensor.permute(1, 2, 0) - global_shape = torch.Size([4, self.world_size * 8, 8]) + # global_shape = torch.Size([4, self.world_size * 8, 8]) self.assertEqual(local_tensor_t.stride(), (8, 1, 32)) - dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec, run_check=False) + dist_tensor = DTensor.from_local( + local_tensor_t, device_mesh, shard1_spec, support_uneven=False + ) # TODO: resolve global_stride = (8 * self.world_size, 1, 32 * self.world_size) self.assertEqual(dist_tensor.stride(), global_stride) @with_comms - def test_from_local(self): + def test_from_local_default(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) shard_spec = [Shard(0)] local_tensor = 
torch.randn(3, 3) @@ -150,187 +153,116 @@ def test_from_local(self): self.assertEqual(local_tensor_with_grad.grad, expected_grad) @with_comms - def test_from_local_with_given_shape_stride(self): - torch.manual_seed(0) - + def test_from_local__check_shape_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # Assert given shape and stride local_tensor = torch.randn(3 * self.world_size, 3) global_tensor = local_tensor shard_spec = (Replicate(),) - with self.assertRaisesRegex(ValueError, "Please pass both shape and stride at the same time."): + with self.assertRaisesRegex(ValueError, "Please pass both shape and stride at the same time!"): DTensor.from_local(local_tensor, device_mesh, shard_spec, shape=global_tensor.size()) - with self.assertRaisesRegex(ValueError, "Please pass both shape and stride at the same time."): + with self.assertRaisesRegex(ValueError, "Please pass both shape and stride at the same time!"): DTensor.from_local(local_tensor, device_mesh, shard_spec, stride=global_tensor.stride()) - # Replicate - local_tensor = torch.randn(3 * self.world_size, 3) - global_tensor = local_tensor - dist_tensor = DTensor.from_local( - local_tensor, device_mesh, [Replicate()], shape=global_tensor.shape, stride=global_tensor.stride() - ) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # Partial - local_tensor = torch.randn(3 * self.world_size, 3) - global_tensor = local_tensor - dist_tensor = DTensor.from_local( - local_tensor, device_mesh, [Partial()], shape=global_tensor.shape, stride=global_tensor.stride() - ) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor * self.world_size, atol=0, rtol=0) - - # Even Shard(0) - local_tensors = [torch.randn(3, 3) for _ in range(self.world_size)] - global_tensor = torch.concat(local_tensors, dim=0) - dist_tensor = DTensor.from_local( - local_tensors[self.rank], device_mesh, [Shard(0)], shape=global_tensor.shape, stride=global_tensor.stride() - ) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # Uneven Shard(0) without pad - global_tensor = torch.randn(self.world_size + 1, 2) - local_tensors, _ = Shard(0)._split_tensor( - global_tensor, - device_mesh.size(dim=0), - with_padding=False, - contiguous=True, - ) - dist_tensor = DTensor.from_local( - local_tensors[self.rank], - device_mesh, - (Shard(0),), - shape=global_tensor.size(), - stride=global_tensor.stride(), - ) - self.assertEqual(dist_tensor.size(), global_tensor.size()) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # # Uneven Shard(0) with pad # TODO - # global_tensor = torch.randn(self.world_size + 1, 2) - # local_tensors, _ = Shard(0)._split_tensor( - # global_tensor, - # device_mesh.size(dim=0), - # with_padding=True, - # contiguous=True, - # ) - # with 
self.assertRaisesRegex( - # ValueError, "Given global shape and stride does not match local shape and stride!" - # ): - # DTensor.from_local(local_tensors[self.rank], device_mesh, (Shard(0),), - # shape=global_tensor.size(), stride=global_tensor.stride()) - - @with_comms - def test_from_local_without_given_shape_stride(self): - torch.manual_seed(0) - - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - - # Replicate - local_tensor = torch.randn(3 * self.world_size, 3) - global_tensor = local_tensor - dist_tensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()]) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # Partial - local_tensor = torch.randn(3 * self.world_size, 3) - global_tensor = local_tensor - dist_tensor = DTensor.from_local(local_tensor, device_mesh, [Partial()]) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor * self.world_size, atol=0, rtol=0) - - # Even Shard(0) - local_tensors = [torch.randn(3, 3) for _ in range(self.world_size)] - global_tensor = torch.concat(local_tensors, dim=0) - dist_tensor = DTensor.from_local(local_tensors[self.rank], device_mesh, [Shard(0)]) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # Uneven Shard(0) without pad - global_tensor = torch.randn(self.world_size + 1, 2) - local_tensors, _ = Shard(0)._split_tensor( - global_tensor, - device_mesh.size(dim=0), - with_padding=False, - contiguous=True, - ) - dist_tensor = DTensor.from_local(local_tensors[self.rank], device_mesh, (Shard(0),)) - self.assertEqual(dist_tensor.size(), global_tensor.size()) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - @with_comms - def test_from_local_without_run_check(self): - torch.manual_seed(0) + # Check Triple kwargs + def _core_test(run_check, given_shape_stride, support_uneven): + torch.manual_seed(0) - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - - # Replicate - local_tensor = torch.randn(3 * self.world_size, 3) - global_tensor = local_tensor - dist_tensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()], run_check=False) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # Partial - local_tensor = torch.randn(3 * self.world_size, 3) - global_tensor = local_tensor - dist_tensor = DTensor.from_local(local_tensor, device_mesh, [Partial()], run_check=False) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - 
self.assertEqual(dist_tensor._local_tensor, global_tensor * self.world_size, atol=0, rtol=0) - - # Even Shard(0) - local_tensors = [torch.randn(3, 3) for _ in range(self.world_size)] - global_tensor = torch.concat(local_tensors, dim=0) - dist_tensor = DTensor.from_local(local_tensors[self.rank], device_mesh, [Shard(0)], run_check=False) - self.assertEqual(dist_tensor.size(), global_tensor.shape) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) - - # Uneven Shard(0) without pad - global_tensor = torch.randn(self.world_size + 1, 2) - local_tensors, _ = Shard(0)._split_tensor( - global_tensor, - device_mesh.size(dim=0), - with_padding=False, - contiguous=True, - ) - dist_tensor = DTensor.from_local( - local_tensors[self.rank], - device_mesh, - (Shard(0),), - run_check=False, - shape=global_tensor.size(), - stride=global_tensor.stride(), - ) - self.assertEqual(dist_tensor.size(), global_tensor.size()) - self.assertEqual(dist_tensor.stride(), global_tensor.stride()) - dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) - self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) + # Replicate + local_tensor = torch.randn(3 * self.world_size, 3) + global_tensor = local_tensor + dist_tensor = DTensor.from_local( + local_tensor, + device_mesh, + [Replicate()], + run_check=run_check, + shape=global_tensor.shape if given_shape_stride else None, + stride=global_tensor.stride() if given_shape_stride else None, + support_uneven=support_uneven, + ) + self.assertEqual(dist_tensor.size(), global_tensor.shape) + self.assertEqual(dist_tensor.stride(), global_tensor.stride()) + dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) + self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) + + # Partial + local_tensor = torch.randn(3 * self.world_size, 3) + global_tensor = local_tensor + dist_tensor = DTensor.from_local( + local_tensor, + device_mesh, + [Partial()], + run_check=run_check, + shape=global_tensor.shape if given_shape_stride else None, + stride=global_tensor.stride() if given_shape_stride else None, + support_uneven=support_uneven, + ) + self.assertEqual(dist_tensor.size(), global_tensor.shape) + self.assertEqual(dist_tensor.stride(), global_tensor.stride()) + dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) + self.assertEqual(dist_tensor._local_tensor, global_tensor * self.world_size, atol=0, rtol=0) + + # Even Shard(0) + local_tensors = [torch.randn(3, 3) for _ in range(self.world_size)] + global_tensor = torch.concat(local_tensors, dim=0) + dist_tensor = DTensor.from_local( + local_tensors[self.rank], + device_mesh, + [Shard(0)], + run_check=run_check, + shape=global_tensor.shape if given_shape_stride else None, + stride=global_tensor.stride() if given_shape_stride else None, + support_uneven=support_uneven, + ) + self.assertEqual(dist_tensor.size(), global_tensor.shape) + self.assertEqual(dist_tensor.stride(), global_tensor.stride()) + dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) + self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) + + # Uneven Shard(0) without pad + global_tensor = torch.randn(self.world_size + 1, 2) + local_tensors, _ = Shard(0)._split_tensor( + global_tensor, + device_mesh.size(dim=0), + with_padding=False, + contiguous=True, + ) + dist_tensor = DTensor.from_local( + local_tensors[self.rank], + 
device_mesh, + (Shard(0),), + run_check=run_check, + shape=global_tensor.shape if given_shape_stride else None, + stride=global_tensor.stride() if given_shape_stride else None, + support_uneven=True, + ) + self.assertEqual(dist_tensor.size(), global_tensor.size()) + self.assertEqual(dist_tensor.stride(), global_tensor.stride()) + dist_tensor = dist_tensor.redistribute(placements=[Replicate()]) + self.assertEqual(dist_tensor._local_tensor, global_tensor, atol=0, rtol=0) + + # # Uneven Shard(0) with pad # TODO + # global_tensor = torch.randn(self.world_size + 1, 2) + # local_tensors, _ = Shard(0)._split_tensor( + # global_tensor, + # device_mesh.size(dim=0), + # with_padding=True, + # contiguous=True, + # ) + # with self.assertRaisesRegex( + # ValueError, "Given global shape and stride does not match local shape and stride!" + # ): + # DTensor.from_local(local_tensors[self.rank], device_mesh, (Shard(0),), + # run_check=run_check, + # shape=global_tensor.shape if given_shape_stride else None, + # stride=global_tensor.stride() if given_shape_stride else None, + # support_uneven=True,) + + for run_check in [True, False]: + for given_shape_stride in [True, False]: + for support_uneven in [True, False]: + _core_test(run_check, given_shape_stride, support_uneven) @with_comms def test_to_local(self): @@ -816,7 +748,7 @@ def test_fakify_dtensor(self): def fn(x): return x - x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False) + x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False, support_uneven=False) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -831,7 +763,7 @@ def test_dynamo_dtensor(self): def fn(x): return x * x + 2 - x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False) + x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False, support_uneven=False) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -844,7 +776,7 @@ def test_dynamo_dtensor_from_local(self): # create DTensor inside fn and run some compute def fn(x): - dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False, support_uneven=False) return dt.to_local() + 2 # below is the op approach for reference @@ -871,7 +803,7 @@ def test_dynamo_dtensor_from_local_redistribute(self): # pass in tensor as inputs/outputs, create DTensor and run redistribute # (allgather collective) inside the fn def fn(x): - dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) + dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False, support_uneven=False) return dt.redistribute(mesh, [Replicate()]).to_local() + 2 x = torch.ones(1) diff --git a/test/dtensor/ops/test_math_ops.py b/test/dtensor/ops/test_math_ops.py new file mode 100644 index 0000000..e2236be --- /dev/null +++ b/test/dtensor/ops/test_math_ops.py @@ -0,0 +1,327 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
+################################################################################ + +import itertools +from unittest import skip +from common_dtensor import ( + DTensorTestBase, + skip_unless_torch_gpu, + with_comms, +) +from unittest import skip + +import torch +from torch.testing._internal.common_utils import run_tests, instantiate_parametrized_tests, parametrize +from vescale import distribute_tensor +from vescale.dtensor.placement_types import Replicate, Shard + + +class DistMathOpsTest(DTensorTestBase): + def linear_op_reductions(self, op_str): + device_mesh = self.build_device_mesh() + shard_spec = [Shard(0)] + + tensor = torch.randn(12, 8, 8) + dtensor = distribute_tensor(tensor, device_mesh, shard_spec) + + op = getattr(tensor, op_str) + op_dt = getattr(dtensor, op_str) + + keep_dim_or_not = [True, False, None] + for dim in range(tensor.ndim): + for keep_dim in keep_dim_or_not: + args = (dim, keep_dim) if keep_dim is not None else (dim,) + if op_str in ("max", "min"): + # min and max return a tuple when dim specified + dim_reduced_tensor, _ = op(*args) + dt_reduced, _ = op_dt(*args) + else: + dim_reduced_tensor = op(*args) + dt_reduced = op_dt(*args) + dt_dim_reduced_tensor = dt_reduced.full_tensor() + self.assertEqual(dt_dim_reduced_tensor, dim_reduced_tensor) + + full_reduced_tensor = op() + dt_full_reduced = op_dt().full_tensor() + self.assertEqual(dt_full_reduced, full_reduced_tensor) + + @with_comms + def test_linear_op_reductions(self): + for op_str in ("all", "sum", "prod", "max", "min"): + self.linear_op_reductions(op_str) + + @with_comms + @skip_unless_torch_gpu + def test_mean(self): + self.linear_op_reductions("mean") + + # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU + @with_comms + @skip("failed") + def test_softmax_fwd(self): + device_mesh = self.build_device_mesh() + + x = torch.rand(8, 12, 16, device=self.device_type) + dims = range(3) # used to convert -1 to the actual dim + softmax_dims = [-1, 0, 1, 2] + shard_dims = [-1, 0, 1, 2] + test_list = list(itertools.product(softmax_dims, shard_dims)) + + for softmax_dim, shard_dim in test_list: + local_y = torch.nn.functional.softmax(x, dim=softmax_dim, dtype=torch.float32) + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + if dims[shard_dim] == dims[softmax_dim]: + with self.assertRaisesRegex(Exception, "Cannot run .* on sharding dimension!$"): + dist_y = torch.nn.functional.softmax(dist_x, dim=softmax_dim, dtype=torch.float32) + else: + dist_y = torch.nn.functional.softmax(dist_x, dim=softmax_dim, dtype=torch.float32) + self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim)) + dist_y = dist_y.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_y.to_local(), local_y) + + @with_comms + @parametrize("func", [torch.argmax, torch.argmin]) + def test_arg_max_arg_min(self, func): + device_mesh = self.build_device_mesh() + shard_spec = [Shard(0)] + tensor = torch.randn(12, 8, 8) + dtensor = distribute_tensor(tensor, device_mesh, shard_spec) + + keep_dim_or_not = [True, False, None] + for dim in range(1, tensor.ndim): + for keep_dim in keep_dim_or_not: + args = (dim, keep_dim) if keep_dim is not None else (dim,) + dt_result = func(dtensor, *args) + t_result = func(tensor, *args) + self.assertEqual(dt_result.full_tensor(), t_result) + + shard_spec = [Replicate()] + tensor = torch.randn(12, 8, 8) + dtensor = distribute_tensor(tensor, device_mesh, shard_spec) + dt_result = func(dtensor) + t_result = func(tensor) + self.assertEqual(dt_result.full_tensor(), t_result) + 
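+    # NOTE: `full_tensor()`, used for the checks throughout these tests, is roughly
+    # shorthand for redistributing every mesh dim back to Replicate and taking the
+    # local tensor, i.e. something like:
+    #     dt.redistribute(device_mesh, [Replicate()] * device_mesh.ndim).to_local()
+    # which is why comparing it against the single-device result is a valid
+    # correctness check for sharded or partial outputs.
+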
+ # TODO: get test_softmax_with_bwd pass on CPU + # DTensor's _softmax_backward_data produces wrong result on CPU on certain dimension. + # fail_on_cpu_list = [(0, -1), (1, -1)] + @with_comms + @skip_unless_torch_gpu + @skip("failed") + def test_softmax_with_bwd(self): + device_mesh = self.build_device_mesh() + + dims = range(3) # used to convert -1 to the actual dim + softmax_dims = [-1, 0, 1, 2] + shard_dims = [-1, 0, 1, 2] + test_list = list(itertools.product(softmax_dims, shard_dims)) + + for params in test_list: + softmax_dim, shard_dim = params + x = torch.rand(8, 12, 16, device=self.device_type, requires_grad=True) + self.assertTrue(x.requires_grad) + local_y = torch.nn.functional.softmax(x, dim=softmax_dim, dtype=torch.float32).sum() + local_y.backward() + + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + self.assertTrue(dist_x.requires_grad) + if dims[softmax_dim] == dims[shard_dim]: + with self.assertRaisesRegex(Exception, "Cannot run .* on sharding dimension!$"): + dist_softmax = dist_x.softmax(dim=softmax_dim) + else: + dist_softmax = dist_x.softmax(dim=softmax_dim) + self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim)) + dist_y = dist_softmax.sum() + dist_y = dist_y.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_y.to_local(), local_y) + self.assertIsNone(dist_x.grad) + dist_y.backward() + self.assertIsNotNone(dist_x.grad) + dist_x_grad = dist_x.grad.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_x_grad.to_local(), x.grad) + + @with_comms + def test_onehot_replicate(self): + device_mesh = self.build_device_mesh() + tensor = torch.randint(0, 8, (8, 8)) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + out = torch.nn.functional.one_hot(tensor, 8) + d_out = torch.nn.functional.one_hot(dtensor, 8) + self.assertTrue(d_out.placements[0].is_replicate()) + self.assertEqual(d_out.to_local(), out) + + @with_comms + def test_onehot_sharded(self): + device_mesh = self.build_device_mesh() + tensor = torch.randint(0, 8, (8, 8)) + dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)]) + out = torch.nn.functional.one_hot(tensor, 8) + d_out = torch.nn.functional.one_hot(dtensor, 8) + self.assertTrue(d_out.placements[0].is_shard(0)) + self.assertEqual(d_out.full_tensor(), out) + + @with_comms + def test_mse_loss(self): + device_mesh = self.build_device_mesh() + tensor = torch.rand((8, 8), requires_grad=True) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + loss = torch.nn.MSELoss() + + label = torch.rand((8, 8)) + d_label = distribute_tensor(label, device_mesh, [Replicate()]) + + local_loss = loss(tensor, label) + d_loss = loss(dtensor, d_label) + local_loss.backward() + d_loss.backward() + + self.assertEqual(tensor.grad, dtensor.grad.to_local()) + + @with_comms + def test_topk(self): + device_mesh = self.build_device_mesh() + tensor = torch.randn(8, 8) + topk_dim = 0 + shard_dim = 1 + dtensor = distribute_tensor(tensor, device_mesh, [Shard(shard_dim)]) + local_result = torch.topk(tensor, 2, topk_dim) + d_result = torch.topk(dtensor, 2, topk_dim) + self.assertTrue(d_result.values.placements[0].is_shard(dim=shard_dim)) + self.assertEqual(d_result.values.full_tensor(), local_result.values) + + @with_comms + def test_topk_backward(self): + device_mesh = self.build_device_mesh() + tensor = torch.randn((8, 8), requires_grad=True) + topk_dim = 0 + shard_dim = 1 + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + local_result = torch.topk(tensor, 2, topk_dim) + d_result = 
torch.topk(dtensor, 2, topk_dim) + self.assertTrue(d_result.values.placements[0].is_replicate()) + self.assertEqual(d_result.values.to_local(), local_result.values) + + loss = local_result.values.sum() + d_loss = d_result.values.sum() + + loss.backward() + d_loss.backward() + + self.assertEqual(tensor.grad, dtensor.grad.to_local()) + + @with_comms + def test_topk_backward_shard(self): + device_mesh = self.build_device_mesh() + tensor = torch.randn((8, 8), requires_grad=True) + topk_dim = 0 + shard_dim = 1 + dtensor = distribute_tensor(tensor, device_mesh, [Shard(1)]) + local_result = torch.topk(tensor, 2, topk_dim) + d_result = torch.topk(dtensor, 2, topk_dim) + self.assertTrue(d_result.values.placements[0].is_shard(dim=1)) + self.assertEqual(d_result.values.full_tensor(), local_result.values) + + loss = local_result.values.sum() + d_loss = d_result.values.full_tensor().sum() + loss.backward() + d_loss.backward() + + self.assertEqual(tensor.grad, dtensor.grad.redistribute(device_mesh, [Replicate()]).to_local()) + + @with_comms + def test_onehot(self): + device_mesh = self.build_device_mesh() + tensor = torch.randint(0, 8, (8, 8)) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + out = torch.nn.functional.one_hot(tensor, 8) + d_out = torch.nn.functional.one_hot(dtensor, 8) + self.assertTrue(d_out.placements[0].is_replicate()) + self.assertEqual(d_out.to_local(), out) + + @with_comms + def test_where(self): + device_mesh = self.build_device_mesh() + tensor = torch.rand((8, 8)) + y = torch.ones((8, 8)) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + d_y = distribute_tensor(y, device_mesh, [Replicate()]) + out = torch.where(tensor > 0, tensor, y) + d_out = torch.where(dtensor > 0, dtensor, d_y) + self.assertTrue(d_out.placements[0].is_replicate()) + self.assertEqual(d_out.to_local(), out) + + dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)]) + d_y = distribute_tensor(y, device_mesh, [Shard(0)]) + + out = torch.where(tensor > 0, tensor, y) + d_out = torch.where(dtensor > 0, dtensor, d_y) + self.assertTrue(d_out.placements[0].is_shard(dim=0)) + self.assertEqual(d_out.full_tensor(), out) + + @with_comms + def test_where_backward(self): + device_mesh = self.build_device_mesh() + tensor = torch.rand((8, 8), requires_grad=True) + y = torch.ones((8, 8), requires_grad=True) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + d_y = distribute_tensor(y, device_mesh, [Replicate()]) + out = torch.where(tensor > 0, tensor, y) + d_out = torch.where(dtensor > 0, dtensor, d_y) + self.assertTrue(d_out.placements[0].is_replicate()) + self.assertEqual(d_out.to_local(), out) + loss = out.sum() + loss.backward() + d_loss = d_out.sum() + d_loss.backward() + self.assertTrue(dtensor.grad.placements[0].is_replicate()) + self.assertEqual(dtensor.grad.to_local(), tensor.grad) + + @with_comms + def test_where_backward_shard(self): + device_mesh = self.build_device_mesh() + tensor = torch.rand((8, 8), requires_grad=True) + y = torch.ones((8, 8), requires_grad=True) + dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)]) + d_y = distribute_tensor(y, device_mesh, [Shard(0)]) + out = torch.where(tensor > 0, tensor, y) + d_out = torch.where(dtensor > 0, dtensor, d_y) + loss = out.sum() + loss.backward() + d_loss = d_out.redistribute(device_mesh, [Replicate()]).sum() + d_loss.backward() + self.assertTrue(dtensor.grad.placements[0].is_shard(dim=0)) + self.assertEqual(dtensor.grad.full_tensor(), tensor.grad) + + @with_comms + def test_unique(self): + # TODO: 
support specifying dim, and it should be implemented in aten.unique_dim + device_mesh = self.build_device_mesh() + tensor = torch.randn(8, 8) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + local_result = torch.unique(tensor) + d_result = torch.unique(dtensor) + self.assertEqual(d_result.to_local(), local_result) + + local_result, local_inverse = torch.unique(tensor, return_inverse=True) + d_result, d_inverse = torch.unique(dtensor, return_inverse=True) + self.assertEqual(d_result.to_local(), local_result) + self.assertEqual(d_inverse.to_local(), local_inverse) + + local_result, local_counts = torch.unique(tensor, return_counts=True) + d_result, d_counts = torch.unique(dtensor, return_counts=True) + self.assertEqual(d_result.to_local(), local_result) + self.assertEqual(d_counts.to_local(), local_counts) + + +instantiate_parametrized_tests(DistMathOpsTest) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dtensor/ops/test_pointwise_ops.py b/test/dtensor/ops/test_pointwise_ops.py index 0f42c1b..9b5207f 100644 --- a/test/dtensor/ops/test_pointwise_ops.py +++ b/test/dtensor/ops/test_pointwise_ops.py @@ -246,6 +246,7 @@ def test_dropout_backward(self): ), ) + @skip("allowing partial dropout") def test_dropout_errors(self): device_mesh = self.build_device_mesh() with self.assertRaisesRegex(RuntimeError, "supported"): @@ -274,6 +275,16 @@ def test_mul_out(self): self.assertEqual(input_tensor, dtensor.to_local()) self.assertEqual(expected, dt.to_local()) + def test_triu(self): + device_mesh = self.build_device_mesh() + input_size = (8, 4) + tensor = torch.randn(*input_size, device=self.device_type) + d_tensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + + out = torch.triu(tensor) + d_out = torch.triu(d_tensor) + self.assertEqual(d_out.to_local(), out) + if __name__ == "__main__": run_tests() diff --git a/test/dtensor/ops/test_tensor_ops.py b/test/dtensor/ops/test_tensor_ops.py new file mode 100644 index 0000000..62c156d --- /dev/null +++ b/test/dtensor/ops/test_tensor_ops.py @@ -0,0 +1,483 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
+################################################################################ + +from common_dtensor import DTensorConverter, DTensorTestBase, with_comms +from unittest import skip + +import torch +from torch.testing._internal.common_utils import run_tests +from unittest import skip + +from vescale import DeviceMesh, DTensor, distribute_tensor +from vescale.dtensor._diff import EnablePartialMode +from vescale.dtensor.placement_types import Partial, Replicate, Shard + + +class DistTensorOpsTest(DTensorTestBase): + @with_comms + def test_aten_contiguous(self): + # this op not covered by dtensor_ops + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + self._test_op( + mesh, + lambda x: torch.ops.aten.contiguous(x), + torch.randn(16, 32), + ) + + @with_comms + def test_detach(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + tensor_to_detach = torch.randn(12, 8, requires_grad=True) + mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec) + detached_mat = mat.detach() + self.assertFalse(detached_mat is mat) + + @with_comms + def test_clone(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + specs = [[Replicate()], [Shard(0)]] + tensor_to_clone = torch.randn(12, 8, requires_grad=True) + for spec in specs: + mat = distribute_tensor(tensor_to_clone, device_mesh, spec) + cloned_mat = mat.clone() + self.assertFalse(cloned_mat is mat) + self.assertEqual(cloned_mat.to_local(), mat.to_local()) + + @with_comms + def test_contiguous(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + tensor = torch.rand(3, 5, 6, requires_grad=True) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + self.assertTrue(dist_tensor.is_contiguous()) + # shard on dim 0 should not change stride (30, 6, 1) + self.assertEqual(dist_tensor.stride(), tensor.stride()) + + new_dt = dist_tensor.transpose(0, 2) + self.assertFalse(new_dt.is_contiguous()) + self.assertFalse(new_dt.to_local().is_contiguous()) + # check stride + self.assertEqual(new_dt.stride(), (1, 6, 30)) + + new_dt = new_dt.contiguous() + self.assertTrue(new_dt.is_contiguous()) + self.assertTrue(new_dt.to_local().is_contiguous()) + # check stride + self.assertEqual(dist_tensor.stride(), tensor.stride()) + + # check backward + new_dt.to_local().sum().backward() + self.assertEqual(tensor.grad, torch.ones(3, 5, 6)) + + @with_comms + @skip("fail") + def test_inplace_op(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_tensor = torch.randn((12, 3), device=self.device_type) + dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) + dt_to_mul = dt_to_add.clone() + expected_add_dt = dt_to_add.clone() + 3 + add_res = dt_to_add.add_(3) + expected_mul_dt = dt_to_mul.clone() * 3 + mul_res = dt_to_mul.mul_(3) + # inplace op should be the same instance before and after + self.assertTrue(add_res is dt_to_add) + self.assertEqual(add_res.to_local(), expected_add_dt.to_local()) + + self.assertTrue(mul_res is dt_to_mul) + self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local()) + + # test inplace op self and other dtensor with other specs + # and make sure out spec not change + shard_spec = [Shard(0)] + partial_spec = [Partial()] + dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec) + partial_grad = DTensor.from_local(torch.randn(12, 3), mesh, partial_spec) + res = dt_to_inplace_add.add_(partial_grad) + self.assertTrue(res is 
dt_to_inplace_add) + self.assertTrue(res.placements == shard_spec) + + @with_comms + @skip("failed") + def test_op_out_variant(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_tensor = torch.randn((12, 3), device=self.device_type) + sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) + expected_dt = sharded_dt_input.clone() + 3 + sharded_dt_out = sharded_dt_input.clone() + res = torch.add(sharded_dt_input, 3, out=sharded_dt_out) + # op out variant should be the same instance before and after + self.assertTrue(res is sharded_dt_out) + self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local()) + + # test op out variant with other spec and make sure out spec not change + replica_spec = [Replicate()] + replicate_out = distribute_tensor(input_tensor, mesh, replica_spec) + expected_dt = replicate_out.clone() + 3 + res = torch.add(sharded_dt_input, 3, out=replicate_out) + self.assertTrue(res is replicate_out) + self.assertTrue(res.placements == replica_spec) + self.assertEqual(replicate_out.to_local(), expected_dt.to_local()) + + @with_comms + def test_empty_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + empty_like_dt = torch.empty_like(dist_tensor) + # empty is not deterministic, so we only check that the shard propagation worked + self.assertEqual((4, 8), empty_like_dt.to_local().shape) + + @with_comms + def test_fill_inplace(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + full_like_dt = torch.fill_(dist_tensor, 42.0) + full_expected = torch.full((4, 8), 42.0) + self.assertEqual(full_expected, full_like_dt.to_local()) + self.assertEqual(full_expected, dist_tensor.to_local()) + + @with_comms + def test_full_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + full_like_dt = torch.full_like(dist_tensor, 42.0) + full_expected = torch.full((4, 8), 42.0) + self.assertEqual(full_expected, full_like_dt.to_local()) + + @with_comms + def test_ones_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + ones_like_dt = torch.ones_like(dist_tensor) + ones_expected = torch.ones(4, 8) + self.assertEqual(ones_expected, ones_like_dt.to_local()) + + @with_comms + @skip("failed") + def test_ones_like_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + with EnablePartialMode(): + ones_like_dt = torch.ones_like(dist_tensor) + ones_expected = torch.ones(dist_tensor.shape) + assert isinstance(ones_like_dt.placements[0], Partial) + ones_like_dt_replicate = torch.ones_like(dist_tensor) + assert isinstance(ones_like_dt_replicate.placements[0], Replicate) + + 
self.assertEqual( + ones_expected, + ones_like_dt.to_local(), + ) + + @with_comms + @skip("failed") + def test_fill_inplace_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + torch.fill_(dist_tensor, 42) + fill_expected = torch.full(dist_tensor.shape, 42, dtype=input_tensor.dtype) + self.assertEqual( + fill_expected, + dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + @skip("failed") + def test_zeros_like_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + with EnablePartialMode(): + zeros_like_dt = torch.zeros_like(dist_tensor) + assert isinstance(zeros_like_dt.placements[0], Partial) + zeros_like_dt_replicate = torch.zeros_like(dist_tensor) + assert isinstance(zeros_like_dt_replicate.placements[0], Replicate) + zeros_expected = torch.zeros(dist_tensor.shape) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + + @with_comms + def test_zero_inplace(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + zeros_like_dt = torch.zero_(dist_tensor) + zeros_expected = torch.zeros(4, 8) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + self.assertEqual(zeros_expected, dist_tensor.to_local()) + + @with_comms + def test_zeros_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + zeros_like_dt = torch.zeros_like(dist_tensor) + zeros_expected = torch.zeros(4, 8) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + + @with_comms + def test_equal(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor_1 = torch.ones(4, 4) + dist_tensor_1 = DTensor.from_local(input_tensor_1, device_mesh, shard_spec) + + # tensors are equal + input_tensor_2 = torch.ones(4, 4) + dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec) + + eq_result = dist_tensor_1.equal(dist_tensor_2) + self.assertTrue(eq_result) + + # tensors are different on some shards + if self.rank == 0: + input_tensor_2 = torch.ones(4, 4) + else: + input_tensor_2 = torch.randn(4, 4) + dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec) + + eq_result = dist_tensor_1.equal(dist_tensor_2) + # equal op all reduces each shard's local result + self.assertFalse(eq_result) + + def _test_op(self, mesh, op_call, *args, **kwargs): + out = op_call(*args, **kwargs) + dtc = DTensorConverter(mesh, args, kwargs) + for d_args, d_kwargs in dtc: + self.assertTrue(dtc.successful()) + d_out = op_call(*d_args, **d_kwargs) + self.assertEqual( + d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local(), + out, + ) + + @with_comms + def test_select(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + shard_spec_1 
= [Shard(1)] + input_tensor_1 = torch.rand(4, 8) + dist_tensor_1 = distribute_tensor(input_tensor_1, device_mesh, shard_spec_1) + dist_result_1 = dist_tensor_1[1] + self.assertEqual(dist_result_1.redistribute(device_mesh, [Replicate()]).to_local(), input_tensor_1[1]) + + shard_spec_2 = [Shard(0)] + input_tensor_2 = torch.rand(4, 8) + dist_tensor_2 = distribute_tensor(input_tensor_2, device_mesh, shard_spec_2) + dist_result_2 = dist_tensor_2[:, 1] + self.assertEqual(dist_result_2.redistribute(device_mesh, [Replicate()]).to_local(), input_tensor_2[:, 1]) + + @with_comms + @skip("failed") + def test_index_select(self): + meshes = [ + DeviceMesh(self.device_type, list(range(self.world_size))), # 1D mesh + # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh + # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh + ] + for mesh in meshes: + self._test_op( + mesh, + lambda x, y: x[y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8)), + ) + self._test_op( + mesh, + lambda x, y: x.index_select(1, y), + torch.randn(16, 32, 16), + torch.randint(5, (4,)), + ) + self._test_op( + mesh, + lambda x, y: x.index_select(0, y), + torch.randn(16, 32, 16), + torch.randint(5, (4,)), + ) + self._test_op( + mesh, + lambda x, y: x[y], + torch.randn(16, 32, 16), + torch.randint(5, (12,)), + ) + self._test_op( + mesh, + lambda x, y: x[:, y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8)), + ) + self._test_op( + mesh, + lambda x, y: x[..., y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 12)), + ) + self._test_op( + mesh, + lambda x, y: x[..., y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8, 16)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, y], + torch.randn(16, 32, 16), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, y], + torch.randn(16, 32, 16), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + # broadcast in inner dimensions + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 1, 12)), + ) + # implicit (left-padded) broadcast + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, y, :, :], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 12)), + torch.randint(5, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, y, :], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 12)), + torch.randint(5, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 1)), + torch.randint(5, (12, 8, 12)), + ) + + @with_comms + def test_index_put_(self): + device_mesh = DeviceMesh(self.device_type, [0, 1, 2, 3]) + inout_sharding = [Replicate()] + partial_sharding = [Partial()] + x = torch.rand(16, 8) + idx = torch.tensor([0, 1, 2, 3]) + src = torch.rand(4, 8) + src2 = torch.rand(4, 8) + x_residual = torch.rand(16, 8) + dx = distribute_tensor(x, device_mesh, partial_sharding) + dx_residual = distribute_tensor(x_residual, device_mesh, partial_sharding) + didx = distribute_tensor(idx, device_mesh, inout_sharding) + dsrc1 = distribute_tensor(src, device_mesh, 
partial_sharding) + dsrc2 = distribute_tensor(src2, device_mesh, partial_sharding) + dsrc1.requires_grad_(True) + dsrc2.requires_grad_(True) + dsrc = dsrc1 + dsrc2 + # out = torch.ops.aten.index_put_(dx, [didx], dsrc) + dx[didx] = dsrc + out = dx + dx_residual + out.redistribute(device_mesh, inout_sharding) + loss = out.mean() + loss.backward() + + @with_comms + def test_scatter(self): + device_mesh = self.build_device_mesh() + tensor = torch.zeros(3, 4) + index = torch.tensor([[1], [2]]) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + dindex = distribute_tensor(index, device_mesh, [Replicate()]) + + local_result = torch.scatter(tensor, 0, index, 1) + d_result = torch.scatter(dtensor, 0, dindex, 1) + self.assertEqual(d_result.to_local(), local_result) + + @with_comms + def test_expand_with_broadcast(self): + device_mesh = self.build_device_mesh() + tensor = torch.randn((4,)) + matrix = torch.randn((2, 3, 4)) + dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) + dmatrix = distribute_tensor(matrix, device_mesh, [Shard(0)]) + dout = dtensor.expand_as(dmatrix) + assert dout._spec.placements[0] == Shard(0), f"sharding error {dout._spec}" + + @with_comms + @skip("failed") + def test_stack(self): + device_mesh = DeviceMesh( + self.device_type, + [0, 1], + ) + x = torch.rand([4, 2, 4, 8]) + # y = torch.rand([32, 2, 4, 128]) + dx = distribute_tensor(x, device_mesh, [Replicate()]).requires_grad_(True) + # dy = distribute_tensor(y, device_mesh, [Shard(1)]).requires_grad_(True) + dx = torch.chunk(dx.transpose(1, 2).float(), 2, dim=-1) + dx = torch.stack(dx) + # torch.autograd.backward(dout, torch.ones_like(dout)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/parallel/ddp_optim/test_clip_grads.py b/test/parallel/ddp_optim/test_clip_grads.py index bd628ed..d3eb833 100644 --- a/test/parallel/ddp_optim/test_clip_grads.py +++ b/test/parallel/ddp_optim/test_clip_grads.py @@ -40,12 +40,13 @@ class VeScaleClipGradsTest(DTensorTestBase): def world_size(self): return 4 - def golden_run(self, params_and_inputs, max_norm): + def golden_run(self, params_and_inputs, max_norm, dtype): m = MLP(HIDDEN_DIM).cuda() m.fc1.weight = torch.nn.Parameter(params_and_inputs["fc1.weight"]) m.fc1.bias = torch.nn.Parameter(params_and_inputs["fc1.bias"]) m.fc2.weight = torch.nn.Parameter(params_and_inputs["fc2.weight"]) m.fc2.bias = torch.nn.Parameter(params_and_inputs["fc2.bias"]) + m.to(dtype) optimizer = torch.optim.Adam(m.parameters(), lr=0.01) @@ -65,12 +66,13 @@ def golden_run(self, params_and_inputs, max_norm): @with_comms @parametrize("max_norm", [2.0]) - def test_clip_grad(self, max_norm): + @parametrize("dtype", [torch.float, torch.bfloat16]) + def test_clip_grad(self, max_norm, dtype): tp_parallel_size = 2 dp_size = self.world_size // tp_parallel_size device_mesh = init_device_mesh(self.device_type, (dp_size, tp_parallel_size), mesh_dim_names=("DP", "TP")) - params_and_inputs = get_unfied_param_and_data(BSZ, HIDDEN_DIM) + params_and_inputs = get_unfied_param_and_data(BSZ, HIDDEN_DIM, dtype) new_params_and_inputs = copy.deepcopy(params_and_inputs) tp_sub_mesh = device_mesh["TP"] dp_pg = device_mesh.get_dim_groups(0) @@ -81,6 +83,7 @@ def test_clip_grad(self, max_norm): ve_model.fc1.bias = torch.nn.Parameter(params_and_inputs["fc1.bias"]) ve_model.fc2.weight = torch.nn.Parameter(params_and_inputs["fc2.weight"]) ve_model.fc2.bias = torch.nn.Parameter(params_and_inputs["fc2.bias"]) + ve_model.to(dtype) ve_model = parallelize_module( ve_model, tp_sub_mesh, {"parameter": 
MLP_PAIRWISE_PARAM_SHARDING_PLAN, "forward": MLP_FWD_RESAHRDING_PLAM} @@ -115,32 +118,32 @@ def test_clip_grad(self, max_norm): # do the grad norm clipping ve_optimizer.clip_grad_norm(ve_optimizer.clip_grad) - golden_mlp = self.golden_run(new_params_and_inputs, max_norm=max_norm) + golden_mlp = self.golden_run(new_params_and_inputs, max_norm=max_norm, dtype=dtype) golden_fc1_weight_grad = distribute_tensor( golden_mlp.fc1.weight.grad.data, tp_sub_mesh, MLP_PAIRWISE_PARAM_SHARDING_PLAN["fc1.weight"] - )._local_tensor + )._local_tensor.to(dtype) golden_fc1_bias_grad = distribute_tensor( golden_mlp.fc1.bias.grad.data, tp_sub_mesh, MLP_PAIRWISE_PARAM_SHARDING_PLAN["fc1.bias"] - )._local_tensor + )._local_tensor.to(dtype) golden_fc2_weight_grad = distribute_tensor( golden_mlp.fc2.weight.grad.data, tp_sub_mesh, MLP_PAIRWISE_PARAM_SHARDING_PLAN["fc2.weight"] - )._local_tensor + )._local_tensor.to(dtype) golden_fc2_bias_grad = distribute_tensor( golden_mlp.fc2.bias.grad.data, tp_sub_mesh, MLP_PAIRWISE_PARAM_SHARDING_PLAN["fc2.bias"] - )._local_tensor + )._local_tensor.to(dtype) if self.rank in [0, 1]: optimizer_params = ve_optimizer.get_parameters() - ve_fc2_bias_grad = optimizer_params[0].grad - ve_fc2_weight_grad = optimizer_params[1].grad - ve_fc1_bias_head_2_grad = optimizer_params[2].grad + ve_fc2_bias_grad = optimizer_params[0].grad.to(dtype) + ve_fc2_weight_grad = optimizer_params[1].grad.to(dtype) + ve_fc1_bias_head_2_grad = optimizer_params[2].grad.to(dtype) torch.testing.assert_close(golden_fc2_bias_grad, ve_fc2_bias_grad) torch.testing.assert_close(golden_fc2_weight_grad.flatten(), ve_fc2_weight_grad) torch.testing.assert_close(golden_fc1_bias_grad[:2,], ve_fc1_bias_head_2_grad) if self.rank in [2, 3]: optimizer_params = ve_optimizer.get_parameters() - ve_fc1_bias_tail_6_grad = optimizer_params[0].grad - ve_fc1_weight_grad = optimizer_params[1].grad + ve_fc1_bias_tail_6_grad = optimizer_params[0].grad.to(dtype) + ve_fc1_weight_grad = optimizer_params[1].grad.to(dtype) torch.testing.assert_close(golden_fc1_bias_grad[2:], ve_fc1_bias_tail_6_grad) torch.testing.assert_close(golden_fc1_weight_grad.flatten(), ve_fc1_weight_grad) diff --git a/test/parallel/ddp_optim/test_ddp.py b/test/parallel/ddp_optim/test_ddp.py index 0aafaa2..6a8046f 100644 --- a/test/parallel/ddp_optim/test_ddp.py +++ b/test/parallel/ddp_optim/test_ddp.py @@ -42,16 +42,16 @@ ) -def get_unfied_param_and_data(bsz, hidden_dim): - fc1_weight = torch.rand(hidden_dim * 4, hidden_dim).cuda() - fc1_bias = torch.rand(hidden_dim * 4).cuda() - fc2_weight = torch.rand(hidden_dim, hidden_dim * 4).cuda() - fc2_bias = torch.rand(hidden_dim).cuda() - - batch1_epoch1 = torch.rand(bsz, hidden_dim).cuda() - batch2_epoch1 = torch.rand(bsz, hidden_dim).cuda() - batch1_epoch2 = torch.rand(bsz, hidden_dim).cuda() - batch2_epoch2 = torch.rand(bsz, hidden_dim).cuda() +def get_unfied_param_and_data(bsz, hidden_dim, dtype=torch.float): + fc1_weight = torch.rand(hidden_dim * 4, hidden_dim, dtype=dtype).cuda() + fc1_bias = torch.rand(hidden_dim * 4, dtype=dtype).cuda() + fc2_weight = torch.rand(hidden_dim, hidden_dim * 4, dtype=dtype).cuda() + fc2_bias = torch.rand(hidden_dim, dtype=dtype).cuda() + + batch1_epoch1 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + batch2_epoch1 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + batch1_epoch2 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() + batch2_epoch2 = torch.rand(bsz, hidden_dim, dtype=dtype).cuda() # allreduce parameter and batches to make sure they are same at all ranks 
torch.distributed.all_reduce(fc1_weight) diff --git a/test/parallel/ddp_optim/test_grad_sync.py b/test/parallel/ddp_optim/test_grad_sync.py index ad95522..f4c8bef 100644 --- a/test/parallel/ddp_optim/test_grad_sync.py +++ b/test/parallel/ddp_optim/test_grad_sync.py @@ -154,7 +154,7 @@ def test_ddp(self, overlap_grad_reduce: bool, use_distributed_optimizer: bool): batch_1 = deepcopy(base_batch_1) batch_2 = deepcopy(base_batch_2) - # ------------- baseline start ------------- # + # ------------- baseline ------------- # base_ln_model = torch.nn.LayerNorm(HIDDEN_DIM) base_ln_model.weight = torch.nn.Parameter(base_ln_weight_param) base_ln_model.bias = torch.nn.Parameter(base_ln_bias_param) @@ -173,7 +173,7 @@ def test_ddp(self, overlap_grad_reduce: bool, use_distributed_optimizer: bool): base_ln_weight = base_ln_model.weight base_ln_bias = base_ln_model.bias - # ------------- baseline end ------------- # + # ------------- vescale ddp ------------- # m = LN(HIDDEN_DIM) m.ln.weight = torch.nn.Parameter(ln_weight_param) @@ -231,15 +231,6 @@ def test_ddp(self, overlap_grad_reduce: bool, use_distributed_optimizer: bool): # -------- check results -------- # - grad_sync_list = ddp_m.module.list_grad_sync() - fqn_sync_list = set([fqn for fqn, _ in grad_sync_list]) # noqa: C403 - self.assertTrue(len(grad_sync_list) == 2) - self.assertTrue("ln.weight.main_grad" in fqn_sync_list) - self.assertTrue("ln.bias.main_grad" in fqn_sync_list) - - self.assertTrue(ddp_ln_weight_grad._spec.placements[0].is_replicate()) - self.assertTrue(ddp_ln_bias_grad._spec.placements[0].is_replicate()) - # NOTE: we can do the following check just because of such a conincidence: # the bias and weight parameter of LayerNorm occupy the same size of memory, # after reduce_scatter grad along DP dimension, DP 0 will hold the whole