"""
Principalement on discute l'exercice 9 de la feuille de TD no 2.
Le code a été développé avec Fredérique Proïa.

La fonction partition1(n, P) construit les partitions de n+1 
plus grades qu'une partition donnee de n.

A la fin, en utilisant la fonction partition1, on construit la 
fonction partitions(n) qui produit la liste de toutes les partitions 
de n.
"""

def estPartition(n, P):
    """
    Prend en entrée un entier n et une liste P
    Renvoie True (si P est une partition de n) ou False (sinon)
    """
    # Test 1 : est-ce que a0 + a1 + ... + ak = n ?
    s = sum(P)
    if s != n:
        return False
    
    # Test 2 : est-ce que a0 >= a1 >= ... >= ak ?
    for i in range(len(P)-1):
        if P[i+1] > P[i]:
            return False
        
    return True



def partition_n1(n, P):
    """
    Prend en entrée un entier n et une liste P
    Teste si P est une partition de n
    Si oui : génère toutes la liste des partitions
         de n+1 qui sont > P (au sens de l'énoncé)
    Si non : affiche une erreur et coupe le programme
    """
    if not estPartition(n, P):
        print(f"{P} n'est pas une partition de {n} !")
        return None
    
    res = []
    for i in range(len(P)):
        P1 = P.copy()
        P1[i] += 1
        if estPartition(n+1, P1):
            res.append(P1)
            
    P1 = P.copy()
    P1.append(1)
    res.append(P1)
            
    return res



def partitions(n):
    """
    Retourne toutes les partitions de n
    Fonction recursive basee sur partition_n1
    """
    if n==1:
        return [[1]]
    else:
        previous_Ps = partitions(n-1)
        res = []
        for P in previous_Ps:
            tmp_Ps = partition_n1(n-1, P)
            for partition in tmp_Ps:
                if partition not in res:
                    res.append(partition)
        return res
    

###############################################################
### Tests
###############################################################

P = [7, 4, 4, 1]
partitions_plusGrandes = partition_n1(sum(P), P)
print(partitions_plusGrandes,
      f"\nsont les partitions de {sum(P)+1} plus grandes que {P}\n")


for n in range(2, 9):
    Ps = partitions(n)
    print(f"pour n = {n} on obtient {len(Ps)} partitions")

print("\nles partitions de 8 :\n", Ps)
