From b3f9f5d787f564cc86dbddaf75997f805a9cc90a Mon Sep 17 00:00:00 2001 From: KlemenSkrlj <47853619+klemen1999@users.noreply.github.com> Date: Tue, 17 Sep 2024 09:14:40 +0200 Subject: [PATCH] [Fix] Corrected config valid sequence for predefined models (#72) --- luxonis_train/utils/config.py | 42 +++++++++++++++++------------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index bdcd13dc..31e4fe5b 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -77,27 +77,6 @@ class ModelConfig(BaseModelExtraForbid): visualizers: list[AttachedModuleConfig] = [] outputs: list[str] = [] - @model_validator(mode="after") - def check_main_metric(self) -> Self: - for metric in self.metrics: - if metric.is_main_metric: - logger.info(f"Main metric: `{metric.name}`") - return self - - logger.warning("No main metric specified.") - if self.metrics: - metric = self.metrics[0] - metric.is_main_metric = True - name = metric.alias or metric.name - logger.info(f"Setting '{name}' as main metric.") - else: - logger.error( - "No metrics specified. " - "This is likely unintended unless " - "the configuration is not used for training." - ) - return self - @model_validator(mode="after") def check_predefined_model(self) -> Self: from luxonis_train.utils.registry import MODELS @@ -120,6 +99,27 @@ def check_predefined_model(self) -> Self: return self + @model_validator(mode="after") + def check_main_metric(self) -> Self: + for metric in self.metrics: + if metric.is_main_metric: + logger.info(f"Main metric: `{metric.name}`") + return self + + logger.warning("No main metric specified.") + if self.metrics: + metric = self.metrics[0] + metric.is_main_metric = True + name = metric.alias or metric.name + logger.info(f"Setting '{name}' as main metric.") + else: + logger.error( + "No metrics specified. " + "This is likely unintended unless " + "the configuration is not used for training." + ) + return self + @model_validator(mode="after") def check_graph(self) -> Self: from luxonis_train.utils.general import is_acyclic