diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index d1515066e..ac2e1f48e 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -31,19 +31,19 @@ jobs: run: python train_gpt2.py - name: Compile training and testing program - run: make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu + run: make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu - name: Train model (With OpenMP) run: OMP_NUM_THREADS=8 ./train_gpt2cu - name: Train model (FP32) with gpt2_124M.bin - run: | + run: | PRECISION=FP32 make train_gpt2cu ./train_gpt2cu -b 4 -t 64 -l 1e-4 -v 200 -s 200 -a 1 -x 10 -e gpt2_124M.bin - + - name: Build FP32 precision run: PRECISION=FP32 make test_gpt2cu profile_gpt2cu - + - name: Run default run: ./test_gpt2cu @@ -52,7 +52,7 @@ jobs: - name: Run recompute LN run: ./test_gpt2cu -r 2 - + - name: Build BF16 precision run: PRECISION=BF16 make train_gpt2cu test_gpt2cu profile_gpt2cu @@ -67,18 +67,18 @@ jobs: - name: Run recompute LN run: ./test_gpt2cu -r 2 - + - name: Train model fp32 (With OpenMP) run: OMP_NUM_THREADS=8 ./train_gpt2fp32cu - name: Execute testing program (With OpenMP) run: OMP_NUM_THREADS=8 ./test_gpt2cu - + - name: Execute testing program fp32 (With OpenMP) run: OMP_NUM_THREADS=8 ./test_gpt2fp32cu - name: Compile training and testing program without OpenMP - run: NO_OMP=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu + run: NO_OMP=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu - name: Train model (No OpenMP) run: NO_OMP=1 ./train_gpt2cu @@ -88,14 +88,14 @@ jobs: - name: Execute testing program (No OpenMP) run: ./test_gpt2cu -b 32 - + - name: Execute testing program fp32 (No OpenMP) run: ./test_gpt2fp32cu - name: Install cuDNN-frontend - run: + run: git clone https://github.com/NVIDIA/cudnn-frontend.git - + - name: Build with cuDNN run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu @@ -110,3 +110,13 @@ jobs: - name: Execute testing program fp32 with cuDNN run: ./test_gpt2fp32cu + + unit-tests-gpu: + runs-on: ubicloud-gpu-standard-1-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Test Device<->File IO + run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml new file mode 100644 index 000000000..81aaace1c --- /dev/null +++ b/.github/workflows/ci_tests.yml @@ -0,0 +1,100 @@ +name: Unit, Static and other Tests + +on: + create: + workflow_dispatch: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + dataloader_test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: test the dataloader without / with sanitize address + run: | + cd dev/test + make PRECISION=BF16 test_dataloader + ./test_dataloader + make clean + make PRECISION=BF16 TEST_CFLAGS="-fsanitize=address -fno-omit-frame-pointer" test_dataloader + ./test_dataloader + + ptx_and_sass_files: + runs-on: ubuntu-latest + container: + image: nvidia/cuda:12.4.1-devel-ubuntu22.04 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install OpenMP and OpenMPI + run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev + + - name: Generate ptx/sass files and upload them to persistent storage + run: | + mkdir -p dev/cuda/ptx_sass_logs + make train_gpt2cu + cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx + cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass + cd dev/cuda + make -j all_ptx + make -j all_sass + cp *.ptx ptx_sass_logs/ + cp *.sass ptx_sass_logs/ + ls ptx_sass_logs/ + + - name: Generate ptx/sass files for A100 and upload them to persistent storage + run: | + mkdir -p dev/cuda/ptx_sass_logs_A100 + make train_gpt2cu GPU_COMPUTE_CAPABILITY=80 + cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx + cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass + cd dev/cuda + make -j GPU_COMPUTE_CAPABILITY=80 all_ptx + make -j GPU_COMPUTE_CAPABILITY=80 all_sass + cp *.ptx ptx_sass_logs_A100/ + cp *.sass ptx_sass_logs_A100/ + ls ptx_sass_logs_A100/ + + - name: Generate ptx/sass files for H100 and upload them to persistent storage + run: | + mkdir -p dev/cuda/ptx_sass_logs_H100 + make train_gpt2cu GPU_COMPUTE_CAPABILITY=90 + cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx + cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass + cd dev/cuda + make -j GPU_COMPUTE_CAPABILITY=90 all_ptx + make -j GPU_COMPUTE_CAPABILITY=90 all_sass + cp *.ptx ptx_sass_logs_H100/ + cp *.sass ptx_sass_logs_H100/ + ls ptx_sass_logs_H100/ + + - name: Upload ptx/sass files + uses: actions/upload-artifact@v4 + with: + name: ptx_sass_files + path: dev/cuda/ptx_sass_logs/ + retention-days: 30 # days to retain + + - name: Upload ptx/sass files for A100 + uses: actions/upload-artifact@v4 + with: + name: ptx_sass_files_A100 + path: dev/cuda/ptx_sass_logs_A100/ + retention-days: 30 # days to retain + + - name: Upload ptx/sass files for H100 + uses: actions/upload-artifact@v4 + with: + name: ptx_sass_files_H100 + path: dev/cuda/ptx_sass_logs_H100/ + retention-days: 30 # days to retain \ No newline at end of file diff --git a/Makefile b/Makefile index 9ea6866df..40721edba 100644 --- a/Makefile +++ b/Makefile @@ -188,27 +188,41 @@ else endif endif -# Check if OpenMPI and NCCL are available, include them if so, for multi-GPU training +# Check if NCCL is available, include if so, for multi-GPU training ifeq ($(NO_MULTI_GPU), 1) - $(info → Multi-GPU (OpenMPI + NCCL) is manually disabled) + $(info → Multi-GPU (NCCL) is manually disabled) else ifneq ($(OS), Windows_NT) # Detect if running on macOS or Linux ifeq ($(SHELL_UNAME), Darwin) - $(info ✗ Multi-GPU on CUDA on Darwin is not supported, skipping OpenMPI + NCCL support) - else ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists) - $(info ✓ OpenMPI found, OK to train with multiple GPUs) - NVCC_INCLUDES += -I/usr/lib/x86_64-linux-gnu/openmpi/include - NVCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ - NVCC_LDLIBS += -lmpi -lnccl + $(info ✗ Multi-GPU on CUDA on Darwin is not supported, skipping NCCL support) + else ifeq ($(shell dpkg -l | grep -q nccl && echo "exists"), exists) + $(info ✓ NCCL found, OK to train with multiple GPUs) NVCC_FLAGS += -DMULTI_GPU + NVCC_LDLIBS += -lnccl else - $(info ✗ OpenMPI is not found, disabling multi-GPU support) - $(info ---> On Linux you can try install OpenMPI with `sudo apt install openmpi-bin openmpi-doc libopenmpi-dev`) + $(info ✗ NCCL is not found, disabling multi-GPU support) + $(info ---> On Linux you can try install NCCL with `sudo apt install libnccl2 libnccl-dev`) endif endif endif +# Attempt to find and include OpenMPI on the system +OPENMPI_DIR ?= /usr/lib/x86_64-linux-gnu/openmpi +OPENMPI_LIB_PATH = $(OPENMPI_DIR)/lib/ +OPENMPI_INCLUDE_PATH = $(OPENMPI_DIR)/include/ +ifeq ($(NO_USE_MPI), 1) + $(info → MPI is manually disabled) +else ifeq ($(shell [ -d $(OPENMPI_LIB_PATH) ] && [ -d $(OPENMPI_INCLUDE_PATH) ] && echo "exists"), exists) + $(info ✓ MPI enabled) + NVCC_INCLUDES += -I$(OPENMPI_INCLUDE_PATH) + NVCC_LDFLAGS += -L$(OPENMPI_LIB_PATH) + NVCC_LDLIBS += -lmpi + NVCC_FLAGS += -DUSE_MPI +else + $(info ✗ MPI not found) +endif + # Precision settings, default to bf16 but ability to override PRECISION ?= BF16 VALID_PRECISIONS := FP32 FP16 BF16 @@ -266,5 +280,5 @@ profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN) $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) clean: - $(REMOVE_FILES) $(TARGETS) + $(REMOVE_FILES) $(TARGETS) $(REMOVE_BUILD_OBJECT_FILES) diff --git a/README.md b/README.md index d3dea874f..4ba884345 100644 --- a/README.md +++ b/README.md @@ -13,28 +13,34 @@ debugging tip: when you run the `make` command to build the binary, modify it by If you won't be training on multiple nodes, aren't interested in mixed precision, and are interested in learning CUDA, the fp32 (legacy) files might be of interest to you. These are files that were "checkpointed" early in the history of llm.c and frozen in time. They are simpler, more portable, and possibly easier to understand. Run the 1 GPU, fp32 code like this: ```bash -pip install -r requirements.txt -python dev/data/tinyshakespeare.py -python train_gpt2.py +chmod u+x ./dev/download_starter_pack.sh +./dev/download_starter_pack.sh make train_gpt2fp32cu ./train_gpt2fp32cu ``` -The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, tokenize it with the GPT-2 Tokenizer, (2) download and save the GPT-2 (124M) weights, (3) init from them in C/CUDA and train for one epoch on tineshakespeare with AdamW (using batch size 4, context length 1024, total of 74 steps), evaluate validation loss, and sample some text. +The download_starter_pack.sh script is a quick & easy way to get started and it downloads a bunch of .bin files that help get you off the ground. These contain: 1) the GPT-2 124M model saved in fp32, in bfloat16, 2) a "debug state" used in unit testing (a small batch of data, and target activations and gradients), 3) the GPT-2 tokenizer, and 3) the tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. Alternatively, instead of running the .sh script, you can re-create these artifacts manually as follows: + +```bash +pip install -r requirements.txt +python dev/data/tinyshakespeare.py +python train_gpt2.py +``` ## quick start (CPU) The "I am so GPU poor that I don't even have one GPU" section. You can still enjoy seeing llm.c train! But you won't go too far. Just like the fp32 version above, the CPU version is an even earlier checkpoint in the history of llm.c, back when it was just a simple reference implementation in C. For example, instead of training from scratch, you can finetune a GPT-2 small (124M) to output Shakespeare-like text, as an example: ```bash -pip install -r requirements.txt -python dev/data/tinyshakespeare.py -python train_gpt2.py +chmod u+x ./dev/download_starter_pack.sh +./dev/download_starter_pack.sh make train_gpt2 OMP_NUM_THREADS=8 ./train_gpt2 ``` -The above lines (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, tokenize it with the GPT-2 Tokenizer, (2) download and save the GPT-2 (124M) weights, (3) init from them in C and train for 40 steps on tineshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. Honestly, unless you have a beefy CPU (and can crank up the number of OMP threads in the launch command), you're not going to get that far on CPU training LLMs, but it might be a good demo/reference. The output looks like this on my MacBook Pro (Apple Silicon M3 Max): +If you'd prefer to avoid running the starter pack script, then as mentioned in the previous section you can reproduce the exact same .bin files and artifacts by running `python dev/data/tinyshakespeare.py` and then `python train_gpt2.py`. + +The above lines (1) download an already tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset and download the GPT-2 (124M) weights, (3) init from them in C and train for 40 steps on tineshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. Honestly, unless you have a beefy CPU (and can crank up the number of OMP threads in the launch command), you're not going to get that far on CPU training LLMs, but it might be a good demo/reference. The output looks like this on my MacBook Pro (Apple Silicon M3 Max): ``` [GPT-2] @@ -128,12 +134,16 @@ sudo apt-get -y install libcudnn9-dev-cuda-12 On top of this you need the [cuDNN frontend](https://github.com/NVIDIA/cudnn-frontend/tree/main), but this is just header files. Simply clone the repo to your disk. The Makefile currently looks for it in either your home directory or the current directory. If you have put it elsewhere, add `CUDNN_FRONTEND_PATH=/path/to/your/cudnn-frontend/include` to the `make` command-line. -**multi-GPU training**. As of April 26, 2024 there is now also support for multi-GPU training using MPI and NCCL. Make sure you install MPI, e.g. on Linux: +## multi-GPU training + +Make sure you install MPI and NCCL, e.g. on Linux: ```bash sudo apt install openmpi-bin openmpi-doc libopenmpi-dev ``` +For NCCL follow the instructions from the [official website](https://developer.nvidia.com/nccl/nccl-download) (e.g. network installer) + and then: ```bash @@ -141,6 +151,23 @@ make train_gpt2cu mpirun -np ./train_gpt2cu ``` +or simply run one of our scripts under `./scripts/`. + +## multi-node training + +Make sure you've installed `NCCL` following instructions from [multi-GPU](#multi-gpu-training) section. + +There are 3 ways we currently support that allow you to run multi-node training: +1) Use OpenMPI to exchange nccl id and initialize NCCL. See e.g. `./scripts/multi_node/run_gpt2_124M_mpi.sh` script for details. +2) Use shared file system to init NCCL. See `./scripts/multi_node/run_gpt2_124M_fs.sbatch` script for details. +3) Use TCP sockets to init NCCL. See `./scripts/multi_node/run_gpt2_124M_tcp.sbatch` script for details. + +Note: +* If you're running in a slurm environment and your slurm doesn't support PMIx (which we assume will be a common situation given that `slurm-wlm` dropped PMIx support) you will have to use FS (2) or TCP (3) approach. To test whether your slurm supports PMIx run: `srun --mpi=list` and see whether you get `pmix` in the output. +* If you don't have slurm set up, you can kick off a multi-node run using `mpirun` - MPI (1). + +None of these 3 methods is superior, we just offer you options so that you can run in your specific environment. + ## experiments / sweeps Just as an example process to sweep learning rates on a machine with 4 GPUs on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`): @@ -198,6 +225,9 @@ Lastly, I will be a lot more sensitive to complexity in the root folder of the p - Mojo - [llm.🔥](https://github.com/dorjeduck/llm.mojo) by @[dorjeduck](https://github.com/dorjeduck): a Mojo port of this project +- OpenCL + - [llm.c](https://github.com/krrishnarraj/llm.c) by @[krrishnarraj](https://github.com/krrishnarraj): an OpenCL port of this project + - Rust - [llm.rs](https://github.com/yijunyu/llm.rs) by @[Yijun Yu](https://github.com/yijunyu): a Rust rewrite with the aim to have same performance - [llm.rs](https://github.com/ToJen/llm.rs) by @[ToJen](https://github.com/ToJen): a Rust port of this project diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 4a14ac49f..68bb74635 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -8,8 +8,20 @@ ifeq ($(NVCC),) $(error nvcc not found.) endif +ifneq ($(CI),true) # if not in CI, then use the GPU query + ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY= + GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too + GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY)) + endif +endif + # Compiler flags -CFLAGS = -O3 --use_fast_math +ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY= + CFLAGS = -O3 --use_fast_math +else + CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] +endif + NVCCFLAGS = -lcublas -lcublasLt -std=c++17 MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ @@ -20,6 +32,8 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux- # Build all targets TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm all: $(TARGETS) +all_ptx: $(TARGETS:%=%.ptx) +all_sass: $(TARGETS:%=%.sass) # Individual targets: forward pass attention_forward: attention_forward.cu @@ -54,6 +68,14 @@ global_norm: global_norm.cu nccl_all_reduce: nccl_all_reduce.cu $(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce +# Generate PTX using cuobjdump +%.ptx: % + cuobjdump --dump-ptx $< > $@ + +# Generate SASS using cuobjdump +%.sass: % + cuobjdump --dump-sass $< > $@ + # Run all targets run_all: all @for target in $(TARGETS); do \ @@ -65,4 +87,4 @@ run_all: all # Clean up clean: - rm -f $(TARGETS) + rm -f $(TARGETS) *.ptx *.sass diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu index 20a6560dd..74dfc2ee2 100644 --- a/dev/cuda/adamw.cu +++ b/dev/cuda/adamw.cu @@ -159,7 +159,7 @@ int main(int argc, char **argv) { // create random data on host (to be used for the CPU reference implementation) float* params_memory = make_random_float(num_parameters); float* grads_memory = make_random_float(num_parameters); - float* m_memory = make_random_float_01(num_parameters); + float* m_memory = make_random_float(num_parameters); float* v_memory = make_random_float_01(num_parameters); // move to GPU diff --git a/dev/cuda/attention_backward.cu b/dev/cuda/attention_backward.cu index c97dbeee8..f6d258dc9 100644 --- a/dev/cuda/attention_backward.cu +++ b/dev/cuda/attention_backward.cu @@ -68,7 +68,7 @@ void attention_forward_cpu(float* out, float* preatt, float* att, float* att_bth = att + b*NH*T*T + h*T*T + t*T; // pass 1: calculate query dot key and maxval - float maxval = -10000.0f; // TODO something better + float maxval = -FLT_MAX; for (int t2 = 0; t2 < T; t2++) { // used to be t2 <= t float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index b632b4a66..66b6a1b3e 100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -98,7 +98,7 @@ void attention_forward_cpu(float* out, float* preatt, float* att, float* att_bth = att + b*NH*T*T + h*T*T + t*T; // pass 1: calculate query dot key and maxval - float maxval = -10000.0f; // TODO something better + float maxval = -FLT_MAX; for (int t2 = 0; t2 <= t; t2++) { const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key @@ -203,7 +203,7 @@ __global__ void attention_softmax_kernel1(float* att, const float* preatt, float* att_bth = att + b*NH*T*T + h*T*T + t*T; // find maxval - float maxval = -10000.0f; // TODO something better + float maxval = -FLT_MAX; for (int t2 = 0; t2 <= t; t2++) { if (preatt_bth[t2] > maxval) { maxval = preatt_bth[t2]; diff --git a/dev/cuda/classifier_fused.cu b/dev/cuda/classifier_fused.cu index 5b6c986e2..9d93f9d64 100644 --- a/dev/cuda/classifier_fused.cu +++ b/dev/cuda/classifier_fused.cu @@ -114,7 +114,7 @@ __device__ SoftmaxParams prepare_softmax(cg::thread_block_tile<32>& warp, int64_t idx, const float* inp, int V, int P) { // this warp (of 32) threads processes one row of inp, i.e. inp[idx, :] of shape (V,) // note that inp is actually (B * T, P) but we only use the first V elements - // this function tehen calculates: + // this function then calculates: // 1) the max value to subtract for numerical stability and // 2) the sum normalization factor const float* x = inp + idx * P; @@ -481,33 +481,6 @@ __global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX } } -// todo - move to common.h - or ideally somewhere it's not duplicated between train & common? -// requires all 32 threads in the warp to be active, but should work for any block size -// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes -// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end -// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1 -using reduction_func_t = float (*) (float); -template -__device__ float blockReduce(float val, bool final_sync=false, float out_of_bounds=0.0f) { - // two reductions of up to 1024 threads: - // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle) - __shared__ float shared_val[32]; - const int lane_id = threadIdx.x % 32; - const int warp_id = threadIdx.x / 32; - const int num_warps = blockDim.x / 32; - - float warp_val = warp_reduction(val); - if (lane_id == 0) { shared_val[warp_id] = warp_val; } - __syncthreads(); - warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds; - float block_val = warp_reduction(warp_val); - - if (final_sync) { - __syncthreads(); // only needed in loops when effectively reusing shared memory etc. - } - return block_val; -} - __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) @@ -707,8 +680,8 @@ int main(int argc, char **argv) { cudaCheck(cudaSetDevice(deviceIdx)); // create host memory of random numbers - float* logits = make_random_float_01(B * T * V); - float* probs = (float*)malloc(B * T * V * sizeof(float)); + float* logits = make_random_float(B * T * V); + float* probs = make_random_float_01(B * T * V); float* dlogits = (float*)malloc(B * T * V * sizeof(float)); float* losses = (float*)malloc(B * T * sizeof(float)); float* dlosses = make_random_float(B * T); @@ -787,6 +760,7 @@ int main(int argc, char **argv) { free(losses); free(dlosses); free(targets); + free(outliers); cudaCheck(cudaFree(d_dlogits)); cudaCheck(cudaFree(d_losses)); cudaCheck(cudaFree(d_logits)); diff --git a/dev/cuda/common.h b/dev/cuda/common.h index 6502baa20..61a783a60 100644 --- a/dev/cuda/common.h +++ b/dev/cuda/common.h @@ -5,6 +5,8 @@ #include #include +#define WARP_SIZE 32U +extern cudaDeviceProp deviceProp; template __host__ __device__ T ceil_div(T dividend, T divisor) { @@ -18,6 +20,39 @@ __device__ float warpReduceSum(float val) { return val; } +// requires all 32 threads in the warp to be active, but should work for any block size +// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes +// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end +// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1 +using reduction_func_t = float (*) (float); + +template +__device__ inline float blockReduce(float val, bool final_sync, float out_of_bounds) { + // two reductions of up to 1024 threads: + // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle) + __shared__ float shared_val[WARP_SIZE]; + const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + + float warp_val = warp_reduction(val); + if (lane_id == 0) { shared_val[warp_id] = warp_val; } + __syncthreads(); + warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds; + float block_val = warp_reduction(warp_val); + + if (final_sync) { + __syncthreads(); // only needed in loops when effectively reusing shared memory etc. + } + return block_val; +} + +// Helper function to call blockReduce with default arguments +template +__device__ inline float blockReduce(float val) { + return blockReduce(val, false, 0.0f); +} + // ---------------------------------------------------------------------------- // checking utils diff --git a/dev/cuda/crossentropy_softmax_backward.cu b/dev/cuda/crossentropy_softmax_backward.cu index 27521bf60..65c72a2f1 100644 --- a/dev/cuda/crossentropy_softmax_backward.cu +++ b/dev/cuda/crossentropy_softmax_backward.cu @@ -99,7 +99,7 @@ int main(int argc, char **argv) { cudaCheck(cudaSetDevice(deviceIdx)); // create host memory of random numbers - float* probs = make_random_float(B * T * V); + float* probs = make_random_float_01(B * T * V); int* targets = make_random_int(B * T, V); float* dlosses = make_random_float(B * T); float* dlogits = make_zeros_float(B * T * V); diff --git a/dev/cuda/fused_residual_forward.cu b/dev/cuda/fused_residual_forward.cu index b98a67c4b..9752873db 100644 --- a/dev/cuda/fused_residual_forward.cu +++ b/dev/cuda/fused_residual_forward.cu @@ -133,7 +133,7 @@ __global__ void fused_residual_forward2(floatX* residual, floatX* normed, floatX for(int c = 0; c < C; ++c) { float out = (float)inp1[c] + (float)inp2[c]; m += out; - residual[c] = out; + residual[c] = (floatX)out; } m = m / C; @@ -149,11 +149,11 @@ __global__ void fused_residual_forward2(floatX* residual, floatX* normed, floatX for (int c = 0; c < C; c++) { float n = (s * ((float)residual[c] - m)); // normalized output float o = n * (float)weight[c] + (float)bias[c]; // scale and shift it - normed[c] = o; // write + normed[c] = (floatX)o; // write } // cache the mean and rstd for the backward pass later - mean[idx] = m; - rstd[idx] = s; + mean[idx] = (floatX)m; + rstd[idx] = (floatX)s; } // handle one token per warp for coalesced access @@ -232,7 +232,7 @@ __global__ void fused_residual_forward_kernel4(floatX* residual, floatX* normed, const x128 in2 = load128cs(inp2 + c); x128 out; for(int k = 0; k < x128::size; ++k) { - out[k] = (float)in1[k] + (float)in2[k]; + out[k] = (floatX)((float)in1[k] + (float)in2[k]); sum += (float)out[k]; sum_sq += (float)out[k] * (float)out[k]; } @@ -309,7 +309,7 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, const x128 in2 = load128cs(inp2 + c); x128 out; for(int k = 0; k < x128::size; ++k) { - out[k] = (float)in1[k] + (float)in2[k]; + out[k] = (floatX)((float)in1[k] + (float)in2[k]); sum += (float)out[k]; } store128cs(residual + c, out); @@ -372,8 +372,8 @@ __global__ void fused_residual_forward_kernel6(floatX* residual, floatX* normed, // weights and biases are shared among all tokens x128* s_weight = reinterpret_cast(params); x128* s_bias = reinterpret_cast(params + C * sizeof(floatX)); - // residual output (input to layernorm) is indpendent for each sub-block indicates by threadIdx.z - x128* s_res = reinterpret_cast(params + (2 + threadIdx.z) * C * sizeof(floatX) ); + // residual output (input to layernorm) is independent for each sub-block indicates by threadIdx.z + x128* s_res = reinterpret_cast(params + (2 + threadIdx.z) * C * sizeof(floatX)); // similarly, each sub-block needs its own reduction buffers float* s_mean = reinterpret_cast(params + (2 + blockDim.z) * C * sizeof(floatX) + threadIdx.z * 32 * sizeof(float)); float* s_var = reinterpret_cast(params + (2 + blockDim.z) * C * sizeof(floatX) + 32 * sizeof(float) * (blockDim.z + threadIdx.z)); @@ -385,10 +385,10 @@ __global__ void fused_residual_forward_kernel6(floatX* residual, floatX* normed, s_weight[c / x128::size] = load128(weight + c); s_bias[c / x128::size] = load128(bias + c); } + // the block-level reductions will cause sync before the first time we read these // => no syncthreads needed here - // loop over all tokens for(int tidx = blockIdx.x * blockDim.z + threadIdx.z; tidx < N; tidx += gridDim.x * blockDim.z) { // adjust pointers to current token diff --git a/dev/cuda/global_norm.cu b/dev/cuda/global_norm.cu index 6c2ed0389..f54a35a42 100644 --- a/dev/cuda/global_norm.cu +++ b/dev/cuda/global_norm.cu @@ -16,6 +16,7 @@ nvcc -O3 --use_fast_math global_norm.cu -o global_norm #define ENABLE_BF16 #include "common.h" +cudaDeviceProp deviceProp; float global_norm_cpu(const float* data, size_t count) { // accumulate in double so we have an accurate numerical reference @@ -89,6 +90,54 @@ __global__ void norm_kernel2(float* out, const T* data, size_t count) { } } +template +__global__ void norm_kernel3(float* out, const T* data, size_t count) { + size_t index = blockIdx.x * blockDim.x + threadIdx.x; + size_t grid_width = blockDim.x * gridDim.x; + float accumulator = 0.f; + for(size_t i = index; i < count; i += grid_width) { + accumulator += (float)data[i] * (float)data[i]; + } + // block-level reduce + float block_sum = blockReduce(accumulator); + if(threadIdx.x == 0) { + atomicAdd(out, block_sum); + } +} + +// Same as kernel3 but without atomic adds -> this allows us to have determinism due to the +// non associativity of floating point operations. Roughly same performance as kernel3. +template +__global__ void norm_kernel4(float* out, const T* data, size_t count) { + size_t index = blockIdx.x * blockDim.x + threadIdx.x; + size_t grid_width = blockDim.x * gridDim.x; + float accumulator = 0.f; + for(size_t i = index; i < count; i += grid_width) { + accumulator += (float)data[i] * (float)data[i]; + } + // block-level reduce + float block_sum = blockReduce(accumulator); + // each block accumulates its partial sum to out[blockIdx.x] + // we want to avoid using atomic add here so we combine this kernel with the aggregate kernel call + // that sums up the partial block sums + if(threadIdx.x == 0) { + out[blockIdx.x] = block_sum; + } +} + +__global__ void global_norm_aggregate_kernel(float* out, size_t count) { + size_t index = threadIdx.x; + // grab block sums from the previous kernel, use 0. as the neutral sum element + float block_sum = (index < count) ? out[index] : 0.f; + float sum = blockReduce(block_sum); + if(threadIdx.x == 0) { + out[0] = sum; // out[0] ends up with the final norm squared + } +} + +// ---------------------------------------------------------------------------- +// kernel launchers + template void global_norm1(float* out, const T* values, size_t count, int block_size) { // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. @@ -111,17 +160,54 @@ void global_norm2(float* out, const T* values, size_t count, int block_size) { cudaCheck(cudaGetLastError()); } +template +void global_norm3(float* out, const T* values, size_t count, int block_size) { + // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. + // having one block less than possible is a tiny performance hit, having + // one block too many is catastrophic, since it only can start once all the other + // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512 + // on all gpus, so the division really is going to be exact. + const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; + assert(grid_size > 0); // gives a better error than letting the call below fail + norm_kernel3<<>>(out, values, count); + cudaCheck(cudaGetLastError()); +} + +template +void global_norm4(float* out, const T* values, size_t count, int block_size) { + if (block_size <= 64) { + block_size = 128; // to avoid triggering the assert below + } + // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. + // having one block less than possible is a tiny performance hit, having + // one block too many is catastrophic, since it only can start once all the other + // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512 + // on all gpus, so the division really is going to be exact. + const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; + assert(grid_size > 0); // gives a better error than letting the call below fail + assert(grid_size < 1024); // we want to later accumulate the block sums in a single block + norm_kernel4<<>>(out, values, count); + cudaCheck(cudaGetLastError()); + global_norm_aggregate_kernel<<<1, 1024>>>(out, grid_size); + cudaCheck(cudaGetLastError()); +} + void global_norm(int kernel_num, float* out, const floatX* values, size_t count, int block_size) { switch (kernel_num) { case 1: return global_norm1(out, values, count, block_size); case 2: return global_norm2(out, values, count, block_size); + case 3: + return global_norm3(out, values, count, block_size); + case 4: + return global_norm4(out, values, count, block_size); } } int main(int argc, const char **argv) { setup_main(); + cudaGetDeviceProperties(&deviceProp, 0); int C = 768; int L = 12; @@ -148,7 +234,7 @@ int main(int argc, const char **argv) { // move to GPU float* d_out; floatX* d_inp; - cudaCheck(cudaMalloc(&d_out, sizeof(float))); + cudaCheck(cudaMalloc(&d_out, 1024 * sizeof(float))); // 1024 needed for kernel 4 cudaCheck(cudaMalloc(&d_inp, num_params * sizeof(floatX))); cudaCheck(memcpy_convert(d_inp, inp, num_params)); diff --git a/dev/cuda/layernorm_backward.cu b/dev/cuda/layernorm_backward.cu index dc9d7e982..3930cecdd 100644 --- a/dev/cuda/layernorm_backward.cu +++ b/dev/cuda/layernorm_backward.cu @@ -874,7 +874,6 @@ __global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX } __trap(); // prefer to crash here than run into a deadlock later on } - constexpr int WARP_SIZE = 32; int BLOCK_SIZE = blockDim.x; int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; // size = 2 * C + 1 @@ -1059,7 +1058,6 @@ layernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C) { - constexpr int WARP_SIZE = 32; int BLOCK_SIZE = blockDim.x; int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; // size = 2 * C + 1 diff --git a/dev/cuda/layernorm_forward.cu b/dev/cuda/layernorm_forward.cu index 3e948289a..0c4675162 100644 --- a/dev/cuda/layernorm_forward.cu +++ b/dev/cuda/layernorm_forward.cu @@ -28,7 +28,6 @@ verstion 5 allocates blocks per row instead of warps per row, same alg as 4 othe #include #include #include "common.h" - // ---------------------------------------------------------------------------- // CPU code reference @@ -290,7 +289,7 @@ __global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __rest int num_warps = blockDim.x / 32; int warp_id = threadIdx.x / 32; int lane_id = threadIdx.x % 32; - int idx = blockIdx.x; // simpoy one block per row + int idx = blockIdx.x; // simply one block per row // the row of input that this group of threads is responsible for const float* x = inp + idx * C; // thread coarsening through the row, reduce the sum in series @@ -337,6 +336,82 @@ __global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __rest } } +// Inspired by `fused_residual_forward_kernel5` in fused_residual_forward.cu +__global__ void layernorm_forward_kernel6(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, + const float* __restrict__ inp, const float* __restrict__ weight, + const float* __restrict__ bias, int N, int C) { + assert(blockDim.x == WARP_SIZE); + + // load weights and biases into shared memory + // do this before we allow any threads to exit! + extern __shared__ char params[]; + // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so + // let's keep everything as x128 + x128* s_weight = reinterpret_cast(params); + x128* s_bias = reinterpret_cast(params) + (C / x128::size); + x128* s_in = reinterpret_cast(params) + ((2 + threadIdx.y) * C / x128::size); + + int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; + for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { + s_weight[i/x128::size] = load128(weight + i); + s_bias[i/x128::size] = load128(bias + i); + } + __syncthreads(); + + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if(idx >= N) { return; } // guard + + // adjust pointers to current token + inp += idx * C; + out += idx * C; + + const float eps = 1e-5f; + float sum = 0.0f; + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = load128cs(inp + c); + for(int k = 0; k < x128::size; ++k) { + sum += (float)in_data[k]; + } + s_in[c / x128::size] = in_data; + } + + sum = warpReduceSum(sum); + float m = sum / C; + float v = 0.f; + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = s_in[c / x128::size]; + for(int k = 0; k < x128::size; ++k) { + v += ((float)in_data[k] - m) * ((float)in_data[k] - m); + } + } + + v = warpReduceSum(v) / C; + float s = rsqrtf(v + eps); + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = s_in[c / x128::size]; + const x128 w = s_weight[c / x128::size]; + const x128 b = s_bias[c / x128::size]; + x128 out_data; + for(int k = 0; k < x128::size; ++k) { + float n = s * ((float)in_data[k] - m); // normalized output + float o = n * (float)w[k] + (float)b[k]; // scale and shift it + out_data[k] = o; + } + + store128cs(out + c, out_data); + } + // cache the mean and rstd for the backward pass later + if(threadIdx.x == 0 && mean != nullptr) { + __stcs(mean + idx, m); + } + // store the rstd, no need to cache it + if(threadIdx.x == 0 && rstd != nullptr) { + __stcs(rstd + idx, s); + } +} + // ---------------------------------------------------------------------------- // kernel launcher @@ -356,9 +431,9 @@ void layernorm_forward2(float* out, float* mean, float* rstd, const int block_size) { int N = B * T; // in mean and rstd, threads cooperate within blocks via reductions - mean_kernel<<>>(mean, inp, N, C, block_size); + mean_kernel<<>>(mean, inp, N, C, block_size); cudaCheck(cudaGetLastError()); - rstd_kernel<<>>(rstd, inp, mean, N, C, block_size); + rstd_kernel<<>>(rstd, inp, mean, N, C, block_size); cudaCheck(cudaGetLastError()); // in the normalization, everything just gets flattened out const int block_size2 = 256; @@ -394,12 +469,38 @@ void layernorm_forward5(float* out, float* mean, float* rstd, int B, int T, int C, const int block_size) { assert(block_size % 32 == 0); + assert(block_size <= 1024); const int N = B * T; const int grid_size = N; layernorm_forward_kernel5<<>>(out, mean, rstd, inp, weight, bias, N, C); cudaCheck(cudaGetLastError()); } +void layernorm_forward6(float* out, float* mean, float* rstd, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, + int block_size) { + assert(block_size % 32 == 0); + const int N = B * T; + int block_y = block_size / WARP_SIZE; + const int grid_size = ceil_div(N, block_y); + size_t smem = (2 + block_y) * C * sizeof(float); + + // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute + // this may fail, in which case we fall back to the smem free implementation. + cudaCheck(cudaGetLastError()); + auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cudaGetLastError(); + if (status == cudaSuccess) { + layernorm_forward_kernel6<<>>(out, mean, rstd, inp, weight, bias, N, C); + } else { + const int grid_size = N; + // fall back to the version without shared memory + layernorm_forward_kernel5<<>>(out, mean, rstd, inp, weight, bias, N, C); + } + cudaCheck(cudaGetLastError()); +} + // kernel version dispatch void layernorm_forward(int kernel_num, float* out, float* mean, float* rstd, @@ -422,6 +523,9 @@ void layernorm_forward(int kernel_num, case 5: layernorm_forward5(out, mean, rstd, inp, weight, bias, B, T, C, block_size); break; + case 6: + layernorm_forward6(out, mean, rstd, inp, weight, bias, B, T, C, block_size); + break; default: printf("Invalid kernel number\n"); exit(1); @@ -473,9 +577,6 @@ int main(int argc, char **argv) { printf("Using kernel %d\n", kernel_num); int block_sizes[] = {32, 64, 128, 256, 512, 1024}; - float* out_gpu = (float*)malloc(B * T * C * sizeof(float)); - float* mean_gpu = (float*)malloc(B * T * sizeof(float)); - float* rstd_gpu = (float*)malloc(B * T * sizeof(float)); layernorm_forward_cpu(out, mean, rstd, inp, weight, bias, B, T, C); diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu index ad56a8bc3..86fd37379 100644 --- a/dev/cuda/matmul_backward_bias.cu +++ b/dev/cuda/matmul_backward_bias.cu @@ -2,7 +2,7 @@ Kernels for matmul backward pass bias only. Compile example: -nvcc -O3 -lcublas -lcublasLt matmul_backward_bias.cu -lineinfo -o matmul_backward_bias +nvcc -O3 -lcublas -lcublasLt -std=c++17 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias ./matmul_backward_bias 1 ./matmul_backward_bias 2 @@ -116,7 +116,7 @@ __global__ void matmul_backward_bias_kernel2(floatX* dbias, const floatX* dout, sum = cg::reduce(warp, sum, cg::plus{}); // write the result to output (global memory) if(warp.thread_rank() == 0) { - dbias[idx] += (floatX)sum; + dbias[idx] = (float)dbias[idx] + sum; } } @@ -149,7 +149,7 @@ __global__ void matmul_backward_bias_kernel3(floatX* dbias, const floatX* dout, float block_sum = cg::reduce(warp, warp_sum, cg::plus{}); // sum(x) // write the result to output (global memory) if(threadIdx.x == 0) { - dbias[idx] += block_sum; + dbias[idx] = (float)dbias[idx] + block_sum; } } @@ -189,7 +189,7 @@ __global__ void matmul_backward_bias_kernel4(floatX* dbias, const floatX* dout, for (int j = 0; j < vstep; j++) { dout_sum += smem[lane_id + j * warpSize]; } - dbias[tl + lane_id] += dout_sum; + dbias[tl + lane_id] = (float)dbias[tl + lane_id] + dout_sum; } } diff --git a/dev/data/fineweb.py b/dev/data/fineweb.py index 1d4184c2a..72c312966 100644 --- a/dev/data/fineweb.py +++ b/dev/data/fineweb.py @@ -14,13 +14,16 @@ "language_score": 0.9185474514961243, "token_count": 594 } + +Example of downloading the 100B dataset of FineWebEDU, from root directory: +python dev/data/fineweb.py -t edu -v 100B +100B runs for small few hours, depending on your internet and computer. """ import os import argparse import multiprocessing as mp import numpy as np import tiktoken -# from huggingface_hub import snapshot_download from datasets import load_dataset from tqdm import tqdm import argparse @@ -28,26 +31,34 @@ from data_common import write_datafile # ------------------------------------------ -parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing") -parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B") -parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens") +parser = argparse.ArgumentParser(description="FineWeb and Edu-FineWeb dataset preprocessing") +parser.add_argument("-t", "--type", type=str, default="classic", help="Fineweb type, edu|classic") +parser.add_argument("-v", "--version", type=str, default="10B", help="Fineweb data sample size, 10B|100B") +parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each data shard in the output .bin files, in tokens") args = parser.parse_args() # FineWeb has a few possible subsamples available -assert args.version in ["10B", "100B"], "version must be one of 10B, 100B" -if args.version == "10B": - local_dir = "fineweb10B" - remote_name = "sample-10BT" -elif args.version == "100B": - local_dir = "fineweb100B" - remote_name = "sample-100BT" +assert args.version in {"10B", "100B"}, "version must be one of: 10B, 100B" +assert args.type in {"edu", "classic"}, "type must be one of: edu, classic" +directories = { + ("classic", "10B"): ("fineweb10B", "sample-10BT"), + ("classic", "100B"): ("fineweb100B", "sample-100BT"), + ("edu", "10B"): ("edu_fineweb10B", "sample-10BT"), + ("edu", "100B"): ("edu_fineweb100B", "sample-100BT") +} +local_dir, remote_name = directories[(args.type, args.version)] # create the cache the local directory if it doesn't exist yet DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) os.makedirs(DATA_CACHE_DIR, exist_ok=True) # download the dataset -fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train") +if args.type == "classic": + fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train") + name = "fineweb" +elif args.type =="edu": + fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train") + name = "edu_fineweb" # init the tokenizer enc = tiktoken.get_encoding("gpt2") @@ -83,7 +94,7 @@ def tokenize(doc): else: # write the current shard and start a new one split = "val" if shard_index == 0 else "train" - filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") + filename = os.path.join(DATA_CACHE_DIR, f"{name}_{split}_{shard_index:06d}.bin") # split the document into whatever fits in this shard; the remainder goes to next one remainder = args.shard_size - token_count progress_bar.update(remainder) @@ -98,5 +109,5 @@ def tokenize(doc): # write any remaining tokens as the last shard if token_count != 0: split = "val" if shard_index == 0 else "train" - filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin") + filename = os.path.join(DATA_CACHE_DIR, f"{name}_{split}_{shard_index:06d}.bin") write_datafile(filename, all_tokens_np[:token_count]) diff --git a/dev/data/fineweb.sh b/dev/data/fineweb.sh new file mode 100755 index 000000000..33e94792f --- /dev/null +++ b/dev/data/fineweb.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +# Downloads the FineWeb100B dataset, but in an already tokenized format in .bin files +# Example: ./fineweb.sh 100 +# would download 100 shards +# Default is all shards + +# Check if MAX_SHARDS is provided as positional first arg, otherwise default to 1024 +if [ $# -eq 0 ]; then + MAX_SHARDS=1028 +else + MAX_SHARDS=$1 +fi + +# Ensure MAX_SHARDS is not greater than 1028 +if [ $MAX_SHARDS -gt 1028 ]; then + MAX_SHARDS=1028 +fi + +# Base URLs +TRAIN_BASE_URL="https://huggingface.co/datasets/chrisdryden/FineWebTokenizedGPT2/resolve/main/fineweb_train_" +VAL_URL="https://huggingface.co/datasets/chrisdryden/FineWebTokenizedGPT2/resolve/main/fineweb_val_000000.bin?download=true" + +# Directory to save files +SAVE_DIR="fineweb100B" + +# Create the directory if it doesn't exist +mkdir -p "$SAVE_DIR" + +# Function to download, decompress, and delete files +download() { + local FILE_URL=$1 + local FILE_NAME=$(basename $FILE_URL | cut -d'?' -f1) + local FILE_PATH="${SAVE_DIR}/${FILE_NAME}" + + # Download the file + curl -s -L -o "$FILE_PATH" "$FILE_URL" + echo "Downloaded $FILE_NAME to $SAVE_DIR" +} + +# Function to manage parallel jobs +run_in_parallel() { + local max_jobs=$1 + shift + local commands=("$@") + local job_count=0 + + for cmd in "${commands[@]}"; do + eval "$cmd" & + ((job_count++)) + if (( job_count >= max_jobs )); then + wait -n + ((job_count--)) + fi + done + + # Wait for any remaining jobs to finish + wait +} + +# Export the function so it's available in subshells +export -f download + +# Download +download "$VAL_URL" & + +# Generate train file commands +train_commands=() +for i in $(seq -f "%06g" 1 $MAX_SHARDS); do + FILE_URL="${TRAIN_BASE_URL}${i}.bin?download=true" + train_commands+=("download \"$FILE_URL\"") +done + +# Run the train file commands in parallel +run_in_parallel 40 "${train_commands[@]}" + +echo "The val shard and first $MAX_SHARDS train shards of FineWeb100B files downloaded in $SAVE_DIR" diff --git a/dev/download_starter_pack.sh b/dev/download_starter_pack.sh new file mode 100755 index 000000000..4034a1c81 --- /dev/null +++ b/dev/download_starter_pack.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# Get the directory of the script +SCRIPT_DIR=$(dirname "$(realpath "$0")") + +# Base URL +BASE_URL="https://huggingface.co/datasets/chrisdryden/llmcDatasets/resolve/main/" + +# Directory paths based on script location +SAVE_DIR_PARENT="$SCRIPT_DIR/.." +SAVE_DIR_TINY="$SCRIPT_DIR/data/tinyshakespeare" + +# Create the directories if they don't exist +mkdir -p "$SAVE_DIR_TINY" + +# Files to download +FILES=( + "gpt2_124M.bin" + "gpt2_124M_bf16.bin" + "gpt2_124M_debug_state.bin" + "gpt2_tokenizer.bin" + "tiny_shakespeare_train.bin" + "tiny_shakespeare_val.bin" +) + +# Function to download files to the appropriate directory +download_file() { + local FILE_NAME=$1 + local FILE_URL="${BASE_URL}${FILE_NAME}?download=true" + local FILE_PATH + + # Determine the save directory based on the file name + if [[ "$FILE_NAME" == tiny_shakespeare* ]]; then + FILE_PATH="${SAVE_DIR_TINY}/${FILE_NAME}" + else + FILE_PATH="${SAVE_DIR_PARENT}/${FILE_NAME}" + fi + + # Download the file + curl -s -L -o "$FILE_PATH" "$FILE_URL" + echo "Downloaded $FILE_NAME to $FILE_PATH" +} + +# Export the function so it's available in subshells +export -f download_file + +# Generate download commands +download_commands=() +for FILE in "${FILES[@]}"; do + download_commands+=("download_file \"$FILE\"") +done + +# Function to manage parallel jobs in increments of a given size +run_in_parallel() { + local batch_size=$1 + shift + local i=0 + local command + + for command; do + eval "$command" & + ((i = (i + 1) % batch_size)) + if [ "$i" -eq 0 ]; then + wait + fi + done + + # Wait for any remaining jobs to finish + wait +} + +# Run the download commands in parallel in batches of 2 +run_in_parallel 6 "${download_commands[@]}" + +echo "All files downloaded and saved in their respective directories" \ No newline at end of file diff --git a/dev/eval/README.md b/dev/eval/README.md new file mode 100644 index 000000000..f44c36327 --- /dev/null +++ b/dev/eval/README.md @@ -0,0 +1,59 @@ +# eleuther eval readme + +The goal here is to run the Eleuther Eval harness exactly in the same way as that used in the [huggingface LLM Leaderboard](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). + +The starting point is a `.bin` file trained by llm.c. We now have to export it to a huggingface model and then evaluate it. + +To export the model, use [export_hf.py](export_hf.py). See its documentation up top. Eample usage, from this directory: + +```bash +cd dev/eval +python export_hf.py --input model.bin --output output_dir +``` + +Where you point to your model .bin file, and huggingface files get written to output_dir. The script can optionally also upload to huggingface hub. One more post-processing that is advisable is to go into the `output_dir`, open up the `config.json` there and add one more entry into the json object: + +``` +"_attn_implementation": "flash_attention_2" +``` + +To use FlashAttention 2. We had trouble evaluating in bfloat16 without using FlashAttention 2 (the scores are much lower, and this was never fully resolved). This is a temporary hack/workaround. + +Now that we have the model in huggingface format, we download the Eleuther Eval Harness repo and run it. Head over to the parent/root directory of the llm.c repo and: + +```bash +git clone https://github.com/EleutherAI/lm-evaluation-harness/ +cd lm-evaluation-harness +git checkout b281b0921b636bc36ad05c0b0b0763bd6dd43463 +pip install -e . +``` + +And then run the run_eval.sh script: + +```bash +./dev/eval/run_eval.sh output_dir result_dir +``` + +Where output_dir can either be local output dir (above), or a huggingface repo name.This will write eval json objects to `./lm-evaluation-harness/results/results_dir`. It will print the results into console, e.g. for a 774M model we see: + +``` +---------------------------------------- +arc_challenge_25shot.json : 30.4608 +gsm8k_5shot.json : 0.1516 +hellaswag_10shot.json : 57.8072 +mmlu_5shot.json : 25.8682 +truthfulqa_0shot.json : 35.7830 +winogrande_5shot.json : 59.3528 +---------------------------------------- +Average Score : 34.9039 +``` + +But you can additionally get these results later by running `summarize_eval.py`: + +```bash +python dev/eval/summarize_eval.py lm-evaluation-harness/results/results_dir +``` + +The same information will be printed again. + +For some reason, the evaluation is quite expensive and runs for somewhere around 1-3 hours, even though it should be a few minutes at most. This has not been satisfyingly resolved so far. \ No newline at end of file diff --git a/dev/eval/export_hf.py b/dev/eval/export_hf.py new file mode 100644 index 000000000..b52cc28ea --- /dev/null +++ b/dev/eval/export_hf.py @@ -0,0 +1,173 @@ +""" +Script to convert GPT2 models from llm.c binary format to Hugging Face + +It can optinally upload to your account on Hugging Face if you have the CLI: + pip install -U "huggingface_hub[cli]" + huggingface-cli login + +Export to a local HF model: + python export_hf.py --input input_file.bin --output output_dir + +Export to a local HF model and also push to your account on Hugging Face: + python export_hf.py --input input_file.bin --output output_dir --push true +""" + +import numpy as np +import torch +import argparse, sys +from transformers import GPT2Config, GPT2Tokenizer, GPT2LMHeadModel + +# ----------------------------------------------------------------------------- +# Tensor functions for both bfloat16 (from int16) and normal float32 +# Both return float32 tensors + +def tensor_bf16(data_int16, transpose=False): + if transpose: + data_int16 = data_int16.transpose(1,0) + return torch.tensor(data_int16).view(torch.bfloat16).to(torch.float32) + +def tensor_fp32(data_float32, transpose=False): + if transpose: + data_float32 = data_float32.transpose(1,0) + return torch.tensor(data_float32).view(torch.float32) + +# ----------------------------------------------------------------------------- +# Main conversion function + +def convert(filepath, output, push_to_hub=False, out_dtype="bfloat16"): + print(f"Converting model {filepath} to {output} in {out_dtype} format and pushing to Hugging Face: {push_to_hub}") + + f = open(filepath, 'rb') + # Read in our header, checking the magic number and version + # version 3 = fp32, padded vocab + # version 5 = bf16, padded vocab + model_header = np.frombuffer(f.read(256*4), dtype=np.int32) + if model_header[0] != 20240326: + print("ERROR: magic number mismatch in the data .bin file!") + exit(1) + version = model_header[1] + if not version in [3, 5]: + print("Bad version in model file") + exit(1) + + # Load in our model parameters + maxT = model_header[2].item() # max sequence length + V = model_header[3].item() # vocab size + L = model_header[4].item() # num layers + H = model_header[5].item() # num heads + C = model_header[6].item() # channels + Vp = model_header[7].item() # padded vocab size + + print(f"{version=}, {maxT=}, {V=}, {Vp=}, {L=}, {H=}, {C=}") + + # Define the shapes of our parameters + shapes = { + 'wte': (Vp, C), + 'wpe': (maxT, C), + 'ln1w': (L, C), + 'ln1b': (L, C), + 'qkvw': (L, 3 * C, C), + 'qkvb': (L, 3 * C), + 'attprojw': (L, C, C), + 'attprojb': (L, C), + 'ln2w': (L, C), + 'ln2b': (L, C), + 'fcw': (L, 4 * C, C), + 'fcb': (L, 4 * C), + 'fcprojw': (L, C, 4 * C), + 'fcprojb': (L, C), + 'lnfw': (C,), + 'lnfb': (C,), + } + + # Load in our weights given our parameter shapes + dtype = np.float32 if version == 3 else np.int16 + w = {} + for key, shape in shapes.items(): + num_elements = np.prod(shape) + data = np.frombuffer(f.read(num_elements * np.dtype(dtype).itemsize), dtype=dtype) + w[key] = data.reshape(shape) + # The binary file saves the padded vocab - drop the padding back to GPT2 size + if shape[0] == Vp: + w[key] = w[key].reshape(shape)[:(V-Vp), :] + # Ensure the file is fully read and then close + assert f.read() == b'' + f.close() + + # Map to our model dict, the tensors at this stage are always fp32 + mk_tensor = { + 3 : tensor_fp32, + 5 : tensor_bf16, + }[version] + model_dict = {} + model_dict['transformer.wte.weight'] = mk_tensor(w['wte']) + model_dict['transformer.wpe.weight'] = mk_tensor(w['wpe']) + model_dict['lm_head.weight'] = model_dict['transformer.wte.weight'] # Tie weights + for i in range(L): + model_dict[f'transformer.h.{i}.ln_1.weight'] = mk_tensor(w['ln1w'][i]) + model_dict[f'transformer.h.{i}.ln_1.bias'] = mk_tensor(w['ln1b'][i]) + model_dict[f'transformer.h.{i}.attn.c_attn.weight'] = mk_tensor(w['qkvw'][i], True) + model_dict[f'transformer.h.{i}.attn.c_attn.bias'] = mk_tensor(w['qkvb'][i]) + model_dict[f'transformer.h.{i}.attn.c_proj.weight'] = mk_tensor(w['attprojw'][i], True) + model_dict[f'transformer.h.{i}.attn.c_proj.bias'] = mk_tensor(w['attprojb'][i]) + model_dict[f'transformer.h.{i}.ln_2.weight'] = mk_tensor(w['ln2w'][i]) + model_dict[f'transformer.h.{i}.ln_2.bias'] = mk_tensor(w['ln2b'][i]) + model_dict[f'transformer.h.{i}.mlp.c_fc.weight'] = mk_tensor(w['fcw'][i], True) + model_dict[f'transformer.h.{i}.mlp.c_fc.bias'] = mk_tensor(w['fcb'][i]) + model_dict[f'transformer.h.{i}.mlp.c_proj.weight'] = mk_tensor(w['fcprojw'][i], True) + model_dict[f'transformer.h.{i}.mlp.c_proj.bias'] = mk_tensor(w['fcprojb'][i]) + model_dict['transformer.ln_f.weight'] = mk_tensor(w['lnfw']) + model_dict['transformer.ln_f.bias'] = mk_tensor(w['lnfb']) + + # Create a GPT-2 model instance, in the requested dtype + config = GPT2Config(vocab_size = V, + n_positions = maxT, + n_ctx = maxT, + n_embd = C, + n_layer = L, + n_head = H) + model = GPT2LMHeadModel(config) + if out_dtype == "bfloat16": + model = model.to(torch.bfloat16) + + # Set the model dict and save + model.load_state_dict(model_dict) + model.save_pretrained(output, max_shard_size="5GB", safe_serialization=True) + + # Copy over a standard gpt2 tokenizer + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.save_pretrained(output) + + if push_to_hub: + print(f"Uploading {output} to Hugging Face") + model.push_to_hub(output) + tokenizer.push_to_hub(output) + +def spin(output): + print("Taking the exported model for a spin...") + print('-'*80) + from transformers import AutoModelForCausalLM, AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(output) + model = AutoModelForCausalLM.from_pretrained(output, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, device_map='cuda') + model.eval() + tokens = tokenizer.encode("During photosynthesis in green plants", return_tensors="pt") + tokens = tokens.to('cuda') + output = model.generate(tokens, max_new_tokens=64, repetition_penalty=1.3) + samples = tokenizer.batch_decode(output) + for sample in samples: + print('-'*30) + print(sample) + +# ----------------------------------------------------------------------------- + +if __name__== '__main__': + parser=argparse.ArgumentParser() + parser.add_argument("--input", "-i", help="The name of the llm.c model.bin file", type=str, required=True) + parser.add_argument("--output","-o", help="The Hugging Face output model directory", type=str, required=True) + parser.add_argument("--dtype", "-d", help="Output as either float32 or bfloat16 (default)", type=str, default="bfloat16") + parser.add_argument("--push", "-p", help="Push the model to your Hugging Face account", type=bool, default=False) + parser.add_argument("--spin", "-s", help="Take the model for a spin at the end?", type=bool, default=True) + args = parser.parse_args() + convert(args.input, args.output, args.push, args.dtype) + if args.spin: + spin(args.output) diff --git a/dev/eval/run_eval.sh b/dev/eval/run_eval.sh new file mode 100755 index 000000000..d1e28f612 --- /dev/null +++ b/dev/eval/run_eval.sh @@ -0,0 +1,52 @@ +# https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard +# (See About tab -> REPRODUCIBILITY) + +# This script is intended to be run from the parent/root directory of llm.c repo. + +# Clone the evaluation harness: + +# git clone https://github.com/EleutherAI/lm-evaluation-harness/ +# cd lm-evaluation-harness +# git checkout b281b0921b636bc36ad05c0b0b0763bd6dd43463 +# pip install -e . + +# Then return to the parent directory and run this script + +# cd .. +# ./dev/eval/run_eval.sh [model_name] [result_name] + +# where model_name is either a HF model such as openai-community/gpt2 or a local path such as ./gpt2-124M-run1 +# and result_name is the name of the folder under lm-evaluation-harness/results to store the evaluations + +# Since the evals can take a couple of hours to run, depending on the model size, you may wish to +# run within a "screen" session or by using nohup to run the script: + +# nohup ./dev/eval/run_eval.sh [model_name] [result_name] > run.txt 2> err.txt & + +if [ -z "$1" ]; then + echo "Error: missing HuggingFace model name or path to local model" + echo "./run_eval.sh hf_account/model_name my_result" + exit 1 +fi +if [ -z "$2" ]; then + echo "Error: missing output name for results" + echo "./run_eval.sh hf_account/model_name my_result" + exit 1 +fi + +export MODEL="$(realpath -s "$1")" +export RESULT="$2" +echo "Evaluating model $MODEL" +echo "Saving results to ./lm-evaluation-harness/results/$RESULT" + +cd lm-evaluation-harness + +python main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks truthfulqa_mc --batch_size 1 --no_cache --write_out --output_path results/$RESULT/truthfulqa_0shot.json --device cuda +python main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks winogrande --batch_size 1 --no_cache --write_out --output_path results/$RESULT/winogrande_5shot.json --device cuda --num_fewshot 5 +python main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks arc_challenge --batch_size 1 --no_cache --write_out --output_path results/$RESULT/arc_challenge_25shot.json --device cuda --num_fewshot 25 +python main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks hellaswag --batch_size 1 --no_cache --write_out --output_path results/$RESULT/hellaswag_10shot.json --device cuda --num_fewshot 10 +python main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks gsm8k --batch_size 1 --no_cache --write_out --output_path results/$RESULT/gsm8k_5shot.json --device cuda --num_fewshot 5 +python main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions --batch_size 1 --no_cache --write_out --output_path results/$RESULT/mmlu_5shot.json --device cuda --num_fewshot 5 + +cd .. +python dev/eval/summarize_eval.py lm-evaluation-harness/results/$RESULT diff --git a/dev/eval/summarize_eval.py b/dev/eval/summarize_eval.py new file mode 100644 index 000000000..82425264e --- /dev/null +++ b/dev/eval/summarize_eval.py @@ -0,0 +1,32 @@ +# example run command +# python dev/eval/summarize_eval.py lm-evaluation-harness/results/result774M +# this script is optional, the run_eval.sh should already print these +# but this script can be used to re-print them + +import json, sys + +RESULT = sys.argv[1] +print("-"*40) + +key = {"arc_challenge_25shot.json": "acc_norm", + "gsm8k_5shot.json": "acc", + "hellaswag_10shot.json": "acc_norm", + "mmlu_5shot.json": "acc", + "truthfulqa_0shot.json": "mc2", + "winogrande_5shot.json": "acc" + } + +total = 0 +for test in ["arc_challenge_25shot.json", "gsm8k_5shot.json", "hellaswag_10shot.json", "mmlu_5shot.json", "truthfulqa_0shot.json", "winogrande_5shot.json"]: + data = json.loads(open("./%s/%s"%(RESULT, test)).read()) + r_count = 0 + r_total = 0 + for test_name in data['results']: + r_count += 1 + r_total += data['results'][test_name][key[test]] + score = (r_total*100)/r_count + print(f"{test:<30} : {score:.4f}") + total += score +average = total / 6.0 +print("-"*40) +print(f"Average Score : {average:.4f}") diff --git a/dev/test/Makefile b/dev/test/Makefile new file mode 100644 index 000000000..dfc1d250f --- /dev/null +++ b/dev/test/Makefile @@ -0,0 +1,166 @@ +CC ?= gcc +# example: make test_dataloader TEST_CFLAGS=-fsanitize=address -fno-omit-frame-pointer +CFLAGS = -Ofast -Wno-unused-result -Wno-ignored-pragmas -Wno-unknown-attributes -g +CFLAGS += $(TEST_CFLAGS) +LDFLAGS = +LDLIBS = -lm +INCLUDES = +CFLAGS_COND = -march=native + +# Find nvcc +SHELL_UNAME = $(shell uname) +REMOVE_FILES = rm -f +OUTPUT_FILE = -o $@ +CUDA_OUTPUT_FILE = -o $@ + +# NVCC flags +# -t=0 is short for --threads, 0 = number of CPUs on the machine +NVCC_FLAGS = -O3 -t=0 --use_fast_math -std=c++17 +NVCC_LDFLAGS = -lcublas -lcublasLt +NVCC_INCLUDES = +NVCC_LDLIBS = +NVCC_CUDNN = +# By default we don't build with cudnn because it blows up compile time from a few seconds to ~minute +USE_CUDNN ?= 0 + +# We will place .o files in the `build` directory (create it if it doesn't exist) +BUILD_DIR = build +$(shell mkdir -p $(BUILD_DIR)) +REMOVE_BUILD_OBJECT_FILES := rm -f $(BUILD_DIR)/*.o + +# Function to check if a file exists in the PATH +define file_exists_in_path + $(which $(1) 2>/dev/null) +endef + +ifneq ($(CI),true) # if not in CI, then use the GPU query + ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY= + ifneq ($(call file_exists_in_path, __nvcc_device_query),) + GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) + GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY)) + endif + endif +endif + +# set to defaults if - make GPU_COMPUTE_CAPABILITY= otherwise use the compute capability detected above +ifneq ($(GPU_COMPUTE_CAPABILITY),) + NVCC_FLAGS += --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] +endif + +# autodect a lot of various supports on current platform +$(info ---------------------------------------------) + +NVCC := $(shell which nvcc 2>/dev/null) + +# Check and include cudnn if available +# You can override the path to cudnn frontend by setting CUDNN_FRONTEND_PATH on the make command line +# By default, we look for it in HOME/cudnn-frontend/include and ./cudnn-frontend/include +# Refer to the README for cuDNN install instructions +ifeq ($(USE_CUDNN), 1) + ifeq ($(shell [ -d $(HOME)/cudnn-frontend/include ] && echo "exists"), exists) + $(info ✓ cuDNN found, will run with flash-attention) + CUDNN_FRONTEND_PATH ?= $(HOME)/cudnn-frontend/include + else ifeq ($(shell [ -d cudnn-frontend/include ] && echo "exists"), exists) + $(info ✓ cuDNN found, will run with flash-attention) + CUDNN_FRONTEND_PATH ?= cudnn-frontend/include + else + $(error ✗ cuDNN not found. See the README for install instructions and the Makefile for hard-coded paths) + endif + NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH) + NVCC_LDFLAGS += -lcudnn + NVCC_FLAGS += -DENABLE_CUDNN + NVCC_CUDNN = $(BUILD_DIR)/cudnn_att.o +else + $(info → cuDNN is manually disabled by default, run make with `USE_CUDNN=1` to try to enable) +endif + +# Check if OpenMP is available +# This is done by attempting to compile an empty file with OpenMP flags +# OpenMP makes the code a lot faster so I advise installing it +# e.g. on MacOS: brew install libomp +# e.g. on Ubuntu: sudo apt-get install libomp-dev +# later, run the program by prepending the number of threads, e.g.: OMP_NUM_THREADS=8 ./gpt2 +# First, check if NO_OMP is set to 1, if not, proceed with the OpenMP checks +ifeq ($(NO_OMP), 1) + $(info OpenMP is manually disabled) +else + ifneq ($(OS), Windows_NT) + # Check for OpenMP support in GCC or Clang on Linux + ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0) + CFLAGS += -fopenmp -DOMP + LDLIBS += -lgomp + $(info ✓ OpenMP found) + else + $(info ✗ OpenMP not found) + endif + endif +endif + +# Check if OpenMPI and NCCL are available, include them if so, for multi-GPU training +ifeq ($(NO_MULTI_GPU), 1) + $(info → Multi-GPU (OpenMPI + NCCL) is manually disabled) +else + ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists) + $(info ✓ OpenMPI found, OK to train with multiple GPUs) + NVCC_INCLUDES += -I/usr/lib/x86_64-linux-gnu/openmpi/include + NVCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ + NVCC_LDLIBS += -lmpi -lnccl + NVCC_FLAGS += -DMULTI_GPU + else + $(info ✗ OpenMPI is not found, disabling multi-GPU support) + $(info ---> On Linux you can try install OpenMPI with `sudo apt install openmpi-bin openmpi-doc libopenmpi-dev`) + endif +endif + +# Precision settings, default to bf16 but ability to override +ifeq ($(MAKECMDGOALS), clean) + PRECISION=BF16 +endif + +VALID_PRECISIONS := FP32 FP16 BF16 +ifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),) + $(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS)) +endif +ifeq ($(PRECISION), FP32) + PFLAGS = -DENABLE_FP32 +else ifeq ($(PRECISION), FP16) + PFLAGS = -DENABLE_FP16 +else + PFLAGS = -DENABLE_BF16 +endif + +# PHONY means these targets will always be executed +.PHONY: all clean + +# Add targets +TARGETS = test_dataloader + +# Dependency files +test_dataloader_dependencies = test_dataloader.d +HEADER_DEPENDENCIES = $(test_dataloader_dependencies) + +# Conditional inclusion of CUDA targets +ifeq ($(NVCC),) + $(info ✗ nvcc not found, skipping GPU/CUDA builds) +else + $(info ✓ nvcc found, including GPU/CUDA support) + TARGETS += +endif + +$(info ---------Build Configuration Complete - Build Targets -------------------------) + +all: $(TARGETS) + +# Generate dependency files +%.d: %.c + $(CC) $(CFLAGS) -MMD -MP -MF $@ -c $< + +# Include the dependency files +-include test_dataloader.d + +test_dataloader: test_dataloader.c + $(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) -MMD -MP $^ $(LDLIBS) $(OUTPUT_FILE) + +clean: + $(REMOVE_FILES) $(TARGETS) *.d *.o + $(REMOVE_BUILD_OBJECT_FILES) diff --git a/dev/test/device_file_io.cu b/dev/test/device_file_io.cu new file mode 100644 index 000000000..71fb1ce7e --- /dev/null +++ b/dev/test/device_file_io.cu @@ -0,0 +1,64 @@ +/* +Tests device <-> file IO functions + +compile and run as (from dev/test directory) +nvcc -o device_file_io device_file_io.cu && ./device_file_io +*/ + + +#include "../../llmc/cuda_common.h" +#include +#include +#include +#include + +void test(size_t nelem, size_t wt_buf_size, size_t rd_buf_size) { + + float* data; + cudaCheck(cudaMalloc(&data, nelem*sizeof(float))); + + // generate random array + std::vector random_data(nelem); + std::mt19937 rng(42); + std::uniform_real_distribution dist(-100.f, 100.f); + std::generate(random_data.begin(), random_data.end(), [&](){ return dist(rng); }); + + cudaCheck(cudaMemcpy(data, random_data.data(), random_data.size()*sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + cudaStreamCreate(&stream); + + FILE* tmp = fopenCheck("tmp.bin", "w"); + device_to_file(tmp, data, nelem * sizeof(float), wt_buf_size, stream); + fcloseCheck(tmp); + + + float* reload; + cudaCheck(cudaMalloc(&reload, nelem*sizeof(float))); + + tmp = fopenCheck("tmp.bin", "r"); + file_to_device(reload, tmp, nelem * sizeof(float), rd_buf_size, stream); + fcloseCheck(tmp); + + std::vector cmp(nelem); + cudaCheck(cudaMemcpy(cmp.data(), reload, nelem * sizeof(float), cudaMemcpyDeviceToHost)); + for(int i = 0; i < nelem; ++i) { + if(random_data[i] != cmp[i]) { + fprintf(stderr, "FAIL: Mismatch at position %d: %f vs %f\n", i, random_data[i], cmp[i]); + remove("tmp.bin"); + exit(EXIT_FAILURE); + } + } + + cudaCheck(cudaFree(reload)); + cudaCheck(cudaFree(data)); + remove("tmp.bin"); +} + +int main() { + test(1025, 10000, 10000); // buffers larger than data + test(1025, 1024, 513); // different and smaller + test(500, 500*sizeof(float), + 500*sizeof(float)); // exact match + test(125'000, 10000, 10000); // large array +} \ No newline at end of file diff --git a/dev/test/test_dataloader.c b/dev/test/test_dataloader.c new file mode 100644 index 000000000..2803da022 --- /dev/null +++ b/dev/test/test_dataloader.c @@ -0,0 +1,304 @@ +/* +Tests our DataLoader + +compile and run as (from dev/test directory) +gcc -O3 -I../../llmc -o test_dataloader test_dataloader.c -lm && ./test_dataloader + +TODOs: +- test load/save state of DataLoader +*/ +#include +#include "../../llmc/dataloader.h" + +#define SHARD_NAME_LEN 64 +char shard_name[SHARD_NAME_LEN]; +const int num_tokens = 140; +int num_shards = 4; + +void check_range(const int *tokens, const int start, const int end, const char *file, int line) { + // checks that the tokens[0, ... end-start] are the range [start, end) + int n = end - start; + for (int i = 0; i < n; i++) { + int token = tokens[i]; + if (token != start + i) { + fprintf(stderr, "Error: tokens[%d] = %d, expected %d\n", i, token, start + i); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + exit(EXIT_FAILURE); + } + } + // printf("tokens in range [%d, %d) OK\n", start, end); +} +#define checkRange(tokens, start, end) check_range(tokens, start, end, __FILE__, __LINE__) + +void check_equals(const int *tokens, const int n, const int expected, const char *file, int line) { + // checks that the tokens[0, ... n] are all equal to expected + for (int i = 0; i < n; i++) { + int token = tokens[i]; + if (token != expected) { + fprintf(stderr, "Error: tokens[%d] = %d, expected %d\n", i, token, expected); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + exit(EXIT_FAILURE); + } + } + // printf("tokens all equal to %d OK\n", expected); +} +#define checkEquals(tokens, n, expected) check_equals(tokens, n, expected, __FILE__, __LINE__) + +void test_simple(void) { + /* + Tests the simplest DataLoader functionality: + - multi-shard + - single-process + - not shuffled + DataLoader should just return all the tokens in order + */ + printf("test_simple... "); + int B = 4; + int T = 8; + int process_rank = 0; + int num_processes = 1; + int should_shuffle = 0; + snprintf(shard_name, SHARD_NAME_LEN, "shard_????.bin"); + DataLoader loader; + dataloader_init(&loader, shard_name, B, T, process_rank, num_processes, should_shuffle); + + int batches_fit = num_tokens / (B * T); // number of batches that fit per shard + int BT = B * T; + int num_epochs = 4; + for (int e = 0; e < num_epochs; e++) { // epoch + for (int s = 0; s < num_shards; s++) { // shard + int start = s * num_tokens; + for (int b = 0; b < batches_fit; b++) { // batch + dataloader_next_batch(&loader); + checkRange(loader.inputs, start, start + BT); + checkRange(loader.targets, start + 1, start + BT + 1); + start += BT; + } + } + } + dataloader_free(&loader); + printf("OK\n"); +} + +void test_multiprocess_simple(void) { + /* + Same as simple above, but using 2 processes. + (which we of course use in a serial, single process way here) + The DataLoaders simply pull chunks of consecutive tokens, so + we expect them to alternate in the "token space". + */ + printf("test_multiprocess_simple... "); + int B = 4; + int T = 8; + int num_processes = 2; + int should_shuffle = 0; + snprintf(shard_name, SHARD_NAME_LEN, "shard_????.bin"); + DataLoader loader0, loader1; + dataloader_init(&loader0, shard_name, B, T, 0, num_processes, should_shuffle); + dataloader_init(&loader1, shard_name, B, T, 1, num_processes, should_shuffle); + + int batches_fit = num_tokens / (B * T * num_processes); // number of batches that fit per shard + int BT = B * T; + int num_epochs = 4; + for (int e = 0; e < num_epochs; e++) { // epoch + for (int s = 0; s < num_shards; s++) { // shard + int start = s * num_tokens; + for (int b = 0; b < batches_fit; b++) { // batch + dataloader_next_batch(&loader0); + dataloader_next_batch(&loader1); + checkRange(loader0.inputs, start, start + BT); + checkRange(loader1.inputs, start + BT, start + 2*BT); + checkRange(loader0.targets, start + 1, start + BT + 1); + checkRange(loader1.targets, start + BT + 1, start + 2*BT + 1); + start += 2*BT; + } + } + } + + dataloader_free(&loader0); + dataloader_free(&loader1); + printf("OK\n"); +} + +void test_shuffled(void) { + /* + Tests the DataLoader when using shuffled: + - multi-shard + - single-process + - shuffled! + DataLoader should return all the tokens, but in randperm order. + So all we check is that we see all the tokens we expect to see, + the correct number of times. + */ + printf("test_shuffled... "); + int B = 4; + int T = 8; + int process_rank = 0; + int num_processes = 1; + int should_shuffle = 1; // should shuffle bit turn on + snprintf(shard_name, 64, "shard_????.bin"); + DataLoader loader; + dataloader_init(&loader, shard_name, B, T, process_rank, num_processes, should_shuffle); + + // get batches from the dataloader and keep stats on what tokens we see + int total_tokens = num_shards * num_tokens; + int *num_seen_inputs = (int *)calloc(total_tokens, sizeof(int)); + int *num_seen_targets = (int *)calloc(total_tokens, sizeof(int)); + int batches_fit = num_tokens / (B * T); // number of batches that fit per shard + int BT = B * T; + int num_epochs = 4; + for (int e = 0; e < num_epochs; e ++) { // epoch + for (int s = 0; s < num_shards; s++) { // shard + int start = s * num_tokens; + for (int b = 0; b < batches_fit; b++) { // batch + dataloader_next_batch(&loader); + // count up the tokens we see + for (int i = 0; i < BT; i++) { + int input_token = loader.inputs[i]; + int target_token = loader.targets[i]; + assert(input_token >= 0 && input_token < total_tokens); + assert(target_token >= 0 && target_token < total_tokens); + num_seen_inputs[input_token]++; + num_seen_targets[target_token]++; + } + start += BT; + } + } + } + + // verify that we saw all the tokens the correct number of times + int tokens_fit = batches_fit * BT; // number of tokens that fit per shard + for (int s = 0; s < num_shards; s++) { + int start = s * num_tokens; + // verify the inputs counts for this shard: + // - the first tokens_fit should have been seen num_epochs times + // - the rest of the tokens in that should should have been seen zero times + checkEquals(num_seen_inputs + start, tokens_fit, num_epochs); + checkEquals(num_seen_inputs + start + tokens_fit, num_tokens - tokens_fit, 0); + // verify the target counts. same thing but offset by 1 + checkEquals(num_seen_targets + start + 1, tokens_fit, num_epochs); + checkEquals(num_seen_targets + start + 1 + tokens_fit, + (s == (num_shards - 1)) ? num_tokens - tokens_fit - 1 : num_tokens - tokens_fit,0); + } + + dataloader_free(&loader); + free(num_seen_inputs); + free(num_seen_targets); + printf("OK\n"); +} + +void test_multiprocess_shuffled(void) { + /* + Tests the DataLoader when using both multiprocess and shuffled: + - multi-shard + - multi-process + - shuffled! + DataLoaders should return all the tokens, but in randperm order. + So all we check is that we see all the tokens we expect to see, + the correct number of times, over multiple epochs. + */ + + printf("test_multiprocess_shuffled... "); + int B = 4; + int T = 8; + const int num_processes = 2; + int should_shuffle = 0; + snprintf(shard_name, SHARD_NAME_LEN, "shard_????.bin"); + DataLoader loaders[num_processes]; + for (int i = 0; i < num_processes; i++) { + dataloader_init(&loaders[i], shard_name, B, T, i, num_processes, should_shuffle); + } + + // get batches from the dataloader and keep stats on what tokens we see + int total_tokens = num_shards * num_tokens; + int *num_seen_inputs = (int *)calloc(total_tokens, sizeof(int)); + int *num_seen_targets = (int *)calloc(total_tokens, sizeof(int)); + int batches_fit = num_tokens / (B * T * num_processes); // number of batches that fit per shard + int BT = B * T; + int num_epochs = 4; + for (int e = 0; e < num_epochs; e ++) { // epoch + for (int s = 0; s < num_shards; s++) { // shard + int start = s * num_tokens; + for (int b = 0; b < batches_fit; b++) { // batch + for (int n = 0; n < num_processes; n++) { // dataloader + DataLoader *loader = &loaders[n]; + dataloader_next_batch(loader); + // count up the tokens we see + for (int i = 0; i < BT; i++) { + int input_token = loader->inputs[i]; + int target_token = loader->targets[i]; + assert(input_token >= 0 && input_token < total_tokens); + assert(target_token >= 0 && target_token < total_tokens); + num_seen_inputs[input_token]++; + num_seen_targets[target_token]++; + } + start += BT; + } + } + } + } + + // verify that we saw all the tokens the correct number of times + int tokens_fit = batches_fit * (B * T * num_processes); // number of tokens that fit per shard + for (int s = 0; s < num_shards; s++) { + int start = s * num_tokens; // token id that starts this shard + // verify the inputs counts for this shard: + // - the first tokens_fit should have been seen num_epochs times + // - the rest of the tokens in that should should have been seen zero times + checkEquals(num_seen_inputs + start, tokens_fit, num_epochs); + checkEquals(num_seen_inputs + start + tokens_fit, num_tokens - tokens_fit, 0); + // verify the target counts. same thing but offset by 1 + checkEquals(num_seen_targets + start + 1, tokens_fit, num_epochs); + checkEquals(num_seen_targets + start + 1 + tokens_fit, + (s == (num_shards - 1)) ? num_tokens - tokens_fit - 1 : num_tokens - tokens_fit,0); + } + + // cleanup + for (int i = 0; i < num_processes; i++) { + dataloader_free(&loaders[i]); + } + free(num_seen_inputs); + free(num_seen_targets); + printf("OK\n"); +} + +int main(void) { + + // generate a few dummy shards of data with incrementing tokens + int header[HEADER_SIZE]; + uint16_t tokens[num_tokens]; + for (int shard_id = 0; shard_id < num_shards; shard_id++) { + // ensure unique tokens across the shards for ez accounting below + int token_offset = shard_id * num_tokens; + for (int i = 0; i < num_tokens; i++) { + tokens[i] = token_offset + i; + } + // write the shard + snprintf(shard_name, SHARD_NAME_LEN, "shard_%04d.bin", shard_id); + header[0] = 20240520; // magic + header[1] = 1; // version + header[2] = num_tokens; // number of tokens within + FILE* shard_file = fopenCheck(shard_name, "wb"); + fwrite(header, sizeof(int), HEADER_SIZE, shard_file); + fwrite(tokens, sizeof(uint16_t), num_tokens, shard_file); + fcloseCheck(shard_file); + printf("Wrote shard %s\n", shard_name); + } + + test_simple(); + test_multiprocess_simple(); + test_shuffled(); + test_multiprocess_shuffled(); + + // clean up the shards + for (int shard_id = 0; shard_id < num_shards; shard_id++) { + snprintf(shard_name, SHARD_NAME_LEN, "shard_%04d.bin", shard_id); + remove(shard_name); + } + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/dev/test/test_outlier_detector.c b/dev/test/test_outlier_detector.c new file mode 100644 index 000000000..75b9ca354 --- /dev/null +++ b/dev/test/test_outlier_detector.c @@ -0,0 +1,52 @@ +/* +Tests our OutlierDetector + +compile and run as (from dev/test directory) +gcc -O3 -I../../llmc -o test_outlier_detector test_outlier_detector.c -lm && ./test_outlier_detector +*/ + +#include +#include "../../llmc/outlier_detector.h" + +int main(void) { + OutlierDetector detector; + init_detector(&detector); + + srand(1337); // init rng + + // generate OUTLIER_DETECTOR_WINDOW_SIZE * 2 random numbers between -1 and 1 + for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE * 2; i++) { + double val = (double)rand() / RAND_MAX * 2 - 1; // Random number between -1 and 1 + double zscore = update_detector(&detector, val); + + printf("Step %d: Value = %.4f, zscore = %.4f\n", i, val, zscore); + + // check that the first OUTLIER_DETECTOR_WINDOW_SIZE values return nan + if (i < OUTLIER_DETECTOR_WINDOW_SIZE) { + if (!isnan(zscore)) { + printf("Error: Expected nan, got %.4f\n", zscore); + return EXIT_FAILURE; + } + } else { + // check that the zscore is within reasonable bounds + if (zscore < -3.0 || zscore > 3.0) { + printf("Error: Z-score %.4f is outside of expected range\n", zscore); + return EXIT_FAILURE; + } + } + } + + // simulate an outlier + double outlier = 10.0; // <--- loss spike + double zscore = update_detector(&detector, outlier); + printf("Outlier Step: Value = %.4f, zscore = %.4f\n", outlier, zscore); + + // check that the z-score here is large + if (zscore < 5.0) { + printf("Error: Z-score %.4f is not large enough for an outlier\n", zscore); + return EXIT_FAILURE; + } + + printf("OK\n"); + return EXIT_SUCCESS; +} diff --git a/dev/unistd.h b/dev/unistd.h index 337f29ad2..6ff68be17 100644 --- a/dev/unistd.h +++ b/dev/unistd.h @@ -4,11 +4,15 @@ #define _CRT_SECURE_NO_WARNINGS #define _USE_MATH_DEFINES +#define WIN32_LEAN_AND_MEAN #include #include -//#define gen_max_length 64 // compile as C++ to skip this VLA issue #include +#include // for malloc and free +#include +#include // for _mkdir and _stat +#include // needed for _access below and _findfirst, _findnext, _findclose #define CLOCK_MONOTONIC 0 static inline int clock_gettime(int ignore_variable, struct timespec* tv) @@ -17,14 +21,12 @@ static inline int clock_gettime(int ignore_variable, struct timespec* tv) } #define OMP /* turn it on */ -#include /* needed for access below */ #define F_OK 0 #define access _access #define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise #define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings -#include /* for _mkdir and _stat */ #define mkdir(path, mode) _mkdir(path) /* sketchy way to get mkdir to work on windows */ #define stat _stat @@ -59,7 +61,7 @@ static inline int glob(const char* pattern, int ignored_flags, int (*ignored_err replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes - if (strchr(pattern_copy, '\\') != NULL) { + if (strchr(pattern_copy, '\\') != (void*) NULL) { strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1); directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0'; } diff --git a/dev/vislog.ipynb b/dev/vislog.ipynb index a6c602ba0..96dbe4877 100644 --- a/dev/vislog.ipynb +++ b/dev/vislog.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -61,63 +61,86 @@ "metadata": {}, "outputs": [], "source": [ - "sz = \"350M\"\n", + "import numpy as np\n", + "\n", + "sz = \"124M\"\n", "loss_baseline = {\n", " \"124M\": 3.424958,\n", " \"350M\": 3.083089,\n", " \"774M\": 3.000580,\n", " \"1558M\": 2.831273,\n", "}[sz]\n", - "hella_baseline = {\n", + "hella2_baseline = { # for GPT-2\n", " \"124M\": 0.294463,\n", " \"350M\": 0.375224,\n", " \"774M\": 0.431986,\n", " \"1558M\": 0.488946,\n", "}[sz]\n", - "\n", + "hella3_baseline = { # for GPT-3\n", + " \"124M\": 0.337,\n", + " \"350M\": 0.436,\n", + " \"774M\": 0.510,\n", + " \"1558M\": 0.547,\n", + "}[sz]\n", "# assumes each model run is stored in this way\n", - "logfile = f\"../log{sz}/main.log\"\n", + "logfile = f\"../log_gpt2_{sz}/main.log\"\n", "streams = parse_logfile(logfile)\n", "\n", + "# optional function that smooths out the loss some\n", + "def smooth_moving_average(signal, window_size):\n", + " if signal.ndim != 1:\n", + " raise ValueError(\"smooth_moving_average only accepts 1D arrays.\")\n", + " if signal.size < window_size:\n", + " raise ValueError(\"Input vector needs to be bigger than window size.\")\n", + " if window_size < 3:\n", + " return signal\n", + "\n", + " s = np.pad(signal, (window_size//2, window_size-1-window_size//2), mode='edge')\n", + " w = np.ones(window_size) / window_size\n", + " smoothed_signal = np.convolve(s, w, mode='valid')\n", + " return smoothed_signal\n", + "\n", "plt.figure(figsize=(16, 6))\n", "\n", "# Panel 1: losses: both train and val\n", "plt.subplot(121)\n", "xs, ys = streams[\"trl\"] # training loss\n", + "ys = np.array(ys)\n", + "# smooth out ys using a rolling window\n", + "# ys = smooth_moving_average(ys, 21) # optional\n", "plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')\n", "print(\"Min Train Loss:\", min(ys))\n", "xs, ys = streams[\"tel\"] # validation loss\n", "plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')\n", "# horizontal line at GPT-2 baseline\n", + "# we don't have GPT-3 loss on this dataset because the weights were never released\n", "if loss_baseline is not None:\n", " plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint val loss\")\n", "plt.xlabel(\"steps\")\n", "plt.ylabel(\"loss\")\n", "plt.yscale('log')\n", + "plt.ylim(top=4.0)\n", "plt.legend()\n", "plt.title(\"Loss\")\n", "print(\"Min Validation Loss:\", min(ys))\n", "\n", "# Panel 2: HellaSwag eval\n", "plt.subplot(122)\n", - "xs, ys = streams[\"eval\"] # HellaSwag eval\n", - "plt.plot(xs, ys, label=f\"llm.c ({sz})\")\n", - "# horizontal line at GPT-2 baseline\n", - "if hella_baseline:\n", - " plt.axhline(y=hella_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n", - "plt.xlabel(\"steps\")\n", - "plt.ylabel(\"accuracy\")\n", - "plt.legend()\n", - "plt.title(\"HellaSwag eval\")\n", - "print(\"Max Hellaswag eval:\", max(ys))" + "if \"eval\" in streams:\n", + " xs, ys = streams[\"eval\"] # HellaSwag eval\n", + " ys = np.array(ys)\n", + " plt.plot(xs, ys, label=f\"llm.c ({sz})\")\n", + " # horizontal line at GPT-2/3 baselines\n", + " if hella2_baseline:\n", + " plt.axhline(y=hella2_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n", + " if hella3_baseline:\n", + " plt.axhline(y=hella3_baseline, color='g', linestyle='--', label=f\"OpenAI GPT-3 ({sz}) checkpoint\")\n", + " plt.xlabel(\"steps\")\n", + " plt.ylabel(\"accuracy\")\n", + " plt.legend()\n", + " plt.title(\"HellaSwag eval\")\n", + " print(\"Max Hellaswag eval:\", max(ys))\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 94d464250..f806eaa35 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -40,9 +40,7 @@ __device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param)); // update our low precision version of the parameters using stochastic rounding // this will be used in the next forward pass - // TODO: simply doing `params_memory[i] = (floatX)param;` breaks everything (why?) - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x + blockDim.y * gridDim.y, seed); - stochastic_rounding(param, ¶ms_memory[idx], random); + stochastic_rounding(param, ¶ms_memory[idx], seed); // write the full, float version of the param into our master copy, if we maintain one // this will be used in the next update if (master_params_memory != NULL) { master_params_memory[idx] = param; } diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 8d165de6e..006ad3010 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -8,6 +8,7 @@ Common utilities for CUDA code. #include #include #include +#include // std::bool_constant #include #include #include @@ -15,6 +16,8 @@ Common utilities for CUDA code. #include #include +#include "utils.h" + // ---------------------------------------------------------------------------- // Global defines and settings @@ -38,11 +41,15 @@ extern cudaDeviceProp deviceProp; // convenience macro for calculating grid/block dimensions for kernels #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +// short-cuts for compile-time boolean values that can be used as function arguments +constexpr std::bool_constant True; +constexpr std::bool_constant False; + // ---------------------------------------------------------------------------- // Error checking // CUDA error checking -void inline cudaCheck(cudaError_t error, const char *file, int line) { +inline void cudaCheck(cudaError_t error, const char *file, int line) { if (error != cudaSuccess) { printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error)); exit(EXIT_FAILURE); @@ -50,6 +57,18 @@ void inline cudaCheck(cudaError_t error, const char *file, int line) { }; #define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__)) +// like cudaFree, but checks for errors _and_ resets the pointer. +template +inline void cudaFreeCheck(T** ptr, const char *file, int line) { + cudaError_t error = cudaFree(*ptr); + if (error != cudaSuccess) { + printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error)); + exit(EXIT_FAILURE); + } + *ptr = nullptr; +} +#define cudaFreeCheck(ptr) (cudaFreeCheck(ptr, __FILE__, __LINE__)) + // ---------------------------------------------------------------------------- // CUDA Precision settings and defines @@ -104,4 +123,87 @@ class NvtxRange { }; #define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__) +// ---------------------------------------------------------------------------- +// Utilities to Read & Write between CUDA memory <-> files + +// copy num_bytes from device pointer src into file dest, using double buffering running on the given stream. +inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { + // allocate pinned buffer for faster, async transfer + char* buffer_space; + cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size)); + // split allocation in two + void* read_buffer = buffer_space; + void* write_buffer = buffer_space + buffer_size; + + // prime the read buffer; first copy means we have to wait + char* gpu_read_ptr = (char*)src; + size_t copy_amount = std::min(buffer_size, num_bytes); + cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream)); + cudaCheck(cudaStreamSynchronize(stream)); + size_t rest_bytes = num_bytes - copy_amount; + size_t write_buffer_size = copy_amount; + gpu_read_ptr += copy_amount; + + std::swap(read_buffer, write_buffer); + // now the main loop; as long as there are bytes left + while(rest_bytes > 0) { + // initiate next read + copy_amount = std::min(buffer_size, rest_bytes); + cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream)); + // while this is going on, transfer the write buffer to disk + fwriteCheck(write_buffer, 1, write_buffer_size, dest); + cudaCheck(cudaStreamSynchronize(stream)); // wait for both buffers to be ready. + + std::swap(read_buffer, write_buffer); + rest_bytes -= copy_amount; + write_buffer_size = copy_amount; + gpu_read_ptr += copy_amount; + } + + // make sure to write the last remaining write buffer + fwriteCheck(write_buffer, 1, write_buffer_size, dest); + cudaCheck(cudaFreeHost(buffer_space)); +} + +// copy num_bytes from file src into device pointer dest, using double buffering running on the given stream. +inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { + // allocate pinned buffer for faster, async transfer + // from the docs (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDART__HIGHLEVEL_ge439496de696b166ba457dab5dd4f356.html) + // WC memory is a good option for buffers that will be written by the CPU and read by the device via mapped pinned memory or host->device transfers. + char* buffer_space; + cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size, cudaHostAllocWriteCombined)); + // split allocation in two + void* read_buffer = buffer_space; + void* write_buffer = buffer_space + buffer_size; + + // prime the read buffer; + char* gpu_write_ptr = (char*)dest; + size_t copy_amount = std::min(buffer_size, num_bytes); + freadCheck(read_buffer, 1, copy_amount, src); + + size_t rest_bytes = num_bytes - copy_amount; + size_t write_buffer_size = copy_amount; + std::swap(read_buffer, write_buffer); + + // now the main loop; as long as there are bytes left + while(rest_bytes > 0) { + // initiate next read + copy_amount = std::min(buffer_size, rest_bytes); + cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream)); + gpu_write_ptr += write_buffer_size; + // while this is going on, read from disk + freadCheck(read_buffer, 1, copy_amount, src); + cudaCheck(cudaStreamSynchronize(stream)); // wait for both buffers to be ready. + + std::swap(read_buffer, write_buffer); + rest_bytes -= copy_amount; + write_buffer_size = copy_amount; + } + + // copy the last remaining write buffer to gpu + cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream)); + cudaCheck(cudaStreamSynchronize(stream)); + cudaCheck(cudaFreeHost(buffer_space)); +} + #endif // CUDA_COMMON_H \ No newline at end of file diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 1bca60a8a..4204c3173 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -153,6 +153,28 @@ __device__ inline float blockReduce(float val, bool final_sync=false, float out_ return block_val; } +// Performs a _deterministic_ sum reduction. determinism is achieved by requiring that only +// a single block be used. +template +__global__ void global_sum_single_block_kernel(float* result, const Float* values, size_t count) { + assert(gridDim.x == 1); // only a single block! + float thread_sum = 0; + for(size_t index = threadIdx.x; index < count; index += blockDim.x) { + thread_sum += (float)values[index]; + } + + float reduction = blockReduce(thread_sum, true); + if(threadIdx.x == 0) { + *result = reduction; + } +} + +template +void global_sum_deterministic(float* result, const Float* values, int count, cudaStream_t stream) { + global_sum_single_block_kernel<<<1, 1024, 0, stream>>>(result, values, count); + cudaCheck(cudaGetLastError()); +} + // ---------------------------------------------------------------------------- // Random Number Generation used in Stochastic Rounding @@ -160,14 +182,14 @@ __device__ inline float blockReduce(float val, bool final_sync=false, float out_ // This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU // todo - possibly overkill and we don't need such high quality random numbers? (tbd) // http://eiserloh.net/noise/SquirrelNoise5.hpp -__device__ __host__ constexpr unsigned int SquirrelNoise5(int positionX, unsigned int seed) +__device__ __host__ constexpr unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed) { constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111 constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111 constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011 constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011 constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101 - unsigned int mangledBits = (unsigned int) positionX; + unsigned int mangledBits = positionX; mangledBits *= SQ5_BIT_NOISE1; mangledBits += seed; mangledBits ^= (mangledBits >> 9); @@ -183,14 +205,18 @@ __device__ __host__ constexpr unsigned int SquirrelNoise5(int positionX, unsigne } __device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed) { - constexpr int PRIME_NUMBER = 198491317; // Large prime number with non-boring bits - return SquirrelNoise5(indexX + (PRIME_NUMBER * indexY), seed); + constexpr unsigned int PRIME_NUMBER = 198491317u; // Large prime number with non-boring bits + unsigned int x = static_cast(indexX); + unsigned int y = static_cast(indexY); + + return SquirrelNoise5(x + (PRIME_NUMBER * y), seed); } // stochastic rounding built on top of Squirel Noise above (with seed updated per step via xorshift) __device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) { // todo - is this stochastic rounding *too good*? can we cut any corners? - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed); + // makes sure each thread gets a different random number + unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed); unsigned int threshold = random & 0xFFFF; unsigned int float_bits = __float_as_uint(in); unsigned int rounded_bits = float_bits & 0x0000FFFF; diff --git a/llmc/cudnn_att.cpp b/llmc/cudnn_att.cpp index 721786243..8d0ad53d1 100644 --- a/llmc/cudnn_att.cpp +++ b/llmc/cudnn_att.cpp @@ -183,6 +183,9 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { .set_uid(Attn_scale_UID) .set_data_type(fe::DataType_t::FLOAT)); auto sdpa_backward_options = fe::graph::SDPA_backward_attributes().set_name("flash_attention_backward") +#if CUDNN_FRONTEND_MAJOR_VERSION > 1 || CUDNN_FRONTEND_MINOR_VERSION >= 5 + .set_deterministic_algorithm(true) // 1.5+ needs this for determinism +#endif .set_causal_mask(true) .set_attn_scale(attn_scale); diff --git a/llmc/dataloader.h b/llmc/dataloader.h index 73073872e..c96296f9f 100644 --- a/llmc/dataloader.h +++ b/llmc/dataloader.h @@ -15,6 +15,7 @@ // defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck // defines: mallocCheck #include "utils.h" +#include "rand.h" // ---------------------------------------------------------------------------- // implementation of glob for Windows is in dev/unistd.h @@ -30,23 +31,37 @@ typedef struct { // each process/worker has to access different parts of the data int process_rank; int num_processes; - // hyperparameters. use size_t to prevent overflow + // batch and token information size_t B; size_t T; - // input handling and its state + size_t num_tokens; // total number of tokens + size_t shard_num_samples; // total number of samples in the current shard per process + // shards and current position glob_t glob_result; // stores the result of glob, for all shards we want to iterate - int current_shard; // the current shard we are reading from + size_t current_shard_idx; // the current shard we are reading from + size_t current_sample_idx; // the current sample we are reading from + // file handle FILE* tokens_file; - int64_t file_size; - int64_t current_position; + // data buffers uint16_t* buffer; // we fread data from file into this buffer - // public variables that could be accessed from outside - size_t num_tokens; // total number of tokens int* inputs; // input tokens into transformer int* targets; // target tokens for the transformer + // random shuffle related variables + mt19937_state shuffle_rng; + int should_shuffle; + int* shard_indices; + int* intra_shard_indices; + // sizes in bytes + size_t total_batch_size_bytes; // total across all processes + size_t local_batch_offset_bytes; // inner-sample offset for this process + size_t header_bytes; // header size in bytes + int64_t file_size_bytes; } DataLoader; int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { + if (loader->should_shuffle) { + shard_index = loader->shard_indices[shard_index]; + } // use the first glob match as the filename for now const char* filename = loader->glob_result.gl_pathv[shard_index]; // open the input file for reading. also only a single file can be opened at a time @@ -68,44 +83,60 @@ int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { assert(ntok > 0); // we expect some tokens in the file. this should never trip, right? // determine the file size and make sure it is consistent with the number of tokens fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file - loader->file_size = ftell(loader->tokens_file); // read the offset, i.e. file size + loader->file_size_bytes = ftell(loader->tokens_file); // read the offset, i.e. file size fseekCheck(loader->tokens_file, 0, SEEK_SET); // seek back to the beginning // we expect ntok in the file to be consistent with filesize, assert that is the case int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t); - if (loader->file_size != expected_file_size) { + if (loader->file_size_bytes != expected_file_size) { printf("Error: file size is not as expected\n"); exit(EXIT_FAILURE); } + // -1 uint16_t due to us taking B*T+1 tokens but moving by B*T tokens + loader->shard_num_samples = (ntok * sizeof(uint16_t) - sizeof(uint16_t)) / loader->total_batch_size_bytes; return ntok; } -void dataloader_resume(DataLoader *loader, int current_shard, int64_t current_position) { - // used during model resumption (-y 1) flag - loader->current_shard = current_shard; - loader->current_position = current_position; - dataloader_load_shard_(loader, loader->current_shard); +void prepare_intra_shard_indices_(DataLoader *loader) { + // shuffle the examples inside the shards + if (loader->intra_shard_indices != NULL) { + // in case shards have different number of samples / sizes + free(loader->intra_shard_indices); + } + loader->intra_shard_indices = (int*)mallocCheck(loader->shard_num_samples * sizeof(int)); + init_identity_permutation(loader->intra_shard_indices, (int) loader->shard_num_samples); + random_permutation(loader->intra_shard_indices, (int) loader->shard_num_samples, &loader->shuffle_rng); } void dataloader_reset(DataLoader *loader) { - // fully resets the DataLoader object to init configuration - // each process starts at a different offset in the file - int64_t header_bytes = HEADER_SIZE * sizeof(int); - int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); - loader->current_shard = 0; - loader->current_position = header_bytes + token_bytes_offset; - dataloader_load_shard_(loader, loader->current_shard); + loader->current_shard_idx = 0; + loader->current_sample_idx = 0; + + if (loader->should_shuffle) { // shuffle the shards + random_permutation(loader->shard_indices, (int) loader->glob_result.gl_pathc, &loader->shuffle_rng); + } + + dataloader_load_shard_(loader, (int) loader->current_shard_idx); + + if (loader->should_shuffle) { + prepare_intra_shard_indices_(loader); + } } void dataloader_advance_(DataLoader *loader) { + if (loader->current_shard_idx == loader->glob_result.gl_pathc - 1) { + // if we are at the last shard, we reset the loader and start a new epoch + dataloader_reset(loader); + return; + } + // advance the loader by loading the next data shard and resetting the position - if (loader->glob_result.gl_pathc > 1) { - // if we have more than one shard, advance to the next one - loader->current_shard = (loader->current_shard + 1) % loader->glob_result.gl_pathc; - dataloader_load_shard_(loader, loader->current_shard); + loader->current_shard_idx = (loader->current_shard_idx + 1) % loader->glob_result.gl_pathc; + loader->current_sample_idx = 0; + dataloader_load_shard_(loader, (int) loader->current_shard_idx); + + if (loader->should_shuffle) { + prepare_intra_shard_indices_(loader); } - int64_t header_bytes = HEADER_SIZE * sizeof(int); - int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); - loader->current_position = header_bytes + token_bytes_offset; } void dataloader_init(DataLoader *loader, @@ -113,12 +144,17 @@ void dataloader_init(DataLoader *loader, size_t B, size_t T, int process_rank, - int num_processes) { + int num_processes, + int should_shuffle) { loader->process_rank = process_rank; loader->num_processes = num_processes; loader->B = B; loader->T = T; loader->tokens_file = NULL; + loader->should_shuffle = should_shuffle; + loader->header_bytes = HEADER_SIZE * sizeof(int); + loader->total_batch_size_bytes = ((loader->num_processes * (loader->B * loader->T)) * sizeof(uint16_t)); + loader->local_batch_offset_bytes = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); // glob to get the list of files matching the pattern, these are our data shards int glob_status = glob(filename_pattern, 0, NULL, &loader->glob_result); @@ -131,6 +167,15 @@ void dataloader_init(DataLoader *loader, exit(EXIT_FAILURE); } + if (should_shuffle) { + mt19937_state shuffle_rng; + manual_seed(&shuffle_rng, 42 + process_rank); + loader->shuffle_rng = shuffle_rng; + loader->shard_indices = (int*)mallocCheck(loader->glob_result.gl_pathc * sizeof(int)); + init_identity_permutation(loader->shard_indices, (int) loader->glob_result.gl_pathc); + loader->intra_shard_indices = NULL; // dynamically allocated allowing different shard sizes + } + // inspect and validate all shards so we don't get any runtime errors later // if too slow / too many shards, may wish to revisit later int64_t ntok_total = 0; @@ -138,7 +183,7 @@ void dataloader_init(DataLoader *loader, int64_t shard_ntok = dataloader_load_shard_(loader, shard_index); // we need at least one batch/shard, the way things are written right now. // can be relaxed a lot later. - assert(shard_ntok >= num_processes * B * T + 1); + assert(shard_ntok >= (int64_t) (num_processes * B * T + 1)); ntok_total += shard_ntok; } // debugging prints @@ -146,40 +191,59 @@ void dataloader_init(DataLoader *loader, // printf("DataLoader: Found %ld tokens across %zu shards\n", ntok_total, loader->glob_result.gl_pathc); // allocate all the space we'll need - loader->buffer = (uint16_t*)malloc((B * T + 1) * sizeof(uint16_t)); - loader->inputs = (int*)malloc(B * T * sizeof(int)); - loader->targets = (int*)malloc(B * T * sizeof(int)); + loader->buffer = (uint16_t*)mallocCheck((B * T + 1) * sizeof(uint16_t)); + loader->inputs = (int*)mallocCheck(B * T * sizeof(int)); + loader->targets = (int*)mallocCheck(B * T * sizeof(int)); loader->num_tokens = ntok_total; // reset the loader, to initialize it dataloader_reset(loader); } -void dataloader_next_batch(DataLoader *loader) { +void dataloader_load_batch(DataLoader* loader) { + assert(!loader->should_shuffle || (loader->should_shuffle && loader->intra_shard_indices != NULL)); + assert(loader->current_sample_idx < loader->shard_num_samples); + size_t idx = loader->should_shuffle ? loader->intra_shard_indices[loader->current_sample_idx] : loader->current_sample_idx; + size_t global_batch_offset_bytes = idx * loader->total_batch_size_bytes; + int64_t current_offset = loader->header_bytes + global_batch_offset_bytes + loader->local_batch_offset_bytes; + size_t B = loader->B; size_t T = loader->T; // read B*T+1 uint16_t tokens from the file into buffer - fseekCheck(loader->tokens_file, loader->current_position, SEEK_SET); + fseekCheck(loader->tokens_file, (int) current_offset, SEEK_SET); freadCheck(loader->buffer, sizeof(uint16_t), B*T+1, loader->tokens_file); // decode the buffer into inputs and targets (cast to int) for (int i = 0; i < B*T; i++) { loader->inputs[i] = (int)loader->buffer[i]; loader->targets[i] = (int)loader->buffer[i+1]; } - // advance the current position by B*T*num_processes integers - // note: the "stride" of tokens by which we move each time is definitely B * T - // we only load B * T + 1 tokens at each iteration because the targets are offset by 1 - loader->current_position += loader->num_processes * B * T * sizeof(uint16_t); +} + +void dataloader_next_batch(DataLoader *loader) { // if the next batch would go past the end of the file, advance the loader - if (loader->current_position + (loader->num_processes * B * T + 1) * sizeof(uint16_t) > loader->file_size) { + if (loader->current_sample_idx >= loader->shard_num_samples) { dataloader_advance_(loader); } + dataloader_load_batch(loader); + loader->current_sample_idx += 1; +} + + +void dataloader_resume(DataLoader *loader, size_t current_shard_idx, size_t current_sample_idx) { + // used during model resumption (-y 1) flag + loader->current_shard_idx = current_shard_idx; + loader->current_sample_idx = current_sample_idx; + dataloader_load_shard_(loader, (int) loader->current_shard_idx); } void dataloader_free(DataLoader *loader) { free(loader->buffer); free(loader->inputs); free(loader->targets); + if (loader->should_shuffle) { + free(loader->shard_indices); + free(loader->intra_shard_indices); + } fcloseCheck(loader->tokens_file); globfree(&loader->glob_result); } @@ -237,7 +301,7 @@ void evalloader_reset(EvalLoader *loader) { // then process 0 should start at 0, process 1 at N/4, process 2 at N/2, etc. // determine how much work there is for all processes int examples_per_process = CEIL_DIV(loader->num_examples, loader->num_processes); - int can_fit_examples = loader->B / ASSUMED_NUM_COMPLETIONS; + int can_fit_examples = (int) (loader->B / ASSUMED_NUM_COMPLETIONS); loader->num_batches = CEIL_DIV(examples_per_process, can_fit_examples); // determine the start and end example indices for this process loader->start_example_index = examples_per_process * loader->process_rank; @@ -249,7 +313,7 @@ void evalloader_reset(EvalLoader *loader) { // now seek through the file to the start of that example // utilize for efficiency int64_t header_bytes = HEADER_SIZE * sizeof(int); - fseekCheck(loader->eval_file, header_bytes, SEEK_SET); + fseekCheck(loader->eval_file, (int) header_bytes, SEEK_SET); for (int i = 0; i < loader->start_example_index; i++) { uint16_t example_header[3]; // read 3 uint16_t values: , , @@ -261,7 +325,7 @@ void evalloader_reset(EvalLoader *loader) { // skip to the next example, keeping in mind that we already read the header size_t remaining_bytes = example_header[1] - sizeof(uint16_t) * 3; assert(remaining_bytes > 0); // we expect some bytes in the example - fseekCheck(loader->eval_file, remaining_bytes, SEEK_CUR); + fseekCheck(loader->eval_file, (int) remaining_bytes, SEEK_CUR); } // now we are at the start of the example we want to start at, pointing at loader->current_example_index = loader->start_example_index; @@ -296,12 +360,12 @@ void evalloader_init(EvalLoader *loader, assert(longest_example_bytes > 0 && longest_example_bytes < (1+ASSUMED_NUM_COMPLETIONS)*T*2); // allocate all the space we'll need - int can_fit_examples = B / ASSUMED_NUM_COMPLETIONS; - loader->buffer = (uint16_t*)malloc(longest_example_bytes); + int can_fit_examples = (int) (B / ASSUMED_NUM_COMPLETIONS); + loader->buffer = (uint16_t*)mallocCheck(longest_example_bytes); loader->inputs = (int*)calloc(B * T, sizeof(int)); loader->targets = (int*)calloc(B * T, sizeof(int)); - loader->mask = (char*)malloc(B * T * sizeof(char)); - loader->label = (int*)malloc(can_fit_examples * sizeof(int)); + loader->mask = (char*)mallocCheck(B * T * sizeof(char)); + loader->label = (int*)mallocCheck(can_fit_examples * sizeof(int)); // reset the loader, to initialize it evalloader_reset(loader); @@ -329,7 +393,7 @@ void evalloader_next_example_(EvalLoader *loader, int example_batch_index) { freadCheck(loader->buffer, sizeof(char), example_bytes, loader->eval_file); // process the example label int label = (int)loader->buffer[0]; - int can_fit_examples = loader->B / ASSUMED_NUM_COMPLETIONS; + int can_fit_examples = (int) (loader->B / ASSUMED_NUM_COMPLETIONS); assert(label >= 0 && label < ASSUMED_NUM_COMPLETIONS); // we expect the label to be in [0, 4) for right now assert(example_batch_index >= 0 && example_batch_index < can_fit_examples); loader->label[example_batch_index] = label; // store for output @@ -386,7 +450,7 @@ void evalloader_next_batch(EvalLoader *loader) { // we have a batch dimension of B, which we want to take full advantage of // each example has some number of completions (usually 4) // so we want to pack as many examples into rows of B as we can fit - int can_fit_examples = B / ASSUMED_NUM_COMPLETIONS; // how many examples can we fit in the batch? + int can_fit_examples = (int) (B / ASSUMED_NUM_COMPLETIONS); // how many examples can we fit in the batch? for (int i = 0; i < can_fit_examples; i++) { if (loader->current_example_index >= loader->end_example_index) { break; // this process has exhausted its work, noop from here on @@ -405,7 +469,7 @@ int evalloader_stat_losses(EvalLoader *loader, float* losses) { size_t B = loader->B; size_t T = loader->T; // iterate the examples in this batch - int can_fit_examples = B / ASSUMED_NUM_COMPLETIONS; + int can_fit_examples = (int) (B / ASSUMED_NUM_COMPLETIONS); for (int i = 0; i < can_fit_examples; i++) { float min_loss = 0.0f; int min_loss_index = -1; diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index 5e5715b18..3aa63e175 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -107,8 +107,11 @@ __global__ void wte_backward_kernel(floatX* dwte, // Add the result to dwte and write back to global memory (read-modify-write) for (unsigned int k = 0; k < x128::size; k++) { - // We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic - stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + k); + // We use stochastic rounding to go from FP32 to BF16 + // The seed is deterministic and unique for each parameter to guarantee we have determinism AND + // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB + // and that somehow messing the quality of random numbers + stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + bucket * WARP_SIZE + threadIdx.x + k); } store128(dwte_ix, packed_in_out); } @@ -139,8 +142,11 @@ __global__ void wpe_backward_kernel(floatX* dwpe, floatX* dwpe_tc = dwpe + (t * C) + c; x128 packed_dwpe = load128(dwpe_tc); for (unsigned int k = 0; k < x128::size; k++) { - // We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic - stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + k); + // We use stochastic rounding to go from FP32 to BF16 + // The seed is deterministic and unique for each parameter to guarantee we have determinism AND + // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB + // and that somehow messing the quality of random numbers + stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + idx + k); } store128(dwpe_tc, packed_dwpe); } diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index b13fa35dd..4837d4cb0 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -65,11 +65,11 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // will _update_ logits to logit gradients // uses template to decide whether to write logits and probs // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts -template +template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) - fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs, + fused_classifier_kernel5(floatX* logits, float* losses, floatX* probs, const float dloss, const int* targets, - int B, int T, int V, int P) { + int B, int T, int V, int P, std::bool_constant) { // note: idx is small enough that it easily fits into 32 bit; // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) // are done is 64 bit @@ -82,9 +82,15 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // calculate the probability needed for the loss and update (single-threaded) if(threadIdx.x == 0) { float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale; - losses[idx] = (floatX)(-logf(prob)); + losses[idx] -= logf(prob); } + // without this synchronization point we have a race condition: + // the logits used above to compute the loss are concurrently (race) modified to carry backward pass grads. + // since the "logits" are overwritten to be in the [-1, 1] range and sp.Offset is sometimes smaller than -90 + // we errouneously end up computing exp^(90+) which gives us infinities in the loss! this is the fix. + __syncthreads(); + // calculate the gradients directly, saves bandwidth from probs during training // but also supports writing probs for inference-only and debugging const floatX* logits_vec = logits + idx * P; @@ -100,7 +106,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) float indicator = (element == ix) ? 1.0f : 0.0f; packed_logits_vec[k] = (floatX)((prob - indicator) * dloss); } - if (WriteLogits){ + if (WriteDLogits){ // reduce cache persistence for the overwritten logits // to maximise probability that logits remain in cache between prepare_softmax and here store128cs(logits + idx * P + i * x128::size, packed_logits_vec); @@ -117,7 +123,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale; float indicator = (i == ix) ? 1.0f : 0.0f; float dlogit = (prob - indicator) * dloss; - if (WriteLogits){ + if (WriteDLogits){ __stcs(logits + idx * P + i, (floatX)dlogit); } if (WriteProbs) { @@ -130,14 +136,14 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // kernel launchers // replaces logits with logit gradients -template -void fused_classifier(Type* logits, Type* losses, +template +void fused_classifier(Type* logits, float* losses, const float dloss, const int* targets, - int B, int T, int V, int P, cudaStream_t stream) { + int B, int T, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream) { NVTX_RANGE_FN(); const int block_size = 1024; const int N = B * T; const int grid_size = N; - fused_classifier_kernel5<<>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P); + fused_classifier_kernel5<<>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits); cudaCheck(cudaGetLastError()); } diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index ce9aa1251..cd5c297b6 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -50,7 +50,7 @@ __global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) { NVTX_RANGE_FN(); const int block_size = 512; - assert(N % block_size == 0); + assert(N % (block_size * x128::size) == 0); const int grid_size = CEIL_DIV(N, block_size * x128::size); gelu_forward_kernel2<<>>(out, inp); cudaCheck(cudaGetLastError()); @@ -59,7 +59,7 @@ void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) { void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream) { NVTX_RANGE_FN(); const int block_size = 128; - assert(N % block_size == 0); + assert(N % (block_size * x128::size) == 0); const int grid_size = CEIL_DIV(N, block_size * x128::size); gelu_backward_inplace_kernel<<>>(d_in_out, inp); cudaCheck(cudaGetLastError()); diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index 92866e81d..e0e23b08a 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -13,13 +13,7 @@ Global norm, used in gradient clipping template __device__ float global_norm_squared_for_range(const T* data, size_t count) { - // we want as few atomics as possible, so each block tries to do - // the maximum amount of work (so no fixed chunk, but instead iterating - // until we run out of data), and then we reduce inside the block - // and finally have just one atomic per block. - // out will be updated atomically from all thread blocks. It is a float, so the - // atomic op is unproblematic - size_t index = threadIdx.x + blockDim.x * blockIdx.x; + size_t index = blockIdx.x * blockDim.x + threadIdx.x; size_t grid_width = blockDim.x * gridDim.x; float accumulator = 0.f; for(size_t i = index; i < count; i += grid_width) { @@ -32,16 +26,47 @@ __device__ float global_norm_squared_for_range(const T* data, size_t count) { template __global__ void global_norm_squared_kernel(float* out, const T* data, size_t count, ptrdiff_t stride) { float block_sum = global_norm_squared_for_range(data + blockIdx.y * stride, count); + // each block accumulates its partial sum to out[out_index] + // we want to avoid using atomic add here so we combine this kernel with another kernel call + // that sums up the partial block sums if(threadIdx.x == 0) { - atomicAdd(out, block_sum); + size_t out_index = blockIdx.y * gridDim.x + blockIdx.x; + out[out_index] = out[out_index] + block_sum; + } +} + +__global__ void global_norm_aggregate_kernel(float* out, size_t grid_size) { + size_t index = threadIdx.x; + // grab block sums from the previous kernel, use 0. as the neutral sum element + float block_sum = (index < grid_size) ? out[index] : 0.f; + float sum = blockReduce(block_sum); + if(threadIdx.x == 0) { + out[0] = sum; // out[0] ends up with the final norm squared } } // ---------------------------------------------------------------------------- // kernel launcher +// Helper function determines the maximum number of block sums +int get_max_num_block_sums(int* num_slices_all, int numel) { + // NOTE: this needs to be kept in sync with `global_norm_squared` below. + const int block_size = 512; + const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; + assert(grid_size > 0); + int max_num_block_sums = 0; + for (int i = 0; i < numel; i++) { + int num_slices = num_slices_all[i]; + const int gx = CEIL_DIV(grid_size, num_slices); + const int gy = num_slices; + max_num_block_sums = max(max_num_block_sums, gx * gy); + } + + return max_num_block_sums; +} + template -void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, bool reset, cudaStream_t stream) { +void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, int max_num_block_sums, bool reset, cudaStream_t stream) { const int block_size = 512; // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. // having one block less than possible is a tiny performance hit, having @@ -50,13 +75,15 @@ void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t st // on all gpus, so the division really is going to be exact. const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; assert(grid_size > 0); // gives a better error than letting the call below fail - // initialize out with zero - if(reset) { - cudaCheck(cudaMemsetAsync(out, 0, sizeof(float), stream)); - } + const int gx = CEIL_DIV(grid_size, num_slices); const int gy = num_slices; + + assert(gx * gy < 1024); // we want to later accumulate the block sums in a single block + + if (reset) { + cudaCheck(cudaMemsetAsync(out, 0, max_num_block_sums * sizeof(float), stream)); + } global_norm_squared_kernel<<>>(out, values, count, stride); cudaCheck(cudaGetLastError()); } - diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index a49f0bffb..045993a26 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -17,7 +17,7 @@ E.g., the layernorms are connected to the residuals so we += in layernorm backwa // ---------------------------------------------------------------------------- // CUDA kernels -__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, +__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, const floatX* __restrict__ inp, const floatX* __restrict__ weight, const floatX* __restrict__ bias, int N, int C) { int lane_id = threadIdx.x % WARP_SIZE; @@ -38,7 +38,7 @@ __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __re sum = warpReduceSum(sum); float m = sum / C; if(lane_id == 0 && mean != nullptr) { - __stcs(mean + idx, (floatX)m); + __stcs(mean + idx, m); } // rstd @@ -50,7 +50,7 @@ __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __re sum = warpReduceSum(sum); float s = rsqrtf(sum / C + 1e-5f); if(lane_id == 0 && rstd != nullptr) { - __stcs(rstd + idx, (floatX)s); + __stcs(rstd + idx, s); } // final normalization and scaling by weight/bias @@ -64,7 +64,82 @@ __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __re } } -__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd, +__global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, + const floatX* __restrict__ bias, int N, int C) { + assert(blockDim.x == WARP_SIZE); + + // load weights and biases into shared memory + // do this before we allow any threads to exit! + extern __shared__ char* params[]; + // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so + // let's keep everything as x128 + x128* s_weight = reinterpret_cast(params); + x128* s_bias = reinterpret_cast(params) + (C / x128::size); + x128* s_in = reinterpret_cast(params) + ((2 + threadIdx.y) * C / x128::size); + + int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; + for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { + s_weight[i/x128::size] = load128(weight + i); + s_bias[i/x128::size] = load128(bias + i); + } + __syncthreads(); + + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if(idx >= N) { return; } // guard + + // adjust pointers to current token + inp += idx * C; + out += idx * C; + + const float eps = 1e-5f; + float sum = 0.0f; + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = load128cs(inp + c); + for(int k = 0; k < x128::size; ++k) { + sum += (float)in_data[k]; + } + s_in[c / x128::size] = in_data; + } + + sum = warpReduceSum(sum); + float m = sum / C; + float v = 0.f; + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = s_in[c / x128::size]; + for(int k = 0; k < x128::size; ++k) { + v += ((float)in_data[k] - m) * ((float)in_data[k] - m); + } + } + + v = warpReduceSum(v) / C; + float s = rsqrtf(v + eps); + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = s_in[c / x128::size]; + const x128 w = s_weight[c / x128::size]; + const x128 b = s_bias[c / x128::size]; + x128 out_data; + for(int k = 0; k < x128::size; ++k) { + float n = s * ((float)in_data[k] - m); // normalized output + float o = n * (float)w[k] + (float)b[k]; // scale and shift it + out_data[k] = (floatX)o; + } + + store128cs(out + c, out_data); + } + // cache the mean and rstd for the backward pass later + if(threadIdx.x == 0 && mean != nullptr) { + __stcs(mean + idx, m); + } + // store the rstd, no need to cache it + if(threadIdx.x == 0 && rstd != nullptr) { + __stcs(rstd + idx, s); + } +} + +__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd, const floatX* inp1, const floatX* inp2, const floatX* weight, const floatX* bias, int N, int C) { @@ -158,7 +233,7 @@ __global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const f __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? layernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, - const floatX* mean, const floatX* rstd, + const float* mean, const float* rstd, int B, int T, int C) { int BLOCK_SIZE = blockDim.x; int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block @@ -207,8 +282,8 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } } - const float mean_bt = (float)mean[bt]; - const float rstd_bt = (float)rstd[bt]; + const float mean_bt = mean[bt]; + const float rstd_bt = rstd[bt]; dnorm_mean = warpReduceSum(dnorm_mean) / C; dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C * rstd_bt - dnorm_mean * mean_bt * rstd_bt; @@ -354,27 +429,42 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // ---------------------------------------------------------------------------- // kernel launchers -void layernorm_forward(floatX* out, floatX* mean, floatX* rstd, +// similar to `fused_residual_forward5` +void layernorm_forward(floatX* out, float* mean, float* rstd, floatX* inp, const floatX* weight, const floatX* bias, int B, int T, int C, cudaStream_t stream) { NVTX_RANGE_FN(); - const int block_size = 512; + const int block_size = 256; + int block_y = block_size / WARP_SIZE; const int N = B * T; - const int grid_size = CEIL_DIV(N * WARP_SIZE, block_size); - layernorm_forward_kernel3<<>>(out, mean, rstd, inp, weight, bias, N, C); + const int grid_size = CEIL_DIV(N, block_y); + size_t smem = (2 + block_y) * C * sizeof(floatX); + + // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute + // this may fail, in which case we fall back to the smem free implementation. + cudaCheck(cudaGetLastError()); + auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cudaGetLastError(); + if (status == cudaSuccess) { + layernorm_forward_kernel6<<>>(out, mean, rstd, inp, weight, bias, N, C); + } else { + // fall back to the version without shared memory + const int grid_size_fb = CEIL_DIV(N * WARP_SIZE, block_size); + layernorm_forward_kernel3<<>>(out, mean, rstd, inp, weight, bias, N, C); + } cudaCheck(cudaGetLastError()); } void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream) { NVTX_RANGE_FN(); const int block_size = 256; - assert(N % block_size == 0); + assert(N % (block_size * x128::size) == 0); const int grid_size = CEIL_DIV(N, block_size * x128::size); residual_forward_kernel<<>>(out, inp1, inp2); cudaCheck(cudaGetLastError()); } -void fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd, +void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, float* rstd, const floatX* inp1, const floatX* inp2, const floatX* weight, const floatX* bias, int N, int C, cudaStream_t stream) { @@ -400,7 +490,7 @@ void fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, flo } void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, + const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, int B, int T, int C, cudaStream_t stream) { NVTX_RANGE_FN(); const int block_size = 512; diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index d217da5d9..91fe9d5cd 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -195,11 +195,11 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation // and write results directly to the output. if(grid_size_y == 1) { - matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, std::bool_constant{}); + matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, False); cudaCheck(cudaGetLastError()); } else { // kernel 9 overwrites temp buffer, so no need to memset - matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, std::bool_constant{}); + matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, True); cudaCheck(cudaGetLastError()); reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y); cudaCheck(cudaGetLastError()); diff --git a/llmc/mfu.h b/llmc/mfu.h index ec26537f1..7753305dc 100644 --- a/llmc/mfu.h +++ b/llmc/mfu.h @@ -73,6 +73,7 @@ static GPUEntry gpu_db[] = { {"NVIDIA GeForce RTX 4070", &ADA, 184, 2475}, {"NVIDIA GeForce RTX 4060 Ti", &ADA, 136, 2535}, {"NVIDIA GeForce RTX 4060", &ADA, 96, 2460}, + {"NVIDIA H100 PCIe", &HOPPER, 456, 1620}, {"NVIDIA H100 80GB HBM3", &HOPPER, 528, 1830}, // HBM3 = SXM5 }; diff --git a/llmc/outlier_detector.h b/llmc/outlier_detector.h new file mode 100644 index 000000000..fb4ded23e --- /dev/null +++ b/llmc/outlier_detector.h @@ -0,0 +1,70 @@ +/* +Simple OutlierDetector that we can use to monitor the loss and grad norm +Internally, it keeps track of a window of measurements and each time we +add a measurement, it returns the z-score of the new value with respect to +the window of measurements. This can be used to detect outliers in the data. + +We use double so that the detector doesn't drift too much, because we +update the mean and variance with += on each step for efficiency. We could +reconsider this choice in the future, as the compute cost here is minimal. +*/ + +#include +#include + +// use compile-time constant for window size to avoid dynamic memory allocations +#define OUTLIER_DETECTOR_WINDOW_SIZE 128 + +typedef struct { + double buffer[OUTLIER_DETECTOR_WINDOW_SIZE]; + int count; + int index; + double sum; + double sum_sq; +} OutlierDetector; + +void init_detector(OutlierDetector *detector) { + for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE; i++) { + detector->buffer[i] = 0.0; + } + detector->count = 0; + detector->index = 0; + detector->sum = 0.0; + detector->sum_sq = 0.0; +} + +double update_detector(OutlierDetector *detector, double new_value) { + + if (detector->count < OUTLIER_DETECTOR_WINDOW_SIZE) { + // here we are still building up a window of observations + detector->buffer[detector->count] = new_value; + detector->sum += new_value; + detector->sum_sq += new_value * new_value; + detector->count++; + return nan(""); // not enough data yet + + } else { + // we've filled the window, so now we can start detecting outliers + + // pop the oldest value from the window + double old_value = detector->buffer[detector->index]; + detector->sum -= old_value; + detector->sum_sq -= old_value * old_value; + // push the new value into the window + detector->buffer[detector->index] = new_value; + detector->sum += new_value; + detector->sum_sq += new_value * new_value; + // move the index to the next position + detector->index = (detector->index + 1) % OUTLIER_DETECTOR_WINDOW_SIZE; + // calculate the z-score of the new value + double mean = detector->sum / OUTLIER_DETECTOR_WINDOW_SIZE; + double variance = (detector->sum_sq / OUTLIER_DETECTOR_WINDOW_SIZE) - (mean * mean); + double std_dev = sqrt(variance); + if (std_dev == 0.0) { + return 0.0; + } + double z = (new_value - mean) / std_dev; + + return z; + } +} diff --git a/llmc/rand.h b/llmc/rand.h index 60ed393bf..b66aa04b7 100644 --- a/llmc/rand.h +++ b/llmc/rand.h @@ -165,13 +165,13 @@ void uniform_(float* data, unsigned int numel, float from, float to, mt19937_sta // Box-Muller transform: maps uniform random numbers to Gaussian distributed numbers // https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform -void normal_fill_16(float* data, float mean, float std, mt19937_state* state) { - #define EPSILONE 1e-12 +void normal_fill_16(float* data, float mean, float std) { + #define EPSILONE 1e-12f for (unsigned int t = 0; t < 8; t++) { float u1 = 1 - data[t]; float u2 = data[t + 8]; float radius = sqrtf(-2 * logf(u1 + EPSILONE)); - float theta = 2.0 * M_PI * u2; + float theta = (float) (2.0 * M_PI * u2); data[t] = (radius * cosf(theta) * std + mean); data[t + 8] = (radius * sinf(theta) * std + mean); } @@ -182,7 +182,7 @@ void normal_fill(float* data, unsigned int numel, float mean, float std, mt19937 data[t] = randfloat32(state); } for (unsigned int i = 0; i < numel - 15; i += 16) { - normal_fill_16(data + i, mean, std, state); + normal_fill_16(data + i, mean, std); } if (numel % 16 != 0) { // recompute the last 16 values @@ -190,12 +190,12 @@ void normal_fill(float* data, unsigned int numel, float mean, float std, mt19937 for (unsigned int i = 0; i < 16; i++) { data[i] = randfloat32(state); } - normal_fill_16(data, mean, std, state); + normal_fill_16(data, mean, std); } } void normal_(float* data, unsigned int numel, float mean, float std, mt19937_state* state) { - #define EPSILONE 1e-12 + #define EPSILONE 1e-12f if (numel >= 16) { normal_fill(data, numel, mean, std, state); } @@ -209,10 +209,10 @@ void normal_(float* data, unsigned int numel, float mean, float std, mt19937_sta continue; } // for numel < 16 we draw a double (float64) - float u1 = randfloat64(state); - float u2 = randfloat64(state); + float u1 = (float) randfloat64(state); + float u2 = (float) randfloat64(state); float radius = sqrtf(-2 * logf(1 - u2 + EPSILONE)); - float theta = 2.0 * M_PI * u1; + float theta = (float) (2.0 * M_PI * u1); next_double_normal_sample = radius * sinf(theta); has_next_double_normal_sample = 1; data[t] = (radius * cosf(theta) * std + mean); @@ -220,4 +220,21 @@ void normal_(float* data, unsigned int numel, float mean, float std, mt19937_sta } } +void init_identity_permutation(int *data, int numel) { + for (int i = 0; i < numel; i++) { + data[i] = i; + } +} + +void random_permutation(int* data, int numel, mt19937_state* state) { + for (int i = numel - 1; i > 0; i--) { + // pick an index j in [0, i] with equal probability + int j = randint32(state) % (i + 1); + // swap i <-> j + int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } +} + #endif \ No newline at end of file diff --git a/llmc/schedulers.h b/llmc/schedulers.h new file mode 100644 index 000000000..9ddc570d1 --- /dev/null +++ b/llmc/schedulers.h @@ -0,0 +1,100 @@ +/* +Implements various learning rate schedulers. +*/ +#ifndef SCHEDULERS_H +#define SCHEDULERS_H + +#include +#include +#include + +typedef struct { + const char* type; + float learning_rate; + int warmup_iterations; + int train_num_batches; + float final_learning_rate_frac; +} LearningRateScheduler; + +void lr_scheduler_init(LearningRateScheduler *scheduler, const char* scheduler_type, float learning_rate, int warmup_iterations, int train_num_batches, float final_learning_rate_frac) { + scheduler->type = scheduler_type; + scheduler->learning_rate = learning_rate; + scheduler->warmup_iterations = warmup_iterations; + scheduler->train_num_batches = train_num_batches; + scheduler->final_learning_rate_frac = final_learning_rate_frac; +} + +// cosine: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac +float get_learning_rate_cosine(LearningRateScheduler *scheduler, int step) { + float lr = scheduler->learning_rate; + if (step < scheduler->warmup_iterations) { + lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations; + } else { + float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations); + assert(0.0f <= decay_ratio && decay_ratio <= 1.0f); + float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0 + assert(0.0f <= coeff && coeff <= 1.0f); + float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac; + lr = min_lr + coeff * (scheduler->learning_rate - min_lr); + } + return lr; +} + +// linear: warmup linearly to max LR, then decay linearly to LR * final_learning_rate_frac +float get_learning_rate_linear(LearningRateScheduler *scheduler, int step) { + float lr = scheduler->learning_rate; + if (step < scheduler->warmup_iterations) { + lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations; + } else { + float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations); + assert(0.0f <= decay_ratio && decay_ratio <= 1.0f); + float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac; + lr = scheduler->learning_rate - decay_ratio * (scheduler->learning_rate - min_lr); + } + return lr; +} + +// constant +float get_learning_rate_constant(LearningRateScheduler *scheduler, int step) { + return scheduler->learning_rate; +} + +// wsd schedule: warmup linearly, keep constant, last 20% decay using 1 - sqrt decay to final_frac (should be 0.0) +// https://arxiv.org/abs/2405.18392 +float get_learning_rate_wsd(LearningRateScheduler *scheduler, int step) { + int decay_point = (int)(0.8f * scheduler->train_num_batches); + float max_lr = scheduler->learning_rate; + float lr = max_lr; + if (step < scheduler->warmup_iterations) { + float decay_ratio = ((float)(step + 1)) / scheduler->warmup_iterations; + lr = max_lr * decay_ratio; + } else if (step < decay_point) { + // noop, keep lr constant + } else { + float decay_ratio = ((float)(step - decay_point)) / (scheduler->train_num_batches - decay_point); + assert(0.0f <= decay_ratio && decay_ratio <= 1.0f); + float min_lr = max_lr * scheduler->final_learning_rate_frac; + return min_lr + (1.0f - sqrtf(decay_ratio)) * (max_lr - min_lr); + } + return lr; +} + +// return the learning rate at a given step +float get_learning_rate(LearningRateScheduler *scheduler, int step) { + float step_learning_rate; + if (strcmp(scheduler->type, "cosine") == 0) { + step_learning_rate = get_learning_rate_cosine(scheduler, step); + } else if (strcmp(scheduler->type, "linear") == 0) { + step_learning_rate = get_learning_rate_linear(scheduler, step); + } else if (strcmp(scheduler->type, "constant") == 0) { + step_learning_rate = get_learning_rate_constant(scheduler, step); + } else if (strcmp(scheduler->type, "wsd") == 0) { + step_learning_rate = get_learning_rate_wsd(scheduler, step); + } else { + fprintf(stderr, "Unknown learning rate scheduler type: %s\n", scheduler->type); + exit(EXIT_FAILURE); + } + return step_learning_rate; +} + +#endif // SCHEDULERS_H diff --git a/llmc/utils.h b/llmc/utils.h index be8acdb46..fece0a7cf 100644 --- a/llmc/utils.h +++ b/llmc/utils.h @@ -7,6 +7,7 @@ #ifndef UTILS_H #define UTILS_H +#include #include #include #include @@ -21,7 +22,7 @@ // simple replace fopen, fread, fclose, fseek // with fopenCheck, freadCheck, fcloseCheck, fseekCheck -FILE *fopen_check(const char *path, const char *mode, const char *file, int line) { +extern inline FILE *fopen_check(const char *path, const char *mode, const char *file, int line) { FILE *fp = fopen(path, mode); if (fp == NULL) { fprintf(stderr, "Error: Failed to open file '%s' at %s:%d\n", path, file, line); @@ -39,7 +40,7 @@ FILE *fopen_check(const char *path, const char *mode, const char *file, int line #define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__) -void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { +extern inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { size_t result = fread(ptr, size, nmemb, stream); if (result != nmemb) { if (feof(stream)) { @@ -61,7 +62,7 @@ void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char #define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__) -void fclose_check(FILE *fp, const char *file, int line) { +extern inline void fclose_check(FILE *fp, const char *file, int line) { if (fclose(fp) != 0) { fprintf(stderr, "Error: Failed to close file at %s:%d\n", file, line); fprintf(stderr, "Error details:\n"); @@ -73,7 +74,7 @@ void fclose_check(FILE *fp, const char *file, int line) { #define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__) -void fseek_check(FILE *fp, long off, int whence, const char *file, int line) { +extern inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) { if (fseek(fp, off, whence) != 0) { fprintf(stderr, "Error: Failed to seek in file at %s:%d\n", file, line); fprintf(stderr, "Error details:\n"); @@ -87,10 +88,32 @@ void fseek_check(FILE *fp, long off, int whence, const char *file, int line) { #define fseekCheck(fp, off, whence) fseek_check(fp, off, whence, __FILE__, __LINE__) +extern inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { + size_t result = fwrite(ptr, size, nmemb, stream); + if (result != nmemb) { + if (feof(stream)) { + fprintf(stderr, "Error: Unexpected end of file at %s:%d\n", file, line); + } else if (ferror(stream)) { + fprintf(stderr, "Error: File write error at %s:%d\n", file, line); + } else { + fprintf(stderr, "Error: Partial write at %s:%d. Expected %zu elements, wrote %zu\n", + file, line, nmemb, result); + } + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Expected elements: %zu\n", nmemb); + fprintf(stderr, " Written elements: %zu\n", result); + exit(EXIT_FAILURE); + } +} + +#define fwriteCheck(ptr, size, nmemb, stream) fwrite_check(ptr, size, nmemb, stream, __FILE__, __LINE__) + // ---------------------------------------------------------------------------- // malloc error-handling wrapper util -void *malloc_check(size_t size, const char *file, int line) { +extern inline void *malloc_check(size_t size, const char *file, int line) { void *ptr = malloc(size); if (ptr == NULL) { fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line); @@ -105,10 +128,29 @@ void *malloc_check(size_t size, const char *file, int line) { #define mallocCheck(size) malloc_check(size, __FILE__, __LINE__) + +// ---------------------------------------------------------------------------- +// check that all tokens are within range +extern inline void token_check(const int* tokens, int token_count, int vocab_size, const char *file, int line) { + for(int i = 0; i < token_count; i++) { + if(!(0 <= tokens[i] && tokens[i] < vocab_size)) { + fprintf(stderr, "Error: Token out of vocabulary at %s:%d\n", file, line); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Token: %d\n", tokens[i]); + fprintf(stderr, " Position: %d\n", i); + fprintf(stderr, " Vocab: %d\n", vocab_size); + exit(EXIT_FAILURE); + } + } +} +#define tokenCheck(tokens, count, vocab) token_check(tokens, count, vocab, __FILE__, __LINE__) + // ---------------------------------------------------------------------------- // I/O ops -void create_dir_if_not_exists(const char *dir) { +extern inline void create_dir_if_not_exists(const char *dir) { if (dir == NULL) { return; } struct stat st = {0}; if (stat(dir, &st) == -1) { @@ -120,7 +162,7 @@ void create_dir_if_not_exists(const char *dir) { } } -int find_max_step(const char* output_log_dir) { +extern inline int find_max_step(const char* output_log_dir) { // find the DONE file in the log dir with highest step count if (output_log_dir == NULL) { return -1; } DIR* dir; diff --git a/llmc/zero.cuh b/llmc/zero.cuh index 160dae7ac..cb0faf8ec 100644 --- a/llmc/zero.cuh +++ b/llmc/zero.cuh @@ -5,6 +5,12 @@ Utilities for ZeRO sharding #ifndef LLMC_ZERO_CUH #define LLMC_ZERO_CUH +#ifdef _WIN32 +#include +#else +#include +#endif + #include #include #include @@ -12,8 +18,10 @@ Utilities for ZeRO sharding #include #ifdef MULTI_GPU -#include #include +#ifdef USE_MPI +#include +#endif #endif // ---------------------------------------------------------------------------- @@ -36,6 +44,7 @@ void nccl_check(ncclResult_t status, const char *file, int line) { } #define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__)) +#ifdef USE_MPI void mpi_check(int status, const char *file, int line) { if (status != MPI_SUCCESS) { char mpi_error[4096]; @@ -46,15 +55,14 @@ void mpi_check(int status, const char *file, int line) { } } #define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__)) +#endif #endif // MULTI_GPU // ---------------------------------------------------------------------------- -// MPI / multi-processing setup - // Parameters specific to training on multiple GPUs. typedef struct { - int process_rank; // Rank of this process among all MPI processes. 0 if no multi-GPU. + int process_rank; // Rank of this process among all processes. 0 if no multi-GPU. int num_processes; // Total number of processes. 1 if no multi-GPU. int local_device_idx; // This process GPU index on current machine. 0 if no multi-GPU. @@ -69,70 +77,381 @@ typedef struct { ncclComm_t nccl_comm; // NCCL communication primitive, used for collective multi-GPU work. cudaStream_t nccl_stream; // CUDA Stream to perform NCCL operations. cudaEvent_t compute_nccl_sync; // Event used to synchronize NCCL with the compute + float* unified_buffer; #endif } MultiGpuConfig; #ifdef MULTI_GPU + +#ifdef _WIN32 +void send_nccl_id_to_clients_windows(ncclUniqueId *nccl_id, SOCKET client_sockets[], int num_clients) { + for (int i = 0; i < num_clients; ++i) { + if (send(client_sockets[i], (const char *)nccl_id, sizeof(*nccl_id), 0) == SOCKET_ERROR) { + printf("Failed to send nccl_id"); + WSACleanup(); + exit(EXIT_FAILURE); + } + closesocket(client_sockets[i]); + } +} +#else +void send_nccl_id_to_clients(ncclUniqueId *nccl_id, int client_sockets[], int num_clients) { + for (int i = 0; i < num_clients; ++i) { + if (send(client_sockets[i], nccl_id, sizeof(*nccl_id), 0) == -1) { + printf("Failed to send nccl_id"); + exit(EXIT_FAILURE); + } + close(client_sockets[i]); + } +} +#endif + +#ifdef _WIN32 +// Same as get_nccl_id_via_tcp but for Windows +ncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* server_ip) { + ncclUniqueId nccl_id; + + int SERVER_PORT = 12345; // hardcoded an arbitrary port number between 1024 and 49151 (registered ports) + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + printf("WSAStartup failed"); + exit(EXIT_FAILURE); + } + + if (result->process_rank == 0) { + ncclCheck(ncclGetUniqueId(&nccl_id)); + + int MAX_CLIENTS = result->num_processes - 1; + SOCKET client_sockets[MAX_CLIENTS]; + int num_clients = 0; + SOCKET server_socket, new_socket; + struct sockaddr_in address; + int addrlen = sizeof(address); + + // Step 1) create a server TCP socket + if ((server_socket = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) { + printf("Socket failed"); + WSACleanup(); + exit(EXIT_FAILURE); + } + + // Step 2) set the server address and port + address.sin_family = AF_INET; // IPv4 + address.sin_addr.s_addr = inet_addr(server_ip); + address.sin_port = htons(SERVER_PORT); + + // Step 3) bind the socket to the address and port + if (bind(server_socket, (struct sockaddr *)&address, sizeof(address)) == SOCKET_ERROR) { + printf("Bind failed"); + closesocket(server_socket); + WSACleanup(); + exit(EXIT_FAILURE); + } + + // Step 4) MAX_CLIENTS specifies the maximum number of clients that can be queued for this server + if (listen(server_socket, MAX_CLIENTS) == SOCKET_ERROR) { + printf("Listen failed"); + closesocket(server_socket); + WSACleanup(); + exit(EXIT_FAILURE); + } + + // Step 5) accept connections from clients + printf("Waiting for clients to connect...\n"); + while (num_clients < MAX_CLIENTS) { + if ((new_socket = accept(server_socket, (struct sockaddr *)&address, &addrlen)) == INVALID_SOCKET) { + printf("Accept failed"); + closesocket(server_socket); + WSACleanup(); + exit(EXIT_FAILURE); + } + client_sockets[num_clients++] = new_socket; + printf("Client %d connected\n", num_clients); + } + + // Step 6) send the NCCL ID to all clients + send_nccl_id_to_clients_windows(&nccl_id, client_sockets, num_clients); + printf("NCCL ID sent to all clients\n"); + + closesocket(server_socket); + } else { + int num_connection_attempts = 5; + int time_to_sleep = 2; + SOCKET client_socket; + struct sockaddr_in serv_addr; + + // Step 1) create a client TCP socket + if ((client_socket = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) { + printf("Socket creation error"); + WSACleanup(); + exit(EXIT_FAILURE); + } + + // Step 2) set the server address and port + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(SERVER_PORT); + if (inet_pton(AF_INET, server_ip, &serv_addr.sin_addr) <= 0) { + printf("Invalid address or address not supported"); + closesocket(client_socket); + WSACleanup(); + exit(EXIT_FAILURE); + } + + // Step 3) Try to connect to the server - retry up to `num_connection_attempts` times if the connection fails + while (connect(client_socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) == SOCKET_ERROR) { + printf("%d Connection failed, retrying in %d seconds\n", result->process_rank, time_to_sleep); + if (--num_connection_attempts == 0) { + printf("Failed to connect to the server\n"); + closesocket(client_socket); + WSACleanup(); + exit(EXIT_FAILURE); + } + Sleep(time_to_sleep * 1000); + } + + // Step 4) receive the NCCL ID from the server + if (recv(client_socket, (char *)&nccl_id, sizeof(nccl_id), 0) <= 0) { + printf("Failed to receive nccl_id"); + closesocket(client_socket); + WSACleanup(); + exit(EXIT_FAILURE); + } + + printf("Received NCCL ID\n"); + closesocket(client_socket); + } + + WSACleanup(); + return nccl_id; +} +#else +ncclUniqueId get_nccl_id_via_tcp(MultiGpuConfig* result, const char* server_ip) { + ncclUniqueId nccl_id; + + int SERVER_PORT = 12345; // hardcoded an arbitrary port number between 1024 and 49151 (registered ports) + if (result->process_rank == 0) { + ncclCheck(ncclGetUniqueId(&nccl_id)); + + int MAX_CLIENTS = result->num_processes - 1; + int client_sockets[MAX_CLIENTS]; + int num_clients = 0; + int server_socket, new_socket; + struct sockaddr_in address; + int addrlen = sizeof(address); + int opt = 1; + + // Step 1) create a server TCP socket + if ((server_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + printf("Socket failed"); + exit(EXIT_FAILURE); + } + + // Step 2) set socket options + // SOL_SOCKET - means that option is configured at socket level + // SO_REUSEADDR - allows to bind to an address which is in a TIME_WAIT state (already used by another socket) - useful when restarting the server + // SO_REUSEPORT - allows to bind to the same port multiple times + if (setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)) < 0) { + printf("Setsockopt failed"); + exit(EXIT_FAILURE); + } + + // Step 3) set the server address and port + address.sin_family = AF_INET; // IPv4 + address.sin_addr.s_addr = inet_addr(server_ip); // alternatively use INADDR_ANY to bind to all interfaces, currently we only allow ethernet + address.sin_port = htons(SERVER_PORT); + + // Step 4) bind the socket to the address and port + if (bind(server_socket, (struct sockaddr *)&address, sizeof(address)) < 0) { + printf("Bind failed"); + exit(EXIT_FAILURE); + } + + // Step 5) MAX_CLIENTS specifies the maximum number of clients that can be queued for this server + if (listen(server_socket, MAX_CLIENTS) < 0) { + printf("Listen failed"); + exit(EXIT_FAILURE); + } + + // Step 6) accept connections from clients + printf("Waiting for clients to connect...\n"); + while (num_clients < MAX_CLIENTS) { + if ((new_socket = accept(server_socket, (struct sockaddr *)&address, (socklen_t*)&addrlen)) < 0) { + printf("Accept failed"); + exit(EXIT_FAILURE); + } + client_sockets[num_clients++] = new_socket; + printf("Client %d connected\n", num_clients); + } + + // Step 7) send the NCCL ID to all clients + send_nccl_id_to_clients(&nccl_id, client_sockets, num_clients); + printf("NCCL ID sent to all clients\n"); + + close(server_socket); + } else { + int num_connection_attempts = 5; + int time_to_sleep = 2; + int client_socket; + struct sockaddr_in serv_addr; + + // Step 1) create a client TCP socket + if ((client_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + printf("Socket creation error"); + exit(EXIT_FAILURE); + } + + // Step 2) set the server address and port + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(SERVER_PORT); + if (inet_pton(AF_INET, server_ip, &serv_addr.sin_addr) <= 0) { + printf("Invalid address or address not supported"); + exit(EXIT_FAILURE); + } + + // Step 3) Try to connect to the server - retry up to `num_connection_attempts` times if the connection fails + while (connect(client_socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) { + printf("%d Connection failed, retrying in %d seconds\n", result->process_rank, time_to_sleep); + if (--num_connection_attempts == 0) { + printf("Failed to connect to the server\n"); + exit(EXIT_FAILURE); + } + sleep(time_to_sleep); + } + + // Step 4) receive the NCCL ID from the server + if (recv(client_socket, &nccl_id, sizeof(nccl_id), 0) <= 0) { + printf("Failed to receive nccl_id"); + exit(EXIT_FAILURE); + } + + printf("Received NCCL ID\n"); + close(client_socket); + } + + return nccl_id; +} +#endif + +ncclUniqueId get_nccl_id_via_fs(MultiGpuConfig* result, char* fs_path) { + // Works assuming that the filesystem is shared among all processes + ncclUniqueId nccl_id; + FILE* idFile; + static char filename[1024]; + snprintf(filename, sizeof(filename), "%s/ncclUniqueId.sync", fs_path); + + if (result->process_rank != 0) { // client processse should wait for the server to write to the file + // This is a naive and not 100% robust way to synchronize the processes but it should work almost always + sleep(2); + } + + if (result->process_rank == 0) { + ncclCheck(ncclGetUniqueId(&nccl_id)); + idFile = fopen(filename, "wb"); + assert(idFile != NULL); + fwrite(&nccl_id, sizeof(nccl_id), 1, idFile); + fclose(idFile); + } else { + // Other ranks wait until the file is available and read the unique ID + do { + sleep(1); // 1 second + idFile = fopen(filename, "rb"); + if (idFile != NULL) break; + } while (idFile == NULL); + freadCheck(&nccl_id, sizeof(nccl_id), 1, idFile); + fclose(idFile); + } + + return nccl_id; +} + +#ifdef USE_MPI // Determine which GPU this process should use. // Processes on the same machines use different GPU indicies. Processes on other machines don't. // Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread int multi_gpu_get_local_device_idx(int process_rank, int num_processes) { - char hostname[1024]; - hostname[1023] = '\0'; - // All processes on the same machine will share the same hostname. - gethostname(hostname, 1023); - for (int i=0; i < 1024; i++) { - if (hostname[i] == '.') { - hostname[i] = '\0'; + char hostname[1024]; + hostname[1023] = '\0'; + // All processes on the same machine will share the same hostname. + gethostname(hostname, 1023); + for (int i=0; i < 1024; i++) { + if (hostname[i] == '.') { + hostname[i] = '\0'; + break; + } + } + uint64_t hostname_hash = 5381u; + for (int c = 0; hostname[c] != '\0'; c++){ hostname_hash = ((hostname_hash << 5u) + hostname_hash) ^ hostname[c]; } + + // Distribute all hostname hashes to all processes. + uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t)); + all_hostsname_hashes[process_rank] = hostname_hash; + mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD)); + + // Identify which GPU we need to use. + int local_device_idx = 0; + for (int current_process = 0; current_process < num_processes; ++current_process) { + if (current_process == process_rank) { + // Found my gpu, local_device_idx now has my target GPU index. break; + } + if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) { + // This process ID runs on the same machine, but it's not me, skip this GPU + local_device_idx++; + } } - } - uint64_t hostname_hash = 5381u; - for (int c = 0; hostname[c] != '\0'; c++){ hostname_hash = ((hostname_hash << 5u) + hostname_hash) ^ hostname[c]; } - - // Distribute all hostname hashes to all processes. - uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t)); - all_hostsname_hashes[process_rank] = hostname_hash; - mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD)); - - // Identify which GPU we need to use. - int local_device_idx = 0; - for (int current_process = 0; current_process < num_processes; ++current_process) { - if (current_process == process_rank) { - // Found my gpu, local_device_idx now has my target GPU index. - break; - } - if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) { - // This process ID runs on the same machine, but it's not me, skip this GPU - local_device_idx++; - } - } - - free(all_hostsname_hashes); - return local_device_idx; + + free(all_hostsname_hashes); + return local_device_idx; } #endif -MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) { +#endif + +MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gpus_per_node, char* server_ip, char* fs_path, char* init_method) { #ifdef MULTI_GPU - // Initialize MPI. MultiGpuConfig result; - mpiCheck(MPI_Init(argc, argv)); - mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank)); - mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes)); - result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes); - cudaCheck(cudaSetDevice(result.local_device_idx)); ncclUniqueId nccl_id; - if (result.process_rank == 0) { - ncclCheck(ncclGetUniqueId(&nccl_id)); + // Get nccl_id using MPI, TCP, or FS (file system synchronization) methods + // On newer slurm versions (slurm-wlm package) PMIx is disabled so we can not use MPI for NCCL init in multi node setup + if (strcmp(init_method, "mpi") == 0) { + #ifdef USE_MPI + mpiCheck(MPI_Init(NULL, NULL)); + mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank)); + mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes)); + result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes); + if (result.process_rank == 0) { + ncclCheck(ncclGetUniqueId(&nccl_id)); + } + mpiCheck(MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); + #else + printf("MPI support is disabled. Please enable MPI support to use MPI-based NCCL-init method.\n"); + exit(EXIT_FAILURE); + #endif + } else { + result.process_rank = process_rank; + result.num_processes = num_processes; + result.local_device_idx = process_rank % gpus_per_node; + if (strcmp(init_method, "tcp") == 0) { + #ifdef _WIN32 + nccl_id = get_nccl_id_via_tcp_windows(&result, server_ip); + #else + nccl_id = get_nccl_id_via_tcp(&result, server_ip); + #endif + } else if (strcmp(init_method, "fs") == 0) { + nccl_id = get_nccl_id_via_fs(&result, fs_path); + } else { + printf("Invalid NCCL-init method\n"); + exit(EXIT_FAILURE); + } } - mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD)); + cudaCheck(cudaSetDevice(result.local_device_idx)); ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank)); cudaCheck(cudaStreamCreate(&result.nccl_stream)); // event without timing for maximum performance cudaCheck(cudaEventCreate(&result.compute_nccl_sync, cudaEventDisableTiming)); nvtxNameCudaStreamA(result.nccl_stream, "nccl stream"); nvtxNameCudaEventA(result.compute_nccl_sync, "nccl compute sync"); + cudaCheck(cudaMallocManaged(&result.unified_buffer, sizeof(float))); return result; #else printf("Multi-GPU support is disabled. Using a single GPU.\n"); @@ -150,15 +469,19 @@ void multi_gpu_config_free(MultiGpuConfig* multi_gpu_config) { ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm)); cudaCheck(cudaStreamDestroy(multi_gpu_config->nccl_stream)); cudaCheck(cudaEventDestroy(multi_gpu_config->compute_nccl_sync)); + cudaCheck(cudaFree(multi_gpu_config->unified_buffer)); + #ifdef USE_MPI mpiCheck(MPI_Finalize()); + #endif #endif } void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) { #ifdef MULTI_GPU if (multi_gpu_config->num_processes > 1) { - mpiCheck(MPI_Barrier(MPI_COMM_WORLD)); + ncclCheck(ncclAllReduce(multi_gpu_config->unified_buffer, multi_gpu_config->unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); } + cudaCheck(cudaDeviceSynchronize()); #endif } diff --git a/profile_gpt2.cu b/profile_gpt2.cu index 669e5efe1..29fbdb971 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -28,11 +28,18 @@ the profile.ncu-rep from a cloud box to local to pretty view. #include "train_gpt2.cu" int main(int argc, char *argv[]) { - multi_gpu_config = multi_gpu_config_init(&argc, &argv); + char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi" + int num_processes = -1; // doesn't matter when using MPI + int process_rank = -1; // doesn't matter when using MPI + int gpus_per_node = -1; // doesn't matter when using MPI + char server_ip[256] = ""; // doesn't matter when using MPI + char fs_path[256] = ""; // doesn't matter when using MPI + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); common_start(true, true); // build the GPT-2 model from a checkpoint GPT2 model; + gpt2_init_common(&model); gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin"); int B = 24; // if program OOMs decrease this number, e.g. all the way down to 4 or etc @@ -52,10 +59,12 @@ int main(int argc, char *argv[]) { set_zero_configs(&multi_gpu_config, 0, model.num_parameters); // do a training step - gpt2_forward(&model, x, y, B, T); + gpt2_forward(&model, x, B, T); gpt2_zero_grad(&model); - gpt2_backward(&model, x, true); - gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config); + gpt2_backward_and_reduce(&model, x, y, 1, true); + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; + gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, grad_scale, 1, &multi_gpu_config); cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings // free diff --git a/requirements.txt b/requirements.txt index 80471a8be..ea4bc768d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ tqdm -numpy +numpy<2 torch tiktoken transformers diff --git a/scripts/multi_node/run_gpt2_124M_fs.sbatch b/scripts/multi_node/run_gpt2_124M_fs.sbatch new file mode 100755 index 000000000..9bef9aaca --- /dev/null +++ b/scripts/multi_node/run_gpt2_124M_fs.sbatch @@ -0,0 +1,85 @@ +#!/bin/bash +#SBATCH --job-name=llmc-multinode # job name +#SBATCH --output=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.log # output file +#SBATCH --error=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.err # error file +#SBATCH --partition=llmc # Specify the GPU partition +#SBATCH --ntasks=16 # total number of processes to launch on all nodes +#SBATCH --nodes=2 # total number of nodes +#SBATCH --ntasks-per-node=8 # assuming each node has 8 gpus +#SBATCH --gres=gpu:8 # request 8 gpus from each node + +# NOTE: change the above slurm arguments to match your system! +# Run with `sbatch ` + +make train_gpt2cu USE_CUDNN=1 NO_USE_MPI=1 + +# NOTE: change the following to match your system +binary_path="/home/ubuntu/llm.c/train_gpt2cu" +out_dir="/ephemeral/data/fineweb/log_gpt2_124M_multi" +train_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_train_*.bin' +val_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_val_*.bin' +sync_fs_path=$out_dir # needs to be a shared filesystem path that all nodes can access + +# In case the file system is shared this is a no-op. +# Otherwise, we need to copy the binary to all nodes. +current_user=$USER +hosts=$(scontrol show hostnames $SLURM_JOB_NODELIST) # get the hostnames of the allocated nodes +current_host=$(hostname) +for host in $hosts; do + if [ $host == $current_host ]; then + continue + fi + echo "copying $binary_path to $current_user@$host" + scp -r $binary_path $current_user@$host:$binary_path +done + +# Use this for NCCL debugging if you run into issues +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# Optimization flags +export NCCL_NET_GDR_LEVEL=2 # use GPUDirect RDMA - allows for direct memory access between GPUs across different nodes by bypassing the CPU +export NCCL_IB_DISABLE=0 # use InfiniBand if available + +# NOTE: change the following environment variables to match your system - or comment them out if you don't need them +export NCCL_SOCKET_IFNAME=ens17 +export OMPI_MCA_btl_tcp_if_include=ens17 +export NCCL_P2P_LEVEL=PXB + +if [ -z "$SLURM_JOB_ID" ]; then + echo "Make sure you're running in a SLURM environment. Did you forget to run with sbatch? Aborting." + exit 1 +else + DATESTRING=`date "+%Y-%m-%dT%H:%M:%S"` + echo "Running in a SLURM environment (job ID: $SLURM_JOB_ID, user: $current_user)" + echo "Running on hosts: $(echo $(scontrol show hostname))" + echo "$DATESTRING" +fi + +srun -l -u bash -c " + $binary_path \ + -i '$train_data_path' \ + -j '$val_data_path' \ + -o $out_dir \ + -v 250 -s 20000 -g 144 \ + -h 1 \ + -b 64 -t 1024 \ + -d 2097152 \ + -r 0 \ + -z 1 \ + -c 0.1 \ + -l 0.0006 \ + -q 0.0 \ + -u 700 \ + -n 5000 \ + -y 1 \ + -e d12 \ + -pn \$SLURM_NTASKS \ + -pr \$SLURM_PROCID \ + -pg \$SLURM_NTASKS_PER_NODE \ + -pf $sync_fs_path \ + -pi "fs" \ +" + +echo "$DATESTRING" \ No newline at end of file diff --git a/scripts/multi_node/run_gpt2_124M_mpi.sh b/scripts/multi_node/run_gpt2_124M_mpi.sh new file mode 100755 index 000000000..e09b027ce --- /dev/null +++ b/scripts/multi_node/run_gpt2_124M_mpi.sh @@ -0,0 +1,49 @@ + +make train_gpt2cu USE_CUDNN=1 + +# NOTE: change the following to match your system +binary_path="/home/ubuntu/llm.c/train_gpt2cu" +out_dir="/ephemeral/data/fineweb/log_gpt2_124M_multi" +train_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_train_*.bin' +val_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_val_*.bin' +# You can find these names either in `/etc/hosts`` file or in the terminal (user@host:~$). +host1="h100-node-1-0" # master and worker node +host2="h100-node-1-1" # worker node + +# In case the file system is shared this is a no-op. +# Otherwise, we need to copy the binary to all nodes. +scp -r $binary_path $USER@$host2:$binary_path + +# Use this for NCCL debugging if you run into issues +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# Optimization flags +export NCCL_NET_GDR_LEVEL=2 # use GPUDirect RDMA - allows for direct memory access between GPUs across different nodes by bypassing the CPU +export NCCL_IB_DISABLE=0 # use InfiniBand if available + +# NOTE: change the following environment variables to match your system - or comment them out if you don't need them +export NCCL_SOCKET_IFNAME=ens17 +export OMPI_MCA_btl_tcp_if_include=ens17 +export NCCL_P2P_LEVEL=PXB + +mpirun -np 16 --host $host1:8,$host2:8 \ + $binary_path \ + -i "$train_data_path" \ + -j "$val_data_path" \ + -o $out_dir \ + -v 250 -s 20000 -g 144 \ + -h 1 \ + -b 64 -t 1024 \ + -d 2097152 \ + -r 0 \ + -z 1 \ + -c 0.1 \ + -l 0.0006 \ + -q 0.1 \ + -u 700 \ + -n 1000 \ + -y 0 \ + -e d12 \ + -pi "mpi" \ diff --git a/scripts/multi_node/run_gpt2_124M_tcp.sbatch b/scripts/multi_node/run_gpt2_124M_tcp.sbatch new file mode 100755 index 000000000..f6cd3a7fa --- /dev/null +++ b/scripts/multi_node/run_gpt2_124M_tcp.sbatch @@ -0,0 +1,86 @@ +#!/bin/bash +#SBATCH --job-name=llmc-multinode # job name +#SBATCH --output=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.log # output file +#SBATCH --error=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.err # error file +#SBATCH --partition=llmc # Specify the GPU partition +#SBATCH --ntasks=16 # total number of processes to launch on all nodes +#SBATCH --nodes=2 # total number of nodes +#SBATCH --ntasks-per-node=8 # assuming each node has 8 gpus +#SBATCH --gres=gpu:8 # request 8 gpus from each node + +# NOTE: change the above slurm arguments to match your system! +# Run with `sbatch ` + +make train_gpt2cu USE_CUDNN=1 NO_USE_MPI=1 + +# NOTE: change the following to match your system +binary_path="/home/ubuntu/llm.c/train_gpt2cu" +out_dir="/ephemeral/data/fineweb/log_gpt2_124M_multi" +train_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_train_*.bin' +val_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_val_*.bin' +# NOTE: change the server_ip to the IP address of the machine that is running process zero +server_ip="10.0.1.220" + +# In case the file system is shared this is a no-op. +# Otherwise, we need to copy the binary to all nodes. +current_user=$USER +hosts=$(scontrol show hostnames $SLURM_JOB_NODELIST) # get the hostnames of the allocated nodes +current_host=$(hostname) +for host in $hosts; do + if [ $host == $current_host ]; then + continue + fi + echo "copying $binary_path to $current_user@$host" + scp -r $binary_path $current_user@$host:$binary_path +done + +# Use this for NCCL debugging if you run into issues +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# Optimization flags +export NCCL_NET_GDR_LEVEL=2 # use GPUDirect RDMA - allows for direct memory access between GPUs across different nodes by bypassing the CPU +export NCCL_IB_DISABLE=0 # use InfiniBand if available + +# NOTE: change the following environment variables to match your system - or comment them out if you don't need them +export NCCL_SOCKET_IFNAME=ens17 +export OMPI_MCA_btl_tcp_if_include=ens17 +export NCCL_P2P_LEVEL=PXB + +if [ -z "$SLURM_JOB_ID" ]; then + echo "Make sure you're running in a SLURM environment. Did you forget to run with sbatch? Aborting." + exit 1 +else + DATESTRING=`date "+%Y-%m-%dT%H:%M:%S"` + echo "Running in a SLURM environment (job ID: $SLURM_JOB_ID, user: $current_user)" + echo "Running on hosts: $(echo $(scontrol show hostname))" + echo "$DATESTRING" +fi + +srun -l -u bash -c " + $binary_path \ + -i '$train_data_path' \ + -j '$val_data_path' \ + -o $out_dir \ + -v 250 -s 20000 -g 144 \ + -h 1 \ + -b 64 -t 1024 \ + -d 2097152 \ + -r 0 \ + -z 1 \ + -c 0.1 \ + -l 0.0006 \ + -q 0.0 \ + -u 700 \ + -n 5000 \ + -y 1 \ + -e d12 \ + -pn \$SLURM_NTASKS \ + -pr \$SLURM_PROCID \ + -pg \$SLURM_NTASKS_PER_NODE \ + -ps $server_ip \ + -pi "tcp" \ +" + +echo "$DATESTRING" diff --git a/scripts/run_gpt2_1558M.sh b/scripts/run_gpt2_1558M.sh new file mode 100644 index 000000000..28095a621 --- /dev/null +++ b/scripts/run_gpt2_1558M.sh @@ -0,0 +1,43 @@ +# GPT-2 (1558M) repro on FineWeb +# 1558M parameter model on 100B tokens +# => 6 * 1558Me6 * 100e9 = 6.966e20 ~= 1e21 capability model +# => 100,000 steps on 1M tokens/step (1,048,576 to be precise) +# on 8X A100 80GB SXM ($14/hr) steps in ~7s/iter +# => training time 100,000 steps * 7s = 194.4 hours ~= 8.1 days ~= $2721.6 + +make train_gpt2cu USE_CUDNN=1 +out_dir="log_gpt2_1558M" +done_file="$out_dir/DONE_00100000" + +# in case the training stalls or crashes, loop to resume (-y 1) +while true; do + + # exit condition is that optimization has finished + if [ -f "$done_file" ]; then + echo "File $done_file exists. Exiting the loop." + break + fi + + # run python dev/data/fineweb.py --version 100B to prepro data + # run python dev/data/hellaswag.py to prepro hellaswag eval + mpirun -np 8 ./train_gpt2cu \ + -i "dev/data/fineweb100B/fineweb_train_*.bin" \ + -j "dev/data/fineweb100B/fineweb_val_*.bin" \ + -o $out_dir \ + -v 250 -s 300000 -g 144 \ + -h 1 \ + -b 16 -t 1024 \ + -d 1048576 \ + -r 0 \ + -z 1 \ + -c 0.1 \ + -l 0.0002 \ + -q 0.1 \ + -u 700 \ + -n 2000 \ + -x 100000 \ + -y 1 \ + -e "d48" + + sleep 1 +done diff --git a/test_gpt2.cu b/test_gpt2.cu index c24f4c432..800df69e5 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -89,7 +89,13 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size } int main(int argc, char *argv[]) { - multi_gpu_config = multi_gpu_config_init(&argc, &argv); + char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi" + int num_processes = -1; // doesn't matter when using MPI + int process_rank = -1; // doesn't matter when using MPI + int gpus_per_node = -1; // doesn't matter when using MPI + char server_ip[256] = ""; // doesn't matter when using MPI + char fs_path[256] = ""; // doesn't matter when using MPI + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); common_start(false, true); // set the right paths @@ -101,6 +107,7 @@ int main(int argc, char *argv[]) { // build the GPT-2 model from a checkpoint GPT2 model; + gpt2_init_common(&model); gpt2_build_from_checkpoint(&model, load_filename); size_t V = model.config.vocab_size; @@ -160,21 +167,22 @@ int main(int argc, char *argv[]) { int allok = 1; // First, do target-free forward pass to validate logits - gpt2_forward(&model, x, NULL, B, T); + gpt2_forward(&model, x, B, T); // at this point, target should be equal to expected_logits, let's compare // copy logits to CPU so we can compare them floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX)); float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float)); - cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost); + cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < B * T * Vp; i++) { logits_cpu[i] = (float)logits_cpu_raw[i]; } + float logit_accuracy_threshold = 1e-3f; + float loss_diff_threshold = 1e-5f; // FP16 and lower require very high tolerances unfortunately. TODO look into more - float logit_accuracy_threshold = 1e-2f; - float loss_diff_threshold = 0.05f; #if defined(ENABLE_BF16) || defined(ENABLE_F16) logit_accuracy_threshold = 25.0f; // 15.0f was too low even without cuDNN?! :( + loss_diff_threshold = 0.05f; #endif // compare the output logits from the forward pass @@ -208,25 +216,17 @@ int main(int argc, char *argv[]) { for (int step = 0; step < 10; step++) { struct timespec start, end; clock_gettime(CLOCK_MONOTONIC, &start); - gpt2_forward(&model, x, y, B, T); + gpt2_forward(&model, x, B, T); gpt2_zero_grad(&model); - gpt2_backward(&model, x, true); + gpt2_backward_and_reduce(&model, x, y, 1, true); clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; if (step == 0) { // error checking at step 0 for reference activations - // compare the achieved loss - if (fabsf(model.mean_loss - *expected_loss) >= loss_diff_threshold) { - printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss); - allok = 0; - } else { - printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss); - } - // move the (mixed precision) grads from GPU to CPU - cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost); + cudaCheck(cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost)); // convert all gradients to float on the CPU char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char* @@ -263,43 +263,56 @@ int main(int argc, char *argv[]) { // In that case it's ok to extend the tolerance by a bit, after a manual review. // Also, different GPUs may use different matrix multiplication algorithms, so the // actual errors can be hardware specific. - allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", 6e-1f); // hmm a bit high - allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", 4e-3f); - allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 1e-1); // hmm a bit high - allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 3.5e-2f); - allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", 2e-2f); - allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", 3e-2f); - allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", 5e-2f); // hmm a bit high - allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", 5e-2f); // hmm a bit high - allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", 5e-2f); // hmm a bit high - allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", 1.5e-2f); - allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", 6e-4f); - allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", 9e-3f); - allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", 2e-3f); - allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", 2.5e-3f); - allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", 0.12f); // hmm bit higher - allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", 2e-2f); + + float grad_thresholds[NUM_PARAMETER_TENSORS] = {5e-1f, 4e-3f, 1e-1f, 3.5e-2f, 2e-2f, 3e-2f, 5e-2f, 5e-2f, 5e-2f, 1.5e-2f, 5e-4f, 8e-3f, 1.5e-3f, 2.5e-3f, 1e-1f, 2e-2f}; + #if defined(ENABLE_FP32) + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + grad_thresholds[i] = 1e-6f; // we can be much more precise in FP32 + } + #endif + + allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", grad_thresholds[0]); + allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", grad_thresholds[1]); + allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", grad_thresholds[2]); + allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", grad_thresholds[3]); + allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", grad_thresholds[4]); + allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", grad_thresholds[5]); + allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", grad_thresholds[6]); + allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", grad_thresholds[7]); + allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", grad_thresholds[8]); + allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", grad_thresholds[9]); + allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", grad_thresholds[10]); + allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", grad_thresholds[11]); + allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", grad_thresholds[12]); + allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", grad_thresholds[13]); + allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", grad_thresholds[14]); + allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]); } - gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+1, &multi_gpu_config); + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); // print the timing information at the end printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); - losses[step] = model.mean_loss; + // the expected losses from PyTorch were copied over after the print formatting rounded + // them to 6 decimal places, so we do the same here + float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000; + losses[step] = rounded_loss; } // expected losses are as follows, from Python float expected_losses[10] = { - 5.2700, - 4.0607, - 3.3202, - 2.7176, - 2.1811, - 1.6538, - 1.1680, - 0.7367, - 0.4008, - 0.1874 + 5.270009, + 4.060681, + 3.320085, + 2.717550, + 2.181066, + 1.653923, + 1.168050, + 0.736873, + 0.401021, + 0.187493 }; // compare @@ -312,10 +325,59 @@ int main(int argc, char *argv[]) { } } + // Finally, let's check determinism + gpt2_write_to_checkpoint(&model, "test_gpt2cu_model.ckpt"); + + DataLoader loader; + dataloader_init(&loader, "dev/data/tinyshakespeare/tiny_shakespeare_val.bin", B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1); + save_state("test_gpt2cu_state.ckpt", 10, &model, &loader); + int tokens[10]; + for (int step = 0; step < 10; step++) { + dataloader_next_batch(&loader); + gpt2_forward(&model, loader.inputs, B, T); + gpt2_zero_grad(&model); + gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true); + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); + losses[step] = model.mean_loss; + tokens[step] = loader.inputs[0]; + } + + // reload + gpt2_free(&model); + gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt"); + int ld_step; + load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt"); + for (int step = 0; step < 10; step++) { + dataloader_next_batch(&loader); + gpt2_forward(&model, loader.inputs, B, T); + gpt2_zero_grad(&model); + gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true); + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); + + if(loader.inputs[0] != tokens[step]) { + printf("Nondeterminism! Token mismatch at step %d: %d vs %d\n", step, tokens[step], loader.inputs[0]); + allok = false; + break; + } + + if(losses[step] != model.mean_loss) { + printf("Nondeterminism! Loss mismatch at step %d: %.15f vs %.15f\n", step, losses[step], model.mean_loss); + allok = false; + break; + } else { + printf("loss ok at step %d: %f %f\n", step, losses[step], model.mean_loss); + } + } + // final approval printf("overall okay: %d\n", allok); + // delete intermediate test files + remove("test_gpt2cu_model.ckpt"); + remove("test_gpt2cu_state.ckpt"); + // free everything + dataloader_free(&loader); gpt2_free(&model); common_free(model); free(x); diff --git a/train_gpt2.c b/train_gpt2.c index 6240f67f9..799a7a854 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -1041,14 +1041,14 @@ void gpt2_free(GPT2 *model) { // ---------------------------------------------------------------------------- // sampler -unsigned int random_u32(unsigned long long *state) { +unsigned int random_u32(uint64_t *state) { // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A *state ^= *state >> 12; *state ^= *state << 25; *state ^= *state >> 27; return (*state * 0x2545F4914F6CDD1Dull) >> 32; } -float random_f32(unsigned long long *state) { // random float32 in [0,1) +float random_f32(uint64_t *state) { // random float32 in [0,1) return (random_u32(state) >> 8) / 16777216.0f; } @@ -1083,8 +1083,8 @@ int main() { int B = 4; // batch size 4 (i.e. 4 independent token sequences will be trained on) int T = 64; // sequence length 64 (i.e. each sequence is 64 tokens long). must be <= maxT, which is 1024 for GPT-2 DataLoader train_loader, val_loader; - dataloader_init(&train_loader, train_tokens, B, T, 0, 1); - dataloader_init(&val_loader, val_tokens, B, T, 0, 1); + dataloader_init(&train_loader, train_tokens, B, T, 0, 1, 1); + dataloader_init(&val_loader, val_tokens, B, T, 0, 1, 0); printf("train dataset num_batches: %zu\n", train_loader.num_tokens / (B*T)); printf("val dataset num_batches: %zu\n", val_loader.num_tokens / (B*T)); int val_num_batches = 5; @@ -1094,7 +1094,7 @@ int main() { tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); // some memory for generating samples from the model - unsigned long long rng_state = 1337; + uint64_t rng_state = 1337; int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); const int genT = 64; // number of steps of inference we will do diff --git a/train_gpt2.cu b/train_gpt2.cu index 49e94c3f3..60e24501f 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,7 +1,6 @@ /* GPT-2 Transformer Neural Net training loop. See README.md for usage. */ - #include #include #include @@ -21,12 +20,16 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/dataloader.h" // defines: manual_seed, normal_ (same as torch.manual_seed and torch.normal) #include "llmc/rand.h" +// defines: lr_scheduler_init, get_learning_rate +#include "llmc/schedulers.h" // defines: sample_softmax, random_f32 #include "llmc/sampler.h" // defines: logger_init, logger_log_eval, logger_log_val, logger_log_train #include "llmc/logger.h" // defines: get_flops_promised #include "llmc/mfu.h" +// defines: OutlierDetector, init_detector, update_detector +#include "llmc/outlier_detector.h" // ----------- GPU utilities ----------- // defines: // WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE @@ -64,6 +67,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. // ----------- Multi-GPU support ----------- #include "llmc/zero.cuh" +// ---------------------------------------------------------------------------- +// global vars for I/O +char filename_buffer[512]; // ---------------------------------------------------------------------------- // global vars containing information about the GPU this process is running on @@ -71,6 +77,8 @@ cudaDeviceProp deviceProp; // fills in common_start() cudaStream_t main_stream; // one global variable to hold the multi-GPU configuration for this process MultiGpuConfig multi_gpu_config; +// buffer size to use for device <-> disk io +constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // convenience function that only prints if the rank of process is zero void printf0(const char *format, ...) { @@ -95,7 +103,6 @@ void set_zero_configs(MultiGpuConfig* multi_gpu_config, int zero_stage, size_t t multi_gpu_config->zero_stage = 0; } else { - printf0("| Zero Stage1 is enabled |\n"); multi_gpu_config->zero_stage = 1; multi_gpu_config->shard_num_parameters = total_parameters / multi_gpu_config->num_processes; } @@ -196,23 +203,29 @@ constexpr int NUM_ACTIVATION_TENSORS = 23; typedef struct { floatX* encoded; // (B, T, C) floatX* ln1; // (L, B, T, C) - floatX* ln1_mean; // (L, B, T) - floatX* ln1_rstd; // (L, B, T) + float* ln1_mean; // (L, B, T) + float* ln1_rstd; // (L, B, T) floatX* atty; // (L, B, T, C) - floatX* att; // (L, B, NH, T, T) (smaller with cuDNN) + // cuDNN saves only some statistics information +#if ENABLE_CUDNN + float* att; // (L, B, NH, T) +#else + floatX* att; // (L, B, NH, T, T) +#endif + floatX* attproj; // (L, B, T, C) floatX* residual2; // (L, B, T, C) floatX* ln2; // (L, B, T, C) - floatX* ln2_mean; // (L, B, T) - floatX* ln2_rstd; // (L, B, T) + float* ln2_mean; // (L, B, T) + float* ln2_rstd; // (L, B, T) floatX* fch; // (L, B, T, 4*C) floatX* fch_gelu; // (L, B, T, 4*C) floatX* fcproj; // (L, B, T, C) floatX* residual3; // (L, B, T, C) floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms - floatX* lnf_mean; // (B, T) - floatX* lnf_rstd; // (B, T) - floatX* losses; // (B, T) + float* lnf_mean; // (B, T) + float* lnf_rstd; // (B, T) + float* losses; // (B, T), will be accumulated in micro-steps // adding these two compared to the CPU .c code, needed for attention kernel as buffers floatX* qkvr; // (L, B, T, 3*C) // in inference mode, this buffer will store the logits @@ -227,76 +240,102 @@ typedef struct { floatX* scratch_btc; // (B, T, C) } ActivationTensors; -void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config, int recompute) { +// enumerator to indentify the datatype of a tensor. +enum class DType : uint8_t { + FP32, FP16, BF16 +}; + +// Given a datatype enum, returns the underlying number of bytes +// for a scalar of that type +size_t sizeof_dtype(DType type) { + switch (type) { + case DType::FP32: + return sizeof(float); + case DType::FP16: + return sizeof(half); + case DType::BF16: + return sizeof(nv_bfloat16); + default: // handle or get compiler warning + fprintf(stderr, "Unknown datatype\n"); + exit(EXIT_FAILURE); + } +} + +DType dtype_of(float* f) { return DType::FP32; } +DType dtype_of(nv_bfloat16 * f) { return DType::BF16; } +DType dtype_of(half * f) { return DType::FP16; } + +struct TensorSpec { + void** ptr; + size_t size; + DType type; +}; + + +#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}; + +void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) { size_t Vp = config.padded_vocab_size; size_t L = config.num_layers; size_t NH = config.num_heads; size_t C = config.channels; - act_sizes[0] = B * T * C; // encoded + tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass - act_sizes[1] = (recompute < 2) ? L * B * T * C : 0; // ln1 - act_sizes[2] = L * B * T; // ln1_mean - act_sizes[3] = L * B * T; // ln1_rstd - act_sizes[4] = L * B * T * C; // atty + tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? L * B * T * C : 0); + tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T); + tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T); + tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C); #ifdef ENABLE_CUDNN // FP32 stats tensor for cuDNN to be passed to backward pass - act_sizes[5] = L * B * NH * T * (sizeof(float) / sizeof(floatX)); + tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T); #else - act_sizes[5] = L * B * NH * T * T; // att + tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T); #endif - act_sizes[6] = L * B * T * C; // attproj - act_sizes[7] = L * B * T * C; // residual2 + tensors[6] = TENSOR_SPEC(data->attproj, L * B * T * C); + tensors[7] = TENSOR_SPEC(data->residual2, L * B * T * C); // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass - act_sizes[8] = (recompute < 2) ? L * B * T * C : 0; // ln2 - act_sizes[9] = L * B * T; // ln2_mean - act_sizes[10] = L * B * T; // ln2_rstd - act_sizes[11] = L * B * T * 4*C; // fch + tensors[8] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0); + tensors[9] = TENSOR_SPEC(data->ln2_mean, L * B * T); + tensors[10] = TENSOR_SPEC(data->ln2_rstd, L * B * T); + tensors[11] = TENSOR_SPEC(data->fch, L * B * T * 4*C); // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer - act_sizes[12] = (recompute < 1) ? L * B * T * 4*C : B * T * 4*C; - act_sizes[13] = L * B * T * C; // fcproj - act_sizes[14] = L * B * T * C; // residual3 - act_sizes[15] = B * T * C; // lnf - act_sizes[16] = B * T; // lnf_mean - act_sizes[17] = B * T; // lnf_rstd - act_sizes[18] = B * T; // losses - act_sizes[19] = L * B * T * 3*C; // qkvr - act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch - - act_sizes[21] = B * T * 4 * C; // scratch_bt4c - act_sizes[22] = B * T * C; // scratch_btc + tensors[12] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C); + tensors[13] = TENSOR_SPEC(data->fcproj, L * B * T * C); + tensors[14] = TENSOR_SPEC(data->residual3, L * B * T * C); + tensors[15] = TENSOR_SPEC(data->lnf, B * T * C); + tensors[16] = TENSOR_SPEC(data->lnf_mean, B * T); + tensors[17] = TENSOR_SPEC(data->lnf_rstd, B * T); + tensors[18] = TENSOR_SPEC(data->losses, B * T); + tensors[19] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); + tensors[20] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp))); + + tensors[21] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); + tensors[22] = TENSOR_SPEC(data->scratch_btc, B * T * C); } -void* malloc_and_point(floatX** targets[], const size_t* act_sizes, size_t n) { - size_t num_activations = 0; - for (size_t i = 0; i < n; i++) { - num_activations += act_sizes[i]; +void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) { + size_t bytes = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + bytes += tensors[i].size * sizeof_dtype(tensors[i].type); } + + printf0("allocating %d MiB for activations\n", (int)round(bytes / (1024 * 1024))); + void* acts_memory; - cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(floatX))); + cudaCheck(cudaMalloc((void**)&acts_memory, bytes)); char* acts_memory_iterator = (char*)acts_memory; - for (size_t i = 0; i < n; i++) { + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { // extra protection so we don't accidentally use an empty buffer - if(act_sizes[i] == 0) { - *(targets[i]) = NULL; + if(tensors[i].size == 0) { + *(tensors[i].ptr) = NULL; }else { - *(targets[i]) = (floatX*) acts_memory_iterator; - acts_memory_iterator += act_sizes[i] * sizeof(floatX); + *(tensors[i].ptr) = acts_memory_iterator; + acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type); } } return acts_memory; } -void* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_sizes) { - floatX** ptrs[] = { - &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty, - &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean, - &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf, - &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output, - &acts->scratch_bt4c, &acts->scratch_btc - }; - return malloc_and_point(ptrs, act_sizes, NUM_ACTIVATION_TENSORS); -} - typedef struct { GPT2Config config; // the weights of the model, and their sizes @@ -315,18 +354,16 @@ typedef struct { float* master_weights; // is NULL unless fp32 weights is enabled. // the activations of the model, and their sizes ActivationTensors acts; - size_t act_sizes[NUM_ACTIVATION_TENSORS]; + TensorSpec acts_specs[NUM_ACTIVATION_TENSORS]; void* acts_memory; - size_t num_activations; // other run state configuration int batch_size; // the batch size (B) of current forward pass int seq_len; // the sequence length (T) of current forward pass int* inputs; // the input tokens for the current forward pass int* targets; // the target tokens for the current forward pass - float mean_loss; // after a forward pass with targets, will be populated with the mean loss - float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs - floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost - float* cpu_losses_fp32; // same but fp32 + float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps + float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps + float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. int use_master_weights; // keep master weights copy in float for optim update? 0|1 int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2 @@ -344,12 +381,13 @@ void gpt2_init_common(GPT2 *model) { model->acts_memory = NULL; model->inputs = NULL; model->targets = NULL; + model->accumulated_mean_loss = NULL; model->cpu_losses = NULL; - model->cpu_losses_fp32 = NULL; // the B,T params are determined and set, fixed on first batch in forward() model->batch_size = 0; model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward() + model->params_memory = NULL; // memory lazily initialized in backward() model->grads_memory = NULL; model->workload_indices = NULL; // on cpu, for encoder_backward @@ -359,7 +397,7 @@ void gpt2_init_common(GPT2 *model) { model->v_memory = NULL; model->master_weights = NULL; // other default settings - model->rng_state = 13371337; // used in stochastic rounding + model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding model->use_master_weights = 1; // safe default: do keep master weights in fp32 model->recompute = 1; // good default: recompute gelu but not layernorm } @@ -380,12 +418,10 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[5] = model->config.num_heads; model_header[6] = model->config.channels; model_header[7] = model->config.padded_vocab_size; - fwrite(model_header, sizeof(int), 256, model_file); + fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes); - cudaCheck(cudaMemcpy(params_memory_cpu, model->params_memory, model->num_parameters_bytes, cudaMemcpyDeviceToHost)); - fwrite(params_memory_cpu, 1, model->num_parameters_bytes, model_file); - free(params_memory_cpu); + device_to_file(model_file, model->params_memory, model->num_parameters_bytes, + IO_BUF_SIZE, main_stream); // close file, we're done fcloseCheck(model_file); } @@ -444,16 +480,14 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { } // create memory for model parameters on the device + assert(model->params_memory == nullptr && "Old model needs to be freed before loading from checkpoint again"); model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); // read in all the parameters from file and copy them to device - void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes); - freadCheck(params_memory_cpu, 1, model->num_parameters_bytes, model_file); - cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); - free(params_memory_cpu); + file_to_device(model->params_memory, model_file, model->num_parameters_bytes, + IO_BUF_SIZE, main_stream); fcloseCheck(model_file); - gpt2_init_common(model); // only return from this function once we are certain the params are ready on the GPU cudaCheck(cudaDeviceSynchronize()); } @@ -543,15 +577,13 @@ void gpt2_build_from_random(GPT2 *model, int depth) { // copy them to GPU cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); free(params_memory_cpu); - - gpt2_init_common(model); } -void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T, int grad_accum_steps=1) { - // right now, this function is fully synchronous with the host +// propagate inputs through the network to produce logits. +// right now, this function is fully synchronous with the host +void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); - // targets are optional and could be NULL - // in this function we must be careful and use size_t instead of int, otherwise + // we must be careful and use size_t instead of int, otherwise // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16. // ensure the model was initialized or error out @@ -574,19 +606,13 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, model->batch_size = B; model->seq_len = T; // allocate the space - fill_in_activation_sizes(model->act_sizes, B, T, model->config, model->recompute); - size_t num_activations = 0; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - num_activations += model->act_sizes[i]; - } - model->num_activations = num_activations; - printf0("allocating %d MiB for activations\n", (int)round(num_activations * sizeof(floatX) / (1024 * 1024))); - model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes); + fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute); + model->acts_memory = malloc_and_point_activations(model->acts_specs); // also create memory for caching inputs and targets cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); - cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(floatX))); - cudaCheck(cudaMallocHost((void**)&model->cpu_losses_fp32, B * T * sizeof(float))); + cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float))); + cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float))); } else { // validate B,T is consistent with how we've allocated the memory before // in principle we could get more clever here in the future, for now this is safest @@ -598,18 +624,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, // copy inputs/targets to the model cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); - if (targets != NULL) { - cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); - } - // validate inputs, all indices must be in the range [0, V) // we can do this while the copies are already underway - for(int i = 0; i < B * T; i++) { - assert(0 <= inputs[i] && inputs[i] < V); - if (targets != NULL) { - assert(0 <= targets[i] && targets[i] < V); - } - } + tokenCheck(inputs, B*T, V); // forward pass ParameterTensors params = model->params; // for brevity @@ -643,8 +660,8 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, floatX* l_attproj = acts.attproj + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - floatX* l_ln2_mean = acts.ln2_mean + l * B * T; - floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; floatX* l_fch = acts.fch + l * B * T * 4*C; // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size @@ -675,8 +692,8 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, // OK, fusion across blocks. if(l+1 != L) { floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf; - floatX* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T; - floatX* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T; + float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T; + float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T; const floatX* l_ln1w = params.ln1w + (l + 1) * C; const floatX* l_ln1b = params.ln1b + (l + 1) * C; fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, l_fcproj, l_ln1w, l_ln1b, @@ -689,45 +706,53 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, } matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + cudaCheck(cudaDeviceSynchronize()); +} - // also forward the cross-entropy loss function if we have the targets - if (targets != NULL) { - NvtxRange classifier_and_loss_range("classifier_and_loss"); - // fused classifier: does the forward pass and first part of the backward pass - const float dloss = 1.0f / (B * T * grad_accum_steps); // results in the uniform average loss over all elements - fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, main_stream); - // for convenience also evaluate the mean loss (TODO re-think this compute+sync point) - cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost)); - float mean_loss = 0.0f; - for (int i = 0; i < B*T; i++) { - float loss = (float)(model->cpu_losses[i]); - model->cpu_losses_fp32[i] = loss; - mean_loss += loss; - } - mean_loss /= B*T*grad_accum_steps; - model->mean_loss = mean_loss; - } else { - // if we don't have targets, we don't have loss - model->mean_loss = -1.0f; + +// Forwards both the model and the loss and is used for validation splits and evals. +// In particular it populates cpu_losses with loss at each token. +// Some of the evals (e.g. HellaSwag) require the per-token losses, which are produced here. +float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) { + assert(targets != NULL); + // forward the model itself + gpt2_forward(model, inputs, B, T); + // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow + const size_t V = model->config.vocab_size; + const size_t Vp = model->config.padded_vocab_size; + + NvtxRange classifier_and_loss_range("classifier_and_loss"); + ActivationTensors acts = model->acts; + float mean_loss = 0.0f; + // fused classifier: does the forward pass and first part of the backward pass + const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements + // note: we don't need to generate dlogits here + cudaCheck(cudaMemset(acts.losses, 0, B*T*sizeof(float))); + cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); + tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets + fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream); + cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); + for (int i = 0; i < B*T; i++) { + mean_loss += model->cpu_losses[i]; } + mean_loss /= B*T; cudaCheck(cudaDeviceSynchronize()); + return mean_loss; } + void gpt2_zero_grad(GPT2 *model) { NVTX_RANGE_FN(); + // the losses accumulate over the duration of gradient accumulation micro steps, also reset here + cudaCheck(cudaMemset(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float))); if (model->grads_memory != NULL) { cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(floatX))); } cudaCheck(cudaDeviceSynchronize()); } -void gpt2_backward(GPT2 *model, int* inputs, bool last_step) { +void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, bool last_step) { NVTX_RANGE_FN(); - // double check we forwarded previously, with targets - if (model->mean_loss == -1.0f) { - printf("Error: must forward with targets before backward\n"); - exit(EXIT_FAILURE); - } // lazily allocate the memory for gradients of the weights and activations, if needed if (model->grads_memory == NULL) { @@ -747,16 +772,25 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) { // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow const size_t B = model->batch_size; const size_t T = model->seq_len; + const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; - // backward pass: go in the reverse order of the forward pass, and call backward() functions ParameterTensors params = model->params; // for brevity ParameterTensors grads = model->grads; ActivationTensors acts = model->acts; + // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier + NvtxRange classifier_and_loss_range("classifier_and_loss"); + const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements + cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); + tokenCheck(targets, B*T, V); + fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream); + + // backward pass: go in the reverse order of the forward pass, and call backward() functions + // reset residual stream gradients (put here to work with gradient accumulation) floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX))); @@ -809,14 +843,14 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) { floatX* dl_fcprojb = grads.fcprojb + l * C; // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - floatX* l_ln1_mean = acts.ln1_mean + l * B * T; - floatX* l_ln1_rstd = acts.ln1_rstd + l * B * T; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - floatX* l_ln2_mean = acts.ln2_mean + l * B * T; - floatX* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; floatX* l_fch = acts.fch + l * B * T * 4*C; floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; // get the pointers of the gradients of the activations for this layer @@ -886,39 +920,42 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) { // Aggregate all gradients that are not part of the transformer blocks if(last_step) { + // reduce all the losses within the current GPU (across all microsteps) + global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream); + // reduce loss across GPUs to a single, final float across all microsteps and GPUs + #if MULTI_GPU + ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); + #endif + cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); + // reduce the gradients for non-transformer block parameters floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb}; const size_t nelem[] = {Vp * C, T * C, C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } cudaCheck(cudaDeviceSynchronize()); + if(last_step) { + model->mean_loss /= B*T*grad_accum_steps; + } else { + model->mean_loss = -1.f; // no loss available yet + } } // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled. -float multi_gpu_cpu_float_sum(float value) { +float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* multi_gpu_config) { #ifdef MULTI_GPU - // note MPI doesn't support all reduce with mean, only sum - float result; - mpiCheck(MPI_Allreduce(&value, &result, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD)); - return result; + if (multi_gpu_config->num_processes == 1) return value; + + float* unified_buffer = multi_gpu_config->unified_buffer; + *unified_buffer = value; + ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); + cudaCheck(cudaDeviceSynchronize()); + return *unified_buffer; #else return value; #endif } -// Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled. -// todo - this version only works if all the parameters are the same size (floatX) -void gpt2_multi_gpu_loss_reduce(GPT2* model, MultiGpuConfig* multi_gpu_config) { -#ifdef MULTI_GPU - NVTX_RANGE_FN(); - // If there's only one process, there is nothing to do - if (multi_gpu_config->num_processes == 1) { return; } - // Average all losses. - model->accumulated_mean_loss = multi_gpu_cpu_float_sum(model->mean_loss) / multi_gpu_config->num_processes; -#endif - cudaCheck(cudaDeviceSynchronize()); -} - // Gets the offset of a specific tensor for a specific layer in the GPT2 model // layer_id is ignored for weights that are not part of a transformer block ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) { @@ -936,7 +973,50 @@ ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te return {offset, size}; } -float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_clip, int t, MultiGpuConfig* multi_gpu_config) { +float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { + NVTX_RANGE_FN(); + floatX* grads_memory = (floatX*)model->grads_memory; + + // repurposing this buffer (which isn't needed now) to write grad norm into it + float* grad_norm_squared = (float*)model->acts.output; + float grad_norm_squared_cpu = 0.0f; + + int num_slices[2] = {1, model->config.num_layers}; + int max_num_block_sums = get_max_num_block_sums(num_slices, 2); + if (multi_gpu_config->zero_stage == 1) { + // because of the ncclReduceScatter() in backward, + // grads_memory only contains the averaged gradients at the local shards, + // so we only calculate the grad norm at the grads_memory belonging to the local shards + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); + ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); + ptrdiff_t offset = tensor.offset + shard.offset; + bool is_first_pass = (i == 0); + if((i < 2 || i > 13)) { + global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, 0, 1, + max_num_block_sums, is_first_pass, main_stream); + } else { + global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, tensor.size, model->config.num_layers, + max_num_block_sums, is_first_pass, main_stream); + } + } + global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); +#if MULTI_GPU + // further sum the (partial) squared norm across all GPUs + ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream)); +#endif + } else { + // in regular DDP, backward has averaged the gradients across all GPUs + // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed + global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream); + global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); + } + cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); + float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); + return grad_norm_cpu; +} + +void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config) { // update the model parameters using the AdamW optimizer // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs // so we may not be responsible for the entire parameter tensor @@ -945,7 +1025,6 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl // TODO: revisit and probably refactor this entire function NVTX_RANGE_FN(); size_t shard_num_parameters = multi_gpu_config->shard_num_parameters; // num parameters we are responsible for - floatX* grads_memory = (floatX*)model->grads_memory; // lazily allocate m,v memory and master weights (usually on the first iteration) if (model->m_memory == NULL) { @@ -965,54 +1044,12 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl init_master_weights = true; } - // gradient clipping - // repurposing this buffer (which isn't needed now) to write grad norm into it - float* grad_norm_squared = (float*)model->acts.output; - float grad_norm_squared_cpu = 0.0f; - - if (multi_gpu_config->zero_stage == 1) { - // because of the ncclReduceScatter() in backward, - // grads_memory only contains the averaged gradients at the local shards, - // so we only calculate the grad norm at the grads_memory belonging to the local shards - cudaCheck(cudaMemsetAsync(grad_norm_squared, 0, sizeof(float), main_stream)); - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - if((i < 2 || i > 13)) { - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); - ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); - ptrdiff_t offset = tensor.offset + shard.offset; - global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, 0, 1, false, main_stream); - } else { - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); - ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); - ptrdiff_t offset = tensor.offset + shard.offset; - global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, tensor.size, model->config.num_layers, - false, main_stream); - } - } - cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); - // further sum the (partial) squared norm across all GPUs (see comment ^1 above) - grad_norm_squared_cpu = multi_gpu_cpu_float_sum(grad_norm_squared_cpu); - } else { - // in regular DDP, backward has averaged the gradients across all GPUs - // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed - global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, true, main_stream); - cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); - } - - if(!isfinite(grad_norm_squared_cpu)) { - // may happen due to some issue (e.g. overflow?) - // TODO: later may want to keep a global counter of instabilities like this - printf0("[WARNING]: grad norm is not finite, skipping AdamW update\n"); - return -1.0f; - } - float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); - float grad_scale = (grad_norm_cpu > grad_clip) ? grad_clip / grad_norm_cpu : 1.0f; - // AdamW update - unsigned int seed = random_u32(&model->rng_state); - // handle adamw for all the transformer blocks for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + // generate a unique seed for each tensor + unsigned int seed = random_u32(&model->rng_state); + int num_layers = model->config.num_layers; if((i < 2 || i > 13)) { num_layers = 1; @@ -1037,7 +1074,7 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl float* master_ptr = NULL; if (model->master_weights != NULL) { master_ptr = model->master_weights + opt_state_offset; } if(init_master_weights) { - size_t grid_size = CEIL_DIV(shard_num_parameters, 512); + size_t grid_size = CEIL_DIV(shard.size, 512); copy_and_cast_kernel<<>>(master_ptr, param_ptr, shard.size, shard.size, tensor.size); cudaCheck(cudaGetLastError()); @@ -1067,7 +1104,6 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl } cudaCheck(cudaDeviceSynchronize()); - return grad_norm_cpu; } float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { @@ -1100,16 +1136,16 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { } void gpt2_free(GPT2 *model) { - cudaCheck(cudaFree(model->params_memory)); - cudaCheck(cudaFree(model->grads_memory)); - cudaCheck(cudaFree(model->m_memory)); - cudaCheck(cudaFree(model->v_memory)); - cudaCheck(cudaFree(model->master_weights)); - cudaCheck(cudaFree(model->acts_memory)); - cudaCheck(cudaFree(model->inputs)); - cudaCheck(cudaFree(model->targets)); + cudaFreeCheck(&model->params_memory); + cudaFreeCheck(&model->grads_memory); + cudaFreeCheck(&model->m_memory); + cudaFreeCheck(&model->v_memory); + cudaFreeCheck(&model->master_weights); + cudaFreeCheck(&model->acts_memory); + cudaFreeCheck(&model->inputs); + cudaFreeCheck(&model->targets); + cudaFreeCheck(&model->accumulated_mean_loss); cudaCheck(cudaFreeHost(model->cpu_losses)); - cudaCheck(cudaFreeHost(model->cpu_losses_fp32)); free(model->workload_indices); free(model->bucket_info); } @@ -1155,13 +1191,6 @@ void common_free(GPT2 &model) { #endif } -#ifndef TESTING -// if we are TESTING (see test_gpt2.cu), we'll skip everything below this point - -// ---------------------------------------------------------------------------- -// training resumption logic, very useful when jobs crash once in a while -// the goal is that we can resume optimization from any checkpoint, bit-perfect -// note that "state" refers to things not already saved in the model checkpoint file void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) { printf("Writing state to %s\n", filename); @@ -1173,23 +1202,34 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) state_header[1] = 1; // version number state_header[2] = multi_gpu_config.num_processes; // number of processes state_header[3] = multi_gpu_config.process_rank; // rank of this process + state_header[4] = model->use_master_weights; // whether we're using fp32 master weights + state_header[5] = loader->should_shuffle; // shuffle state of the dataloader // int main state, start at 10 to leave some padding state_header[10] = step; // step of the optimization - // model state, state, start at 20 to leave some padding + // model rng state, start at 20 to leave some padding *((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state // dataloader state, start at 30 to leave some padding - state_header[30] = loader->current_shard; // shard of the dataset - *((int64_t*)&state_header[31]) = loader->current_position; // position in shard - fwrite(state_header, sizeof(int), 256, state_file); + *((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset + *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard + fwriteCheck(state_header, sizeof(int), 256, state_file); + // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float)); - cudaCheck(cudaMemcpy(cpu_buffer, model->m_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); - fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(cpu_buffer, model->v_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); - fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - free(cpu_buffer); - fclose(state_file); + device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + if(model->use_master_weights) { + device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + } + + // write dataloader state if we are using the Permuted version of it + if (loader->should_shuffle) { + fwriteCheck(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file); // number of shards + fwriteCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file); + fwriteCheck(&loader->shard_num_samples, sizeof(size_t), 1, state_file); + fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); + fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); + } + fcloseCheck(state_file); } void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) { @@ -1200,32 +1240,114 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename assert(state_header[1] == 1); // version number assert(state_header[2] == multi_gpu_config.num_processes); // number of processes assert(state_header[3] == multi_gpu_config.process_rank); // rank of this process + int use_master_weights = state_header[4]; // whether we're using fp32 master weights + int should_shuffle = state_header[5]; // shuffle state of the dataloader *step = state_header[10]; // step of the optimization model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state - int current_shard = state_header[30]; // shard of the dataset - int64_t current_position = *((int64_t*)&state_header[31]); // position in shard - dataloader_resume(loader, current_shard, current_position); - // read AdamW m, v (they are all float) - // also allocate the m, v memory in the model, if it does not yet exist + size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index + size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard + + // read AdamW m, v, master_weights (they are all float) + // allocate all the needed memory as necessary size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; if (model->m_memory == NULL) { printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); - printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float))); + } + if (model->v_memory == NULL) { + printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); } - float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float)); - freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(model->m_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); - freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(model->v_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); - free(cpu_buffer); - fclose(state_file); + if(use_master_weights == 1 && !model->use_master_weights) { + printf0("Warning: Master weights are present in state, but not enabled for current run."); + } else if (use_master_weights == 0 && model->use_master_weights) { + printf0("Error: Master weights requested, but not present in state file."); + exit(EXIT_FAILURE); + } + if (model->master_weights == NULL && use_master_weights == 1) { + printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); + cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float))); + } + file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + if(model->use_master_weights) { + file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + } + + // revive the DataLoader object and its state + loader->should_shuffle = should_shuffle; + if (should_shuffle == 1) { + // ensure the number of shards matches + size_t glob_result_gl_pathc; + freadCheck(&glob_result_gl_pathc, sizeof(size_t), 1, state_file); + assert(glob_result_gl_pathc == loader->glob_result.gl_pathc); + // read the shard indices + loader->shard_indices = (int*)mallocCheck(loader->glob_result.gl_pathc * sizeof(int)); + freadCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file); + // ensure the number of samples matches + size_t shard_num_samples; + freadCheck(&shard_num_samples, sizeof(size_t), 1, state_file); + assert(shard_num_samples == loader->shard_num_samples); + // read the intra-shard indices + loader->intra_shard_indices = (int*)mallocCheck(loader->shard_num_samples * sizeof(int)); + freadCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); + // read the shuffle rng state + freadCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); + } + dataloader_resume(loader, current_shard_idx, current_sample_idx); + + // all done, close state file + fcloseCheck(state_file); +} + +void write_checkpoint(const char* output_log_dir, int step, GPT2* model, DataLoader* train_loader, MultiGpuConfig* multi_gpu_config) { + // a checkpoint contains: model weights, optimizer/dataloader state, and a DONE file + printf0("Writing checkpoint at step %d\n", step); + int rank = multi_gpu_config->process_rank; + // only rank 0 writes the model file because it is the same across all ranks + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, step); + gpt2_write_to_checkpoint(model, filename_buffer); + } + // all ranks write their state file + snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, step, rank); + save_state(filename_buffer, step, model, train_loader); + // DONE file is a signal that this checkpoint as a whole is complete + multi_gpu_barrier(multi_gpu_config); + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/DONE_%08d", output_log_dir, step); + FILE* done_file = fopenCheck(filename_buffer, "w"); + fcloseCheck(done_file); + } +} + +void delete_checkpoint(const char* output_log_dir, int step, MultiGpuConfig* multi_gpu_config) { + // mirrors write_checkpoint function, cleans up checkpoint from disk + printf0("Deleting checkpoint at step %d\n", step); + int rank = multi_gpu_config->process_rank; + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, step); + remove(filename_buffer); + } + snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, step, rank); + remove(filename_buffer); + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/DONE_%08d", output_log_dir, step); + remove(filename_buffer); + } } +#ifndef TESTING +// if we are TESTING (see test_gpt2.cu), we'll skip everything below this point + +// ---------------------------------------------------------------------------- +// training resumption logic, very useful when jobs crash once in a while +// the goal is that we can resume optimization from any checkpoint, bit-perfect +// note that "state" refers to things not already saved in the model checkpoint file + // ---------------------------------------------------------------------------- // CLI, poor man's argparse -// unclaimed flags lol: k,p +// (all single letters have been claimed now) void error_usage() { fprintf(stderr, "Usage: ./train_gpt2cu [options]\n"); @@ -1236,6 +1358,8 @@ void error_usage() { fprintf(stderr, " -e input from model at this filename (default = gpt2_124M_bf16.bin)\n"); fprintf(stderr, " -o output log dir (default = NULL, no logging)\n"); fprintf(stderr, " -n write optimization checkpoints every how many steps? (default 0, don't)\n"); + fprintf(stderr, " -nk max number of checkpoints to keep in the directory, removing old ones (0 = disable, default)\n"); + fprintf(stderr, " -nm every how many step checkpoints are considered major? major checkpoints never get deleted.\n"); fprintf(stderr, " -y resume optimization found inside output log dir? (0=restart/overwrite, 1=resume/append)\n"); // token layout for each step of the optimization fprintf(stderr, " -b (per-GPU, micro) batch size B (default = 4)\n"); @@ -1244,10 +1368,13 @@ void error_usage() { // workload (number of steps) fprintf(stderr, " -x max_steps of optimization to run (-1 (default) = disable, run 1 epoch)\n"); // optimization + fprintf(stderr, " -k learning rate scheduler (default = cosine)\n"); fprintf(stderr, " -l learning rate (default = 3e-4f)\n"); fprintf(stderr, " -u learning rate warmup iterations (default = 0, no warmup)\n"); fprintf(stderr, " -q learning rate decay: final fraction, at end of training (default = 1.0 (no decay))\n"); fprintf(stderr, " -c weight decay (default = 0.0f)\n"); + fprintf(stderr, " -sl outlier stability: skip update if loss goes above this in zscore (0.0f=off)\n"); + fprintf(stderr, " -sg outlier stability: skip update if grad_norm goes above this in zscore (0.0f=off)\n"); // evaluation fprintf(stderr, " -v val_loss_every, how often we evaluate val loss (default = 20)\n"); fprintf(stderr, " -m val_max_steps, up to how many val batches to estimate val loss? (default = 20)\n"); @@ -1262,20 +1389,28 @@ void error_usage() { // memory management fprintf(stderr, " -z zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n"); fprintf(stderr, " -r recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\n"); + // multi-node settings + fprintf(stderr, " -pn num_processes (default = 1)\n"); + fprintf(stderr, " -pr process_rank (default = 0)\n"); + fprintf(stderr, " -pg gpus_per_node (default = 8)\n"); + fprintf(stderr, " -pm nccl_init_method: tcp,fs,mpi (default = mpi)\n"); + fprintf(stderr, " -ps server_ip - used only when nccl_init_method is tcp (default = -1)\n"); + fprintf(stderr, " -pp fs_path - used only when nccl_init_method is fs (default = /tmp)\n"); exit(EXIT_FAILURE); } // ---------------------------------------------------------------------------- // main training loop int main(int argc, char *argv[]) { - multi_gpu_config = multi_gpu_config_init(&argc, &argv); - // read in the (optional) command line arguments const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin"; const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model + const char* lr_scheduler_type = "cosine"; const char* output_log_dir = NULL; - int checkpoint_every = 0; // write optimization checkpoints every how many steps? + int checkpoint_every = 0; // write checkpoints every how many steps? + int checkpoints_keep = 0; // how long checkpoint history do we keep? (in units of checkpoints) + int major_checkpoint_every = 0; // major checkpoints never get deleted when maintaining history int resume = 0; // resume the optimization, if one is found inside output_log_dir? int B = 4; // batch size int T = 1024; // sequence length max @@ -1284,6 +1419,8 @@ int main(int argc, char *argv[]) { int warmup_iterations = 0; float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training float weight_decay = 0.0f; + float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore + float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore int val_loss_every = 20; // every how many steps do we eval validation loss? int val_max_steps = 20; // how many batches max do we eval for validation loss? int sample_every = 20; // every how many steps to do inference? @@ -1295,16 +1432,23 @@ int main(int argc, char *argv[]) { int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training int hellaswag_eval = 0; + // multi-node settings + int num_processes = 1; // this should be set by the slurm environment + int process_rank = 0; // this should be set by the slurm environment + int gpus_per_node = 8; // this should be set by the slurm environment + char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi" + char server_ip[256] = ""; // used if init_method set to "tcp" -> set to your server ip address + char fs_path[256] = ""; // used if init_method set to "fs" -> set to a shared filesystem path for (int i = 1; i < argc; i+=2) { if (i + 1 >= argc) { error_usage(); } // must have arg after flag if (argv[i][0] != '-') { error_usage(); } // must start with dash - if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter) + if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { error_usage(); } // must be -x[y] (one dash, one or two letters) // read in the args if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; } else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; } else if (argv[i][1] == 'e') { load_filename = argv[i+1]; } else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; } - else if (argv[i][1] == 'n') { checkpoint_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'n' && argv[i][2] == '\0') { checkpoint_every = atoi(argv[i+1]); } else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); } else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size else if (argv[i][1] == 't') { T = atoi(argv[i+1]); } @@ -1316,7 +1460,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); } else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); } else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); } - else if (argv[i][1] == 's') { sample_every = atoi(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == '\0') { sample_every = atoi(argv[i+1]); } else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); } else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); } else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); } @@ -1324,19 +1468,26 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); } else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); } else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); } + else if (argv[i][1] == 'k') { lr_scheduler_type = argv[i+1]; } + else if (argv[i][1] == 'p' && argv[i][2] == 'i') { strcpy(nccl_init_method, argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'f') { strcpy(fs_path, argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 's') { strcpy(server_ip, argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'n') { num_processes = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'r') { process_rank = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'g') { gpus_per_node = atoi(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == 'l') { skip_update_lossz = atof(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); } + else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); } + else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); } else { error_usage(); } } + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); + // should do a bit more error checking here assert(warmup_iterations >= 0); if (output_log_dir != NULL) { assert(strlen(output_log_dir) < 400); // careful bunch of hardcoded snprintf around this } - // check if output_log_dir does not exist or is a file - struct stat info; - if (output_log_dir != NULL && (stat(output_log_dir, &info ) != 0 || !(info.st_mode & S_IFDIR))) { - fprintf(stderr, "-o \"%s\" does not exist or is a file - are you specifying a file instead of dir?\n", output_log_dir); - exit(EXIT_FAILURE); - } int tokens_per_fwdbwd = B * T * multi_gpu_config.num_processes; // one micro-batch processes this many tokens // calculate sensible default for total batch size as assuming no gradient accumulation if (total_batch_size == -1) { total_batch_size = tokens_per_fwdbwd; } @@ -1357,10 +1508,13 @@ int main(int argc, char *argv[]) { printf0("| micro batch size B | %-50d |\n", B); printf0("| sequence length T | %-50d |\n", T); printf0("| total batch size | %-50d |\n", total_batch_size); + printf0("| LR scheduler | %-50s |\n", lr_scheduler_type); printf0("| learning rate (LR) | %-50e |\n", learning_rate); printf0("| warmup iterations | %-50d |\n", warmup_iterations); printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac); printf0("| weight decay | %-50e |\n", weight_decay); + printf0("| skip update lossz | %-50f |\n", skip_update_lossz); + printf0("| skip update gradz | %-50f |\n", skip_update_gradz); printf0("| max_steps | %-50d |\n", max_steps); printf0("| val_loss_every | %-50d |\n", val_loss_every); printf0("| val_max_steps | %-50d |\n", val_max_steps); @@ -1381,7 +1535,6 @@ int main(int argc, char *argv[]) { printf0("+-----------------------+----------------------------------------------------+\n"); // figure out if we are going to be resuming the optimization - char filename_buffer[512]; int resuming = 0; int resume_max_step = find_max_step(output_log_dir); if (resume == 1) { @@ -1390,12 +1543,13 @@ int main(int argc, char *argv[]) { if (resume_max_step == -1) { } else { resuming = 1; - snprintf(filename_buffer, 512, "%s/model_%08d.bin", output_log_dir, resume_max_step); + snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, resume_max_step); } } // build the GPT-2 model GPT2 model; + gpt2_init_common(&model); // if load_filename is of the form "dX" where X is an integer (e.g. d12), then we build // a random model with the depth of the model specified by X (e.g. 12). otherwise interpret // this variable as a checkpoint filename, and load that checkpoint @@ -1427,8 +1581,8 @@ int main(int argc, char *argv[]) { // build DataLoaders for both train and val DataLoader train_loader, val_loader; - dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes); - dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes); + dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1); + dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 0); // figure out the number of training steps we will run for int train_num_batches = max_steps; // passed in from command line if (train_num_batches == -1) { @@ -1480,7 +1634,7 @@ int main(int argc, char *argv[]) { printf0("=> setting grad_accum_steps=%d\n", grad_accum_steps); // set up logging - create_dir_if_not_exists(output_log_dir); + if (multi_gpu_config.process_rank == 0) { create_dir_if_not_exists(output_log_dir); } Logger logger; logger_init(&logger, output_log_dir, multi_gpu_config.process_rank, resume); @@ -1488,6 +1642,11 @@ int main(int argc, char *argv[]) { Tokenizer tokenizer; tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); + // set up learning rate scheduler + LearningRateScheduler lr_scheduler; + lr_scheduler_init(&lr_scheduler, lr_scheduler_type, learning_rate, + warmup_iterations, train_num_batches, final_learning_rate_frac); + // some memory for generating samples from the model int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX)); @@ -1496,10 +1655,15 @@ int main(int argc, char *argv[]) { // if we found a checkpoint to resume from, load the optimization state int step = 0; if (resuming == 1) { - snprintf(filename_buffer, 512, "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); + snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); load_state(&step, &model, &train_loader, filename_buffer); } + // init an OutlierDetector the training loss + OutlierDetector loss_outlier_detector, grad_norm_outlier_detector; + init_detector(&loss_outlier_detector); + init_detector(&grad_norm_outlier_detector); + // train cudaEvent_t start, end; cudaCheck(cudaEventCreate(&start)); @@ -1519,11 +1683,10 @@ int main(int argc, char *argv[]) { dataloader_reset(&val_loader); for (int i = 0; i < val_num_batches; i++) { dataloader_next_batch(&val_loader); - gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T); - val_loss += model.mean_loss; + val_loss += gpt2_validate(&model, val_loader.inputs, val_loader.targets, B, T); } val_loss /= val_num_batches; - val_loss = multi_gpu_cpu_float_sum(val_loss) / multi_gpu_config.num_processes; + val_loss = multi_gpu_cpu_float_sum(val_loss, &multi_gpu_config) / multi_gpu_config.num_processes; printf0("val loss %f\n", val_loss); logger_log_val(&logger, step, val_loss); } @@ -1537,12 +1700,12 @@ int main(int argc, char *argv[]) { for (int i = 0; i < eval_loader.num_batches; i++) { if (i % 10 == 0) { printf("evaluating HellaSwag: %d/%d\r", i, eval_loader.num_batches); } evalloader_next_batch(&eval_loader); - gpt2_forward(&model, eval_loader.inputs, eval_loader.targets, B, T); - int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses_fp32); + gpt2_validate(&model, eval_loader.inputs, eval_loader.targets, B, T); + int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses); eval_acc_norm += (float)correct; } // careful because not all ranks may have the exact same allocation of number of examples - eval_acc_norm = multi_gpu_cpu_float_sum(eval_acc_norm); + eval_acc_norm = multi_gpu_cpu_float_sum(eval_acc_norm, &multi_gpu_config); printf0("HellaSwag: %d/%d = %f\n", (int)eval_acc_norm, eval_loader.num_examples, eval_acc_norm / eval_loader.num_examples); logger_log_eval(&logger, step, eval_acc_norm / eval_loader.num_examples); } @@ -1565,7 +1728,7 @@ int main(int argc, char *argv[]) { // we re-calculate the forward pass for all of (B,T) positions from scratch // but the inference here is just for sanity checking anyway // and we can maybe optimize a bit more later, with careful tests - gpt2_forward(&model, gen_tokens, NULL, B, T); + gpt2_forward(&model, gen_tokens, B, T); // furthermore, below we're only using b=0 (i.e. the first row) of all B rows // we're in principle running B "inference streams" in parallel here // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU) @@ -1597,23 +1760,17 @@ int main(int argc, char *argv[]) { // once in a while checkpoint the optimization state (all ranks) if ((checkpoint_every > 0 && output_log_dir != NULL && resuming == 0) && ((step > 0 && step % checkpoint_every == 0) || last_step)) { - assert(strlen(output_log_dir) < 400); // being a bit lazy here - // only rank 0 writes the model file because it is the same across all ranks - if (multi_gpu_config.process_rank == 0) { - snprintf(filename_buffer, 512, "%s/model_%08d.bin", output_log_dir, step); - gpt2_write_to_checkpoint(&model, filename_buffer); - } - // all ranks write their state file - snprintf(filename_buffer, 512, "%s/state_%08d_%05d.bin", output_log_dir, step, multi_gpu_config.process_rank); - save_state(filename_buffer, step, &model, &train_loader); - // DONE file is a signal that this checkpoint as a whole is complete - multi_gpu_barrier(&multi_gpu_config); - if (multi_gpu_config.process_rank == 0) { - snprintf(filename_buffer, 512, "%s/DONE_%08d", output_log_dir, step); - FILE* done_file = fopenCheck(filename_buffer, "w"); - fclose(done_file); + // writes model .bin file, state .bin files, and DONE file for step + write_checkpoint(output_log_dir, step, &model, &train_loader, &multi_gpu_config); + // we only keep checkpoints_keep checkpoints on disk to save space + // so now that we wrote a new checkpoint, delete one old one (unless it is a "major" checkpoint) + // we only do this is checkpoint keeping is turned on (checkpoints_keep > 0) + int step_delete = step - checkpoints_keep * checkpoint_every; + if (checkpoints_keep > 0 && step_delete > 0 && + (major_checkpoint_every == 0 || step_delete % major_checkpoint_every != 0) + ) { + delete_checkpoint(output_log_dir, step_delete, &multi_gpu_config); } - multi_gpu_barrier(&multi_gpu_config); } resuming = 0; @@ -1626,8 +1783,7 @@ int main(int argc, char *argv[]) { // --------------- TRAINING SECTION BEGIN ----------------- // do one training step, doing forward/backward/update on total_batch_size tokens cudaEventRecord(start); - // gradient accumulation loop over micro-batches - float lossf = 0.0f; // for getting the mean loss over the accumulation steps + // gradient and loss accumulation loop over micro-batches for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) { // fetch the next data batch // and if we're overfitting a single batch, we'll only call this a single time @@ -1636,30 +1792,27 @@ int main(int argc, char *argv[]) { dataloader_next_batch(&train_loader); } // forward pass. note that we pass in grad_accum_steps, which scales down the loss - gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps); - lossf += model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward + gpt2_forward(&model, train_loader.inputs, B, T); // backward pass. all model params accumulate gradients with += inside this inner loop - gpt2_backward(&model, train_loader.inputs, micro_step == grad_accum_steps - 1); + gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step == grad_accum_steps - 1); } - // override the mean loss, accounting for the gradient accumulation loop - // this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced - model.mean_loss = lossf; - // average the loss and the gradients between all processes - gpt2_multi_gpu_loss_reduce(&model, &multi_gpu_config); - // learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac - float step_learning_rate = learning_rate; - if (step < warmup_iterations) { - step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations; + float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss)); // loss z-score + // fetch the next learning rate + float step_learning_rate = get_learning_rate(&lr_scheduler, step); + // calculate the gradient norm and how much we wish to scale the gradient + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score + // update the model parameters + if (isfinite(zloss) && skip_update_lossz != 0.0f && zloss > skip_update_lossz) { + printf0("skipping update due to loss z-score of %f\n", zloss); + } else if (isfinite(zgrad) && skip_update_gradz != 0.0f && zgrad > skip_update_gradz) { + printf0("skipping update due to grad z-score of %f\n", zgrad); } else { - float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations); - assert(0.0f <= decay_ratio && decay_ratio <= 1.0f); - float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0 - assert(0.0f <= coeff && coeff <= 1.0f); - float min_lr = learning_rate * final_learning_rate_frac; - step_learning_rate = min_lr + coeff * (learning_rate - min_lr); + // clip the gradient norm to a maximum value + float grad_clip = 1.0f; + float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f; + gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); } - // update the model parameters - float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config); // zero out the gradients for the next iteration gpt2_zero_grad(&model); cudaCheck(cudaEventRecord(end)); @@ -1679,10 +1832,9 @@ int main(int argc, char *argv[]) { ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second; bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step)); } - float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss; float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f); - printf0("step %4d/%d | train loss %7.6f | norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", - step + 1, train_num_batches, accumulated_loss, grad_norm, step_learning_rate, + printf0("step %4d/%d | loss %7.6f (%+.2fz)| norm %6.4f (%+.2fz)| lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", + step + 1, train_num_batches, model.mean_loss, zloss, grad_norm, zgrad, step_learning_rate, time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second); logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm); diff --git a/train_gpt2.py b/train_gpt2.py index 3e924487b..f71de9f59 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -802,6 +802,11 @@ def get_lr(it): or (args.overfit_single_batch and step == 0 and micro_step == 0): x, y = train_loader.next_batch() x, y = x.to(device), y.to(device) + if ddp: + # we want only the last micro-step to sync grads in a DDP model + # the official way to do this is with model.no_sync(), but that is a + # context manager that bloats the code, so we just toggle this variable + model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1) # forward pass with ctx: _, loss = model(x, y, return_logits=False) @@ -812,11 +817,6 @@ def get_lr(it): loss = loss / grad_accum_steps lossf += loss.detach() # keep track of the mean loss # backward pass - if ddp: - # we want only the last micro-step to sync grads in a DDP model - # the official way to do this is with model.no_sync(), but that is a - # context manager that bloats the code, so we just toggle this variable - model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1) if not args.inference_only: loss.backward() if ddp: diff --git a/train_gpt2_fp32.cu b/train_gpt2_fp32.cu index ebd7c9257..df412ea5e 100644 --- a/train_gpt2_fp32.cu +++ b/train_gpt2_fp32.cu @@ -1634,8 +1634,8 @@ int main(int argc, char *argv[]) { // build DataLoaders for both train and val DataLoader train_loader, val_loader; - dataloader_init(&train_loader, train_data_pattern, B, T, 0, 1); - dataloader_init(&val_loader, val_data_pattern, B, T, 0, 1); + dataloader_init(&train_loader, train_data_pattern, B, T, 0, 1, 1); + dataloader_init(&val_loader, val_data_pattern, B, T, 0, 1, 0); int train_num_batches = train_loader.num_tokens / (B*T); // let's do 1 epoch by default for now int val_num_batches = val_loader.num_tokens / (B*T); if (val_num_batches > val_max_steps) { val_num_batches = val_max_steps; }