I'm continuing training from an OLMo checkpoint on LUMI. When scaling up to 64 nodes (with global_train_batch_size=512), training runs fine until the first checkpoint save, at which point one of the nodes hits an OOM error:
...
nid005143:0 out: 2025-01-20 10:29:10.590 nid005143:0 olmo.train:967 INFO [step=50/65355,epoch=0]
nid005143:0 out: train/masked_instances_local_rank=0
nid005143:0 out: optim/total_grad_norm=0.3484
nid005143:0 out: train/CrossEntropyLoss=2.633
nid005143:0 out: train/Perplexity=13.91
nid005143:0 out: train/ZLoss=0.0035
nid005143:0 out: throughput/total_tokens=104,857,600
nid005143:0 out: throughput/total_training_Gflops=684,340,122
nid005143:0 out: throughput/total_training_log_Gflops=20.34
nid005143:0 out: throughput/device/tokens_per_second=174.8
nid005143:0 out: throughput/device/batches_per_second=0.0427
nid005143:0 out: System/Peak GPU Memory (MB)=42,394
nid005143:0 out: 2025-01-20 10:29:11.183 nid005143:0 olmo.train:1259 INFO Saving checkpoint...
nid005143:0 out: 2025-01-20 10:29:11.786 nid005143:0 olmo.checkpoint:1922 INFO Saving model and optim state...
slurmstepd: error: Detected 1 oom_kill event in StepId=9178265.0. Some of the step tasks have been OOM Killed.
srun: error: nid006190: task 277: Out Of Memory
srun: Terminating StepId=9178265.0
slurmstepd: error: Detected 1 oom_kill event in StepId=9178265.0. Some of the step tasks have been OOM Killed.
slurmstepd: error: *** STEP 9178265.0 ON nid005143 CANCELLED AT 2025-01-20T10:34:05 ***
Have you encountered this issue before? Would you have any suggestions to get around this?
Note that in another setup (32 nodes, global_train_batch_size=256), checkpointing works as expected.
My job script (test-mling.sh) and config file (test-mling.yaml) are appended below.
test-mling.sh
#!/bin/bash
#SBATCH --job-name=test-mling
#SBATCH --account=project_PROJECT_ID
#SBATCH --output=/scratch/project_PROJECT_ID/logs/%j.log
#SBATCH --nodes=64                # Total number of nodes
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8         # Allocate one gpu per MPI rank
#SBATCH --cpus-per-task=7
#SBATCH --exclusive=user
#SBATCH --hint=nomultithread
#SBATCH --mem=480G                # max on lumi-g
#SBATCH --time=02:00:00
#SBATCH --partition=standard-g
module load LUMI/22.08 partition/G
export OLMO_CONTAINER=lumi-flash_latest.sif
export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export MPICH_GPU_SUPPORT_ENABLED=1
export NCCL_SOCKET_IFNAME=hsn
export NCCL_NET_GDR_LEVEL=3
export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
export CXI_FORK_SAFE=1
export CXI_FORK_SAFE_HP=1
export FI_CXI_DISABLE_CQ_HUGETLB=1
# We need to set this to avoid "Cassini Event Queue overflow detected." errors.
export FI_CXI_DEFAULT_CQ_SIZE=131072
#export NCCL_DEBUG=INFO
export PYTHONPATH=.:${PYTHONPATH}
export ROCM_PATH=/opt/rocm
export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.0/lib64
# Try playing with max_split_size_mb if you run into OOM errors.
export PYTORCH_HIP_ALLOC_CONF='max_split_size_mb:512'
#export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True
srun \
--cpus-per-task=$SLURM_CPUS_PER_TASK \
--distribution=block:block \
--kill-on-bad-exit \
scripts/run_with_environment.sh \
singularity exec \
-B"$PROJECT_DIR:$PROJECT_DIR" \
-B"$SCRATCH_DIR:$SCRATCH_DIR" \
-B"$FLASH_DIR:$FLASH_DIR" \
-B"$USER_DIR:$USER_DIR" \
-B /opt/cray:/opt/cray \
-B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
-B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
$PROJECT_DIR/containers/$OLMO_CONTAINER \
python scripts/train.py configs/test-mling.yaml --run_name=${SLURM_JOB_ID} ${@}
This is a little puzzling, but I noticed this line: SBATCH --mem=480G. Do you need that? Does it restrict you? I don't think we set that when we ran on LUMI.
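If you want to rule that limit out, a minimal sketch of the change (assuming LUMI's standard-g partition honours the usual Slurm convention that --mem=0 requests all the memory available on each node):

# Hypothetical tweak to test-mling.sh: drop the explicit 480G cap on host memory
# and ask Slurm for all the memory the node has instead.
##SBATCH --mem=480G      # previous setting
#SBATCH --mem=0          # 0 = request all available memory on each allocated node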
You might already know this, but you're running out of CPU memory, not GPU memory. It probably has something to do with all 8 ranks per node assembling the model, or too much of the model, in host memory at the same time; you can't fit eight 7B models in CPU memory at once. You could try a different checkpointing scheme. We had good luck with local checkpointing on LUMI before, but that was before a bunch of software updates.
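For scale (assuming fp32 weights and Adam optimizer state, which may not match this exact setup): a 7B-parameter model is roughly 28 GB of parameters plus about 56 GB of optimizer state, so eight ranks gathering that at once would need far more host memory than a 480 GB node provides. A minimal sketch of the override suggested above, assuming the config follows the public OLMo TrainConfig (the sharded_checkpointer field and its local value are taken from that repo and may differ in your branch):

# Hypothetical command-line override in test-mling.sh: switch to the local sharded
# checkpointer so each rank writes only its own shard and never has to assemble
# the full model and optimizer state in host memory.
python scripts/train.py configs/test-mling.yaml \
    --run_name=${SLURM_JOB_ID} \
    --sharded_checkpointer=local ${@}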