From 18f0503ef4a8a973d246bfbc06c77a544fb0c2e1 Mon Sep 17 00:00:00 2001
From: ZhangYiqin <312065559@qq.com>
Date: Mon, 31 Jul 2023 10:43:39 +0800
Subject: [PATCH 1/3] Add channel arguments to mae_head

When trying iTPN pretraining, only images with 3 channels are
supported. One of the restrictions comes from MAEHead.
---
 mmpretrain/models/heads/mae_head.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py
index 1a5366d13b5..b76ecedd96d 100644
--- a/mmpretrain/models/heads/mae_head.py
+++ b/mmpretrain/models/heads/mae_head.py
@@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
         norm_pix_loss (bool): Whether or not normalize target.
             Defaults to False.
         patch_size (int): Patch size. Defaults to 16.
+        in_channels (int): Number of input channels. Defaults to 3.
     """
 
     def __init__(self,
                  loss: dict,
                  norm_pix: bool = False,
-                 patch_size: int = 16) -> None:
+                 patch_size: int = 16,
+                 in_channels: int = 3) -> None:
         super().__init__()
         self.norm_pix = norm_pix
         self.patch_size = patch_size
+        self.in_channels = in_channels
         self.loss_module = MODELS.build(loss)
 
     def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@@ -30,19 +33,19 @@ def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
 
         Args:
             imgs (torch.Tensor): A batch of images. The shape should
-                be :math:`(B, 3, H, W)`.
+                be :math:`(B, C, H, W)`.
 
         Returns:
             torch.Tensor: Patchified images. The shape is
-                :math:`(B, L, \text{patch_size}^2 \times 3)`.
+                :math:`(B, L, \text{patch_size}^2 \times C)`.
         """
         p = self.patch_size
         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
 
         h = w = imgs.shape[2] // p
-        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+        x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
         x = torch.einsum('nchpwq->nhwpqc', x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
         return x
 
     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
@@ -50,18 +53,18 @@ def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
 
         Args:
             x (torch.Tensor): The shape is
-                :math:`(B, L, \text{patch_size}^2 \times 3)`.
+                :math:`(B, L, \text{patch_size}^2 \times C)`.
 
         Returns:
-            torch.Tensor: The shape is :math:`(B, 3, H, W)`.
+            torch.Tensor: The shape is :math:`(B, C, H, W)`.
         """
         p = self.patch_size
         h = w = int(x.shape[1]**.5)
         assert h * w == x.shape[1]
 
-        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
         x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+        imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
         return imgs
 
     def construct_target(self, target: torch.Tensor) -> torch.Tensor:
@@ -71,7 +74,7 @@ def construct_target(self, target: torch.Tensor) -> torch.Tensor:
             normalize the image according to ``norm_pix``.
 
         Args:
-            target (torch.Tensor): Image with the shape of B x 3 x H x W
+            target (torch.Tensor): Image with the shape of B x C x H x W
 
         Returns:
             torch.Tensor: Tokenized images with the shape of B x L x C
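
Note: a minimal usage sketch of the new argument (not part of the patch;
the loss config mirrors the MAE defaults in mmpretrain, and the shapes
assume 224x224 inputs):

    import torch
    from mmpretrain.models.heads import MAEPretrainHead

    # Build the head for single-channel images, e.g. grayscale inputs.
    head = MAEPretrainHead(
        loss=dict(type='PixelReconstructionLoss', criterion='L2'),
        norm_pix=True,
        patch_size=16,
        in_channels=1,  # the new argument; previously hard-coded to 3
    )

    imgs = torch.rand(2, 1, 224, 224)      # (B, C, H, W) with C == 1
    tokens = head.patchify(imgs)           # (B, L, patch_size**2 * C)
    assert tokens.shape == (2, 196, 256)   # 14*14 patches, 16*16*1 values
    assert head.unpatchify(tokens).shape == imgs.shape
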
From b49deecebe6751fe8d70d3f85c94a084311089ac Mon Sep 17 00:00:00 2001
From: ZhangYiqin <312065559@qq.com>
Date: Mon, 31 Jul 2023 10:47:34 +0800
Subject: [PATCH 2/3] Transfer other arguments from iTPNHiViT to HiViT

HiViT supports specifying the number of input channels, but the
iTPNHiViT class cannot pass channel arguments through to it. This is
one of the reasons the iTPNHiViT implementation only supports images
with 3 channels.
---
 mmpretrain/models/selfsup/itpn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py
index 85efd254053..271a93e9c19 100644
--- a/mmpretrain/models/selfsup/itpn.py
+++ b/mmpretrain/models/selfsup/itpn.py
@@ -64,6 +64,7 @@ def __init__(
         layer_scale_init_value: float = 0.0,
         mask_ratio: float = 0.75,
         reconstruction_type: str = 'pixel',
+        **kwargs,
     ):
         super().__init__(
             arch=arch,
@@ -80,7 +81,8 @@ def __init__(
             norm_cfg=norm_cfg,
             ape=ape,
             rpe=rpe,
-            layer_scale_init_value=layer_scale_init_value)
+            layer_scale_init_value=layer_scale_init_value,
+            **kwargs,)
         self.pos_embed.requires_grad = False
 
         self.mask_ratio = mask_ratio

From c8af8e07cf2d9826c9a8ccbaed969a5019186dbc Mon Sep 17 00:00:00 2001
From: ZhangYiqin <312065559@qq.com>
Date: Sun, 3 Sep 2023 18:48:52 +0800
Subject: [PATCH 3/3] Update itpn.py

Fix a style hint: move the closing parenthesis of the
``super().__init__()`` call onto its own line.
---
 mmpretrain/models/selfsup/itpn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py
index 271a93e9c19..488a9963182 100644
--- a/mmpretrain/models/selfsup/itpn.py
+++ b/mmpretrain/models/selfsup/itpn.py
@@ -82,7 +82,8 @@ def __init__(
             ape=ape,
             rpe=rpe,
             layer_scale_init_value=layer_scale_init_value,
-            **kwargs,)
+            **kwargs,
+        )
         self.pos_embed.requires_grad = False
 
         self.mask_ratio = mask_ratio
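
Note: with the patches above applied, a channel setting can reach HiViT
through the forwarded ``**kwargs``. A minimal sketch, assuming HiViT's
channel argument is named ``in_channels`` (check the HiViT signature in
your mmpretrain version before relying on this name):

    from mmpretrain.models.selfsup import iTPNHiViT

    # Any keyword HiViT accepts can now be forwarded through **kwargs;
    # in_channels is assumed here to be HiViT's channel argument.
    backbone = iTPNHiViT(
        arch='base',
        reconstruction_type='pixel',
        mask_ratio=0.75,
        in_channels=1,  # forwarded to HiViT via **kwargs (name assumed)
    )
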