From 3dff8feee545734717cc61d5b1e2422f0a1085ca Mon Sep 17 00:00:00 2001 From: ardfork <134447697+ardfork@users.noreply.github.com> Date: Sun, 6 Aug 2023 01:27:35 +0000 Subject: [PATCH] Fix HIP on recent PyTorch version (#224) --- exllama_ext/hip_compat.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exllama_ext/hip_compat.cuh b/exllama_ext/hip_compat.cuh index e650cd4a..1301ca08 100644 --- a/exllama_ext/hip_compat.cuh +++ b/exllama_ext/hip_compat.cuh @@ -1,7 +1,7 @@ #ifndef _hip_compat_cuh #define _hip_compat_cuh -// Workaround for a bug in hipamd, backported from upstream. +// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6. __device__ __forceinline__ __half __compat_hrcp(__half x) { return __half_raw{ static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; @@ -15,7 +15,7 @@ __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { #define hrcp __compat_hrcp #define h2rcp __compat_h2rcp -// Workaround for hipify_python using rocblas instead of hipblas. +// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf. __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, @@ -37,7 +37,9 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t reinterpret_cast(beta), reinterpret_cast(CP), ldc); } +#define hipblasHgemm __compat_hipblasHgemm +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. #define rocblas_handle hipblasHandle_t #define rocblas_operation_none HIPBLAS_OP_N #define rocblas_get_stream hipblasGetStream