Skip to content

Commit

Permalink
Objective evaluate fix for lecture
Browse files Browse the repository at this point in the history
  • Loading branch information
staeros committed Nov 14, 2023
1 parent fa505cb commit bad34db
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nas/optimizer/objective/future/nas_objective_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ def _graph_fit(self, graph: NasGraph, train_data: InputData, log, fold_id) -> Ne
batch_size = self._requirements.model_requirements.batch_size
opt_epochs = self._requirements.opt_epochs

opt_data, val_data = train_test_data_setup(train_data, stratify=train_data.target)
opt_dataset = DataLoader(self._dataset_builder.build(opt_data), batch_size=batch_size, shuffle=shuffle_flag)
val_dataset = DataLoader(self._dataset_builder.build(val_data), batch_size=batch_size, shuffle=shuffle_flag)
# opt_data, val_data = train_test_data_setup(train_data, stratify=train_data.target)
opt_dataset = DataLoader(self._dataset_builder.build(train_data), batch_size=batch_size, shuffle=shuffle_flag)
# val_dataset = DataLoader(self._dataset_builder.build(val_data), batch_size=batch_size, shuffle=shuffle_flag)

input_shape = self._requirements.model_requirements.input_shape
trainer = self._model_trainer_builder.build(input_shape=input_shape, output_shape=classes, graph=graph)
trainer.fit_model(train_data=opt_dataset, val_data=val_dataset, epochs=opt_epochs)
trainer.fit_model(train_data=opt_dataset, epochs=opt_epochs)
return trainer

def _evaluate_fitted_model(self, fitted_model: NeuralSearchModel, test_data: InputData, graph: NasGraph,
Expand Down

0 comments on commit bad34db

Please sign in to comment.