__author__ = "Yizhuo Wu, Chang Gao"
__license__ = "Apache-2.0 License"
__email__ = "[email protected], [email protected]"

import torch
import torch.nn as nn

from backbones.rvtdcnn import RVTDCNN


class CoreModel(nn.Module):
    """A behavioral model that wraps a selectable backbone (GMP, GRU, DGRU, QGRU,
    LSTM, VDLSTM or RVTDCNN) behind a common (batch, time, feature) interface."""

    def __init__(self, input_size, hidden_size, num_layers, backbone_type):
        super(CoreModel, self).__init__()
self.output_size = 2 # PA outputs: I & Q
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.backbone_type = backbone_type
self.batch_first = True # Force batch first
self.bidirectional = False
self.bias = True
if backbone_type == 'gmp':
from backbones.gmp import GMP
self.backbone = GMP()
elif backbone_type == 'gru':
from backbones.gru import GRU
self.backbone = GRU(input_size=self.input_size,
hidden_size=self.hidden_size,
output_size=self.output_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
batch_first=self.batch_first,
bias=self.bias)
elif backbone_type == 'dgru':
from backbones.dgru import DGRU
self.backbone = DGRU(hidden_size=self.hidden_size,
output_size=self.output_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
batch_first=self.batch_first,
bias=self.bias)
elif backbone_type == 'qgru':
from backbones.qgru import QGRU
self.backbone = QGRU(hidden_size=self.hidden_size,
output_size=self.output_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
batch_first=self.batch_first,
bias=self.bias)
elif backbone_type == 'qgru_amp1':
from backbones.qgru_amp1 import QGRU
self.backbone = QGRU(hidden_size=self.hidden_size,
output_size=self.output_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
batch_first=self.batch_first,
bias=self.bias)
elif backbone_type == 'lstm':
from backbones.lstm import LSTM
self.backbone = LSTM(input_size=self.input_size,
hidden_size=self.hidden_size,
output_size=self.output_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
batch_first=self.batch_first,
bias=self.bias)
elif backbone_type == 'vdlstm':
from backbones.vdlstm import VDLSTM
self.backbone = VDLSTM(input_size=self.input_size,
hidden_size=self.hidden_size,
output_size=self.output_size,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
batch_first=self.batch_first,
bias=self.bias)
elif backbone_type == 'rvtdcnn':
self.backbone = RVTDCNN(fc_hid_size=hidden_size)
else:
raise ValueError(f"The backbone type '{self.backbone_type}' is not supported. Please add your own "
f"backbone under ./backbones and update models.py accordingly.")
        # Initialize backbone parameters if the backbone implements reset_parameters();
        # backbones without it are left at their default initialization
        try:
            self.backbone.reset_parameters()
            print("Backbone Initialized...")
        except AttributeError:
            pass

    def forward(self, x, h_0=None):
        device = x.device
        batch_size = x.size(0)  # NOTE: dim of x must be (batch, time, feat) / (N, T, F)
        if h_0 is None:  # Create the initial hidden state if none is given
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        # Forward propagate through the backbone
        out = self.backbone(x, h_0)
return out
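

def _demo_core_model():
    # A minimal usage sketch added by the editor, not part of the original file.
    # It assumes the 'gru' backbone maps (N, T, 2) inputs to (N, T, 2) outputs;
    # hidden_size, num_layers and the tensor shape below are illustrative only.
    model = CoreModel(input_size=2, hidden_size=16, num_layers=1,
                      backbone_type='gru')
    x = torch.randn(8, 100, 2)  # (batch, time, feat): I & Q samples over time
    y = model(x)                # h_0 is created inside forward() when omitted
    print("CoreModel output shape:", y.shape)  # expected: (8, 100, 2)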


class CascadedModel(nn.Module):
    """A DPD model cascaded with a (typically frozen) PA model: x -> DPD -> PA."""

    def __init__(self, dpd_model, pa_model):
        super(CascadedModel, self).__init__()
        self.dpd_model = dpd_model
        self.pa_model = pa_model

    def freeze_pa_model(self):
        # Freeze the PA model so that only the DPD model receives gradient updates
        for param in self.pa_model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.dpd_model(x)  # Pre-distort the input signal
        x = self.pa_model(x)   # Pass the pre-distorted signal through the PA model
        return x
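

if __name__ == "__main__":
    # A minimal smoke-test sketch added by the editor, not part of the original
    # file. It assumes both backbones map (N, T, 2) -> (N, T, 2); all sizes are
    # illustrative and not taken from the OpenDPD configs.
    dpd = CoreModel(input_size=2, hidden_size=16, num_layers=1,
                    backbone_type='gru')
    pa = CoreModel(input_size=2, hidden_size=16, num_layers=1,
                   backbone_type='gru')
    cascaded = CascadedModel(dpd_model=dpd, pa_model=pa)
    # PA parameters get requires_grad=False; gradients still flow through the
    # PA model back to the DPD model, so only the DPD model would learn
    cascaded.freeze_pa_model()
    y = cascaded(torch.randn(8, 100, 2))
    print("CascadedModel output shape:", y.shape)  # expected: (8, 100, 2)
    _demo_core_model()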