#!/usr/bin/env python

from copy import deepcopy
from numpy import array, identity
from rich import print


def dilatation(A, i, λ):
    """Scale row ``i`` of ``A`` by the scalar ``λ`` (row operation L_i <- λ L_i).

    ``A`` is modified in place and also returned, so calls can be chained.

    >>> mat = array([[1, 2, 3], [3, 4, 6], [1, 0, -1]], dtype=float)
    >>> dilatation(mat, 1, 4)
    array([[ 1.,  2.,  3.],
           [12., 16., 24.],
           [ 1.,  0., -1.]])
    """
    factor = λ
    target_row = i
    A[target_row] *= factor
    return A


def permutation(A, i, j):
    """Exchange rows ``i`` and ``j`` of ``A`` (row operation L_i <-> L_j).

    ``A`` is modified in place and returned. Swapping a row with itself
    leaves ``A`` unchanged.

    >>> mat = array([[1, 2, 3], [3, 4, 6], [1, 0, -1]], dtype=float)
    >>> permutation(mat, 1, 2)
    array([[ 1.,  2.,  3.],
           [ 1.,  0., -1.],
           [ 3.,  4.,  6.]])
    """
    # Keep an independent copy of row i: A[i] on a 2-D array is a view that
    # the next assignment would overwrite.
    saved = A[i].copy()
    A[i] = A[j]
    A[j] = saved
    return A


def transvection(A, i, j, λ):
    """Add ``λ`` times row ``j`` to row ``i`` (row operation L_i <- L_i + λ L_j).

    ``A`` is modified in place and returned.

    >>> mat = array([[1, 2, 3], [3, 4, 6], [1, 0, -1]], dtype=float)
    >>> transvection(mat, 0, 1, 2)
    array([[ 7., 10., 15.],
           [ 3.,  4.,  6.],
           [ 1.,  0., -1.]])
    """
    scaled_source = λ * A[j]
    A[i] += scaled_source
    return A


def inversion_mat(mat=None):
    """Return the inverse of the square matrix ``mat``.

    Convenience wrapper that delegates to ``inversion``, which runs
    Gauss-Jordan elimination on a copy of ``mat`` (``mat`` itself is not
    modified).

    Raises:
        TypeError: if no matrix is supplied.
    """
    if mat is None:
        # Optional only so the original zero-argument call still parses;
        # there is nothing to invert without a matrix.
        raise TypeError("inversion_mat() missing required argument: 'mat'")
    return inversion(mat)

# Demo fixtures: a 3x3 test matrix and the 3x3 identity.
I3 = identity(3)
A = array([[1, 0, 1], [0, 2, 1], [1] * 3], dtype=float)

# Printing is demo output only; keep it out of module import.
if __name__ == "__main__":
    print(A)
    print(I3)

def recherche_pivot(A, b, j):
    """Partial pivoting on column ``j``.

    Among rows ``j`` through the last row of ``A``, select the row whose entry
    in column ``j`` has the largest absolute value (the first such row on
    ties). If it is not already row ``j``, swap it into position ``j`` in
    both ``A`` and ``b`` so the system stays equivalent.

    Both arrays are modified in place; returns ``(A, b)``.
    """
    best = max(
        range(j, A.shape[0]),
        key=lambda row: abs(A[row, j]),
        default=j,  # no candidate rows: leave everything as is
    )
    if best == j:
        return A, b
    permutation(A, best, j)
    permutation(b, best, j)
    return A, b

# Demo only: run (and mutate A / I3) when executed as a script, not on import.
if __name__ == "__main__":
    print(recherche_pivot(A, I3, 1))

def elimination_bas(A, b, j):
    """Clear column ``j`` below the pivot ``A[j, j]``.

    For every row below ``j``, add the multiple of row ``j`` that cancels its
    entry in column ``j``; the same row operation is applied to ``b``.
    The pivot is assumed non-zero (this is not checked).

    Both arrays are modified in place; returns ``(A, b)``.
    """
    for row in range(j + 1, A.shape[0]):
        factor = -A[row, j] / A[j, j]
        A[row] += factor * A[j]
        b[row] += factor * b[j]
    return A, b

# Demo only: run (and mutate A / I3) when executed as a script, not on import.
if __name__ == "__main__":
    print(elimination_bas(A, I3, 1))

def descente(A, b):
    """Forward phase of Gaussian elimination.

    For every column except the last, bring the largest-magnitude entry at or
    below the diagonal onto the diagonal (``recherche_pivot``), then clear the
    entries below it (``elimination_bas``). The same row operations are
    applied to ``b``.

    Both arrays are modified in place; returns ``(A, b)``.
    """
    last_col = A.shape[1] - 1
    for col in range(last_col):
        A, b = recherche_pivot(A, b, col)
        A, b = elimination_bas(A, b, col)
    return A, b

# Fresh copies: the elimination helpers mutate their arguments in place.
I3 = identity(3)
A = array([[1, 0, 1], [0, 2, 1], [1] * 3], dtype=float)

if __name__ == "__main__":
    print(descente(A, I3))

def elimination_haut(A, b, j):
    """Clear column ``j`` above the pivot ``A[j, j]``.

    For every row ``0 .. j-1``, add the multiple of row ``j`` that cancels
    its entry in column ``j``; the same row operation is applied to ``b``.
    This is the backward (Gauss-Jordan) counterpart of ``elimination_bas``.
    The pivot is assumed non-zero (this is not checked).

    Both arrays are modified in place; returns ``(A, b)``.
    """
    # Rows *above* the pivot: the previous range(j + 1, n) duplicated
    # elimination_bas and left the upper triangle untouched.
    for i in range(j):
        c = -A[i, j] / A[j, j]
        A[i] += A[j] * c
        b[i] += b[j] * c
    return A, b

def remontee(A, b):
    """Backward phase: call ``elimination_haut`` on each column from last to second.

    Column 0 is skipped because there is no row above its pivot. The same row
    operations are applied to ``b``.

    Both arrays are modified in place; returns ``(A, b)``.
    """
    for col in reversed(range(1, A.shape[0])):
        elimination_haut(A, b, col)
    return A, b


def solve_diagonal(A, b):
    """Normalise each row by its diagonal entry.

    For each row ``k`` of ``A``, divide row ``k`` of ``b`` by ``A[k, k]`` and
    then set ``A[k, k]`` to 1. Only diagonal entries of ``A`` are read or
    written; the step is meant for an ``A`` already reduced to diagonal form.

    Modifies ``A`` and ``b`` in place and returns ``b``.
    """
    for k in range(A.shape[0]):
        pivot = A[k, k]
        b[k] /= pivot
        A[k, k] = 1
    return b


def gauss(A, b):
    """Run the elimination pipeline on ``A`` with right-hand side ``b``.

    Steps, all in place:

    1. ``descente``: forward elimination with partial pivoting.
    2. ``remontee``: backward elimination.
    3. ``solve_diagonal``: divide each row by its pivot.

    Returns ``b``. Provided the backward phase clears the entries above the
    diagonal, ``b`` then holds the solution X of A X = b; with ``b`` the
    identity, that solution is the inverse of ``A``.
    """
    A, b = descente(A, b)
    A, b = remontee(A, b)
    return solve_diagonal(A, b)


# Fresh copies: gauss mutates its arguments in place.
I3 = identity(3)
A = array([[1, 0, 1], [0, 2, 1], [1] * 3], dtype=float)

if __name__ == "__main__":
    print(gauss(A, I3))

def inversion(A):
    """Return the inverse of the square matrix ``A``.

    ``A`` is first copied into a new float array, so the caller's matrix is
    left untouched and integer or nested-list input works. Gauss-Jordan
    elimination is then run on that copy, with the identity as right-hand
    side.

    >>> from numpy import allclose
    >>> mat = array([[1, 2, 3], [3, 4, 6], [1, 0, -1]], dtype=float)
    >>> inv = inversion(mat)
    >>> allclose(inv, [[-2, 1, 0], [4.5, -2, 1.5], [-2, 1, -1]])
    True
    >>> allclose(mat @ inv, identity(3))
    True

    Integer input is accepted as well:

    >>> allclose(inversion([[2, 0], [0, 4]]), [[0.5, 0], [0, 0.25]])
    True

    Raises:
        ValueError: if ``A`` is not a square 2-D matrix.
    """
    # array(...) always copies. Forcing float keeps the in-place row operations
    # (division, fractional multiples) from failing on integer input.
    A_1 = array(A, dtype=float)
    if A_1.ndim != 2 or A_1.shape[0] != A_1.shape[1]:
        raise ValueError(f"inversion expects a square matrix, got shape {A_1.shape}")
    return gauss(A_1, identity(A_1.shape[0]))

