From aa865a28cdfe823a995cd1685dfd064d0a33a1eb Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Tue, 22 Mar 2022 22:28:36 +0530 Subject: [PATCH 01/70] WIP attention body changes --- src/neural/cuda/common_kernels.cu | 81 ++++++ src/neural/cuda/kernels.h | 3 + src/neural/cuda/layers.cc | 381 ++++++++++++++++++----------- src/neural/cuda/layers.h | 98 +++++--- src/neural/cuda/network_cuda.cc | 393 +++++++++++++++++------------- src/neural/network_legacy.cc | 10 + src/neural/network_legacy.h | 13 + 7 files changed, 643 insertions(+), 336 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index a1ecb36d97..a7d08f457d 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -146,6 +146,10 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, addBiasBatched_kernel<<>>( output, input, bias, N, C); break; + case MISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; default: throw Exception( "unsupported activation in addBiasBatched. Add in switch-case here"); @@ -984,6 +988,75 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, <<>>(C, output, keys, ppo, policy_attn_logits); } +__device__ constexpr float kPosEncoding[64][6] = { + {0., 0., 0., 0., 0., 0.}, {0., 0., 0., 0., 0., 1.}, + {0., 0., 0., 0., 1., 0.}, {0., 0., 0., 0., 1., 1.}, + {0., 0., 0., 1., 0., 0.}, {0., 0., 0., 1., 0., 1.}, + {0., 0., 0., 1., 1., 0.}, {0., 0., 0., 1., 1., 1.}, + {0., 0., 1., 0., 0., 0.}, {0., 0., 1., 0., 0., 1.}, + {0., 0., 1., 0., 1., 0.}, {0., 0., 1., 0., 1., 1.}, + {0., 0., 1., 1., 0., 0.}, {0., 0., 1., 1., 0., 1.}, + {0., 0., 1., 1., 1., 0.}, {0., 0., 1., 1., 1., 1.}, + {0., 1., 0., 0., 0., 0.}, {0., 1., 0., 0., 0., 1.}, + {0., 1., 0., 0., 1., 0.}, {0., 1., 0., 0., 1., 1.}, + {0., 1., 0., 1., 0., 0.}, {0., 1., 0., 1., 0., 1.}, + {0., 1., 0., 1., 1., 0.}, {0., 1., 0., 1., 1., 1.}, + {0., 1., 1., 0., 0., 0.}, {0., 1., 1., 0., 0., 1.}, + {0., 1., 1., 0., 1., 0.}, {0., 1., 1., 0., 1., 1.}, + {0., 1., 1., 1., 0., 0.}, {0., 1., 1., 1., 0., 1.}, + {0., 1., 1., 1., 1., 0.}, {0., 1., 1., 1., 1., 1.}, + {1., 0., 0., 0., 0., 0.}, {1., 0., 0., 0., 0., 1.}, + {1., 0., 0., 0., 1., 0.}, {1., 0., 0., 0., 1., 1.}, + {1., 0., 0., 1., 0., 0.}, {1., 0., 0., 1., 0., 1.}, + {1., 0., 0., 1., 1., 0.}, {1., 0., 0., 1., 1., 1.}, + {1., 0., 1., 0., 0., 0.}, {1., 0., 1., 0., 0., 1.}, + {1., 0., 1., 0., 1., 0.}, {1., 0., 1., 0., 1., 1.}, + {1., 0., 1., 1., 0., 0.}, {1., 0., 1., 1., 0., 1.}, + {1., 0., 1., 1., 1., 0.}, {1., 0., 1., 1., 1., 1.}, + {1., 1., 0., 0., 0., 0.}, {1., 1., 0., 0., 0., 1.}, + {1., 1., 0., 0., 1., 0.}, {1., 1., 0., 0., 1., 1.}, + {1., 1., 0., 1., 0., 0.}, {1., 1., 0., 1., 0., 1.}, + {1., 1., 0., 1., 1., 0.}, {1., 1., 0., 1., 1., 1.}, + {1., 1., 1., 0., 0., 0.}, {1., 1., 1., 0., 0., 1.}, + {1., 1., 1., 0., 1., 0.}, {1., 1., 1., 0., 1., 1.}, + {1., 1., 1., 1., 0., 0.}, {1., 1., 1., 1., 0., 1.}, + {1., 1., 1., 1., 1., 0.}, {1., 1., 1., 1., 1., 1.}}; + +template +__global__ void preprocess_for_attention_body_kernel(T* output, const T* input) { + int n = blockIdx.x; + int hw = blockIdx.y; + int c = threadIdx.x; + + T op; + if (c >= kInputPlanes) + { + // concatenate from fixed pos encoding array + op = (T) (kPosEncoding[hw][c - kInputPlanes]); + } else { + + op = input[n * 64 * kInputPlanes + c * 64 + hw]; // nchw + } + + constexpr int outputC = kInputPlanes + 6; + + // convert to nhwc + output[n * 64 * outputC + hw * outputC + c] = op; +} + +template +void 
inputPreprocessForAttentionBody(T* output, const T* input, int N, + cudaStream_t stream) { + // N * 64 blocks + // (kInputPlanes + 6) threads + // Each thread computes a single output element + dim3 gridSize = dim3(N, 64); + int blockSize = kInputPlanes + 6; + preprocess_for_attention_body_kernel + <<>>(output, input); +} + + // Template instantiation. template void copyTypeConverted(half* op, float* ip, int N, cudaStream_t stream); @@ -1205,5 +1278,13 @@ template void convertNCHWtoNHWC(half* output_tensor, const half* input_tensor, int Nin, int Cin, int Nout, int Cout, int H, int W); + +template void inputPreprocessForAttentionBody(half* output, + const half* input, + int N, cudaStream_t stream); + +template void inputPreprocessForAttentionBody(float* output, + const float* input, int N, + cudaStream_t stream); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 2763d09c1e..e1038a8aee 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -130,5 +130,8 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, const T* ppo, const T* policy_attn_logits, cudaStream_t stream); +template +void inputPreprocessForAttentionBody(T* output, const T* input, int N, + cudaStream_t stream); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 789bd04f75..8beb0c7002 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1373,14 +1373,16 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, allocAndUpload(&ip4_pol_w_, weights.ip4_pol_w, scratch); for (const auto& enc : weights.pol_encoder) { - EncoderWeights* pW = new EncoderWeights(enc, scratch); + EncoderBlock* pW = new EncoderBlock( + enc, scratch, encoder_heads_, embedding_op_size_); encoder_weights_.emplace_back(pW); } } template -AttentionPolicyHead::EncoderWeights::EncoderWeights( - const LegacyWeights::EncoderLayer& cpu_weights, void* scratch) { +EncoderBlock::EncoderBlock( + const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, + int size) : encoder_heads_(heads), embedding_op_size_(size) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); mha_v_size_ = cpu_weights.mha.v_b.size(); @@ -1478,6 +1480,145 @@ static void cublasXGemmStridedBatched( } } +// input/output tensor is scratch1, others are used as scratch. +// TODO: fix naming of scratch buffers +template +void EncoderBlock::Eval(int N, DataType* scratch1, + DataType* scratch0, + DataType* scratch2, + DataType* scratch3, + cublasHandle_t cublas, + cudaStream_t stream) const { + const int d_model = mha_q_size_; + const int depth = d_model / encoder_heads_; + + DataType* mha_q; + DataType* mha_k; + DataType* mha_v; + + { + const int num_inputs = embedding_op_size_; + const int num_outputs = d_model; + const int batch = N * 64; + + mha_q = scratch0; + mha_k = mha_q + num_outputs * batch; + mha_v = mha_k + num_outputs * batch; + + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, + mha_qkv_w, num_inputs, num_inputs * num_outputs, scratch1, + num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch, 3); + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, + NONE, stream); + } + + // Apply split_heads() to q, k and v + // which basically transposes (batch_size, 64, num_heads, depth) + // to (batch_size, num_heads, 64, depth) + // Do we really need to transpose here? 
+ // (Maybe not, we can play with strides of the gemm and do independent gemms + // for each encoder head) + + // Apply scaled dot product attention: + /* + matmul_qk = tf.matmul(q, k, transpose_b=True) + dk = tf.cast(tf.shape(k)[-1], self.model_dtype) + scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) + attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) + output = tf.matmul(attention_weights, v) + */ + + // shape(k)[-1] = depth + float factor = 1.0f / sqrt((float)depth); + + // matmul_qk = tf.matmul(q, k, transpose_b=True) + for (int i = 0; i < encoder_heads_; i++) { + int offset = i * depth; + // layout of the output: encoder_heads_ * Batch * 64 * 64 + int outOffset = i * N * 64 * 64; + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, + depth /*K*/, // A/B, and M/N are swapped for row-major to col-major + // transform + factor, // to handle "/ tf.math.sqrt(dk)" + mha_k + offset /*A*/, + d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over + // other "depth" slices / heads + 64 * d_model, /*strideA*/ + mha_q + offset /*B*/, + d_model /*LDB*/, // to skip over other other "depth" slices / heads + 64 * d_model, /*strideB*/ + 0.0f, + scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 + 64 /*LDC*/, 64 * 64 /*strideC*/, N); + } + + // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) + // attention_weights -> scratch2 + Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); + + // output = tf.matmul(attention_weights, v) + for (int i = 0; i < encoder_heads_; i++) { + int offset = i * depth; // for output and "v" matrix + // layout: encoder_heads_ * Batch*64*64 + int weightsOffset = i * N * 64 * 64; + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, 1.0f, + mha_v + offset /*A*/, // "v" matrix + d_model /*LDA*/, // to skip over other "depth" slices / heads + 64 * d_model, /*strideA*/ + scratch2 + weightsOffset /*B*/, 64 /*LDB*/, 64 * 64, /*strideB*/ + 0.0f, scratch3 + offset /*C*/, // output goes to scratch3 + d_model /*LDC*/, 64 * d_model /*strideC*/, N); + } + + // #final dense layer (mha_dense), scratch3 -> scratch2 + { + const int num_inputs = d_model; + const int num_outputs = embedding_op_size_; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)mha_dense_w, num_inputs, + scratch3, num_inputs, 0.0f, scratch2, num_outputs); + } + + // LN1: skip connection and layer normalization (also bias add of prev gemm) + // scratch2/scratch1 -> scratch0 + LayerNorm(N * 64, embedding_op_size_, scratch0, scratch2, + mha_dense_b, scratch1, ln1_gammas, ln1_betas, + 1e-6, stream); + + // #FFN dense 1, scratch0 -> scratch1 + const int encoder_dff = ffn_dense1_size_; + { + const int num_inputs = embedding_op_size_; + const int num_outputs = encoder_dff; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, + scratch0, num_inputs, 0.0f, scratch1, num_outputs); + addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, + SELU, stream); + } + + // #FFN dense 2, scratch1 -> scratch2 + { + const int num_inputs = encoder_dff; + const int num_outputs = embedding_op_size_; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, + scratch1, num_inputs, 
0.0f, scratch2, num_outputs); + } + + // LN2: skip connection and layer normilization (also bias add of prev gemm) + // scratch2/scratch0 -> scratch1 + LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, + ffn_dense2_b, scratch0, ln2_gammas, ln2_betas, + 1e-6, stream); +} + template void AttentionPolicyHead::Eval( int N, DataType* output, const DataType* input, const DataType* input2, @@ -1508,142 +1649,7 @@ void AttentionPolicyHead::Eval( // 2. Encoder layers for (const auto pEnc : encoder_weights_) { - const auto& enc = *pEnc; - const int d_model = enc.mha_q_size_; - const int depth = d_model / encoder_heads_; - - DataType* mha_q; - DataType* mha_k; - DataType* mha_v; - - { - const int num_inputs = embedding_op_size_; - const int num_outputs = d_model; - const int batch = N * 64; - - mha_q = scratch0; - mha_k = mha_q + num_outputs * batch; - mha_v = mha_k + num_outputs * batch; - - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, - 1.0f, enc.mha_qkv_w, num_inputs, num_inputs * num_outputs, - pol_embedding, num_inputs, 0, 0.0f, mha_q, num_outputs, - num_outputs * batch, 3); - addBiasBatched(mha_q, mha_q, enc.mha_qkv_b, 3, batch, - num_outputs, NONE, stream); - } - - // Apply split_heads() to q, k and v - // which basically transposes (batch_size, 64, num_heads, depth) - // to (batch_size, num_heads, 64, depth) - // Do we really need to transpose here? - // (Maybe not, we can play with strides of the gemm and do independent gemms - // for each encoder head) - - // Apply scaled dot product attention: - /* - matmul_qk = tf.matmul(q, k, transpose_b=True) - dk = tf.cast(tf.shape(k)[-1], self.model_dtype) - scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) - attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) - output = tf.matmul(attention_weights, v) - */ - - // shape(k)[-1] = depth - float factor = 1.0f / sqrt((float)depth); - - // matmul_qk = tf.matmul(q, k, transpose_b=True) - for (int i = 0; i < encoder_heads_; i++) { - int offset = i * depth; - // layout of the output: encoder_heads_ * Batch * 64 * 64 - int outOffset = i * N * 64 * 64; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - 64 /*M*/, 64 /*N*/, - depth /*K*/, // A/B, and M/N are swapped for row-major to col-major - // transform - factor, // to handle "/ tf.math.sqrt(dk)" - mha_k + offset /*A*/, - d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over - // other "depth" slices / heads - 64 * d_model, /*strideA*/ - mha_q + offset /*B*/, - d_model /*LDB*/, // to skip over other other "depth" slices / heads - 64 * d_model, /*strideB*/ - 0.0f, - scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 - 64 /*LDC*/, 64 * 64 /*strideC*/, N); - } - - // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) - // attention_weights -> scratch2 - Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); - - // output = tf.matmul(attention_weights, v) - for (int i = 0; i < encoder_heads_; i++) { - int offset = i * depth; // for output and "v" matrix - // layout: encoder_heads_ * Batch*64*64 - int weightsOffset = i * N * 64 * 64; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, - 1.0f, mha_v + offset /*A*/, // "v" matrix - d_model /*LDA*/, // to skip over other "depth" slices / heads - 64 * d_model, /*strideA*/ - scratch2 + weightsOffset /*B*/, 64 /*LDB*/, 64 * 64, /*strideB*/ - 0.0f, scratch3 + offset /*C*/, // output goes to scratch3 - d_model /*LDC*/, 64 * 
d_model /*strideC*/, N); - } - - // #final dense layer (mha_dense), scratch3 -> scratch2 - { - const int num_inputs = d_model; - const int num_outputs = embedding_op_size_; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)enc.mha_dense_w, - num_inputs, scratch3, num_inputs, 0.0f, scratch2, - num_outputs); - } - - // LN1: skip connection and layer normalization (also bias add of prev gemm) - // scratch2/scratch1 -> scratch0 - LayerNorm(N * 64, embedding_op_size_, scratch0, scratch2, - enc.mha_dense_b, scratch1, enc.ln1_gammas, - enc.ln1_betas, 1e-6, stream); - - // #FFN dense 1, scratch0 -> scratch1 - const int encoder_dff = enc.ffn_dense1_size_; - { - const int num_inputs = embedding_op_size_; - const int num_outputs = encoder_dff; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)enc.ffn_dense1_w, - num_inputs, scratch0, num_inputs, 0.0f, scratch1, - num_outputs); - addBiasBatched(scratch1, scratch1, enc.ffn_dense1_b, 1, batch, - num_outputs, SELU, stream); - } - - // #FFN dense 2, scratch1 -> scratch2 - { - const int num_inputs = encoder_dff; - const int num_outputs = embedding_op_size_; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)enc.ffn_dense2_w, - num_inputs, scratch1, num_inputs, 0.0f, scratch2, - num_outputs); - } - - // LN2: skip connection and layer normilization (also bias add of prev gemm) - // scratch2/scratch0 -> scratch1 - LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, - enc.ffn_dense2_b, scratch0, enc.ln2_gammas, - enc.ln2_betas, 1e-6, stream); - - + pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream); } // End of encoder blocks DataType* wq; @@ -1709,7 +1715,7 @@ AttentionPolicyHead::~AttentionPolicyHead() { } template -AttentionPolicyHead::EncoderWeights::~EncoderWeights() { +EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(mha_q_w)); ReportCUDAErrors(cudaFree(mha_q_b)); ReportCUDAErrors(cudaFree(mha_k_w)); @@ -1730,6 +1736,95 @@ AttentionPolicyHead::EncoderWeights::~EncoderWeights() { ReportCUDAErrors(cudaFree(ln2_betas)); } + +template +AttentionBody::AttentionBody(BaseLayer* ip, + const LegacyWeights& weights, + void* scratch, + ActivationFunction default_act, + int num_res_blocks) + : BaseLayer(64 * weights.ip_emb_b.size(), 1, 1, ip) { + embedding_op_size_ = weights.ip_emb_b.size(); + encoder_head_count_ = weights.encoder_head_count; + num_resi_blocks_ = num_res_blocks; + default_act_ = default_act; + + allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); + allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); + + for (const auto& enc : weights.pol_encoder) { + EncoderBlock* pW = new EncoderBlock( + enc, scratch, encoder_head_count_, embedding_op_size_); + encoder_weights_.emplace_back(pW); + } +} + +template +AttentionBody::~AttentionBody() { + ReportCUDAErrors(cudaFree(ip_emb_w_)); + ReportCUDAErrors(cudaFree(ip_emb_b_)); + for (const auto pEnc : encoder_weights_) delete pEnc; +} + + +template +void AttentionBody::Eval( + int N, DataType* output, const DataType* input, const DataType* input2, + void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, + cublasHandle_t cublas, cudaStream_t stream) { + DataType* scratch0 = (DataType*)scratch; + DataType* scratch1 = (DataType*)input2; + DataType* scratch2 = output + scratch_size / (2 * sizeof(DataType)); + DataType* scratch3 
= scratch1 + scratch_size / (2 * sizeof(DataType)); + + int inputC = this->input_->GetC(); + if (num_resi_blocks_ == 0) + { + assert(inputC == kInputPlanes); + /* + # if there are no residual blocks (pure transformer), do some input + processing + flow = tf.transpose(inputs, perm=[0, 2, 3, 1]) + flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]]) + # add positional encoding for each square to the input + positional_encoding = tf.broadcast_to(tf.convert_to_tensor(self.POS_ENC, + dtype=self.model_dtype), [tf.shape(flow)[0], 64, + tf.shape(self.POS_ENC)[2]]) flow = tf.concat([flow, positional_encoding], + axis=2) + */ + inputPreprocessForAttentionBody(scratch0, input, N, stream); + inputC += 6; + } else { + // #redirect flow through encoder blocks + // flow = tf.transpose(flow, perm = [ 0, 2, 3, 1 ]) + // flow = tf.reshape(flow, [ -1, 64, self.RESIDUAL_FILTERS ]) + convertNCHWtoNHWC(scratch0, input, N, inputC, N, inputC, 8, 8); + } + + // 1. square embedding (fully connected layer) + // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ + DataType* embedding = scratch1; + { + const int num_outputs = embedding_op_size_; + const int num_inputs = inputC; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_w_, + num_inputs, scratch0, num_inputs, 0.0f, embedding, + num_outputs); + addBiasBatched(embedding, embedding, ip_emb_b_, 1, batch, + num_outputs, default_act_, stream); + } + + // 2. Encoder layers + for (const auto pEnc : encoder_weights_) { + pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream); + } // End of encoder blocks + + +} + + // Template instantiation. #ifdef USE_CUDNN template class ConvLayer; @@ -1757,6 +1852,12 @@ template class ResidualBlock; template class AttentionPolicyHead; template class AttentionPolicyHead; +template class EncoderBlock; +template class EncoderBlock; + +template class AttentionBody; +template class AttentionBody; + // Misc error handling stuff. 
#ifdef USE_CUDNN void CudnnError(cudnnStatus_t status, const char* file, const int& line) { diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 2bb56ce15b..097e196aba 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -332,6 +332,42 @@ class ResidualBlock : public BaseLayer { DataType* b2_; }; +template +class EncoderBlock { + public: + EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size); + ~EncoderBlock(); + + void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, + DataType* scratch2, cublasHandle_t cublas, + cudaStream_t stream) const; + + // all GPU side pointers + DataType *mha_q_w, *mha_q_b; + DataType *mha_k_w, *mha_k_b; + DataType *mha_v_w, *mha_v_b; + DataType *mha_qkv_w, *mha_qkv_b; + DataType *mha_dense_w, *mha_dense_b; + + DataType *ln1_gammas, *ln1_betas; + + DataType *ffn_dense1_w, *ffn_dense1_b; + DataType *ffn_dense2_w, *ffn_dense2_b; + + DataType *ln2_gammas, *ln2_betas; + + int mha_q_size_; + int mha_k_size_; + int mha_v_size_; + int mha_dense_size_; + + int ffn_dense1_size_; + int ffn_dense2_size_; + + int embedding_op_size_; + int encoder_heads_; +}; + // The Attention policy head implementation // Responsible for loading weights into GPU memory, and evaluating the entire // policy head @@ -354,33 +390,6 @@ class AttentionPolicyHead : public BaseLayer { cudaStream_t stream) override; private: - struct EncoderWeights { - EncoderWeights(const LegacyWeights::EncoderLayer& cpu_weights, - void* scratch); - ~EncoderWeights(); - // all GPU side pointers - DataType *mha_q_w, *mha_q_b; - DataType *mha_k_w, *mha_k_b; - DataType *mha_v_w, *mha_v_b; - DataType *mha_qkv_w, *mha_qkv_b; - DataType *mha_dense_w, *mha_dense_b; - - DataType *ln1_gammas, *ln1_betas; - - DataType *ffn_dense1_w, *ffn_dense1_b; - DataType *ffn_dense2_w, *ffn_dense2_b; - - DataType *ln2_gammas, *ln2_betas; - - int mha_q_size_; - int mha_k_size_; - int mha_v_size_; - int mha_dense_size_; - - int ffn_dense1_size_; - int ffn_dense2_size_; - }; - // GPU allocations to hold various weights used by the attention policy head DataType *ip_pol_w_, *ip_pol_b_; // "embedding" in policy attention DataType *ip2_pol_w_, *ip2_pol_b_; // "wq" in policy attention @@ -396,8 +405,41 @@ class AttentionPolicyHead : public BaseLayer { int encoder_heads_; int policy_d_model_; - std::vector encoder_weights_; + std::vector*> encoder_weights_; +}; + + +// The Attention body implementation +// Responsible for loading weights into GPU memory, and evaluating the entire +// attention network part of the body including the stack of encoder layers +template +class AttentionBody : public BaseLayer { + using BaseLayer::C; + using BaseLayer::H; + using BaseLayer::W; + using BaseLayer::GetC; + using BaseLayer::GetH; + using BaseLayer::GetW; + + public: + AttentionBody(BaseLayer* ip, const LegacyWeights& weights, + void* scratch, ActivationFunction default_act, int num_res_blocks); + ~AttentionBody(); + void Eval(int N, DataType* output, const DataType* input, + const DataType* input2, void* scratch, size_t scratch_size, + cudnnHandle_t cudnn, cublasHandle_t cublas, + cudaStream_t stream) override; + + private: + // GPU allocations to hold various weights used by the attention policy head + DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body + int embedding_op_size_; + int encoder_head_count_; + std::vector*> encoder_weights_; + ActivationFunction default_act_; + int num_resi_blocks_; }; + } // namespace cudnn_backend } // 
namespace lczero diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 3a1ccacaca..2fb036c0aa 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -226,6 +226,19 @@ class CudaNetwork : public Network { const int kNumFilters = (int)weights.input.biases.size(); numBlocks_ = (int)weights.residual.size(); + attn_body_ = (weights.ip_emb_b.size() > 0); + if (attn_body_) { + if (numBlocks_ > 0) + throw Exception("Found residual blocks in network with attention body!"); + } + num_encoder_blocks_ = (int) weights.encoder.size(); + + // Ankan - test! + printf("\nNum filters: %d, num blocks: %d\n", kNumFilters, numBlocks_); + printf("\nNum encoder blocks: %d, num policy encoder blocks: %d\n", + num_encoder_blocks_, (int)weights.pol_encoder.size()); + + // Warn if the memory required for storing transformed weights is // going to exceed 40% of total video memory, force custom_winograd off // if it's going to exceed 50% of memory. @@ -294,63 +307,75 @@ class CudaNetwork : public Network { // 2. Build the network, and copy the weights to GPU memory. - // Input. - { - auto inputConv = std::make_unique>( - nullptr, kNumFilters, 8, 8, kNumInputPlanes, mish_net ? MISH : RELU, - true, false, false, 0, use_gemm_ex, use_res_block_winograd_fuse_opt_); - inputConv->LoadWeights(&weights.input.weights[0], - &weights.input.biases[0], scratch_mem_); - network_.emplace_back(std::move(inputConv)); - } - - // Residual block. - for (int block = 0; block < numBlocks_; block++) { - bool has_se = weights.residual[block].has_se; - int se_k = (int)weights.residual[block].se.b1.size(); - - if (use_res_block_winograd_fuse_opt_) { - auto layer = std::make_unique>( - getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, block == 0, - block == (numBlocks_ - 1), mish_net ? MISH : RELU, deviceProp.sharedMemPerBlockOptin); - layer->LoadWeights0(&weights.residual[block].conv1.weights[0], - &weights.residual[block].conv1.biases[0], - scratch_mem_); - layer->LoadWeights1(&weights.residual[block].conv2.weights[0], - &weights.residual[block].conv2.biases[0], - scratch_mem_); - if (has_se) - layer->LoadSEWeights(&weights.residual[block].se.w1[0], - &weights.residual[block].se.b1[0], - &weights.residual[block].se.w2[0], - &weights.residual[block].se.b2[0], scratch_mem_); - network_.emplace_back(std::move(layer)); - } else { - auto conv1 = std::make_unique>( - getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, false, false, 0, use_gemm_ex); - conv1->LoadWeights(&weights.residual[block].conv1.weights[0], - &weights.residual[block].conv1.biases[0], - scratch_mem_); - network_.emplace_back(std::move(conv1)); + if (attn_body_) { + auto body = std::make_unique>( + getLastLayer(), weights, scratch_mem_, mish_net ? MISH : RELU, + numBlocks_); + network_.emplace_back(std::move(body)); + } else { + // Input. + { + auto inputConv = std::make_unique>( + nullptr, kNumFilters, 8, 8, kNumInputPlanes, mish_net ? MISH : RELU, + true, false, false, 0, use_gemm_ex, + use_res_block_winograd_fuse_opt_); + inputConv->LoadWeights(&weights.input.weights[0], + &weights.input.biases[0], scratch_mem_); + network_.emplace_back(std::move(inputConv)); + } - auto conv2 = std::make_unique>( - getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? 
MISH : RELU, true, true, has_se, se_k, use_gemm_ex); - conv2->LoadWeights(&weights.residual[block].conv2.weights[0], - &weights.residual[block].conv2.biases[0], - scratch_mem_); - if (has_se) - conv2->LoadSEWeights(&weights.residual[block].se.w1[0], - &weights.residual[block].se.b1[0], - &weights.residual[block].se.w2[0], - &weights.residual[block].se.b2[0], scratch_mem_); - network_.emplace_back(std::move(conv2)); + // Residual block. + for (int block = 0; block < numBlocks_; block++) { + bool has_se = weights.residual[block].has_se; + int se_k = (int)weights.residual[block].se.b1.size(); + + if (use_res_block_winograd_fuse_opt_) { + auto layer = std::make_unique>( + getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, + block == 0, block == (numBlocks_ - 1), mish_net ? MISH : RELU, + deviceProp.sharedMemPerBlockOptin); + layer->LoadWeights0(&weights.residual[block].conv1.weights[0], + &weights.residual[block].conv1.biases[0], + scratch_mem_); + layer->LoadWeights1(&weights.residual[block].conv2.weights[0], + &weights.residual[block].conv2.biases[0], + scratch_mem_); + if (has_se) + layer->LoadSEWeights(&weights.residual[block].se.w1[0], + &weights.residual[block].se.b1[0], + &weights.residual[block].se.w2[0], + &weights.residual[block].se.b2[0], + scratch_mem_); + network_.emplace_back(std::move(layer)); + } else { + auto conv1 = std::make_unique>( + getLastLayer(), kNumFilters, 8, 8, kNumFilters, + mish_net ? MISH : RELU, true, false, false, 0, use_gemm_ex); + conv1->LoadWeights(&weights.residual[block].conv1.weights[0], + &weights.residual[block].conv1.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv1)); + + auto conv2 = std::make_unique>( + getLastLayer(), kNumFilters, 8, 8, kNumFilters, + mish_net ? MISH : RELU, true, true, has_se, se_k, use_gemm_ex); + conv2->LoadWeights(&weights.residual[block].conv2.weights[0], + &weights.residual[block].conv2.biases[0], + scratch_mem_); + if (has_se) + conv2->LoadSEWeights(&weights.residual[block].se.w1[0], + &weights.residual[block].se.b1[0], + &weights.residual[block].se.w2[0], + &weights.residual[block].se.b2[0], + scratch_mem_); + network_.emplace_back(std::move(conv2)); + } } } resi_last_ = getLastLayer(); + // Policy head. if (attn_policy_) { auto AttentionPolicy = std::make_unique>( @@ -399,9 +424,9 @@ class CudaNetwork : public Network { scratch_mem_); network_.emplace_back(std::move(FCPol)); } - policy_out_ = getLastLayer(); // Value head. + if (!attn_body_) { auto convVal = std::make_unique>( resi_last_, weights.value.biases.size(), 8, 8, kNumFilters, @@ -428,13 +453,12 @@ class CudaNetwork : public Network { scratch_mem_); network_.emplace_back(std::move(FCVal2)); } - value_out_ = getLastLayer(); // Moves left head moves_left_ = (file.format().network_format().moves_left() == pblczero::NetworkFormat::MOVES_LEFT_V1) && options.GetOrDefault("mlh", true); - if (moves_left_) { + if ((!attn_body_) && moves_left_) { auto convMov = std::make_unique>( resi_last_, weights.moves_left.biases.size(), 8, 8, kNumFilters, mish_net ? MISH : RELU, true, use_gemm_ex); @@ -455,7 +479,6 @@ class CudaNetwork : public Network { scratch_mem_); network_.emplace_back(std::move(FCMov2)); } - moves_left_out_ = getLastLayer(); // 3. Allocate GPU memory for running the network: // - three buffers of max size are enough (one to hold input, second to @@ -532,36 +555,21 @@ class CudaNetwork : public Network { float* opMov = io->op_moves_left_mem_gpu_; int l = 0; - // Input. - network_[l++]->Eval( - batchSize, - use_res_block_winograd_fuse_opt_ ? 
tensor_mem[1] : tensor_mem[2], - tensor_mem[0], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // input conv - - // Residual block. - for (int block = 0; block < numBlocks_; block++) { - if (use_res_block_winograd_fuse_opt_) { - network_[l++]->Eval(batchSize, tensor_mem[2], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // block - } else { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // conv1 - - network_[l++]->Eval(batchSize, tensor_mem[2], tensor_mem[0], - tensor_mem[2], scratch_mem, scratch_size_, nullptr, - cublas, stream); // conv2 - } - } - // Policy head. - if (attn_policy_) { + if (attn_body_) { + // 1. Entire Attention network body (including value and wdl heads) + network_[l++]->Eval( + batchSize, tensor_mem[2], tensor_mem[0], tensor_mem[1], scratch_mem, + scratch_size_, nullptr, cublas, + stream); + + // 2. Attention policy head network_[l++]->Eval( batchSize, tensor_mem[0], tensor_mem[2], tensor_mem[1], scratch_mem, scratch_size_, nullptr, cublas, stream); // Entire Attention policy head except for the policy map + + // 3. Policy map layer if (fp16) { network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, scratch_mem, scratch_size_, nullptr, cublas, @@ -574,110 +582,160 @@ class CudaNetwork : public Network { scratch_mem, scratch_size_, nullptr, cublas, stream); // policy map layer // POLICY output } + } else { + // Old CNN style networks + // Input. + network_[l++]->Eval( + batchSize, + use_res_block_winograd_fuse_opt_ ? tensor_mem[1] : tensor_mem[2], + tensor_mem[0], nullptr, scratch_mem, scratch_size_, nullptr, cublas, + stream); // input conv + + // Residual block. + for (int block = 0; block < numBlocks_; block++) { + if (use_res_block_winograd_fuse_opt_) { + network_[l++]->Eval(batchSize, tensor_mem[2], tensor_mem[1], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // block + } else { + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // conv1 + + network_[l++]->Eval(batchSize, tensor_mem[2], tensor_mem[0], + tensor_mem[2], scratch_mem, scratch_size_, + nullptr, cublas, stream); // conv2 + } + } - } else if (conv_policy_) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy conv1 - - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy conv2 + // Policy head. 
+ if (attn_policy_) { + network_[l++]->Eval( + batchSize, tensor_mem[0], tensor_mem[2], tensor_mem[1], scratch_mem, + scratch_size_, nullptr, cublas, + stream); // Entire Attention policy head except for the policy map + if (fp16) { + network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // policy map layer + copyTypeConverted(opPol, (half*)(tensor_mem[1]), + batchSize * kNumOutputPolicy, + stream); // POLICY output + } else { + network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], + nullptr, scratch_mem, scratch_size_, nullptr, + cublas, + stream); // policy map layer // POLICY output + } - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy map layer - copyTypeConverted(opPol, (half*)(tensor_mem[0]), - batchSize * kNumOutputPolicy, - stream); // POLICY output - } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[1], nullptr, + } else if (conv_policy_) { + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy map layer // POLICY output - } - } else { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // pol conv + stream); // policy conv1 - if (fp16) { network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // pol FC - - copyTypeConverted(opPol, (half*)(tensor_mem[1]), - batchSize * kNumOutputPolicy, stream); // POLICY + stream); // policy conv2 + + if (fp16) { + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // policy map layer + copyTypeConverted(opPol, (half*)(tensor_mem[0]), + batchSize * kNumOutputPolicy, + stream); // POLICY output + } else { + network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[1], + nullptr, scratch_mem, scratch_size_, nullptr, + cublas, + stream); // policy map layer // POLICY output + } } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], nullptr, + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // pol FC // POLICY + stream); // pol conv + + if (fp16) { + network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // pol FC + + copyTypeConverted(opPol, (half*)(tensor_mem[1]), + batchSize * kNumOutputPolicy, stream); // POLICY + } else { + network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], + nullptr, scratch_mem, scratch_size_, nullptr, + cublas, + stream); // pol FC // POLICY + } } - } - - // Copy policy output from device memory to host memory. 
- ReportCUDAErrors( - cudaMemcpyAsync(io->op_policy_mem_, io->op_policy_mem_gpu_, - sizeof(float) * kNumOutputPolicy * batchSize, - cudaMemcpyDeviceToHost, stream)); - // value head - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value conv - - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC1 - - if (wdl_) { - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE - copyTypeConverted(opVal, (half*)(tensor_mem[0]), 3 * batchSize, - stream); // VALUE - } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE - } - } else { - if (fp16) { - // TODO: consider fusing the bias-add of FC2 with format conversion. - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 - copyTypeConverted(opVal, (half*)(tensor_mem[0]), batchSize, - stream); // VALUE - } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE - } - } + // Copy policy output from device memory to host memory. + ReportCUDAErrors( + cudaMemcpyAsync(io->op_policy_mem_, io->op_policy_mem_gpu_, + sizeof(float) * kNumOutputPolicy * batchSize, + cudaMemcpyDeviceToHost, stream)); - if (moves_left_) { - // Moves left head + // value head network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // moves conv + stream); // value conv network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // moves FC1 + stream); // value FC1 + + if (wdl_) { + if (fp16) { + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC2 // VALUE + copyTypeConverted(opVal, (half*)(tensor_mem[0]), 3 * batchSize, + stream); // VALUE + } else { + network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], + nullptr, scratch_mem, scratch_size_, nullptr, + cublas, + stream); // value FC2 // VALUE + } + } else { + if (fp16) { + // TODO: consider fusing the bias-add of FC2 with format conversion. + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC2 + copyTypeConverted(opVal, (half*)(tensor_mem[0]), batchSize, + stream); // VALUE + } else { + network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], + nullptr, scratch_mem, scratch_size_, nullptr, + cublas, + stream); // value FC2 // VALUE + } + } - // Moves left FC2 - if (fp16) { - // TODO: consider fusing the bias-add of FC2 with format conversion. 
- network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, + if (moves_left_) { + // Moves left head + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); - copyTypeConverted(opMov, (half*)(tensor_mem[0]), batchSize, stream); - } else { - network_[l++]->Eval(batchSize, (DataType*)opMov, tensor_mem[1], nullptr, + stream); // moves conv + + network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); + stream); // moves FC1 + + // Moves left FC2 + if (fp16) { + // TODO: consider fusing the bias-add of FC2 with format conversion. + network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); + copyTypeConverted(opMov, (half*)(tensor_mem[0]), batchSize, stream); + } else { + network_[l++]->Eval(batchSize, (DataType*)opMov, tensor_mem[1], + nullptr, scratch_mem, scratch_size_, nullptr, + cublas, stream); + } } } @@ -774,13 +832,12 @@ class CudaNetwork : public Network { bool has_se_; bool conv_policy_; bool attn_policy_; + bool attn_body_; + int num_encoder_blocks_; std::vector>> network_; BaseLayer* getLastLayer() { return network_.back().get(); } BaseLayer* resi_last_; - BaseLayer* policy_out_; - BaseLayer* value_out_; - BaseLayer* moves_left_out_; size_t tensor_mem_size_; size_t scratch_size_; diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index f4819d4952..0872ae0ddc 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -30,6 +30,8 @@ static constexpr float kEpsilon = 1e-5f; LegacyWeights::LegacyWeights(const pblczero::Weights& weights) : input(weights.input()), + ip_emb_w(LayerAdapter(weights.ip_emb_w()).as_vector()), + ip_emb_b(LayerAdapter(weights.ip_emb_b()).as_vector()), policy1(weights.policy1()), policy(weights.policy()), ip_pol_w(LayerAdapter(weights.ip_pol_w()).as_vector()), @@ -40,11 +42,15 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) ip3_pol_b(LayerAdapter(weights.ip3_pol_b()).as_vector()), ip4_pol_w(LayerAdapter(weights.ip4_pol_w()).as_vector()), value(weights.value()), + ip_val_w(LayerAdapter(weights.ip_val_w()).as_vector()), + ip_val_b(LayerAdapter(weights.ip_val_b()).as_vector()), ip1_val_w(LayerAdapter(weights.ip1_val_w()).as_vector()), ip1_val_b(LayerAdapter(weights.ip1_val_b()).as_vector()), ip2_val_w(LayerAdapter(weights.ip2_val_w()).as_vector()), ip2_val_b(LayerAdapter(weights.ip2_val_b()).as_vector()), moves_left(weights.moves_left()), + ip_mov_w(LayerAdapter(weights.ip_mov_w()).as_vector()), + ip_mov_b(LayerAdapter(weights.ip_mov_b()).as_vector()), ip1_mov_w(LayerAdapter(weights.ip1_mov_w()).as_vector()), ip1_mov_b(LayerAdapter(weights.ip1_mov_b()).as_vector()), ip2_mov_w(LayerAdapter(weights.ip2_mov_w()).as_vector()), @@ -52,6 +58,10 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) for (const auto& res : weights.residual()) { residual.emplace_back(res); } + encoder_head_count = weights.headcount(); + for (const auto& enc : weights.encoder()) { + encoder.emplace_back(enc); + } pol_encoder_head_count = weights.pol_headcount(); for (const auto& enc : weights.pol_encoder()) { pol_encoder.emplace_back(enc); diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 3ba6028d5e..19284af172 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -88,6 +88,15 @@ struct LegacyWeights { // Input convnet. 
ConvBlock input; + // Embedding layer + Vec ip_emb_w; + Vec ip_emb_b; + + // Encoder stack. + std::vector encoder; + int encoder_head_count; + + // Residual tower. std::vector residual; @@ -109,6 +118,8 @@ struct LegacyWeights { // Value head ConvBlock value; + Vec ip_val_w; + Vec ip_val_b; Vec ip1_val_w; Vec ip1_val_b; Vec ip2_val_w; @@ -116,6 +127,8 @@ struct LegacyWeights { // Moves left head ConvBlock moves_left; + Vec ip_mov_w; + Vec ip_mov_b; Vec ip1_mov_w; Vec ip1_mov_b; Vec ip2_mov_w; From d6728dc7504045cfda6ca66edcd14081662274ff Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Wed, 23 Mar 2022 17:06:07 +0530 Subject: [PATCH 02/70] more updates to match training code - skip connection add before layer norm now has a scaling factor (alpha) - replace conv layer of value and mlh heads with an embedding layer when attention body is used. --- src/neural/cuda/common_kernels.cu | 12 +- src/neural/cuda/kernels.h | 2 +- src/neural/cuda/layers.cc | 107 ++++++--- src/neural/cuda/layers.h | 36 ++- src/neural/cuda/network_cuda.cc | 360 +++++++++++++++--------------- 5 files changed, 297 insertions(+), 220 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index a7d08f457d..60fb20c845 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -788,7 +788,7 @@ __device__ __forceinline__ float shared_sum_for_layer_norm(float x) { template __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, - const T* betas, float ep) { + const T* betas, float ep, float alpha) { int n = blockIdx.x * blockDim.z + threadIdx.z; if (n >= N) return; int c = (threadIdx.y * 32 + threadIdx.x) * 4; @@ -831,7 +831,7 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const float s = 0; if (!oobThread) for (int i = 0; i < 4; i++) { - val[i] += b[i] + sk[i]; + val[i] += b[i] + sk[i] * alpha; s += val[i]; } @@ -873,7 +873,7 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const // normalization is done across C dimension (i.e, sums and std deviations taken over elements in C dim) template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, - const T* skip, const T* gammas, const T* betas, float ep, + const T* skip, const T* gammas, const T* betas, float ep, float alpha, cudaStream_t stream) { // process 4 elements per thread to achieve close to peak memory bandwidth if (C % 4 != 0) throw Exception("unsupported filter size"); @@ -889,7 +889,7 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias, gridDim.z = 1; layer_norm_kernel<<>>( - N, C, output, input, bias, skip, gammas, betas, ep); + N, C, output, input, bias, skip, gammas, betas, ep, alpha); ReportCUDAErrors(cudaGetLastError()); } @@ -1251,11 +1251,11 @@ template void Softmax(int N, int C, float* output, const float* input, template void LayerNorm(int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, - cudaStream_t stream); + float alpha, cudaStream_t stream); template void LayerNorm(int N, int C, float* output, const float* input, const float* bias, const float* skip, const float* gammas, const float* betas, - float ep, cudaStream_t stream); + float ep, float alpha, cudaStream_t stream); template void ComputePromotionLogits(int N, int C, half* output, const half* keys, const half* ppo, diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 
e1038a8aee..a0bc2b9c06 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -123,7 +123,7 @@ void Softmax(int N, int C, T* output, const T* input, cudaStream_t stream); template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, - cudaStream_t stream); + float alpha, cudaStream_t stream); template void ComputePromotionLogits(int N, int C, T* output, const T* keys, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 8beb0c7002..87aaad3d14 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1331,8 +1331,12 @@ void allocAndUpload(DataType** gpu_dest, std::vector cpu_src, template AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, - void* scratch) - : BaseLayer(64 * 64 + 24 * 8, 1, 1, ip) { + void* scratch, + bool attention_body, + ActivationFunction act) + : attention_body_(attention_body), + act_(attention_body ? act : SELU), // HACK : old networks without attention body (e.g: T79 use hardcoded SELU activations) + BaseLayer(64 * 64 + 24 * 8, 1, 1, ip) { embedding_op_size_ = weights.ip_pol_b.size(); wq_op_size_ = weights.ip2_pol_b.size(); wk_op_size_ = weights.ip3_pol_b.size(); @@ -1374,7 +1378,7 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( - enc, scratch, encoder_heads_, embedding_op_size_); + enc, scratch, encoder_heads_, embedding_op_size_, 1.0f); // using alpha = 1 for now (TODO: may change?) encoder_weights_.emplace_back(pW); } } @@ -1382,7 +1386,8 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, template EncoderBlock::EncoderBlock( const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, - int size) : encoder_heads_(heads), embedding_op_size_(size) { + int size, float alpha) + : encoder_heads_(heads), embedding_op_size_(size), alpha_(alpha) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); mha_v_size_ = cpu_weights.mha.v_b.size(); @@ -1483,12 +1488,10 @@ static void cublasXGemmStridedBatched( // input/output tensor is scratch1, others are used as scratch. 
// TODO: fix naming of scratch buffers template -void EncoderBlock::Eval(int N, DataType* scratch1, - DataType* scratch0, - DataType* scratch2, - DataType* scratch3, - cublasHandle_t cublas, - cudaStream_t stream) const { +void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, + DataType* scratch2, DataType* scratch3, + cublasHandle_t cublas, cudaStream_t stream, + ActivationFunction act) const { const int d_model = mha_q_size_; const int depth = d_model / encoder_heads_; @@ -1586,8 +1589,8 @@ void EncoderBlock::Eval(int N, DataType* scratch1, // LN1: skip connection and layer normalization (also bias add of prev gemm) // scratch2/scratch1 -> scratch0 LayerNorm(N * 64, embedding_op_size_, scratch0, scratch2, - mha_dense_b, scratch1, ln1_gammas, ln1_betas, - 1e-6, stream); + mha_dense_b, scratch1, ln1_gammas, ln1_betas, 1e-6, + alpha_, stream); // #FFN dense 1, scratch0 -> scratch1 const int encoder_dff = ffn_dense1_size_; @@ -1599,7 +1602,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, scratch0, num_inputs, 0.0f, scratch1, num_outputs); addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, - SELU, stream); + act, stream); } // #FFN dense 2, scratch1 -> scratch2 @@ -1615,8 +1618,8 @@ void EncoderBlock::Eval(int N, DataType* scratch1, // LN2: skip connection and layer normilization (also bias add of prev gemm) // scratch2/scratch0 -> scratch1 LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, - ffn_dense2_b, scratch0, ln2_gammas, ln2_betas, - 1e-6, stream); + ffn_dense2_b, scratch0, ln2_gammas, ln2_betas, 1e-6, + alpha_, stream); } template @@ -1629,8 +1632,10 @@ void AttentionPolicyHead::Eval( DataType* scratch2 = output + scratch_size / (2 * sizeof(DataType)); DataType* scratch3 = scratch1 + scratch_size / (2 * sizeof(DataType)); + int inputC = this->input_->GetC(); - convertNCHWtoNHWC(scratch0, input, N, inputC, N, inputC, 8, 8); + if (!attention_body_) + convertNCHWtoNHWC(scratch0, input, N, inputC, N, inputC, 8, 8); // 1. Policy embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ @@ -1641,15 +1646,15 @@ void AttentionPolicyHead::Eval( const int batch = N * 64; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_pol_w_, - num_inputs, scratch0, num_inputs, 0.0f, pol_embedding, - num_outputs); + num_inputs, attention_body_ ? input : scratch0, + num_inputs, 0.0f, pol_embedding, num_outputs); addBiasBatched(pol_embedding, pol_embedding, ip_pol_b_, 1, batch, - num_outputs, SELU, stream); + num_outputs, act_, stream); } // 2. 
Encoder layers for (const auto pEnc : encoder_weights_) { - pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream); + pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, act_); } // End of encoder blocks DataType* wq; @@ -1738,23 +1743,59 @@ EncoderBlock::~EncoderBlock() { template -AttentionBody::AttentionBody(BaseLayer* ip, - const LegacyWeights& weights, +EmbeddingLayer::EmbeddingLayer(BaseLayer* ip, + const std::vector& weights, + const std::vector& biases, + void* scratch, + ActivationFunction act) + : BaseLayer(biases.size(), 8, 8, ip), act_(act) { + allocAndUpload(&weights_, weights, scratch); + allocAndUpload(&biases_, biases, scratch); +} + +template +EmbeddingLayer::~EmbeddingLayer() { + ReportCUDAErrors(cudaFree(weights_)); + ReportCUDAErrors(cudaFree(biases_)); +} + +template +void EmbeddingLayer::Eval( + int N, DataType* output, const DataType* input, const DataType* /*input2*/, + void* /*scratch*/, size_t /*scratch_size*/, cudnnHandle_t /*cudnn*/, + cublasHandle_t cublas, cudaStream_t stream) { + + const int num_outputs = this->GetC(); + const int num_inputs = this->input_->GetC(); + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, weights_, + num_inputs, input, num_inputs, 0.0f, output, + num_outputs); + addBiasBatched(output, output, biases_, 1, batch, num_outputs, + act_, stream); +} + +template +AttentionBody::AttentionBody(const LegacyWeights& weights, void* scratch, ActivationFunction default_act, - int num_res_blocks) - : BaseLayer(64 * weights.ip_emb_b.size(), 1, 1, ip) { - embedding_op_size_ = weights.ip_emb_b.size(); - encoder_head_count_ = weights.encoder_head_count; - num_resi_blocks_ = num_res_blocks; - default_act_ = default_act; + int num_res_blocks, int input_c) + : embedding_op_size_(weights.ip_emb_b.size()), + encoder_head_count_(weights.encoder_head_count), + num_resi_blocks_(num_res_blocks), + default_act_(default_act), + input_c_(input_c), + BaseLayer(embedding_op_size_, 8, 8, nullptr) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); + int num_encoders = weights.pol_encoder.size(); + float alpha = (float) pow(2.0 * num_encoders, 0.25); for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( - enc, scratch, encoder_head_count_, embedding_op_size_); + enc, scratch, encoder_head_count_, embedding_op_size_, alpha); encoder_weights_.emplace_back(pW); } } @@ -1777,7 +1818,7 @@ void AttentionBody::Eval( DataType* scratch2 = output + scratch_size / (2 * sizeof(DataType)); DataType* scratch3 = scratch1 + scratch_size / (2 * sizeof(DataType)); - int inputC = this->input_->GetC(); + int inputC = input_c_; if (num_resi_blocks_ == 0) { assert(inputC == kInputPlanes); @@ -1818,7 +1859,8 @@ void AttentionBody::Eval( // 2. Encoder layers for (const auto pEnc : encoder_weights_) { - pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream); + pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, + default_act_); } // End of encoder blocks @@ -1858,6 +1900,9 @@ template class EncoderBlock; template class AttentionBody; template class AttentionBody; +template class EmbeddingLayer; +template class EmbeddingLayer; + // Misc error handling stuff. 
#ifdef USE_CUDNN void CudnnError(cudnnStatus_t status, const char* file, const int& line) { diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 097e196aba..af629af79e 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -335,12 +335,12 @@ class ResidualBlock : public BaseLayer { template class EncoderBlock { public: - EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size); + EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, DataType* scratch2, cublasHandle_t cublas, - cudaStream_t stream) const; + cudaStream_t stream, ActivationFunction act) const; // all GPU side pointers DataType *mha_q_w, *mha_q_b; @@ -366,6 +366,8 @@ class EncoderBlock { int embedding_op_size_; int encoder_heads_; + + float alpha_; // scale to apply to skip connection add }; // The Attention policy head implementation @@ -382,7 +384,7 @@ class AttentionPolicyHead : public BaseLayer { public: AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, - void* scratch); + void* scratch, bool attention_body, ActivationFunction act); ~AttentionPolicyHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -404,10 +406,32 @@ class AttentionPolicyHead : public BaseLayer { int encoder_heads_; int policy_d_model_; + bool attention_body_; + ActivationFunction act_; std::vector*> encoder_weights_; }; +template +class EmbeddingLayer : public BaseLayer { + using BaseLayer::C; + using BaseLayer::H; + using BaseLayer::W; + +public: + EmbeddingLayer(BaseLayer* ip, const std::vector& weights, + const std::vector& biases, void* scratch, + ActivationFunction activation); + ~EmbeddingLayer(); + + void Eval(int N, DataType* output, const DataType* input, + const DataType* input2, void* scratch, size_t scratch_size, + cudnnHandle_t cudnn, cublasHandle_t cublas, + cudaStream_t stream) override; + private: + DataType *weights_, *biases_; + ActivationFunction act_; +}; // The Attention body implementation // Responsible for loading weights into GPU memory, and evaluating the entire @@ -422,8 +446,9 @@ class AttentionBody : public BaseLayer { using BaseLayer::GetW; public: - AttentionBody(BaseLayer* ip, const LegacyWeights& weights, - void* scratch, ActivationFunction default_act, int num_res_blocks); + AttentionBody(const LegacyWeights& weights, void* scratch, + ActivationFunction default_act, int num_res_blocks, + int input_c); ~AttentionBody(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -438,6 +463,7 @@ class AttentionBody : public BaseLayer { std::vector*> encoder_weights_; ActivationFunction default_act_; int num_resi_blocks_; + int input_c_; }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 2fb036c0aa..7b0ba27ad9 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -226,12 +226,11 @@ class CudaNetwork : public Network { const int kNumFilters = (int)weights.input.biases.size(); numBlocks_ = (int)weights.residual.size(); - attn_body_ = (weights.ip_emb_b.size() > 0); + num_encoder_blocks_ = (int) weights.encoder.size(); + attn_body_ = (num_encoder_blocks_ > 0); if (attn_body_) { - if (numBlocks_ > 0) - throw Exception("Found residual blocks in network with attention body!"); + 
assert(weights.ip_emb_b.size() > 0); } - num_encoder_blocks_ = (int) weights.encoder.size(); // Ankan - test! printf("\nNum filters: %d, num blocks: %d\n", kNumFilters, numBlocks_); @@ -305,18 +304,16 @@ class CudaNetwork : public Network { const bool mish_net = file.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH; + ActivationFunction act = mish_net ? MISH : RELU; + // 2. Build the network, and copy the weights to GPU memory. - if (attn_body_) { - auto body = std::make_unique>( - getLastLayer(), weights, scratch_mem_, mish_net ? MISH : RELU, - numBlocks_); - network_.emplace_back(std::move(body)); - } else { + // Input conv only used if there are residual blocks in the network + if (numBlocks_ > 0) { // Input. { auto inputConv = std::make_unique>( - nullptr, kNumFilters, 8, 8, kNumInputPlanes, mish_net ? MISH : RELU, + nullptr, kNumFilters, 8, 8, kNumInputPlanes, act, true, false, false, 0, use_gemm_ex, use_res_block_winograd_fuse_opt_); inputConv->LoadWeights(&weights.input.weights[0], @@ -332,7 +329,7 @@ class CudaNetwork : public Network { if (use_res_block_winograd_fuse_opt_) { auto layer = std::make_unique>( getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, - block == 0, block == (numBlocks_ - 1), mish_net ? MISH : RELU, + block == 0, block == (numBlocks_ - 1), act, deviceProp.sharedMemPerBlockOptin); layer->LoadWeights0(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], @@ -349,16 +346,16 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(layer)); } else { auto conv1 = std::make_unique>( - getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, false, false, 0, use_gemm_ex); + getLastLayer(), kNumFilters, 8, 8, kNumFilters, act, true, false, + false, 0, use_gemm_ex); conv1->LoadWeights(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], scratch_mem_); network_.emplace_back(std::move(conv1)); auto conv2 = std::make_unique>( - getLastLayer(), kNumFilters, 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, true, has_se, se_k, use_gemm_ex); + getLastLayer(), kNumFilters, 8, 8, kNumFilters, act, true, true, + has_se, se_k, use_gemm_ex); conv2->LoadWeights(&weights.residual[block].conv2.weights[0], &weights.residual[block].conv2.biases[0], scratch_mem_); @@ -371,15 +368,22 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(conv2)); } } + resi_last_ = getLastLayer(); } - resi_last_ = getLastLayer(); + if (attn_body_) { + auto attention_body = std::make_unique>( + weights, scratch_mem_, act, numBlocks_, + numBlocks_ > 0 ? kNumFilters : kInputPlanes); + network_.emplace_back(std::move(attention_body)); + encoder_last_ = getLastLayer(); + } // Policy head. if (attn_policy_) { auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_); + getLastLayer(), weights, scratch_mem_, attn_body_, act); network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( @@ -388,8 +392,9 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(policymap)); } else if (conv_policy_) { + assert(!attn_body_); // not supported with attention body auto conv1 = std::make_unique>( - resi_last_, kNumFilters, 8, 8, kNumFilters, mish_net ? 
MISH : RELU, + resi_last_, kNumFilters, 8, 8, kNumFilters, act, true, false, false, 0, use_gemm_ex); conv1->LoadWeights(&weights.policy1.weights[0], &weights.policy1.biases[0], scratch_mem_); @@ -411,9 +416,10 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(policymap)); } else { + assert(!attn_body_); // not supported with attention body auto convPol = std::make_unique>( - resi_last_, weights.policy.biases.size(), 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, use_gemm_ex); + resi_last_, weights.policy.biases.size(), 8, 8, kNumFilters, act, + true, use_gemm_ex); convPol->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], scratch_mem_); network_.emplace_back(std::move(convPol)); @@ -426,18 +432,23 @@ class CudaNetwork : public Network { } // Value head. - if (!attn_body_) { - auto convVal = std::make_unique>( - resi_last_, weights.value.biases.size(), 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, use_gemm_ex); - convVal->LoadWeights(&weights.value.weights[0], &weights.value.biases[0], - scratch_mem_); - network_.emplace_back(std::move(convVal)); + if (attn_body_) { + auto embedded_val = std::make_unique>( + encoder_last_, weights.ip_val_w, weights.ip_val_b, scratch_mem_, + act); + network_.emplace_back(std::move(embedded_val)); + } else { + auto convVal = std::make_unique>( + resi_last_, weights.value.biases.size(), 8, 8, kNumFilters, act, + true, use_gemm_ex); + convVal->LoadWeights(&weights.value.weights[0], + &weights.value.biases[0], scratch_mem_); + network_.emplace_back(std::move(convVal)); + } auto FCVal1 = std::make_unique>( - getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, - mish_net ? MISH : RELU); + getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, act); FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal1)); @@ -458,17 +469,22 @@ class CudaNetwork : public Network { moves_left_ = (file.format().network_format().moves_left() == pblczero::NetworkFormat::MOVES_LEFT_V1) && options.GetOrDefault("mlh", true); - if ((!attn_body_) && moves_left_) { - auto convMov = std::make_unique>( - resi_last_, weights.moves_left.biases.size(), 8, 8, kNumFilters, - mish_net ? MISH : RELU, true, use_gemm_ex); - convMov->LoadWeights(&weights.moves_left.weights[0], - &weights.moves_left.biases[0], scratch_mem_); - network_.emplace_back(std::move(convMov)); - + if (moves_left_) { + if (attn_body_) { + auto embedded_mov = std::make_unique>( + encoder_last_, weights.ip_mov_w, weights.ip_mov_b, scratch_mem_, + act); + network_.emplace_back(std::move(embedded_mov)); + } else { + auto convMov = std::make_unique>( + resi_last_, weights.moves_left.biases.size(), 8, 8, kNumFilters, + act, true, use_gemm_ex); + convMov->LoadWeights(&weights.moves_left.weights[0], + &weights.moves_left.biases[0], scratch_mem_); + network_.emplace_back(std::move(convMov)); + } auto FCMov1 = std::make_unique>( - getLastLayer(), weights.ip1_mov_b.size(), 1, 1, true, - mish_net ? MISH : RELU); + getLastLayer(), weights.ip1_mov_b.size(), 1, 1, true, act); FCMov1->LoadWeights(&weights.ip1_mov_w[0], &weights.ip1_mov_b[0], scratch_mem_); network_.emplace_back(std::move(FCMov1)); @@ -556,34 +572,11 @@ class CudaNetwork : public Network { int l = 0; - if (attn_body_) { - // 1. Entire Attention network body (including value and wdl heads) - network_[l++]->Eval( - batchSize, tensor_mem[2], tensor_mem[0], tensor_mem[1], scratch_mem, - scratch_size_, nullptr, cublas, - stream); - - // 2. 
Attention policy head - network_[l++]->Eval( - batchSize, tensor_mem[0], tensor_mem[2], tensor_mem[1], scratch_mem, - scratch_size_, nullptr, cublas, - stream); // Entire Attention policy head except for the policy map + DataType* flow = tensor_mem[0]; + DataType* spare1 = tensor_mem[1]; + DataType* spare2 = tensor_mem[2]; - // 3. Policy map layer - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy map layer - copyTypeConverted(opPol, (half*)(tensor_mem[1]), - batchSize * kNumOutputPolicy, - stream); // POLICY output - } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy map layer // POLICY output - } - } else { - // Old CNN style networks + if (numBlocks_ > 0) { // Input. network_[l++]->Eval( batchSize, @@ -608,134 +601,146 @@ class CudaNetwork : public Network { } } - // Policy head. - if (attn_policy_) { - network_[l++]->Eval( - batchSize, tensor_mem[0], tensor_mem[2], tensor_mem[1], scratch_mem, - scratch_size_, nullptr, cublas, - stream); // Entire Attention policy head except for the policy map - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy map layer - copyTypeConverted(opPol, (half*)(tensor_mem[1]), - batchSize * kNumOutputPolicy, - stream); // POLICY output - } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], - nullptr, scratch_mem, scratch_size_, nullptr, - cublas, - stream); // policy map layer // POLICY output - } + flow = tensor_mem[2]; + spare1 = tensor_mem[0]; + spare2 = tensor_mem[1]; + } - } else if (conv_policy_) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy conv1 - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy conv2 + if (attn_body_) { + network_[l++]->Eval(batchSize, tensor_mem[1], + (numBlocks_ > 0) ? tensor_mem[2] : tensor_mem[0], + (numBlocks_ > 0) ? tensor_mem[0] : tensor_mem[2], + scratch_mem, scratch_size_, nullptr, + cublas, stream); // Entire attention body of the network + + flow = tensor_mem[1]; + spare1 = tensor_mem[0]; + spare2 = tensor_mem[2]; + } - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // policy map layer - copyTypeConverted(opPol, (half*)(tensor_mem[0]), - batchSize * kNumOutputPolicy, - stream); // POLICY output - } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[1], - nullptr, scratch_mem, scratch_size_, nullptr, - cublas, - stream); // policy map layer // POLICY output - } + // Policy head. 
+ if (attn_policy_) { + network_[l++]->Eval( + batchSize, spare1, flow, spare2, scratch_mem, + scratch_size_, nullptr, cublas, + stream); // Entire Attention policy head except for the policy map + if (fp16) { + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // policy map layer + copyTypeConverted(opPol, (half*)spare2, + batchSize * kNumOutputPolicy, + stream); // POLICY output } else { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, + network_[l++]->Eval(batchSize, (DataType*)opPol, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // pol conv - - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // pol FC - - copyTypeConverted(opPol, (half*)(tensor_mem[1]), - batchSize * kNumOutputPolicy, stream); // POLICY - } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem[0], - nullptr, scratch_mem, scratch_size_, nullptr, - cublas, - stream); // pol FC // POLICY - } + stream); // policy map layer // POLICY output } - // Copy policy output from device memory to host memory. - ReportCUDAErrors( - cudaMemcpyAsync(io->op_policy_mem_, io->op_policy_mem_gpu_, - sizeof(float) * kNumOutputPolicy * batchSize, - cudaMemcpyDeviceToHost, stream)); - - // value head - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, + } else if (conv_policy_) { + network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // value conv + stream); // policy conv1 - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC1 + stream); // policy conv2 - if (wdl_) { - if (fp16) { - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE - copyTypeConverted(opVal, (half*)(tensor_mem[0]), 3 * batchSize, - stream); // VALUE - } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], - nullptr, scratch_mem, scratch_size_, nullptr, - cublas, - stream); // value FC2 // VALUE - } + if (fp16) { + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // policy map layer + copyTypeConverted(opPol, (half*)(spare1), + batchSize * kNumOutputPolicy, + stream); // POLICY output } else { - if (fp16) { - // TODO: consider fusing the bias-add of FC2 with format conversion. 
- network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 - copyTypeConverted(opVal, (half*)(tensor_mem[0]), batchSize, - stream); // VALUE - } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, tensor_mem[1], - nullptr, scratch_mem, scratch_size_, nullptr, - cublas, - stream); // value FC2 // VALUE - } + network_[l++]->Eval(batchSize, (DataType*)opPol, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // policy map layer // POLICY output } + } else { + network_[l++]->Eval(batchSize, spare1, flow, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // pol conv - if (moves_left_) { - // Moves left head - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[2], nullptr, + if (fp16) { + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // moves conv + stream); // pol FC - network_[l++]->Eval(batchSize, tensor_mem[1], tensor_mem[0], nullptr, + copyTypeConverted(opPol, (half*)(spare2), + batchSize * kNumOutputPolicy, stream); // POLICY + } else { + network_[l++]->Eval(batchSize, (DataType*)opPol, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, - stream); // moves FC1 + stream); // pol FC // POLICY + } + } - // Moves left FC2 - if (fp16) { - // TODO: consider fusing the bias-add of FC2 with format conversion. - network_[l++]->Eval(batchSize, tensor_mem[0], tensor_mem[1], nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); - copyTypeConverted(opMov, (half*)(tensor_mem[0]), batchSize, stream); - } else { - network_[l++]->Eval(batchSize, (DataType*)opMov, tensor_mem[1], - nullptr, scratch_mem, scratch_size_, nullptr, - cublas, stream); - } + // Copy policy output from device memory to host memory. + ReportCUDAErrors( + cudaMemcpyAsync(io->op_policy_mem_, io->op_policy_mem_gpu_, + sizeof(float) * kNumOutputPolicy * batchSize, + cudaMemcpyDeviceToHost, stream)); + + // value head + network_[l++]->Eval(batchSize, spare1, flow, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value conv or embedding + + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC1 + + if (wdl_) { + if (fp16) { + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC2 // VALUE + copyTypeConverted(opVal, (half*)spare1, 3 * batchSize, + stream); // VALUE + } else { + network_[l++]->Eval(batchSize, (DataType*)opVal, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC2 // VALUE + } + } else { + if (fp16) { + // TODO: consider fusing the bias-add of FC2 with format conversion. 
+ network_[l++]->Eval(batchSize, spare1, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC2 + copyTypeConverted(opVal, (half*)(spare1), batchSize, + stream); // VALUE + } else { + network_[l++]->Eval(batchSize, (DataType*)opVal, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value FC2 // VALUE + } + } + + if (moves_left_) { + // Moves left head + network_[l++]->Eval(batchSize, spare1, flow, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // moves conv or embedding + + network_[l++]->Eval(batchSize, spare2, spare1, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // moves FC1 + + // Moves left FC2 + if (fp16) { + // TODO: consider fusing the bias-add of FC2 with format conversion. + network_[l++]->Eval(batchSize, spare1, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); + copyTypeConverted(opMov, (half*)(spare1), batchSize, stream); + } else { + network_[l++]->Eval(batchSize, (DataType*)opMov, spare2, nullptr, + scratch_mem, scratch_size_, nullptr, cublas, + stream); } } @@ -838,6 +843,7 @@ class CudaNetwork : public Network { BaseLayer* getLastLayer() { return network_.back().get(); } BaseLayer* resi_last_; + BaseLayer* encoder_last_; size_t tensor_mem_size_; size_t scratch_size_; From 49a814325fa132406796459c36d2989b3e4815bb Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Wed, 23 Mar 2022 17:26:36 +0530 Subject: [PATCH 03/70] fix few crashes --- src/neural/cuda/layers.cc | 4 ++-- src/neural/cuda/network_cuda.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 87aaad3d14..bc1bcba2a3 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -56,7 +56,7 @@ BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc, template BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip) - : input_(ip), C(c), H(h), W(w), nhwc_(ip->nhwc_), use_gemm_ex_(false) {} + : input_(ip), C(c), H(h), W(w), nhwc_(ip ? ip->nhwc_ : false), use_gemm_ex_(false) {} #ifdef USE_CUDNN template @@ -1786,7 +1786,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, num_resi_blocks_(num_res_blocks), default_act_(default_act), input_c_(input_c), - BaseLayer(embedding_op_size_, 8, 8, nullptr) { + BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 7b0ba27ad9..5c8f21f88c 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -271,7 +271,7 @@ class CudaNetwork : public Network { // 0. Check for SE. has_se_ = false; - if (weights.residual[0].has_se) { + if (numBlocks_ && weights.residual[0].has_se) { has_se_ = true; } From 7d96c6382ff619f53632847203e65ec1b9193ebc Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Wed, 23 Mar 2022 17:41:52 +0530 Subject: [PATCH 04/70] use the right encoder block for body! 
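The attention body has to be built from weights.encoder (the body's own encoder
stack), not from weights.pol_encoder, which belongs to the attention policy head.
Using the wrong list also skews the skip-connection scale, since

  float alpha = (float)pow(2.0 * weights.encoder.size(), 0.25);

depends on the number of body encoder blocks. For illustration only (the
normalization call itself is not part of this patch), alpha is the kind of factor
applied to the skip connection before the add-and-normalize step, roughly
out = LayerNorm(alpha * x + Sublayer(x)).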
--- src/neural/cuda/layers.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index bc1bcba2a3..c8135702fa 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1791,9 +1791,9 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); - int num_encoders = weights.pol_encoder.size(); + int num_encoders = weights.encoder.size(); float alpha = (float) pow(2.0 * num_encoders, 0.25); - for (const auto& enc : weights.pol_encoder) { + for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha); encoder_weights_.emplace_back(pW); From f5fe73779bc3289a783d120fe187735afa260841 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Wed, 23 Mar 2022 21:21:31 +0530 Subject: [PATCH 05/70] fix output of AttentionBody --- src/neural/cuda/common_kernels.cu | 4 +++ src/neural/cuda/layers.cc | 42 +++++++++++++++++++++++++------ src/neural/cuda/network_cudnn.cc | 31 ----------------------- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 60fb20c845..52f161fff6 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -150,6 +150,10 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, addBiasBatched_kernel <<>>(output, input, bias, N, C); break; + case RELU: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; default: throw Exception( "unsupported activation in addBiasBatched. Add in switch-case here"); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index c8135702fa..9c0958d541 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -35,8 +35,37 @@ #include "utils/fp16_utils.h" namespace lczero { -// void dumpTensor(void* memory, int elements, const char* message, bool fp16 = -// false); + +#if 0 +// debug code to dump allocation in GPU memory +void dumpTensor(void *memory, int elements, const char *message, bool fp16 = true) +{ + printf("\n%s\n", message); + int elementSize = (int) (fp16 ? 
sizeof(half) : sizeof(float)); + int bytes = elements * elementSize; + void *temp = malloc(bytes); + cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); + + for (int i = 0; i < elements; i++) + { + float val; + if (fp16) + { + half *arr = (half*)temp; + val = (float)arr[i]; + } + else + { + float *arr = (float *)temp; + val = arr[i]; + } + printf("%8.4f ", val); + if ((i % 8) == 7) printf("\n"); + } + free(temp); + printf("\n"); +} +#endif namespace cudnn_backend { @@ -1632,7 +1661,6 @@ void AttentionPolicyHead::Eval( DataType* scratch2 = output + scratch_size / (2 * sizeof(DataType)); DataType* scratch3 = scratch1 + scratch_size / (2 * sizeof(DataType)); - int inputC = this->input_->GetC(); if (!attention_body_) convertNCHWtoNHWC(scratch0, input, N, inputC, N, inputC, 8, 8); @@ -1814,9 +1842,9 @@ void AttentionBody::Eval( void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, cudaStream_t stream) { DataType* scratch0 = (DataType*)scratch; - DataType* scratch1 = (DataType*)input2; - DataType* scratch2 = output + scratch_size / (2 * sizeof(DataType)); - DataType* scratch3 = scratch1 + scratch_size / (2 * sizeof(DataType)); + DataType* scratch1 = (DataType*)output; + DataType* scratch2 = (DataType*)input2; + DataType* scratch3 = scratch2 + scratch_size / (2 * sizeof(DataType)); int inputC = input_c_; if (num_resi_blocks_ == 0) @@ -1862,8 +1890,6 @@ void AttentionBody::Eval( pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, default_act_); } // End of encoder blocks - - } diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index 4a0e3016ca..df4bdfed4c 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -46,37 +46,6 @@ namespace lczero { using namespace cudnn_backend; -#if 0 -// debug code to dump allocation in GPU memory -void dumpTensor(void *memory, int elements, const char *message, bool fp16 = false) -{ - printf("\n%s\n", message); - int elementSize = (int) (fp16 ? sizeof(half) : sizeof(float)); - int bytes = elements * elementSize; - void *temp = malloc(bytes); - cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); - - for (int i = 0; i < elements; i++) - { - float val; - if (fp16) - { - half *arr = (half*)temp; - val = (float)arr[i]; - } - else - { - float *arr = (float *)temp; - val = arr[i]; - } - printf("%8.4f ", val); - if ((i % 8) == 7) printf("\n"); - } - free(temp); - printf("\n"); -} -#endif - template class CudnnNetwork; From 72cbc13dfb7141e6cdc3d1bd7d2c2a7f526bec02 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Fri, 25 Mar 2022 20:02:19 +0530 Subject: [PATCH 06/70] move pos encoding table to common header file - also remove hardcoding. 
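For reference, the preprocessing that consumes this table is straightforward: each
board square keeps its kInputPlanes input features (read from the NCHW input) and
gets the per-square positional encoding appended, with the result written in NHWC
order. A minimal CPU sketch of what preprocess_for_attention_body_kernel computes
(the function name and float-only signature below are illustrative):

  void PreprocessForAttentionBodyCpu(float* out, const float* in, int N) {
    const int c_out = kInputPlanes + kNumPosEncodingChannels;
    for (int n = 0; n < N; n++)
      for (int hw = 0; hw < 64; hw++)
        for (int c = 0; c < c_out; c++) {
          float v = (c < kInputPlanes)
                        ? in[n * kInputPlanes * 64 + c * 64 + hw]  // NCHW input plane
                        : kPosEncoding[hw][c - kInputPlanes];      // appended encoding
          out[n * 64 * c_out + hw * c_out + c] = v;                // NHWC output
        }
  }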
--- src/neural/cuda/common_kernels.cu | 40 +-- src/neural/cuda/layers.cc | 10 +- src/neural/shared/attention_policy_map.h | 384 +++++++++++++++++++++-- 3 files changed, 370 insertions(+), 64 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 52f161fff6..45829b8854 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -30,6 +30,8 @@ #include "cuda_common.h" #include "winograd_helper.inc" +#include "neural/shared/attention_policy_map.h" + namespace lczero { namespace cudnn_backend { namespace { @@ -992,40 +994,6 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, <<>>(C, output, keys, ppo, policy_attn_logits); } -__device__ constexpr float kPosEncoding[64][6] = { - {0., 0., 0., 0., 0., 0.}, {0., 0., 0., 0., 0., 1.}, - {0., 0., 0., 0., 1., 0.}, {0., 0., 0., 0., 1., 1.}, - {0., 0., 0., 1., 0., 0.}, {0., 0., 0., 1., 0., 1.}, - {0., 0., 0., 1., 1., 0.}, {0., 0., 0., 1., 1., 1.}, - {0., 0., 1., 0., 0., 0.}, {0., 0., 1., 0., 0., 1.}, - {0., 0., 1., 0., 1., 0.}, {0., 0., 1., 0., 1., 1.}, - {0., 0., 1., 1., 0., 0.}, {0., 0., 1., 1., 0., 1.}, - {0., 0., 1., 1., 1., 0.}, {0., 0., 1., 1., 1., 1.}, - {0., 1., 0., 0., 0., 0.}, {0., 1., 0., 0., 0., 1.}, - {0., 1., 0., 0., 1., 0.}, {0., 1., 0., 0., 1., 1.}, - {0., 1., 0., 1., 0., 0.}, {0., 1., 0., 1., 0., 1.}, - {0., 1., 0., 1., 1., 0.}, {0., 1., 0., 1., 1., 1.}, - {0., 1., 1., 0., 0., 0.}, {0., 1., 1., 0., 0., 1.}, - {0., 1., 1., 0., 1., 0.}, {0., 1., 1., 0., 1., 1.}, - {0., 1., 1., 1., 0., 0.}, {0., 1., 1., 1., 0., 1.}, - {0., 1., 1., 1., 1., 0.}, {0., 1., 1., 1., 1., 1.}, - {1., 0., 0., 0., 0., 0.}, {1., 0., 0., 0., 0., 1.}, - {1., 0., 0., 0., 1., 0.}, {1., 0., 0., 0., 1., 1.}, - {1., 0., 0., 1., 0., 0.}, {1., 0., 0., 1., 0., 1.}, - {1., 0., 0., 1., 1., 0.}, {1., 0., 0., 1., 1., 1.}, - {1., 0., 1., 0., 0., 0.}, {1., 0., 1., 0., 0., 1.}, - {1., 0., 1., 0., 1., 0.}, {1., 0., 1., 0., 1., 1.}, - {1., 0., 1., 1., 0., 0.}, {1., 0., 1., 1., 0., 1.}, - {1., 0., 1., 1., 1., 0.}, {1., 0., 1., 1., 1., 1.}, - {1., 1., 0., 0., 0., 0.}, {1., 1., 0., 0., 0., 1.}, - {1., 1., 0., 0., 1., 0.}, {1., 1., 0., 0., 1., 1.}, - {1., 1., 0., 1., 0., 0.}, {1., 1., 0., 1., 0., 1.}, - {1., 1., 0., 1., 1., 0.}, {1., 1., 0., 1., 1., 1.}, - {1., 1., 1., 0., 0., 0.}, {1., 1., 1., 0., 0., 1.}, - {1., 1., 1., 0., 1., 0.}, {1., 1., 1., 0., 1., 1.}, - {1., 1., 1., 1., 0., 0.}, {1., 1., 1., 1., 0., 1.}, - {1., 1., 1., 1., 1., 0.}, {1., 1., 1., 1., 1., 1.}}; - template __global__ void preprocess_for_attention_body_kernel(T* output, const T* input) { int n = blockIdx.x; @@ -1042,7 +1010,7 @@ __global__ void preprocess_for_attention_body_kernel(T* output, const T* input) op = input[n * 64 * kInputPlanes + c * 64 + hw]; // nchw } - constexpr int outputC = kInputPlanes + 6; + constexpr int outputC = kInputPlanes + kNumPosEncodingChannels; // convert to nhwc output[n * 64 * outputC + hw * outputC + c] = op; @@ -1055,7 +1023,7 @@ void inputPreprocessForAttentionBody(T* output, const T* input, int N, // (kInputPlanes + 6) threads // Each thread computes a single output element dim3 gridSize = dim3(N, 64); - int blockSize = kInputPlanes + 6; + int blockSize = kInputPlanes + kNumPosEncodingChannels; preprocess_for_attention_body_kernel <<>>(output, input); } diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 9c0958d541..e96426e8a6 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -33,13 +33,15 @@ #include "cuda_common.h" #include "kernels.h" #include 
"utils/fp16_utils.h" +#include "neural/shared/attention_policy_map.h" namespace lczero { #if 0 // debug code to dump allocation in GPU memory -void dumpTensor(void *memory, int elements, const char *message, bool fp16 = true) -{ +template +void dumpTensor(T* memory, int elements, const char* message) { + const bool fp16 = std::is_same::value; printf("\n%s\n", message); int elementSize = (int) (fp16 ? sizeof(half) : sizeof(float)); int bytes = elements * elementSize; @@ -1792,7 +1794,6 @@ void EmbeddingLayer::Eval( int N, DataType* output, const DataType* input, const DataType* /*input2*/, void* /*scratch*/, size_t /*scratch_size*/, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, cudaStream_t stream) { - const int num_outputs = this->GetC(); const int num_inputs = this->input_->GetC(); const int batch = N * 64; @@ -1862,7 +1863,7 @@ void AttentionBody::Eval( axis=2) */ inputPreprocessForAttentionBody(scratch0, input, N, stream); - inputC += 6; + inputC += kNumPosEncodingChannels; } else { // #redirect flow through encoder blocks // flow = tf.transpose(flow, perm = [ 0, 2, 3, 1 ]) @@ -1890,6 +1891,7 @@ void AttentionBody::Eval( pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, default_act_); } // End of encoder blocks + } diff --git a/src/neural/shared/attention_policy_map.h b/src/neural/shared/attention_policy_map.h index 5ab0966654..df39df7dc9 100644 --- a/src/neural/shared/attention_policy_map.h +++ b/src/neural/shared/attention_policy_map.h @@ -380,30 +380,366 @@ const short kAttnPolicyMap[] = { 1848, 1849, 1850, 1851, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1852, 1853, 1854, 1855, 1856, 1857}; - -} // namespace lczero - - - - - - - - - - - - - - - - - - - - - - +#if 0 +// used only by Arcturai's T02 network +constexpr int kNumPosEncodingChannels = 6; +__device__ constexpr float kPosEncoding[64][kNumPosEncodingChannels] = { + {0., 0., 0., 0., 0., 0.}, {0., 0., 0., 0., 0., 1.}, + {0., 0., 0., 0., 1., 0.}, {0., 0., 0., 0., 1., 1.}, + {0., 0., 0., 1., 0., 0.}, {0., 0., 0., 1., 0., 1.}, + {0., 0., 0., 1., 1., 0.}, {0., 0., 0., 1., 1., 1.}, + {0., 0., 1., 0., 0., 0.}, {0., 0., 1., 0., 0., 1.}, + {0., 0., 1., 0., 1., 0.}, {0., 0., 1., 0., 1., 1.}, + {0., 0., 1., 1., 0., 0.}, {0., 0., 1., 1., 0., 1.}, + {0., 0., 1., 1., 1., 0.}, {0., 0., 1., 1., 1., 1.}, + {0., 1., 0., 0., 0., 0.}, {0., 1., 0., 0., 0., 1.}, + {0., 1., 0., 0., 1., 0.}, {0., 1., 0., 0., 1., 1.}, + {0., 1., 0., 1., 0., 0.}, {0., 1., 0., 1., 0., 1.}, + {0., 1., 0., 1., 1., 0.}, {0., 1., 0., 1., 1., 1.}, + {0., 1., 1., 0., 0., 0.}, {0., 1., 1., 0., 0., 1.}, + {0., 1., 1., 0., 1., 0.}, {0., 1., 1., 0., 1., 1.}, + {0., 1., 1., 1., 0., 0.}, {0., 1., 1., 1., 0., 1.}, + {0., 1., 1., 1., 1., 0.}, {0., 1., 1., 1., 1., 1.}, + {1., 0., 0., 0., 0., 0.}, {1., 0., 0., 0., 0., 1.}, + {1., 0., 0., 0., 1., 0.}, {1., 0., 0., 0., 1., 1.}, + {1., 0., 0., 1., 0., 0.}, {1., 0., 0., 1., 0., 1.}, + {1., 0., 0., 1., 1., 0.}, {1., 0., 0., 1., 1., 1.}, + {1., 0., 1., 0., 0., 0.}, {1., 0., 1., 0., 0., 1.}, + {1., 0., 1., 0., 1., 0.}, {1., 0., 1., 0., 1., 1.}, + {1., 0., 1., 1., 0., 0.}, {1., 0., 1., 1., 0., 1.}, + {1., 0., 1., 1., 1., 0.}, {1., 0., 1., 1., 1., 1.}, + {1., 1., 0., 0., 0., 0.}, {1., 1., 0., 0., 0., 1.}, + {1., 1., 0., 0., 1., 0.}, {1., 1., 0., 0., 1., 1.}, + {1., 1., 0., 1., 0., 0.}, {1., 1., 0., 1., 0., 1.}, + {1., 1., 0., 1., 1., 0.}, {1., 1., 0., 1., 1., 1.}, + {1., 1., 1., 0., 0., 0.}, {1., 1., 1., 0., 0., 1.}, + {1., 1., 1., 0., 1., 0.}, {1., 1., 1., 0., 1., 1.}, + {1., 1., 1., 1., 0., 
0.}, {1., 1., 1., 1., 0., 1.}, + {1., 1., 1., 1., 1., 0.}, {1., 1., 1., 1., 1., 1.}}; +#endif + +constexpr int kNumPosEncodingChannels = 64; +#if defined(__CUDA_ARCH__) +__device__ +#endif +const float kPosEncoding[64][kNumPosEncodingChannels] = { + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 
0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 
0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + -1.0}; +} // namespace lczero From af2db3dca1603604d9d21b63673c639e843389a2 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Mon, 28 Mar 2022 19:45:07 +0530 Subject: [PATCH 07/70] add hack to match training side bug - will be removed once it's fixed. --- src/neural/cuda/common_kernels.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 45829b8854..d42d772b8c 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1006,10 +1006,11 @@ __global__ void preprocess_for_attention_body_kernel(T* output, const T* input) // concatenate from fixed pos encoding array op = (T) (kPosEncoding[hw][c - kInputPlanes]); } else { - - op = input[n * 64 * kInputPlanes + c * 64 + hw]; // nchw + op = input[n * kInputPlanes * 64 + c * 64 + hw]; // nchw } + if (c == 109) op = (T) (float(op) / 99.0f); // Ankan - hack to match bug in training side! + constexpr int outputC = kInputPlanes + kNumPosEncodingChannels; // convert to nhwc @@ -1020,7 +1021,7 @@ template void inputPreprocessForAttentionBody(T* output, const T* input, int N, cudaStream_t stream) { // N * 64 blocks - // (kInputPlanes + 6) threads + // (kInputPlanes + kNumPosEncodingChannels) threads // Each thread computes a single output element dim3 gridSize = dim3(N, 64); int blockSize = kInputPlanes + kNumPosEncodingChannels; From 3c2639cca1098d15d88bb1de4b8ac0f7887c55f2 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Mon, 28 Mar 2022 22:22:47 +0530 Subject: [PATCH 08/70] fix build error --- src/neural/cuda/layers.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index e96426e8a6..6666ed235b 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -31,6 +31,7 @@ #include #include "cuda_common.h" +#include "neural/network.h" #include "kernels.h" #include "utils/fp16_utils.h" #include "neural/shared/attention_policy_map.h" From 410c4f9378db994b2ea7dffb1ed012289ccf955b Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Tue, 29 Mar 2022 14:14:12 +0530 Subject: [PATCH 09/70] remove hack for plies ply plane training side bug - also fix scratch space calculation. --- src/neural/cuda/common_kernels.cu | 2 -- src/neural/cuda/network_cuda.cc | 60 ++++++++++++++++++++++++------- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index d42d772b8c..919f03cd64 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1009,8 +1009,6 @@ __global__ void preprocess_for_attention_body_kernel(T* output, const T* input) op = input[n * kInputPlanes * 64 + c * 64 + hw]; // nchw } - if (c == 109) op = (T) (float(op) / 99.0f); // Ankan - hack to match bug in training side! 
- constexpr int outputC = kInputPlanes + kNumPosEncodingChannels; // convert to nhwc diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 5c8f21f88c..44d2585051 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -82,6 +82,39 @@ static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { return size; } +static size_t getMaxAttentionBodySize(const LegacyWeights& weights, int N) { + const size_t embedding_op_size = weights.ip_emb_b.size(); + + size_t encoder_d_model = 0; + size_t encoder_dff = 0; + + if (weights.encoder.size() > 0) { + encoder_d_model = weights.encoder[0].mha.q_b.size(); + encoder_dff = weights.encoder[0].ffn.dense1_b.size(); + + assert(encoder_d_model == weights.encoder[0].mha.k_b.size()); + assert(encoder_d_model == weights.encoder[0].mha.v_b.size()); + assert(embedding_op_size == weights.encoder[0].ffn.dense2_b.size()); + } + + const size_t encoder_heads = weights.encoder_head_count; + + size_t size = + N * 64 * std::max(embedding_op_size, encoder_d_model); + + // size of matmul_qk matrix = encoder_heads_ * Batch * 64 * 64 + const size_t matmul_qk_size = encoder_heads * N * 64 * 64; + const size_t output_size = N * (64 * 64 + 8 * 24); + size = std::max(size, std::max(matmul_qk_size, output_size)); + + size_t qkv_size = N * 64 * encoder_d_model; + // We store qkv in single allocation, and other intermediate tensors are + // sometimes stored by splitting an allocation into two halves. + size = std::max(2 * size, 3 * qkv_size); + return size; +} + + template class CudaNetworkComputation : public NetworkComputation { public: @@ -232,11 +265,6 @@ class CudaNetwork : public Network { assert(weights.ip_emb_b.size() > 0); } - // Ankan - test! - printf("\nNum filters: %d, num blocks: %d\n", kNumFilters, numBlocks_); - printf("\nNum encoder blocks: %d, num policy encoder blocks: %d\n", - num_encoder_blocks_, (int)weights.pol_encoder.size()); - // Warn if the memory required for storing transformed weights is // going to exceed 40% of total video memory, force custom_winograd off @@ -289,15 +317,21 @@ class CudaNetwork : public Network { // Need additional space for transformed input/outputs which are 36/16 // times size (4x4 block transformed into 6x6). 
- const size_t transformed_tensor_size = - (size_t)(max_batch_size_ * kNumFilters * 64 * (36.0 / 16.0) * - sizeof(DataType)); - scratch_size_ = std::max(scratch_size_, 2 * transformed_tensor_size); + if (numBlocks_ > 0) { + const size_t transformed_tensor_size = + (size_t)(max_batch_size_ * kNumFilters * 64 * (36.0 / 16.0) * + sizeof(DataType)); + scratch_size_ = std::max(scratch_size_, 2 * transformed_tensor_size); + } - // Attention policy head may need more memory - const size_t attentionSize = + // Attention policy head or body may need more memory + const size_t attentionPolicySize = getMaxAttentionHeadSize(weights, max_batch_size_); - scratch_size_ = std::max(scratch_size_, attentionSize); + + const size_t attentionBodySize = + getMaxAttentionBodySize(weights, max_batch_size_); + scratch_size_ = std::max(scratch_size_, + std::max(attentionPolicySize, attentionBodySize)); ReportCUDAErrors(cudaMalloc(&scratch_mem_, scratch_size_)); @@ -508,7 +542,7 @@ class CudaNetwork : public Network { maxSize = std::max(maxSize, layer->GetOutputSize(max_batch_size_)); } - if ((attn_policy_ || use_res_block_winograd_fuse_opt_) && + if ((attn_policy_ || use_res_block_winograd_fuse_opt_ || attn_body_) && (scratch_size_ > maxSize)) { maxSize = scratch_size_; } From b6c8e4372724fdf20081257ef2210b2210ddad6d Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Mon, 12 Sep 2022 16:49:41 +0530 Subject: [PATCH 10/70] Fix attention body/head size - factor of sizeof(DataType) was missing. --- src/neural/cuda/network_cuda.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 44d2585051..34acd0fbdd 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -326,10 +326,10 @@ class CudaNetwork : public Network { // Attention policy head or body may need more memory const size_t attentionPolicySize = - getMaxAttentionHeadSize(weights, max_batch_size_); + getMaxAttentionHeadSize(weights, max_batch_size_) * sizeof(DataType); const size_t attentionBodySize = - getMaxAttentionBodySize(weights, max_batch_size_); + getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(DataType); scratch_size_ = std::max(scratch_size_, std::max(attentionPolicySize, attentionBodySize)); From 844186b3adbcc67558ec655c7f28963ae10fbd42 Mon Sep 17 00:00:00 2001 From: Alma Date: Thu, 15 Dec 2022 09:04:49 +0100 Subject: [PATCH 11/70] Add input gating kernel. 
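Input gating applies a learned, per-square and per-channel affine transform to the
embedding that feeds the encoder stack: every activation is scaled by ip_mult_gate
and shifted by ip_add_gate, with the gating weights stored transposed (C x 64)
relative to the NHWC activations (64 x C per position). A minimal CPU sketch of the
intended math, for reference only (the function name is illustrative; the index
arithmetic matches the kernel added below):

    // out, in: [N][64][C] row-major; mult, add: [C][64] (transposed).
    void InputGatingReference(float* out, const float* in, const float* mult,
                              const float* add, int N, int HW, int C) {
      for (int n = 0; n < N; n++)
        for (int hw = 0; hw < HW; hw++)
          for (int c = 0; c < C; c++) {
            int idx = (n * HW + hw) * C + c;  // activation index (NHWC)
            int idxT = c * HW + hw;           // transposed gating-weight index
            out[idx] = in[idx] * mult[idxT] + add[idxT];
          }
    }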
--- libs/lczero-common | 2 +- src/neural/cuda/common_kernels.cu | 30 ++++++++++++++++++++++++++++++ src/neural/cuda/cuda_common.h | 2 +- src/neural/cuda/kernels.h | 5 +++++ src/neural/cuda/layers.cc | 15 ++++++++++++++- src/neural/cuda/layers.h | 2 ++ src/neural/network_legacy.cc | 22 ++++++++++++++++++++-- src/neural/network_legacy.h | 24 ++++++++++++++++++++++++ src/neural/shared/activation.cc | 2 ++ src/neural/shared/activation.h | 2 +- 10 files changed, 100 insertions(+), 6 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index 4dfa4ce833..2165d35bf6 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 4dfa4ce8339357819f7de01517e6297d4c768cdf +Subproject commit 2165d35bf63e95549eb4feff06a755ec88af5264 diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 919f03cd64..fc3d6a966f 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -156,6 +156,9 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, addBiasBatched_kernel <<>>(output, input, bias, N, C); break; + case SWISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C); default: throw Exception( "unsupported activation in addBiasBatched. Add in switch-case here"); @@ -1027,6 +1030,26 @@ void inputPreprocessForAttentionBody(T* output, const T* input, int N, <<>>(output, input); } +template +__global__ void input_gating_kernel(T* output, const T* input, const T* mult, const T* add) { + int n = blockIdx.x * blockDim.x * blockDim.y; + int idx = threadIdx.y * blockDim.x + threadIdx.x; // index in input + int idxT = threadIdx.x * blockDim.y + threadIdx.y; // index in transposed weights arrays mult and add. + + // Combine multiply gating, add gating and weights transpose. + output[n + idx] = input[n + idx] * mult[idxT] + add[idxT]; +} + +template +void applyInputGating(T* output, const T* input, const T* mult, const T* add, + int N, int HW, int C, cudaStream_t stream) { + // N blocks, + // (C * output_size) threads + // Each thread computes a single output element + dim3 gridSize = dim3(N, 1); + dim3 blockSize = dim3(C, HW); + input_gating_kernel <<>>(output, input, mult, add); +} // Template instantiation. 
template void copyTypeConverted(half* op, float* ip, int N, @@ -1257,5 +1280,12 @@ template void inputPreprocessForAttentionBody(half* output, template void inputPreprocessForAttentionBody(float* output, const float* input, int N, cudaStream_t stream); + +template void applyInputGating(half* output, const half* input, const half* mult, const half* add, + int N, int C, int output_size, cudaStream_t stream); + +template void applyInputGating(float* output, const float* input, const float* mult, const float* add, + int N, int C, int output_size, cudaStream_t stream); + } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/cuda_common.h b/src/neural/cuda/cuda_common.h index 759238cd4e..5a44bc4555 100644 --- a/src/neural/cuda/cuda_common.h +++ b/src/neural/cuda/cuda_common.h @@ -74,7 +74,7 @@ void CudaError(cudaError_t status, const char* file, const int& line); inline int DivUp(int a, int b) { return (a + b - 1) / b; } -enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH }; +enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH, SWISH }; } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index a0bc2b9c06..cc317b7743 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -133,5 +133,10 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, template void inputPreprocessForAttentionBody(T* output, const T* input, int N, cudaStream_t stream); + +template +void applyInputGating(T* output, const T* input, const T* mult, const T* add, + int N, int HW, int C, cudaStream_t stream); + } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 0a60312a74..1eadb17201 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -38,7 +38,7 @@ namespace lczero { -#if 0 +#if 1 // debug code to dump allocation in GPU memory template void dumpTensor(T* memory, int elements, const char* message) { @@ -1815,11 +1815,17 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, num_resi_blocks_(num_res_blocks), default_act_(default_act), input_c_(input_c), + has_gating_(weights.ip_mult_gate.size() > 0 && weights.ip_add_gate.size() > 0), BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); + if (has_gating_) { + allocAndUpload(&ip_mult_gate_, weights.ip_mult_gate, scratch); + allocAndUpload(&ip_add_gate_, weights.ip_add_gate, scratch); + } + int num_encoders = weights.encoder.size(); float alpha = (float) pow(2.0 * num_encoders, 0.25); for (const auto& enc : weights.encoder) { @@ -1886,6 +1892,13 @@ void AttentionBody::Eval( num_outputs, default_act_, stream); } + // Input gating + if (has_gating_) { + applyInputGating(embedding, embedding, ip_mult_gate_, ip_add_gate_, + N, 64, embedding_op_size_, stream); + dumpTensor(embedding, 64 * embedding_op_size_, "input gating outputs"); + } + // 2. 
Encoder layers for (const auto pEnc : encoder_weights_) { pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index af629af79e..f2992c68a0 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -458,12 +458,14 @@ class AttentionBody : public BaseLayer { private: // GPU allocations to hold various weights used by the attention policy head DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body + DataType *ip_mult_gate_, *ip_add_gate_; // input gating int embedding_op_size_; int encoder_head_count_; std::vector*> encoder_weights_; ActivationFunction default_act_; int num_resi_blocks_; int input_c_; + const bool has_gating_; }; diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index 0872ae0ddc..387590de6b 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -32,6 +32,8 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) : input(weights.input()), ip_emb_w(LayerAdapter(weights.ip_emb_w()).as_vector()), ip_emb_b(LayerAdapter(weights.ip_emb_b()).as_vector()), + ip_mult_gate(LayerAdapter(weights.ip_mult_gate()).as_vector()), + ip_add_gate(LayerAdapter(weights.ip_add_gate()).as_vector()), policy1(weights.policy1()), policy(weights.policy()), ip_pol_w(LayerAdapter(weights.ip_pol_w()).as_vector()), @@ -54,7 +56,9 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) ip1_mov_w(LayerAdapter(weights.ip1_mov_w()).as_vector()), ip1_mov_b(LayerAdapter(weights.ip1_mov_b()).as_vector()), ip2_mov_w(LayerAdapter(weights.ip2_mov_w()).as_vector()), - ip2_mov_b(LayerAdapter(weights.ip2_mov_b()).as_vector()) { + ip2_mov_b(LayerAdapter(weights.ip2_mov_b()).as_vector()), + smolgen_w(LayerAdapter(weights.smolgen_w()).as_vector()), + has_smolgen(weights.has_smolgen_w()) { for (const auto& res : weights.residual()) { residual.emplace_back(res); } @@ -145,7 +149,9 @@ LegacyWeights::MHA::MHA(const pblczero::Weights::MHA& mha) v_w(LayerAdapter(mha.v_w()).as_vector()), v_b(LayerAdapter(mha.v_b()).as_vector()), dense_w(LayerAdapter(mha.dense_w()).as_vector()), - dense_b(LayerAdapter(mha.dense_b()).as_vector()) {} + dense_b(LayerAdapter(mha.dense_b()).as_vector()), + smolgen(Smolgen(mha.smolgen())), + has_smolgen(mha.has_smolgen()) {} LegacyWeights::FFN::FFN(const pblczero::Weights::FFN& ffn) : dense1_w(LayerAdapter(ffn.dense1_w()).as_vector()), @@ -162,4 +168,16 @@ LegacyWeights::EncoderLayer::EncoderLayer( ln2_gammas(LayerAdapter(encoder.ln2_gammas()).as_vector()), ln2_betas(LayerAdapter(encoder.ln2_betas()).as_vector()) {} +LegacyWeights::Smolgen::Smolgen( + const pblczero::Weights::Smolgen& smolgen) + : compress(LayerAdapter(smolgen.compress()).as_vector()), + dense1_w(LayerAdapter(smolgen.dense1_w()).as_vector()), + dense1_b(LayerAdapter(smolgen.dense1_b()).as_vector()), + ln1_gammas(LayerAdapter(smolgen.ln1_gammas()).as_vector()), + ln1_betas(LayerAdapter(smolgen.ln1_betas()).as_vector()), + dense2_w(LayerAdapter(smolgen.dense2_w()).as_vector()), + dense2_b(LayerAdapter(smolgen.dense2_b()).as_vector()), + ln2_gammas(LayerAdapter(smolgen.ln2_gammas()).as_vector()), + ln2_betas(LayerAdapter(smolgen.ln2_betas()).as_vector()) {} + } // namespace lczero diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 19284af172..540607da42 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -55,6 +55,19 @@ struct LegacyWeights { bool has_se; }; + struct Smolgen { + explicit Smolgen(const 
pblczero::Weights::Smolgen& smolgen); + Vec compress; + Vec dense1_w; + Vec dense1_b; + Vec ln1_gammas; + Vec ln1_betas; + Vec dense2_w; + Vec dense2_b; + Vec ln2_gammas; + Vec ln2_betas; + }; + struct MHA { explicit MHA(const pblczero::Weights::MHA& mha); Vec q_w; @@ -65,6 +78,8 @@ struct LegacyWeights { Vec v_b; Vec dense_w; Vec dense_b; + Smolgen smolgen; + bool has_smolgen; }; struct FFN { @@ -92,6 +107,10 @@ struct LegacyWeights { Vec ip_emb_w; Vec ip_emb_b; + // Input gating + Vec ip_mult_gate; + Vec ip_add_gate; + // Encoder stack. std::vector encoder; int encoder_head_count; @@ -133,6 +152,11 @@ struct LegacyWeights { Vec ip1_mov_b; Vec ip2_mov_w; Vec ip2_mov_b; + + // Smolgen global weights + Vec smolgen_w; + Vec smolgen_b; + bool has_smolgen; }; } // namespace lczero diff --git a/src/neural/shared/activation.cc b/src/neural/shared/activation.cc index ecf154f2ed..9f6b7489c4 100644 --- a/src/neural/shared/activation.cc +++ b/src/neural/shared/activation.cc @@ -78,6 +78,8 @@ float Activate(const float val, const ActivationFunction activation) { return 1.0f / (1.0f + expf(-val)); case SELU: return selu(val); + case SWISH: + return val / (1.0f + expf(-val)); case NONE: // Nothing to do. break; diff --git a/src/neural/shared/activation.h b/src/neural/shared/activation.h index 5937126b32..8a55df486b 100644 --- a/src/neural/shared/activation.h +++ b/src/neural/shared/activation.h @@ -22,7 +22,7 @@ #include namespace lczero { -enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH }; +enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH, SWISH }; // Softmax activation void SoftmaxActivation(const size_t size, const float* input, float* output); From 780d47f96d9f5d663a35ae18378d76d830996e76 Mon Sep 17 00:00:00 2001 From: Alma Date: Thu, 15 Dec 2022 12:09:51 +0100 Subject: [PATCH 12/70] Completed input gating --- src/neural/cuda/common_kernels.cu | 35 ++++++++++++++++++++----------- src/neural/cuda/layers.cc | 6 +++--- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index fc3d6a966f..05eee1be02 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1031,24 +1031,35 @@ void inputPreprocessForAttentionBody(T* output, const T* input, int N, } template -__global__ void input_gating_kernel(T* output, const T* input, const T* mult, const T* add) { - int n = blockIdx.x * blockDim.x * blockDim.y; - int idx = threadIdx.y * blockDim.x + threadIdx.x; // index in input - int idxT = threadIdx.x * blockDim.y + threadIdx.y; // index in transposed weights arrays mult and add. - - // Combine multiply gating, add gating and weights transpose. - output[n + idx] = input[n + idx] * mult[idxT] + add[idxT]; +__global__ void input_gating_kernel(T* output, const T* input, const T* mult, const T* add, int HW, int C) { + int n_offset = blockIdx.z * HW * C; + int idx = threadIdx.y * C + blockIdx.x * blockDim.x + threadIdx.x; // index in input + int idxT = (blockIdx.x * blockDim.x + threadIdx.x) * HW + threadIdx.y; // index in transposed weights arrays mult and add. + + if (idx < HW * C) { + // Combine multiply gating, add gating and weights transpose. 
+ float op = (float) input[n_offset + idx] * (float) mult[idxT] + (float) add[idxT]; + output[n_offset + idx] = (T) op; + } } template void applyInputGating(T* output, const T* input, const T* mult, const T* add, int N, int HW, int C, cudaStream_t stream) { - // N blocks, - // (C * output_size) threads + // Multiple blocks to fit into each input area / volume + // Block x position indicates horizontal section of area + // Block y position indicates batch // Each thread computes a single output element - dim3 gridSize = dim3(N, 1); - dim3 blockSize = dim3(C, HW); - input_gating_kernel <<>>(output, input, mult, add); + dim3 blockSize, gridSize; + blockSize.x = DivUp(1024, HW); + blockSize.y = HW; + blockSize.z = 1; + gridSize.x = DivUp(C, blockSize.x); + gridSize.y = 1; + gridSize.z = N; + input_gating_kernel<<>>(output, input, mult, add, HW, C); + + ReportCUDAErrors(cudaGetLastError()); } // Template instantiation. diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 1eadb17201..fd9ca2efe9 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -62,8 +62,9 @@ void dumpTensor(T* memory, int elements, const char* message) { float *arr = (float *)temp; val = arr[i]; } - printf("%8.4f ", val); - if ((i % 8) == 7) printf("\n"); + // printf("%8.4f ", val); + // if ((i % 8) == 7) printf("\n"); + printf("%i;%8.4f\n", i, val); } free(temp); printf("\n"); @@ -1896,7 +1897,6 @@ void AttentionBody::Eval( if (has_gating_) { applyInputGating(embedding, embedding, ip_mult_gate_, ip_add_gate_, N, 64, embedding_op_size_, stream); - dumpTensor(embedding, 64 * embedding_op_size_, "input gating outputs"); } // 2. Encoder layers From f25f0bd168cfaa47bafe2984cb5abc967c6f6f5d Mon Sep 17 00:00:00 2001 From: Alma Date: Sat, 24 Dec 2022 01:13:44 +0100 Subject: [PATCH 13/70] Add input gating, smolgen and sqrrelu. --- src/neural/cuda/common_kernels.cu | 40 ++++- src/neural/cuda/cuda_common.h | 2 +- src/neural/cuda/kernels.h | 6 +- src/neural/cuda/layers.cc | 230 ++++++++++++++++++++++++++-- src/neural/cuda/layers.h | 24 ++- src/neural/cuda/network_cuda.cc | 1 + src/neural/cuda/winograd_helper.inc | 7 + src/neural/shared/activation.h | 2 +- 8 files changed, 286 insertions(+), 26 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 05eee1be02..01c8329f01 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -159,6 +159,11 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, case SWISH: addBiasBatched_kernel <<>>(output, input, bias, N, C); + break; + case RELU_2: // square relu + addBiasBatched_kernel + <<>>(output, input, bias, N, C); + break; default: throw Exception( "unsupported activation in addBiasBatched. 
Add in switch-case here"); @@ -797,7 +802,7 @@ __device__ __forceinline__ float shared_sum_for_layer_norm(float x) { template __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, - const T* betas, float ep, float alpha) { + const T* betas, float ep, float alpha, ActivationFunction act) { int n = blockIdx.x * blockDim.z + threadIdx.z; if (n >= N) return; int c = (threadIdx.y * 32 + threadIdx.x) * 4; @@ -840,7 +845,7 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const float s = 0; if (!oobThread) for (int i = 0; i < 4; i++) { - val[i] += b[i] + sk[i] * alpha; + val[i] = activate(val[i] + b[i], act) + sk[i] * alpha; s += val[i]; } @@ -883,7 +888,7 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, - cudaStream_t stream) { + ActivationFunction act, cudaStream_t stream) { // process 4 elements per thread to achieve close to peak memory bandwidth if (C % 4 != 0) throw Exception("unsupported filter size"); if (C > 4096) throw Exception("unsupported filter size"); @@ -898,7 +903,7 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias, gridDim.z = 1; layer_norm_kernel<<>>( - N, C, output, input, bias, skip, gammas, betas, ep, alpha); + N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); ReportCUDAErrors(cudaGetLastError()); } @@ -1062,6 +1067,25 @@ void applyInputGating(T* output, const T* input, const T* mult, const T* add, ReportCUDAErrors(cudaGetLastError()); } +template +__global__ void mask_layer_kernel(T* output, const T* input, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + //printf("idx %i", idx); + if (idx >= size) return; + if (idx >= 2048) output[idx] = 0.0; + // if (idx > 2048) output[idx] = input[idx]; + // else output[idx] = 0.0; +} + +template +void maskLayer(T* output, const T* input, int size, cudaStream_t stream) { + int blockSize = min(1024, size); + dim3 gridSize = size / blockSize; + mask_layer_kernel<<>>(output, input, size); + ReportCUDAErrors(cudaGetLastError()); +} + // Template instantiation. 
template void copyTypeConverted(half* op, float* ip, int N, cudaStream_t stream); @@ -1256,11 +1280,13 @@ template void Softmax(int N, int C, float* output, const float* input, template void LayerNorm(int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, - float alpha, cudaStream_t stream); + float alpha, ActivationFunction act, + cudaStream_t stream); template void LayerNorm(int N, int C, float* output, const float* input, const float* bias, const float* skip, const float* gammas, const float* betas, - float ep, float alpha, cudaStream_t stream); + float ep, float alpha, ActivationFunction act, + cudaStream_t stream); template void ComputePromotionLogits(int N, int C, half* output, const half* keys, const half* ppo, @@ -1298,5 +1324,7 @@ template void applyInputGating(half* output, const half* input, const half template void applyInputGating(float* output, const float* input, const float* mult, const float* add, int N, int C, int output_size, cudaStream_t stream); +template void maskLayer(half* output, const half* input, int size, cudaStream_t); +template void maskLayer(float* output, const float* input, int size, cudaStream_t); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/cuda_common.h b/src/neural/cuda/cuda_common.h index 5a44bc4555..5c335053d8 100644 --- a/src/neural/cuda/cuda_common.h +++ b/src/neural/cuda/cuda_common.h @@ -74,7 +74,7 @@ void CudaError(cudaError_t status, const char* file, const int& line); inline int DivUp(int a, int b) { return (a + b - 1) / b; } -enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH, SWISH }; +enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH, SWISH, RELU_2 }; } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index cc317b7743..4c56dcad28 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -123,7 +123,7 @@ void Softmax(int N, int C, T* output, const T* input, cudaStream_t stream); template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, - float alpha, cudaStream_t stream); + float alpha, ActivationFunction act, cudaStream_t stream); template void ComputePromotionLogits(int N, int C, T* output, const T* keys, @@ -137,6 +137,8 @@ void inputPreprocessForAttentionBody(T* output, const T* input, int N, template void applyInputGating(T* output, const T* input, const T* mult, const T* add, int N, int HW, int C, cudaStream_t stream); - + +template +void maskLayer(T* output, const T* input, int size, cudaStream_t); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index fd9ca2efe9..bfc0f322ae 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -64,7 +64,7 @@ void dumpTensor(T* memory, int elements, const char* message) { } // printf("%8.4f ", val); // if ((i % 8) == 7) printf("\n"); - printf("%i;%8.4f\n", i, val); + printf("%i;%.6f\n", i, val); } free(temp); printf("\n"); @@ -1412,7 +1412,8 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( - enc, scratch, encoder_heads_, embedding_op_size_, 1.0f); // using alpha = 1 for now (TODO: may change?) + enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, + nullptr, 0); // using alpha = 1 for now (TODO: may change?) 
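    // nullptr / 0 here: the policy-head encoder blocks carry no global smolgen
    // weights; those are only passed to the encoder blocks constructed inside
    // AttentionBody.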
encoder_weights_.emplace_back(pW); } } @@ -1420,8 +1421,9 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, template EncoderBlock::EncoderBlock( const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, - int size, float alpha) - : encoder_heads_(heads), embedding_op_size_(size), alpha_(alpha) { + int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size) + : encoder_heads_(heads), embedding_op_size_(size), alpha_(alpha), + has_smolgen_(cpu_weights.mha.has_smolgen) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); mha_v_size_ = cpu_weights.mha.v_b.size(); @@ -1475,6 +1477,29 @@ EncoderBlock::EncoderBlock( allocAndUpload(&ln2_gammas, cpu_weights.ln2_gammas, scratch); allocAndUpload(&ln2_betas, cpu_weights.ln2_betas, scratch); + + // Smolgen weights. + if (has_smolgen_) { + smol_compress_size_ = cpu_weights.mha.smolgen.compress.size() / mha_q_size_; + smol_dense_1_size_ = cpu_weights.mha.smolgen.dense1_b.size(); + smol_dense_2_size_ = cpu_weights.mha.smolgen.dense2_b.size(); + smol_global_size_ = smolgen_global_size; + + allocAndUpload(&smol_compress, cpu_weights.mha.smolgen.compress, scratch); + allocAndUpload(&smol_dense1_w, cpu_weights.mha.smolgen.dense1_w, scratch); + allocAndUpload(&smol_dense1_b, cpu_weights.mha.smolgen.dense1_b, scratch); + allocAndUpload(&smol_dense2_w, cpu_weights.mha.smolgen.dense2_w, scratch); + allocAndUpload(&smol_dense2_b, cpu_weights.mha.smolgen.dense2_b, scratch); + + allocAndUpload(&smol_ln1_gammas, cpu_weights.mha.smolgen.ln1_gammas, scratch); + allocAndUpload(&smol_ln1_betas, cpu_weights.mha.smolgen.ln1_betas, scratch); + allocAndUpload(&smol_ln2_gammas, cpu_weights.mha.smolgen.ln2_gammas, scratch); + allocAndUpload(&smol_ln2_betas, cpu_weights.mha.smolgen.ln2_betas, scratch); + + // GPU memory already allocated in AttentionBody. + smol_global = smolgen_global_scratch; + } + } template @@ -1525,10 +1550,135 @@ template void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, DataType* scratch2, DataType* scratch3, cublasHandle_t cublas, cudaStream_t stream, - ActivationFunction act) const { + ActivationFunction act, int layer_id) const { const int d_model = mha_q_size_; const int depth = d_model / encoder_heads_; + char desc [40]; + snprintf (desc, 40, "Encoder layer #%d input", layer_id); + dumpTensor(scratch1, 10, desc); + const int layer_to_print = 2; + + // Calculate smolgen weights. Do this first so we can make use of + // scratch2 and scratch3. + if (has_smolgen_) { + { + // Compress. + // input shape: N, 64, d_model + // output shape: N, 64, hidden_channels + const int num_inputs = d_model; + const int num_outputs = smol_compress_size_; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)smol_compress, num_inputs, + scratch1, num_inputs, 0.0f, scratch0, num_outputs); + if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol compress"); + + } + + { + // Hidden 1 dense. 
+ // input shape: N, 64 * hidden_channels + // output shape: N, hidden_sz + const int num_inputs = 64 * smol_compress_size_; + const int num_outputs = smol_dense_1_size_; + const int batch = N; + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs /*M*/, + batch /*N*/, num_inputs /*K*/, 1.0f, + smol_dense1_w /*A*/, // "smol_weight_gen" weights + num_inputs /*LDA*/, + 0, /*strideA*/ + scratch0 /*B*/, + num_inputs /*LDB*/, + num_inputs, /*strideB*/ + 0.0f, + scratch2 /*C*/, // output goes to scratch3 + num_outputs /*LDC*/, num_outputs /*strideC*/, batch); + // dumpTensor(scratch2, 100, "Batch 1"); + // dumpTensor(scratch2 + 256, 100, "Batch 2"); + // return; + if (layer_id == layer_to_print) dumpTensor(smol_dense1_w, 1000, "smol_dense1_w"); + + if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "smol hidden1"); + + LayerNorm(batch, num_outputs, scratch0, scratch2, smol_dense1_b, + scratch2, smol_ln1_gammas, smol_ln1_betas, 1e-6, + 0.0, /* alpha = 0 since we don't need skip */ SWISH, stream); + if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol hidden1 ln"); + + // dumpTensor(scratch0, 100, "Batch 1"); + // dumpTensor(scratch0 + 256, 100, "Batch 2"); + // return; + } + + { + // Hidden 2 dense (gen_from) + // input shape: N, hidden_sz + // output shape: N, heads * gen_sz + const int num_inputs = smol_dense_1_size_; + const int num_outputs = smol_dense_2_size_; + const int batch = N; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)smol_dense2_w, num_inputs, + scratch0, num_inputs, 0.0f, scratch2, num_outputs); + + LayerNorm(batch, num_outputs, scratch0, scratch2, smol_dense2_b, + scratch2, smol_ln2_gammas, smol_ln2_betas, 1e-6, + 0.0, /* alpha = 0 since we don't need skip */ SWISH, stream); + if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol gen_from"); + + // dumpTensor(scratch0, 100, "Batch 1"); + // dumpTensor(scratch0 + num_outputs, 100, "Batch 2"); + // return; + // Smolgen global 'smol_weight_gen' + // input shape: N, heads, gen_sz + // output shape: heads, N, 64 * 64 + // transpose: heads, N, 64, 64 to match scaled attention weights + + } + + { + // Final smolgen weights generation. 
+ /* + gen_from = tf.reshape(gen_from, [-1, heads, gen_sz]) + out = self.smol_weight_gen_dense(gen_from) + */ + const int num_inputs = smol_dense_2_size_ / encoder_heads_; /* num_inputs == gen_sz == 256 */ + const int num_outputs = smol_global_size_; /* hwhw: 64 * 64 */ + const int batch = N; + for (int i = 0; i < encoder_heads_; i++) { + int inputOffset = i * num_inputs; + // int inputOffset = 1 * num_inputs; + int outputOffset = i * batch * num_outputs; + // int outputOffset = 1 * batch * num_outputs; + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs /*M*/, + batch /*N*/, num_inputs /*K*/, 1.0f, + smol_global /*A*/, // "smol_weight_gen" weights + num_inputs /*LDA*/, + 0, /*strideA*/ + scratch0 + inputOffset /*B*/, + num_inputs * encoder_heads_ /*LDB*/, + num_inputs * encoder_heads_, /*strideB*/ + 0.0f, + scratch3 + outputOffset /*C*/, // output goes to scratch1 + num_outputs /*LDC*/, num_outputs /*strideC*/, batch); + } + // for (int b = 0; b < N; b++) { + // for (int h = 0; h < encoder_heads_; h++) { + // // int start = (b * 12 + h) * 4096; + // int start = (h * N + b) * 4096; + // char desc [40]; + // snprintf (desc, 40, "batch %d head %d", b, h); + // dumpTensor(scratch3 + start, 50, desc); + // } + // } + // return; + } + + } + DataType* mha_q; DataType* mha_k; DataType* mha_v; @@ -1565,6 +1715,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) output = tf.matmul(attention_weights, v) */ + if (layer_id == layer_to_print) dumpTensor(scratch3, 10, "smolgen before"); // shape(k)[-1] = depth float factor = 1.0f / sqrt((float)depth); @@ -1590,7 +1741,29 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 64 /*LDC*/, 64 * 64 /*strideC*/, N); } + // dumpTensor(scratch2, 64*64*12*2, "softmax after attention weights"); + if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "attn"); + + // Add smolgen weights to the scaled matmul_qk attention logits. 
+ if (has_smolgen_) { + int size = N * encoder_heads_ * 64 * 64; + addVectors(scratch2, scratch2, scratch3, size, size, size, NONE, stream); + } + if (layer_id == layer_to_print) { + dumpTensor(scratch3, 10, "smolgen"); + dumpTensor(scratch2, 10, "attn + smolgen"); + } + // for (int b = 0; b < N; b++) { + // for (int h = 0; h < encoder_heads_; h++) { + // // int start = (b * 12 + h) * 4096; + // int start = (h * N + b) * 4096; + // char desc [40]; + // snprintf (desc, 40, "batch %d head %d", b, h); + // dumpTensor(scratch2 + start, 50, desc); + // } + // } + // return; // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) // attention_weights -> scratch2 Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); @@ -1609,7 +1782,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, 0.0f, scratch3 + offset /*C*/, // output goes to scratch3 d_model /*LDC*/, 64 * d_model /*strideC*/, N); } - + // dumpTensor(scratch2, 200, "softmax after attention weights"); // #final dense layer (mha_dense), scratch3 -> scratch2 { const int num_inputs = d_model; @@ -1624,7 +1797,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // scratch2/scratch1 -> scratch0 LayerNorm(N * 64, embedding_op_size_, scratch0, scratch2, mha_dense_b, scratch1, ln1_gammas, ln1_betas, 1e-6, - alpha_, stream); + alpha_, NONE, stream); // #FFN dense 1, scratch0 -> scratch1 const int encoder_dff = ffn_dense1_size_; @@ -1636,9 +1809,9 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, scratch0, num_inputs, 0.0f, scratch1, num_outputs); addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, - act, stream); + has_smolgen_ ? RELU_2 : act, stream); // @todo sqr relu to have its own flag } - +if (layer_id == layer_to_print) dumpTensor(scratch1, 10, "ffn2"); // #FFN dense 2, scratch1 -> scratch2 { const int num_inputs = encoder_dff; @@ -1648,12 +1821,19 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, scratch1, num_inputs, 0.0f, scratch2, num_outputs); } - + // LN2: skip connection and layer normilization (also bias add of prev gemm) // scratch2/scratch0 -> scratch1 LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, ffn_dense2_b, scratch0, ln2_gammas, ln2_betas, 1e-6, - alpha_, stream); + alpha_, NONE, stream); + // dumpTensor(scratch1, 40, "layernorm 2"); + // for (int b = 0; b < N; b++) { + // int start = b * embedding_op_size_ * 64; + // char desc [40]; + // snprintf (desc, 40, "batch %d", b); + // dumpTensor(scratch1 + start, 50, desc); + // } } template @@ -1770,6 +1950,15 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(ffn_dense2_b)); ReportCUDAErrors(cudaFree(ln2_gammas)); ReportCUDAErrors(cudaFree(ln2_betas)); + ReportCUDAErrors(cudaFree(smol_compress)); + ReportCUDAErrors(cudaFree(smol_dense1_w)); + ReportCUDAErrors(cudaFree(smol_dense1_b)); + ReportCUDAErrors(cudaFree(smol_dense2_w)); + ReportCUDAErrors(cudaFree(smol_dense2_b)); + ReportCUDAErrors(cudaFree(smol_ln1_gammas)); + ReportCUDAErrors(cudaFree(smol_ln1_betas)); + ReportCUDAErrors(cudaFree(smol_ln2_gammas)); + ReportCUDAErrors(cudaFree(smol_ln2_betas)); } @@ -1817,6 +2006,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, default_act_(default_act), input_c_(input_c), has_gating_(weights.ip_mult_gate.size() > 0 && weights.ip_add_gate.size() > 0), + has_smolgen_(weights.has_smolgen), 
BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); @@ -1827,11 +2017,17 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, allocAndUpload(&ip_add_gate_, weights.ip_add_gate, scratch); } + if (has_smolgen_) { + allocAndUpload(&smolgen_global_, weights.smolgen_w, scratch); + smolgen_global_size_ = 64 * 64; + } + int num_encoders = weights.encoder.size(); float alpha = (float) pow(2.0 * num_encoders, 0.25); for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( - enc, scratch, encoder_head_count_, embedding_op_size_, alpha); + enc, scratch, encoder_head_count_, embedding_op_size_, alpha, + smolgen_global_, smolgen_global_size_); encoder_weights_.emplace_back(pW); } } @@ -1840,6 +2036,9 @@ template AttentionBody::~AttentionBody() { ReportCUDAErrors(cudaFree(ip_emb_w_)); ReportCUDAErrors(cudaFree(ip_emb_b_)); + ReportCUDAErrors(cudaFree(ip_mult_gate_)); + ReportCUDAErrors(cudaFree(ip_add_gate_)); + ReportCUDAErrors(cudaFree(smolgen_global_)); for (const auto pEnc : encoder_weights_) delete pEnc; } @@ -1900,11 +2099,14 @@ void AttentionBody::Eval( } // 2. Encoder layers + int i = 0; for (const auto pEnc : encoder_weights_) { pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, - default_act_); - } // End of encoder blocks + default_act_, i++); + if (i == 3) break; + } // End of encoder blocks +dumpTensor(scratch1, 50, "Attention body output"); } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index f2992c68a0..19f846e458 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -335,12 +335,14 @@ class ResidualBlock : public BaseLayer { template class EncoderBlock { public: - EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha); + EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, + int heads, int size, float alpha, DataType* smolgen_global_scratch, + int smolgen_global_size); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, DataType* scratch2, cublasHandle_t cublas, - cudaStream_t stream, ActivationFunction act) const; + cudaStream_t stream, ActivationFunction act, int layer_id = 0) const; // all GPU side pointers DataType *mha_q_w, *mha_q_b; @@ -356,6 +358,13 @@ class EncoderBlock { DataType *ln2_gammas, *ln2_betas; + DataType *smol_compress; + DataType *smol_dense1_w, *smol_dense1_b; + DataType *smol_dense2_w, *smol_dense2_b; + DataType *smol_ln1_gammas, *smol_ln1_betas; + DataType *smol_ln2_gammas, *smol_ln2_betas; + DataType *smol_global; + int mha_q_size_; int mha_k_size_; int mha_v_size_; @@ -368,6 +377,14 @@ class EncoderBlock { int encoder_heads_; float alpha_; // scale to apply to skip connection add + + const bool has_smolgen_; + + // Output sizes for smolgen layers. 
+ int smol_compress_size_; + int smol_dense_1_size_; + int smol_dense_2_size_; + int smol_global_size_; }; // The Attention policy head implementation @@ -459,13 +476,16 @@ class AttentionBody : public BaseLayer { // GPU allocations to hold various weights used by the attention policy head DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body DataType *ip_mult_gate_, *ip_add_gate_; // input gating + DataType *smolgen_global_; // global smolgen weights for all encoder layers int embedding_op_size_; int encoder_head_count_; std::vector*> encoder_weights_; ActivationFunction default_act_; int num_resi_blocks_; int input_c_; + int smolgen_global_size_; const bool has_gating_; + const bool has_smolgen_; }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index ef6425ef4c..27c83bee84 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -592,6 +592,7 @@ class CudaNetwork : public Network { stream = 0; // default stream cublas = cublas_; } + printf("\n multistream: %i", multi_stream_); bool fp16 = std::is_same::value; if (fp16) { diff --git a/src/neural/cuda/winograd_helper.inc b/src/neural/cuda/winograd_helper.inc index 456649ba87..6f9d00d55a 100644 --- a/src/neural/cuda/winograd_helper.inc +++ b/src/neural/cuda/winograd_helper.inc @@ -44,6 +44,10 @@ __device__ __forceinline__ float activate(float cVal, case RELU: if (cVal < 0) cVal = 0; break; + case RELU_2: + if (cVal < 0) cVal = 0; + cVal *= cVal; + break; case TANH: cVal = tanh(cVal); break; @@ -61,6 +65,9 @@ __device__ __forceinline__ float activate(float cVal, case MISH: cVal = mishActivate(cVal); break; + case SWISH: + cVal /= (1.0f + __expf(-cVal)); + break; } return cVal; } diff --git a/src/neural/shared/activation.h b/src/neural/shared/activation.h index 8a55df486b..18f1c4868d 100644 --- a/src/neural/shared/activation.h +++ b/src/neural/shared/activation.h @@ -22,7 +22,7 @@ #include namespace lczero { -enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH, SWISH }; +enum ActivationFunction { NONE, RELU, TANH, SIGMOID, SELU, MISH, SWISH, RELU_2 }; // Softmax activation void SoftmaxActivation(const size_t size, const float* input, float* output); From e1cd35ee39505db555c6ec4a68419ca100fcb4a9 Mon Sep 17 00:00:00 2001 From: Alma Date: Sun, 25 Dec 2022 00:41:07 +0100 Subject: [PATCH 14/70] Fixed unstable softmax implementation. 
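exp() of large attention logits overflows (already around 11 for fp16, ~88 for
fp32), which turns the softmax denominator into inf/NaN. The standard fix, applied
here, is to subtract the per-row maximum before exponentiating; the result is
unchanged since softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m) for any constant m.
A scalar reference of the stable form, illustrative only (the CUDA kernels below
compute the row max with warp shuffles and a float atomicMax):

    #include <algorithm>
    #include <cmath>

    // Numerically stable softmax over one row of C logits.
    void SoftmaxRow(float* out, const float* in, int C) {
      float maxval = in[0];
      for (int i = 1; i < C; i++) maxval = std::max(maxval, in[i]);
      float sum = 0.0f;
      for (int i = 0; i < C; i++) {
        out[i] = std::exp(in[i] - maxval);
        sum += out[i];
      }
      for (int i = 0; i < C; i++) out[i] /= sum;
    }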
--- src/neural/cuda/common_kernels.cu | 29 +++++++--- src/neural/cuda/layers.cc | 83 ++++++++++++++++++++--------- src/neural/cuda/winograd_helper.inc | 18 +++++++ 3 files changed, 99 insertions(+), 31 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 01c8329f01..ae76fb7017 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -700,8 +700,12 @@ __global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { copyAs(&x[0], &input[index * 2]); } - ex[0] = exp(x[0]); - ex[1] = exp(x[1]); + float threadMax = max(x[0], x[1]); + float maxval = warpMax(threadMax); + maxval = __shfl_sync(0xFFFFFFFF, maxval, 0); + + ex[0] = exp(x[0] - maxval); + ex[1] = exp(x[1] - maxval); float threadSum = ex[0] + ex[1]; float Sum = warpReduce(threadSum); @@ -734,14 +738,25 @@ __global__ void softmax_kernel(T* output, const T* input) { int C = blockDim.x; int index = n * C + c; - __shared__ float sum; - if (c == 0) sum = 0; - __syncthreads(); - // softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis) float x = (float)input[index]; - float ex = exp(x); + + __shared__ float sum, maxval; + if (c == 0) { + sum = 0; + maxval = x; + } + + __syncthreads(); + + // Get max across warp first, and then update across C dimension + float warpmax = warpMax(x); + if ((c & 0x1F) == 0) atomicMaxFloat(&maxval, warpmax); + + __syncthreads(); + + float ex = exp(x - maxval); // compute warp wide sums first float val = warpReduce(ex); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index bfc0f322ae..8ac46f046e 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -38,16 +38,20 @@ namespace lczero { -#if 1 +#if 0 // debug code to dump allocation in GPU memory template -void dumpTensor(T* memory, int elements, const char* message) { +void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { const bool fp16 = std::is_same::value; printf("\n%s\n", message); int elementSize = (int) (fp16 ? 
sizeof(half) : sizeof(float)); int bytes = elements * elementSize; void *temp = malloc(bytes); cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); + float maxval = -std::numeric_limits::max(); + float minval = std::numeric_limits::max(); + int nans = 0; + int nanss[10] {}; for (int i = 0; i < elements; i++) { @@ -62,11 +66,30 @@ void dumpTensor(T* memory, int elements, const char* message) { float *arr = (float *)temp; val = arr[i]; } - // printf("%8.4f ", val); - // if ((i % 8) == 7) printf("\n"); - printf("%i;%.6f\n", i, val); + maxval = std::max(maxval, val); + minval = std::min(minval, val); + + if (std::isnan(val)) { + if (nans < 10) nanss[nans] = i; + nans++; + } + + if (!only_summary || i < 2 || i == elements - 1) { + // printf("%8.4f ", val); + // if ((i % 8) == 7) printf("\n"); + printf("%i;%.6f\n", i, val); + } } free(temp); + if (maxval == -std::numeric_limits::max()) + maxval = std::numeric_limits::quiet_NaN(); + if (minval == std::numeric_limits::max()) + minval = std::numeric_limits::quiet_NaN(); + + printf("Max: %.6f, Min: %.6f, NaNs: %i of %i", maxval, minval, nans, elements); + printf("\nNaN indices: "); + for (int i=0; i 10) printf("......"); printf("\n"); } #endif @@ -1554,10 +1577,10 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, const int d_model = mha_q_size_; const int depth = d_model / encoder_heads_; - char desc [40]; - snprintf (desc, 40, "Encoder layer #%d input", layer_id); - dumpTensor(scratch1, 10, desc); - const int layer_to_print = 2; + // char desc [100]; + // snprintf (desc, 100, "Encoder layer #%d\n======================================\n\nInput", layer_id); + // dumpTensor(scratch1, d_model * 64 * N, desc, true); + // const int layer_to_print = 1; // Calculate smolgen weights. Do this first so we can make use of // scratch2 and scratch3. 
@@ -1572,7 +1595,10 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)smol_compress, num_inputs, scratch1, num_inputs, 0.0f, scratch0, num_outputs); - if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol compress"); + + // if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol compress"); + // dumpTensor(scratch0, num_outputs * batch, "smol compress", true); + // printf("dmodel: %i, outputs: %i, batch: %i\n", num_inputs, num_outputs, batch); } @@ -1598,14 +1624,18 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // dumpTensor(scratch2, 100, "Batch 1"); // dumpTensor(scratch2 + 256, 100, "Batch 2"); // return; - if (layer_id == layer_to_print) dumpTensor(smol_dense1_w, 1000, "smol_dense1_w"); + // if (layer_id == layer_to_print) dumpTensor(smol_dense1_w, 1000, "smol_dense1_w"); + // dumpTensor(smol_dense1_w, num_inputs * num_outputs, "smol_dense1_w", true); + // dumpTensor(scratch2, num_outputs * batch, "smol hidden1", true); - if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "smol hidden1"); + // if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "smol hidden1"); LayerNorm(batch, num_outputs, scratch0, scratch2, smol_dense1_b, scratch2, smol_ln1_gammas, smol_ln1_betas, 1e-6, 0.0, /* alpha = 0 since we don't need skip */ SWISH, stream); - if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol hidden1 ln"); + + // if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol hidden1 ln"); + // dumpTensor(scratch0, num_outputs * batch, "smol hidden1 ln", true); // dumpTensor(scratch0, 100, "Batch 1"); // dumpTensor(scratch0 + 256, 100, "Batch 2"); @@ -1626,7 +1656,9 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, LayerNorm(batch, num_outputs, scratch0, scratch2, smol_dense2_b, scratch2, smol_ln2_gammas, smol_ln2_betas, 1e-6, 0.0, /* alpha = 0 since we don't need skip */ SWISH, stream); - if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol gen_from"); + + // if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol gen_from"); + // dumpTensor(scratch0, num_outputs * batch, "smol gen_from", true); // dumpTensor(scratch0, 100, "Batch 1"); // dumpTensor(scratch0 + num_outputs, 100, "Batch 2"); @@ -1660,11 +1692,12 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, 0, /*strideA*/ scratch0 + inputOffset /*B*/, num_inputs * encoder_heads_ /*LDB*/, - num_inputs * encoder_heads_, /*strideB*/ + num_inputs * encoder_heads_ /*strideB*/, 0.0f, scratch3 + outputOffset /*C*/, // output goes to scratch1 num_outputs /*LDC*/, num_outputs /*strideC*/, batch); } + // dumpTensor(scratch3, num_outputs * batch, "smol gen weights", true); // for (int b = 0; b < N; b++) { // for (int h = 0; h < encoder_heads_; h++) { // // int start = (b * 12 + h) * 4096; @@ -1715,7 +1748,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) output = tf.matmul(attention_weights, v) */ - if (layer_id == layer_to_print) dumpTensor(scratch3, 10, "smolgen before"); + // if (layer_id == layer_to_print) dumpTensor(scratch3, 10, "smolgen before"); // shape(k)[-1] = depth float factor = 1.0f / sqrt((float)depth); @@ -1742,17 +1775,15 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, 64 /*LDC*/, 64 * 64 /*strideC*/, N); } // dumpTensor(scratch2, 64*64*12*2, "softmax after 
attention weights"); - if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "attn"); + // if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "attn"); + // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn", true); // Add smolgen weights to the scaled matmul_qk attention logits. if (has_smolgen_) { int size = N * encoder_heads_ * 64 * 64; addVectors(scratch2, scratch2, scratch3, size, size, size, NONE, stream); } - if (layer_id == layer_to_print) { - dumpTensor(scratch3, 10, "smolgen"); - dumpTensor(scratch2, 10, "attn + smolgen"); - } + // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn + smolgen weights", true); // for (int b = 0; b < N; b++) { // for (int h = 0; h < encoder_heads_; h++) { @@ -1767,6 +1798,8 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) // attention_weights -> scratch2 Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); + // if (layer_id == layer_to_print) dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn + smolgen weights"); + // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn softmax", true); // output = tf.matmul(attention_weights, v) for (int i = 0; i < encoder_heads_; i++) { @@ -1810,8 +1843,9 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, scratch0, num_inputs, 0.0f, scratch1, num_outputs); addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, has_smolgen_ ? RELU_2 : act, stream); // @todo sqr relu to have its own flag + // dumpTensor(scratch1, num_outputs * batch, "ffn1", true); } -if (layer_id == layer_to_print) dumpTensor(scratch1, 10, "ffn2"); + // #FFN dense 2, scratch1 -> scratch2 { const int num_inputs = encoder_dff; @@ -1820,6 +1854,7 @@ if (layer_id == layer_to_print) dumpTensor(scratch1, 10, "ffn2"); cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, scratch1, num_inputs, 0.0f, scratch2, num_outputs); + // dumpTensor(scratch2, num_outputs * batch, "ffn2", true); } // LN2: skip connection and layer normilization (also bias add of prev gemm) @@ -2103,10 +2138,10 @@ void AttentionBody::Eval( for (const auto pEnc : encoder_weights_) { pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, default_act_, i++); - if (i == 3) break; + // if (i == 3) break; } // End of encoder blocks -dumpTensor(scratch1, 50, "Attention body output"); +// dumpTensor(scratch1, 50, "Attention body output"); } diff --git a/src/neural/cuda/winograd_helper.inc b/src/neural/cuda/winograd_helper.inc index 6f9d00d55a..f762d85430 100644 --- a/src/neural/cuda/winograd_helper.inc +++ b/src/neural/cuda/winograd_helper.inc @@ -426,6 +426,24 @@ __device__ __forceinline__ float warpReduce(float x) { return x; } +// fast max reduction for the warp +__device__ __forceinline__ float warpMax(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + x = max(x, __shfl_xor_sync(0xFFFFFFFF, x, mask)); + + return x; +} + +// atomic max implementation for floats +__device__ __forceinline__ float atomicMaxFloat (float * addr, float val) { + float max; + max = (val >= 0) ? 
__int_as_float(atomicMax((int *)addr, __float_as_int(val))) : + __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(val))); + + return max; +} + // Helper fuction to do vector loads/stores template __device__ __forceinline__ void copyAs(void* dst, const void* src) { From b5d6930e89ff68135437b765a4ea6e674526a121 Mon Sep 17 00:00:00 2001 From: Alma Date: Sun, 25 Dec 2022 00:43:23 +0100 Subject: [PATCH 15/70] Remove debug log --- src/neural/cuda/network_cuda.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 27c83bee84..ef6425ef4c 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -592,7 +592,6 @@ class CudaNetwork : public Network { stream = 0; // default stream cublas = cublas_; } - printf("\n multistream: %i", multi_stream_); bool fp16 = std::is_same::value; if (fp16) { From e9cda40300b909f2470c0d5f4f442dd6ee511926 Mon Sep 17 00:00:00 2001 From: Alma Date: Sun, 25 Dec 2022 14:19:50 +0100 Subject: [PATCH 16/70] Tilp's fix for smolgen gemms. --- src/neural/cuda/common_kernels.cu | 33 +++++++++++++++++++++++ src/neural/cuda/kernels.h | 6 +++++ src/neural/cuda/layers.cc | 45 ++++++++----------------------- 3 files changed, 50 insertions(+), 34 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index ae76fb7017..36030d6c46 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -73,6 +73,34 @@ void addVectors(T* c, T* a, T* b, int size, int asize, int bsize, ReportCUDAErrors(cudaGetLastError()); } +template +__global__ void addVectorsHNC_NHC_kernel(T* a, T* b, int N, int H, int C) { + int i = threadIdx.x + blockDim.x * blockIdx.x; + if (i < N * H * C) { + int orig_i = i; + int c = i % C; + i /= C; + int n = i % N; + i /= N; + int h = i; + float aVal = (float)a[orig_i]; + float bVal = (float)b[n * H * C + h * C + c]; + + float cVal = aVal + bVal; + + a[orig_i] = (T)cVal; + } +} + +template +void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream) { + const int kBlockSize = 256; + int blocks = DivUp(N * H * C, kBlockSize); + addVectorsHNC_NHC_kernel<<>>(a, b, N, H, C); + + ReportCUDAErrors(cudaGetLastError()); +} + template __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, int N, int C) { @@ -1127,6 +1155,11 @@ template void addVectors(half* c, half* a, half* b, int size, int asize, int bsize, ActivationFunction act, cudaStream_t stream); +template void addVectorsHNC_NHC(float* a, float* b, int N, int H, int C, + cudaStream_t stream); +template void addVectorsHNC_NHC(half* a, half* b, int N, int H, int C, + cudaStream_t stream); + template void addBiasBatched(float* output, const float* input, const float* bias, int Batch, int N, int C, ActivationFunction activation, diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 4c56dcad28..7a66efcaa4 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -36,6 +36,12 @@ template void addVectors(T* c, T* a, T* b, int size, int asize, int bsize, ActivationFunction activation, cudaStream_t stream); +// Adds two vectors of equal size overwriting the first with the sum. +// This specialisation performs a transposition of the first 2 indexes +// of the second while performing the addition. 
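// Concretely, with a laid out [H][N][C] and b laid out [N][H][C], the kernel
// computes a[h][n][c] += b[n][h][c] (used to add the batch-major smolgen weights
// onto the head-major attention logits).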
+template +void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream); + // Optimized kernel to add bias to innermost dimension // and perform optional activation (to be used with GEMMs/fully connected) template diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 8ac46f046e..1f0a7566a3 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1583,7 +1583,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // const int layer_to_print = 1; // Calculate smolgen weights. Do this first so we can make use of - // scratch2 and scratch3. + // scratch0, scratch2 and scratch3. if (has_smolgen_) { { // Compress. @@ -1609,18 +1609,10 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, const int num_inputs = 64 * smol_compress_size_; const int num_outputs = smol_dense_1_size_; const int batch = N; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs /*M*/, - batch /*N*/, num_inputs /*K*/, 1.0f, - smol_dense1_w /*A*/, // "smol_weight_gen" weights - num_inputs /*LDA*/, - 0, /*strideA*/ - scratch0 /*B*/, - num_inputs /*LDB*/, - num_inputs, /*strideB*/ - 0.0f, - scratch2 /*C*/, // output goes to scratch3 - num_outputs /*LDC*/, num_outputs /*strideC*/, batch); + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)smol_dense1_w, num_inputs, + scratch0, num_inputs, 0.0f, scratch2, num_outputs); + // dumpTensor(scratch2, 100, "Batch 1"); // dumpTensor(scratch2 + 256, 100, "Batch 2"); // return; @@ -1678,25 +1670,10 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, */ const int num_inputs = smol_dense_2_size_ / encoder_heads_; /* num_inputs == gen_sz == 256 */ const int num_outputs = smol_global_size_; /* hwhw: 64 * 64 */ - const int batch = N; - for (int i = 0; i < encoder_heads_; i++) { - int inputOffset = i * num_inputs; - // int inputOffset = 1 * num_inputs; - int outputOffset = i * batch * num_outputs; - // int outputOffset = 1 * batch * num_outputs; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs /*M*/, - batch /*N*/, num_inputs /*K*/, 1.0f, - smol_global /*A*/, // "smol_weight_gen" weights - num_inputs /*LDA*/, - 0, /*strideA*/ - scratch0 + inputOffset /*B*/, - num_inputs * encoder_heads_ /*LDB*/, - num_inputs * encoder_heads_ /*strideB*/, - 0.0f, - scratch3 + outputOffset /*C*/, // output goes to scratch1 - num_outputs /*LDC*/, num_outputs /*strideC*/, batch); - } + const int batch = N*encoder_heads_; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)smol_global, num_inputs, + scratch0, num_inputs, 0.0f, scratch3, num_outputs); // dumpTensor(scratch3, num_outputs * batch, "smol gen weights", true); // for (int b = 0; b < N; b++) { // for (int h = 0; h < encoder_heads_; h++) { @@ -1779,9 +1756,9 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn", true); // Add smolgen weights to the scaled matmul_qk attention logits. + // smolgen weights need to be transposed first, kernel handles that. 
if (has_smolgen_) { - int size = N * encoder_heads_ * 64 * 64; - addVectors(scratch2, scratch2, scratch3, size, size, size, NONE, stream); + addVectorsHNC_NHC(scratch2, scratch3, N, encoder_heads_, 64 * 64, stream); } // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn + smolgen weights", true); From e98ef8dd54330e5a717a23e98498de456a3e1f5d Mon Sep 17 00:00:00 2001 From: Alma Date: Mon, 26 Dec 2022 16:02:51 +0100 Subject: [PATCH 17/70] Remove debug code --- src/neural/cuda/common_kernels.cu | 22 -------- src/neural/cuda/kernels.h | 3 -- src/neural/cuda/layers.cc | 87 ++----------------------------- src/neural/cuda/layers.h | 2 +- 4 files changed, 5 insertions(+), 109 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 36030d6c46..8a72ed3885 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1110,25 +1110,6 @@ void applyInputGating(T* output, const T* input, const T* mult, const T* add, ReportCUDAErrors(cudaGetLastError()); } -template -__global__ void mask_layer_kernel(T* output, const T* input, int size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - //printf("idx %i", idx); - if (idx >= size) return; - if (idx >= 2048) output[idx] = 0.0; - // if (idx > 2048) output[idx] = input[idx]; - // else output[idx] = 0.0; -} - -template -void maskLayer(T* output, const T* input, int size, cudaStream_t stream) { - int blockSize = min(1024, size); - dim3 gridSize = size / blockSize; - mask_layer_kernel<<>>(output, input, size); - ReportCUDAErrors(cudaGetLastError()); -} - // Template instantiation. template void copyTypeConverted(half* op, float* ip, int N, cudaStream_t stream); @@ -1371,8 +1352,5 @@ template void applyInputGating(half* output, const half* input, const half template void applyInputGating(float* output, const float* input, const float* mult, const float* add, int N, int C, int output_size, cudaStream_t stream); - -template void maskLayer(half* output, const half* input, int size, cudaStream_t); -template void maskLayer(float* output, const float* input, int size, cudaStream_t); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 7a66efcaa4..cb15250768 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -143,8 +143,5 @@ void inputPreprocessForAttentionBody(T* output, const T* input, int N, template void applyInputGating(T* output, const T* input, const T* mult, const T* add, int N, int HW, int C, cudaStream_t stream); - -template -void maskLayer(T* output, const T* input, int size, cudaStream_t); } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index fffd60c6e2..27757792c4 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1581,15 +1581,10 @@ template void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, DataType* scratch2, DataType* scratch3, cublasHandle_t cublas, cudaStream_t stream, - ActivationFunction act, int layer_id) const { + ActivationFunction act) const { const int d_model = mha_q_size_; const int depth = d_model / encoder_heads_; - // char desc [100]; - // snprintf (desc, 100, "Encoder layer #%d\n======================================\n\nInput", layer_id); - // dumpTensor(scratch1, d_model * 64 * N, desc, true); - // const int layer_to_print = 1; - // Calculate smolgen weights. Do this first so we can make use of // scratch0, scratch2 and scratch3. 
if (has_smolgen_) { @@ -1603,11 +1598,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)smol_compress, num_inputs, scratch1, num_inputs, 0.0f, scratch0, num_outputs); - - // if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol compress"); - // dumpTensor(scratch0, num_outputs * batch, "smol compress", true); - // printf("dmodel: %i, outputs: %i, batch: %i\n", num_inputs, num_outputs, batch); - } { @@ -1621,25 +1611,9 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, num_inputs, 1.0f, (const DataType*)smol_dense1_w, num_inputs, scratch0, num_inputs, 0.0f, scratch2, num_outputs); - // dumpTensor(scratch2, 100, "Batch 1"); - // dumpTensor(scratch2 + 256, 100, "Batch 2"); - // return; - // if (layer_id == layer_to_print) dumpTensor(smol_dense1_w, 1000, "smol_dense1_w"); - // dumpTensor(smol_dense1_w, num_inputs * num_outputs, "smol_dense1_w", true); - // dumpTensor(scratch2, num_outputs * batch, "smol hidden1", true); - - // if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "smol hidden1"); - LayerNorm(batch, num_outputs, scratch0, scratch2, smol_dense1_b, scratch2, smol_ln1_gammas, smol_ln1_betas, 1e-6, 0.0, /* alpha = 0 since we don't need skip */ SWISH, stream); - - // if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol hidden1 ln"); - // dumpTensor(scratch0, num_outputs * batch, "smol hidden1 ln", true); - - // dumpTensor(scratch0, 100, "Batch 1"); - // dumpTensor(scratch0 + 256, 100, "Batch 2"); - // return; } { @@ -1656,18 +1630,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, LayerNorm(batch, num_outputs, scratch0, scratch2, smol_dense2_b, scratch2, smol_ln2_gammas, smol_ln2_betas, 1e-6, 0.0, /* alpha = 0 since we don't need skip */ SWISH, stream); - - // if (layer_id == layer_to_print) dumpTensor(scratch0, 10, "smol gen_from"); - // dumpTensor(scratch0, num_outputs * batch, "smol gen_from", true); - - // dumpTensor(scratch0, 100, "Batch 1"); - // dumpTensor(scratch0 + num_outputs, 100, "Batch 2"); - // return; - // Smolgen global 'smol_weight_gen' - // input shape: N, heads, gen_sz - // output shape: heads, N, 64 * 64 - // transpose: heads, N, 64, 64 to match scaled attention weights - } { @@ -1682,17 +1644,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)smol_global, num_inputs, scratch0, num_inputs, 0.0f, scratch3, num_outputs); - // dumpTensor(scratch3, num_outputs * batch, "smol gen weights", true); - // for (int b = 0; b < N; b++) { - // for (int h = 0; h < encoder_heads_; h++) { - // // int start = (b * 12 + h) * 4096; - // int start = (h * N + b) * 4096; - // char desc [40]; - // snprintf (desc, 40, "batch %d head %d", b, h); - // dumpTensor(scratch3 + start, 50, desc); - // } - // } - // return; } } @@ -1733,7 +1684,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) output = tf.matmul(attention_weights, v) */ - // if (layer_id == layer_to_print) dumpTensor(scratch3, 10, "smolgen before"); // shape(k)[-1] = depth float factor = 1.0f / sqrt((float)depth); @@ -1759,32 +1709,16 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 64 /*LDC*/, 64 * 64 /*strideC*/, N); } - // 
dumpTensor(scratch2, 64*64*12*2, "softmax after attention weights"); - // if (layer_id == layer_to_print) dumpTensor(scratch2, 10, "attn"); - // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn", true); // Add smolgen weights to the scaled matmul_qk attention logits. // smolgen weights need to be transposed first, kernel handles that. if (has_smolgen_) { addVectorsHNC_NHC(scratch2, scratch3, N, encoder_heads_, 64 * 64, stream); } - // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn + smolgen weights", true); - - // for (int b = 0; b < N; b++) { - // for (int h = 0; h < encoder_heads_; h++) { - // // int start = (b * 12 + h) * 4096; - // int start = (h * N + b) * 4096; - // char desc [40]; - // snprintf (desc, 40, "batch %d head %d", b, h); - // dumpTensor(scratch2 + start, 50, desc); - // } - // } - // return; + // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) // attention_weights -> scratch2 Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); - // if (layer_id == layer_to_print) dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn + smolgen weights"); - // dumpTensor(scratch2, N * encoder_heads_ * 64 * 64, "attn softmax", true); // output = tf.matmul(attention_weights, v) for (int i = 0; i < encoder_heads_; i++) { @@ -1800,8 +1734,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, 0.0f, scratch3 + offset /*C*/, // output goes to scratch3 d_model /*LDC*/, 64 * d_model /*strideC*/, N); } - // dumpTensor(scratch2, 200, "softmax after attention weights"); - // #final dense layer (mha_dense), scratch3 -> scratch2 + { const int num_inputs = d_model; const int num_outputs = embedding_op_size_; @@ -1828,7 +1761,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, scratch0, num_inputs, 0.0f, scratch1, num_outputs); addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, has_smolgen_ ? 
RELU_2 : act, stream); // @todo sqr relu to have its own flag - // dumpTensor(scratch1, num_outputs * batch, "ffn1", true); } // #FFN dense 2, scratch1 -> scratch2 @@ -1839,7 +1771,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, scratch1, num_inputs, 0.0f, scratch2, num_outputs); - // dumpTensor(scratch2, num_outputs * batch, "ffn2", true); } // LN2: skip connection and layer normilization (also bias add of prev gemm) @@ -1847,13 +1778,6 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, ffn_dense2_b, scratch0, ln2_gammas, ln2_betas, 1e-6, alpha_, NONE, stream); - // dumpTensor(scratch1, 40, "layernorm 2"); - // for (int b = 0; b < N; b++) { - // int start = b * embedding_op_size_ * 64; - // char desc [40]; - // snprintf (desc, 40, "batch %d", b); - // dumpTensor(scratch1 + start, 50, desc); - // } } template @@ -2122,11 +2046,8 @@ void AttentionBody::Eval( int i = 0; for (const auto pEnc : encoder_weights_) { pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, - default_act_, i++); - // if (i == 3) break; - + default_act_); } // End of encoder blocks -// dumpTensor(scratch1, 50, "Attention body output"); } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 19f846e458..69263b0258 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -342,7 +342,7 @@ class EncoderBlock { void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, DataType* scratch2, cublasHandle_t cublas, - cudaStream_t stream, ActivationFunction act, int layer_id = 0) const; + cudaStream_t stream, ActivationFunction act) const; // all GPU side pointers DataType *mha_q_w, *mha_q_b; From b5afc19a79bcb4b1b751a07eaf73683525306254 Mon Sep 17 00:00:00 2001 From: Alma Date: Tue, 27 Dec 2022 18:44:35 +0100 Subject: [PATCH 18/70] Add tilps perf improvement on existing attention qkv matmuls. 
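The per-head attention GEMMs (the scaled QK^T and the attention*V products)
previously issued one strided-batched cublas call per head. This change builds
a device-side table of operand pointers, one entry per (batch position, head)
pair up to max_batch_size, uploads it once, and then drives a single
cublas*GemmBatched call over N * encoder_heads_ sub-problems. To keep the
cached pointers valid across different runtime batch sizes, Q/K/V are laid out
at max_batch_size strides, and addBiasBatched gains an Nstride overload so the
bias/activation pass can address that padded layout.

For reference, the underlying cuBLAS pattern looks roughly like this
(standalone sketch with illustrative names, float path only; the real code
fills one table covering all five operand roles at once):

    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <vector>

    // C[i] = A[i]^T * B[i] for `batch` sub-problems whose operands live at
    // arbitrary, precomputed offsets inside larger tensors.
    void GemmBatchedViaPointerTable(cublasHandle_t handle, int m, int n, int k,
                                    const std::vector<const float*>& a,
                                    const std::vector<const float*>& b,
                                    const std::vector<float*>& c,
                                    int lda, int ldb, int ldc) {
      int batch = (int)c.size();
      const float **a_dev, **b_dev;
      float** c_dev;
      cudaMalloc((void**)&a_dev, batch * sizeof(float*));
      cudaMalloc((void**)&b_dev, batch * sizeof(float*));
      cudaMalloc((void**)&c_dev, batch * sizeof(float*));
      // Upload once; the tables stay valid while the underlying buffers do.
      cudaMemcpy(a_dev, a.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
      cudaMemcpy(b_dev, b.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
      cudaMemcpy(c_dev, c.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
      const float alpha = 1.0f, beta = 0.0f;
      cublasSgemmBatched(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha,
                         a_dev, lda, b_dev, ldb, &beta, c_dev, ldc, batch);
      cudaFree(a_dev); cudaFree(b_dev); cudaFree(c_dev);
    }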
--- src/neural/cuda/common_kernels.cu | 108 +++++++++++++++++++++++++ src/neural/cuda/kernels.h | 6 ++ src/neural/cuda/layers.cc | 126 +++++++++++++++++++++--------- src/neural/cuda/layers.h | 10 ++- src/neural/cuda/network_cuda.cc | 4 +- 5 files changed, 214 insertions(+), 40 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 8a72ed3885..ebaf2b83dc 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -200,6 +200,105 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, ReportCUDAErrors(cudaGetLastError()); } +template +__global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, + int N, int C, int Nstride) { + int batch = blockIdx.y; + int n = blockIdx.x * blockDim.y + threadIdx.y; + if (n >= N) return; + int c = threadIdx.x * 4; + + int biasIndex = batch * C + c; + int tensorIndex = batch * Nstride * C + n * C + c; + + float val[4]; + float b[4]; + + // Load from memory + const bool fp16 = std::is_same::value; + if (fp16) { + half inp[4]; + copyAs(&inp[0], &input[tensorIndex]); +#pragma unroll + for (int i = 0; i < 4; i++) val[i] = (float)inp[i]; + + copyAs(&inp[0], &bias[biasIndex]); +#pragma unroll + for (int i = 0; i < 4; i++) b[i] = (float)inp[i]; + } else { + copyAs(&val[0], &input[tensorIndex]); + copyAs(&b[0], &bias[biasIndex]); + } + + // Perform bias add and activation +#pragma unroll + for (int i = 0; i < 4; i++) { + float x = val[i] + b[i]; + x = activate(x, act); + val[i] = x; + } + + // write to memory + if (fp16) { + half op[4]; +#pragma unroll + for (int i = 0; i < 4; i++) op[i] = (half)val[i]; + copyAs(&output[tensorIndex], &op[0]); + } else { + copyAs(&output[tensorIndex], &val[0]); + } +} + +// Input/output tensors are Batch * N * C +// bias tensor is N * C (i.e, different bias for each Batch dimension) +template +void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, + int C, int Nstride, ActivationFunction activation, cudaStream_t stream) { + // process 4 elements per thread to achieve close to peak memory bandwidth + if (C % 4 != 0) throw Exception("unsupported filter size"); + if (C > 4096) throw Exception("unsupported filter size"); + + dim3 blockDim, gridDim; + blockDim.x = C / 4; + blockDim.y = std::min(std::max(512 / blockDim.x, 1u), (unsigned int) N); + blockDim.z = 1; + gridDim.x = DivUp(N, blockDim.y); + gridDim.y = Batch; + gridDim.z = 1; + + switch (activation) { + case NONE: + addBiasBatched_kernel<<>>( + output, input, bias, N, C, Nstride); + break; + case SELU: + addBiasBatched_kernel<<>>( + output, input, bias, N, C, Nstride); + break; + case MISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, Nstride); + break; + case RELU: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, Nstride); + break; + case SWISH: + addBiasBatched_kernel + <<>>(output, input, bias, N, C, Nstride); + break; + case RELU_2: // square relu + addBiasBatched_kernel + <<>>(output, input, bias, N, C, Nstride); + break; + default: + throw Exception( + "unsupported activation in addBiasBatched. 
Add in switch-case here"); + } + + ReportCUDAErrors(cudaGetLastError()); +} + template __global__ void addBias_NCHW_kernel(T* c, T* a, T* b, int N, int C, int H, int W, ActivationFunction activation) { @@ -1150,6 +1249,15 @@ template void addBiasBatched(half* output, const half* input, ActivationFunction activation, cudaStream_t stream); +template void addBiasBatched(float* output, const float* input, + const float* bias, int Batch, int N, int C, int Nstride, + ActivationFunction activation, + cudaStream_t stream); +template void addBiasBatched(half* output, const half* input, + const half* bias, int Batch, int N, int C, int Nstride, + ActivationFunction activation, + cudaStream_t stream); + template void addBias_NCHW(float* c, float* a, float* b, int N, int C, int H, int W, ActivationFunction activation, cudaStream_t stream); diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index cb15250768..7545e10241 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -48,6 +48,12 @@ template void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, int C, ActivationFunction activation, cudaStream_t stream); +// Optimized kernel to add bias to innermost dimension +// and perform optional activation (to be used with GEMMs/fully connected) +template +void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, + int C, int Nstride, ActivationFunction activation, cudaStream_t stream); + // Add bias to convolution's output. template void addBias_NCHW(T* c, T* a, T* b, int N, int C, int H, int W, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 27757792c4..f0acb1c176 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1398,7 +1398,7 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, void* scratch, bool attention_body, - ActivationFunction act) + ActivationFunction act, int max_batch_size) : attention_body_(attention_body), act_(attention_body ? act : SELU), // HACK : old networks without attention body (e.g: T79 use hardcoded SELU activations) BaseLayer(64 * 64 + 24 * 8, 1, 1, ip) { @@ -1444,7 +1444,7 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, - nullptr, 0); // using alpha = 1 for now (TODO: may change?) + nullptr, 0, max_batch_size); // using alpha = 1 for now (TODO: may change?) 
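// max_batch_size is forwarded so that each EncoderBlock can build and cache
// its device-side GEMM pointer tables once, sized for the largest batch it
// will ever see (used by the batched QK^T and attention*V calls in
// EncoderBlock::Eval).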
encoder_weights_.emplace_back(pW); } } @@ -1452,9 +1452,9 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, template EncoderBlock::EncoderBlock( const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, - int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size) + int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size) : encoder_heads_(heads), embedding_op_size_(size), alpha_(alpha), - has_smolgen_(cpu_weights.mha.has_smolgen) { + has_smolgen_(cpu_weights.mha.has_smolgen), max_batch_size_(max_batch_size) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); mha_v_size_ = cpu_weights.mha.v_b.size(); @@ -1575,6 +1575,28 @@ static void cublasXGemmStridedBatched( } } +template +static void cublasXGemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float alpha, DataType** A, int lda, + DataType** B, int ldb, + float beta, DataType** C, int ldc, int batchCount) { + const bool fp16 = std::is_same::value; + if (fp16) { + unsigned short alpha_h = FP32toFP16(alpha); + unsigned short beta_h = FP32toFP16(beta); + ReportCUBLASErrors(cublasHgemmBatched( + handle, transa, transb, m, n, k, (const half*)&alpha_h, (half**)A, lda, + (half**)B, ldb, (const half*)&beta_h, (half**)C, ldc, + batchCount)); + } else { + ReportCUBLASErrors(cublasSgemmBatched( + handle, transa, transb, m, n, k, &alpha, (float**)A, lda, + (float**)B, ldb, &beta, (float**)C, ldc, + batchCount)); + } +} + // input/output tensor is scratch1, others are used as scratch. // TODO: fix naming of scratch buffers template @@ -1640,7 +1662,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, */ const int num_inputs = smol_dense_2_size_ / encoder_heads_; /* num_inputs == gen_sz == 256 */ const int num_outputs = smol_global_size_; /* hwhw: 64 * 64 */ - const int batch = N*encoder_heads_; + const int batch = N * encoder_heads_; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)smol_global, num_inputs, scratch0, num_inputs, 0.0f, scratch3, num_outputs); @@ -1656,16 +1678,17 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, const int num_inputs = embedding_op_size_; const int num_outputs = d_model; const int batch = N * 64; + const int max_batch = max_batch_size_ * 64; mha_q = scratch0; - mha_k = mha_q + num_outputs * batch; - mha_v = mha_k + num_outputs * batch; + mha_k = mha_q + num_outputs * max_batch; + mha_v = mha_k + num_outputs * max_batch; cublasXGemmStridedBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_qkv_w, num_inputs, num_inputs * num_outputs, scratch1, - num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch, 3); - addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, + num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * max_batch, 3); + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, max_batch, NONE, stream); } @@ -1689,52 +1712,84 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, float factor = 1.0f / sqrt((float)depth); // matmul_qk = tf.matmul(q, k, transpose_b=True) - for (int i = 0; i < encoder_heads_; i++) { - int offset = i * depth; - // layout of the output: encoder_heads_ * Batch * 64 * 64 - int outOffset = i * N * 64 * 64; - cublasXGemmStridedBatched( + { + if (scratch0 != last_known_scratch_) { + std::vector offsets(encoder_heads_ * 
max_batch_size_*5); + for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { + int h = i % encoder_heads_; + int n = i / encoder_heads_; + offsets[i] = mha_k + h * depth + 64 * d_model * n; + } + for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { + int h = i % encoder_heads_; + int n = i / encoder_heads_; + offsets[i + encoder_heads_ * max_batch_size_] = mha_q + h * depth + 64 * d_model * n; + } + for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { + offsets[i + 2 * encoder_heads_ * max_batch_size_] = scratch2 + i * 64 * 64; + } + for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { + int h = i % encoder_heads_; + int n = i / encoder_heads_; + offsets[i + 3 * encoder_heads_ * max_batch_size_] = mha_v + h * depth + 64 * d_model * n; + } + for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { + int h = i % encoder_heads_; + int n = i / encoder_heads_; + offsets[i + 4 * encoder_heads_ * max_batch_size_] = scratch3 + h*depth + 64*d_model*n; + } + ReportCUDAErrors(cudaMalloc((void**)&scratch_rel_ptrs_, encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*))); + ReportCUDAErrors( + cudaMemcpy(scratch_rel_ptrs_, offsets.data(), encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*), + cudaMemcpyHostToDevice)); + last_known_scratch_ = scratch0; + } + cublasXGemmBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, depth /*K*/, // A/B, and M/N are swapped for row-major to col-major // transform factor, // to handle "/ tf.math.sqrt(dk)" - mha_k + offset /*A*/, + scratch_rel_ptrs_,// mha_k + offset /*A*/, d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over // other "depth" slices / heads - 64 * d_model, /*strideA*/ - mha_q + offset /*B*/, + //64 * d_model, /*strideA*/ + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_,//mha_q + offset /*B*/, d_model /*LDB*/, // to skip over other other "depth" slices / heads - 64 * d_model, /*strideB*/ + //64 * d_model, /*strideB*/ 0.0f, - scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 - 64 /*LDC*/, 64 * 64 /*strideC*/, N); - } + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 2, //scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 + 64 /*LDC*/, + //64 * 64 /*strideC*/, + N * encoder_heads_); + } // Add smolgen weights to the scaled matmul_qk attention logits. // smolgen weights need to be transposed first, kernel handles that. 
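// Layout note: with the pointer table above, the matmul_qk results are written
// in [batch][head][64*64] order, the same order the fused smolgen GEMM uses
// for scratch3, so the transposing add is no longer needed and the plain
// element-wise addVectors below is sufficient.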
if (has_smolgen_) { - addVectorsHNC_NHC(scratch2, scratch3, N, encoder_heads_, 64 * 64, stream); + const int size = N * encoder_heads_ * 64 * 64; + addVectors(scratch2, scratch2, scratch3, size, size, size, NONE, stream); } // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) // attention_weights -> scratch2 Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); - // output = tf.matmul(attention_weights, v) - for (int i = 0; i < encoder_heads_; i++) { - int offset = i * depth; // for output and "v" matrix - // layout: encoder_heads_ * Batch*64*64 - int weightsOffset = i * N * 64 * 64; - cublasXGemmStridedBatched( + { + cublasXGemmBatched( cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, 1.0f, - mha_v + offset /*A*/, // "v" matrix + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 3, //mha_v + offset /*A*/, // "v" matrix d_model /*LDA*/, // to skip over other "depth" slices / heads - 64 * d_model, /*strideA*/ - scratch2 + weightsOffset /*B*/, 64 /*LDB*/, 64 * 64, /*strideB*/ - 0.0f, scratch3 + offset /*C*/, // output goes to scratch3 - d_model /*LDC*/, 64 * d_model /*strideC*/, N); + //64 * d_model, /*strideA*/ + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 2, //scratch2 + weightsOffset /*B*/, + 64 /*LDB*/, //64 * 64, /*strideB*/ + 0.0f, + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 4, //scratch3 + offset /*C*/, // output goes to scratch3 + d_model /*LDC*/, + //64 * d_model /*strideC*/, + N * encoder_heads_); } + // #final dense layer (mha_dense), scratch3 -> scratch2 { const int num_inputs = d_model; const int num_outputs = embedding_op_size_; @@ -1943,7 +1998,7 @@ template AttentionBody::AttentionBody(const LegacyWeights& weights, void* scratch, ActivationFunction default_act, - int num_res_blocks, int input_c) + int num_res_blocks, int input_c, int max_batch_size) : embedding_op_size_(weights.ip_emb_b.size()), encoder_head_count_(weights.encoder_head_count), num_resi_blocks_(num_res_blocks), @@ -1971,7 +2026,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha, - smolgen_global_, smolgen_global_size_); + smolgen_global_, smolgen_global_size_, max_batch_size); encoder_weights_.emplace_back(pW); } } @@ -2048,6 +2103,7 @@ void AttentionBody::Eval( pEnc->Eval(N, scratch1, scratch0, scratch2, scratch3, cublas, stream, default_act_); } // End of encoder blocks + // dumpTensor(scratch1, N * 64 * embedding_op_size_, "Outputs"); } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 69263b0258..1a9837db94 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -337,7 +337,7 @@ class EncoderBlock { public: EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, - int smolgen_global_size); + int smolgen_global_size, int max_batch_size); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, @@ -385,6 +385,10 @@ class EncoderBlock { int smol_dense_1_size_; int smol_dense_2_size_; int smol_global_size_; + + const int max_batch_size_; + mutable DataType** scratch_rel_ptrs_; + mutable DataType* last_known_scratch_; }; // The Attention policy head implementation @@ -401,7 +405,7 @@ class AttentionPolicyHead : public BaseLayer { public: AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, - void* scratch, bool 
attention_body, ActivationFunction act); + void* scratch, bool attention_body, ActivationFunction act, int max_batch_size); ~AttentionPolicyHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -465,7 +469,7 @@ class AttentionBody : public BaseLayer { public: AttentionBody(const LegacyWeights& weights, void* scratch, ActivationFunction default_act, int num_res_blocks, - int input_c); + int input_c, int max_batch_size); ~AttentionBody(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 7ec3c9644f..c8d3c63e32 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -415,7 +415,7 @@ class CudaNetwork : public Network { if (attn_body_) { auto attention_body = std::make_unique>( weights, scratch_mem_, act, numBlocks_, - numBlocks_ > 0 ? kNumFilters : kInputPlanes); + numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_); network_.emplace_back(std::move(attention_body)); encoder_last_ = getLastLayer(); @@ -424,7 +424,7 @@ class CudaNetwork : public Network { // Policy head. if (attn_policy_) { auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_, attn_body_, act); + getLastLayer(), weights, scratch_mem_, attn_body_, act, max_batch_size_); network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( From 9f9304b04a590d505dd5a6f7f6ddabe1ee87f756 Mon Sep 17 00:00:00 2001 From: Alma Date: Tue, 27 Dec 2022 22:39:20 +0100 Subject: [PATCH 19/70] Fix cudnn build failures. --- src/neural/cuda/network_cudnn.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index e5fdfdd3ad..92ba9abf36 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -504,7 +504,7 @@ class CudnnNetwork : public Network { // Policy head. if (attn_policy_) { auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_); + getLastLayer(), weights, scratch_mem_, false, SELU, max_batch_size_); network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( From edbd8a8284f78a5ac971c6913dbf2d19ea3c7936 Mon Sep 17 00:00:00 2001 From: Alma Date: Wed, 28 Dec 2022 12:28:11 +0100 Subject: [PATCH 20/70] Add tilps perf patch for fused smolgen weights add / softmax --- src/neural/cuda/common_kernels.cu | 29 +++++++++++++++++++++-------- src/neural/cuda/kernels.h | 2 +- src/neural/cuda/layers.cc | 14 ++++++-------- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index ebaf2b83dc..f7469046c0 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -808,12 +808,12 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, // each thread processes two elements. 
Each warp computes a sum (over 64 // elements) template -__global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { +__global__ void softmax_opt_64_kernel(T* output, const T* input, const T* input2, int N) { int index = blockDim.x * blockIdx.x + threadIdx.x; if (index >= N) return; - float x[2]; + float x[4]; float ex[2]; // Load from memory @@ -823,10 +823,22 @@ __global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { copyAs(&inp[0], &input[index * 2]); x[0] = (float)inp[0]; x[1] = (float)inp[1]; + if (input2 != nullptr) { + copyAs(&inp[0], &input2[index * 2]); + x[2] = (float)inp[0]; + x[3] = (float)inp[1]; + } } else { copyAs(&x[0], &input[index * 2]); + if (input2 != nullptr) { + copyAs(&x[2], &input2[index * 2]); + } } + if (input2 != nullptr) { + x[0] += x[2]; + x[1] += x[3]; + } float threadMax = max(x[0], x[1]); float maxval = warpMax(threadMax); maxval = __shfl_sync(0xFFFFFFFF, maxval, 0); @@ -859,7 +871,7 @@ __global__ void softmax_opt_64_kernel(T* output, const T* input, int N) { // Sums are computed in shared memory // C threads per block, N blocks template -__global__ void softmax_kernel(T* output, const T* input) { +__global__ void softmax_kernel(T* output, const T* input, const T* input2) { int n = blockIdx.x; int c = threadIdx.x; int C = blockDim.x; @@ -868,6 +880,7 @@ __global__ void softmax_kernel(T* output, const T* input) { // softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis) float x = (float)input[index]; + if (input2 != nullptr) x += (float)input2[index]; __shared__ float sum, maxval; if (c == 0) { @@ -899,14 +912,14 @@ __global__ void softmax_kernel(T* output, const T* input) { } template -void Softmax(int N, int C, T* output, const T* input, cudaStream_t stream) { +void Softmax(int N, int C, T* output, const T* input, const T* input2, cudaStream_t stream) { if (C == 64) { int size = N * 32; // Total no of threads needed const int kBlockSize = 256; int blocks = DivUp(size, kBlockSize); - softmax_opt_64_kernel<<>>(output, input, size); + softmax_opt_64_kernel<<>>(output, input, input2, size); } else { - softmax_kernel<<>>(output, input); + softmax_kernel<<>>(output, input, input2); } ReportCUDAErrors(cudaGetLastError()); @@ -1409,9 +1422,9 @@ template void OutputInputTransform( const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); -template void Softmax(int N, int C, half* output, const half* input, +template void Softmax(int N, int C, half* output, const half* input, const half* input2, cudaStream_t stream); -template void Softmax(int N, int C, float* output, const float* input, +template void Softmax(int N, int C, float* output, const float* input, const float* input2, cudaStream_t stream); template void LayerNorm(int N, int C, half* output, const half* input, diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 7545e10241..d6e69dc980 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -130,7 +130,7 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, cudaStream_t stream); template -void Softmax(int N, int C, T* output, const T* input, cudaStream_t stream); +void Softmax(int N, int C, T* output, const T* input, const T* input2, cudaStream_t stream); template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index f0acb1c176..04315e261d 100644 --- a/src/neural/cuda/layers.cc +++ 
b/src/neural/cuda/layers.cc @@ -1763,16 +1763,14 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, N * encoder_heads_); } - // Add smolgen weights to the scaled matmul_qk attention logits. - // smolgen weights need to be transposed first, kernel handles that. - if (has_smolgen_) { - const int size = N * encoder_heads_ * 64 * 64; - addVectors(scratch2, scratch2, scratch3, size, size, size, NONE, stream); - } - // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) // attention_weights -> scratch2 - Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, stream); + if (has_smolgen_) { + // Add smolgen weights to the scaled matmul_qk attention logits before softmax. + Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, scratch3, stream); + } else { + Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, (const DataType*)nullptr, stream); + } { cublasXGemmBatched( From d207abefa7b840211e913f9ce8548ae5ffb7f76e Mon Sep 17 00:00:00 2001 From: Alma Date: Sun, 8 Jan 2023 16:10:43 +0100 Subject: [PATCH 21/70] Fix errors in non-attentionbody nets. --- src/neural/cuda/layers.cc | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 04315e261d..aca4a5ada4 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1825,7 +1825,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, scratch1, num_inputs, 0.0f, scratch2, num_outputs); } - + // LN2: skip connection and layer normilization (also bias add of prev gemm) // scratch2/scratch0 -> scratch1 LayerNorm(N * 64, embedding_op_size_, scratch1, scratch2, @@ -1947,15 +1947,17 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(ffn_dense2_b)); ReportCUDAErrors(cudaFree(ln2_gammas)); ReportCUDAErrors(cudaFree(ln2_betas)); - ReportCUDAErrors(cudaFree(smol_compress)); - ReportCUDAErrors(cudaFree(smol_dense1_w)); - ReportCUDAErrors(cudaFree(smol_dense1_b)); - ReportCUDAErrors(cudaFree(smol_dense2_w)); - ReportCUDAErrors(cudaFree(smol_dense2_b)); - ReportCUDAErrors(cudaFree(smol_ln1_gammas)); - ReportCUDAErrors(cudaFree(smol_ln1_betas)); - ReportCUDAErrors(cudaFree(smol_ln2_gammas)); - ReportCUDAErrors(cudaFree(smol_ln2_betas)); + if (has_smolgen_) { + ReportCUDAErrors(cudaFree(smol_compress)); + ReportCUDAErrors(cudaFree(smol_dense1_w)); + ReportCUDAErrors(cudaFree(smol_dense1_b)); + ReportCUDAErrors(cudaFree(smol_dense2_w)); + ReportCUDAErrors(cudaFree(smol_dense2_b)); + ReportCUDAErrors(cudaFree(smol_ln1_gammas)); + ReportCUDAErrors(cudaFree(smol_ln1_betas)); + ReportCUDAErrors(cudaFree(smol_ln2_gammas)); + ReportCUDAErrors(cudaFree(smol_ln2_betas)); + } } @@ -2033,9 +2035,13 @@ template AttentionBody::~AttentionBody() { ReportCUDAErrors(cudaFree(ip_emb_w_)); ReportCUDAErrors(cudaFree(ip_emb_b_)); - ReportCUDAErrors(cudaFree(ip_mult_gate_)); - ReportCUDAErrors(cudaFree(ip_add_gate_)); - ReportCUDAErrors(cudaFree(smolgen_global_)); + if (has_gating_) { + ReportCUDAErrors(cudaFree(ip_mult_gate_)); + ReportCUDAErrors(cudaFree(ip_add_gate_)); + } + if (has_smolgen_) { + ReportCUDAErrors(cudaFree(smolgen_global_)); + } for (const auto pEnc : encoder_weights_) delete pEnc; } From 70b05219b244d56c65d23c25da456aba8f976711 Mon Sep 17 00:00:00 2001 From: Alma Date: Mon, 9 Jan 2023 04:33:35 +0100 Subject: [PATCH 22/70] Add multistream support. Allow new attentionbody nets. 
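Each InputsOutputs instance now owns its own CUDA stream and cublas handle, so
concurrent requests can run on separate streams with separate scratch
allocations. A single "last known scratch" pointer is then no longer enough to
decide when the encoder's cached GEMM pointer table has to be rebuilt, so the
cache is keyed by stream instead, along the lines of:

    // simplified from the change below
    if (scratch0 != offset_scratches_[stream]) {
      // rebuild and upload the pointer table for this scratch allocation
      offset_scratches_[stream] = scratch0;
    }

The offset-building code is also collapsed into a single loop over
(head, batch) pairs. In addition, the cuda backend now accepts
NETWORK_ATTENTIONBODY_WITH_HEADFORMAT networks, and the lczero-common
submodule is bumped.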
--- libs/lczero-common | 2 +- src/neural/cuda/inputs_outputs.h | 2 +- src/neural/cuda/layers.cc | 18 ++---------------- src/neural/cuda/layers.h | 2 +- src/neural/cuda/network_cuda.cc | 4 +++- 5 files changed, 8 insertions(+), 20 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index 2165d35bf6..8201edc7c2 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 2165d35bf63e95549eb4feff06a755ec88af5264 +Subproject commit 8201edc7c2d00b22e0858b963098e8af14c725f8 diff --git a/src/neural/cuda/inputs_outputs.h b/src/neural/cuda/inputs_outputs.h index d8677b4d91..03b9fe5725 100644 --- a/src/neural/cuda/inputs_outputs.h +++ b/src/neural/cuda/inputs_outputs.h @@ -127,9 +127,9 @@ struct InputsOutputs { // cuda stream used to run the network cudaStream_t stream_; - cublasHandle_t cublas_; // cublas handle used to run the network + cublasHandle_t cublas_; }; diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index aca4a5ada4..f839d27347 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1713,36 +1713,22 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // matmul_qk = tf.matmul(q, k, transpose_b=True) { - if (scratch0 != last_known_scratch_) { + if (scratch0 != offset_scratches_[stream]) { std::vector offsets(encoder_heads_ * max_batch_size_*5); for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { int h = i % encoder_heads_; int n = i / encoder_heads_; offsets[i] = mha_k + h * depth + 64 * d_model * n; - } - for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { - int h = i % encoder_heads_; - int n = i / encoder_heads_; offsets[i + encoder_heads_ * max_batch_size_] = mha_q + h * depth + 64 * d_model * n; - } - for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { offsets[i + 2 * encoder_heads_ * max_batch_size_] = scratch2 + i * 64 * 64; - } - for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { - int h = i % encoder_heads_; - int n = i / encoder_heads_; offsets[i + 3 * encoder_heads_ * max_batch_size_] = mha_v + h * depth + 64 * d_model * n; - } - for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { - int h = i % encoder_heads_; - int n = i / encoder_heads_; offsets[i + 4 * encoder_heads_ * max_batch_size_] = scratch3 + h*depth + 64*d_model*n; } ReportCUDAErrors(cudaMalloc((void**)&scratch_rel_ptrs_, encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*))); ReportCUDAErrors( cudaMemcpy(scratch_rel_ptrs_, offsets.data(), encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*), cudaMemcpyHostToDevice)); - last_known_scratch_ = scratch0; + offset_scratches_[stream] = scratch0; } cublasXGemmBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, 64 /*M*/, 64 /*N*/, diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 1a9837db94..f74deb467e 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -388,7 +388,7 @@ class EncoderBlock { const int max_batch_size_; mutable DataType** scratch_rel_ptrs_; - mutable DataType* last_known_scratch_; + mutable std::unordered_map offset_scratches_; }; // The Attention policy head implementation diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 456ec7861a..1b218f7006 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -1027,7 +1027,9 @@ std::unique_ptr MakeCudaNetwork(const std::optional& w, if (weights.format().network_format().network() != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && 
weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT) { + pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + weights.format().network_format().network() != + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( weights.format().network_format().network()) + From eb184f47eb9bb57ba944b84f55a71517a47915f9 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Sat, 4 Mar 2023 10:28:20 +0530 Subject: [PATCH 23/70] add 8 elements per thread layernorm - to handle bigger/wider networks --- src/neural/cuda/common_kernels.cu | 91 ++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 3 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index f7469046c0..f1ac91637c 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1038,26 +1038,111 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const } } +__global__ void layer_norm_kernel_8_el_per_thread( + int N, int C, half* output, const half* input, const half* bias, + const half* skip, const half* gammas, const half* betas, float ep, + float alpha, ActivationFunction act) { + int n = blockIdx.x * blockDim.z + threadIdx.z; + if (n >= N) return; + int c = (threadIdx.y * 32 + threadIdx.x) * 8; + bool oobThread = c >= C; + + int biasIndex = c; + int tensorIndex = n * C + c; + + float val[8] = {}; + float b[8] = {}; + float sk[8] = {}; + float bts[8] = {}; + float gms[8] = {}; + + if (!oobThread) { + // Load from memory (8 elements a time) + half inp[8]; + copyAs(&inp[0], &input[tensorIndex]); + for (int i = 0; i < 8; i++) val[i] = (float)inp[i]; + copyAs(&inp[0], &skip[tensorIndex]); + for (int i = 0; i < 8; i++) sk[i] = (float)inp[i]; + copyAs(&inp[0], &bias[biasIndex]); + for (int i = 0; i < 8; i++) b[i] = (float)inp[i]; + copyAs(&inp[0], &betas[biasIndex]); + for (int i = 0; i < 8; i++) bts[i] = (float)inp[i]; + copyAs(&inp[0], &gammas[biasIndex]); + for (int i = 0; i < 8; i++) gms[i] = (float)inp[i]; + } + + // 1. Compute mean + float s = 0; + if (!oobThread) + for (int i = 0; i < 8; i++) { + val[i] = activate(val[i] + b[i], act) + sk[i] * alpha; + s += val[i]; + } + + s = shared_sum_for_layer_norm(s); + float mean = s / C; + + // 2. Compute varience + s = 0; + if (!oobThread) + for (int i = 0; i < 8; i++) { + float d = val[i] - mean; + float d_sq = d * d; + s += d_sq; + } + s = shared_sum_for_layer_norm(s); + float var = s / C; + + // 3. 
Normalize + for (int i = 0; i < 8; i++) { + float d = val[i] - mean; + float norm = d / sqrt(var + ep); + float op = norm * gms[i] + bts[i]; + val[i] = op; + } + + if (!oobThread) { + // Write to memory + half op[8]; + for (int i = 0; i < 8; i++) op[i] = (half)val[i]; + copyAs(&output[tensorIndex], &op[0]); + } +} + + // add (optional) skip connection to input, and then perform Layer normalization // normalization is done across C dimension (i.e, sums and std deviations taken over elements in C dim) template void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, ActivationFunction act, cudaStream_t stream) { + const bool fp16 = std::is_same::value; // process 4 elements per thread to achieve close to peak memory bandwidth if (C % 4 != 0) throw Exception("unsupported filter size"); - if (C > 4096) throw Exception("unsupported filter size"); + if (C > 4096) + { + if (!fp16 || (C % 8 != 0) || C > 8192) + throw Exception("unsupported filter size"); + } + + const int EL_PER_THREAD = (C > 4096) ? 8 : 4; dim3 blockDim, gridDim; blockDim.x = 32; - blockDim.y = DivUp(C / 4, 32); + blockDim.y = DivUp(C / EL_PER_THREAD, 32); blockDim.z = std::min(std::max(512 / (blockDim.x * blockDim.y), 1u), (unsigned int)N); gridDim.x = DivUp(N, blockDim.z); gridDim.y = 1; gridDim.z = 1; - layer_norm_kernel<<>>( + if (EL_PER_THREAD == 8 && fp16) + layer_norm_kernel_8_el_per_thread<<>>( + N, C, (half*)output, (const half*)input, (const half*)bias, + (const half*)skip, (const half*)gammas, (const half*)betas, ep, alpha, + act); + else + layer_norm_kernel<<>>( N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); ReportCUDAErrors(cudaGetLastError()); From 6e0161a0c013d2db5c199b738fbe8f80e02e0cde Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Sat, 4 Mar 2023 14:53:55 +0530 Subject: [PATCH 24/70] Try fused MHA from cutlass 1.3% improvement in BT2 on RTX 4090 15.6% improvement in test BT3 network with 64 heads. 
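The kernel is taken from cutlass example 41 (the xFormers memory-efficient
attention forward pass). It fuses the scaled Q*K^T product, the optional
additive attention bias (used here to feed in the smolgen weights), the row
softmax and the multiplication by V into a single kernel, so the per-head
64x64 logits never have to be materialised in global memory or run through a
separate Softmax launch. Tensors are consumed in BMHK order (batch, query
position, head, depth), which matches how mha_q/k/v are already laid out, so
no extra transposes are needed. The path is fp16-only and is built for SM80+
(see the new cutlass meson option); fusedMHA() returns false when a
configuration is not supported, so the caller can keep the existing
batched-GEMM + softmax path as a fallback.

A plausible call shape, based on the fusedMHA() entry point added in
cutlass_kernels.cu (buffer names here are illustrative, not the actual
variables at the call site):

    // skip == nullptr selects the no-bias template instantiation.
    bool done = fusedMHA(mha_out, mha_q, mha_k, mha_v,
                         has_smolgen_ ? smolgen_weights : nullptr,
                         N, encoder_heads_, depth, stream);
    if (!done) {
      // fall back to the unfused cublas GEMM + Softmax implementation
    }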
--- build.cmd | 13 +- meson.build | 14 + meson_options.txt | 10 + src/neural/cuda/cutlass_kernels.cu | 124 + .../fused_multi_head_attention/debug_utils.h | 234 ++ .../epilogue/epilogue_pipelined.h | 632 +++++ .../epilogue/epilogue_rescale_output.h | 262 ++ .../epilogue_thread_apply_logsumexp.h | 175 ++ .../gemm/custom_mma.h | 124 + .../gemm/custom_mma_base.h | 183 ++ .../gemm/custom_mma_multistage.h | 767 ++++++ .../gemm/custom_mma_pipelined.h | 401 ++++ .../gemm/find_default_mma.h | 191 ++ .../gemm/mma_accum_lambda_iterator.h | 378 +++ .../gemm/mma_from_smem.h | 2055 ++++++++++++++++ .../gemm_kernel_utils.h | 248 ++ .../epilogue_predicated_tile_iterator.h | 752 ++++++ .../iterators/make_residual_last.h | 97 + ...cated_tile_access_iterator_residual_last.h | 2115 ++++++++++++++++ .../predicated_tile_iterator_residual_last.h | 2120 +++++++++++++++++ .../iterators/transpose_warp_iterator.h | 53 + .../iterators/warp_iterator_from_smem.h | 278 +++ .../kernel_forward.h | 1236 ++++++++++ .../transform/tile_smem_loader.h | 88 + src/neural/cuda/kernels.h | 6 + src/neural/cuda/layers.cc | 117 +- src/neural/cuda/layers.h | 6 +- src/neural/cuda/network_cuda.cc | 7 +- 28 files changed, 12626 insertions(+), 60 deletions(-) create mode 100644 src/neural/cuda/cutlass_kernels.cu create mode 100644 src/neural/cuda/fused_multi_head_attention/debug_utils.h create mode 100644 src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_pipelined.h create mode 100644 src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_rescale_output.h create mode 100644 src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/custom_mma.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_base.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_multistage.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_pipelined.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/find_default_mma.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm/mma_from_smem.h create mode 100644 src/neural/cuda/fused_multi_head_attention/gemm_kernel_utils.h create mode 100644 src/neural/cuda/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h create mode 100644 src/neural/cuda/fused_multi_head_attention/iterators/make_residual_last.h create mode 100644 src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h create mode 100644 src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h create mode 100644 src/neural/cuda/fused_multi_head_attention/iterators/transpose_warp_iterator.h create mode 100644 src/neural/cuda/fused_multi_head_attention/iterators/warp_iterator_from_smem.h create mode 100644 src/neural/cuda/fused_multi_head_attention/kernel_forward.h create mode 100644 src/neural/cuda/fused_multi_head_attention/transform/tile_smem_loader.h diff --git a/build.cmd b/build.cmd index 9b607e01be..21a03a771e 100644 --- a/build.cmd +++ b/build.cmd @@ -2,7 +2,7 @@ setlocal rem 1. Set the following for the options you want to build. -set CUDNN=true +set CUDNN=false set CUDA=true set DX12=false set OPENCL=false @@ -11,10 +11,12 @@ set DNNL=false set OPENBLAS=false set EIGEN=false set TEST=false +set CUTLASS=true rem 2. 
Edit the paths for the build dependencies. -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.0 set CUDNN_PATH=%CUDA_PATH% +set CUTLASS_INCLUDE_PATH=C:\dev\cutlass-2.11.0\include set OPENBLAS_PATH=C:\OpenBLAS set MKL_PATH=C:\Program Files (x86)\IntelSWTools\compilers_and_libraries\windows\mkl set DNNL_PATH=C:\dnnl_win_1.1.1_cpu_vcomp @@ -34,13 +36,13 @@ if exist "C:\Program Files\Microsoft Visual Studio\2022" ( where /q cl if errorlevel 1 call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 set backend=vs2022 -) else if exist "C:\Program Files (x86)\Microsoft Visual Studio\2019" ( +) else if exist "D:\Program Files (x86)\Microsoft Visual Studio\2019" ( where /q cl - if errorlevel 1 call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 + if errorlevel 1 call "D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 set backend=vs2019 ) else ( where /q cl - if errorlevel 1 call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 + if errorlevel 1 call "D:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" amd64 set backend=vs2017 ) @@ -63,6 +65,7 @@ meson build --backend %backend% --buildtype release -Ddx=%DX12% -Dcudnn=%CUDNN% -Dmkl_include="%MKL_PATH%\include" -Dmkl_libdirs="%MKL_PATH%\lib\intel64" -Ddnnl_dir="%DNNL_PATH%" ^ -Dopencl_libdirs="%OPENCL_LIB_PATH%" -Dopencl_include="%OPENCL_INCLUDE_PATH%" ^ -Dopenblas_include="%OPENBLAS_PATH%\include" -Dopenblas_libdirs="%OPENBLAS_PATH%\lib" ^ +-Dcutlass_include="%CUTLASS_INCLUDE_PATH%" -Dcutlass="%CUTLASS%" ^ -Ddefault_library=static if errorlevel 1 exit /b diff --git a/meson.build b/meson.build index 13490824e3..2a70100d0f 100644 --- a/meson.build +++ b/meson.build @@ -478,6 +478,11 @@ if get_option('build_backends') cuda_arguments += ['-ccbin=' + get_option('nvcc_ccbin')] endif cuda_cc = get_option('cc_cuda') # Unfortunately option cuda_cc is reserved. + if get_option('cutlass') + add_project_arguments('-DUSE_CUTLASS', language : 'cpp') + cuda_arguments += ['-DUSE_CUTLASS'] + cuda_arguments += ['-I', get_option('cutlass_include')] + endif nvcc_extra_args = [] if cuda_cc != '' nvcc_extra_args = ['-arch=compute_' + cuda_cc, '-code=sm_' + cuda_cc] @@ -497,6 +502,15 @@ if get_option('build_backends') depend_files: 'src/neural/cuda/winograd_helper.inc', command : [nvcc, nvcc_extra_args, cuda_arguments] ) + + if get_option('cutlass') + nvcc_cutlass_args = ['-arch=compute_80', '-code=sm_80'] + files += custom_target('cuda cutlass code', + input : 'src/neural/cuda/cutlass_kernels.cu', + output : outputname, + command : [nvcc, nvcc_cutlass_args, cuda_arguments] + ) + endif # Handling of fp16 cuda code. 
nvcc_sm_list = ['80', '75', '86', '70', '89', '90'] diff --git a/meson_options.txt b/meson_options.txt index cc2364be45..c55caa392a 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -43,6 +43,11 @@ option('cudnn_include', value: ['/opt/cuda/include/', '/usr/local/cuda/include/', '/usr/lib/cuda/include/'], description: 'Paths to cudnn include directory') +option('cutlass_include', + type: 'string', + value: '/usr', + description: 'Paths to cutlass include directory') + option('build_backends', type: 'boolean', value: true, @@ -73,6 +78,11 @@ option('plain_cuda', value: true, description: 'Enable CUDA backend') +option('cutlass', + type: 'boolean', + value: false, + description: 'Enable cutlass lib for cuda backend. Only supports Ampere+ right now') + option('opencl', type: 'boolean', value: true, diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu new file mode 100644 index 0000000000..b4dc01ed87 --- /dev/null +++ b/src/neural/cuda/cutlass_kernels.cu @@ -0,0 +1,124 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2018 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. 
+*/ + +#include "cuda_common.h" + +#ifdef USE_CUTLASS + +// Fused MHA implementation from cutlass example #41 +#include "fused_multi_head_attention/kernel_forward.h" + + + +template +bool fusedMHACutlass(void* output, void* q, void* k, void* v, void* skip, + int batch_size, int num_heads, int depth, + cudaStream_t stream) { + cutlass::half_t* mha_q = (cutlass::half_t*)q; + cutlass::half_t* mha_k = (cutlass::half_t*)k; + cutlass::half_t* mha_v = (cutlass::half_t*)v; + + constexpr int kQueriesPerBlock = 64; + constexpr int kKeysPerBlock = 64; + constexpr bool kSingleValueIteration = true; + + using Attention = + AttentionKernel; + + typename Attention::Params p; + { // set parameters + p.query_ptr = mha_q; + p.key_ptr = mha_k; + p.value_ptr = mha_v; + p.logsumexp_ptr = nullptr; // Only needed for bw + p.output_accum_ptr = nullptr; + if (Attention::kNeedsOutputAccumulatorBuffer) { + // throw Exception("Unhandled case in cutlass MHA"); + return false; + } + p.output_ptr = (cutlass::half_t*)output; + p.attn_bias_ptr = (cutlass::half_t*)skip; + + p.scale = 1.0f / sqrt((float)depth); + + p.num_heads = num_heads; + p.num_batches = batch_size; + p.head_dim = depth; + p.head_dim_value = depth; + p.num_queries = 64; + p.num_keys = 64; + + // All tensors are in BMHK shapes + p.q_strideH = depth; + p.k_strideH = depth; + p.v_strideH = depth; + p.q_strideM = depth * num_heads; + p.k_strideM = depth * num_heads; + p.v_strideM = depth * num_heads; + p.q_strideB = p.q_strideM * 64; + p.k_strideB = p.k_strideM * 64; + p.v_strideB = p.v_strideM * 64; + p.o_strideM = p.head_dim_value * p.num_heads; + + p.bias_strideH = 64 * 64; + p.bias_strideM = 64; + p.bias_strideB = num_heads * p.bias_strideH; + } + + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + } + if (!Attention::check_supported(p)) { + // throw Exception("Unhandled case in cutlass MHA"); + return false; + } + + kernel_fn<<>>(p); + + // ReportCUDAErrors(cudaGetLastError()); + return true; +} + +bool fusedMHA(void* output, void* mha_q, void* mha_k, void* mha_v, void* skip, + int batch_size, int num_heads, int depth, cudaStream_t stream) { + if (skip == nullptr) + return fusedMHACutlass(output, mha_q, mha_k, mha_v, skip, batch_size, + num_heads, depth, stream); + else + return fusedMHACutlass(output, mha_q, mha_k, mha_v, skip, batch_size, + num_heads, depth, stream); +} +#endif diff --git a/src/neural/cuda/fused_multi_head_attention/debug_utils.h b/src/neural/cuda/fused_multi_head_attention/debug_utils.h new file mode 100644 index 0000000000..aafc62d6e8 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/debug_utils.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (int _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) 
\ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (int _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + 
PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_pipelined.h b/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000..2a574e71f2 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,632 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 
+ + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = 
AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if 
(!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } 
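[Editorial aside, not part of the patch] The compute_source_needed_() path above keeps two source fragments and loads the tile for iteration iter + 1 while iteration iter is being processed, which is the "load 2 source fragments at the same time" pipelining mentioned in the file comment. A minimal host-side sketch of that ping-pong indexing follows; the buffer contents, tile values, and kIterations here are made up for illustration and are not CUTLASS types.

// Illustrative ping-pong buffering, mirroring source_fragment[iter % 2] and
// source_fragment[(iter + 1) % 2] in compute_source_needed_() above.
#include <cstdio>

int main() {
  constexpr int kIterations = 5;    // stand-in for OutputTileIterator::kIterations
  int source_fragment[2] = {0, 0};  // stand-ins for the two source fragments

  source_fragment[0] = 100;         // preload the tile for iteration 0
  for (int iter = 0; iter < kIterations; ++iter) {
    // Prefetch the source tile for the next iteration into the other buffer.
    if (iter + 1 < kIterations) {
      source_fragment[(iter + 1) % 2] = 100 + (iter + 1);
    }
    // Meanwhile, this iteration consumes the tile that was loaded earlier.
    std::printf("iter %d uses tile %d from buffer %d\n", iter,
                source_fragment[iter % 2], iter % 2);
  }
  return 0;
}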
+ + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000..a5d8f8d3f9 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? 
(1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h b/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000..2e286d3f46 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,175 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. 
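[Editorial aside, not part of the patch] The rescaling that MemoryEfficientAttentionNormalize::operator() above performs can be written out in scalar form: alpha is 1 / s_prime on the last partial-K block and 1 otherwise, beta = alpha * m_prime, and D = alpha * accum + beta * source, with the source term dropped on the first block. The sketch below is a standalone illustration under those rules; the sample values of s_prime and m_prime are invented, not taken from the kernel.

#include <cstdio>

// Scalar form of the two operator() overloads shown above.
float normalize_step(float accum, float source, float s_prime, float m_prime,
                     bool is_first, bool is_last) {
  float alpha = is_last ? 1.0f / s_prime : 1.0f;
  if (is_first) {
    return alpha * accum;                // no previously written output to rescale
  }
  float beta = alpha * m_prime;          // renormalize the earlier partial output
  return alpha * accum + beta * source;  // D = alpha * Accum + beta * C
}

int main() {
  // Two partial-K blocks contributing to one output element (illustrative numbers).
  float out = normalize_step(2.0f, 0.0f, /*s_prime=*/4.0f, /*m_prime=*/0.5f,
                             /*is_first=*/true, /*is_last=*/false);
  out = normalize_step(1.0f, out, /*s_prime=*/4.0f, /*m_prime=*/0.5f,
                       /*is_first=*/false, /*is_last=*/true);
  std::printf("normalized output = %f\n", out);
  return 0;
}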
+*/ + +#pragma once + +#include + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template < + typename ElementOutput_, // output + typename ElementLSE_, // accumulator from LSE + typename ElementAccumulator_, // accumulator from matmul + typename ElementCompute_, // intermediate compute (and exp calculation) + int ElementsPerAccess> +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter< + ElementCompute, + ElementAccumulator, + kElementsPerAccess>()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter< + ElementOutput, + ElementCompute, + kElementsPerAccess>()(frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + 
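[Editorial aside, not part of the patch] The ApplyLogSumExp functor above turns matmul accumulators into probabilities using a precomputed log-sum-exp: output = exp(input - lse). A small host-side check of that identity follows (plain C++, not the CUTLASS functor; the logits are arbitrary sample values): the rescaled values sum to 1, i.e. the epilogue reproduces a softmax row.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One row of attention logits (sample values).
  std::vector<float> logits = {1.0f, 2.0f, 0.5f};

  // Numerically stable log-sum-exp of the row.
  float m = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float x : logits) sum += std::exp(x - m);
  float lse = m + std::log(sum);

  // ApplyLogSumExp-style epilogue: out = exp(in - lse).
  float total = 0.0f;
  for (float x : logits) {
    float p = std::exp(x - lse);
    total += p;
    std::printf("%f ", p);
  }
  std::printf("\nrow sum = %f\n", total);  // ~1.0
  return 0;
}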
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma.h b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma.h new file mode 100644 index 0000000000..7326bad586 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + cutlass::arch::CacheOperation::Kind CacheOpA, + typename IteratorB, + typename SmemIteratorB, + cutlass::arch::CacheOperation::Kind CacheOpB, + typename ElementC, + typename LayoutC, + typename Policy, + int Stages, + cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + Stages, + SharedMemoryClear>, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? 
Stages + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + kStages, + SharedMemoryClear, + kMaxK>; +}; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + typename IteratorB, + typename SmemIteratorB, + typename ElementC, + typename LayoutC, + typename Policy, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>; +}; diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_base.h b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000..6c6d07819b --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_base.h @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
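[Editorial aside, not part of the patch] The MakeCustomMma specialization above clamps the pipeline depth when the K extent is bounded at compile time: with an upper bound kMaxK there is no benefit to more stages than ceil(kMaxK / Shape::kK). The standalone sketch below mirrors that expression; the tile depth and stage counts are illustrative constants, not values from this patch.

#include <algorithm>
#include <cstdio>
#include <limits>

// Mirrors the kStages expression in MakeCustomMma<MmaMultistage<...>, kMaxK>.
constexpr int reduced_stages(int stages, int max_k, int shape_k) {
  return max_k == std::numeric_limits<int>::max()
             ? stages
             : std::min(stages, (max_k + shape_k - 1) / shape_k);
}

int main() {
  constexpr int kShapeK = 32;  // threadblock tile depth along K (illustrative)
  std::printf("%d\n", reduced_stages(5, std::numeric_limits<int>::max(), kShapeK));  // 5
  std::printf("%d\n", reduced_stages(5, 64, kShapeK));  // 2: only two K tiles exist
  std::printf("%d\n", reduced_stages(5, 96, kShapeK));  // 3
  return 0;
}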
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef{buffer.data(), Layout()}; + } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape< + Shape::kM + Policy::SmemPaddingA::kRow, + Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + using SharedStorageA = OperandSharedStorage< + typename Operator::ElementA, + ShapeA, + typename Operator::LayoutA>; + using SharedStorageB = OperandSharedStorage< + typename Operator::ElementB, + ShapeB, + typename Operator::LayoutB>; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + 
SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_multistage.h b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000..e5cdc88fae --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,767 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper boundon the K dimension + int kMaxK = cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail { + static_assert( + Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of 
warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + zero_outside_bounds_ = value; + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index( + group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index( + group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + 
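[Editorial aside, not part of the patch] copy_tiles_and_advance() above issues the cp.async copies for one group of a stage; together with the prologue it implements the usual multistage schedule: kNumStagesConcurrentLoad tiles (Stages - 1 in the common case where shared memory does not hold the whole matrix) are prefetched up front, and each mainloop iteration starts the copy for a later tile while computing on the oldest resident one. The host-side sketch below models only the stage indexing, not the asynchronous copies themselves; the stage and tile counts are made up.

#include <cstdio>

int main() {
  constexpr int kStages = 3;  // shared-memory buffers (illustrative)
  constexpr int kTiles = 7;   // K tiles to consume (illustrative)

  int write_stage = 0;
  int read_stage = 0;
  int next_tile = 0;

  // Prologue: prefetch kStages - 1 tiles into distinct stages.
  for (; next_tile < kStages - 1 && next_tile < kTiles; ++next_tile) {
    std::printf("prologue: copy tile %d -> stage %d\n", next_tile, write_stage);
    write_stage = (write_stage + 1) % kStages;
  }

  // Mainloop: overlap the next copy with math on the oldest stage.
  for (int k = 0; k < kTiles; ++k) {
    if (next_tile < kTiles) {
      std::printf("mainloop: copy tile %d -> stage %d\n", next_tile, write_stage);
      write_stage = (write_stage + 1) % kStages;
      ++next_tile;
    }
    std::printf("mainloop: math on tile %d from stage %d\n", k, read_stage);
    read_stage = (read_stage + 1) % kStages;
  }
  return 0;
}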
template + CUTLASS_DEVICE static void _prologue( + IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
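+    // (With kClearLastStage this is done by the explicit zero-stores just
+    // below; with kZfill it is instead handled by cp_async_zfill during the
+    // copies and by the drain at the end of the mainloop.)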
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform( + warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. 
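+        // Concretely: the prefetch of the next warp-level fragment is skipped
+        // only when the buffer is non-circular, gemm_k_iterations is already
+        // exhausted, and this is the final warp-level k-group, since that load
+        // would step past the end of the linearly-filled buffer.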
+ if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform( + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
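+          // (cp_async_wait<N> blocks until at most N cp.async stages are still
+          // in flight, so the stage consumed next is resident in shared memory
+          // before the __syncthreads() that follows.)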
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_pipelined.h b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000..73112e9a26 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,401 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
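+///
+/// This is the two-stage, double-buffered variant: operands are staged through
+/// shared memory with ordinary loads and stores rather than cp.async, so it
+/// serves as the fallback when the multistage mainloop is not used. For this
+/// class set_prologue_done() and the static prologue() are intentional no-ops.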
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined( + typename Base::SharedStorageA& 
shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + 
iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/find_default_mma.h 
b/src/neural/cuda/fused_multi_head_attention/gemm/find_default_mma.h
new file mode 100644
index 0000000000..2e6b35b652
--- /dev/null
+++ b/src/neural/cuda/fused_multi_head_attention/gemm/find_default_mma.h
@@ -0,0 +1,191 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief Cutlass provides helper template functions to figure out the right
+    data structures to instantiate to run a GEMM with various parameters (see
+    `cutlass/gemm/threadblock/default_mma.h`). However, due to template
+    instantiation priority rules, it will only create an MmaMultistage with
+    kStages=3 (otherwise creates an MmaPipelined - which is not compatible with
+    FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
+    so we just copy-pasted some code from `default_mma.h` and
+    `default_mma_core.h` files and wrapped this template to allow our use case.
+
+    This is really only for the FastF32 case - aka using TensorCores with fp32.
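+
+    Illustrative use (the parameter values below are hypothetical, not taken
+    from this patch):
+
+      using Mma = typename cutlass::gemm::threadblock::FindDefaultMma<
+          float, cutlass::layout::RowMajor, 4,     // A: element, layout, alignment
+          float, cutlass::layout::ColumnMajor, 4,  // B: element, layout, alignment
+          float, cutlass::layout::RowMajor,        // accumulator element, C layout
+          cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+          cutlass::gemm::GemmShape<64, 64, 32>,    // threadblock tile
+          cutlass::gemm::GemmShape<32, 32, 32>,    // warp tile
+          cutlass::gemm::GemmShape<16, 8, 8>,      // tf32 tensor-core instruction
+          2,                                       // kStages = 2 (the whole point)
+          cutlass::arch::OpMultiplyAddFastF32>::DefaultMma::ThreadblockMma;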
+*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + AccumulatorsInRowMajor, + SharedMemoryClear>; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template < + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + int kStages, + typename Operator> +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = 
cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h b/src/neural/cuda/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 0000000000..ad2b7e02fb --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. 
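+
+All specializations expose the same interface:
+ - get_lane_offset(lane_id, warp_id, tile_offset) returns the (row, col) of the
+   first accumulator element owned by the calling thread;
+ - iterateRows(lane_offset, beginRow, op, endRow) visits every element of the
+   warp's accumulator fragment, calling op(accum_m, accum_n, idx) with the
+   logical matrix coordinates and the index into the fragment;
+ - reduceSameRow(lane_id, value, fn) combines value across the lanes that share
+   a row and returns true on the single lane that should publish the result.
+
+Illustrative sketch of a row-wise maximum over a float accumulator fragment
+(LambdaIterator, accum_frag, lane_offset and lane_id are placeholder names):
+
+  float row_max = -std::numeric_limits<float>::infinity();
+  LambdaIterator::iterateRows(
+      lane_offset,
+      [&](int accum_m) { row_max = -std::numeric_limits<float>::infinity(); },
+      [&](int accum_m, int accum_n, int idx) {
+        row_max = cutlass::fast_max(row_max, accum_frag[idx]);
+      },
+      [&](int accum_m) {
+        LambdaIterator::reduceSameRow(
+            lane_id, row_max,
+            [](float a, float b) { return cutlass::fast_max(a, b); });
+      });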
+*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) 
>> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert( + cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN 
+ + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kWarpSize> +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/src/neural/cuda/fused_multi_head_attention/gemm/mma_from_smem.h b/src/neural/cuda/fused_multi_head_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000..993af37a67 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm/mma_from_smem.h @@ -0,0 +1,2055 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + 
CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum value for K + int kMaxK, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a 
warp + int lane_idx) + : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // using size 1 is kind of a hack to get around arrays of zero-sized objects + // not being allowed. the compiler is probably smart enough to wipe it out + // anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
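+///
+/// Unlike the regular pipelined Mma, operand A is never read from global
+/// memory here: warp_tile_iterator_A_ loads it directly from an accumulator
+/// tile that an earlier GEMM left in shared memory (AccumulatorSharedStorage),
+/// optionally multiplying each fragment by a per-element scale tile when
+/// ScaleOperandA_ is set. Only operand B is streamed from global memory. This
+/// is what lets the fused attention kernel chain its two GEMMs without writing
+/// the intermediate result back to global memory.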
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + // END smem + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< + Shape_, + AccumulatorSharedStorage::Shape::kN, + Policy_, + 2> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape_, + AccumulatorSharedStorage::Shape::kN, + Policy_, + 2>; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
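+  /// FragmentAScaler::apply() converts the scale fragment to the A fragment's
+  /// element type and multiplies elementwise; when ScaleOperandA is false the
+  /// no-op specialization returns the fragment unchanged and compiles away.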
+ using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp iterator over A tile held in shared memory + WarpIteratorA warp_iter_a, + // warp iterator over A_scale tile held in shared memory + WarpIteratorAScale warp_iter_a_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(warp_iter_a), + warp_tile_iterator_A_scale_(warp_iter_a_scale), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp 
offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async tranfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
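+        // Two register fragments per operand are kept: while fragment
+        // [warp_mma_k % 2] feeds the current warp-level MMA, fragment
+        // [(warp_mma_k + 1) % 2] is prefetched for the next iteration.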
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma( + accum, + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
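+/// Multi-stage (sm80+) counterpart of the pipelined class above: tiles of
+/// operand B are staged into shared memory with cp.async across up to Stages_
+/// buffers, while operand A is again read from the accumulator tile kept in
+/// shared memory, with optional elementwise scaling. When the shared-memory
+/// buffers already cover the whole B extent (kSmemContainsEntireB) the
+/// mainloop issues no further global-memory loads.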
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory + : public MmaBaseFromSharedMemory { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? 
Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp level iterator over operand A tile kept in shared memory + WarpIteratorA1 warp_tile_iterator_A1, + // warp level iterator over operand A elementwise scale tile kept in + // shared memory. + WarpIteratorAScale warp_tile_iterator_A1_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(warp_tile_iterator_A1), + warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n) + : 
Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_( + accumulator_shared_storage.accum_ref(), + lane_idx), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue( + iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue( + IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + 
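+      // (both the global-memory read iterator and the shared-memory write
+      // iterator advance by one tile)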
iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
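+    // (this staging path is only taken for OpMultiplyAddFastF32 /
+    // OpMultiplyAddComplexFastF32; other math operators accumulate directly
+    // into accum and tmp_accum stays unused)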
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
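+          // (i.e. wait until at most kStages-2 cp.async groups are still in
+          // flight, so the stage about to be read is resident in shared
+          // memory)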
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + 
cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template < + typename Mma_, + typename AccumulatorSharedStorage, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA = false> +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + using RegularMma = MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< + WarpShape, + InstructionShape, + typename RegularMma::Operator::IteratorA, + Policy_>::WarpIterator; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + AccumulatorSharedStorage_, + 
IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory< + WarpShape, + InstructionShape, + typename RegularMma::Operator::IteratorA, + Policy_>::WarpIterator; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; + + static int constexpr kMaxK = kIsTransposedA + ? 
AccumulatorSharedStorage_::Shape::kM + : AccumulatorSharedStorage_::Shape::kN; + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + AccumulatorSharedStorage_, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. 
Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? + using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue( + minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 
1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< + WarpShape, + cutlass::gemm::GemmShape<32, 32, 4>, + scalar_t, + SmemAccumulatorLayout>; + + // // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + 16, + 32>, // typename SmemIteratorD0::TensorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + + using OutputLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + // Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * 
Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast( + ref_.data() + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + 
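+    // (position this warp's iterator at the tile given by tile_coords within
+    // the threadblock-wide accumulator kept in shared memory)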
ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/gemm_kernel_utils.h b/src/neural/cuda/fused_multi_head_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000..3fe57f0064 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/gemm_kernel_utils.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/arch/mma.h" + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + XFORMERS_CHECK( \ + false, \ + "Your device is too old. 
We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + XFORMERS_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") +#define XFORMERS_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + return false; \ + } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << #COND " failed\n"; \ + return false; \ + } +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
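+// DefaultGemmType picks the operator class, instruction shape and minimum
+// alignment from (ArchTag, scalar_t): Simt FMA is the fallback, f32 on sm80+
+// uses 16x8x8 TensorOp with 3xTF32 (OpMultiplyAddFastF32), 16-bit types on
+// sm75+ use 16x8x8 TensorOp, and f16 on Volta uses 8x8x4 TensorOp.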
+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? 
fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { + return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/src/neural/cuda/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/src/neural/cuda/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000..44f38dbcb8 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,752 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Epilogue iterator that supports prefetching + + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + 
///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + 
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < 
extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment 
from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + 
byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/iterators/make_residual_last.h b/src/neural/cuda/fused_multi_head_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000..e6b5d58a8a --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/iterators/make_residual_last.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000..b9c38cc338 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2115 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. 
The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : 
Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params const& params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset seperated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. 
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. 
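+  /// The vector index advances first, then the contiguous and strided
+  /// iteration indices; the byte pointer is bumped only on strided steps and
+  /// at tile boundaries (and not at all in Gather mode), and the predicate
+  /// mask is never recomputed here.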
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. 
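+/// (Thin wrapper around the pitch-linear specialization above: for
+/// column-major data the rows are the contiguous dimension, so extents and
+/// offsets are forwarded as (row, column) pitch-linear coordinates.)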
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : 
PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. 
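+/// (Here the columns are the contiguous dimension, so extents, threadblock
+/// offsets and tile offsets are swapped to (column, row) before being handed
+/// to the underlying pitch-linear iterator.)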
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} 
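+
+  // Illustrative usage sketch (not part of this patch; Shape, ThreadMap,
+  // AccessType and the loop variable are placeholders). Params is
+  // host-constructible and reusable; set_residual_tile(true) swaps in the
+  // residual-tile mask captured at construction, which callers typically
+  // enable only for the one possibly-partial tile:
+  //
+  //   using Iterator = PredicatedTileAccessIteratorResidualLast<
+  //       Shape, Element, cutlass::layout::RowMajor, 0, ThreadMap, AccessType>;
+  //   typename Iterator::Params params(layout);
+  //   Iterator it(params, ptr, extent, thread_id, tb_offset);
+  //   for (int k = k_iterations; k > 0; --k) {
+  //     it.set_residual_tile(k == 1);  // residual mask for the final tile only
+  //     /* load via it.get() / it.valid(), advance with ++it and
+  //        it.add_tile_offset(...) */
+  //   }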
+ + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. 
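+/// (Both strides contribute to every address here, so the Params struct below
+/// precomputes byte increments for the contiguous, strided and tile-advance
+/// steps instead of relying on a single pitch.)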
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const& params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
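+  /// (The predicate bookkeeping itself is shared with the pitch-linear
+  /// iterators through UnderlyingPredicates; only the pointer arithmetic in
+  /// this specialization is layout-specific.)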
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
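+  /// In this affine specialization every contiguous and strided step also
+  /// moves the pointer explicitly (inc_contiguous_, inc_next_strided_),
+  /// since the per-access offsets cannot be folded into a single pitch.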
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
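+//
+// Illustrative note: this specialization and the row-major one that follows are
+// thin adapters over the AffineRankN<2> iterator above.  The column-major form
+// forwards the strides unchanged as AffineRankN<2>(stride(0), stride(1)) and
+// maps (row, column) onto (contiguous, strided); the row-major form swaps both.
+// For example, a hypothetical 100 x 80 column-major matrix with strides
+// (1, 100) becomes pitch-linear extent (100, 80) over an AffineRankN<2>(1, 100)
+// layout.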
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. 
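+//
+// Illustrative note: the adapter below folds the interleaving factor into the
+// pitch-linear extent handed to the underlying iterator.  With a hypothetical
+// InterleavedK = 32 and a 128 x 64 matrix:
+//
+//   extent (128, 64)            -> PitchLinearCoord(128 * 32, 64 / 32) = (4096, 2)
+//   threadblock offset (64, 32) -> PitchLinearCoord(64 * 32, 32 / 32)  = (2048, 1)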
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent 
of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of 
tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000..4bb96a1395 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. 
+/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
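+//
+// A possible instantiation, shown only as a sketch: it assumes CUTLASS' stock
+// PitchLinearStripminedThreadMap, whereas the fused multi-head attention
+// kernels added by this change configure these iterators differently.
+//
+//   using Shape     = cutlass::layout::PitchLinearShape<64, 8>;
+//   using ThreadMap =
+//       cutlass::transform::PitchLinearStripminedThreadMap<Shape, 32, 4>;
+//   using Iterator  =
+//       cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
+//           Shape, cutlass::half_t, cutlass::layout::PitchLinear,
+//           /*AdvanceRank=*/1, ThreadMap>;
+//
+//   // Host side: build Iterator::Params once from the tensor layout, then pass
+//   // params, pointer, extent and thread id to the device-side constructor.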
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< 
ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s 
= 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
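+  // Illustrative note: like every other member of this ColumnMajor adapter, the
+  // post-increment below simply forwards to the underlying pitch-linear
+  // iterator; the layout handling was already resolved in the constructor,
+  // which mapped the extent (rows, columns) and the threadblock offset
+  // (row, column) onto pitch-linear (contiguous, strided) in that order.
+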
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
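+  // Illustrative note: relative to the ColumnMajor adapter above, the roles are
+  // swapped here: the constructor maps the extent (rows, columns) to
+  // pitch-linear (columns, rows), and the underlying iterator is instantiated
+  // with (kAdvanceRank == 0 ? 1 : 0) so that advancing along rows versus
+  // columns keeps its logical meaning for row-major data.
+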
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
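+//
+// Illustrative note: AffineRankN<2> places no unit-stride requirement on either
+// dimension; the element offset of logical coordinate (i, j) is
+// i * stride[0] + j * stride[1].  For example, with hypothetical strides (3, 7),
+// coordinate (2, 5) sits 2 * 3 + 5 * 7 = 41 elements (164 bytes for float) from
+// the base pointer, which is why this specialization keeps both strides in its
+// precomputed Params rather than assuming a contiguous fast dimension.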
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + 
params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
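+ ///
+ /// One plausible driving pattern, modeled on the pipelined GEMM mainloops
+ /// that consume this iterator (`kTiles`, `frag` and the consumer are assumed
+ /// placeholder names):
+ ///
+ ///   it.set_residual_tile(kTiles == 1);  // first tile may already be the partial one
+ ///   for (int k = 0; k < kTiles; ++k) {
+ ///     it.load(frag);
+ ///     ++it;
+ ///     it.set_residual_tile(k + 2 == kTiles);  // predicate the upcoming, last (partial) tile
+ ///     it.clear_mask(k + 1 == kTiles);         // disable further loads once past the end
+ ///     // ... consume `frag` ...
+ ///   }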
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
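+ ///
+ /// This row-major specialization simply swaps coordinates before delegating
+ /// to the AffineRankN<2> iterator: Params maps the layout as
+ /// AffineRankN<2>(stride(1), stride(0)), and an extent of (rows, columns)
+ /// becomes PitchLinearCoord(columns, rows). For example (illustrative values
+ /// only), extent (128, 64) with threadblock offset (64, 0) is presented to
+ /// the underlying iterator as extent (64, 128) and offset (0, 64).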
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
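+ ///
+ /// For this interleaved specialization the advance is delegated to the
+ /// underlying pitch-linear iterator, whose tile shape is
+ /// PitchLinearShape<Shape::kRow * kInterleavedK, Shape::kColumn / kInterleavedK>.
+ /// A worked example, assuming Shape = <64, 64> and kInterleavedK = 32
+ /// (illustrative values only):
+ ///
+ ///   underlying tile    : PitchLinearShape<64 * 32, 64 / 32> = <2048, 2>
+ ///   extent (128, 64)   : PitchLinearCoord(128 * 32, 64 / 32) = (4096, 2)
+ ///   tb offset (64, 64) : PitchLinearCoord(64 * 32, 64 / 32)  = (2048, 2)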
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/iterators/transpose_warp_iterator.h b/src/neural/cuda/fused_multi_head_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000..37c42ea238 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp::WarpIteratorFromSmem> { + using Iterator = + cutlass::gemm::warp::WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/src/neural/cuda/fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/src/neural/cuda/fused_multi_head_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000..37f416996c --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = cutlass::MatrixShape<16, 8>; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; + ++inst_m_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; + ++access_m_idx) { + int access_idx = access_m_idx + + kTilesPerInstruction * + (inner_idx + kAccessesInner * inst_m_idx); + + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + } else { + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
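+ ///
+ /// An illustrative per-warp driving loop (`smem_ref`, `frag_A`, `lane_id` and
+ /// the warp-level MMA call are assumed placeholder names):
+ ///
+ ///   using IterA = WarpIteratorFromSmem<Operand::kA, cutlass::half_t>;
+ ///   IterA iter_A(smem_ref, lane_id);  // TensorRef into the RowMajor smem tile
+ ///   typename IterA::Fragment frag_A;
+ ///   CUTLASS_PRAGMA_UNROLL
+ ///   for (int k = 0; k < IterA::kIterations; ++k) {
+ ///     iter_A.load(frag_A);  // ldmatrix from shared memory
+ ///     ++iter_A;             // next k-group; auto-advances the tile after the last one
+ ///     // mma(accum, frag_A, frag_B, accum);
+ ///   }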
+ CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[0], ref_.data() + ref_.offset(offset)); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/src/neural/cuda/fused_multi_head_attention/kernel_forward.h b/src/neural/cuda/fused_multi_head_attention/kernel_forward.h new file mode 100644 index 0000000000..0564bcefe5 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/kernel_forward.h @@ -0,0 +1,1236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#ifdef HAS_PYTORCH +#include +#include +#endif + +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "debug_utils.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/tile_smem_loader.h" + +#include + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSm() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock_, + bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock` + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true> +struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kSingleValueIteration_; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSm() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* causal_diagonal_ptr = nullptr; + int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; + + // Output tensors + output_t* output_ptr; // [num_queries, num_heads, head_dim_value] + output_accum_t* + output_accum_ptr; // [num_queries, num_heads, head_dim_value] + lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + + // Scale + accum_t scale; + + // Dimensions/strides + int32_t head_dim; + int32_t head_dim_value; + int32_t num_queries; + int32_t num_keys; + + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t bias_strideM = 0; + + int32_t o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int32_t bias_strideH = 0; + + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int32_t bias_strideB = 0; + + int32_t num_batches; + int32_t num_heads; + + // dropout + bool 
use_dropout; + unsigned long long dropout_batch_head_rng_offset; + float dropout_prob; +#ifdef HAS_PYTORCH + at::PhiloxCudaState rng_engine_inputs; +#endif + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + + int64_t q_start, k_start; + // Advance to current batch - in case of different sequence lengths + if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = seqstart_k_ptr[1]; + } + + num_queries = q_next_start - q_start; + num_keys = k_end - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (causal_diagonal_ptr) { + causal_diagonal_offset = causal_diagonal_ptr[batch_id]; + } + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset += num_keys - num_queries; + } + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + num_keys = cutlass::fast_min( + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries 
== 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that + // uses too much smem + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, 
kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + XFORMERS_CHECK( + p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + XFORMERS_CHECK( + p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + } + XFORMERS_CHECK( + p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK( + p.k_strideM % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK( + p.v_strideM % kAlignmentV == 0, "value is not correctly aligned"); + XFORMERS_CHECK( + p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK( + p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK( + p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); + XFORMERS_CHECK( + p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask, + "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + +#ifdef HAS_PYTORCH + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
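+      // Rough sketch of the mapping used here (matches the skipahead() call
+      // further below): the init already folds in
+      // dropout_batch_head_rng_offset, and each thread later advances by
+      // q * num_keys + k, so element (q, k) of this block's attention matrix
+      // always sees the same position of the Philox sequence.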
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &curand_state_init); + } +#endif + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1.mm, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_id(); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); 
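+        // For reference: BiasLoader::load() below (see
+        // transform/tile_smem_loader.h) stages this gmem tile into shared
+        // memory through a register fragment and ends with a __syncthreads(),
+        // after which the whole threadblock can read Bij from smem.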
+ cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + // last absolute col is (last absolute query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + p.num_keys - iter_key_start >= kKeysPerBlock, + kFullColumns, + ([&] { + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax< + typename MM0::Mma::Operator::IteratorC, + kFullColumns, + kIsFirst>( + accum_o, + accum, + mi, + m_prime, + s_prime, + lane_id(), + thread_id(), + warp_id(), + p.num_keys - iter_key_start, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + +#ifdef HAS_PYTORCH + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. 
its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } +#endif + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 
1 + : ceil_div( + (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv( + shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = call_conditional< + kIsLast, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename 
DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template < + typename WarpIteratorC, + bool kFullColumns, + bool kIsFirst> + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... 
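+              // (once every warp has done its atomicMax and the barrier below
+              // has run, mi[r] holds the running max of the scaled logits in
+              // row r; the later passes fold this new max into m_prime,
+              // s_prime and the exp() of the logits, i.e. the streaming
+              // softmax update described at the top of this function)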
+ atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x; + } + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.y; + } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/src/neural/cuda/fused_multi_head_attention/transform/tile_smem_loader.h b/src/neural/cuda/fused_multi_head_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000..345bc5bb68 --- /dev/null +++ b/src/neural/cuda/fused_multi_head_attention/transform/tile_smem_loader.h @@ -0,0 +1,88 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index d6e69dc980..5231f3ae66 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -149,5 +149,11 @@ void inputPreprocessForAttentionBody(T* output, const T* input, int N, template void applyInputGating(T* output, const T* input, const T* mult, const T* add, int N, int HW, int C, cudaStream_t stream); + } // namespace cudnn_backend } // namespace lczero + +// Work around to avoid "nvcc error : 'cudafe++' died with status 0xC0000409" error +// For some reason nvcc runs into this random error when trying to compile this function inside the namespaces +bool fusedMHA(void* output, void* mha_q, void* mha_k, void* mha_v, void* skip, + int batch_size, int num_heads, int depth, cudaStream_t stream); \ No newline at end of file diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index f839d27347..2987f52faf 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -38,7 +38,7 @@ namespace lczero { -#if 0 +#if 1 // debug code to dump allocation in GPU memory template void dumpTensor(T* memory, int elements, const char* message, 
bool only_summary = false) { @@ -75,9 +75,9 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary } if (!only_summary || i < 2 || i == elements - 1) { - // printf("%8.4f ", val); - // if ((i % 8) == 7) printf("\n"); - printf("%i;%.6f\n", i, val); + printf("%8.4f ", val); + if ((i % 8) == 7) printf("\n"); + //printf("%i;%.6f\n", i, val); } } free(temp); @@ -1444,7 +1444,7 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, - nullptr, 0, max_batch_size); // using alpha = 1 for now (TODO: may change?) + nullptr, 0, max_batch_size, false); // using alpha = 1 for now (TODO: may change?) encoder_weights_.emplace_back(pW); } } @@ -1452,8 +1452,12 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, template EncoderBlock::EncoderBlock( const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, - int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size) - : encoder_heads_(heads), embedding_op_size_(size), alpha_(alpha), + int size, float alpha, DataType* smolgen_global_scratch, + int smolgen_global_size, int max_batch_size, bool fused_mha) + : encoder_heads_(heads), + embedding_op_size_(size), + alpha_(alpha), + use_fused_mha_(fused_mha), has_smolgen_(cpu_weights.mha.has_smolgen), max_batch_size_(max_batch_size) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); @@ -1664,8 +1668,8 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, const int num_outputs = smol_global_size_; /* hwhw: 64 * 64 */ const int batch = N * encoder_heads_; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)smol_global, num_inputs, - scratch0, num_inputs, 0.0f, scratch3, num_outputs); + num_inputs, 1.0f, (const DataType*)smol_global, num_inputs, scratch0, num_inputs, + 0.0f, scratch3, num_outputs); } } @@ -1711,23 +1715,42 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // shape(k)[-1] = depth float factor = 1.0f / sqrt((float)depth); +#ifdef USE_CUTLASS + if (use_fused_mha_) { + // TODO: check if we need skip in a different tensor than same tensor as output! + bool success = + fusedMHA(scratch3, mha_q, mha_k, mha_v, + has_smolgen_ ? 
scratch3 : nullptr, N, + encoder_heads_, depth, stream); + + ReportCUDAErrors(cudaGetLastError()); + if (!success) throw Exception("Some error running fused MHA"); + } else +#endif // matmul_qk = tf.matmul(q, k, transpose_b=True) { if (scratch0 != offset_scratches_[stream]) { - std::vector offsets(encoder_heads_ * max_batch_size_*5); + std::vector offsets(encoder_heads_ * max_batch_size_ * 5); for (int i = 0; i < encoder_heads_ * max_batch_size_; i++) { int h = i % encoder_heads_; int n = i / encoder_heads_; offsets[i] = mha_k + h * depth + 64 * d_model * n; - offsets[i + encoder_heads_ * max_batch_size_] = mha_q + h * depth + 64 * d_model * n; - offsets[i + 2 * encoder_heads_ * max_batch_size_] = scratch2 + i * 64 * 64; - offsets[i + 3 * encoder_heads_ * max_batch_size_] = mha_v + h * depth + 64 * d_model * n; - offsets[i + 4 * encoder_heads_ * max_batch_size_] = scratch3 + h*depth + 64*d_model*n; + offsets[i + encoder_heads_ * max_batch_size_] = + mha_q + h * depth + 64 * d_model * n; + offsets[i + 2 * encoder_heads_ * max_batch_size_] = + scratch2 + i * 64 * 64; + offsets[i + 3 * encoder_heads_ * max_batch_size_] = + mha_v + h * depth + 64 * d_model * n; + offsets[i + 4 * encoder_heads_ * max_batch_size_] = + scratch3 + h * depth + 64 * d_model * n; } - ReportCUDAErrors(cudaMalloc((void**)&scratch_rel_ptrs_, encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*))); ReportCUDAErrors( - cudaMemcpy(scratch_rel_ptrs_, offsets.data(), encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*), - cudaMemcpyHostToDevice)); + cudaMalloc((void**)&scratch_rel_ptrs_, + encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*))); + ReportCUDAErrors( + cudaMemcpy(scratch_rel_ptrs_, offsets.data(), + encoder_heads_ * max_batch_size_ * 5 * sizeof(DataType*), + cudaMemcpyHostToDevice)); offset_scratches_[stream] = scratch0; } cublasXGemmBatched( @@ -1735,42 +1758,30 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, depth /*K*/, // A/B, and M/N are swapped for row-major to col-major // transform factor, // to handle "/ tf.math.sqrt(dk)" - scratch_rel_ptrs_,// mha_k + offset /*A*/, - d_model /*LDA*/, // (d_model = depth * encoder_heads_) to skip over - // other "depth" slices / heads - //64 * d_model, /*strideA*/ - scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_,//mha_q + offset /*B*/, - d_model /*LDB*/, // to skip over other other "depth" slices / heads - //64 * d_model, /*strideB*/ - 0.0f, - scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 2, //scratch2 + outOffset /*C*/, // output (matmul_qk) goes to scratch2 - 64 /*LDC*/, - //64 * 64 /*strideC*/, - N * encoder_heads_); - } - - // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) - // attention_weights -> scratch2 - if (has_smolgen_) { - // Add smolgen weights to the scaled matmul_qk attention logits before softmax. - Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, scratch3, stream); - } else { - Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, (const DataType*)nullptr, stream); - } + scratch_rel_ptrs_, d_model /*LDA*/, + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_, d_model /*LDB*/, + 0.0f, scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 2, + 64 /*LDC*/, N * encoder_heads_); + + // attention_weights = tf.nn.softmax(scaled_attention_logits, axis = -1) + // attention_weights -> scratch2 + if (has_smolgen_) { + // Add smolgen weights to the scaled matmul_qk attention logits before + // softmax. 
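+      // i.e. per head and per query row r (over the 64 key positions j):
+      //   w[r][j] = exp(logit[r][j] + smolgen[r][j]) /
+      //             sum_k exp(logit[r][k] + smolgen[r][k])
+      // where the smolgen term is the 64x64 matrix computed into scratch3
+      // above.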
+ Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, scratch3, + stream); + } else { + Softmax(encoder_heads_ * N * 64, 64, scratch2, scratch2, + (const DataType*)nullptr, stream); + } - { cublasXGemmBatched( cublas, CUBLAS_OP_N, CUBLAS_OP_N, depth /*M*/, 64 /*N*/, 64 /*K*/, 1.0f, - scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 3, //mha_v + offset /*A*/, // "v" matrix - d_model /*LDA*/, // to skip over other "depth" slices / heads - //64 * d_model, /*strideA*/ - scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 2, //scratch2 + weightsOffset /*B*/, - 64 /*LDB*/, //64 * 64, /*strideB*/ - 0.0f, - scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 4, //scratch3 + offset /*C*/, // output goes to scratch3 - d_model /*LDC*/, - //64 * d_model /*strideC*/, - N * encoder_heads_); + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 3, + d_model /*LDA*/, + scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 2, 64 /*LDB*/, + 0.0f, scratch_rel_ptrs_ + encoder_heads_ * max_batch_size_ * 4, + d_model /*LDC*/, N * encoder_heads_); } // #final dense layer (mha_dense), scratch3 -> scratch2 @@ -1984,7 +1995,8 @@ template AttentionBody::AttentionBody(const LegacyWeights& weights, void* scratch, ActivationFunction default_act, - int num_res_blocks, int input_c, int max_batch_size) + int num_res_blocks, int input_c, + int max_batch_size, bool fused_mha) : embedding_op_size_(weights.ip_emb_b.size()), encoder_head_count_(weights.encoder_head_count), num_resi_blocks_(num_res_blocks), @@ -1992,6 +2004,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, input_c_(input_c), has_gating_(weights.ip_mult_gate.size() > 0 && weights.ip_add_gate.size() > 0), has_smolgen_(weights.has_smolgen), + use_fused_mha_(fused_mha), BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); @@ -2012,7 +2025,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha, - smolgen_global_, smolgen_global_size_, max_batch_size); + smolgen_global_, smolgen_global_size_, max_batch_size, use_fused_mha_); encoder_weights_.emplace_back(pW); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index f74deb467e..0e440b9fcf 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -337,7 +337,7 @@ class EncoderBlock { public: EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, - int smolgen_global_size, int max_batch_size); + int smolgen_global_size, int max_batch_size, bool fused_mha); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, @@ -379,6 +379,7 @@ class EncoderBlock { float alpha_; // scale to apply to skip connection add const bool has_smolgen_; + const bool use_fused_mha_; // Output sizes for smolgen layers. 
int smol_compress_size_; @@ -469,7 +470,7 @@ class AttentionBody : public BaseLayer { public: AttentionBody(const LegacyWeights& weights, void* scratch, ActivationFunction default_act, int num_res_blocks, - int input_c, int max_batch_size); + int input_c, int max_batch_size, bool fused_mha); ~AttentionBody(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -490,6 +491,7 @@ class AttentionBody : public BaseLayer { int smolgen_global_size_; const bool has_gating_; const bool has_smolgen_; + const bool use_fused_mha_; }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 1b218f7006..e72b70d3b2 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -302,6 +302,11 @@ class CudaNetwork : public Network { use_res_block_winograd_fuse_opt_ = options.Get("res_block_fusing"); } + bool use_fused_mha = false; + if (deviceProp.major >= 8 && fp16) { + use_fused_mha = options.GetOrDefault("fused_mha", true); + } + const bool use_gemm_ex = deviceProp.major >= 5; // 0. Check for SE. @@ -415,7 +420,7 @@ class CudaNetwork : public Network { if (attn_body_) { auto attention_body = std::make_unique>( weights, scratch_mem_, act, numBlocks_, - numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_); + numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_, use_fused_mha); network_.emplace_back(std::move(attention_body)); encoder_last_ = getLastLayer(); From c62cf2d024329fd9d60c01c3e948e55764d8a6be Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Mon, 24 Apr 2023 17:11:49 +0530 Subject: [PATCH 25/70] first (somewhat) working INT8 attempt - only tries doing the KQV dense layers in int8. - Accuracy seems reasonable. - Right now quantization isn't fused, and de-quantization is done with bias add. - Both the above can be possibly be fused with more work. 
- Also need to attempt INT8 for other dense layers (MHA dense, FFN1 and FFN2) --- src/neural/cuda/cutlass_kernels.cu | 850 +++++++++++++++++++++++++++++ src/neural/cuda/kernels.h | 7 +- src/neural/cuda/layers.cc | 191 ++++++- src/neural/cuda/layers.h | 16 +- src/neural/cuda/network_cuda.cc | 92 +++- 5 files changed, 1132 insertions(+), 24 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index b4dc01ed87..94dbf22966 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -26,9 +26,15 @@ */ #include "cuda_common.h" +#include #ifdef USE_CUTLASS + +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_batched.h" + + // Fused MHA implementation from cutlass example #41 #include "fused_multi_head_attention/kernel_forward.h" @@ -121,4 +127,848 @@ bool fusedMHA(void* output, void* mha_q, void* mha_k, void* mha_v, void* skip, return fusedMHACutlass(output, mha_q, mha_k, mha_v, skip, batch_size, num_heads, depth, stream); } + + +namespace lczero { +namespace cudnn_backend { + +// function to calculate mean +static float mean(float arr[], int n) { + float sum = 0; + for (int i = 0; i < n; i++) { + sum += arr[i]; + } + return sum / n; +} + + +#include + +int fileNumber = 0; + +// function to calculate standard deviation +static float stdDev(float arr[], int n) { + float m = mean(arr, n); // get the mean + float var = 0; // initialize variance + for (int i = 0; i < n; i++) { + var += pow(arr[i] - m, 2); // add the squared difference from mean + } + var /= n; // divide by number of elements + return sqrt(var); // return the square root of variance +} + +float computeScaleFactor(const half* memory, int elements) { + std::vector fpArr(elements); + + void* temp = malloc(elements * sizeof(half)); + cudaMemcpy(temp, memory, elements * sizeof(half), cudaMemcpyDeviceToHost); + + float absmax = 0; + for (int i = 0; i < elements; i++) { + float val; + half* arr = (half*)temp; + val = (float)arr[i]; + fpArr[i] = val; + if (val == val) + absmax = std::max(absmax, fabs(val)); + else { + printf("\nNAN found!\n"); + exit(0); + } + } + + //float avg = mean(&fpArr[0], elements); + //float stddev = stdDev(&fpArr[0], elements); + + + // Ankan - for testing + char filename[100]; + sprintf(filename, "Mat_%d", fileNumber++); + FILE* fp; + fopen_s(&fp, filename, "wb+"); + fwrite(&fpArr[0], sizeof(float), elements, fp); + fclose(fp); + + // 4x standard deviation should be enough range ? 
+ // No, it seems including the outliers is important :-/ + // absmax = std::min(stddev * 4, absmax); + free(temp); + // printf(" absmax: %f ", absmax); + return 127.0f / absmax; +} + +// Helper fuction to do vector loads/stores +template +__device__ __forceinline__ void copyAs(void* dst, const void* src) { + *((T*)(dst)) = *((const T*)(src)); +} + + +// each thread processes 8 elements (=> 16 byte reads, 8 byte writes) +__global__ void convertFp16ToInt8(int8_t* output, const half* input, float scaleFactor, int N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int firstEl = tid * 8; + if (firstEl >= N) return; + + half ip[8]; + int8_t op[8]; + + copyAs(&ip[0], &input[firstEl]); + + for (int i = 0; i < 8; i++) { + float val = roundf((float)ip[i] * scaleFactor); + if (val > 127) val = 127; + if (val < -128) val = -128; + op[i] = (int8_t)(val); + } + /* + if (firstEl == 0) + printf("\nfrom kernel: input: %f, output: %f\n", (float)ip[0], + (float)op[0]); + */ + copyAs(&output[firstEl], &op[0]); +} + +__global__ void convertInt8ToFp16(half* output, const int8_t* input, + float scaleFactor, int N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int firstEl = tid * 8; + if (firstEl >= N) return; + + half op[8]; + int8_t ip[8]; + + copyAs(&ip[0], &input[firstEl]); + + for (int i = 0; i < 8; i++) op[i] = (half)(scaleFactor * (float)ip[i]); + + copyAs(&output[firstEl], &op[0]); +} + + +// debug code to dump allocation in GPU memory +template +void dumpTensor(const T* memory, int elements, const char* message, + bool only_summary = false, bool cpu_tensor = false) { + const bool fp16 = std::is_same::value; + const bool int8 = std::is_same::value; + printf("\n%s\n", message); + int elementSize = (int)(fp16 ? sizeof(half) : sizeof(float)); + if (int8) elementSize = sizeof(int8_t); + int bytes = elements * elementSize; + void* temp = (void*)memory; + if (!cpu_tensor) { + temp = malloc(bytes); + cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); + } + float maxval = -std::numeric_limits::max(); + float minval = std::numeric_limits::max(); + int nans = 0; + int nanss[10]{}; + + std::vector fpArr(elements); + for (int i = 0; i < elements; i++) { + float val; + if (int8) { + int8_t* arr = (int8_t*)temp; + val = (float)arr[i]; + } + else if (fp16) { + half* arr = (half*)temp; + val = (float)arr[i]; + } else { + float* arr = (float*)temp; + val = arr[i]; + } + fpArr[i] = val; + maxval = std::max(maxval, val); + minval = std::min(minval, val); + + if (std::isnan(val)) { + if (nans < 10) nanss[nans] = i; + nans++; + } + + if (!only_summary || i < 2 || i == elements - 1) { + printf("%8.4f ", val); + if ((i % 8) == 7) printf("\n"); + // printf("%i;%.6f\n", i, val); + } + } + if (!cpu_tensor) free(temp); + if (maxval == -std::numeric_limits::max()) + maxval = std::numeric_limits::quiet_NaN(); + if (minval == std::numeric_limits::max()) + minval = std::numeric_limits::quiet_NaN(); + + float avg = mean(&fpArr[0], elements); + float stddev = stdDev(&fpArr[0], elements); + printf("Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f, NaNs: %i of %i", + maxval, minval, avg, stddev, nans, elements); + if (nans > 0) { + printf("\nNaN indices: "); + for (int i = 0; i < nans && i < 10; i++) printf("%i ", nanss[i]); + if (nans > 10) printf("......"); + } + printf("\n"); +} + + + +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, + int M, int N, int K, int batchSize, + int AStride, int BStride, int OutStride, + float alphaf, float betaf) { + //dumpTensor(A, 512, "A after scaling", 
false); + //dumpTensor(B, 512, "B after scaling", false); + + using ElementAccumulator = int32_t; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + using ElementOutput = + int8_t; // <- data type of elements in output matrix Out + + // TODO: figure out why row major for matrix B doesn't work?!!! + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm80; + + using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 128>; + using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 128>; + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 32>; + + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementComputeEpilogue>; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::GemmBatched< + ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, + LayoutOutput, ElementAccumulator, MMAOp, SmArch, ShapeMMAThreadBlock, + ShapeMMAWarp, ShapeMMAOp, EpilogueOp, SwizzleThreadBlock, NumStages>; + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(alphaf); + ElementComputeEpilogue beta = ElementComputeEpilogue(betaf); + + typename Gemm::Arguments arguments{ + {M, N, K}, {A, K}, AStride, {B, K}, BStride, {Out, N}, + OutStride, {Out, N}, OutStride, {alpha, beta}, batchSize}; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(arguments, nullptr); + status = gemm_op(); +} + + +static void calibrateGemm(int8_t* weights_int8, + float* input_scaling_factors, + float* output_scaling_factors, + float* cpuA, float* cpuB, int M, + int N, int K, int batchSize) { + std::vector scaling_factors(K); + + // apply smooth-quant (basically adjust A and B matrices to make quantization + // easier) + for (int k = 0; k < K; k++) { + float absMaxA = 0; + float absMaxB = 0; + // scan a column of Matrix A to find the abs max. + for (int y = 0; y < M; y++) { + float val = cpuA[y * K + k]; + absMaxA = std::max(absMaxA, abs(val)); + } + + // scan a column of Matrix B (from each batch dimension) + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) { + float val = cpuB[b * N * K + x * K + k]; + absMaxB = std::max(absMaxB, abs(val)); + } + + // compute scaling factor: + float s = sqrt(absMaxA / (absMaxB)); + + // sanity check, don't use too small, or too big scaling factors + if (s < 1) + s = 1.0f; // don't try to squeeze activations for improving range of + // weights! 
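+    // e.g. absMaxA = 8 and absMaxB = 0.5 give s = sqrt(8 / 0.5) = 4: the
+    // activation column shrinks to an absmax of ~2 while the weight column
+    // grows to ~2, so both sides use the int8 range more evenly.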
+ if (s > 10) s = 10.0f; + + scaling_factors[k] = s; + + // printf("\nMaxA: %f, MaxB: %f, scale: %f ", absMaxA, absMaxB, s); + + // scale A and B matrices using the scaling factor + for (int y = 0; y < M; y++) { + float val = cpuA[y * K + k]; + val /= s; + cpuA[y * K + k] = (half)val; + } + + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) { + float val = cpuB[b * N * K + x * K + k]; + val *= s; + cpuB[b * N * K + x * K + k] = (half)val; + } + } + + // figure out scaling factors for A and B matrices + float absMaxA = 0; + for (int i = 0; i < M * K; i++) { + float val = cpuA[i]; + absMaxA = std::max(absMaxA, abs(val)); + } + + float AFactor = 127.0 / absMaxA; + + // update the scaling factors based on global max for Activation matrix + for (int i = 0; i < K; i++) { + input_scaling_factors[i] = 127.0f / (scaling_factors[i] * absMaxA); + } + + std::vector BFactor(batchSize); + for (int b = 0; b < batchSize; b++) { + float absMaxB = 0; + for (int i = 0; i < K * N; i++) { + float val = cpuB[i + b * K * N]; + absMaxB = std::max(absMaxB, abs(val)); + } + + // quantize the weights + float scaleB = 127.0f / absMaxB; + BFactor[b] = scaleB; + for (int i = 0; i < K * N; i++) { + float val = cpuB[i + b * K * N]; + // quantize and clamp + val = (val * scaleB); + if (val > 127) val = 127; + if (val < -128) val = -128; + weights_int8[i + b * K * N] = (int8_t)roundf(val); + } + } + + // output scaling factors + for (int i = 0; i < batchSize; i++) + output_scaling_factors[i] = 127.0 / (AFactor * BFactor[i]); + + // Ankan - for debug/test + //printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", + // 127.0 / absMaxA, BFactor[0], BFactor[1], BFactor[2]); + + +} + + +// Same Activation (A) matrix (M x K) is multiplied by batchSize x B matrices / +// weights (K x N transposed) The outputs are: +// 1. quantized weight matrices (weights_int8) +// 2. "per-column" scaling factors (input_scaling_factors) needed to quantize +// matrix A +// 3. 
Scaling factors to dequantize the output matrix (just 3 values: factorQ, factorK, factorV) +// M_Batch is the batch size component in "M" dimension +// maxValuesA contains the max values in activation matrix found so far +template +void calibrateGemmForInt8(int8_t* weights_int8, + float* input_scaling_factors, + float* output_scaling_factors, + float* maxValuesA, + const DataType* A, const DataType* B, int M, + int N, int K, int batchSize, int M_Batch) { + + auto cpuA = (DataType*)malloc(M_Batch * M * K * sizeof(DataType)); + auto cpuB = (DataType*)malloc(batchSize * K * N * sizeof(DataType)); + + ReportCUDAErrors(cudaMemcpy(cpuA, A, M_Batch * M * K * sizeof(DataType), + cudaMemcpyDeviceToHost)); + ReportCUDAErrors(cudaMemcpy(cpuB, B, batchSize * K * N * sizeof(DataType), + cudaMemcpyDeviceToHost)); + + // convert to FP32 (if not already in fp32, and pick one Activation matrix at a time) + auto fpA = (float*)malloc(M * K * sizeof(float)); + auto fpB = (float*)malloc(batchSize * K * N * sizeof(float)); + + for (int i = 0; i < K * N * batchSize; i++) + fpB[i] = (float)cpuB[i]; + + + for (int b = 0; b < M_Batch; b++) { + for (int i = 0; i < M * K; i++) { + float val = abs((float)cpuA[b * M * K + i]); + val = std::max(val, maxValuesA[i]); + fpA[i] = val; + maxValuesA[i] = val; // update the max activation matrix + } + + // calibrate a single sample + calibrateGemm(weights_int8, input_scaling_factors, + output_scaling_factors, fpA, fpB, M, N, K, + batchSize); + } + + free(fpA); + free(fpB); + free(cpuA); + free(cpuB); + +} + + +void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, + int N, int K, int batchSize, int AStride, + int BStride, int OutStride, + bool useInt8 = true); + +// Test routine that does the following +// Figures out the range of values of the inputs (A and B) +// computes scaling factors for both A and B. +// convert fp16->int8 using the scaling factors. +// run int8 GEMM +// convert the result back from int8 -> fp16 +// Decide on: +// 1. Which scaling factor to use for the GEMM +// 2. which scaling factor to use for the post-process pass ? +void cutlassMatrixMulBTransposedWithInt8(const half* A, const half* B, half* Out, int M, + int N, int K, int batchSize, int AStride, + int BStride, int OutStride) { + bool useSmooth = true; + half *smoothA, *smoothB; + + if (!useSmooth) { + smoothA = (half*)A; + smoothB = (half*)B; + } else { + // SmoothQuant test: + // pre-process both A and B matrices to get best out of int8 matmul + // In practice + // * this post processing of A need to be fused with previous layer (once + // we know per-channel scaling factors which is computed offline) + // * post processing of B can be done one time / offline. + + ReportCUDAErrors(cudaMalloc(&smoothA, M * K * sizeof(half))); + ReportCUDAErrors(cudaMalloc(&smoothB, batchSize * K * N * sizeof(half))); + + // TODO: assumption on A's stride (to be 0), and B to be packed. + half* cpuA = (half*)malloc(M * K * sizeof(half)); + half* cpuB = (half*)malloc(batchSize * K * N * sizeof(half)); + + ReportCUDAErrors(cudaMemcpy(cpuA, A, M * K * sizeof(half), cudaMemcpyDeviceToHost)); + ReportCUDAErrors(cudaMemcpy(cpuB, B, batchSize * K * N * sizeof(half), + cudaMemcpyDeviceToHost)); + + for (int k = 0; k < K; k++) { + float absMaxA = 0; + float absMaxB = 0; + // scan a column of Matrix A to find the abs max. 
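+      // (this repeats the per-input-channel balancing of calibrateGemm()
+      // above: outlier activation channels hand part of their dynamic range
+      // over to the matching weight rows before either side is quantized)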
+ for (int y = 0; y < M; y++) { + float val = (float)cpuA[y * K + k]; + absMaxA = std::max(absMaxA, abs(val)); + } + + // scan a column of Matrix B (from each batch dimension) + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) { + float val = (float)cpuB[b * BStride + x * K + k]; + absMaxB = std::max(absMaxB, abs(val)); + } + + // compute scaling factor: + float s = sqrt(absMaxA / (absMaxB)); + + // sanity check, don't use too small, or too big scaling factors + if (s < 1) s = 1.0f; // don't try to squeeze activations for improving range of weights! + if (s > 10) s = 10.0f; + + // printf("\nMaxA: %f, MaxB: %f, scale: %f ", absMaxA, absMaxB, s); + + // scale A and B matrices using the scaling factor + for (int y = 0; y < M; y++) { + float val = (float)cpuA[y * K + k]; + val /= s; + cpuA[y * K + k] = (half)val; + } + + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) { + float val = (float)cpuB[b * BStride + x * K + k]; + val *= s; + cpuB[b * BStride + x * K + k] = (half)val; + } + } + + ReportCUDAErrors(cudaMemcpy(smoothA, cpuA, M * K * sizeof(half), cudaMemcpyHostToDevice)); + ReportCUDAErrors(cudaMemcpy(smoothB, cpuB, batchSize * K * N * sizeof(half), + cudaMemcpyHostToDevice)); + + free(cpuA); + free(cpuB); + ReportCUDAErrors(cudaGetLastError()); + } + + // do the actual multiplication so that we can compute OutFactor + /* + cutlassMatrixMulBTransposed(A, B, Out, M, N, K, batchSize, + AStride, + BStride, + OutStride, false); + + ReportCUDAErrors(cudaGetLastError()); + */ + + // Ankan - test! + // copy the output to CPU and check the range of values per "channel" + /* + { + half* cpuOut = (half*)malloc(batchSize * M * N * sizeof(half)); + cudaMemcpy(cpuOut, Out, M * N * sizeof(half) * batchSize, + cudaMemcpyDeviceToHost); + + for (int n = 0; n < N; n++) { + + float absMax = 0; + // scan a column of Matrix A to find the abs max. + for (int m = 0; m < M; m++) { + float val = (float)cpuOut[M*N*2 + n + m * M]; + absMax = std::max(absMax, val); + } + + printf("MaxOut: %f\n", absMax); + } + + free(cpuOut); + } + */ + + + //printf("\nRows: %d, cols: %d\n", M, K); // Ankan - test! + float AFactor = computeScaleFactor(smoothA, M * K); + + //printf("\nRows: %d, cols: %d\n", K, N * batchSize); // Ankan - test! 
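+  // Dequantization sketch for the steps below: with A quantized by AFactor
+  // and each B slice by its own BFactorQ/K/V, the int32 accumulator is
+  // roughly
+  //   acc ~= (A * AFactor) . (B * BFactor)
+  // The GEMM runs with alpha = 1/opScale so the result fits back into int8,
+  // and the fp16 output is then recovered as
+  //   out8 * opScale / (AFactor * BFactor).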
+ //float BFactor = computeScaleFactor(smoothB, N * K * batchSize); + + int offsetQ = N * K * 0; + int offsetK = N * K * 1; + int offsetV = N * K * 2; + + float BFactorQ = computeScaleFactor(smoothB + offsetQ, N * K); + float BFactorK = computeScaleFactor(smoothB + offsetK, N * K); + float BFactorV = computeScaleFactor(smoothB + offsetV, N * K); + + + // float OutFactor = computeScaleFactor(Out, M * N * batchSize); + + float opScale = 127; + //opScale = 1.5f * (AFactor * BFactor) / OutFactor; + + printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", AFactor, + BFactorQ, BFactorK, BFactorV); + + //dumpTensor(A, /*M * K*/512, "A before scaling", false); + //dumpTensor(B, /*N * K * batchSize*/ 512, "B before scaling", false); + + int8_t* a8; + int8_t* b8; + int8_t* out8; + + ReportCUDAErrors(cudaMalloc(&a8, M * K)); + ReportCUDAErrors(cudaMalloc(&b8, N * K * batchSize)); + ReportCUDAErrors(cudaMalloc(&out8, OutStride * batchSize * sizeof(half))); + + // Convert A matrix + { + int numBlocks = lczero::cudnn_backend::DivUp(M * K, 256 * 8); + convertFp16ToInt8<<>>(a8, smoothA, AFactor, M * K); + ReportCUDAErrors(cudaGetLastError()); + } + + // Convert B matrix + { + //int numBlocks = lczero::cudnn_backend::DivUp(N * K * batchSize, 256 * 8); + //convertFp16ToInt8<<>>(b8, smoothB, BFactor, + // N * K * batchSize); + + int numBlocks = lczero::cudnn_backend::DivUp(N * K, 256 * 8); + + convertFp16ToInt8<<>>(b8 + offsetQ, smoothB + offsetQ, + BFactorQ, N * K); + convertFp16ToInt8<<>>(b8 + offsetK, smoothB + offsetK, + BFactorK, N * K); + convertFp16ToInt8<<>>(b8 + offsetV, smoothB + offsetV, + BFactorV, N * K); + + ReportCUDAErrors(cudaGetLastError()); + } + + //dumpTensor(a8, 512, "A after scaling", false); + //dumpTensor(b8, 512, "B after scaling", false); + + // run the inmt8 GEMM (1/127 scaling factor to compensate for the multiplications) + cutlassMatrixMulBTransposed(a8, b8, out8, M, N, K, batchSize, AStride, + BStride, OutStride, 1 / opScale, 0.0f); + + ReportCUDAErrors(cudaGetLastError()); + + //dumpTensor(out8, 512, "output of int8 gemm", false); + + // convert out matrix + { + int offsetOutQ = M * N * 0; + int offsetOutK = M * N * 1; + int offsetOutV = M * N * 2; + + float factorQ = opScale / (AFactor * BFactorQ); + float factorK = opScale / (AFactor * BFactorK); + float factorV = opScale / (AFactor * BFactorV); + int numBlocks = lczero::cudnn_backend::DivUp(OutStride, 256 * 8); + convertInt8ToFp16<<>>(Out + offsetOutQ, out8 + offsetOutQ, + factorQ, OutStride); + convertInt8ToFp16<<>>(Out + offsetOutK, out8 + offsetOutK, + factorK, OutStride); + convertInt8ToFp16<<>>(Out + offsetOutV, out8 + offsetOutV, + factorV, OutStride); + ReportCUDAErrors(cudaGetLastError()); + } + + //dumpTensor(Out, 512, "output after scaling back", false); + ReportCUDAErrors(cudaFree(a8)); + ReportCUDAErrors(cudaFree(b8)); + ReportCUDAErrors(cudaFree(out8)); + + if (useSmooth) { + ReportCUDAErrors(cudaFree(smoothA)); + ReportCUDAErrors(cudaFree(smoothB)); + } +} + +int gCounter = 0; + +void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, + int N, int K, int batchSize, int AStride, int BStride, int OutStride, bool useInt8) { + + gCounter++; + + // Ankan - test! 
+ //if (useInt8 && gCounter==5) + if (useInt8) + return cutlassMatrixMulBTransposedWithInt8(A, B, Out, M, N, K, batchSize, + AStride, BStride, OutStride); + + half halfOne = (half)1.0f; + half halfZero = (half)0.0f; + + using ElementAccumulator = cutlass::half_t; // <- data type of accumulator + using ElementComputeEpilogue = + ElementAccumulator; // <- data type of epilogue operations + using ElementInputA = + cutlass::half_t; // <- data type of elements in input matrix A + using ElementInputB = + cutlass::half_t; // <- data type of elements in input matrix B + using ElementOutput = + cutlass::half_t; // <- data type of elements in output matrix D + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm80; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N + // = 128, K = 32 + // This code section describes tile size a warp will compute + using ShapeMMAWarp = + cutlass::gemm::GemmShape<32, 64, + 32>; // <- warp tile M = 64, N = 64, K = 32 + // This code section describes the size of MMA op + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = + // 8, N = 8, K = 4 + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; // <- + // ?? + + // This code section describes ? + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. 
This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + float>; // <- data type for alpha/beta in linear combination function + + constexpr int NumStages = 3; // stages == 2/4 is also good sometimes + + using Gemm = cutlass::gemm::device::GemmBatched< + ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, + LayoutOutput, ElementAccumulator, MMAOp, SmArch, ShapeMMAThreadBlock, + ShapeMMAWarp, ShapeMMAOp, EpilogueOp, SwizzleThreadBlock, NumStages>; + + Gemm gemm_op; + + cutlass::Status status = gemm_op({{M, N, K}, + {(cutlass::half_t const*)A, K}, + AStride, + {(cutlass::half_t const*)B, K}, + BStride, + {(cutlass::half_t const*)Out, N}, + OutStride, + {(cutlass::half_t*)Out, N}, + OutStride, + {halfOne, halfZero}, + batchSize}); +} + +// process 8 elements per thread (in x dimension) +__global__ void quantizeMatrix(int8_t* output, const half* input, int height, + int width, const float* scale) { + int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + float factor[8]; + half ip[8]; + int8_t op[8]; + + copyAs(&ip[0], &input[y * width + x]); + copyAs(&factor[0], &scale[x]); + copyAs(&factor[4], &scale[x+4]); + + for (int i = 0; i < 8; i++) { + float val = roundf((float)ip[i] * factor[i]); + if (val > 127) val = 127; + if (val < -128) val = -128; + op[i] = (int8_t)(val); + } + + copyAs(&output[y * width + x], &op[0]); +} + + +// The scale is per column +void quantizeActivationMatrix(int8_t* output, const half* input, int height, + int width, const float* scale, cudaStream_t stream) { + + dim3 blockDim(16, 16); + dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), + lczero::cudnn_backend::DivUp(height, 16)); + quantizeMatrix<<>>(output, input, height, width, + scale); + ReportCUDAErrors(cudaGetLastError()); +} + + +#define MAX_BATCH_DEQUANT 16 + +struct ScaleParam { + float scale[MAX_BATCH_DEQUANT]; +}; + +// process 8 elements per thread (in x dimension) +__global__ void deQuantizeMatrix(half* output, const int8_t* input, const half *bias, int height, int width, int stride, ScaleParam s) { + int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; + int y = blockIdx.y * blockDim.y + threadIdx.y; + int b = blockIdx.z; + + if (x >= width || y >= height) return; + + float factor = s.scale[b]; + + int8_t ip[8] = {}; + half op[8] = {}; + half bi[8] = {}; + + copyAs(&ip[0], &input[b * stride + y * width + x]); + copyAs(&bi[0], &bias[b * width + x]); + + for (int i = 0; i < 8; i++) { + float val = (float)ip[i]; + val *= factor; + val += (float)bi[i]; + op[i] = (half) val; + } + + copyAs(&output[b * stride + y * width + x], &op[0]); +} + + + +// the scale (in CPU memory) is per "batch" +// the bias is per column, per batch +void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, + int height, int width, int batchSize, + float* scale, const half* bias, + cudaStream_t stream) { + dim3 blockDim(16, 16); + dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), + lczero::cudnn_backend::DivUp(height, 16), batchSize); + + assert(batchSize < MAX_BATCH_DEQUANT); // otherwise we will need to put them in GPU memory + + int stride = width * height; + + ScaleParam s = {}; + for (int i = 0; i < batchSize; i++) s.scale[i] = scale[i]; + + deQuantizeMatrix<<>>(output, input, bias, height, width, stride, s); + ReportCUDAErrors(cudaGetLastError()); + +} + + + + +template void calibrateGemmForInt8(int8_t* 
weights_int8, + float* input_scaling_factors, + float* output_scaling_factors, + float* maxValuesA, const float* A, + const float* B, int M, int N, int K, + int batchSize, int M_Batch); +template void calibrateGemmForInt8(int8_t* weights_int8, + float* input_scaling_factors, + float* output_scaling_factors, + float* maxValuesA, const half* A, + const half* B, int M, int N, int K, + int batchSize, int M_Batch); + + +template void dumpTensor(const float* memory, int elements, + const char* message, bool only_summary, + bool cpu_tensor); + +template void dumpTensor(const half* memory, int elements, + const char* message, bool only_summary, + bool cpu_tensor); + +template void dumpTensor(const int8_t* memory, int elements, + const char* message, bool only_summary, + bool cpu_tensor); + + +}; // namespace cudnn_backend +}; // namespace lczero #endif diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 5231f3ae66..0b48ec4896 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -150,10 +150,15 @@ template void applyInputGating(T* output, const T* input, const T* mult, const T* add, int N, int HW, int C, cudaStream_t stream); +void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, + int N, int K, int batchSize, int AStride, + int BStride, int OutStride, bool useInt8 = true); + } // namespace cudnn_backend } // namespace lczero // Work around to avoid "nvcc error : 'cudafe++' died with status 0xC0000409" error // For some reason nvcc runs into this random error when trying to compile this function inside the namespaces bool fusedMHA(void* output, void* mha_q, void* mha_k, void* mha_v, void* skip, - int batch_size, int num_heads, int depth, cudaStream_t stream); \ No newline at end of file + int batch_size, int num_heads, int depth, cudaStream_t stream); + diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 2987f52faf..bdea75d6b2 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -39,6 +39,30 @@ namespace lczero { #if 1 + +#include +using namespace std; + +// function to calculate mean +double mean(float arr[], int n) { + float sum = 0; + for (int i = 0; i < n; i++) { + sum += arr[i]; + } + return sum / n; +} + +// function to calculate standard deviation +float stdDev(float arr[], int n) { + float m = mean(arr, n); // get the mean + float var = 0; // initialize variance + for (int i = 0; i < n; i++) { + var += pow(arr[i] - m, 2); // add the squared difference from mean + } + var /= n; // divide by number of elements + return sqrt(var); // return the square root of variance +} + // debug code to dump allocation in GPU memory template void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { @@ -53,6 +77,7 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary int nans = 0; int nanss[10] {}; + std::vector fpArr(elements); for (int i = 0; i < elements; i++) { float val; @@ -66,6 +91,7 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary float *arr = (float *)temp; val = arr[i]; } + fpArr[i] = val; maxval = std::max(maxval, val); minval = std::min(minval, val); @@ -86,10 +112,15 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary if (minval == std::numeric_limits::max()) minval = std::numeric_limits::quiet_NaN(); - printf("Max: %.6f, Min: %.6f, NaNs: %i of %i", maxval, minval, nans, elements); - printf("\nNaN indices: "); - for (int i=0; i 10) printf("......"); + + float 
avg = mean(&fpArr[0], elements); + float stddev = stdDev(&fpArr[0], elements); + printf("Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f, NaNs: %i of %i", maxval, minval, avg, stddev, nans, elements); + if (nans > 0) { + printf("\nNaN indices: "); + for (int i = 0; i < nans && i < 10; i++) printf("%i ", nanss[i]); + if (nans > 10) printf("......"); + } printf("\n"); } #endif @@ -1444,7 +1475,7 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, - nullptr, 0, max_batch_size, false); // using alpha = 1 for now (TODO: may change?) + nullptr, 0, max_batch_size, false, false, false, nullptr, 0); // using alpha = 1 for now (TODO: may change?) encoder_weights_.emplace_back(pW); } } @@ -1453,11 +1484,15 @@ template EncoderBlock::EncoderBlock( const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, - int smolgen_global_size, int max_batch_size, bool fused_mha) + int smolgen_global_size, int max_batch_size, bool fused_mha, + bool int8_calibrate, bool int8_inference, void* int8_weights, + int blockIndex) : encoder_heads_(heads), embedding_op_size_(size), alpha_(alpha), use_fused_mha_(fused_mha), + int8_inf_(int8_inference), + int8_cali_(int8_calibrate), has_smolgen_(cpu_weights.mha.has_smolgen), max_batch_size_(max_batch_size) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); @@ -1533,6 +1568,51 @@ EncoderBlock::EncoderBlock( // GPU memory already allocated in AttentionBody. smol_global = smolgen_global_scratch; + + // int8 stuff + int per_encoder_size = embedding_op_size_ * sizeof(float) + + 3 * embedding_op_size_ * mha_q_size_ * sizeof(int8_t) + + 3 * sizeof(float); + auto w = (int8_t*)int8_weights; + // go to current encoder block + w += per_encoder_size * blockIndex; + + if (int8_inference) { + ReportCUDAErrors(cudaMalloc(&input_scaling_factors_, + sizeof(float) * embedding_op_size_)); + ReportCUDAErrors(cudaMemcpy(input_scaling_factors_, w, + sizeof(float) * embedding_op_size_, + cudaMemcpyHostToDevice)); + + //printf("\nCopied input scaling factors for index: %d\n", blockIndex); + + // go to int8 kqv weights + w += embedding_op_size_ * sizeof(float); + size_t elements = cpu_weights.mha.q_w.size(); + size_t size = elements * sizeof(int8_t) * 3; + ReportCUDAErrors(cudaMalloc(&kqv_int8_, size)); + ReportCUDAErrors(cudaMemcpy(kqv_int8_, w, size, cudaMemcpyHostToDevice)); + + //printf("\nCopied int8 weights for index: %d, size: %d\n", blockIndex, + // size); + + // go to output scaling factors + w += size; + output_scaling_factors_ = (float*)w; + } else if (int8_calibrate) { + // just save the pointers (we will over-write here during calibration) + input_scaling_factors_ = (float*)w; + w += embedding_op_size_ * sizeof(float); + kqv_int8_ = w; + w += 3 * cpu_weights.mha.q_w.size() * sizeof(int8_t); + output_scaling_factors_ = (float*)w; + + // to keep track of max values in input activation matrix + input_matrix_max_values_ = + (float*)malloc(64 * embedding_op_size_ * sizeof(float)); + memset(input_matrix_max_values_, 0, + 64 * embedding_op_size_ * sizeof(float)); + } } } @@ -1601,6 +1681,27 @@ static void cublasXGemmBatched( } } + +template +void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, + float* output_scaling_factors, float* maxValuesA, + const DataType* A, const DataType* B, int M, int N, + int K, int batchSize, 
int M_Batch); + +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, + int M, int N, int K, int batchSize, + int AStride, int BStride, int OutStride, + float alphaf, float betaf); + +void quantizeActivationMatrix(int8_t* output, const half* input, int height, + int width, const float* scale, + cudaStream_t stream); + +void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, + int height, int width, int batchSize, + float* scale, const half* bias, + cudaStream_t stream); + // input/output tensor is scratch1, others are used as scratch. // TODO: fix naming of scratch buffers template @@ -1678,22 +1779,65 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, DataType* mha_k; DataType* mha_v; + //dumpTensor(scratch1, embedding_op_size_ * 64 * N, "input to mha_kqv gemm", true); + //dumpTensor(mha_qkv_w, embedding_op_size_ * d_model * 3, "weights to mha_kqv gemm", + // true); + //exit(0); + { const int num_inputs = embedding_op_size_; const int num_outputs = d_model; const int batch = N * 64; const int max_batch = max_batch_size_ * 64; + const int batch_to_use = use_fused_mha_ ? batch : max_batch; // The array of GPU pointers assume max batch mha_q = scratch0; - mha_k = mha_q + num_outputs * max_batch; - mha_v = mha_k + num_outputs * max_batch; + mha_k = mha_q + num_outputs * batch_to_use; + mha_v = mha_k + num_outputs * batch_to_use; - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, - mha_qkv_w, num_inputs, num_inputs * num_outputs, scratch1, - num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * max_batch, 3); - addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, max_batch, - NONE, stream); + if (int8_cali_) { + calibrateGemmForInt8(kqv_int8_, input_scaling_factors_, output_scaling_factors_, + input_matrix_max_values_, scratch1, mha_qkv_w, 64, + d_model, embedding_op_size_, 3, N); + } + + + if (int8_inf_) { + // printf("\nAttempting int8_inf\n"); + // 1. quantize the inputs (scratch1 -> scratch0) + quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch1, batch, + embedding_op_size_, input_scaling_factors_, + stream); + + // 2. perform int8 GEMM (scratch0 -> scratch2) + cutlassMatrixMulBTransposed((const int8_t*)scratch0, kqv_int8_, + (int8_t*)scratch2, batch, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + + // 3. 
de-quantize outputs - fused with bias add (scratch2 -> scratch0) + deQuantizeOutputMatrixBiasAdd((half*)scratch0, (const int8_t*)scratch2, batch, num_outputs, 3, + output_scaling_factors_, (const half*)mha_qkv_b, stream); + } else { + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, + 1.0f, mha_qkv_w, num_inputs, num_inputs * num_outputs, scratch1, + num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, + 3); + /* + cutlassMatrixMulBTransposed((const half*)scratch1, (const half*)mha_qkv_w, + (half*) mha_q, batch, + num_outputs, num_inputs, 3, + 0, num_inputs * num_outputs, + num_outputs * batch_to_use); + */ + // dumpTensor(mha_q, num_outputs * N, "output of kqv gemm", false); + // exit(0); + + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, + batch_to_use, NONE, stream); + } } // Apply split_heads() to q, k and v @@ -1955,6 +2099,12 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(smol_ln2_gammas)); ReportCUDAErrors(cudaFree(smol_ln2_betas)); } + if (int8_inf_) { + ReportCUDAErrors(cudaFree(kqv_int8_)); + ReportCUDAErrors(cudaFree(input_scaling_factors_)); + } else if (int8_cali_) { + free(input_matrix_max_values_); + } } @@ -1992,11 +2142,10 @@ void EmbeddingLayer::Eval( } template -AttentionBody::AttentionBody(const LegacyWeights& weights, - void* scratch, - ActivationFunction default_act, - int num_res_blocks, int input_c, - int max_batch_size, bool fused_mha) +AttentionBody::AttentionBody( + const LegacyWeights& weights, void* scratch, ActivationFunction default_act, + int num_res_blocks, int input_c, int max_batch_size, bool fused_mha, + bool int8_calibrate, bool int8_inference, void* int8_weights) : embedding_op_size_(weights.ip_emb_b.size()), encoder_head_count_(weights.encoder_head_count), num_resi_blocks_(num_res_blocks), @@ -2022,10 +2171,12 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, int num_encoders = weights.encoder.size(); float alpha = (float) pow(2.0 * num_encoders, 0.25); + int index = 0; for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha, - smolgen_global_, smolgen_global_size_, max_batch_size, use_fused_mha_); + smolgen_global_, smolgen_global_size_, max_batch_size, use_fused_mha_, + int8_calibrate, int8_inference, int8_weights, index++); encoder_weights_.emplace_back(pW); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 0e440b9fcf..0ec58ba849 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -337,7 +337,8 @@ class EncoderBlock { public: EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, - int smolgen_global_size, int max_batch_size, bool fused_mha); + int smolgen_global_size, int max_batch_size, bool fused_mha, bool int8_calibrate, + bool int8_inference, void* int8_weights, int blockIndex); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, @@ -365,6 +366,16 @@ class EncoderBlock { DataType *smol_ln2_gammas, *smol_ln2_betas; DataType *smol_global; + + bool int8_inf_, int8_cali_; + + // calibration factors and weights for INT8 (in GPU memory when int8_inf_ is set, otherwise in CPU memory) + int8_t* kqv_int8_; // int8 quantized weights for the KQV matrix multiplication + float* input_scaling_factors_; // scaling factors needed to quantize the inputs + float* output_scaling_factors_; // scaling factors 
needed to dequantize the outputs (just 3 floats: always in CPU memory) + float* input_matrix_max_values_; // max values of input matrix to KQV GEMM + + int mha_q_size_; int mha_k_size_; int mha_v_size_; @@ -470,7 +481,8 @@ class AttentionBody : public BaseLayer { public: AttentionBody(const LegacyWeights& weights, void* scratch, ActivationFunction default_act, int num_res_blocks, - int input_c, int max_batch_size, bool fused_mha); + int input_c, int max_batch_size, bool fused_mha, + bool int8_calibrate, bool int8_inference, void *int8_weights); ~AttentionBody(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index e72b70d3b2..9822eb6650 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -45,6 +45,12 @@ namespace lczero { using namespace cudnn_backend; +namespace cudnn_backend { +template +void dumpTensor(const T* memory, int elements, const char* message, + bool only_summary = false, bool cpu_tensor = false); +} + template class CudaNetwork; @@ -352,6 +358,74 @@ class CudaNetwork : public Network { ActivationFunction act = mish_net ? MISH : RELU; + + use_int8_ = options.GetOrDefault("int8", false); + int8_calibration_run_ = options.GetOrDefault("int8-calibrate", false); + + if (int8_calibration_run_ || use_int8_) { + if (!fp16 && use_int8_) + throw Exception("INT8 is supported only with cuda-fp16 backend."); + if (!attn_body_) + throw Exception("INT8 only supported for attention body networks"); + + // Structure of the weights file: + // For each encoder block - + // * per-channel scaling factors for Input Matrix to QKV GEMM + // (to use for quantization of the input) + // * qunatized (int8) weights for QKV GEMMs + // * float factorQ, factorK, factorV + // (basically factors needed to de-quantize the output) TODO! + // (will add more as we try int8 for more layers) + int embedding_op_size = weights.ip_emb_b.size(); + int encoder_d_model = weights.encoder[0].mha.q_b.size(); + int num_encoders = weights.encoder.size(); + int8_weights_size_ = + num_encoders * ((embedding_op_size + 3) * sizeof(float) + + 3 * embedding_op_size * encoder_d_model * sizeof(int8_t)); + int8_weights_ = malloc(int8_weights_size_); + memset(int8_weights_, 0, int8_weights_size_); + + printf("\nint8_weights_size: %d\n", int8_weights_size_); + } + + if (int8_calibration_run_) { + // we will write the file at the time of exit. + } else if (use_int8_) { + FILE* fp = fopen("weights_quant.bin", "rb"); + if (!fp) { + CERR << "ERROR: weights_quant.bin not found. 
Please run 'lc0 benchmark " + "-t 1 --nodes=1 -w --backend=cuda-fp16 " + "--backend-opts=int8-calibrate' first"; + throw Exception("Quantized weights not found"); + } else { + int read = fread(int8_weights_, 1, int8_weights_size_, fp); + fclose(fp); + if (read != int8_weights_size_) + throw Exception( + "Quantized weights likely corrupted or of different network"); + +#if 0 + // Ankan - test: dump some weights here + float* data = (float*)int8_weights_; + dumpTensor(data, weights.ip_emb_b.size(), + "per-channel scaling factors for input", + false, true); + + int8_t* w = (int8_t*)int8_weights_; + w += weights.ip_emb_b.size() * sizeof(float); + dumpTensor(w, 512, "quantized weights", + false, true); + + w += 3 * weights.ip_emb_b.size() * weights.encoder[0].mha.q_b.size() * + sizeof(int8_t); + dumpTensor((float*)w, 3, "scaling factors for output", false, true); + + exit(0); +#endif + + } + } + // 2. Build the network, and copy the weights to GPU memory. // Input conv only used if there are residual blocks in the network @@ -420,7 +494,8 @@ class CudaNetwork : public Network { if (attn_body_) { auto attention_body = std::make_unique>( weights, scratch_mem_, act, numBlocks_, - numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_, use_fused_mha); + numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_, + use_fused_mha, int8_calibration_run_, use_int8_, int8_weights_); network_.emplace_back(std::move(attention_body)); encoder_last_ = getLastLayer(); @@ -870,6 +945,16 @@ class CudaNetwork : public Network { } cublasDestroy(cublas_); } + + if (int8_calibration_run_) { + // write the calibration data/weights to file + FILE* fp = fopen("weights_quant.bin", "wb+"); + fwrite(int8_weights_, 1, int8_weights_size_, fp); + fclose(fp); + } + if (int8_calibration_run_ || use_int8_) + free(int8_weights_); + } const NetworkCapabilities& GetCapabilities() const override { @@ -919,6 +1004,8 @@ class CudaNetwork : public Network { // tower bool multi_stream_; // run multiple parallel network evals bool allow_cache_opt_; // try to fit residual block activations in L2 cache + bool use_int8_; // try to use INT8 (works only with cuda-fp16 backend) + bool int8_calibration_run_; // this is a calibration run to figure out quantization factors // Currently only one NN Eval can happen a time (we can fix this if needed // by allocating more memory). @@ -952,6 +1039,9 @@ class CudaNetwork : public Network { mutable std::mutex inputs_outputs_lock_; std::list> free_inputs_outputs_; + void* int8_weights_; // loaded from disk / to be stored to disk + int int8_weights_size_; + void showInfo() const { int version; int ret = cudaRuntimeGetVersion(&version); From c2e63496bb70bf743228714b331787ccb0346852 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Mon, 24 Apr 2023 20:55:19 +0530 Subject: [PATCH 26/70] clean up some unused/test code --- src/neural/cuda/cutlass_kernels.cu | 521 +++++------------------------ src/neural/cuda/layers.cc | 2 + 2 files changed, 95 insertions(+), 428 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 94dbf22966..31e600ed83 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -1,6 +1,6 @@ /* This file is part of Leela Chess Zero. 
- Copyright (C) 2018 The LCZero Authors + Copyright (C) 2023 The LCZero Authors Leela Chess is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -142,10 +142,6 @@ static float mean(float arr[], int n) { } -#include - -int fileNumber = 0; - // function to calculate standard deviation static float stdDev(float arr[], int n) { float m = mean(arr, n); // get the mean @@ -157,95 +153,12 @@ static float stdDev(float arr[], int n) { return sqrt(var); // return the square root of variance } -float computeScaleFactor(const half* memory, int elements) { - std::vector fpArr(elements); - - void* temp = malloc(elements * sizeof(half)); - cudaMemcpy(temp, memory, elements * sizeof(half), cudaMemcpyDeviceToHost); - - float absmax = 0; - for (int i = 0; i < elements; i++) { - float val; - half* arr = (half*)temp; - val = (float)arr[i]; - fpArr[i] = val; - if (val == val) - absmax = std::max(absmax, fabs(val)); - else { - printf("\nNAN found!\n"); - exit(0); - } - } - - //float avg = mean(&fpArr[0], elements); - //float stddev = stdDev(&fpArr[0], elements); - - - // Ankan - for testing - char filename[100]; - sprintf(filename, "Mat_%d", fileNumber++); - FILE* fp; - fopen_s(&fp, filename, "wb+"); - fwrite(&fpArr[0], sizeof(float), elements, fp); - fclose(fp); - - // 4x standard deviation should be enough range ? - // No, it seems including the outliers is important :-/ - // absmax = std::min(stddev * 4, absmax); - free(temp); - // printf(" absmax: %f ", absmax); - return 127.0f / absmax; -} - // Helper fuction to do vector loads/stores template __device__ __forceinline__ void copyAs(void* dst, const void* src) { *((T*)(dst)) = *((const T*)(src)); } - -// each thread processes 8 elements (=> 16 byte reads, 8 byte writes) -__global__ void convertFp16ToInt8(int8_t* output, const half* input, float scaleFactor, int N) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int firstEl = tid * 8; - if (firstEl >= N) return; - - half ip[8]; - int8_t op[8]; - - copyAs(&ip[0], &input[firstEl]); - - for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] * scaleFactor); - if (val > 127) val = 127; - if (val < -128) val = -128; - op[i] = (int8_t)(val); - } - /* - if (firstEl == 0) - printf("\nfrom kernel: input: %f, output: %f\n", (float)ip[0], - (float)op[0]); - */ - copyAs(&output[firstEl], &op[0]); -} - -__global__ void convertInt8ToFp16(half* output, const int8_t* input, - float scaleFactor, int N) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int firstEl = tid * 8; - if (firstEl >= N) return; - - half op[8]; - int8_t ip[8]; - - copyAs(&ip[0], &input[firstEl]); - - for (int i = 0; i < 8; i++) op[i] = (half)(scaleFactor * (float)ip[i]); - - copyAs(&output[firstEl], &op[0]); -} - - // debug code to dump allocation in GPU memory template void dumpTensor(const T* memory, int elements, const char* message, @@ -314,7 +227,7 @@ void dumpTensor(const T* memory, int elements, const char* message, } - +// int8 GEMM using CUTLASS void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, @@ -334,8 +247,6 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, using LayoutInputB = cutlass::layout::ColumnMajor; using LayoutOutput = cutlass::layout::RowMajor; - // This code section describes whether you want to use tensor cores or regular - // SIMT cores on GPU SM using MMAOp = cutlass::arch::OpClassTensorOp; // 
This code section describes CUDA SM architecture number @@ -377,12 +288,85 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, status = gemm_op(); } +// FP16 GEMM using cutlass +void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, + int N, int K, int batchSize, int AStride, int BStride, int OutStride, bool useInt8) { + + half halfOne = (half)1.0f; + half halfZero = (half)0.0f; -static void calibrateGemm(int8_t* weights_int8, - float* input_scaling_factors, - float* output_scaling_factors, - float* cpuA, float* cpuB, int M, - int N, int K, int batchSize) { + using ElementAccumulator = cutlass::half_t; // <- data type of accumulator + using ElementComputeEpilogue = + ElementAccumulator; // <- data type of epilogue operations + using ElementInputA = + cutlass::half_t; // <- data type of elements in input matrix A + using ElementInputB = + cutlass::half_t; // <- data type of elements in input matrix B + using ElementOutput = + cutlass::half_t; // <- data type of elements in output matrix D + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm80; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N + // = 128, K = 32 + // This code section describes tile size a warp will compute + using ShapeMMAWarp = + cutlass::gemm::GemmShape<32, 64, + 32>; // <- warp tile M = 64, N = 64, K = 32 + // This code section describes the size of MMA op + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = + // 8, N = 8, K = 4 + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; // <- + // ?? + + // This code section describes ? + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. 
This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + float>; // <- data type for alpha/beta in linear combination function + + constexpr int NumStages = 3; // stages == 2/4 is also good sometimes + + using Gemm = cutlass::gemm::device::GemmBatched< + ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, + LayoutOutput, ElementAccumulator, MMAOp, SmArch, ShapeMMAThreadBlock, + ShapeMMAWarp, ShapeMMAOp, EpilogueOp, SwizzleThreadBlock, NumStages>; + + Gemm gemm_op; + + cutlass::Status status = gemm_op({{M, N, K}, + {(cutlass::half_t const*)A, K}, + AStride, + {(cutlass::half_t const*)B, K}, + BStride, + {(cutlass::half_t const*)Out, N}, + OutStride, + {(cutlass::half_t*)Out, N}, + OutStride, + {halfOne, halfZero}, + batchSize}); +} + + +static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, + float* output_scaling_factors, float* cpuA, + float* cpuB, int M, int N, int K, int batchSize) { std::vector scaling_factors(K); // apply smooth-quant (basically adjust A and B matrices to make quantization @@ -471,29 +455,24 @@ static void calibrateGemm(int8_t* weights_int8, output_scaling_factors[i] = 127.0 / (AFactor * BFactor[i]); // Ankan - for debug/test - //printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", + // printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", // 127.0 / absMaxA, BFactor[0], BFactor[1], BFactor[2]); - - } - // Same Activation (A) matrix (M x K) is multiplied by batchSize x B matrices / // weights (K x N transposed) The outputs are: // 1. quantized weight matrices (weights_int8) // 2. "per-column" scaling factors (input_scaling_factors) needed to quantize // matrix A -// 3. Scaling factors to dequantize the output matrix (just 3 values: factorQ, factorK, factorV) +// 3. 
Scaling factors to dequantize the output matrix (just 3 values: factorQ, +// factorK, factorV) // M_Batch is the batch size component in "M" dimension // maxValuesA contains the max values in activation matrix found so far template -void calibrateGemmForInt8(int8_t* weights_int8, - float* input_scaling_factors, - float* output_scaling_factors, - float* maxValuesA, - const DataType* A, const DataType* B, int M, - int N, int K, int batchSize, int M_Batch) { - +void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, + float* output_scaling_factors, float* maxValuesA, + const DataType* A, const DataType* B, int M, int N, + int K, int batchSize, int M_Batch) { auto cpuA = (DataType*)malloc(M_Batch * M * K * sizeof(DataType)); auto cpuB = (DataType*)malloc(batchSize * K * N * sizeof(DataType)); @@ -502,346 +481,33 @@ void calibrateGemmForInt8(int8_t* weights_int8, ReportCUDAErrors(cudaMemcpy(cpuB, B, batchSize * K * N * sizeof(DataType), cudaMemcpyDeviceToHost)); - // convert to FP32 (if not already in fp32, and pick one Activation matrix at a time) + // convert to FP32 (if not already in fp32, and pick one Activation matrix at + // a time) auto fpA = (float*)malloc(M * K * sizeof(float)); auto fpB = (float*)malloc(batchSize * K * N * sizeof(float)); - for (int i = 0; i < K * N * batchSize; i++) - fpB[i] = (float)cpuB[i]; - + for (int i = 0; i < K * N * batchSize; i++) fpB[i] = (float)cpuB[i]; for (int b = 0; b < M_Batch; b++) { for (int i = 0; i < M * K; i++) { float val = abs((float)cpuA[b * M * K + i]); val = std::max(val, maxValuesA[i]); fpA[i] = val; - maxValuesA[i] = val; // update the max activation matrix + maxValuesA[i] = val; // update the max activation matrix } // calibrate a single sample - calibrateGemm(weights_int8, input_scaling_factors, - output_scaling_factors, fpA, fpB, M, N, K, - batchSize); + calibrateGemm(weights_int8, input_scaling_factors, output_scaling_factors, + fpA, fpB, M, N, K, batchSize); } free(fpA); free(fpB); free(cpuA); free(cpuB); - } -void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, - int N, int K, int batchSize, int AStride, - int BStride, int OutStride, - bool useInt8 = true); - -// Test routine that does the following -// Figures out the range of values of the inputs (A and B) -// computes scaling factors for both A and B. -// convert fp16->int8 using the scaling factors. -// run int8 GEMM -// convert the result back from int8 -> fp16 -// Decide on: -// 1. Which scaling factor to use for the GEMM -// 2. which scaling factor to use for the post-process pass ? -void cutlassMatrixMulBTransposedWithInt8(const half* A, const half* B, half* Out, int M, - int N, int K, int batchSize, int AStride, - int BStride, int OutStride) { - bool useSmooth = true; - half *smoothA, *smoothB; - - if (!useSmooth) { - smoothA = (half*)A; - smoothB = (half*)B; - } else { - // SmoothQuant test: - // pre-process both A and B matrices to get best out of int8 matmul - // In practice - // * this post processing of A need to be fused with previous layer (once - // we know per-channel scaling factors which is computed offline) - // * post processing of B can be done one time / offline. - - ReportCUDAErrors(cudaMalloc(&smoothA, M * K * sizeof(half))); - ReportCUDAErrors(cudaMalloc(&smoothB, batchSize * K * N * sizeof(half))); - - // TODO: assumption on A's stride (to be 0), and B to be packed. 
- half* cpuA = (half*)malloc(M * K * sizeof(half)); - half* cpuB = (half*)malloc(batchSize * K * N * sizeof(half)); - - ReportCUDAErrors(cudaMemcpy(cpuA, A, M * K * sizeof(half), cudaMemcpyDeviceToHost)); - ReportCUDAErrors(cudaMemcpy(cpuB, B, batchSize * K * N * sizeof(half), - cudaMemcpyDeviceToHost)); - - for (int k = 0; k < K; k++) { - float absMaxA = 0; - float absMaxB = 0; - // scan a column of Matrix A to find the abs max. - for (int y = 0; y < M; y++) { - float val = (float)cpuA[y * K + k]; - absMaxA = std::max(absMaxA, abs(val)); - } - - // scan a column of Matrix B (from each batch dimension) - for (int b = 0; b < batchSize; b++) - for (int x = 0; x < N; x++) { - float val = (float)cpuB[b * BStride + x * K + k]; - absMaxB = std::max(absMaxB, abs(val)); - } - - // compute scaling factor: - float s = sqrt(absMaxA / (absMaxB)); - - // sanity check, don't use too small, or too big scaling factors - if (s < 1) s = 1.0f; // don't try to squeeze activations for improving range of weights! - if (s > 10) s = 10.0f; - - // printf("\nMaxA: %f, MaxB: %f, scale: %f ", absMaxA, absMaxB, s); - - // scale A and B matrices using the scaling factor - for (int y = 0; y < M; y++) { - float val = (float)cpuA[y * K + k]; - val /= s; - cpuA[y * K + k] = (half)val; - } - - for (int b = 0; b < batchSize; b++) - for (int x = 0; x < N; x++) { - float val = (float)cpuB[b * BStride + x * K + k]; - val *= s; - cpuB[b * BStride + x * K + k] = (half)val; - } - } - - ReportCUDAErrors(cudaMemcpy(smoothA, cpuA, M * K * sizeof(half), cudaMemcpyHostToDevice)); - ReportCUDAErrors(cudaMemcpy(smoothB, cpuB, batchSize * K * N * sizeof(half), - cudaMemcpyHostToDevice)); - - free(cpuA); - free(cpuB); - ReportCUDAErrors(cudaGetLastError()); - } - - // do the actual multiplication so that we can compute OutFactor - /* - cutlassMatrixMulBTransposed(A, B, Out, M, N, K, batchSize, - AStride, - BStride, - OutStride, false); - - ReportCUDAErrors(cudaGetLastError()); - */ - - // Ankan - test! - // copy the output to CPU and check the range of values per "channel" - /* - { - half* cpuOut = (half*)malloc(batchSize * M * N * sizeof(half)); - cudaMemcpy(cpuOut, Out, M * N * sizeof(half) * batchSize, - cudaMemcpyDeviceToHost); - - for (int n = 0; n < N; n++) { - - float absMax = 0; - // scan a column of Matrix A to find the abs max. - for (int m = 0; m < M; m++) { - float val = (float)cpuOut[M*N*2 + n + m * M]; - absMax = std::max(absMax, val); - } - - printf("MaxOut: %f\n", absMax); - } - - free(cpuOut); - } - */ - - - //printf("\nRows: %d, cols: %d\n", M, K); // Ankan - test! - float AFactor = computeScaleFactor(smoothA, M * K); - - //printf("\nRows: %d, cols: %d\n", K, N * batchSize); // Ankan - test! 
- //float BFactor = computeScaleFactor(smoothB, N * K * batchSize); - - int offsetQ = N * K * 0; - int offsetK = N * K * 1; - int offsetV = N * K * 2; - - float BFactorQ = computeScaleFactor(smoothB + offsetQ, N * K); - float BFactorK = computeScaleFactor(smoothB + offsetK, N * K); - float BFactorV = computeScaleFactor(smoothB + offsetV, N * K); - - - // float OutFactor = computeScaleFactor(Out, M * N * batchSize); - - float opScale = 127; - //opScale = 1.5f * (AFactor * BFactor) / OutFactor; - - printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", AFactor, - BFactorQ, BFactorK, BFactorV); - - //dumpTensor(A, /*M * K*/512, "A before scaling", false); - //dumpTensor(B, /*N * K * batchSize*/ 512, "B before scaling", false); - - int8_t* a8; - int8_t* b8; - int8_t* out8; - - ReportCUDAErrors(cudaMalloc(&a8, M * K)); - ReportCUDAErrors(cudaMalloc(&b8, N * K * batchSize)); - ReportCUDAErrors(cudaMalloc(&out8, OutStride * batchSize * sizeof(half))); - - // Convert A matrix - { - int numBlocks = lczero::cudnn_backend::DivUp(M * K, 256 * 8); - convertFp16ToInt8<<>>(a8, smoothA, AFactor, M * K); - ReportCUDAErrors(cudaGetLastError()); - } - - // Convert B matrix - { - //int numBlocks = lczero::cudnn_backend::DivUp(N * K * batchSize, 256 * 8); - //convertFp16ToInt8<<>>(b8, smoothB, BFactor, - // N * K * batchSize); - - int numBlocks = lczero::cudnn_backend::DivUp(N * K, 256 * 8); - - convertFp16ToInt8<<>>(b8 + offsetQ, smoothB + offsetQ, - BFactorQ, N * K); - convertFp16ToInt8<<>>(b8 + offsetK, smoothB + offsetK, - BFactorK, N * K); - convertFp16ToInt8<<>>(b8 + offsetV, smoothB + offsetV, - BFactorV, N * K); - - ReportCUDAErrors(cudaGetLastError()); - } - - //dumpTensor(a8, 512, "A after scaling", false); - //dumpTensor(b8, 512, "B after scaling", false); - - // run the inmt8 GEMM (1/127 scaling factor to compensate for the multiplications) - cutlassMatrixMulBTransposed(a8, b8, out8, M, N, K, batchSize, AStride, - BStride, OutStride, 1 / opScale, 0.0f); - - ReportCUDAErrors(cudaGetLastError()); - - //dumpTensor(out8, 512, "output of int8 gemm", false); - - // convert out matrix - { - int offsetOutQ = M * N * 0; - int offsetOutK = M * N * 1; - int offsetOutV = M * N * 2; - - float factorQ = opScale / (AFactor * BFactorQ); - float factorK = opScale / (AFactor * BFactorK); - float factorV = opScale / (AFactor * BFactorV); - int numBlocks = lczero::cudnn_backend::DivUp(OutStride, 256 * 8); - convertInt8ToFp16<<>>(Out + offsetOutQ, out8 + offsetOutQ, - factorQ, OutStride); - convertInt8ToFp16<<>>(Out + offsetOutK, out8 + offsetOutK, - factorK, OutStride); - convertInt8ToFp16<<>>(Out + offsetOutV, out8 + offsetOutV, - factorV, OutStride); - ReportCUDAErrors(cudaGetLastError()); - } - - //dumpTensor(Out, 512, "output after scaling back", false); - ReportCUDAErrors(cudaFree(a8)); - ReportCUDAErrors(cudaFree(b8)); - ReportCUDAErrors(cudaFree(out8)); - - if (useSmooth) { - ReportCUDAErrors(cudaFree(smoothA)); - ReportCUDAErrors(cudaFree(smoothB)); - } -} - -int gCounter = 0; - -void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, - int N, int K, int batchSize, int AStride, int BStride, int OutStride, bool useInt8) { - - gCounter++; - - // Ankan - test! 
- //if (useInt8 && gCounter==5) - if (useInt8) - return cutlassMatrixMulBTransposedWithInt8(A, B, Out, M, N, K, batchSize, - AStride, BStride, OutStride); - - half halfOne = (half)1.0f; - half halfZero = (half)0.0f; - - using ElementAccumulator = cutlass::half_t; // <- data type of accumulator - using ElementComputeEpilogue = - ElementAccumulator; // <- data type of epilogue operations - using ElementInputA = - cutlass::half_t; // <- data type of elements in input matrix A - using ElementInputB = - cutlass::half_t; // <- data type of elements in input matrix B - using ElementOutput = - cutlass::half_t; // <- data type of elements in output matrix D - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - using MMAOp = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using SmArch = cutlass::arch::Sm80; - - // This code section describes the tile size a thread block will compute - using ShapeMMAThreadBlock = - cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N - // = 128, K = 32 - // This code section describes tile size a warp will compute - using ShapeMMAWarp = - cutlass::gemm::GemmShape<32, 64, - 32>; // <- warp tile M = 64, N = 64, K = 32 - // This code section describes the size of MMA op - using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = - // 8, N = 8, K = 4 - - // This code section describes how threadblocks are scheduled on GPU - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; // <- - // ?? - - // This code section describes ? - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. 
This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - float>; // <- data type for alpha/beta in linear combination function - - constexpr int NumStages = 3; // stages == 2/4 is also good sometimes - - using Gemm = cutlass::gemm::device::GemmBatched< - ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, - LayoutOutput, ElementAccumulator, MMAOp, SmArch, ShapeMMAThreadBlock, - ShapeMMAWarp, ShapeMMAOp, EpilogueOp, SwizzleThreadBlock, NumStages>; - - Gemm gemm_op; - - cutlass::Status status = gemm_op({{M, N, K}, - {(cutlass::half_t const*)A, K}, - AStride, - {(cutlass::half_t const*)B, K}, - BStride, - {(cutlass::half_t const*)Out, N}, - OutStride, - {(cutlass::half_t*)Out, N}, - OutStride, - {halfOne, halfZero}, - batchSize}); -} - // process 8 elements per thread (in x dimension) __global__ void quantizeMatrix(int8_t* output, const half* input, int height, int width, const float* scale) { @@ -955,7 +621,6 @@ template void calibrateGemmForInt8(int8_t* weights_int8, const half* B, int M, int N, int K, int batchSize, int M_Batch); - template void dumpTensor(const float* memory, int elements, const char* message, bool only_summary, bool cpu_tensor); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index bdea75d6b2..942e070751 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1805,6 +1805,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, if (int8_inf_) { // printf("\nAttempting int8_inf\n"); // 1. quantize the inputs (scratch1 -> scratch0) + // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch1, batch, embedding_op_size_, input_scaling_factors_, stream); @@ -1817,6 +1818,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, ReportCUDAErrors(cudaGetLastError()); // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch0) + // TODO: fuse the entire thing with the above GEMM. deQuantizeOutputMatrixBiasAdd((half*)scratch0, (const int8_t*)scratch2, batch, num_outputs, 3, output_scaling_factors_, (const half*)mha_qkv_b, stream); } else { From a2eb7d286efa4bfd43a853b0d9ae9d1128437254 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Tue, 25 Apr 2023 20:12:34 +0530 Subject: [PATCH 27/70] try int8 for more layers - all big GEMMs implemented. FFN1 seems to have issue. 
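
[Editor's note] From this patch on, every large encoder GEMM follows the same three-step int8 path: quantize the fp16 activations with per-column scales (quantizeActivationMatrix), run an int8 GEMM whose epilogue scales the int32 accumulator by 1/127 and stores int8 (cutlassMatrixMulBTransposed), then dequantize with a per-output-matrix factor while adding the bias (deQuantizeOutputMatrixBiasAdd). The CPU sketch below is an illustrative reference for that arithmetic only; the function and variable names are the editor's and are not part of the patch, and it ignores the SmoothQuant column rebalancing and the fused activation handled elsewhere.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Quantize an M x K activation matrix with per-column scales
    // (each scale is 127 / absmax of that input column).
    std::vector<int8_t> QuantizeActivations(const std::vector<float>& a, int M, int K,
                                            const std::vector<float>& col_scale) {
      std::vector<int8_t> q(M * K);
      for (int y = 0; y < M; y++)
        for (int k = 0; k < K; k++) {
          float v = std::round(a[y * K + k] * col_scale[k]);
          q[y * K + k] = (int8_t)std::min(127.0f, std::max(-128.0f, v));
        }
      return q;
    }

    // int8 GEMM with B transposed (B is N x K), epilogue alpha = 1/127 stored
    // as int8, then dequantization with a per-matrix output factor
    // (127 / (input_scale * weight_scale)) and a per-column bias add.
    std::vector<float> Int8GemmDequantBias(const std::vector<int8_t>& a,
                                           const std::vector<int8_t>& b, int M, int N,
                                           int K, float out_factor,
                                           const std::vector<float>& bias) {
      std::vector<float> out(M * N);
      for (int y = 0; y < M; y++)
        for (int x = 0; x < N; x++) {
          int32_t acc = 0;
          for (int k = 0; k < K; k++)
            acc += int32_t(a[y * K + k]) * int32_t(b[x * K + k]);
          // epilogue: scale by 1/127 and saturate to int8
          // (mirrors the alpha = 1/127 passed to the CUTLASS int8 GEMM)
          float scaled = std::round(acc / 127.0f);
          int8_t q = (int8_t)std::min(127.0f, std::max(-128.0f, scaled));
          // dequantize and add bias
          out[y * N + x] = q * out_factor + bias[x];
        }
      return out;
    }

The per-column scales sit on the activation side because, after the SmoothQuant rebalancing done during calibration, the outlier range has been shifted into the weights, which are quantized once offline per matrix.
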
--- src/neural/cuda/cutlass_kernels.cu | 20 ++- src/neural/cuda/layers.cc | 249 +++++++++++++++++++++-------- src/neural/cuda/layers.h | 25 ++- src/neural/cuda/network_cuda.cc | 31 +++- 4 files changed, 239 insertions(+), 86 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 31e600ed83..e826c13f34 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -27,6 +27,7 @@ #include "cuda_common.h" #include +#include "winograd_helper.inc" #ifdef USE_CUTLASS @@ -153,11 +154,13 @@ static float stdDev(float arr[], int n) { return sqrt(var); // return the square root of variance } +/* // Helper fuction to do vector loads/stores template __device__ __forceinline__ void copyAs(void* dst, const void* src) { *((T*)(dst)) = *((const T*)(src)); } +*/ // debug code to dump allocation in GPU memory template @@ -232,6 +235,7 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf) { + // Ankan - For testing! //dumpTensor(A, 512, "A after scaling", false); //dumpTensor(B, 512, "B after scaling", false); @@ -455,7 +459,7 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, output_scaling_factors[i] = 127.0 / (AFactor * BFactor[i]); // Ankan - for debug/test - // printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", + //printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", // 127.0 / absMaxA, BFactor[0], BFactor[1], BFactor[2]); } @@ -555,7 +559,7 @@ struct ScaleParam { }; // process 8 elements per thread (in x dimension) -__global__ void deQuantizeMatrix(half* output, const int8_t* input, const half *bias, int height, int width, int stride, ScaleParam s) { +__global__ void deQuantizeMatrix(half* output, const int8_t* input, const half *bias, int height, int width, int stride, ScaleParam s, ActivationFunction act) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; int b = blockIdx.z; @@ -569,13 +573,14 @@ __global__ void deQuantizeMatrix(half* output, const int8_t* input, const half * half bi[8] = {}; copyAs(&ip[0], &input[b * stride + y * width + x]); - copyAs(&bi[0], &bias[b * width + x]); + if (bias) + copyAs(&bi[0], &bias[b * width + x]); for (int i = 0; i < 8; i++) { float val = (float)ip[i]; val *= factor; - val += (float)bi[i]; - op[i] = (half) val; + if (bias) val += (float)bi[i]; + op[i] = (half) activate(val, act); } copyAs(&output[b * stride + y * width + x], &op[0]); @@ -588,7 +593,8 @@ __global__ void deQuantizeMatrix(half* output, const int8_t* input, const half * void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, float* scale, const half* bias, - cudaStream_t stream) { + cudaStream_t stream, + ActivationFunction act = NONE) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16), batchSize); @@ -600,7 +606,7 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, ScaleParam s = {}; for (int i = 0; i < batchSize; i++) s.scale[i] = scale[i]; - deQuantizeMatrix<<>>(output, input, bias, height, width, stride, s); + deQuantizeMatrix<<>>(output, input, bias, height, width, stride, s, act); ReportCUDAErrors(cudaGetLastError()); } diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 942e070751..c1db9205d7 100644 --- 
a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1480,6 +1480,51 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, } } + +static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, int InputCols, int OutputCols, int outBatch, bool cali) { + size_t matrix_size = InputCols * OutputCols * sizeof(int8_t) * outBatch; + if (!cali) { + // Load weights for INT8 inference + + // (per-column) scaling factors for the input + ReportCUDAErrors(cudaMalloc(&data.input_scaling_factors, + sizeof(float) * InputCols)); + ReportCUDAErrors(cudaMemcpy(data.input_scaling_factors, w, + sizeof(float) * InputCols, + cudaMemcpyHostToDevice)); + + // go to weights + w += InputCols * sizeof(float); + ReportCUDAErrors(cudaMalloc(&data.weights_int8, matrix_size)); + ReportCUDAErrors( + cudaMemcpy(data.weights_int8, w, matrix_size, cudaMemcpyHostToDevice)); + + // go to output scaling factors + w += matrix_size; + data.output_scaling_factors = (float*)w; + + + // go to next entry + w += outBatch * sizeof(float); + } else { + // Just save the pointers to CPU weights (we will over-write here during calibration) + data.input_scaling_factors = (float*)w; + w += InputCols * sizeof(float); + data.weights_int8 = w; + w += matrix_size; + data.output_scaling_factors = (float*)w; + w += outBatch * sizeof(float); + + // to keep track of max values in input activation matrix + int InputMatrixSizeForBatch1 = 64 * InputCols * sizeof(float); + data.input_matrix_max_values = (float*)malloc(InputMatrixSizeForBatch1); + memset(data.input_matrix_max_values, 0, InputMatrixSizeForBatch1); + } + + // return pointer to next item + return w; +} + template EncoderBlock::EncoderBlock( const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, @@ -1568,53 +1613,33 @@ EncoderBlock::EncoderBlock( // GPU memory already allocated in AttentionBody. 
smol_global = smolgen_global_scratch; - - // int8 stuff + } + // int8 stuff + if (int8_inference || int8_calibrate) { int per_encoder_size = embedding_op_size_ * sizeof(float) + - 3 * embedding_op_size_ * mha_q_size_ * sizeof(int8_t) + - 3 * sizeof(float); + 3 * embedding_op_size_ * mha_q_size_ + + 3 * sizeof(float) + mha_q_size_ * sizeof(float) + + embedding_op_size_ * mha_q_size_ + sizeof(float) + + embedding_op_size_ * sizeof(float) + + ffn_dense1_size_ * mha_q_size_ + sizeof(float) + + ffn_dense1_size_ * sizeof(float) + + embedding_op_size_ * ffn_dense1_size_ + sizeof(float); + auto w = (int8_t*)int8_weights; // go to current encoder block - w += per_encoder_size * blockIndex; - - if (int8_inference) { - ReportCUDAErrors(cudaMalloc(&input_scaling_factors_, - sizeof(float) * embedding_op_size_)); - ReportCUDAErrors(cudaMemcpy(input_scaling_factors_, w, - sizeof(float) * embedding_op_size_, - cudaMemcpyHostToDevice)); - - //printf("\nCopied input scaling factors for index: %d\n", blockIndex); - - // go to int8 kqv weights - w += embedding_op_size_ * sizeof(float); - size_t elements = cpu_weights.mha.q_w.size(); - size_t size = elements * sizeof(int8_t) * 3; - ReportCUDAErrors(cudaMalloc(&kqv_int8_, size)); - ReportCUDAErrors(cudaMemcpy(kqv_int8_, w, size, cudaMemcpyHostToDevice)); - - //printf("\nCopied int8 weights for index: %d, size: %d\n", blockIndex, - // size); - - // go to output scaling factors - w += size; - output_scaling_factors_ = (float*)w; - } else if (int8_calibrate) { - // just save the pointers (we will over-write here during calibration) - input_scaling_factors_ = (float*)w; - w += embedding_op_size_ * sizeof(float); - kqv_int8_ = w; - w += 3 * cpu_weights.mha.q_w.size() * sizeof(int8_t); - output_scaling_factors_ = (float*)w; - - // to keep track of max values in input activation matrix - input_matrix_max_values_ = - (float*)malloc(64 * embedding_op_size_ * sizeof(float)); - memset(input_matrix_max_values_, 0, - 64 * embedding_op_size_ * sizeof(float)); - } - } + w += per_encoder_size * blockIndex; + w = SetQuantizationData(kqv_, w, embedding_op_size_, mha_q_size_, 3, int8_calibrate); + w = SetQuantizationData(mha_dense_, w, mha_q_size_, embedding_op_size_, 1, int8_calibrate); + w = SetQuantizationData(ffn1_, w, embedding_op_size_, ffn_dense1_size_, 1, int8_calibrate); + SetQuantizationData(ffn2_, w, ffn_dense1_size_, embedding_op_size_, 1, int8_calibrate); + // print some weights + /* + printf("\noutput scale first factor: %f, %f, %f, %f\n", + *kqv_.output_scaling_factors, *mha_dense_.output_scaling_factors, + *ffn1_.output_scaling_factors, *ffn2_.output_scaling_factors); + */ + } } template @@ -1700,7 +1725,7 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, float* scale, const half* bias, - cudaStream_t stream); + cudaStream_t stream, ActivationFunction act = NONE); // input/output tensor is scratch1, others are used as scratch. 
// TODO: fix naming of scratch buffers @@ -1796,31 +1821,30 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, mha_v = mha_k + num_outputs * batch_to_use; if (int8_cali_) { - calibrateGemmForInt8(kqv_int8_, input_scaling_factors_, output_scaling_factors_, - input_matrix_max_values_, scratch1, mha_qkv_w, 64, + calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, + kqv_.output_scaling_factors, kqv_.input_matrix_max_values, scratch1, mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); } - - + if (int8_inf_) { // printf("\nAttempting int8_inf\n"); // 1. quantize the inputs (scratch1 -> scratch0) // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch1, batch, - embedding_op_size_, input_scaling_factors_, + embedding_op_size_, kqv_.input_scaling_factors, stream); // 2. perform int8 GEMM (scratch0 -> scratch2) - cutlassMatrixMulBTransposed((const int8_t*)scratch0, kqv_int8_, - (int8_t*)scratch2, batch, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); + cutlassMatrixMulBTransposed( + (const int8_t*)scratch0, kqv_.weights_int8, (int8_t*)scratch2, batch, + num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); ReportCUDAErrors(cudaGetLastError()); // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch0) // TODO: fuse the entire thing with the above GEMM. deQuantizeOutputMatrixBiasAdd((half*)scratch0, (const int8_t*)scratch2, batch, num_outputs, 3, - output_scaling_factors_, (const half*)mha_qkv_b, stream); + kqv_.output_scaling_factors, (const half*)mha_qkv_b, stream); } else { cublasXGemmStridedBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, @@ -1932,12 +1956,40 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // #final dense layer (mha_dense), scratch3 -> scratch2 { + if (int8_cali_) { + calibrateGemmForInt8( + mha_dense_.weights_int8, mha_dense_.input_scaling_factors, + mha_dense_.output_scaling_factors, mha_dense_.input_matrix_max_values, + scratch3, mha_dense_w, 64, embedding_op_size_, d_model, 1, N); + } + const int num_inputs = d_model; const int num_outputs = embedding_op_size_; const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)mha_dense_w, num_inputs, - scratch3, num_inputs, 0.0f, scratch2, num_outputs); + + if (int8_inf_) { + // 1. quantize the inputs (scratch3 -> scratch0) + // TODO: Fuse this step with the previous fused MHA + quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch3, batch, + num_inputs, mha_dense_.input_scaling_factors, + stream); + + // 2. perform int8 GEMM (scratch0 -> scratch3) + cutlassMatrixMulBTransposed( + (const int8_t*)scratch0, mha_dense_.weights_int8, (int8_t*)scratch3, batch, + num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + + // 3. de-quantize outputs (scratch3 -> scratch2) + // TODO: Fuse this with LN1 (should be easy!) 
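// Sketch of the arithmetic behind this per-tensor int8 path (for orientation;
// approximate by construction, using names from the calibration code above):
//   q_x[r][k] = round(x[r][k] * input_scaling_factors[k])       (quantizeActivationMatrix)
//   acc[r][c] = sum_k q_x[r][k] * q_w[c][k]                     (int8 GEMM, int32 accumulator)
//   q_y[r][c] = sat_int8(acc[r][c] * (1.0f / 127))              (GEMM epilogue, alphaf = 1.0/127.0)
//   y[r][c]   = q_y[r][c] * output_scaling_factor (+ bias[c] when a bias is fused)
//                                                               (deQuantizeOutputMatrixBiasAdd)
// With output_scaling_factor calibrated as 127 / (AFactor * BFactor), y approximates
// the fp16 result x * w^T (+ bias); that approximation is what calibration tunes.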
+ deQuantizeOutputMatrixBiasAdd( + (half*)scratch2, (const int8_t*)scratch3, batch, num_outputs, 1, + mha_dense_.output_scaling_factors, nullptr, stream); + } else { + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, mha_dense_w, num_inputs, scratch3, + num_inputs, 0.0f, scratch2, num_outputs); + } } // LN1: skip connection and layer normalization (also bias add of prev gemm) @@ -1949,24 +2001,80 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, // #FFN dense 1, scratch0 -> scratch1 const int encoder_dff = ffn_dense1_size_; { + if (int8_cali_) { + calibrateGemmForInt8( + ffn1_.weights_int8, ffn1_.input_scaling_factors, + ffn1_.output_scaling_factors, ffn1_.input_matrix_max_values, scratch1, + ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); + } + const int num_inputs = embedding_op_size_; const int num_outputs = encoder_dff; const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, - scratch0, num_inputs, 0.0f, scratch1, num_outputs); - addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, - has_smolgen_ ? RELU_2 : act, stream); // @todo sqr relu to have its own flag + + if (false && int8_inf_) { // Ankan - test! .. enabling this one kills accuracy :-/ + // 1. quantize the inputs (scratch0 -> scratch1) + // TODO: Fuse this step with LN1 (should be easy) + quantizeActivationMatrix((int8_t*)scratch1, (const half*)scratch0, batch, + num_inputs, ffn1_.input_scaling_factors, stream); + + // 2. perform int8 GEMM (scratch1 -> scratch2) + cutlassMatrixMulBTransposed((const int8_t*)scratch1, ffn1_.weights_int8, + (int8_t*)scratch2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + + // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch1) + // TODO: Fuse this with the above GEMM + deQuantizeOutputMatrixBiasAdd( + (half*)scratch1, (const int8_t*)scratch2, batch, num_outputs, 1, + ffn1_.output_scaling_factors, (const half*)ffn_dense1_b, stream, + has_smolgen_ ? RELU_2 : act); + } else { + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, ffn_dense1_w, num_inputs, scratch0, + num_inputs, 0.0f, scratch1, num_outputs); + addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, + has_smolgen_ ? RELU_2 : act, + stream); // @todo sqr relu to have its own flag + } } // #FFN dense 2, scratch1 -> scratch2 { + if (int8_cali_) { + calibrateGemmForInt8( + ffn2_.weights_int8, ffn2_.input_scaling_factors, + ffn2_.output_scaling_factors, ffn2_.input_matrix_max_values, scratch1, + ffn_dense2_w, 64, embedding_op_size_, encoder_dff, 1, N); + } + const int num_inputs = encoder_dff; const int num_outputs = embedding_op_size_; const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + if (int8_inf_) { + // 1. quantize the inputs (scratch1 -> scratch2) + // TODO: Fuse this step with above bias add at least (or ideally with the + // above GEMM) + quantizeActivationMatrix((int8_t*)scratch2, (const half*)scratch1, batch, + num_inputs, ffn2_.input_scaling_factors, stream); + + // 2. perform int8 GEMM (scratch2 -> scratch1) + cutlassMatrixMulBTransposed((const int8_t*)scratch2, ffn2_.weights_int8, + (int8_t*)scratch1, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + + // 3. 
de-quantize outputs (scratch1 -> scratch2) + // TODO: Fuse this with LN2 (should be easy) + deQuantizeOutputMatrixBiasAdd( + (half*)scratch2, (const int8_t*)scratch1, batch, num_outputs, 1, + ffn2_.output_scaling_factors, nullptr, stream); + } else { + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, scratch1, num_inputs, 0.0f, scratch2, num_outputs); + } } // LN2: skip connection and layer normilization (also bias add of prev gemm) @@ -2102,10 +2210,19 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(smol_ln2_betas)); } if (int8_inf_) { - ReportCUDAErrors(cudaFree(kqv_int8_)); - ReportCUDAErrors(cudaFree(input_scaling_factors_)); + ReportCUDAErrors(cudaFree(kqv_.weights_int8)); + ReportCUDAErrors(cudaFree(kqv_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(mha_dense_.weights_int8)); + ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(ffn1_.weights_int8)); + ReportCUDAErrors(cudaFree(ffn1_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(ffn2_.weights_int8)); + ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); } else if (int8_cali_) { - free(input_matrix_max_values_); + free(kqv_.input_matrix_max_values); + free(mha_dense_.input_matrix_max_values); + free(ffn1_.input_matrix_max_values); + free(ffn2_.input_matrix_max_values); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 0ec58ba849..12b9f3f0c2 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -332,6 +332,19 @@ class ResidualBlock : public BaseLayer { DataType* b2_; }; +// calibration factors and weights for INT8 +// (in GPU memory when int8_inf_ is set, otherwise in CPU memory) +struct MatMulQuantizationData { + int8_t* weights_int8; // int8 quantized weights + float* input_scaling_factors; // per-column scaling factors for input + // quantization + float* output_scaling_factors; // per-tensor scaling factors for output + // dequantization (always in CPU memory) + float* input_matrix_max_values; // max values of input matrix (always in CPU + // memory) +}; + + template class EncoderBlock { public: @@ -367,14 +380,12 @@ class EncoderBlock { DataType *smol_global; + // int 8 stuff bool int8_inf_, int8_cali_; - - // calibration factors and weights for INT8 (in GPU memory when int8_inf_ is set, otherwise in CPU memory) - int8_t* kqv_int8_; // int8 quantized weights for the KQV matrix multiplication - float* input_scaling_factors_; // scaling factors needed to quantize the inputs - float* output_scaling_factors_; // scaling factors needed to dequantize the outputs (just 3 floats: always in CPU memory) - float* input_matrix_max_values_; // max values of input matrix to KQV GEMM - + MatMulQuantizationData kqv_; + MatMulQuantizationData mha_dense_; + MatMulQuantizationData ffn1_; + MatMulQuantizationData ffn2_; int mha_q_size_; int mha_k_size_; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 9822eb6650..3c01f53053 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -370,18 +370,37 @@ class CudaNetwork : public Network { // Structure of the weights file: // For each encoder block - - // * per-channel scaling factors for Input Matrix to QKV GEMM + // * per-channel scaling factors for Input Matrix to QKV GEMM (embedding_op_size floats) // (to use for quantization of the input) - // * qunatized (int8) weights for QKV GEMMs + // * qunatized (int8) weights for QKV GEMMs (3 * encoder_d_model 
* embedding_op_size int8_ts) // * float factorQ, factorK, factorV - // (basically factors needed to de-quantize the output) TODO! - // (will add more as we try int8 for more layers) + // (basically factors needed to de-quantize the output) + // + // * per-channel scaling factors for the MHA dense layer's input (encoder_d_model floats) + // * Qunatized (int8) weights for MHA dense (embedding_op_size * encoder_d_model int8_ts) + // * per-tensor output scaling factor for MHA dense (single float) + // + // * per-channel scaling factors for input to FFN1 (embedding_op_size_ floats) + // * Qunatized (int8) weights for FFN1 (encoder_dff * encoder_d_model int8_ts) + // * single output scaling factor for FFN1 (single float) + // + // * per-channel scaling factors for input to FFN2 (encoder_dff floats) + // * Qunatized (int8) weights for FFN1 (encoder_d_model * encoder_dff int8_ts) + // * single output scaling factor for FFN1 (single float) int embedding_op_size = weights.ip_emb_b.size(); int encoder_d_model = weights.encoder[0].mha.q_b.size(); + int encoder_dff = weights.encoder[0].ffn.dense1_b.size(); int num_encoders = weights.encoder.size(); int8_weights_size_ = - num_encoders * ((embedding_op_size + 3) * sizeof(float) + - 3 * embedding_op_size * encoder_d_model * sizeof(int8_t)); + num_encoders * + (embedding_op_size * sizeof(float) + + 3 * embedding_op_size * encoder_d_model + 3 * sizeof(float) + + encoder_d_model * sizeof(float) + + embedding_op_size * encoder_d_model + sizeof(float) + + embedding_op_size * sizeof(float) + encoder_dff * encoder_d_model + + sizeof(float) + encoder_dff * sizeof(float) + + embedding_op_size * encoder_dff + sizeof(float)); + int8_weights_ = malloc(int8_weights_size_); memset(int8_weights_, 0, int8_weights_size_); From 451cc30763f5fab755f0b326f20d8476ae456cfc Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Thu, 25 May 2023 19:23:56 +0530 Subject: [PATCH 28/70] per-column quantization for outputs - somewhat working, but ~25 elo weaker than baseline :-/ --- src/neural/cuda/cutlass_kernels.cu | 441 +++++++++++++++++++++++++---- src/neural/cuda/layers.cc | 252 ++++++++++++++--- src/neural/cuda/layers.h | 10 +- src/neural/cuda/network_cuda.cc | 32 +-- 4 files changed, 625 insertions(+), 110 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index e826c13f34..3c904054f7 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -29,12 +29,18 @@ #include #include "winograd_helper.inc" +#include +#include + + #ifdef USE_CUTLASS #include "cutlass/gemm/device/gemm_array.h" #include "cutlass/gemm/device/gemm_batched.h" - +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" // Fused MHA implementation from cutlass example #41 #include "fused_multi_head_attention/kernel_forward.h" @@ -142,7 +148,6 @@ static float mean(float arr[], int n) { return sum / n; } - // function to calculate standard deviation static float stdDev(float arr[], int n) { float m = mean(arr, n); // get the mean @@ -165,7 +170,7 @@ __device__ __forceinline__ void copyAs(void* dst, const void* src) { // debug code to dump allocation in GPU memory template void dumpTensor(const T* memory, int elements, const char* message, - bool only_summary = false, bool cpu_tensor = false) { + bool only_summary = false, bool cpu_tensor = false) { const bool fp16 = std::is_same::value; const bool 
int8 = std::is_same::value; printf("\n%s\n", message); @@ -188,8 +193,7 @@ void dumpTensor(const T* memory, int elements, const char* message, if (int8) { int8_t* arr = (int8_t*)temp; val = (float)arr[i]; - } - else if (fp16) { + } else if (fp16) { half* arr = (half*)temp; val = (float)arr[i]; } else { @@ -229,6 +233,68 @@ void dumpTensor(const T* memory, int elements, const char* message, printf("\n"); } +// int8 GEMM using CUTLASS (with per-column output quantization) +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, + const float* scaleVector, int8_t* Out, int M, + int N, int K, int batchSize, int AStride, + int BStride, int OutStride, int VecStride, + float alphaf, float betaf) { + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementIO = int8_t; + using ElementScale = float; + using ThreadBlockSwizzle = + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + constexpr int elementsPerAccess = + 128 / cutlass::sizeof_bits::value; + + using EpilogueOutputOp = + cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementIO, ElementAccumulator, ElementComputeEpilogue, ElementIO, + ElementIO, + ElementScale, // element Vector + elementsPerAccess, false, + cutlass::multiplies>; + + using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast< + ElementIO, cutlass::layout::RowMajor, ElementIO, + cutlass::layout::ColumnMajor, ElementIO, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 32>, EpilogueOutputOp, ThreadBlockSwizzle, + 2>; + + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kBatched, + {M, N, K}, + batchSize, + {alphaf, betaf}, + A, + B, + nullptr, + Out, + (float*)scaleVector, + nullptr, + AStride, + BStride, + 0, + OutStride, + VecStride, // batch_stride_Vector + 0, + K, + K, + 0, + N, + 0, + 0}; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(arguments, nullptr); + status = gemm_op(); +} // int8 GEMM using CUTLASS void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, @@ -236,8 +302,8 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, int AStride, int BStride, int OutStride, float alphaf, float betaf) { // Ankan - For testing! 
- //dumpTensor(A, 512, "A after scaling", false); - //dumpTensor(B, 512, "B after scaling", false); + // dumpTensor(A, 512, "A after scaling", false); + // dumpTensor(B, 512, "B after scaling", false); using ElementAccumulator = int32_t; // <- data type of accumulator using ElementComputeEpilogue = float; // <- data type of epilogue operations @@ -292,9 +358,207 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, status = gemm_op(); } +void cutlassMatrixMulBTransposed_Emulate_INT8(const half* A, const half* B, + half* Out, int M, int N, int K, + int batchSize, int AStride, + int BStride, int OutStride, + bool useInt8) { + // emulate int8 quantization: + // * Copy inputs to CPU + // * Use smooth quant to figure out per-channel scaling factor for the input + // * compute quantized weights and scaling factors for A and B matrices + // * quantize the A and B matrices + // * multiply them on CPU + // * de-quantize the output back to fp16 + + int ASize = M * K; + int BSize = K * N * batchSize; + int OutSize = M * N * batchSize; + half* cpuA = (half*)malloc(ASize * sizeof(half)); + half* cpuB = (half*)malloc(BSize * sizeof(half)); + half* cpuOut = (half*)malloc(OutSize * sizeof(half)); + + int8_t* AInt8 = (int8_t*)malloc(ASize); + int8_t* BInt8 = (int8_t*)malloc(BSize); + int8_t* OutInt8 = (int8_t*)malloc(OutSize); + + cudaMemcpy(cpuA, A, ASize * sizeof(half), cudaMemcpyDeviceToHost); + cudaMemcpy(cpuB, B, BSize * sizeof(half), cudaMemcpyDeviceToHost); + + std::vector scaling_factors(K); + std::vector input_scaling_factors(K); // Not used here, but just for testing. + std::vector output_scaling_factors(N * batchSize); + + // apply smooth-quant (basically adjust A and B matrices to make + // quantization easier) + for (int k = 0; k < K; k++) { + float absMaxA = 0; + float absMaxB = 0; + // scan a column of Matrix A to find the abs max. + for (int y = 0; y < M; y++) { + float val = (float) cpuA[y * K + k]; + absMaxA = std::max(absMaxA, abs(val)); + } + + // scan a column of Matrix B (from each batch dimension) + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) { + float val = (float) cpuB[b * N * K + x * K + k]; + absMaxB = std::max(absMaxB, abs(val)); + } + + // compute scaling factor: + float s = sqrt(absMaxA / (absMaxB)); + + // sanity check, don't use too small, or too big scaling factors + if (s < 1) + s = 1.0f; // don't try to squeeze activations for improving range of + // weights! 
+ if (s > 10) s = 10.0f; + + scaling_factors[k] = s; + + // printf("\nMaxA: %f, MaxB: %f, scale: %f ", absMaxA, absMaxB, s); + + // scale A and B matrices using the scaling factor + for (int y = 0; y < M; y++) { + float val = (float) cpuA[y * K + k]; + val /= s; + cpuA[y * K + k] = (half) val; + } + + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) { + float val = (float) cpuB[b * N * K + x * K + k]; + val *= s; + cpuB[b * N * K + x * K + k] = (half) val; + } + } + + // figure out scaling factors for A and B matrices + float absMaxA = 0; + for (int i = 0; i < M * K; i++) { + float val = (float) cpuA[i]; + absMaxA = std::max(absMaxA, abs(val)); + } + + float AFactor = 127.0 / absMaxA; + + // update the scaling factors based on global max for Activation matrix + for (int i = 0; i < K; i++) { + input_scaling_factors[i] = 127.0f / (scaling_factors[i] * absMaxA); + } + + std::vector BFactor(batchSize); + for (int b = 0; b < batchSize; b++) { + float absMaxB = 0; + for (int i = 0; i < K * N; i++) { + float val = cpuB[i + b * K * N]; + absMaxB = std::max(absMaxB, abs(val)); + } + + // quantize the weights + float scaleB = 127.0f / absMaxB; + BFactor[b] = scaleB; + for (int i = 0; i < K * N; i++) { + float val = (float) cpuB[i + b * K * N]; + // quantize and clamp + val = (val * scaleB); + if (val > 127) val = 127; + if (val < -128) val = -128; + BInt8[i + b * K * N] = (int8_t)roundf(val); + } + } + + // quantize input activation matrix (A) + for (int i = 0; i < M * K; i++) { + float val = (float)cpuA[i]; + val = (val * AFactor); + if (val > 127) val = 127; + if (val < -128) val = -128; + AInt8[i] = (int8_t)roundf(val); + } + + + // output scaling factors + // multiply the matrices to figure out range of values in output matrix for + // per-channel quantization + for (int b = 0; b < batchSize; b++) { + for (int x = 0; x < N; x++) { + float colAbsMax = 0; + for (int y = 0; y < M; y++) { + int s = 0; + for (int k = 0; k < K; k++) { + int v1 = AInt8[y * K + k]; + int v2 = BInt8[b * K * N + x * K + k]; + s += v1 * v2; + } + colAbsMax = std::max(colAbsMax, (float)abs(s)); + } + float outFactor = colAbsMax ? 
(127.0f / colAbsMax) : 1; + output_scaling_factors[b * N + x] = outFactor; + } + } + + std::vector output_deq_factors(batchSize); + for (int i = 0; i < batchSize; i++) + output_deq_factors[i] = 1.0f / (AFactor * BFactor[i]); + + // Actually multiply the int8 matrices and apply per-column scaling to store int8 result + for (int b = 0; b < batchSize; b++) { + for (int x = 0; x < N; x++) { + for (int y = 0; y < M; y++) { + int s = 0; + for (int k = 0; k < K; k++) { + int v1 = AInt8[y * K + k]; + int v2 = BInt8[b * K * N + x * K + k]; + s += v1 * v2; + } + OutInt8[b * M * N + N * y + x] = + (int8_t) roundf(s * output_scaling_factors[b * N + x]); + } + } + } + + + // dequantize the output matrix + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) + for (int y = 0; y < M; y++) { + float val = (float) OutInt8[b * M * N + N * y + x]; + val /= output_scaling_factors[b * N + x]; + val *= output_deq_factors[b]; + cpuOut[b * M * N + N * y + x] = (half)val; + } + + + // dump the inputs and outputs for debugging - Ankan + dumpTensor(&input_scaling_factors[0], 768, "input_scaling_factors", false, true); + dumpTensor(AInt8, 768, "input quantized", false, true); + dumpTensor(BInt8, 768, "weight quantized", false, true); + dumpTensor(OutInt8, 768, "output quantized", false, true); + dumpTensor(&output_scaling_factors[0], 768, "output_scaling_factors", + false, true); + //dumpTensor(cpuOut, 768, "dequantized output", false, true); + //exit(0); + + cudaMemcpy(Out, cpuOut, OutSize * sizeof(half), cudaMemcpyHostToDevice); + + free(cpuA); + free(cpuB); + free(cpuOut); + + free(AInt8); + free(BInt8); + free(OutInt8); +} + // FP16 GEMM using cutlass void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, - int N, int K, int batchSize, int AStride, int BStride, int OutStride, bool useInt8) { + int N, int K, int batchSize, int AStride, + int BStride, int OutStride, bool useInt8) { + if (useInt8) + return cutlassMatrixMulBTransposed_Emulate_INT8(A, B, Out, M, N, K, batchSize, AStride, BStride, OutStride, useInt8); half halfOne = (half)1.0f; half halfZero = (half)0.0f; @@ -342,7 +606,7 @@ void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, // precision, it's 8 elements. This // becomes the vector width of math // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator + ElementAccumulator, // <- data type of accumulator float>; // <- data type for alpha/beta in linear combination function constexpr int NumStages = 3; // stages == 2/4 is also good sometimes @@ -368,11 +632,15 @@ void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, } +int8_t values[1024 * 64]; static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, float* cpuA, - float* cpuB, int M, int N, int K, int batchSize) { + float* output_scaling_factors, float *output_deq_factors, float* cpuA, + float* cpuB, float *maxValuesA, float *maxValuesOut, int M, int N, int K, int batchSize) { std::vector scaling_factors(K); + std::vector A_Max(M * K); // this is another matrix we use to track calculations using max values + for (int i = 0; i < M * K; i++) A_Max[i] = maxValuesA[i]; + // apply smooth-quant (basically adjust A and B matrices to make quantization // easier) for (int k = 0; k < K; k++) { @@ -380,7 +648,7 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, float absMaxB = 0; // scan a column of Matrix A to find the abs max. 
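// Worked illustration of this smooth-quant pass with made-up numbers: if a column
// of A peaks at 8.0 while the matching column of B peaks at 0.5, then
//   s = sqrt(8.0 / 0.5) = 4.0
// so the A column is divided by 4 and the B column multiplied by 4, bringing both
// closer to the same dynamic range before int8 rounding (the emulation path above
// clamps s to [1, 10] for the same reason).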
for (int y = 0; y < M; y++) { - float val = cpuA[y * K + k]; + float val = A_Max[y * K + k]; absMaxA = std::max(absMaxA, abs(val)); } @@ -406,23 +674,23 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, // scale A and B matrices using the scaling factor for (int y = 0; y < M; y++) { - float val = cpuA[y * K + k]; + float val = A_Max[y * K + k]; val /= s; - cpuA[y * K + k] = (half)val; + A_Max[y * K + k] = val; } for (int b = 0; b < batchSize; b++) for (int x = 0; x < N; x++) { float val = cpuB[b * N * K + x * K + k]; val *= s; - cpuB[b * N * K + x * K + k] = (half)val; + cpuB[b * N * K + x * K + k] = val; } } // figure out scaling factors for A and B matrices float absMaxA = 0; for (int i = 0; i < M * K; i++) { - float val = cpuA[i]; + float val = A_Max[i]; absMaxA = std::max(absMaxA, abs(val)); } @@ -455,11 +723,66 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, } // output scaling factors + // multiply the matrices to figure out range of values in output matrix for per-channel quantization + for (int b = 0; b < batchSize; b++) { + for (int x = 0; x < N; x++) { + float colAbsMax = 0; + for (int y = 0; y < M; y++) { + int s = 0; + for (int k = 0; k < K; k++) { + int v1 = (int)roundf(input_scaling_factors[k] * cpuA[y * K + k]); + int v2 = weights_int8[b * K * N + x * K + k]; + s += v1 * v2; + } + float m = maxValuesOut[b * M * N + y * N + x]; + m = std::max(m, abs((float)s)); + maxValuesOut[b * M * N + y * N + x] = m; // also update the abs max value found in output matrix + colAbsMax = std::max(colAbsMax, m); + } + float outFactor = colAbsMax ? (127.0f / colAbsMax) : 1; + output_scaling_factors[b * N + x] = outFactor; + } + } + for (int i = 0; i < batchSize; i++) - output_scaling_factors[i] = 127.0 / (AFactor * BFactor[i]); + output_deq_factors[i] = 1.0f / (AFactor * BFactor[i]); + +#if 0 + // Ankan - For debug - print the quantized expected values of output matrix + for (int b = 0; b < batchSize; b++) + for (int x = 0; x < N; x++) + for (int y = 0; y < M; y++) { + float s = 0; + for (int k = 0; k < K; k++) { + float v1 = input_scaling_factors[k] * cpuA[y * K + k]; + float v2 = (float)(weights_int8[b * K * N + x * K + k]); + s += v1 * v2; + } + values[b * M * N + y * N + x] = + (int8_t)roundf(s * output_scaling_factors[b * N + x]); + } + + for (int y = 0; y < M; y++) + for (int k = 0; k < K; k++) + cpuA[i] *= input_scaling_factors[k]; + + + dumpTensor(cpuA, 768, "input matrix during calibration", + false, true); + + dumpTensor(weights_int8, 768, "weights - during calibration", false, + true); + + + dumpTensor(output_scaling_factors, 768, "output_scaling_factors", + false, true); + + dumpTensor(values, 768, "output during quantization", false, true); + exit(0); +#endif // Ankan - for debug/test - //printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", + // printf("\nScaling factors - A: %g, B_Q: %g, B_K: %g, B_V: %g \n", // 127.0 / absMaxA, BFactor[0], BFactor[1], BFactor[2]); } @@ -474,7 +797,8 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, // maxValuesA contains the max values in activation matrix found so far template void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, float* maxValuesA, + float* output_scaling_factors, + float* output_deq_factors, float* maxValuesA, float *maxValuesOut, const DataType* A, const DataType* B, int M, int N, int K, int batchSize, int M_Batch) { auto cpuA = (DataType*)malloc(M_Batch * M * K * 
sizeof(DataType)); @@ -494,15 +818,16 @@ void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, for (int b = 0; b < M_Batch; b++) { for (int i = 0; i < M * K; i++) { - float val = abs((float)cpuA[b * M * K + i]); - val = std::max(val, maxValuesA[i]); + float val = (float)cpuA[b * M * K + i]; fpA[i] = val; + val = std::max(abs(val), maxValuesA[i]); maxValuesA[i] = val; // update the max activation matrix } // calibrate a single sample calibrateGemm(weights_int8, input_scaling_factors, output_scaling_factors, - fpA, fpB, M, N, K, batchSize); + output_deq_factors, fpA, fpB, maxValuesA, maxValuesOut, M, N, K, + batchSize); } free(fpA); @@ -511,7 +836,6 @@ void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, free(cpuB); } - // process 8 elements per thread (in x dimension) __global__ void quantizeMatrix(int8_t* output, const half* input, int height, int width, const float* scale) { @@ -526,7 +850,7 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, copyAs(&ip[0], &input[y * width + x]); copyAs(&factor[0], &scale[x]); - copyAs(&factor[4], &scale[x+4]); + copyAs(&factor[4], &scale[x + 4]); for (int i = 0; i < 8; i++) { float val = roundf((float)ip[i] * factor[i]); @@ -538,11 +862,10 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, copyAs(&output[y * width + x], &op[0]); } - // The scale is per column void quantizeActivationMatrix(int8_t* output, const half* input, int height, - int width, const float* scale, cudaStream_t stream) { - + int width, const float* scale, + cudaStream_t stream) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16)); @@ -551,7 +874,6 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, ReportCUDAErrors(cudaGetLastError()); } - #define MAX_BATCH_DEQUANT 16 struct ScaleParam { @@ -559,73 +881,82 @@ struct ScaleParam { }; // process 8 elements per thread (in x dimension) -__global__ void deQuantizeMatrix(half* output, const int8_t* input, const half *bias, int height, int width, int stride, ScaleParam s, ActivationFunction act) { +__global__ void deQuantizeMatrix(half* output, const int8_t* input, + const half* bias, int height, int width, + int stride, const float *invScale, ScaleParam deq, + ActivationFunction act) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; int b = blockIdx.z; if (x >= width || y >= height) return; - float factor = s.scale[b]; - int8_t ip[8] = {}; half op[8] = {}; half bi[8] = {}; + float inv_scale[8]; + float deq_scale = deq.scale[b]; copyAs(&ip[0], &input[b * stride + y * width + x]); - if (bias) - copyAs(&bi[0], &bias[b * width + x]); + if (bias) copyAs(&bi[0], &bias[b * width + x]); + + if (invScale) { + copyAs(&inv_scale[0], &invScale[b * width + x]); + copyAs(&inv_scale[4], &invScale[b * width + x + 4]); + } else { + for (int i = 0; i < 8; i++) inv_scale[i] = 1 / 127.0f; + } for (int i = 0; i < 8; i++) { float val = (float)ip[i]; - val *= factor; + val *= (deq_scale / inv_scale[i]); if (bias) val += (float)bi[i]; - op[i] = (half) activate(val, act); + op[i] = (half)activate(val, act); } copyAs(&output[b * stride + y * width + x], &op[0]); } - - // the scale (in CPU memory) is per "batch" // the bias is per column, per batch void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, - float* scale, const half* bias, + float* invScale, float 
*deq, const half* bias, cudaStream_t stream, ActivationFunction act = NONE) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16), batchSize); - assert(batchSize < MAX_BATCH_DEQUANT); // otherwise we will need to put them in GPU memory + // otherwise we will need to put them in GPU memory + assert(batchSize < MAX_BATCH_DEQUANT); int stride = width * height; ScaleParam s = {}; - for (int i = 0; i < batchSize; i++) s.scale[i] = scale[i]; + for (int i = 0; i < batchSize; i++) s.scale[i] = deq[i]; - deQuantizeMatrix<<>>(output, input, bias, height, width, stride, s, act); + deQuantizeMatrix<<>>( + output, input, bias, height, width, stride, invScale, s, act); ReportCUDAErrors(cudaGetLastError()); - } +void fillGpuArray(float* arr, float val, int count) { + thrust::device_ptr dev_ptr(arr); + thrust::fill(dev_ptr, dev_ptr + count, val); +} - -template void calibrateGemmForInt8(int8_t* weights_int8, - float* input_scaling_factors, - float* output_scaling_factors, - float* maxValuesA, const float* A, - const float* B, int M, int N, int K, - int batchSize, int M_Batch); -template void calibrateGemmForInt8(int8_t* weights_int8, - float* input_scaling_factors, - float* output_scaling_factors, - float* maxValuesA, const half* A, - const half* B, int M, int N, int K, - int batchSize, int M_Batch); +template void calibrateGemmForInt8( + int8_t* weights_int8, float* input_scaling_factors, + float* output_scaling_factors, float* output_deq_factors, float* maxValuesA, + float* maxValuesOut, const float* A, const float* B, int M, int N, int K, + int batchSize, int M_Batch); +template void calibrateGemmForInt8( + int8_t* weights_int8, float* input_scaling_factors, + float* output_scaling_factors, float* output_deq_factors, float* maxValuesA, + float* maxValuesOut, const half* A, const half* B, int M, int N, int K, + int batchSize, int M_Batch); template void dumpTensor(const float* memory, int elements, const char* message, bool only_summary, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index c1db9205d7..81323769a6 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -38,7 +38,14 @@ namespace lczero { -#if 1 +namespace cudnn_backend { +template +void dumpTensor(const T* memory, int elements, const char* message, + bool only_summary = false, bool cpu_tensor = false); +} + + +#if 0 #include using namespace std; @@ -1480,6 +1487,7 @@ AttentionPolicyHead::AttentionPolicyHead(BaseLayer* ip, } } +void fillGpuArray(float *arr, float val, int count); static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, int InputCols, int OutputCols, int outBatch, bool cali) { size_t matrix_size = InputCols * OutputCols * sizeof(int8_t) * outBatch; @@ -1501,10 +1509,16 @@ static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, int // go to output scaling factors w += matrix_size; - data.output_scaling_factors = (float*)w; - + ReportCUDAErrors(cudaMalloc(&data.output_scaling_factors, + sizeof(float) * OutputCols * outBatch)); + ReportCUDAErrors(cudaMemcpy(data.output_scaling_factors, w, + sizeof(float) * OutputCols * outBatch, + cudaMemcpyHostToDevice)); + // go to output dequantization factors + w += outBatch * OutputCols * sizeof(float); + data.output_deq_factors = (float*) w; - // go to next entry + // go to next item w += outBatch * sizeof(float); } else { // Just save the pointers to CPU weights (we will over-write here during calibration) @@ -1513,12 +1527,18 @@ static 
int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, int data.weights_int8 = w; w += matrix_size; data.output_scaling_factors = (float*)w; + w += outBatch * OutputCols * sizeof(float); + data.output_deq_factors = (float*) w; w += outBatch * sizeof(float); - // to keep track of max values in input activation matrix + // to keep track of max values in activation matrices int InputMatrixSizeForBatch1 = 64 * InputCols * sizeof(float); data.input_matrix_max_values = (float*)malloc(InputMatrixSizeForBatch1); memset(data.input_matrix_max_values, 0, InputMatrixSizeForBatch1); + + int OutputMatrixSizeForBatch1 = 64 * OutputCols * sizeof(float) * outBatch; + data.output_matrix_max_values = (float*)malloc(OutputMatrixSizeForBatch1); + memset(data.output_matrix_max_values, 0, OutputMatrixSizeForBatch1); } // return pointer to next item @@ -1615,7 +1635,9 @@ EncoderBlock::EncoderBlock( smol_global = smolgen_global_scratch; } // int8 stuff + blockIndex_ = blockIndex; if (int8_inference || int8_calibrate) { + /* int per_encoder_size = embedding_op_size_ * sizeof(float) + 3 * embedding_op_size_ * mha_q_size_ + 3 * sizeof(float) + mha_q_size_ * sizeof(float) + @@ -1624,6 +1646,15 @@ EncoderBlock::EncoderBlock( ffn_dense1_size_ * mha_q_size_ + sizeof(float) + ffn_dense1_size_ * sizeof(float) + embedding_op_size_ * ffn_dense1_size_ + sizeof(float); + */ + int embedding_op_size = embedding_op_size_; + int encoder_d_model = mha_q_size_; + int encoder_dff = ffn_dense1_size_; + int per_encoder_size = + (embedding_op_size * sizeof(float) + 3 * embedding_op_size * encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + + encoder_d_model * sizeof(float) + encoder_d_model * embedding_op_size + (embedding_op_size + 1) * sizeof(float) + + embedding_op_size * sizeof(float) + embedding_op_size * encoder_dff + (encoder_dff + 1) * sizeof(float) + + encoder_dff * sizeof(float) + encoder_dff * embedding_op_size + (embedding_op_size + 1) * sizeof(float)); auto w = (int8_t*)int8_weights; // go to current encoder block @@ -1631,7 +1662,8 @@ EncoderBlock::EncoderBlock( w = SetQuantizationData(kqv_, w, embedding_op_size_, mha_q_size_, 3, int8_calibrate); w = SetQuantizationData(mha_dense_, w, mha_q_size_, embedding_op_size_, 1, int8_calibrate); w = SetQuantizationData(ffn1_, w, embedding_op_size_, ffn_dense1_size_, 1, int8_calibrate); - SetQuantizationData(ffn2_, w, ffn_dense1_size_, embedding_op_size_, 1, int8_calibrate); + w = SetQuantizationData(ffn2_, w, ffn_dense1_size_, embedding_op_size_, 1, int8_calibrate); + // printf("\nSize of weights: %d\n", (w - (int8_t*)int8_weights)); // print some weights /* @@ -1709,8 +1741,9 @@ static void cublasXGemmBatched( template void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, float* maxValuesA, - const DataType* A, const DataType* B, int M, int N, + float* output_scaling_factors, + float* output_deq_factors, float* maxValuesA, + float* maxValuesOut, const DataType* A, const DataType* B, int M, int N, int K, int batchSize, int M_Batch); void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, @@ -1718,13 +1751,19 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, int AStride, int BStride, int OutStride, float alphaf, float betaf); +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, + const float* scaleVector, int8_t* Out, int M, + int N, int K, int batchSize, int AStride, + int BStride, int OutStride, int VecStride, + float alphaf, 
float betaf); + void quantizeActivationMatrix(int8_t* output, const half* input, int height, int width, const float* scale, cudaStream_t stream); void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, - float* scale, const half* bias, + float* scale, float *deq, const half* bias, cudaStream_t stream, ActivationFunction act = NONE); // input/output tensor is scratch1, others are used as scratch. @@ -1822,11 +1861,13 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, if (int8_cali_) { calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, - kqv_.output_scaling_factors, kqv_.input_matrix_max_values, scratch1, mha_qkv_w, 64, - d_model, embedding_op_size_, 3, N); + kqv_.output_scaling_factors, kqv_.output_deq_factors, + kqv_.input_matrix_max_values, + kqv_.output_matrix_max_values, scratch1, + mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); } - if (int8_inf_) { + if (true && int8_inf_) { // printf("\nAttempting int8_inf\n"); // 1. quantize the inputs (scratch1 -> scratch0) // TODO: Fuse this step with layer-norm of previous block @@ -1835,35 +1876,68 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, stream); // 2. perform int8 GEMM (scratch0 -> scratch2) + /* cutlassMatrixMulBTransposed( (const int8_t*)scratch0, kqv_.weights_int8, (int8_t*)scratch2, batch, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); + */ + // per-layer output scaling + + cutlassMatrixMulBTransposed( + (const int8_t*)scratch0, kqv_.weights_int8, + kqv_.output_scaling_factors, (int8_t*)scratch2, batch, + num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + /* + dumpTensor((const int8_t*)scratch0, 768, "quantized input matrix", + false, false); + + dumpTensor(kqv_.weights_int8, 768, + "weights - during run", false, false); + + dumpTensor((const int8_t*)scratch2, 768, + "some quantized output values", false, false); + dumpTensor(kqv_.output_scaling_factors, 768, + "output_scaling_factors - during run", + false, false); + */ // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch0) // TODO: fuse the entire thing with the above GEMM. 
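// Sketch of the per-column scheme (for orientation; inferred from the calibration
// code above): calibration stores outScale[j] = 127 / colAbsMax[j] per output column
// and deq = 1 / (AFactor * BFactor) per GEMM in the batch. The GEMM epilogue scales
// the int32 accumulator by outScale[j] before saturating to int8, so the kernel below
// recovers the result as
//   y[r][j] = q_y[r][j] * (deq / outScale[j]) + bias[j]
// which cancels both factors and approximates the fp16 GEMM output.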
deQuantizeOutputMatrixBiasAdd((half*)scratch0, (const int8_t*)scratch2, batch, num_outputs, 3, - kqv_.output_scaling_factors, (const half*)mha_qkv_b, stream); - } else { + kqv_.output_scaling_factors, + kqv_.output_deq_factors, (const half*) mha_qkv_b, stream); + + /* + dumpTensor((const half*)scratch0, 768, + "dequantized output values after bias add", false, false); + exit(0); + */ + } else { + cublasXGemmStridedBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_qkv_w, num_inputs, num_inputs * num_outputs, scratch1, num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, 3); - /* +#if 0 cutlassMatrixMulBTransposed((const half*)scratch1, (const half*)mha_qkv_w, - (half*) mha_q, batch, - num_outputs, num_inputs, 3, - 0, num_inputs * num_outputs, - num_outputs * batch_to_use); - */ - // dumpTensor(mha_q, num_outputs * N, "output of kqv gemm", false); - // exit(0); - + (half*)mha_q, batch_to_use, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, true); + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, NONE, stream); - } + + dumpTensor((const DataType*)mha_q, + /*num_outputs * batch_to_use*/ 768, + "ref output values after bias add", false, false); + exit(0); +#endif + } } // Apply split_heads() to q, k and v @@ -1959,7 +2033,9 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, if (int8_cali_) { calibrateGemmForInt8( mha_dense_.weights_int8, mha_dense_.input_scaling_factors, - mha_dense_.output_scaling_factors, mha_dense_.input_matrix_max_values, + mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, + mha_dense_.input_matrix_max_values, + mha_dense_.output_matrix_max_values, scratch3, mha_dense_w, 64, embedding_op_size_, d_model, 1, N); } @@ -1967,7 +2043,7 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, const int num_outputs = embedding_op_size_; const int batch = N * 64; - if (int8_inf_) { + if (true && int8_inf_) { // 1. quantize the inputs (scratch3 -> scratch0) // TODO: Fuse this step with the previous fused MHA quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch3, batch, @@ -1975,20 +2051,35 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, stream); // 2. perform int8 GEMM (scratch0 -> scratch3) + /* cutlassMatrixMulBTransposed( (const int8_t*)scratch0, mha_dense_.weights_int8, (int8_t*)scratch3, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + */ + + cutlassMatrixMulBTransposed( + (const int8_t*)scratch0, mha_dense_.weights_int8, + mha_dense_.output_scaling_factors, (int8_t*)scratch3, + batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + + ReportCUDAErrors(cudaGetLastError()); // 3. de-quantize outputs (scratch3 -> scratch2) // TODO: Fuse this with LN1 (should be easy!) 
deQuantizeOutputMatrixBiasAdd( (half*)scratch2, (const int8_t*)scratch3, batch, num_outputs, 1, - mha_dense_.output_scaling_factors, nullptr, stream); + mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, + nullptr, stream); } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_dense_w, num_inputs, scratch3, num_inputs, 0.0f, scratch2, num_outputs); + /* + cutlassMatrixMulBTransposed( + (const half*)scratch3, (const half*)mha_dense_w, (half*)scratch2, + batch, num_outputs, num_inputs, 1, 0, 0, 0, true); + */ } } @@ -2004,39 +2095,82 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, if (int8_cali_) { calibrateGemmForInt8( ffn1_.weights_int8, ffn1_.input_scaling_factors, - ffn1_.output_scaling_factors, ffn1_.input_matrix_max_values, scratch1, - ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); + ffn1_.output_scaling_factors, ffn1_.output_deq_factors, + ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, scratch0, + ffn_dense1_w, 64, + encoder_dff, embedding_op_size_, 1, N); } const int num_inputs = embedding_op_size_; const int num_outputs = encoder_dff; const int batch = N * 64; - if (false && int8_inf_) { // Ankan - test! .. enabling this one kills accuracy :-/ + if (true && int8_inf_) { // 1. quantize the inputs (scratch0 -> scratch1) // TODO: Fuse this step with LN1 (should be easy) quantizeActivationMatrix((int8_t*)scratch1, (const half*)scratch0, batch, num_inputs, ffn1_.input_scaling_factors, stream); // 2. perform int8 GEMM (scratch1 -> scratch2) + /* cutlassMatrixMulBTransposed((const int8_t*)scratch1, ffn1_.weights_int8, (int8_t*)scratch2, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + */ + cutlassMatrixMulBTransposed((const int8_t*)scratch1, ffn1_.weights_int8, + ffn1_.output_scaling_factors, + (int8_t*)scratch2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + /* + dumpTensor(ffn1_.input_scaling_factors, 768, + "input_scaling_factors - during run", false, false); + + dumpTensor((const int8_t*)scratch1, 768, "quantized input matrix", + false, false); + + dumpTensor(ffn1_.weights_int8, 768, + "weights - during run", false, false); + + dumpTensor((const int8_t*)scratch2, 768, + "some quantized output values", false, false); + dumpTensor(ffn1_.output_scaling_factors, 768, + "output_scaling_factors - during run", + false, false); + */ // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch1) // TODO: Fuse this with the above GEMM deQuantizeOutputMatrixBiasAdd( (half*)scratch1, (const int8_t*)scratch2, batch, num_outputs, 1, - ffn1_.output_scaling_factors, (const half*)ffn_dense1_b, stream, + ffn1_.output_scaling_factors, ffn1_.output_deq_factors, + (const half*)ffn_dense1_b, stream, has_smolgen_ ? RELU_2 : act); + + // Ankan - test! + //dumpTensor((const DataType*)scratch1, 768, + // "runtime output values after bias and RELU2", false, false); + //exit(0); + } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, ffn_dense1_w, num_inputs, scratch0, num_inputs, 0.0f, scratch1, num_outputs); + /* + cutlassMatrixMulBTransposed( + (const half*)scratch0, (const half*)ffn_dense1_w, (half*)scratch1, + batch, num_outputs, num_inputs, 1, 0, 0, 0, true); + */ addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, has_smolgen_ ? RELU_2 : act, stream); // @todo sqr relu to have its own flag + + // Ankan - test! 
+ //dumpTensor((const DataType*)scratch1, 768, + // "Ref output values after bias and RELU2", false, + // false); + //exit(0); } } @@ -2045,35 +2179,75 @@ void EncoderBlock::Eval(int N, DataType* scratch1, DataType* scratch0, if (int8_cali_) { calibrateGemmForInt8( ffn2_.weights_int8, ffn2_.input_scaling_factors, - ffn2_.output_scaling_factors, ffn2_.input_matrix_max_values, scratch1, + ffn2_.output_scaling_factors, ffn2_.output_deq_factors, + ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, scratch1, ffn_dense2_w, 64, embedding_op_size_, encoder_dff, 1, N); } const int num_inputs = encoder_dff; const int num_outputs = embedding_op_size_; const int batch = N * 64; - if (int8_inf_) { + if (true && int8_inf_) { // 1. quantize the inputs (scratch1 -> scratch2) // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) quantizeActivationMatrix((int8_t*)scratch2, (const half*)scratch1, batch, num_inputs, ffn2_.input_scaling_factors, stream); + /* + dumpTensor((const float*)ffn2_.input_scaling_factors, 768, + "input scaling factors during run", + false, false); + + dumpTensor((const int8_t*)scratch2, 768, "quantized input matrix", + false, false); + dumpTensor(ffn2_.weights_int8, 768, "weights - during run", false, + false); + */ // 2. perform int8 GEMM (scratch2 -> scratch1) + /* cutlassMatrixMulBTransposed((const int8_t*)scratch2, ffn2_.weights_int8, (int8_t*)scratch1, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + */ + + cutlassMatrixMulBTransposed((const int8_t*)scratch2, ffn2_.weights_int8, ffn2_.output_scaling_factors, + (int8_t*)scratch1, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + /* + dumpTensor((const int8_t*)scratch1, 768, + "some quantized output values", false, false); + dumpTensor(ffn1_.output_scaling_factors, 768, + "output_scaling_factors - during run", false, false); + */ // 3. 
de-quantize outputs (scratch1 -> scratch2) // TODO: Fuse this with LN2 (should be easy) deQuantizeOutputMatrixBiasAdd( (half*)scratch2, (const int8_t*)scratch1, batch, num_outputs, 1, - ffn2_.output_scaling_factors, nullptr, stream); + ffn2_.output_scaling_factors, + ffn2_.output_deq_factors, nullptr, stream); + /* + dumpTensor((const half*)scratch2, 768, "dequantized output values", + false, false); + exit(0);*/ } else { - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, scratch1, num_inputs, 0.0f, scratch2, num_outputs); + /* + cutlassMatrixMulBTransposed( + (const half*)scratch1, (const half*)ffn_dense2_w, (half*)scratch2, + batch, num_outputs, num_inputs, 1, 0, 0, 0, true); + */ + /* + dumpTensor((const half*)scratch2, 768, "dequantized output values - ref", + false, false); + + exit(0); + */ } } @@ -2212,17 +2386,25 @@ EncoderBlock::~EncoderBlock() { if (int8_inf_) { ReportCUDAErrors(cudaFree(kqv_.weights_int8)); ReportCUDAErrors(cudaFree(kqv_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(kqv_.output_scaling_factors)); ReportCUDAErrors(cudaFree(mha_dense_.weights_int8)); ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); ReportCUDAErrors(cudaFree(ffn1_.weights_int8)); ReportCUDAErrors(cudaFree(ffn1_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(ffn1_.output_scaling_factors)); ReportCUDAErrors(cudaFree(ffn2_.weights_int8)); ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); } else if (int8_cali_) { free(kqv_.input_matrix_max_values); + free(kqv_.output_matrix_max_values); free(mha_dense_.input_matrix_max_values); + free(mha_dense_.output_matrix_max_values); free(ffn1_.input_matrix_max_values); + free(ffn1_.output_matrix_max_values); free(ffn2_.input_matrix_max_values); + free(ffn2_.output_matrix_max_values); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 12b9f3f0c2..3d96d955a1 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -338,10 +338,11 @@ struct MatMulQuantizationData { int8_t* weights_int8; // int8 quantized weights float* input_scaling_factors; // per-column scaling factors for input // quantization - float* output_scaling_factors; // per-tensor scaling factors for output - // dequantization (always in CPU memory) - float* input_matrix_max_values; // max values of input matrix (always in CPU - // memory) + float* output_scaling_factors; // per-column scaling factors for output + // dequantization + float* output_deq_factors; // per-tensor. 
Always in cpu memory (passed as constants to dequantization kernels) + float* input_matrix_max_values; // max values of input matrix (always in CPU memory) + float* output_matrix_max_values; // max values in output matrix (always in CPU memory) }; @@ -381,6 +382,7 @@ class EncoderBlock { // int 8 stuff + int blockIndex_; bool int8_inf_, int8_cali_; MatMulQuantizationData kqv_; MatMulQuantizationData mha_dense_; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 3c01f53053..489f35fb6f 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -373,33 +373,33 @@ class CudaNetwork : public Network { // * per-channel scaling factors for Input Matrix to QKV GEMM (embedding_op_size floats) // (to use for quantization of the input) // * qunatized (int8) weights for QKV GEMMs (3 * encoder_d_model * embedding_op_size int8_ts) - // * float factorQ, factorK, factorV - // (basically factors needed to de-quantize the output) + // * per-channel scaling factors for quantizing the Outut matrix (encoder_d_model * 3 floats) + // * per-tensor output dequantization factors (3 floats) // // * per-channel scaling factors for the MHA dense layer's input (encoder_d_model floats) // * Qunatized (int8) weights for MHA dense (embedding_op_size * encoder_d_model int8_ts) - // * per-tensor output scaling factor for MHA dense (single float) + // * per-channel output scaling factors for MHA dense (embedding_op_size floats) + // * per-tensor output dequantization factor (1 float) // // * per-channel scaling factors for input to FFN1 (embedding_op_size_ floats) // * Qunatized (int8) weights for FFN1 (encoder_dff * encoder_d_model int8_ts) - // * single output scaling factor for FFN1 (single float) + // * per-channel output scaling factors for FFN1 (encoder_dff floats) + // * per-tensor output dequantization factor (1 float) // // * per-channel scaling factors for input to FFN2 (encoder_dff floats) - // * Qunatized (int8) weights for FFN1 (encoder_d_model * encoder_dff int8_ts) - // * single output scaling factor for FFN1 (single float) + // * Qunatized (int8) weights for FFN2 (embedding_op_size * encoder_dff int8_ts) + // * per-channel output scaling factors for FFN2 (embedding_op_size floats) + // * per-tensor output dequantization factor (1 float) int embedding_op_size = weights.ip_emb_b.size(); int encoder_d_model = weights.encoder[0].mha.q_b.size(); int encoder_dff = weights.encoder[0].ffn.dense1_b.size(); int num_encoders = weights.encoder.size(); int8_weights_size_ = num_encoders * - (embedding_op_size * sizeof(float) + - 3 * embedding_op_size * encoder_d_model + 3 * sizeof(float) + - encoder_d_model * sizeof(float) + - embedding_op_size * encoder_d_model + sizeof(float) + - embedding_op_size * sizeof(float) + encoder_dff * encoder_d_model + - sizeof(float) + encoder_dff * sizeof(float) + - embedding_op_size * encoder_dff + sizeof(float)); + (embedding_op_size * sizeof(float) + 3 * embedding_op_size * encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + + encoder_d_model * sizeof(float) + encoder_d_model * embedding_op_size + (embedding_op_size + 1) * sizeof(float) + + embedding_op_size * sizeof(float) + embedding_op_size * encoder_dff + (encoder_dff + 1) * sizeof(float) + + encoder_dff * sizeof(float) + encoder_dff * embedding_op_size + (embedding_op_size + 1) * sizeof(float)); int8_weights_ = malloc(int8_weights_size_); memset(int8_weights_, 0, int8_weights_size_); @@ -413,8 +413,8 @@ class CudaNetwork : public Network { FILE* fp = 
fopen("weights_quant.bin", "rb"); if (!fp) { CERR << "ERROR: weights_quant.bin not found. Please run 'lc0 benchmark " - "-t 1 --nodes=1 -w --backend=cuda-fp16 " - "--backend-opts=int8-calibrate' first"; + "-t 1 --nodes=1 -w --backend=cuda " + "--backend-opts=int8-calibrate=true' first"; throw Exception("Quantized weights not found"); } else { int read = fread(int8_weights_, 1, int8_weights_size_, fp); @@ -437,7 +437,7 @@ class CudaNetwork : public Network { w += 3 * weights.ip_emb_b.size() * weights.encoder[0].mha.q_b.size() * sizeof(int8_t); - dumpTensor((float*)w, 3, "scaling factors for output", false, true); + dumpTensor((float*)w, 768, "scaling factors for output", false, true); exit(0); #endif From b9e671162436fa8f776eb64451de4765425c004f Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Thu, 25 May 2023 19:52:28 +0530 Subject: [PATCH 29/70] integrate changes from master for bigger layer norms --- src/neural/cuda/common_kernels.cu | 166 ++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 11 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index f1ac91637c..cd49e04d85 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1038,6 +1038,150 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const } } + +template +__global__ void layer_norm_kernel_slow(int N, int C, T* output, const T* input, + const T* bias, const T* skip, const T* gammas, + const T* betas, float ep, float alpha, + ActivationFunction act) { + int n = blockIdx.x * blockDim.z + threadIdx.z; + if (n >= N) return; + int c = (threadIdx.y * 32 + threadIdx.x) * 16; + bool oobThread = c >= C; + + int biasIndex = c; + int tensorIndex = n * C + c; + + float val[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float oth[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + const bool fp16 = std::is_same::value; + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &input[tensorIndex]); + for (int i = 0; i < 8; i++) val[i] = (float)inp[i]; + copyAs(&inp[0], &input[tensorIndex + 8]); + for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i]; + copyAs(&inp[0], &bias[biasIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &bias[biasIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + for (int i = 0; i < 16; i++) val[i] += oth[i]; + } else { + copyAs(&val[0], &input[tensorIndex]); + copyAs(&val[4], &input[tensorIndex + 4]); + copyAs(&val[8], &input[tensorIndex + 8]); + copyAs(&val[12], &input[tensorIndex + 12]); + copyAs(&oth[0], &bias[biasIndex]); + copyAs(&oth[4], &bias[biasIndex + 4]); + copyAs(&oth[8], &bias[biasIndex + 8]); + copyAs(&oth[12], &bias[biasIndex + 12]); + for (int i = 0; i < 16; i++) val[i] += oth[i]; + } + } + + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &skip[tensorIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &skip[tensorIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &skip[tensorIndex]); + copyAs(&oth[4], &skip[tensorIndex + 4]); + copyAs(&oth[8], &skip[tensorIndex + 8]); + copyAs(&oth[12], &skip[tensorIndex + 12]); + } + } + + // 1. 
Compute mean + float s = 0; + if (!oobThread) + for (int i = 0; i < 16; i++) { + val[i] = activate(val[i], act) + oth[i] * alpha; + s += val[i]; + } + + s = shared_sum_for_layer_norm(s); + float mean = s / C; + + // 2. Compute varience + s = 0; + if (!oobThread) + for (int i = 0; i < 16; i++) { + float d = val[i] - mean; + float d_sq = d * d; + s += d_sq; + } + s = shared_sum_for_layer_norm(s); + float var = s / C; + + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &gammas[biasIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &gammas[biasIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &gammas[biasIndex]); + copyAs(&oth[4], &gammas[biasIndex + 4]); + copyAs(&oth[8], &gammas[biasIndex + 8]); + copyAs(&oth[12], &gammas[biasIndex + 12]); + } + } + + // 3. Normalize + for (int i = 0; i < 16; i++) { + float d = val[i] - mean; + float norm = d / sqrt(var + ep); + float op = norm * oth[i]; + val[i] = op; + } + + if (!oobThread) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &betas[biasIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &betas[biasIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &betas[biasIndex]); + copyAs(&oth[4], &betas[biasIndex + 4]); + copyAs(&oth[8], &betas[biasIndex + 8]); + copyAs(&oth[12], &betas[biasIndex + 12]); + } + } + + for (int i = 0; i < 16; i++) { + val[i] += oth[i]; + } + + if (!oobThread) { + // Write to memory + if (fp16) { + half op[8]; + for (int i = 0; i < 8; i++) op[i] = (half)val[i]; + copyAs(&output[tensorIndex], &op[0]); + for (int i = 0; i < 8; i++) op[i] = (half)val[i + 8]; + copyAs(&output[tensorIndex + 8], &op[0]); + } else { + copyAs(&output[tensorIndex], &val[0]); + copyAs(&output[tensorIndex + 4], &val[4]); + copyAs(&output[tensorIndex + 8], &val[8]); + copyAs(&output[tensorIndex + 12], &val[12]); + } + } +} + + __global__ void layer_norm_kernel_8_el_per_thread( int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, @@ -1117,15 +1261,12 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, ActivationFunction act, cudaStream_t stream) { const bool fp16 = std::is_same::value; - // process 4 elements per thread to achieve close to peak memory bandwidth - if (C % 4 != 0) throw Exception("unsupported filter size"); - if (C > 4096) - { - if (!fp16 || (C % 8 != 0) || C > 8192) - throw Exception("unsupported filter size"); - } - const int EL_PER_THREAD = (C > 4096) ? 8 : 4; + // process 4 or 8 elements per thread to achieve close to peak memory bandwidth + if (C > 16384) throw Exception("unsupported filter size"); + const int EL_PER_THREAD = (C <= 4096) ? 4 : (C <= 8192 && fp16) ? 
8 : 16; + + if (C % EL_PER_THREAD != 0) throw Exception("unsupported filter size"); dim3 blockDim, gridDim; blockDim.x = 32; @@ -1136,14 +1277,17 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias, gridDim.y = 1; gridDim.z = 1; - if (EL_PER_THREAD == 8 && fp16) + if (EL_PER_THREAD == 8) layer_norm_kernel_8_el_per_thread<<>>( N, C, (half*)output, (const half*)input, (const half*)bias, (const half*)skip, (const half*)gammas, (const half*)betas, ep, alpha, act); - else - layer_norm_kernel<<>>( + else if (EL_PER_THREAD == 16) + layer_norm_kernel_slow<<>>( N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); + else + layer_norm_kernel<<>>( + N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); ReportCUDAErrors(cudaGetLastError()); } From 787090b56fcf6b4d6caadbd4dc2724be4c25e25b Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Thu, 7 Sep 2023 06:14:49 +0100 Subject: [PATCH 30/70] Common changes for new multiple head architecture. --- libs/lczero-common | 2 +- src/neural/network_legacy.cc | 57 ++++++++++++++++++++++++++++- src/neural/network_legacy.h | 69 ++++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 2 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index fafda0f59c..66d43cf9d4 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit fafda0f59c8511b5d933ef758c1e4b10a62da1e0 +Subproject commit 66d43cf9d41a8c6987083a1f5f026f8d0e2c0307 diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index 387590de6b..be166289df 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -30,10 +30,17 @@ static constexpr float kEpsilon = 1e-5f; LegacyWeights::LegacyWeights(const pblczero::Weights& weights) : input(weights.input()), + ip_emb_preproc_w(LayerAdapter(weights.ip_emb_preproc_w()).as_vector()), + ip_emb_preproc_b(LayerAdapter(weights.ip_emb_preproc_b()).as_vector()), ip_emb_w(LayerAdapter(weights.ip_emb_w()).as_vector()), ip_emb_b(LayerAdapter(weights.ip_emb_b()).as_vector()), + ip_emb_ln_gammas(LayerAdapter(weights.ip_emb_ln_gammas()).as_vector()), + ip_emb_ln_betas(LayerAdapter(weights.ip_emb_ln_betas()).as_vector()), ip_mult_gate(LayerAdapter(weights.ip_mult_gate()).as_vector()), ip_add_gate(LayerAdapter(weights.ip_add_gate()).as_vector()), + ip_emb_ffn(weights.ip_emb_ffn()), + ip_emb_ffn_ln_gammas(LayerAdapter(weights.ip_emb_ffn_ln_gammas()).as_vector()), + ip_emb_ffn_ln_betas(LayerAdapter(weights.ip_emb_ffn_ln_betas()).as_vector()), policy1(weights.policy1()), policy(weights.policy()), ip_pol_w(LayerAdapter(weights.ip_pol_w()).as_vector()), @@ -58,7 +65,10 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) ip2_mov_w(LayerAdapter(weights.ip2_mov_w()).as_vector()), ip2_mov_b(LayerAdapter(weights.ip2_mov_b()).as_vector()), smolgen_w(LayerAdapter(weights.smolgen_w()).as_vector()), - has_smolgen(weights.has_smolgen_w()) { + has_smolgen(weights.has_smolgen_w()), + policy_heads(weights.policy_heads()), + value_heads(weights.value_heads()), + has_multiheads(weights.has_policy_heads() && weights.has_value_heads()) { for (const auto& res : weights.residual()) { residual.emplace_back(res); } @@ -180,4 +190,49 @@ LegacyWeights::Smolgen::Smolgen( ln2_gammas(LayerAdapter(smolgen.ln2_gammas()).as_vector()), ln2_betas(LayerAdapter(smolgen.ln2_betas()).as_vector()) {} +LegacyWeights::PolicyHead::PolicyHead( + const pblczero::Weights::PolicyHead& policyhead) + : policy1(policyhead.policy1()), + policy(policyhead.policy()), + 
ip_pol_w(LayerAdapter(policyhead.ip_pol_w()).as_vector()), + ip_pol_b(LayerAdapter(policyhead.ip_pol_b()).as_vector()), + ip2_pol_w(LayerAdapter(policyhead.ip2_pol_w()).as_vector()), + ip2_pol_b(LayerAdapter(policyhead.ip2_pol_b()).as_vector()), + ip3_pol_w(LayerAdapter(policyhead.ip3_pol_w()).as_vector()), + ip3_pol_b(LayerAdapter(policyhead.ip3_pol_b()).as_vector()), + ip4_pol_w(LayerAdapter(policyhead.ip4_pol_w()).as_vector()) { + + pol_encoder_head_count = policyhead.pol_headcount(); + for (const auto& enc : policyhead.pol_encoder()) { + pol_encoder.emplace_back(enc); + } +} + +LegacyWeights::ValueHead::ValueHead( + const pblczero::Weights::ValueHead& valuehead) + : value(valuehead.value()), + ip_val_w(LayerAdapter(valuehead.ip_val_w()).as_vector()), + ip_val_b(LayerAdapter(valuehead.ip_val_b()).as_vector()), + ip1_val_w(LayerAdapter(valuehead.ip1_val_w()).as_vector()), + ip1_val_b(LayerAdapter(valuehead.ip1_val_b()).as_vector()), + ip2_val_w(LayerAdapter(valuehead.ip2_val_w()).as_vector()), + ip2_val_b(LayerAdapter(valuehead.ip2_val_b()).as_vector()), + ip_val_err_w(LayerAdapter(valuehead.ip_val_err_w()).as_vector()), + ip_val_err_b(LayerAdapter(valuehead.ip_val_err_b()).as_vector()) {} + +LegacyWeights::PolicyHeads::PolicyHeads( + const pblczero::Weights::PolicyHeads& policyheads) + : ip_pol_w(LayerAdapter(policyheads.ip_pol_w()).as_vector()), + ip_pol_b(LayerAdapter(policyheads.ip_pol_b()).as_vector()), + vanilla(policyheads.vanilla()), + optimistic_st(policyheads.optimistic_st()), + soft(policyheads.soft()), + opponent(policyheads.opponent()) {} + +LegacyWeights::ValueHeads::ValueHeads( + const pblczero::Weights::ValueHeads& valueheads) + : winner(valueheads.winner()), + q(valueheads.winner()), + st(valueheads.winner()) {} + } // namespace lczero diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 5715c40fbb..bf09eb3376 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -100,17 +100,80 @@ struct LegacyWeights { Vec ln2_betas; }; + struct PolicyHead { + explicit PolicyHead(const pblczero::Weights::PolicyHead& policyhead); + // Policy head + // Extra convolution for AZ-style policy head + ConvBlock policy1; + ConvBlock policy; + Vec ip_pol_w; + Vec ip_pol_b; + // Extra params for attention policy head + Vec ip2_pol_w; + Vec ip2_pol_b; + Vec ip3_pol_w; + Vec ip3_pol_b; + Vec ip4_pol_w; + int pol_encoder_head_count; + std::vector pol_encoder; + }; + + struct ValueHead { + explicit ValueHead(const pblczero::Weights::ValueHead& valuehead); + // Value head + ConvBlock value; + Vec ip_val_w; + Vec ip_val_b; + Vec ip1_val_w; + Vec ip1_val_b; + Vec ip2_val_w; + Vec ip2_val_b; + Vec ip_val_err_w; + Vec ip_val_err_b; + }; + + struct PolicyHeads { + explicit PolicyHeads(const pblczero::Weights::PolicyHeads& policyheads); + Vec ip_pol_w; + Vec ip_pol_b; + PolicyHead vanilla; + PolicyHead optimistic_st; + PolicyHead soft; + PolicyHead opponent; + }; + + struct ValueHeads { + explicit ValueHeads(const pblczero::Weights::ValueHeads& valueheads); + ValueHead winner; + ValueHead q; + ValueHead st; + }; + // Input convnet. ConvBlock input; + // Embedding preprocess layer. + Vec ip_emb_preproc_w; + Vec ip_emb_preproc_b; + // Embedding layer Vec ip_emb_w; Vec ip_emb_b; + // Embedding layernorm + // @todo can this be folded into weights? 
+ Vec ip_emb_ln_gammas; + Vec ip_emb_ln_betas; + // Input gating Vec ip_mult_gate; Vec ip_add_gate; + // Embedding feedforward network + FFN ip_emb_ffn; + Vec ip_emb_ffn_ln_gammas; + Vec ip_emb_ffn_ln_betas; + // Encoder stack. std::vector encoder; int encoder_head_count; @@ -143,6 +206,12 @@ struct LegacyWeights { Vec ip2_val_w; Vec ip2_val_b; + + // Policy and value multiheads + ValueHeads value_heads; + PolicyHeads policy_heads; + bool has_multiheads; + // Moves left head ConvBlock moves_left; Vec ip_mov_w; From 9cbc0720a865d3ae0e91f26ffb99365e0af08c19 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Thu, 7 Sep 2023 08:30:59 +0100 Subject: [PATCH 31/70] Fix typo bug. --- src/neural/network_legacy.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index be166289df..be3ea76ffc 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -232,7 +232,7 @@ LegacyWeights::PolicyHeads::PolicyHeads( LegacyWeights::ValueHeads::ValueHeads( const pblczero::Weights::ValueHeads& valueheads) : winner(valueheads.winner()), - q(valueheads.winner()), - st(valueheads.winner()) {} + q(valueheads.q()), + st(valueheads.st()) {} } // namespace lczero From bf5d82a0a564a88646cd9a0324d620125336e886 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Fri, 8 Sep 2023 04:11:22 +0100 Subject: [PATCH 32/70] Fix circleci failures. --- .circleci/config.yml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5ef7d003e4..0a5e2aa9e4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,9 +7,7 @@ jobs: - checkout - run: name: "Pull Submodules" - command: | - git submodule init - git submodule update --remote + command: git submodule update --init - run: name: Update Meson command: pip3 install --upgrade meson==0.58.1 @@ -31,9 +29,7 @@ jobs: - checkout - run: name: "Pull Submodules" - command: | - git submodule init - git submodule update --remote + command: git submodule update --init --no-single-branch - run: name: Install build tools command: | From 4157c4f5b71205d60f4efa62b2f073c3a2668d19 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Fri, 8 Sep 2023 11:41:17 +0100 Subject: [PATCH 33/70] Remove no-single-branch --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 0a5e2aa9e4..ea86d77c53 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -29,7 +29,7 @@ jobs: - checkout - run: name: "Pull Submodules" - command: git submodule update --init --no-single-branch + command: git submodule update --init - run: name: Install build tools command: | From 9d7b0dbbc4cfbefd967796d0c4bdcd22a6b8833e Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Tue, 12 Sep 2023 04:56:39 +0200 Subject: [PATCH 34/70] Implement new input encoding architecture. 
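Note: the diff below teaches preprocess_for_attention_body_kernel to take the appended positional-encoding values either from the fixed table (old behaviour) or from a per-batch tensor produced by the new "embedding/preprocess" dense layer. For reference, a minimal CPU sketch of the equivalent transform; the function and variable names here are illustrative only and do not mirror the backend API:

    #include <cstddef>
    #include <vector>

    // Builds the NHWC tensor fed to the square embedding:
    //   input    : [N, input_size, 64] board planes in NCHW order
    //   encoding : [N, 64, encoding_size] when new_encoding is true,
    //              otherwise a fixed [64, encoding_size] table shared by the batch
    //   output   : [N, 64, input_size + encoding_size]
    std::vector<float> PreprocessForAttentionBodyRef(
        const std::vector<float>& input, const std::vector<float>& encoding,
        int N, int input_size, int encoding_size, bool new_encoding) {
      const int out_c = input_size + encoding_size;
      std::vector<float> output(static_cast<std::size_t>(N) * 64 * out_c);
      for (int n = 0; n < N; n++)
        for (int hw = 0; hw < 64; hw++)
          for (int c = 0; c < out_c; c++) {
            float v;
            if (c < input_size) {
              // NCHW -> NHWC transpose of the board planes.
              v = input[(n * input_size + c) * 64 + hw];
            } else if (new_encoding) {
              // Batch-dependent encoding computed by the preprocess dense layer.
              v = encoding[(n * 64 + hw) * encoding_size + (c - input_size)];
            } else {
              // One fixed row of encoding values per square.
              v = encoding[hw * encoding_size + (c - input_size)];
            }
            output[(n * 64 + hw) * out_c + c] = v;  // concatenate along channels
          }
      return output;
    }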
--- src/neural/cuda/common_kernels.cu | 37 ++++-- src/neural/cuda/kernels.h | 4 +- src/neural/cuda/layers.cc | 193 ++++++++++++++++++++++++------ src/neural/cuda/layers.h | 17 ++- src/neural/cuda/network_cuda.cc | 25 ++-- 5 files changed, 215 insertions(+), 61 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 2959cfc9ec..044a2343f3 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1227,20 +1227,26 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, template __global__ void preprocess_for_attention_body_kernel(T* output, const T* input, - const float* encoding) { + const T* encoding, + int input_size, int encoding_size, + bool new_encoding) { int n = blockIdx.x; int hw = blockIdx.y; int c = threadIdx.x; T op; - if (c >= kInputPlanes) { - // concatenate from fixed pos encoding array - op = (T)(encoding[64 * hw + (c - kInputPlanes)]); + if (c >= input_size) { + // concatenate from position encoding array + if (new_encoding) { + op = (T)(encoding[n * 64 * encoding_size + hw * encoding_size + (c - input_size)]); + } else { + op = (T)(encoding[64 * hw + (c - input_size)]); + } } else { - op = input[n * kInputPlanes * 64 + c * 64 + hw]; // nchw + op = input[n * input_size * 64 + c * 64 + hw]; // nchw } - constexpr int outputC = kInputPlanes + kNumPosEncodingChannels; + int outputC = input_size + encoding_size; // convert to nhwc output[n * 64 * outputC + hw * outputC + c] = op; @@ -1248,15 +1254,17 @@ __global__ void preprocess_for_attention_body_kernel(T* output, const T* input, template void inputPreprocessForAttentionBody(T* output, const T* input, - const float* encoding, int N, + const T* encoding, int N, + int input_size, int encoding_size, + bool new_encoding, cudaStream_t stream) { // N * 64 blocks // (kInputPlanes + kNumPosEncodingChannels) threads // Each thread computes a single output element dim3 gridSize = dim3(N, 64); - int blockSize = kInputPlanes + kNumPosEncodingChannels; + int blockSize = input_size + encoding_size; preprocess_for_attention_body_kernel - <<>>(output, input, encoding); + <<>>(output, input, encoding, input_size, encoding_size, new_encoding); } template @@ -1565,13 +1573,20 @@ template void convertNCHWtoNHWC(half* output_tensor, template void inputPreprocessForAttentionBody(half* output, const half* input, - const float* encoding, - int N, cudaStream_t stream); + const half* encoding, + int N, + int input_size, + int encoding_size, + bool new_encoding, + cudaStream_t stream); template void inputPreprocessForAttentionBody(float* output, const float* input, const float* encoding, int N, + int input_size, + int encoding_size, + bool new_encoding, cudaStream_t stream); template void applyInputGating(half* output, const half* input, diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index fa405c1946..f7ea8aa3c2 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -147,7 +147,9 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys, template void inputPreprocessForAttentionBody(T* output, const T* input, - const float* encoding, int N, + const T* encoding, int N, + int input_size, int encoding_size, + bool new_encoding, cudaStream_t stream); template diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index d4e02b3a46..d232768478 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -39,7 +39,7 @@ namespace lczero { -#if 0 +#if 1 // debug code to dump allocation in GPU 
memory template void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { @@ -2051,15 +2051,38 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, input_c_(input_c), has_gating_(weights.ip_mult_gate.size() > 0 && weights.ip_add_gate.size() > 0), - has_smolgen_(weights.has_smolgen) { + has_smolgen_(weights.has_smolgen), + new_encoding_(weights.has_multiheads) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); - { + if (new_encoding_) { + allocAndUpload(&ip_emb_pre_w_, weights.ip_emb_preproc_w, scratch); + allocAndUpload(&ip_emb_pre_b_, weights.ip_emb_preproc_b, scratch); + + allocAndUpload(&ip_emb_ln_g_, weights.ip_emb_ln_gammas, scratch); + allocAndUpload(&ip_emb_ln_b_, weights.ip_emb_ln_betas, scratch); + + allocAndUpload(&ip_emb_ffn_d1_w_, weights.ip_emb_ffn.dense1_w, scratch); + allocAndUpload(&ip_emb_ffn_d1_b_, weights.ip_emb_ffn.dense1_b, scratch); + + allocAndUpload(&ip_emb_ffn_d2_w_, weights.ip_emb_ffn.dense2_w, scratch); + allocAndUpload(&ip_emb_ffn_d2_b_, weights.ip_emb_ffn.dense2_b, scratch); + + allocAndUpload(&ip_emb_ffn_ln_g_, weights.ip_emb_ffn_ln_gammas, scratch); + allocAndUpload(&ip_emb_ffn_ln_b_, weights.ip_emb_ffn_ln_betas, scratch); + + // 12 is the number of input channels used for the input encoding. + embedding_dense_size_ = weights.ip_emb_preproc_b.size() / 64; + embedding_ffn_size_ = weights.ip_emb_ffn.dense2_b.size(); + embedding_ffn_dff_ = weights.ip_emb_ffn.dense1_b.size(); + } + else { size_t size = 64 * kNumPosEncodingChannels * sizeof(float); ReportCUDAErrors(cudaMalloc(&pos_encoding_, size)); ReportCUDAErrors( - cudaMemcpy(pos_encoding_, kPosEncoding, size, cudaMemcpyHostToDevice)); + cudaMemcpy(scratch, kPosEncoding, size, cudaMemcpyHostToDevice)); + copyTypeConverted(pos_encoding_, (float*)scratch, size, 0); } if (has_gating_) { @@ -2087,7 +2110,21 @@ template AttentionBody::~AttentionBody() { ReportCUDAErrors(cudaFree(ip_emb_w_)); ReportCUDAErrors(cudaFree(ip_emb_b_)); - ReportCUDAErrors(cudaFree(pos_encoding_)); + if (new_encoding_) { + ReportCUDAErrors(cudaFree(ip_emb_pre_w_)); + ReportCUDAErrors(cudaFree(ip_emb_pre_b_)); + ReportCUDAErrors(cudaFree(ip_emb_ln_g_)); + ReportCUDAErrors(cudaFree(ip_emb_ln_b_)); + ReportCUDAErrors(cudaFree(ip_emb_ffn_d1_w_)); + ReportCUDAErrors(cudaFree(ip_emb_ffn_d1_b_)); + ReportCUDAErrors(cudaFree(ip_emb_ffn_d2_w_)); + ReportCUDAErrors(cudaFree(ip_emb_ffn_d2_b_)); + ReportCUDAErrors(cudaFree(ip_emb_ffn_ln_g_)); + ReportCUDAErrors(cudaFree(ip_emb_ffn_ln_b_)); + } + else { + ReportCUDAErrors(cudaFree(pos_encoding_)); + } if (has_gating_) { ReportCUDAErrors(cudaFree(ip_mult_gate_)); ReportCUDAErrors(cudaFree(ip_add_gate_)); @@ -2113,19 +2150,49 @@ void AttentionBody::Eval(int N, DataType* output, if (num_resi_blocks_ == 0) { assert(inputC == kInputPlanes); /* - # if there are no residual blocks (pure transformer), do some input - processing - flow = tf.transpose(inputs, perm=[0, 2, 3, 1]) - flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]]) - # add positional encoding for each square to the input - positional_encoding = tf.broadcast_to(tf.convert_to_tensor(self.POS_ENC, - dtype=self.model_dtype), [tf.shape(flow)[0], 64, - tf.shape(self.POS_ENC)[2]]) flow = tf.concat([flow, positional_encoding], - axis=2) + # if there are no residual blocks (pure transformer), do some input + processing */ - inputPreprocessForAttentionBody((DataType*)scratch, input, pos_encoding_, N, - stream); - inputC += kNumPosEncodingChannels; + if 
(new_encoding_) { + // New encoding is made of dense layer fed with input from a 12-channel slice of the input tensor. + // pos_info = flow[..., :12] + // pos_info_flat = tf.reshape(pos_info, [-1, 64 * 12]) + // pos_info_processed = tf.keras.layers.Dense(64*self.embedding_dense_sz, + // name=name+"embedding/preprocess")(pos_info_flat) + const int num_outputs = 64 * embedding_dense_size_; + const int num_inputs = 64 * 12; + const int batch = N; + + convertNCHWtoNHWC((DataType*)scratch, input, N, inputC, N, 12, 8, 8); + cublasXgemm( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, + (const DataType*)ip_emb_pre_w_, num_inputs, + (const DataType*)scratch, num_inputs, + 0.0f, buffer1, num_outputs); + + // addBiasBatched(buffer1, buffer1, ip_emb_pre_b_, batch, N, num_outputs, + // ACTIVATION_NONE, stream); + const int size = num_outputs * N; + // @todo addBiasBatched has a 4096 channel limit, needs refactoring. + addVectors(buffer1, buffer1, ip_emb_pre_b_, size, size, num_outputs, ACTIVATION_NONE, stream); + inputPreprocessForAttentionBody((DataType*)scratch, input, buffer1, N, kInputPlanes, + embedding_dense_size_, true, stream); + inputC += embedding_dense_size_; + } + else { + /* + flow = tf.transpose(inputs, perm=[0, 2, 3, 1]) + flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]]) + # add positional encoding for each square to the input + positional_encoding = tf.broadcast_to(tf.convert_to_tensor(self.POS_ENC, + dtype=self.model_dtype), [tf.shape(flow)[0], 64, + tf.shape(self.POS_ENC)[2]]) flow = tf.concat([flow, positional_encoding], + axis=2) + */ + inputPreprocessForAttentionBody((DataType*)scratch, input, pos_encoding_, N, + kInputPlanes, kNumPosEncodingChannels, false, stream); + inputC += kNumPosEncodingChannels; + } } else { // #redirect flow through encoder blocks // flow = tf.transpose(flow, perm = [ 0, 2, 3, 1 ]) @@ -2133,25 +2200,81 @@ void AttentionBody::Eval(int N, DataType* output, convertNCHWtoNHWC((DataType*)scratch, input, N, inputC, N, inputC, 8, 8); } - // 1. square embedding (fully connected layer) - // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ - DataType* embedding = output_tensor; - { - const int num_outputs = embedding_op_size_; - const int num_inputs = inputC; - const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_w_, - num_inputs, (DataType*)scratch, num_inputs, 0.0f, - embedding, num_outputs); - addBiasBatched(embedding, embedding, ip_emb_b_, 1, batch, num_outputs, - activations_.default_activation, stream); - } + if (new_encoding_) { + // 1. square embedding (fully connected layer) + // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ + DataType* embedding = output_tensor; + DataType* temp = (DataType*)scratch; + { + const int num_outputs = embedding_op_size_; + const int num_inputs = inputC; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_w_, + num_inputs, temp, num_inputs, 0.0f, + embedding, num_outputs); + // embedding layer norm with fused in bias add of previous gemm. 
+ LayerNorm(N * 64, embedding_op_size_, temp, embedding, ip_emb_b_, + embedding, ip_emb_ln_g_, ip_emb_ln_b_, 1e-3, 0.0, + activations_.default_activation, stream); + } - // Input gating - if (has_gating_) { - applyInputGating(embedding, embedding, ip_mult_gate_, - ip_add_gate_, N, 64, embedding_op_size_, stream); + // Input gating + if (has_gating_) { + applyInputGating(temp, temp, ip_mult_gate_, + ip_add_gate_, N, 64, embedding_op_size_, stream); + } + + // embedding FFN dense 1 + { + const int num_inputs = embedding_ffn_size_; + const int num_outputs = embedding_ffn_dff_; // encoder_dff + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, num_inputs, + temp, num_inputs, 0.0f, buffer1, num_outputs); + addBiasBatched(buffer1, buffer1, ip_emb_ffn_d1_b_, 1, batch, + num_outputs, activations_.ffn_activation, stream); + } + + // embedding FFN dense 2 + { + const int num_inputs = embedding_ffn_dff_; // encoder_dff + const int num_outputs = embedding_ffn_size_; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, num_inputs, + buffer1, num_inputs, 0.0f, buffer2, num_outputs); + // // LN2: skip connection and layer normilization (also bias add of prev gemm) + // // buffer1/scratch -> in_out_tensor + float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); + LayerNorm(N * 64, embedding_ffn_size_, embedding, temp, + ip_emb_ffn_d2_b_, buffer2, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, + 1e-3, alpha, ACTIVATION_NONE, stream); + } + + dumpTensor(embedding, embedding_ffn_size_ * 64 * N, "FFN1", false); + } + else { + // 1. square embedding (fully connected layer) + // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ + DataType* embedding = output_tensor; + { + const int num_outputs = embedding_op_size_; + const int num_inputs = inputC; + const int batch = N * 64; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_w_, + num_inputs, (DataType*)scratch, num_inputs, 0.0f, + embedding, num_outputs); + addBiasBatched(embedding, embedding, ip_emb_b_, 1, batch, num_outputs, + activations_.default_activation, stream); + } + // Input gating + if (has_gating_) { + applyInputGating(embedding, embedding, ip_mult_gate_, + ip_add_gate_, N, 64, embedding_op_size_, stream); + } } // 2. Encoder blocks diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 174b18ec15..5cfeffde63 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -482,12 +482,21 @@ class AttentionBody : public BaseLayer { DataType*** = nullptr) override; private: - // GPU allocations to hold various weights used by the attention policy head - DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body - DataType *ip_mult_gate_, *ip_add_gate_; // input gating + // GPU allocations to hold various weights used by the attention net body. + DataType *ip_emb_pre_w_, *ip_emb_pre_b_; // input position preprocessing weights. 
+ DataType *ip_emb_w_, *ip_emb_b_; // "embedding" layer in net body + DataType *ip_emb_ln_g_, *ip_emb_ln_b_; // input embedding layernorm gamma and beta + DataType *ip_mult_gate_, *ip_add_gate_; // input gating + DataType *ip_emb_ffn_d1_w_, *ip_emb_ffn_d1_b_; // input embedding FFN dense1 weights + DataType *ip_emb_ffn_d2_w_, *ip_emb_ffn_d2_b_; // input embedding FFN dense2 weights + DataType *ip_emb_ffn_ln_g_, *ip_emb_ffn_ln_b_; // input embedding FFN layernorm gamma and beta DataType *smolgen_global_; // global smolgen weights for all encoder layers - float* pos_encoding_; + bool new_encoding_; // flag for new position encoding + DataType *pos_encoding_; + int embedding_dense_size_; int embedding_op_size_; + int embedding_ffn_size_; + int embedding_ffn_dff_; int encoder_head_count_; std::vector*> encoder_weights_; Activations activations_; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 275a332e6e..3cf0a527ab 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -50,23 +50,23 @@ template class CudaNetwork; static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { - const size_t embedding_op_size = weights.ip_pol_b.size(); - const size_t policy_d_model = weights.ip2_pol_b.size(); - assert(policy_d_model == weights.ip3_pol_b.size()); + const size_t embedding_op_size = weights.policy_heads.vanilla.ip_pol_b.size(); + const size_t policy_d_model = weights.policy_heads.vanilla.ip2_pol_b.size(); + assert(policy_d_model == weights.policy_heads.vanilla.ip3_pol_b.size()); size_t encoder_d_model = 0; size_t encoder_dff = 0; - if (weights.pol_encoder.size() > 0) { - encoder_d_model = weights.pol_encoder[0].mha.q_b.size(); - encoder_dff = weights.pol_encoder[0].ffn.dense1_b.size(); + if (weights.policy_heads.vanilla.pol_encoder.size() > 0) { + encoder_d_model = weights.policy_heads.vanilla.pol_encoder[0].mha.q_b.size(); + encoder_dff = weights.policy_heads.vanilla.pol_encoder[0].ffn.dense1_b.size(); - assert(encoder_d_model == weights.pol_encoder[0].mha.k_b.size()); - assert(encoder_d_model == weights.pol_encoder[0].mha.v_b.size()); - assert(embedding_op_size == weights.pol_encoder[0].ffn.dense2_b.size()); + assert(encoder_d_model == weights.policy_heads.vanilla.pol_encoder[0].mha.k_b.size()); + assert(encoder_d_model == weights.policy_heads.vanilla.pol_encoder[0].mha.v_b.size()); + assert(embedding_op_size == weights.policy_heads.vanilla.pol_encoder[0].ffn.dense2_b.size()); } - const size_t encoder_heads = weights.pol_encoder_head_count; + const size_t encoder_heads = weights.policy_heads.vanilla.pol_encoder_head_count; size_t size = N * 64 * @@ -204,6 +204,9 @@ class CudaNetwork : public Network { max_batch_size_ = options.GetOrDefault("max_batch", 1024); + policy_head_ = options.GetOrDefault("policy_head", "vanilla"); + value_head_ = options.GetOrDefault("value_head", "q"); + showInfo(); int total_gpus; @@ -950,6 +953,8 @@ class CudaNetwork : public Network { bool attn_policy_; bool attn_body_; int num_encoder_blocks_; + std::string policy_head_; + std::string value_head_; std::vector>> network_; BaseLayer* getLastLayer() { return network_.back().get(); } From 9746ea4b512c90188ab98b53d62642683676028e Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Wed, 13 Sep 2023 23:11:50 +0200 Subject: [PATCH 35/70] Fix layer norms and add new multiple heads. 
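Note: after this fix the shared LayerNorm helper follows one contract at every call site: y = gamma * normalize(act(x + bias) * alpha + skip) + beta, where skip may be null and epsilon is added to the variance. A minimal single-row CPU sketch of that contract (illustrative names, not the backend API):

    #include <cmath>
    #include <vector>

    // Reference layer norm over one row of C channels.
    void LayerNormRowRef(int C, float* out, const float* x, const float* bias,
                         const float* skip,  // may be nullptr (no residual input)
                         const float* gamma, const float* beta, float eps,
                         float alpha, float (*act)(float)) {
      std::vector<float> v(C);
      float mean = 0.0f;
      for (int c = 0; c < C; c++) {
        // Activated, bias-added branch is scaled by alpha; skip is added as-is.
        v[c] = act(x[c] + bias[c]) * alpha + (skip ? skip[c] : 0.0f);
        mean += v[c];
      }
      mean /= C;
      float var = 0.0f;
      for (int c = 0; c < C; c++) var += (v[c] - mean) * (v[c] - mean);
      var /= C;
      const float inv_std = 1.0f / std::sqrt(var + eps);
      for (int c = 0; c < C; c++) {
        out[c] = gamma[c] * (v[c] - mean) * inv_std + beta[c];
      }
    }

In the attention-body encoder blocks the call sites pass act = identity and alpha = (2 * num_encoders)^(-1/4), so it is the GEMM output (the new branch) that gets down-scaled while skip carries the untouched residual stream; the smolgen and embedding layer norms instead pass skip = nullptr with alpha = 1.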
--- src/neural/cuda/common_kernels.cu | 39 ++++++++++------ src/neural/cuda/layers.cc | 76 +++++++++++++++++-------------- src/neural/cuda/layers.h | 2 +- src/neural/cuda/network_cuda.cc | 29 +++++++----- src/neural/cuda/network_cudnn.cc | 3 +- 5 files changed, 88 insertions(+), 61 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 044a2343f3..b6f8e49eb6 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -1007,27 +1007,36 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, } if (!oobThread) { - // Load from memory (16 elements a time) - if (fp16) { - half inp[8]; - copyAs(&inp[0], &skip[tensorIndex]); - for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; - copyAs(&inp[0], &skip[tensorIndex + 8]); - for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; - } else { - copyAs(&oth[0], &skip[tensorIndex]); - copyAs(&oth[4], &skip[tensorIndex + 4]); - copyAs(&oth[8], &skip[tensorIndex + 8]); - copyAs(&oth[12], &skip[tensorIndex + 12]); + if (skip != nullptr) { + // Load from memory (16 elements a time) + if (fp16) { + half inp[8]; + copyAs(&inp[0], &skip[tensorIndex]); + for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; + copyAs(&inp[0], &skip[tensorIndex + 8]); + for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; + } else { + copyAs(&oth[0], &skip[tensorIndex]); + copyAs(&oth[4], &skip[tensorIndex + 4]); + copyAs(&oth[8], &skip[tensorIndex + 8]); + copyAs(&oth[12], &skip[tensorIndex + 12]); + } } } // 1. Compute mean float s = 0; if (!oobThread) - for (int i = 0; i < 16; i++) { - val[i] = activate(val[i], act) + oth[i] * alpha; - s += val[i]; + if (skip != nullptr) { + for (int i = 0; i < 16; i++) { + val[i] = activate(val[i], act) * alpha + oth[i]; + s += val[i]; + } + } else { + for (int i = 0; i < 16; i++) { + val[i] = activate(val[i], act) * alpha; + s += val[i]; + } } s = shared_sum_for_layer_norm(s); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 2e8595b22a..bdbba2c75e 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1422,32 +1422,46 @@ void allocAndUpload(DataType** gpu_dest, std::vector cpu_src, template AttentionPolicyHead::AttentionPolicyHead( BaseLayer* ip, const LegacyWeights& weights, void* scratch, - bool attention_body, ActivationFunction act, int max_batch_size) + bool attention_body, ActivationFunction act, std::string policy_head, int max_batch_size) : BaseLayer(64 * 64 + 24 * 8, 1, 1, ip), attention_body_(attention_body), // Old networks without attention body (e.g. T79) use hardcoded SELU // activations. act_(attention_body ? act : ACTIVATION_SELU) { - embedding_op_size_ = weights.ip_pol_b.size(); - wq_op_size_ = weights.ip2_pol_b.size(); - wk_op_size_ = weights.ip3_pol_b.size(); + // Selected head to construct. + // Use vanilla as default head. 
+ /* @todo check that head exists */ + LegacyWeights::PolicyHead head = weights.policy_heads.vanilla; + if (policy_head == "optimistic" /* @todo check that head exists */) { + head = weights.policy_heads.optimistic_st; + } + else if (policy_head == "soft" /* @todo check that head exists */) { + head = weights.policy_heads.soft; + } + else if (policy_head == "opponent" /* @todo check that head exists */) { + head = weights.policy_heads.opponent; + } + + embedding_op_size_ = weights.policy_heads.ip_pol_b.size(); + wq_op_size_ = head.ip2_pol_b.size(); + wk_op_size_ = head.ip3_pol_b.size(); - encoder_heads_ = weights.pol_encoder_head_count; + encoder_heads_ = head.pol_encoder_head_count; policy_d_model_ = wq_op_size_; - allocAndUpload(&ip_pol_w_, weights.ip_pol_w, scratch); - allocAndUpload(&ip_pol_b_, weights.ip_pol_b, scratch); + allocAndUpload(&ip_pol_w_, weights.policy_heads.ip_pol_w, scratch); + allocAndUpload(&ip_pol_b_, weights.policy_heads.ip_pol_b, scratch); - allocAndUpload(&ip2_pol_w_, weights.ip2_pol_w, scratch); - allocAndUpload(&ip2_pol_b_, weights.ip2_pol_b, scratch); + allocAndUpload(&ip2_pol_w_, head.ip2_pol_w, scratch); + allocAndUpload(&ip2_pol_b_, head.ip2_pol_b, scratch); - allocAndUpload(&ip3_pol_w_, weights.ip3_pol_w, scratch); - allocAndUpload(&ip3_pol_b_, weights.ip3_pol_b, scratch); + allocAndUpload(&ip3_pol_w_, head.ip3_pol_w, scratch); + allocAndUpload(&ip3_pol_b_, head.ip3_pol_b, scratch); // big allocation to hold wq and wk weights one after the other { - size_t elements = weights.ip2_pol_w.size(); - assert(elements == weights.ip3_pol_w.size()); + size_t elements = head.ip2_pol_w.size(); + assert(elements == head.ip3_pol_w.size()); size_t size = elements * sizeof(DataType) * 2; ReportCUDAErrors(cudaMalloc(&wqk_w_, size)); @@ -1456,7 +1470,7 @@ AttentionPolicyHead::AttentionPolicyHead( ReportCUDAErrors(cudaMemcpy(wqk_w_ + elements, ip3_pol_w_, size / 2, cudaMemcpyDeviceToDevice)); - elements = weights.ip2_pol_b.size(); + elements = head.ip2_pol_b.size(); size = elements * sizeof(DataType) * 2; ReportCUDAErrors(cudaMalloc(&wqk_b_, size)); ReportCUDAErrors( @@ -1465,9 +1479,9 @@ AttentionPolicyHead::AttentionPolicyHead( cudaMemcpyDeviceToDevice)); } - allocAndUpload(&ip4_pol_w_, weights.ip4_pol_w, scratch); + allocAndUpload(&ip4_pol_w_, head.ip4_pol_w, scratch); - for (const auto& enc : weights.pol_encoder) { + for (const auto& enc : head.pol_encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, // using alpha = 1 for now (TODO: may change?) 
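// Usage note: the head that gets built is chosen from the backend options read
// in network_cuda.cc later in this patch ("policy_head" and "value_head", with
// "vanilla" and "q" as defaults), e.g.
//   ./lc0 --backend=cuda-fp16 --backend-opts=policy_head=optimistic,value_head=q
// Unrecognized names currently fall back to the defaults; the @todo comments
// above mark the still-missing checks that the requested head actually exists
// in the network file.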
@@ -1677,9 +1691,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); LayerNorm(batch, num_outputs, scratch, buffer1, smol_dense1_b, - buffer1, smol_ln1_gammas, smol_ln1_betas, 1e-3, - 0.0, /* alpha = 0 since we don't need skip */ - smolgen_activation_, stream); + (DataType*)nullptr, smol_ln1_gammas, smol_ln1_betas, 1e-3, + 1.0, smolgen_activation_, stream); } { @@ -1695,9 +1708,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); LayerNorm(batch, num_outputs, scratch, buffer1, smol_dense2_b, - buffer1, smol_ln2_gammas, smol_ln2_betas, 1e-3, - 0.0, /* alpha = 0 since we don't need skip */ - smolgen_activation_, stream); + (DataType*)nullptr, smol_ln2_gammas, smol_ln2_betas, 1e-3, + 1.0, smolgen_activation_, stream); } { @@ -1848,7 +1860,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // LN1: skip connection and layer normalization (also bias add of prev gemm) // buffer1/in_out_tensor -> scratch LayerNorm(N * 64, embedding_op_size_, scratch, buffer1, mha_dense_b, - in_out_tensor, ln1_gammas, ln1_betas, 1e-6, alpha_, + in_out_tensor, ln1_gammas, ln1_betas, 1e-3, alpha_, ACTIVATION_NONE, stream); // #FFN dense 1, scratch -> in_out_tensor @@ -1876,7 +1888,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // LN2: skip connection and layer normilization (also bias add of prev gemm) // buffer1/scratch -> in_out_tensor LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, - ffn_dense2_b, scratch, ln2_gammas, ln2_betas, 1e-6, + ffn_dense2_b, scratch, ln2_gammas, ln2_betas, 1e-3, alpha_, ACTIVATION_NONE, stream); } @@ -2096,7 +2108,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, } int num_encoders = weights.encoder.size(); - float alpha = (float)pow(2.0 * num_encoders, 0.25); + float alpha = (float)pow(2.0 * num_encoders, -0.25); for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha, @@ -2215,7 +2227,7 @@ void AttentionBody::Eval(int N, DataType* output, embedding, num_outputs); // embedding layer norm with fused in bias add of previous gemm. LayerNorm(N * 64, embedding_op_size_, temp, embedding, ip_emb_b_, - embedding, ip_emb_ln_g_, ip_emb_ln_b_, 1e-3, 0.0, + (DataType*)nullptr, ip_emb_ln_g_, ip_emb_ln_b_, 1e-3, 1.0, activations_.default_activation, stream); } @@ -2245,17 +2257,15 @@ void AttentionBody::Eval(int N, DataType* output, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, num_inputs, buffer1, num_inputs, 0.0f, buffer2, num_outputs); - // // LN2: skip connection and layer normilization (also bias add of prev gemm) - // // buffer1/scratch -> in_out_tensor + // Embedding LN: skip connection and layer normilization (also bias add of prev gemm) + // buffer2 -> embedding float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); - LayerNorm(N * 64, embedding_ffn_size_, embedding, temp, - ip_emb_ffn_d2_b_, buffer2, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, + LayerNorm(N * 64, embedding_ffn_size_, embedding, buffer2, + ip_emb_ffn_d2_b_, temp, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, stream); } - dumpTensor(embedding, embedding_ffn_size_ * 64 * N, "FFN1", false); - } - else { + } else { // 1. 
square embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ DataType* embedding = output_tensor; diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 5cfeffde63..1524eaf1ed 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -409,7 +409,7 @@ class AttentionPolicyHead : public BaseLayer { public: AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, void* scratch, bool attention_body, - ActivationFunction act, int max_batch_size); + ActivationFunction act, std::string policy_head, int max_batch_size); ~AttentionPolicyHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 3cf0a527ab..d8704cc8ac 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -204,9 +204,6 @@ class CudaNetwork : public Network { max_batch_size_ = options.GetOrDefault("max_batch", 1024); - policy_head_ = options.GetOrDefault("policy_head", "vanilla"); - value_head_ = options.GetOrDefault("value_head", "q"); - showInfo(); int total_gpus; @@ -443,8 +440,9 @@ class CudaNetwork : public Network { // Policy head. if (attn_policy_) { + std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_, attn_body_, act, + getLastLayer(), weights, scratch_mem_, attn_body_, act, policy_head, max_batch_size_); network_.emplace_back(std::move(AttentionPolicy)); @@ -495,9 +493,20 @@ class CudaNetwork : public Network { // Value head. { + // Selected head to construct. + // Use value_q as default head. + std::string value_head = options.GetOrDefault("value_head", "q"); + /* @todo check that head exists */ + LegacyWeights::ValueHead head = weights.value_heads.q; + if (value_head == "winner" /* @todo check that head exists */) { + head = weights.value_heads.winner; + } + else if (value_head == "st" /* @todo check that head exists */) { + head = weights.value_heads.st; + } if (attn_body_) { auto embedded_val = std::make_unique>( - encoder_last_, weights.ip_val_w, weights.ip_val_b, scratch_mem_, + encoder_last_, head.ip_val_w, head.ip_val_b, scratch_mem_, act); network_.emplace_back(std::move(embedded_val)); } else { @@ -510,8 +519,8 @@ class CudaNetwork : public Network { } auto FCVal1 = std::make_unique>( - getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, act); - FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0], + getLastLayer(), head.ip1_val_b.size(), 1, 1, true, act); + FCVal1->LoadWeights(&head.ip1_val_w[0], &head.ip1_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal1)); @@ -520,9 +529,9 @@ class CudaNetwork : public Network { auto fc2_tanh = !wdl_; auto FCVal2 = std::make_unique>( - getLastLayer(), weights.ip2_val_b.size(), 1, 1, true, + getLastLayer(), head.ip2_val_b.size(), 1, 1, true, fc2_tanh ? 
ACTIVATION_TANH : ACTIVATION_NONE); - FCVal2->LoadWeights(&weights.ip2_val_w[0], &weights.ip2_val_b[0], + FCVal2->LoadWeights(&head.ip2_val_w[0], &head.ip2_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal2)); } @@ -953,8 +962,6 @@ class CudaNetwork : public Network { bool attn_policy_; bool attn_body_; int num_encoder_blocks_; - std::string policy_head_; - std::string value_head_; std::vector>> network_; BaseLayer* getLastLayer() { return network_.back().get(); } diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index 95278f3590..edfdddc8b6 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -512,9 +512,10 @@ class CudnnNetwork : public Network { // Policy head. if (attn_policy_) { + std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); auto AttentionPolicy = std::make_unique>( getLastLayer(), weights, scratch_mem_, false, ACTIVATION_SELU, - max_batch_size_); + policy_head, max_batch_size_); network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( From f7205734fcb268efbbdd43f3eda4655187679f90 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Fri, 22 Sep 2023 18:01:22 +0200 Subject: [PATCH 36/70] Add new method to get value error from neural net. --- src/neural/blas/network_blas.cc | 4 ++++ src/neural/cache.cc | 7 +++++++ src/neural/cache.h | 3 +++ src/neural/cuda/network_cuda.cc | 4 ++++ src/neural/cuda/network_cudnn.cc | 4 ++++ src/neural/dx/network_dx.h | 4 ++++ src/neural/metal/network_metal.h | 4 ++++ src/neural/network.h | 2 ++ src/neural/network_check.cc | 4 ++++ src/neural/network_demux.cc | 6 ++++++ src/neural/network_mux.cc | 4 ++++ src/neural/network_random.cc | 2 ++ src/neural/network_record.cc | 4 ++++ src/neural/network_tf_cc.cc | 3 +++ src/neural/network_trivial.cc | 2 ++ src/neural/onednn/network_onednn.cc | 4 ++++ src/neural/onnx/network_onnx.cc | 6 ++++++ src/neural/opencl/network_opencl.cc | 4 ++++ src/python/weights.h | 3 +++ 19 files changed, 74 insertions(+) diff --git a/src/neural/blas/network_blas.cc b/src/neural/blas/network_blas.cc index a9667206da..2b93d21187 100644 --- a/src/neural/blas/network_blas.cc +++ b/src/neural/blas/network_blas.cc @@ -100,6 +100,10 @@ class BlasComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetMVal(int sample) const override { if (moves_left_) { return m_values_[sample]; diff --git a/src/neural/cache.cc b/src/neural/cache.cc index d729a562f0..213466edab 100644 --- a/src/neural/cache.cc +++ b/src/neural/cache.cc @@ -89,6 +89,7 @@ void CachingComputation::ComputeBlocking() { req->q = parent_->GetQVal(item.idx_in_parent); req->d = parent_->GetDVal(item.idx_in_parent); req->m = parent_->GetMVal(item.idx_in_parent); + req->e = parent_->GetEVal(item.idx_in_parent); int idx = 0; for (auto x : item.probabilities_to_cache) { req->p[idx++] = @@ -110,6 +111,12 @@ float CachingComputation::GetDVal(int sample) const { return item.lock->d; } +float CachingComputation::GetEVal(int sample) const { + const auto& item = batch_[sample]; + if (item.idx_in_parent >= 0) return parent_->GetEVal(item.idx_in_parent); + return item.lock->e; +} + float CachingComputation::GetMVal(int sample) const { const auto& item = batch_[sample]; if (item.idx_in_parent >= 0) return parent_->GetMVal(item.idx_in_parent); diff --git a/src/neural/cache.h b/src/neural/cache.h index 207e0fe6e4..97c7c75dda 100644 --- a/src/neural/cache.h +++ b/src/neural/cache.h @@ -38,6 +38,7 @@ 
struct CachedNNRequest { float q; float d; float m; + float e; // TODO(mooskagh) Don't really need index if using perfect hash. SmallArray p; }; @@ -78,6 +79,8 @@ class CachingComputation { float GetQVal(int sample) const; // Returns probability of draw if NN has WDL value head. float GetDVal(int sample) const; + // Returns E (value error) value for @sample. + float GetEVal(int sample) const; // Returns estimated remaining moves. float GetMVal(int sample) const; // Returns P value @move_id of @sample. diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 275a332e6e..c21222c6ff 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -163,6 +163,10 @@ class CudaNetworkComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index d68d280e72..599a05f293 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -132,6 +132,10 @@ class CudnnNetworkComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/dx/network_dx.h b/src/neural/dx/network_dx.h index 82deccbebd..3ca839fc90 100644 --- a/src/neural/dx/network_dx.h +++ b/src/neural/dx/network_dx.h @@ -122,6 +122,10 @@ class DxNetworkComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetPVal(int sample, int move_id) const override { return inputs_outputs_ ->op_policy_mem_final_[sample * kNumOutputPolicy + move_id]; diff --git a/src/neural/metal/network_metal.h b/src/neural/metal/network_metal.h index b2e2df4b39..7eada32d6f 100644 --- a/src/neural/metal/network_metal.h +++ b/src/neural/metal/network_metal.h @@ -82,6 +82,10 @@ class MetalNetworkComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/network.h b/src/neural/network.h index 054b2ebd33..7bc7ad3fc4 100644 --- a/src/neural/network.h +++ b/src/neural/network.h @@ -64,6 +64,8 @@ class NetworkComputation { // Returns Q value of @sample. virtual float GetQVal(int sample) const = 0; virtual float GetDVal(int sample) const = 0; + // Returns E (value error) value for @sample. + virtual float GetEVal(int sample) const = 0; // Returns P value @move_id of @sample. 
virtual float GetPVal(int sample, int move_id) const = 0; virtual float GetMVal(int sample) const = 0; diff --git a/src/neural/network_check.cc b/src/neural/network_check.cc index 1b67266cff..c115a3da5e 100644 --- a/src/neural/network_check.cc +++ b/src/neural/network_check.cc @@ -105,6 +105,10 @@ class CheckComputation : public NetworkComputation { return work_comp_->GetDVal(sample); } + float GetEVal(int sample) const override { + return work_comp_->GetEVal(sample); + } + float GetMVal(int sample) const override { return work_comp_->GetMVal(sample); } diff --git a/src/neural/network_demux.cc b/src/neural/network_demux.cc index 8935b696c6..32524aef4c 100644 --- a/src/neural/network_demux.cc +++ b/src/neural/network_demux.cc @@ -58,6 +58,12 @@ class DemuxingComputation : public NetworkComputation { return parents_[idx]->GetDVal(offset); } + float GetEVal(int sample) const override { + int idx = sample / partial_size_; + int offset = sample % partial_size_; + return parents_[idx]->GetEVal(offset); + } + float GetMVal(int sample) const override { int idx = sample / partial_size_; int offset = sample % partial_size_; diff --git a/src/neural/network_mux.cc b/src/neural/network_mux.cc index e8d6ec71d3..f7fd17a786 100644 --- a/src/neural/network_mux.cc +++ b/src/neural/network_mux.cc @@ -54,6 +54,10 @@ class MuxingComputation : public NetworkComputation { return parent_->GetDVal(sample + idx_in_parent_); } + float GetEVal(int sample) const override { + return parent_->GetEVal(sample + idx_in_parent_); + } + float GetMVal(int sample) const override { return parent_->GetMVal(sample + idx_in_parent_); } diff --git a/src/neural/network_random.cc b/src/neural/network_random.cc index 5b4a2661bb..a8539c149c 100644 --- a/src/neural/network_random.cc +++ b/src/neural/network_random.cc @@ -78,6 +78,8 @@ class RandomNetworkComputation : public NetworkComputation { return d; } + float GetEVal(int /* sample */) const override { return 0.0f; } + float GetMVal(int /* sample */) const override { return 0.0f; } float GetPVal(int sample, int move_id) const override { diff --git a/src/neural/network_record.cc b/src/neural/network_record.cc index 74a908e184..1defefccbe 100644 --- a/src/neural/network_record.cc +++ b/src/neural/network_record.cc @@ -75,6 +75,9 @@ class RecordComputation : public NetworkComputation { float GetDVal(int sample) const override { return Capture(inner_->GetDVal(sample), sample); } + float GetEVal(int sample) const override { + return Capture(inner_->GetEVal(sample), sample); + } // Returns P value @move_id of @sample. float GetPVal(int sample, int move_id) const override { return Capture(inner_->GetPVal(sample, move_id), sample); @@ -144,6 +147,7 @@ class ReplayComputation : public NetworkComputation { // Returns Q value of @sample. float GetQVal(int sample) const override { return Replay(sample); } float GetDVal(int sample) const override { return Replay(sample); } + float GetEVal(int sample) const override { return Replay(sample); } // Returns P value @move_id of @sample. 
float GetPVal(int sample, int) const override { return Replay(sample); } float GetMVal(int sample) const override { return Replay(sample); } diff --git a/src/neural/network_tf_cc.cc b/src/neural/network_tf_cc.cc index 548baa6f0f..8fa11bd768 100644 --- a/src/neural/network_tf_cc.cc +++ b/src/neural/network_tf_cc.cc @@ -366,6 +366,9 @@ class TFNetworkComputation : public NetworkComputation { return 0.0f; } } + float GetEVal(int sample) const override { + return 0.0f; + } float GetPVal(int sample, int move_id) const override { return output_[1].template matrix()(sample, move_id); } diff --git a/src/neural/network_trivial.cc b/src/neural/network_trivial.cc index 196c0b14c1..4c6bede980 100644 --- a/src/neural/network_trivial.cc +++ b/src/neural/network_trivial.cc @@ -444,6 +444,8 @@ class TrivialNetworkComputation : public NetworkComputation { float GetDVal(int) const override { return 0.0f; } + float GetEVal(int) const override { return 0.0f; } + float GetMVal(int /* sample */) const override { return 0.0f; } float GetPVal(int /* sample */, int move_id) const override { diff --git a/src/neural/onednn/network_onednn.cc b/src/neural/onednn/network_onednn.cc index 8587e12982..9f906df6bb 100644 --- a/src/neural/onednn/network_onednn.cc +++ b/src/neural/onednn/network_onednn.cc @@ -128,6 +128,10 @@ class OnednnNetworkComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/onnx/network_onnx.cc b/src/neural/onnx/network_onnx.cc index d222f745b5..d978bc5fae 100644 --- a/src/neural/onnx/network_onnx.cc +++ b/src/neural/onnx/network_onnx.cc @@ -65,6 +65,7 @@ class OnnxComputation : public NetworkComputation { void ComputeBlocking() override; float GetQVal(int sample) const override; float GetDVal(int sample) const override; + float GetEVal(int sample) const override; float GetPVal(int sample, int move_id) const override; float GetMVal(int sample) const override; @@ -175,6 +176,11 @@ float OnnxComputation::GetDVal(int sample) const { return AsFloat(data[sample * 3 + 1]); } +template +float OnnxComputation::GetEVal(int sample) const { + return 0.0; +} + template float OnnxComputation::GetPVal(int sample, int move_id) const { const auto& data = output_tensors_data_[network_->policy_head_]; diff --git a/src/neural/opencl/network_opencl.cc b/src/neural/opencl/network_opencl.cc index f4a59d0587..781d5ba0d6 100644 --- a/src/neural/opencl/network_opencl.cc +++ b/src/neural/opencl/network_opencl.cc @@ -185,6 +185,10 @@ class OpenCLComputation : public NetworkComputation { } } + float GetEVal(int sample) const override { + return 0.0f; + } + float GetMVal(int sample) const override { if (moves_left_) { auto d = m_values_[sample]; diff --git a/src/python/weights.h b/src/python/weights.h index 18288c5f69..6936ea3e5a 100644 --- a/src/python/weights.h +++ b/src/python/weights.h @@ -134,10 +134,12 @@ class Output { for (int i = 0; i < 1858; ++i) p_[i] = computation.GetPVal(idx, i); q_ = computation.GetQVal(idx); d_ = computation.GetDVal(idx); + e_ = computation.GetEVal(idx); m_ = computation.GetMVal(idx); } float q() const { return q_; } float d() const { return d_; } + float e() const { return e_; } float m() const { return m_; } std::vector p_raw(const std::vector& indicies) { std::vector result(indicies.size()); @@ -173,6 +175,7 @@ class Output { float p_[1858]; float q_; float d_; + float e_; float m_; 
}; From 916b3fa617abe1d9b705fab6aa55aaea5b989763 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Mon, 25 Sep 2023 19:24:26 +0200 Subject: [PATCH 37/70] Refactor value head into separate class. Add short term value error head. --- src/neural/cuda/inputs_outputs.h | 21 +++++- src/neural/cuda/layers.cc | 118 ++++++++++++++++++++++++++++++ src/neural/cuda/layers.h | 40 ++++++++++ src/neural/cuda/network_cuda.cc | 121 +++++++++++++------------------ src/neural/cuda/network_cudnn.cc | 4 +- 5 files changed, 229 insertions(+), 75 deletions(-) diff --git a/src/neural/cuda/inputs_outputs.h b/src/neural/cuda/inputs_outputs.h index da3b5a6b00..ea89ed1272 100644 --- a/src/neural/cuda/inputs_outputs.h +++ b/src/neural/cuda/inputs_outputs.h @@ -31,9 +31,11 @@ namespace lczero { namespace cudnn_backend { struct InputsOutputs { - InputsOutputs(int maxBatchSize, bool wdl, bool moves_left, + InputsOutputs(int maxBatchSize, bool wdl, bool wdl_err, bool moves_left, size_t tensor_mem_size = 0, size_t scratch_size = 0, - bool cublasDisableTensorCores = false) { + bool cublasDisableTensorCores = false) + : has_moves_left_(moves_left), + has_wdl_err_(wdl_err) { ReportCUDAErrors(cudaHostAlloc( &input_masks_mem_, maxBatchSize * kInputPlanes * sizeof(uint64_t), cudaHostAllocMapped)); @@ -68,6 +70,14 @@ struct InputsOutputs { op_moves_left_mem_, 0)); } + if (wdl_err) { + ReportCUDAErrors(cudaHostAlloc(&op_value_err_mem_, + maxBatchSize * sizeof(float), + cudaHostAllocMapped)); + ReportCUDAErrors(cudaHostGetDevicePointer(&op_value_err_mem_gpu_, + op_value_err_mem_, 0)); + } + // memory for network execution managed inside this structure if (tensor_mem_size) { multi_stream_ = true; @@ -92,6 +102,8 @@ struct InputsOutputs { ReportCUDAErrors(cudaFreeHost(op_policy_mem_)); ReportCUDAErrors(cudaFree(op_policy_mem_gpu_)); ReportCUDAErrors(cudaFreeHost(op_value_mem_)); + if (has_moves_left_) ReportCUDAErrors(cudaFreeHost(op_moves_left_mem_)); + if (has_wdl_err_) ReportCUDAErrors(cudaFreeHost(op_value_err_mem_)); if (multi_stream_) { for (auto mem : tensor_mem_) { @@ -110,12 +122,14 @@ struct InputsOutputs { float* input_val_mem_; float* op_policy_mem_; float* op_value_mem_; + float* op_value_err_mem_; float* op_moves_left_mem_; // GPU pointers for the above allocations. uint64_t* input_masks_mem_gpu_; float* input_val_mem_gpu_; float* op_value_mem_gpu_; + float* op_value_err_mem_gpu_; float* op_moves_left_mem_gpu_; // This is a seperate copy. @@ -134,6 +148,9 @@ struct InputsOutputs { // cublas handle used to run the network cublasHandle_t cublas_; + + bool has_moves_left_ = false; + bool has_wdl_err_ = false; }; } // namespace cudnn_backend diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index bdbba2c75e..67b00db897 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2294,6 +2294,121 @@ void AttentionBody::Eval(int N, DataType* output, } // End of encoder blocks } +template +ValueHead::ValueHead(BaseLayer* ip, + const LegacyWeights::ValueHead& weights, + void* scratch, bool attention_body, + bool wdl, bool wdl_err, + ActivationFunction act, + int max_batch_size) + : BaseLayer(weights.ip_val_b.size(), 8, 8, ip), + attention_body_(attention_body), + embedding_size_(attention_body ? weights.ip_val_b.size() : 0), + convolution_size_(attention_body ? 0 : weights.value.biases.size()), + value_hidden_size_(weights.ip1_val_b.size()), + act_(act), + wdl_(wdl), + wdl_err_(wdl_err) { + // Selected head to construct. + // Use value_q as default head. 
+ /* @todo check that head exists */ + if (attention_body_) { + allocAndUpload(&ip_val_w_, weights.ip_val_w, scratch); + allocAndUpload(&ip_val_b_, weights.ip_val_b, scratch); + } else { + // @todo value convolution here. + } + + allocAndUpload(&ip1_val_w_, weights.ip1_val_w, scratch); + allocAndUpload(&ip1_val_b_, weights.ip1_val_b, scratch); + + allocAndUpload(&ip2_val_w_, weights.ip2_val_w, scratch); + allocAndUpload(&ip2_val_b_, weights.ip2_val_b, scratch); + + if (wdl_err_) { + allocAndUpload(&ip_val_err_w_, weights.ip_val_err_w, scratch); + allocAndUpload(&ip_val_err_b_, weights.ip_val_err_b, scratch); + } +} + +template +ValueHead::~ValueHead() { + if (attention_body_) { + ReportCUDAErrors(cudaFree(ip_val_w_)); + ReportCUDAErrors(cudaFree(ip_val_b_)); + } + ReportCUDAErrors(cudaFree(ip1_val_w_)); + ReportCUDAErrors(cudaFree(ip1_val_b_)); + ReportCUDAErrors(cudaFree(ip2_val_w_)); + ReportCUDAErrors(cudaFree(ip2_val_b_)); + if (wdl_err_) { + ReportCUDAErrors(cudaFree(ip_val_err_w_)); + ReportCUDAErrors(cudaFree(ip_val_err_b_)); + } +} + +template +void ValueHead::Eval(int N, DataType* output, const DataType* input, + const DataType* input2, void* scratch, + size_t scratch_size, cudnnHandle_t /*cudnn*/, + cublasHandle_t cublas, cudaStream_t stream, + DataType***) { + DataType* buffer = (DataType*)input2; + { + const int num_inputs = this->input_->GetC(); + const int num_outputs = embedding_size_; + const int batch = N * 64; + if (attention_body_) { + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_val_w_, num_inputs, + input, num_inputs, 0.0f, buffer, num_outputs); + addBiasBatched(buffer, buffer, ip_val_b_, N, 64, num_outputs, act_, stream); + + } else { + // Convolution for old conv value head + } + } + + { + // Value dense 1 + const int num_inputs = embedding_size_ * 64; + const int num_outputs = value_hidden_size_; + const int batch = N; + DataType* layer_out = (DataType*)scratch; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip1_val_w_, num_inputs, + buffer, num_inputs, 0.0f, layer_out, num_outputs); + addBiasBatched(layer_out, layer_out, ip1_val_b_, 1, batch, + num_outputs, act_, stream); + } + + { + // Value dense 2 + const int num_inputs = value_hidden_size_; + const int num_outputs = wdl_ ? 3 : 1; + const int batch = N; + DataType* layer_out = wdl_err_ ? (DataType*)buffer : (DataType*)output; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip2_val_w_, num_inputs, + (DataType*)scratch, num_inputs, 0.0f, layer_out, num_outputs); + addVectors(layer_out, layer_out, ip2_val_b_, num_outputs * batch, + num_outputs * batch, num_outputs, wdl_ ? ACTIVATION_NONE : ACTIVATION_TANH, + stream); + } + + if (wdl_err_) { + // Value error dense + const int num_inputs = value_hidden_size_; + const int num_outputs = 1; + const int batch = N; + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_val_err_w_, num_inputs, + (DataType*)scratch, num_inputs, 0.0f, output, num_outputs); + addVectors(output, output, ip_val_err_b_, N, + N, 1, ACTIVATION_SIGMOID, stream); + } +} + // Template instantiation. #ifdef USE_CUDNN template class ConvLayer; @@ -2330,6 +2445,9 @@ template class AttentionBody; template class EmbeddingLayer; template class EmbeddingLayer; +template class ValueHead; +template class ValueHead; + // Misc error handling stuff. 
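For reference, the data flow through the new ValueHead::Eval above is: an embedding GEMM over N * 64 square rows, a flatten to N rows of 64 * embedding_size_ values, FC1 down to value_hidden_size_, FC2 to 3 outputs (WDL, no activation) or 1 output (tanh), and, when wdl_err_ is set, a separate 1-wide FC finished with a sigmoid. The sketch below shows how the host side can read those outputs back; it mirrors the GetQVal/GetDVal/GetEVal accessors added to CudaNetworkComputation further down, but the function and parameter names here are illustrative and not part of the patch.

// Illustrative readout sketch (not part of the patch); mirrors the accessors
// in CudaNetworkComputation below.
struct ValueReadout {
  float q;  // scalar value, or w - l for WDL nets
  float d;  // draw probability (WDL nets only)
  float e;  // short-term value error (error head only)
};

inline ValueReadout ReadValue(const float* op_value, const float* op_value_err,
                              int sample, bool wdl, bool wdl_err) {
  ValueReadout r{0.0f, 0.0f, 0.0f};
  if (wdl) {
    const float w = op_value[3 * sample + 0];
    const float l = op_value[3 * sample + 2];
    r.q = w - l;
    r.d = op_value[3 * sample + 1];
  } else {
    r.q = op_value[sample];  // tanh output of value FC2
  }
  if (wdl_err) r.e = op_value_err[sample];  // sigmoid output of the error FC
  return r;
}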
#ifdef USE_CUDNN void CudnnError(cudnnStatus_t status, const char* file, const int& line) { diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 1524eaf1ed..d5f991ea84 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -507,5 +507,45 @@ class AttentionBody : public BaseLayer { const bool has_smolgen_; }; +// The value head implementation +// Responsible for loading weights into GPU memory, and evaluating the value +// head and value error head +template +class ValueHead : public BaseLayer { + using BaseLayer::C; + using BaseLayer::H; + using BaseLayer::W; + using BaseLayer::GetC; + using BaseLayer::GetH; + using BaseLayer::GetW; + + public: + ValueHead(BaseLayer* ip, const LegacyWeights::ValueHead& weights, + void* scratch, bool attention_body, bool wdl, bool wdl_err, + ActivationFunction act, int max_batch_size); + ~ValueHead(); + void Eval(int N, DataType* output, const DataType* input, + const DataType* input2, void* scratch, size_t scratch_size, + cudnnHandle_t cudnn, cublasHandle_t cublas, cudaStream_t stream, + DataType*** = nullptr) override; + + private: + // GPU allocations to hold various weights used by the attention policy head + DataType *value_w_, *value_b_; // "convolution" in value head (legacy) + DataType *ip_val_w_, *ip_val_b_; // "embedding" in value head + DataType *ip1_val_w_, *ip1_val_b_; // "FC1" in value head + DataType *ip2_val_w_, *ip2_val_b_; // "FC2" in value head + DataType *ip_val_err_w_, *ip_val_err_b_; // value error "FC" weights + + int embedding_size_; + int convolution_size_; + int value_hidden_size_; + bool wdl_; + bool wdl_err_; + bool attention_body_; + ActivationFunction act_; +}; + + } // namespace cudnn_backend } // namespace lczero diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 55a4e09559..515f82c14d 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -121,7 +121,7 @@ template class CudaNetworkComputation : public NetworkComputation { public: CudaNetworkComputation(CudaNetwork* network, bool wdl, - bool moves_left); + bool wdl_err, bool moves_left); ~CudaNetworkComputation(); void AddInput(InputPlanes&& input) override { @@ -149,21 +149,21 @@ class CudaNetworkComputation : public NetworkComputation { auto w = inputs_outputs_->op_value_mem_[3 * sample + 0]; auto l = inputs_outputs_->op_value_mem_[3 * sample + 2]; return w - l; - } else { - return inputs_outputs_->op_value_mem_[sample]; } + return inputs_outputs_->op_value_mem_[sample]; } float GetDVal(int sample) const override { if (wdl_) { - auto d = inputs_outputs_->op_value_mem_[3 * sample + 1]; - return d; - } else { - return 0.0f; + return inputs_outputs_->op_value_mem_[3 * sample + 1]; } + return 0.0f; } float GetEVal(int sample) const override { + if (wdl_err_) { + return inputs_outputs_->op_value_err_mem_[sample]; + } return 0.0f; } @@ -183,6 +183,7 @@ class CudaNetworkComputation : public NetworkComputation { std::unique_ptr inputs_outputs_; int batch_size_; bool wdl_; + bool wdl_err_; bool moves_left_; CudaNetwork* network_; @@ -495,49 +496,33 @@ class CudaNetwork : public Network { network_.emplace_back(std::move(FCPol)); } - // Value head. + // Value heads. { - // Selected head to construct. - // Use value_q as default head. 
std::string value_head = options.GetOrDefault("value_head", "q"); - /* @todo check that head exists */ - LegacyWeights::ValueHead head = weights.value_heads.q; + LegacyWeights::ValueHead& head = weights.value_heads.q; if (value_head == "winner" /* @todo check that head exists */) { head = weights.value_heads.winner; } else if (value_head == "st" /* @todo check that head exists */) { head = weights.value_heads.st; } - if (attn_body_) { - auto embedded_val = std::make_unique>( - encoder_last_, head.ip_val_w, head.ip_val_b, scratch_mem_, - act); - network_.emplace_back(std::move(embedded_val)); - } else { - auto convVal = std::make_unique>( - resi_last_, weights.value.biases.size(), 8, 8, kNumFilters, act, - true, use_gemm_ex); - convVal->LoadWeights(&weights.value.weights[0], - &weights.value.biases[0], scratch_mem_); - network_.emplace_back(std::move(convVal)); - } - - auto FCVal1 = std::make_unique>( - getLastLayer(), head.ip1_val_b.size(), 1, 1, true, act); - FCVal1->LoadWeights(&head.ip1_val_w[0], &head.ip1_val_b[0], - scratch_mem_); - network_.emplace_back(std::move(FCVal1)); - wdl_ = file.format().network_format().value() == pblczero::NetworkFormat::VALUE_WDL; - auto fc2_tanh = !wdl_; - - auto FCVal2 = std::make_unique>( - getLastLayer(), head.ip2_val_b.size(), 1, 1, true, - fc2_tanh ? ACTIVATION_TANH : ACTIVATION_NONE); - FCVal2->LoadWeights(&head.ip2_val_w[0], &head.ip2_val_b[0], - scratch_mem_); - network_.emplace_back(std::move(FCVal2)); + BaseLayer* lastlayer = attn_body_ ? encoder_last_ : resi_last_; + auto value_main = std::make_unique>( + lastlayer, head, scratch_mem_, attn_body_, wdl_, false, + act, max_batch_size_ + ); + network_.emplace_back(std::move(value_main)); + + wdl_err_ = weights.has_multiheads && weights.value_heads.st.ip_val_err_b.size() > 0; + if (wdl_err_) { + auto value_err = std::make_unique>( + lastlayer, weights.value_heads.st, scratch_mem_, attn_body_, + wdl_, true, act, max_batch_size_ + ); + network_.emplace_back(std::move(value_err)); + } } // Moves left head @@ -650,6 +635,7 @@ class CudaNetwork : public Network { float* opPol = io->op_policy_mem_gpu_; float* opVal = io->op_value_mem_gpu_; float* opMov = io->op_moves_left_mem_gpu_; + float* opValErr = io->op_value_err_mem_gpu_; // Figure out if the memory requirment for running the res block would fit // in the L2 cache. @@ -807,38 +793,29 @@ class CudaNetwork : public Network { cudaMemcpyDeviceToHost, stream)); // value head - network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, - scratch_size_, nullptr, cublas, - stream); // value conv or embedding - - network_[l++]->Eval(batchSize, spare2, spare1, nullptr, scratch_mem, - scratch_size_, nullptr, cublas, - stream); // value FC1 - - if (wdl_) { - if (fp16) { - network_[l++]->Eval(batchSize, spare1, spare2, nullptr, scratch_mem, - scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE - copyTypeConverted(opVal, (half*)spare1, 3 * batchSize, - stream); // VALUE - } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, spare2, nullptr, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE - } + if (fp16) { + network_[l++]->Eval(batchSize, spare1, flow, spare2, scratch_mem, + scratch_size_, nullptr, cublas, + stream); // value head + copyTypeConverted(opVal, (half*)spare1, wdl_ ? 
3 * batchSize : batchSize, + stream); } else { + network_[l++]->Eval(batchSize, (DataType*)opVal, flow, spare2, + scratch_mem, scratch_size_, nullptr, cublas, + stream); // value head + } + + if (wdl_err_) { + // value error head if (fp16) { - // TODO: consider fusing the bias-add of FC2 with format conversion. - network_[l++]->Eval(batchSize, spare1, spare2, nullptr, scratch_mem, + network_[l++]->Eval(batchSize, spare1, flow, spare2, scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 - copyTypeConverted(opVal, (half*)(spare1), batchSize, - stream); // VALUE + stream); // value error head + copyTypeConverted(opValErr, (half*)spare1, batchSize, stream); } else { - network_[l++]->Eval(batchSize, (DataType*)opVal, spare2, nullptr, + network_[l++]->Eval(batchSize, (DataType*)opValErr, flow, spare2, scratch_mem, scratch_size_, nullptr, cublas, - stream); // value FC2 // VALUE + stream); // value error head } } @@ -916,6 +893,7 @@ class CudaNetwork : public Network { // from a different thread). ReportCUDAErrors(cudaSetDevice(gpu_id_)); return std::make_unique>(this, wdl_, + wdl_err_, moves_left_); } @@ -923,7 +901,7 @@ class CudaNetwork : public Network { std::lock_guard lock(inputs_outputs_lock_); if (free_inputs_outputs_.empty()) { return std::make_unique( - max_batch_size_, wdl_, moves_left_, tensor_mem_size_, scratch_size_, + max_batch_size_, wdl_, wdl_err_, moves_left_, tensor_mem_size_, scratch_size_, !has_tensor_cores_ && std::is_same::value); } else { std::unique_ptr resource = @@ -941,7 +919,7 @@ class CudaNetwork : public Network { // Apparently nvcc doesn't see constructor invocations through make_unique. // This function invokes constructor just to please complier and silence // warning. Is never called (but compiler thinks that it could). - void UglyFunctionToSilenceNvccWarning() { InputsOutputs io(0, false, false); } + void UglyFunctionToSilenceNvccWarning() { InputsOutputs io(0, false, false, false); } private: const NetworkCapabilities capabilities_; @@ -949,6 +927,7 @@ class CudaNetwork : public Network { int l2_cache_size_; int max_batch_size_; bool wdl_; + bool wdl_err_; bool moves_left_; bool use_res_block_winograd_fuse_opt_; // fuse operations inside the residual // tower @@ -1041,8 +1020,8 @@ class CudaNetwork : public Network { template CudaNetworkComputation::CudaNetworkComputation( - CudaNetwork* network, bool wdl, bool moves_left) - : wdl_(wdl), moves_left_(moves_left), network_(network) { + CudaNetwork* network, bool wdl, bool wdl_err, bool moves_left) + : wdl_(wdl), wdl_err_(wdl_err), moves_left_(moves_left), network_(network) { batch_size_ = 0; inputs_outputs_ = network_->GetInputsOutputs(); } diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index 6cc791967b..c85e56c0d0 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -925,7 +925,7 @@ class CudnnNetwork : public Network { std::unique_ptr GetInputsOutputs() { std::lock_guard lock(inputs_outputs_lock_); if (free_inputs_outputs_.empty()) { - return std::make_unique(max_batch_size_, wdl_, + return std::make_unique(max_batch_size_, wdl_, false, moves_left_); } else { std::unique_ptr resource = @@ -943,7 +943,7 @@ class CudnnNetwork : public Network { // Apparently nvcc doesn't see constructor invocations through make_unique. // This function invokes constructor just to please complier and silence // warning. Is never called (but compiler thinks that it could). 
- void UglyFunctionToSilenceNvccWarning() { InputsOutputs io(0, false, false); } + void UglyFunctionToSilenceNvccWarning() { InputsOutputs io(0, false, false, false); } private: const NetworkCapabilities capabilities_; From ad5390bfe0c75f09b75b98c5de20194c3d5f6d4e Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Mon, 25 Sep 2023 19:25:20 +0200 Subject: [PATCH 38/70] Add error to NodeToProcess. --- src/mcts/search.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/mcts/search.h b/src/mcts/search.h index ead8ed2682..84631993c1 100644 --- a/src/mcts/search.h +++ b/src/mcts/search.h @@ -302,6 +302,8 @@ class SearchWorker { float v; // Draw probability for NN's with WDL value head. float d; + // Value error from NN's with value error head. + float e; // Estimated remaining plies left. float m; int multivisit = 0; @@ -347,6 +349,8 @@ class SearchWorker { float GetDVal(int) const { return lock->d; } + float GetEVal(int) const { return lock->e; } + float GetMVal(int) const { return lock->m; } float GetPVal(int, int move_id) const { From bd34a6f22952a9698a6eeb1c701dd5e99caf8491 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Fri, 29 Sep 2023 23:46:36 +0200 Subject: [PATCH 39/70] Fix bug in value head bias add. --- src/neural/cuda/layers.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 67b00db897..351ad90b93 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2362,7 +2362,7 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_val_w_, num_inputs, input, num_inputs, 0.0f, buffer, num_outputs); - addBiasBatched(buffer, buffer, ip_val_b_, N, 64, num_outputs, act_, stream); + addBiasBatched(buffer, buffer, ip_val_b_, 1, batch, num_outputs, act_, stream); } else { // Convolution for old conv value head From bfbe14c7c69d4ab3082755861e092d4d50908926 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Thu, 5 Oct 2023 02:11:51 +0200 Subject: [PATCH 40/70] Support for multihead architecture in protobuf. --- libs/lczero-common | 2 +- src/neural/cuda/network_cuda.cc | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index 66d43cf9d4..372f1195a1 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 66d43cf9d41a8c6987083a1f5f026f8d0e2c0307 +Subproject commit 372f1195a1fed5516a735dcfddf7e71e59030e1f diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 515f82c14d..68d90684df 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -204,7 +204,8 @@ class CudaNetwork : public Network { attn_policy_ = file.format().network_format().policy() == pblczero::NetworkFormat::POLICY_ATTENTION; - attn_body_ = file.format().network_format().network() == + // Mask out the multihead format bit 7. 
+ attn_body_ = (file.format().network_format().network() & 127) == pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; max_batch_size_ = options.GetOrDefault("max_batch", 1024); @@ -1046,16 +1047,11 @@ std::unique_ptr MakeCudaNetwork(const std::optional& w, " backend requires a network file."); } const WeightsFile& weights = *w; - if (weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && - weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && - weights.format().network_format().network() != - pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT) { + if ((weights.format().network_format().network() & 128) == 0) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( weights.format().network_format().network()) + - " is not supported by the CUDA backend."); + " is not currently supported by this version of the CUDA backend."); } if (weights.format().network_format().policy() != pblczero::NetworkFormat::POLICY_CLASSICAL && From 4cdfab3a9745cecf3a35388be6cbd27621b47c20 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sun, 8 Oct 2023 01:27:36 +0100 Subject: [PATCH 41/70] Add backward compatibility adjustments to old nets to work in multihead architecture. --- libs/lczero-common | 2 +- src/neural/loader.cc | 100 +++++++++++++++++++++++++++++++++++ src/neural/network_legacy.cc | 5 +- src/neural/network_legacy.h | 6 +++ 4 files changed, 110 insertions(+), 3 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index 66d43cf9d4..82f6b0a5fd 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 66d43cf9d41a8c6987083a1f5f026f8d0e2c0307 +Subproject commit 82f6b0a5fdc8c79f527d3f9accf6b529cc8abac1 diff --git a/src/neural/loader.cc b/src/neural/loader.cc index 4e8e76ec57..deaabc7541 100644 --- a/src/neural/loader.cc +++ b/src/neural/loader.cc @@ -104,6 +104,75 @@ std::string DecompressGzip(const std::string& filename) { return buffer; } +void MovePolicyHead(WeightsFile* file) { + bool attn_policy = file->format().network_format().policy() == + pblczero::NetworkFormat::POLICY_ATTENTION; + auto vanilla = file->mutable_weights()->mutable_policy_heads()->mutable_vanilla(); + auto mutable_weights = file->mutable_weights(); + if (attn_policy && file->weights().has_ip_pol_b()) { + // For attention policy weights, ip_pol_w and ip_pol_b (embedding weights) + // are moved to the main "policy_heads" struct, where all policy heads + // share them. + auto heads = file->mutable_weights()->mutable_policy_heads(); + *heads->mutable_ip_pol_w() = file->weights().ip_pol_w(); + *heads->mutable_ip_pol_b() = file->weights().ip_pol_b(); + mutable_weights->mutable_ip_pol_w()->Clear(); + mutable_weights->mutable_ip_pol_b()->Clear(); + + // Some older attention policy nets have policy encoders. + for (auto enc : file->weights().pol_encoder()) { + *vanilla->add_pol_encoder() = enc; + } + vanilla->set_pol_headcount(file->weights().pol_headcount()); + } + + // Macro to move remaining shared weights around. + #define MOVE(name) \ + if (file->weights().has_##name()) { \ + *vanilla->mutable_##name() = file->weights().name(); \ + mutable_weights->mutable_##name()->Clear(); \ + } + if (!attn_policy) { + // These weights are used by older style policy heads. + MOVE(policy1); + MOVE(policy); + MOVE(ip_pol_w); + MOVE(ip_pol_b); + } + // Weights common to all policy implementations. 
+ MOVE(ip2_pol_w); + MOVE(ip2_pol_b); + MOVE(ip3_pol_w); + MOVE(ip3_pol_b); + MOVE(ip4_pol_w); + + #undef MOVE +} + +void MoveValueHead(WeightsFile* file) { + bool attn_body = file->format().network_format().network() == + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; + auto winner = file->mutable_weights()->mutable_value_heads()->mutable_winner(); + auto mutable_weights = file->mutable_weights(); + // Macro to move weights around. + #define MOVE(name) \ + if (file->weights().has_##name()) { \ + *winner->mutable_##name() = file->weights().name(); \ + mutable_weights->mutable_##name()->Clear(); \ + } + if (!attn_body) { + MOVE(value); + } + MOVE(ip2_val_w); + MOVE(ip2_val_b); + MOVE(ip1_val_w); + MOVE(ip1_val_b); + MOVE(ip_val_w); + MOVE(ip_val_b); + + #undef MOVE +} + void FixOlderWeightsFile(WeightsFile* file) { using nf = pblczero::NetworkFormat; auto network_format = file->format().network_format().network(); @@ -140,6 +209,37 @@ void FixOlderWeightsFile(WeightsFile* file) { net->set_smolgen_activation(pblczero::NetworkFormat::ACTIVATION_SWISH); } } + + // Get updated network format. + network_format = file->format().network_format().network(); + auto embedding_type = file->format().network_format().input_embedding(); + bool multihead_format = (network_format & 128) == 128; + if (!multihead_format) { + auto weights = file->weights(); + if (weights.has_policy_heads() && weights.has_value_heads()) { + CERR << "Weights file has multihead format, updating format flag"; + net->set_network(static_cast(network_format | 128)); + net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_PE_DENSE); + } + else { + CERR << "Weights file has single head format, rewriting to multihead format"; + // Move policy and value heads. + MovePolicyHead(file); + MoveValueHead(file); + net->set_network(static_cast(network_format | 128)); + switch (network_format) { + case pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT: + net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_PE_MAP); + break; + default: + net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_NONE); + } + } + } else { + if (embedding_type != pblczero::NetworkFormat::INPUT_EMBEDDING_PE_DENSE) { + net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_PE_DENSE); + } + } } WeightsFile ParseWeightsProto(const std::string& buffer) { diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index be3ea76ffc..e487709af6 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -67,8 +67,9 @@ LegacyWeights::LegacyWeights(const pblczero::Weights& weights) smolgen_w(LayerAdapter(weights.smolgen_w()).as_vector()), has_smolgen(weights.has_smolgen_w()), policy_heads(weights.policy_heads()), - value_heads(weights.value_heads()), - has_multiheads(weights.has_policy_heads() && weights.has_value_heads()) { + value_heads(weights.value_heads()) { + has_multiheads = weights.has_policy_heads() && weights.policy_heads().has_optimistic_st() + && weights.has_value_heads() && weights.value_heads().has_q(); for (const auto& res : weights.residual()) { residual.emplace_back(res); } diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index bf09eb3376..ab3d9706f3 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -227,4 +227,10 @@ struct LegacyWeights { bool has_smolgen; }; +enum InputEmbedding { + INPUT_EMBEDDING_NONE = 0, + INPUT_EMBEDDING_PE_MAP = 1, + INPUT_EMBEDDING_PE_DENSE = 2, +}; + } // namespace lczero From 
3c0ded9786423046b79fb916b53ef655279ec900 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sun, 8 Oct 2023 23:22:42 +0200 Subject: [PATCH 42/70] Add backward compatibility. --- src/neural/cuda/layers.cc | 31 ++++--- src/neural/cuda/layers.h | 9 +- src/neural/cuda/network_cuda.cc | 146 +++++++++++++++++++++----------- 3 files changed, 118 insertions(+), 68 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 351ad90b93..038c9b32ba 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1428,17 +1428,15 @@ AttentionPolicyHead::AttentionPolicyHead( // Old networks without attention body (e.g. T79) use hardcoded SELU // activations. act_(attention_body ? act : ACTIVATION_SELU) { - // Selected head to construct. - // Use vanilla as default head. - /* @todo check that head exists */ + // Selected head to construct, use vanilla as default head. LegacyWeights::PolicyHead head = weights.policy_heads.vanilla; - if (policy_head == "optimistic" /* @todo check that head exists */) { + if (policy_head == "optimistic") { head = weights.policy_heads.optimistic_st; } - else if (policy_head == "soft" /* @todo check that head exists */) { + else if (policy_head == "soft") { head = weights.policy_heads.soft; } - else if (policy_head == "opponent" /* @todo check that head exists */) { + else if (policy_head == "opponent") { head = weights.policy_heads.opponent; } @@ -2054,7 +2052,7 @@ template AttentionBody::AttentionBody(const LegacyWeights& weights, void* scratch, Activations activations, int num_res_blocks, int input_c, - int max_batch_size) + int max_batch_size, bool new_encoding) : BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr), embedding_op_size_(weights.ip_emb_b.size()), encoder_head_count_(weights.encoder_head_count), @@ -2064,7 +2062,7 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, has_gating_(weights.ip_mult_gate.size() > 0 && weights.ip_add_gate.size() > 0), has_smolgen_(weights.has_smolgen), - new_encoding_(weights.has_multiheads) { + new_encoding_(new_encoding) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); @@ -2300,23 +2298,23 @@ ValueHead::ValueHead(BaseLayer* ip, void* scratch, bool attention_body, bool wdl, bool wdl_err, ActivationFunction act, - int max_batch_size) + int max_batch_size, bool use_gemm_ex) : BaseLayer(weights.ip_val_b.size(), 8, 8, ip), attention_body_(attention_body), - embedding_size_(attention_body ? weights.ip_val_b.size() : 0), - convolution_size_(attention_body ? 0 : weights.value.biases.size()), + embedding_size_(attention_body ? weights.ip_val_b.size() : weights.value.biases.size()), value_hidden_size_(weights.ip1_val_b.size()), act_(act), wdl_(wdl), wdl_err_(wdl_err) { - // Selected head to construct. - // Use value_q as default head. - /* @todo check that head exists */ if (attention_body_) { allocAndUpload(&ip_val_w_, weights.ip_val_w, scratch); allocAndUpload(&ip_val_b_, weights.ip_val_b, scratch); } else { - // @todo value convolution here. 
+ conv_ = std::make_unique>( + ip, weights.value.biases.size(), 8, 8, ip->GetC(), act, + true, use_gemm_ex); + conv_->LoadWeights((float*)&weights.value.weights[0], + (float*)&weights.value.biases[0], scratch); } allocAndUpload(&ip1_val_w_, weights.ip1_val_w, scratch); @@ -2365,7 +2363,8 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, addBiasBatched(buffer, buffer, ip_val_b_, 1, batch, num_outputs, act_, stream); } else { - // Convolution for old conv value head + conv_->Eval(N, buffer, input, nullptr, scratch, + scratch_size, nullptr, cublas, stream); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index d5f991ea84..a876008995 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -474,7 +474,7 @@ class AttentionBody : public BaseLayer { public: AttentionBody(const LegacyWeights& weights, void* scratch, Activations activations, int num_res_blocks, int input_c, - int max_batch_size); + int max_batch_size, bool new_encoding); ~AttentionBody(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -522,7 +522,7 @@ class ValueHead : public BaseLayer { public: ValueHead(BaseLayer* ip, const LegacyWeights::ValueHead& weights, void* scratch, bool attention_body, bool wdl, bool wdl_err, - ActivationFunction act, int max_batch_size); + ActivationFunction act, int max_batch_size, bool use_gemm_ex); ~ValueHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -530,15 +530,16 @@ class ValueHead : public BaseLayer { DataType*** = nullptr) override; private: + // "convolution" in value head (legacy) + std::unique_ptr> conv_; + // GPU allocations to hold various weights used by the attention policy head - DataType *value_w_, *value_b_; // "convolution" in value head (legacy) DataType *ip_val_w_, *ip_val_b_; // "embedding" in value head DataType *ip1_val_w_, *ip1_val_b_; // "FC1" in value head DataType *ip2_val_w_, *ip2_val_b_; // "FC2" in value head DataType *ip_val_err_w_, *ip_val_err_b_; // value error "FC" weights int embedding_size_; - int convolution_size_; int value_hidden_size_; bool wdl_; bool wdl_err_; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 68d90684df..94cfabf4d9 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -356,6 +356,35 @@ class CudaNetwork : public Network { ActivationFunction act = mish_net ? ACTIVATION_MISH : ACTIVATION_RELU; + std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); + // Check that selected policy head exists. 
+ if (attn_policy_) { + if ((policy_head == "vanilla" && weights.policy_heads.vanilla.ip2_pol_b.size() == 0) + || (policy_head == "optimistic" && weights.policy_heads.optimistic_st.ip2_pol_b.size() == 0) + || (policy_head == "soft" && weights.policy_heads.soft.ip2_pol_b.size() == 0) + || (policy_head != "vanilla" && policy_head != "optimistic" && policy_head != "soft")) { + throw Exception("The policy head you specified '" + policy_head + "'" + + " does not exist in this net."); + } + } else { + if ((policy_head == "vanilla" && weights.policy_heads.vanilla.policy.weights.size() == 0) + || (policy_head == "optimistic" && weights.policy_heads.optimistic_st.policy.weights.size() == 0) + || (policy_head == "soft" && weights.policy_heads.soft.policy.weights.size() == 0) + || (policy_head != "vanilla" && policy_head != "optimistic" && policy_head != "soft")) { + throw Exception("The policy head you specified '" + policy_head + "'" + + " does not exist in this net."); + } + } + + std::string value_head = options.GetOrDefault("value_head", "winner"); + // Check that selected value head exists. + if ((value_head == "winner" && weights.value_heads.winner.ip1_val_b.size() == 0) + || (value_head == "q" && weights.value_heads.q.ip1_val_b.size() == 0) + || (value_head == "st" && weights.value_heads.st.ip1_val_b.size() == 0) + || (value_head != "winner" && value_head != "q" && value_head != "st")) { + throw Exception("The value head you specified '" + value_head + "'" + + " does not exist in this net."); + } // 2. Build the network, and copy the weights to GPU memory. // Input conv only used if there are residual blocks in the network @@ -436,9 +465,13 @@ class CudaNetwork : public Network { : static_cast(ffn_activation); activations.default_activation = act; + auto new_encoding = static_cast( + file.format().network_format().input_embedding()) + == InputEmbedding::INPUT_EMBEDDING_PE_DENSE; auto attention_body = std::make_unique>( weights, scratch_mem_, activations, numBlocks_, - numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_); + numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_, + new_encoding); network_.emplace_back(std::move(attention_body)); encoder_last_ = getLastLayer(); @@ -446,7 +479,6 @@ class CudaNetwork : public Network { // Policy head. if (attn_policy_) { - std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); auto AttentionPolicy = std::make_unique>( getLastLayer(), weights, scratch_mem_, attn_body_, act, policy_head, max_batch_size_); @@ -457,54 +489,68 @@ class CudaNetwork : public Network { policymap->LoadWeights(kAttnPolicyMap, scratch_mem_); network_.emplace_back(std::move(policymap)); - } else if (conv_policy_) { - assert(!attn_body_); // not supported with attention body - auto conv1 = std::make_unique>( - resi_last_, kNumFilters, 8, 8, kNumFilters, act, true, false, false, - 0, use_gemm_ex); - conv1->LoadWeights(&weights.policy1.weights[0], - &weights.policy1.biases[0], scratch_mem_); - network_.emplace_back(std::move(conv1)); - - auto pol_channels = weights.policy.biases.size(); - - // No relu - auto conv2 = std::make_unique>( - getLastLayer(), pol_channels, 8, 8, kNumFilters, ACTIVATION_NONE, - true, false, false, 0, use_gemm_ex); - conv2->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], - scratch_mem_); - network_.emplace_back(std::move(conv2)); + } else { + // Selected head to construct, use vanilla as default head. 
+ lczero::LegacyWeights::PolicyHead head = weights.policy_heads.vanilla; + if (policy_head == "optimistic") { + head = weights.policy_heads.optimistic_st; + } + else if (policy_head == "soft") { + head = weights.policy_heads.soft; + } + else if (policy_head == "opponent") { + head = weights.policy_heads.opponent; + } + if (conv_policy_) { + assert(!attn_body_); // not supported with attention body + auto conv1 = std::make_unique>( + resi_last_, kNumFilters, 8, 8, kNumFilters, act, true, false, false, + 0, use_gemm_ex); + conv1->LoadWeights(&head.policy1.weights[0], + &head.policy1.biases[0], scratch_mem_); + network_.emplace_back(std::move(conv1)); + + auto pol_channels = head.policy.biases.size(); + + // No relu + auto conv2 = std::make_unique>( + getLastLayer(), pol_channels, 8, 8, kNumFilters, ACTIVATION_NONE, + true, false, false, 0, use_gemm_ex); + conv2->LoadWeights(&head.policy.weights[0], &head.policy.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv2)); - auto policymap = std::make_unique>( - getLastLayer(), kNumOutputPolicy, 1, 1, 73 * 8 * 8, false); - policymap->LoadWeights(kConvPolicyMap, scratch_mem_); + auto policymap = std::make_unique>( + getLastLayer(), kNumOutputPolicy, 1, 1, 73 * 8 * 8, false); + policymap->LoadWeights(kConvPolicyMap, scratch_mem_); - network_.emplace_back(std::move(policymap)); - } else { - assert(!attn_body_); // not supported with attention body - auto convPol = std::make_unique>( - resi_last_, weights.policy.biases.size(), 8, 8, kNumFilters, act, - true, use_gemm_ex); - convPol->LoadWeights(&weights.policy.weights[0], - &weights.policy.biases[0], scratch_mem_); - network_.emplace_back(std::move(convPol)); - - auto FCPol = std::make_unique>( - getLastLayer(), weights.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); - FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0], - scratch_mem_); - network_.emplace_back(std::move(FCPol)); + network_.emplace_back(std::move(policymap)); + } else { + assert(!attn_body_); // not supported with attention body + auto convPol = std::make_unique>( + resi_last_, head.policy.biases.size(), 8, 8, kNumFilters, act, + true, use_gemm_ex); + convPol->LoadWeights(&head.policy.weights[0], + &head.policy.biases[0], scratch_mem_); + network_.emplace_back(std::move(convPol)); + + auto FCPol = std::make_unique>( + getLastLayer(), head.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); + FCPol->LoadWeights(&head.ip_pol_w[0], &head.ip_pol_b[0], + scratch_mem_); + network_.emplace_back(std::move(FCPol)); + } } // Value heads. { - std::string value_head = options.GetOrDefault("value_head", "q"); - LegacyWeights::ValueHead& head = weights.value_heads.q; - if (value_head == "winner" /* @todo check that head exists */) { - head = weights.value_heads.winner; + // Selected head to construct, use value_winner as default head. + std::string value_head = options.GetOrDefault("value_head", "winner"); + LegacyWeights::ValueHead& head = weights.value_heads.winner; + if (value_head == "q") { + head = weights.value_heads.q; } - else if (value_head == "st" /* @todo check that head exists */) { + else if (value_head == "st") { head = weights.value_heads.st; } wdl_ = file.format().network_format().value() == @@ -512,15 +558,15 @@ class CudaNetwork : public Network { BaseLayer* lastlayer = attn_body_ ? 
encoder_last_ : resi_last_; auto value_main = std::make_unique>( lastlayer, head, scratch_mem_, attn_body_, wdl_, false, - act, max_batch_size_ + act, max_batch_size_, use_gemm_ex ); network_.emplace_back(std::move(value_main)); - wdl_err_ = weights.has_multiheads && weights.value_heads.st.ip_val_err_b.size() > 0; + wdl_err_ = weights.value_heads.st.ip_val_err_b.size() > 0; if (wdl_err_) { auto value_err = std::make_unique>( lastlayer, weights.value_heads.st, scratch_mem_, attn_body_, - wdl_, true, act, max_batch_size_ + wdl_, true, act, max_batch_size_, use_gemm_ex ); network_.emplace_back(std::move(value_err)); } @@ -1047,11 +1093,15 @@ std::unique_ptr MakeCudaNetwork(const std::optional& w, " backend requires a network file."); } const WeightsFile& weights = *w; - if ((weights.format().network_format().network() & 128) == 0) { + auto format = weights.format().network_format().network() & 127; + if ((weights.format().network_format().network() & 128) == 0 || + (format != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && + format != pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && + format != pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT)) { throw Exception("Network format " + pblczero::NetworkFormat::NetworkStructure_Name( weights.format().network_format().network()) + - " is not currently supported by this version of the CUDA backend."); + " is not supported by the CUDA backend."); } if (weights.format().network_format().policy() != pblczero::NetworkFormat::POLICY_CLASSICAL && From 4333f154f85a04956411b64ea0cbd5349a4b2bbe Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 6 Jan 2024 16:22:27 +0100 Subject: [PATCH 43/70] Fix conflict resolution artifacts. --- src/neural/cuda/common_kernels.cu | 129 ------------------------------ src/neural/cuda/layers.cc | 1 - src/neural/cuda/network_cuda.cc | 41 ++++++++-- 3 files changed, 36 insertions(+), 135 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index b5b095adbb..1db750b810 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -33,8 +33,6 @@ #include "neural/shared/attention_policy_map.h" #include "winograd_helper.inc" -#include "neural/shared/attention_policy_map.h" - namespace lczero { namespace cudnn_backend { namespace { @@ -104,34 +102,6 @@ void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream) { ReportCUDAErrors(cudaGetLastError()); } -template -__global__ void addVectorsHNC_NHC_kernel(T* a, T* b, int N, int H, int C) { - int i = threadIdx.x + blockDim.x * blockIdx.x; - if (i < N * H * C) { - int orig_i = i; - int c = i % C; - i /= C; - int n = i % N; - i /= N; - int h = i; - float aVal = (float)a[orig_i]; - float bVal = (float)b[n * H * C + h * C + c]; - - float cVal = aVal + bVal; - - a[orig_i] = (T)cVal; - } -} - -template -void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream) { - const int kBlockSize = 256; - int blocks = DivUp(N * H * C, kBlockSize); - addVectorsHNC_NHC_kernel<<>>(a, b, N, H, C); - - ReportCUDAErrors(cudaGetLastError()); -} - template __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, int N, int C) { @@ -337,105 +307,6 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, ReportCUDAErrors(cudaGetLastError()); } -template -__global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, - int N, int C, int Nstride) { - int batch = blockIdx.y; - int n = blockIdx.x * blockDim.y 
+ threadIdx.y; - if (n >= N) return; - int c = threadIdx.x * 4; - - int biasIndex = batch * C + c; - int tensorIndex = batch * Nstride * C + n * C + c; - - float val[4]; - float b[4]; - - // Load from memory - const bool fp16 = std::is_same::value; - if (fp16) { - half inp[4]; - copyAs(&inp[0], &input[tensorIndex]); -#pragma unroll - for (int i = 0; i < 4; i++) val[i] = (float)inp[i]; - - copyAs(&inp[0], &bias[biasIndex]); -#pragma unroll - for (int i = 0; i < 4; i++) b[i] = (float)inp[i]; - } else { - copyAs(&val[0], &input[tensorIndex]); - copyAs(&b[0], &bias[biasIndex]); - } - - // Perform bias add and activation -#pragma unroll - for (int i = 0; i < 4; i++) { - float x = val[i] + b[i]; - x = activate(x, act); - val[i] = x; - } - - // write to memory - if (fp16) { - half op[4]; -#pragma unroll - for (int i = 0; i < 4; i++) op[i] = (half)val[i]; - copyAs(&output[tensorIndex], &op[0]); - } else { - copyAs(&output[tensorIndex], &val[0]); - } -} - -// Input/output tensors are Batch * N * C -// bias tensor is N * C (i.e, different bias for each Batch dimension) -template -void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, - int C, int Nstride, ActivationFunction activation, cudaStream_t stream) { - // process 4 elements per thread to achieve close to peak memory bandwidth - if (C % 4 != 0) throw Exception("unsupported filter size"); - if (C > 4096) throw Exception("unsupported filter size"); - - dim3 blockDim, gridDim; - blockDim.x = C / 4; - blockDim.y = std::min(std::max(512 / blockDim.x, 1u), (unsigned int) N); - blockDim.z = 1; - gridDim.x = DivUp(N, blockDim.y); - gridDim.y = Batch; - gridDim.z = 1; - - switch (activation) { - case ACTIVATION_NONE: - addBiasBatched_kernel - <<>>(output, input, bias, N, C); - break; - case ACTIVATION_SELU: - addBiasBatched_kernel - <<>>(output, input, bias, N, C); - break; - case ACTIVATION_MISH: - addBiasBatched_kernel - <<>>(output, input, bias, N, C); - break; - case ACTIVATION_RELU: - addBiasBatched_kernel - <<>>(output, input, bias, N, C); - break; - case ACTIVATION_SWISH: - addBiasBatched_kernel - <<>>(output, input, bias, N, C); - break; - case ACTIVATION_RELU_2: // square relu - addBiasBatched_kernel - <<>>(output, input, bias, N, C); - break; - default: - throw Exception( - "unsupported activation in addBiasBatched. 
Add in switch-case here"); - } - - ReportCUDAErrors(cudaGetLastError()); -} - template __global__ void addBias_NCHW_kernel(T* c, T* a, T* b, int N, int C, int H, int W, ActivationFunction activation) { diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 93f4476c68..59a0c94fe3 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2000,7 +2000,6 @@ AttentionPolicyHead::~AttentionPolicyHead() { } template -EncoderBlock::~EncoderBlock() { EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(mha_q_w)); ReportCUDAErrors(cudaFree(mha_q_b)); diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 21daf32c5d..36c75e799c 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -68,8 +68,42 @@ static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { const size_t encoder_heads = weights.policy_heads.vanilla.pol_encoder_head_count; - size_t size = N * 64 * std::max(std::max(embedding_op_size, encoder_dff), - policy_d_model); + size_t size = + N * 64 * + std::max(std::max(embedding_op_size, encoder_dff), policy_d_model); + + // size of matmul_qk matrix = encoder_heads_ * Batch * 64 * 64 + const size_t matmul_qk_size = encoder_heads * N * 64 * 64; + const size_t output_size = N * (64 * 64 + 8 * 24); + size = std::max(size, std::max(matmul_qk_size, output_size)); + + size_t qkv_size = N * 64 * encoder_d_model; + // We store qkv in single allocation, and other intermediate tensors are + // sometimes stored by splitting an allocation into two halves. + size = std::max(2 * size, 3 * qkv_size); + return size; +} + +static size_t getMaxAttentionBodySize(const LegacyWeights& weights, int N) { + const size_t embedding_op_size = weights.ip_emb_b.size(); + + size_t encoder_d_model = 0; + size_t encoder_dff = 0; + + if (weights.encoder.size() > 0) { + encoder_d_model = weights.encoder[0].mha.q_b.size(); + encoder_dff = weights.encoder[0].ffn.dense1_b.size(); + + assert(encoder_d_model == weights.encoder[0].mha.k_b.size()); + assert(encoder_d_model == weights.encoder[0].mha.v_b.size()); + assert(embedding_op_size == weights.encoder[0].ffn.dense2_b.size()); + } + + const size_t encoder_heads = weights.encoder_head_count; + + size_t size = + N * 64 * + std::max(std::max(embedding_op_size, encoder_dff), encoder_d_model); // size of matmul_qk matrix = encoder_heads_ * Batch * 64 * 64 const size_t matmul_qk_size = encoder_heads * N * 64 * 64; @@ -781,7 +815,6 @@ class CudaNetwork : public Network { copyTypeConverted(opPol, (half*)(spare1), batchSize * kNumOutputPolicy, stream); // POLICY output } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, spare2, nullptr, network_[l++]->Eval(batchSize, (DataType*)opPol, spare2, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // policy map layer // POLICY output @@ -799,7 +832,6 @@ class CudaNetwork : public Network { copyTypeConverted(opPol, (half*)(spare2), batchSize * kNumOutputPolicy, stream); // POLICY } else { - network_[l++]->Eval(batchSize, (DataType*)opPol, spare1, nullptr, network_[l++]->Eval(batchSize, (DataType*)opPol, spare1, nullptr, scratch_mem, scratch_size_, nullptr, cublas, stream); // pol FC // POLICY @@ -856,7 +888,6 @@ class CudaNetwork : public Network { scratch_size_, nullptr, cublas, stream); copyTypeConverted(opMov, (half*)(spare1), batchSize, stream); } else { - network_[l++]->Eval(batchSize, (DataType*)opMov, spare2, nullptr, network_[l++]->Eval(batchSize, (DataType*)opMov, spare2, nullptr, scratch_mem, 
scratch_size_, nullptr, cublas, stream); From 288a337aa671baca509503fb9c6e82fa0a78d661 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 6 Jan 2024 18:36:19 +0100 Subject: [PATCH 44/70] Fix omissions. --- src/neural/cuda/kernels.h | 2 +- src/neural/cuda/layers.cc | 4 +--- src/neural/cuda/layers.h | 1 + 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index fc515df721..6ac71324d5 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -161,4 +161,4 @@ void applyInputGating(T* output, const T* input, const T* mult, const T* add, // Work around to avoid "nvcc error : 'cudafe++' died with status 0xC0000409" error // For some reason nvcc runs into this random error when trying to compile this function inside the namespaces bool fusedMHA(void* output, void* mha_q, void* mha_k, void* mha_v, void* skip, - int batch_size, int num_heads, int depth, cudaStream_t stream); \ No newline at end of file + int batch_size, int num_heads, int depth, cudaStream_t stream); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 59a0c94fe3..10460ac452 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -31,13 +31,11 @@ #include #include "cuda_common.h" -#include "neural/network.h" #include "kernels.h" #include "neural/network.h" #include "neural/shared/activation.h" #include "neural/shared/attention_policy_map.h" #include "utils/fp16_utils.h" -#include "neural/shared/attention_policy_map.h" namespace lczero { @@ -1486,7 +1484,7 @@ AttentionPolicyHead::AttentionPolicyHead( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, // using alpha = 1 for now (TODO: may change?) nullptr, 0, max_batch_size, ACTIVATION_SWISH, - act_); // smolgen weights not implemented in policy encoder heads yet. + act_, false); // smolgen weights not implemented in policy encoder heads yet. encoder_weights_.emplace_back(pW); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index c8b48123b5..1ac7fdfa8c 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -382,6 +382,7 @@ class EncoderBlock { float alpha_; // scale to apply to skip connection add const bool has_smolgen_; + const bool use_fused_mha_; const ActivationFunction smolgen_activation_; const ActivationFunction ffn_activation_; From 316b85c9df9abf67b468dec24fa78ef7857292a5 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 2 Mar 2024 01:17:25 +0100 Subject: [PATCH 45/70] Remove short-term error value accessor. --- src/mcts/search.h | 4 ---- src/neural/blas/network_blas.cc | 4 ---- src/neural/cache.cc | 7 ------- src/neural/cache.h | 3 --- src/neural/cuda/network_cuda.cc | 7 ------- src/neural/cuda/network_cudnn.cc | 4 ---- src/neural/dx/network_dx.h | 4 ---- src/neural/metal/network_metal.h | 4 ---- src/neural/network.h | 2 -- src/neural/network_check.cc | 4 ---- src/neural/network_demux.cc | 6 ------ src/neural/network_mux.cc | 4 ---- src/neural/network_random.cc | 2 -- src/neural/network_record.cc | 4 ---- src/neural/network_tf_cc.cc | 3 --- src/neural/network_trivial.cc | 2 -- src/neural/onednn/network_onednn.cc | 4 ---- src/neural/onnx/network_onnx.cc | 6 ------ src/neural/opencl/network_opencl.cc | 4 ---- src/python/weights.h | 3 --- 20 files changed, 81 deletions(-) diff --git a/src/mcts/search.h b/src/mcts/search.h index 3b86771e5f..c2ff2aa116 100644 --- a/src/mcts/search.h +++ b/src/mcts/search.h @@ -320,8 +320,6 @@ class SearchWorker { float v; // Draw probability for NN's with WDL value head. 
float d; - // Value error from NN's with value error head. - float e; // Estimated remaining plies left. float m; int multivisit = 0; @@ -367,8 +365,6 @@ class SearchWorker { float GetDVal(int) const { return lock->d; } - float GetEVal(int) const { return lock->e; } - float GetMVal(int) const { return lock->m; } float GetPVal(int, int move_id) const { diff --git a/src/neural/blas/network_blas.cc b/src/neural/blas/network_blas.cc index 8afb30d1e5..a212ce931e 100644 --- a/src/neural/blas/network_blas.cc +++ b/src/neural/blas/network_blas.cc @@ -100,10 +100,6 @@ class BlasComputation : public NetworkComputation { } } - float GetEVal(int sample) const override { - return 0.0f; - } - float GetMVal(int sample) const override { if (moves_left_) { return m_values_[sample]; diff --git a/src/neural/cache.cc b/src/neural/cache.cc index 213466edab..d729a562f0 100644 --- a/src/neural/cache.cc +++ b/src/neural/cache.cc @@ -89,7 +89,6 @@ void CachingComputation::ComputeBlocking() { req->q = parent_->GetQVal(item.idx_in_parent); req->d = parent_->GetDVal(item.idx_in_parent); req->m = parent_->GetMVal(item.idx_in_parent); - req->e = parent_->GetEVal(item.idx_in_parent); int idx = 0; for (auto x : item.probabilities_to_cache) { req->p[idx++] = @@ -111,12 +110,6 @@ float CachingComputation::GetDVal(int sample) const { return item.lock->d; } -float CachingComputation::GetEVal(int sample) const { - const auto& item = batch_[sample]; - if (item.idx_in_parent >= 0) return parent_->GetEVal(item.idx_in_parent); - return item.lock->e; -} - float CachingComputation::GetMVal(int sample) const { const auto& item = batch_[sample]; if (item.idx_in_parent >= 0) return parent_->GetMVal(item.idx_in_parent); diff --git a/src/neural/cache.h b/src/neural/cache.h index 97c7c75dda..207e0fe6e4 100644 --- a/src/neural/cache.h +++ b/src/neural/cache.h @@ -38,7 +38,6 @@ struct CachedNNRequest { float q; float d; float m; - float e; // TODO(mooskagh) Don't really need index if using perfect hash. SmallArray p; }; @@ -79,8 +78,6 @@ class CachingComputation { float GetQVal(int sample) const; // Returns probability of draw if NN has WDL value head. float GetDVal(int sample) const; - // Returns E (value error) value for @sample. - float GetEVal(int sample) const; // Returns estimated remaining moves. float GetMVal(int sample) const; // Returns P value @move_id of @sample. 
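With the short-term error accessor removed, the per-sample interface on NetworkComputation (see the network.h hunk further down in this patch) is back to four outputs. Below is a minimal sketch of that remaining surface, for orientation only; the class name is made up and the authoritative declarations are the ones in the diff.

// Orientation-only sketch, not the real header.
class ValueOutputsSketch {
 public:
  virtual ~ValueOutputsSketch() = default;
  virtual float GetQVal(int sample) const = 0;               // value; w - l for WDL heads
  virtual float GetDVal(int sample) const = 0;               // draw probability, 0.0f without WDL
  virtual float GetPVal(int sample, int move_id) const = 0;  // policy output for one move
  virtual float GetMVal(int sample) const = 0;               // moves-left estimate, 0.0f if absent
};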
diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 8ac7111af7..d75bab2c38 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -160,13 +160,6 @@ class CudaNetworkComputation : public NetworkComputation { return 0.0f; } - float GetEVal(int sample) const override { - if (wdl_err_) { - return inputs_outputs_->op_value_err_mem_[sample]; - } - return 0.0f; - } - float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index 16993b1a4a..eaf6a631c0 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -132,10 +132,6 @@ class CudnnNetworkComputation : public NetworkComputation { } } - float GetEVal(int sample) const override { - return 0.0f; - } - float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/dx/network_dx.h b/src/neural/dx/network_dx.h index 3ca839fc90..82deccbebd 100644 --- a/src/neural/dx/network_dx.h +++ b/src/neural/dx/network_dx.h @@ -122,10 +122,6 @@ class DxNetworkComputation : public NetworkComputation { } } - float GetEVal(int sample) const override { - return 0.0f; - } - float GetPVal(int sample, int move_id) const override { return inputs_outputs_ ->op_policy_mem_final_[sample * kNumOutputPolicy + move_id]; diff --git a/src/neural/metal/network_metal.h b/src/neural/metal/network_metal.h index 7eada32d6f..b2e2df4b39 100644 --- a/src/neural/metal/network_metal.h +++ b/src/neural/metal/network_metal.h @@ -82,10 +82,6 @@ class MetalNetworkComputation : public NetworkComputation { } } - float GetEVal(int sample) const override { - return 0.0f; - } - float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/network.h b/src/neural/network.h index ed44e20f78..3129ffc518 100644 --- a/src/neural/network.h +++ b/src/neural/network.h @@ -64,8 +64,6 @@ class NetworkComputation { // Returns Q value of @sample. virtual float GetQVal(int sample) const = 0; virtual float GetDVal(int sample) const = 0; - // Returns E (value error) value for @sample. - virtual float GetEVal(int sample) const = 0; // Returns P value @move_id of @sample. 
virtual float GetPVal(int sample, int move_id) const = 0; virtual float GetMVal(int sample) const = 0; diff --git a/src/neural/network_check.cc b/src/neural/network_check.cc index c115a3da5e..1b67266cff 100644 --- a/src/neural/network_check.cc +++ b/src/neural/network_check.cc @@ -105,10 +105,6 @@ class CheckComputation : public NetworkComputation { return work_comp_->GetDVal(sample); } - float GetEVal(int sample) const override { - return work_comp_->GetEVal(sample); - } - float GetMVal(int sample) const override { return work_comp_->GetMVal(sample); } diff --git a/src/neural/network_demux.cc b/src/neural/network_demux.cc index e280b4def7..a1a28f779f 100644 --- a/src/neural/network_demux.cc +++ b/src/neural/network_demux.cc @@ -58,12 +58,6 @@ class DemuxingComputation : public NetworkComputation { return parents_[idx]->GetDVal(offset); } - float GetEVal(int sample) const override { - int idx = sample / partial_size_; - int offset = sample % partial_size_; - return parents_[idx]->GetEVal(offset); - } - float GetMVal(int sample) const override { int idx = sample / partial_size_; int offset = sample % partial_size_; diff --git a/src/neural/network_mux.cc b/src/neural/network_mux.cc index b04ac7f4cb..e7b6fe6835 100644 --- a/src/neural/network_mux.cc +++ b/src/neural/network_mux.cc @@ -54,10 +54,6 @@ class MuxingComputation : public NetworkComputation { return parent_->GetDVal(sample + idx_in_parent_); } - float GetEVal(int sample) const override { - return parent_->GetEVal(sample + idx_in_parent_); - } - float GetMVal(int sample) const override { return parent_->GetMVal(sample + idx_in_parent_); } diff --git a/src/neural/network_random.cc b/src/neural/network_random.cc index a8539c149c..5b4a2661bb 100644 --- a/src/neural/network_random.cc +++ b/src/neural/network_random.cc @@ -78,8 +78,6 @@ class RandomNetworkComputation : public NetworkComputation { return d; } - float GetEVal(int /* sample */) const override { return 0.0f; } - float GetMVal(int /* sample */) const override { return 0.0f; } float GetPVal(int sample, int move_id) const override { diff --git a/src/neural/network_record.cc b/src/neural/network_record.cc index 1defefccbe..74a908e184 100644 --- a/src/neural/network_record.cc +++ b/src/neural/network_record.cc @@ -75,9 +75,6 @@ class RecordComputation : public NetworkComputation { float GetDVal(int sample) const override { return Capture(inner_->GetDVal(sample), sample); } - float GetEVal(int sample) const override { - return Capture(inner_->GetEVal(sample), sample); - } // Returns P value @move_id of @sample. float GetPVal(int sample, int move_id) const override { return Capture(inner_->GetPVal(sample, move_id), sample); @@ -147,7 +144,6 @@ class ReplayComputation : public NetworkComputation { // Returns Q value of @sample. float GetQVal(int sample) const override { return Replay(sample); } float GetDVal(int sample) const override { return Replay(sample); } - float GetEVal(int sample) const override { return Replay(sample); } // Returns P value @move_id of @sample. 
float GetPVal(int sample, int) const override { return Replay(sample); } float GetMVal(int sample) const override { return Replay(sample); } diff --git a/src/neural/network_tf_cc.cc b/src/neural/network_tf_cc.cc index 8fa11bd768..548baa6f0f 100644 --- a/src/neural/network_tf_cc.cc +++ b/src/neural/network_tf_cc.cc @@ -366,9 +366,6 @@ class TFNetworkComputation : public NetworkComputation { return 0.0f; } } - float GetEVal(int sample) const override { - return 0.0f; - } float GetPVal(int sample, int move_id) const override { return output_[1].template matrix()(sample, move_id); } diff --git a/src/neural/network_trivial.cc b/src/neural/network_trivial.cc index 4c6bede980..196c0b14c1 100644 --- a/src/neural/network_trivial.cc +++ b/src/neural/network_trivial.cc @@ -444,8 +444,6 @@ class TrivialNetworkComputation : public NetworkComputation { float GetDVal(int) const override { return 0.0f; } - float GetEVal(int) const override { return 0.0f; } - float GetMVal(int /* sample */) const override { return 0.0f; } float GetPVal(int /* sample */, int move_id) const override { diff --git a/src/neural/onednn/network_onednn.cc b/src/neural/onednn/network_onednn.cc index 6b2874fd36..d0aecedbc4 100644 --- a/src/neural/onednn/network_onednn.cc +++ b/src/neural/onednn/network_onednn.cc @@ -128,10 +128,6 @@ class OnednnNetworkComputation : public NetworkComputation { } } - float GetEVal(int sample) const override { - return 0.0f; - } - float GetPVal(int sample, int move_id) const override { return inputs_outputs_->op_policy_mem_[sample * kNumOutputPolicy + move_id]; } diff --git a/src/neural/onnx/network_onnx.cc b/src/neural/onnx/network_onnx.cc index b2788bcdb4..e7dfe98b2d 100644 --- a/src/neural/onnx/network_onnx.cc +++ b/src/neural/onnx/network_onnx.cc @@ -65,7 +65,6 @@ class OnnxComputation : public NetworkComputation { void ComputeBlocking() override; float GetQVal(int sample) const override; float GetDVal(int sample) const override; - float GetEVal(int sample) const override; float GetPVal(int sample, int move_id) const override; float GetMVal(int sample) const override; @@ -185,11 +184,6 @@ float OnnxComputation::GetDVal(int sample) const { return AsFloat(data[sample * 3 + 1]); } -template -float OnnxComputation::GetEVal(int sample) const { - return 0.0; -} - template float OnnxComputation::GetPVal(int sample, int move_id) const { const auto& data = output_tensors_data_[network_->policy_head_]; diff --git a/src/neural/opencl/network_opencl.cc b/src/neural/opencl/network_opencl.cc index 781d5ba0d6..f4a59d0587 100644 --- a/src/neural/opencl/network_opencl.cc +++ b/src/neural/opencl/network_opencl.cc @@ -185,10 +185,6 @@ class OpenCLComputation : public NetworkComputation { } } - float GetEVal(int sample) const override { - return 0.0f; - } - float GetMVal(int sample) const override { if (moves_left_) { auto d = m_values_[sample]; diff --git a/src/python/weights.h b/src/python/weights.h index 6936ea3e5a..18288c5f69 100644 --- a/src/python/weights.h +++ b/src/python/weights.h @@ -134,12 +134,10 @@ class Output { for (int i = 0; i < 1858; ++i) p_[i] = computation.GetPVal(idx, i); q_ = computation.GetQVal(idx); d_ = computation.GetDVal(idx); - e_ = computation.GetEVal(idx); m_ = computation.GetMVal(idx); } float q() const { return q_; } float d() const { return d_; } - float e() const { return e_; } float m() const { return m_; } std::vector p_raw(const std::vector& indicies) { std::vector result(indicies.size()); @@ -175,7 +173,6 @@ class Output { float p_[1858]; float q_; float d_; - float e_; float m_; 
}; From 284eb27fefeb036cc5f61928120ffb7ad21b10b6 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 2 Mar 2024 01:57:16 +0100 Subject: [PATCH 46/70] Remove old artifacts from network_legacy --- src/neural/cuda/layers.cc | 10 ++-- src/neural/loader.cc | 100 ----------------------------------- src/neural/network_legacy.cc | 6 +-- src/neural/network_legacy.h | 55 ------------------- 4 files changed, 6 insertions(+), 165 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 038c9b32ba..dc87474403 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2069,16 +2069,16 @@ AttentionBody::AttentionBody(const LegacyWeights& weights, if (new_encoding_) { allocAndUpload(&ip_emb_pre_w_, weights.ip_emb_preproc_w, scratch); allocAndUpload(&ip_emb_pre_b_, weights.ip_emb_preproc_b, scratch); - + allocAndUpload(&ip_emb_ln_g_, weights.ip_emb_ln_gammas, scratch); allocAndUpload(&ip_emb_ln_b_, weights.ip_emb_ln_betas, scratch); - + allocAndUpload(&ip_emb_ffn_d1_w_, weights.ip_emb_ffn.dense1_w, scratch); allocAndUpload(&ip_emb_ffn_d1_b_, weights.ip_emb_ffn.dense1_b, scratch); - + allocAndUpload(&ip_emb_ffn_d2_w_, weights.ip_emb_ffn.dense2_w, scratch); allocAndUpload(&ip_emb_ffn_d2_b_, weights.ip_emb_ffn.dense2_b, scratch); - + allocAndUpload(&ip_emb_ffn_ln_g_, weights.ip_emb_ffn_ln_gammas, scratch); allocAndUpload(&ip_emb_ffn_ln_b_, weights.ip_emb_ffn_ln_betas, scratch); @@ -2179,7 +2179,7 @@ void AttentionBody::Eval(int N, DataType* output, (const DataType*)ip_emb_pre_w_, num_inputs, (const DataType*)scratch, num_inputs, 0.0f, buffer1, num_outputs); - + // addBiasBatched(buffer1, buffer1, ip_emb_pre_b_, batch, N, num_outputs, // ACTIVATION_NONE, stream); const int size = num_outputs * N; diff --git a/src/neural/loader.cc b/src/neural/loader.cc index aab8c1d32f..c88d985c2b 100644 --- a/src/neural/loader.cc +++ b/src/neural/loader.cc @@ -104,75 +104,6 @@ std::string DecompressGzip(const std::string& filename) { return buffer; } -void MovePolicyHead(WeightsFile* file) { - bool attn_policy = file->format().network_format().policy() == - pblczero::NetworkFormat::POLICY_ATTENTION; - auto vanilla = file->mutable_weights()->mutable_policy_heads()->mutable_vanilla(); - auto mutable_weights = file->mutable_weights(); - if (attn_policy && file->weights().has_ip_pol_b()) { - // For attention policy weights, ip_pol_w and ip_pol_b (embedding weights) - // are moved to the main "policy_heads" struct, where all policy heads - // share them. - auto heads = file->mutable_weights()->mutable_policy_heads(); - *heads->mutable_ip_pol_w() = file->weights().ip_pol_w(); - *heads->mutable_ip_pol_b() = file->weights().ip_pol_b(); - mutable_weights->mutable_ip_pol_w()->Clear(); - mutable_weights->mutable_ip_pol_b()->Clear(); - - // Some older attention policy nets have policy encoders. - for (auto enc : file->weights().pol_encoder()) { - *vanilla->add_pol_encoder() = enc; - } - vanilla->set_pol_headcount(file->weights().pol_headcount()); - } - - // Macro to move remaining shared weights around. - #define MOVE(name) \ - if (file->weights().has_##name()) { \ - *vanilla->mutable_##name() = file->weights().name(); \ - mutable_weights->mutable_##name()->Clear(); \ - } - if (!attn_policy) { - // These weights are used by older style policy heads. - MOVE(policy1); - MOVE(policy); - MOVE(ip_pol_w); - MOVE(ip_pol_b); - } - // Weights common to all policy implementations. 
- MOVE(ip2_pol_w); - MOVE(ip2_pol_b); - MOVE(ip3_pol_w); - MOVE(ip3_pol_b); - MOVE(ip4_pol_w); - - #undef MOVE -} - -void MoveValueHead(WeightsFile* file) { - bool attn_body = file->format().network_format().network() == - pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; - auto winner = file->mutable_weights()->mutable_value_heads()->mutable_winner(); - auto mutable_weights = file->mutable_weights(); - // Macro to move weights around. - #define MOVE(name) \ - if (file->weights().has_##name()) { \ - *winner->mutable_##name() = file->weights().name(); \ - mutable_weights->mutable_##name()->Clear(); \ - } - if (!attn_body) { - MOVE(value); - } - MOVE(ip2_val_w); - MOVE(ip2_val_b); - MOVE(ip1_val_w); - MOVE(ip1_val_b); - MOVE(ip_val_w); - MOVE(ip_val_b); - - #undef MOVE -} - void FixOlderWeightsFile(WeightsFile* file) { using nf = pblczero::NetworkFormat; auto network_format = file->format().network_format().network(); @@ -225,37 +156,6 @@ void FixOlderWeightsFile(WeightsFile* file) { net->set_input_embedding(nf::INPUT_EMBEDDING_PE_MAP); } } - - // Get updated network format. - network_format = file->format().network_format().network(); - auto embedding_type = file->format().network_format().input_embedding(); - bool multihead_format = (network_format & 128) == 128; - if (!multihead_format) { - auto weights = file->weights(); - if (weights.has_policy_heads() && weights.has_value_heads()) { - CERR << "Weights file has multihead format, updating format flag"; - net->set_network(static_cast(network_format | 128)); - net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_PE_DENSE); - } - else { - CERR << "Weights file has single head format, rewriting to multihead format"; - // Move policy and value heads. - MovePolicyHead(file); - MoveValueHead(file); - net->set_network(static_cast(network_format | 128)); - switch (network_format) { - case pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT: - net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_PE_MAP); - break; - default: - net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_NONE); - } - } - } else { - if (embedding_type != pblczero::NetworkFormat::INPUT_EMBEDDING_PE_DENSE) { - net->set_input_embedding(pblczero::NetworkFormat::INPUT_EMBEDDING_PE_DENSE); - } - } } WeightsFile ParseWeightsProto(const std::string& buffer) { diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index fdc7d228f4..b2f4a1b510 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -53,11 +53,7 @@ BaseWeights::BaseWeights(const pblczero::Weights& weights) ip2_mov_w(LayerAdapter(weights.ip2_mov_w()).as_vector()), ip2_mov_b(LayerAdapter(weights.ip2_mov_b()).as_vector()), smolgen_w(LayerAdapter(weights.smolgen_w()).as_vector()), - has_smolgen(weights.has_smolgen_w()), - policy_heads(weights.policy_heads()), - value_heads(weights.value_heads()) { - has_multiheads = weights.has_policy_heads() && weights.policy_heads().has_optimistic_st() - && weights.has_value_heads() && weights.value_heads().has_q(); + has_smolgen(weights.has_smolgen_w()) { for (const auto& res : weights.residual()) { residual.emplace_back(res); } diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 7cc1eac87e..72ce67544f 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -101,55 +101,6 @@ struct BaseWeights { Vec ln2_betas; }; - struct PolicyHead { - explicit PolicyHead(const pblczero::Weights::PolicyHead& policyhead); - // Policy head - // Extra 
convolution for AZ-style policy head - ConvBlock policy1; - ConvBlock policy; - Vec ip_pol_w; - Vec ip_pol_b; - // Extra params for attention policy head - Vec ip2_pol_w; - Vec ip2_pol_b; - Vec ip3_pol_w; - Vec ip3_pol_b; - Vec ip4_pol_w; - int pol_encoder_head_count; - std::vector pol_encoder; - }; - - struct ValueHead { - explicit ValueHead(const pblczero::Weights::ValueHead& valuehead); - // Value head - ConvBlock value; - Vec ip_val_w; - Vec ip_val_b; - Vec ip1_val_w; - Vec ip1_val_b; - Vec ip2_val_w; - Vec ip2_val_b; - Vec ip_val_err_w; - Vec ip_val_err_b; - }; - - struct PolicyHeads { - explicit PolicyHeads(const pblczero::Weights::PolicyHeads& policyheads); - Vec ip_pol_w; - Vec ip_pol_b; - PolicyHead vanilla; - PolicyHead optimistic_st; - PolicyHead soft; - PolicyHead opponent; - }; - - struct ValueHeads { - explicit ValueHeads(const pblczero::Weights::ValueHeads& valueheads); - ValueHead winner; - ValueHead q; - ValueHead st; - }; - // Input convnet. ConvBlock input; @@ -283,10 +234,4 @@ enum InputEmbedding { INPUT_EMBEDDING_PE_DENSE = 2, }; -enum InputEmbedding { - INPUT_EMBEDDING_NONE = 0, - INPUT_EMBEDDING_PE_MAP = 1, - INPUT_EMBEDDING_PE_DENSE = 2, -}; - } // namespace lczero From d446d128568075a5298c414839bcf74379d67b01 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 2 Mar 2024 17:38:27 +0100 Subject: [PATCH 47/70] Fix backend to use MultiHeadWeights struct. --- src/neural/cuda/layers.cc | 63 +++---- src/neural/cuda/layers.h | 15 +- src/neural/cuda/network_cuda.cc | 293 ++++++++++++++----------------- src/neural/cuda/network_cudnn.cc | 127 +++++++------- 4 files changed, 234 insertions(+), 264 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index dc87474403..c55988ff43 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -39,7 +39,7 @@ namespace lczero { -#if 1 +#if 0 // debug code to dump allocation in GPU memory template void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { @@ -1421,45 +1421,34 @@ void allocAndUpload(DataType** gpu_dest, std::vector cpu_src, template AttentionPolicyHead::AttentionPolicyHead( - BaseLayer* ip, const LegacyWeights& weights, void* scratch, - bool attention_body, ActivationFunction act, std::string policy_head, int max_batch_size) + BaseLayer* ip, const MultiHeadWeights::PolicyHead& weights, + void* scratch, bool attention_body, ActivationFunction act, + int max_batch_size) : BaseLayer(64 * 64 + 24 * 8, 1, 1, ip), attention_body_(attention_body), // Old networks without attention body (e.g. T79) use hardcoded SELU // activations. act_(attention_body ? act : ACTIVATION_SELU) { - // Selected head to construct, use vanilla as default head. 
- LegacyWeights::PolicyHead head = weights.policy_heads.vanilla; - if (policy_head == "optimistic") { - head = weights.policy_heads.optimistic_st; - } - else if (policy_head == "soft") { - head = weights.policy_heads.soft; - } - else if (policy_head == "opponent") { - head = weights.policy_heads.opponent; - } - - embedding_op_size_ = weights.policy_heads.ip_pol_b.size(); - wq_op_size_ = head.ip2_pol_b.size(); - wk_op_size_ = head.ip3_pol_b.size(); + embedding_op_size_ = weights.ip_pol_b.size(); + wq_op_size_ = weights.ip2_pol_b.size(); + wk_op_size_ = weights.ip3_pol_b.size(); - encoder_heads_ = head.pol_encoder_head_count; + encoder_heads_ = weights.pol_encoder_head_count; policy_d_model_ = wq_op_size_; - allocAndUpload(&ip_pol_w_, weights.policy_heads.ip_pol_w, scratch); - allocAndUpload(&ip_pol_b_, weights.policy_heads.ip_pol_b, scratch); + allocAndUpload(&ip_pol_w_, weights.ip_pol_w, scratch); + allocAndUpload(&ip_pol_b_, weights.ip_pol_b, scratch); - allocAndUpload(&ip2_pol_w_, head.ip2_pol_w, scratch); - allocAndUpload(&ip2_pol_b_, head.ip2_pol_b, scratch); + allocAndUpload(&ip2_pol_w_, weights.ip2_pol_w, scratch); + allocAndUpload(&ip2_pol_b_, weights.ip2_pol_b, scratch); - allocAndUpload(&ip3_pol_w_, head.ip3_pol_w, scratch); - allocAndUpload(&ip3_pol_b_, head.ip3_pol_b, scratch); + allocAndUpload(&ip3_pol_w_, weights.ip3_pol_w, scratch); + allocAndUpload(&ip3_pol_b_, weights.ip3_pol_b, scratch); // big allocation to hold wq and wk weights one after the other { - size_t elements = head.ip2_pol_w.size(); - assert(elements == head.ip3_pol_w.size()); + size_t elements = weights.ip2_pol_w.size(); + assert(elements == weights.ip3_pol_w.size()); size_t size = elements * sizeof(DataType) * 2; ReportCUDAErrors(cudaMalloc(&wqk_w_, size)); @@ -1468,7 +1457,7 @@ AttentionPolicyHead::AttentionPolicyHead( ReportCUDAErrors(cudaMemcpy(wqk_w_ + elements, ip3_pol_w_, size / 2, cudaMemcpyDeviceToDevice)); - elements = head.ip2_pol_b.size(); + elements = weights.ip2_pol_b.size(); size = elements * sizeof(DataType) * 2; ReportCUDAErrors(cudaMalloc(&wqk_b_, size)); ReportCUDAErrors( @@ -1477,9 +1466,9 @@ AttentionPolicyHead::AttentionPolicyHead( cudaMemcpyDeviceToDevice)); } - allocAndUpload(&ip4_pol_w_, head.ip4_pol_w, scratch); + allocAndUpload(&ip4_pol_w_, weights.ip4_pol_w, scratch); - for (const auto& enc : head.pol_encoder) { + for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, // using alpha = 1 for now (TODO: may change?) 
@@ -1491,7 +1480,7 @@ AttentionPolicyHead::AttentionPolicyHead( template EncoderBlock::EncoderBlock( - const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, int heads, + const MultiHeadWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, ActivationFunction ffn_act) @@ -2049,7 +2038,7 @@ void EmbeddingLayer::Eval( } template -AttentionBody::AttentionBody(const LegacyWeights& weights, +AttentionBody::AttentionBody(const MultiHeadWeights& weights, void* scratch, Activations activations, int num_res_blocks, int input_c, int max_batch_size, bool new_encoding) @@ -2294,14 +2283,14 @@ void AttentionBody::Eval(int N, DataType* output, template ValueHead::ValueHead(BaseLayer* ip, - const LegacyWeights::ValueHead& weights, - void* scratch, bool attention_body, - bool wdl, bool wdl_err, - ActivationFunction act, + const MultiHeadWeights::ValueHead& weights, + void* scratch, bool attention_body, bool wdl, + bool wdl_err, ActivationFunction act, int max_batch_size, bool use_gemm_ex) : BaseLayer(weights.ip_val_b.size(), 8, 8, ip), attention_body_(attention_body), - embedding_size_(attention_body ? weights.ip_val_b.size() : weights.value.biases.size()), + embedding_size_(attention_body ? weights.ip_val_b.size() + : weights.value.biases.size()), value_hidden_size_(weights.ip1_val_b.size()), act_(act), wdl_(wdl), diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index a876008995..f685f2e228 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -336,7 +336,7 @@ class ResidualBlock : public BaseLayer { template class EncoderBlock { public: - EncoderBlock(const LegacyWeights::EncoderLayer& cpu_weights, void* scratch, + EncoderBlock(const MultiHeadWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, @@ -407,9 +407,10 @@ class AttentionPolicyHead : public BaseLayer { using BaseLayer::GetW; public: - AttentionPolicyHead(BaseLayer* ip, const LegacyWeights& weights, + AttentionPolicyHead(BaseLayer* ip, + const MultiHeadWeights::PolicyHead& weights, void* scratch, bool attention_body, - ActivationFunction act, std::string policy_head, int max_batch_size); + ActivationFunction act, int max_batch_size); ~AttentionPolicyHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -472,7 +473,7 @@ class AttentionBody : public BaseLayer { using BaseLayer::GetW; public: - AttentionBody(const LegacyWeights& weights, void* scratch, + AttentionBody(const MultiHeadWeights& weights, void* scratch, Activations activations, int num_res_blocks, int input_c, int max_batch_size, bool new_encoding); ~AttentionBody(); @@ -520,9 +521,9 @@ class ValueHead : public BaseLayer { using BaseLayer::GetW; public: - ValueHead(BaseLayer* ip, const LegacyWeights::ValueHead& weights, - void* scratch, bool attention_body, bool wdl, bool wdl_err, - ActivationFunction act, int max_batch_size, bool use_gemm_ex); + ValueHead(BaseLayer* ip, const MultiHeadWeights::ValueHead& weights, + void* scratch, bool attention_body, bool wdl, bool wdl_err, + ActivationFunction act, int max_batch_size, bool use_gemm_ex); ~ValueHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, diff --git 
a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index d75bab2c38..bbad91a0a1 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -49,24 +49,25 @@ using namespace cudnn_backend; template class CudaNetwork; -static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { - const size_t embedding_op_size = weights.policy_heads.vanilla.ip_pol_b.size(); - const size_t policy_d_model = weights.policy_heads.vanilla.ip2_pol_b.size(); - assert(policy_d_model == weights.policy_heads.vanilla.ip3_pol_b.size()); +static size_t getMaxAttentionHeadSize(const MultiHeadWeights& weights, int N) { + const auto vanilla = weights.policy_heads.at("vanilla"); + const size_t embedding_op_size = vanilla.ip_pol_b.size(); + const size_t policy_d_model = vanilla.ip2_pol_b.size(); + assert(policy_d_model == vanilla.ip3_pol_b.size()); size_t encoder_d_model = 0; size_t encoder_dff = 0; - if (weights.policy_heads.vanilla.pol_encoder.size() > 0) { - encoder_d_model = weights.policy_heads.vanilla.pol_encoder[0].mha.q_b.size(); - encoder_dff = weights.policy_heads.vanilla.pol_encoder[0].ffn.dense1_b.size(); + if (vanilla.pol_encoder.size() > 0) { + encoder_d_model = vanilla.pol_encoder[0].mha.q_b.size(); + encoder_dff = vanilla.pol_encoder[0].ffn.dense1_b.size(); - assert(encoder_d_model == weights.policy_heads.vanilla.pol_encoder[0].mha.k_b.size()); - assert(encoder_d_model == weights.policy_heads.vanilla.pol_encoder[0].mha.v_b.size()); - assert(embedding_op_size == weights.policy_heads.vanilla.pol_encoder[0].ffn.dense2_b.size()); + assert(encoder_d_model == vanilla.pol_encoder[0].mha.k_b.size()); + assert(encoder_d_model == vanilla.pol_encoder[0].mha.v_b.size()); + assert(embedding_op_size == vanilla.pol_encoder[0].ffn.dense2_b.size()); } - const size_t encoder_heads = weights.policy_heads.vanilla.pol_encoder_head_count; + const size_t encoder_heads = vanilla.pol_encoder_head_count; size_t size = N * 64 * @@ -84,7 +85,7 @@ static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { return size; } -static size_t getMaxAttentionBodySize(const LegacyWeights& weights, int N) { +static size_t getMaxAttentionBodySize(const MultiHeadWeights& weights, int N) { const size_t embedding_op_size = weights.ip_emb_b.size(); size_t encoder_d_model = 0; @@ -188,18 +189,15 @@ class CudaNetwork : public Network { CudaNetwork(const WeightsFile& file, const OptionsDict& options) : capabilities_{file.format().network_format().input(), file.format().network_format().moves_left()} { - LegacyWeights weights(file.weights()); + MultiHeadWeights weights(file.weights()); gpu_id_ = options.GetOrDefault("gpu", 0); - conv_policy_ = file.format().network_format().policy() == - pblczero::NetworkFormat::POLICY_CONVOLUTION; - - attn_policy_ = file.format().network_format().policy() == - pblczero::NetworkFormat::POLICY_ATTENTION; - - // Mask out the multihead format bit 7. 
- attn_body_ = (file.format().network_format().network() & 127) == - pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; + const auto nf = file.format().network_format(); + using NF = pblczero::NetworkFormat; + conv_policy_ = nf.policy() == NF::POLICY_CONVOLUTION; + attn_policy_ = nf.policy() == NF::POLICY_ATTENTION; + attn_body_ = nf.network() == NF::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT || + nf.network() == NF::NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT; max_batch_size_ = options.GetOrDefault("max_batch", 1024); // min_batch_size_ is chosen as 4 as it is common that for sizes less than @@ -360,32 +358,17 @@ class CudaNetwork : public Network { std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); // Check that selected policy head exists. - if (attn_policy_) { - if ((policy_head == "vanilla" && weights.policy_heads.vanilla.ip2_pol_b.size() == 0) - || (policy_head == "optimistic" && weights.policy_heads.optimistic_st.ip2_pol_b.size() == 0) - || (policy_head == "soft" && weights.policy_heads.soft.ip2_pol_b.size() == 0) - || (policy_head != "vanilla" && policy_head != "optimistic" && policy_head != "soft")) { - throw Exception("The policy head you specified '" + policy_head + "'" - + " does not exist in this net."); - } - } else { - if ((policy_head == "vanilla" && weights.policy_heads.vanilla.policy.weights.size() == 0) - || (policy_head == "optimistic" && weights.policy_heads.optimistic_st.policy.weights.size() == 0) - || (policy_head == "soft" && weights.policy_heads.soft.policy.weights.size() == 0) - || (policy_head != "vanilla" && policy_head != "optimistic" && policy_head != "soft")) { - throw Exception("The policy head you specified '" + policy_head + "'" - + " does not exist in this net."); - } + if (policy_head == "optimistic") policy_head = "optimistic_st"; + if (weights.policy_heads.count(policy_head) == 0) { + throw Exception("The policy head you specified '" + policy_head + + "' does not exist in this net."); } std::string value_head = options.GetOrDefault("value_head", "winner"); // Check that selected value head exists. - if ((value_head == "winner" && weights.value_heads.winner.ip1_val_b.size() == 0) - || (value_head == "q" && weights.value_heads.q.ip1_val_b.size() == 0) - || (value_head == "st" && weights.value_heads.st.ip1_val_b.size() == 0) - || (value_head != "winner" && value_head != "q" && value_head != "st")) { - throw Exception("The value head you specified '" + value_head + "'" - + " does not exist in this net."); + if (weights.value_heads.count(value_head) == 0) { + throw Exception("The value head you specified '" + value_head + + "' does not exist in this net."); } // 2. Build the network, and copy the weights to GPU memory. @@ -480,81 +463,67 @@ class CudaNetwork : public Network { } // Policy head. - if (attn_policy_) { - auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_, attn_body_, act, policy_head, - max_batch_size_); - network_.emplace_back(std::move(AttentionPolicy)); - - auto policymap = std::make_unique>( - getLastLayer(), kNumOutputPolicy, 1, 1, 64 * 64 + 8 * 24, true); - policymap->LoadWeights(kAttnPolicyMap, scratch_mem_); - network_.emplace_back(std::move(policymap)); - - } else { - // Selected head to construct, use vanilla as default head. 
- lczero::LegacyWeights::PolicyHead head = weights.policy_heads.vanilla; - if (policy_head == "optimistic") { - head = weights.policy_heads.optimistic_st; - } - else if (policy_head == "soft") { - head = weights.policy_heads.soft; - } - else if (policy_head == "opponent") { - head = weights.policy_heads.opponent; - } - if (conv_policy_) { - assert(!attn_body_); // not supported with attention body - auto conv1 = std::make_unique>( - resi_last_, kNumFilters, 8, 8, kNumFilters, act, true, false, false, - 0, use_gemm_ex); - conv1->LoadWeights(&head.policy1.weights[0], - &head.policy1.biases[0], scratch_mem_); - network_.emplace_back(std::move(conv1)); - - auto pol_channels = head.policy.biases.size(); - - // No relu - auto conv2 = std::make_unique>( - getLastLayer(), pol_channels, 8, 8, kNumFilters, ACTIVATION_NONE, - true, false, false, 0, use_gemm_ex); - conv2->LoadWeights(&head.policy.weights[0], &head.policy.biases[0], - scratch_mem_); - network_.emplace_back(std::move(conv2)); + { + MultiHeadWeights::PolicyHead& head = weights.policy_heads.at(policy_head); + if (attn_policy_) { + auto AttentionPolicy = std::make_unique>( + getLastLayer(), head, scratch_mem_, attn_body_, act, + max_batch_size_); + network_.emplace_back(std::move(AttentionPolicy)); auto policymap = std::make_unique>( - getLastLayer(), kNumOutputPolicy, 1, 1, 73 * 8 * 8, false); - policymap->LoadWeights(kConvPolicyMap, scratch_mem_); - + getLastLayer(), kNumOutputPolicy, 1, 1, 64 * 64 + 8 * 24, true); + policymap->LoadWeights(kAttnPolicyMap, scratch_mem_); network_.emplace_back(std::move(policymap)); + } else { - assert(!attn_body_); // not supported with attention body - auto convPol = std::make_unique>( - resi_last_, head.policy.biases.size(), 8, 8, kNumFilters, act, - true, use_gemm_ex); - convPol->LoadWeights(&head.policy.weights[0], - &head.policy.biases[0], scratch_mem_); - network_.emplace_back(std::move(convPol)); - - auto FCPol = std::make_unique>( - getLastLayer(), head.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); - FCPol->LoadWeights(&head.ip_pol_w[0], &head.ip_pol_b[0], - scratch_mem_); - network_.emplace_back(std::move(FCPol)); + if (conv_policy_) { + assert(!attn_body_); // not supported with attention body + auto conv1 = std::make_unique>( + resi_last_, kNumFilters, 8, 8, kNumFilters, act, true, false, + false, 0, use_gemm_ex); + conv1->LoadWeights(&head.policy1.weights[0], &head.policy1.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv1)); + + auto pol_channels = head.policy.biases.size(); + + // No relu + auto conv2 = std::make_unique>( + getLastLayer(), pol_channels, 8, 8, kNumFilters, ACTIVATION_NONE, + true, false, false, 0, use_gemm_ex); + conv2->LoadWeights(&head.policy.weights[0], &head.policy.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv2)); + + auto policymap = std::make_unique>( + getLastLayer(), kNumOutputPolicy, 1, 1, 73 * 8 * 8, false); + policymap->LoadWeights(kConvPolicyMap, scratch_mem_); + + network_.emplace_back(std::move(policymap)); + } else { + assert(!attn_body_); // not supported with attention body + auto convPol = std::make_unique>( + resi_last_, head.policy.biases.size(), 8, 8, kNumFilters, act, + true, use_gemm_ex); + convPol->LoadWeights(&head.policy.weights[0], &head.policy.biases[0], + scratch_mem_); + network_.emplace_back(std::move(convPol)); + + auto FCPol = std::make_unique>( + getLastLayer(), head.ip_pol_b.size(), 1, 1, true, + ACTIVATION_NONE); + FCPol->LoadWeights(&head.ip_pol_w[0], &head.ip_pol_b[0], + scratch_mem_); + 
network_.emplace_back(std::move(FCPol)); + } } } // Value heads. { - // Selected head to construct, use value_winner as default head. - std::string value_head = options.GetOrDefault("value_head", "winner"); - LegacyWeights::ValueHead& head = weights.value_heads.winner; - if (value_head == "q") { - head = weights.value_heads.q; - } - else if (value_head == "st") { - head = weights.value_heads.st; - } + const MultiHeadWeights::ValueHead& head = + weights.value_heads.at(value_head); wdl_ = file.format().network_format().value() == pblczero::NetworkFormat::VALUE_WDL; BaseLayer* lastlayer = attn_body_ ? encoder_last_ : resi_last_; @@ -564,12 +533,11 @@ class CudaNetwork : public Network { ); network_.emplace_back(std::move(value_main)); - wdl_err_ = weights.value_heads.st.ip_val_err_b.size() > 0; + wdl_err_ = weights.value_heads.count("st") > 0; if (wdl_err_) { auto value_err = std::make_unique>( - lastlayer, weights.value_heads.st, scratch_mem_, attn_body_, - wdl_, true, act, max_batch_size_, use_gemm_ex - ); + lastlayer, weights.value_heads.at("st"), scratch_mem_, attn_body_, + wdl_, true, act, max_batch_size_, use_gemm_ex); network_.emplace_back(std::move(value_err)); } } @@ -1108,54 +1076,63 @@ std::unique_ptr MakeCudaNetwork(const std::optional& w, " backend requires a network file."); } const WeightsFile& weights = *w; - auto format = weights.format().network_format().network() & 127; - if ((weights.format().network_format().network() & 128) == 0 || - (format != pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT && - format != pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT && - format != pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT)) { - throw Exception("Network format " + - pblczero::NetworkFormat::NetworkStructure_Name( - weights.format().network_format().network()) + - " is not supported by the CUDA backend."); + auto nf = weights.format().network_format(); + using NF = pblczero::NetworkFormat; + switch (nf.network()) { + case NF::NETWORK_CLASSICAL_WITH_HEADFORMAT: + case NF::NETWORK_SE_WITH_HEADFORMAT: + case NF::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT: + case NF::NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT: + break; + default: + throw Exception("Network format " + + NF::NetworkStructure_Name(nf.network()) + + " is not supported by the CUDA backend."); } - if (weights.format().network_format().policy() != - pblczero::NetworkFormat::POLICY_CLASSICAL && - weights.format().network_format().policy() != - pblczero::NetworkFormat::POLICY_CONVOLUTION && - weights.format().network_format().policy() != - pblczero::NetworkFormat::POLICY_ATTENTION) { - throw Exception("Policy format " + - pblczero::NetworkFormat::PolicyFormat_Name( - weights.format().network_format().policy()) + - " is not supported by the CUDA backend."); + switch (nf.policy()) { + case NF::POLICY_CLASSICAL: + case NF::POLICY_CONVOLUTION: + case NF::POLICY_ATTENTION: + break; + default: + throw Exception("Policy format " + NF::PolicyFormat_Name(nf.policy()) + + " is not supported by the CUDA backend."); } - if (weights.format().network_format().value() != - pblczero::NetworkFormat::VALUE_CLASSICAL && - weights.format().network_format().value() != - pblczero::NetworkFormat::VALUE_WDL) { - throw Exception("Value format " + - pblczero::NetworkFormat::ValueFormat_Name( - weights.format().network_format().value()) + - " is not supported by the CUDA backend."); + switch (nf.value()) { + case NF::VALUE_CLASSICAL: + case NF::VALUE_WDL: + break; + default: + throw Exception("Value format " + 
NF::ValueFormat_Name(nf.value()) + + " is not supported by the CUDA backend."); } - if (weights.format().network_format().moves_left() != - pblczero::NetworkFormat::MOVES_LEFT_NONE && - weights.format().network_format().moves_left() != - pblczero::NetworkFormat::MOVES_LEFT_V1) { - throw Exception("Moves left head format " + - pblczero::NetworkFormat::MovesLeftFormat_Name( - weights.format().network_format().moves_left()) + - " is not supported by the CUDA backend."); + switch (nf.moves_left()) { + case NF::MOVES_LEFT_NONE: + case NF::MOVES_LEFT_V1: + break; + default: + throw Exception("Moves left head format " + + NF::MovesLeftFormat_Name(nf.moves_left()) + + " is not supported by the CUDA backend."); } - if (weights.format().network_format().default_activation() != - pblczero::NetworkFormat::DEFAULT_ACTIVATION_RELU && - weights.format().network_format().default_activation() != - pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH) { - throw Exception( - "Default activation " + - pblczero::NetworkFormat::DefaultActivation_Name( - weights.format().network_format().default_activation()) + - " is not supported by the CUDA backend."); + switch (nf.default_activation()) { + case NF::DEFAULT_ACTIVATION_RELU: + case NF::DEFAULT_ACTIVATION_MISH: + break; + default: + throw Exception("Default activation " + + NF::DefaultActivation_Name(nf.default_activation()) + + " is not supported by the CUDA backend."); + } + switch (nf.input_embedding()) { + case NF::INPUT_EMBEDDING_NONE: + case NF::INPUT_EMBEDDING_PE_MAP: + case NF::INPUT_EMBEDDING_PE_DENSE: + break; + default: + throw Exception("Input embedding " + + NF::InputEmbeddingFormat_Name(nf.input_embedding()) + + " is not supported by the CUDA backend."); } return std::make_unique>(weights, options); } diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index eaf6a631c0..e0b1b4e21e 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -51,24 +51,25 @@ using namespace cudnn_backend; template class CudnnNetwork; -static size_t getMaxAttentionHeadSize(const LegacyWeights& weights, int N) { - const size_t embedding_op_size = weights.ip_pol_b.size(); - const size_t policy_d_model = weights.ip2_pol_b.size(); - assert(policy_d_model == weights.ip3_pol_b.size()); +static size_t getMaxAttentionHeadSize(const MultiHeadWeights& weights, int N) { + const auto vanilla = weights.policy_heads.at("vanilla"); + const size_t embedding_op_size = vanilla.ip_pol_b.size(); + const size_t policy_d_model = vanilla.ip2_pol_b.size(); + assert(policy_d_model == vanilla.ip3_pol_b.size()); size_t encoder_d_model = 0; size_t encoder_dff = 0; - if (weights.pol_encoder.size() > 0) { - encoder_d_model = weights.pol_encoder[0].mha.q_b.size(); - encoder_dff = weights.pol_encoder[0].ffn.dense1_b.size(); + if (vanilla.pol_encoder.size() > 0) { + encoder_d_model = vanilla.pol_encoder[0].mha.q_b.size(); + encoder_dff = vanilla.pol_encoder[0].ffn.dense1_b.size(); - assert(encoder_d_model == weights.pol_encoder[0].mha.k_b.size()); - assert(encoder_d_model == weights.pol_encoder[0].mha.v_b.size()); - assert(embedding_op_size == weights.pol_encoder[0].ffn.dense2_b.size()); + assert(encoder_d_model == vanilla.pol_encoder[0].mha.k_b.size()); + assert(encoder_d_model == vanilla.pol_encoder[0].mha.v_b.size()); + assert(embedding_op_size == vanilla.pol_encoder[0].ffn.dense2_b.size()); } - const size_t encoder_heads = weights.pol_encoder_head_count; + const size_t encoder_heads = vanilla.pol_encoder_head_count; size_t size = N * 64 * @@ 
-159,7 +160,7 @@ class CudnnNetwork : public Network { CudnnNetwork(const WeightsFile& file, const OptionsDict& options) : capabilities_{file.format().network_format().input(), file.format().network_format().moves_left()} { - LegacyWeights weights(file.weights()); + MultiHeadWeights weights(file.weights()); gpu_id_ = options.GetOrDefault("gpu", 0); conv_policy_ = file.format().network_format().policy() == @@ -519,69 +520,71 @@ class CudnnNetwork : public Network { resi_last_ = getLastLayer(); // Policy head. - if (attn_policy_) { - std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); - auto AttentionPolicy = std::make_unique>( - getLastLayer(), weights, scratch_mem_, false, ACTIVATION_SELU, - policy_head, max_batch_size_); - network_.emplace_back(std::move(AttentionPolicy)); - - auto policymap = std::make_unique>( - getLastLayer(), kNumOutputPolicy, 1, 1, 64 * 64 + 8 * 24, true); - policymap->LoadWeights(kAttnPolicyMap, scratch_mem_); - network_.emplace_back(std::move(policymap)); - } else if (conv_policy_) { - auto conv1 = std::make_unique>( - resi_last_, kNumFilters, 8, 8, 3, kNumFilters, - mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); - conv1->LoadWeights(&weights.policy1.weights[0], - &weights.policy1.biases[0], scratch_mem_); - network_.emplace_back(std::move(conv1)); + { + auto head = weights.policy_heads.at("vanilla"); + if (attn_policy_) { + auto AttentionPolicy = std::make_unique>( + getLastLayer(), head, scratch_mem_, false, ACTIVATION_SELU, + max_batch_size_); + network_.emplace_back(std::move(AttentionPolicy)); + + auto policymap = std::make_unique>( + getLastLayer(), kNumOutputPolicy, 1, 1, 64 * 64 + 8 * 24, true); + policymap->LoadWeights(kAttnPolicyMap, scratch_mem_); + network_.emplace_back(std::move(policymap)); + } else if (conv_policy_) { + auto conv1 = std::make_unique>( + resi_last_, kNumFilters, 8, 8, 3, kNumFilters, + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); + conv1->LoadWeights(&head.policy1.weights[0], &head.policy1.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv1)); + + auto pol_channels = head.policy.biases.size(); - auto pol_channels = weights.policy.biases.size(); + // No relu + auto conv2 = std::make_unique>( + getLastLayer(), pol_channels, 8, 8, 3, kNumFilters, ACTIVATION_NONE, + true); + conv2->LoadWeights(&head.policy.weights[0], &head.policy.biases[0], + scratch_mem_); + network_.emplace_back(std::move(conv2)); - // No relu - auto conv2 = std::make_unique>( - getLastLayer(), pol_channels, 8, 8, 3, kNumFilters, ACTIVATION_NONE, - true); - conv2->LoadWeights(&weights.policy.weights[0], &weights.policy.biases[0], - scratch_mem_); - network_.emplace_back(std::move(conv2)); + auto policymap = std::make_unique>( + getLastLayer(), kNumOutputPolicy, 1, 1, 73 * 8 * 8, false); + policymap->LoadWeights(kConvPolicyMap, scratch_mem_); - auto policymap = std::make_unique>( - getLastLayer(), kNumOutputPolicy, 1, 1, 73 * 8 * 8, false); - policymap->LoadWeights(kConvPolicyMap, scratch_mem_); + network_.emplace_back(std::move(policymap)); + } else { + auto convPol = std::make_unique>( + resi_last_, head.policy.biases.size(), 8, 8, 1, kNumFilters, + mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); + convPol->LoadWeights(&head.policy.weights[0], &head.policy.biases[0], + scratch_mem_); + network_.emplace_back(std::move(convPol)); - network_.emplace_back(std::move(policymap)); - } else { - auto convPol = std::make_unique>( - resi_last_, weights.policy.biases.size(), 8, 8, 1, kNumFilters, - mish_net ? 
ACTIVATION_MISH : ACTIVATION_RELU, true); - convPol->LoadWeights(&weights.policy.weights[0], - &weights.policy.biases[0], scratch_mem_); - network_.emplace_back(std::move(convPol)); - - auto FCPol = std::make_unique>( - getLastLayer(), weights.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); - FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0], - scratch_mem_); - network_.emplace_back(std::move(FCPol)); + auto FCPol = std::make_unique>( + getLastLayer(), head.ip_pol_b.size(), 1, 1, true, ACTIVATION_NONE); + FCPol->LoadWeights(&head.ip_pol_w[0], &head.ip_pol_b[0], scratch_mem_); + network_.emplace_back(std::move(FCPol)); + } + policy_out_ = getLastLayer(); } - policy_out_ = getLastLayer(); // Value head. { + auto& head = weights.value_heads.at("winner"); auto convVal = std::make_unique>( - resi_last_, weights.value.biases.size(), 8, 8, 1, kNumFilters, + resi_last_, head.value.biases.size(), 8, 8, 1, kNumFilters, mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); - convVal->LoadWeights(&weights.value.weights[0], &weights.value.biases[0], + convVal->LoadWeights(&head.value.weights[0], &head.value.biases[0], scratch_mem_); network_.emplace_back(std::move(convVal)); auto FCVal1 = std::make_unique>( - getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, + getLastLayer(), head.ip1_val_b.size(), 1, 1, true, mish_net ? ACTIVATION_MISH : ACTIVATION_RELU); - FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0], + FCVal1->LoadWeights(&head.ip1_val_w[0], &head.ip1_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal1)); @@ -590,9 +593,9 @@ class CudnnNetwork : public Network { auto fc2_tanh = !wdl_; auto FCVal2 = std::make_unique>( - getLastLayer(), weights.ip2_val_b.size(), 1, 1, true, + getLastLayer(), head.ip2_val_b.size(), 1, 1, true, fc2_tanh ? ACTIVATION_TANH : ACTIVATION_NONE); - FCVal2->LoadWeights(&weights.ip2_val_w[0], &weights.ip2_val_b[0], + FCVal2->LoadWeights(&head.ip2_val_w[0], &head.ip2_val_b[0], scratch_mem_); network_.emplace_back(std::move(FCVal2)); } From c38b568c66c531aaa516a9575fbbcd3c87527ad9 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 2 Mar 2024 17:50:28 +0100 Subject: [PATCH 48/70] File formatting. 
--- src/neural/cuda/layers.cc | 155 +++++++++++++++++--------------- src/neural/cuda/network_cuda.cc | 34 +++---- 2 files changed, 103 insertions(+), 86 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index c55988ff43..ea8a12109f 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1678,8 +1678,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); LayerNorm(batch, num_outputs, scratch, buffer1, smol_dense1_b, - (DataType*)nullptr, smol_ln1_gammas, smol_ln1_betas, 1e-3, - 1.0, smolgen_activation_, stream); + (DataType*)nullptr, smol_ln1_gammas, smol_ln1_betas, + 1e-3, 1.0, smolgen_activation_, stream); } { @@ -1695,8 +1695,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); LayerNorm(batch, num_outputs, scratch, buffer1, smol_dense2_b, - (DataType*)nullptr, smol_ln2_gammas, smol_ln2_betas, 1e-3, - 1.0, smolgen_activation_, stream); + (DataType*)nullptr, smol_ln2_gammas, smol_ln2_betas, + 1e-3, 1.0, smolgen_activation_, stream); } { @@ -2062,21 +2062,26 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, allocAndUpload(&ip_emb_ln_g_, weights.ip_emb_ln_gammas, scratch); allocAndUpload(&ip_emb_ln_b_, weights.ip_emb_ln_betas, scratch); - allocAndUpload(&ip_emb_ffn_d1_w_, weights.ip_emb_ffn.dense1_w, scratch); - allocAndUpload(&ip_emb_ffn_d1_b_, weights.ip_emb_ffn.dense1_b, scratch); + allocAndUpload(&ip_emb_ffn_d1_w_, weights.ip_emb_ffn.dense1_w, + scratch); + allocAndUpload(&ip_emb_ffn_d1_b_, weights.ip_emb_ffn.dense1_b, + scratch); - allocAndUpload(&ip_emb_ffn_d2_w_, weights.ip_emb_ffn.dense2_w, scratch); - allocAndUpload(&ip_emb_ffn_d2_b_, weights.ip_emb_ffn.dense2_b, scratch); + allocAndUpload(&ip_emb_ffn_d2_w_, weights.ip_emb_ffn.dense2_w, + scratch); + allocAndUpload(&ip_emb_ffn_d2_b_, weights.ip_emb_ffn.dense2_b, + scratch); - allocAndUpload(&ip_emb_ffn_ln_g_, weights.ip_emb_ffn_ln_gammas, scratch); - allocAndUpload(&ip_emb_ffn_ln_b_, weights.ip_emb_ffn_ln_betas, scratch); + allocAndUpload(&ip_emb_ffn_ln_g_, weights.ip_emb_ffn_ln_gammas, + scratch); + allocAndUpload(&ip_emb_ffn_ln_b_, weights.ip_emb_ffn_ln_betas, + scratch); // 12 is the number of input channels used for the input encoding. embedding_dense_size_ = weights.ip_emb_preproc_b.size() / 64; embedding_ffn_size_ = weights.ip_emb_ffn.dense2_b.size(); embedding_ffn_dff_ = weights.ip_emb_ffn.dense1_b.size(); - } - else { + } else { size_t size = 64 * kNumPosEncodingChannels * sizeof(float); ReportCUDAErrors(cudaMalloc(&pos_encoding_, size)); ReportCUDAErrors( @@ -2120,8 +2125,7 @@ AttentionBody::~AttentionBody() { ReportCUDAErrors(cudaFree(ip_emb_ffn_d2_b_)); ReportCUDAErrors(cudaFree(ip_emb_ffn_ln_g_)); ReportCUDAErrors(cudaFree(ip_emb_ffn_ln_b_)); - } - else { + } else { ReportCUDAErrors(cudaFree(pos_encoding_)); } if (has_gating_) { @@ -2153,10 +2157,10 @@ void AttentionBody::Eval(int N, DataType* output, processing */ if (new_encoding_) { - // New encoding is made of dense layer fed with input from a 12-channel slice of the input tensor. - // pos_info = flow[..., :12] - // pos_info_flat = tf.reshape(pos_info, [-1, 64 * 12]) - // pos_info_processed = tf.keras.layers.Dense(64*self.embedding_dense_sz, + // New encoding is made of dense layer fed with input from a 12-channel + // slice of the input tensor. 
pos_info = flow[..., :12] pos_info_flat = + // tf.reshape(pos_info, [-1, 64 * 12]) pos_info_processed = + // tf.keras.layers.Dense(64*self.embedding_dense_sz, // name=name+"embedding/preprocess")(pos_info_flat) const int num_outputs = 64 * embedding_dense_size_; const int num_inputs = 64 * 12; @@ -2164,21 +2168,21 @@ void AttentionBody::Eval(int N, DataType* output, convertNCHWtoNHWC((DataType*)scratch, input, N, inputC, N, 12, 8, 8); cublasXgemm( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, - (const DataType*)ip_emb_pre_w_, num_inputs, - (const DataType*)scratch, num_inputs, - 0.0f, buffer1, num_outputs); + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, + 1.0f, (const DataType*)ip_emb_pre_w_, num_inputs, + (const DataType*)scratch, num_inputs, 0.0f, buffer1, num_outputs); // addBiasBatched(buffer1, buffer1, ip_emb_pre_b_, batch, N, num_outputs, // ACTIVATION_NONE, stream); const int size = num_outputs * N; // @todo addBiasBatched has a 4096 channel limit, needs refactoring. - addVectors(buffer1, buffer1, ip_emb_pre_b_, size, size, num_outputs, ACTIVATION_NONE, stream); - inputPreprocessForAttentionBody((DataType*)scratch, input, buffer1, N, kInputPlanes, - embedding_dense_size_, true, stream); + addVectors(buffer1, buffer1, ip_emb_pre_b_, size, size, num_outputs, + ACTIVATION_NONE, stream); + inputPreprocessForAttentionBody((DataType*)scratch, input, buffer1, N, + kInputPlanes, embedding_dense_size_, true, + stream); inputC += embedding_dense_size_; - } - else { + } else { /* flow = tf.transpose(inputs, perm=[0, 2, 3, 1]) flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]]) @@ -2188,8 +2192,9 @@ void AttentionBody::Eval(int N, DataType* output, tf.shape(self.POS_ENC)[2]]) flow = tf.concat([flow, positional_encoding], axis=2) */ - inputPreprocessForAttentionBody((DataType*)scratch, input, pos_encoding_, N, - kInputPlanes, kNumPosEncodingChannels, false, stream); + inputPreprocessForAttentionBody((DataType*)scratch, input, pos_encoding_, + N, kInputPlanes, kNumPosEncodingChannels, + false, stream); inputC += kNumPosEncodingChannels; } } else { @@ -2208,20 +2213,21 @@ void AttentionBody::Eval(int N, DataType* output, const int num_outputs = embedding_op_size_; const int num_inputs = inputC; const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_w_, - num_inputs, temp, num_inputs, 0.0f, - embedding, num_outputs); + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, + batch, num_inputs, 1.0f, (const DataType*)ip_emb_w_, + num_inputs, temp, num_inputs, 0.0f, embedding, + num_outputs); // embedding layer norm with fused in bias add of previous gemm. 
- LayerNorm(N * 64, embedding_op_size_, temp, embedding, ip_emb_b_, - (DataType*)nullptr, ip_emb_ln_g_, ip_emb_ln_b_, 1e-3, 1.0, - activations_.default_activation, stream); + LayerNorm(N * 64, embedding_op_size_, temp, embedding, + ip_emb_b_, (DataType*)nullptr, ip_emb_ln_g_, + ip_emb_ln_b_, 1e-3, 1.0, + activations_.default_activation, stream); } // Input gating if (has_gating_) { - applyInputGating(temp, temp, ip_mult_gate_, - ip_add_gate_, N, 64, embedding_op_size_, stream); + applyInputGating(temp, temp, ip_mult_gate_, ip_add_gate_, N, 64, + embedding_op_size_, stream); } // embedding FFN dense 1 @@ -2230,10 +2236,10 @@ void AttentionBody::Eval(int N, DataType* output, const int num_outputs = embedding_ffn_dff_; // encoder_dff const int batch = N * 64; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, num_inputs, - temp, num_inputs, 0.0f, buffer1, num_outputs); - addBiasBatched(buffer1, buffer1, ip_emb_ffn_d1_b_, 1, batch, - num_outputs, activations_.ffn_activation, stream); + num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, + num_inputs, temp, num_inputs, 0.0f, buffer1, num_outputs); + addBiasBatched(buffer1, buffer1, ip_emb_ffn_d1_b_, 1, batch, num_outputs, + activations_.ffn_activation, stream); } // embedding FFN dense 2 @@ -2242,14 +2248,15 @@ void AttentionBody::Eval(int N, DataType* output, const int num_outputs = embedding_ffn_size_; const int batch = N * 64; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, num_inputs, - buffer1, num_inputs, 0.0f, buffer2, num_outputs); - // Embedding LN: skip connection and layer normilization (also bias add of prev gemm) - // buffer2 -> embedding + num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, + num_inputs, buffer1, num_inputs, 0.0f, buffer2, num_outputs); + // Embedding LN: skip connection and layer normilization (also bias add of + // prev gemm) buffer2 -> embedding float alpha = (float)pow(2. 
* encoder_weights_.size(), -0.25); LayerNorm(N * 64, embedding_ffn_size_, embedding, buffer2, - ip_emb_ffn_d2_b_, temp, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, - 1e-3, alpha, ACTIVATION_NONE, stream); + ip_emb_ffn_d2_b_, temp, ip_emb_ffn_ln_g_, + ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, + stream); } } else { @@ -2260,17 +2267,18 @@ void AttentionBody::Eval(int N, DataType* output, const int num_outputs = embedding_op_size_; const int num_inputs = inputC; const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_w_, + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, + batch, num_inputs, 1.0f, (const DataType*)ip_emb_w_, num_inputs, (DataType*)scratch, num_inputs, 0.0f, embedding, num_outputs); addBiasBatched(embedding, embedding, ip_emb_b_, 1, batch, num_outputs, - activations_.default_activation, stream); + activations_.default_activation, stream); } // Input gating if (has_gating_) { applyInputGating(embedding, embedding, ip_mult_gate_, - ip_add_gate_, N, 64, embedding_op_size_, stream); + ip_add_gate_, N, 64, embedding_op_size_, + stream); } } @@ -2300,10 +2308,10 @@ ValueHead::ValueHead(BaseLayer* ip, allocAndUpload(&ip_val_b_, weights.ip_val_b, scratch); } else { conv_ = std::make_unique>( - ip, weights.value.biases.size(), 8, 8, ip->GetC(), act, - true, use_gemm_ex); + ip, weights.value.biases.size(), 8, 8, ip->GetC(), act, true, + use_gemm_ex); conv_->LoadWeights((float*)&weights.value.weights[0], - (float*)&weights.value.biases[0], scratch); + (float*)&weights.value.biases[0], scratch); } allocAndUpload(&ip1_val_w_, weights.ip1_val_w, scratch); @@ -2346,14 +2354,16 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, const int num_outputs = embedding_size_; const int batch = N * 64; if (attention_body_) { - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_val_w_, num_inputs, - input, num_inputs, 0.0f, buffer, num_outputs); - addBiasBatched(buffer, buffer, ip_val_b_, 1, batch, num_outputs, act_, stream); + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, + batch, num_inputs, 1.0f, (const DataType*)ip_val_w_, + num_inputs, input, num_inputs, 0.0f, buffer, + num_outputs); + addBiasBatched(buffer, buffer, ip_val_b_, 1, batch, num_outputs, + act_, stream); } else { - conv_->Eval(N, buffer, input, nullptr, scratch, - scratch_size, nullptr, cublas, stream); + conv_->Eval(N, buffer, input, nullptr, scratch, scratch_size, nullptr, + cublas, stream); } } @@ -2364,8 +2374,9 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, const int batch = N; DataType* layer_out = (DataType*)scratch; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip1_val_w_, num_inputs, - buffer, num_inputs, 0.0f, layer_out, num_outputs); + num_inputs, 1.0f, (const DataType*)ip1_val_w_, + num_inputs, buffer, num_inputs, 0.0f, layer_out, + num_outputs); addBiasBatched(layer_out, layer_out, ip1_val_b_, 1, batch, num_outputs, act_, stream); } @@ -2377,11 +2388,12 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, const int batch = N; DataType* layer_out = wdl_err_ ? 
(DataType*)buffer : (DataType*)output; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip2_val_w_, num_inputs, - (DataType*)scratch, num_inputs, 0.0f, layer_out, num_outputs); + num_inputs, 1.0f, (const DataType*)ip2_val_w_, + num_inputs, (DataType*)scratch, num_inputs, 0.0f, + layer_out, num_outputs); addVectors(layer_out, layer_out, ip2_val_b_, num_outputs * batch, - num_outputs * batch, num_outputs, wdl_ ? ACTIVATION_NONE : ACTIVATION_TANH, - stream); + num_outputs * batch, num_outputs, + wdl_ ? ACTIVATION_NONE : ACTIVATION_TANH, stream); } if (wdl_err_) { @@ -2390,10 +2402,11 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, const int num_outputs = 1; const int batch = N; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_val_err_w_, num_inputs, - (DataType*)scratch, num_inputs, 0.0f, output, num_outputs); - addVectors(output, output, ip_val_err_b_, N, - N, 1, ACTIVATION_SIGMOID, stream); + num_inputs, 1.0f, (const DataType*)ip_val_err_w_, + num_inputs, (DataType*)scratch, num_inputs, 0.0f, + output, num_outputs); + addVectors(output, output, ip_val_err_b_, N, N, 1, ACTIVATION_SIGMOID, + stream); } } diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index bbad91a0a1..ca15008a27 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -121,8 +121,8 @@ static size_t getMaxAttentionBodySize(const MultiHeadWeights& weights, int N) { template class CudaNetworkComputation : public NetworkComputation { public: - CudaNetworkComputation(CudaNetwork* network, bool wdl, - bool wdl_err, bool moves_left); + CudaNetworkComputation(CudaNetwork* network, bool wdl, bool wdl_err, + bool moves_left); ~CudaNetworkComputation(); void AddInput(InputPlanes&& input) override { @@ -356,7 +356,8 @@ class CudaNetwork : public Network { ActivationFunction act = mish_net ? ACTIVATION_MISH : ACTIVATION_RELU; - std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); + std::string policy_head = + options.GetOrDefault("policy_head", "vanilla"); // Check that selected policy head exists. if (policy_head == "optimistic") policy_head = "optimistic_st"; if (weights.policy_heads.count(policy_head) == 0) { @@ -364,7 +365,8 @@ class CudaNetwork : public Network { "' does not exist in this net."); } - std::string value_head = options.GetOrDefault("value_head", "winner"); + std::string value_head = + options.GetOrDefault("value_head", "winner"); // Check that selected value head exists. if (weights.value_heads.count(value_head) == 0) { throw Exception("The value head you specified '" + value_head + @@ -450,9 +452,10 @@ class CudaNetwork : public Network { : static_cast(ffn_activation); activations.default_activation = act; - auto new_encoding = static_cast( - file.format().network_format().input_embedding()) - == InputEmbedding::INPUT_EMBEDDING_PE_DENSE; + auto new_encoding = + static_cast( + file.format().network_format().input_embedding()) == + InputEmbedding::INPUT_EMBEDDING_PE_DENSE; auto attention_body = std::make_unique>( weights, scratch_mem_, activations, numBlocks_, numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_, @@ -528,9 +531,8 @@ class CudaNetwork : public Network { pblczero::NetworkFormat::VALUE_WDL; BaseLayer* lastlayer = attn_body_ ? 
encoder_last_ : resi_last_; auto value_main = std::make_unique>( - lastlayer, head, scratch_mem_, attn_body_, wdl_, false, - act, max_batch_size_, use_gemm_ex - ); + lastlayer, head, scratch_mem_, attn_body_, wdl_, false, act, + max_batch_size_, use_gemm_ex); network_.emplace_back(std::move(value_main)); wdl_err_ = weights.value_heads.count("st") > 0; @@ -920,16 +922,16 @@ class CudaNetwork : public Network { // Set correct gpu id for this computation (as it might have been called // from a different thread). ReportCUDAErrors(cudaSetDevice(gpu_id_)); - return std::make_unique>(this, wdl_, - wdl_err_, - moves_left_); + return std::make_unique>( + this, wdl_, wdl_err_, moves_left_); } std::unique_ptr GetInputsOutputs() { std::lock_guard lock(inputs_outputs_lock_); if (free_inputs_outputs_.empty()) { return std::make_unique( - max_batch_size_, wdl_, wdl_err_, moves_left_, tensor_mem_size_, scratch_size_, + max_batch_size_, wdl_, wdl_err_, moves_left_, tensor_mem_size_, + scratch_size_, !has_tensor_cores_ && std::is_same::value); } else { std::unique_ptr resource = @@ -947,7 +949,9 @@ class CudaNetwork : public Network { // Apparently nvcc doesn't see constructor invocations through make_unique. // This function invokes constructor just to please complier and silence // warning. Is never called (but compiler thinks that it could). - void UglyFunctionToSilenceNvccWarning() { InputsOutputs io(0, false, false, false); } + void UglyFunctionToSilenceNvccWarning() { + InputsOutputs io(0, false, false, false); + } private: const NetworkCapabilities capabilities_; From eb26621dbb95c056b6010cd6433781bc0f01fe04 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 2 Mar 2024 19:11:32 +0100 Subject: [PATCH 49/70] Fix layernorm epsilon for older attentionbody nets. --- src/neural/cuda/layers.cc | 23 ++++++++++++++--------- src/neural/cuda/layers.h | 3 ++- src/neural/cuda/network_cudnn.cc | 4 ++-- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index ea8a12109f..4c78e5b389 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1472,8 +1472,11 @@ AttentionPolicyHead::AttentionPolicyHead( EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, 1.0f, // using alpha = 1 for now (TODO: may change?) - nullptr, 0, max_batch_size, ACTIVATION_SWISH, - act_); // smolgen weights not implemented in policy encoder heads yet. + nullptr, 0, max_batch_size, + ACTIVATION_SWISH, // smolgen weights not implemented in policy encoder + // heads yet. + act_, 1e-6); // attentionbody nets don't have encoders, so using old + // epsilon for backward compatibility. 
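The policy-head encoder blocks and older attention-body nets (those without the PE_DENSE input encoding) were trained with a layer-norm epsilon of 1e-6, while newer nets use 1e-3, so EncoderBlock now takes the epsilon as a constructor parameter (default_eps) instead of hardcoding 1e-3.

For reference, a minimal CPU-side sketch of what the epsilon affects (an illustrative stand-in only, not the fused LayerNorm kernel in the backend; names here are made up for the example):

    #include <cmath>
    #include <vector>

    // Normalizes one row of C channels:
    //   out = gamma * (x - mean) / sqrt(var + eps) + beta
    // eps is 1e-6 for the older nets and 1e-3 for the newer ones.
    std::vector<float> LayerNormRow(const std::vector<float>& x,
                                    const std::vector<float>& gamma,
                                    const std::vector<float>& beta, float eps) {
      const size_t c = x.size();
      float mean = 0.0f;
      for (float v : x) mean += v;
      mean /= c;
      float var = 0.0f;
      for (float v : x) var += (v - mean) * (v - mean);
      var /= c;
      std::vector<float> out(c);
      for (size_t i = 0; i < c; ++i)
        out[i] = gamma[i] * (x[i] - mean) / std::sqrt(var + eps) + beta[i];
      return out;
    }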
encoder_weights_.emplace_back(pW); } } @@ -1483,14 +1486,15 @@ EncoderBlock::EncoderBlock( const MultiHeadWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, - ActivationFunction ffn_act) + ActivationFunction ffn_act, float default_eps) : embedding_op_size_(size), encoder_heads_(heads), alpha_(alpha), has_smolgen_(cpu_weights.mha.has_smolgen), smolgen_activation_(smolgen_act), ffn_activation_(ffn_act), - max_batch_size_(max_batch_size) { + max_batch_size_(max_batch_size), + default_eps_(default_eps) { mha_q_size_ = cpu_weights.mha.q_b.size(); mha_k_size_ = cpu_weights.mha.k_b.size(); mha_v_size_ = cpu_weights.mha.v_b.size(); @@ -1847,8 +1851,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // LN1: skip connection and layer normalization (also bias add of prev gemm) // buffer1/in_out_tensor -> scratch LayerNorm(N * 64, embedding_op_size_, scratch, buffer1, mha_dense_b, - in_out_tensor, ln1_gammas, ln1_betas, 1e-3, alpha_, - ACTIVATION_NONE, stream); + in_out_tensor, ln1_gammas, ln1_betas, default_eps_, + alpha_, ACTIVATION_NONE, stream); // #FFN dense 1, scratch -> in_out_tensor { @@ -1875,8 +1879,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // LN2: skip connection and layer normilization (also bias add of prev gemm) // buffer1/scratch -> in_out_tensor LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, - ffn_dense2_b, scratch, ln2_gammas, ln2_betas, 1e-3, - alpha_, ACTIVATION_NONE, stream); + ffn_dense2_b, scratch, ln2_gammas, ln2_betas, + default_eps_, alpha_, ACTIVATION_NONE, stream); } template @@ -2105,7 +2109,8 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha, smolgen_global_, smolgen_global_size_, max_batch_size, - activations_.smolgen_activation, activations_.ffn_activation); + activations_.smolgen_activation, activations_.ffn_activation, + new_encoding_ ? 1e-3 : 1e-6); encoder_weights_.emplace_back(pW); } } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index f685f2e228..d88af74901 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -340,7 +340,7 @@ class EncoderBlock { int heads, int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, - ActivationFunction ffn_act); + ActivationFunction ffn_act, float default_eps); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, @@ -380,6 +380,7 @@ class EncoderBlock { int encoder_heads_; float alpha_; // scale to apply to skip connection add + float default_eps_; // value of epsilon where it wasn't specified in training const bool has_smolgen_; const ActivationFunction smolgen_activation_; diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index e0b1b4e21e..a5a3f59151 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -521,7 +521,7 @@ class CudnnNetwork : public Network { // Policy head. { - auto head = weights.policy_heads.at("vanilla"); + MultiHeadWeights::PolicyHead& head = weights.policy_heads.at("vanilla"); if (attn_policy_) { auto AttentionPolicy = std::make_unique>( getLastLayer(), head, scratch_mem_, false, ACTIVATION_SELU, @@ -573,7 +573,7 @@ class CudnnNetwork : public Network { // Value head. 
{ - auto& head = weights.value_heads.at("winner"); + MultiHeadWeights::ValueHead& head = weights.value_heads.at("winner"); auto convVal = std::make_unique>( resi_last_, head.value.biases.size(), 8, 8, 1, kNumFilters, mish_net ? ACTIVATION_MISH : ACTIVATION_RELU, true); From 8a2009d13594e30d4e2899e711f121d9cdc89304 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sat, 2 Mar 2024 21:57:55 +0100 Subject: [PATCH 50/70] Minor comment fixes. --- src/neural/cuda/layers.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 4c78e5b389..1b60cbb21b 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1471,12 +1471,12 @@ AttentionPolicyHead::AttentionPolicyHead( for (const auto& enc : weights.pol_encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_heads_, embedding_op_size_, - 1.0f, // using alpha = 1 for now (TODO: may change?) - nullptr, 0, max_batch_size, - ACTIVATION_SWISH, // smolgen weights not implemented in policy encoder - // heads yet. - act_, 1e-6); // attentionbody nets don't have encoders, so using old - // epsilon for backward compatibility. + 1.0f, // using alpha = 1 for now (TODO: may change?) + nullptr, 0, // smolgen weights not implemented in + // policy encoder heads yet. + max_batch_size, ACTIVATION_SWISH, act_, + 1e-6); // attentionbody nets don't have policy encoders, so using old + // epsilon for backward compatibility with T78. encoder_weights_.emplace_back(pW); } } @@ -2163,9 +2163,10 @@ void AttentionBody::Eval(int N, DataType* output, */ if (new_encoding_) { // New encoding is made of dense layer fed with input from a 12-channel - // slice of the input tensor. pos_info = flow[..., :12] pos_info_flat = - // tf.reshape(pos_info, [-1, 64 * 12]) pos_info_processed = - // tf.keras.layers.Dense(64*self.embedding_dense_sz, + // slice of the input tensor. + // pos_info = flow[..., :12] + // pos_info_flat = tf.reshape(pos_info, [-1, 64 * 12]) + // pos_info_processed = tf.keras.layers.Dense(64*self.embedding_dense_sz, // name=name+"embedding/preprocess")(pos_info_flat) const int num_outputs = 64 * embedding_dense_size_; const int num_inputs = 64 * 12; From 4ffee572bec5eecc14856c3b375ddf49f4b25485 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sun, 3 Mar 2024 12:54:42 +0100 Subject: [PATCH 51/70] Change 'optimistic_st' key to 'optimistic' in policy head map. --- src/neural/cuda/network_cuda.cc | 1 - src/neural/network_legacy.cc | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index ca15008a27..e0eb728710 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -359,7 +359,6 @@ class CudaNetwork : public Network { std::string policy_head = options.GetOrDefault("policy_head", "vanilla"); // Check that selected policy head exists. 
- if (policy_head == "optimistic") policy_head = "optimistic_st"; if (weights.policy_heads.count(policy_head) == 0) { throw Exception("The policy head you specified '" + policy_head + "' does not exist in this net."); diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index b2f4a1b510..53846353c6 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -242,7 +242,7 @@ MultiHeadWeights::MultiHeadWeights(const pblczero::Weights& weights) if (weights.has_policy_heads()) { if (weights.policy_heads().has_optimistic_st()) { policy_heads.emplace( - std::piecewise_construct, std::forward_as_tuple("optimistic_st"), + std::piecewise_construct, std::forward_as_tuple("optimistic"), std::forward_as_tuple(weights.policy_heads().optimistic_st(), ip_pol_w, ip_pol_b)); } From ee8133605a72e481b5e785070f7c2742db518d00 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Sun, 3 Mar 2024 13:24:25 +0100 Subject: [PATCH 52/70] Switch cudnn to cuda for multiheadformat. --- src/neural/cuda/network_cudnn.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index a5a3f59151..89c6d82b74 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -1096,6 +1096,7 @@ std::unique_ptr MakeCudnnNetwork(const std::optional& w, case pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT: break; case pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT: + case pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT: CERR << "Network format not supported by CuDNN backend, switching to " "CUDA."; return NetworkFactory::Get()->Create( From 58ae0ae4cff7a1387a4aaad709ae1c86c27e8b83 Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Mon, 11 Mar 2024 13:19:48 +0100 Subject: [PATCH 53/70] Fix buffer naming, fix source to build. --- src/neural/cuda/cutlass_kernels.cu | 17 +-- src/neural/cuda/layers.cc | 161 ++++++++++++----------------- src/neural/cuda/layers.h | 5 +- 3 files changed, 77 insertions(+), 106 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 3c904054f7..1abd7f42d8 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -25,13 +25,14 @@ Program grant you additional permission to convey the resulting work. 
*/ -#include "cuda_common.h" -#include -#include "winograd_helper.inc" - #include #include +#include + +#include "cuda_common.h" +#include "neural/shared/activation.h" +#include "winograd_helper.inc" #ifdef USE_CUTLASS @@ -252,8 +253,10 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, cutlass::epilogue::thread::LinearCombinationBiasElementwise< ElementIO, ElementAccumulator, ElementComputeEpilogue, ElementIO, ElementIO, - ElementScale, // element Vector - elementsPerAccess, false, + // ElementScale, // element Vector + elementsPerAccess, + // false, + cutlass::epilogue::thread::Identity, cutlass::multiplies>; using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast< @@ -923,7 +926,7 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, float* invScale, float *deq, const half* bias, cudaStream_t stream, - ActivationFunction act = NONE) { + ActivationFunction act = ACTIVATION_NONE) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16), batchSize); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index f6e3d1ac96..04a7103aa0 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1783,37 +1783,6 @@ static void cublasXGemmBatched(cublasHandle_t handle, cublasOperation_t transa, } } -// input/output tensor is in_out_tensor, others are used as scratch. - -template -void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, - float* output_deq_factors, float* maxValuesA, - float* maxValuesOut, const DataType* A, const DataType* B, int M, int N, - int K, int batchSize, int M_Batch); - -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, - int M, int N, int K, int batchSize, - int AStride, int BStride, int OutStride, - float alphaf, float betaf); - -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, int8_t* Out, int M, - int N, int K, int batchSize, int AStride, - int BStride, int OutStride, int VecStride, - float alphaf, float betaf); - -void quantizeActivationMatrix(int8_t* output, const half* input, int height, - int width, const float* scale, - cudaStream_t stream); - -void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, - int height, int width, int batchSize, - float* scale, float *deq, const half* bias, - cudaStream_t stream, ActivationFunction act = NONE); - -// input/output tensor is scratch1, others are used as scratch. -// TODO: fix naming of scratch buffers template void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, float* output_scaling_factors, @@ -1839,7 +1808,7 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, float* scale, float *deq, const half* bias, - cudaStream_t stream, ActivationFunction act = NONE); + cudaStream_t stream, ActivationFunction act = ACTIVATION_NONE); // input/output tensor is in_out_tensor, others are used as scratch. 
template @@ -1922,7 +1891,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, DataType* mha_k; DataType* mha_v; - //dumpTensor(scratch1, embedding_op_size_ * 64 * N, "input to mha_kqv gemm", true); + //dumpTensor(in_out_tensor, embedding_op_size_ * 64 * N, "input to mha_kqv gemm", true); //dumpTensor(mha_qkv_w, embedding_op_size_ * d_model * 3, "weights to mha_kqv gemm", // true); //exit(0); @@ -1934,7 +1903,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, const int max_batch = max_batch_size_ * 64; const int batch_to_use = use_fused_mha_ ? batch : max_batch; // The array of GPU pointers assume max batch - mha_q = scratch0; + mha_q = scratch; mha_k = mha_q + num_outputs * batch_to_use; mha_v = mha_k + num_outputs * batch_to_use; @@ -1942,56 +1911,56 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, kqv_.output_scaling_factors, kqv_.output_deq_factors, kqv_.input_matrix_max_values, - kqv_.output_matrix_max_values, scratch1, + kqv_.output_matrix_max_values, in_out_tensor, mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); } if (true && int8_inf_) { // printf("\nAttempting int8_inf\n"); - // 1. quantize the inputs (scratch1 -> scratch0) + // 1. quantize the inputs (in_out_tensor -> scratch) // TODO: Fuse this step with layer-norm of previous block - quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch1, batch, + quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, kqv_.input_scaling_factors, stream); - // 2. perform int8 GEMM (scratch0 -> scratch2) + // 2. perform int8 GEMM (scratch -> buffer1) /* cutlassMatrixMulBTransposed( - (const int8_t*)scratch0, kqv_.weights_int8, (int8_t*)scratch2, batch, + (const int8_t*)scratch, kqv_.weights_int8, (int8_t*)buffer1, batch, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); */ // per-layer output scaling cutlassMatrixMulBTransposed( - (const int8_t*)scratch0, kqv_.weights_int8, - kqv_.output_scaling_factors, (int8_t*)scratch2, batch, + (const int8_t*)scratch, kqv_.weights_int8, + kqv_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); /* - dumpTensor((const int8_t*)scratch0, 768, "quantized input matrix", + dumpTensor((const int8_t*)scratch, 768, "quantized input matrix", false, false); dumpTensor(kqv_.weights_int8, 768, "weights - during run", false, false); - dumpTensor((const int8_t*)scratch2, 768, + dumpTensor((const int8_t*)buffer1, 768, "some quantized output values", false, false); dumpTensor(kqv_.output_scaling_factors, 768, "output_scaling_factors - during run", false, false); */ - // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch0) + // 3. de-quantize outputs - fused with bias add (buffer1 -> scratch) // TODO: fuse the entire thing with the above GEMM. 
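// Editor's sketch of the arithmetic behind the three numbered steps above
// (quantize -> int8 GEMM -> de-quantize). Plain C++ for illustration only:
// the real kernels use per-column scaling factors and fuse the bias add and
// activation into the GEMM epilogue, which is omitted here. Assumes <math.h>
// for roundf/fminf/fmaxf; names are illustrative.
static inline int8_t QuantizeRef(float x, float input_scale) {
  float q = roundf(x / input_scale);                // divide by the input scaling factor
  return (int8_t)fmaxf(-128.0f, fminf(127.0f, q));  // clamp to the int8 range
}
static inline float DequantizeRef(int8_t q, float output_scale, float bias) {
  // Approximate inverse applied to the int8 GEMM result before the next layer.
  return (float)q * output_scale + bias;
}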
- deQuantizeOutputMatrixBiasAdd((half*)scratch0, (const int8_t*)scratch2, batch, num_outputs, 3, + deQuantizeOutputMatrixBiasAdd((half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, kqv_.output_scaling_factors, kqv_.output_deq_factors, (const half*) mha_qkv_b, stream); /* - dumpTensor((const half*)scratch0, 768, + dumpTensor((const half*)scratch, 768, "dequantized output values after bias add", false, false); exit(0); */ @@ -2002,7 +1971,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, 3); #if 0 - cutlassMatrixMulBTransposed((const half*)scratch1, (const half*)mha_qkv_w, + cutlassMatrixMulBTransposed((const half*)in_out_tensor, (const half*)mha_qkv_w, (half*)mha_q, batch_to_use, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, true); @@ -2133,7 +2102,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, mha_dense_.input_matrix_max_values, mha_dense_.output_matrix_max_values, - scratch3, mha_dense_w, 64, embedding_op_size_, d_model, 1, N); + buffer2, mha_dense_w, 64, embedding_op_size_, d_model, 1, N); } const int num_inputs = d_model; @@ -2141,40 +2110,40 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, const int batch = N * 64; if (true && int8_inf_) { - // 1. quantize the inputs (scratch3 -> scratch0) + // 1. quantize the inputs (buffer2 -> scratch) // TODO: Fuse this step with the previous fused MHA - quantizeActivationMatrix((int8_t*)scratch0, (const half*)scratch3, batch, + quantizeActivationMatrix((int8_t*)scratch, (const half*)buffer2, batch, num_inputs, mha_dense_.input_scaling_factors, stream); - // 2. perform int8 GEMM (scratch0 -> scratch3) + // 2. perform int8 GEMM (scratch -> buffer2) /* cutlassMatrixMulBTransposed( - (const int8_t*)scratch0, mha_dense_.weights_int8, (int8_t*)scratch3, batch, + (const int8_t*)scratch, mha_dense_.weights_int8, (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); */ cutlassMatrixMulBTransposed( - (const int8_t*)scratch0, mha_dense_.weights_int8, - mha_dense_.output_scaling_factors, (int8_t*)scratch3, + (const int8_t*)scratch, mha_dense_.weights_int8, + mha_dense_.output_scaling_factors, (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); - // 3. de-quantize outputs (scratch3 -> scratch2) + // 3. de-quantize outputs (buffer2 -> buffer1) // TODO: Fuse this with LN1 (should be easy!) 
deQuantizeOutputMatrixBiasAdd( - (half*)scratch2, (const int8_t*)scratch3, batch, num_outputs, 1, + (half*)buffer1, (const int8_t*)buffer2, batch, num_outputs, 1, mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, nullptr, stream); } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, mha_dense_w, num_inputs, scratch3, - num_inputs, 0.0f, scratch2, num_outputs); + num_inputs, 1.0f, mha_dense_w, num_inputs, buffer2, + num_inputs, 0.0f, buffer1, num_outputs); /* cutlassMatrixMulBTransposed( - (const half*)scratch3, (const half*)mha_dense_w, (half*)scratch2, + (const half*)buffer2, (const half*)mha_dense_w, (half*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, true); */ } @@ -2186,14 +2155,14 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, in_out_tensor, ln1_gammas, ln1_betas, default_eps_, alpha_, ACTIVATION_NONE, stream); - // #FFN dense 1, scratch0 -> in_out_tensor + // #FFN dense 1, scratch -> in_out_tensor const int encoder_dff = ffn_dense1_size_; { if (int8_cali_) { calibrateGemmForInt8( ffn1_.weights_int8, ffn1_.input_scaling_factors, ffn1_.output_scaling_factors, ffn1_.output_deq_factors, - ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, scratch0, + ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, scratch, ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); } @@ -2203,20 +2172,20 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, const int batch = N * 64; if (true && int8_inf_) { - // 1. quantize the inputs (scratch0 -> scratch1) + // 1. quantize the inputs (scratch -> in_out_tensor) // TODO: Fuse this step with LN1 (should be easy) - quantizeActivationMatrix((int8_t*)scratch1, (const half*)scratch0, batch, + quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, batch, num_inputs, ffn1_.input_scaling_factors, stream); - // 2. perform int8 GEMM (scratch1 -> scratch2) + // 2. perform int8 GEMM (in_out_tensor -> buffer1) /* - cutlassMatrixMulBTransposed((const int8_t*)scratch1, ffn1_.weights_int8, - (int8_t*)scratch2, batch, num_outputs, + cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn1_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); */ - cutlassMatrixMulBTransposed((const int8_t*)scratch1, ffn1_.weights_int8, + cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn1_.weights_int8, ffn1_.output_scaling_factors, - (int8_t*)scratch2, batch, num_outputs, + (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); @@ -2224,47 +2193,45 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, dumpTensor(ffn1_.input_scaling_factors, 768, "input_scaling_factors - during run", false, false); - dumpTensor((const int8_t*)scratch1, 768, "quantized input matrix", + dumpTensor((const int8_t*)in_out_tensor, 768, "quantized input matrix", false, false); dumpTensor(ffn1_.weights_int8, 768, "weights - during run", false, false); - dumpTensor((const int8_t*)scratch2, 768, + dumpTensor((const int8_t*)buffer1, 768, "some quantized output values", false, false); dumpTensor(ffn1_.output_scaling_factors, 768, "output_scaling_factors - during run", false, false); */ - // 3. de-quantize outputs - fused with bias add (scratch2 -> scratch1) + // 3. 
de-quantize outputs - fused with bias add (buffer1 -> in_out_tensor) // TODO: Fuse this with the above GEMM deQuantizeOutputMatrixBiasAdd( - (half*)scratch1, (const int8_t*)scratch2, batch, num_outputs, 1, + (half*)in_out_tensor, (const int8_t*)buffer1, batch, num_outputs, 1, ffn1_.output_scaling_factors, ffn1_.output_deq_factors, - (const half*)ffn_dense1_b, stream, - has_smolgen_ ? RELU_2 : act); + (const half*)ffn_dense1_b, stream, ffn_activation_); // Ankan - test! - //dumpTensor((const DataType*)scratch1, 768, + //dumpTensor((const DataType*)in_out_tensor, 768, // "runtime output values after bias and RELU2", false, false); //exit(0); } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, ffn_dense1_w, num_inputs, scratch0, - num_inputs, 0.0f, scratch1, num_outputs); + num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, + scratch, num_inputs, 0.0f, in_out_tensor, num_outputs); /* cutlassMatrixMulBTransposed( - (const half*)scratch0, (const half*)ffn_dense1_w, (half*)scratch1, + (const half*)scratch, (const half*)ffn_dense1_w, (half*)in_out_tensor, batch, num_outputs, num_inputs, 1, 0, 0, 0, true); */ - addBiasBatched(scratch1, scratch1, ffn_dense1_b, 1, batch, num_outputs, - has_smolgen_ ? RELU_2 : act, - stream); // @todo sqr relu to have its own flag + addBiasBatched(in_out_tensor, in_out_tensor, ffn_dense1_b, 1, batch, + num_outputs, ffn_activation_, stream); // Ankan - test! - //dumpTensor((const DataType*)scratch1, 768, + //dumpTensor((const DataType*)in_out_tensor, 768, // "Ref output values after bias and RELU2", false, // false); //exit(0); @@ -2277,7 +2244,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, calibrateGemmForInt8( ffn2_.weights_int8, ffn2_.input_scaling_factors, ffn2_.output_scaling_factors, ffn2_.output_deq_factors, - ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, scratch1, + ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, in_out_tensor, ffn_dense2_w, 64, embedding_op_size_, encoder_dff, 1, N); } @@ -2285,49 +2252,49 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, const int num_outputs = embedding_op_size_; const int batch = N * 64; if (true && int8_inf_) { - // 1. quantize the inputs (scratch1 -> scratch2) + // 1. quantize the inputs (in_out_tensor -> buffer1) // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) - quantizeActivationMatrix((int8_t*)scratch2, (const half*)scratch1, batch, + quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, batch, num_inputs, ffn2_.input_scaling_factors, stream); /* dumpTensor((const float*)ffn2_.input_scaling_factors, 768, "input scaling factors during run", false, false); - dumpTensor((const int8_t*)scratch2, 768, "quantized input matrix", + dumpTensor((const int8_t*)buffer1, 768, "quantized input matrix", false, false); dumpTensor(ffn2_.weights_int8, 768, "weights - during run", false, false); */ - // 2. perform int8 GEMM (scratch2 -> scratch1) + // 2. 
perform int8 GEMM (buffer1 -> in_out_tensor) /* - cutlassMatrixMulBTransposed((const int8_t*)scratch2, ffn2_.weights_int8, - (int8_t*)scratch1, batch, num_outputs, + cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, + (int8_t*)in_out_tensor, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); */ - cutlassMatrixMulBTransposed((const int8_t*)scratch2, ffn2_.weights_int8, ffn2_.output_scaling_factors, - (int8_t*)scratch1, batch, num_outputs, + cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, ffn2_.output_scaling_factors, + (int8_t*)in_out_tensor, batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); /* - dumpTensor((const int8_t*)scratch1, 768, + dumpTensor((const int8_t*)in_out_tensor, 768, "some quantized output values", false, false); dumpTensor(ffn1_.output_scaling_factors, 768, "output_scaling_factors - during run", false, false); */ - // 3. de-quantize outputs (scratch1 -> scratch2) + // 3. de-quantize outputs (in_out_tensor -> buffer1) // TODO: Fuse this with LN2 (should be easy) deQuantizeOutputMatrixBiasAdd( - (half*)scratch2, (const int8_t*)scratch1, batch, num_outputs, 1, + (half*)buffer1, (const int8_t*)in_out_tensor, batch, num_outputs, 1, ffn2_.output_scaling_factors, ffn2_.output_deq_factors, nullptr, stream); /* - dumpTensor((const half*)scratch2, 768, "dequantized output values", + dumpTensor((const half*)buffer1, 768, "dequantized output values", false, false); exit(0);*/ } else { @@ -2336,11 +2303,11 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); /* cutlassMatrixMulBTransposed( - (const half*)scratch1, (const half*)ffn_dense2_w, (half*)scratch2, + (const half*)in_out_tensor, (const half*)ffn_dense2_w, (half*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, true); */ /* - dumpTensor((const half*)scratch2, 768, "dequantized output values - ref", + dumpTensor((const half*)buffer1, 768, "dequantized output values - ref", false, false); exit(0); diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 9064ca66f2..428a1b43fd 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -354,8 +354,9 @@ class EncoderBlock { int heads, int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, - ActivationFunction ffn_act, float default_eps, bool fused_mha, bool int8_calibrate, - bool int8_inference, void* int8_weights, int blockIndex); + ActivationFunction ffn_act, float default_eps, bool fused_mha, + bool int8_calibrate, bool int8_inference, void* int8_weights, + int blockIndex); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, From d0a8536eac44278fca79a194e88befbbf349830f Mon Sep 17 00:00:00 2001 From: Aniebiet Udoh Date: Thu, 7 Mar 2024 15:35:03 +0100 Subject: [PATCH 54/70] Remove value error head inference. 
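For reference, the inference path removed below was a single-output dense layer
over the value head's hidden activations followed by a sigmoid, together with
the op_value_err_mem_ host/device buffers and the wdl_err_ plumbing that fed
it. A rough scalar sketch of the deleted math (editor's illustration only; the
real code ran it as the cublasXgemm plus addVectors with ACTIVATION_SIGMOID
visible in the diff):

    // err = sigmoid(dot(ip_val_err_w, hidden) + ip_val_err_b), one scalar per sample
    float acc = ip_val_err_b;
    for (int i = 0; i < value_hidden_size; i++) acc += ip_val_err_w[i] * hidden[i];
    float err = 1.0f / (1.0f + expf(-acc));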
--- src/neural/cuda/inputs_outputs.h | 17 ++----------- src/neural/cuda/layers.cc | 31 +++-------------------- src/neural/cuda/layers.h | 5 ++-- src/neural/cuda/network_cuda.cc | 42 ++++++-------------------------- 4 files changed, 16 insertions(+), 79 deletions(-) diff --git a/src/neural/cuda/inputs_outputs.h b/src/neural/cuda/inputs_outputs.h index ea89ed1272..f7f82bdaf2 100644 --- a/src/neural/cuda/inputs_outputs.h +++ b/src/neural/cuda/inputs_outputs.h @@ -31,11 +31,10 @@ namespace lczero { namespace cudnn_backend { struct InputsOutputs { - InputsOutputs(int maxBatchSize, bool wdl, bool wdl_err, bool moves_left, + InputsOutputs(int maxBatchSize, bool wdl, bool moves_left, size_t tensor_mem_size = 0, size_t scratch_size = 0, bool cublasDisableTensorCores = false) - : has_moves_left_(moves_left), - has_wdl_err_(wdl_err) { + : has_moves_left_(moves_left) { ReportCUDAErrors(cudaHostAlloc( &input_masks_mem_, maxBatchSize * kInputPlanes * sizeof(uint64_t), cudaHostAllocMapped)); @@ -70,14 +69,6 @@ struct InputsOutputs { op_moves_left_mem_, 0)); } - if (wdl_err) { - ReportCUDAErrors(cudaHostAlloc(&op_value_err_mem_, - maxBatchSize * sizeof(float), - cudaHostAllocMapped)); - ReportCUDAErrors(cudaHostGetDevicePointer(&op_value_err_mem_gpu_, - op_value_err_mem_, 0)); - } - // memory for network execution managed inside this structure if (tensor_mem_size) { multi_stream_ = true; @@ -103,7 +94,6 @@ struct InputsOutputs { ReportCUDAErrors(cudaFree(op_policy_mem_gpu_)); ReportCUDAErrors(cudaFreeHost(op_value_mem_)); if (has_moves_left_) ReportCUDAErrors(cudaFreeHost(op_moves_left_mem_)); - if (has_wdl_err_) ReportCUDAErrors(cudaFreeHost(op_value_err_mem_)); if (multi_stream_) { for (auto mem : tensor_mem_) { @@ -122,14 +112,12 @@ struct InputsOutputs { float* input_val_mem_; float* op_policy_mem_; float* op_value_mem_; - float* op_value_err_mem_; float* op_moves_left_mem_; // GPU pointers for the above allocations. uint64_t* input_masks_mem_gpu_; float* input_val_mem_gpu_; float* op_value_mem_gpu_; - float* op_value_err_mem_gpu_; float* op_moves_left_mem_gpu_; // This is a seperate copy. @@ -150,7 +138,6 @@ struct InputsOutputs { cublasHandle_t cublas_; bool has_moves_left_ = false; - bool has_wdl_err_ = false; }; } // namespace cudnn_backend diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 04a7103aa0..31a63250f6 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2767,16 +2767,15 @@ template ValueHead::ValueHead(BaseLayer* ip, const MultiHeadWeights::ValueHead& weights, void* scratch, bool attention_body, bool wdl, - bool wdl_err, ActivationFunction act, - int max_batch_size, bool use_gemm_ex) + ActivationFunction act, int max_batch_size, + bool use_gemm_ex) : BaseLayer(weights.ip_val_b.size(), 8, 8, ip), attention_body_(attention_body), embedding_size_(attention_body ? 
weights.ip_val_b.size() : weights.value.biases.size()), value_hidden_size_(weights.ip1_val_b.size()), act_(act), - wdl_(wdl), - wdl_err_(wdl_err) { + wdl_(wdl) { if (attention_body_) { allocAndUpload(&ip_val_w_, weights.ip_val_w, scratch); allocAndUpload(&ip_val_b_, weights.ip_val_b, scratch); @@ -2793,11 +2792,6 @@ ValueHead::ValueHead(BaseLayer* ip, allocAndUpload(&ip2_val_w_, weights.ip2_val_w, scratch); allocAndUpload(&ip2_val_b_, weights.ip2_val_b, scratch); - - if (wdl_err_) { - allocAndUpload(&ip_val_err_w_, weights.ip_val_err_w, scratch); - allocAndUpload(&ip_val_err_b_, weights.ip_val_err_b, scratch); - } } template @@ -2810,10 +2804,6 @@ ValueHead::~ValueHead() { ReportCUDAErrors(cudaFree(ip1_val_b_)); ReportCUDAErrors(cudaFree(ip2_val_w_)); ReportCUDAErrors(cudaFree(ip2_val_b_)); - if (wdl_err_) { - ReportCUDAErrors(cudaFree(ip_val_err_w_)); - ReportCUDAErrors(cudaFree(ip_val_err_b_)); - } } template @@ -2860,7 +2850,7 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, const int num_inputs = value_hidden_size_; const int num_outputs = wdl_ ? 3 : 1; const int batch = N; - DataType* layer_out = wdl_err_ ? (DataType*)buffer : (DataType*)output; + DataType* layer_out = (DataType*)output; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip2_val_w_, num_inputs, (DataType*)scratch, num_inputs, 0.0f, @@ -2869,19 +2859,6 @@ void ValueHead::Eval(int N, DataType* output, const DataType* input, num_outputs * batch, num_outputs, wdl_ ? ACTIVATION_NONE : ACTIVATION_TANH, stream); } - - if (wdl_err_) { - // Value error dense - const int num_inputs = value_hidden_size_; - const int num_outputs = 1; - const int batch = N; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_val_err_w_, - num_inputs, (DataType*)scratch, num_inputs, 0.0f, - output, num_outputs); - addVectors(output, output, ip_val_err_b_, N, N, 1, ACTIVATION_SIGMOID, - stream); - } } // Template instantiation. 
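With the error head gone, the tail of ValueHead::Eval reduces to the final
dense layer shown above. A rough scalar sketch of that last stage, assuming the
hidden activations are already in scratch (editor's illustration; the weight
layout is simplified relative to the column-major cuBLAS call):

    for (int o = 0; o < (wdl ? 3 : 1); o++) {
      float v = ip2_val_b[o];
      for (int i = 0; i < value_hidden_size; i++)
        v += ip2_val_w[o * value_hidden_size + i] * hidden[i];
      output[o] = wdl ? v : tanhf(v);  // WDL logits left raw; scalar value tanh-squashed
    }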
diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 428a1b43fd..995aae3e35 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -551,8 +551,8 @@ class ValueHead : public BaseLayer { public: ValueHead(BaseLayer* ip, const MultiHeadWeights::ValueHead& weights, - void* scratch, bool attention_body, bool wdl, bool wdl_err, - ActivationFunction act, int max_batch_size, bool use_gemm_ex); + void* scratch, bool attention_body, bool wdl, ActivationFunction act, + int max_batch_size, bool use_gemm_ex); ~ValueHead(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, @@ -572,7 +572,6 @@ class ValueHead : public BaseLayer { int embedding_size_; int value_hidden_size_; bool wdl_; - bool wdl_err_; bool attention_body_; ActivationFunction act_; }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 6fc94d36a6..29b0aab85f 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -127,8 +127,8 @@ static size_t getMaxAttentionBodySize(const MultiHeadWeights& weights, int N) { template class CudaNetworkComputation : public NetworkComputation { public: - CudaNetworkComputation(CudaNetwork* network, bool wdl, bool wdl_err, - bool moves_left); + CudaNetworkComputation(CudaNetwork* network, + bool wdl, bool moves_left); ~CudaNetworkComputation(); void AddInput(InputPlanes&& input) override { @@ -183,7 +183,6 @@ class CudaNetworkComputation : public NetworkComputation { std::unique_ptr inputs_outputs_; int batch_size_; bool wdl_; - bool wdl_err_; bool moves_left_; CudaNetwork* network_; @@ -629,17 +628,9 @@ class CudaNetwork : public Network { pblczero::NetworkFormat::VALUE_WDL; BaseLayer* lastlayer = attn_body_ ? encoder_last_ : resi_last_; auto value_main = std::make_unique>( - lastlayer, head, scratch_mem_, attn_body_, wdl_, false, act, + lastlayer, head, scratch_mem_, attn_body_, wdl_, act, max_batch_size_, use_gemm_ex); network_.emplace_back(std::move(value_main)); - - wdl_err_ = weights.value_heads.count("st") > 0; - if (wdl_err_) { - auto value_err = std::make_unique>( - lastlayer, weights.value_heads.at("st"), scratch_mem_, attn_body_, - wdl_, true, act, max_batch_size_, use_gemm_ex); - network_.emplace_back(std::move(value_err)); - } } // Moves left head @@ -756,7 +747,6 @@ class CudaNetwork : public Network { float* opPol = io->op_policy_mem_gpu_; float* opVal = io->op_value_mem_gpu_; float* opMov = io->op_moves_left_mem_gpu_; - float* opValErr = io->op_value_err_mem_gpu_; // Figure out if the memory requirment for running the res block would fit // in the L2 cache. @@ -926,20 +916,6 @@ class CudaNetwork : public Network { stream); // value head } - if (wdl_err_) { - // value error head - if (fp16) { - network_[l++]->Eval(batchSize, spare1, flow, spare2, scratch_mem, - scratch_size_, nullptr, cublas, - stream); // value error head - copyTypeConverted(opValErr, (half*)spare1, batchSize, stream); - } else { - network_[l++]->Eval(batchSize, (DataType*)opValErr, flow, spare2, - scratch_mem, scratch_size_, nullptr, cublas, - stream); // value error head - } - } - if (moves_left_) { // Moves left head network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, @@ -1030,16 +1006,15 @@ class CudaNetwork : public Network { // Set correct gpu id for this computation (as it might have been called // from a different thread). 
ReportCUDAErrors(cudaSetDevice(gpu_id_)); - return std::make_unique>( - this, wdl_, wdl_err_, moves_left_); + return std::make_unique>(this, wdl_, + moves_left_); } std::unique_ptr GetInputsOutputs() { std::lock_guard lock(inputs_outputs_lock_); if (free_inputs_outputs_.empty()) { return std::make_unique( - max_batch_size_, wdl_, wdl_err_, moves_left_, tensor_mem_size_, - scratch_size_, + max_batch_size_, wdl_, moves_left_, tensor_mem_size_, scratch_size_, !has_tensor_cores_ && std::is_same::value); } else { std::unique_ptr resource = @@ -1069,7 +1044,6 @@ class CudaNetwork : public Network { int max_batch_size_; int min_batch_size_; bool wdl_; - bool wdl_err_; bool moves_left_; bool use_res_block_winograd_fuse_opt_; // fuse operations inside the residual // tower @@ -1167,8 +1141,8 @@ class CudaNetwork : public Network { template CudaNetworkComputation::CudaNetworkComputation( - CudaNetwork* network, bool wdl, bool wdl_err, bool moves_left) - : wdl_(wdl), wdl_err_(wdl_err), moves_left_(moves_left), network_(network) { + CudaNetwork* network, bool wdl, bool moves_left) + : wdl_(wdl), moves_left_(moves_left), network_(network) { batch_size_ = 0; inputs_outputs_ = network_->GetInputsOutputs(); } From 405046ac9f9035eb85891083d5fbeb2c68b6ef23 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Sun, 24 Mar 2024 01:13:17 +0100 Subject: [PATCH 55/70] Fix conflict resolution artefacts. --- src/neural/cuda/inputs_outputs.h | 5 +- src/neural/cuda/kernels.h | 10 +- src/neural/cuda/layers.cc | 166 ++++++++++++++++--------------- src/neural/cuda/network_cuda.cc | 37 ++----- src/neural/cuda/network_cudnn.cc | 14 +-- 5 files changed, 110 insertions(+), 122 deletions(-) diff --git a/src/neural/cuda/inputs_outputs.h b/src/neural/cuda/inputs_outputs.h index ed566cc7af..248f4b23b3 100644 --- a/src/neural/cuda/inputs_outputs.h +++ b/src/neural/cuda/inputs_outputs.h @@ -33,8 +33,7 @@ namespace cudnn_backend { struct InputsOutputs { InputsOutputs(int maxBatchSize, bool wdl, bool moves_left, size_t tensor_mem_size = 0, size_t scratch_size = 0, - bool cublasDisableTensorCores = false) - : has_moves_left_(moves_left) { + bool cublasDisableTensorCores = false) { ReportCUDAErrors(cudaHostAlloc( &input_masks_mem_, maxBatchSize * kInputPlanes * sizeof(uint64_t), cudaHostAllocMapped)); @@ -137,8 +136,6 @@ struct InputsOutputs { // cublas handle used to run the network cublasHandle_t cublas_; - - bool has_moves_left_ = false; }; } // namespace cudnn_backend diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index 83f64066ec..cc49253ab9 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -154,16 +154,18 @@ void inputPreprocessForAttentionBody(T* output, const T* input, template void applyInputGating(T* output, const T* input, const T* mult, const T* add, - int N, int HW, int C, cudaStream_t stream); + int N, int HW, int C, cudaStream_t stream); void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, int N, int K, int batchSize, int AStride, - int BStride, int OutStride, bool useInt8 = true); + int BStride, int OutStride, + bool useInt8 = true); } // namespace cudnn_backend } // namespace lczero -// Work around to avoid "nvcc error : 'cudafe++' died with status 0xC0000409" error -// For some reason nvcc runs into this random error when trying to compile this function inside the namespaces +// Work around to avoid "nvcc error : 'cudafe++' died with status 0xC0000409" +// error For some reason nvcc runs into this random error when trying to compile +// this function 
inside the namespaces bool fusedMHA(void* output, void* mha_q, void* mha_k, void* mha_v, void* skip, int batch_size, int num_heads, int depth, cudaStream_t stream); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index e64bd47821..1c0962649f 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -45,7 +45,6 @@ void dumpTensor(const T* memory, int elements, const char* message, bool only_summary = false, bool cpu_tensor = false); } - #if 0 #include @@ -1520,16 +1519,18 @@ AttentionPolicyHead::AttentionPolicyHead( } } -void fillGpuArray(float *arr, float val, int count); +void fillGpuArray(float* arr, float val, int count); -static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, int InputCols, int OutputCols, int outBatch, bool cali) { +static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, + int InputCols, int OutputCols, int outBatch, + bool cali) { size_t matrix_size = InputCols * OutputCols * sizeof(int8_t) * outBatch; if (!cali) { // Load weights for INT8 inference - + // (per-column) scaling factors for the input - ReportCUDAErrors(cudaMalloc(&data.input_scaling_factors, - sizeof(float) * InputCols)); + ReportCUDAErrors( + cudaMalloc(&data.input_scaling_factors, sizeof(float) * InputCols)); ReportCUDAErrors(cudaMemcpy(data.input_scaling_factors, w, sizeof(float) * InputCols, cudaMemcpyHostToDevice)); @@ -1549,19 +1550,20 @@ static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, int cudaMemcpyHostToDevice)); // go to output dequantization factors w += outBatch * OutputCols * sizeof(float); - data.output_deq_factors = (float*) w; + data.output_deq_factors = (float*)w; // go to next item w += outBatch * sizeof(float); } else { - // Just save the pointers to CPU weights (we will over-write here during calibration) + // Just save the pointers to CPU weights (we will over-write here during + // calibration) data.input_scaling_factors = (float*)w; w += InputCols * sizeof(float); data.weights_int8 = w; w += matrix_size; data.output_scaling_factors = (float*)w; w += outBatch * OutputCols * sizeof(float); - data.output_deq_factors = (float*) w; + data.output_deq_factors = (float*)w; w += outBatch * sizeof(float); // to keep track of max values in activation matrices @@ -1692,12 +1694,13 @@ EncoderBlock::EncoderBlock( embedding_op_size_ * sizeof(float) + ffn_dense1_size_ * mha_q_size_ + sizeof(float) + ffn_dense1_size_ * sizeof(float) + - embedding_op_size_ * ffn_dense1_size_ + sizeof(float); + embedding_op_size_ * ffn_dense1_size_ + + sizeof(float); */ int embedding_op_size = embedding_op_size_; int encoder_d_model = mha_q_size_; int encoder_dff = ffn_dense1_size_; - int per_encoder_size = + int per_encoder_size = (embedding_op_size * sizeof(float) + 3 * embedding_op_size * encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + encoder_d_model * sizeof(float) + encoder_d_model * embedding_op_size + (embedding_op_size + 1) * sizeof(float) + embedding_op_size * sizeof(float) + embedding_op_size * encoder_dff + (encoder_dff + 1) * sizeof(float) + @@ -1787,8 +1790,9 @@ template void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, float* output_scaling_factors, float* output_deq_factors, float* maxValuesA, - float* maxValuesOut, const DataType* A, const DataType* B, int M, int N, - int K, int batchSize, int M_Batch); + float* maxValuesOut, const DataType* A, + const DataType* B, int M, int N, int K, int batchSize, + int M_Batch); void cutlassMatrixMulBTransposed(const 
int8_t* A, const int8_t* B, int8_t* Out, int M, int N, int K, int batchSize, @@ -1807,8 +1811,9 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, - float* scale, float *deq, const half* bias, - cudaStream_t stream, ActivationFunction act = ACTIVATION_NONE); + float* scale, float* deq, const half* bias, + cudaStream_t stream, + ActivationFunction act = ACTIVATION_NONE); // input/output tensor is in_out_tensor, others are used as scratch. template @@ -1909,19 +1914,19 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, if (int8_cali_) { calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, - kqv_.output_scaling_factors, kqv_.output_deq_factors, + kqv_.output_scaling_factors, kqv_.output_deq_factors, kqv_.input_matrix_max_values, kqv_.output_matrix_max_values, in_out_tensor, - mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); + mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); } if (true && int8_inf_) { // printf("\nAttempting int8_inf\n"); // 1. quantize the inputs (in_out_tensor -> scratch) // TODO: Fuse this step with layer-norm of previous block - quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, - embedding_op_size_, kqv_.input_scaling_factors, - stream); + quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, + batch, embedding_op_size_, + kqv_.input_scaling_factors, stream); // 2. perform int8 GEMM (scratch -> buffer1) /* @@ -1929,15 +1934,15 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (const int8_t*)scratch, kqv_.weights_int8, (int8_t*)buffer1, batch, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); - */ + */ // per-layer output scaling - + cutlassMatrixMulBTransposed( (const int8_t*)scratch, kqv_.weights_int8, - kqv_.output_scaling_factors, (int8_t*)buffer1, batch, - num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, + kqv_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); - + ReportCUDAErrors(cudaGetLastError()); /* dumpTensor((const int8_t*)scratch, 768, "quantized input matrix", @@ -1955,21 +1960,22 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, */ // 3. de-quantize outputs - fused with bias add (buffer1 -> scratch) // TODO: fuse the entire thing with the above GEMM. 
- deQuantizeOutputMatrixBiasAdd((half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, - kqv_.output_scaling_factors, - kqv_.output_deq_factors, (const half*) mha_qkv_b, stream); + deQuantizeOutputMatrixBiasAdd( + (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, + kqv_.output_scaling_factors, kqv_.output_deq_factors, + (const half*)mha_qkv_b, stream); /* dumpTensor((const half*)scratch, 768, "dequantized output values after bias add", false, false); exit(0); */ - } else { - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, - mha_qkv_w, num_inputs, num_inputs * num_outputs, in_out_tensor, - num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, - 3); + } else { + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, + 1.0f, mha_qkv_w, num_inputs, num_inputs * num_outputs, in_out_tensor, + num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, + 3); #if 0 cutlassMatrixMulBTransposed((const half*)in_out_tensor, (const half*)mha_qkv_w, (half*)mha_q, batch_to_use, num_outputs, @@ -1984,7 +1990,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, "ref output values after bias add", false, false); exit(0); #endif - } + } } // Apply split_heads() to q, k and v @@ -2008,11 +2014,11 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, #ifdef USE_CUTLASS if (use_fused_mha_) { - // TODO: check if we need skip in a different tensor than same tensor as output! + // TODO: check if we need skip in a different tensor than same tensor as + // output! bool success = - fusedMHA(buffer2, mha_q, mha_k, mha_v, - has_smolgen_ ? buffer2 : nullptr, N, - encoder_heads_, depth, stream); + fusedMHA(buffer2, mha_q, mha_k, mha_v, has_smolgen_ ? buffer2 : nullptr, + N, encoder_heads_, depth, stream); ReportCUDAErrors(cudaGetLastError()); if (!success) throw Exception("Some error running fused MHA"); @@ -2101,8 +2107,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, mha_dense_.weights_int8, mha_dense_.input_scaling_factors, mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, mha_dense_.input_matrix_max_values, - mha_dense_.output_matrix_max_values, - buffer2, mha_dense_w, 64, embedding_op_size_, d_model, 1, N); + mha_dense_.output_matrix_max_values, buffer2, mha_dense_w, 64, + embedding_op_size_, d_model, 1, N); } const int num_inputs = d_model; @@ -2119,15 +2125,14 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 2. 
perform int8 GEMM (scratch -> buffer2) /* cutlassMatrixMulBTransposed( - (const int8_t*)scratch, mha_dense_.weights_int8, (int8_t*)buffer2, batch, - num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + (const int8_t*)scratch, mha_dense_.weights_int8, (int8_t*)buffer2, + batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); */ - + cutlassMatrixMulBTransposed( (const int8_t*)scratch, mha_dense_.weights_int8, - mha_dense_.output_scaling_factors, (int8_t*)buffer2, - batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); - + mha_dense_.output_scaling_factors, (int8_t*)buffer2, batch, + num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); @@ -2162,39 +2167,39 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, calibrateGemmForInt8( ffn1_.weights_int8, ffn1_.input_scaling_factors, ffn1_.output_scaling_factors, ffn1_.output_deq_factors, - ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, scratch, - ffn_dense1_w, 64, - encoder_dff, embedding_op_size_, 1, N); + ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, + scratch, ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); } const int num_inputs = embedding_op_size_; const int num_outputs = ffn_dense1_size_; // encoder_dff const int batch = N * 64; - if (true && int8_inf_) { + if (true && int8_inf_) { // 1. quantize the inputs (scratch -> in_out_tensor) // TODO: Fuse this step with LN1 (should be easy) - quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, batch, - num_inputs, ffn1_.input_scaling_factors, stream); + quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, + batch, num_inputs, ffn1_.input_scaling_factors, + stream); // 2. perform int8 GEMM (in_out_tensor -> buffer1) /* - cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn1_.weights_int8, - (int8_t*)buffer1, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, + ffn1_.weights_int8, (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, + 0, 0, 0, 1.0 / 127.0, 0.0f); */ - cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn1_.weights_int8, - ffn1_.output_scaling_factors, - (int8_t*)buffer1, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed( + (const int8_t*)in_out_tensor, ffn1_.weights_int8, + ffn1_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); /* dumpTensor(ffn1_.input_scaling_factors, 768, "input_scaling_factors - during run", false, false); - dumpTensor((const int8_t*)in_out_tensor, 768, "quantized input matrix", - false, false); + dumpTensor((const int8_t*)in_out_tensor, 768, "quantized input + matrix", false, false); dumpTensor(ffn1_.weights_int8, 768, "weights - during run", false, false); @@ -2214,9 +2219,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (const half*)ffn_dense1_b, stream, ffn_activation_); // Ankan - test! - //dumpTensor((const DataType*)in_out_tensor, 768, - // "runtime output values after bias and RELU2", false, false); - //exit(0); + // dumpTensor((const DataType*)in_out_tensor, 768, + // "runtime output values after bias and RELU2", + // false, false); + // exit(0); } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, @@ -2231,10 +2237,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_outputs, ffn_activation_, stream); // Ankan - test! 
- //dumpTensor((const DataType*)in_out_tensor, 768, + // dumpTensor((const DataType*)in_out_tensor, 768, // "Ref output values after bias and RELU2", false, // false); - //exit(0); + // exit(0); } } @@ -2244,8 +2250,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, calibrateGemmForInt8( ffn2_.weights_int8, ffn2_.input_scaling_factors, ffn2_.output_scaling_factors, ffn2_.output_deq_factors, - ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, in_out_tensor, - ffn_dense2_w, 64, embedding_op_size_, encoder_dff, 1, N); + ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, + in_out_tensor, ffn_dense2_w, 64, embedding_op_size_, encoder_dff, 1, + N); } const int num_inputs = ffn_dense1_size_; // encoder_dff @@ -2255,8 +2262,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. quantize the inputs (in_out_tensor -> buffer1) // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) - quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, batch, - num_inputs, ffn2_.input_scaling_factors, stream); + quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, + batch, num_inputs, ffn2_.input_scaling_factors, + stream); /* dumpTensor((const float*)ffn2_.input_scaling_factors, 768, "input scaling factors during run", @@ -2274,11 +2282,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (int8_t*)in_out_tensor, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); */ - - cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, ffn2_.output_scaling_factors, + + cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, + ffn2_.output_scaling_factors, (int8_t*)in_out_tensor, batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); - + ReportCUDAErrors(cudaGetLastError()); /* dumpTensor((const int8_t*)in_out_tensor, 768, @@ -2291,16 +2300,16 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this with LN2 (should be easy) deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (const int8_t*)in_out_tensor, batch, num_outputs, 1, - ffn2_.output_scaling_factors, - ffn2_.output_deq_factors, nullptr, stream); + ffn2_.output_scaling_factors, ffn2_.output_deq_factors, nullptr, + stream); /* dumpTensor((const half*)buffer1, 768, "dequantized output values", false, false); exit(0);*/ } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, - in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); + num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, + in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); /* cutlassMatrixMulBTransposed( (const half*)in_out_tensor, (const half*)ffn_dense2_w, (half*)buffer1, @@ -2473,7 +2482,6 @@ EncoderBlock::~EncoderBlock() { } } - template EmbeddingLayer::EmbeddingLayer(BaseLayer* ip, const std::vector& weights, diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 22593ccb80..92c1ff8a4d 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -64,16 +64,16 @@ static size_t getMaxAttentionHeadSize( size_t encoder_d_model = 0; size_t encoder_dff = 0; - if (vanilla.pol_encoder.size() > 0) { - encoder_d_model = vanilla.pol_encoder[0].mha.q_b.size(); - encoder_dff = vanilla.pol_encoder[0].ffn.dense1_b.size(); + if (weights.pol_encoder.size() > 0) { + encoder_d_model = weights.pol_encoder[0].mha.q_b.size(); + encoder_dff = weights.pol_encoder[0].ffn.dense1_b.size(); - 
assert(encoder_d_model == vanilla.pol_encoder[0].mha.k_b.size()); - assert(encoder_d_model == vanilla.pol_encoder[0].mha.v_b.size()); - assert(embedding_op_size == vanilla.pol_encoder[0].ffn.dense2_b.size()); + assert(encoder_d_model == weights.pol_encoder[0].mha.k_b.size()); + assert(encoder_d_model == weights.pol_encoder[0].mha.v_b.size()); + assert(embedding_op_size == weights.pol_encoder[0].ffn.dense2_b.size()); } - const size_t encoder_heads = vanilla.pol_encoder_head_count; + const size_t encoder_heads = weights.pol_encoder_head_count; size_t size = N * 64 * @@ -383,22 +383,6 @@ class CudaNetwork : public Network { ActivationFunction act = mish_net ? ACTIVATION_MISH : ACTIVATION_RELU; - std::string policy_head = - options.GetOrDefault("policy_head", "vanilla"); - // Check that selected policy head exists. - if (weights.policy_heads.count(policy_head) == 0) { - throw Exception("The policy head you specified '" + policy_head + - "' does not exist in this net."); - } - - std::string value_head = - options.GetOrDefault("value_head", "winner"); - // Check that selected value head exists. - if (weights.value_heads.count(value_head) == 0) { - throw Exception("The value head you specified '" + value_head + - "' does not exist in this net."); - } - use_int8_ = options.GetOrDefault("int8", false); int8_calibration_run_ = options.GetOrDefault("int8-calibrate", false); @@ -565,16 +549,13 @@ class CudaNetwork : public Network { : static_cast(ffn_activation); activations.default_activation = act; - auto new_encoding = - static_cast( - file.format().network_format().input_embedding()) == - InputEmbedding::INPUT_EMBEDDING_PE_DENSE; auto attention_body = std::make_unique>( weights, scratch_mem_, activations, numBlocks_, numBlocks_ > 0 ? kNumFilters : kInputPlanes, max_batch_size_, static_cast( file.format().network_format().input_embedding()) == - InputEmbedding::INPUT_EMBEDDING_PE_DENSE); + InputEmbedding::INPUT_EMBEDDING_PE_DENSE, + use_fused_mha, int8_calibration_run_, use_int8_, int8_weights_); network_.emplace_back(std::move(attention_body)); encoder_last_ = getLastLayer(); diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index e32be87ad2..dc39906fc4 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -60,16 +60,16 @@ static size_t getMaxAttentionHeadSize( size_t encoder_d_model = 0; size_t encoder_dff = 0; - if (vanilla.pol_encoder.size() > 0) { - encoder_d_model = vanilla.pol_encoder[0].mha.q_b.size(); - encoder_dff = vanilla.pol_encoder[0].ffn.dense1_b.size(); + if (weights.pol_encoder.size() > 0) { + encoder_d_model = weights.pol_encoder[0].mha.q_b.size(); + encoder_dff = weights.pol_encoder[0].ffn.dense1_b.size(); - assert(encoder_d_model == vanilla.pol_encoder[0].mha.k_b.size()); - assert(encoder_d_model == vanilla.pol_encoder[0].mha.v_b.size()); - assert(embedding_op_size == vanilla.pol_encoder[0].ffn.dense2_b.size()); + assert(encoder_d_model == weights.pol_encoder[0].mha.k_b.size()); + assert(encoder_d_model == weights.pol_encoder[0].mha.v_b.size()); + assert(embedding_op_size == weights.pol_encoder[0].ffn.dense2_b.size()); } - const size_t encoder_heads = vanilla.pol_encoder_head_count; + const size_t encoder_heads = weights.pol_encoder_head_count; size_t size = N * 64 * From 451fbf34b6738805ec611073975d3992cb34d9cc Mon Sep 17 00:00:00 2001 From: almaudoh Date: Mon, 22 Apr 2024 14:28:58 +0200 Subject: [PATCH 56/70] WIP. Reworked int8 to use scaling factors stored in weights. 
Added kernels for clipping of inputs for non-int8 inference. --- libs/lczero-common | 2 +- src/neural/cuda/cutlass_kernels.cu | 198 +++++++++++- src/neural/cuda/layers.cc | 496 ++++++++++++++++++++--------- src/neural/cuda/layers.h | 7 +- src/neural/cuda/network_cuda.cc | 80 +---- src/neural/network_legacy.cc | 15 +- src/neural/network_legacy.h | 12 + 7 files changed, 564 insertions(+), 246 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index 55e1b382ef..250a498c49 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 55e1b382efadd57903e37f2a2e29caef3ea85799 +Subproject commit 250a498c49018354cef95fcc81a42158d9cd38d1 diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 1abd7f42d8..9049b861b0 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -185,8 +185,12 @@ void dumpTensor(const T* memory, int elements, const char* message, } float maxval = -std::numeric_limits::max(); float minval = std::numeric_limits::max(); - int nans = 0; - int nanss[10]{}; + int cnans = 0; + int cnlims = 0; + int cplims = 0; + int nans[10]{}; + int nlims[10]{}; + int plims[10]{}; std::vector fpArr(elements); for (int i = 0; i < elements; i++) { @@ -206,14 +210,28 @@ void dumpTensor(const T* memory, int elements, const char* message, minval = std::min(minval, val); if (std::isnan(val)) { - if (nans < 10) nanss[nans] = i; - nans++; + if (cnans < 10) nans[cnans] = i; + cnans++; + } + if (int8) { + if (val >= 127) { + if (cplims < 10) plims[cplims] = i; + cplims++; + } else if (val <= -128) { + if (cnlims < 10) nlims[cnlims] = i; + cnlims++; + } } if (!only_summary || i < 2 || i == elements - 1) { - printf("%8.4f ", val); - if ((i % 8) == 7) printf("\n"); - // printf("%i;%.6f\n", i, val); + if (int8) { + printf("%6i ", (int8_t)val); + // printf("%i;%6i\n", i, (int8_t)val); + } else { + printf("%8.6f ", val); + // printf("%i;%8.6f\n", i, val); + } + if ((i % 8) == 7 || i == elements - 1) printf("\n"); } } if (!cpu_tensor) free(temp); @@ -224,12 +242,33 @@ void dumpTensor(const T* memory, int elements, const char* message, float avg = mean(&fpArr[0], elements); float stddev = stdDev(&fpArr[0], elements); - printf("Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f, NaNs: %i of %i", - maxval, minval, avg, stddev, nans, elements); - if (nans > 0) { + if (int8) { + printf( + "Max: %i, Min: %i, Mean: %i, StdDev: %i\n" + "NaNs: %i, HiQuantLimit: %i, LoQuantLimit: %i, Total: %i", + (int8_t)maxval, (int8_t)minval, (int8_t)avg, (int8_t)stddev, cnans, + cplims, cnlims, elements); + + } else { + printf( + "Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f\n" + "NaNs: %i of %i", + maxval, minval, avg, stddev, cnans, elements); + } + if (cnans > 0) { printf("\nNaN indices: "); - for (int i = 0; i < nans && i < 10; i++) printf("%i ", nanss[i]); - if (nans > 10) printf("......"); + for (int i = 0; i < cnans && i < 10; i++) printf("%i ", nans[i]); + if (cnans > 10) printf("......"); + } + if (cplims > 0) { + printf("\n127 indices: "); + for (int i = 0; i < cplims && i < 10; i++) printf("%i ", plims[i]); + if (cplims > 10) printf("......"); + } + if (cnlims > 0) { + printf("\n-128 indices: "); + for (int i = 0; i < cnlims && i < 10; i++) printf("%i ", nlims[i]); + if (cnlims > 10) printf("......"); } printf("\n"); } @@ -856,7 +895,7 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, copyAs(&factor[4], &scale[x + 4]); for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] * 
factor[i]); + float val = roundf((float)ip[i] / factor[i]); if (val > 127) val = 127; if (val < -128) val = -128; op[i] = (int8_t)(val); @@ -877,6 +916,70 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, ReportCUDAErrors(cudaGetLastError()); } +// Quantize matrix with single scale value +// process 8 elements per thread (in x dimension) +__global__ void quantizeMatrix(int8_t* output, const half* input, int height, + int width, const float scale) { + int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + half ip[8]; + int8_t op[8]; + + copyAs(&ip[0], &input[y * width + x]); + + for (int i = 0; i < 8; i++) { + float val = roundf((float)ip[i] / scale); + if (val > 127) val = 127; + if (val < -128) val = -128; + op[i] = (int8_t)(val); + } + + copyAs(&output[y * width + x], &op[0]); +} + +// The scale is for all columns. +void quantizeActivationMatrix(int8_t* output, const half* input, int height, + int width, const float scale, + cudaStream_t stream) { + dim3 blockDim(16, 16); + dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), + lczero::cudnn_backend::DivUp(height, 16)); + quantizeMatrix<<>>(output, input, height, width, + scale); + ReportCUDAErrors(cudaGetLastError()); +} + +// Quantize matrix with single scale value +template +__global__ void clipMatrix(T* output, const T* input, const float* factors, + int height, int width) { + int x = (blockIdx.x * blockDim.x + threadIdx.x); + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + float limit = (float)(127 * factors[x]); + float val = (float)input[y * width + x]; + if (val > limit) val = limit; + if (val < -limit) val = -limit; + output[y * width + x] = (T)val; +} + +template +void clipActivationMatrix(DataType* output, const DataType* input, + const float* factors, int height, int width, + cudaStream_t stream) { + dim3 blockDim(16, 16); + dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16), + lczero::cudnn_backend::DivUp(height, 16)); + clipMatrix + <<>>(output, input, factors, height, width); + ReportCUDAErrors(cudaGetLastError()); +} + #define MAX_BATCH_DEQUANT 16 struct ScaleParam { @@ -924,9 +1027,9 @@ __global__ void deQuantizeMatrix(half* output, const int8_t* input, // the bias is per column, per batch void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int height, int width, int batchSize, - float* invScale, float *deq, const half* bias, - cudaStream_t stream, - ActivationFunction act = ACTIVATION_NONE) { + float* invScale, float* deq, + const half* bias, ActivationFunction act, + cudaStream_t stream) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16), batchSize); @@ -944,6 +1047,63 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, ReportCUDAErrors(cudaGetLastError()); } +// process 8 elements per thread (in x dimension) +__global__ void deQuantizeMatrix(half* output, const int8_t* input, + const half* bias, int height, int width, + int stride, const float* invScale, + ActivationFunction act) { + int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; + int y = blockIdx.y * blockDim.y + threadIdx.y; + int b = blockIdx.z; + + if (x >= width || y >= height) return; + + int8_t ip[8] = {}; + half op[8] = {}; + half bi[8] = {}; + float inv_scale[8]; + + copyAs(&ip[0], &input[b * stride + y * width + x]); + if (bias) copyAs(&bi[0], &bias[b * width + x]); + 
+ if (invScale) { + copyAs(&inv_scale[0], &invScale[b * width + x]); + copyAs(&inv_scale[4], &invScale[b * width + x + 4]); + } else { + for (int i = 0; i < 8; i++) inv_scale[i] = 1 / 127.0f; + } + + for (int i = 0; i < 8; i++) { + float val = (float)ip[i]; + val *= inv_scale[i]; + if (bias) val += (float)bi[i]; + op[i] = (half)activate(val, act); + } + + copyAs(&output[b * stride + y * width + x], &op[0]); +} + +// the scale (in CPU memory) is per "batch" +// the bias is per column, per batch +void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, + int height, int width, int batchSize, + float* invScale, const half* bias, + ActivationFunction act, + cudaStream_t stream) { + dim3 blockDim(16, 16); + dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), + lczero::cudnn_backend::DivUp(height, 16), batchSize); + + // otherwise we will need to put them in GPU memory + assert(batchSize < MAX_BATCH_DEQUANT); + + int stride = width * height; + + deQuantizeMatrix<<>>( + output, input, bias, height, width, stride, invScale, act); + ReportCUDAErrors(cudaGetLastError()); +} + void fillGpuArray(float* arr, float val, int count) { thrust::device_ptr dev_ptr(arr); thrust::fill(dev_ptr, dev_ptr + count, val); @@ -960,6 +1120,12 @@ template void calibrateGemmForInt8( float* output_scaling_factors, float* output_deq_factors, float* maxValuesA, float* maxValuesOut, const half* A, const half* B, int M, int N, int K, int batchSize, int M_Batch); +template void clipActivationMatrix(float* output, const float* input, + const float* factors, int height, + int width, cudaStream_t stream); +template void clipActivationMatrix(half* output, const half* input, + const float* factors, int height, + int width, cudaStream_t stream); template void dumpTensor(const float* memory, int elements, const char* message, bool only_summary, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 1c0962649f..dd5de94c05 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -74,8 +74,10 @@ float stdDev(float arr[], int n) { template void dumpTensor(T* memory, int elements, const char* message, bool only_summary = false) { const bool fp16 = std::is_same::value; + const bool int8 = std::is_same::value; printf("\n%s\n", message); int elementSize = (int) (fp16 ? 
sizeof(half) : sizeof(float)); + if (int8) elementSize = sizeof(int8_t); int bytes = elements * elementSize; void *temp = malloc(bytes); cudaMemcpy(temp, memory, bytes, cudaMemcpyDeviceToHost); @@ -88,15 +90,15 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary for (int i = 0; i < elements; i++) { float val; - if (fp16) - { - half *arr = (half*)temp; - val = (float)arr[i]; - } - else - { - float *arr = (float *)temp; - val = arr[i]; + if (int8) { + int8_t* arr = (int8_t*)temp; + val = (float)arr[i]; + } else if (fp16) { + half* arr = (half*)temp; + val = (float)arr[i]; + } else { + float* arr = (float*)temp; + val = arr[i]; } fpArr[i] = val; maxval = std::max(maxval, val); @@ -108,9 +110,13 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary } if (!only_summary || i < 2 || i == elements - 1) { - printf("%8.4f ", val); - if ((i % 8) == 7) printf("\n"); - //printf("%i;%.6f\n", i, val); + if (int8) { + printf("%6i ", (int8_t)val); + } else { + printf("%8.6f ", val); + } + if ((i % 8) == 7 || i == elements - 1) printf("\n"); + // printf("%i;%.6f\n", i, val); } } free(temp); @@ -122,7 +128,15 @@ void dumpTensor(T* memory, int elements, const char* message, bool only_summary float avg = mean(&fpArr[0], elements); float stddev = stdDev(&fpArr[0], elements); - printf("Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f, NaNs: %i of %i", maxval, minval, avg, stddev, nans, elements); + if (int8) { + printf("Max: %i, Min: %i, Mean: %i, StdDev: %i, NaNs: %i of %i", + (int8_t)maxval, (int8_t)minval, (int8_t)avg, (int8_t)stddev, nans, + elements); + + } else { + printf("Max: %.6f, Min: %.6f, Mean: %.6f, StdDev: %.6f, NaNs: %i of %i", + maxval, minval, avg, stddev, nans, elements); + } if (nans > 0) { printf("\nNaN indices: "); for (int i = 0; i < nans && i < 10; i++) printf("%i ", nanss[i]); @@ -1514,7 +1528,7 @@ AttentionPolicyHead::AttentionPolicyHead( max_batch_size, ACTIVATION_SWISH, act_, 1e-6, // attentionbody nets don't have policy encoders, so using old // epsilon for backward compatibility with T78. 
- false, false, false, nullptr, 0); + false, false, 0); encoder_weights_.emplace_back(pW); } } @@ -1580,28 +1594,159 @@ static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, return w; } +template +void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, + float* output_scaling_factors, + float* output_deq_factors, float* maxValuesA, + float* maxValuesOut, const DataType* A, + const DataType* B, int M, int N, int K, int batchSize, + int M_Batch); + +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, + int M, int N, int K, int batchSize, + int AStride, int BStride, int OutStride, + float alphaf, float betaf); + +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, + const float* scaleVector, int8_t* Out, int M, + int N, int K, int batchSize, int AStride, + int BStride, int OutStride, int VecStride, + float alphaf, float betaf); + +void quantizeActivationMatrix(int8_t* output, const half* input, int height, + int width, const float* scale, + cudaStream_t stream); + +void quantizeActivationMatrix(int8_t* output, const half* input, int height, + int width, const float scale, + cudaStream_t stream); + +template +void clipActivationMatrix(DataType* output, const DataType* input, + const float* factors, int height, int weight, + cudaStream_t stream); + +void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, + int height, int width, int batchSize, + float* scale, float* deq, const half* bias, + ActivationFunction act, cudaStream_t stream); + +void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, + int height, int width, int batchSize, + float* scale, const half* bias, + ActivationFunction act, cudaStream_t stream); + +static void LoadQuantizationData(MatMulQuantizationData& data, + const half* weights, int input_len, + int output_len, + const std::vector& weightsFactors, + const std::vector& inputFactors, + cudaStream_t stream) { + // Load weights for INT8 inference + + // (per-column) scaling factors for the input and output. + ReportCUDAErrors( + cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); + ReportCUDAErrors( + cudaMalloc(&data.output_scaling_factors, output_len * sizeof(float))); + + if (inputFactors.size() > 1) { + // ReportCUDAErrors(cudaMemcpy( + // data.output_scaling_factors, inputFactors.data(), + // inputFactors.size() * sizeof(float), cudaMemcpyHostToDevice)); + throw Exception("Channelwise quantization not yet supported."); + } else { + // Repeatedly fill values into the input factors buffer. + fillGpuArray(data.input_scaling_factors, inputFactors[0], input_len); + + // Repeatedly fill values into the output factors buffer. + fillGpuArray(data.output_scaling_factors, + inputFactors[0] * weightsFactors[0], output_len); + } + + // Load weights and run a GPU kernel to scale it. + int weights_len = input_len * output_len; + ReportCUDAErrors( + cudaMalloc(&data.weights_int8, weights_len * sizeof(int8_t))); + quantizeActivationMatrix(data.weights_int8, weights, 1, weights_len, + weightsFactors[0], stream); +} + +static void LoadKQVQuantizationData(MatMulQuantizationData& data, + const half* kqv_weights, int input_len, + int output_len, + const std::vector& kWeightsFactors, + const std::vector& qWeightsFactors, + const std::vector& vWeightsFactors, + const std::vector& inputFactors, + cudaStream_t stream) { + // Load weights for INT8 inference. + + // (per-column) scaling factors for the input and output. 
+ ReportCUDAErrors( + cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); + ReportCUDAErrors( + cudaMalloc(&data.output_scaling_factors, output_len * 3 * sizeof(float))); + + if (inputFactors.size() > 1) { + // ReportCUDAErrors(cudaMemcpy( + // data.output_scaling_factors, inputFactors.data(), + // inputFactors.size() * sizeof(float), cudaMemcpyHostToDevice)); + throw Exception("Channelwise quantization not yet supported."); + } else { + // Repeatedly fill values into the input factors buffer. + fillGpuArray(data.input_scaling_factors, inputFactors[0], input_len); + + // Repeatedly fill values into the output factors buffer. + fillGpuArray(data.output_scaling_factors, + inputFactors[0] * kWeightsFactors[0], output_len); + fillGpuArray(data.output_scaling_factors + output_len, + inputFactors[0] * qWeightsFactors[0], output_len); + fillGpuArray(data.output_scaling_factors + output_len * 2, + inputFactors[0] * vWeightsFactors[0], output_len); + } + + // Load KQV weights and run a GPU kernel to scale them. + int weights_len = input_len * output_len; + ReportCUDAErrors( + cudaMalloc(&data.weights_int8, weights_len * 3 * sizeof(int8_t))); + quantizeActivationMatrix(data.weights_int8, kqv_weights, 1, weights_len, + kWeightsFactors[0], stream); + quantizeActivationMatrix(data.weights_int8 + weights_len, + kqv_weights + weights_len, 1, weights_len, + qWeightsFactors[0], stream); + quantizeActivationMatrix(data.weights_int8 + weights_len * 2, + kqv_weights + weights_len * 2, 1, weights_len, + vWeightsFactors[0], stream); +} + template EncoderBlock::EncoderBlock( const MultiHeadWeights::EncoderLayer& cpu_weights, void* scratch, int heads, int size, float alpha, DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, ActivationFunction ffn_act, float default_eps, bool fused_mha, - bool int8_calibrate, bool int8_inference, void* int8_weights, - int blockIndex) + bool int8_inference, int blockIndex) : embedding_op_size_(size), encoder_heads_(heads), alpha_(alpha), use_fused_mha_(fused_mha), int8_inf_(int8_inference), - int8_cali_(int8_calibrate), has_smolgen_(cpu_weights.mha.has_smolgen), smolgen_activation_(smolgen_act), ffn_activation_(ffn_act), max_batch_size_(max_batch_size), - default_eps_(default_eps) { - mha_q_size_ = cpu_weights.mha.q_b.size(); - mha_k_size_ = cpu_weights.mha.k_b.size(); - mha_v_size_ = cpu_weights.mha.v_b.size(); + default_eps_(default_eps), + is_quantized_(cpu_weights.mha.s1.size() > 0) { + mha_q_size_ = cpu_weights.mha.q_b.size() > 0 + ? cpu_weights.mha.q_b.size() + : cpu_weights.mha.q_w.size() / size; + mha_k_size_ = cpu_weights.mha.k_b.size() > 0 + ? cpu_weights.mha.k_b.size() + : cpu_weights.mha.k_w.size() / size; + mha_v_size_ = cpu_weights.mha.v_b.size() > 0 + ? cpu_weights.mha.v_b.size() + : cpu_weights.mha.v_w.size() / size; mha_dense_size_ = cpu_weights.mha.dense_b.size(); ffn_dense1_size_ = cpu_weights.ffn.dense1_b.size(); ffn_dense2_size_ = cpu_weights.ffn.dense2_b.size(); @@ -1652,6 +1797,8 @@ EncoderBlock::EncoderBlock( allocAndUpload(&ln2_gammas, cpu_weights.ln2_gammas, scratch); allocAndUpload(&ln2_betas, cpu_weights.ln2_betas, scratch); + // printf("mhaqb: %i, mhaqw: %i, embedsize: %i\n", cpu_weights.mha.q_b.size(), + // cpu_weights.mha.q_w.size(), size); // Smolgen weights. 
if (has_smolgen_) { @@ -1685,7 +1832,7 @@ EncoderBlock::EncoderBlock( } // int8 stuff blockIndex_ = blockIndex; - if (int8_inference || int8_calibrate) { + if (int8_inf_ || is_quantized_) { /* int per_encoder_size = embedding_op_size_ * sizeof(float) + 3 * embedding_op_size_ * mha_q_size_ + @@ -1697,23 +1844,46 @@ EncoderBlock::EncoderBlock( embedding_op_size_ * ffn_dense1_size_ + sizeof(float); */ - int embedding_op_size = embedding_op_size_; - int encoder_d_model = mha_q_size_; - int encoder_dff = ffn_dense1_size_; - int per_encoder_size = - (embedding_op_size * sizeof(float) + 3 * embedding_op_size * encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + - encoder_d_model * sizeof(float) + encoder_d_model * embedding_op_size + (embedding_op_size + 1) * sizeof(float) + - embedding_op_size * sizeof(float) + embedding_op_size * encoder_dff + (encoder_dff + 1) * sizeof(float) + - encoder_dff * sizeof(float) + encoder_dff * embedding_op_size + (embedding_op_size + 1) * sizeof(float)); - - auto w = (int8_t*)int8_weights; - // go to current encoder block - w += per_encoder_size * blockIndex; - w = SetQuantizationData(kqv_, w, embedding_op_size_, mha_q_size_, 3, int8_calibrate); - w = SetQuantizationData(mha_dense_, w, mha_q_size_, embedding_op_size_, 1, int8_calibrate); - w = SetQuantizationData(ffn1_, w, embedding_op_size_, ffn_dense1_size_, 1, int8_calibrate); - w = SetQuantizationData(ffn2_, w, ffn_dense1_size_, embedding_op_size_, 1, int8_calibrate); - // printf("\nSize of weights: %d\n", (w - (int8_t*)int8_weights)); + /** + int embedding_op_size = embedding_op_size_; + int encoder_d_model = mha_q_size_; + int encoder_dff = ffn_dense1_size_; + int per_encoder_size = + (embedding_op_size * sizeof(float) + 3 * embedding_op_size * + encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + + encoder_d_model * sizeof(float) + encoder_d_model * + embedding_op_size + (embedding_op_size + 1) * sizeof(float) + + embedding_op_size * sizeof(float) + embedding_op_size * + encoder_dff + (encoder_dff + 1) * sizeof(float) + + encoder_dff * sizeof(float) + encoder_dff * + embedding_op_size + (embedding_op_size + 1) * sizeof(float)); + + auto w = (int8_t*)int8_weights; + // go to current encoder block + w += per_encoder_size * blockIndex; + w = SetQuantizationData(kqv_, w, embedding_op_size_, mha_q_size_, 3, + int8_calibrate); w = SetQuantizationData(mha_dense_, w, mha_q_size_, + embedding_op_size_, 1, int8_calibrate); w = SetQuantizationData(ffn1_, w, + embedding_op_size_, ffn_dense1_size_, 1, int8_calibrate); w = + SetQuantizationData(ffn2_, w, ffn_dense1_size_, embedding_op_size_, 1, + int8_calibrate); + // printf("\nSize of weights: %d\n", (w - (int8_t*)int8_weights)); + */ + + CERR << "QKV input factor: " << cpu_weights.mha.s1[0]; + LoadKQVQuantizationData(kqv_, (half*)mha_qkv_w, embedding_op_size_, + mha_q_size_, cpu_weights.mha.k_s, + cpu_weights.mha.q_s, cpu_weights.mha.v_s, + cpu_weights.mha.s1, 0); + LoadQuantizationData(mha_dense_, (half*)mha_dense_w, embedding_op_size_, + mha_dense_size_, cpu_weights.mha.dense_s, + cpu_weights.mha.s2, 0); + LoadQuantizationData(ffn1_, (half*)ffn_dense1_w, embedding_op_size_, + ffn_dense1_size_, cpu_weights.ffn.dense1_s, + cpu_weights.ffn.s1, 0); + LoadQuantizationData(ffn2_, (half*)ffn_dense2_w, ffn_dense1_size_, + ffn_dense2_size_, cpu_weights.ffn.dense2_s, + cpu_weights.ffn.s2, 0); // print some weights /* @@ -1786,35 +1956,6 @@ static void cublasXGemmBatched(cublasHandle_t handle, cublasOperation_t transa, } } -template -void 
calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, - float* output_deq_factors, float* maxValuesA, - float* maxValuesOut, const DataType* A, - const DataType* B, int M, int N, int K, int batchSize, - int M_Batch); - -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, - int M, int N, int K, int batchSize, - int AStride, int BStride, int OutStride, - float alphaf, float betaf); - -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, int8_t* Out, int M, - int N, int K, int batchSize, int AStride, - int BStride, int OutStride, int VecStride, - float alphaf, float betaf); - -void quantizeActivationMatrix(int8_t* output, const half* input, int height, - int width, const float* scale, - cudaStream_t stream); - -void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, - int height, int width, int batchSize, - float* scale, float* deq, const half* bias, - cudaStream_t stream, - ActivationFunction act = ACTIVATION_NONE); - // input/output tensor is in_out_tensor, others are used as scratch. template void EncoderBlock::Eval(int N, DataType* in_out_tensor, @@ -1912,29 +2053,35 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, mha_k = mha_q + num_outputs * batch_to_use; mha_v = mha_k + num_outputs * batch_to_use; - if (int8_cali_) { - calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, - kqv_.output_scaling_factors, kqv_.output_deq_factors, - kqv_.input_matrix_max_values, - kqv_.output_matrix_max_values, in_out_tensor, - mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); - } + // if (int8_cali_) { + // calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, + // kqv_.output_scaling_factors, + // kqv_.output_deq_factors, + // kqv_.input_matrix_max_values, + // kqv_.output_matrix_max_values, in_out_tensor, + // mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); + // } - if (true && int8_inf_) { + if (is_quantized_ && int8_inf_) { // printf("\nAttempting int8_inf\n"); // 1. quantize the inputs (in_out_tensor -> scratch) // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, - kqv_.input_scaling_factors, stream); + kqv_.input_scaling_factors, 0); + dumpTensor((int8_t*)scratch, num_inputs * batch / 4, + "encoder 1 qkv input quantized", true); + dumpTensor(kqv_.weights_int8, num_inputs * num_outputs * 3, + "encoder 1 qkv weights quantized", true); // 2. 
perform int8 GEMM (scratch -> buffer1) - /* - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, kqv_.weights_int8, (int8_t*)buffer1, batch, - num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, 1.0 / 127.0, 0.0f); - */ + + cutlassMatrixMulBTransposed((const int8_t*)scratch, kqv_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, 1.0, 0.0f); + dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, + "encoder 1 qkv output quantized", true); // per-layer output scaling cutlassMatrixMulBTransposed( @@ -1943,6 +2090,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); + dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, + "encoder 1 qkv output quantized-descaled", true); + exit(0); + ReportCUDAErrors(cudaGetLastError()); /* dumpTensor((const int8_t*)scratch, 768, "quantized input matrix", @@ -1960,17 +2111,41 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, */ // 3. de-quantize outputs - fused with bias add (buffer1 -> scratch) // TODO: fuse the entire thing with the above GEMM. + // deQuantizeOutputMatrixBiasAdd( + // (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, + // kqv_.output_scaling_factors, kqv_.output_deq_factors, + // (const half*)mha_qkv_b, ACTIVATION_NONE, stream); + deQuantizeOutputMatrixBiasAdd( (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, - kqv_.output_scaling_factors, kqv_.output_deq_factors, - (const half*)mha_qkv_b, stream); + kqv_.output_scaling_factors, (const half*)/*mha_qkv_b*/ nullptr, + ACTIVATION_NONE, stream); + + // dumpTensor((int8_t*)scratch, num_outputs * batch / 4, + // "encoder 1 qkv output no bias", true); + + dumpTensor((half*)scratch, num_outputs * batch / 4, + "encoder 1 qkv output dequant"); + + dumpTensor((float*)kqv_.input_scaling_factors, num_inputs, + "encoder 1 qkv input factors"); + + dumpTensor((float*)kqv_.output_scaling_factors, num_outputs * 3, + "encoder 1 qkv output factors"); + exit(0); /* dumpTensor((const half*)scratch, 768, - "dequantized output values after bias add", false, false); - exit(0); + "dequantized output values after bias add", false, + false); exit(0); */ } else { + if (is_quantized_) { + clipActivationMatrix( + (DataType*)in_out_tensor, (const DataType*)in_out_tensor, + kqv_.input_scaling_factors, batch, num_inputs, stream); + } + cublasXGemmStridedBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_qkv_w, num_inputs, num_inputs * num_outputs, in_out_tensor, @@ -1981,15 +2156,23 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (half*)mha_q, batch_to_use, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, true); - +#endif + dumpTensor((DataType*)mha_q, num_outputs * batch / 4, + "encoder 1 qkv output no bias", false); + exit(0); + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, - batch_to_use, NONE, stream); + batch_to_use, ACTIVATION_NONE, stream); - dumpTensor((const DataType*)mha_q, - /*num_outputs * batch_to_use*/ 768, - "ref output values after bias add", false, false); + dumpTensor((DataType*)mha_q, num_outputs * batch / 4, + "encoder 1 qkv output", true); + // exit(0); + + // dumpTensor((const DataType*)mha_q, + // /*num_outputs * batch_to_use*/ + // 768, "ref output values after + // bias add", false, false); exit(0); -#endif } } @@ -2102,20 +2285,20 @@ void 
EncoderBlock::Eval(int N, DataType* in_out_tensor, // #final dense layer (mha_dense), buffer2 -> buffer1 { - if (int8_cali_) { - calibrateGemmForInt8( - mha_dense_.weights_int8, mha_dense_.input_scaling_factors, - mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, - mha_dense_.input_matrix_max_values, - mha_dense_.output_matrix_max_values, buffer2, mha_dense_w, 64, - embedding_op_size_, d_model, 1, N); - } + // if (int8_cali_) { + // calibrateGemmForInt8( + // mha_dense_.weights_int8, mha_dense_.input_scaling_factors, + // mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, + // mha_dense_.input_matrix_max_values, + // mha_dense_.output_matrix_max_values, buffer2, mha_dense_w, 64, + // embedding_op_size_, d_model, 1, N); + // } const int num_inputs = d_model; const int num_outputs = embedding_op_size_; const int batch = N * 64; - if (true && int8_inf_) { + if (is_quantized_ && int8_inf_) { // 1. quantize the inputs (buffer2 -> scratch) // TODO: Fuse this step with the previous fused MHA quantizeActivationMatrix((int8_t*)scratch, (const half*)buffer2, batch, @@ -2123,25 +2306,35 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, stream); // 2. perform int8 GEMM (scratch -> buffer2) - /* cutlassMatrixMulBTransposed( (const int8_t*)scratch, mha_dense_.weights_int8, (int8_t*)buffer2, - batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); - */ + batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); + /* cutlassMatrixMulBTransposed( (const int8_t*)scratch, mha_dense_.weights_int8, mha_dense_.output_scaling_factors, (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + */ ReportCUDAErrors(cudaGetLastError()); // 3. de-quantize outputs (buffer2 -> buffer1) // TODO: Fuse this with LN1 (should be easy!) + // deQuantizeOutputMatrixBiasAdd( + // (half*)buffer1, (const int8_t*)buffer2, batch, num_outputs, 1, + // mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, + // nullptr, ACTIVATION_NONE, stream); + deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (const int8_t*)buffer2, batch, num_outputs, 1, - mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, - nullptr, stream); + mha_dense_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); + + // dequantizeWithLayerNorm(N * 64, embedding_op_size_, scratch, + // buffer1, mha_dense_b, + // in_out_tensor, ln1_gammas, ln1_betas, default_eps_, + // alpha_, ACTIVATION_NONE, stream); + } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_dense_w, num_inputs, buffer2, @@ -2163,19 +2356,19 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // #FFN dense 1, scratch -> in_out_tensor const int encoder_dff = ffn_dense1_size_; { - if (int8_cali_) { - calibrateGemmForInt8( - ffn1_.weights_int8, ffn1_.input_scaling_factors, - ffn1_.output_scaling_factors, ffn1_.output_deq_factors, - ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, - scratch, ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); - } + // if (int8_cali_) { + // calibrateGemmForInt8( + // ffn1_.weights_int8, ffn1_.input_scaling_factors, + // ffn1_.output_scaling_factors, ffn1_.output_deq_factors, + // ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, + // scratch, ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); + // } const int num_inputs = embedding_op_size_; const int num_outputs = ffn_dense1_size_; // encoder_dff const int batch = N * 64; - if (true && int8_inf_) { + if (is_quantized_ && int8_inf_) { // 1. 
quantize the inputs (scratch -> in_out_tensor) // TODO: Fuse this step with LN1 (should be easy) quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, @@ -2189,9 +2382,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, 0, 0, 0, 1.0 / 127.0, 0.0f); */ cutlassMatrixMulBTransposed( - (const int8_t*)in_out_tensor, ffn1_.weights_int8, - ffn1_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + (const int8_t*)in_out_tensor, ffn1_.weights_int8, (int8_t*)buffer1, + batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); /* @@ -2215,8 +2407,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this with the above GEMM deQuantizeOutputMatrixBiasAdd( (half*)in_out_tensor, (const int8_t*)buffer1, batch, num_outputs, 1, - ffn1_.output_scaling_factors, ffn1_.output_deq_factors, - (const half*)ffn_dense1_b, stream, ffn_activation_); + ffn1_.output_scaling_factors, (const half*)ffn_dense1_b, + ffn_activation_, stream); // Ankan - test! // dumpTensor((const DataType*)in_out_tensor, 768, @@ -2246,14 +2438,14 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // #FFN dense 2, in_out_tensor -> buffer1 { - if (int8_cali_) { - calibrateGemmForInt8( - ffn2_.weights_int8, ffn2_.input_scaling_factors, - ffn2_.output_scaling_factors, ffn2_.output_deq_factors, - ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, - in_out_tensor, ffn_dense2_w, 64, embedding_op_size_, encoder_dff, 1, - N); - } + // if (int8_cali_) { + // calibrateGemmForInt8( + // ffn2_.weights_int8, ffn2_.input_scaling_factors, + // ffn2_.output_scaling_factors, ffn2_.output_deq_factors, + // ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, + // in_out_tensor, ffn_dense2_w, 64, embedding_op_size_, encoder_dff, + // 1, N); + // } const int num_inputs = ffn_dense1_size_; // encoder_dff const int num_outputs = embedding_op_size_; @@ -2284,9 +2476,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, */ cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, - ffn2_.output_scaling_factors, (int8_t*)in_out_tensor, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); + num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); ReportCUDAErrors(cudaGetLastError()); /* @@ -2300,8 +2491,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this with LN2 (should be easy) deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (const int8_t*)in_out_tensor, batch, num_outputs, 1, - ffn2_.output_scaling_factors, ffn2_.output_deq_factors, nullptr, - stream); + ffn2_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); /* dumpTensor((const half*)buffer1, 768, "dequantized output values", false, false); @@ -2470,16 +2660,17 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(ffn2_.weights_int8)); ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); - } else if (int8_cali_) { - free(kqv_.input_matrix_max_values); - free(kqv_.output_matrix_max_values); - free(mha_dense_.input_matrix_max_values); - free(mha_dense_.output_matrix_max_values); - free(ffn1_.input_matrix_max_values); - free(ffn1_.output_matrix_max_values); - free(ffn2_.input_matrix_max_values); - free(ffn2_.output_matrix_max_values); } + // else if (int8_cali_) { + // free(kqv_.input_matrix_max_values); + // free(kqv_.output_matrix_max_values); + // free(mha_dense_.input_matrix_max_values); + // free(mha_dense_.output_matrix_max_values); 
+ // free(ffn1_.input_matrix_max_values); + // free(ffn1_.output_matrix_max_values); + // free(ffn2_.input_matrix_max_values); + // free(ffn2_.output_matrix_max_values); + // } } template @@ -2518,8 +2709,7 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, int num_res_blocks, int input_c, int max_batch_size, bool is_pe_dense_embedding, - bool fused_mha, bool int8_calibrate, - bool int8_inference, void* int8_weights) + bool fused_mha, bool int8_inference) : BaseLayer(weights.ip_emb_b.size(), 8, 8, nullptr), embedding_op_size_(weights.ip_emb_b.size()), encoder_head_count_(weights.encoder_head_count), @@ -2567,6 +2757,13 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, cudaMemcpy(scratch, kPosEncoding, size, cudaMemcpyHostToDevice)); copyTypeConverted(pos_encoding_, (float*)scratch, size, 0); } + printf("Is PE Dense Embedding %i\n", is_pe_dense_embedding); + + printf("has_gating: %i, has_smolgen: %i\n", has_gating_, has_smolgen_); + printf("ip_mult_gate: %i\n", weights.ip_mult_gate.size()); + // for (auto i=0; i(&ip_mult_gate_, weights.ip_mult_gate, scratch); @@ -2581,13 +2778,14 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, int num_encoders = weights.encoder.size(); float alpha = (float)pow(2.0 * num_encoders, -0.25); int index = 0; + for (const auto& enc : weights.encoder) { EncoderBlock* pW = new EncoderBlock( enc, scratch, encoder_head_count_, embedding_op_size_, alpha, smolgen_global_, smolgen_global_size_, max_batch_size, activations_.smolgen_activation, activations_.ffn_activation, - is_pe_dense_embedding_ ? 1e-3 : 1e-6, use_fused_mha_, int8_calibrate, - int8_inference, int8_weights, index++); + is_pe_dense_embedding_ ? 1e-3 : 1e-6, use_fused_mha_, int8_inference, + index++); encoder_weights_.emplace_back(pW); } } @@ -2766,9 +2964,11 @@ void AttentionBody::Eval(int N, DataType* output, } // 2. 
Encoder blocks - for (const auto pEnc : encoder_weights_) { - pEnc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, - stream, offset_pointers); + for (const auto enc : encoder_weights_) { + enc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, + stream, offset_pointers); + dumpTensor(output_tensor, embedding_op_size_, "encoder 1 output"); + break; } // End of encoder blocks } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 50f5312c6c..27d33f8faa 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -355,8 +355,7 @@ class EncoderBlock { DataType* smolgen_global_scratch, int smolgen_global_size, int max_batch_size, ActivationFunction smolgen_act, ActivationFunction ffn_act, float default_eps, bool fused_mha, - bool int8_calibrate, bool int8_inference, void* int8_weights, - int blockIndex); + bool int8_inference, int blockIndex); ~EncoderBlock(); void Eval(int N, DataType* inpop, DataType* scratch0, DataType* scratch1, @@ -387,7 +386,7 @@ class EncoderBlock { // int 8 stuff int blockIndex_; - bool int8_inf_, int8_cali_; + bool int8_inf_, is_quantized_; MatMulQuantizationData kqv_; MatMulQuantizationData mha_dense_; MatMulQuantizationData ffn1_; @@ -503,7 +502,7 @@ class AttentionBody : public BaseLayer { AttentionBody(const MultiHeadWeights& weights, void* scratch, Activations activations, int num_res_blocks, int input_c, int max_batch_size, bool is_pe_dense_embedding, bool fused_mha, - bool int8_calibrate, bool int8_inference, void* int8_weights); + bool int8_inference); ~AttentionBody(); void Eval(int N, DataType* output, const DataType* input, const DataType* input2, void* scratch, size_t scratch_size, diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 92c1ff8a4d..7cfca11d9d 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -383,70 +383,17 @@ class CudaNetwork : public Network { ActivationFunction act = mish_net ? ACTIVATION_MISH : ACTIVATION_RELU; + // @todo can we auto-detect by checking for scaling factors? + // @todo or otherwise assert that weights file isn't quantized. 
use_int8_ = options.GetOrDefault("int8", false); - int8_calibration_run_ = options.GetOrDefault("int8-calibrate", false); - if (int8_calibration_run_ || use_int8_) { - if (!fp16 && use_int8_) + if (use_int8_) { + if (!fp16) throw Exception("INT8 is supported only with cuda-fp16 backend."); if (!attn_body_) throw Exception("INT8 only supported for attention body networks"); - - // Structure of the weights file: - // For each encoder block - - // * per-channel scaling factors for Input Matrix to QKV GEMM (embedding_op_size floats) - // (to use for quantization of the input) - // * qunatized (int8) weights for QKV GEMMs (3 * encoder_d_model * embedding_op_size int8_ts) - // * per-channel scaling factors for quantizing the Outut matrix (encoder_d_model * 3 floats) - // * per-tensor output dequantization factors (3 floats) - // - // * per-channel scaling factors for the MHA dense layer's input (encoder_d_model floats) - // * Qunatized (int8) weights for MHA dense (embedding_op_size * encoder_d_model int8_ts) - // * per-channel output scaling factors for MHA dense (embedding_op_size floats) - // * per-tensor output dequantization factor (1 float) - // - // * per-channel scaling factors for input to FFN1 (embedding_op_size_ floats) - // * Qunatized (int8) weights for FFN1 (encoder_dff * encoder_d_model int8_ts) - // * per-channel output scaling factors for FFN1 (encoder_dff floats) - // * per-tensor output dequantization factor (1 float) - // - // * per-channel scaling factors for input to FFN2 (encoder_dff floats) - // * Qunatized (int8) weights for FFN2 (embedding_op_size * encoder_dff int8_ts) - // * per-channel output scaling factors for FFN2 (embedding_op_size floats) - // * per-tensor output dequantization factor (1 float) - int embedding_op_size = weights.ip_emb_b.size(); - int encoder_d_model = weights.encoder[0].mha.q_b.size(); - int encoder_dff = weights.encoder[0].ffn.dense1_b.size(); - int num_encoders = weights.encoder.size(); - int8_weights_size_ = - num_encoders * - (embedding_op_size * sizeof(float) + 3 * embedding_op_size * encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + - encoder_d_model * sizeof(float) + encoder_d_model * embedding_op_size + (embedding_op_size + 1) * sizeof(float) + - embedding_op_size * sizeof(float) + embedding_op_size * encoder_dff + (encoder_dff + 1) * sizeof(float) + - encoder_dff * sizeof(float) + encoder_dff * embedding_op_size + (embedding_op_size + 1) * sizeof(float)); - - int8_weights_ = malloc(int8_weights_size_); - memset(int8_weights_, 0, int8_weights_size_); - - printf("\nint8_weights_size: %d\n", int8_weights_size_); } - if (int8_calibration_run_) { - // we will write the file at the time of exit. - } else if (use_int8_) { - FILE* fp = fopen("weights_quant.bin", "rb"); - if (!fp) { - CERR << "ERROR: weights_quant.bin not found. Please run 'lc0 benchmark " - "-t 1 --nodes=1 -w --backend=cuda " - "--backend-opts=int8-calibrate=true' first"; - throw Exception("Quantized weights not found"); - } else { - int read = fread(int8_weights_, 1, int8_weights_size_, fp); - fclose(fp); - if (read != int8_weights_size_) - throw Exception( - "Quantized weights likely corrupted or of different network"); - #if 0 // Ankan - test: dump some weights here float* data = (float*)int8_weights_; @@ -466,9 +413,6 @@ class CudaNetwork : public Network { exit(0); #endif - } - } - // 2. Build the network, and copy the weights to GPU memory. 
// Input conv only used if there are residual blocks in the network @@ -555,7 +499,7 @@ class CudaNetwork : public Network { static_cast( file.format().network_format().input_embedding()) == InputEmbedding::INPUT_EMBEDDING_PE_DENSE, - use_fused_mha, int8_calibration_run_, use_int8_, int8_weights_); + use_fused_mha, use_int8_); network_.emplace_back(std::move(attention_body)); encoder_last_ = getLastLayer(); @@ -978,16 +922,6 @@ class CudaNetwork : public Network { ReportCUDAErrors(cudaFree(head_offset_pointers_)); cublasDestroy(cublas_); } - - if (int8_calibration_run_) { - // write the calibration data/weights to file - FILE* fp = fopen("weights_quant.bin", "wb+"); - fwrite(int8_weights_, 1, int8_weights_size_, fp); - fclose(fp); - } - if (int8_calibration_run_ || use_int8_) - free(int8_weights_); - } const NetworkCapabilities& GetCapabilities() const override { @@ -1049,7 +983,6 @@ class CudaNetwork : public Network { bool multi_stream_; // run multiple parallel network evals bool allow_cache_opt_; // try to fit residual block activations in L2 cache bool use_int8_; // try to use INT8 (works only with cuda-fp16 backend) - bool int8_calibration_run_; // this is a calibration run to figure out quantization factors // Currently only one NN Eval can happen a time (we can fix this if needed // by allocating more memory). @@ -1086,9 +1019,6 @@ class CudaNetwork : public Network { mutable std::mutex inputs_outputs_lock_; std::list> free_inputs_outputs_; - void* int8_weights_; // loaded from disk / to be stored to disk - int int8_weights_size_; - void showInfo() const { int version; int ret = cudaRuntimeGetVersion(&version); diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index 53846353c6..61f126f862 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -142,13 +142,24 @@ BaseWeights::MHA::MHA(const pblczero::Weights::MHA& mha) dense_w(LayerAdapter(mha.dense_w()).as_vector()), dense_b(LayerAdapter(mha.dense_b()).as_vector()), smolgen(Smolgen(mha.smolgen())), - has_smolgen(mha.has_smolgen()) {} + has_smolgen(mha.has_smolgen()), + q_s(LayerAdapter(mha.q_s()).as_vector()), + k_s(LayerAdapter(mha.k_s()).as_vector()), + v_s(LayerAdapter(mha.v_s()).as_vector()), + s1(LayerAdapter(mha.s1()).as_vector()), + s2(LayerAdapter(mha.s2()).as_vector()), + dense_s(LayerAdapter(mha.dense_s()).as_vector()), + has_int8(mha.has_s1()) {} BaseWeights::FFN::FFN(const pblczero::Weights::FFN& ffn) : dense1_w(LayerAdapter(ffn.dense1_w()).as_vector()), dense1_b(LayerAdapter(ffn.dense1_b()).as_vector()), dense2_w(LayerAdapter(ffn.dense2_w()).as_vector()), - dense2_b(LayerAdapter(ffn.dense2_b()).as_vector()) {} + dense2_b(LayerAdapter(ffn.dense2_b()).as_vector()), + s1(LayerAdapter(ffn.s1()).as_vector()), + s2(LayerAdapter(ffn.s2()).as_vector()), + dense1_s(LayerAdapter(ffn.dense1_s()).as_vector()), + dense2_s(LayerAdapter(ffn.dense2_s()).as_vector()) {} BaseWeights::EncoderLayer::EncoderLayer( const pblczero::Weights::EncoderLayer& encoder) diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 72ce67544f..4c443a95ef 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -81,6 +81,13 @@ struct BaseWeights { Vec dense_b; Smolgen smolgen; bool has_smolgen; + bool has_int8; + Vec q_s; + Vec k_s; + Vec v_s; + Vec s1; + Vec s2; + Vec dense_s; }; struct FFN { @@ -89,6 +96,11 @@ struct BaseWeights { Vec dense1_b; Vec dense2_w; Vec dense2_b; + bool has_int8; + Vec s1; + Vec s2; + Vec dense1_s; + Vec dense2_s; }; struct EncoderLayer { 
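[Editor's note, illustrative sketch - not part of the patch.] The commit above moves to per-tensor quantization scales stored in the weights file (the new mha q_s/k_s/v_s/s1/s2/dense_s and ffn s1/s2/dense1_s/dense2_s fields) and quantizes activations on the fly. The minimal host-side C++ sketch below shows the arithmetic the new kernels implement, assuming a single input scale s_in and weight scale s_w (all names and numbers here are made up): quantize with round(x / scale) clamped to [-128, 127] as in quantizeMatrix, accumulate the int8 GEMM in int32, then dequantize with the combined factor s_in * s_w before the bias add as in deQuantizeOutputMatrixBiasAdd. The real code additionally squeezes the accumulator back into an int8 tensor between the GEMM and the dequantize step; the accumulator rescale factor introduced in the next commit exists to keep that intermediate in range.

// Illustrative reference for the per-tensor int8 scheme above (not part of the patch).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static int8_t Quantize(float x, float scale) {
  float q = std::round(x / scale);                       // same rounding as quantizeMatrix
  return (int8_t)std::max(-128.0f, std::min(127.0f, q)); // clamp to int8 range
}

int main() {
  const int M = 2, N = 3, K = 4;
  std::vector<float> A(M * K, 0.5f), B(N * K, 0.25f), bias(N, 0.1f);
  const float s_in = 0.01f, s_w = 0.005f;  // per-tensor scales, as stored in the weights file

  std::vector<int8_t> qA(M * K), qB(N * K);
  for (int i = 0; i < M * K; i++) qA[i] = Quantize(A[i], s_in);
  for (int i = 0; i < N * K; i++) qB[i] = Quantize(B[i], s_w);

  // int8 GEMM with int32 accumulation; B is [N x K], i.e. used transposed.
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      int32_t acc = 0;
      for (int k = 0; k < K; k++)
        acc += (int32_t)qA[m * K + k] * (int32_t)qB[n * K + k];
      // Dequantize with the combined factor and add the bias.
      float y = (float)acc * (s_in * s_w) + bias[n];
      printf("%8.5f ", y);  // expect 0.60000 for this toy data
    }
  }
  printf("\n");
  return 0;
}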
From 2682eefe77674f6dd7bff8bb90db6a4ccad45f33 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Wed, 1 May 2024 22:37:31 +0200 Subject: [PATCH 57/70] Add additional scaling factor for matmul accumulator. Rename variables. Fix bugs. --- src/neural/cuda/cutlass_kernels.cu | 12 +- src/neural/cuda/layers.cc | 291 +++++++++++++++++------------ src/neural/cuda/layers.h | 3 +- 3 files changed, 180 insertions(+), 126 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 9049b861b0..3d35abb75c 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -223,15 +223,15 @@ void dumpTensor(const T* memory, int elements, const char* message, } } - if (!only_summary || i < 2 || i == elements - 1) { + if (!only_summary || i < 3 || i == elements - 1) { if (int8) { - printf("%6i ", (int8_t)val); - // printf("%i;%6i\n", i, (int8_t)val); + // printf("%6i ", (int8_t)val); + printf("%i;%6i\n", i, (int8_t)val); } else { - printf("%8.6f ", val); - // printf("%i;%8.6f\n", i, val); + // printf("%8.6f ", val); + printf("%i;%8.6f\n", i, val); } - if ((i % 8) == 7 || i == elements - 1) printf("\n"); + // if ((i % 8) == 7 || i == elements - 1) printf("\n"); } } if (!cpu_tensor) free(temp); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index dd5de94c05..f4c76515fa 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -153,6 +153,8 @@ namespace cudnn_backend { // than using multiple passes. The flag can be set to false for debugging. static constexpr bool kUseFusedSELayer = true; +static constexpr bool clipInputActivations = true; + template BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc) : input_(ip), C(c), H(h), W(w), nhwc_(nhwc), use_gemm_ex_(false) {} @@ -1639,8 +1641,9 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, static void LoadQuantizationData(MatMulQuantizationData& data, const half* weights, int input_len, int output_len, - const std::vector& weightsFactors, - const std::vector& inputFactors, + const std::vector& weight_factors, + const std::vector& input_factors, + const float rescale_factor, cudaStream_t stream) { // Load weights for INT8 inference @@ -1650,18 +1653,20 @@ static void LoadQuantizationData(MatMulQuantizationData& data, ReportCUDAErrors( cudaMalloc(&data.output_scaling_factors, output_len * sizeof(float))); - if (inputFactors.size() > 1) { + if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( - // data.output_scaling_factors, inputFactors.data(), - // inputFactors.size() * sizeof(float), cudaMemcpyHostToDevice)); + // data.output_scaling_factors, input_factors.data(), + // input_factors.size() * sizeof(float), cudaMemcpyHostToDevice)); throw Exception("Channelwise quantization not yet supported."); } else { // Repeatedly fill values into the input factors buffer. - fillGpuArray(data.input_scaling_factors, inputFactors[0], input_len); + fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); // Repeatedly fill values into the output factors buffer. + data.accum_rescale_factor = 1.0 / rescale_factor; fillGpuArray(data.output_scaling_factors, - inputFactors[0] * weightsFactors[0], output_len); + input_factors[0] * weight_factors[0] * rescale_factor, + output_len); } // Load weights and run a GPU kernel to scale it. 
@@ -1669,16 +1674,17 @@ static void LoadQuantizationData(MatMulQuantizationData& data, ReportCUDAErrors( cudaMalloc(&data.weights_int8, weights_len * sizeof(int8_t))); quantizeActivationMatrix(data.weights_int8, weights, 1, weights_len, - weightsFactors[0], stream); + weight_factors[0], stream); } -static void LoadKQVQuantizationData(MatMulQuantizationData& data, - const half* kqv_weights, int input_len, +static void LoadQKVQuantizationData(MatMulQuantizationData& data, + const half* qkv_weights, int input_len, int output_len, - const std::vector& kWeightsFactors, - const std::vector& qWeightsFactors, - const std::vector& vWeightsFactors, - const std::vector& inputFactors, + const std::vector& q_weight_factors, + const std::vector& k_weight_factors, + const std::vector& v_weight_factors, + const std::vector& input_factors, + const float rescale_factor, cudaStream_t stream) { // Load weights for INT8 inference. @@ -1688,36 +1694,40 @@ static void LoadKQVQuantizationData(MatMulQuantizationData& data, ReportCUDAErrors( cudaMalloc(&data.output_scaling_factors, output_len * 3 * sizeof(float))); - if (inputFactors.size() > 1) { + if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( - // data.output_scaling_factors, inputFactors.data(), - // inputFactors.size() * sizeof(float), cudaMemcpyHostToDevice)); + // data.output_scaling_factors, input_factors.data(), + // input_factors.size() * sizeof(float), cudaMemcpyHostToDevice)); throw Exception("Channelwise quantization not yet supported."); } else { // Repeatedly fill values into the input factors buffer. - fillGpuArray(data.input_scaling_factors, inputFactors[0], input_len); + fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); // Repeatedly fill values into the output factors buffer. + data.accum_rescale_factor = 1.0 / rescale_factor; fillGpuArray(data.output_scaling_factors, - inputFactors[0] * kWeightsFactors[0], output_len); + input_factors[0] * q_weight_factors[0] * rescale_factor, + output_len); fillGpuArray(data.output_scaling_factors + output_len, - inputFactors[0] * qWeightsFactors[0], output_len); + input_factors[0] * k_weight_factors[0] * rescale_factor, + output_len); fillGpuArray(data.output_scaling_factors + output_len * 2, - inputFactors[0] * vWeightsFactors[0], output_len); + input_factors[0] * v_weight_factors[0] * rescale_factor, + output_len); } - // Load KQV weights and run a GPU kernel to scale them. + // Load QKV weights and run a GPU kernel to scale them. 
int weights_len = input_len * output_len; ReportCUDAErrors( cudaMalloc(&data.weights_int8, weights_len * 3 * sizeof(int8_t))); - quantizeActivationMatrix(data.weights_int8, kqv_weights, 1, weights_len, - kWeightsFactors[0], stream); + quantizeActivationMatrix(data.weights_int8, qkv_weights, 1, weights_len, + q_weight_factors[0], stream); quantizeActivationMatrix(data.weights_int8 + weights_len, - kqv_weights + weights_len, 1, weights_len, - qWeightsFactors[0], stream); + qkv_weights + weights_len, 1, weights_len, + k_weight_factors[0], stream); quantizeActivationMatrix(data.weights_int8 + weights_len * 2, - kqv_weights + weights_len * 2, 1, weights_len, - vWeightsFactors[0], stream); + qkv_weights + weights_len * 2, 1, weights_len, + v_weight_factors[0], stream); } template @@ -1861,7 +1871,7 @@ EncoderBlock::EncoderBlock( auto w = (int8_t*)int8_weights; // go to current encoder block w += per_encoder_size * blockIndex; - w = SetQuantizationData(kqv_, w, embedding_op_size_, mha_q_size_, 3, + w = SetQuantizationData(qkv_, w, embedding_op_size_, mha_q_size_, 3, int8_calibrate); w = SetQuantizationData(mha_dense_, w, mha_q_size_, embedding_op_size_, 1, int8_calibrate); w = SetQuantizationData(ffn1_, w, embedding_op_size_, ffn_dense1_size_, 1, int8_calibrate); w = @@ -1870,25 +1880,25 @@ EncoderBlock::EncoderBlock( // printf("\nSize of weights: %d\n", (w - (int8_t*)int8_weights)); */ - CERR << "QKV input factor: " << cpu_weights.mha.s1[0]; - LoadKQVQuantizationData(kqv_, (half*)mha_qkv_w, embedding_op_size_, - mha_q_size_, cpu_weights.mha.k_s, - cpu_weights.mha.q_s, cpu_weights.mha.v_s, - cpu_weights.mha.s1, 0); + // CERR << "QKV input factor: " << cpu_weights.mha.s1[0]; + LoadQKVQuantizationData(qkv_, (half*)mha_qkv_w, embedding_op_size_, + mha_q_size_, cpu_weights.mha.q_s, + cpu_weights.mha.k_s, cpu_weights.mha.v_s, + cpu_weights.mha.s1, 127.0, 0); LoadQuantizationData(mha_dense_, (half*)mha_dense_w, embedding_op_size_, mha_dense_size_, cpu_weights.mha.dense_s, - cpu_weights.mha.s2, 0); + cpu_weights.mha.s2, 127.0, 0); LoadQuantizationData(ffn1_, (half*)ffn_dense1_w, embedding_op_size_, ffn_dense1_size_, cpu_weights.ffn.dense1_s, - cpu_weights.ffn.s1, 0); + cpu_weights.ffn.s1, 511.0, 0); LoadQuantizationData(ffn2_, (half*)ffn_dense2_w, ffn_dense1_size_, ffn_dense2_size_, cpu_weights.ffn.dense2_s, - cpu_weights.ffn.s2, 0); + cpu_weights.ffn.s2, 255.0, 0); // print some weights /* printf("\noutput scale first factor: %f, %f, %f, %f\n", - *kqv_.output_scaling_factors, *mha_dense_.output_scaling_factors, + *qkv_.output_scaling_factors, *mha_dense_.output_scaling_factors, *ffn1_.output_scaling_factors, *ffn2_.output_scaling_factors); */ } @@ -2037,28 +2047,32 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, DataType* mha_k; DataType* mha_v; - //dumpTensor(in_out_tensor, embedding_op_size_ * 64 * N, "input to mha_kqv gemm", true); - //dumpTensor(mha_qkv_w, embedding_op_size_ * d_model * 3, "weights to mha_kqv gemm", - // true); - //exit(0); + // dumpTensor(in_out_tensor, embedding_op_size_ * 64 * N, "input to mha_kqv + // gemm", true); dumpTensor(mha_qkv_w, embedding_op_size_ * d_model * 3, + // "weights to mha_kqv gemm", + // true); + // exit(0); { const int num_inputs = embedding_op_size_; const int num_outputs = d_model; const int batch = N * 64; const int max_batch = max_batch_size_ * 64; - const int batch_to_use = use_fused_mha_ ? batch : max_batch; // The array of GPU pointers assume max batch + const int batch_to_use = + use_fused_mha_ + ? 
batch + : max_batch; // The array of GPU pointers assume max batch mha_q = scratch; mha_k = mha_q + num_outputs * batch_to_use; mha_v = mha_k + num_outputs * batch_to_use; // if (int8_cali_) { - // calibrateGemmForInt8(kqv_.weights_int8, kqv_.input_scaling_factors, - // kqv_.output_scaling_factors, - // kqv_.output_deq_factors, - // kqv_.input_matrix_max_values, - // kqv_.output_matrix_max_values, in_out_tensor, + // calibrateGemmForInt8(qkv_.weights_int8, qkv_.input_scaling_factors, + // qkv_.output_scaling_factors, + // qkv_.output_deq_factors, + // qkv_.input_matrix_max_values, + // qkv_.output_matrix_max_values, in_out_tensor, // mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); // } @@ -2068,44 +2082,47 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, - kqv_.input_scaling_factors, 0); - dumpTensor((int8_t*)scratch, num_inputs * batch / 4, - "encoder 1 qkv input quantized", true); - dumpTensor(kqv_.weights_int8, num_inputs * num_outputs * 3, - "encoder 1 qkv weights quantized", true); + qkv_.input_scaling_factors, stream); + // dumpTensor((int8_t*)scratch, num_inputs * batch / 4, + // "encoder 1 qkv input quantized", true); + // dumpTensor(qkv_.weights_int8, num_inputs * num_outputs * 3, + // "encoder 1 qkv weights quantized", true); // 2. perform int8 GEMM (scratch -> buffer1) - cutlassMatrixMulBTransposed((const int8_t*)scratch, kqv_.weights_int8, - (int8_t*)buffer1, batch, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, 1.0, 0.0f); - dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, - "encoder 1 qkv output quantized", true); - // per-layer output scaling + cutlassMatrixMulBTransposed( + (const int8_t*)scratch, qkv_.weights_int8, (int8_t*)buffer1, batch, + num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, qkv_.accum_rescale_factor, 0.0f); + // dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, + // "encoder 1 qkv output quantized", true); + // dumpTensor(qkv_.output_scaling_factors, num_outputs, + // "Output scaling factors", true); + // // per-layer output scaling + /* cutlassMatrixMulBTransposed( - (const int8_t*)scratch, kqv_.weights_int8, - kqv_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, + (const int8_t*)scratch, qkv_.weights_int8, + qkv_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); - - dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, - "encoder 1 qkv output quantized-descaled", true); - exit(0); + num_outputs * batch_to_use, num_outputs, 1.0, 0.0f); + */ + // dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, + // "encoder 1 qkv output quantized-descaled", true); + // exit(0); ReportCUDAErrors(cudaGetLastError()); /* dumpTensor((const int8_t*)scratch, 768, "quantized input matrix", false, false); - dumpTensor(kqv_.weights_int8, 768, + dumpTensor(qkv_.weights_int8, 768, "weights - during run", false, false); dumpTensor((const int8_t*)buffer1, 768, "some quantized output values", false, false); - dumpTensor(kqv_.output_scaling_factors, 768, + dumpTensor(qkv_.output_scaling_factors, 768, "output_scaling_factors - during run", false, false); */ @@ -2113,26 +2130,21 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: fuse the entire thing with the above GEMM. 
// deQuantizeOutputMatrixBiasAdd( // (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, - // kqv_.output_scaling_factors, kqv_.output_deq_factors, + // qkv_.output_scaling_factors, qkv_.output_deq_factors, // (const half*)mha_qkv_b, ACTIVATION_NONE, stream); deQuantizeOutputMatrixBiasAdd( (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, - kqv_.output_scaling_factors, (const half*)/*mha_qkv_b*/ nullptr, - ACTIVATION_NONE, stream); + qkv_.output_scaling_factors, (const half*)mha_qkv_b, ACTIVATION_NONE, + stream); // dumpTensor((int8_t*)scratch, num_outputs * batch / 4, // "encoder 1 qkv output no bias", true); - dumpTensor((half*)scratch, num_outputs * batch / 4, - "encoder 1 qkv output dequant"); - - dumpTensor((float*)kqv_.input_scaling_factors, num_inputs, - "encoder 1 qkv input factors"); + // dumpTensor((half*)scratch, num_outputs * batch / 4, + // "encoder 1 qkv output dequant", true); - dumpTensor((float*)kqv_.output_scaling_factors, num_outputs * 3, - "encoder 1 qkv output factors"); - exit(0); + // exit(0); /* dumpTensor((const half*)scratch, 768, @@ -2140,10 +2152,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, false); exit(0); */ } else { - if (is_quantized_) { + if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - kqv_.input_scaling_factors, batch, num_inputs, stream); + qkv_.input_scaling_factors, batch, num_inputs, stream); } cublasXGemmStridedBatched( @@ -2157,22 +2169,22 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, true); #endif - dumpTensor((DataType*)mha_q, num_outputs * batch / 4, - "encoder 1 qkv output no bias", false); - exit(0); + // dumpTensor((DataType*)mha_q, num_outputs * batch / 4, + // "encoder 1 qkv output no bias", true); + // exit(0); addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, ACTIVATION_NONE, stream); - dumpTensor((DataType*)mha_q, num_outputs * batch / 4, - "encoder 1 qkv output", true); + // dumpTensor((DataType*)mha_q, num_outputs * batch / 4, + // "encoder 1 qkv output", true); // exit(0); // dumpTensor((const DataType*)mha_q, // /*num_outputs * batch_to_use*/ // 768, "ref output values after // bias add", false, false); - exit(0); + // exit(0); } } @@ -2306,9 +2318,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, stream); // 2. 
perform int8 GEMM (scratch -> buffer2) - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, mha_dense_.weights_int8, (int8_t*)buffer2, - batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)scratch, + mha_dense_.weights_int8, (int8_t*)buffer2, + batch, num_outputs, num_inputs, 1, 0, 0, 0, + mha_dense_.accum_rescale_factor, 0.0f); /* cutlassMatrixMulBTransposed( @@ -2329,6 +2342,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (const int8_t*)buffer2, batch, num_outputs, 1, mha_dense_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); + // dumpTensor((DataType*)buffer1, num_outputs * batch / 4, + // "encoder 1 mha dense output dequant", true); + // exit(0); // dequantizeWithLayerNorm(N * 64, embedding_op_size_, scratch, // buffer1, mha_dense_b, @@ -2336,6 +2352,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // alpha_, ACTIVATION_NONE, stream); } else { + if (is_quantized_ && clipInputActivations) { + clipActivationMatrix( + (DataType*)buffer2, (const DataType*)buffer2, + mha_dense_.input_scaling_factors, batch, num_inputs, stream); + } + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_dense_w, num_inputs, buffer2, num_inputs, 0.0f, buffer1, num_outputs); @@ -2344,6 +2366,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (const half*)buffer2, (const half*)mha_dense_w, (half*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, true); */ + // dumpTensor((DataType*)buffer1, num_outputs * batch / 4, + // "encoder 1 mha dense output", true); + // exit(0); } } @@ -2379,12 +2404,15 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, /* cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn1_.weights_int8, (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, - 0, 0, 0, 1.0 / 127.0, 0.0f); + 0, 0, 0, ffn1_.accum_rescale_factor, 0.0f); */ - cutlassMatrixMulBTransposed( - (const int8_t*)in_out_tensor, ffn1_.weights_int8, (int8_t*)buffer1, - batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - + cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, + ffn1_.weights_int8, (int8_t*)buffer1, batch, + num_outputs, num_inputs, 1, 0, 0, 0, + ffn1_.accum_rescale_factor, 0.0f); + // dumpTensor((const int8_t*)buffer1, num_outputs * batch / 4, + // "encoder 1 ffn1 output quant", true); + // ReportCUDAErrors(cudaGetLastError()); ReportCUDAErrors(cudaGetLastError()); /* dumpTensor(ffn1_.input_scaling_factors, 768, @@ -2410,6 +2438,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, ffn1_.output_scaling_factors, (const half*)ffn_dense1_b, ffn_activation_, stream); + // dumpTensor((DataType*)in_out_tensor, num_outputs * batch / 4, + // "encoder 1 ffn1 dense output dequant"); + // exit(0); + // Ankan - test! 
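// [Illustrative note, not patch content.] How accum_rescale_factor composes
// with output_scaling_factors in this int8 -> int8 variant, going by
// LoadQuantizationData above (accum_rescale_factor = 1/r and
// output_scaling_factors = s_in * s_w * r, with r = 127/255/511 per layer):
#include <algorithm>
#include <cmath>
#include <cstdint>

static float int8_path_value_ref(int32_t acc, float s_in, float s_w, float r) {
  // GEMM epilogue: shrink the int32 accumulator by 1/r and saturate to int8
  // (assuming round-to-nearest, as a clamping converter typically does).
  float q = std::max(-128.0f, std::min(127.0f, std::round((float)acc / r)));
  // Dequantize: r cancels out, so the result is ~ acc * s_in * s_w; the
  // choice of r only trades precision against accumulator headroom.
  return q * (s_in * s_w * r);
}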
// dumpTensor((const DataType*)in_out_tensor, 768, // "runtime output values after bias and RELU2", @@ -2417,6 +2449,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // exit(0); } else { + if (is_quantized_ && clipInputActivations) { + clipActivationMatrix( + (DataType*)scratch, (const DataType*)scratch, + ffn1_.input_scaling_factors, batch, num_inputs, stream); + } + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, scratch, num_inputs, 0.0f, in_out_tensor, num_outputs); @@ -2428,6 +2466,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, addBiasBatched(in_out_tensor, in_out_tensor, ffn_dense1_b, 1, batch, num_outputs, ffn_activation_, stream); + // dumpTensor((DataType*)in_out_tensor, num_outputs * batch / 4, + // "encoder 1 ffn1 dense output"); + // exit(0); + // Ankan - test! // dumpTensor((const DataType*)in_out_tensor, 768, // "Ref output values after bias and RELU2", false, @@ -2450,7 +2492,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, const int num_inputs = ffn_dense1_size_; // encoder_dff const int num_outputs = embedding_op_size_; const int batch = N * 64; - if (true && int8_inf_) { + if (is_quantized_ && int8_inf_) { // 1. quantize the inputs (in_out_tensor -> buffer1) // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) @@ -2469,22 +2511,19 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, false); */ // 2. perform int8 GEMM (buffer1 -> in_out_tensor) - /* - cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, - (int8_t*)in_out_tensor, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 1.0 / 127.0, 0.0f); - */ - cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, (int8_t*)in_out_tensor, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - + num_inputs, 1, 0, 0, 0, + ffn2_.accum_rescale_factor, 0.0f); + // dumpTensor((const int8_t*)in_out_tensor, num_outputs * batch / + // 4, + // "encoder 1 ffn2 output quant", true); ReportCUDAErrors(cudaGetLastError()); /* dumpTensor((const int8_t*)in_out_tensor, 768, "some quantized output values", false, false); - dumpTensor(ffn1_.output_scaling_factors, 768, + dumpTensor(ffn2_.output_scaling_factors, 768, "output_scaling_factors - during run", false, false); */ // 3. 
de-quantize outputs (in_out_tensor -> buffer1) @@ -2492,11 +2531,16 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (const int8_t*)in_out_tensor, batch, num_outputs, 1, ffn2_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); - /* - dumpTensor((const half*)buffer1, 768, "dequantized output values", - false, false); - exit(0);*/ + + dumpTensor((const DataType*)buffer1, num_outputs * batch / 4, + "encoder 1 ffn2 output dequant"); + exit(0); } else { + if (is_quantized_ && clipInputActivations) { + clipActivationMatrix( + (DataType*)in_out_tensor, (const DataType*)in_out_tensor, + ffn2_.input_scaling_factors, batch, num_inputs, stream); + } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); @@ -2505,9 +2549,14 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (const half*)in_out_tensor, (const half*)ffn_dense2_w, (half*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, true); */ + + dumpTensor((DataType*)buffer1, num_outputs * batch / 4, + "encoder 1 ffn2 output"); + exit(0); + /* - dumpTensor((const half*)buffer1, 768, "dequantized output values - ref", - false, false); + dumpTensor((const half*)buffer1, 768, "dequantized output values - + ref", false, false); exit(0); */ @@ -2519,6 +2568,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, ffn_dense2_b, scratch, ln2_gammas, ln2_betas, default_eps_, alpha_, ACTIVATION_NONE, stream); + + // printf("\nencoder %i", blockIndex_); + // dumpTensor((const half*)in_out_tensor, embedding_op_size_ * 64, "ln2 + // output", true); exit(0); } template @@ -2648,9 +2701,9 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(smol_ln2_betas)); } if (int8_inf_) { - ReportCUDAErrors(cudaFree(kqv_.weights_int8)); - ReportCUDAErrors(cudaFree(kqv_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(kqv_.output_scaling_factors)); + ReportCUDAErrors(cudaFree(qkv_.weights_int8)); + ReportCUDAErrors(cudaFree(qkv_.input_scaling_factors)); + ReportCUDAErrors(cudaFree(qkv_.output_scaling_factors)); ReportCUDAErrors(cudaFree(mha_dense_.weights_int8)); ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); @@ -2662,8 +2715,8 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); } // else if (int8_cali_) { - // free(kqv_.input_matrix_max_values); - // free(kqv_.output_matrix_max_values); + // free(qkv_.input_matrix_max_values); + // free(qkv_.output_matrix_max_values); // free(mha_dense_.input_matrix_max_values); // free(mha_dense_.output_matrix_max_values); // free(ffn1_.input_matrix_max_values); @@ -2967,8 +3020,8 @@ void AttentionBody::Eval(int N, DataType* output, for (const auto enc : encoder_weights_) { enc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, stream, offset_pointers); - dumpTensor(output_tensor, embedding_op_size_, "encoder 1 output"); - break; + // dumpTensor(output_tensor, embedding_op_size_, "encoder 1 output"); + // break; } // End of encoder blocks } diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 27d33f8faa..ee9d969e51 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -344,6 +344,7 @@ struct MatMulQuantizationData { float* output_deq_factors; // per-tensor. 
Always in cpu memory (passed as constants to dequantization kernels) float* input_matrix_max_values; // max values of input matrix (always in CPU memory) float* output_matrix_max_values; // max values in output matrix (always in CPU memory) + float accum_rescale_factor; // accumulator rescale factor for matmuls to prevent overflow }; @@ -387,7 +388,7 @@ class EncoderBlock { // int 8 stuff int blockIndex_; bool int8_inf_, is_quantized_; - MatMulQuantizationData kqv_; + MatMulQuantizationData qkv_; MatMulQuantizationData mha_dense_; MatMulQuantizationData ffn1_; MatMulQuantizationData ffn2_; From 0f61c2f0f5400b3fc56ac8b02b03e8835228af10 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Fri, 3 May 2024 07:53:14 +0200 Subject: [PATCH 58/70] Remove debug outputs. --- src/neural/cuda/layers.cc | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index f4c76515fa..0f53935979 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -153,7 +153,7 @@ namespace cudnn_backend { // than using multiple passes. The flag can be set to false for debugging. static constexpr bool kUseFusedSELayer = true; -static constexpr bool clipInputActivations = true; +static constexpr bool clipInputActivations = false; template BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc) @@ -2532,9 +2532,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, (half*)buffer1, (const int8_t*)in_out_tensor, batch, num_outputs, 1, ffn2_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); - dumpTensor((const DataType*)buffer1, num_outputs * batch / 4, - "encoder 1 ffn2 output dequant"); - exit(0); + // dumpTensor((const DataType*)buffer1, num_outputs * batch / 4, + // "encoder 1 ffn2 output dequant"); + // exit(0); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( @@ -2550,9 +2550,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, batch, num_outputs, num_inputs, 1, 0, 0, 0, true); */ - dumpTensor((DataType*)buffer1, num_outputs * batch / 4, - "encoder 1 ffn2 output"); - exit(0); + // dumpTensor((DataType*)buffer1, num_outputs * batch / 4, + // "encoder 1 ffn2 output"); + // exit(0); /* dumpTensor((const half*)buffer1, 768, "dequantized output values - @@ -2569,9 +2569,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, ffn_dense2_b, scratch, ln2_gammas, ln2_betas, default_eps_, alpha_, ACTIVATION_NONE, stream); - // printf("\nencoder %i", blockIndex_); - // dumpTensor((const half*)in_out_tensor, embedding_op_size_ * 64, "ln2 - // output", true); exit(0); + // printf("\nencoder %i: batchsize", blockIndex_, N); + // dumpTensor((const half*)in_out_tensor, embedding_op_size_ * 64, "ln2 output", + // true); + // exit(0); } template @@ -2587,6 +2588,8 @@ void AttentionPolicyHead::Eval( if (!attention_body_) convertNCHWtoNHWC((DataType*)scratch, input, N, inputC, N, inputC, 8, 8); + // dumpTensor(input, inputC * 64, "policy input tensor", true); + // 1. Policy embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ DataType* pol_embedding = input2_tensor; @@ -2601,6 +2604,8 @@ void AttentionPolicyHead::Eval( num_inputs, 0.0f, pol_embedding, num_outputs); addBiasBatched(pol_embedding, pol_embedding, ip_pol_b_, 1, batch, num_outputs, act_, stream); + // dumpTensor(pol_embedding, embedding_op_size_ * 64, + // "policy embedding tensor", true); } // 2. 
Encoder layers @@ -2608,6 +2613,8 @@ void AttentionPolicyHead::Eval( pEnc->Eval(N, input2_tensor, (DataType*)scratch, buffer1, buffer2, cublas, stream, offset_pointers); } // End of encoder blocks + // dumpTensor(input2_tensor, embedding_op_size_ * 64, "policy encoder output", + // true); DataType* wq; DataType* wk; @@ -2625,6 +2632,8 @@ void AttentionPolicyHead::Eval( addBiasBatched(wq, wq, wqk_b_, 2, batch, num_outputs, ACTIVATION_NONE, stream); + // dumpTensor(wq, num_outputs * 64, "policy attn wq output", true); + // dumpTensor(wk, num_outputs * 64, "policy attn wk output", true); } // dk = tf.math.sqrt(tf.cast(tf.shape(keys)[-1], self.model_dtype)) @@ -2645,6 +2654,7 @@ void AttentionPolicyHead::Eval( wq /*B*/, policy_d_model_ /*LDB*/, 64 * policy_d_model_, /*strideB*/ 0.0f, output /*C*/, // output (policy_attn_logits) 64 /*LDC*/, 64 * 64 + 8 * 24 /*strideC*/, N); + // dumpTensor(output, 64 * 64, "policy attn output", true); } // Compute promotion_logits in a single kernel (and put the result just after From 71ec58ba4948107bae0ef4b366291e6843aee252 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Sun, 5 May 2024 03:32:52 +0200 Subject: [PATCH 59/70] Add quantization to embedding layer FFN. Add weights clipping for normal fp16 weights. --- src/neural/cuda/cutlass_kernels.cu | 7 +- src/neural/cuda/layers.cc | 180 +++++++++++++++++++++-------- src/neural/cuda/layers.h | 5 + src/neural/cuda/network_cuda.cc | 2 +- 4 files changed, 145 insertions(+), 49 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 3d35abb75c..3445c52ba7 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -961,10 +961,11 @@ __global__ void clipMatrix(T* output, const T* input, const float* factors, if (x >= width || y >= height) return; - float limit = (float)(127 * factors[x]); + float ulimit = 127.0 * factors[x]; + float llimit = -128.0 * factors[x]; float val = (float)input[y * width + x]; - if (val > limit) val = limit; - if (val < -limit) val = -limit; + if (val > ulimit) val = ulimit; + if (val < llimit) val = llimit; output[y * width + x] = (T)val; } diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 0f53935979..ed8df5370a 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -153,7 +153,7 @@ namespace cudnn_backend { // than using multiple passes. The flag can be set to false for debugging. 
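// [Illustrative sketch, not patch content.] The per-element clamp applied by
// the clipMatrix kernel changed above: the int8 range is asymmetric, so the
// lower bound becomes -128 * factor rather than -(127 * factor).
#include <algorithm>

static float clip_for_int8_ref(float x, float factor) {
  const float hi = 127.0f * factor;
  const float lo = -128.0f * factor;
  return std::min(hi, std::max(lo, x));
}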
static constexpr bool kUseFusedSELayer = true; -static constexpr bool clipInputActivations = false; +static constexpr bool clipInputActivations = true; template BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc) @@ -1638,12 +1638,11 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, float* scale, const half* bias, ActivationFunction act, cudaStream_t stream); -static void LoadQuantizationData(MatMulQuantizationData& data, - const half* weights, int input_len, - int output_len, +static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, + int input_len, int output_len, const std::vector& weight_factors, const std::vector& input_factors, - const float rescale_factor, + const float rescale_factor, void* scratch, cudaStream_t stream) { // Load weights for INT8 inference @@ -1675,16 +1674,21 @@ static void LoadQuantizationData(MatMulQuantizationData& data, cudaMalloc(&data.weights_int8, weights_len * sizeof(int8_t))); quantizeActivationMatrix(data.weights_int8, weights, 1, weights_len, weight_factors[0], stream); + + // The original weights also need to be clipped for fp16 inference. + fillGpuArray((float*)scratch, weight_factors[0], input_len); + clipActivationMatrix(weights, weights, (const float*)scratch, + output_len, input_len, stream); } static void LoadQKVQuantizationData(MatMulQuantizationData& data, - const half* qkv_weights, int input_len, + half* qkv_weights, int input_len, int output_len, const std::vector& q_weight_factors, const std::vector& k_weight_factors, const std::vector& v_weight_factors, const std::vector& input_factors, - const float rescale_factor, + const float rescale_factor, void* scratch, cudaStream_t stream) { // Load weights for INT8 inference. @@ -1728,6 +1732,24 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, quantizeActivationMatrix(data.weights_int8 + weights_len * 2, qkv_weights + weights_len * 2, 1, weights_len, v_weight_factors[0], stream); + + // The original weights also need to be clipped for fp16 inference + // q weights. + fillGpuArray((float*)scratch, q_weight_factors[0], input_len); + clipActivationMatrix(qkv_weights, qkv_weights, (const float*)scratch, + output_len, input_len, stream); + + // k weights. + fillGpuArray((float*)scratch, k_weight_factors[0], input_len); + clipActivationMatrix(qkv_weights + weights_len, + qkv_weights + weights_len, (const float*)scratch, + output_len, input_len, stream); + + // v weights. 
+ fillGpuArray((float*)scratch, v_weight_factors[0], input_len); + clipActivationMatrix( + qkv_weights + weights_len * 2, qkv_weights + weights_len * 2, + (const float*)scratch, output_len, input_len, stream); } template @@ -1884,16 +1906,16 @@ EncoderBlock::EncoderBlock( LoadQKVQuantizationData(qkv_, (half*)mha_qkv_w, embedding_op_size_, mha_q_size_, cpu_weights.mha.q_s, cpu_weights.mha.k_s, cpu_weights.mha.v_s, - cpu_weights.mha.s1, 127.0, 0); + cpu_weights.mha.s1, 127.0, scratch, 0); LoadQuantizationData(mha_dense_, (half*)mha_dense_w, embedding_op_size_, mha_dense_size_, cpu_weights.mha.dense_s, - cpu_weights.mha.s2, 127.0, 0); + cpu_weights.mha.s2, 127.0, scratch, 0); LoadQuantizationData(ffn1_, (half*)ffn_dense1_w, embedding_op_size_, ffn_dense1_size_, cpu_weights.ffn.dense1_s, - cpu_weights.ffn.s1, 511.0, 0); + cpu_weights.ffn.s1, 511.0, scratch, 0); LoadQuantizationData(ffn2_, (half*)ffn_dense2_w, ffn_dense1_size_, ffn_dense2_size_, cpu_weights.ffn.dense2_s, - cpu_weights.ffn.s2, 255.0, 0); + cpu_weights.ffn.s2, 255.0, scratch, 0); // print some weights /* @@ -2783,7 +2805,9 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, weights.ip_add_gate.size() > 0), has_smolgen_(weights.has_smolgen), is_pe_dense_embedding_(is_pe_dense_embedding), - use_fused_mha_(fused_mha) { + use_fused_mha_(fused_mha), + int8_inf_(int8_inference), + is_quantized_(weights.ip_emb_ffn.s1.size() > 0) { allocAndUpload(&ip_emb_w_, weights.ip_emb_w, scratch); allocAndUpload(&ip_emb_b_, weights.ip_emb_b, scratch); @@ -2813,6 +2837,16 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, embedding_dense_size_ = weights.ip_emb_preproc_b.size() / 64; embedding_ffn_size_ = weights.ip_emb_ffn.dense2_b.size(); embedding_ffn_dff_ = weights.ip_emb_ffn.dense1_b.size(); + + // Quantization data for input embedding FFN layers. + LoadQuantizationData(emb_ffn1_, (half*)ip_emb_ffn_d1_w_, + embedding_dense_size_, embedding_ffn_dff_, + weights.ip_emb_ffn.dense1_s, weights.ip_emb_ffn.s1, + 127.0, scratch, 0); + LoadQuantizationData(emb_ffn2_, (half*)ip_emb_ffn_d2_w_, embedding_ffn_dff_, + embedding_ffn_size_, weights.ip_emb_ffn.dense2_s, + weights.ip_emb_ffn.s2, 127.0, scratch, 0); + } else { size_t size = 64 * kNumPosEncodingChannels * sizeof(float); ReportCUDAErrors(cudaMalloc(&pos_encoding_, size)); @@ -2951,18 +2985,18 @@ void AttentionBody::Eval(int N, DataType* output, if (is_pe_dense_embedding_) { // 1. square embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ - DataType* embedding = output_tensor; - DataType* temp = (DataType*)scratch; + DataType* embedding = (DataType*)scratch; + DataType* temp = output_tensor; { const int num_outputs = embedding_op_size_; const int num_inputs = inputC; const int batch = N * 64; cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_w_, - num_inputs, temp, num_inputs, 0.0f, embedding, + num_inputs, embedding, num_inputs, 0.0f, temp, num_outputs); // embedding layer norm with fused in bias add of previous gemm. 
- LayerNorm(N * 64, embedding_op_size_, temp, embedding, + LayerNorm(N * 64, embedding_op_size_, embedding, temp, ip_emb_b_, (DataType*)nullptr, ip_emb_ln_g_, ip_emb_ln_b_, 1e-3, 1.0, activations_.default_activation, stream); @@ -2970,8 +3004,9 @@ void AttentionBody::Eval(int N, DataType* output, // Input gating if (has_gating_) { - applyInputGating(temp, temp, ip_mult_gate_, ip_add_gate_, N, 64, - embedding_op_size_, stream); + applyInputGating(embedding, embedding, ip_mult_gate_, + ip_add_gate_, N, 64, embedding_op_size_, + stream); } // embedding FFN dense 1 @@ -2979,11 +3014,37 @@ void AttentionBody::Eval(int N, DataType* output, const int num_inputs = embedding_ffn_size_; const int num_outputs = embedding_ffn_dff_; // encoder_dff const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, - num_inputs, temp, num_inputs, 0.0f, buffer1, num_outputs); - addBiasBatched(buffer1, buffer1, ip_emb_ffn_d1_b_, 1, batch, num_outputs, - activations_.ffn_activation, stream); + + if (is_quantized_ && int8_inf_) { + // 1. quantize (embedding -> temp) + quantizeActivationMatrix((int8_t*)temp, (const half*)embedding, batch, + num_inputs, emb_ffn1_.input_scaling_factors, + stream); + + // 2. int8 matmul (temp -> buffer1) + cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn1_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, + num_inputs, 1, 0, 0, 0, + emb_ffn1_.accum_rescale_factor, 0.0f); + ReportCUDAErrors(cudaGetLastError()); + + // 3. dequantize + bias add (buffer1 -> buffer1) + deQuantizeOutputMatrixBiasAdd( + (half*)buffer1, (const int8_t*)buffer1, batch, num_outputs, 1, + emb_ffn1_.output_scaling_factors, (const half*)ip_emb_ffn_d1_b_, + activations_.ffn_activation, stream); + } else { + if (is_quantized_ && clipInputActivations) { + clipActivationMatrix( + (DataType*)temp, (const DataType*)embedding, + emb_ffn1_.input_scaling_factors, batch, num_inputs, stream); + } + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, + num_inputs, temp, num_inputs, 0.0f, buffer1, num_outputs); + addBiasBatched(buffer1, buffer1, ip_emb_ffn_d1_b_, 1, batch, + num_outputs, activations_.ffn_activation, stream); + } } // embedding FFN dense 2 @@ -2991,22 +3052,51 @@ void AttentionBody::Eval(int N, DataType* output, const int num_inputs = embedding_ffn_dff_; // encoder_dff const int num_outputs = embedding_ffn_size_; const int batch = N * 64; - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, - num_inputs, buffer1, num_inputs, 0.0f, buffer2, num_outputs); - // Embedding LN: skip connection and layer normilization (also bias add of - // prev gemm) buffer2 -> embedding - float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); - LayerNorm(N * 64, embedding_ffn_size_, embedding, buffer2, - ip_emb_ffn_d2_b_, temp, ip_emb_ffn_ln_g_, - ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, - stream); - } + if (is_quantized_ && int8_inf_) { + // 1. quantize (buffer1 -> temp) + quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, + num_inputs, emb_ffn2_.input_scaling_factors, + stream); + // 2. int8 matmul (temp -> buffer1) + cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, + num_inputs, 1, 0, 0, 0, + emb_ffn2_.accum_rescale_factor, 0.0f); + + // 3. 
dequantize + bias add (buffer1 -> buffer2) + deQuantizeOutputMatrixBiasAdd( + (half*)buffer2, (const int8_t*)buffer1, batch, num_outputs, 1, + emb_ffn2_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); + + float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); + LayerNorm(N * 64, embedding_ffn_size_, output_tensor, buffer2, + ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, + ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, + stream); + } else { + if (is_quantized_ && clipInputActivations) { + clipActivationMatrix( + (DataType*)buffer1, (const DataType*)buffer1, + emb_ffn2_.input_scaling_factors, batch, num_inputs, stream); + } + + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, + num_inputs, buffer1, num_inputs, 0.0f, buffer2, + num_outputs); + // Embedding LN: skip connection and layer normilization (also bias add + // of prev gemm) buffer2 -> embedding + float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); + LayerNorm(N * 64, embedding_ffn_size_, output_tensor, buffer2, + ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, + ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, + stream); + } + } } else { // 1. square embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ - DataType* embedding = output_tensor; { const int num_outputs = embedding_op_size_; const int num_inputs = inputC; @@ -3014,25 +3104,25 @@ void AttentionBody::Eval(int N, DataType* output, cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_w_, num_inputs, (DataType*)scratch, num_inputs, 0.0f, - embedding, num_outputs); - addBiasBatched(embedding, embedding, ip_emb_b_, 1, batch, num_outputs, - activations_.default_activation, stream); + output_tensor, num_outputs); + addBiasBatched(output_tensor, output_tensor, ip_emb_b_, 1, batch, + num_outputs, activations_.default_activation, stream); } // Input gating if (has_gating_) { - applyInputGating(embedding, embedding, ip_mult_gate_, + applyInputGating(output_tensor, output_tensor, ip_mult_gate_, ip_add_gate_, N, 64, embedding_op_size_, stream); } } // 2. 
Encoder blocks - for (const auto enc : encoder_weights_) { - enc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, - stream, offset_pointers); - // dumpTensor(output_tensor, embedding_op_size_, "encoder 1 output"); - // break; - } // End of encoder blocks + // for (const auto enc : encoder_weights_) { + // enc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, + // stream, offset_pointers); + // // dumpTensor(output_tensor, embedding_op_size_, "encoder 1 output"); + // // if (i++ == 10) break; + // } // End of encoder blocks } template diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index ee9d969e51..372407f766 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -535,6 +535,11 @@ class AttentionBody : public BaseLayer { const bool has_gating_; const bool has_smolgen_; const bool use_fused_mha_; + const bool int8_inf_; + const bool is_quantized_; + + MatMulQuantizationData emb_ffn1_; + MatMulQuantizationData emb_ffn2_; }; // The value head implementation diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 7cfca11d9d..057ea93ea3 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -800,7 +800,7 @@ class CudaNetwork : public Network { scratch_mem, scratch_size_, nullptr, cublas, stream); // policy map layer // POLICY output } - + dumpTensor(opPol, 1858, "Output policy", false); } else if (conv_policy_) { network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, scratch_size_, nullptr, cublas, From 01c24e5cf41789500127274910c2b08218a3417e Mon Sep 17 00:00:00 2001 From: almaudoh Date: Wed, 8 May 2024 05:33:23 +0200 Subject: [PATCH 60/70] Update gemms to provide int8->fp32 for correct results. Remove old and unused code. 
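A minimal CPU sketch of the int8 -> fp32 GEMM this change moves to (an
illustration only, assuming the per-column scale vector is broadcast across
rows and applied directly to the int32 accumulator, so no intermediate int8
saturation step remains):

    #include <cstdint>

    // A: M x K, W: N x K (B transposed), col_scale: N entries, C: M x N floats.
    static void int8_to_fp32_gemm_ref(const int8_t* A, const int8_t* W,
                                      const float* col_scale, float* C,
                                      int M, int N, int K) {
      for (int m = 0; m < M; m++)
        for (int n = 0; n < N; n++) {
          int32_t acc = 0;
          for (int k = 0; k < K; k++)
            acc += (int32_t)A[m * K + k] * (int32_t)W[n * K + k];
          C[m * N + n] = (float)acc * col_scale[n];  // float out, no clamp
        }
    }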
--- src/neural/cuda/common_kernels.cu | 120 +++--- src/neural/cuda/cutlass_kernels.cu | 109 ++--- src/neural/cuda/kernels.h | 12 +- src/neural/cuda/layers.cc | 621 +++++++---------------------- src/neural/cuda/layers.h | 1 + src/neural/cuda/network_cuda.cc | 22 +- 6 files changed, 266 insertions(+), 619 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index b6a7ce492f..bcbd85b30b 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -102,9 +102,9 @@ void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream) { ReportCUDAErrors(cudaGetLastError()); } -template -__global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, - int N, int C) { +template +__global__ void addBiasBatched_kernel(T* output, const IT* input, + const BT* bias, int N, int C) { int batch = blockIdx.y; int n = blockIdx.x * blockDim.y + threadIdx.y; if (n >= N) return; @@ -117,18 +117,21 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, float b[4]; // Load from memory - const bool fp16 = std::is_same::value; - if (fp16) { + if (std::is_same::value) { half inp[4]; copyAs(&inp[0], &input[tensorIndex]); #pragma unroll for (int i = 0; i < 4; i++) val[i] = (float)inp[i]; + } else { + copyAs(&val[0], &input[tensorIndex]); + } + if (std::is_same::value) { + half inp[4]; copyAs(&inp[0], &bias[biasIndex]); #pragma unroll for (int i = 0; i < 4; i++) b[i] = (float)inp[i]; } else { - copyAs(&val[0], &input[tensorIndex]); copyAs(&b[0], &bias[biasIndex]); } @@ -141,7 +144,7 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, } // write to memory - if (fp16) { + if (std::is_same::value) { half op[4]; #pragma unroll for (int i = 0; i < 4; i++) op[i] = (half)val[i]; @@ -153,9 +156,10 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, // Input/output tensors are Batch * N * C // bias tensor is N * C (i.e, different bias for each Batch dimension) -template -void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, - int C, ActivationFunction activation, cudaStream_t stream) { +template +void addBiasBatched(T* output, const IT* input, const BT* bias, int Batch, + int N, int C, ActivationFunction activation, + cudaStream_t stream) { // process 4 elements per thread to achieve close to peak memory bandwidth if (C % 4 != 0) throw Exception("unsupported filter size"); if (C > 4096) throw Exception("unsupported filter size"); @@ -170,27 +174,27 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, switch (activation) { case ACTIVATION_NONE: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C); break; case ACTIVATION_SELU: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C); break; case ACTIVATION_MISH: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C); break; case ACTIVATION_RELU: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C); break; case ACTIVATION_SWISH: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C); break; case ACTIVATION_RELU_2: // square relu - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C); break; default: @@ -201,8 +205,8 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, ReportCUDAErrors(cudaGetLastError()); } -template -__global__ void addBiasBatched_kernel(T* 
output, const T* input, const T* bias, +template +__global__ void addBiasBatched_kernel(T* output, const IT* input, const BT* bias, int N, int C, int Nstride) { int batch = blockIdx.y; int n = blockIdx.x * blockDim.y + threadIdx.y; @@ -216,18 +220,21 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, float b[4]; // Load from memory - const bool fp16 = std::is_same::value; - if (fp16) { + if (std::is_same::value) { half inp[4]; copyAs(&inp[0], &input[tensorIndex]); #pragma unroll for (int i = 0; i < 4; i++) val[i] = (float)inp[i]; + } else { + copyAs(&val[0], &input[tensorIndex]); + } + if (std::is_same::value) { + half inp[4]; copyAs(&inp[0], &bias[biasIndex]); #pragma unroll for (int i = 0; i < 4; i++) b[i] = (float)inp[i]; } else { - copyAs(&val[0], &input[tensorIndex]); copyAs(&b[0], &bias[biasIndex]); } @@ -240,7 +247,7 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, } // write to memory - if (fp16) { + if (std::is_same::value) { half op[4]; #pragma unroll for (int i = 0; i < 4; i++) op[i] = (half)val[i]; @@ -252,9 +259,9 @@ __global__ void addBiasBatched_kernel(T* output, const T* input, const T* bias, // Input/output tensors are Batch * N * C // bias tensor is N * C (i.e, different bias for each Batch dimension) -template -void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, - int C, int Nstride, ActivationFunction activation, +template +void addBiasBatched(T* output, const IT* input, const BT* bias, int Batch, + int N, int C, int Nstride, ActivationFunction activation, cudaStream_t stream) { // process 4 elements per thread to achieve close to peak memory bandwidth if (C % 4 != 0) throw Exception("unsupported filter size"); @@ -270,32 +277,32 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, switch (activation) { case ACTIVATION_NONE: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C, Nstride); break; case ACTIVATION_SELU: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C, Nstride); break; case ACTIVATION_MISH: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C, Nstride); break; case ACTIVATION_RELU: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C, Nstride); break; case ACTIVATION_SWISH: - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C, Nstride); break; case ACTIVATION_RELU_2: // square relu - addBiasBatched_kernel + addBiasBatched_kernel <<>>(output, input, bias, N, C, Nstride); break; @@ -959,8 +966,8 @@ __device__ __forceinline__ float shared_sum_for_layer_norm(float x) { // Each thread processes 4 elements // 1. Perform Bias add, and skip add // 2. 
Perform layer norm (normalize across C dimension) -template -__global__ void layer_norm_kernel(int N, int C, T* output, const T* input, +template +__global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, ActivationFunction act) { @@ -978,22 +985,26 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, const bool fp16 = std::is_same::value; if (!oobThread) { // Load from memory (16 elements a time) - if (fp16) { + if (std::is_same::value) { half inp[8]; copyAs(&inp[0], &input[tensorIndex]); for (int i = 0; i < 8; i++) val[i] = (float)inp[i]; copyAs(&inp[0], &input[tensorIndex + 8]); for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i]; + } else { + copyAs(&val[0], &input[tensorIndex]); + copyAs(&val[4], &input[tensorIndex + 4]); + copyAs(&val[8], &input[tensorIndex + 8]); + copyAs(&val[12], &input[tensorIndex + 12]); + } + if (fp16) { + half inp[8]; copyAs(&inp[0], &bias[biasIndex]); for (int i = 0; i < 8; i++) oth[i] = (float)inp[i]; copyAs(&inp[0], &bias[biasIndex + 8]); for (int i = 0; i < 8; i++) oth[i + 8] = (float)inp[i]; for (int i = 0; i < 16; i++) val[i] += oth[i]; } else { - copyAs(&val[0], &input[tensorIndex]); - copyAs(&val[4], &input[tensorIndex + 4]); - copyAs(&val[8], &input[tensorIndex + 8]); - copyAs(&val[12], &input[tensorIndex + 12]); copyAs(&oth[0], &bias[biasIndex]); copyAs(&oth[4], &bias[biasIndex + 4]); copyAs(&oth[8], &bias[biasIndex + 8]); @@ -1110,12 +1121,12 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const T* input, } } - template __global__ void layer_norm_kernel_slow(int N, int C, T* output, const T* input, - const T* bias, const T* skip, const T* gammas, - const T* betas, float ep, float alpha, - ActivationFunction act) { + const T* bias, const T* skip, + const T* gammas, const T* betas, + float ep, float alpha, + ActivationFunction act) { int n = blockIdx.x * blockDim.z + threadIdx.z; if (n >= N) return; int c = (threadIdx.y * 32 + threadIdx.x) * 16; @@ -1260,7 +1271,6 @@ __global__ void layer_norm_kernel_slow(int N, int C, T* output, const T* input, } } - __global__ void layer_norm_kernel_8_el_per_thread( int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, @@ -1332,12 +1342,11 @@ __global__ void layer_norm_kernel_8_el_per_thread( } } - // add (optional) skip connection to input, and then perform Layer normalization // normalization is done across C dimension (i.e, sums and std deviations taken // over elements in C dim) -template -void LayerNorm(int N, int C, T* output, const T* input, const T* bias, +template +void LayerNorm(int N, int C, T* output, const IT* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, ActivationFunction act, cudaStream_t stream) { // process 4 elements per thread to achieve close to peak memory bandwidth @@ -1353,7 +1362,7 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias, gridDim.y = 1; gridDim.z = 1; - layer_norm_kernel<<>>( + layer_norm_kernel<<>>( N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); ReportCUDAErrors(cudaGetLastError()); @@ -1465,7 +1474,8 @@ __global__ void preprocess_for_attention_body_kernel( if (c >= input_size) { // concatenate from position encoding array if (is_pe_dense_embedding) { - op = (T)(encoding[n * 64 * encoding_size + hw * encoding_size + (c - input_size)]); + op = (T)(encoding[n * 64 * 
encoding_size + hw * encoding_size + + (c - input_size)]); } else { op = (T)(encoding[64 * hw + (c - input_size)]); } @@ -1572,6 +1582,10 @@ template void addBiasBatched(half* output, const half* input, ActivationFunction activation, cudaStream_t stream); +template void addBiasBatched( + half* output, const float* input, const half* bias, int Batch, int N, int C, + ActivationFunction activation, cudaStream_t stream); + template void addBiasBatched(float* output, const float* input, const float* bias, int Batch, int N, int C, int Nstride, ActivationFunction activation, @@ -1581,6 +1595,10 @@ template void addBiasBatched(half* output, const half* input, int Nstride, ActivationFunction activation, cudaStream_t stream); +template void addBiasBatched( + half* output, const float* input, const half* bias, int Batch, int N, int C, + int Nstride, ActivationFunction activation, cudaStream_t stream); + template void addBias_NCHW(float* c, float* a, float* b, int N, int C, int H, int W, ActivationFunction activation, cudaStream_t stream); @@ -1766,6 +1784,12 @@ template void Softmax(int N, int C, half* output, const half* input, template void Softmax(int N, int C, float* output, const float* input, const float* input2, cudaStream_t stream); +template void LayerNorm(int N, int C, half* output, + const float* input, const half* bias, + const half* skip, const half* gammas, + const half* betas, float ep, float alpha, + ActivationFunction act, + cudaStream_t stream); template void LayerNorm(int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 3445c52ba7..4cec5b01d5 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -174,9 +174,11 @@ void dumpTensor(const T* memory, int elements, const char* message, bool only_summary = false, bool cpu_tensor = false) { const bool fp16 = std::is_same::value; const bool int8 = std::is_same::value; + const bool int32 = std::is_same::value; printf("\n%s\n", message); int elementSize = (int)(fp16 ? 
sizeof(half) : sizeof(float)); if (int8) elementSize = sizeof(int8_t); + if (int32) elementSize = sizeof(int32_t); int bytes = elements * elementSize; void* temp = (void*)memory; if (!cpu_tensor) { @@ -195,7 +197,10 @@ void dumpTensor(const T* memory, int elements, const char* message, std::vector fpArr(elements); for (int i = 0; i < elements; i++) { float val; - if (int8) { + if (int32) { + int32_t* arr = (int32_t*)temp; + val = (float)arr[i]; + } else if (int8) { int8_t* arr = (int8_t*)temp; val = (float)arr[i]; } else if (fp16) { @@ -242,12 +247,12 @@ void dumpTensor(const T* memory, int elements, const char* message, float avg = mean(&fpArr[0], elements); float stddev = stdDev(&fpArr[0], elements); - if (int8) { + if (int8 || int32) { printf( "Max: %i, Min: %i, Mean: %i, StdDev: %i\n" "NaNs: %i, HiQuantLimit: %i, LoQuantLimit: %i, Total: %i", - (int8_t)maxval, (int8_t)minval, (int8_t)avg, (int8_t)stddev, cnans, - cplims, cnlims, elements); + (int)maxval, (int)minval, (int)avg, (int)stddev, cnans, cplims, cnlims, + elements); } else { printf( @@ -275,23 +280,24 @@ void dumpTensor(const T* memory, int elements, const char* message, // int8 GEMM using CUTLASS (with per-column output quantization) void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, int8_t* Out, int M, + const float* scaleVector, float* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, int VecStride, float alphaf, float betaf) { using ElementAccumulator = int32_t; using ElementComputeEpilogue = float; - using ElementIO = int8_t; + using ElementInput = int8_t; + using ElementOutput = float; using ElementScale = float; using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; constexpr int elementsPerAccess = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< - ElementIO, ElementAccumulator, ElementComputeEpilogue, ElementIO, - ElementIO, + ElementOutput, ElementAccumulator, ElementComputeEpilogue, ElementOutput, + ElementOutput, // ElementScale, // element Vector elementsPerAccess, // false, @@ -299,8 +305,8 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, cutlass::multiplies>; using Gemm = cutlass::gemm::device::GemmUniversalWithBroadcast< - ElementIO, cutlass::layout::RowMajor, ElementIO, - cutlass::layout::ColumnMajor, ElementIO, cutlass::layout::RowMajor, + ElementInput, cutlass::layout::RowMajor, ElementInput, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, @@ -339,7 +345,7 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, } // int8 GEMM using CUTLASS -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, float* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf) { @@ -351,8 +357,7 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementInputA = int8_t; // <- data type of elements in input matrix A using ElementInputB = int8_t; // <- data type of elements in input matrix B - using ElementOutput = - int8_t; // <- data type of 
elements in output matrix Out + using ElementOutput = float; // <- data type of elements in output matrix Out // TODO: figure out why row major for matrix B doesn't work?!!! using LayoutInputA = cutlass::layout::RowMajor; @@ -954,15 +959,15 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, // Quantize matrix with single scale value template -__global__ void clipMatrix(T* output, const T* input, const float* factors, +__global__ void clipMatrix(T* output, const T* input, const float scale_factor, int height, int width) { int x = (blockIdx.x * blockDim.x + threadIdx.x); int y = blockIdx.y * blockDim.y + threadIdx.y; if (x >= width || y >= height) return; - float ulimit = 127.0 * factors[x]; - float llimit = -128.0 * factors[x]; + float ulimit = 127.0 * scale_factor; + float llimit = -128.0 * scale_factor; float val = (float)input[y * width + x]; if (val > ulimit) val = ulimit; if (val < llimit) val = llimit; @@ -971,13 +976,13 @@ __global__ void clipMatrix(T* output, const T* input, const float* factors, template void clipActivationMatrix(DataType* output, const DataType* input, - const float* factors, int height, int width, + const float scale_factor, int height, int width, cudaStream_t stream) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16), lczero::cudnn_backend::DivUp(height, 16)); - clipMatrix - <<>>(output, input, factors, height, width); + clipMatrix<<>>( + output, input, scale_factor, height, width); ReportCUDAErrors(cudaGetLastError()); } @@ -1048,63 +1053,6 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, ReportCUDAErrors(cudaGetLastError()); } -// process 8 elements per thread (in x dimension) -__global__ void deQuantizeMatrix(half* output, const int8_t* input, - const half* bias, int height, int width, - int stride, const float* invScale, - ActivationFunction act) { - int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; - int y = blockIdx.y * blockDim.y + threadIdx.y; - int b = blockIdx.z; - - if (x >= width || y >= height) return; - - int8_t ip[8] = {}; - half op[8] = {}; - half bi[8] = {}; - float inv_scale[8]; - - copyAs(&ip[0], &input[b * stride + y * width + x]); - if (bias) copyAs(&bi[0], &bias[b * width + x]); - - if (invScale) { - copyAs(&inv_scale[0], &invScale[b * width + x]); - copyAs(&inv_scale[4], &invScale[b * width + x + 4]); - } else { - for (int i = 0; i < 8; i++) inv_scale[i] = 1 / 127.0f; - } - - for (int i = 0; i < 8; i++) { - float val = (float)ip[i]; - val *= inv_scale[i]; - if (bias) val += (float)bi[i]; - op[i] = (half)activate(val, act); - } - - copyAs(&output[b * stride + y * width + x], &op[0]); -} - -// the scale (in CPU memory) is per "batch" -// the bias is per column, per batch -void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, - int height, int width, int batchSize, - float* invScale, const half* bias, - ActivationFunction act, - cudaStream_t stream) { - dim3 blockDim(16, 16); - dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), - lczero::cudnn_backend::DivUp(height, 16), batchSize); - - // otherwise we will need to put them in GPU memory - assert(batchSize < MAX_BATCH_DEQUANT); - - int stride = width * height; - - deQuantizeMatrix<<>>( - output, input, bias, height, width, stride, invScale, act); - ReportCUDAErrors(cudaGetLastError()); -} - void fillGpuArray(float* arr, float val, int count) { thrust::device_ptr dev_ptr(arr); thrust::fill(dev_ptr, dev_ptr + count, val); @@ -1122,10 +1070,10 @@ template void calibrateGemmForInt8( float* 
maxValuesOut, const half* A, const half* B, int M, int N, int K, int batchSize, int M_Batch); template void clipActivationMatrix(float* output, const float* input, - const float* factors, int height, + const float scale_factor, int height, int width, cudaStream_t stream); template void clipActivationMatrix(half* output, const half* input, - const float* factors, int height, + const float scale_factor, int height, int width, cudaStream_t stream); template void dumpTensor(const float* memory, int elements, @@ -1140,6 +1088,9 @@ template void dumpTensor(const int8_t* memory, int elements, const char* message, bool only_summary, bool cpu_tensor); +template void dumpTensor(const int32_t* memory, int elements, + const char* message, bool only_summary, + bool cpu_tensor); }; // namespace cudnn_backend }; // namespace lczero diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index cc49253ab9..ff2cc793d1 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -45,14 +45,14 @@ void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream); // Optimized kernel to add bias to innermost dimension // and perform optional activation (to be used with GEMMs/fully connected) -template -void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, +template +void addBiasBatched(T* output, const IT* input, const BT* bias, int Batch, int N, int C, ActivationFunction activation, cudaStream_t stream); // Optimized kernel to add bias to innermost dimension // and perform optional activation (to be used with GEMMs/fully connected) -template -void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N, +template +void addBiasBatched(T* output, const IT* input, const BT* bias, int Batch, int N, int C, int Nstride, ActivationFunction activation, cudaStream_t stream); @@ -135,8 +135,8 @@ template void Softmax(int N, int C, T* output, const T* input, const T* input2, cudaStream_t stream); -template -void LayerNorm(int N, int C, T* output, const T* input, const T* bias, +template +void LayerNorm(int N, int C, T* output, const IT* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, ActivationFunction act, cudaStream_t stream); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index ed8df5370a..f13365aa32 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1537,80 +1537,13 @@ AttentionPolicyHead::AttentionPolicyHead( void fillGpuArray(float* arr, float val, int count); -static int8_t* SetQuantizationData(MatMulQuantizationData& data, int8_t* w, - int InputCols, int OutputCols, int outBatch, - bool cali) { - size_t matrix_size = InputCols * OutputCols * sizeof(int8_t) * outBatch; - if (!cali) { - // Load weights for INT8 inference - - // (per-column) scaling factors for the input - ReportCUDAErrors( - cudaMalloc(&data.input_scaling_factors, sizeof(float) * InputCols)); - ReportCUDAErrors(cudaMemcpy(data.input_scaling_factors, w, - sizeof(float) * InputCols, - cudaMemcpyHostToDevice)); - - // go to weights - w += InputCols * sizeof(float); - ReportCUDAErrors(cudaMalloc(&data.weights_int8, matrix_size)); - ReportCUDAErrors( - cudaMemcpy(data.weights_int8, w, matrix_size, cudaMemcpyHostToDevice)); - - // go to output scaling factors - w += matrix_size; - ReportCUDAErrors(cudaMalloc(&data.output_scaling_factors, - sizeof(float) * OutputCols * outBatch)); - ReportCUDAErrors(cudaMemcpy(data.output_scaling_factors, w, - sizeof(float) * OutputCols * outBatch, - 
cudaMemcpyHostToDevice)); - // go to output dequantization factors - w += outBatch * OutputCols * sizeof(float); - data.output_deq_factors = (float*)w; - - // go to next item - w += outBatch * sizeof(float); - } else { - // Just save the pointers to CPU weights (we will over-write here during - // calibration) - data.input_scaling_factors = (float*)w; - w += InputCols * sizeof(float); - data.weights_int8 = w; - w += matrix_size; - data.output_scaling_factors = (float*)w; - w += outBatch * OutputCols * sizeof(float); - data.output_deq_factors = (float*)w; - w += outBatch * sizeof(float); - - // to keep track of max values in activation matrices - int InputMatrixSizeForBatch1 = 64 * InputCols * sizeof(float); - data.input_matrix_max_values = (float*)malloc(InputMatrixSizeForBatch1); - memset(data.input_matrix_max_values, 0, InputMatrixSizeForBatch1); - - int OutputMatrixSizeForBatch1 = 64 * OutputCols * sizeof(float) * outBatch; - data.output_matrix_max_values = (float*)malloc(OutputMatrixSizeForBatch1); - memset(data.output_matrix_max_values, 0, OutputMatrixSizeForBatch1); - } - - // return pointer to next item - return w; -} - -template -void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, - float* output_deq_factors, float* maxValuesA, - float* maxValuesOut, const DataType* A, - const DataType* B, int M, int N, int K, int batchSize, - int M_Batch); - -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int8_t* Out, +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, float* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf); void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, int8_t* Out, int M, + const float* scaleVector, float* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, int VecStride, float alphaf, float betaf); @@ -1625,7 +1558,7 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, template void clipActivationMatrix(DataType* output, const DataType* input, - const float* factors, int height, int weight, + const float scale_factor, int height, int width, cudaStream_t stream); void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, @@ -1642,7 +1575,6 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, int input_len, int output_len, const std::vector& weight_factors, const std::vector& input_factors, - const float rescale_factor, void* scratch, cudaStream_t stream) { // Load weights for INT8 inference @@ -1661,11 +1593,13 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, // Repeatedly fill values into the input factors buffer. fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); + // Same factor will be used for clipping inputs in fp16 mode. + data.fp16_clip_scale_factor = input_factors[0]; + // Repeatedly fill values into the output factors buffer. - data.accum_rescale_factor = 1.0 / rescale_factor; + data.accum_rescale_factor = input_factors[0] * weight_factors[0]; fillGpuArray(data.output_scaling_factors, - input_factors[0] * weight_factors[0] * rescale_factor, - output_len); + input_factors[0] * weight_factors[0], output_len); } // Load weights and run a GPU kernel to scale it. 
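// [Illustrative sketch, not patch content.] What the weight preparation above
// amounts to per tensor, assuming quantizeActivationMatrix does
// round-to-nearest with saturation: the int8 copy feeds the int8 path, and
// the fp16 weights are additionally clipped in place (clip only, no rounding)
// so the fp16 fallback saturates the same way as the quantized path.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static void prepare_weights_ref(std::vector<float>& w_fp16,
                                std::vector<int8_t>& w_int8, float s_w) {
  w_int8.resize(w_fp16.size());
  for (size_t i = 0; i < w_fp16.size(); i++) {
    float q = std::round(w_fp16[i] / s_w);
    w_int8[i] = (int8_t)std::max(-128.0f, std::min(127.0f, q));
    // fp16 path: clamp to the representable range, keep precision inside it.
    w_fp16[i] = std::max(-128.0f * s_w, std::min(127.0f * s_w, w_fp16[i]));
  }
}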
@@ -1676,9 +1610,8 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, weight_factors[0], stream); // The original weights also need to be clipped for fp16 inference. - fillGpuArray((float*)scratch, weight_factors[0], input_len); - clipActivationMatrix(weights, weights, (const float*)scratch, - output_len, input_len, stream); + clipActivationMatrix(weights, weights, weight_factors[0], 1, + weights_len, stream); } static void LoadQKVQuantizationData(MatMulQuantizationData& data, @@ -1688,7 +1621,6 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, const std::vector& k_weight_factors, const std::vector& v_weight_factors, const std::vector& input_factors, - const float rescale_factor, void* scratch, cudaStream_t stream) { // Load weights for INT8 inference. @@ -1707,17 +1639,16 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // Repeatedly fill values into the input factors buffer. fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); + // Same factor will be used for clipping inputs in fp16 mode. + data.fp16_clip_scale_factor = input_factors[0]; + // Repeatedly fill values into the output factors buffer. - data.accum_rescale_factor = 1.0 / rescale_factor; fillGpuArray(data.output_scaling_factors, - input_factors[0] * q_weight_factors[0] * rescale_factor, - output_len); + input_factors[0] * q_weight_factors[0], output_len); fillGpuArray(data.output_scaling_factors + output_len, - input_factors[0] * k_weight_factors[0] * rescale_factor, - output_len); + input_factors[0] * k_weight_factors[0], output_len); fillGpuArray(data.output_scaling_factors + output_len * 2, - input_factors[0] * v_weight_factors[0] * rescale_factor, - output_len); + input_factors[0] * v_weight_factors[0], output_len); } // Load QKV weights and run a GPU kernel to scale them. @@ -1735,21 +1666,18 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // The original weights also need to be clipped for fp16 inference // q weights. - fillGpuArray((float*)scratch, q_weight_factors[0], input_len); - clipActivationMatrix(qkv_weights, qkv_weights, (const float*)scratch, + clipActivationMatrix(qkv_weights, qkv_weights, q_weight_factors[0], output_len, input_len, stream); // k weights. - fillGpuArray((float*)scratch, k_weight_factors[0], input_len); clipActivationMatrix(qkv_weights + weights_len, - qkv_weights + weights_len, (const float*)scratch, + qkv_weights + weights_len, k_weight_factors[0], output_len, input_len, stream); // v weights. - fillGpuArray((float*)scratch, v_weight_factors[0], input_len); - clipActivationMatrix( - qkv_weights + weights_len * 2, qkv_weights + weights_len * 2, - (const float*)scratch, output_len, input_len, stream); + clipActivationMatrix(qkv_weights + weights_len * 2, + qkv_weights + weights_len * 2, v_weight_factors[0], + output_len, input_len, stream); } template @@ -1829,8 +1757,6 @@ EncoderBlock::EncoderBlock( allocAndUpload(&ln2_gammas, cpu_weights.ln2_gammas, scratch); allocAndUpload(&ln2_betas, cpu_weights.ln2_betas, scratch); - // printf("mhaqb: %i, mhaqw: %i, embedsize: %i\n", cpu_weights.mha.q_b.size(), - // cpu_weights.mha.q_w.size(), size); // Smolgen weights. 
if (has_smolgen_) { @@ -1865,67 +1791,36 @@ EncoderBlock::EncoderBlock( // int8 stuff blockIndex_ = blockIndex; if (int8_inf_ || is_quantized_) { - /* - int per_encoder_size = embedding_op_size_ * sizeof(float) + - 3 * embedding_op_size_ * mha_q_size_ + - 3 * sizeof(float) + mha_q_size_ * sizeof(float) + - embedding_op_size_ * mha_q_size_ + sizeof(float) + - embedding_op_size_ * sizeof(float) + - ffn_dense1_size_ * mha_q_size_ + sizeof(float) + - ffn_dense1_size_ * sizeof(float) + - embedding_op_size_ * ffn_dense1_size_ + - sizeof(float); - */ - /** - int embedding_op_size = embedding_op_size_; - int encoder_d_model = mha_q_size_; - int encoder_dff = ffn_dense1_size_; - int per_encoder_size = - (embedding_op_size * sizeof(float) + 3 * embedding_op_size * - encoder_d_model + 3 * (encoder_d_model + 1) * sizeof(float) + - encoder_d_model * sizeof(float) + encoder_d_model * - embedding_op_size + (embedding_op_size + 1) * sizeof(float) + - embedding_op_size * sizeof(float) + embedding_op_size * - encoder_dff + (encoder_dff + 1) * sizeof(float) + - encoder_dff * sizeof(float) + encoder_dff * - embedding_op_size + (embedding_op_size + 1) * sizeof(float)); - - auto w = (int8_t*)int8_weights; - // go to current encoder block - w += per_encoder_size * blockIndex; - w = SetQuantizationData(qkv_, w, embedding_op_size_, mha_q_size_, 3, - int8_calibrate); w = SetQuantizationData(mha_dense_, w, mha_q_size_, - embedding_op_size_, 1, int8_calibrate); w = SetQuantizationData(ffn1_, w, - embedding_op_size_, ffn_dense1_size_, 1, int8_calibrate); w = - SetQuantizationData(ffn2_, w, ffn_dense1_size_, embedding_op_size_, 1, - int8_calibrate); - // printf("\nSize of weights: %d\n", (w - (int8_t*)int8_weights)); - */ - - // CERR << "QKV input factor: " << cpu_weights.mha.s1[0]; LoadQKVQuantizationData(qkv_, (half*)mha_qkv_w, embedding_op_size_, mha_q_size_, cpu_weights.mha.q_s, cpu_weights.mha.k_s, cpu_weights.mha.v_s, - cpu_weights.mha.s1, 127.0, scratch, 0); + cpu_weights.mha.s1, 0); LoadQuantizationData(mha_dense_, (half*)mha_dense_w, embedding_op_size_, mha_dense_size_, cpu_weights.mha.dense_s, - cpu_weights.mha.s2, 127.0, scratch, 0); + cpu_weights.mha.s2, 0); LoadQuantizationData(ffn1_, (half*)ffn_dense1_w, embedding_op_size_, ffn_dense1_size_, cpu_weights.ffn.dense1_s, - cpu_weights.ffn.s1, 511.0, scratch, 0); + cpu_weights.ffn.s1, 0); LoadQuantizationData(ffn2_, (half*)ffn_dense2_w, ffn_dense1_size_, ffn_dense2_size_, cpu_weights.ffn.dense2_s, - cpu_weights.ffn.s2, 255.0, scratch, 0); - - // print some weights - /* - printf("\noutput scale first factor: %f, %f, %f, %f\n", - *qkv_.output_scaling_factors, *mha_dense_.output_scaling_factors, - *ffn1_.output_scaling_factors, *ffn2_.output_scaling_factors); - */ + cpu_weights.ffn.s2, 0); } } +// Int8 x Int8 -> Float32 +static void cublasXgemm(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + float alpha, const int8_t* A, int lda, const int8_t* B, + int ldb, float beta, float* C, int ldc) { + ReportCUBLASErrors(cublasGemmEx(handle, transa, transb, m, n, k, &alpha, A, + CUDA_R_8I /* int8 data type */, lda, B, + CUDA_R_8I /* Data type of B */, ldb, &beta, C, + CUDA_R_32F /* Data type of C */, ldc, + CUDA_R_32F, // Compute type FP32 + CUBLAS_GEMM_DEFAULT // Algorithm type + )); +} + template static void cublasXgemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, @@ -1946,14 +1841,24 @@ static void cublasXgemm(cublasHandle_t handle, cublasOperation_t transa, } } 
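// [Editorial sketch, not part of the patch] A self-contained CPU reference of
// the three-step int8 path that the encoder evaluation below walks through
// (1. quantize activations, 2. int8 GEMM with products accumulated in int32,
// 3. dequantize + bias add). All names here are illustrative; in the real code
// these steps run in the quantizeActivationMatrix, cutlass/cublas GEMM and
// deQuantizeOutputMatrixBiasAdd / LayerNorm kernels.
#include <cmath>
#include <cstdint>
#include <vector>

// out[r][c] = (sum_k qA[r][k] * qW[c][k]) * in_scale * w_scale + bias[c],
// with the weight matrix stored row-major as [cols x k] (i.e. B transposed),
// matching the CUBLAS_OP_T / "BTransposed" convention of the wrappers above.
static void int8_gemm_reference(const std::vector<float>& activations,   // rows x k
                                const std::vector<int8_t>& weights_int8, // cols x k
                                const std::vector<float>& bias,          // cols
                                std::vector<float>& output,              // rows x cols
                                int rows, int cols, int k,
                                float in_scale, float w_scale) {
  // 1. Quantize activations with the per-tensor input scale.
  std::vector<int8_t> q(activations.size());
  for (size_t i = 0; i < activations.size(); i++) {
    float v = std::round(activations[i] / in_scale);
    q[i] = static_cast<int8_t>(v > 127.f ? 127.f : (v < -128.f ? -128.f : v));
  }
  output.assign(static_cast<size_t>(rows) * cols, 0.0f);
  for (int r = 0; r < rows; r++)
    for (int c = 0; c < cols; c++) {
      // 2. Integer GEMM: int8 products accumulated in a 32-bit accumulator.
      int32_t acc = 0;
      for (int i = 0; i < k; i++)
        acc += static_cast<int32_t>(q[r * k + i]) *
               static_cast<int32_t>(weights_int8[c * k + i]);
      // 3. Dequantize with in_scale * w_scale, then add the fp bias.
      output[static_cast<size_t>(r) * cols + c] =
          static_cast<float>(acc) * in_scale * w_scale + bias[c];
    }
}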
-template +template static void cublasXGemmStridedBatched( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, float alpha, const void* A, int lda, long long int strideA, const void* B, int ldb, long long int strideB, float beta, void* C, int ldc, long long int strideC, int batchCount) { + const bool int8 = std::is_same::value; const bool fp16 = std::is_same::value; - if (fp16) { + const bool out_float = std::is_same::value; + if (int8 && out_float) { + // @TODO Gemm is failing. All zeros. + int8_t alpha_i = (int8_t)alpha; + OutDataType beta_i = (OutDataType)beta; + ReportCUBLASErrors(cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &alpha_i, A, CUDA_R_8I, lda, strideA, + B, CUDA_R_8I, ldb, strideB, &beta_i, C, CUDA_R_32F, ldc, strideC, + batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + } else if (fp16) { unsigned short alpha_h = FP32toFP16(alpha); unsigned short beta_h = FP32toFP16(beta); ReportCUBLASErrors(cublasGemmStridedBatchedEx( @@ -2068,13 +1973,6 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, DataType* mha_q; DataType* mha_k; DataType* mha_v; - - // dumpTensor(in_out_tensor, embedding_op_size_ * 64 * N, "input to mha_kqv - // gemm", true); dumpTensor(mha_qkv_w, embedding_op_size_ * d_model * 3, - // "weights to mha_kqv gemm", - // true); - // exit(0); - { const int num_inputs = embedding_op_size_; const int num_outputs = d_model; @@ -2089,38 +1987,30 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, mha_k = mha_q + num_outputs * batch_to_use; mha_v = mha_k + num_outputs * batch_to_use; - // if (int8_cali_) { - // calibrateGemmForInt8(qkv_.weights_int8, qkv_.input_scaling_factors, - // qkv_.output_scaling_factors, - // qkv_.output_deq_factors, - // qkv_.input_matrix_max_values, - // qkv_.output_matrix_max_values, in_out_tensor, - // mha_qkv_w, 64, d_model, embedding_op_size_, 3, N); - // } - if (is_quantized_ && int8_inf_) { - // printf("\nAttempting int8_inf\n"); // 1. quantize the inputs (in_out_tensor -> scratch) // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, qkv_.input_scaling_factors, stream); - // dumpTensor((int8_t*)scratch, num_inputs * batch / 4, - // "encoder 1 qkv input quantized", true); - // dumpTensor(qkv_.weights_int8, num_inputs * num_outputs * 3, - // "encoder 1 qkv weights quantized", true); // 2. perform int8 GEMM (scratch -> buffer1) - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, qkv_.weights_int8, (int8_t*)buffer1, batch, - num_outputs, num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, qkv_.accum_rescale_factor, 0.0f); - // dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, - // "encoder 1 qkv output quantized", true); - // dumpTensor(qkv_.output_scaling_factors, num_outputs, - // "Output scaling factors", true); - // // per-layer output scaling + (const int8_t*)scratch, qkv_.weights_int8, + qkv_.output_scaling_factors, (float*)buffer1, batch, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); + + // cublasXGemmStridedBatched( + // cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, + // 1.0f, qkv_.weights_int8, num_inputs, num_inputs * num_outputs, + // (const int8_t*)scratch, num_inputs, 0, 0.0f, (float*)buffer1, + // num_outputs, num_outputs * batch_to_use, 3); + + // 3. 
Bias add - mixed precision (buffer1 -> mha_q) + addBiasBatched(mha_q, (float*)buffer1, mha_qkv_b, 3, + batch, num_outputs, batch_to_use, + ACTIVATION_NONE, stream); /* cutlassMatrixMulBTransposed( @@ -2129,55 +2019,14 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 3, 0, num_inputs * num_outputs, num_outputs * batch_to_use, num_outputs, 1.0, 0.0f); */ - // dumpTensor((int8_t*)buffer1, num_outputs * batch / 4, - // "encoder 1 qkv output quantized-descaled", true); - // exit(0); ReportCUDAErrors(cudaGetLastError()); - /* - dumpTensor((const int8_t*)scratch, 768, "quantized input matrix", - false, false); - - dumpTensor(qkv_.weights_int8, 768, - "weights - during run", false, false); - - dumpTensor((const int8_t*)buffer1, 768, - "some quantized output values", false, false); - - dumpTensor(qkv_.output_scaling_factors, 768, - "output_scaling_factors - during run", - false, false); - */ - // 3. de-quantize outputs - fused with bias add (buffer1 -> scratch) - // TODO: fuse the entire thing with the above GEMM. - // deQuantizeOutputMatrixBiasAdd( - // (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, - // qkv_.output_scaling_factors, qkv_.output_deq_factors, - // (const half*)mha_qkv_b, ACTIVATION_NONE, stream); - deQuantizeOutputMatrixBiasAdd( - (half*)scratch, (const int8_t*)buffer1, batch, num_outputs, 3, - qkv_.output_scaling_factors, (const half*)mha_qkv_b, ACTIVATION_NONE, - stream); - - // dumpTensor((int8_t*)scratch, num_outputs * batch / 4, - // "encoder 1 qkv output no bias", true); - - // dumpTensor((half*)scratch, num_outputs * batch / 4, - // "encoder 1 qkv output dequant", true); - - // exit(0); - - /* - dumpTensor((const half*)scratch, 768, - "dequantized output values after bias add", false, - false); exit(0); - */ } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - qkv_.input_scaling_factors, batch, num_inputs, stream); + qkv_.fp16_clip_scale_factor, batch, num_inputs, stream); } cublasXGemmStridedBatched( @@ -2185,28 +2034,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, 1.0f, mha_qkv_w, num_inputs, num_inputs * num_outputs, in_out_tensor, num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, 3); -#if 0 - cutlassMatrixMulBTransposed((const half*)in_out_tensor, (const half*)mha_qkv_w, - (half*)mha_q, batch_to_use, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, true); -#endif - // dumpTensor((DataType*)mha_q, num_outputs * batch / 4, - // "encoder 1 qkv output no bias", true); - // exit(0); addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, ACTIVATION_NONE, stream); - // dumpTensor((DataType*)mha_q, num_outputs * batch / 4, - // "encoder 1 qkv output", true); - // exit(0); - - // dumpTensor((const DataType*)mha_q, - // /*num_outputs * batch_to_use*/ - // 768, "ref output values after - // bias add", false, false); - // exit(0); } } @@ -2319,15 +2150,6 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // #final dense layer (mha_dense), buffer2 -> buffer1 { - // if (int8_cali_) { - // calibrateGemmForInt8( - // mha_dense_.weights_int8, mha_dense_.input_scaling_factors, - // mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, - // mha_dense_.input_matrix_max_values, - // mha_dense_.output_matrix_max_values, buffer2, mha_dense_w, 64, - // embedding_op_size_, d_model, 1, N); - // } - const int num_inputs = d_model; const int num_outputs = 
embedding_op_size_; const int batch = N * 64; @@ -2341,76 +2163,39 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 2. perform int8 GEMM (scratch -> buffer2) cutlassMatrixMulBTransposed((const int8_t*)scratch, - mha_dense_.weights_int8, (int8_t*)buffer2, + mha_dense_.weights_int8, (float*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, mha_dense_.accum_rescale_factor, 0.0f); - /* - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, mha_dense_.weights_int8, - mha_dense_.output_scaling_factors, (int8_t*)buffer2, batch, - num_outputs, num_inputs, 1, 0, 0, 0, 0, 1.0f, 0.0f); - */ - ReportCUDAErrors(cudaGetLastError()); - // 3. de-quantize outputs (buffer2 -> buffer1) - // TODO: Fuse this with LN1 (should be easy!) - // deQuantizeOutputMatrixBiasAdd( - // (half*)buffer1, (const int8_t*)buffer2, batch, num_outputs, 1, - // mha_dense_.output_scaling_factors, mha_dense_.output_deq_factors, - // nullptr, ACTIVATION_NONE, stream); - - deQuantizeOutputMatrixBiasAdd( - (half*)buffer1, (const int8_t*)buffer2, batch, num_outputs, 1, - mha_dense_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); - // dumpTensor((DataType*)buffer1, num_outputs * batch / 4, - // "encoder 1 mha dense output dequant", true); - // exit(0); - - // dequantizeWithLayerNorm(N * 64, embedding_op_size_, scratch, - // buffer1, mha_dense_b, - // in_out_tensor, ln1_gammas, ln1_betas, default_eps_, - // alpha_, ACTIVATION_NONE, stream); - + // LN1: skip connection and layer normalization (also bias add of prev + // gemm) buffer2 -> scratch + LayerNorm(N * 64, embedding_op_size_, scratch, (float*)buffer2, + mha_dense_b, in_out_tensor, ln1_gammas, ln1_betas, + default_eps_, alpha_, ACTIVATION_NONE, stream); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)buffer2, (const DataType*)buffer2, - mha_dense_.input_scaling_factors, batch, num_inputs, stream); + mha_dense_.fp16_clip_scale_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_dense_w, num_inputs, buffer2, num_inputs, 0.0f, buffer1, num_outputs); - /* - cutlassMatrixMulBTransposed( - (const half*)buffer2, (const half*)mha_dense_w, (half*)buffer1, - batch, num_outputs, num_inputs, 1, 0, 0, 0, true); - */ - // dumpTensor((DataType*)buffer1, num_outputs * batch / 4, - // "encoder 1 mha dense output", true); - // exit(0); + + // LN1: skip connection and layer normalization (also bias add of prev + // gemm) buffer1/in_out_tensor -> scratch + LayerNorm(N * 64, embedding_op_size_, scratch, buffer1, + mha_dense_b, in_out_tensor, ln1_gammas, ln1_betas, + default_eps_, alpha_, ACTIVATION_NONE, stream); } } - // LN1: skip connection and layer normalization (also bias add of prev gemm) - // buffer1/in_out_tensor -> scratch - LayerNorm(N * 64, embedding_op_size_, scratch, buffer1, mha_dense_b, - in_out_tensor, ln1_gammas, ln1_betas, default_eps_, - alpha_, ACTIVATION_NONE, stream); - // #FFN dense 1, scratch -> in_out_tensor const int encoder_dff = ffn_dense1_size_; { - // if (int8_cali_) { - // calibrateGemmForInt8( - // ffn1_.weights_int8, ffn1_.input_scaling_factors, - // ffn1_.output_scaling_factors, ffn1_.output_deq_factors, - // ffn1_.input_matrix_max_values, ffn1_.output_matrix_max_values, - // scratch, ffn_dense1_w, 64, encoder_dff, embedding_op_size_, 1, N); - // } - const int num_inputs = embedding_op_size_; const int num_outputs = ffn_dense1_size_; // encoder_dff const int batch = N * 64; @@ -2423,94 +2208,40 @@ void EncoderBlock::Eval(int N, 
DataType* in_out_tensor, stream); // 2. perform int8 GEMM (in_out_tensor -> buffer1) - /* - cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, - ffn1_.weights_int8, (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, - 0, 0, 0, ffn1_.accum_rescale_factor, 0.0f); - */ cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, - ffn1_.weights_int8, (int8_t*)buffer1, batch, - num_outputs, num_inputs, 1, 0, 0, 0, + ffn1_.weights_int8, (float*)buffer1, + batch, num_outputs, num_inputs, 1, 0, 0, 0, ffn1_.accum_rescale_factor, 0.0f); - // dumpTensor((const int8_t*)buffer1, num_outputs * batch / 4, - // "encoder 1 ffn1 output quant", true); - // ReportCUDAErrors(cudaGetLastError()); - ReportCUDAErrors(cudaGetLastError()); - /* - dumpTensor(ffn1_.input_scaling_factors, 768, - "input_scaling_factors - during run", false, false); - - dumpTensor((const int8_t*)in_out_tensor, 768, "quantized input - matrix", false, false); - - dumpTensor(ffn1_.weights_int8, 768, - "weights - during run", false, false); - - dumpTensor((const int8_t*)buffer1, 768, - "some quantized output values", false, false); - - dumpTensor(ffn1_.output_scaling_factors, 768, - "output_scaling_factors - during run", - false, false); - */ - // 3. de-quantize outputs - fused with bias add (buffer1 -> in_out_tensor) - // TODO: Fuse this with the above GEMM - deQuantizeOutputMatrixBiasAdd( - (half*)in_out_tensor, (const int8_t*)buffer1, batch, num_outputs, 1, - ffn1_.output_scaling_factors, (const half*)ffn_dense1_b, - ffn_activation_, stream); - - // dumpTensor((DataType*)in_out_tensor, num_outputs * batch / 4, - // "encoder 1 ffn1 dense output dequant"); - // exit(0); - - // Ankan - test! - // dumpTensor((const DataType*)in_out_tensor, 768, - // "runtime output values after bias and RELU2", - // false, false); - // exit(0); + + // 3. Bias add - mixed precision (buffer1 -> in_out_tensor) + addBiasBatched(in_out_tensor, (float*)buffer1, ffn_dense1_b, 1, batch, + num_outputs, ffn_activation_, stream); } else { if (is_quantized_ && clipInputActivations) { + // Note `scratch` should not be changed as it is the FFN input to be used + // as skip connection later at the layer norm. clipActivationMatrix( - (DataType*)scratch, (const DataType*)scratch, - ffn1_.input_scaling_factors, batch, num_inputs, stream); + (DataType*)in_out_tensor, (const DataType*)scratch, + ffn1_.fp16_clip_scale_factor, batch, num_inputs, stream); + + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, + in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); + } else { + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, + scratch, num_inputs, 0.0f, buffer1, num_outputs); } - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, - scratch, num_inputs, 0.0f, in_out_tensor, num_outputs); - /* - cutlassMatrixMulBTransposed( - (const half*)scratch, (const half*)ffn_dense1_w, (half*)in_out_tensor, - batch, num_outputs, num_inputs, 1, 0, 0, 0, true); - */ - addBiasBatched(in_out_tensor, in_out_tensor, ffn_dense1_b, 1, batch, + addBiasBatched(in_out_tensor, buffer1, ffn_dense1_b, 1, batch, num_outputs, ffn_activation_, stream); - // dumpTensor((DataType*)in_out_tensor, num_outputs * batch / 4, - // "encoder 1 ffn1 dense output"); - // exit(0); - - // Ankan - test! 
- // dumpTensor((const DataType*)in_out_tensor, 768, - // "Ref output values after bias and RELU2", false, - // false); - // exit(0); } } // #FFN dense 2, in_out_tensor -> buffer1 { - // if (int8_cali_) { - // calibrateGemmForInt8( - // ffn2_.weights_int8, ffn2_.input_scaling_factors, - // ffn2_.output_scaling_factors, ffn2_.output_deq_factors, - // ffn2_.input_matrix_max_values, ffn2_.output_matrix_max_values, - // in_out_tensor, ffn_dense2_w, 64, embedding_op_size_, encoder_dff, - // 1, N); - // } - const int num_inputs = ffn_dense1_size_; // encoder_dff const int num_outputs = embedding_op_size_; const int batch = N * 64; @@ -2521,80 +2252,42 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, batch, num_inputs, ffn2_.input_scaling_factors, stream); - /* - dumpTensor((const float*)ffn2_.input_scaling_factors, 768, - "input scaling factors during run", - false, false); - - dumpTensor((const int8_t*)buffer1, 768, "quantized input matrix", - false, false); - dumpTensor(ffn2_.weights_int8, 768, "weights - during run", false, - false); - */ - // 2. perform int8 GEMM (buffer1 -> in_out_tensor) + // 2. perform int8 GEMM (buffer1 -> buffer2) cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, - (int8_t*)in_out_tensor, batch, num_outputs, + (float*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, ffn2_.accum_rescale_factor, 0.0f); - // dumpTensor((const int8_t*)in_out_tensor, num_outputs * batch / - // 4, - // "encoder 1 ffn2 output quant", true); ReportCUDAErrors(cudaGetLastError()); - /* - dumpTensor((const int8_t*)in_out_tensor, 768, - "some quantized output values", false, false); - - dumpTensor(ffn2_.output_scaling_factors, 768, - "output_scaling_factors - during run", false, false); - */ - // 3. 
de-quantize outputs (in_out_tensor -> buffer1) - // TODO: Fuse this with LN2 (should be easy) - deQuantizeOutputMatrixBiasAdd( - (half*)buffer1, (const int8_t*)in_out_tensor, batch, num_outputs, 1, - ffn2_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); - - // dumpTensor((const DataType*)buffer1, num_outputs * batch / 4, - // "encoder 1 ffn2 output dequant"); - // exit(0); + + // LN2: skip connection and layer normilization (also bias add of prev + // gemm) buffer2/scratch -> in_out_tensor + LayerNorm(N * 64, embedding_op_size_, in_out_tensor, + (float*)buffer2, ffn_dense2_b, scratch, ln2_gammas, + ln2_betas, default_eps_, alpha_, ACTIVATION_NONE, + stream); + } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - ffn2_.input_scaling_factors, batch, num_inputs, stream); + ffn2_.fp16_clip_scale_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); - /* - cutlassMatrixMulBTransposed( - (const half*)in_out_tensor, (const half*)ffn_dense2_w, (half*)buffer1, - batch, num_outputs, num_inputs, 1, 0, 0, 0, true); - */ - // dumpTensor((DataType*)buffer1, num_outputs * batch / 4, - // "encoder 1 ffn2 output"); - // exit(0); - - /* - dumpTensor((const half*)buffer1, 768, "dequantized output values - - ref", false, false); + // LN2: skip connection and layer normilization (also bias add of prev + // gemm) buffer1/scratch -> in_out_tensor + LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, + ffn_dense2_b, scratch, ln2_gammas, ln2_betas, + default_eps_, alpha_, ACTIVATION_NONE, stream); - exit(0); - */ + // dumpTensor((const DataType*)in_out_tensor, num_outputs * 64, + // "encoder 1 ffn2 LN output", false); + // exit(0); } } - - // LN2: skip connection and layer normilization (also bias add of prev gemm) - // buffer1/scratch -> in_out_tensor - LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, - ffn_dense2_b, scratch, ln2_gammas, ln2_betas, - default_eps_, alpha_, ACTIVATION_NONE, stream); - - // printf("\nencoder %i: batchsize", blockIndex_, N); - // dumpTensor((const half*)in_out_tensor, embedding_op_size_ * 64, "ln2 output", - // true); - // exit(0); } template @@ -2840,13 +2533,11 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, // Quantization data for input embedding FFN layers. 
LoadQuantizationData(emb_ffn1_, (half*)ip_emb_ffn_d1_w_, - embedding_dense_size_, embedding_ffn_dff_, - weights.ip_emb_ffn.dense1_s, weights.ip_emb_ffn.s1, - 127.0, scratch, 0); + embedding_ffn_size_, embedding_ffn_dff_, + weights.ip_emb_ffn.dense1_s, weights.ip_emb_ffn.s1, 0); LoadQuantizationData(emb_ffn2_, (half*)ip_emb_ffn_d2_w_, embedding_ffn_dff_, embedding_ffn_size_, weights.ip_emb_ffn.dense2_s, - weights.ip_emb_ffn.s2, 127.0, scratch, 0); - + weights.ip_emb_ffn.s2, 0); } else { size_t size = 64 * kNumPosEncodingChannels * sizeof(float); ReportCUDAErrors(cudaMalloc(&pos_encoding_, size)); @@ -2854,13 +2545,6 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, cudaMemcpy(scratch, kPosEncoding, size, cudaMemcpyHostToDevice)); copyTypeConverted(pos_encoding_, (float*)scratch, size, 0); } - printf("Is PE Dense Embedding %i\n", is_pe_dense_embedding); - - printf("has_gating: %i, has_smolgen: %i\n", has_gating_, has_smolgen_); - printf("ip_mult_gate: %i\n", weights.ip_mult_gate.size()); - // for (auto i=0; i(&ip_mult_gate_, weights.ip_mult_gate, scratch); @@ -3021,23 +2705,27 @@ void AttentionBody::Eval(int N, DataType* output, num_inputs, emb_ffn1_.input_scaling_factors, stream); - // 2. int8 matmul (temp -> buffer1) + // 2. int8 matmul (temp -> buffer2) cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn1_.weights_int8, - (int8_t*)buffer1, batch, num_outputs, + (float*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, emb_ffn1_.accum_rescale_factor, 0.0f); + + // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, emb_ffn1_.accum_rescale_factor, + // emb_ffn1_.weights_int8, num_inputs, (int8_t*)temp, + // num_inputs, 0.0f, (float*)buffer2, num_outputs); ReportCUDAErrors(cudaGetLastError()); - // 3. dequantize + bias add (buffer1 -> buffer1) - deQuantizeOutputMatrixBiasAdd( - (half*)buffer1, (const int8_t*)buffer1, batch, num_outputs, 1, - emb_ffn1_.output_scaling_factors, (const half*)ip_emb_ffn_d1_b_, - activations_.ffn_activation, stream); + // 3. Bias add (mixed precision) (buffer2 -> buffer1) + addBiasBatched(buffer1, (float*)buffer2, ip_emb_ffn_d1_b_, 1, batch, + num_outputs, activations_.ffn_activation, stream); + } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)temp, (const DataType*)embedding, - emb_ffn1_.input_scaling_factors, batch, num_inputs, stream); + emb_ffn1_.fp16_clip_scale_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, @@ -3057,34 +2745,36 @@ void AttentionBody::Eval(int N, DataType* output, quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, num_inputs, emb_ffn2_.input_scaling_factors, stream); - - // 2. int8 matmul (temp -> buffer1) + // 2. int8 matmul (temp -> buffer2) cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, - (int8_t*)buffer1, batch, num_outputs, + (float*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, emb_ffn2_.accum_rescale_factor, 0.0f); - // 3. 
dequantize + bias add (buffer1 -> buffer2) - deQuantizeOutputMatrixBiasAdd( - (half*)buffer2, (const int8_t*)buffer1, batch, num_outputs, 1, - emb_ffn2_.output_scaling_factors, nullptr, ACTIVATION_NONE, stream); + // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, emb_ffn2_.accum_rescale_factor, + // emb_ffn2_.weights_int8, num_inputs, (const int8_t*)temp, + // num_inputs, 0.0f, (float*)buffer2, num_outputs); + // Embedding LN: skip connection and layer normalization (also bias add + // of prev gemm) (buffer2 -> embedding/output_tensor) float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); - LayerNorm(N * 64, embedding_ffn_size_, output_tensor, buffer2, - ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, - ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, - stream); + LayerNorm(N * 64, embedding_ffn_size_, output_tensor, + (float*)buffer2, ip_emb_ffn_d2_b_, embedding, + ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, + ACTIVATION_NONE, stream); + } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)buffer1, (const DataType*)buffer1, - emb_ffn2_.input_scaling_factors, batch, num_inputs, stream); + emb_ffn2_.fp16_clip_scale_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d2_w_, - num_inputs, buffer1, num_inputs, 0.0f, buffer2, - num_outputs); + num_inputs, 1.0f, ip_emb_ffn_d2_w_, num_inputs, buffer1, + num_inputs, 0.0f, buffer2, num_outputs); + // Embedding LN: skip connection and layer normilization (also bias add // of prev gemm) buffer2 -> embedding float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); @@ -3117,12 +2807,11 @@ void AttentionBody::Eval(int N, DataType* output, } // 2. 
Encoder blocks - // for (const auto enc : encoder_weights_) { - // enc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, - // stream, offset_pointers); - // // dumpTensor(output_tensor, embedding_op_size_, "encoder 1 output"); - // // if (i++ == 10) break; - // } // End of encoder blocks + int i = 0; + for (const auto enc : encoder_weights_) { + enc->Eval(N, output_tensor, (DataType*)scratch, buffer1, buffer2, cublas, + stream, offset_pointers); + } // End of encoder blocks } template diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 372407f766..dcb10bb8e5 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -345,6 +345,7 @@ struct MatMulQuantizationData { float* input_matrix_max_values; // max values of input matrix (always in CPU memory) float* output_matrix_max_values; // max values in output matrix (always in CPU memory) float accum_rescale_factor; // accumulator rescale factor for matmuls to prevent overflow + float fp16_clip_scale_factor; // scale factor for clipping inputs in fp16 or fp32 mode }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 057ea93ea3..003bbb8357 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -46,11 +46,13 @@ namespace lczero { using namespace cudnn_backend; +#if 0 namespace cudnn_backend { template void dumpTensor(const T* memory, int elements, const char* message, bool only_summary = false, bool cpu_tensor = false); } +#endif template class CudaNetwork; @@ -394,25 +396,6 @@ class CudaNetwork : public Network { throw Exception("INT8 only supported for attention body networks"); } -#if 0 - // Ankan - test: dump some weights here - float* data = (float*)int8_weights_; - dumpTensor(data, weights.ip_emb_b.size(), - "per-channel scaling factors for input", - false, true); - - int8_t* w = (int8_t*)int8_weights_; - w += weights.ip_emb_b.size() * sizeof(float); - dumpTensor(w, 512, "quantized weights", - false, true); - - w += 3 * weights.ip_emb_b.size() * weights.encoder[0].mha.q_b.size() * - sizeof(int8_t); - dumpTensor((float*)w, 768, "scaling factors for output", false, true); - - exit(0); -#endif - // 2. Build the network, and copy the weights to GPU memory. 
// Input conv only used if there are residual blocks in the network @@ -800,7 +783,6 @@ class CudaNetwork : public Network { scratch_mem, scratch_size_, nullptr, cublas, stream); // policy map layer // POLICY output } - dumpTensor(opPol, 1858, "Output policy", false); } else if (conv_policy_) { network_[l++]->Eval(batchSize, spare1, flow, nullptr, scratch_mem, scratch_size_, nullptr, cublas, From 09fcd568d797b2249be8e2643ebcab17381c9dc7 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Sun, 12 May 2024 04:45:51 +0200 Subject: [PATCH 61/70] Change gemms to int32 - wip --- src/neural/cuda/common_kernels.cu | 43 +++++-- src/neural/cuda/cutlass_kernels.cu | 48 ++++---- src/neural/cuda/kernels.h | 3 +- src/neural/cuda/layers.cc | 190 ++++++++++++++++------------- src/neural/cuda/layers.h | 3 +- 5 files changed, 161 insertions(+), 126 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index bcbd85b30b..cc8175e461 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -970,7 +970,7 @@ template __global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, float alpha, - ActivationFunction act) { + ActivationFunction act, float dequant_scale) { int n = blockIdx.x * blockDim.z + threadIdx.z; if (n >= N) return; int c = (threadIdx.y * 32 + threadIdx.x) * 16; @@ -985,7 +985,16 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, const bool fp16 = std::is_same::value; if (!oobThread) { // Load from memory (16 elements a time) - if (std::is_same::value) { + if (std::is_same::value) { + int32_t inp[8]; + copyAs(&inp[0], &input[tensorIndex]); + copyAs(&inp[4], &input[tensorIndex + 4]); + for (int i = 0; i < 8; i++) val[i] = (float)inp[i] * dequant_scale; + + copyAs(&inp[0], &input[tensorIndex + 8]); + copyAs(&inp[4], &input[tensorIndex + 16]); + for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i] * dequant_scale; + } else if (std::is_same::value) { half inp[8]; copyAs(&inp[0], &input[tensorIndex]); for (int i = 0; i < 8; i++) val[i] = (float)inp[i]; @@ -1348,7 +1357,8 @@ __global__ void layer_norm_kernel_8_el_per_thread( template void LayerNorm(int N, int C, T* output, const IT* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, - float alpha, ActivationFunction act, cudaStream_t stream) { + float alpha, ActivationFunction act, cudaStream_t stream, + float dequant_scale) { // process 4 elements per thread to achieve close to peak memory bandwidth if (C % 16 != 0) throw Exception("unsupported filter size"); if (C > 16384) throw Exception("unsupported filter size"); @@ -1363,7 +1373,8 @@ void LayerNorm(int N, int C, T* output, const IT* input, const T* bias, gridDim.z = 1; layer_norm_kernel<<>>( - N, C, output, input, bias, skip, gammas, betas, ep, alpha, act); + N, C, output, input, bias, skip, gammas, betas, ep, alpha, act, + dequant_scale); ReportCUDAErrors(cudaGetLastError()); } @@ -1784,22 +1795,30 @@ template void Softmax(int N, int C, half* output, const half* input, template void Softmax(int N, int C, float* output, const float* input, const float* input2, cudaStream_t stream); -template void LayerNorm(int N, int C, half* output, - const float* input, const half* bias, - const half* skip, const half* gammas, - const half* betas, float ep, float alpha, - ActivationFunction act, - cudaStream_t stream); +template void LayerNorm(int N, int C, half* output, + const int32_t* input, const 
half* bias, + const half* skip, const half* gammas, + const half* betas, float ep, float alpha, + ActivationFunction act, + cudaStream_t stream, + float dequant_scale); +template void LayerNorm(int N, int C, float* output, + const int32_t* input, const float* bias, + const float* skip, const float* gammas, + const float* betas, float ep, + float alpha, ActivationFunction act, + cudaStream_t stream, + float dequant_scale); template void LayerNorm(int N, int C, half* output, const half* input, const half* bias, const half* skip, const half* gammas, const half* betas, float ep, float alpha, ActivationFunction act, - cudaStream_t stream); + cudaStream_t stream, float dequant_scale); template void LayerNorm(int N, int C, float* output, const float* input, const float* bias, const float* skip, const float* gammas, const float* betas, float ep, float alpha, ActivationFunction act, - cudaStream_t stream); + cudaStream_t stream, float dequant_scale); template void ComputePromotionLogits(int N, int C, half* output, const half* keys, const half* ppo, diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 4cec5b01d5..3b16431307 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -229,9 +229,9 @@ void dumpTensor(const T* memory, int elements, const char* message, } if (!only_summary || i < 3 || i == elements - 1) { - if (int8) { + if (int8 || int32) { // printf("%6i ", (int8_t)val); - printf("%i;%6i\n", i, (int8_t)val); + printf("%i;%8i\n", i, (int)val); } else { // printf("%8.6f ", val); printf("%i;%8.6f\n", i, val); @@ -280,14 +280,14 @@ void dumpTensor(const T* memory, int elements, const char* message, // int8 GEMM using CUTLASS (with per-column output quantization) void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, float* Out, int M, + const float* scaleVector, int32_t* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, int VecStride, float alphaf, float betaf) { using ElementAccumulator = int32_t; using ElementComputeEpilogue = float; using ElementInput = int8_t; - using ElementOutput = float; + using ElementOutput = int32_t; using ElementScale = float; using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; @@ -296,8 +296,8 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< - ElementOutput, ElementAccumulator, ElementComputeEpilogue, ElementOutput, - ElementOutput, + ElementOutput, ElementAccumulator, ElementComputeEpilogue, + ElementOutput, ElementOutput, // ElementScale, // element Vector elementsPerAccess, // false, @@ -345,7 +345,7 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, } // int8 GEMM using CUTLASS -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, float* Out, +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int32_t* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf) { @@ -357,7 +357,8 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, float* Out, using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementInputA = int8_t; // <- data type of elements in input matrix A using ElementInputB = int8_t; // <- data type of elements in input matrix B - using ElementOutput = float; // <- data type of elements in output matrix Out + using 
ElementOutput = + int32_t; // <- data type of elements in gemm output matrix Out // TODO: figure out why row major for matrix B doesn't work?!!! using LayoutInputA = cutlass::layout::RowMajor; @@ -993,35 +994,35 @@ struct ScaleParam { }; // process 8 elements per thread (in x dimension) -__global__ void deQuantizeMatrix(half* output, const int8_t* input, +__global__ void deQuantizeMatrix(half* output, const int32_t* input, const half* bias, int height, int width, - int stride, const float *invScale, ScaleParam deq, - ActivationFunction act) { + int stride, const float* invScaleArr, + const float invScale, ActivationFunction act) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; int b = blockIdx.z; if (x >= width || y >= height) return; - int8_t ip[8] = {}; + int32_t ip[8] = {}; half op[8] = {}; half bi[8] = {}; float inv_scale[8]; - float deq_scale = deq.scale[b]; - copyAs(&ip[0], &input[b * stride + y * width + x]); + copyAs(&ip[0], &input[b * stride + y * width + x]); + copyAs(&ip[4], &input[b * stride + y * width + x + 4]); if (bias) copyAs(&bi[0], &bias[b * width + x]); - if (invScale) { - copyAs(&inv_scale[0], &invScale[b * width + x]); - copyAs(&inv_scale[4], &invScale[b * width + x + 4]); + if (invScaleArr) { + copyAs(&inv_scale[0], &invScaleArr[b * width + x]); + copyAs(&inv_scale[4], &invScaleArr[b * width + x + 4]); } else { - for (int i = 0; i < 8; i++) inv_scale[i] = 1 / 127.0f; + for (int i = 0; i < 8; i++) inv_scale[i] = invScale; } for (int i = 0; i < 8; i++) { float val = (float)ip[i]; - val *= (deq_scale / inv_scale[i]); + val *= inv_scale[i]; if (bias) val += (float)bi[i]; op[i] = (half)activate(val, act); } @@ -1031,9 +1032,9 @@ __global__ void deQuantizeMatrix(half* output, const int8_t* input, // the scale (in CPU memory) is per "batch" // the bias is per column, per batch -void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, +void deQuantizeOutputMatrixBiasAdd(half* output, const int32_t* input, int height, int width, int batchSize, - float* invScale, float* deq, + float* invScaleArr, float invScale, const half* bias, ActivationFunction act, cudaStream_t stream) { dim3 blockDim(16, 16); @@ -1045,11 +1046,8 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, int stride = width * height; - ScaleParam s = {}; - for (int i = 0; i < batchSize; i++) s.scale[i] = deq[i]; - deQuantizeMatrix<<>>( - output, input, bias, height, width, stride, invScale, s, act); + output, input, bias, height, width, stride, invScaleArr, invScale, act); ReportCUDAErrors(cudaGetLastError()); } diff --git a/src/neural/cuda/kernels.h b/src/neural/cuda/kernels.h index ff2cc793d1..b5fbf8e43a 100644 --- a/src/neural/cuda/kernels.h +++ b/src/neural/cuda/kernels.h @@ -138,7 +138,8 @@ void Softmax(int N, int C, T* output, const T* input, const T* input2, template void LayerNorm(int N, int C, T* output, const IT* input, const T* bias, const T* skip, const T* gammas, const T* betas, float ep, - float alpha, ActivationFunction act, cudaStream_t stream); + float alpha, ActivationFunction act, + cudaStream_t stream, float dequant_scale = 1.0); template void ComputePromotionLogits(int N, int C, T* output, const T* keys, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index f13365aa32..351843a172 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1537,13 +1537,13 @@ AttentionPolicyHead::AttentionPolicyHead( void fillGpuArray(float* arr, float val, int count); -void 
cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, float* Out, +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int32_t* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf); void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, float* Out, int M, + const float* scaleVector, int32_t* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, int VecStride, float alphaf, float betaf); @@ -1561,28 +1561,23 @@ void clipActivationMatrix(DataType* output, const DataType* input, const float scale_factor, int height, int width, cudaStream_t stream); -void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, +void deQuantizeOutputMatrixBiasAdd(half* output, const int32_t* input, int height, int width, int batchSize, - float* scale, float* deq, const half* bias, - ActivationFunction act, cudaStream_t stream); - -void deQuantizeOutputMatrixBiasAdd(half* output, const int8_t* input, - int height, int width, int batchSize, - float* scale, const half* bias, - ActivationFunction act, cudaStream_t stream); + float* scaleArr, float scale, + const half* bias, ActivationFunction act, + cudaStream_t stream); static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, int input_len, int output_len, const std::vector& weight_factors, const std::vector& input_factors, - cudaStream_t stream) { + cudaStream_t stream, + bool vector_output_factor = false) { // Load weights for INT8 inference // (per-column) scaling factors for the input and output. ReportCUDAErrors( cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); - ReportCUDAErrors( - cudaMalloc(&data.output_scaling_factors, output_len * sizeof(float))); if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( @@ -1597,9 +1592,13 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, data.fp16_clip_scale_factor = input_factors[0]; // Repeatedly fill values into the output factors buffer. - data.accum_rescale_factor = input_factors[0] * weight_factors[0]; - fillGpuArray(data.output_scaling_factors, - input_factors[0] * weight_factors[0], output_len); + data.output_rescale_factor = input_factors[0] * weight_factors[0]; + if (vector_output_factor) { + ReportCUDAErrors( + cudaMalloc(&data.output_scaling_factors, output_len * sizeof(float))); + fillGpuArray(data.output_scaling_factors, data.output_rescale_factor, + output_len); + } } // Load weights and run a GPU kernel to scale it. @@ -1807,16 +1806,16 @@ EncoderBlock::EncoderBlock( } } -// Int8 x Int8 -> Float32 +// Int8 x Int8 -> Int32 static void cublasXgemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - float alpha, const int8_t* A, int lda, const int8_t* B, - int ldb, float beta, float* C, int ldc) { + int alpha, const int8_t* A, int lda, const int8_t* B, + int ldb, int beta, int32_t* C, int ldc) { ReportCUBLASErrors(cublasGemmEx(handle, transa, transb, m, n, k, &alpha, A, CUDA_R_8I /* int8 data type */, lda, B, CUDA_R_8I /* Data type of B */, ldb, &beta, C, - CUDA_R_32F /* Data type of C */, ldc, - CUDA_R_32F, // Compute type FP32 + CUDA_R_32I /* Data type of C */, ldc, + CUBLAS_COMPUTE_32I, // Compute type Int32 CUBLAS_GEMM_DEFAULT // Algorithm type )); } @@ -1995,30 +1994,31 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, qkv_.input_scaling_factors, stream); // 2. 
perform int8 GEMM (scratch -> buffer1) - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, qkv_.weights_int8, - qkv_.output_scaling_factors, (float*)buffer1, batch, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, num_outputs, 1.0f, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, + (int32_t*)buffer1, batch, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs, 1.0f, 0.0f); + // dumpTensor((float*)qkv_.output_scaling_factors, 5, + // "encoder 1 qkv output dequant", true); + // dumpTensor((int32_t*)buffer1, 5, "encoder 1 qkv output dequant", true); + // exit(0); + + // 3. Dequantize and bias add - mixed precision (buffer1 -> mha_q) + deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch, + num_outputs, 3, qkv_.output_scaling_factors, + 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, + stream); // cublasXGemmStridedBatched( - // cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, - // 1.0f, qkv_.weights_int8, num_inputs, num_inputs * num_outputs, - // (const int8_t*)scratch, num_inputs, 0, 0.0f, (float*)buffer1, - // num_outputs, num_outputs * batch_to_use, 3); + // cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, 1.0f, qkv_.weights_int8, num_inputs, num_inputs * + // num_outputs, (const int8_t*)scratch, num_inputs, 0, 0.0f, + // (float*)buffer1, num_outputs, num_outputs * batch_to_use, 3); // 3. Bias add - mixed precision (buffer1 -> mha_q) - addBiasBatched(mha_q, (float*)buffer1, mha_qkv_b, 3, - batch, num_outputs, batch_to_use, - ACTIVATION_NONE, stream); - - /* - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, qkv_.weights_int8, - qkv_.output_scaling_factors, (int8_t*)buffer1, batch, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs * batch_to_use, num_outputs, 1.0, 0.0f); - */ + // addBiasBatched(mha_q, (float*)buffer1, mha_qkv_b, 3, + // batch, num_outputs, batch_to_use, + // ACTIVATION_NONE, stream); ReportCUDAErrors(cudaGetLastError()); @@ -2037,8 +2037,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, ACTIVATION_NONE, stream); - } + // dumpTensor(mha_q, num_outputs * 64 * 3, "encoder 1 qkv output", true); + // exit(0); } // Apply split_heads() to q, k and v @@ -2162,18 +2163,18 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, stream); // 2. perform int8 GEMM (scratch -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)scratch, - mha_dense_.weights_int8, (float*)buffer2, - batch, num_outputs, num_inputs, 1, 0, 0, 0, - mha_dense_.accum_rescale_factor, 0.0f); + cutlassMatrixMulBTransposed( + (const int8_t*)scratch, mha_dense_.weights_int8, (int32_t*)buffer2, + batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); ReportCUDAErrors(cudaGetLastError()); // LN1: skip connection and layer normalization (also bias add of prev // gemm) buffer2 -> scratch - LayerNorm(N * 64, embedding_op_size_, scratch, (float*)buffer2, - mha_dense_b, in_out_tensor, ln1_gammas, ln1_betas, - default_eps_, alpha_, ACTIVATION_NONE, stream); + LayerNorm( + N * 64, embedding_op_size_, scratch, (int32_t*)buffer2, mha_dense_b, + in_out_tensor, ln1_gammas, ln1_betas, default_eps_, alpha_, + ACTIVATION_NONE, stream, mha_dense_.output_rescale_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( @@ -2208,14 +2209,19 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, stream); // 2. 
perform int8 GEMM (in_out_tensor -> buffer1) - cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, - ffn1_.weights_int8, (float*)buffer1, - batch, num_outputs, num_inputs, 1, 0, 0, 0, - ffn1_.accum_rescale_factor, 0.0f); - + cutlassMatrixMulBTransposed( + (const int8_t*)in_out_tensor, ffn1_.weights_int8, (int32_t*)buffer1, + batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); + + // 3. Dequantize and bias add - mixed precision (buffer1 -> in_out_tensor) + deQuantizeOutputMatrixBiasAdd( + (half*)in_out_tensor, (int32_t*)buffer1, batch, num_outputs, 1, + nullptr, ffn1_.output_rescale_factor, (half*)ffn_dense1_b, + ffn_activation_, stream); + // 3. Bias add - mixed precision (buffer1 -> in_out_tensor) - addBiasBatched(in_out_tensor, (float*)buffer1, ffn_dense1_b, 1, batch, - num_outputs, ffn_activation_, stream); + // addBiasBatched(in_out_tensor, (float*)buffer1, ffn_dense1_b, 1, batch, + // num_outputs, ffn_activation_, stream); } else { if (is_quantized_ && clipInputActivations) { @@ -2255,17 +2261,16 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 2. perform int8 GEMM (buffer1 -> buffer2) cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, - (float*)buffer2, batch, num_outputs, - num_inputs, 1, 0, 0, 0, - ffn2_.accum_rescale_factor, 0.0f); + (int32_t*)buffer2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0, 0.0f); ReportCUDAErrors(cudaGetLastError()); // LN2: skip connection and layer normilization (also bias add of prev // gemm) buffer2/scratch -> in_out_tensor - LayerNorm(N * 64, embedding_op_size_, in_out_tensor, - (float*)buffer2, ffn_dense2_b, scratch, ln2_gammas, - ln2_betas, default_eps_, alpha_, ACTIVATION_NONE, - stream); + LayerNorm( + N * 64, embedding_op_size_, in_out_tensor, (int32_t*)buffer2, + ffn_dense2_b, scratch, ln2_gammas, ln2_betas, default_eps_, alpha_, + ACTIVATION_NONE, stream, ffn2_.output_rescale_factor); } else { if (is_quantized_ && clipInputActivations) { @@ -2706,20 +2711,27 @@ void AttentionBody::Eval(int N, DataType* output, stream); // 2. int8 matmul (temp -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn1_.weights_int8, - (float*)buffer2, batch, num_outputs, - num_inputs, 1, 0, 0, 0, - emb_ffn1_.accum_rescale_factor, 0.0f); - - // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - // num_inputs, emb_ffn1_.accum_rescale_factor, - // emb_ffn1_.weights_int8, num_inputs, (int8_t*)temp, - // num_inputs, 0.0f, (float*)buffer2, num_outputs); + // cutlassMatrixMulBTransposed((const int8_t*)temp, + // emb_ffn1_.weights_int8, + // (int32_t*)buffer2, batch, num_outputs, + // num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1, emb_ffn1_.weights_int8, num_inputs, + (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, + num_outputs); + + // 3. Dequantize and bias add (mixed precision) (buffer2 -> buffer1) + deQuantizeOutputMatrixBiasAdd( + (half*)buffer1, (int32_t*)buffer2, batch, num_outputs, 1, nullptr, + emb_ffn1_.output_rescale_factor, (half*)ip_emb_ffn_d1_b_, + activations_.ffn_activation, stream); + ReportCUDAErrors(cudaGetLastError()); // 3. 
Bias add (mixed precision) (buffer2 -> buffer1) - addBiasBatched(buffer1, (float*)buffer2, ip_emb_ffn_d1_b_, 1, batch, - num_outputs, activations_.ffn_activation, stream); + // addBiasBatched(buffer1, (float*)buffer2, ip_emb_ffn_d1_b_, 1, batch, + // num_outputs, activations_.ffn_activation, stream); } else { if (is_quantized_ && clipInputActivations) { @@ -2746,23 +2758,23 @@ void AttentionBody::Eval(int N, DataType* output, num_inputs, emb_ffn2_.input_scaling_factors, stream); // 2. int8 matmul (temp -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, - (float*)buffer2, batch, num_outputs, - num_inputs, 1, 0, 0, 0, - emb_ffn2_.accum_rescale_factor, 0.0f); + // cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, + // (int32_t*)buffer2, batch, num_outputs, + // num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - // num_inputs, emb_ffn2_.accum_rescale_factor, - // emb_ffn2_.weights_int8, num_inputs, (const int8_t*)temp, - // num_inputs, 0.0f, (float*)buffer2, num_outputs); + cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1, emb_ffn2_.weights_int8, num_inputs, + (const int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, + num_outputs); // Embedding LN: skip connection and layer normalization (also bias add // of prev gemm) (buffer2 -> embedding/output_tensor) float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); - LayerNorm(N * 64, embedding_ffn_size_, output_tensor, - (float*)buffer2, ip_emb_ffn_d2_b_, embedding, - ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, - ACTIVATION_NONE, stream); + LayerNorm( + batch, num_outputs, output_tensor, (int32_t*)buffer2, + ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, + 1e-3, alpha, ACTIVATION_NONE, stream, + emb_ffn2_.output_rescale_factor); } else { if (is_quantized_ && clipInputActivations) { @@ -2778,12 +2790,16 @@ void AttentionBody::Eval(int N, DataType* output, // Embedding LN: skip connection and layer normilization (also bias add // of prev gemm) buffer2 -> embedding float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); - LayerNorm(N * 64, embedding_ffn_size_, output_tensor, buffer2, + LayerNorm(batch, num_outputs, output_tensor, buffer2, ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, stream); } } + // dumpTensor((int32_t*)buffer1, num_outputs * 64, "embed ffn2 cublas", true); + // dumpTensor((DataType*)output_tensor, embedding_ffn_size_ * 64, + // "embed ffn2 ln", true); + // exit(0); } else { // 1. square embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index dcb10bb8e5..82ecb5d687 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -344,7 +344,8 @@ struct MatMulQuantizationData { float* output_deq_factors; // per-tensor. 
Always in cpu memory (passed as constants to dequantization kernels) float* input_matrix_max_values; // max values of input matrix (always in CPU memory) float* output_matrix_max_values; // max values in output matrix (always in CPU memory) - float accum_rescale_factor; // accumulator rescale factor for matmuls to prevent overflow + float output_rescale_factor; // accumulator rescale factor for matmuls to + // prevent overflow float fp16_clip_scale_factor; // scale factor for clipping inputs in fp16 or fp32 mode }; From 868b9ac8d0274925605e601d289b55cf037aa6c7 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Mon, 13 May 2024 03:17:46 +0200 Subject: [PATCH 62/70] Fix bugs in int8 implementation - ith extra (ssuper) pair of eyes from @tilps --- src/neural/cuda/common_kernels.cu | 2 +- src/neural/cuda/cutlass_kernels.cu | 59 +++------- src/neural/cuda/layers.cc | 166 ++++++++++++----------------- src/neural/cuda/layers.h | 8 +- src/neural/cuda/network_cuda.cc | 2 +- 5 files changed, 86 insertions(+), 151 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index cc8175e461..4f8a7e3239 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -992,7 +992,7 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, for (int i = 0; i < 8; i++) val[i] = (float)inp[i] * dequant_scale; copyAs(&inp[0], &input[tensorIndex + 8]); - copyAs(&inp[4], &input[tensorIndex + 16]); + copyAs(&inp[4], &input[tensorIndex + 12]); for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i] * dequant_scale; } else if (std::is_same::value) { half inp[8]; diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 3b16431307..674a6977e7 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -886,7 +886,8 @@ void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, // process 8 elements per thread (in x dimension) __global__ void quantizeMatrix(int8_t* output, const half* input, int height, - int width, const float* scale) { + int width, const float* scaleArr, + const float scale) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; @@ -897,47 +898,16 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, int8_t op[8]; copyAs(&ip[0], &input[y * width + x]); - copyAs(&factor[0], &scale[x]); - copyAs(&factor[4], &scale[x + 4]); - for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] / factor[i]); - if (val > 127) val = 127; - if (val < -128) val = -128; - op[i] = (int8_t)(val); + if (scaleArr) { + copyAs(&factor[0], &scaleArr[x]); + copyAs(&factor[4], &scaleArr[x + 4]); + } else { + for (int i = 0; i < 8; i++) factor[i] = scale; } - copyAs(&output[y * width + x], &op[0]); -} - -// The scale is per column -void quantizeActivationMatrix(int8_t* output, const half* input, int height, - int width, const float* scale, - cudaStream_t stream) { - dim3 blockDim(16, 16); - dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), - lczero::cudnn_backend::DivUp(height, 16)); - quantizeMatrix<<>>(output, input, height, width, - scale); - ReportCUDAErrors(cudaGetLastError()); -} - -// Quantize matrix with single scale value -// process 8 elements per thread (in x dimension) -__global__ void quantizeMatrix(int8_t* output, const half* input, int height, - int width, const float scale) { - int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; - int y = blockIdx.y * blockDim.y + 
threadIdx.y; - - if (x >= width || y >= height) return; - - half ip[8]; - int8_t op[8]; - - copyAs(&ip[0], &input[y * width + x]); - for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] / scale); + float val = roundf((float)ip[i] / (factor[i] + 1e-5f)); if (val > 127) val = 127; if (val < -128) val = -128; op[i] = (int8_t)(val); @@ -946,7 +916,7 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, copyAs(&output[y * width + x], &op[0]); } -// The scale is for all columns. +// The scale is per column void quantizeActivationMatrix(int8_t* output, const half* input, int height, int width, const float scale, cudaStream_t stream) { @@ -954,7 +924,7 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16)); quantizeMatrix<<>>(output, input, height, width, - scale); + nullptr, scale); ReportCUDAErrors(cudaGetLastError()); } @@ -967,11 +937,12 @@ __global__ void clipMatrix(T* output, const T* input, const float scale_factor, if (x >= width || y >= height) return; - float ulimit = 127.0 * scale_factor; - float llimit = -128.0 * scale_factor; float val = (float)input[y * width + x]; - if (val > ulimit) val = ulimit; - if (val < llimit) val = llimit; + val /= (1e-5 + scale_factor); + val = roundf(val); + if (val > 127.0f) val = 127.0f; + if (val < -128.0f) val = -128.0f; + val *= scale_factor; output[y * width + x] = (T)val; } diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 351843a172..9a1d7593ce 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1548,10 +1548,6 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int BStride, int OutStride, int VecStride, float alphaf, float betaf); -void quantizeActivationMatrix(int8_t* output, const half* input, int height, - int width, const float* scale, - cudaStream_t stream); - void quantizeActivationMatrix(int8_t* output, const half* input, int height, int width, const float scale, cudaStream_t stream); @@ -1571,34 +1567,20 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, int input_len, int output_len, const std::vector& weight_factors, const std::vector& input_factors, - cudaStream_t stream, - bool vector_output_factor = false) { + cudaStream_t stream) { // Load weights for INT8 inference - // (per-column) scaling factors for the input and output. - ReportCUDAErrors( - cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); - if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( // data.output_scaling_factors, input_factors.data(), // input_factors.size() * sizeof(float), cudaMemcpyHostToDevice)); throw Exception("Channelwise quantization not yet supported."); } else { - // Repeatedly fill values into the input factors buffer. - fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); - - // Same factor will be used for clipping inputs in fp16 mode. - data.fp16_clip_scale_factor = input_factors[0]; + // Same factor will be used for clipping inputs in int8 and fp16 mode. + data.input_scaling_factor = input_factors[0]; - // Repeatedly fill values into the output factors buffer. 
- data.output_rescale_factor = input_factors[0] * weight_factors[0]; - if (vector_output_factor) { - ReportCUDAErrors( - cudaMalloc(&data.output_scaling_factors, output_len * sizeof(float))); - fillGpuArray(data.output_scaling_factors, data.output_rescale_factor, - output_len); - } + // Output scaling factor = input factor x weight factor. + data.output_scaling_factor = input_factors[0] * weight_factors[0]; } // Load weights and run a GPU kernel to scale it. @@ -1624,8 +1606,8 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // Load weights for INT8 inference. // (per-column) scaling factors for the input and output. - ReportCUDAErrors( - cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); + // ReportCUDAErrors( + // cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); ReportCUDAErrors( cudaMalloc(&data.output_scaling_factors, output_len * 3 * sizeof(float))); @@ -1635,11 +1617,8 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // input_factors.size() * sizeof(float), cudaMemcpyHostToDevice)); throw Exception("Channelwise quantization not yet supported."); } else { - // Repeatedly fill values into the input factors buffer. - fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); - - // Same factor will be used for clipping inputs in fp16 mode. - data.fp16_clip_scale_factor = input_factors[0]; + // Same factor will be used for clipping inputs in int8 and fp16 mode. + data.input_scaling_factor = input_factors[0]; // Repeatedly fill values into the output factors buffer. fillGpuArray(data.output_scaling_factors, @@ -1848,15 +1827,14 @@ static void cublasXGemmStridedBatched( float beta, void* C, int ldc, long long int strideC, int batchCount) { const bool int8 = std::is_same::value; const bool fp16 = std::is_same::value; - const bool out_float = std::is_same::value; - if (int8 && out_float) { - // @TODO Gemm is failing. All zeros. - int8_t alpha_i = (int8_t)alpha; - OutDataType beta_i = (OutDataType)beta; + const bool out_int32 = std::is_same::value; + if (int8 && out_int32) { + int32_t alpha_i = (int32_t)alpha; + int32_t beta_i = (int32_t)beta; ReportCUBLASErrors(cublasGemmStridedBatchedEx( handle, transa, transb, m, n, k, &alpha_i, A, CUDA_R_8I, lda, strideA, - B, CUDA_R_8I, ldb, strideB, &beta_i, C, CUDA_R_32F, ldc, strideC, - batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + B, CUDA_R_8I, ldb, strideB, &beta_i, C, CUDA_R_32I, ldc, strideC, + batchCount, CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT)); } else if (fp16) { unsigned short alpha_h = FP32toFP16(alpha); unsigned short beta_h = FP32toFP16(beta); @@ -1991,30 +1969,26 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, - qkv_.input_scaling_factors, stream); + qkv_.input_scaling_factor, stream); // 2. 
perform int8 GEMM (scratch -> buffer1) - cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, - (int32_t*)buffer1, batch, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs, 1.0f, 0.0f); - // dumpTensor((float*)qkv_.output_scaling_factors, 5, - // "encoder 1 qkv output dequant", true); - // dumpTensor((int32_t*)buffer1, 5, "encoder 1 qkv output dequant", true); - // exit(0); + // cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, + // (int32_t*)buffer1, batch, num_outputs, + // num_inputs, 3, 0, num_inputs * num_outputs, + // num_outputs * batch_to_use, 1.0f, 0.0f); + + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, qkv_.weights_int8, num_inputs, num_inputs * + num_outputs, (const int8_t*)scratch, num_inputs, 0, 0.0f, + (int32_t*)buffer1, num_outputs, num_outputs * batch_to_use, 3); // 3. Dequantize and bias add - mixed precision (buffer1 -> mha_q) - deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch, + deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch_to_use, num_outputs, 3, qkv_.output_scaling_factors, 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, stream); - // cublasXGemmStridedBatched( - // cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - // num_inputs, 1.0f, qkv_.weights_int8, num_inputs, num_inputs * - // num_outputs, (const int8_t*)scratch, num_inputs, 0, 0.0f, - // (float*)buffer1, num_outputs, num_outputs * batch_to_use, 3); - // 3. Bias add - mixed precision (buffer1 -> mha_q) // addBiasBatched(mha_q, (float*)buffer1, mha_qkv_b, 3, // batch, num_outputs, batch_to_use, @@ -2026,7 +2000,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - qkv_.fp16_clip_scale_factor, batch, num_inputs, stream); + qkv_.input_scaling_factor, batch, num_inputs, stream); } cublasXGemmStridedBatched( @@ -2038,8 +2012,6 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, ACTIVATION_NONE, stream); } - // dumpTensor(mha_q, num_outputs * 64 * 3, "encoder 1 qkv output", true); - // exit(0); } // Apply split_heads() to q, k and v @@ -2159,7 +2131,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. quantize the inputs (buffer2 -> scratch) // TODO: Fuse this step with the previous fused MHA quantizeActivationMatrix((int8_t*)scratch, (const half*)buffer2, batch, - num_inputs, mha_dense_.input_scaling_factors, + num_inputs, mha_dense_.input_scaling_factor, stream); // 2. perform int8 GEMM (scratch -> buffer2) @@ -2174,12 +2146,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, LayerNorm( N * 64, embedding_op_size_, scratch, (int32_t*)buffer2, mha_dense_b, in_out_tensor, ln1_gammas, ln1_betas, default_eps_, alpha_, - ACTIVATION_NONE, stream, mha_dense_.output_rescale_factor); + ACTIVATION_NONE, stream, mha_dense_.output_scaling_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)buffer2, (const DataType*)buffer2, - mha_dense_.fp16_clip_scale_factor, batch, num_inputs, stream); + mha_dense_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, @@ -2205,7 +2177,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. 
quantize the inputs (scratch -> in_out_tensor) // TODO: Fuse this step with LN1 (should be easy) quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, - batch, num_inputs, ffn1_.input_scaling_factors, + batch, num_inputs, ffn1_.input_scaling_factor, stream); // 2. perform int8 GEMM (in_out_tensor -> buffer1) @@ -2216,7 +2188,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 3. Dequantize and bias add - mixed precision (buffer1 -> in_out_tensor) deQuantizeOutputMatrixBiasAdd( (half*)in_out_tensor, (int32_t*)buffer1, batch, num_outputs, 1, - nullptr, ffn1_.output_rescale_factor, (half*)ffn_dense1_b, + nullptr, ffn1_.output_scaling_factor, (half*)ffn_dense1_b, ffn_activation_, stream); // 3. Bias add - mixed precision (buffer1 -> in_out_tensor) @@ -2229,7 +2201,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // as skip connection later at the layer norm. clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)scratch, - ffn1_.fp16_clip_scale_factor, batch, num_inputs, stream); + ffn1_.input_scaling_factor, batch, num_inputs, stream); cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, @@ -2256,7 +2228,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, - batch, num_inputs, ffn2_.input_scaling_factors, + batch, num_inputs, ffn2_.input_scaling_factor, stream); // 2. perform int8 GEMM (buffer1 -> buffer2) @@ -2270,13 +2242,13 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, LayerNorm( N * 64, embedding_op_size_, in_out_tensor, (int32_t*)buffer2, ffn_dense2_b, scratch, ln2_gammas, ln2_betas, default_eps_, alpha_, - ACTIVATION_NONE, stream, ffn2_.output_rescale_factor); + ACTIVATION_NONE, stream, ffn2_.output_scaling_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - ffn2_.fp16_clip_scale_factor, batch, num_inputs, stream); + ffn2_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, @@ -2301,8 +2273,8 @@ void AttentionPolicyHead::Eval( void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, cudaStream_t stream, DataType*** offset_pointers) { DataType* input2_tensor = (DataType*)input2; - DataType* buffer1 = output + scratch_size / (2 * sizeof(DataType)); - DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(DataType)); + DataType* buffer1 = output + scratch_size / (2 * sizeof(float)); + DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(float)); int inputC = this->input_->GetC(); if (!attention_body_) @@ -2432,17 +2404,17 @@ EncoderBlock::~EncoderBlock() { } if (int8_inf_) { ReportCUDAErrors(cudaFree(qkv_.weights_int8)); - ReportCUDAErrors(cudaFree(qkv_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(qkv_.input_scaling_factors)); ReportCUDAErrors(cudaFree(qkv_.output_scaling_factors)); ReportCUDAErrors(cudaFree(mha_dense_.weights_int8)); - ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); + // ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); 
ReportCUDAErrors(cudaFree(ffn1_.weights_int8)); - ReportCUDAErrors(cudaFree(ffn1_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(ffn1_.output_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn1_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn1_.output_scaling_factors)); ReportCUDAErrors(cudaFree(ffn2_.weights_int8)); - ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); } // else if (int8_cali_) { // free(qkv_.input_matrix_max_values); @@ -2707,24 +2679,23 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && int8_inf_) { // 1. quantize (embedding -> temp) quantizeActivationMatrix((int8_t*)temp, (const half*)embedding, batch, - num_inputs, emb_ffn1_.input_scaling_factors, + num_inputs, emb_ffn1_.input_scaling_factor, stream); // 2. int8 matmul (temp -> buffer2) - // cutlassMatrixMulBTransposed((const int8_t*)temp, - // emb_ffn1_.weights_int8, - // (int32_t*)buffer2, batch, num_outputs, - // num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn1_.weights_int8, + (int32_t*)buffer2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1, emb_ffn1_.weights_int8, num_inputs, - (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, - num_outputs); + // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, 1, emb_ffn1_.weights_int8, num_inputs, + // (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, + // num_outputs); // 3. Dequantize and bias add (mixed precision) (buffer2 -> buffer1) deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (int32_t*)buffer2, batch, num_outputs, 1, nullptr, - emb_ffn1_.output_rescale_factor, (half*)ip_emb_ffn_d1_b_, + emb_ffn1_.output_scaling_factor, (half*)ip_emb_ffn_d1_b_, activations_.ffn_activation, stream); ReportCUDAErrors(cudaGetLastError()); @@ -2737,7 +2708,7 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)temp, (const DataType*)embedding, - emb_ffn1_.fp16_clip_scale_factor, batch, num_inputs, stream); + emb_ffn1_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, @@ -2755,17 +2726,18 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && int8_inf_) { // 1. quantize (buffer1 -> temp) quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, - num_inputs, emb_ffn2_.input_scaling_factors, + num_inputs, emb_ffn2_.input_scaling_factor, stream); + // 2. 
int8 matmul (temp -> buffer2) - // cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, - // (int32_t*)buffer2, batch, num_outputs, - // num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, + (int32_t*)buffer2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1, emb_ffn2_.weights_int8, num_inputs, - (const int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, - num_outputs); + // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, 1, emb_ffn2_.weights_int8, num_inputs, + // (const int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, + // num_outputs); // Embedding LN: skip connection and layer normalization (also bias add // of prev gemm) (buffer2 -> embedding/output_tensor) @@ -2774,13 +2746,13 @@ void AttentionBody::Eval(int N, DataType* output, batch, num_outputs, output_tensor, (int32_t*)buffer2, ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, stream, - emb_ffn2_.output_rescale_factor); + emb_ffn2_.output_scaling_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)buffer1, (const DataType*)buffer1, - emb_ffn2_.fp16_clip_scale_factor, batch, num_inputs, stream); + emb_ffn2_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, @@ -2796,10 +2768,6 @@ void AttentionBody::Eval(int N, DataType* output, stream); } } - // dumpTensor((int32_t*)buffer1, num_outputs * 64, "embed ffn2 cublas", true); - // dumpTensor((DataType*)output_tensor, embedding_ffn_size_ * 64, - // "embed ffn2 ln", true); - // exit(0); } else { // 1. square embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 82ecb5d687..1fb080a07a 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -341,12 +341,8 @@ struct MatMulQuantizationData { // quantization float* output_scaling_factors; // per-column scaling factors for output // dequantization - float* output_deq_factors; // per-tensor. 
Always in cpu memory (passed as constants to dequantization kernels) - float* input_matrix_max_values; // max values of input matrix (always in CPU memory) - float* output_matrix_max_values; // max values in output matrix (always in CPU memory) - float output_rescale_factor; // accumulator rescale factor for matmuls to - // prevent overflow - float fp16_clip_scale_factor; // scale factor for clipping inputs in fp16 or fp32 mode + float output_scaling_factor; // single value output dequantization factor + float input_scaling_factor; // single value input quantization factor }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 003bbb8357..c12ef69a83 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -374,7 +374,7 @@ class CudaNetwork : public Network { sizeof(DataType); const size_t attentionBodySize = - getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(DataType); + getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(float); scratch_size_ = std::max(scratch_size_, std::max(attentionPolicySize, attentionBodySize)); From 30bb6403d6d023a3e22a4ae0744fb9da49d1a5b6 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Mon, 13 May 2024 03:17:46 +0200 Subject: [PATCH 63/70] Fix bugs in int8 implementation - with extra (super) pair of eyes from @tilps --- src/neural/cuda/common_kernels.cu | 2 +- src/neural/cuda/cutlass_kernels.cu | 59 +++------- src/neural/cuda/layers.cc | 166 ++++++++++++----------------- src/neural/cuda/layers.h | 8 +- src/neural/cuda/network_cuda.cc | 2 +- 5 files changed, 86 insertions(+), 151 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index cc8175e461..4f8a7e3239 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -992,7 +992,7 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, for (int i = 0; i < 8; i++) val[i] = (float)inp[i] * dequant_scale; copyAs(&inp[0], &input[tensorIndex + 8]); - copyAs(&inp[4], &input[tensorIndex + 16]); + copyAs(&inp[4], &input[tensorIndex + 12]); for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i] * dequant_scale; } else if (std::is_same::value) { half inp[8]; diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 3b16431307..674a6977e7 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -886,7 +886,8 @@ void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, // process 8 elements per thread (in x dimension) __global__ void quantizeMatrix(int8_t* output, const half* input, int height, - int width, const float* scale) { + int width, const float* scaleArr, + const float scale) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; @@ -897,47 +898,16 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, int8_t op[8]; copyAs(&ip[0], &input[y * width + x]); - copyAs(&factor[0], &scale[x]); - copyAs(&factor[4], &scale[x + 4]); - for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] / factor[i]); - if (val > 127) val = 127; - if (val < -128) val = -128; - op[i] = (int8_t)(val); + if (scaleArr) { + copyAs(&factor[0], &scaleArr[x]); + copyAs(&factor[4], &scaleArr[x + 4]); + } else { + for (int i = 0; i < 8; i++) factor[i] = scale; } - copyAs(&output[y * width + x], &op[0]); -} - -// The scale is per column -void quantizeActivationMatrix(int8_t* output, const half* input, int height, - 
int width, const float* scale, - cudaStream_t stream) { - dim3 blockDim(16, 16); - dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), - lczero::cudnn_backend::DivUp(height, 16)); - quantizeMatrix<<>>(output, input, height, width, - scale); - ReportCUDAErrors(cudaGetLastError()); -} - -// Quantize matrix with single scale value -// process 8 elements per thread (in x dimension) -__global__ void quantizeMatrix(int8_t* output, const half* input, int height, - int width, const float scale) { - int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; - int y = blockIdx.y * blockDim.y + threadIdx.y; - - if (x >= width || y >= height) return; - - half ip[8]; - int8_t op[8]; - - copyAs(&ip[0], &input[y * width + x]); - for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] / scale); + float val = roundf((float)ip[i] / (factor[i] + 1e-5f)); if (val > 127) val = 127; if (val < -128) val = -128; op[i] = (int8_t)(val); @@ -946,7 +916,7 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, copyAs(&output[y * width + x], &op[0]); } -// The scale is for all columns. +// The scale is per column void quantizeActivationMatrix(int8_t* output, const half* input, int height, int width, const float scale, cudaStream_t stream) { @@ -954,7 +924,7 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16)); quantizeMatrix<<>>(output, input, height, width, - scale); + nullptr, scale); ReportCUDAErrors(cudaGetLastError()); } @@ -967,11 +937,12 @@ __global__ void clipMatrix(T* output, const T* input, const float scale_factor, if (x >= width || y >= height) return; - float ulimit = 127.0 * scale_factor; - float llimit = -128.0 * scale_factor; float val = (float)input[y * width + x]; - if (val > ulimit) val = ulimit; - if (val < llimit) val = llimit; + val /= (1e-5 + scale_factor); + val = roundf(val); + if (val > 127.0f) val = 127.0f; + if (val < -128.0f) val = -128.0f; + val *= scale_factor; output[y * width + x] = (T)val; } diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 351843a172..9a1d7593ce 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1548,10 +1548,6 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int BStride, int OutStride, int VecStride, float alphaf, float betaf); -void quantizeActivationMatrix(int8_t* output, const half* input, int height, - int width, const float* scale, - cudaStream_t stream); - void quantizeActivationMatrix(int8_t* output, const half* input, int height, int width, const float scale, cudaStream_t stream); @@ -1571,34 +1567,20 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, int input_len, int output_len, const std::vector& weight_factors, const std::vector& input_factors, - cudaStream_t stream, - bool vector_output_factor = false) { + cudaStream_t stream) { // Load weights for INT8 inference - // (per-column) scaling factors for the input and output. - ReportCUDAErrors( - cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); - if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( // data.output_scaling_factors, input_factors.data(), // input_factors.size() * sizeof(float), cudaMemcpyHostToDevice)); throw Exception("Channelwise quantization not yet supported."); } else { - // Repeatedly fill values into the input factors buffer. 
- fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); - - // Same factor will be used for clipping inputs in fp16 mode. - data.fp16_clip_scale_factor = input_factors[0]; + // Same factor will be used for clipping inputs in int8 and fp16 mode. + data.input_scaling_factor = input_factors[0]; - // Repeatedly fill values into the output factors buffer. - data.output_rescale_factor = input_factors[0] * weight_factors[0]; - if (vector_output_factor) { - ReportCUDAErrors( - cudaMalloc(&data.output_scaling_factors, output_len * sizeof(float))); - fillGpuArray(data.output_scaling_factors, data.output_rescale_factor, - output_len); - } + // Output scaling factor = input factor x weight factor. + data.output_scaling_factor = input_factors[0] * weight_factors[0]; } // Load weights and run a GPU kernel to scale it. @@ -1624,8 +1606,8 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // Load weights for INT8 inference. // (per-column) scaling factors for the input and output. - ReportCUDAErrors( - cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); + // ReportCUDAErrors( + // cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); ReportCUDAErrors( cudaMalloc(&data.output_scaling_factors, output_len * 3 * sizeof(float))); @@ -1635,11 +1617,8 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // input_factors.size() * sizeof(float), cudaMemcpyHostToDevice)); throw Exception("Channelwise quantization not yet supported."); } else { - // Repeatedly fill values into the input factors buffer. - fillGpuArray(data.input_scaling_factors, input_factors[0], input_len); - - // Same factor will be used for clipping inputs in fp16 mode. - data.fp16_clip_scale_factor = input_factors[0]; + // Same factor will be used for clipping inputs in int8 and fp16 mode. + data.input_scaling_factor = input_factors[0]; // Repeatedly fill values into the output factors buffer. fillGpuArray(data.output_scaling_factors, @@ -1848,15 +1827,14 @@ static void cublasXGemmStridedBatched( float beta, void* C, int ldc, long long int strideC, int batchCount) { const bool int8 = std::is_same::value; const bool fp16 = std::is_same::value; - const bool out_float = std::is_same::value; - if (int8 && out_float) { - // @TODO Gemm is failing. All zeros. - int8_t alpha_i = (int8_t)alpha; - OutDataType beta_i = (OutDataType)beta; + const bool out_int32 = std::is_same::value; + if (int8 && out_int32) { + int32_t alpha_i = (int32_t)alpha; + int32_t beta_i = (int32_t)beta; ReportCUBLASErrors(cublasGemmStridedBatchedEx( handle, transa, transb, m, n, k, &alpha_i, A, CUDA_R_8I, lda, strideA, - B, CUDA_R_8I, ldb, strideB, &beta_i, C, CUDA_R_32F, ldc, strideC, - batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + B, CUDA_R_8I, ldb, strideB, &beta_i, C, CUDA_R_32I, ldc, strideC, + batchCount, CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT)); } else if (fp16) { unsigned short alpha_h = FP32toFP16(alpha); unsigned short beta_h = FP32toFP16(beta); @@ -1991,30 +1969,26 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, - qkv_.input_scaling_factors, stream); + qkv_.input_scaling_factor, stream); // 2. 
perform int8 GEMM (scratch -> buffer1) - cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, - (int32_t*)buffer1, batch, num_outputs, - num_inputs, 3, 0, num_inputs * num_outputs, - num_outputs, 1.0f, 0.0f); - // dumpTensor((float*)qkv_.output_scaling_factors, 5, - // "encoder 1 qkv output dequant", true); - // dumpTensor((int32_t*)buffer1, 5, "encoder 1 qkv output dequant", true); - // exit(0); + // cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, + // (int32_t*)buffer1, batch, num_outputs, + // num_inputs, 3, 0, num_inputs * num_outputs, + // num_outputs * batch_to_use, 1.0f, 0.0f); + + cublasXGemmStridedBatched( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + num_inputs, 1.0f, qkv_.weights_int8, num_inputs, num_inputs * + num_outputs, (const int8_t*)scratch, num_inputs, 0, 0.0f, + (int32_t*)buffer1, num_outputs, num_outputs * batch_to_use, 3); // 3. Dequantize and bias add - mixed precision (buffer1 -> mha_q) - deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch, + deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch_to_use, num_outputs, 3, qkv_.output_scaling_factors, 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, stream); - // cublasXGemmStridedBatched( - // cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - // num_inputs, 1.0f, qkv_.weights_int8, num_inputs, num_inputs * - // num_outputs, (const int8_t*)scratch, num_inputs, 0, 0.0f, - // (float*)buffer1, num_outputs, num_outputs * batch_to_use, 3); - // 3. Bias add - mixed precision (buffer1 -> mha_q) // addBiasBatched(mha_q, (float*)buffer1, mha_qkv_b, 3, // batch, num_outputs, batch_to_use, @@ -2026,7 +2000,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - qkv_.fp16_clip_scale_factor, batch, num_inputs, stream); + qkv_.input_scaling_factor, batch, num_inputs, stream); } cublasXGemmStridedBatched( @@ -2038,8 +2012,6 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, ACTIVATION_NONE, stream); } - // dumpTensor(mha_q, num_outputs * 64 * 3, "encoder 1 qkv output", true); - // exit(0); } // Apply split_heads() to q, k and v @@ -2159,7 +2131,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. quantize the inputs (buffer2 -> scratch) // TODO: Fuse this step with the previous fused MHA quantizeActivationMatrix((int8_t*)scratch, (const half*)buffer2, batch, - num_inputs, mha_dense_.input_scaling_factors, + num_inputs, mha_dense_.input_scaling_factor, stream); // 2. perform int8 GEMM (scratch -> buffer2) @@ -2174,12 +2146,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, LayerNorm( N * 64, embedding_op_size_, scratch, (int32_t*)buffer2, mha_dense_b, in_out_tensor, ln1_gammas, ln1_betas, default_eps_, alpha_, - ACTIVATION_NONE, stream, mha_dense_.output_rescale_factor); + ACTIVATION_NONE, stream, mha_dense_.output_scaling_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)buffer2, (const DataType*)buffer2, - mha_dense_.fp16_clip_scale_factor, batch, num_inputs, stream); + mha_dense_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, @@ -2205,7 +2177,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. 
quantize the inputs (scratch -> in_out_tensor) // TODO: Fuse this step with LN1 (should be easy) quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, - batch, num_inputs, ffn1_.input_scaling_factors, + batch, num_inputs, ffn1_.input_scaling_factor, stream); // 2. perform int8 GEMM (in_out_tensor -> buffer1) @@ -2216,7 +2188,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 3. Dequantize and bias add - mixed precision (buffer1 -> in_out_tensor) deQuantizeOutputMatrixBiasAdd( (half*)in_out_tensor, (int32_t*)buffer1, batch, num_outputs, 1, - nullptr, ffn1_.output_rescale_factor, (half*)ffn_dense1_b, + nullptr, ffn1_.output_scaling_factor, (half*)ffn_dense1_b, ffn_activation_, stream); // 3. Bias add - mixed precision (buffer1 -> in_out_tensor) @@ -2229,7 +2201,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // as skip connection later at the layer norm. clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)scratch, - ffn1_.fp16_clip_scale_factor, batch, num_inputs, stream); + ffn1_.input_scaling_factor, batch, num_inputs, stream); cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, @@ -2256,7 +2228,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, - batch, num_inputs, ffn2_.input_scaling_factors, + batch, num_inputs, ffn2_.input_scaling_factor, stream); // 2. perform int8 GEMM (buffer1 -> buffer2) @@ -2270,13 +2242,13 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, LayerNorm( N * 64, embedding_op_size_, in_out_tensor, (int32_t*)buffer2, ffn_dense2_b, scratch, ln2_gammas, ln2_betas, default_eps_, alpha_, - ACTIVATION_NONE, stream, ffn2_.output_rescale_factor); + ACTIVATION_NONE, stream, ffn2_.output_scaling_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - ffn2_.fp16_clip_scale_factor, batch, num_inputs, stream); + ffn2_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, @@ -2301,8 +2273,8 @@ void AttentionPolicyHead::Eval( void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, cudaStream_t stream, DataType*** offset_pointers) { DataType* input2_tensor = (DataType*)input2; - DataType* buffer1 = output + scratch_size / (2 * sizeof(DataType)); - DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(DataType)); + DataType* buffer1 = output + scratch_size / (2 * sizeof(float)); + DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(float)); int inputC = this->input_->GetC(); if (!attention_body_) @@ -2432,17 +2404,17 @@ EncoderBlock::~EncoderBlock() { } if (int8_inf_) { ReportCUDAErrors(cudaFree(qkv_.weights_int8)); - ReportCUDAErrors(cudaFree(qkv_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(qkv_.input_scaling_factors)); ReportCUDAErrors(cudaFree(qkv_.output_scaling_factors)); ReportCUDAErrors(cudaFree(mha_dense_.weights_int8)); - ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); + // ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); 
ReportCUDAErrors(cudaFree(ffn1_.weights_int8)); - ReportCUDAErrors(cudaFree(ffn1_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(ffn1_.output_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn1_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn1_.output_scaling_factors)); ReportCUDAErrors(cudaFree(ffn2_.weights_int8)); - ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); - ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn2_.input_scaling_factors)); + // ReportCUDAErrors(cudaFree(ffn2_.output_scaling_factors)); } // else if (int8_cali_) { // free(qkv_.input_matrix_max_values); @@ -2707,24 +2679,23 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && int8_inf_) { // 1. quantize (embedding -> temp) quantizeActivationMatrix((int8_t*)temp, (const half*)embedding, batch, - num_inputs, emb_ffn1_.input_scaling_factors, + num_inputs, emb_ffn1_.input_scaling_factor, stream); // 2. int8 matmul (temp -> buffer2) - // cutlassMatrixMulBTransposed((const int8_t*)temp, - // emb_ffn1_.weights_int8, - // (int32_t*)buffer2, batch, num_outputs, - // num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn1_.weights_int8, + (int32_t*)buffer2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1, emb_ffn1_.weights_int8, num_inputs, - (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, - num_outputs); + // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, 1, emb_ffn1_.weights_int8, num_inputs, + // (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, + // num_outputs); // 3. Dequantize and bias add (mixed precision) (buffer2 -> buffer1) deQuantizeOutputMatrixBiasAdd( (half*)buffer1, (int32_t*)buffer2, batch, num_outputs, 1, nullptr, - emb_ffn1_.output_rescale_factor, (half*)ip_emb_ffn_d1_b_, + emb_ffn1_.output_scaling_factor, (half*)ip_emb_ffn_d1_b_, activations_.ffn_activation, stream); ReportCUDAErrors(cudaGetLastError()); @@ -2737,7 +2708,7 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)temp, (const DataType*)embedding, - emb_ffn1_.fp16_clip_scale_factor, batch, num_inputs, stream); + emb_ffn1_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, @@ -2755,17 +2726,18 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && int8_inf_) { // 1. quantize (buffer1 -> temp) quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, - num_inputs, emb_ffn2_.input_scaling_factors, + num_inputs, emb_ffn2_.input_scaling_factor, stream); + // 2. 
int8 matmul (temp -> buffer2) - // cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, - // (int32_t*)buffer2, batch, num_outputs, - // num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, + (int32_t*)buffer2, batch, num_outputs, + num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); - cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1, emb_ffn2_.weights_int8, num_inputs, - (const int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, - num_outputs); + // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, + // num_inputs, 1, emb_ffn2_.weights_int8, num_inputs, + // (const int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, + // num_outputs); // Embedding LN: skip connection and layer normalization (also bias add // of prev gemm) (buffer2 -> embedding/output_tensor) @@ -2774,13 +2746,13 @@ void AttentionBody::Eval(int N, DataType* output, batch, num_outputs, output_tensor, (int32_t*)buffer2, ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, stream, - emb_ffn2_.output_rescale_factor); + emb_ffn2_.output_scaling_factor); } else { if (is_quantized_ && clipInputActivations) { clipActivationMatrix( (DataType*)buffer1, (const DataType*)buffer1, - emb_ffn2_.fp16_clip_scale_factor, batch, num_inputs, stream); + emb_ffn2_.input_scaling_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, @@ -2796,10 +2768,6 @@ void AttentionBody::Eval(int N, DataType* output, stream); } } - // dumpTensor((int32_t*)buffer1, num_outputs * 64, "embed ffn2 cublas", true); - // dumpTensor((DataType*)output_tensor, embedding_ffn_size_ * 64, - // "embed ffn2 ln", true); - // exit(0); } else { // 1. square embedding (fully connected layer) // Input data in NHWC layout N*(64)*C, output is N*(64)*embedding_op_size_ diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 82ecb5d687..1fb080a07a 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -341,12 +341,8 @@ struct MatMulQuantizationData { // quantization float* output_scaling_factors; // per-column scaling factors for output // dequantization - float* output_deq_factors; // per-tensor. 
Always in cpu memory (passed as constants to dequantization kernels) - float* input_matrix_max_values; // max values of input matrix (always in CPU memory) - float* output_matrix_max_values; // max values in output matrix (always in CPU memory) - float output_rescale_factor; // accumulator rescale factor for matmuls to - // prevent overflow - float fp16_clip_scale_factor; // scale factor for clipping inputs in fp16 or fp32 mode + float output_scaling_factor; // single value output dequantization factor + float input_scaling_factor; // single value input quantization factor }; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 003bbb8357..c12ef69a83 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -374,7 +374,7 @@ class CudaNetwork : public Network { sizeof(DataType); const size_t attentionBodySize = - getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(DataType); + getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(float); scratch_size_ = std::max(scratch_size_, std::max(attentionPolicySize, attentionBodySize)); From e82761a5e13ce25607588a935fa48850e137d467 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Mon, 13 May 2024 04:15:42 +0200 Subject: [PATCH 64/70] Fix promotion to double for clipMatrix. --- src/neural/cuda/cutlass_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 674a6977e7..1340937dfc 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -938,7 +938,7 @@ __global__ void clipMatrix(T* output, const T* input, const float scale_factor, if (x >= width || y >= height) return; float val = (float)input[y * width + x]; - val /= (1e-5 + scale_factor); + val /= (1e-5f + scale_factor); val = roundf(val); if (val > 127.0f) val = 127.0f; if (val < -128.0f) val = -128.0f; From 4e3a650d0792b4e88f5b88e80c137b890736dd68 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Mon, 13 May 2024 10:13:23 +0200 Subject: [PATCH 65/70] Fix scratch size and change epiloge compute to int32. 
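
For reference, a standalone CPU sketch of the quantization convention the
int8 GEMMs in this series rely on (illustration only, not code from the
patch; the scales and values below are made up): int8 activations and weights
are accumulated in an int32 epilogue, and the float result is recovered by
scaling the accumulator with output_scaling_factor = input_scale *
weight_scale, so no per-element float scaling needs to happen inside the
GEMM itself.

    // Toy example: int8 dot product with an int32 accumulator, then dequantize.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const float input_scale = 0.05f;   // hypothetical activation quantization step
      const float weight_scale = 0.02f;  // hypothetical weight quantization step
      const float a[4] = {0.40f, -0.10f, 0.25f, 0.05f};
      const float w[4] = {0.10f, 0.30f, -0.20f, 0.08f};

      int32_t accum = 0;      // accumulate in int32 to avoid overflowing int8
      float reference = 0.0f;
      for (int i = 0; i < 4; i++) {
        int8_t qa = (int8_t)roundf(a[i] / input_scale);
        int8_t qw = (int8_t)roundf(w[i] / weight_scale);
        accum += (int32_t)qa * (int32_t)qw;
        reference += a[i] * w[i];
      }
      float dequant = accum * (input_scale * weight_scale);  // output_scaling_factor
      printf("int8 path: %f, fp32 reference: %f\n", dequant, reference);
      return 0;
    }

All scaling therefore happens after the GEMM, in deQuantizeOutputMatrixBiasAdd
or in the int8-aware LayerNorm, which is what allows the epilogue compute type
to be a plain integer type.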
--- src/neural/cuda/cutlass_kernels.cu | 4 ++-- src/neural/cuda/layers.cc | 4 ++-- src/neural/cuda/network_cuda.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 1340937dfc..c36ab6506a 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -285,7 +285,7 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int BStride, int OutStride, int VecStride, float alphaf, float betaf) { using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; + using ElementComputeEpilogue = int32_t; using ElementInput = int8_t; using ElementOutput = int32_t; using ElementScale = float; @@ -354,7 +354,7 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int32_t* Out, // dumpTensor(B, 512, "B after scaling", false); using ElementAccumulator = int32_t; // <- data type of accumulator - using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementComputeEpilogue = int32_t; // <- data type of epilogue operations using ElementInputA = int8_t; // <- data type of elements in input matrix A using ElementInputB = int8_t; // <- data type of elements in input matrix B using ElementOutput = diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index c456801cd1..0610eea455 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2273,8 +2273,8 @@ void AttentionPolicyHead::Eval( void* scratch, size_t scratch_size, cudnnHandle_t /*cudnn*/, cublasHandle_t cublas, cudaStream_t stream, DataType*** offset_pointers) { DataType* input2_tensor = (DataType*)input2; - DataType* buffer1 = output + scratch_size / (2 * sizeof(float)); - DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(float)); + DataType* buffer1 = output + scratch_size / (2 * sizeof(DataType)); + DataType* buffer2 = input2_tensor + scratch_size / (2 * sizeof(DataType)); int inputC = this->input_->GetC(); if (!attention_body_) diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 4dfd6da6f8..47507fa242 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -374,7 +374,7 @@ class CudaNetwork : public Network { sizeof(DataType); const size_t attentionBodySize = - getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(float); + getMaxAttentionBodySize(weights, max_batch_size_) * sizeof(DataType) * 2; scratch_size_ = std::max(scratch_size_, std::max(attentionPolicySize, attentionBodySize)); From dddc978162dff8de6620b4c2dc902c5d83ba3922 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Mon, 13 May 2024 14:20:33 +0200 Subject: [PATCH 66/70] Fuse FFN2 quantize to FFN2 dequantize+bias-add. 2% speedup. 
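
For reference, a CPU sketch of the fused epilogue this change introduces in
deQuantizeMatrix (illustration only, not the CUDA kernel; the fused_epilogue
helper is hypothetical and relu is just a stand-in for the configured ffn
activation): the int32 accumulator is dequantized, bias-added and activated
as before, then immediately requantized with the next layer's input scale so
the following GEMM can consume int8 directly, removing the separate
quantizeActivationMatrix pass.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static float relu(float x) { return x > 0.0f ? x : 0.0f; }  // stand-in activation

    static int8_t fused_epilogue(int32_t accum, float inv_scale, float bias,
                                 float next_input_scale) {
      float val = (float)accum * inv_scale;            // dequantize the GEMM output
      val += bias;                                     // bias add
      val = relu(val);                                 // activation
      val = roundf(val / (next_input_scale + 1e-5f));  // requantize for the next int8 GEMM
      if (val > 127.0f) val = 127.0f;                  // clamp to int8 range
      if (val < -128.0f) val = -128.0f;
      return (int8_t)val;
    }

    int main() {
      printf("%d\n", (int)fused_epilogue(12345, 0.001f, 0.5f, 0.05f));
      return 0;
    }

This is also why the explicit quantizeActivationMatrix calls in front of the
FFN2 and embedding-FFN2 GEMMs are commented out below: their input already
arrives quantized from the previous dequantize + bias-add step.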
--- src/neural/cuda/cutlass_kernels.cu | 55 ++++++++++++++++++++++-------- src/neural/cuda/layers.cc | 33 ++++++++++-------- 2 files changed, 59 insertions(+), 29 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index c36ab6506a..e8a580322a 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -965,10 +965,12 @@ struct ScaleParam { }; // process 8 elements per thread (in x dimension) -__global__ void deQuantizeMatrix(half* output, const int32_t* input, +template +__global__ void deQuantizeMatrix(OT* output, const int32_t* input, const half* bias, int height, int width, int stride, const float* invScaleArr, - const float invScale, ActivationFunction act) { + const float invScale, ActivationFunction act, + const float nextInputScale) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; int b = blockIdx.z; @@ -976,7 +978,7 @@ __global__ void deQuantizeMatrix(half* output, const int32_t* input, if (x >= width || y >= height) return; int32_t ip[8] = {}; - half op[8] = {}; + OT op[8] = {}; half bi[8] = {}; float inv_scale[8]; @@ -991,23 +993,37 @@ __global__ void deQuantizeMatrix(half* output, const int32_t* input, for (int i = 0; i < 8; i++) inv_scale[i] = invScale; } - for (int i = 0; i < 8; i++) { - float val = (float)ip[i]; - val *= inv_scale[i]; - if (bias) val += (float)bi[i]; - op[i] = (half)activate(val, act); + if (std::is_same::value) { + for (int i = 0; i < 8; i++) { + float val = (float)ip[i]; + val *= inv_scale[i]; + if (bias) val += (float)bi[i]; + op[i] = (OT)activate(val, act); + } + copyAs(&output[b * stride + y * width + x], &op[0]); + } else if (std::is_same::value) { + for (int i = 0; i < 8; i++) { + float val = (float)ip[i]; + val *= inv_scale[i]; + if (bias) val += (float)bi[i]; + val = activate(val, act); + val = roundf(val / (nextInputScale + 1e-5f)); + if (val > 127) val = 127; + if (val < -128) val = -128; + op[i] = (OT)(val); + } + copyAs(&output[b * stride + y * width + x], &op[0]); } - - copyAs(&output[b * stride + y * width + x], &op[0]); } // the scale (in CPU memory) is per "batch" // the bias is per column, per batch -void deQuantizeOutputMatrixBiasAdd(half* output, const int32_t* input, +template +void deQuantizeOutputMatrixBiasAdd(OutputType* output, const int32_t* input, int height, int width, int batchSize, float* invScaleArr, float invScale, const half* bias, ActivationFunction act, - cudaStream_t stream) { + float nextInputScale, cudaStream_t stream) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16 * 8), lczero::cudnn_backend::DivUp(height, 16), batchSize); @@ -1017,8 +1033,9 @@ void deQuantizeOutputMatrixBiasAdd(half* output, const int32_t* input, int stride = width * height; - deQuantizeMatrix<<>>( - output, input, bias, height, width, stride, invScaleArr, invScale, act); + deQuantizeMatrix<<>>( + output, input, bias, height, width, stride, invScaleArr, invScale, act, + nextInputScale); ReportCUDAErrors(cudaGetLastError()); } @@ -1061,6 +1078,16 @@ template void dumpTensor(const int32_t* memory, int elements, const char* message, bool only_summary, bool cpu_tensor); +template void deQuantizeOutputMatrixBiasAdd( + half* output, const int32_t* input, int height, int width, int batchSize, + float* invScaleArr, float invScale, const half* bias, + ActivationFunction act, float nextInputScale, cudaStream_t stream); + +template void deQuantizeOutputMatrixBiasAdd( + int8_t* output, const int32_t* 
input, int height, int width, int batchSize, + float* invScaleArr, float invScale, const half* bias, + ActivationFunction act, float nextInputScale, cudaStream_t stream); + }; // namespace cudnn_backend }; // namespace lczero #endif diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 0610eea455..c91d233586 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1557,10 +1557,12 @@ void clipActivationMatrix(DataType* output, const DataType* input, const float scale_factor, int height, int width, cudaStream_t stream); -void deQuantizeOutputMatrixBiasAdd(half* output, const int32_t* input, +template +void deQuantizeOutputMatrixBiasAdd(OutputType* output, const int32_t* input, int height, int width, int batchSize, float* scaleArr, float scale, const half* bias, ActivationFunction act, + float nextInputScale, cudaStream_t stream); static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, @@ -1986,7 +1988,7 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 3. Dequantize and bias add - mixed precision (buffer1 -> mha_q) deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch_to_use, num_outputs, 3, qkv_.output_scaling_factors, - 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, + 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, 1.0f, stream); // 3. Bias add - mixed precision (buffer1 -> mha_q) @@ -2187,9 +2189,9 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 3. Dequantize and bias add - mixed precision (buffer1 -> in_out_tensor) deQuantizeOutputMatrixBiasAdd( - (half*)in_out_tensor, (int32_t*)buffer1, batch, num_outputs, 1, + (int8_t*)in_out_tensor, (int32_t*)buffer1, batch, num_outputs, 1, nullptr, ffn1_.output_scaling_factor, (half*)ffn_dense1_b, - ffn_activation_, stream); + ffn_activation_, ffn2_.input_scaling_factor, stream); // 3. Bias add - mixed precision (buffer1 -> in_out_tensor) // addBiasBatched(in_out_tensor, (float*)buffer1, ffn_dense1_b, 1, batch, @@ -2227,12 +2229,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. quantize the inputs (in_out_tensor -> buffer1) // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) - quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, - batch, num_inputs, ffn2_.input_scaling_factor, - stream); + // quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, + // batch, num_inputs, ffn2_.input_scaling_factor, + // stream); - // 2. perform int8 GEMM (buffer1 -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)buffer1, ffn2_.weights_int8, + // 2. perform int8 GEMM (in_out_tensor -> buffer2) + cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn2_.weights_int8, (int32_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); ReportCUDAErrors(cudaGetLastError()); @@ -2692,11 +2694,12 @@ void AttentionBody::Eval(int N, DataType* output, // (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, // num_outputs); - // 3. Dequantize and bias add (mixed precision) (buffer2 -> buffer1) + // 3. 
Dequantize and bias add (mixed precision) (buffer2 -> temp) deQuantizeOutputMatrixBiasAdd( - (half*)buffer1, (int32_t*)buffer2, batch, num_outputs, 1, nullptr, + (int8_t*)temp, (int32_t*)buffer2, batch, num_outputs, 1, nullptr, emb_ffn1_.output_scaling_factor, (half*)ip_emb_ffn_d1_b_, - activations_.ffn_activation, stream); + activations_.ffn_activation, emb_ffn2_.input_scaling_factor, + stream); ReportCUDAErrors(cudaGetLastError()); @@ -2725,9 +2728,9 @@ void AttentionBody::Eval(int N, DataType* output, const int batch = N * 64; if (is_quantized_ && int8_inf_) { // 1. quantize (buffer1 -> temp) - quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, - num_inputs, emb_ffn2_.input_scaling_factor, - stream); + // quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, + // num_inputs, emb_ffn2_.input_scaling_factor, + // stream); // 2. int8 matmul (temp -> buffer2) cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, From adad5452f91cc7ea649a3f535d61595eaaf21802 Mon Sep 17 00:00:00 2001 From: almaudoh Date: Thu, 16 May 2024 12:51:38 +0200 Subject: [PATCH 67/70] Implement int8 in all gemms except QKV. Fuse dequant-bias + add+quantize next layer --- libs/lczero-common | 2 +- src/neural/cuda/common_kernels.cu | 25 ++- src/neural/cuda/cutlass_kernels.cu | 174 +++++++++++++------- src/neural/cuda/layers.cc | 256 ++++++++++++++++++----------- src/neural/cuda/layers.h | 6 +- src/neural/network_legacy.cc | 10 +- src/neural/network_legacy.h | 9 +- 7 files changed, 322 insertions(+), 160 deletions(-) diff --git a/libs/lczero-common b/libs/lczero-common index 250a498c49..e05fb7a505 160000 --- a/libs/lczero-common +++ b/libs/lczero-common @@ -1 +1 @@ -Subproject commit 250a498c49018354cef95fcc81a42158d9cd38d1 +Subproject commit e05fb7a505554682acc8a197eb797c26b6db161d diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 4f8a7e3239..b9ec0b1564 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -985,7 +985,11 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, const bool fp16 = std::is_same::value; if (!oobThread) { // Load from memory (16 elements a time) - if (std::is_same::value) { + if (std::is_same::value) { + int8_t inp[16]; + copyAs(&inp[0], &input[tensorIndex]); + for (int i = 0; i < 16; i++) val[i] = (float)inp[i] * dequant_scale; + } else if (std::is_same::value) { int32_t inp[8]; copyAs(&inp[0], &input[tensorIndex]); copyAs(&inp[4], &input[tensorIndex + 4]); @@ -1000,12 +1004,15 @@ __global__ void layer_norm_kernel(int N, int C, T* output, const IT* input, for (int i = 0; i < 8; i++) val[i] = (float)inp[i]; copyAs(&inp[0], &input[tensorIndex + 8]); for (int i = 0; i < 8; i++) val[i + 8] = (float)inp[i]; - } else { + } else if (std::is_same::value) { copyAs(&val[0], &input[tensorIndex]); copyAs(&val[4], &input[tensorIndex + 4]); copyAs(&val[8], &input[tensorIndex + 8]); copyAs(&val[12], &input[tensorIndex + 12]); + } else { + return; // @todo } + if (fp16) { half inp[8]; copyAs(&inp[0], &bias[biasIndex]); @@ -1795,6 +1802,20 @@ template void Softmax(int N, int C, half* output, const half* input, template void Softmax(int N, int C, float* output, const float* input, const float* input2, cudaStream_t stream); +template void LayerNorm(int N, int C, half* output, + const int8_t* input, const half* bias, + const half* skip, const half* gammas, + const half* betas, float ep, float alpha, + ActivationFunction act, + cudaStream_t stream, + float 
dequant_scale); +template void LayerNorm(int N, int C, float* output, + const int8_t* input, const float* bias, + const float* skip, const float* gammas, + const float* betas, float ep, + float alpha, ActivationFunction act, + cudaStream_t stream, + float dequant_scale); template void LayerNorm(int N, int C, half* output, const int32_t* input, const half* bias, const half* skip, const half* gammas, diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index e8a580322a..3e6444369a 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -228,7 +228,7 @@ void dumpTensor(const T* memory, int elements, const char* message, } } - if (!only_summary || i < 3 || i == elements - 1) { + if (!only_summary || i < 20 || i == elements - 1) { if (int8 || int32) { // printf("%6i ", (int8_t)val); printf("%i;%8i\n", i, (int)val); @@ -279,15 +279,16 @@ void dumpTensor(const T* memory, int elements, const char* message, } // int8 GEMM using CUTLASS (with per-column output quantization) +template void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, int32_t* Out, int M, + const float* scaleVector, OutType* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, int VecStride, float alphaf, float betaf) { using ElementAccumulator = int32_t; using ElementComputeEpilogue = int32_t; using ElementInput = int8_t; - using ElementOutput = int32_t; + using ElementOutput = OutType; using ElementScale = float; using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; @@ -345,7 +346,8 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, } // int8 GEMM using CUTLASS -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int32_t* Out, +template +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, OutType* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf) { @@ -354,11 +356,11 @@ void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int32_t* Out, // dumpTensor(B, 512, "B after scaling", false); using ElementAccumulator = int32_t; // <- data type of accumulator - using ElementComputeEpilogue = int32_t; // <- data type of epilogue operations + using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementInputA = int8_t; // <- data type of elements in input matrix A using ElementInputB = int8_t; // <- data type of elements in input matrix B using ElementOutput = - int32_t; // <- data type of elements in gemm output matrix Out + OutType; // <- data type of elements in gemm output matrix Out // TODO: figure out why row major for matrix B doesn't work?!!! using LayoutInputA = cutlass::layout::RowMajor; @@ -434,7 +436,8 @@ void cutlassMatrixMulBTransposed_Emulate_INT8(const half* A, const half* B, cudaMemcpy(cpuB, B, BSize * sizeof(half), cudaMemcpyDeviceToHost); std::vector scaling_factors(K); - std::vector input_scaling_factors(K); // Not used here, but just for testing. + std::vector input_quantize_factors( + K); // Not used here, but just for testing. 
std::vector output_scaling_factors(N * batchSize); // apply smooth-quant (basically adjust A and B matrices to make @@ -494,7 +497,7 @@ void cutlassMatrixMulBTransposed_Emulate_INT8(const half* A, const half* B, // update the scaling factors based on global max for Activation matrix for (int i = 0; i < K; i++) { - input_scaling_factors[i] = 127.0f / (scaling_factors[i] * absMaxA); + input_quantize_factors[i] = 127.0f / (scaling_factors[i] * absMaxA); } std::vector BFactor(batchSize); @@ -581,7 +584,8 @@ void cutlassMatrixMulBTransposed_Emulate_INT8(const half* A, const half* B, // dump the inputs and outputs for debugging - Ankan - dumpTensor(&input_scaling_factors[0], 768, "input_scaling_factors", false, true); + dumpTensor(&input_quantize_factors[0], 768, "input_quantize_factors", + false, true); dumpTensor(AInt8, 768, "input quantized", false, true); dumpTensor(BInt8, 768, "weight quantized", false, true); dumpTensor(OutInt8, 768, "output quantized", false, true); @@ -681,9 +685,11 @@ void cutlassMatrixMulBTransposed(const half* A, const half* B, half* Out, int M, int8_t values[1024 * 64]; -static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, - float* output_scaling_factors, float *output_deq_factors, float* cpuA, - float* cpuB, float *maxValuesA, float *maxValuesOut, int M, int N, int K, int batchSize) { +static void calibrateGemm(int8_t* weights_int8, float* input_quantize_factors, + float* output_scaling_factors, + float* output_deq_factors, float* cpuA, float* cpuB, + float* maxValuesA, float* maxValuesOut, int M, int N, + int K, int batchSize) { std::vector scaling_factors(K); std::vector A_Max(M * K); // this is another matrix we use to track calculations using max values @@ -746,7 +752,7 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, // update the scaling factors based on global max for Activation matrix for (int i = 0; i < K; i++) { - input_scaling_factors[i] = 127.0f / (scaling_factors[i] * absMaxA); + input_quantize_factors[i] = 127.0f / (scaling_factors[i] * absMaxA); } std::vector BFactor(batchSize); @@ -778,7 +784,7 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, for (int y = 0; y < M; y++) { int s = 0; for (int k = 0; k < K; k++) { - int v1 = (int)roundf(input_scaling_factors[k] * cpuA[y * K + k]); + int v1 = (int)roundf(input_quantize_factors[k] * cpuA[y * K + k]); int v2 = weights_int8[b * K * N + x * K + k]; s += v1 * v2; } @@ -802,7 +808,7 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, for (int y = 0; y < M; y++) { float s = 0; for (int k = 0; k < K; k++) { - float v1 = input_scaling_factors[k] * cpuA[y * K + k]; + float v1 = input_quantize_factors[k] * cpuA[y * K + k]; float v2 = (float)(weights_int8[b * K * N + x * K + k]); s += v1 * v2; } @@ -812,7 +818,7 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, for (int y = 0; y < M; y++) for (int k = 0; k < K; k++) - cpuA[i] *= input_scaling_factors[k]; + cpuA[i] *= input_quantize_factors[k]; dumpTensor(cpuA, 768, "input matrix during calibration", @@ -837,18 +843,19 @@ static void calibrateGemm(int8_t* weights_int8, float* input_scaling_factors, // Same Activation (A) matrix (M x K) is multiplied by batchSize x B matrices / // weights (K x N transposed) The outputs are: // 1. quantized weight matrices (weights_int8) -// 2. "per-column" scaling factors (input_scaling_factors) needed to quantize +// 2. 
"per-column" scaling factors (input_quantize_factors) needed to quantize // matrix A // 3. Scaling factors to dequantize the output matrix (just 3 values: factorQ, // factorK, factorV) // M_Batch is the batch size component in "M" dimension // maxValuesA contains the max values in activation matrix found so far template -void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, +void calibrateGemmForInt8(int8_t* weights_int8, float* input_quantize_factors, float* output_scaling_factors, - float* output_deq_factors, float* maxValuesA, float *maxValuesOut, - const DataType* A, const DataType* B, int M, int N, - int K, int batchSize, int M_Batch) { + float* output_deq_factors, float* maxValuesA, + float* maxValuesOut, const DataType* A, + const DataType* B, int M, int N, int K, int batchSize, + int M_Batch) { auto cpuA = (DataType*)malloc(M_Batch * M * K * sizeof(DataType)); auto cpuB = (DataType*)malloc(batchSize * K * N * sizeof(DataType)); @@ -873,9 +880,9 @@ void calibrateGemmForInt8(int8_t* weights_int8, float* input_scaling_factors, } // calibrate a single sample - calibrateGemm(weights_int8, input_scaling_factors, output_scaling_factors, - output_deq_factors, fpA, fpB, maxValuesA, maxValuesOut, M, N, K, - batchSize); + calibrateGemm(weights_int8, input_quantize_factors, output_scaling_factors, + output_deq_factors, fpA, fpB, maxValuesA, maxValuesOut, M, N, + K, batchSize); } free(fpA); @@ -930,31 +937,37 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, // Quantize matrix with single scale value template -__global__ void clipMatrix(T* output, const T* input, const float scale_factor, - int height, int width) { - int x = (blockIdx.x * blockDim.x + threadIdx.x); +__global__ void clipMatrix(T* output, const T* input, const float* scale_factors, + const float scale_factor, int height, int width) { + int x = blockIdx.x * blockDim.x + threadIdx.x; int y = blockIdx.y * blockDim.y + threadIdx.y; if (x >= width || y >= height) return; + float factor = scale_factor; + + if (scale_factors) { + factor = scale_factors[x]; + } + float val = (float)input[y * width + x]; - val /= (1e-5f + scale_factor); + val /= (1e-5f + factor); val = roundf(val); if (val > 127.0f) val = 127.0f; if (val < -128.0f) val = -128.0f; - val *= scale_factor; + val *= factor; output[y * width + x] = (T)val; } template void clipActivationMatrix(DataType* output, const DataType* input, - const float scale_factor, int height, int width, - cudaStream_t stream) { + const float* scale_factors, const float scale_factor, + int height, int width, cudaStream_t stream) { dim3 blockDim(16, 16); dim3 gridDim(lczero::cudnn_backend::DivUp(width, 16), lczero::cudnn_backend::DivUp(height, 16)); clipMatrix<<>>( - output, input, scale_factor, height, width); + output, input, scale_factors, scale_factor, height, width); ReportCUDAErrors(cudaGetLastError()); } @@ -965,11 +978,11 @@ struct ScaleParam { }; // process 8 elements per thread (in x dimension) -template -__global__ void deQuantizeMatrix(OT* output, const int32_t* input, - const half* bias, int height, int width, - int stride, const float* invScaleArr, - const float invScale, ActivationFunction act, +template +__global__ void deQuantizeMatrix(OT* output, const IT* input, const half* bias, + int height, int width, int stride, + const float* invScaleArr, const float invScale, + ActivationFunction act, const float nextInputScale) { int x = (blockIdx.x * blockDim.x + threadIdx.x) * 8; int y = blockIdx.y * blockDim.y + threadIdx.y; @@ -977,30 
+990,42 @@ __global__ void deQuantizeMatrix(OT* output, const int32_t* input, if (x >= width || y >= height) return; - int32_t ip[8] = {}; + IT ip[8] = {}; OT op[8] = {}; half bi[8] = {}; float inv_scale[8]; - copyAs(&ip[0], &input[b * stride + y * width + x]); - copyAs(&ip[4], &input[b * stride + y * width + x + 4]); + if (std::is_same::value) { + copyAs(&ip[0], &input[b * stride + y * width + x]); + } else if (std::is_same::value) { + copyAs(&ip[0], &input[b * stride + y * width + x]); + copyAs(&ip[4], &input[b * stride + y * width + x + 4]); + } else { + return; + } if (bias) copyAs(&bi[0], &bias[b * width + x]); + // // Int8 is already scaled by invScale / nextInputScale, so only + // // multiply by nextInputScale to get it to match. + // if (std::is_same::value) { + // for (int i = 0; i < 8; i++) inv_scale[i] = nextInputScale; + // } else { if (invScaleArr) { copyAs(&inv_scale[0], &invScaleArr[b * width + x]); copyAs(&inv_scale[4], &invScaleArr[b * width + x + 4]); } else { for (int i = 0; i < 8; i++) inv_scale[i] = invScale; } - - if (std::is_same::value) { - for (int i = 0; i < 8; i++) { - float val = (float)ip[i]; - val *= inv_scale[i]; - if (bias) val += (float)bi[i]; - op[i] = (OT)activate(val, act); - } - copyAs(&output[b * stride + y * width + x], &op[0]); + // } + + if (std::is_same::value) { + for (int i = 0; i < 8; i++) { + float val = (float)ip[i]; + val *= inv_scale[i]; + if (bias) val += (float)bi[i]; + op[i] = (OT)activate(val, act); + } + copyAs(&output[b * stride + y * width + x], &op[0]); } else if (std::is_same::value) { for (int i = 0; i < 8; i++) { float val = (float)ip[i]; @@ -1018,8 +1043,8 @@ __global__ void deQuantizeMatrix(OT* output, const int32_t* input, // the scale (in CPU memory) is per "batch" // the bias is per column, per batch -template -void deQuantizeOutputMatrixBiasAdd(OutputType* output, const int32_t* input, +template +void deQuantizeOutputMatrixBiasAdd(OutputType* output, const InputType* input, int height, int width, int batchSize, float* invScaleArr, float invScale, const half* bias, ActivationFunction act, @@ -1033,7 +1058,7 @@ void deQuantizeOutputMatrixBiasAdd(OutputType* output, const int32_t* input, int stride = width * height; - deQuantizeMatrix<<>>( + deQuantizeMatrix<<>>( output, input, bias, height, width, stride, invScaleArr, invScale, act, nextInputScale); ReportCUDAErrors(cudaGetLastError()); @@ -1044,21 +1069,22 @@ void fillGpuArray(float* arr, float val, int count) { thrust::fill(dev_ptr, dev_ptr + count, val); } - template void calibrateGemmForInt8( - int8_t* weights_int8, float* input_scaling_factors, + int8_t* weights_int8, float* input_quantize_factors, float* output_scaling_factors, float* output_deq_factors, float* maxValuesA, float* maxValuesOut, const float* A, const float* B, int M, int N, int K, int batchSize, int M_Batch); template void calibrateGemmForInt8( - int8_t* weights_int8, float* input_scaling_factors, + int8_t* weights_int8, float* input_quantize_factors, float* output_scaling_factors, float* output_deq_factors, float* maxValuesA, float* maxValuesOut, const half* A, const half* B, int M, int N, int K, int batchSize, int M_Batch); template void clipActivationMatrix(float* output, const float* input, + const float* scale_factors, const float scale_factor, int height, int width, cudaStream_t stream); template void clipActivationMatrix(half* output, const half* input, + const float* scale_factors, const float scale_factor, int height, int width, cudaStream_t stream); @@ -1067,8 +1093,8 @@ template void 
dumpTensor(const float* memory, int elements, bool cpu_tensor); template void dumpTensor(const half* memory, int elements, - const char* message, bool only_summary, - bool cpu_tensor); + const char* message, bool only_summary, + bool cpu_tensor); template void dumpTensor(const int8_t* memory, int elements, const char* message, bool only_summary, @@ -1088,6 +1114,40 @@ template void deQuantizeOutputMatrixBiasAdd( float* invScaleArr, float invScale, const half* bias, ActivationFunction act, float nextInputScale, cudaStream_t stream); +template void deQuantizeOutputMatrixBiasAdd( + half* output, const int8_t* input, int height, int width, int batchSize, + float* invScaleArr, float invScale, const half* bias, + ActivationFunction act, float nextInputScale, cudaStream_t stream); + +template void deQuantizeOutputMatrixBiasAdd( + int8_t* output, const int8_t* input, int height, int width, int batchSize, + float* invScaleArr, float invScale, const half* bias, + ActivationFunction act, float nextInputScale, cudaStream_t stream); + +template void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, + int32_t* Out, int M, int N, int K, + int batchSize, int AStride, + int BStride, int OutStride, + float alphaf, float betaf); + +template void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, + int8_t* Out, int M, int N, int K, + int batchSize, int AStride, + int BStride, int OutStride, + float alphaf, float betaf); + +template void cutlassMatrixMulBTransposed( + const int8_t* A, const int8_t* B, const float* scaleVector, int32_t* Out, + int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, + int VecStride, float alphaf, float betaf); + +template void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, + const float* scaleVector, int8_t* Out, + int M, int N, int K, int batchSize, + int AStride, int BStride, + int OutStride, int VecStride, + float alphaf, float betaf); + }; // namespace cudnn_backend }; // namespace lczero #endif diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index c91d233586..1434719fcc 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -153,7 +153,7 @@ namespace cudnn_backend { // than using multiple passes. The flag can be set to false for debugging. 
static constexpr bool kUseFusedSELayer = true; -static constexpr bool clipInputActivations = true; +static constexpr bool clipQuantizedActivations = true; template BaseLayer::BaseLayer(int c, int h, int w, BaseLayer* ip, bool nhwc) @@ -1537,13 +1537,15 @@ AttentionPolicyHead::AttentionPolicyHead( void fillGpuArray(float* arr, float val, int count); -void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, int32_t* Out, +template +void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, OutType* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, float alphaf, float betaf); +template void cutlassMatrixMulBTransposed(const int8_t* A, const int8_t* B, - const float* scaleVector, int32_t* Out, int M, + const float* scaleVector, OutType* Out, int M, int N, int K, int batchSize, int AStride, int BStride, int OutStride, int VecStride, float alphaf, float betaf); @@ -1554,21 +1556,21 @@ void quantizeActivationMatrix(int8_t* output, const half* input, int height, template void clipActivationMatrix(DataType* output, const DataType* input, - const float scale_factor, int height, int width, - cudaStream_t stream); + const float* scale_factors, const float scale_factor, + int height, int width, cudaStream_t stream); -template -void deQuantizeOutputMatrixBiasAdd(OutputType* output, const int32_t* input, +template +void deQuantizeOutputMatrixBiasAdd(OutputType* output, const InputType* input, int height, int width, int batchSize, float* scaleArr, float scale, const half* bias, ActivationFunction act, - float nextInputScale, - cudaStream_t stream); + float nextInputScale, cudaStream_t stream); static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, int input_len, int output_len, const std::vector& weight_factors, const std::vector& input_factors, + const std::vector& output_factors, cudaStream_t stream) { // Load weights for INT8 inference @@ -1579,10 +1581,14 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, throw Exception("Channelwise quantization not yet supported."); } else { // Same factor will be used for clipping inputs in int8 and fp16 mode. - data.input_scaling_factor = input_factors[0]; + data.input_quantize_factor = input_factors[0]; // Output scaling factor = input factor x weight factor. - data.output_scaling_factor = input_factors[0] * weight_factors[0]; + data.output_scaling_factor = + input_factors[0] * weight_factors[0] / output_factors[0]; + + // Output scale. + data.output_dequant_factor = output_factors[0]; } // Load weights and run a GPU kernel to scale it. @@ -1593,7 +1599,7 @@ static void LoadQuantizationData(MatMulQuantizationData& data, half* weights, weight_factors[0], stream); // The original weights also need to be clipped for fp16 inference. - clipActivationMatrix(weights, weights, weight_factors[0], 1, + clipActivationMatrix(weights, weights, nullptr, weight_factors[0], 1, weights_len, stream); } @@ -1604,15 +1610,12 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, const std::vector& k_weight_factors, const std::vector& v_weight_factors, const std::vector& input_factors, + const std::vector& q_output_factors, + const std::vector& k_output_factors, + const std::vector& v_output_factors, cudaStream_t stream) { // Load weights for INT8 inference. - // (per-column) scaling factors for the input and output. 
- // ReportCUDAErrors( - // cudaMalloc(&data.input_scaling_factors, input_len * sizeof(float))); - ReportCUDAErrors( - cudaMalloc(&data.output_scaling_factors, output_len * 3 * sizeof(float))); - if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( // data.output_scaling_factors, input_factors.data(), @@ -1620,15 +1623,28 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, throw Exception("Channelwise quantization not yet supported."); } else { // Same factor will be used for clipping inputs in int8 and fp16 mode. - data.input_scaling_factor = input_factors[0]; + data.input_quantize_factor = input_factors[0]; // Repeatedly fill values into the output factors buffer. + ReportCUDAErrors(cudaMalloc(&data.output_scaling_factors, + output_len * 3 * sizeof(float))); fillGpuArray(data.output_scaling_factors, - input_factors[0] * q_weight_factors[0], output_len); + input_factors[0] * q_weight_factors[0] / q_output_factors[0], + output_len); fillGpuArray(data.output_scaling_factors + output_len, - input_factors[0] * k_weight_factors[0], output_len); + input_factors[0] * k_weight_factors[0] / k_output_factors[0], + output_len); fillGpuArray(data.output_scaling_factors + output_len * 2, - input_factors[0] * v_weight_factors[0], output_len); + input_factors[0] * v_weight_factors[0] / v_output_factors[0], + output_len); + + ReportCUDAErrors(cudaMalloc(&data.output_dequant_factors, + output_len * 3 * sizeof(float))); + fillGpuArray(data.output_dequant_factors, q_output_factors[0], output_len); + fillGpuArray(data.output_dequant_factors + output_len, k_output_factors[0], + output_len); + fillGpuArray(data.output_dequant_factors + output_len * 2, + v_output_factors[0], output_len); } // Load QKV weights and run a GPU kernel to scale them. @@ -1646,18 +1662,19 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, // The original weights also need to be clipped for fp16 inference // q weights. - clipActivationMatrix(qkv_weights, qkv_weights, q_weight_factors[0], - output_len, input_len, stream); + clipActivationMatrix(qkv_weights, qkv_weights, nullptr, + q_weight_factors[0], output_len, input_len, + stream); // k weights. - clipActivationMatrix(qkv_weights + weights_len, - qkv_weights + weights_len, k_weight_factors[0], - output_len, input_len, stream); + clipActivationMatrix( + qkv_weights + weights_len, qkv_weights + weights_len, nullptr, + k_weight_factors[0], output_len, input_len, stream); // v weights. 
- clipActivationMatrix(qkv_weights + weights_len * 2, - qkv_weights + weights_len * 2, v_weight_factors[0], - output_len, input_len, stream); + clipActivationMatrix( + qkv_weights + weights_len * 2, qkv_weights + weights_len * 2, nullptr, + v_weight_factors[0], output_len, input_len, stream); } template @@ -1771,19 +1788,20 @@ EncoderBlock::EncoderBlock( // int8 stuff blockIndex_ = blockIndex; if (int8_inf_ || is_quantized_) { - LoadQKVQuantizationData(qkv_, (half*)mha_qkv_w, embedding_op_size_, - mha_q_size_, cpu_weights.mha.q_s, - cpu_weights.mha.k_s, cpu_weights.mha.v_s, - cpu_weights.mha.s1, 0); + LoadQKVQuantizationData( + qkv_, (half*)mha_qkv_w, embedding_op_size_, mha_q_size_, + cpu_weights.mha.q_s, cpu_weights.mha.k_s, cpu_weights.mha.v_s, + cpu_weights.mha.s1, cpu_weights.mha.q_out_s, cpu_weights.mha.k_out_s, + cpu_weights.mha.v_out_s, 0); LoadQuantizationData(mha_dense_, (half*)mha_dense_w, embedding_op_size_, mha_dense_size_, cpu_weights.mha.dense_s, - cpu_weights.mha.s2, 0); + cpu_weights.mha.s2, cpu_weights.mha.dense_out_s, 0); LoadQuantizationData(ffn1_, (half*)ffn_dense1_w, embedding_op_size_, ffn_dense1_size_, cpu_weights.ffn.dense1_s, - cpu_weights.ffn.s1, 0); + cpu_weights.ffn.s1, cpu_weights.ffn.dense1_out_s, 0); LoadQuantizationData(ffn2_, (half*)ffn_dense2_w, ffn_dense1_size_, ffn_dense2_size_, cpu_weights.ffn.dense2_s, - cpu_weights.ffn.s2, 0); + cpu_weights.ffn.s2, cpu_weights.ffn.dense2_out_s, 0); } } @@ -1966,12 +1984,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, mha_k = mha_q + num_outputs * batch_to_use; mha_v = mha_k + num_outputs * batch_to_use; - if (is_quantized_ && int8_inf_) { + if (false && is_quantized_ && int8_inf_) { // 1. quantize the inputs (in_out_tensor -> scratch) // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, batch, embedding_op_size_, - qkv_.input_scaling_factor, stream); + qkv_.input_quantize_factor, stream); // 2. perform int8 GEMM (scratch -> buffer1) // cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, @@ -1999,10 +2017,10 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, ReportCUDAErrors(cudaGetLastError()); } else { - if (is_quantized_ && clipInputActivations) { + if (is_quantized_ && clipQuantizedActivations) { clipActivationMatrix( - (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - qkv_.input_scaling_factor, batch, num_inputs, stream); + (DataType*)in_out_tensor, (const DataType*)in_out_tensor, nullptr, + qkv_.input_quantize_factor, batch, num_inputs, stream); } cublasXGemmStridedBatched( @@ -2011,6 +2029,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, num_inputs, 0, 0.0f, mha_q, num_outputs, num_outputs * batch_to_use, 3); + if (is_quantized_ && clipQuantizedActivations) { + clipActivationMatrix((DataType*)mha_q, (const DataType*)mha_q, + qkv_.output_dequant_factors, 0.0f, batch, + num_outputs, stream); + } + addBiasBatched(mha_q, mha_q, mha_qkv_b, 3, batch, num_outputs, batch_to_use, ACTIVATION_NONE, stream); } @@ -2133,33 +2157,40 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. quantize the inputs (buffer2 -> scratch) // TODO: Fuse this step with the previous fused MHA quantizeActivationMatrix((int8_t*)scratch, (const half*)buffer2, batch, - num_inputs, mha_dense_.input_scaling_factor, + num_inputs, mha_dense_.input_quantize_factor, stream); // 2. 
perform int8 GEMM (scratch -> buffer2) cutlassMatrixMulBTransposed( - (const int8_t*)scratch, mha_dense_.weights_int8, (int32_t*)buffer2, - batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); + (const int8_t*)scratch, (const int8_t*)mha_dense_.weights_int8, + (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, + mha_dense_.output_scaling_factor, 0.0f); ReportCUDAErrors(cudaGetLastError()); // LN1: skip connection and layer normalization (also bias add of prev // gemm) buffer2 -> scratch - LayerNorm( - N * 64, embedding_op_size_, scratch, (int32_t*)buffer2, mha_dense_b, + LayerNorm( + N * 64, embedding_op_size_, scratch, (int8_t*)buffer2, mha_dense_b, in_out_tensor, ln1_gammas, ln1_betas, default_eps_, alpha_, - ACTIVATION_NONE, stream, mha_dense_.output_scaling_factor); + ACTIVATION_NONE, stream, mha_dense_.output_dequant_factor); } else { - if (is_quantized_ && clipInputActivations) { + if (is_quantized_ && clipQuantizedActivations) { clipActivationMatrix( - (DataType*)buffer2, (const DataType*)buffer2, - mha_dense_.input_scaling_factor, batch, num_inputs, stream); + (DataType*)buffer2, (const DataType*)buffer2, nullptr, + mha_dense_.input_quantize_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, mha_dense_w, num_inputs, buffer2, num_inputs, 0.0f, buffer1, num_outputs); + if (is_quantized_ && clipQuantizedActivations) { + clipActivationMatrix( + (DataType*)buffer1, (const DataType*)buffer1, nullptr, + mha_dense_.output_dequant_factor, batch, num_outputs, stream); + } + // LN1: skip connection and layer normalization (also bias add of prev // gemm) buffer1/in_out_tensor -> scratch LayerNorm(N * 64, embedding_op_size_, scratch, buffer1, @@ -2167,6 +2198,8 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, default_eps_, alpha_, ACTIVATION_NONE, stream); } } + // dumpTensor(scratch, embedding_op_size_ * N * 64, "ffn1 input", true); + // exit(0); // #FFN dense 1, scratch -> in_out_tensor const int encoder_dff = ffn_dense1_size_; @@ -2179,41 +2212,47 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // 1. quantize the inputs (scratch -> in_out_tensor) // TODO: Fuse this step with LN1 (should be easy) quantizeActivationMatrix((int8_t*)in_out_tensor, (const half*)scratch, - batch, num_inputs, ffn1_.input_scaling_factor, + batch, num_inputs, ffn1_.input_quantize_factor, stream); // 2. perform int8 GEMM (in_out_tensor -> buffer1) cutlassMatrixMulBTransposed( - (const int8_t*)in_out_tensor, ffn1_.weights_int8, (int32_t*)buffer1, - batch, num_outputs, num_inputs, 1, 0, 0, 0, 1.0, 0.0f); + (const int8_t*)in_out_tensor, (const int8_t*)ffn1_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, + ffn1_.output_scaling_factor, 0.0f); - // 3. Dequantize and bias add - mixed precision (buffer1 -> in_out_tensor) + // 3. Dequantize and bias add - mixed precision (buffer1 -> + // in_out_tensor) deQuantizeOutputMatrixBiasAdd( - (int8_t*)in_out_tensor, (int32_t*)buffer1, batch, num_outputs, 1, - nullptr, ffn1_.output_scaling_factor, (half*)ffn_dense1_b, - ffn_activation_, ffn2_.input_scaling_factor, stream); + (int8_t*)in_out_tensor, (int8_t*)buffer1, batch, num_outputs, 1, + nullptr, ffn1_.output_dequant_factor, (half*)ffn_dense1_b, + ffn_activation_, ffn2_.input_quantize_factor, stream); // 3. 
Bias add - mixed precision (buffer1 -> in_out_tensor) // addBiasBatched(in_out_tensor, (float*)buffer1, ffn_dense1_b, 1, batch, // num_outputs, ffn_activation_, stream); } else { - if (is_quantized_ && clipInputActivations) { + if (is_quantized_ && clipQuantizedActivations) { // Note `scratch` should not be changed as it is the FFN input to be used // as skip connection later at the layer norm. clipActivationMatrix( - (DataType*)in_out_tensor, (const DataType*)scratch, - ffn1_.input_scaling_factor, batch, num_inputs, stream); + (DataType*)in_out_tensor, (const DataType*)scratch, nullptr, + ffn1_.input_quantize_factor, batch, num_inputs, stream); cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); + + clipActivationMatrix( + (DataType*)buffer1, (const DataType*)buffer1, nullptr, + ffn1_.output_dequant_factor, batch, num_outputs, stream); + } else { cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense1_w, num_inputs, scratch, num_inputs, 0.0f, buffer1, num_outputs); } - addBiasBatched(in_out_tensor, buffer1, ffn_dense1_b, 1, batch, num_outputs, ffn_activation_, stream); @@ -2230,32 +2269,39 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // TODO: Fuse this step with above bias add at least (or ideally with the // above GEMM) // quantizeActivationMatrix((int8_t*)buffer1, (const half*)in_out_tensor, - // batch, num_inputs, ffn2_.input_scaling_factor, - // stream); + // batch, num_inputs, + // ffn2_.input_quantize_factor, stream); // 2. perform int8 GEMM (in_out_tensor -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)in_out_tensor, ffn2_.weights_int8, - (int32_t*)buffer2, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 1.0, 0.0f); + cutlassMatrixMulBTransposed( + (const int8_t*)in_out_tensor, (const int8_t*)ffn2_.weights_int8, + (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, + ffn2_.output_scaling_factor, 0.0f); ReportCUDAErrors(cudaGetLastError()); // LN2: skip connection and layer normilization (also bias add of prev // gemm) buffer2/scratch -> in_out_tensor - LayerNorm( - N * 64, embedding_op_size_, in_out_tensor, (int32_t*)buffer2, + LayerNorm( + N * 64, embedding_op_size_, in_out_tensor, (int8_t*)buffer2, ffn_dense2_b, scratch, ln2_gammas, ln2_betas, default_eps_, alpha_, - ACTIVATION_NONE, stream, ffn2_.output_scaling_factor); + ACTIVATION_NONE, stream, ffn2_.output_dequant_factor); } else { - if (is_quantized_ && clipInputActivations) { + if (is_quantized_ && clipQuantizedActivations) { clipActivationMatrix( - (DataType*)in_out_tensor, (const DataType*)in_out_tensor, - ffn2_.input_scaling_factor, batch, num_inputs, stream); + (DataType*)in_out_tensor, (const DataType*)in_out_tensor, nullptr, + ffn2_.input_quantize_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ffn_dense2_w, num_inputs, in_out_tensor, num_inputs, 0.0f, buffer1, num_outputs); + if (is_quantized_ && clipQuantizedActivations) { + clipActivationMatrix( + (DataType*)buffer1, (const DataType*)buffer1, nullptr, + ffn2_.output_dequant_factor, batch, num_outputs, stream); + } + // LN2: skip connection and layer normilization (also bias add of prev // gemm) buffer1/scratch -> in_out_tensor LayerNorm(N * 64, embedding_op_size_, in_out_tensor, buffer1, @@ -2266,6 +2312,8 @@ void EncoderBlock::Eval(int N, DataType* 
in_out_tensor, // "encoder 1 ffn2 LN output", false); // exit(0); } + // dumpTensor(in_out_tensor, num_outputs * N * 64, "Biases", true); + // exit(0); } } @@ -2408,6 +2456,7 @@ EncoderBlock::~EncoderBlock() { ReportCUDAErrors(cudaFree(qkv_.weights_int8)); // ReportCUDAErrors(cudaFree(qkv_.input_scaling_factors)); ReportCUDAErrors(cudaFree(qkv_.output_scaling_factors)); + ReportCUDAErrors(cudaFree(qkv_.output_dequant_factors)); ReportCUDAErrors(cudaFree(mha_dense_.weights_int8)); // ReportCUDAErrors(cudaFree(mha_dense_.input_scaling_factors)); // ReportCUDAErrors(cudaFree(mha_dense_.output_scaling_factors)); @@ -2513,10 +2562,12 @@ AttentionBody::AttentionBody(const MultiHeadWeights& weights, // Quantization data for input embedding FFN layers. LoadQuantizationData(emb_ffn1_, (half*)ip_emb_ffn_d1_w_, embedding_ffn_size_, embedding_ffn_dff_, - weights.ip_emb_ffn.dense1_s, weights.ip_emb_ffn.s1, 0); + weights.ip_emb_ffn.dense1_s, weights.ip_emb_ffn.s1, + weights.ip_emb_ffn.dense1_out_s, 0); LoadQuantizationData(emb_ffn2_, (half*)ip_emb_ffn_d2_w_, embedding_ffn_dff_, embedding_ffn_size_, weights.ip_emb_ffn.dense2_s, - weights.ip_emb_ffn.s2, 0); + weights.ip_emb_ffn.s2, weights.ip_emb_ffn.dense2_out_s, + 0); } else { size_t size = 64 * kNumPosEncodingChannels * sizeof(float); ReportCUDAErrors(cudaMalloc(&pos_encoding_, size)); @@ -2681,14 +2732,17 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && int8_inf_) { // 1. quantize (embedding -> temp) quantizeActivationMatrix((int8_t*)temp, (const half*)embedding, batch, - num_inputs, emb_ffn1_.input_scaling_factor, + num_inputs, emb_ffn1_.input_quantize_factor, stream); // 2. int8 matmul (temp -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn1_.weights_int8, - (int32_t*)buffer2, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed( + (const int8_t*)temp, (const int8_t*)emb_ffn1_.weights_int8, + (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, + emb_ffn1_.output_scaling_factor, 0.0f); + // dumpTensor((int8_t*)buffer2, 64 * num_outputs, "output", true); + // exit(0); // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, // num_inputs, 1, emb_ffn1_.weights_int8, num_inputs, // (int8_t*)temp, num_inputs, 0, (int32_t*)buffer2, @@ -2696,9 +2750,9 @@ void AttentionBody::Eval(int N, DataType* output, // 3. 
Dequantize and bias add (mixed precision) (buffer2 -> temp) deQuantizeOutputMatrixBiasAdd( - (int8_t*)temp, (int32_t*)buffer2, batch, num_outputs, 1, nullptr, - emb_ffn1_.output_scaling_factor, (half*)ip_emb_ffn_d1_b_, - activations_.ffn_activation, emb_ffn2_.input_scaling_factor, + (int8_t*)temp, (int8_t*)buffer2, batch, num_outputs, 1, nullptr, + emb_ffn1_.output_dequant_factor, (half*)ip_emb_ffn_d1_b_, + activations_.ffn_activation, emb_ffn2_.input_quantize_factor, stream); ReportCUDAErrors(cudaGetLastError()); @@ -2708,14 +2762,19 @@ void AttentionBody::Eval(int N, DataType* output, // num_outputs, activations_.ffn_activation, stream); } else { - if (is_quantized_ && clipInputActivations) { + if (is_quantized_ && clipQuantizedActivations) { clipActivationMatrix( - (DataType*)temp, (const DataType*)embedding, - emb_ffn1_.input_scaling_factor, batch, num_inputs, stream); + (DataType*)temp, (const DataType*)embedding, nullptr, + emb_ffn1_.input_quantize_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, (const DataType*)ip_emb_ffn_d1_w_, num_inputs, temp, num_inputs, 0.0f, buffer1, num_outputs); + if (is_quantized_ && clipQuantizedActivations) { + clipActivationMatrix( + (DataType*)buffer1, (const DataType*)buffer1, nullptr, + emb_ffn1_.output_dequant_factor, batch, num_outputs, stream); + } addBiasBatched(buffer1, buffer1, ip_emb_ffn_d1_b_, 1, batch, num_outputs, activations_.ffn_activation, stream); } @@ -2729,13 +2788,14 @@ void AttentionBody::Eval(int N, DataType* output, if (is_quantized_ && int8_inf_) { // 1. quantize (buffer1 -> temp) // quantizeActivationMatrix((int8_t*)temp, (const half*)buffer1, batch, - // num_inputs, emb_ffn2_.input_scaling_factor, + // num_inputs, emb_ffn2_.input_quantize_factor, // stream); // 2. int8 matmul (temp -> buffer2) - cutlassMatrixMulBTransposed((const int8_t*)temp, emb_ffn2_.weights_int8, - (int32_t*)buffer2, batch, num_outputs, - num_inputs, 1, 0, 0, 0, 1.0f, 0.0f); + cutlassMatrixMulBTransposed( + (const int8_t*)temp, (const int8_t*)emb_ffn2_.weights_int8, + (int8_t*)buffer2, batch, num_outputs, num_inputs, 1, 0, 0, 0, + emb_ffn2_.output_scaling_factor, 0.0f); // cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, // num_inputs, 1, emb_ffn2_.weights_int8, num_inputs, @@ -2745,23 +2805,29 @@ void AttentionBody::Eval(int N, DataType* output, // Embedding LN: skip connection and layer normalization (also bias add // of prev gemm) (buffer2 -> embedding/output_tensor) float alpha = (float)pow(2. 
* encoder_weights_.size(), -0.25); - LayerNorm( - batch, num_outputs, output_tensor, (int32_t*)buffer2, + LayerNorm( + batch, num_outputs, output_tensor, (int8_t*)buffer2, ip_emb_ffn_d2_b_, embedding, ip_emb_ffn_ln_g_, ip_emb_ffn_ln_b_, 1e-3, alpha, ACTIVATION_NONE, stream, - emb_ffn2_.output_scaling_factor); + emb_ffn2_.output_dequant_factor); } else { - if (is_quantized_ && clipInputActivations) { + if (is_quantized_ && clipQuantizedActivations) { clipActivationMatrix( - (DataType*)buffer1, (const DataType*)buffer1, - emb_ffn2_.input_scaling_factor, batch, num_inputs, stream); + (DataType*)buffer1, (const DataType*)buffer1, nullptr, + emb_ffn2_.input_quantize_factor, batch, num_inputs, stream); } cublasXgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, num_inputs, 1.0f, ip_emb_ffn_d2_w_, num_inputs, buffer1, num_inputs, 0.0f, buffer2, num_outputs); + if (is_quantized_ && clipQuantizedActivations) { + clipActivationMatrix( + (DataType*)buffer2, (const DataType*)buffer2, nullptr, + emb_ffn2_.output_dequant_factor, batch, num_outputs, stream); + } + // Embedding LN: skip connection and layer normilization (also bias add // of prev gemm) buffer2 -> embedding float alpha = (float)pow(2. * encoder_weights_.size(), -0.25); diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 789c5db819..8ec2f72f3f 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -337,12 +337,14 @@ class ResidualBlock : public BaseLayer { // (in GPU memory when int8_inf_ is set, otherwise in CPU memory) struct MatMulQuantizationData { int8_t* weights_int8; // int8 quantized weights - float* input_scaling_factors; // per-column scaling factors for input + float* input_quantize_factors; // per-column scaling factors for input // quantization float* output_scaling_factors; // per-column scaling factors for output // dequantization float output_scaling_factor; // single value output dequantization factor - float input_scaling_factor; // single value input quantization factor + float input_quantize_factor; // single value input quantization factor + float output_dequant_factor; + float* output_dequant_factors; }; diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc index 61f126f862..1a72172f74 100644 --- a/src/neural/network_legacy.cc +++ b/src/neural/network_legacy.cc @@ -146,9 +146,13 @@ BaseWeights::MHA::MHA(const pblczero::Weights::MHA& mha) q_s(LayerAdapter(mha.q_s()).as_vector()), k_s(LayerAdapter(mha.k_s()).as_vector()), v_s(LayerAdapter(mha.v_s()).as_vector()), + q_out_s(LayerAdapter(mha.q_out_s()).as_vector()), + k_out_s(LayerAdapter(mha.k_out_s()).as_vector()), + v_out_s(LayerAdapter(mha.v_out_s()).as_vector()), s1(LayerAdapter(mha.s1()).as_vector()), s2(LayerAdapter(mha.s2()).as_vector()), dense_s(LayerAdapter(mha.dense_s()).as_vector()), + dense_out_s(LayerAdapter(mha.dense_out_s()).as_vector()), has_int8(mha.has_s1()) {} BaseWeights::FFN::FFN(const pblczero::Weights::FFN& ffn) @@ -157,9 +161,11 @@ BaseWeights::FFN::FFN(const pblczero::Weights::FFN& ffn) dense2_w(LayerAdapter(ffn.dense2_w()).as_vector()), dense2_b(LayerAdapter(ffn.dense2_b()).as_vector()), s1(LayerAdapter(ffn.s1()).as_vector()), - s2(LayerAdapter(ffn.s2()).as_vector()), dense1_s(LayerAdapter(ffn.dense1_s()).as_vector()), - dense2_s(LayerAdapter(ffn.dense2_s()).as_vector()) {} + dense1_out_s(LayerAdapter(ffn.dense1_out_s()).as_vector()), + s2(LayerAdapter(ffn.s2()).as_vector()), + dense2_s(LayerAdapter(ffn.dense2_s()).as_vector()), + 
dense2_out_s(LayerAdapter(ffn.dense2_out_s()).as_vector()) {} BaseWeights::EncoderLayer::EncoderLayer( const pblczero::Weights::EncoderLayer& encoder) diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h index 4c443a95ef..5f72f28237 100644 --- a/src/neural/network_legacy.h +++ b/src/neural/network_legacy.h @@ -85,9 +85,13 @@ struct BaseWeights { Vec q_s; Vec k_s; Vec v_s; + Vec q_out_s; + Vec k_out_s; + Vec v_out_s; Vec s1; Vec s2; Vec dense_s; + Vec dense_out_s; }; struct FFN { @@ -98,9 +102,12 @@ struct BaseWeights { Vec dense2_b; bool has_int8; Vec s1; - Vec s2; Vec dense1_s; + Vec dense1_out_s; + + Vec s2; Vec dense2_s; + Vec dense2_out_s; }; struct EncoderLayer { From 870ebba19979013d005416ed9a2e5ed243667ecd Mon Sep 17 00:00:00 2001 From: almaudoh Date: Thu, 16 May 2024 13:14:59 +0200 Subject: [PATCH 68/70] Remove epsilon from quantize --- src/neural/cuda/cutlass_kernels.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 3e6444369a..777a8ad4cb 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -914,7 +914,7 @@ __global__ void quantizeMatrix(int8_t* output, const half* input, int height, } for (int i = 0; i < 8; i++) { - float val = roundf((float)ip[i] / (factor[i] + 1e-5f)); + float val = roundf((float)ip[i] / factor[i]); if (val > 127) val = 127; if (val < -128) val = -128; op[i] = (int8_t)(val); @@ -951,7 +951,7 @@ __global__ void clipMatrix(T* output, const T* input, const float* scale_factors } float val = (float)input[y * width + x]; - val /= (1e-5f + factor); + val /= factor; val = roundf(val); if (val > 127.0f) val = 127.0f; if (val < -128.0f) val = -128.0f; @@ -1032,7 +1032,7 @@ __global__ void deQuantizeMatrix(OT* output, const IT* input, const half* bias, val *= inv_scale[i]; if (bias) val += (float)bi[i]; val = activate(val, act); - val = roundf(val / (nextInputScale + 1e-5f)); + val = roundf(val / nextInputScale); if (val > 127) val = 127; if (val < -128) val = -128; op[i] = (OT)(val); From 80cc6737112447b65756d966e84637fbdcf73efb Mon Sep 17 00:00:00 2001 From: almaudoh Date: Sun, 2 Jun 2024 04:22:24 +0200 Subject: [PATCH 69/70] Split QKV to allow use of int8->int8 cutlass matmul. 
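(Editor's note, not part of the commit itself.) The int8->int8 CUTLASS call here takes a single float alpha per GEMM, so as long as Q, K and V keep separate per-tensor factors, one batched QKV GEMM cannot rescale all three projections at once; this commit therefore runs three GEMMs, one per projection. A minimal sketch of how each projection's scales are composed, assuming per-tensor factors (weight_factor / output_factor stand for the values loaded from the net, the other names follow MatMulQuantizationData):

    // alpha fed to the int8 GEMM epilogue: int32 accumulator -> int8 output grid
    float output_scaling_factor = input_quantize_factor * weight_factor / output_factor;
    // matching dequantization of the int8 GEMM output back to fp16:
    //   y = (half)((float)q_out * output_dequant_factor + bias)
    float output_dequant_factor = output_factor;

The follow-up commit re-fuses the three GEMMs by averaging the per-projection factors.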
--- src/neural/cuda/cutlass_kernels.cu | 12 +++--- src/neural/cuda/layers.cc | 67 +++++++++++++++++++++++------- src/neural/cuda/layers.h | 3 ++ 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/src/neural/cuda/cutlass_kernels.cu b/src/neural/cuda/cutlass_kernels.cu index 777a8ad4cb..772b4c88bf 100644 --- a/src/neural/cuda/cutlass_kernels.cu +++ b/src/neural/cuda/cutlass_kernels.cu @@ -228,15 +228,15 @@ void dumpTensor(const T* memory, int elements, const char* message, } } - if (!only_summary || i < 20 || i == elements - 1) { + if (!only_summary || i < 7 || i == elements - 1) { if (int8 || int32) { - // printf("%6i ", (int8_t)val); - printf("%i;%8i\n", i, (int)val); + printf("%6i ", (int)val); + // printf("%i;%8i\n", i, (int)val); } else { - // printf("%8.6f ", val); - printf("%i;%8.6f\n", i, val); + printf("%8.6f ", val); + // printf("%i;%8.6f\n", i, val); } - // if ((i % 8) == 7 || i == elements - 1) printf("\n"); + if ((i % 8) == 7 || i == elements - 1) printf("\n"); } } if (!cpu_tensor) free(temp); diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 1434719fcc..8afea8e829 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1616,6 +1616,10 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, cudaStream_t stream) { // Load weights for INT8 inference. + float avwfactor = + (q_weight_factors[0] + k_weight_factors[0] + v_weight_factors[0]) / 3; + float avofactor = + (q_output_factors[0] + k_output_factors[0] + v_output_factors[0]) / 3; if (input_factors.size() > 1) { // ReportCUDAErrors(cudaMemcpy( // data.output_scaling_factors, input_factors.data(), @@ -1645,6 +1649,12 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, output_len); fillGpuArray(data.output_dequant_factors + output_len * 2, v_output_factors[0], output_len); + + // Output scaling factor = input factor x weight factor. + data.output_scaling_factor = input_factors[0] * avwfactor / avofactor; + + // Output scale. + data.output_dequant_factor = avofactor; } // Load QKV weights and run a GPU kernel to scale them. @@ -1659,6 +1669,9 @@ static void LoadQKVQuantizationData(MatMulQuantizationData& data, quantizeActivationMatrix(data.weights_int8 + weights_len * 2, qkv_weights + weights_len * 2, 1, weights_len, v_weight_factors[0], stream); + // quantizeActivationMatrix(data.weights_int8, qkv_weights, 1, weights_len * + // 3, + // avwfactor, stream); // The original weights also need to be clipped for fp16 inference // q weights. 
@@ -1793,6 +1806,15 @@ EncoderBlock::EncoderBlock( cpu_weights.mha.q_s, cpu_weights.mha.k_s, cpu_weights.mha.v_s, cpu_weights.mha.s1, cpu_weights.mha.q_out_s, cpu_weights.mha.k_out_s, cpu_weights.mha.v_out_s, 0); + LoadQuantizationData(q_, (half*)mha_q_w, embedding_op_size_, mha_q_size_, + cpu_weights.mha.q_s, cpu_weights.mha.s1, + cpu_weights.mha.q_out_s, 0); + LoadQuantizationData(k_, (half*)mha_k_w, embedding_op_size_, mha_k_size_, + cpu_weights.mha.k_s, cpu_weights.mha.s1, + cpu_weights.mha.k_out_s, 0); + LoadQuantizationData(v_, (half*)mha_v_w, embedding_op_size_, mha_v_size_, + cpu_weights.mha.v_s, cpu_weights.mha.s1, + cpu_weights.mha.v_out_s, 0); LoadQuantizationData(mha_dense_, (half*)mha_dense_w, embedding_op_size_, mha_dense_size_, cpu_weights.mha.dense_s, cpu_weights.mha.s2, cpu_weights.mha.dense_out_s, 0); @@ -1848,6 +1870,7 @@ static void cublasXGemmStridedBatched( const bool int8 = std::is_same::value; const bool fp16 = std::is_same::value; const bool out_int32 = std::is_same::value; + const bool out_int8 = std::is_same::value; if (int8 && out_int32) { int32_t alpha_i = (int32_t)alpha; int32_t beta_i = (int32_t)beta; @@ -1855,6 +1878,11 @@ static void cublasXGemmStridedBatched( handle, transa, transb, m, n, k, &alpha_i, A, CUDA_R_8I, lda, strideA, B, CUDA_R_8I, ldb, strideB, &beta_i, C, CUDA_R_32I, ldc, strideC, batchCount, CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT)); + } else if (int8 && out_int8) { + ReportCUBLASErrors(cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &alpha, A, CUDA_R_8I, lda, strideA, B, + CUDA_R_8I, ldb, strideB, &beta, C, CUDA_R_8I, ldc, strideC, batchCount, + CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT)); } else if (fp16) { unsigned short alpha_h = FP32toFP16(alpha); unsigned short beta_h = FP32toFP16(beta); @@ -1984,12 +2012,12 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, mha_k = mha_q + num_outputs * batch_to_use; mha_v = mha_k + num_outputs * batch_to_use; - if (false && is_quantized_ && int8_inf_) { + if (is_quantized_ && int8_inf_) { // 1. quantize the inputs (in_out_tensor -> scratch) // TODO: Fuse this step with layer-norm of previous block quantizeActivationMatrix((int8_t*)scratch, (const half*)in_out_tensor, - batch, embedding_op_size_, - qkv_.input_quantize_factor, stream); + batch, num_inputs, qkv_.input_quantize_factor, + stream); // 2. perform int8 GEMM (scratch -> buffer1) // cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, @@ -1997,17 +2025,28 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, // num_inputs, 3, 0, num_inputs * num_outputs, // num_outputs * batch_to_use, 1.0f, 0.0f); - cublasXGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, num_outputs, batch, - num_inputs, 1.0f, qkv_.weights_int8, num_inputs, num_inputs * - num_outputs, (const int8_t*)scratch, num_inputs, 0, 0.0f, - (int32_t*)buffer1, num_outputs, num_outputs * batch_to_use, 3); - - // 3. 
Dequantize and bias add - mixed precision (buffer1 -> mha_q) - deQuantizeOutputMatrixBiasAdd((half*)mha_q, (int32_t*)buffer1, batch_to_use, - num_outputs, 3, qkv_.output_scaling_factors, - 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, 1.0f, - stream); + int ostride = num_outputs * batch_to_use; + cutlassMatrixMulBTransposed( + (const int8_t*)scratch, (const int8_t*)q_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, + q_.output_scaling_factor, 0.0f); + + cutlassMatrixMulBTransposed( + (const int8_t*)scratch, (const int8_t*)k_.weights_int8, + (int8_t*)buffer1 + ostride, batch, num_outputs, num_inputs, 1, 0, 0, + 0, k_.output_scaling_factor, 0.0f); + + cutlassMatrixMulBTransposed( + (const int8_t*)scratch, (const int8_t*)v_.weights_int8, + (int8_t*)buffer1 + 2 * ostride, batch, num_outputs, num_inputs, 1, 0, + 0, 0, v_.output_scaling_factor, 0.0f); + + // 3. Dequantize and bias add - mixed precision (buffer1 -> + // scratch/mha_q/mha_k/mha_v) + deQuantizeOutputMatrixBiasAdd( + (half*)scratch, (int8_t*)buffer1, batch_to_use, num_outputs, 3, + qkv_.output_dequant_factors, 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, + 1.0f, stream); // 3. Bias add - mixed precision (buffer1 -> mha_q) // addBiasBatched(mha_q, (float*)buffer1, mha_qkv_b, 3, diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 8ec2f72f3f..4132db1fbf 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -389,6 +389,9 @@ class EncoderBlock { int blockIndex_; bool int8_inf_, is_quantized_; MatMulQuantizationData qkv_; + MatMulQuantizationData q_; + MatMulQuantizationData k_; + MatMulQuantizationData v_; MatMulQuantizationData mha_dense_; MatMulQuantizationData ffn1_; MatMulQuantizationData ffn2_; From 7fceeebbb91091a98af331a878679689735daf5f Mon Sep 17 00:00:00 2001 From: almaudoh Date: Sun, 2 Jun 2024 09:11:15 +0200 Subject: [PATCH 70/70] Implement fused QKV with averaged scaling factors. --- src/neural/cuda/layers.cc | 42 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 8afea8e829..1698922ad5 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -2020,32 +2020,32 @@ void EncoderBlock::Eval(int N, DataType* in_out_tensor, stream); // 2. 
perform int8 GEMM (scratch -> buffer1) - // cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, - // (int32_t*)buffer1, batch, num_outputs, - // num_inputs, 3, 0, num_inputs * num_outputs, - // num_outputs * batch_to_use, 1.0f, 0.0f); - - int ostride = num_outputs * batch_to_use; - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, (const int8_t*)q_.weights_int8, - (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, - q_.output_scaling_factor, 0.0f); - - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, (const int8_t*)k_.weights_int8, - (int8_t*)buffer1 + ostride, batch, num_outputs, num_inputs, 1, 0, 0, - 0, k_.output_scaling_factor, 0.0f); - - cutlassMatrixMulBTransposed( - (const int8_t*)scratch, (const int8_t*)v_.weights_int8, - (int8_t*)buffer1 + 2 * ostride, batch, num_outputs, num_inputs, 1, 0, - 0, 0, v_.output_scaling_factor, 0.0f); + cutlassMatrixMulBTransposed((const int8_t*)scratch, qkv_.weights_int8, + (int8_t*)buffer1, batch, num_outputs, + num_inputs, 3, 0, num_inputs * num_outputs, + num_outputs * batch_to_use, qkv_.output_scaling_factor, 0.0f); + + // int ostride = num_outputs * batch_to_use; + // cutlassMatrixMulBTransposed( + // (const int8_t*)scratch, (const int8_t*)q_.weights_int8, + // (int8_t*)buffer1, batch, num_outputs, num_inputs, 1, 0, 0, 0, + // q_.output_scaling_factor, 0.0f); + + // cutlassMatrixMulBTransposed( + // (const int8_t*)scratch, (const int8_t*)k_.weights_int8, + // (int8_t*)buffer1 + ostride, batch, num_outputs, num_inputs, 1, 0, 0, + // 0, k_.output_scaling_factor, 0.0f); + + // cutlassMatrixMulBTransposed( + // (const int8_t*)scratch, (const int8_t*)v_.weights_int8, + // (int8_t*)buffer1 + 2 * ostride, batch, num_outputs, num_inputs, 1, 0, + // 0, 0, v_.output_scaling_factor, 0.0f); // 3. Dequantize and bias add - mixed precision (buffer1 -> // scratch/mha_q/mha_k/mha_v) deQuantizeOutputMatrixBiasAdd( (half*)scratch, (int8_t*)buffer1, batch_to_use, num_outputs, 3, - qkv_.output_dequant_factors, 1.0f, (half*)mha_qkv_b, ACTIVATION_NONE, + nullptr, qkv_.output_dequant_factor, (half*)mha_qkv_b, ACTIVATION_NONE, 1.0f, stream); // 3. Bias add - mixed precision (buffer1 -> mha_q)
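      // Editorial sketch, illustrative only (not part of the patch): with the
      // averaged factors the fused QKV path effectively does, per element
      // (in_s / avg_w_s / avg_out_s are shorthand for the factors computed in
      // LoadQKVQuantizationData):
      //
      //   // int8 GEMM epilogue, alpha = in_s * avg_w_s / avg_out_s:
      //   int8_t q8 = (int8_t)fminf(127.f, fmaxf(-128.f, roundf(acc_i32 * alpha)));
      //   // deQuantizeOutputMatrixBiasAdd with invScale = qkv_.output_dequant_factor:
      //   half y = (half)((float)q8 * avg_out_s + (float)bias);
      //
      // i.e. Q, K and V now share one output grid (avg_out_s), which is what
      // lets a single batched int8->int8 GEMM replace the three split GEMMs.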