Merge pull request karpathy#436 from ChrisDryden/boundscheck
Moved bounds checks to outside of the kernel
karpathy authored May 19, 2024
2 parents 2751fa0 + 6de1137 commit 6c8bc17
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions train_gpt2.cu
@@ -845,9 +845,8 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
     }
 }
 
-__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
+__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
-    if (idx >= N) { return; }
 
     x128 packed_out;
     x128 packed_inp1 = load128cs(inp1 + idx);
@@ -859,9 +858,8 @@ __global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const f
 }
 
 #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
-__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) {
+__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
-    if (idx >= N) { return; }
 
     x128 packed_out;
     x128 packed_inp = load128cs(inp + idx); // load and do not keep in cache
@@ -875,9 +873,8 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) {
     store128(out + idx, packed_out);
 }
 
-__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
+__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
-    if (idx >= N) { return; }
 
     x128 packed_dinp;
     x128 packed_inp = load128cs(inp + idx);
@@ -1509,8 +1506,9 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
 void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
     NVTX_RANGE_FN();
     const int block_size = 256;
+    assert(N % block_size == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);
+    residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2);
     cudaCheck(cudaGetLastError());
 }
 
@@ -1542,16 +1540,18 @@ void fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, flo
 void gelu_forward(floatX* out, const floatX* inp, int N) {
     NVTX_RANGE_FN();
     const int block_size = 512;
+    assert(N % block_size == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    gelu_forward_kernel2<<<grid_size, block_size>>>(out, inp, N);
+    gelu_forward_kernel2<<<grid_size, block_size>>>(out, inp);
     cudaCheck(cudaGetLastError());
 }
 
 void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
     NVTX_RANGE_FN();
     const int block_size = 128;
+    assert(N % block_size == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout, N);
+    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout);
     cudaCheck(cudaGetLastError());
 }
 
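The idea behind the diff is that each host-side launcher already sizes the grid to cover the N elements, so the per-thread `if (idx >= N) { return; }` guard in the kernels becomes redundant once the host asserts that N divides the launch configuration evenly. Below is a minimal, self-contained sketch of that pattern with plain float and no x128 vectorization; scale_kernel, scale, and the sizes used are hypothetical illustration, not code from train_gpt2.cu.

// Sketch: bounds check moved from the kernel to a host-side assert.
#include <assert.h>
#include <stdio.h>
#include <cuda_runtime.h>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

// No per-thread bounds check: the launcher guarantees the grid covers exactly N.
__global__ void scale_kernel(float* out, const float* inp) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    out[idx] = 2.0f * inp[idx];
}

void scale(float* out, const float* inp, int N) {
    const int block_size = 256;
    assert(N % block_size == 0);               // bounds check now lives on the host
    const int grid_size = CEIL_DIV(N, block_size);
    scale_kernel<<<grid_size, block_size>>>(out, inp);
}

int main() {
    const int N = 1 << 20;                     // multiple of block_size by construction
    float *inp, *out;
    cudaMalloc(&inp, N * sizeof(float));
    cudaMalloc(&out, N * sizeof(float));
    cudaMemset(inp, 0, N * sizeof(float));
    scale(out, inp, N);
    cudaDeviceSynchronize();
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(inp);
    cudaFree(out);
    return 0;
}

The trade-off mirrors the diff: a branch-free kernel body in exchange for a hard requirement that callers pass sizes compatible with the block configuration, which the added assert makes explicit at launch time.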
