diff --git a/app/src/main/rs/radix2.rsh b/app/src/main/rs/radix2.rsh index ca1de1a..c85567e 100644 --- a/app/src/main/rs/radix2.rsh +++ b/app/src/main/rs/radix2.rsh @@ -18,7 +18,6 @@ limitations under the License. #define RADIX2_RSH #include "complex.rsh" -#include "radix2_generated.rsh" static inline void dft2(complex_t *out0, complex_t *out1, complex_t in0, complex_t in1) { @@ -38,17 +37,6 @@ static inline void fwd4(complex_t *out0, complex_t *out1, complex_t *out2, compl *out3 = complex(b, in1 - in3); } -static void radix2(complex_t *out, float *in, int N, int S) -{ - // we only need 4 <= N forward FFTs - if (N == 4) { - fwd4(out, out + 1, out + 2, out + 3, in[0], in[S], in[2 * S], in[3 * S]); - return; - } - radix2(out, in, N / 2, 2 * S); - radix2(out + N / 2, in + S, N / 2, 2 * S); - for (int k0 = 0, k1 = N / 2, l1 = 0; k0 < N / 2; ++k0, ++k1, l1 += S) - dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); -} +#include "radix2_generated.rsh" #endif \ No newline at end of file diff --git a/app/src/main/rs/radix2_generated.rsh b/app/src/main/rs/radix2_generated.rsh index cd7c41a..19e33a5 100644 --- a/app/src/main/rs/radix2_generated.rsh +++ b/app/src/main/rs/radix2_generated.rsh @@ -258,3 +258,56 @@ static const complex_t radix2_z[256] = { { -0x1.ffd886084cd0dp-1, -0x1.92155f7a36689p-6 }, { -0x1.fff62169b92dbp-1, -0x1.921d1fcdec7b3p-7 } }; +static inline void dit4(complex_t *out, float *in) +{ + fwd4(out, out + 1, out + 2, out + 3, in[0], in[128], in[256], in[384]); +} +static void dit8(complex_t *out, float *in) +{ + dit4(out, in); + dit4(out + 4, in + 64); + for (int k0 = 0, k1 = 4, l1 = 0; k0 < 4; ++k0, ++k1, l1 += 64) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} +static void dit16(complex_t *out, float *in) +{ + dit8(out, in); + dit8(out + 8, in + 32); + for (int k0 = 0, k1 = 8, l1 = 0; k0 < 8; ++k0, ++k1, l1 += 32) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} +static void dit32(complex_t *out, float *in) +{ + dit16(out, in); + dit16(out + 16, in + 16); + for (int k0 = 0, k1 = 16, l1 = 0; k0 < 16; ++k0, ++k1, l1 += 16) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} +static void dit64(complex_t *out, float *in) +{ + dit32(out, in); + dit32(out + 32, in + 8); + for (int k0 = 0, k1 = 32, l1 = 0; k0 < 32; ++k0, ++k1, l1 += 8) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} +static void dit128(complex_t *out, float *in) +{ + dit64(out, in); + dit64(out + 64, in + 4); + for (int k0 = 0, k1 = 64, l1 = 0; k0 < 64; ++k0, ++k1, l1 += 4) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} +static void dit256(complex_t *out, float *in) +{ + dit128(out, in); + dit128(out + 128, in + 2); + for (int k0 = 0, k1 = 128, l1 = 0; k0 < 128; ++k0, ++k1, l1 += 2) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} +static void forward(complex_t *out, float *in) +{ + dit256(out, in); + dit256(out + 256, in + 1); + for (int k0 = 0, k1 = 256, l1 = 0; k0 < 256; ++k0, ++k1, l1 += 1) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); +} diff --git a/app/src/main/rs/stft.rsh b/app/src/main/rs/stft.rsh index 2d1b01f..520b273 100644 --- a/app/src/main/rs/stft.rsh +++ b/app/src/main/rs/stft.rsh @@ -136,7 +136,7 @@ static void spectrum_analyzer(int amplitude) for (int i = 0; i < stft_N; ++i) input[i&(radix2_N-1)] += stft_w[i] * buffer[(i+n)&(stft_N-1)]; // yep, were wasting 2x performance - radix2(output, input, radix2_N, 1); + forward(output, input); for (int i = 0; i < radix2_N; ++i) input[i] = 0.0f; for (int j = spectrogram_height - 1; 0 < j; --j) diff --git a/utils/radix2.c b/utils/radix2.c index e50d771..be7cbf9 100644 --- a/utils/radix2.c +++ b/utils/radix2.c @@ -30,5 +30,20 @@ int main() printf("\t{ %a, %a }%s\n", creal(z), cimag(z), n < (N/2-1) ? "," : ""); } printf("};\n"); + printf("static inline void dit4(complex_t *out, float *in)\n{\n"); + printf("\tfwd4(out, out + 1, out + 2, out + 3, in[0], in[%i], in[%i], in[%i]);\n}\n", N / 4, N / 2, 3 * N / 4); + for (int n = 4, s = N / 8; n < N; n *= 2, s /= 2) { + printf("static void "); + if (n < N / 2) + printf("dit%i", n * 2); + else + printf("forward"); + printf("(complex_t *out, float *in)\n{\n"); + printf("\tdit%i(out, in);\n", n); + printf("\tdit%i(out + %i, in + %i);\n", n, n, s); + printf("\tfor (int k0 = 0, k1 = %i, l1 = 0; k0 < %i; ++k0, ++k1, l1 += %i)\n", n, n, s); + printf("\t\tdft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1]));\n"); + printf("}\n"); + } return 0; }