"""Solve square linear systems ``A x = b`` by Gaussian elimination.

Elimination is available without pivoting, with column (partial) pivoting and
with total pivoting, each followed by back substitution.  Running the module
as a script compares the result against ``scipy.linalg.solve``.
"""
import numpy


def _check_shapes(A, b):
    """Validate that ``b`` has shape (n, 1) and ``A`` has shape (n, n); return n.

    Raises:
        ValueError: if either array has the wrong shape.
    """
    if numpy.ndim(b) != 2 or numpy.shape(b)[1] != 1:
        raise ValueError("b should be shape (n, 1), but is %s" % (numpy.shape(b),))
    n = numpy.shape(b)[0]
    if numpy.shape(A) != (n, n):
        raise ValueError("A should be shape (%d, %d), but is %s" % (n, n, numpy.shape(A)))
    return n


def _working_dtype(A, b):
    """Return the dtype to compute in: the common dtype of A and b.

    Bool/integer dtypes are promoted to float64, because the algorithms divide
    and update arrays in place (an int array would truncate or fail to cast).
    Float, complex and object dtypes are kept as they are.
    """
    dtype = numpy.result_type(numpy.asarray(A), numpy.asarray(b))
    if dtype.kind in "biu":  # bool, signed int, unsigned int
        dtype = numpy.dtype(numpy.float64)
    return dtype


def _eliminate_below(A, b, r):
    """Zero column r below the (nonzero) pivot A[r, r], updating A and b in place."""
    factors = A[r + 1:, r] / A[r, r]
    # Subtract factors[i] * (pivot row) from every row i below the pivot.
    A[r + 1:, r + 1:] -= numpy.outer(factors, A[r, r + 1:])
    b[r + 1:, 0] -= factors * b[r, 0]
    A[r + 1:, r] = 0


def backsubstitute(A, b):
    """Solve the upper-triangular system ``A x = b`` by back substitution.

    Only the diagonal and the upper triangle of ``A`` are read.

    Args:
        A: upper-triangular matrix of shape (n, n).
        b: right-hand side of shape (n, 1).

    Returns:
        A new array ``x`` of shape (n, 1). Integer inputs yield a float result.

    Raises:
        ValueError: if the shapes of ``A`` and ``b`` are inconsistent.
        numpy.linalg.LinAlgError: if a diagonal entry of ``A`` is zero.
    """
    n = _check_shapes(A, b)
    A = numpy.asarray(A)
    x = numpy.array(b, dtype=_working_dtype(A, b))
    for i in range(n - 1, -1, -1):
        if A[i, i] == 0:
            raise numpy.linalg.LinAlgError("Matrix is singular: zero diagonal entry in row %d" % i)
        # x[i+1:] already holds the solved unknowns; x[i, 0] still holds b[i].
        x[i, 0] = (x[i, 0] - A[i, i + 1:].dot(x[i + 1:, 0])) / A[i, i]
    return x


def gauss_eliminate(A_in, b_in):
    """Reduce ``A x = b`` to upper-triangular form without pivoting.

    The inputs are not modified.

    Args:
        A_in: coefficient matrix of shape (n, n).
        b_in: right-hand side of shape (n, 1).

    Returns:
        ``(A, b)``: new arrays where ``A`` is upper triangular and ``A x = b``
        has the same solution as the input system.

    Raises:
        ValueError: if the shapes are inconsistent.
        numpy.linalg.LinAlgError: if a zero pivot is met (a pivoting variant
            may still succeed).
    """
    n = _check_shapes(A_in, b_in)
    dtype = _working_dtype(A_in, b_in)
    A = numpy.array(A_in, dtype=dtype)
    b = numpy.array(b_in, dtype=dtype)
    for r in range(n):
        if A[r, r] == 0:
            raise numpy.linalg.LinAlgError(
                "Zero pivot in row %d; use a pivoting variant" % r)
        _eliminate_below(A, b, r)
    return A, b


def gauss_eliminate_columnpivot(A_in, b_in):
    """Reduce ``A x = b`` to upper-triangular form with column (partial) pivoting.

    At step r the row with the largest ``|A[i, r]|`` (i >= r) is swapped into
    row r before elimination. The inputs are not modified.

    Returns:
        ``(A, b)``: new arrays; ``A`` is upper triangular, with the same
        unknown ordering as the input system.

    Raises:
        ValueError: if the shapes are inconsistent.
        numpy.linalg.LinAlgError: if the matrix is singular.
    """
    n = _check_shapes(A_in, b_in)
    dtype = _working_dtype(A_in, b_in)
    A = numpy.array(A_in, dtype=dtype)
    b = numpy.array(b_in, dtype=dtype)
    for r in range(n):
        p = r + abs(A[r:, r]).argmax()
        if A[p, r] == 0:
            raise numpy.linalg.LinAlgError("Matrix is singular")
        # Fancy indexing on the right-hand side copies, so this is a true swap.
        A[[r, p], :] = A[[p, r], :]
        b[[r, p]] = b[[p, r]]
        _eliminate_below(A, b, r)
    return A, b


def gauss_eliminate_totalpivot(A_in, b_in):
    """Reduce ``A x = b`` to upper-triangular form with total pivoting.

    At step r the entry of largest magnitude in ``A[r:, r:]`` is moved to
    ``A[r, r]`` by swapping both rows and columns. The inputs are not modified.

    Returns:
        ``(A, b, original_index)``: ``A`` is upper triangular and
        ``original_index[k]`` is the index, in the input system, of the
        unknown that now belongs to column k. If ``y`` solves ``A y = b``,
        the solution of the input system is ``x[original_index[k]] = y[k]``.

    Raises:
        ValueError: if the shapes are inconsistent.
        numpy.linalg.LinAlgError: if the matrix is singular.
    """
    n = _check_shapes(A_in, b_in)
    dtype = _working_dtype(A_in, b_in)
    A = numpy.array(A_in, dtype=dtype)
    b = numpy.array(b_in, dtype=dtype)
    original_index = numpy.arange(n)
    for r in range(n):
        magnitudes = abs(A[r:, r:])
        p, q = numpy.unravel_index(magnitudes.argmax(), magnitudes.shape)
        p += r  # row to swap into position r
        q += r  # column to swap into position r
        if A[p, q] == 0:
            raise numpy.linalg.LinAlgError("Matrix is singular")
        A[[r, p], :] = A[[p, r], :]
        b[[r, p]] = b[[p, r]]
        # A column swap permutes the unknowns; remember where each came from.
        A[:, [r, q]] = A[:, [q, r]]
        original_index[[r, q]] = original_index[[q, r]]
        _eliminate_below(A, b, r)
    return A, b, original_index


def solve(A, b, pivoting="total"):
    """Solve ``A x = b`` by Gaussian elimination and back substitution.

    Args:
        A: coefficient matrix of shape (n, n).
        b: right-hand side of shape (n, 1).
        pivoting: ``"total"`` (default; row and column swaps),
            ``"column"`` (row swaps only) or ``"none"``.

    Returns:
        The solution ``x`` of shape (n, 1).

    Raises:
        ValueError: for inconsistent shapes or an unknown ``pivoting`` value.
        numpy.linalg.LinAlgError: if the matrix is singular (or, with
            ``pivoting="none"``, if a zero pivot is met).
    """
    if pivoting == "total":
        Ares, bres, original_index = gauss_eliminate_totalpivot(A, b)
    elif pivoting == "column":
        Ares, bres = gauss_eliminate_columnpivot(A, b)
        original_index = None
    elif pivoting == "none":
        Ares, bres = gauss_eliminate(A, b)
        original_index = None
    else:
        raise ValueError("pivoting must be 'total', 'column' or 'none', not %r" % (pivoting,))
    y = backsubstitute(Ares, bres)
    if original_index is None:
        return y
    # Undo the column permutation of total pivoting.
    x = numpy.empty_like(y)
    x[original_index] = y
    return x


def _demo(k, A, b):
    """Print system number k, its solutions from scipy and ``solve``, and their max difference."""
    import scipy.linalg

    print("----------")
    print("Solve eq %d" % k)
    print("----------")
    print("A%d * x%d = b%d" % (k, k, k))
    print("A%d = %s" % (k, A))
    print("b%d = %s" % (k, b))
    x_scipy = scipy.linalg.solve(A, b)
    print("x%d_scipy = %s" % (k, x_scipy))
    x_own = solve(A, b)
    print("x%d_own = %s" % (k, x_own))
    diff = abs(x_own - x_scipy)
    print("Maximal error = %s" % diff.max())
    print("")


if __name__ == "__main__":
    N = 5
    A1 = numpy.random.random((N, N))
    A2 = A1.copy()
    # A zero leading entry defeats elimination without pivoting.
    A2[0, 0] = 0.0
    b2 = b1 = numpy.random.random((N, 1))
    _demo(1, A1, b1)
    _demo(2, A2, b2)