From 566e2738da2d83f055718d8edb609ad8dc325204 Mon Sep 17 00:00:00 2001 From: S22 <864453277@qq.com> Date: Mon, 11 Dec 2023 03:29:12 +0800 Subject: [PATCH] Fix conv1d erro. See issues #26 and #51 . --- step/step_arch/graphwavenet/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/step/step_arch/graphwavenet/model.py b/step/step_arch/graphwavenet/model.py index 4ca335c..f664d59 100644 --- a/step/step_arch/graphwavenet/model.py +++ b/step/step_arch/graphwavenet/model.py @@ -99,13 +99,13 @@ def __init__(self, num_nodes, support_len, dropout=0.3, gcn_bool=True, addaptadj # dilated convolutions self.filter_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1,kernel_size),dilation=new_dilation)) - self.gate_convs.append(nn.Conv1d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation)) + self.gate_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation)) # 1x1 convolution for residual connection - self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1))) + self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1))) # 1x1 convolution for skip connection - self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1))) + self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1))) self.bn.append(nn.BatchNorm2d(residual_channels)) new_dilation *= 2 receptive_field += additional_scope