[Sparsity] When sparsifying only Linear layers with Wanda, PerChannelNormObserver() is added to embedding layers, leading to RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long #1133

Open
agrawal-aka opened this issue Oct 22, 2024 · 5 comments


agrawal-aka commented Oct 22, 2024

Hello, I created a test script for DistilBERT inference with the Wanda sparsifier, which I was testing on an AArch64 platform:

import torch
from transformers import BertForSequenceClassification, BertTokenizer, pipeline
from torch.ao.pruning import WeightNormSparsifier
from torch.profiler import profile, record_function, ProfilerActivity
import torch.profiler
from torchao.sparsity.wanda import WandaSparsifier
from torchao.quantization.quant_api import _is_linear
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

torch.manual_seed(100)

sparsifier = WandaSparsifier(
    sparsity_level=0.6
)

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

sparse_config = []
for name, mod in model.named_modules():
    if _is_linear(mod, name):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

print("sparse config:",sparse_config)
sparsifier.prepare(model, sparse_config)
#print(model.distilbert.embeddings)

#Calibration samples - for wanda
calibration_texts = [
    "I love using CPUs for inference.",
    "This is a sample text for calibration.",
    "Calibration is important for pruning accuracy."]


# Tokenize and pass the calibration samples through the model
for text in calibration_texts:
    inputs = tokenizer(text, return_tensors="pt")
    #print(inputs)
    with torch.no_grad():
        model(**inputs)  # Forward pass to collect activation statistics

# Now that activation statistics have been collected, you can proceed with pruning
sparsifier.step()
sparsifier.squash_mask()

# Apply sparsity to linear layers and convert to CSR format
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and "layer" in name:
        # Convert dense weights to CSR format
        module.weight = torch.nn.Parameter(module.weight.to_sparse_csr())


# Set the model to evaluation mode
model.eval()

# Initialize Hugging Face sentiment analysis pipeline with the sparsified model
sentiment_analysis_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)

# Run inference using the sparsified model
input_text = "I really love using PyTorch"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    with profile(with_stack=True,
    profile_memory=True, record_shapes=True) as prof:
        outputs = model(**inputs)
print(prof.key_averages(group_by_input_shape=False).table(sort_by="self_cpu_time_total", row_limit=-1))
print(outputs)
prediction = model.config.id2label[outputs.logits.argmax().item()]

print(f"Predicted sentiment: {prediction}")

This raises the following error:
RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long

The issue occurs only with Wanda: WeightNormSparsifier with the same sparse config does not trigger it, and everything runs fine. I found that the problem comes from PerChannelNormObserver() being attached to the embedding layers by sparsifier.prepare(); the observer internally calls linalg.vector_norm.
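A minimal sketch of why an observer on the embedding fails, assuming the observer ends up calling `torch.linalg.vector_norm` on the module's input: an embedding layer receives integer token ids (int64, i.e. "Long"), which `vector_norm` rejects, while Linear layers receive floating-point activations. The token ids and hidden size below are illustrative values, not taken from the script.

```python
import torch

# Illustrative token ids, shaped like tokenizer output; the tokenizer returns int64 ids.
input_ids = torch.tensor([[101, 1045, 2293, 102]])
print(input_ids.dtype)  # torch.int64

try:
    # This is the kind of call a norm observer would make on the embedding's input.
    torch.linalg.vector_norm(input_ids, dim=-1)
except RuntimeError as err:
    print("embedding input:", err)

# Linear layers see floating-point activations, so the same call succeeds there.
hidden = torch.randn(1, 4, 768)
print(torch.linalg.vector_norm(hidden, dim=-1).shape)  # torch.Size([1, 4])
```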

So, as a workaround, after calling sparsifier.prepare() I reinitialized the embedding layers from the pretrained model before passing in the calibration texts, and the script then ran successfully:

sparsifier.prepare(model, sparse_config)
#added workaround to remove observers from embedding layer to avoid error
model.distilbert.embeddings = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english").distilbert.embeddings

# Print the model to confirm that the observer has been removed from the embeddings block
print(model.distilbert.embeddings)

# Tokenize and pass the calibration samples through the model
#.
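Reloading via `from_pretrained` loads the whole model a second time just to recover the embeddings. A possibly cheaper variant of the same workaround, assuming `prepare()` only attaches observers and does not change the embedding weights, is to snapshot the embeddings with `copy.deepcopy` before `prepare()` and swap them back afterwards. The sketch below uses a hypothetical `TinyModel` and a hypothetical `fake_prepare` that hooks every layer, standing in for DistilBERT and for the reported behaviour; it is not torchao's observer code.

```python
import copy
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in for DistilBERT: embeddings followed by a linear layer."""
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 2)

    def forward(self, input_ids):
        return self.proj(self.embeddings(input_ids))

def norm_hook(module, args):
    # Mimics a norm observer; returns None so the module input is left unchanged.
    torch.linalg.vector_norm(args[0], dim=-1)

def fake_prepare(model):
    # Hypothetical reproduction of the bug: hooks land on every layer,
    # including the embedding, whose input is int64 token ids.
    for module in model.modules():
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.register_forward_pre_hook(norm_hook)

ids = torch.tensor([[1, 2, 3]])

broken = TinyModel()
fake_prepare(broken)
try:
    broken(ids)
except RuntimeError as err:
    print("without workaround:", err)

fixed = TinyModel()
clean_embeddings = copy.deepcopy(fixed.embeddings)  # snapshot taken before prepare
fake_prepare(fixed)
fixed.embeddings = clean_embeddings  # replace the hooked module with the clean copy
with torch.no_grad():
    out = fixed(ids)
print(out.shape)  # torch.Size([1, 3, 2])
```

Like the reload, this only removes the observers from the embeddings block. Other unconfigured modules may still carry observers that happen to receive float inputs.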

Am I missing something in my script? I don't think this is expected behaviour: the sparse config is the same for both WeightNormSparsifier and WandaSparsifier, yet it works with one and not the other.

jcaip self-assigned this Oct 22, 2024
Contributor

jcaip commented Oct 22, 2024

Hi @agrawal-aka, yeah, this looks like unexpected behavior to me. It seems we need to add a check so that observers are applied only to the layers in the config; I think we're currently applying observers to the whole model.

agrawal-aka
Author

Yes @jcaip, I think that's correct. I verified that FakeSparsity() is added only to the layers specified in the config, but the observers are added everywhere regardless of the config, which needs to be corrected. When observers are added to embedding layers, the runtime error appears in the forward pass.
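The fix discussed here is to attach observers only to modules whose weights appear in the sparse config, the same way FakeSparsity() is applied. Below is a plain-PyTorch sketch of that idea using forward pre-hooks; `attach_norm_hooks` and `TinyModel` are hypothetical names, and this is not torchao's actual PerChannelNormObserver machinery. It accumulates the per-input-channel L2 norm of activations, the statistic Wanda combines with weight magnitudes, but only for configured modules, so the embedding never sees a hook.

```python
import torch
import torch.nn as nn

def attach_norm_hooks(model, sparse_config):
    """Hypothetical helper: collect per-input-channel activation norms only for
    modules whose weight is listed in sparse_config."""
    # "distilbert.transformer.layer.0.ffn.lin1.weight" -> "distilbert.transformer.layer.0.ffn.lin1"
    targets = {cfg["tensor_fqn"].rsplit(".", 1)[0] for cfg in sparse_config}
    sq_sums, handles = {}, []
    for name, module in model.named_modules():
        if name not in targets:
            continue  # embeddings and other unconfigured modules get no hook
        def hook(mod, args, name=name):
            x = args[0].detach().float()
            x = x.reshape(-1, x.shape[-1])  # (tokens, in_features)
            sq = (x * x).sum(dim=0)  # running sum of squares per input channel
            sq_sums[name] = sq_sums.get(name, torch.zeros_like(sq)) + sq
        handles.append(module.register_forward_pre_hook(hook))
    return sq_sums, handles


class TinyModel(nn.Module):
    """Hypothetical stand-in: embeddings followed by one configured linear layer."""
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 2)

    def forward(self, input_ids):
        return self.proj(self.embeddings(input_ids))


model = TinyModel()
config = [{"tensor_fqn": "proj.weight"}]
sq_sums, handles = attach_norm_hooks(model, config)
with torch.no_grad():
    model(torch.tensor([[1, 2, 3]]))  # calibration pass; the embedding is untouched
norms = {name: s.sqrt() for name, s in sq_sums.items()}  # ||X_j||_2 per input channel
print(sorted(norms), norms["proj"].shape)  # ['proj'] torch.Size([4])
for h in handles:
    h.remove()
```

Deriving the module names from `tensor_fqn` keeps the set of observed modules identical to the set of sparsified ones, which is the invariant the thread says is currently missing.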

Contributor

jcaip commented Oct 23, 2024

@agrawal-aka Thanks for verifying. Are you interested in opening a PR? No worries if not, I can make the fix then.

agrawal-aka
Author

Hi @jcaip, yes, I'm currently working on the fix and will raise a PR shortly. Thanks for checking!

Contributor

jcaip commented Oct 25, 2024

Thanks! Feel free to comment here if you need any help.
