diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py index ddaeae6a..79e6ce6f 100644 --- a/optimum_benchmark/backends/base.py +++ b/optimum_benchmark/backends/base.py @@ -63,8 +63,8 @@ def __init__(self, config: BackendConfigT): elif self.config.library == "timm": self.logger.info("\t+ Benchmarking a Timm model") - self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config) self.pretrained_config = get_timm_pretrained_config(self.config.model) + self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config) self.automodel_loader = get_timm_automodel_loader() self.pretrained_processor = None self.generation_config = None