diff --git a/effects/abstract_dsp/include/fast_convolution.h b/effects/abstract_dsp/include/fast_convolution.h index c5a29a45..39156779 100644 --- a/effects/abstract_dsp/include/fast_convolution.h +++ b/effects/abstract_dsp/include/fast_convolution.h @@ -33,6 +33,8 @@ namespace shrapnel::dsp { template + // The FFT library doesn't seem to work when using fewer than 8 points + requires(N >= 8) class FastConvolution final { public: @@ -63,6 +65,12 @@ class FastConvolution final std::copy(out.begin(), out.end(), a_copy.begin()); } + ~FastConvolution() + { + esp32_fft_destroy(fft_config); + esp32_fft_destroy(ifft_config); + } + /** * Process samples * @@ -83,8 +91,6 @@ class FastConvolution final profiling_mark_stage("convolution b transform"); // multiply A * B - // TODO this probably needs a special case for handling DC and Nyquist - // special encoding in the first element complex_multiply(a_copy.data(), fft_out.data(), ifft_in.data()); profiling_mark_stage("convolution complex_multiply"); @@ -107,23 +113,32 @@ class FastConvolution final // (ra rb + ra im_b j) + // part 1 // (- im_a im_b + im_a r_b j) // part 2 + // The FFT result has a special encoding for DC and Nyquist. DC is + // stored in the first element, and nyquist in the second. + out[0] = a[0] * b[0]; + out[1] = a[1] * b[1]; + + a += 2; + b += 2; + out += 2; + // TODO we multiply then add here, could we use the MADD.S instruction // to speed it up? - + // part 1 - dsps_mul_f32(a, b, out, N / 2, 2, 2, 2); - dsps_mul_f32(a, b + 1, out + 1, N / 2, 2, 2, 2); + dsps_mul_f32(a, b, out, N / 2 - 1, 2, 2, 2); + dsps_mul_f32(a, b + 1, out + 1, N / 2 - 1, 2, 2, 2); // part 2 - std::array out2; + std::array out2; auto out2_ptr = reinterpret_cast(out2.data()); - dsps_mul_f32(a + 1, b + 1, out2_ptr, N / 2, 2, 2, 2); - dsps_mul_f32(a + 1, b, out2_ptr + 1, N / 2, 2, 2, 2); - dsps_mulc_f32(out2_ptr, out2_ptr, N / 2, -1, 2, 2); + dsps_mul_f32(a + 1, b + 1, out2_ptr, N / 2 - 1, 2, 2, 2); + dsps_mul_f32(a + 1, b, out2_ptr + 1, N / 2 - 1, 2, 2, 2); + dsps_mulc_f32(out2_ptr, out2_ptr, N / 2 - 1, -1, 2, 2); - dsps_add_f32(out, out2_ptr, out, N, 1, 1, 1); + dsps_add_f32(out, out2_ptr, out, N - 2, 1, 1, 1); } - + private: std::array a_copy; esp32_fft_config_t *fft_config; diff --git a/effects/abstract_dsp/test/test_fast_convolution.cpp b/effects/abstract_dsp/test/test_fast_convolution.cpp index 9a4236c4..a77a23bb 100644 --- a/effects/abstract_dsp/test/test_fast_convolution.cpp +++ b/effects/abstract_dsp/test/test_fast_convolution.cpp @@ -56,7 +56,7 @@ TEST_F(FastConvolution, ComplexMultiply) using namespace std::complex_literals; std::array, 8> in_a{ - 0.f + 0if, + 2.f + 3if, 1.f + 0if, 0.f + 2if, 1.f + 2if, @@ -66,7 +66,7 @@ TEST_F(FastConvolution, ComplexMultiply) }; std::array, 8> in_b{ - 1.f + 1if, + 7.f + 21if, 1.f + 1if, 1.f + 1if, 1.f + 1if, @@ -84,7 +84,9 @@ TEST_F(FastConvolution, ComplexMultiply) for(std::size_t i = 0; i < 8; i++) { - auto ref = in_a[i] * in_b[i]; + auto ref = i == 0 ? in_a[0].real() * in_b[0].real() + + 1if * (in_a[0].imag() * in_b[0].imag()) + : in_a[i] * in_b[i]; EXPECT_FLOAT_EQ(ref.real(), out_uut[i].real()) << "i: " << i << " in_a: " << in_a[i] << " in_b: " << in_b[i]; @@ -95,8 +97,8 @@ TEST_F(FastConvolution, ComplexMultiply) TEST_F(FastConvolution, ImpulseZeroDelayIsIdentity) { - std::array input_a{}; - std::array input_b{}; + std::array input_a{}; + std::array input_b{}; input_b[0] = 1; @@ -104,21 +106,32 @@ TEST_F(FastConvolution, ImpulseZeroDelayIsIdentity) input_a[1] = 2; input_a[2] = 3; input_a[3] = 4; + input_a[4] = 5; + input_a[5] = 6; + input_a[6] = 7; + input_a[7] = 8; - std::array out{}; + std::array out{}; - shrapnel::dsp::FastConvolution<4, 4> uut(input_a); + shrapnel::dsp::FastConvolution<8, 8> uut(input_a); uut.process(input_b, out); EXPECT_THAT(out, - ElementsAre(FloatEq(1), FloatEq(2), FloatEq(3), FloatEq(4))); + ElementsAre(FloatEq(1), + FloatEq(2), + FloatEq(3), + FloatEq(4), + FloatEq(5), + FloatEq(6), + FloatEq(7), + FloatEq(8))); } TEST_F(FastConvolution, ImpulseNonZeroDelay) { - std::array input_a{}; - std::array input_b{}; + std::array input_a{}; + std::array input_b{}; input_b[1] = 1; @@ -126,42 +139,61 @@ TEST_F(FastConvolution, ImpulseNonZeroDelay) input_a[1] = 2; input_a[2] = 3; input_a[3] = 4; + input_a[4] = 5; + input_a[5] = 6; + input_a[6] = 7; + input_a[7] = 8; - std::array out{}; + std::array out{}; - shrapnel::dsp::FastConvolution<4, 4> uut(input_a); + shrapnel::dsp::FastConvolution<8, 8> uut(input_a); uut.process(input_b, out); EXPECT_THAT(out, - ElementsAre(FloatEq(4), FloatEq(1), FloatEq(2), FloatEq(3))); + ElementsAre(FloatEq(8), + FloatEq(1), + FloatEq(2), + FloatEq(3), + FloatEq(4), + FloatEq(5), + FloatEq(6), + FloatEq(7))); } TEST_F(FastConvolution, IsCommutative) { - std::array input_a{}; - std::array input_b{}; + std::array input_a{}; + std::array input_b{}; input_a[0] = 1; input_a[1] = 2; input_a[2] = 3; input_a[3] = 4; - - input_b[0] = 5; - input_b[1] = 6; - input_b[2] = 7; - input_b[3] = 8; - - std::array out_a{}; - std::array out_b{}; - - shrapnel::dsp::FastConvolution<4, 4> uut_a(input_a); - shrapnel::dsp::FastConvolution<4, 4> uut_b(input_b); + input_a[4] = 5; + input_a[5] = 6; + input_a[6] = 7; + input_a[7] = 8; + + input_b[0] = 9; + input_b[1] = 10; + input_b[2] = 11; + input_b[3] = 12; + input_b[4] = 13; + input_b[5] = 14; + input_b[6] = 15; + input_b[7] = 16; + + std::array out_a{}; + std::array out_b{}; + + shrapnel::dsp::FastConvolution<8, 8> uut_a(input_a); + shrapnel::dsp::FastConvolution<8, 8> uut_b(input_b); uut_a.process(input_b, out_a); uut_b.process(input_a, out_b); - for(std::size_t i = 0; i < 4; i++) + for(std::size_t i = 0; i < 8; i++) { EXPECT_FLOAT_EQ(out_a[i], out_b[i]); } @@ -169,8 +201,8 @@ TEST_F(FastConvolution, IsCommutative) TEST_F(FastConvolution, IsLinear) { - std::array input_a{}; - std::array input_b{}; + std::array input_a{}; + std::array input_b{}; input_b[0] = 1; input_b[1] = 1; @@ -179,15 +211,24 @@ TEST_F(FastConvolution, IsLinear) input_a[1] = 2; input_a[2] = 3; input_a[3] = 4; + input_a[4] = 5; + input_a[5] = 6; + input_a[6] = 7; + input_a[7] = 8; - std::array out{}; + std::array out{}; - shrapnel::dsp::FastConvolution<4, 4> uut(input_a); + shrapnel::dsp::FastConvolution<8, 8> uut(input_a); uut.process(input_b, out); - EXPECT_THAT( - out, - ElementsAre( - FloatEq(1 + 4), FloatEq(2 + 1), FloatEq(3 + 2), FloatEq(4 + 3))); + EXPECT_THAT(out, + ElementsAre(FloatEq(1 + 8), + FloatEq(2 + 1), + FloatEq(3 + 2), + FloatEq(4 + 3), + FloatEq(5 + 4), + FloatEq(6 + 5), + FloatEq(7 + 6), + FloatEq(8 + 7))); }