import numpy as np
from math import *
from copy import deepcopy
from pylab import *
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
from rich import print

# Number of grid nodes along each side of the square domain (40 intervals -> 41 nodes).
mesh_size = 41
# Convergence threshold used by calcul(): sweeping stops once the largest
# change of any cell between two consecutive sweeps is no more than this.
precision = 1.0e-4


def initialise(size: "int | None" = None) -> np.ndarray:
    """Return a fresh square grid holding only the boundary conditions.

    The interior is zero; the four edges are fixed as follows:

    * top row: a "V" profile, 100 at both corners falling linearly to 0 at the
      centre;
    * bottom row: the complementary inverted "V", ``100 - top``;
    * left and right columns: a linear ramp from 100 (top) to 0 (bottom).

    Args:
        size: number of grid nodes per side. Defaults to the module-level
            ``mesh_size`` (read at call time). Any size >= 2 is accepted; the
            previous list-mirroring construction only worked for odd sizes.

    Returns:
        A new ``(size, size)`` float array.

    Raises:
        ValueError: if ``size`` is smaller than 2.
    """
    if size is None:
        size = mesh_size
    if size < 2:
        raise ValueError(f"size must be at least 2, got {size}")

    # |linspace(100, -100)| runs 100 -> 0 -> 100 for any size. For the default
    # size of 41 the step is exactly -5, so every value matches the former
    # mirrored-half construction bit for bit.
    top = np.abs(np.linspace(100, -100, size))
    sides = np.linspace(100, 0, size)

    domain = np.zeros((size, size))
    domain[0] = top
    domain[-1] = 100 - top
    # Columns are written last so they own the corner cells; the values agree
    # with the top/bottom rows there anyway (100 at the top, 0 at the bottom).
    domain[:, 0] = sides
    domain[:, -1] = sides
    return domain


def Δ(value, top, right, bottom, left):
    """Return the Jacobi update for one cell: the mean of its four neighbours.

    This is the five-point discrete Laplace equation solved for the centre
    value (not the discrete Laplacian itself). ``value`` -- the current centre
    value -- is accepted so callers can pass the full stencil, but it does not
    enter the result. Only ``+`` and ``*`` are used, so the function also
    applies element-wise to NumPy arrays.
    """
    neighbour_sum = left + right + top + bottom
    return 0.25 * neighbour_sum


def iteration(domain) -> np.ndarray:
    """Perform one Jacobi sweep of the Laplace solver.

    Every interior cell of the result is ``Δ`` of its four neighbours in
    ``domain``. The boundary is taken fresh from ``initialise()``, so the fixed
    edge values are re-imposed on every sweep. ``domain`` itself is not modified.

    Args:
        domain: 2-D potential grid. It is expected to have the same shape as
            ``initialise()``'s result; other shapes are not supported (NumPy
            will generally refuse to broadcast them).

    Returns:
        A new array holding the updated potential.
    """
    domain = np.asarray(domain)
    domain_next = initialise()
    # One whole-array update instead of a per-cell Python double loop: each
    # slice aligns the named neighbour with every interior cell [1:-1, 1:-1]
    # (row 0 is the top row of the grid).
    domain_next[1:-1, 1:-1] = Δ(
        value=domain[1:-1, 1:-1],
        top=domain[:-2, 1:-1],  # row above
        bottom=domain[2:, 1:-1],  # row below
        left=domain[1:-1, :-2],  # column to the left
        right=domain[1:-1, 2:],  # column to the right
    )
    return domain_next


def ecart_max(domain1, domain2) -> float:
    """Return the largest absolute element-wise difference between two grids.

    The value is also printed, which serves as a per-sweep progress trace.

    Args:
        domain1: first grid (array or nested sequences).
        domain2: second grid, same shape as ``domain1``.

    Returns:
        ``max |domain1 - domain2|``; 0 when the grids are identical.

    Raises:
        ValueError: if the two grids do not have the same shape.
    """
    first = np.asarray(domain1)
    second = np.asarray(domain2)
    if first.shape != second.shape:
        raise ValueError(f"shape mismatch: {first.shape} vs {second.shape}")

    ecart = np.abs(first - second).max()
    print(ecart)
    # Return 0 as-is. The old code swapped 0 for a huge sentinel. Because
    # calcul() keeps sweeping while this value exceeds the threshold, reaching
    # an exact fixed point (identical consecutive grids) then looped forever.
    return ecart


def calcul(tolerance: "float | None" = None) -> np.ndarray:
    """Sweep from the initial grid until two consecutive sweeps agree.

    Args:
        tolerance: stop once ``ecart_max`` of two consecutive grids is no
            larger than this. Defaults to the module-level ``precision``
            (read at call time).

    Returns:
        The last grid computed by ``iteration``.
    """
    if tolerance is None:
        tolerance = precision

    domain_prev = initialise()
    domain_next = iteration(domain_prev)
    while ecart_max(domain_prev, domain_next) > tolerance:
        # iteration() builds a new array and never writes to its argument, so
        # rebinding is enough; the former deepcopy() was a redundant full copy
        # on every sweep.
        domain_prev = domain_next
        domain_next = iteration(domain_next)
    return domain_next


def electric_field(domain):
    """Return the field components ``(field_x, field_y)`` of a potential grid.

    Forward differences between neighbouring cells (unit grid spacing):

    * ``field_x[i, j] = -(V[i, j+1] - V[i, j])``: minus the gradient along
      the column index;
    * ``field_y[i, j] = V[i+1, j] - V[i, j]``: deliberately *without* the
      minus sign. NOTE(review): this presumably compensates for the row index
      growing downwards in the ``imshow`` image the arrows are drawn over, so
      ``field_y`` points "up" on screen. Confirm before reusing elsewhere.

    The last row and last column have no forward neighbour and stay 0.

    Args:
        domain: 2-D potential grid (array or nested sequences).

    Returns:
        Two float arrays with the same shape as ``domain``.
    """
    potential = np.asarray(domain, dtype=float)
    field_x = np.zeros(potential.shape)
    field_y = np.zeros(potential.shape)

    # Sized from the input itself rather than the global mesh_size.
    core = potential[:-1, :-1]
    field_y[:-1, :-1] = potential[1:, :-1] - core
    field_x[:-1, :-1] = -(potential[:-1, 1:] - core)
    return field_x, field_y


if __name__ == "__main__":
    potentiels = calcul()

    # 2-D view: colour map of the potential with field arrows drawn on top.
    fig, ax = plt.subplots()
    image = ax.imshow(potentiels, cmap="gnuplot2")
    ax.quiver(*electric_field(potentiels))
    fig.colorbar(image)

    # 3-D view of the same potential as a surface. Axes3D(fig) no longer
    # attaches itself to the figure in recent Matplotlib (deprecated in 3.4),
    # which left an empty window; add_subplot(projection="3d") is the
    # supported way to create 3-D axes.
    fig3d = plt.figure()
    ax3d = fig3d.add_subplot(projection="3d")
    n_rows, n_cols = potentiels.shape
    X, Y = np.meshgrid(np.arange(n_cols, dtype=float), np.arange(n_rows, dtype=float))

    ax3d.plot_surface(X, Y, potentiels, rstride=1, cstride=1, cmap="hot")
    plt.show()
