import numpy import matplotlib.pyplot as plt import timeit def dft_forw(f): N = len(f) fhat = numpy.zeros(N, dtype=numpy.complex) fac = -2.0j*numpy.pi/N for n in range(N): for k in range(N): fhat[k] += f[n]*numpy.exp(fac*k*n) return fhat def dft_back(fhat): N = len(fhat) f = numpy.zeros(N, dtype=numpy.complex) fac = 2.0j*numpy.pi/N for n in range(N): for k in range(N): f[n] += fhat[k]*numpy.exp(fac*k*n) f /= N return f def fft_forw(f): N = len(f) if N == 1: return f else: fhat = numpy.zeros(N, dtype=numpy.complex) fac = -2.0j*numpy.pi/N g = fft_forw(f[0::2]) u = fft_forw(f[1::2]) for k in range(N/2-1): fhat[k] = g[k] + u[k]*numpy.exp(fac*k) fhat[k+N/2] = g[k] - u[k]*numpy.exp(fac*k) return fhat def fft_back(f): N = len(f) if N == 1: return f else: fhat = numpy.zeros(N, dtype=numpy.complex) fac = 2.0j*numpy.pi/N g = fft_forw(f[0::2]) u = fft_forw(f[1::2]) for k in range(N/2-1): fhat[k] = g[k] + u[k]*numpy.exp(fac*k) fhat[k+N/2] = g[k] - u[k]*numpy.exp(fac*k) return fhat def f(x): return 1/(1+x*x) allN = numpy.array([2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]) fv = {} for N in allN: x=numpy.linspace(-5, 5, N) fv[N] = f(x) def test_fft(N): fft_back(fft_forw(fv[N])) def test_dft(N): dft_back(dft_forw(fv[N])) def test_numpy(N): numpy.fft.ifft(numpy.fft.fft(fv[N])) if __name__ == '__main__': import timeit numpy_timing = numpy.empty(len(allN), dtype=float) fft_timing = numpy.empty(len(allN), dtype=float) dft_timing = numpy.empty(len(allN), dtype=float) i = 0 for N in allN: print "N=%d" % N numpy_timing[i] = timeit.timeit('fourier3.test_numpy(%d)' % N, 'import fourier3', number=1) print " numpy_timing=%f" % numpy_timing[i] fft_timing[i] = timeit.timeit('fourier3.test_fft(%d)' % N, 'import fourier3', number=1) print " fft_timing=%f" % fft_timing[i] dft_timing[i] = timeit.timeit('fourier3.test_dft(%d)' % N, 'import fourier3', number=1) print " dft_timing=%f" % dft_timing[i] i += 1 plt.loglog( allN, numpy_timing, 'o-', allN, fft_timing, 'o-', allN, dft_timing, 'o-', ) plt.legend(('fft (numpy)', 'fft (python)', 'dft (python)')) plt.show()