From f60a56aa2437fe6d4482c7b50ceb4b6d510b580a Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Tue, 2 Jul 2024 19:48:30 +0200 Subject: [PATCH] fix bug in PredictorEvaluator for ZeroCost proxies --- naslib/defaults/predictor_evaluator.py | 21 ++++++++++++++++----- naslib/predictors/zerocost.py | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/naslib/defaults/predictor_evaluator.py b/naslib/defaults/predictor_evaluator.py index 320e81e395..78287aabda 100644 --- a/naslib/defaults/predictor_evaluator.py +++ b/naslib/defaults/predictor_evaluator.py @@ -10,9 +10,12 @@ from sklearn import metrics import math +from naslib.predictors.zerocost import ZeroCost from naslib.search_spaces.core.query_metrics import Metric from naslib.utils import generate_kfold, cross_validation +from naslib import utils + logger = logging.getLogger(__name__) @@ -47,6 +50,9 @@ def __init__(self, predictor, config=None): self.num_arches_to_mutate = 5 self.max_mutation_rate = 3 + # For ZeroCost proxies + self.dataloader = None + def adapt_search_space( self, search_space, load_labeled, scope=None, dataset_api=None ): @@ -70,6 +76,9 @@ def adapt_search_space( "This search space is not yet implemented in PredictorEvaluator." ) + if isinstance(self.predictor, ZeroCost): + self.dataloader, _, _, _, _ = utils.get_train_val_loaders(self.config) + def get_full_arch_info(self, arch): """ Given an arch, return the accuracy, train_time, @@ -139,10 +148,8 @@ def load_dataset(self, load_labeled=False, data_size=10, arch_hash_map={}): arch.load_labeled_architecture(dataset_api=self.dataset_api) arch_hash = arch.get_hash() - if False: # removing this for consistency, for now - continue - else: - arch_hash_map[arch_hash] = True + + arch_hash_map[arch_hash] = True accuracy, train_time, info_dict = self.get_full_arch_info(arch) xdata.append(arch) @@ -295,7 +302,11 @@ def single_evaluate(self, train_data, test_data, fidelity): hyperparams = self.predictor.get_hyperparams() fit_time_end = time.time() - test_pred = self.predictor.query(xtest, test_info) + if isinstance(self.predictor, ZeroCost): + [g.parse() for g in xtest] # parse the graphs because they will be used + test_pred = self.predictor.query_batch(xtest, self.dataloader) + else: + test_pred = self.predictor.query(xtest, test_info) query_time_end = time.time() # If the predictor is an ensemble, take the mean diff --git a/naslib/predictors/zerocost.py b/naslib/predictors/zerocost.py index fd10a488ae..0faa878705 100644 --- a/naslib/predictors/zerocost.py +++ b/naslib/predictors/zerocost.py @@ -4,6 +4,7 @@ based on https://github.com/mohsaied/zero-cost-nas """ import torch +import numpy as np import logging import math @@ -24,12 +25,21 @@ def __init__(self, method_type="jacov"): self.num_imgs_or_batches = 1 self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - def query(self, graph, dataloader=None, info=None): + def query_batch(self, graphs, dataloader): + scores = [] + + for graph in graphs: + score = self.query(graph, dataloader) + scores.append(score) + + return np.array(scores) + + def query(self, graph, dataloader): loss_fn = graph.get_loss_fn() n_classes = graph.num_classes score = predictive.find_measures( - net_orig=graph, + net_orig=graph.to(self.device), dataloader=dataloader, dataload_info=(self.dataload, self.num_imgs_or_batches, n_classes), device=self.device,