diff --git a/model/attention/BAM.py b/model/attention/BAM.py
index 2f51f0b..9f1d921 100644
--- a/model/attention/BAM.py
+++ b/model/attention/BAM.py
@@ -39,7 +39,7 @@ def __init__(self,channel,reduction=16,num_layers=3,dia_val=2):
         self.sa.add_module('bn_reduce1',nn.BatchNorm2d(channel//reduction))
         self.sa.add_module('relu_reduce1',nn.ReLU())
         for i in range(num_layers):
-            self.sa.add_module('conv_%d'%i,nn.Conv2d(kernel_size=3,in_channels=channel//reduction,out_channels=channel//reduction,padding=1,dilation=dia_val))
+            self.sa.add_module('conv_%d'%i,nn.Conv2d(kernel_size=3,in_channels=channel//reduction,out_channels=channel//reduction,padding="same",dilation=dia_val))
             self.sa.add_module('bn_%d'%i,nn.BatchNorm2d(channel//reduction))
             self.sa.add_module('relu_%d'%i,nn.ReLU())
         self.sa.add_module('last_conv',nn.Conv2d(channel//reduction,1,kernel_size=1))
@@ -76,7 +76,6 @@ def init_weights(self):
                 init.constant_(m.bias, 0)
 
     def forward(self, x):
-        b, c, _, _ = x.size()
         sa_out=self.sa(x)
         ca_out=self.ca(x)
         weight=self.sigmoid(sa_out+ca_out)
@@ -85,8 +84,9 @@ def forward(self, x):
 
 
 
 if __name__ == '__main__':
-    input=torch.randn(50,512,7,7)
-    bam = BAMBlock(channel=512,reduction=16,dia_val=2)
+    input=torch.randn(3,32,512,512)
+    bam = BAMBlock(channel=32,reduction=16,dia_val=2)
+    bam.eval()
     output=bam(input)
     print(output.shape)
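
Note on the padding change, with a minimal sketch (assuming PyTorch >= 1.9, where nn.Conv2d accepts padding="same"): a 3x3 convolution with dilation=2 has an effective receptive field of 5x5, so padding=1 shrinks each spatial dimension by 2 per layer. The spatial-attention map then no longer matches the input, and the original demo only ran because 7x7 inputs happened to collapse to a broadcastable 1x1 map; padding="same" keeps the map aligned with the input for arbitrary sizes.

import torch
import torch.nn as nn

x = torch.randn(3, 32, 512, 512)

# Old setting: padding too small for the dilated kernel, output shrinks.
conv_old = nn.Conv2d(32, 32, kernel_size=3, padding=1, dilation=2)
# New setting: "same" padding preserves the spatial size (stride 1).
conv_new = nn.Conv2d(32, 32, kernel_size=3, padding="same", dilation=2)

print(conv_old(x).shape)  # torch.Size([3, 32, 510, 510]) -- misaligned with x
print(conv_new(x).shape)  # torch.Size([3, 32, 512, 512]) -- matches x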