
Jcaip/llm bsr #1601

Draft: wants to merge 16 commits into main

Conversation

@jcaip (Contributor) commented Jan 22, 2025

This PR promotes Supermask and block sparsity from prototype -> torchao.sparsity

It adds a new public API for SupermaskLinear, which users can use to apply Supermask to their models with:

sparsify_(model, lambda x: SupermaskLinear.to_dense(x, sparsity_level=0.9))
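For readers unfamiliar with Supermask, the core idea can be sketched in plain Python. This is a hedged emulation, not the torchao implementation; apply_supermask, weights, and scores are illustrative names. Supermask learns per-weight scores and keeps only the top-scoring fraction of weights, zeroing the rest:

```python
def apply_supermask(weights, scores, sparsity_level):
    """Zero out the lowest-scoring `sparsity_level` fraction of weights.

    `scores` stands in for the learned per-weight scores that Supermask
    trains; at a 0.9 sparsity level, 90% of weights would be zeroed.
    """
    n = len(weights)
    n_prune = int(n * sparsity_level)
    # indices of the n_prune lowest-scoring weights
    pruned = set(sorted(range(n), key=lambda i: scores[i])[:n_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]

weights = [0.5, -1.2, 0.3, 2.0, -0.1]
scores  = [0.9,  0.1, 0.4, 0.8,  0.2]
sparse = apply_supermask(weights, scores, sparsity_level=0.4)
# the two lowest-scoring weights (indices 1 and 4) are zeroed
```

In the real API the resulting sparse weights are then converted to a dense (or BSR) tensor via SupermaskLinear.to_dense, as shown in the sparsify_ call above.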

I have also updated the existing Supermask SAM testing code to use this new API.

It also ports over the Triton addmm kernels from core, to let us modify them as needed. I've added padding support to the Triton kernel, which was a 4 tok/s improvement (214 -> 218).

  • Adds padding to BSR
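The padding idea can be sketched independently of Triton. This is a plain-Python illustration under the assumption that blocks are padded up to the 16-wide tile the kernel currently hardcodes; pad_to_multiple and pad_block are hypothetical helpers, not code from this PR. Rounding the block's column dimension up to the tile width and zero-filling means tile loads never run off the end of a block:

```python
def pad_to_multiple(size, align=16):
    """Round size up to the nearest multiple of align (the tile width)."""
    return ((size + align - 1) // align) * align

def pad_block(block, align=16):
    """Zero-pad a 2-D block's column dimension up to a multiple of align."""
    cols = len(block[0])
    padded_cols = pad_to_multiple(cols, align)
    return [row + [0.0] * (padded_cols - cols) for row in block]

block = [[1.0] * 10, [2.0] * 10]   # a 2x10 block, narrower than one 16-wide tile
padded = pad_block(block)          # now 2x16; the last 6 columns are zero
```

Because the padding is zeros, it contributes nothing to the addmm accumulation, so results are unchanged while loads stay tile-aligned.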

Benchmarking on an H100 with the following commands:

export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8192 --profile baseline_prefill
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8192 --sparsity bsr --profile bsr_prefill
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --profile baseline
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --sparsity bsr --profile bsr

yields a 134 -> 218 tok/s improvement on LLM decoding.

pytorch-bot bot commented Jan 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1601

Note: Links to docs will display an error until the docs builds have been completed.

❌ 10 New Failures

As of commit b414b49 with merge base 11333ba (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 22, 2025
offsets = tl.arange(0, 16)[None, :]
dense_block = tl.load(
dense_block_ptrs + dense_tiled_row_stride * dense_row_idx,
mask=offsets < BLOCKSIZE_COL,
@jcaip (author) commented on the diff:
cc @cpuhrsch masking added in here for the padding

row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)

if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0:
This is the padding logic (need to do this properly instead of hardcoding 16)
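The effect of the mask in the tl.load above can be emulated in plain Python. This is a hedged sketch, not the kernel itself; masked_load is an illustrative stand-in for Triton's tl.load(ptr, mask=..., other=0.0), where lanes whose mask is False return the fill value instead of reading memory:

```python
def masked_load(buf, offsets, mask, other=0.0):
    """Emulate a masked tile load: lanes where mask is False yield `other`
    instead of reading (possibly out-of-bounds) memory."""
    return [buf[o] if m else other for o, m in zip(offsets, mask)]

BLOCKSIZE_COL = 10           # actual (unpadded) block width
TILE = 16                    # the hardcoded tile width the comment wants generalized
offsets = list(range(TILE))
mask = [o < BLOCKSIZE_COL for o in offsets]   # mirrors offsets < BLOCKSIZE_COL
row = [float(i) for i in range(BLOCKSIZE_COL)]
loaded = masked_load(row, offsets, mask)
# lanes 10..15 are filled with 0.0 rather than reading past the block
```

Generalizing the hardcoded 16 would mean deriving the tile width (and hence the mask) from the kernel's launch parameters rather than a constant.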



@implements(aten.sum.dim_IntList)
def block_sparse_sum(func, types, args, kwargs):
cc @cpuhrsch This computes the sum correctly for the fast-path reduction, but it doesn't work with torch.compile because of L300, temp_sum = bsr.values()[start:stop], which errors out on data-dependent control flow.

I think we can instead add a new kernel to the bsr_dense_addmm implementation to handle the fast path there, and rewrite this in Triton.
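For context, the reduction being described can be sketched in plain Python. This is a hedged illustration of the BSR layout, not torchao code; bsr_row_sums is a hypothetical helper. values[crow_indices[i]:crow_indices[i+1]] holds block-row i's blocks, and that slice is exactly the data-dependent step torch.compile rejects:

```python
def bsr_row_sums(crow_indices, values):
    """Sum all elements of each block-row of a BSR matrix.

    `crow_indices` has one entry per block-row plus a terminator;
    `values` is a list of 2-D blocks (lists of lists) in row order.
    """
    sums = []
    for i in range(len(crow_indices) - 1):
        start, stop = crow_indices[i], crow_indices[i + 1]
        blocks = values[start:stop]   # the data-dependent slice at issue
        sums.append(sum(sum(sum(row) for row in b) for b in blocks))
    return sums

# block-row 0 has one 2x2 block; block-row 1 has two
crow = [0, 1, 3]
values = [[[1.0, 1.0], [1.0, 1.0]],
          [[2.0, 0.0], [0.0, 2.0]],
          [[1.0, 2.0], [3.0, 4.0]]]
row_sums = bsr_row_sums(crow, values)   # [4.0, 14.0]
```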

A contributor replied:
Right, so here it's useful to view crow_indices and values as a NestedTensor and then use sum from there :) This is possible because values + crow_indices is like values + offsets.
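One way to read this suggestion in plain Python: with (crow_indices, values) treated as (offsets, values), the per-segment sum can be computed from a prefix sum over per-block totals, with no data-dependent slice. This is my interpretation of the NestedTensor-style reduction, not necessarily how it is implemented; segment_sums is an illustrative helper:

```python
def segment_sums(offsets, per_block_sums):
    """Sum contiguous segments of `per_block_sums`, where segment i spans
    indices offsets[i]..offsets[i+1] (the crow_indices/offsets convention)."""
    prefix = [0]
    for s in per_block_sums:
        prefix.append(prefix[-1] + s)
    # segment i's total is a difference of prefix sums: no slicing needed
    return [prefix[offsets[i + 1]] - prefix[offsets[i]]
            for i in range(len(offsets) - 1)]

# per-block totals for the three blocks of the example above
row_sums = segment_sums([0, 1, 3], [4, 4, 10])   # [4, 14]
```

Because every step here is a fixed-shape scan plus an indexed gather, this formulation should be friendlier to torch.compile than slicing values with data-dependent bounds.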
