Quelques idées d'optimisation d'un code numérique écrit en Python

 

Loïc Gouarin
25 octobre 2014

Les différents outils pour le calcul

 

Pourquoi vouloir optimiser ?

Différents facteurs font qu'un code numérique écrit en Python peut souffrir de certaines lenteurs

  • l'aspect dynamique du langage
  • la création de tableaux temporaires
  • la non vectorisation des opérations sur les tableaux
  • la présence du GIL,
  • ...

Mon problème jouet

On s'intéresse ici à la résolution des équations de Saint-Venant par une méthode de Lattice Boltzmann appelée $D_2Q_9$.

 

Principe de la méthode

Elle est constituée de 3 étapes simples

  • une phase de mise à jour des points du bord
  • une phase de transport
  • une phase de collision

La grille de calcul

 

La phase de transport

 

In []:
def transport(f):
    f[:, 1:, 1] = f[:, :-1, 1]
    f[1:, :, 2] = f[:-1, :, 2]
    ...

La phase de collision

 

In []:
def collision(f, m):
    f2m(f, m)
    relaxation(m)
    m2f(m, f)
    

La phase de collision

 

In []:
def relaxation(m):
    m[:, :, 3] += c*(-2*m[:, :, 0] + 3.0*m[:, :, 1]**2 + 3.0*m[:, :, 2]**2 - m[:, :, 3])
    m[:, :, 4] += c*(m[:, :, 0] + 1.5*m[:, :, 1]**2 + 1.5*m[:, :, 2]**2 - m[:, :, 4])
    ...

Un pas de temps

In []:
def one_time_step(f, m):
    periodic_bc(f)
    transport(f)
    f2m(f, m)
    relaxation(m)
    m2f(m, f)    

Cython

Principe général

 

Définir des types statiques permettant à Cython de comprendre que nous ne sommes plus dans une partie Python mais dans une partie pouvant être écrite facilement en C et donc optimisée.

Les étapes

  • Faire un copier-coller de la fonction Python à optimiser
  • Typer les variables de la fonction
  • Dérouler les boucles (si possible dans le bon sens)
  • Ajouter des directives permettant d'optimiser les accès aux tableaux NumPy

Exemple sur le transport

In []:
def transport(f):
    f[:, 1:, 1] = f[:, :-1, 1]
    ...
Copie de la fonction à optimiser dans un fichier pyx.
In []:
def transport(double[:, :, ::1] f):
    f[:, 1:, 1] = f[:, :-1, 1]
    ...
Typer les variables de la fonction.
In []:
def transport(double[:, :, ::1] f):
    cdef: 
        int i, j
        int nx = f.shape[0]
        int ny = f.shape[1]

    for i in xrange(nx-1, 0, -1):
        for j in xrange(ny):
            f[i, j, 1] = f[i-1, j, 1]
    ...
Dérouler les boucles.
In []:
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def transport(double[:, :, ::1] f):
    cdef: 
        int i, j
        int nx = f.shape[0]
        int ny = f.shape[1]

    for i in xrange(nx-1, 0, -1):
        for j in xrange(ny):
            f[i, j, 1] = f[i-1, j, 1]
    ...
Ajout des directives.

Just In Time

  • numba
  • parakeet
  • pythran ("Ahead Of Time")
  • ...

Principe général

  • se branche sur le module écrit en Python
  • ajout d'un décorateur ou d'un commentaire sur la fonction à optimiser
  • construction d'un graphe de flot de contrôle
  • recherche des types des variables par inférence
  • optimisation de la fonction

numba et parakeet

In []:
from parakeet import jit

@jit
def transport(f):
    ...
In []:
from numba import jit

@jit("void(f8[:,:,:])")
def transport(f):
    ...

pythran

In []:
#pythran export transport(float[][][])
def transport(f):
    ...

Benchmarks

On prend notre schéma $D_2Q_9$ sur une grille de taille $1024\times 1024$ et on réalise à chaque fois 100 essais sur les fonctions testées. On calcule le temps moyen.

Versions utilisées

  • NumPy 1.9.0
  • Cython 0.20.1
  • Numba 0.15.1
  • Parakeet 0.23.2
  • pythran 0.6

Attention: le stockage de f et de m est (ns, nx, ny) pour NumPy et (nx, ny, ns) pour Numba, parakeet, pythran et Cython.

En conservant l'écriture vectorielle de NumPy

 

En utilisant des boucles

 

x4 par rapport à NumPy

Peut-on encore optimiser ?

On reprend la phase de transport.  

In []:
def one_time_step(f1, f2):
    nx, ny, ns = f1.shape
    floc = np.zeros(ns)    
    mloc = np.zeros(ns)    
    
    periodic_bc(f1)
    for i in range(1, nx-1):
        for j in range(1, ny-1):
            getf(f1, floc, i, j)
            f2m_loc(floc, mloc)
            relaxation_loc(mloc)
            m2f_loc(mloc, floc)
            setf(f2, floc, i, j)
        
x10 par rapport à NumPy pour Pythran
x40 pour la version Cython avec openMP

Conclusion

  • L'optimisation n'est plus le travail du développeur Python.
  • Les JIT bien que jeunes sont performants.
  • Ils peuvent faire plus que Cython (?).
  • Support pour openMP et GPU encore léger.

Questions ?