# -*- coding: utf-8 -*-

from __future__ import print_function
import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt


# Define a function to do the backsubstitution:
def backsubstitute(A, b):
    """Solve ``A x = b`` for an upper-triangular matrix ``A`` by back substitution.

    Only the diagonal of ``A`` and the entries above it are read; anything
    below the diagonal is ignored.  The unknowns are computed from the last
    one upwards: ``x[i] = (b[i] - sum_{j > i} A[i, j] * x[j]) / A[i, i]``.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Upper-triangular coefficient matrix.
    b : array_like, shape (n,)
        Right-hand side; not modified.

    Returns
    -------
    x : ndarray, shape (n,)
        The solution, in a floating-point dtype even for integer inputs.
        A zero on the diagonal is not checked; NumPy's division then yields
        inf/nan values.

    Raises
    ------
    ValueError
        If ``A`` is not of shape ``(n, n)`` with ``n == len(b)``.
    """
    A = np.asarray(A)
    b = np.asarray(b)
    n = len(b)
    rows, cols = A.shape
    if rows != n or cols != n:
        raise ValueError("A should be shape (%d, %d), but is %s" % (n,n,A.shape))

    # Work on a floating-point copy of b: storing the fractional intermediate
    # results into an integer array would silently truncate them.
    x = b.astype(np.result_type(A, b, 1.0))

    for i in range(n - 1, -1, -1):
        # x[i + 1:] already holds the unknowns solved in earlier iterations.
        x[i] = (x[i] - np.dot(A[i, i + 1:], x[i + 1:])) / A[i, i]

    return x


# With that function, we can already solve Ax=b when A is an upper triangular matrix.
# Small demo: a 3x3 system whose matrix has only zeros below the diagonal.
A = np.array([[ 1.0, 2.0, 3.0],
              [ 0.0, 4.0, 1.5],
              [ 0.0, 0.0, 3.0]])
b = np.array([1.0, 2.0, 3.0])
x = backsubstitute(A, b)
print()
print("b   =", b)
print("x   =", x)

# Verify the solution: A*x should reproduce b.
print("A*x =", np.dot(A, x))



# Task 2.1: Gauss Elimination with canonical pivoting

# Define a function "gauss_eliminate" that brings a square matrix into upper triangular form (applying the same row operations to the right-hand side b):
def gauss_eliminate(A_in, b_in):
    """Reduce the system ``A_in x = b_in`` to an equivalent upper-triangular one.

    Gaussian elimination with canonical pivoting, i.e. without any row
    exchanges: the diagonal entry ``A[r, r]`` is always the pivot for column
    ``r``.  A zero pivot is not detected; the division then produces inf/nan
    values and the result is meaningless (``gauss_eliminate_partial`` handles
    such systems).

    Parameters
    ----------
    A_in : array_like, shape (n, n)
        Coefficient matrix; not modified.
    b_in : array_like, shape (n,)
        Right-hand side; not modified.

    Returns
    -------
    A : ndarray, shape (n, n)
        Upper-triangular matrix (entries below the diagonal set to 0), in a
        floating-point dtype.
    b : ndarray, shape (n,)
        Right-hand side with the same row operations applied.

    Raises
    ------
    ValueError
        If ``A_in`` is not of shape ``(n, n)`` with ``n == len(b_in)``.
    """
    A_arr = np.asarray(A_in)
    b_arr = np.asarray(b_in)
    # Work on floating-point copies: the elimination multipliers are
    # fractional, and writing them back into integer arrays would either raise
    # (in-place row updates) or silently truncate (scalar updates of b).
    dtype = np.result_type(A_arr, b_arr, 1.0)
    A = A_arr.astype(dtype)
    b = b_arr.astype(dtype)

    n = len(b)
    rows, cols = A.shape
    if rows != n or cols != n:
        raise ValueError("A should be shape (%d, %d), but is %s" % (n, n, A.shape))

    # r is the pivot row/column currently being eliminated.
    for r in range(n):
        # Multipliers that cancel column r in every row below the pivot row.
        factors = A[r + 1:, r] / A[r, r]
        b[r + 1:] -= factors * b[r]
        A[r + 1:, r + 1:] -= np.outer(factors, A[r, r + 1:])
        A[r + 1:, r] = 0.0

    return A, b

def solve(A, b):
    """Solve ``A x = b`` by Gauss elimination without pivoting.

    The system is first triangularized with ``gauss_eliminate``; the
    resulting upper-triangular system is then solved with ``backsubstitute``.
    """
    upper, rhs = gauss_eliminate(A, b)
    return backsubstitute(upper, rhs)


# When the function is implemented correctly, the following example LES (linear equation system) should be solved.
# Each row of A1 has the form [1, t, t**2, t**3] (rounded to 8 digits) for
# t = 0.2, 0.5333..., 0.8667..., 1.2, i.e. A1 is a cubic Vandermonde matrix.
A1 = np.array([[ 1.        ,  0.2       ,  0.04      ,  0.008     ],
               [ 1.        ,  0.53333333,  0.28444444,  0.1517037 ],
               [ 1.        ,  0.86666667,  0.75111111,  0.65096296],
               [ 1.        ,  1.2       ,  1.44      ,  1.728     ]])
b1 = np.array([ 0.95105652, -0.20791169, -0.74314483,  0.95105652])

# solve the LES:
x1 = solve(A1, b1)
print()
print("b1    =", b1)
print("x1    =", x1)
# Residual check: A1*x1 should reproduce b1.
print("A1*x1 =", np.dot(A1, x1))

# compare the result to the reference (each component within an absolute tolerance of 1e-5):
x1_ref = np.array([ 1.27785932, -0.17219435, -8.75422788,  7.22564732])
if np.all(np.abs(x1 - x1_ref) < 1.e-5):
    print("Congratulations, the result is correct!")
else:
    print("This result is wrong!")


# Unfortunately, canonical pivoting doesn't work in all cases.
# The following example will not work with your function ``solve``.
# Therefore, in task 2.3 (below) the partial pivoting scheme is implemented.
# A2 equals A1 except for A2[0, 0] = 0 (and b2 equals b1): the very first pivot is
# zero, so the elimination without row swaps divides by zero (NumPy typically emits
# a RuntimeWarning) and x2 below is not a valid solution.
A2 = np.array([[ 0.        ,  0.2       ,  0.04      ,  0.008     ],
               [ 1.        ,  0.53333333,  0.28444444,  0.1517037 ],
               [ 1.        ,  0.86666667,  0.75111111,  0.65096296],
               [ 1.        ,  1.2       ,  1.44      ,  1.728     ]])
b2 = np.array([ 0.95105652, -0.20791169, -0.74314483,  0.95105652])

# "solve" the LES:
x2 = solve(A2, b2)
print()
print("b2    =", b2)
print("x2    =", x2)
# Residual check: for a correct solution A2*x2 would reproduce b2.
print("A2*x2 =", np.dot(A2, x2))

# compare the result to the reference (each component within an absolute tolerance of 1e-5):
x2_ref = np.array([ -0.85418404,   8.06213979, -18.74818115,  11.0694755 ])
if np.all(np.abs(x2 - x2_ref) < 1.e-5):
    print("Congratulations, the result is correct!")
else:
    print("This result is wrong!")



# Task 2.2: Interpolating Polynomial


# Task 2.2.1
# Now we use the function "solve" to determine the coefficients of the interpolating
# cubic p(x) = c0 + c1*x + c2*x**2 + c3*x**3.

# Four equally spaced nodes on [0, 2*pi] and the sine values at those nodes.
xsin = np.linspace(0.0, 2.0 * np.pi, 4)
ysin = np.sin(xsin)

# Vandermonde matrix: column k holds xsin**k, so A.dot(c) evaluates p at every node.
A = np.column_stack(tuple(xsin ** power for power in range(4)))

# Solve A c = ysin; the coefficients come out in order of increasing power.
coefficients = solve(A, ysin)
print(coefficients)


# Task 2.2.2
# Then we write a function to evaluate the polynomial at different points for the given coefficients.
# Note that by using NumPy's element-wise operations, such a function can work directly on whole arrays of points instead of single numbers.
#def polyeval(x, coeffs):
#    f = np.ones_like(x)
#    _sum = np.zeros_like(x)
#    for a in coeffs:
#        _sum += a * f
#        f *= x
#    return _sum

# Nicer Version:
def polyeval(x, coeffs):
    """Evaluate the polynomial ``sum_k coeffs[k] * x**k`` at the point(s) ``x``.

    Parameters
    ----------
    x : scalar or array_like
        Evaluation point(s); any shape is accepted, including a plain number.
    coeffs : array_like, shape (m,)
        Coefficients in order of increasing power (``coeffs[0]`` is the
        constant term).

    Returns
    -------
    ndarray or scalar
        The polynomial values, with the same shape as ``x``.
    """
    x = np.asarray(x)
    coeffs = np.asarray(coeffs)
    # A trailing axis lets broadcasting raise every point to every power
    # 0..m-1; ``...`` preserves whatever leading shape x has (including 0-d).
    powers = x[..., None] ** np.arange(coeffs.shape[0])
    return (coeffs * powers).sum(axis=-1)

# Task 2.2.3
# With that knowledge, we can plot the interpolating polynomial.

# 100 evaluation points on the same interval [0, 2*pi] as the interpolation nodes.
xs = np.linspace(0.0, 2.0 * np.pi, 100)

# plot sine
plt.figure()
plt.subplot(1, 1, 1)
plt.plot(xs, np.sin(xs), label="sin(x)")
# The interpolation nodes as markers; they have no label, so the legend omits them.
plt.plot(xsin, ysin, 'o')
# plot the polynomial
plt.plot(xs, polyeval(xs, coefficients), label="p(x)")
plt.legend()
# With the TkAgg backend selected at the top of the file this opens a window and
# normally blocks until it is closed; the pivoting tasks below run afterwards.
plt.show()

# Task 2.3: Partial and Total Pivoting

# Task 2.3.1: Partial Pivoting
# Let's return to Gauss elimination. In task 2.1 we were not able to solve the second example, because its
# first pivot A2[0, 0] is zero. Partial pivoting avoids this by swapping rows before each elimination step.

def gauss_eliminate_partial(A_in, b_in):
    """Reduce ``A_in x = b_in`` to upper-triangular form using partial pivoting.

    Before column ``r`` is eliminated, the row (at or below ``r``) whose entry
    in column ``r`` has the largest magnitude is swapped into the pivot
    position, and the same swap is applied to ``b``.  This avoids zero pivots
    in non-singular systems and keeps every elimination multiplier at most 1
    in magnitude.  A singular matrix is not detected: a zero pivot leads to
    inf/nan values here or in the subsequent back substitution.

    Parameters
    ----------
    A_in : array_like, shape (n, n)
        Coefficient matrix; not modified.
    b_in : array_like, shape (n,)
        Right-hand side; not modified.

    Returns
    -------
    A : ndarray, shape (n, n)
        Upper-triangular matrix of the row-permuted system, in a
        floating-point dtype.
    b : ndarray, shape (n,)
        Right-hand side with the same row swaps and operations applied.
        Row swaps do not reorder the unknowns.

    Raises
    ------
    ValueError
        If ``A_in`` is not of shape ``(n, n)`` with ``n == len(b_in)``.
    """
    A_arr = np.asarray(A_in)
    b_arr = np.asarray(b_in)
    # Work on floating-point copies: writing fractional multipliers back into
    # integer arrays would either raise or silently truncate.
    dtype = np.result_type(A_arr, b_arr, 1.0)
    A = A_arr.astype(dtype)
    b = b_arr.astype(dtype)

    n = len(b)
    rows, cols = A.shape
    if rows != n or cols != n:
        raise ValueError("A should be shape (%d, %d), but is %s" % (n,n,A.shape))

    # r is the pivot row/column currently being eliminated.
    for r in range(n):
        # Absolute index of the row with the largest |A[k, r]| for k >= r.
        pivot = r + np.abs(A[r:, r]).argmax()
        A[[r, pivot], :] = A[[pivot, r], :]
        b[[r, pivot]] = b[[pivot, r]]

        # Multipliers that cancel column r in every row below the pivot row.
        factors = A[r + 1:, r] / A[r, r]
        b[r + 1:] -= factors * b[r]
        A[r + 1:, r + 1:] -= np.outer(factors, A[r, r + 1:])
        A[r + 1:, r] = 0.0

    return A, b

def solve_partial(A, b):
    """Solve ``A x = b`` by Gauss elimination with partial pivoting.

    The system is triangularized with ``gauss_eliminate_partial`` and solved
    with ``backsubstitute``.  Row swaps do not reorder the unknowns, so the
    back-substituted vector is already the solution of the original system.
    """
    Ares, bres = gauss_eliminate_partial(A, b)
    return backsubstitute(Ares, bres)

# Let's see whether this now solves the problem:
# Partial pivoting first swaps a row with a nonzero first entry into the pivot
# position, so the zero A2[0, 0] is never used as a divisor.
x2 = solve_partial(A2, b2)
print()
print("b2    =", b2)
print("x2    =", x2)
# Residual check: A2*x2 should reproduce b2.
print("A2*x2 =", np.dot(A2, x2))

# compare the result to the reference (each component within an absolute tolerance of 1e-5):
x2_ref = np.array([ -0.85418404,   8.06213979, -18.74818115,  11.0694755 ])
if np.all(np.abs(x2 - x2_ref) < 1.e-5):
    print("Congratulations, the result is correct!")
else:
    print("This result is wrong!")


# Task 2.3.3: Total Pivoting

def gauss_eliminate_complete(A_in, b_in):
    """Reduce ``A_in x = b_in`` to upper-triangular form using complete pivoting.

    At every step the entry of largest magnitude in the remaining submatrix
    ``A[r:, r:]`` is moved to the pivot position ``(r, r)``. This takes a row
    swap (also applied to ``b``) and a column swap. Column swaps reorder the
    unknowns, so the resulting permutation is returned as well. A singular
    matrix is not detected: an all-zero remaining submatrix gives a zero
    pivot and hence inf/nan values.

    Parameters
    ----------
    A_in : array_like, shape (n, n)
        Coefficient matrix; not modified.
    b_in : array_like, shape (n,)
        Right-hand side; not modified.

    Returns
    -------
    A : ndarray, shape (n, n)
        Upper-triangular matrix of the row- and column-permuted system, in a
        floating-point dtype.
    b : ndarray, shape (n,)
        Right-hand side with the same row swaps and operations applied.
    original_index : ndarray of int, shape (n,)
        ``original_index[k]`` is the index of the original unknown that now
        belongs to column ``k``.  If ``y`` solves the reduced system, the
        solution of the original one is obtained via ``x[original_index] = y``.

    Raises
    ------
    ValueError
        If ``A_in`` is not of shape ``(n, n)`` with ``n == len(b_in)``.
    """
    A_arr = np.asarray(A_in)
    b_arr = np.asarray(b_in)
    # Work on floating-point copies: writing fractional multipliers back into
    # integer arrays would either raise or silently truncate.
    dtype = np.result_type(A_arr, b_arr, 1.0)
    A = A_arr.astype(dtype)
    b = b_arr.astype(dtype)

    n = len(b)
    rows, cols = A.shape
    if rows != n or cols != n:
        raise ValueError("A should be shape (%d, %d), but is %s" % (n,n,A.shape))

    original_index = np.arange(0, n)

    # r is the pivot row/column currently being eliminated.
    for r in range(n):
        # Position of the largest |entry| in the trailing submatrix A[r:, r:],
        # converted to absolute row/column indices of A.
        sub_row, sub_col = np.unravel_index(np.abs(A[r:, r:]).argmax(),
                                            (n - r, n - r))
        pivot_row = r + sub_row
        pivot_col = r + sub_col

        # Row swap: the equations are reordered, so b follows along.
        A[[r, pivot_row], :] = A[[pivot_row, r], :]
        b[[r, pivot_row]] = b[[pivot_row, r]]

        # Column swap: the unknowns are reordered; record it for the caller.
        A[:, [r, pivot_col]] = A[:, [pivot_col, r]]
        original_index[[r, pivot_col]] = original_index[[pivot_col, r]]

        # Multipliers that cancel column r in every row below the pivot row.
        factors = A[r + 1:, r] / A[r, r]
        b[r + 1:] -= factors * b[r]
        A[r + 1:, r + 1:] -= np.outer(factors, A[r, r + 1:])
        A[r + 1:, r] = 0.0

    return A, b, original_index

def solve_complete(A, b):
    """Solve ``A x = b`` by Gauss elimination with complete pivoting.

    The column exchanges made during elimination permute the unknowns.
    Entry ``k`` of the back-substituted vector belongs to the original
    unknown ``original_index[k]``, so it is scattered back into that slot.
    """
    Ares, bres, original_index = gauss_eliminate_complete(A, b)
    permuted = backsubstitute(Ares, bres)
    x = np.empty_like(permuted)
    x[original_index] = permuted
    return x

# Check the solution:
# Complete pivoting also swaps columns; solve_complete maps the permuted
# unknowns back to their original order before returning x2.
x2 = solve_complete(A2, b2)
print()
print("b2    =", b2)
print("x2    =", x2)
# Residual check: A2*x2 should reproduce b2.
print("A2*x2 =", np.dot(A2, x2))

# compare the result to the reference (each component within an absolute tolerance of 1e-5):
x2_ref = np.array([ -0.85418404,   8.06213979, -18.74818115,  11.0694755 ])
if np.all(np.abs(x2 - x2_ref) < 1.e-5):
    print("Congratulations, the result is correct!")
else:
    print("This result is wrong!")



