From 5630d73164feb66f0330f79f2d52203ad462c27b Mon Sep 17 00:00:00 2001 From: Tcc0403 Date: Wed, 16 Oct 2024 06:14:03 +0800 Subject: [PATCH 1/2] Add missing ignore_index tests --- .../test_fused_linear_cross_entropy.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 57e2cf53..23e128e7 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -164,6 +164,99 @@ def test_correctness( ) +@pytest.mark.parametrize( + "B, T, H, V", + [ + # (2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160 + (8, 2048, 4096, 32000), # llama2, mistral + # Comment out to speed up testing + # (4, 2048, 4096, 128256), # llama3 8B + # (4, 1024, 8192, 128256), # llama3 70B + (4, 423, 8192, 32000), # random shape + ], +) +@pytest.mark.parametrize( + "reduction, scalar, dtype, atol, rtol", + [ + ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2), + ("mean", 1.0, torch.float32, 1e-5, 5e-4), + ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), + ("sum", 1.0, torch.float32, 1e-3, 5e-2), + ], +) +@pytest.mark.parametrize("ignore_index", [-100, 42]) +def test_correctness_with_ignore_index( + B, T, H, V, scalar, dtype, bias, ignore_index, reduction, atol, rtol +): + device = "cuda" + torch_lm_head_ce = TorchLMHeadCE( + H=H, + V=V, + bias=bias, + ignore_index=ignore_index, + reduction=reduction, + dtype=dtype, + ).to(device) + liger_lm_head_ce = LigerLMHeadCE( + H=H, + V=V, + bias=bias, + ignore_index=ignore_index, + reduction=reduction, + dtype=dtype, + ).to(device) + + # init the linear in all CEs with the same weights + torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand( + V, device=device, dtype=dtype + ) + + _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index + + output1 = torch_lm_head_ce(_input1, target) + output2 = liger_lm_head_ce(_input2, target) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_ce.lin.weight.grad, + liger_lm_head_ce.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + if bias: + assert_verbose_allclose( + torch_lm_head_ce.lin.bias.grad, + liger_lm_head_ce.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + @pytest.mark.parametrize( "B, T, H, V", [ From 8350b91f6ff787d0d0da7872720b9e799c06c90c Mon Sep 17 00:00:00 2001 From: Tcc0403 Date: Wed, 16 Oct 2024 06:35:30 +0800 Subject: [PATCH 2/2] Fix test script --- .../test_fused_linear_cross_entropy.py | 100 +++--------------- 1 file changed, 14 insertions(+), 86 deletions(-) diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 23e128e7..f048276b 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -100,9 +100,20 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("label_smoothing", [0, 0.1]) +@pytest.mark.parametrize("label_smoothing, ignore_index", [(0.0, -100), (0.1, 42)]) def test_correctness( - B, T, H, V, scalar, dtype, bias, label_smoothing, reduction, atol, rtol + B, + T, + H, + V, + scalar, + dtype, + bias, + label_smoothing, + ignore_index, + reduction, + atol, + rtol, ): device = "cuda" torch_lm_head_ce = TorchLMHeadCE( @@ -110,89 +121,6 @@ def test_correctness( V=V, bias=bias, label_smoothing=label_smoothing, - reduction=reduction, - dtype=dtype, - ).to(device) - liger_lm_head_ce = LigerLMHeadCE( - H=H, - V=V, - bias=bias, - label_smoothing=label_smoothing, - reduction=reduction, - dtype=dtype, - ).to(device) - - # init the linear in all CEs with the same weights - torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand( - V, device=device, dtype=dtype - ) - - _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar - _input1 = _tensor.detach().clone().requires_grad_(True) - _input2 = _tensor.detach().clone().requires_grad_(True) - - target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) - - output1 = torch_lm_head_ce(_input1, target) - output2 = liger_lm_head_ce(_input2, target) - - assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - - output1.backward() - output2.backward() - - assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) - - assert_verbose_allclose( - torch_lm_head_ce.lin.weight.grad, - liger_lm_head_ce.lin.weight.grad, - atol=atol, - rtol=rtol, - ) - - if bias: - assert_verbose_allclose( - torch_lm_head_ce.lin.bias.grad, - liger_lm_head_ce.lin.bias.grad, - atol=atol, - rtol=rtol, - ) - - -@pytest.mark.parametrize( - "B, T, H, V", - [ - # (2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160 - (8, 2048, 4096, 32000), # llama2, mistral - # Comment out to speed up testing - # (4, 2048, 4096, 128256), # llama3 8B - # (4, 1024, 8192, 128256), # llama3 70B - (4, 423, 8192, 32000), # random shape - ], -) -@pytest.mark.parametrize( - "reduction, scalar, dtype, atol, rtol", - [ - ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2), - ("mean", 1.0, torch.float32, 1e-5, 5e-4), - ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), - ("sum", 1.0, torch.float32, 1e-3, 5e-2), - ], -) -@pytest.mark.parametrize("ignore_index", [-100, 42]) -def test_correctness_with_ignore_index( - B, T, H, V, scalar, dtype, bias, ignore_index, reduction, atol, rtol -): - device = "cuda" - torch_lm_head_ce = TorchLMHeadCE( - H=H, - V=V, - bias=bias, ignore_index=ignore_index, reduction=reduction, dtype=dtype, @@ -201,6 +129,7 @@ def test_correctness_with_ignore_index( H=H, V=V, bias=bias, + label_smoothing=label_smoothing, ignore_index=ignore_index, reduction=reduction, dtype=dtype, @@ -221,7 +150,6 @@ def test_correctness_with_ignore_index( _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) - # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( 1, B * T // 2, (1,)