fused_linear_cross_entropy: Move float32 cast into kernel #238

Open · wants to merge 4 commits into base: main
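
In brief, as read from the diffs below: fused_linear_cross_entropy_forward previously upcast each logits_chunk to float32 on the host before launching the Triton kernel and cast it back to the original dtype afterwards; this change removes that host-side round trip and instead casts every tl.load to float32 inside liger_cross_entropy_kernel, so the softmax and loss math still run in float32 while the logits buffer keeps its original (e.g. bfloat16) dtype. The shape of the change, paraphrased as a comment sketch (simplified, placeholder names such as n_rows, not the exact code):

    # Before (host side, in fused_linear_cross_entropy_forward):
    #     logits_chunk = logits_chunk.float()          # upcast the whole chunk
    #     liger_cross_entropy_kernel[(n_rows,)](...)   # kernel reads/writes float32
    #     logits_chunk = logits_chunk.to(dtype)        # downcast afterwards
    #
    # After (device side, in liger_cross_entropy_kernel):
    #     X_block = tl.load(X_ptr + X_offsets, ...).cast(tl.float32)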
20 changes: 13 additions & 7 deletions src/liger_kernel/ops/cross_entropy.py
@@ -61,8 +61,8 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
+    ori_X_y = tl.load(X_ptr + y).cast(
+        tl.float32
     )  # we need to store the original value of X_y for the loss calculation
 
     # Label smoothing is a general case of normal cross entropy
@@ -73,8 +73,11 @@
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
@@ -94,8 +97,11 @@
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
@@ -124,7 +130,7 @@ def liger_cross_entropy_kernel(
         loss = loss * (1 - label_smoothing) + smooth_loss
 
     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
+    X_y = tl.load(X_ptr + y).cast(tl.float32)
     X_y += -(1 - label_smoothing) / (n_non_ignore)
 
     tl.store(loss_ptr, loss)
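
For context on the "[Online softmax] first pass: find max + sum" loop above: the kernel keeps a running maximum m and a running sum d of exponentials, rescaling d whenever the maximum grows, so the softmax denominator can be built in a single streaming pass over the row. A minimal PyTorch sketch of that recurrence (illustrative only; the block size and final check are arbitrary choices, and this is not the Triton kernel itself):

    import math

    import torch

    def online_max_and_sum(x: torch.Tensor, block_size: int = 4096):
        # Running max (m) and running sum of exp(x - m) (d), updated block by block.
        m, d = float("-inf"), 0.0
        for i in range(0, x.numel(), block_size):
            block = x[i : i + block_size].float()
            m_new = max(m, block.max().item())
            # Rescale the old sum to the new max before adding this block's contribution.
            d = d * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
            m = m_new
        return m, d

    x = torch.randn(128256)
    m, d = online_max_and_sum(x)
    assert torch.allclose(torch.softmax(x, dim=0), torch.exp(x - m) / d, atol=1e-6)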
13 changes: 0 additions & 13 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -15,9 +15,6 @@
 def fused_linear_cross_entropy_forward(
     _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
 ):
-    dtype = (
-        torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
-    )
     device = _input.device
 
     # inputs have shape: BT x H
@@ -65,9 +62,6 @@ def fused_linear_cross_entropy_forward(
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         n_non_ignore = (target_chunk != ignore_index).sum().item()
 
-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
-
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
         target_chunk = target_chunk.contiguous()
@@ -88,13 +82,6 @@ def fused_linear_cross_entropy_forward(
             num_warps=32,
         )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
         # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
         # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
         # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
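
The comment block deleted above spelled out why the logits math runs in float32 in the first place: following the HuggingFace Llama implementation, the forward and backward w.r.t. the logits are done in fp32 for numerical stability, especially with a huge vocabulary (the test below uses 128256 classes). A tiny illustration of the underlying precision gap (generic dtype behaviour, not taken from the PR):

    import torch

    # bfloat16 stores 7 explicit mantissa bits, so representable values near 1.0
    # are about 2**-7 apart; float32 resolves differences roughly 65,000x smaller.
    print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
    print(torch.finfo(torch.float32).eps)   # ~1.19e-07

    x = torch.tensor(1.0, dtype=torch.bfloat16)
    print((x + 0.001 == x).item())                  # True: the small update is rounded away in bfloat16
    print((x.float() + 0.001 == x.float()).item())  # False: float32 keeps it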
63 changes: 62 additions & 1 deletion test/transformers/test_cross_entropy.py
@@ -4,7 +4,10 @@
 import torch
 from torch.nn import CrossEntropyLoss
 
-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.cross_entropy import (
+    LigerCrossEntropyFunction,
+    liger_cross_entropy_kernel,
+)
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 
@@ -493,3 +496,61 @@ def test_large_no_exception(B, T, V):
     # The large inputs were hitting cuda illegal memory access because of
     # https://github.com/triton-lang/triton/issues/1058
     _full_pass_once(B, T, V)
+
+
+def test_float32_internal():
+    """
+    This test validates that the internal softmax calculations occur in float32,
+    even if the input dtype is bfloat16.
+    """
+    # Set up test parameters
+    batch_size = 4
+    n_cols = 128256
+    n_non_ignore = batch_size
+    ignore_index = -100
+    label_smoothing = 0.0
+    BLOCK_SIZE = 32768
+
+    # Initialize input tensors
+    X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda")
+    Y = torch.randint(0, n_cols, (batch_size,), device="cuda")
+
+    # Run kernel for bfloat16
+    X_bf16 = X_init.clone()
+    loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
+    liger_cross_entropy_kernel[(batch_size,)](
+        X_ptr=X_bf16,
+        X_stride=X_bf16.stride(-2),
+        Y_ptr=Y,
+        Y_stride=Y.stride(-1),
+        loss_ptr=loss_bf16,
+        loss_stride=loss_bf16.stride(-1),
+        n_cols=n_cols,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=32,
+    )
+
+    # Run kernel for float32
+    X_fp32 = X_init.float()
+    loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
+    liger_cross_entropy_kernel[(batch_size,)](
+        X_ptr=X_fp32,
+        X_stride=X_fp32.stride(-2),
+        Y_ptr=Y,
+        Y_stride=Y.stride(-1),
+        loss_ptr=loss_fp32,
+        loss_stride=loss_fp32.stride(-1),
+        n_cols=n_cols,
+        n_non_ignore=n_non_ignore,
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=32,
+    )
+
+    # The results should be **exactly** equal after downcasting
+    assert (X_bf16 == X_fp32.bfloat16()).all()
+    assert (loss_bf16 == loss_fp32).all()
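
A note on why test_float32_internal can assert bit-exact equality rather than use a tolerance: casting bfloat16 to float32 is lossless, so if the kernel really performs all of its arithmetic in float32, the bf16-input run and the fp32-input run operate on identical float32 values and must produce identical results; the only rounding happens on the store back into the bfloat16 X buffer, which is why the test compares X_bf16 against X_fp32.bfloat16(). A hedged pure-PyTorch analogue of that property (not part of the PR):

    import torch

    x_bf16 = torch.randn(4, 128256, dtype=torch.bfloat16)
    x_fp32 = x_bf16.float()  # bf16 -> fp32 is exact, so both tensors hold the same values

    # Stand-in for a kernel that upcasts internally, mirroring tl.load(...).cast(tl.float32)
    out_from_bf16 = torch.softmax(x_bf16.float(), dim=-1).bfloat16()
    out_from_fp32 = torch.softmax(x_fp32, dim=-1).bfloat16()

    assert (out_from_bf16 == out_from_fp32).all()  # bit-identical, no tolerance needed

If the name is kept as-is, the new test should be runnable on its own with something like pytest test/transformers/test_cross_entropy.py::test_float32_internal on a CUDA machine.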