Skip to content

Commit

Permalink
Merge pull request karpathy#283 from ChrisDryden/configurationRewrite
Browse files: browse the repository at this point in the history
Changed the ordering of the type configuration to make unchanged values easy to see
  • Loading branch information
karpathy authored Apr 30, 2024
2 parents 9978649 + cf3e6ef commit 9464f42
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,20 @@ enum PrecisionMode {
PRECISION_BF16
};

// fp32
// Default Properties
typedef float floatN;
#define CUBLAS_LOWP_COMPUTE cublas_compute_type
#ifdef MULTI_GPU
const ncclDataType_t ncclFloatN = ncclFloat;
#endif

// Specific configurations based on the enabled precision
#if defined(ENABLE_FP32)
typedef float floatX;
#define CUBLAS_LOWP CUDA_R_32F
#define CUBLAS_LOWP_COMPUTE cublas_compute_type // auto-select FP32 vs TF32
const char* load_filename = "gpt2_124M.bin"; // fp32 weights
PrecisionMode PRECISION_MODE = PRECISION_FP32;
#define PRECISION_MODE PRECISION_FP32
const char* load_filename = "gpt2_124M.bin";
const char* precision_mode_str = "fp32";

#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclFloat;
#endif
Expand All @@ -80,24 +85,19 @@ const ncclDataType_t ncclFloatX = ncclFloat;
#elif defined(ENABLE_FP16)
typedef half floatX;
#define CUBLAS_LOWP CUDA_R_16F
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F
const char* load_filename = "gpt2_124M.bin"; // fp32 weights
PrecisionMode PRECISION_MODE = PRECISION_FP16;
#define PRECISION_MODE PRECISION_FP16
const char* load_filename = "gpt2_124M.bin";
const char* precision_mode_str = "fp16";

#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclHalf;
#endif

// bfloat16 (default!)
#else
#else // Default to bfloat16
typedef __nv_bfloat16 floatX;
#define CUBLAS_LOWP CUDA_R_16BF
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights
PrecisionMode PRECISION_MODE = PRECISION_BF16;
#define PRECISION_MODE PRECISION_BF16
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights specific filename
const char* precision_mode_str = "bf16";

#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclBfloat16;
#endif
Expand Down

0 comments on commit 9464f42

Please sign in to comment.