#15732: add matmul block h/w parameter processing (#15938)
### Ticket
Link to GitHub issue #15732

### Problem description
- Circular buffers (CBs) don't fit into L1 for some matmul configurations

### What's changed
- adjust block h/w so that the CBs fit into L1, at least when the output is
not sharded (see the sketch below)
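
A minimal sketch of the idea (hypothetical names and sizes; the actual implementation is in the matmul op's parameter-processing code, not shown here): estimate the circular-buffer footprint for a candidate block shape and halve the block dimensions until it fits in the per-core L1 budget.

```
# Hypothetical sketch of the block h/w adjustment; the tile size, L1 budget,
# and CB footprint formula are illustrative assumptions, not tt-metal internals.
TILE_BYTES = 32 * 32 * 2  # one bfloat16 tile
L1_BUDGET = 1024 * 1024   # assumed per-core L1 available for matmul CBs


def cb_bytes(block_h: int, block_w: int, in0_block_w: int) -> int:
    """Rough footprint: double-buffered input CBs plus one output CB."""
    in0_tiles = 2 * block_h * in0_block_w
    in1_tiles = 2 * in0_block_w * block_w
    out_tiles = block_h * block_w
    return (in0_tiles + in1_tiles + out_tiles) * TILE_BYTES


def fit_blocks(block_h: int, block_w: int, in0_block_w: int) -> tuple[int, int]:
    """Halve the larger block dimension until the CBs fit into L1."""
    while cb_bytes(block_h, block_w, in0_block_w) > L1_BUDGET:
        if block_h >= block_w and block_h > 1:
            block_h //= 2
        elif block_w > 1:
            block_w //= 2
        else:
            break  # cannot shrink further; leave it to validation to report
    return block_h, block_w
```

This sketch ignores divisibility constraints (real block dims must divide the per-core output shape) and, per the change above, only applies when the output is not sharded.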

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12583509804
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12576215442
- [x] Model regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12598782714
- [x] Device performance regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12601098604/job/35121724725
(has the same mnist issue as main)

https://github.com/tenstorrent/tt-metal/actions/runs/12598829248/job/35114714224
will probably need an update because the numbers still differ; updated the
expected perf and got a clean run:
https://github.com/tenstorrent/tt-metal/actions/runs/12642511513
```
2025-01-03 17:33:56.936 | ERROR    | models.perf.device_perf_utils:check_device_perf_results:135 - ttnn_functional_convnet_mnist1_ AVG DEVICE KERNEL SAMPLES/S is too slow with 2160.8073, min expected 2716.485.
```
vs
```
2025-01-03 15:24:13.851 | ERROR    | models.perf.device_perf_utils:check_device_perf_results:135 - ttnn_functional_convnet_mnist1_ AVG DEVICE KERNEL SAMPLES/S is too slow with 1938.6871, min expected 2716.485.
```
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [x] New/Existing tests provide coverage for changes

Also ran some T3k tests:
- Frequent: https://github.com/tenstorrent/tt-metal/actions/runs/12329955947
- Unit: https://github.com/tenstorrent/tt-metal/actions/runs/12329950882
- Demo: https://github.com/tenstorrent/tt-metal/actions/runs/12565631818
  needs the expected text updated; besides that, the same regressions as main
  (https://github.com/tenstorrent/tt-metal/actions/runs/12565308545). A new run
  has only the same issues as main:
  https://github.com/tenstorrent/tt-metal/actions/runs/12638947085/job/35217119707

And more pipelines:
- T3k perplexity:
  https://github.com/tenstorrent/tt-metal/actions/runs/12601369497
- Single card demo: updated the expected text since the program configs
  changed; now passes:
  https://github.com/tenstorrent/tt-metal/actions/runs/12603042794
- T3k model perf:
  https://github.com/tenstorrent/tt-metal/actions/runs/12601372071
  (expected perf regressions that also exist in main)
bbradelTT authored Jan 7, 2025
1 parent bf94433 commit 3787a8a
Showing 15 changed files with 438 additions and 128 deletions.
2 changes: 1 addition & 1 deletion models/demos/bert_tiny/tests/test_performance.py
@@ -121,7 +121,7 @@ def test_perf_device_bare_metal(batch_size, expected_perf):
     if is_wormhole_b0():
         expected_perf = 3990.0
     else:
-        expected_perf = 3476.55
+        expected_perf = 3460.0
 
     command = f"pytest tests/ttnn/integration_tests/bert_tiny/test_bert_tiny.py::test_bert_for_question_answering"
     cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
2 changes: 1 addition & 1 deletion models/demos/convnet_mnist/tests/test_performance.py
@@ -119,7 +119,7 @@ def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf):
     subdir = "ttnn_convnet_mnist"
     num_iterations = 1
     margin = 0.03
-    expected_perf = 1800 if is_grayskull() else 2800.5
+    expected_perf = 2430 if is_grayskull() else 3358.0
 
     command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py"
     cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
2 changes: 1 addition & 1 deletion models/demos/distilbert/tests/test_perf_distilbert.py
@@ -154,7 +154,7 @@ def test_distilbert_perf_device(batch_size, test, reset_seeds):
     if is_grayskull():
         expected_perf = 57.3
     elif is_wormhole_b0():
-        expected_perf = 103.884
+        expected_perf = 95.5
 
     command = f"pytest tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert.py::test_distilbert_for_question_answering[sequence_size=768-batch_size=8-model_name=distilbert-base-uncased-distilled-squad]"
     cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
@@ -61,7 +61,7 @@ class DeviceSetup(Enum):
 PREFILL_CONFIG_TO_PCC = {
     DeviceSetup.GRAYSKULL: {
         "BFLOAT16-DRAM": {
-            128: (0.88, 0.97, 0.88),
+            128: (0.87, 0.97, 0.88),
             256: (0.92, 0.97, 0.88),
         },
         "BFLOAT16-L1": {
2 changes: 1 addition & 1 deletion models/demos/t3000/falcon7b/expected_greedy_output.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion models/demos/wormhole/bert_tiny/tests/test_performance.py
@@ -122,7 +122,7 @@ def test_perf_bert_tiny(
 @pytest.mark.parametrize(
     "batch_size, expected_perf",
     [
-        (16, 6946.78),
+        (16, 6292.78),
     ],
 )
 def test_perf_device_bare_metal(batch_size, expected_perf):
@@ -175,7 +175,7 @@ def test_distilbert_perf_device(batch_size, test, reset_seeds):
     margin = 0.03
     num_iterations = 1
 
-    expected_perf = 224
+    expected_perf = 245
     if ttnn.GetNumAvailableDevices() == 2:
         batch_size = batch_size * 2
 
2 changes: 1 addition & 1 deletion models/demos/wormhole/falcon7b/expected_greedy_output.json

Large diffs are not rendered by default.

72 changes: 59 additions & 13 deletions tests/sweep_framework/sweeps/matmul/short/matmul_traces.py
@@ -115,8 +115,9 @@
             (9, 768, 768, 640),
             (920, 256, 256, 256),
         ],
-        "core_grid": [False],
-        "dtype": [ttnn.float32],
+        "core_grid": [True, False],
+        "dtype": [ttnn.float32, ttnn.bfloat16],
+        "test_bias": [True, False],
     },
     "gpt": {
         "params": [
@@ -352,7 +353,8 @@
             (64, 12, 64, 1024, 64, 12, 1024, 64),
         ],
         "core_grid": [True, False],
-        "dtype": [ttnn.float32],
+        "dtype": [ttnn.float32, ttnn.bfloat16],
+        "test_bias": [True, False],
     },
     "forge": {
         "params": [
@@ -2607,30 +2609,70 @@
                 256,
             ),
         ],
-        "core_grid": [False],
+        "core_grid": [True, False],
         "dtype": [ttnn.float32, ttnn.bfloat16],
+        "test_bias": [True, False],
     },
 }
 
 
-def run_matmul(device, params, core_grid, dtype):
+# Invalidate vector is called during the generation phase where each vector will be passed in.
+# If invalidated, the vector will still be stored but will be skipped.
+# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid.
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
+    # Cannot have bias and batch. If only four params, two input tensors have a dimension of 2 and cannot be batched.
+    if test_vector["test_bias"] and len(test_vector["params"]) > 4 and test_vector["params"][0] > 1:
+        return True, "Batched input not supported when bias exists"
+    return False, None
+
+
+def run_matmul(device, params, core_grid, dtype, test_bias):
+    # Cannot have bias and batch. If only four params, two input tensors have a dimension of 2 and cannot be batched.
+    if test_bias and len(params) > 4 and params[0] > 1:
+        pytest.skip("Batched input not supported when bias exists")
     if core_grid == False:
         grid = None
     else:
         grid = device.core_grid
+    if dtype == ttnn.bfloat16:
+        compute_kernel_config = None
+    else:
+        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi2,
+            math_approx_mode=False,
+            fp32_dest_acc_en=True,
+            packer_l1_acc=True,
+        )
 
     count = len(params)
     half = int(count / 2)
     shape0 = params[0:half]
     shape1 = params[half:count]
+    shape2 = [shape1[-1]]
     torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32)
     torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32)
+    torch_input_tensor2 = torch.rand(shape2, dtype=torch.float32)
     torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1)
+    if test_bias:
+        torch_output_tensor += torch_input_tensor2
 
     input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
     input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor2 = ttnn.from_torch(torch_input_tensor2, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
     start_time = start_measuring_time()
-    output_tensor = ttnn.matmul(input_tensor0, input_tensor1, core_grid=grid)
+    if test_bias:
+        output_tensor = ttnn.linear(
+            input_tensor0,
+            input_tensor1,
+            core_grid=grid,
+            compute_kernel_config=compute_kernel_config,
+            bias=input_tensor2,
+        )
+    else:
+        output_tensor = ttnn.matmul(
+            input_tensor0, input_tensor1, core_grid=grid, compute_kernel_config=compute_kernel_config
+        )
     output_tensor = ttnn.to_torch(output_tensor)
     e2e_perf = stop_measuring_time(start_time)
     expected_pcc = 0.99
@@ -2640,29 +2682,33 @@ def run_matmul(device, params, core_grid, dtype):
@pytest.mark.parametrize("params", parameters["pytorch"]["params"])
@pytest.mark.parametrize("core_grid", parameters["pytorch"]["core_grid"])
@pytest.mark.parametrize("dtype", parameters["pytorch"]["dtype"])
def test_pytorch(device, params, core_grid, dtype):
run_matmul(device, params, core_grid, dtype)
@pytest.mark.parametrize("test_bias", parameters["pytorch"]["test_bias"])
def test_pytorch(device, params, core_grid, dtype, test_bias):
run_matmul(device, params, core_grid, dtype, test_bias)


@pytest.mark.parametrize("params", parameters["gpt"]["params"])
@pytest.mark.parametrize("core_grid", parameters["gpt"]["core_grid"])
@pytest.mark.parametrize("dtype", parameters["gpt"]["dtype"])
def test_gpt(device, params, core_grid, dtype):
run_matmul(device, params, core_grid, dtype)
@pytest.mark.parametrize("test_bias", parameters["gpt"]["test_bias"])
def test_gpt(device, params, core_grid, dtype, test_bias):
run_matmul(device, params, core_grid, dtype, test_bias)


@pytest.mark.parametrize("params", parameters["forge"]["params"])
@pytest.mark.parametrize("core_grid", parameters["forge"]["core_grid"])
@pytest.mark.parametrize("dtype", parameters["forge"]["dtype"])
def test_forge(device, params, core_grid, dtype):
run_matmul(device, params, core_grid, dtype)
@pytest.mark.parametrize("test_bias", parameters["forge"]["test_bias"])
def test_forge(device, params, core_grid, dtype, test_bias):
run_matmul(device, params, core_grid, dtype, test_bias)


 def run(
     params,
     core_grid,
     dtype,
+    test_bias,
     *,
     device,
 ) -> list:
-    return run_matmul(device, params, core_grid, dtype)
+    return run_matmul(device, params, core_grid, dtype, test_bias)
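
For reference, the bias path that `run_matmul` exercises is just a matmul plus a broadcast add over the last dimension; a minimal self-contained PyTorch check of the same reference computation (shapes chosen arbitrarily):

```
import torch

a = torch.rand(64, 32)  # activations
b = torch.rand(32, 16)  # weights
bias = torch.rand(16)   # broadcast over rows, matching shape2 = [shape1[-1]] above

out = torch.matmul(a, b) + bias  # the reference ttnn.linear(..., bias=...) is compared against
assert out.shape == (64, 16)
```

As `invalidate_vector` encodes, the ttnn bias path itself does not support batched inputs, so the sweep only exercises bias with unbatched shapes.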
2 changes: 1 addition & 1 deletion tests/ttnn/integration_tests/bert/test_performance.py
@@ -59,7 +59,7 @@ def preprocess_inputs(
 def get_expected_times(bert):
     return {
         ttnn_bert: (0.1, 0.1),
-        ttnn_optimized_bert: (5.55, 0.07),
+        ttnn_optimized_bert: (5.55, 0.11),
         ttnn_optimized_sharded_bert: (5.7, 0.07),
     }[bert]

2 changes: 1 addition & 1 deletion tests/ttnn/integration_tests/whisper/test_performance.py
@@ -18,7 +18,7 @@
 def get_expected_times(functional_whisper):
     return {
         ttnn_functional_whisper: (11.7, 4.16),
-        ttnn_optimized_functional_whisper: (1.57, 1.35),
+        ttnn_optimized_functional_whisper: (1.65, 1.35),
     }[functional_whisper]


33 changes: 32 additions & 1 deletion tests/ttnn/unit_tests/operations/test_linear.py
@@ -7,7 +7,7 @@
 import ttnn
 
 from tests.ttnn.utils_for_testing import assert_with_pcc
-from models.utility_functions import torch_random, is_wormhole_b0
+from models.utility_functions import torch_random, is_wormhole_b0, skip_for_grayskull
 
 
 @pytest.mark.parametrize("batch_sizes", [(1,)])
@@ -310,3 +310,34 @@ def test_linear_by_passing_in_1D_systolic_array_program_config_and_optional_outo
     assert_with_pcc(torch_output_tensor, output_tensor, 0.997)
     assert_with_pcc(torch_output_tensor, optional_output_tensor, 0.997)
     assert_with_pcc(optional_output_tensor, output_tensor, 0.997)
+
+
+@skip_for_grayskull()
+def test_linear_with_fp32_dest_acc_and_bias(device):
+    torch.manual_seed(0)
+    torch_input_tensor_a = torch.rand([64, 1, 256, 384])
+    torch_input_tensor_b = torch.rand([1, 1, 1152, 384])
+    torch_input_tensor_c = torch.rand([1, 1, 1, 1152])
+    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+        math_fidelity=ttnn.MathFidelity.HiFi2,
+        math_approx_mode=False,
+        fp32_dest_acc_en=True,
+        packer_l1_acc=True,
+    )
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch.transpose(torch_input_tensor_b, -1, -2))
+    torch_output_tensor += torch_input_tensor_c
+
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    input_tensor_c = ttnn.from_torch(torch_input_tensor_c, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+
+    output1 = ttnn.linear(
+        input_tensor_a,
+        input_tensor_b,
+        bias=input_tensor_c,
+        compute_kernel_config=compute_kernel_config,
+        core_grid=ttnn.CoreGrid(y=8, x=7),
+        transpose_b=True,
+    )
+    output_tensor = ttnn.to_torch(output1)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
26 changes: 13 additions & 13 deletions tt-train/tests/model/gpt2s_test.cpp
@@ -45,27 +45,27 @@ TEST_F(GPT2SBatch64Test, Matmul) {
         {{{64, 12, 1024, 1024}, {64, 12, 1024, 64}, false, false}, ExpectedResult::OK},
         {{{768, 65536}, {65536, 96}, false, false}, ExpectedResult::OK},
         {{{65536, 768}, {65536, 96}, true, false}, ExpectedResult::OK},
-        {{{65536, 96}, {1, 1, 96, 768}, false, false}, ExpectedResult::ERROR},
-        {{{65536, 96}, {1, 1, 768, 96}, false, true}, ExpectedResult::ERROR},
+        {{{65536, 96}, {1, 1, 96, 768}, false, false}, ExpectedResult::OK},
+        {{{65536, 96}, {1, 1, 768, 96}, false, true}, ExpectedResult::OK},
         {{{3072, 65536}, {65536, 768}, false, false}, ExpectedResult::OK},
         {{{65536, 3072}, {65536, 768}, true, false}, ExpectedResult::OK},
-        {{{65536, 768}, {1, 1, 768, 3072}, false, false}, ExpectedResult::ERROR},
-        {{{65536, 768}, {1, 1, 3072, 768}, false, true}, ExpectedResult::ERROR},
+        {{{65536, 768}, {1, 1, 768, 3072}, false, false}, ExpectedResult::OK},
+        {{{65536, 768}, {1, 1, 3072, 768}, false, true}, ExpectedResult::OK},
         {{{768, 65536}, {65536, 3072}, false, false}, ExpectedResult::OK},
         {{{65536, 768}, {65536, 3072}, true, false}, ExpectedResult::OK},
-        {{{65536, 3072}, {1, 1, 3072, 768}, false, false}, ExpectedResult::ERROR},
-        {{{65536, 3072}, {1, 1, 768, 3072}, false, true}, ExpectedResult::ERROR},
-        {{{65536, 3072}, {3072, 768}, false, false}, ExpectedResult::ERROR},
-        {{{65536, 3072}, {768, 3072}, false, true}, ExpectedResult::ERROR},
+        {{{65536, 3072}, {1, 1, 3072, 768}, false, false}, ExpectedResult::OK},
+        {{{65536, 3072}, {1, 1, 768, 3072}, false, true}, ExpectedResult::OK},
+        {{{65536, 3072}, {3072, 768}, false, false}, ExpectedResult::OK},
+        {{{65536, 3072}, {768, 3072}, false, true}, ExpectedResult::OK},
         {{{768, 65536}, {65536, 768}, false, false}, ExpectedResult::OK},
         {{{65536, 768}, {65536, 768}, true, false}, ExpectedResult::OK},
-        {{{65536, 768}, {1, 1, 768, 768}, false, false}, ExpectedResult::ERROR},
-        {{{768, 65536}, {1, 1, 768, 768}, true, false}, ExpectedResult::ERROR},
+        {{{65536, 768}, {1, 1, 768, 768}, false, false}, ExpectedResult::OK},
+        {{{768, 65536}, {1, 1, 768, 768}, true, false}, ExpectedResult::OK},
         {{{768, 65536}, {65536, 2304}, false, false}, ExpectedResult::OK},
         {{{65536, 768}, {65536, 2304}, true, false}, ExpectedResult::OK},
-        {{{65536, 768}, {768, 50257}, false, false}, ExpectedResult::ERROR},
-        {{{65536, 768}, {50304, 768}, false, true}, ExpectedResult::ERROR},
-        {{{65536, 50304}, {50304, 768}, false, false}, ExpectedResult::ERROR},
+        {{{65536, 768}, {768, 50257}, false, false}, ExpectedResult::OK},
+        {{{65536, 768}, {50304, 768}, false, true}, ExpectedResult::OK},
+        {{{65536, 50304}, {50304, 768}, false, false}, ExpectedResult::OK},
     };
 
     auto run_matmul = [](auto& a, auto& b, bool transpose_a, bool transpose_b) {
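
Rough arithmetic for why large block shapes used to blow past L1 and why halving them helps, applying the same footprint estimate as the sketch in "What's changed" above (tile size, CB formula, and the per-core L1 figure are illustrative assumptions, not tt-metal internals):

```
# Back-of-the-envelope CB sizing; all numbers are illustrative assumptions.
TILE_BYTES = 32 * 32 * 2  # one bfloat16 tile
L1_BYTES = 1464 * 1024    # approximate L1 per Tensix core on Wormhole

def cb_footprint(block_h, block_w, in0_block_w):
    # double-buffered input CBs plus one output CB, counted in tiles
    tiles = 2 * block_h * in0_block_w + 2 * in0_block_w * block_w + block_h * block_w
    return tiles * TILE_BYTES

print(cb_footprint(24, 24, 4) / 1024)  # 1920.0 KB: does not fit in L1
print(cb_footprint(12, 12, 4) / 1024)  # 672.0 KB: fits after halving both dims
```

With the adjustment in place, shapes like {65536, 3072} x {3072, 768} now run instead of erroring out, which is what the flipped `ExpectedResult` values above assert.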