diff --git a/app/src/main/rs/radix2.rsh b/app/src/main/rs/radix2.rsh index 1002b66..ca1de1a 100644 --- a/app/src/main/rs/radix2.rsh +++ b/app/src/main/rs/radix2.rsh @@ -20,34 +20,35 @@ limitations under the License. #include "complex.rsh" #include "radix2_generated.rsh" -static void radix2(complex_t *out, float *in, int N, int S, int L) +static inline void dft2(complex_t *out0, complex_t *out1, complex_t in0, complex_t in1) +{ + *out0 = in0 + in1; + *out1 = in0 - in1; +} + +static inline void fwd4(complex_t *out0, complex_t *out1, complex_t *out2, complex_t *out3, + float in0, float in1, float in2, float in3) { - // we only need 4 <= N forward FFTs - if (N == 4) { - float in0 = in[0]; - float in1 = in[S]; - float in2 = in[2 * S]; - float in3 = in[3 * S]; float a = in0 + in2; float b = in0 - in2; float c = in1 + in3; - out[0] = complex(a + c, 0.0f); - out[1] = complex(b, in3 - in1); - out[2] = complex(a - c, 0.0f); - out[3] = complex(b, in1 - in3); + *out0 = complex(a + c, 0.0f); + *out1 = complex(b, in3 - in1); + *out2 = complex(a - c, 0.0f); + *out3 = complex(b, in1 - in3); +} + +static void radix2(complex_t *out, float *in, int N, int S) +{ + // we only need 4 <= N forward FFTs + if (N == 4) { + fwd4(out, out + 1, out + 2, out + 3, in[0], in[S], in[2 * S], in[3 * S]); return; } - radix2(out, in, N / 2, 2 * S, L + 1); - radix2(out + N / 2, in + S, N / 2, 2 * S, L + 1); - for (int k = 0; k < N / 2; ++k) { - int ke = k; - int ko = k + N / 2; - complex_t w = radix2_z[k << L]; - complex_t even = out[ke]; - complex_t odd = cmul(w, out[ko]); - out[ke] = even + odd; - out[ko] = even - odd; - } + radix2(out, in, N / 2, 2 * S); + radix2(out + N / 2, in + S, N / 2, 2 * S); + for (int k0 = 0, k1 = N / 2, l1 = 0; k0 < N / 2; ++k0, ++k1, l1 += S) + dft2(out + k0, out + k1, out[k0], cmul(radix2_z[l1], out[k1])); } #endif \ No newline at end of file diff --git a/app/src/main/rs/stft.rsh b/app/src/main/rs/stft.rsh index a539199..2d1b01f 100644 --- a/app/src/main/rs/stft.rsh +++ b/app/src/main/rs/stft.rsh @@ -136,7 +136,7 @@ static void spectrum_analyzer(int amplitude) for (int i = 0; i < stft_N; ++i) input[i&(radix2_N-1)] += stft_w[i] * buffer[(i+n)&(stft_N-1)]; // yep, were wasting 2x performance - radix2(output, input, radix2_N, 1, 0); + radix2(output, input, radix2_N, 1); for (int i = 0; i < radix2_N; ++i) input[i] = 0.0f; for (int j = spectrogram_height - 1; 0 < j; --j)