import numpy
import matplotlib.pyplot

def getRhos(N, Nsystems=25, Ncharges=10):
    """Generate Nsystems random charge grids of size NxN.

    Every grid starts at zero, then receives Ncharges unit positive
    charges followed by Ncharges unit negative charges. Each charge is
    placed on an independently drawn random cell, so several charges may
    share a cell; their contributions accumulate, and every returned
    grid therefore sums to exactly zero.

    Parameters
    ----------
    N : int
        Number of grid cells along each axis.
    Nsystems : int, optional
        Number of independent grids to generate.
    Ncharges : int, optional
        Number of positive charges (and, separately, of negative
        charges) placed on each grid.

    Returns
    -------
    list of numpy.ndarray
        Nsystems float arrays of shape (N, N).
    """
    rhos = []

    for _ in range(Nsystems):
        rho = numpy.zeros((N, N))

        # Draw order (xs before ys, positives before negatives) is kept
        # so a seeded global RNG is consumed exactly as before.
        xs = numpy.random.randint(N, size=Ncharges)
        ys = numpy.random.randint(N, size=Ncharges)
        # `rho[xs, ys] += 1.0` is buffered: a cell picked several times
        # would only be incremented once. ufunc.at applies the operation
        # once per index pair, so coinciding charges stack.
        numpy.add.at(rho, (xs, ys), 1.0)

        xs = numpy.random.randint(N, size=Ncharges)
        ys = numpy.random.randint(N, size=Ncharges)
        numpy.subtract.at(rho, (xs, ys), 1.0)

        rhos.append(rho)

    return rhos

def getMatrix(N, L=1.0):
    """Build the dense 5-point Laplacian matrix for a periodic NxN grid.

    Grid cell (x, y) maps to the linear index ``(x % N) + N * (y % N)``,
    so neighbours wrap around at the edges. Each row holds the stencil
    ``(sum of the four neighbours - 4 * centre) / h**2`` with grid
    spacing ``h = L / N``.

    Row 0 is then replaced by the unit row ``[1, 0, 0, ...]``. With that
    row, a solve returns the right-hand side's first entry as the first
    unknown; this pins the additive constant that the periodic Laplacian
    otherwise leaves undetermined.

    Parameters
    ----------
    N : int
        Number of grid cells along each axis.
    L : float, optional
        Side length of the domain; the grid spacing is ``L / N``.

    Returns
    -------
    numpy.ndarray
        Float array of shape (N * N, N * N).
    """
    def linindex(x, y):
        # The modulo implements the periodic wrap-around.
        return (x % N) + N * (y % N)

    A = numpy.zeros((N * N, N * N))
    for x in range(N):
        for y in range(N):
            i = linindex(x, y)
            # Accumulate rather than assign: for N <= 2 the wrapped
            # neighbours coincide with each other (N == 2) or with the
            # cell itself (N == 1), and their stencil weights must add
            # up instead of overwriting one another.
            A[i, i] += -4.0
            A[i, linindex(x + 1, y)] += 1.0
            A[i, linindex(x - 1, y)] += 1.0
            A[i, linindex(x, y + 1)] += 1.0
            A[i, linindex(x, y - 1)] += 1.0

    h = float(L) / N
    A *= 1.0 / (h * h)

    # Replace the first equation by "unknown 0 == rhs[0]".
    A[0, :] = 0.0
    A[0, 0] = 1.0

    return A

def plotImages(phis):
    """Draw each array in phis as an image in a 5x5 grid of subplots.

    Images are placed in iteration order, starting at subplot slot 1,
    and both axes of every subplot are hidden. The grid has 25 slots, so
    phis is expected to hold at most 25 items.
    """
    pyplot = matplotlib.pyplot
    for slot, image in enumerate(phis, 1):
        panel = pyplot.subplot(5, 5, slot)
        # Hide the x axis first, then the y axis.
        for get_axis in (panel.axes.get_xaxis, panel.axes.get_yaxis):
            get_axis().set_visible(False)
        pyplot.imshow(image)
