Skip to content

Commit

Permalink
Config object kept in trt inferencer
Browse files Browse the repository at this point in the history
  • Loading branch information
enesozi authored Jun 18, 2024
1 parent 909c164 commit b42fce6
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions sahi/models/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, deploy_cfg: str, model_cfg: str, engine_file: str, device: Op
deploy_cfg,
model_cfg,
)
self.cfg = model_cfg
self.task_processor = build_task_processor(model_cfg, deploy_cfg, device)
self.model = self.task_processor.build_backend_model(
[engine_file], self.task_processor.update_data_preprocessor
Expand Down Expand Up @@ -255,8 +256,6 @@ def has_mask(self):
"""
Returns if model output contains segmentation mask
"""

# has_mask = self.model.model.with_mask
train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"]
has_mask = any(isinstance(item, dict) and any("mask" in key for key in item.keys()) for item in train_pipeline)

Expand Down

0 comments on commit b42fce6

Please sign in to comment.