from __future__ import print_function
from numpy import *
from matplotlib.pyplot import *

# Define the functions
def f(x):
    """Return 1 / (1 + cos(x)**2).

    ``cos`` is numpy's (module-level star import), so array input is
    evaluated elementwise.
    """
    cos_sq = cos(x) ** 2
    return 1. / (1. + cos_sq)
def g(x):
    """Return x**-12 - x**-6.

    The expression has the form of a Lennard-Jones-type potential with unit
    parameters (no 4*epsilon prefactor, sigma = 1). Note: numpy integer
    arrays cannot be raised to negative integer powers, so pass floats when
    using arrays.
    """
    repulsive = x ** -12
    attractive = x ** -6
    return repulsive - attractive

def _check_deviation(got, ref, what, sym, tol, plot_errors):
    """Report the largest elementwise deviation |got - ref| and test it against tol.

    Prints the maximal deviation with the offending entries. On failure it
    also prints an error message and, optionally, plots all deviations.

    Parameters
    ----------
    got, ref : numpy.ndarray
        Computed and reference sequences of equal, non-empty shape.
    what : str
        Name used in messages ("transformed" / "reconstructed").
    sym : str
        Symbol used to label the entries in the printout (e.g. "fhat").
    tol : float
        Largest acceptable absolute deviation.
    plot_errors : bool
        If True, plot the deviations with matplotlib when the check fails.

    Returns
    -------
    bool
        True if the maximal deviation is <= tol; NaN deviations count as failure.
    """
    errs = abs(got - ref)
    argerr = int(errs.argmax())  # argmax returns the first NaN index, if any
    err = errs[argerr]
    print("Maximal error of {0} function:\n  err={1}\n  {2}[{3}]={4} {2}ref[{3}]={5}"
          .format(what, err, sym, argerr, got[argerr], ref[argerr]))
    # Phrased as "err <= tol" rather than "err > tol means failure" so that a
    # NaN deviation fails the check instead of silently passing.
    if err <= tol:
        return True
    print("ERROR: Fourier transform is buggy.\n  {0} function shows too large deviations!"
          .format(what.capitalize()))
    if plot_errors:
        # figure/plot/title/show come from the module-level matplotlib.pyplot
        # star import; show() is needed for the figure to appear in a script.
        figure()
        plot(errs)
        title("|{0} - {0}ref|".format(sym))
        show()
    return False


def verify_fourier(forward, back, N=1024, tol=1.e-10, plot_errors=True):
    """Check a user-supplied discrete Fourier transform pair against numpy.fft.

    A random real sequence of length ``N`` is transformed with ``forward``, and
    the result is passed to ``back``. The coefficients are compared with
    ``numpy.fft.fft`` when ``forward`` returns ``N`` of them. When it returns
    ``N//2 + 1`` (half spectrum of a real signal), they are compared with
    ``numpy.fft.rfft`` instead. The reconstruction must reproduce the original
    sequence. Results and errors are printed.

    Parameters
    ----------
    forward : callable
        Maps a real 1-D array of length N to its Fourier coefficients.
    back : callable
        Maps the output of ``forward`` back to a sequence of length N.
    N : int, optional
        Length of the random test sequence (default 1024). Use an even N when
        testing a half-spectrum (rfft-style) pair.
    tol : float, optional
        Largest acceptable absolute deviation (default 1e-10).
    plot_errors : bool, optional
        Plot the deviations with matplotlib when a check fails (default True).

    Returns
    -------
    bool
        True if every check passed, False otherwise.

    Raises
    ------
    ValueError
        If ``N`` < 1.
    """
    import numpy as np

    if N < 1:
        raise ValueError("N must be a positive integer (got {})".format(N))

    # Random real test signal
    vs = np.random.random(N)

    # Fourier transform and back; ``back`` receives forward's raw output.
    raw_hat = forward(vs)
    vsrec = np.asarray(back(raw_hat))
    vshat = np.asarray(raw_hat)

    # Accept the full spectrum (N coefficients) or the real-input half
    # spectrum (N//2 + 1 coefficients), and pick the matching reference.
    nfull, nhalf = N, N // 2 + 1
    if vshat.shape == (nfull,):
        vshatref = np.fft.fft(vs)
    elif vshat.shape == (nhalf,):
        vshatref = np.fft.rfft(vs)
    else:
        print("ERROR: Fourier transform is buggy.\n  Bad number of coefficients, Expected {} or {}! (shape={})"
              .format(nhalf, nfull, vshat.shape))
        return False

    # The maximal deviation of the Fourier coefficients should be small.
    if not _check_deviation(vshat, vshatref, "transformed", "fhat", tol, plot_errors):
        return False

    # Guard against a wrong-length result before elementwise comparison.
    if vsrec.shape != vs.shape:
        print("ERROR: Fourier transform is buggy.\n  Reconstructed function has shape {}, expected {}!"
              .format(vsrec.shape, vs.shape))
        return False

    # The maximal deviation of the reconstructed sequence should be small.
    if not _check_deviation(vsrec, vs, "reconstructed", "f", tol, plot_errors):
        return False

    print("\nCONGRATULATIONS! Fourier transform works!")
    return True

# Sanity run: numpy's own FFT pair is the reference implementation that
# verify_fourier compares against, so this call should pass.
verify_fourier(forward=fft.fft, back=fft.ifft)
