Skip to content

Commit

Permalink
fix failing type-check
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Jan 23, 2025
1 parent afc1603 commit 3f6c497
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def forward(
assigned_scores: Tensor,
mask_positive: Tensor,
):
max_assigned_scores_sum = max(assigned_scores.sum(), 1)
max_assigned_scores_sum = max(assigned_scores.sum().item(), 1)
loss_cls = (
self.bce(pred_scores, assigned_scores)
).sum() / max_assigned_scores_sum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,22 @@ def __init__(

def prepare(
self, inputs: Packet[Tensor], labels: Labels
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
) -> tuple[
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
]:
det_feats = self.get_input_tensors(inputs, "features")
proto = self.get_input_tensors(inputs, "prototypes")
pred_mask = self.get_input_tensors(inputs, "mask_coeficients")
proto = self.get_input_tensors(inputs, "prototypes")[0]
pred_mask = self.get_input_tensors(inputs, "mask_coeficients")[0]
self._init_parameters(det_feats)
batch_size, _, mask_h, mask_w = proto.shape
pred_distri, pred_scores = torch.cat(
Expand Down Expand Up @@ -129,7 +141,7 @@ def forward(
target_masks: Tensor,
img_idx: Tensor,
):
max_assigned_scores_sum = max(assigned_scores.sum(), 1)
max_assigned_scores_sum = max(assigned_scores.sum().item(), 1)
loss_cls = (
self.bce(pred_scores, assigned_scores)
).sum() / max_assigned_scores_sum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MeanAveragePrecision(

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.is_segmentation = (
self.is_segmentation = (self.node.tasks is not None) and (
TaskType.INSTANCE_SEGMENTATION in self.node.tasks
)

Expand All @@ -38,7 +38,7 @@ def __init__(self, **kwargs: Any):
else:
iou_type = "bbox"

self.metric = detection.MeanAveragePrecision(iou_type=iou_type)
self.metric = detection.MeanAveragePrecision(iou_type=iou_type) # type: ignore

def update(
self,
Expand Down Expand Up @@ -77,7 +77,7 @@ def prepare(
"labels": output_nms_bboxes[i][:, 5].int(),
}
if self.is_segmentation:
pred["masks"] = output_nms_masks[i].to(
pred["masks"] = output_nms_masks[i].to( # type: ignore
dtype=torch.bool
) # Predicted masks (M, H, W)
output_list.append(pred)
Expand All @@ -93,7 +93,7 @@ def prepare(
"labels": curr_label[:, 1].int(),
}
if self.is_segmentation:
gt["masks"] = mask_label[box_label[:, 0] == i].to(
gt["masks"] = mask_label[box_label[:, 0] == i].to( # type: ignore
dtype=torch.bool
)
label_list.append(gt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def __init__(
self.alpha = alpha

def prepare(
self, inputs: Packet[Tensor], labels: Labels | None
) -> tuple[Tensor, Tensor, list[Tensor], Tensor | None, Tensor | None]:
self, inputs: Packet[Tensor], labels: Labels
) -> tuple[Tensor, Tensor, list[Tensor], list[Tensor]]:
# Override the prepare base method
target_bboxes = self.get_label(labels, TaskType.BOUNDINGBOX)
target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION)
Expand Down Expand Up @@ -211,8 +211,8 @@ def forward(
prediction_canvas: Tensor,
target_bboxes: Tensor | None,
target_masks: Tensor | None,
predicted_bboxes: Tensor,
predicted_masks: Tensor,
predicted_bboxes: list[Tensor],
predicted_masks: list[Tensor],
) -> tuple[Tensor, Tensor] | Tensor:
"""Creates visualizations of the predicted and target bounding
boxes and instance masks.
Expand All @@ -229,12 +229,12 @@ def forward(
@type target_masks: Tensor | None
@param target_masks: Tensor containing the target instance
masks.
@type predicted_bboxes: Tensor
@param predicted_bboxes: Tensor containing the predicted
bounding boxes.
@type predicted_masks: Tensor
@param predicted_masks: Tensor containing the predicted instance
masks.
@type predicted_bboxes: list[Tensor]
@param predicted_bboxes: List of tensors containing the
predicted bounding boxes.
@type predicted_masks: list[Tensor]
@param predicted_masks: List of tensors containing the predicted
instance masks.
"""
predictions_viz = self.draw_predictions(
prediction_canvas,
Expand Down
16 changes: 9 additions & 7 deletions luxonis_train/nodes/heads/precision_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
cls_outputs = []
reg_outputs = []
for i in range(self.n_heads):
reg_output = self.detection_heads[i][0](x[i])
cls_output = self.detection_heads[i][1](x[i])
reg_output = self.detection_heads[i][0](x[i]) # type: ignore
cls_output = self.detection_heads[i][1](x[i]) # type: ignore
reg_outputs.append(reg_output)
cls_outputs.append(cls_output)
return reg_outputs, cls_outputs
Expand All @@ -153,7 +153,9 @@ def wrap(

if self.export:
return {
self.task: self._prepare_bbox_export(reg_outputs, cls_outputs)
"boundingbox": self._prepare_bbox_export(
reg_outputs, cls_outputs
)
}

boxes = non_max_suppression(
Expand Down Expand Up @@ -199,13 +201,13 @@ def _prepare_bbox_and_cls(

def _prepare_bbox_export(
self, reg_outputs: list[Tensor], cls_outputs: list[Tensor]
) -> Tensor:
) -> list[Tensor]:
"""Prepare the output for export."""
return self._prepare_bbox_and_cls(reg_outputs, cls_outputs)

def _prepare_bbox_inference_output(
self, reg_outputs: list[Tensor], cls_outputs: list[Tensor]
):
) -> Tensor:
"""Perform inference on predicted bounding boxes and class
probabilities."""
processed_outputs = self._prepare_bbox_and_cls(
Expand Down Expand Up @@ -254,8 +256,8 @@ def bias_init(self):
classification branches.
"""
for head, stride in zip(self.detection_heads, self.stride):
reg_branch = head[0]
cls_branch = head[1]
reg_branch = head[0] # type: ignore
cls_branch = head[1] # type: ignore

reg_conv = reg_branch[-1]
reg_conv.bias.data[:] = 1.0
Expand Down
19 changes: 10 additions & 9 deletions luxonis_train/nodes/heads/precision_seg_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(

def forward(
self, inputs: list[Tensor]
) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]:
) -> tuple[tuple[list[Tensor], list[Tensor]], Tensor, list[Tensor]]:
prototypes = self.proto(inputs[0])
mask_coefficients = [
self.mask_layers[i](inputs[i]) for i in range(self.n_heads)
Expand All @@ -89,16 +89,17 @@ def forward(
return det_outs, prototypes, mask_coefficients

def wrap(
self, output: tuple[list[Tensor], Tensor, Tensor]
self,
output: tuple[tuple[list[Tensor], list[Tensor]], Tensor, list[Tensor]],
) -> Packet[Tensor]:
det_feats, prototypes, mask_coefficients = output

if self.export:
pred_bboxes = self._prepare_bbox_export(*det_feats)
pred_bboxes = self._prepare_bbox_export(*det_feats) # type: ignore
return {
"boundingbox": pred_bboxes,
"masks": mask_coefficients,
"prototypes": prototypes,
"prototypes": [prototypes],
}

det_feats_combined = [
Expand All @@ -115,11 +116,11 @@ def wrap(
if self.training:
return {
"features": det_feats_combined,
"prototypes": prototypes,
"mask_coeficients": mask_coefficients,
"prototypes": [prototypes],
"mask_coeficients": [mask_coefficients],
}

pred_bboxes = self._prepare_bbox_inference_output(*det_feats)
pred_bboxes = self._prepare_bbox_inference_output(*det_feats) # type: ignore
preds_combined = torch.cat(
[pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1
)
Expand All @@ -135,8 +136,8 @@ def wrap(

results = {
"features": det_feats_combined,
"prototypes": prototypes,
"mask_coeficients": mask_coefficients,
"prototypes": [prototypes],
"mask_coeficients": [mask_coefficients],
"boundingbox": [],
"instance_segmentation": [],
}
Expand Down

0 comments on commit 3f6c497

Please sign in to comment.