#! /usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import matplotlib       # uncomment these two lines if the program doesn't
matplotlib.use("TkAgg") # throw any errors but exits without having shown any plots.
import numpy as np
import matplotlib.pyplot as plt



#----------------- FUNCTION DEFINITIONS -----------------#

# Plot a function together with its interpolants for several numbers of supporting points
def plot_interpolating_splines(function, interval, splinesclass, n_support_points=(5, 10, 15), nodes=np.linspace):
    """Plot <function> on <interval> together with its interpolants.

    Draws <function> as a thin black line on the current matplotlib axes and,
    for every n in <n_support_points>, overlays the interpolant built from
    n supporting points as a dotted line labelled P_n(x).

    Parameters
    ----------
    function : callable
        Vectorised function to plot. Its ``__name__`` (with '_' replaced by
        '-') is used in the legend entry and in the title.
    interval : sequence of two floats
        Plotting range. The endpoints may be given in either order.
    splinesclass : callable
        Called as ``splinesclass(support_x, support_y)``. It must return a
        callable interpolant that can be evaluated on an array of x values.
    n_support_points : iterable of int
        Numbers of supporting points for which an interpolant is drawn.
    nodes : callable
        Called as ``nodes(lo, hi, n)`` with lo <= hi. It should return n
        supporting points in [lo, hi]. The default gives equally spaced points.
    """
    lo, hi = min(interval), max(interval)
    function_name = function.__name__.replace('_', '-')

    # Dense, evenly spaced grid used to draw smooth curves.
    x = np.linspace(lo, hi, 1000)
    plt.plot(x, function(x), '-', color='black', linewidth=0.5, label="{}($x$)".format(function_name))

    for n in n_support_points:
        support_x = nodes(lo, hi, n)
        support_y = function(support_x)
        ip = splinesclass(support_x, support_y)
        plt.plot(x, ip(x), ':', linewidth=3.0, label="P$_{{{}}}(x)$".format(n))

    # Use the same (sorted) bounds as the sampling grid, so that a reversed
    # <interval> does not flip the x axis.
    plt.xlim(lo, hi)
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.title(function_name+" function")
    plt.legend(loc="best")


# A function plotting the Sine, Runge, and Lennard-Jones functions
# together with the interpolants built by a given spline constructor:
def generate_plots(splinesclass, n_support_points=(5, 10, 15), nodes=np.linspace):
    """Open a new figure comparing <splinesclass> interpolants of three test functions.

    Draws one subplot each, left to right:
    - sin(x) on [0, 2*pi]
    - the Runge function 1/(1+x^2) on [-5, 5]
    - the Lennard-Jones potential x^-12 - x^-6 on [1, 6]

    <n_support_points> and <nodes> are forwarded unchanged to
    plot_interpolating_splines. The window title is the ``__name__`` of
    <splinesclass>.
    """
    # The function names double as legend/title text ('_' is shown as '-').
    def Sine(x): return np.sin(x)
    def Runge(x): return 1.0 / (1.0 + x ** 2)
    def Lennard_Jones(x): return x ** (-12) - x ** (-6)

    cases = [(Sine,          [ 0.0, 2.0 * np.pi]),
             (Runge,         [-5.0, 5.0        ]),
             (Lennard_Jones, [ 1.0, 6.0        ])]

    fig = plt.figure()
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6. The figure manager's method works on old and new versions.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title("{}".format(splinesclass.__name__))

    for i, (function, interval) in enumerate(cases, start=1):
        plt.subplot(1, len(cases), i)
        plot_interpolating_splines(function, interval, splinesclass, n_support_points, nodes)



#------------------ CLASS DEFINITIONS -------------------#

# Happy hacking :-)




#------------------------- DEMO -------------------------#

import scipy.interpolate as si
def cubic_spline(x, y):
    """Return a cubic-spline interpolant through the points (x, y).

    Bounds checking is disabled, so evaluating the result outside the data
    range returns interp1d's fill value instead of raising.
    """
    return si.interp1d(x, y, kind='cubic', bounds_error=False)
def quadratic_spline(x, y):
    """Return a quadratic-spline interpolant through the points (x, y).

    Bounds checking is disabled, so evaluating the result outside the data
    range returns interp1d's fill value instead of raising.
    """
    return si.interp1d(x, y, kind='quadratic', bounds_error=False)

# Build one comparison figure per spline constructor, then display them all.
for spline_builder in (cubic_spline, quadratic_spline):
    generate_plots(spline_builder)

plt.show()



