[Sparsity] When sparsifying only Linear layers with Wanda, PerChannelNormObserver() is added to embedding layers, leading to RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long
#1133 · Open · agrawal-aka opened this issue on Oct 22, 2024 · 5 comments
Hello, I created a test script that I was testing on an Aarch64 platform for DistilBERT inference using the Wanda sparsifier:
```python
import torch
from transformers import BertForSequenceClassification, BertTokenizer, pipeline
from torch.ao.pruning import WeightNormSparsifier
from torch.profiler import profile, record_function, ProfilerActivity
import torch.profiler
from torchao.sparsity.wanda import WandaSparsifier
from torchao.quantization.quant_api import _is_linear
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

torch.manual_seed(100)

sparsifier = WandaSparsifier(
    sparsity_level=0.6
)

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

sparse_config = []
for name, mod in model.named_modules():
    if _is_linear(mod, name):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})
print("sparse config:", sparse_config)

sparsifier.prepare(model, sparse_config)
#print(model.distilbert.embeddings)

# Calibration samples - for wanda
calibration_texts = [
    "I love using CPUs for inference.",
    "This is a sample text for calibration.",
    "Calibration is important for pruning accuracy."]

# Tokenize and pass the calibration samples through the model
for text in calibration_texts:
    inputs = tokenizer(text, return_tensors="pt")
    #print(inputs)
    with torch.no_grad():
        model(**inputs)  # Forward pass to collect activation statistics

# Now that activation statistics have been collected, you can proceed with pruning
sparsifier.step()
sparsifier.squash_mask()

# Apply sparsity to linear layers and convert to CSR format
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and "layer" in name:
        # Convert dense weights to CSR format
        module.weight = torch.nn.Parameter(module.weight.to_sparse_csr())

# Set the model to evaluation mode
model.eval()

# Initialize Hugging Face sentiment analysis pipeline with the sparsified model
sentiment_analysis_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)

# Run inference using the sparsified model
input_text = "I really love using PyTorch"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    with profile(with_stack=True,
                 profile_memory=True, record_shapes=True) as prof:
        outputs = model(**inputs)

print(prof.key_averages(group_by_input_shape=False).table(sort_by="self_cpu_time_total", row_limit=-1))
print(outputs)

prediction = model.config.id2label[outputs.logits.argmax().item()]
print(f"Predicted sentiment: {prediction}")
```
This raises the following error: `RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long`
The issue occurs only when I use Wanda; the WeightNormSparsifier with the same sparse config doesn't create this issue and everything runs fine. I understood that the problem comes from PerChannelNormObserver() being attached to the embedding layers after sparsifier.prepare(), which internally triggers linalg.vector_norm.
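For reference, the failure can be reproduced in isolation: the token ids produced by the tokenizer are integer (Long) tensors, and linalg.vector_norm refuses non-floating-point input, which is exactly the error reported here. This is a minimal sketch of that failure, not the observer's actual code:

```python
import torch

# Token ids produced by the tokenizer are integer (torch.int64 / Long) tensors.
token_ids = torch.tensor([[101, 2023, 2003, 102]])
print(token_ids.dtype)  # torch.int64

# linalg.vector_norm only accepts floating point or complex inputs, so a norm
# computed over the Long tensor fails with the error reported in this issue:
#   RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long
try:
    torch.linalg.vector_norm(token_ids, dim=0)
except RuntimeError as err:
    print(err)
```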
So, as a workaround, after sparsifier.prepare() is called I reinitialise the embedding layers from the pretrained model before passing in the calibration texts, and with that I am able to run the script successfully:
```python
sparsifier.prepare(model, sparse_config)

# Added workaround: remove observers from the embedding layers to avoid the error
model.distilbert.embeddings = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
).distilbert.embeddings

# Print the model to confirm that the observer has been removed from the embeddings block
print(model.distilbert.embeddings)

# Tokenize and pass the calibration samples through the model
# ...
```
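An alternative way to express the same workaround, without reloading pretrained weights, would be to strip the observers from every module that is not named in the sparse config. This is only a sketch and rests on an assumption about the internals (that prepare() attaches the observer to each module as an `activation_post_process` child plus a forward hook, as eager-mode quantization does); the `configured_fqns` set below is my own helper, not part of the library:

```python
# Sketch of a config-driven cleanup after sparsifier.prepare(model, sparse_config).
# Assumption: the observer lives on each module as `activation_post_process` and is
# driven by a forward hook, so removing both leaves the module otherwise untouched.
configured_fqns = {cfg["tensor_fqn"].rsplit(".", 1)[0] for cfg in sparse_config}

for name, module in model.named_modules():
    if name in configured_fqns:
        continue  # keep the observers on the Linear layers we want Wanda to see
    if hasattr(module, "activation_post_process"):
        delattr(module, "activation_post_process")
        module._forward_hooks.clear()  # drop the hook that fed the removed observer
```

With the embeddings (and any other unconfigured modules) stripped this way, the calibration forward passes should no longer hit linalg.vector_norm on integer inputs.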
Am I missing something in my script? I don't think this should be the expected behaviour, because the sparse config is the same whether I use the weight-norm sparsifier or Wanda, yet it works in one case and not in the other.
Hi @agrawal-aka, yeah, this looks like unexpected behavior to me. I think we need to add a check to make sure we apply the observers only to the layers in our config; right now we appear to be applying observers to the whole model.
Yes @jcaip, I think that's correct. I verified that FakeSparsity() is added only to the layers specified in the config, but the observers are added everywhere regardless of the config, which needs to be corrected. And when observers are added to the embedding layers, the runtime error pops up during the forward pass.
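For anyone debugging this, one quick way to make the mismatch visible (purely a sketch; the `activation_post_process` attribute name is an assumption based on eager-mode quantization conventions, and `configured_fqns` is a hypothetical helper) is to diff the modules that received an observer against the modules named in the config right after prepare():

```python
# Compare where observers ended up vs. where the sparse config asked for them.
configured_fqns = {cfg["tensor_fqn"].rsplit(".", 1)[0] for cfg in sparse_config}
observed_fqns = {
    name
    for name, module in model.named_modules()
    if hasattr(module, "activation_post_process")
}

unexpected = sorted(observed_fqns - configured_fqns)
print(f"{len(unexpected)} modules received an observer despite not being in the config:")
for fqn in unexpected:
    print("  ", fqn)  # with the current behaviour this list includes the embedding modules
```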