Skip to content

Commit

Permalink
make prepare_for_eval backward compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
Niccolo-Ajroldi committed Oct 21, 2024
1 parent d9c4ee9 commit 9caedc5
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,25 +378,27 @@ def train_once(
workload.eval_period_time_sec or train_state['training_complete']):

# Prepare for evaluation (timed).
with profiler.profile('Prepare for eval'):
del batch
prepare_for_eval_start_time = get_time()
optimizer_state, model_params, model_state = prepare_for_eval(
workload=workload,
current_param_container=model_params,
current_params_types=workload.model_params_types,
model_state=model_state,
hyperparameters=hyperparameters,
loss_type=workload.loss_type,
optimizer_state=optimizer_state,
eval_results=eval_results,
global_step=global_step,
rng=prep_eval_rng)
prepare_for_eval_end_time = get_time()

# Update sumbission time.
train_state['accumulated_submission_time'] += (
prepare_for_eval_end_time - prepare_for_eval_start_time)
if prepare_for_eval is not None:

with profiler.profile('Prepare for eval'):
del batch
prepare_for_eval_start_time = get_time()
optimizer_state, model_params, model_state = prepare_for_eval(
workload=workload,
current_param_container=model_params,
current_params_types=workload.model_params_types,
model_state=model_state,
hyperparameters=hyperparameters,
loss_type=workload.loss_type,
optimizer_state=optimizer_state,
eval_results=eval_results,
global_step=global_step,
rng=prep_eval_rng)
prepare_for_eval_end_time = get_time()

# Update sumbission time.
train_state['accumulated_submission_time'] += (
prepare_for_eval_end_time - prepare_for_eval_start_time)

# Check if time is remaining,
# use 3x the runtime budget for the self-tuning ruleset.
Expand Down Expand Up @@ -548,7 +550,7 @@ def score_submission_on_workload(workload: spec.Workload,
init_optimizer_state = submission_module.init_optimizer_state
update_params = submission_module.update_params
data_selection = submission_module.data_selection
prepare_for_eval = submission_module.prepare_for_eval
prepare_for_eval = getattr(submission_module, 'prepare_for_eval', None)
try:
global_batch_size = submission_module.get_batch_size(workload_name)
except ValueError:
Expand Down

0 comments on commit 9caedc5

Please sign in to comment.