From 5f018b221f53aa095c679f47ca6c1943b88d6c48 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Sun, 18 Jul 2021 14:17:17 +0800 Subject: [PATCH] feat: make it jittable --- kazane/decimate.py | 8 ++++---- kazane/fftconv.py | 10 +++++++--- kazane/upsample.py | 10 +++++----- tests/test_decimate.py | 4 ++-- tests/test_upsample.py | 8 ++++---- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/kazane/decimate.py b/kazane/decimate.py index d764cce..1155938 100644 --- a/kazane/decimate.py +++ b/kazane/decimate.py @@ -7,8 +7,6 @@ from .upsample import _pad_to_block_2 from .fftconv import _custom_fft_conv1d -BLOCK_RATIO = 4 - class Decimate(nn.Module): r"""Downsampling by an integer amount. @@ -32,6 +30,8 @@ class Decimate(nn.Module): 250 """ + __constants__ = ['BLOCK_RATIO'] + BLOCK_RATIO: int = 4 def __init__(self, q: int = 2, @@ -54,7 +54,7 @@ def forward(self, x: torch.Tensor): shape = x.shape x = x.view(-1, 1, shape[-1]) - block_length = self.kernel.shape[-1] * self.stride * BLOCK_RATIO + block_length = self.kernel.shape[-1] * self.stride * self.BLOCK_RATIO out_size = shape[-1] // self.stride if shape[-1] < block_length: x = F.pad(x, [self.padding] * 2, mode='reflect') @@ -66,4 +66,4 @@ def forward(self, x: torch.Tensor): y = _custom_fft_conv1d(x, self.kernel, stride=self.stride) y = y.view(-1, num_blocks * y.shape[-1])[..., :out_size] - return y.view(*shape[:-1], -1) + return y.view(shape[:-1] + (-1,)) diff --git a/kazane/fftconv.py b/kazane/fftconv.py index 6a9501a..5ecdb67 100644 --- a/kazane/fftconv.py +++ b/kazane/fftconv.py @@ -22,7 +22,8 @@ def _custom_fft_conv1d(input: Tensor, weight: Tensor, X = rfft(input, n=s) W = rfft(weight, n=s) - W.imag.mul_(-1) + # W.imag.mul_(-1) + torch.view_as_real(W)[..., 1].mul_(-1) Y = X * W # handle stride @@ -32,8 +33,11 @@ def _custom_fft_conv1d(input: Tensor, weight: Tensor, step_size = new_n_fft // 2 strided_Y_size = step_size + 1 - unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size) - unfolded_Y_imag = Y.imag[..., + tmpY = torch.view_as_real(Y) + Y_real, Y_imag = tmpY[..., 0], tmpY[..., 1] + + unfolded_Y_real = Y_real.unfold(-1, strided_Y_size, step_size) + unfolded_Y_imag = Y_imag[..., 1:].unfold(-1, strided_Y_size - 2, step_size) Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(-2), unfolded_Y_imag[..., ::2, :].sum(-2) diff --git a/kazane/upsample.py b/kazane/upsample.py index 42515de..0e76034 100644 --- a/kazane/upsample.py +++ b/kazane/upsample.py @@ -5,8 +5,6 @@ from .sinc import sinc_kernel from .fftconv import _custom_fft_conv1d -BLOCK_RATIO = 5 - def _pad_to_block(x, block_size): offset = x.shape[-1] % block_size @@ -15,7 +13,7 @@ def _pad_to_block(x, block_size): return x.view(*x.shape[:-1], -1, block_size) -def _pad_to_block_2(x, block_size, padding): +def _pad_to_block_2(x: torch.Tensor, block_size: int, padding: int): offset = x.shape[-1] % block_size if offset: offset = block_size - offset @@ -45,6 +43,8 @@ class Upsample(nn.Module): 1500 """ + __constants__ = ['BLOCK_RATIO'] + BLOCK_RATIO: int = 5 def __init__(self, q: int = 2, @@ -67,7 +67,7 @@ def forward(self, x: torch.Tensor): shape = x.shape x = x.view(-1, 1, shape[-1]) - block_length = self.kernel.shape[-1] * BLOCK_RATIO + block_length = self.kernel.shape[-1] * self.BLOCK_RATIO if shape[-1] < block_length: x = F.pad(x, [self.padding] * 2, mode='reflect') y = _custom_fft_conv1d(x, self.kernel) @@ -81,4 +81,4 @@ def forward(self, x: torch.Tensor): 2).reshape(-1, q, num_blocks * block_length)[..., :shape[-1]] y = y.transpose(1, 2).contiguous() - return y.view(*shape[:-1], -1) + return y.view(shape[:-1] + (-1,)) diff --git a/tests/test_decimate.py b/tests/test_decimate.py index cda5780..e84781d 100644 --- a/tests/test_decimate.py +++ b/tests/test_decimate.py @@ -29,7 +29,7 @@ def test_quality_sine(q, zeros, rms): DURATION = 2 sr = 44100 sr_new = sr // q - dec = Decimate(q, zeros, roll_off=0.945).to(dtype) + dec = torch.jit.script(Decimate(q, zeros, roll_off=0.945)).to(dtype) x = make_tone(FREQ, sr, DURATION) y = make_tone(FREQ, sr_new, DURATION) @@ -52,7 +52,7 @@ def test_quality_sweep(q, zeros, rms): DURATION = 5 sr = 44100 sr_new = sr // q - dec = Decimate(q, zeros, roll_off=0.945).to(dtype) + dec = torch.jit.script(Decimate(q, zeros, roll_off=0.945)).to(dtype) x = make_sweep(FREQ, sr, DURATION) y = make_sweep(FREQ, sr_new, DURATION) diff --git a/tests/test_upsample.py b/tests/test_upsample.py index b4296a8..ba062e7 100644 --- a/tests/test_upsample.py +++ b/tests/test_upsample.py @@ -19,8 +19,8 @@ def test_quality_sine(q, zeros, rms): DURATION = 2 sr = 22050 sr_new = sr * q - up = Upsample(q, zeros, - roll_off=0.945).to(dtype) + up = torch.jit.script(Upsample(q, zeros, + roll_off=0.945)).to(dtype) x = make_tone(FREQ, sr, DURATION) y = make_tone(FREQ, sr_new, DURATION) @@ -43,8 +43,8 @@ def test_quality_sweep(q, zeros, rms): DURATION = 5 sr = 22050 sr_new = sr * q - up = Upsample(q, zeros, - roll_off=0.945).to(dtype) + up = torch.jit.script(Upsample(q, zeros, + roll_off=0.945)).to(dtype) x = make_sweep(FREQ, sr, DURATION) y = make_sweep(FREQ, sr_new, DURATION)