From 0932f0cea7712313b5f679141a7d83a470a7f192 Mon Sep 17 00:00:00 2001 From: "haowen.han@mthreads.com" Date: Tue, 20 Feb 2024 16:26:49 +0800 Subject: [PATCH] support musa --- mmrotate/evaluation/metrics/dota_metric.py | 7 ++-- mmrotate/models/losses/convex_giou_loss.py | 17 ++++++---- mmrotate/models/losses/rotated_iou_loss.py | 33 +++++++++++++------ .../evaluation/metrics/dota_r360_metric.py | 6 +++- tools/deployment/mmrotate_handler.py | 15 ++++++--- 5 files changed, 54 insertions(+), 24 deletions(-) diff --git a/mmrotate/evaluation/metrics/dota_metric.py b/mmrotate/evaluation/metrics/dota_metric.py index bd4564a4d..9fd055eb1 100644 --- a/mmrotate/evaluation/metrics/dota_metric.py +++ b/mmrotate/evaluation/metrics/dota_metric.py @@ -14,7 +14,7 @@ from mmengine.evaluator import BaseMetric from mmengine.fileio import dump from mmengine.logging import MMLogger - +from mmengine.device import is_musa_available, is_cuda_available from mmrotate.evaluation import eval_rbbox_map from mmrotate.registry import METRICS from mmrotate.structures.bbox import rbox2qbox @@ -157,7 +157,10 @@ def merge_results(self, results: Sequence[dict], big_img_results.append(dets[labels == i]) else: try: - cls_dets = torch.from_numpy(dets[labels == i]).cuda() + if is_musa_available(): + cls_dets = torch.from_numpy(dets[labels == i]).cuda() + elif is_musa_available(): + cls_dets = torch.from_numpy(dets[labels == i]).musa() except: # noqa: E722 cls_dets = torch.from_numpy(dets[labels == i]) if self.predict_box_type == 'rbox': diff --git a/mmrotate/models/losses/convex_giou_loss.py b/mmrotate/models/losses/convex_giou_loss.py index b1e73a584..9e900d9d6 100644 --- a/mmrotate/models/losses/convex_giou_loss.py +++ b/mmrotate/models/losses/convex_giou_loss.py @@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable from mmrotate.registry import MODELS - +from mmengine.device import is_musa_available, is_cuda_available class ConvexGIoULossFuction(Function): """The function of Convex GIoU loss.""" @@ -227,11 +227,16 @@ def forward(ctx, target_aspect = AspectRatio(target) smooth_loss_weight = torch.exp((-1 / 4) * target_aspect) - loss = \ - smooth_loss_weight * (diff_mean_loss.reshape(-1, 1).cuda() + - diff_corners_loss.reshape(-1, 1).cuda()) + \ - 1 - (1 - 2 * smooth_loss_weight) * convex_gious - + if is_cuda_available(): + loss = \ + smooth_loss_weight * (diff_mean_loss.reshape(-1, 1).cuda() + + diff_corners_loss.reshape(-1, 1).cuda()) + \ + 1 - (1 - 2 * smooth_loss_weight) * convex_gious + if is_musa_available(): + loss = \ + smooth_loss_weight * (diff_mean_loss.reshape(-1, 1).musa() + + diff_corners_loss.reshape(-1, 1).musa()) + \ + 1 - (1 - 2 * smooth_loss_weight) * convex_gious if weight is not None: loss = loss * weight grad = grad * weight.reshape(-1, 1) diff --git a/mmrotate/models/losses/rotated_iou_loss.py b/mmrotate/models/losses/rotated_iou_loss.py index 7944ce0af..50d239d2e 100644 --- a/mmrotate/models/losses/rotated_iou_loss.py +++ b/mmrotate/models/losses/rotated_iou_loss.py @@ -6,6 +6,7 @@ from mmdet.models.losses.utils import weighted_loss from mmrotate.registry import MODELS +from mmengine.device import is_musa_available, is_cuda_available try: from mmcv.ops import diff_iou_rotated_2d @@ -127,14 +128,26 @@ def forward(self, # iou_loss of shape (n,) assert weight.shape == pred.shape weight = weight.mean(-1) - with torch.cuda.amp.autocast(enabled=False): - loss = self.loss_weight * rotated_iou_loss( - pred, - target, - weight, - mode=self.mode, - eps=self.eps, - reduction=reduction, - avg_factor=avg_factor, - **kwargs) + if is_cuda_available(): + with torch.cuda.amp.autocast(enabled=False): + loss = self.loss_weight * rotated_iou_loss( + pred, + target, + weight, + mode=self.mode, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + elif is_musa_available(): + with torch.musa.amp.autocast(enabled=False): + loss = self.loss_weight * rotated_iou_loss( + pred, + target, + weight, + mode=self.mode, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) return loss diff --git a/projects/RR360/evaluation/metrics/dota_r360_metric.py b/projects/RR360/evaluation/metrics/dota_r360_metric.py index 71c5a674a..fe4f82a4b 100644 --- a/projects/RR360/evaluation/metrics/dota_r360_metric.py +++ b/projects/RR360/evaluation/metrics/dota_r360_metric.py @@ -14,6 +14,7 @@ from mmengine.evaluator import BaseMetric from mmengine.fileio import dump from mmengine.logging import MMLogger +from mmengine.device import is_musa_available, is_cuda_available from projects.RR360.evaluation import eval_rbbox_head_map from mmrotate.registry import METRICS @@ -159,7 +160,10 @@ def merge_results(self, results: Sequence[dict], big_img_results.append(dets[labels == i]) else: try: - cls_dets = torch.from_numpy(dets[labels == i]).cuda() + if is_cuda_available(): + cls_dets = torch.from_numpy(dets[labels == i]).cuda() + elif is_musa_available(): + cls_dets = torch.from_numpy(dets[labels == i]).musa() except: # noqa: E722 cls_dets = torch.from_numpy(dets[labels == i]) if self.predict_box_type == 'rbox': diff --git a/tools/deployment/mmrotate_handler.py b/tools/deployment/mmrotate_handler.py index 5991f1766..0320ae1ae 100644 --- a/tools/deployment/mmrotate_handler.py +++ b/tools/deployment/mmrotate_handler.py @@ -8,7 +8,7 @@ from ts.torch_handler.base_handler import BaseHandler import mmrotate # noqa: F401 - +from mmengine.device import is_musa_available, is_cuda_available class MMRotateHandler(BaseHandler): """MMRotate handler to load torchscript or eager mode [state_dict] @@ -23,10 +23,15 @@ def initialize(self, context): pertaining to the model artifacts parameters. """ properties = context.system_properties - self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = torch.device(self.map_location + ':' + - str(properties.get('gpu_id')) if torch.cuda. - is_available() else self.map_location) + if is_cuda_available(): + self.map_location = 'cuda' + self.device = torch.device(self.map_location + ':' + str(properties.get('gpu_id'))) + elif is_musa_available(): + self.map_location = 'musa' + self.device = torch.device(self.map_location + ':' + str(properties.get('gpu_id'))) + else: + self.map_location = 'cpu' + self.device = torch.device(self.map_location) self.manifest = context.manifest model_dir = properties.get('model_dir')