forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathamd_support.h
356 lines (325 loc) · 12.5 KB
/
amd_support.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
/*
Goal: unobtrusively provide support for AMD devices with minimal changes to the main CUDA code
Example (assuming ROCm 6.1.1 installed in /opt/rocm, or ROCM_PATH environment variable is set):
*/
#pragma once
#ifdef MULTI_GPU
#include <mpi.h>
#include <rccl/rccl.h>
#endif
#if defined(__gfx1100__) || defined(__gfx1103__)
#define AMD_TARGET_ARCH_RDNA3
#elif defined(__gfx90a__)
#define AMD_TARGET_ARCH_CDNA2
#elif defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define AMD_TARGET_ARCH_CDNA3
#endif
#include <hip/hip_bfloat16.h>
#ifndef DISABLE_CK
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/ck.hpp"
// Shorthand for ck::Sequence, used to spell the integer-sequence tuning parameters below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// GEMM specialization handed to both CK device ops in matmul_forward_gfx11.
// NOTE(review): MNKPadding presumably makes CK pad M/N/K that are not multiples of the tile sizes - confirm against the CK docs.
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// cublaslt does not have kernels for gfx11, so the best alternative in terms of perf/effort seems to be Composable Kernel (CK)
// somewhat janky to invoke with all of the templating, but it works.
/*
 * bf16 matmul forward for gfx11, implemented with CK WMMA device ops.
 *
 * out     - output buffer written by the GEMM
 * inp     - first GEMM operand (layout declared RowMajor below)
 * weight  - second GEMM operand (layout declared ColumnMajor below)
 * bias    - optional; NULL selects the plain GEMM, otherwise the GEMM with an
 *           element-wise Add of the bias buffer is used
 * B, T, C, OC - problem sizes; the GEMM is built with the sizes (B*T, OC, C)
 * stream  - stream passed to the CK invoker through StreamConfig
 *
 * NOTE(review): shapes are presumably inp (B*T, C), weight (OC, C), bias (OC),
 * out (B*T, OC) - confirm with the callers.
 * NOTE(review): the argument is not validated before Run() (e.g. through the
 * device op's IsSupportedArgument, if available) and the value returned by
 * Run() is discarded.
 */
static inline void matmul_forward_gfx11(hip_bfloat16* out,
const hip_bfloat16* inp, const hip_bfloat16* weight, const hip_bfloat16* bias,
int B, int T, int C, int OC, cudaStream_t stream) {
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
// c_element_op is only consumed by the bias == NULL branch,
// cde_element_op only by the bias branch.
auto c_element_op = CElementOp{};
auto cde_element_op = CDEElementOp{};
if (bias == NULL) {
// No bias: plain GEMM with pass-through element ops.
auto device_op = ck::tensor_operation::device::DeviceGemmWmma_CShuffle <
// tensor layouts (presumably A, B, C in that order - confirm against the CK header)
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
// data types: bf16 everywhere except one float slot (presumably the accumulator - confirm)
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::bhalf_t,
// element-wise ops and GEMM specialization
AElementOp,
BElementOp,
CElementOp,
GemmSpec,
// NOTE(review): everything below is positional CK tuning (presumably block size,
// per-block tile sizes, WMMA tile sizes, repeats, block-transfer and C-shuffle
// settings) - verify the order against the CK header before changing any value.
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
1>{};
auto invoker = device_op.MakeInvoker();
// Constness of inp/weight is cast away because non-const bhalf_t pointers are passed here.
auto argument = device_op.MakeArgument(
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(inp)),
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(weight)),
reinterpret_cast<ck::bhalf_t*>(out),
// presumably M, N, K - confirm
B*T,
OC,
C,
// presumably the strides of inp, weight and out - confirm
C,
C,
OC,
a_element_op,
b_element_op,
c_element_op);
invoker.Run(argument, StreamConfig{stream});
} else {
// Bias present: GEMM with one extra input tensor combined through the Add element op.
auto device_op = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle <
// tensor layouts; the ck::Tuple entry describes the single extra (bias) tensor
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::Tuple<ck::tensor_layout::gemm::RowMajor>,
ck::tensor_layout::gemm::RowMajor,
// data types: bf16 everywhere except one float slot (presumably the accumulator - confirm)
ck::bhalf_t,
ck::bhalf_t,
ck::Tuple<ck::bhalf_t>,
ck::bhalf_t,
float,
ck::bhalf_t,
// element-wise ops and GEMM specialization
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
// NOTE(review): positional CK tuning values; same values as the no-bias branch
// except that this list has no trailing "1" - verify against the CK header.
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8>{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(inp)),
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(weight)),
// single-element array holding the bias pointer
std::array<const void*, 1>{reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(bias))},
reinterpret_cast<ck::bhalf_t*>(out),
// presumably M, N, K - confirm
B*T,
OC,
C,
// presumably the strides of inp and weight - confirm
C,
C,
// NOTE(review): presumably the bias stride; 0 would reuse the same bias values for every row - confirm
std::array<ck::index_t, 1>{0},
// presumably the stride of out - confirm
OC,
a_element_op,
b_element_op,
cde_element_op);
invoker.Run(argument, StreamConfig{stream});
}
}
#endif
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <rocblas/rocblas.h>
#include <hipblaslt/hipblaslt.h>
#include <hip/hip_cooperative_groups.h>
// macros below handle mostly cublaslt stuff not handled by hipify (yet)
// --- cuBLASLt names mapped one-to-one onto their hipBLASLt counterparts ---
#define cublasLtMatmulPreferenceSetAttribute hipblasLtMatmulPreferenceSetAttribute
#define CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
#define cublasLtMatmulPreferenceCreate hipblasLtMatmulPreferenceCreate
#define cublasLtMatmulDescSetAttribute hipblasLtMatmulDescSetAttribute
#define cublasLtMatmulPreferenceDestroy hipblasLtMatmulPreferenceDestroy
#define cublasLtMatmulDescDestroy hipblasLtMatmulDescDestroy
#define cublasLtMatmulAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic
#define cublasLtMatrixLayoutDestroy hipblasLtMatrixLayoutDestroy
#define CUBLASLT_EPILOGUE_GELU_BIAS HIPBLASLT_EPILOGUE_GELU_BIAS
#define CUBLASLT_EPILOGUE_GELU HIPBLASLT_EPILOGUE_GELU
#define CUBLASLT_EPILOGUE_BIAS HIPBLASLT_EPILOGUE_BIAS
#define CUBLASLT_EPILOGUE_DEFAULT HIPBLASLT_EPILOGUE_DEFAULT
#define cublasLtEpilogue_t hipblasLtEpilogue_t
#define cublasLtMatmulHeuristicResult_t hipblasLtMatmulHeuristicResult_t
#define cublasLtMatrixLayout_t hipblasLtMatrixLayout_t
#define cublasLtMatmulPreference_t hipblasLtMatmulPreference_t
#define cublasLtMatmulDesc_t hipblasLtMatmulDesc_t
#define cublasLtHandle_t hipblasLtHandle_t
#define cublasLtMatmul hipblasLtMatmul
#define CUBLASLT_MATMUL_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA
#define CUBLASLT_MATMUL_DESC_TRANSB HIPBLASLT_MATMUL_DESC_TRANSB
#define CUBLASLT_MATMUL_DESC_EPILOGUE HIPBLASLT_MATMUL_DESC_EPILOGUE
#define CUBLASLT_MATMUL_DESC_BIAS_POINTER HIPBLASLT_MATMUL_DESC_BIAS_POINTER
#define cublasLtCreate hipblasLtCreate
#define cublasLtDestroy hipblasLtDestroy
#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
#define cublasLtMatmulDescCreate hipblasLtMatmulDescCreate
// --- math-mode setters become a constant success status; their arguments are discarded ---
#define cublasSetMathMode(handle, mode) HIPBLAS_STATUS_SUCCESS
#define hipblasSetMathMode(handle, mode) HIPBLAS_STATUS_SUCCESS
// --- cuBLAS enum/type names mapped onto hipBLAS names ---
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define cublasMath_t hipblasMath_t
#define CUBLAS_TF32_TENSOR_OP_MATH HIPBLAS_TF32_TENSOR_OP_MATH
#define CUBLAS_DEFAULT_MATH HIPBLAS_DEFAULT_MATH
// --- calls turned into no-ops: each expands to a constant or an empty block and
// --- never evaluates its arguments
#define hipFuncSetAttribute(x,y,z) 0
#define hipProfilerStart(x) hipSuccess
#define hipProfilerStop(x) hipSuccess
#define nvtxRangePush(x) {}
#define nvtxRangePop(x) {}
#define nvtxNameCudaStreamA(x,y) {}
#define cublasSetWorkspace(x,y,z) HIPBLAS_STATUS_SUCCESS
#define nvtxNameCudaEventA(x,y) {}
// Two-argument form only: forwards to hipStreamWaitEvent with the third argument fixed to 0.
#define cudaStreamWaitEvent(x,y) hipStreamWaitEvent(x,y,0)
// CUDA-named float -> hip_bfloat16 conversion, delegating to
// hip_bfloat16::round_to_bfloat16.
static __device__ __forceinline__ hip_bfloat16 __float2bfloat16_rn(float f) {
    hip_bfloat16 narrowed = hip_bfloat16::round_to_bfloat16(f);
    return narrowed;
}
// CUDA-named hip_bfloat16 -> float conversion, using the type's own
// conversion to float.
static __device__ __forceinline__ float __bfloat162float(hip_bfloat16 f) {
    const float widened = static_cast<float>(f);
    return widened;
}
// CUDA-style __shfl_xor_sync built on __shfl_xor.
// The participation mask is accepted so CUDA call sites compile unchanged, but
// it is not forwarded to the underlying shuffle.
template <typename T>
static __device__ __forceinline__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize) {
    (void)mask;  // intentionally unused
    T exchanged = __shfl_xor(var, laneMask, width);
    return exchanged;
}
// CUDA-style __shfl_down_sync built on __shfl_down.
// The participation mask is accepted so CUDA call sites compile unchanged, but
// it is not forwarded; laneMask is passed straight through as the second
// argument of __shfl_down.
template <typename T>
static __device__ __forceinline__ T __shfl_down_sync(unsigned mask, T var, int laneMask, int width=warpSize) {
    (void)mask;  // intentionally unused
    T shifted = __shfl_down(var, laneMask, width);
    return shifted;
}
// provide cache hints where possible
// Route the CUDA cache-hinted store/load intrinsics to the patched_* overloads.
#define __stcs(ptr, val) patched_stcs(ptr, val)
#define __ldcs(ptr) patched_ldcs(ptr)
// __stcg degrades to a plain store. It is wrapped in do/while(0) so that
// `__stcg(p, v);` is exactly one statement (usable in an un-braced if/else),
// and the value is parenthesized so any expression can be passed safely.
#define __stcg(ptr, val) do { *(ptr) = (val); } while (0)
// patched_stcs overloads: the targets of the __stcs macro.

// float: a single non-temporal store.
static __device__ __forceinline__ void patched_stcs(float *addr, float val) {
    __builtin_nontemporal_store(val, addr);
}

// hip_bfloat16: ordinary store, no cache hint is applied.
static __device__ __forceinline__ void patched_stcs(hip_bfloat16 *addr, hip_bfloat16 val) {
    addr[0] = val;
}

// int4: written as four consecutive int non-temporal stores (x, y, z, w).
static __device__ __forceinline__ void patched_stcs(int4 *addr, int4 val) {
    int *words = (int *)addr;
    __builtin_nontemporal_store(val.x, words + 0);
    __builtin_nontemporal_store(val.y, words + 1);
    __builtin_nontemporal_store(val.z, words + 2);
    __builtin_nontemporal_store(val.w, words + 3);
}
// patched_ldcs overloads: the targets of the __ldcs macro.

// float: a single non-temporal load.
static __device__ __forceinline__ float patched_ldcs(const float *addr) {
    const float loaded = __builtin_nontemporal_load(addr);
    return loaded;
}

// int4: read as four consecutive int non-temporal loads, then repacked.
static __device__ __forceinline__ int4 patched_ldcs(const int4 *addr) {
    const int *words = (const int *) addr;
    const int w0 = __builtin_nontemporal_load(words + 0);
    const int w1 = __builtin_nontemporal_load(words + 1);
    const int w2 = __builtin_nontemporal_load(words + 2);
    const int w3 = __builtin_nontemporal_load(words + 3);
    return make_int4(w0, w1, w2, w3);
}

// hip_bfloat16: ordinary load, no cache hint is applied.
static __device__ __forceinline__ hip_bfloat16 patched_ldcs(const hip_bfloat16 *addr) {
    return addr[0];
}
#if defined(AMD_TARGET_ARCH_RDNA3)
// Sum reduction of x written as hand-scheduled inline assembly (RDNA3 build only).
// Instruction sequence, in order:
//   1. ds_swizzle_b32 with swizzle(SWAP,16) fetches a partner lane's x into v1,
//      s_waitcnt waits for it, then v_add_f32 adds it into x.
//   2. four v_add_f32_dpp steps add x to a row-rotated copy of itself, with
//      rotations 8, 4, 2 and 1 (row_ror).
// v1 is used as a fixed scratch register and is listed as a clobber; x is the
// only operand and is read-write ("+v").
// NOTE(review): presumably assumes a 32-lane wavefront (SWAP,16 pairing the two
// 16-lane halves) and that every lane ends up holding the full sum - confirm
// against the RDNA3 ISA guide.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
asm volatile ("ds_swizzle_b32 v1, %0 offset:swizzle(SWAP,16) \n"\
"s_waitcnt lgkmcnt(0) \n"\
"v_add_f32_e32 %0, %0, v1 \n"
"s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) \n"\
"v_add_f32_dpp %0, %0, %0 row_ror:8 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"v_add_f32_dpp %0, %0, %0 row_ror:4 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) \n"\
"v_add_f32_dpp %0, %0, %0 row_ror:2 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"v_add_f32_dpp %0, %0, %0 row_ror:1 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"
: "+v"(x) : : "v1");
return x;
}
// Max reduction of x written as hand-scheduled inline assembly (RDNA3 build only).
// Instruction sequence, in order (note: the reverse of warp_reduce_sum's order):
//   1. four v_max_f32_dpp steps take the max of x and a row-rotated copy of
//      itself, with rotations 8, 4, 2 and 1 (row_ror).
//   2. ds_swizzle_b32 with swizzle(SWAP,16) fetches a partner lane's x into v1,
//      s_waitcnt waits for it, then v_max_f32 folds it into x.
// v1 is used as a fixed scratch register and is listed as a clobber; x is the
// only operand and is read-write ("+v").
// NOTE(review): presumably assumes a 32-lane wavefront and that every lane ends
// up holding the full maximum - confirm against the RDNA3 ISA guide.
static __device__ __forceinline__ float warp_reduce_max(float x) {
asm volatile ("s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) \n"\
"v_max_f32_dpp %0, %0, %0 row_ror:8 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"v_max_f32_dpp %0, %0, %0 row_ror:4 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) \n"\
"v_max_f32_dpp %0, %0, %0 row_ror:2 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"v_max_f32_dpp %0, %0, %0 row_ror:1 row_mask:0xf bank_mask:0xf bound_ctrl:1 \n"\
"ds_swizzle_b32 v1, %0 offset:swizzle(SWAP,16) \n"\
"s_waitcnt lgkmcnt(0) \n"\
"v_max_f32_e32 %0, %0, v1 \n"
: "+v"(x) : : "v1");
return x;
}
#else
// Sum reduction via XOR shuffles: each step adds the value held by the lane
// whose index differs by `stride`, with stride halving from width/2 down to 1.
// The shuffle width is 64 when WAVEFRONTSIZE64 is defined, otherwise 32.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#ifdef WAVEFRONTSIZE64
    constexpr int kWidth = 64;
#else
    constexpr int kWidth = 32;
#endif
    int stride = kWidth / 2;
    while (stride > 0) {
        x += __shfl_xor(x, stride, kWidth);
        stride >>= 1;
    }
    return x;
}
// Max reduction via XOR shuffles: each step keeps fmaxf of the local value and
// the value held by the lane whose index differs by `stride`, with stride
// halving from width/2 down to 1.
// The shuffle width is 64 when WAVEFRONTSIZE64 is defined, otherwise 32.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#ifdef WAVEFRONTSIZE64
    constexpr int kWidth = 64;
#else
    constexpr int kWidth = 32;
#endif
    int stride = kWidth / 2;
    while (stride > 0) {
        x = fmaxf(x, __shfl_xor(x, stride, kWidth));
        stride >>= 1;
    }
    return x;
}
#endif
// Additions to the cooperative_groups namespace: plus<T> / greater<T> functors
// and reduce() overloads for 32-wide tiles that dispatch to warp_reduce_sum /
// warp_reduce_max.
namespace cooperative_groups {

// Common base of the reduction functors; its own reduce() is addition.
template <typename T>
struct reduce_operator {
    static __device__ __forceinline__ T reduce(const T a, const T b) {
        return a + b;
    }
};

// Addition functor.
template <typename T>
struct plus : public reduce_operator<T> {
    static __device__ __forceinline__ T reduce(const T a, const T b) {
        return a + b;
    }
};

// Maximum functor; implemented with fmaxf.
template <typename T>
struct greater : public reduce_operator<T> {
    static __device__ __forceinline__ T reduce(const T a, const T b) {
        return fmaxf(a, b);
    }
};

// reduce(tile, x, plus<T>) -> warp_reduce_sum(x). The tile and the functor
// object only select the overload; neither is read.
template <typename T>
static __device__ __forceinline__ float reduce(const thread_block_tile<32>& /*warp*/, float x, const plus<T>& /*op*/) {
    return warp_reduce_sum(x);
}

// reduce(tile, x, greater<T>) -> warp_reduce_max(x). The tile and the functor
// object only select the overload; neither is read.
template <typename T>
static __device__ __forceinline__ float reduce(const thread_block_tile<32>& /*warp*/, float x, const greater<T>& /*op*/) {
    return warp_reduce_max(x);
}

// Explicit instantiations of the float functors.
template struct plus<float>;
template struct greater<float>;

}  // namespace cooperative_groups