Skip to content

Commit

Permalink
fix: export
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Dec 12, 2024
1 parent 95ea9c2 commit a28f0c7
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 75 deletions.
41 changes: 18 additions & 23 deletions luxonis_train/nodes/blocks/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,26 @@ def forward(self, x):


class DFL(nn.Module):
def __init__(self, channels: int = 16):
"""
Constructs the module with a convolutional layer using the specified input channels.
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
@type channels: int
@param channels: Number of input channels. Defaults to 16.
def __init__(self, reg_max: int = 16):
"""The DFL (Distribution Focal Loss) module processes input
tensors by applying softmax over a specified dimension and
projecting the resulting tensor to produce output logits.
@type reg_max: int
@param reg_max: Maximum number of regression outputs. Defaults
to 16.
"""
super().__init__()
self.transform = nn.Conv2d(
channels, 1, kernel_size=1, bias=False
).requires_grad_(False)
weights = torch.arange(channels, dtype=torch.float32)
self.transform.weight.data.copy_(weights.view(1, channels, 1, 1))
self.num_channels = channels

def forward(self, input: Tensor):
"""Transforms the input tensor and returns the processed
output."""
batch_size, _, anchors = input.size()
reshaped = input.view(batch_size, 4, self.num_channels, anchors)
softmaxed = reshaped.transpose(2, 1).softmax(dim=1)
processed = self.transform(softmaxed)
return processed.view(batch_size, 4, anchors)
self.proj_conv = nn.Conv2d(reg_max, 1, kernel_size=1, bias=False)
self.proj_conv.weight.data.copy_(
torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1)
)
self.proj_conv.requires_grad_(False)

def forward(self, x: Tensor) -> Tensor:
bs, _, h, w = x.size()
x = F.softmax(x.view(bs, 4, -1, h * w).permute(0, 2, 1, 3), dim=1)
return self.proj_conv(x)[:, 0].view(bs, 4, h, w)


class ConvModule(nn.Sequential):
Expand Down
98 changes: 67 additions & 31 deletions luxonis_train/nodes/heads/precision_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,38 @@ def __init__(
self.bias_init()
self.initialize_weights()

def forward(self, x: list[Tensor]) -> list[Tensor]:
def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
cls_outputs = []
reg_outputs = []
for i in range(self.n_heads):
reg_output = self.detection_heads[i][0](x[i])
cls_output = self.detection_heads[i][1](x[i])
x[i] = torch.cat((reg_output, cls_output), 1)
return x
reg_outputs.append(reg_output)
cls_outputs.append(cls_output)
return reg_outputs, cls_outputs

def wrap(self, output: list[Tensor]) -> Packet[Tensor]:
def wrap(
self, output: tuple[list[Tensor], list[Tensor]]
) -> Packet[Tensor]:
reg_outputs, cls_outputs = (
output # ([bs, 4*reg_max, h_f, w_f]), ([bs, n_classes, h_f, w_f])
)
features = [
torch.cat((reg, cls), dim=1)
for reg, cls in zip(reg_outputs, cls_outputs)
]
if self.training:
return {
"features": output,
"features": features,
}

if self.export:
return {self.task: [self._export_bbox_output(output)]}
return {
self.task: self._prepare_bbox_export(reg_outputs, cls_outputs)
}

boxes = non_max_suppression(
self._inference_bbox_output(output),
self._prepare_bbox_inference_output(reg_outputs, cls_outputs),
n_classes=self.n_classes,
conf_thres=self.conf_thres,
iou_thres=self.iou_thres,
Expand All @@ -153,7 +167,7 @@ def wrap(self, output: list[Tensor]) -> Packet[Tensor]:
)

return {
"features": output,
"features": features,
"boundingbox": boxes,
}

Expand All @@ -169,46 +183,68 @@ def _fit_stride_to_n_heads(self):
)
return stride

def _extract_cls_and_box(self, x: list[Tensor]):
def _prepare_bbox_and_cls(
self, reg_outputs: list[Tensor], cls_outputs: list[Tensor]
) -> list[Tensor]:
"""Extract classification and bounding box tensors."""
shape = x[0].shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1)
return box, cls.sigmoid(), shape # Apply sigmoid to cls
output = []
for i in range(self.n_heads):
box = self.dfl(reg_outputs[i])
cls = cls_outputs[i].sigmoid()
conf = cls.max(1, keepdim=True)[0]
output.append(
torch.cat([box, conf, cls], dim=1)
) # [bs, 4 + 1 + n_classes, h_f, w_f]
return output

def _export_bbox_output(self, x: list[Tensor]):
def _prepare_bbox_export(
self, reg_outputs: list[Tensor], cls_outputs: list[Tensor]
) -> Tensor:
"""Prepare the output for export."""
box, cls, _ = self._extract_cls_and_box(x)
box_dist = self.dfl(box) # Shape: [N, 4, N_anchors]
conf, _ = cls.max(1, keepdim=True) # Shape: [N, 1, N_anchors]
export_output = torch.cat(
[box_dist, conf, cls], dim=1
) # Shape: [N, 4 + 1 + num_classes, N_anchors]
return export_output

def _inference_bbox_output(self, x: list[Tensor]):
return self._prepare_bbox_and_cls(reg_outputs, cls_outputs)

def _prepare_bbox_inference_output(
self, reg_outputs: list[Tensor], cls_outputs: list[Tensor]
):
"""Perform inference on predicted bounding boxes and class
probabilities."""
box, cls, shape = self._extract_cls_and_box(x)
box_dist = self.dfl(box)
processed_outputs = self._prepare_bbox_and_cls(
reg_outputs, cls_outputs
)
box_dists = []
class_probs = []
for feature in processed_outputs:
bs, _, h, w = feature.size()
reshaped = feature.view(bs, -1, h * w)
box_dist = reshaped[:, :4, :]
cls = reshaped[:, 5:, :]
box_dists.append(box_dist)
class_probs.append(cls)

box_dists = torch.cat(box_dists, dim=2)
class_probs = torch.cat(class_probs, dim=2)

_, anchor_points, _, strides = anchors_for_fpn_features(
x, self.stride, 0.5
processed_outputs, self.stride, 0.5
)

pred_bboxes = dist2bbox(
box_dist, anchor_points.transpose(0, 1), out_format="xyxy", dim=1
box_dists, anchor_points.transpose(0, 1), out_format="xyxy", dim=1
) * strides.transpose(0, 1)

base_output = [
pred_bboxes.permute(0, 2, 1),
pred_bboxes.permute(0, 2, 1), # [BS, H*W, 4]
torch.ones(
(shape[0], pred_bboxes.shape[2], 1),
(box_dists.shape[0], pred_bboxes.shape[2], 1),
dtype=pred_bboxes.dtype,
device=pred_bboxes.device,
),
cls.permute(0, 2, 1),
class_probs.permute(0, 2, 1), # [BS, H*W, n_classes]
]

output_merged = torch.cat(base_output, dim=-1)
output_merged = torch.cat(
base_output, dim=-1
) # [BS, H*W, 4 + 1 + n_classes]
return output_merged

def bias_init(self):
Expand Down
47 changes: 26 additions & 21 deletions luxonis_train/nodes/heads/precision_seg_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,10 @@ def forward(
self, inputs: list[Tensor]
) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]:
prototypes = self.proto(inputs[0])
bs = prototypes.shape[0]
mask_coefficients = torch.cat(
[
self.mask_layers[i](inputs[i]).view(bs, self.n_masks, -1)
for i in range(self.n_heads)
],
dim=2,
)
mask_coefficients = [
self.mask_layers[i](inputs[i]) for i in range(self.n_heads)
]

det_outs = super().forward(inputs)

return det_outs, prototypes, mask_coefficients
Expand All @@ -96,25 +92,34 @@ def wrap(
self, output: tuple[list[Tensor], Tensor, Tensor]
) -> Packet[Tensor]:
det_feats, prototypes, mask_coefficients = output
if self.training:

if self.export:
pred_bboxes = self._prepare_bbox_export(*det_feats)
return {
"features": det_feats,
"boundingbox": pred_bboxes,
"masks": mask_coefficients,
"prototypes": prototypes,
"mask_coeficients": mask_coefficients,
}

if self.export:
pred_bboxes = self._export_bbox_output(det_feats)
det_feats_combined = [
torch.cat((reg, cls), dim=1) for reg, cls in zip(*det_feats)
]
mask_coefficients = torch.cat(
[
coef.view(coef.size(0), self.n_masks, -1)
for coef in mask_coefficients
],
dim=2,
)

if self.training:
return {
TaskType.INSTANCE_SEGMENTATION: [
torch.cat(
[pred_bboxes, mask_coefficients], 1
), # Shape: [N, 4 + 1 + num_classes + n_masks, N_anchors]
],
"prototypes": [prototypes], # Shape: [N, n_masks, H, W]
"features": det_feats_combined,
"prototypes": prototypes,
"mask_coeficients": mask_coefficients,
}

pred_bboxes = self._inference_bbox_output(det_feats)
pred_bboxes = self._prepare_bbox_inference_output(*det_feats)
preds_combined = torch.cat(
[pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1
)
Expand All @@ -129,7 +134,7 @@ def wrap(
)

results = {
"features": det_feats,
"features": det_feats_combined,
"prototypes": prototypes,
"mask_coeficients": mask_coefficients,
"boundingbox": [],
Expand Down

0 comments on commit a28f0c7

Please sign in to comment.