Commit
added ryzenai quantization
IlyasMoutawwakil committed Mar 5, 2024
1 parent ac508c2 commit 873b73b
Showing 5 changed files with 132 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -34,7 +34,7 @@ style:
 install:
 	pip install -e .
 
-## Docker builds
+## Docker build
 
 define build_docker
 	docker build -f docker/$(1).dockerfile --build-arg USER_ID=$(USER_ID) --build-arg GROUP_ID=$(GROUP_ID) -t opt-bench-$(1):local .
@@ -10,10 +10,11 @@ defaults:
 experiment_name: ryzenai_resnet50
 
 backend:
-  device: ipu
-  export: false
-  model: amd/resnet50
+  device: cpu
+  export: true
+  model: timm/mobilenetv3_large_100.ra_in1k
+  provider: CPUExecutionProvider
+  auto_quantization: cpu_cnn_config
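The switch to device: cpu with CPUExecutionProvider lets this example run without RyzenAI IPU/NPU hardware, and auto_quantization names an AutoQuantizationConfig preset that the backend resolves by attribute (see backend.py below). Pieced together from the hunk (the file header for this example config was not captured in this view), the updated backend section reads roughly as follows, a sketch showing only the keys visible in the hunk:

experiment_name: ryzenai_resnet50

backend:
  device: cpu
  export: true
  model: timm/mobilenetv3_large_100.ra_in1k
  provider: CPUExecutionProvider
  auto_quantization: cpu_cnn_config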
132 changes: 110 additions & 22 deletions optimum_benchmark/backends/ryzenai/backend.py
@@ -7,9 +7,13 @@

 import torch
 from hydra.utils import get_class
+from onnxruntime import SessionOptions
+from optimum.amd.ryzenai import AutoQuantizationConfig, QuantizationConfig, RyzenAIOnnxQuantizer
+from optimum.exporters.onnx import main_export
 from safetensors.torch import save_file
 from transformers.utils.logging import set_verbosity_error
 
+from ...generators.dataset_generator import DatasetGenerator
 from ...task_utils import IMAGE_PROCESSING_TASKS, TEXT_GENERATION_TASKS
 from ..base import Backend
 from ..transformers_utils import random_init_weights
@@ -27,27 +31,51 @@ class RyzenAIBackend(Backend[RyzenAIConfig]):

     def __init__(self, config: RyzenAIConfig) -> None:
         super().__init__(config)
-        self.validate_task()
 
+        self.ryzenaimodel_class = get_class(TASKS_TO_RYZENAIMODEL[self.config.task])
+        LOGGER.info(f"\t+ Using RyzenAIModel class {self.ryzenaimodel_class.__name__}")
+
+        self.session_options = SessionOptions()
+        if self.config.session_options:
+            LOGGER.info("\t+ Processing session options")
+            for key, value in self.config.session_options.items():
+                setattr(self.session_options, key, value)
+
         LOGGER.info("\t+ Creating backend temporary directory")
         self.tmpdir = TemporaryDirectory()
 
-        if self.config.no_weights:
-            LOGGER.info("\t+ Loading no weights RyzenAIModel")
-            self.load_ryzenaimodel_with_no_weights()
+        if self.is_quantized:
+            if self.config.no_weights:
+                LOGGER.info("\t+ Loading no weights AutoModel")
+                self.load_automodel_with_no_weights()
+            else:
+                LOGGER.info("\t+ Loading pretrained AutoModel")
+                self.load_automodel_from_pretrained()
+
+            original_model, original_export = self.config.model, self.config.export
+
+            LOGGER.info("\t+ Exporting model to ONNX")
+            self.export_onnx_model()
+            self.config.model = self.exported_model
+
+            LOGGER.info("\t+ Applying RyzenAI quantization")
+            self.quantize_onnx_files()
+            self.config.model = self.quantized_model
+
+            self.config.export = False
+            LOGGER.info("\t+ Loading quantized RyzenAIModel")
+            self.load_ryzenaimodel_from_pretrained()
+
+            self.config.model, self.config.export = original_model, original_export
+
+        elif self.config.no_weights:
+            raise NotImplementedError("`no_weights` is only supported when RyzenAI model is quantized from scratch")
         else:
             LOGGER.info("\t+ Loading pretrained RyzenAIModel")
             self.load_ryzenaimodel_from_pretrained()
 
         self.tmpdir.cleanup()
 
-    def validate_task(self) -> None:
-        if self.config.task not in TASKS_TO_RYZENAIMODEL:
-            raise NotImplementedError(f"RyzenAIBackend does not support task {self.config.task}")
-
-        self.ryzenaimodel_class = get_class(TASKS_TO_RYZENAIMODEL[self.config.task])
-        LOGGER.info(f"\t+ Using RyzenAIModel class {self.ryzenaimodel_class.__name__}")
-
     def create_no_weights_model(self) -> None:
         self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
         LOGGER.info("\t+ Creating no weights model directory")
@@ -65,30 +93,44 @@ def create_no_weights_model(self) -> None:
     def load_automodel_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
         self.create_no_weights_model()
 
         with random_init_weights():
             original_model, self.config.model = self.config.model, self.no_weights_model
             LOGGER.info("\t+ Loading no weights AutoModel")
             self.load_automodel_from_pretrained()
             self.config.model = original_model
 
-        LOGGER.info("\t+ Tying model weights")
-        self.pretrained_model.tie_weights()
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Tying weights")
+            self.pretrained_model.tie_weights()
+
+    def export_onnx_model(self) -> None:
+        self.exported_model = f"{self.tmpdir.name}/exported_model"
+        main_export(
+            model_name_or_path=self.config.model,
+            output=self.exported_model,
+            task=self.config.task,
+            no_dynamic_axes=True,
+            batch_size=1,
+            opset=13,
+        )
 
     def load_automodel_from_pretrained(self) -> None:
-        self.pretrained_model = self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)
+        if self.config.library == "timm":
+            self.pretrained_model = self.automodel_class(model_name=self.config.model)
+        else:
+            self.pretrained_model = self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)
 
     def load_ryzenaimodel_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
         self.create_no_weights_model()
 
         with random_init_weights():
-            original_model, self.config.model = self.config.model, self.no_weights_model
-            original_export, self.config.export = self.config.export, True
+            original_model, original_export = self.config.model, self.config.export
+
+            self.config.model, self.config.export = self.no_weights_model, False
             LOGGER.info("\t+ Loading no weights RyzenAIModel")
             self.load_ryzenaimodel_from_pretrained()
-            self.config.model = original_model
-            self.config.export = original_export
+
+            self.config.model, self.config.export = original_model, original_export
 
     def load_ryzenaimodel_from_pretrained(self) -> None:
         self.pretrained_model = self.ryzenaimodel_class.from_pretrained(
@@ -100,6 +142,52 @@ def load_ryzenaimodel_from_pretrained(self) -> None:
             **self.ryzenaimodel_kwargs,
         )
 
+    def quantize_onnx_files(self) -> None:
+        LOGGER.info("\t+ Attempting quantization")
+        self.quantized_model = f"{self.tmpdir.name}/quantized_model"
+
+        LOGGER.info("\t+ Processing quantization config")
+        if self.config.auto_quantization is not None:
+            auto_quantization_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
+            quantization_config = auto_quantization_class(**self.config.auto_quantization_config)
+        elif self.config.quantization:
+            quantization_config = QuantizationConfig(**self.config.quantization_config)
+
+        LOGGER.info("\t+ Generating calibration dataset")
+        dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
+        calibration_dataset = DatasetGenerator(
+            task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes
+        )()
+        calibration_dataset = calibration_dataset.remove_columns(["labels"])
+
+        for onnx_file_name in self.onnx_files_names:
+            LOGGER.info(f"\t+ Creating quantizer for {onnx_file_name}")
+            quantizer = RyzenAIOnnxQuantizer.from_pretrained(self.config.model, file_name=onnx_file_name)
+
+            LOGGER.info("\t+ Quantizing model")
+            quantizer.quantize(
+                save_dir=self.quantized_model,
+                quantization_config=quantization_config,
+                dataset=calibration_dataset,
+                # TODO: add support for these (maybe)
+                batch_size=1,
+                file_suffix="",
+            )
+
+        if self.pretrained_processor is not None:
+            self.pretrained_processor.save_pretrained(self.quantized_model)
+        if self.config.library == "transformers":
+            self.pretrained_config.save_pretrained(self.quantized_model)
+
+    @property
+    def onnx_files_names(self):
+        assert os.path.isdir(self.config.model), f"{self.config.model} is not a directory"
+        return [file for file in os.listdir(self.config.model) if file.endswith(".onnx")]
+
+    @property
+    def is_quantized(self) -> bool:
+        return bool(self.config.quantization or self.config.auto_quantization)
+
     @property
     def ryzenaimodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
@@ -112,8 +200,8 @@ def ryzenaimodel_kwargs(self) -> Dict[str, Any]:
     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         inputs = super().prepare_inputs(inputs)
 
-        if self.config.task in IMAGE_PROCESSING_TASKS:
-            # channels last
+        if not self.config.export and self.config.task in IMAGE_PROCESSING_TASKS:
+            # original AMD RyzenAI models expect channels-last (NHWC) pixel values
             inputs["pixel_values"] = inputs["pixel_values"].permute(0, 2, 3, 1)
 
         return inputs
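Taken together, the new backend path exports the model to ONNX with static axes and then quantizes every exported file. Below is a minimal standalone sketch of that flow, mirroring the export_onnx_model() and quantize_onnx_files() calls above; the model name, output paths, task string, exported file name, and the toy calibration set are assumptions for illustration, not part of the commit:

import numpy as np
from datasets import Dataset
from optimum.amd.ryzenai import AutoQuantizationConfig, RyzenAIOnnxQuantizer
from optimum.exporters.onnx import main_export

# 1. Export with static axes, since RyzenAI quantization expects fixed input shapes
main_export(
    model_name_or_path="timm/mobilenetv3_large_100.ra_in1k",  # assumed model
    output="exported_model",
    task="image-classification",
    no_dynamic_axes=True,
    batch_size=1,
    opset=13,
)

# 2. Build a toy calibration set (stand-in for the backend's DatasetGenerator;
#    real calibration should use representative samples)
calibration_dataset = Dataset.from_dict(
    {"pixel_values": [np.random.rand(3, 224, 224).astype(np.float32).tolist()]}
)

# 3. Quantize the exported ONNX file (assumed to be named model.onnx) with a preset
quantization_config = AutoQuantizationConfig.cpu_cnn_config()
quantizer = RyzenAIOnnxQuantizer.from_pretrained("exported_model", file_name="model.onnx")
quantizer.quantize(
    save_dir="quantized_model",
    quantization_config=quantization_config,
    dataset=calibration_dataset,
    batch_size=1,
    file_suffix="",
)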
19 changes: 17 additions & 2 deletions optimum_benchmark/backends/ryzenai/config.py
@@ -2,6 +2,7 @@
 from typing import Any, Dict, Optional
 
 from ..config import BackendConfig
+from .utils import TASKS_TO_RYZENAIMODEL
 
 
 @dataclass
@@ -17,15 +18,29 @@ class RyzenAIConfig(BackendConfig):
     export: bool = True
     use_cache: bool = True
 
+    # session options
+    session_options: Dict[str, Any] = field(default_factory=dict)
+
     # provider options
     provider: Optional[str] = None
     provider_options: Dict[str, Any] = field(default_factory=dict)
 
     # ryzenai config
     vaip_config: Optional[str] = None  # /usr/bin/vaip_config.json
 
+    # auto quantization options
+    auto_quantization: Optional[str] = None  # ipu_cnn_config, cpu_cnn_config
+    auto_quantization_config: Dict[str, Any] = field(default_factory=dict)
+
+    # manual quantization options
+    quantization: bool = False
+    quantization_config: Dict[str, Any] = field(default_factory=dict)
+
     def __post_init__(self):
         super().__post_init__()
 
-        if self.device not in ["ipu", "npu"]:
-            raise ValueError(f"RyzenAIBackend only supports IPU/NPU device, got {self.device}")
+        if self.device not in ["cpu", "ipu", "npu"]:
+            raise ValueError(f"RyzenAIBackend only supports CPU & IPU/NPU devices, got {self.device}")
+
+        if self.task not in TASKS_TO_RYZENAIMODEL:
+            raise NotImplementedError(f"RyzenAIBackend does not support task {self.task}")
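For reference, the two quantization paths these fields enable map onto an experiment YAML roughly as follows (a sketch: the preset name must be an attribute of AutoQuantizationConfig, and quantization_config keys are forwarded verbatim to QuantizationConfig, whose accepted keys are not shown in this commit):

backend:
  # preset path: pick an AutoQuantizationConfig factory by name
  auto_quantization: cpu_cnn_config   # or ipu_cnn_config
  auto_quantization_config: {}        # extra kwargs for the chosen preset

  # alternatively, the manual path
  # quantization: true
  # quantization_config: {}           # kwargs forwarded to QuantizationConfig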
2 changes: 0 additions & 2 deletions optimum_benchmark/backends/timm_utils.py
@@ -11,8 +11,6 @@
 def get_timm_pretrained_config(model_name: str) -> PretrainedConfig:
     model_source, model_name = timm.models.parse_model_name(model_name)
     if model_source == "hf-hub":
-        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
-        # load model weights + pretrained_cfg from Hugging Face hub.
         pretrained_cfg, model_name = timm.models.load_model_config_from_hf(model_name)
         return pretrained_cfg
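For context, this helper backs the timm branch of the new load_automodel_from_pretrained() above; a quick hypothetical usage, with model names assumed (the non-hf-hub branch of the function is truncated in this view):

from optimum_benchmark.backends.timm_utils import get_timm_pretrained_config

# Plain timm name: resolved through timm's registry
cfg = get_timm_pretrained_config("timm/mobilenetv3_large_100.ra_in1k")

# hf-hub-prefixed name: pretrained_cfg loaded from the Hugging Face Hub
cfg = get_timm_pretrained_config("hf-hub:timm/mobilenetv3_large_100.ra_in1k")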
