Merge pull request karpathy#298 from karpathy/feature/packed128
Feature/packed128
karpathy authored Apr 29, 2024
2 parents 906d22f + af2bc47 commit 2490f78
Showing 3 changed files with 151 additions and 0 deletions.
61 changes: 61 additions & 0 deletions dev/cuda/common.h
@@ -32,6 +32,67 @@ void cublasCheck(cublasStatus_t status, const char *file, int line)
}
#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }

// ----------------------------------------------------------------------------
// Packed128 data structure, which forces the compiler to use 128-bit loads/stores
// on GPUs that support them (the LDG.128 and STG.128 instructions).
// This is a bit similar to the use of float4 for 32-bit floats, but supports
// arbitrary precision.

template<class ElementType>
struct alignas(16) Packed128 {
    __device__ Packed128() = default;
    __device__ explicit Packed128(int4 bits) {
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&payload, &bits, sizeof(bits));
    }

    __device__ ElementType& operator[](int index) {
        return payload[index];
    }
    __device__ const ElementType& operator[](int index) const {
        return payload[index];
    }
    __device__ float fp32(int index) {
        return static_cast<float>(payload[index]);
    }
    __device__ int4 get_bits() const {
        int4 bits;
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&bits, &payload, sizeof(bits));
        return bits;
    }

    static constexpr size_t size = sizeof(int4) / sizeof(ElementType);
    ElementType payload[size];
};
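
// Sanity check (illustrative; not part of this commit): one packet always spans
// 16 bytes (sizeof(int4)), so the element count follows from the element width;
// with half-precision elements, for example, one packet would carry 8 of them.
static_assert(Packed128<float>::size == 4, "one 16-byte packet carries four floats");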

// short-form typedef
typedef Packed128<float> f128;

// load a Packed128 from an aligned memory address
template<class ElementType>
__device__ Packed128<ElementType> load128(const ElementType* address) {
    return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}

// load a Packed128 from an aligned memory address with streaming cache hint
template<class ElementType>
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
    return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}

// store a Packed128 to an aligned memory address
template<class ElementType>
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
    *reinterpret_cast<int4*>(target) = value.get_bits();
}

// store a Packed128 to an aligned memory address with streaming cache hint
template<class ElementType>
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
    __stcs(reinterpret_cast<int4*>(target), value.get_bits());
}
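
// A minimal usage sketch of these helpers (illustrative; not part of this diff),
// assuming dst and src are 16-byte aligned and n is a multiple of f128::size:
__global__ void copy128(float* dst, const float* src, int n) {
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size;
    if (i < n) {
        // one 128-bit load and one 128-bit store move four floats per thread
        store128(dst + i, load128cs(src + i));
    }
}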

// ----------------------------------------------------------------------------
// random utils

29 changes: 29 additions & 0 deletions dev/cuda/gelu_forward.cu
@@ -11,6 +11,9 @@ If encountering "error: identifier "M_PI" is undefined", add the following lines
version 1 is a naive port from CPU code to kernel
./gelu_forward 1
version 2 uses the Packed128 data structure
./gelu_forward 2
*/

#include <stdio.h>
@@ -44,6 +47,23 @@ __global__ void gelu_kernel(float* out, const float* inp, int N) {
    }
}

// elementwise ops are nice and easy: each thread processes one packet of 4 floats
__global__ void gelu_kernel2(float* out, const float* inp, int N) {
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size;
    if (i < N) {  // note: assumes N is a multiple of f128::size
        f128 packet_out;
        f128 packet_in = load128cs(inp + i); // load and do not keep in cache
        for(int k = 0; k < packet_in.size; ++k) {
            float xi = packet_in[k];
            float cube = 0.044715f * xi * xi * xi;
            packet_out[k] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
        }
        // use store128 rather than store128cs (cache-streaming store) in case it is
        // useful for the data to be in the cache for the next operation after this GeLU
        store128(out + i, packet_out);
    }
}
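
// gelu_kernel2 assumes N is a multiple of f128::size; if that guarantee were ever
// dropped, a variant with a scalar fallback for the ragged tail could look like
// the following sketch (hypothetical; not part of this commit):
__global__ void gelu_kernel2_tail_safe(float* out, const float* inp, int N) {
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size;
    if (i + f128::size <= N) {
        // full packet: same vectorized path as gelu_kernel2
        f128 packet_out;
        f128 packet_in = load128cs(inp + i);
        for (int k = 0; k < packet_in.size; ++k) {
            float xi = packet_in[k];
            float cube = 0.044715f * xi * xi * xi;
            packet_out[k] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
        }
        store128(out + i, packet_out);
    } else {
        // ragged tail: fall back to scalar loads/stores for the last few elements
        for (int j = i; j < N; ++j) {
            float xi = inp[j];
            float cube = 0.044715f * xi * xi * xi;
            out[j] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
        }
    }
}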

// ----------------------------------------------------------------------------
// kernel launcher

@@ -53,6 +73,12 @@ void gelu_forward1(float* out, const float* inp, int N, const int block_size) {
    cudaCheck(cudaGetLastError());
}

void gelu_forward2(float* out, const float* inp, int N, const int block_size) {
    // each thread processes f128::size (= 4) elements, so we need 4x fewer blocks;
    // round up so that a ragged tail of elements still gets covered
    const int grid_size = ceil_div(N, block_size * 4);
    gelu_kernel2<<<grid_size, block_size>>>(out, inp, N);
    cudaCheck(cudaGetLastError());
}
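
// For reference (an assumption; its definition is not shown in this diff): ceil_div
// is presumably the usual round-up integer division from common.h, along these lines:
//     int ceil_div(int dividend, int divisor) { return (dividend + divisor - 1) / divisor; }
// rounding up guarantees the grid covers all N elements even when N does not divide
// evenly by block_size * 4.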

// kernel version dispatch
void gelu_forward(int kernel_num,
float* out,
@@ -63,6 +89,9 @@ void gelu_forward(int kernel_num,
        case 1:
            gelu_forward1(out, inp, B * T * C, block_size);
            break;
        case 2:
            gelu_forward2(out, inp, B * T * C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
61 changes: 61 additions & 0 deletions train_gpt2.cu
@@ -191,6 +191,67 @@ __device__ void atomicAddX(float* addr, float val) {
    atomicAdd(addr, val);
}

// ----------------------------------------------------------------------------
// Packed128 data structure, which forces the compiler to use 128-bit loads/stores
// on GPUs that support them (the LDG.128 and STG.128 instructions).
// This is a bit similar to the use of float4 for 32-bit floats, but supports
// arbitrary precision.

template<class ElementType>
struct alignas(16) Packed128 {
    __device__ Packed128() = default;
    __device__ explicit Packed128(int4 bits) {
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&payload, &bits, sizeof(bits));
    }

    __device__ ElementType& operator[](int index) {
        return payload[index];
    }
    __device__ const ElementType& operator[](int index) const {
        return payload[index];
    }
    __device__ float fp32(int index) {
        return static_cast<float>(payload[index]);
    }
    __device__ int4 get_bits() const {
        int4 bits;
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&bits, &payload, sizeof(bits));
        return bits;
    }

    static constexpr size_t size = sizeof(int4) / sizeof(ElementType);
    ElementType payload[size];
};

// short-form typedef
typedef Packed128<float> f128;

// load a Packed128 from an aligned memory address
template<class ElementType>
__device__ Packed128<ElementType> load128(const ElementType* address) {
    return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}

// load a Packed128 from an aligned memory address with streaming cache hint
template<class ElementType>
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
    return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}

// store a Packed128 to an aligned memory address
template<class ElementType>
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
    *reinterpret_cast<int4*>(target) = value.get_bits();
}

// store a Packed128 to an aligned memory address with streaming cache hint
template<class ElementType>
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
    __stcs(reinterpret_cast<int4*>(target), value.get_bits());
}

// ----------------------------------------------------------------------------
// Random Number Generation
