diff --git a/step/step_arch/graphwavenet/model.py b/step/step_arch/graphwavenet/model.py index 4ca335c..f664d59 100644 --- a/step/step_arch/graphwavenet/model.py +++ b/step/step_arch/graphwavenet/model.py @@ -99,13 +99,13 @@ def __init__(self, num_nodes, support_len, dropout=0.3, gcn_bool=True, addaptadj # dilated convolutions self.filter_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1,kernel_size),dilation=new_dilation)) - self.gate_convs.append(nn.Conv1d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation)) + self.gate_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation)) # 1x1 convolution for residual connection - self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1))) + self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1))) # 1x1 convolution for skip connection - self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1))) + self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1))) self.bn.append(nn.BatchNorm2d(residual_channels)) new_dilation *= 2 receptive_field += additional_scope