import matplotlib.pyplot as plt
import numpy as np

# Parameters shared by the matrix builder and the update step in this file.
N = 100  # number of unknowns; every matrix below is N x N
δ = 0.001  # time increment per update, in seconds
h = 1 / 15  # step size entering α — NOTE(review): not derived from N; confirm intended
α = δ / h ** 2  # scalar that multiplies the operator matrix in the update
In = np.eye(N)  # N x N identity matrix


def operator_matrix(n: "int | None" = None) -> np.ndarray:
    """Return the ``n x n`` tridiagonal matrix with 2 on the diagonal and -1 beside it.

    Every entry on the main diagonal is 2, every entry on the first sub- and
    super-diagonal is -1, and all remaining entries are 0.

    Args:
        n: Number of rows/columns. When omitted, the module-level ``N`` is
            used, which keeps existing zero-argument calls unchanged.

    Returns:
        A new float64 array of shape ``(n, n)``.

    Raises:
        ValueError: If the requested size is negative.
    """
    # Resolve the default at call time so the function follows the current N.
    size = N if n is None else n
    if size < 0:
        raise ValueError(f"matrix size must be non-negative, got {size}")

    # Shifted identities give the off-diagonals without any index loop, and
    # they degrade gracefully for sizes 0 and 1 (no out-of-range writes).
    return 2.0 * np.eye(size) - np.eye(size, k=1) - np.eye(size, k=-1)


# Build the operator matrix once at import time; step() reads this module-level value.
Op = operator_matrix()


def step(current: np.ndarray, time: int) -> np.ndarray:
    """Fill entry ``time + 1`` of ``current`` from entry ``time`` and return ``current``.

    The new entry is the matrix ``In - α * Op`` applied to ``current[time]``.
    ``current`` is modified in place, and the same array object is returned.

    Args:
        current: Array indexed by time step along its first axis
            (assumes each ``current[t]`` is a length-N vector — TODO confirm
            against the caller).
        time: Index of the entry to advance from.

    Returns:
        The ``current`` array that was passed in, after the in-place write.
    """
    update_matrix = In - α * Op
    previous = current[time]
    current[time + 1] = np.dot(update_matrix, previous)
    return current

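# NOTE(review): disabled driver, kept for reference. As written it hands step() a
# 1-D `current`, while step() reads `current[time]` and writes `current[time + 1]`
# as whole entries along the first axis, so `current` must become a 2-D array
# before this is re-enabled.
# x = np.linspace(0, 1, N)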
# current = x*(1-x)
# print(current)
# for t in range(1, N-1):
#     print(t)
#     if t % 100 == 0:
#         plotlabel = f"t = {t*δ}"
#         plt.plot(x, current[t], label=plotlabel, color=plt.get_cmap("hot")(t / δ))
#     current = step(current, t)
# 
# plt.xlabel("$x$", fontsize=26)
# plt.ylabel("$T$", fontsize=26, rotation=90)
# plt.title("test pok gorejĥrtor fhdrighoerhergorjgrejg lkughfiuse")
# plt.show()
