Skip to content

Commit

Permalink
feat: make it jittable
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Jul 18, 2021
1 parent dd20372 commit 5f018b2
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
8 changes: 4 additions & 4 deletions kazane/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from .upsample import _pad_to_block_2
from .fftconv import _custom_fft_conv1d

BLOCK_RATIO = 4


class Decimate(nn.Module):
r"""Downsampling by an integer amount.
Expand All @@ -32,6 +30,8 @@ class Decimate(nn.Module):
250
"""
__constants__ = ['BLOCK_RATIO']
BLOCK_RATIO: int = 4

def __init__(self,
q: int = 2,
Expand All @@ -54,7 +54,7 @@ def forward(self, x: torch.Tensor):
shape = x.shape
x = x.view(-1, 1, shape[-1])

block_length = self.kernel.shape[-1] * self.stride * BLOCK_RATIO
block_length = self.kernel.shape[-1] * self.stride * self.BLOCK_RATIO
out_size = shape[-1] // self.stride
if shape[-1] < block_length:
x = F.pad(x, [self.padding] * 2, mode='reflect')
Expand All @@ -66,4 +66,4 @@ def forward(self, x: torch.Tensor):
y = _custom_fft_conv1d(x, self.kernel, stride=self.stride)
y = y.view(-1, num_blocks * y.shape[-1])[..., :out_size]

return y.view(*shape[:-1], -1)
return y.view(shape[:-1] + (-1,))
10 changes: 7 additions & 3 deletions kazane/fftconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def _custom_fft_conv1d(input: Tensor, weight: Tensor,

X = rfft(input, n=s)
W = rfft(weight, n=s)
W.imag.mul_(-1)
# W.imag.mul_(-1)
torch.view_as_real(W)[..., 1].mul_(-1)
Y = X * W

# handle stride
Expand All @@ -32,8 +33,11 @@ def _custom_fft_conv1d(input: Tensor, weight: Tensor,
step_size = new_n_fft // 2
strided_Y_size = step_size + 1

unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
unfolded_Y_imag = Y.imag[...,
tmpY = torch.view_as_real(Y)
Y_real, Y_imag = tmpY[..., 0], tmpY[..., 1]

unfolded_Y_real = Y_real.unfold(-1, strided_Y_size, step_size)
unfolded_Y_imag = Y_imag[...,
1:].unfold(-1, strided_Y_size - 2, step_size)
Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2,
:].sum(-2), unfolded_Y_imag[..., ::2, :].sum(-2)
Expand Down
10 changes: 5 additions & 5 deletions kazane/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from .sinc import sinc_kernel
from .fftconv import _custom_fft_conv1d

BLOCK_RATIO = 5


def _pad_to_block(x, block_size):
offset = x.shape[-1] % block_size
Expand All @@ -15,7 +13,7 @@ def _pad_to_block(x, block_size):
return x.view(*x.shape[:-1], -1, block_size)


def _pad_to_block_2(x, block_size, padding):
def _pad_to_block_2(x: torch.Tensor, block_size: int, padding: int):
offset = x.shape[-1] % block_size
if offset:
offset = block_size - offset
Expand Down Expand Up @@ -45,6 +43,8 @@ class Upsample(nn.Module):
1500
"""
__constants__ = ['BLOCK_RATIO']
BLOCK_RATIO: int = 5

def __init__(self,
q: int = 2,
Expand All @@ -67,7 +67,7 @@ def forward(self, x: torch.Tensor):
shape = x.shape
x = x.view(-1, 1, shape[-1])

block_length = self.kernel.shape[-1] * BLOCK_RATIO
block_length = self.kernel.shape[-1] * self.BLOCK_RATIO
if shape[-1] < block_length:
x = F.pad(x, [self.padding] * 2, mode='reflect')
y = _custom_fft_conv1d(x, self.kernel)
Expand All @@ -81,4 +81,4 @@ def forward(self, x: torch.Tensor):
2).reshape(-1, q, num_blocks * block_length)[..., :shape[-1]]

y = y.transpose(1, 2).contiguous()
return y.view(*shape[:-1], -1)
return y.view(shape[:-1] + (-1,))
4 changes: 2 additions & 2 deletions tests/test_decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_quality_sine(q, zeros, rms):
DURATION = 2
sr = 44100
sr_new = sr // q
dec = Decimate(q, zeros, roll_off=0.945).to(dtype)
dec = torch.jit.script(Decimate(q, zeros, roll_off=0.945)).to(dtype)

x = make_tone(FREQ, sr, DURATION)
y = make_tone(FREQ, sr_new, DURATION)
Expand All @@ -52,7 +52,7 @@ def test_quality_sweep(q, zeros, rms):
DURATION = 5
sr = 44100
sr_new = sr // q
dec = Decimate(q, zeros, roll_off=0.945).to(dtype)
dec = torch.jit.script(Decimate(q, zeros, roll_off=0.945)).to(dtype)

x = make_sweep(FREQ, sr, DURATION)
y = make_sweep(FREQ, sr_new, DURATION)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_quality_sine(q, zeros, rms):
DURATION = 2
sr = 22050
sr_new = sr * q
up = Upsample(q, zeros,
roll_off=0.945).to(dtype)
up = torch.jit.script(Upsample(q, zeros,
roll_off=0.945)).to(dtype)

x = make_tone(FREQ, sr, DURATION)
y = make_tone(FREQ, sr_new, DURATION)
Expand All @@ -43,8 +43,8 @@ def test_quality_sweep(q, zeros, rms):
DURATION = 5
sr = 22050
sr_new = sr * q
up = Upsample(q, zeros,
roll_off=0.945).to(dtype)
up = torch.jit.script(Upsample(q, zeros,
roll_off=0.945)).to(dtype)

x = make_sweep(FREQ, sr, DURATION)
y = make_sweep(FREQ, sr_new, DURATION)
Expand Down

0 comments on commit 5f018b2

Please sign in to comment.