Skip to content

Commit

Permalink
Fix FFT
Browse files Browse the repository at this point in the history
  • Loading branch information
Barabas5532 committed Mar 11, 2024
1 parent 2142507 commit 23b5838
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 46 deletions.
37 changes: 26 additions & 11 deletions effects/abstract_dsp/include/fast_convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
namespace shrapnel::dsp {

template <std::size_t N, std::size_t K>
// The FFT library doesn't seem to work when using fewer than 8 points
requires(N >= 8)
class FastConvolution final
{
public:
Expand Down Expand Up @@ -63,6 +65,12 @@ class FastConvolution final
std::copy(out.begin(), out.end(), a_copy.begin());
}

~FastConvolution()
{
esp32_fft_destroy(fft_config);
esp32_fft_destroy(ifft_config);
}

/**
* Process samples
*
Expand All @@ -83,8 +91,6 @@ class FastConvolution final
profiling_mark_stage("convolution b transform");

// multiply A * B
// TODO this probably needs a special case for handling DC and Nyquist
// special encoding in the first element
complex_multiply(a_copy.data(), fft_out.data(), ifft_in.data());
profiling_mark_stage("convolution complex_multiply");

Expand All @@ -107,23 +113,32 @@ class FastConvolution final
// (ra rb + ra im_b j) + // part 1
// (- im_a im_b + im_a r_b j) // part 2

// The FFT result has a special encoding for DC and Nyquist. DC is
// stored in the first element, and nyquist in the second.
out[0] = a[0] * b[0];
out[1] = a[1] * b[1];

a += 2;
b += 2;
out += 2;

// TODO we multiply then add here, could we use the MADD.S instruction
// to speed it up?

// part 1
dsps_mul_f32(a, b, out, N / 2, 2, 2, 2);
dsps_mul_f32(a, b + 1, out + 1, N / 2, 2, 2, 2);
dsps_mul_f32(a, b, out, N / 2 - 1, 2, 2, 2);
dsps_mul_f32(a, b + 1, out + 1, N / 2 - 1, 2, 2, 2);

// part 2
std::array<float, N> out2;
std::array<float, N - 2> out2;
auto out2_ptr = reinterpret_cast<float *>(out2.data());
dsps_mul_f32(a + 1, b + 1, out2_ptr, N / 2, 2, 2, 2);
dsps_mul_f32(a + 1, b, out2_ptr + 1, N / 2, 2, 2, 2);
dsps_mulc_f32(out2_ptr, out2_ptr, N / 2, -1, 2, 2);
dsps_mul_f32(a + 1, b + 1, out2_ptr, N / 2 - 1, 2, 2, 2);
dsps_mul_f32(a + 1, b, out2_ptr + 1, N / 2 - 1, 2, 2, 2);
dsps_mulc_f32(out2_ptr, out2_ptr, N / 2 - 1, -1, 2, 2);

dsps_add_f32(out, out2_ptr, out, N, 1, 1, 1);
dsps_add_f32(out, out2_ptr, out, N - 2, 1, 1, 1);
}

private:
std::array<float, N> a_copy;
esp32_fft_config_t *fft_config;
Expand Down
111 changes: 76 additions & 35 deletions effects/abstract_dsp/test/test_fast_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TEST_F(FastConvolution, ComplexMultiply)
using namespace std::complex_literals;

std::array<std::complex<float>, 8> in_a{
0.f + 0if,
2.f + 3if,
1.f + 0if,
0.f + 2if,
1.f + 2if,
Expand All @@ -66,7 +66,7 @@ TEST_F(FastConvolution, ComplexMultiply)
};

std::array<std::complex<float>, 8> in_b{
1.f + 1if,
7.f + 21if,
1.f + 1if,
1.f + 1if,
1.f + 1if,
Expand All @@ -84,7 +84,9 @@ TEST_F(FastConvolution, ComplexMultiply)

for(std::size_t i = 0; i < 8; i++)
{
auto ref = in_a[i] * in_b[i];
auto ref = i == 0 ? in_a[0].real() * in_b[0].real() +
1if * (in_a[0].imag() * in_b[0].imag())
: in_a[i] * in_b[i];

EXPECT_FLOAT_EQ(ref.real(), out_uut[i].real())
<< "i: " << i << " in_a: " << in_a[i] << " in_b: " << in_b[i];
Expand All @@ -95,82 +97,112 @@ TEST_F(FastConvolution, ComplexMultiply)

TEST_F(FastConvolution, ImpulseZeroDelayIsIdentity)
{
std::array<float, 4> input_a{};
std::array<float, 4> input_b{};
std::array<float, 8> input_a{};
std::array<float, 8> input_b{};

input_b[0] = 1;

input_a[0] = 1;
input_a[1] = 2;
input_a[2] = 3;
input_a[3] = 4;
input_a[4] = 5;
input_a[5] = 6;
input_a[6] = 7;
input_a[7] = 8;

std::array<float, 4> out{};
std::array<float, 8> out{};

shrapnel::dsp::FastConvolution<4, 4> uut(input_a);
shrapnel::dsp::FastConvolution<8, 8> uut(input_a);

uut.process(input_b, out);

EXPECT_THAT(out,
ElementsAre(FloatEq(1), FloatEq(2), FloatEq(3), FloatEq(4)));
ElementsAre(FloatEq(1),
FloatEq(2),
FloatEq(3),
FloatEq(4),
FloatEq(5),
FloatEq(6),
FloatEq(7),
FloatEq(8)));
}

TEST_F(FastConvolution, ImpulseNonZeroDelay)
{
std::array<float, 4> input_a{};
std::array<float, 4> input_b{};
std::array<float, 8> input_a{};
std::array<float, 8> input_b{};

input_b[1] = 1;

input_a[0] = 1;
input_a[1] = 2;
input_a[2] = 3;
input_a[3] = 4;
input_a[4] = 5;
input_a[5] = 6;
input_a[6] = 7;
input_a[7] = 8;

std::array<float, 4> out{};
std::array<float, 8> out{};

shrapnel::dsp::FastConvolution<4, 4> uut(input_a);
shrapnel::dsp::FastConvolution<8, 8> uut(input_a);

uut.process(input_b, out);

EXPECT_THAT(out,
ElementsAre(FloatEq(4), FloatEq(1), FloatEq(2), FloatEq(3)));
ElementsAre(FloatEq(8),
FloatEq(1),
FloatEq(2),
FloatEq(3),
FloatEq(4),
FloatEq(5),
FloatEq(6),
FloatEq(7)));
}

TEST_F(FastConvolution, IsCommutative)
{
std::array<float, 4> input_a{};
std::array<float, 4> input_b{};
std::array<float, 8> input_a{};
std::array<float, 8> input_b{};

input_a[0] = 1;
input_a[1] = 2;
input_a[2] = 3;
input_a[3] = 4;

input_b[0] = 5;
input_b[1] = 6;
input_b[2] = 7;
input_b[3] = 8;

std::array<float, 4> out_a{};
std::array<float, 4> out_b{};

shrapnel::dsp::FastConvolution<4, 4> uut_a(input_a);
shrapnel::dsp::FastConvolution<4, 4> uut_b(input_b);
input_a[4] = 5;
input_a[5] = 6;
input_a[6] = 7;
input_a[7] = 8;

input_b[0] = 9;
input_b[1] = 10;
input_b[2] = 11;
input_b[3] = 12;
input_b[4] = 13;
input_b[5] = 14;
input_b[6] = 15;
input_b[7] = 16;

std::array<float, 8> out_a{};
std::array<float, 8> out_b{};

shrapnel::dsp::FastConvolution<8, 8> uut_a(input_a);
shrapnel::dsp::FastConvolution<8, 8> uut_b(input_b);

uut_a.process(input_b, out_a);
uut_b.process(input_a, out_b);

for(std::size_t i = 0; i < 4; i++)
for(std::size_t i = 0; i < 8; i++)
{
EXPECT_FLOAT_EQ(out_a[i], out_b[i]);
}
}

TEST_F(FastConvolution, IsLinear)
{
std::array<float, 4> input_a{};
std::array<float, 4> input_b{};
std::array<float, 8> input_a{};
std::array<float, 8> input_b{};

input_b[0] = 1;
input_b[1] = 1;
Expand All @@ -179,15 +211,24 @@ TEST_F(FastConvolution, IsLinear)
input_a[1] = 2;
input_a[2] = 3;
input_a[3] = 4;
input_a[4] = 5;
input_a[5] = 6;
input_a[6] = 7;
input_a[7] = 8;

std::array<float, 4> out{};
std::array<float, 8> out{};

shrapnel::dsp::FastConvolution<4, 4> uut(input_a);
shrapnel::dsp::FastConvolution<8, 8> uut(input_a);

uut.process(input_b, out);

EXPECT_THAT(
out,
ElementsAre(
FloatEq(1 + 4), FloatEq(2 + 1), FloatEq(3 + 2), FloatEq(4 + 3)));
EXPECT_THAT(out,
ElementsAre(FloatEq(1 + 8),
FloatEq(2 + 1),
FloatEq(3 + 2),
FloatEq(4 + 3),
FloatEq(5 + 4),
FloatEq(6 + 5),
FloatEq(7 + 6),
FloatEq(8 + 7)));
}

0 comments on commit 23b5838

Please sign in to comment.