Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support OHD-SJTU dataset and object heading detection #704

Open
wants to merge 1 commit into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmrotate/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .dota import DOTAv2Dataset # noqa: F401, F403
from .dota import DOTADataset, DOTAv15Dataset
from .hrsc import HRSCDataset # noqa: F401, F403
from .ohd_sjtu import OHD_SJTUDataset_L, OHD_SJTUDataset_S
from .transforms import * # noqa: F401, F403

__all__ = [
'DOTADataset', 'DOTAv15Dataset', 'DOTAv2Dataset', 'HRSCDataset',
'DIORDataset'
'DIORDataset', 'OHD_SJTUDataset_S', 'OHD_SJTUDataset_L'
]
148 changes: 148 additions & 0 deletions mmrotate/datasets/ohd_sjtu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
from typing import List, Tuple

from mmengine.dataset import BaseDataset

from mmrotate.registry import DATASETS


@DATASETS.register_module()
class OHD_SJTUDataset_S(BaseDataset):
"""OHD-SJTU-S dataset for detection.

Note: 'ann_file' in OHD_SJTUDataset is different from the BaseDataset.
In BaseDataset, it is the path of an annotation file. In OHD_SJTUDataset,
it is the path of a folder containing txt files.

Args:
img_shape (tuple[int]):
diff_thr (int):
"""

METAINFO = {
'classes': ('ship', 'plane'),
# palette is a list of color tuples, which is used for visualization.
'palette': [(165, 42, 42), (189, 183, 107)]
}

def __init__(self,
img_shape: Tuple[int, int] = (1024, 1024),
diff_thr: int = 100,
**kwargs) -> None:
self.img_shape = img_shape
self.diff_thr = diff_thr
super().__init__(**kwargs)

def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as 'self.ann_file'.

Returns:
List[dict]: A list of annotation.
"""
assert self._metainfo.get('classes', None) is not None, \
"classes in 'OHD-SJTUDataset' can not be None"

cls_map = {c: i for i, c in enumerate(self.metainfo['classes'])}
data_list = []
if self.ann_file == '':
img_files = glob.glob(
osp.join(self.data_prefix['img_path'], '*.png'))
for img_path in img_files:
data_info = {'img_path': img_path}
img_name = osp.split(img_path)[1]
data_info['file_name'] = img_name
img_id = img_name[:-4]
data_info['img_id'] = img_id
data_info['height'] = self.img_shape[0]
data_info['width'] = self.img_shape[1]

instance = dict(
bbox=[], bbox_head=[], bbox_label=[], ignore_flag=0)
data_info['instances'] = [instance]
data_list.append(data_info)
return data_list
else:
txt_files = glob.glob(osp.join(self.ann_file, '*.txt'))
if len(txt_files) == 0:
raise ValueError('There is no txt file in '
f'{self.ann_file}')
for txt_file in txt_files:
img_id = osp.split(txt_file)[1][:-4]
data_info = {'img_id': img_id}
img_name = img_id + '.png'
data_info['file_name'] = img_name
data_info['img_path'] = osp.join(self.data_prefix['img_path'],
img_name)
data_info['height'] = self.img_shape[0]
data_info['width'] = self.img_shape[1]

instances = []
with open(txt_file) as f:
contents = f.readlines()
for content in contents:
bbox_info = content.split(' ')
instance = {'bbox': [float(i) for i in bbox_info[:10]]}
cls_name = bbox_info[-2]
instance['bbox_label'] = cls_map[cls_name]
difficulty = int(bbox_info[-1])
if difficulty > self.diff_thr:
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instances.append(instance)
data_info['instances'] = instances
data_list.append(data_info)
print(len(data_list))
return data_list

def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.

Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list

filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
if self.filter_cfg is not None else False

valid_data_infos = []
for i, data_info in enumerate(self.data_list):
if filter_empty_gt and len(data_info['instances']) == 0:
continue
valid_data_infos.append(data_info)

return valid_data_infos

def get_cat_ids(self, idx: int) -> List[int]:
"""Get OHD-SJTU category ids by index.

Args:
idx(int): Index of data.
Returns:
List[int]: All categories in the image of specified index.
"""

instances = self.get_data_info(idx)['instances']
return [instance['bbox_label'] for instance in instances]


@DATASETS.register_module()
class OHD_SJTUDataset_L(OHD_SJTUDataset_S):
"""OHD-SJTU-L dataset for detection.

Note: 'ann_file' in OHD_SJTUDataset is different from the BaseDataset.
In BaseDataset, it is the path of an annotation file. In OHD_SJTUDataset,
it is the path of a folder containing txt files.
"""

METAINFO = {
'classes': ('ship', 'plane', 'small-vehicle', 'large-vehicle',
'harbor', 'helicopter'),
# palette is a list of color tuples, which is used for visualization.
'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
(138, 43, 226), (255, 128, 0)]
}
Loading