diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index f7673a12..0495ae73 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -216,15 +216,31 @@ def num_categories(self): @property def has_mask(self): """ - Returns if model output contains segmentation mask + Returns if model output contains segmentation mask. + Considers both single dataset and ConcatDataset scenarios. """ - # has_mask = self.model.model.with_mask - train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"] - has_mask = any( - isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) - for item in train_pipeline - ) - return has_mask + + def check_pipeline_for_mask(pipeline): + return any( + isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) + for item in pipeline + ) + + # Access the dataset from the configuration + dataset_config = self.model.cfg["train_dataloader"]["dataset"] + + if dataset_config["type"] == "ConcatDataset": + # If using ConcatDataset, check each dataset individually + datasets = dataset_config["datasets"] + for dataset in datasets: + if check_pipeline_for_mask(dataset["pipeline"]): + return True + else: + # Otherwise, assume a single dataset with its own pipeline + if check_pipeline_for_mask(dataset_config["pipeline"]): + return True + + return False @property def category_names(self):