-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecoder.py
35 lines (27 loc) · 1.37 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn as nn
from options import HiDDenConfiguration
from conv_bn_relu import ConvBNRelu
class Decoder(nn.Module):
    """
    Decoder module. Receives a watermarked image and extracts the watermark.

    The input image may have various kinds of noise applied to it,
    such as Crop, JpegCompression, and so on. See Noise layers for more.

    The network is a stack of ConvBNRelu blocks, a global average pooling
    layer and a final linear layer. For a batch of images it returns a
    tensor of shape (batch, config.message_length).
    """

    def __init__(self, config: "HiDDenConfiguration"):
        """
        Build the decoder network.

        The configuration fields read here are ``decoder_channels``,
        ``decoder_blocks`` and ``message_length``.
        """
        super().__init__()
        self.channels = config.decoder_channels

        # First block takes 3 input channels (presumably an RGB image --
        # confirm against ConvBNRelu's signature).
        layers = [ConvBNRelu(3, self.channels)]

        # decoder_blocks - 2 channel-preserving blocks; range() is empty when
        # decoder_blocks <= 2, in which case no such block is added.
        layers.extend(
            ConvBNRelu(self.channels, self.channels)
            for _ in range(config.decoder_blocks - 2)
        )

        # Two blocks built with stride=2; the second one changes the channel
        # count to message_length so that there is one channel per message
        # element.
        layers.append(ConvBNRelu(self.channels, self.channels, stride=2))
        layers.append(ConvBNRelu(self.channels, config.message_length, stride=2))

        # Average every channel over its spatial extent down to a single value.
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))

        self.layers = nn.Sequential(*layers)
        self.linear = nn.Linear(config.message_length, config.message_length)

    def forward(self, image_with_wm):
        """
        Extract the message from a batch of (possibly noised) watermarked images.

        Returns the output of the final linear layer, of shape b x c where
        c is the message length the decoder was built with.
        """
        pooled = self.layers(image_with_wm)
        # The pooled output is of shape b x c x 1 x 1 and we want b x c.
        # Squeeze the two trailing dummy dimensions explicitly: a bare
        # squeeze() would also drop the batch dimension when b=1.
        # The out-of-place squeeze() is used on purpose: the in-place squeeze_()
        # would reshape the very tensor object returned by self.layers, which
        # is visible to anything else holding it (e.g. forward hooks).
        message = pooled.squeeze(3).squeeze(2)
        return self.linear(message)