Skip to content

Commit

Permalink
Pass backend name to EnergyTracker in Training scenario (huggingface#279
Browse files Browse the repository at this point in the history
)
  • Loading branch information
asesorov authored and vicoooo26 committed Oct 23, 2024
1 parent 561deca commit c3908cb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 3 additions & 1 deletion optimum_benchmark/scenarios/training/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:

if self.config.energy:
self.logger.info("\t+ Creating energy tracking context manager")
energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
energy_tracker = EnergyTracker(
device=backend.config.device, backend=backend.config.name, device_ids=backend.config.device_ids
)

if self.config.memory:
self.logger.info("\t+ Entering memory tracking context manager")
Expand Down
8 changes: 7 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def test_api_launch(device, scenario, library, task, model):

if scenario == "training":
if library == "transformers":
scenario_config = TrainingConfig(memory=True, latency=True, warmup_steps=2, max_steps=5)
scenario_config = TrainingConfig(
memory=True,
latency=True,
energy=not is_rocm_system(),
warmup_steps=2,
max_steps=5,
)
else:
pytest.skip("Training scenario is only available for Transformers library")

Expand Down

0 comments on commit c3908cb

Please sign in to comment.