Un cosmologue a remplacé un outil classique par une solution plus rapide, tout en gardant la même physique. Résultat : des calculs 7 fois plus rapides et des gradients exacts, sans tout réinventer.
INSTALLER DIFFRAX EN 10 SECONDES
Pour commencer à utiliser Diffrax, il suffit d'installer deux bibliothèques avec pip :
pip install jax diffrax
JAX est une bibliothèque qui permet de faire du calcul numérique accéléré par GPU, et Diffrax est un solveur d'équations différentielles numériques écrit entièrement avec JAX. Pas de réseau de neurones, pas d'approximation : juste des algorithmes classiques, mais optimisés et différentiables.
LE CODE QUI A TOUT CHANGÉ : PASSER DE SCIPY À DIFFRAX
Voici le code original, basé sur SciPy, un outil classique pour résoudre des équations différentielles :
from scipy.integrate import solve_ivp
import numpy as np
C_KMS = 299792.458 # vitesse de la lumière [km/s]
def rhs(z, chi, Om, H0):
return C_KMS / (H0 np.sqrt(Om(1+z)**3 + (1-Om)))
def forwardscipy(Om, H0, zobs):
sol = solveivp(rhs, tspan=(0, z_obs[-1]),
y0=[0.0], teval=zobs,
args=(Om, H0), method="RK45",
rtol=1e-8, atol=1e-10)
chi = sol.y[0]
return 5 np.log10((1 + z_obs) chi * 1e5) # module de distance
Ce code résout une équation différentielle pour calculer la distance comobile dans un univers en expansion. La fonction rhs décrit comment cette distance évolue avec le redshift z, en fonction des paramètres Om (densité de matière) et H0 (constante de Hubble).
Voici maintenant la version avec Diffrax, qui fait exactement la même chose, mais en mieux :
import jax, jax.numpy as jnp
import diffrax as dfx
# Obligatoire : activer les nombres à 64 bits (on verra pourquoi plus tard)
jax.config.update("jaxenablex64", True)
def H_jax(z, Om, H0):
return H0 jnp.sqrt(Om(1+z)**3 + (1-Om))
@jax.jit # compile une fois, exécute vite à chaque appel
@jax.jit # compile une fois, exécute vite à chaque appel
def forwarddiffrax(theta, zobs):
Om, H0 = theta[0], theta[1]
sol = dfx.diffeqsolve(
dfx.ODETerm(lambda z, chi, a: CKMS / Hjax(z, a[0], a[1])),
dfx.Tsit5(),
t0=0.0, t1=float(z_obs[-1]), # valeur initiale et finale
dt0=1e-3, # taille initiale du pas
y0=jnp.array(0.0), # condition initiale
args=(Om, H0),
saveat=dfx.SaveAt(ts=z_obs),
stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
maxsteps=10000,
)
chi = sol.ys
return 5 jnp.log10((1 + z_obs) chi * 1e5)
La physique reste identique, et l'algorithme de résolution aussi : Tsit5 est très proche de RK45, les deux étant des méthodes de Runge-Kutta. La différence ? Deux lignes de code en plus : @jax.jit et l'API de Diffrax. Mais ces deux petites modifications changent tout.
CE QUE CES DEUX LIGNES DE CODE RAPPORTENT : 7 FOIS PLUS VITE
Avec SciPy, un appel à forward_scipy prend environ 404 microsecondes. Avec Diffrax, après compilation, le même appel ne prend que 59 microsecondes. Soit 7 fois plus rapide.
Mais d'où vient cette accélération ? Ce n'est pas de la magie. Avec SciPy, à chaque appel, Python doit revenir dans le backend C/Cython. La mémoire est réallouée. La boucle adaptative qui ajuste les pas d'intégration doit passer par l'interpréteur Python à chaque itération : « Est-ce que l'erreur locale est trop grande ? Rejeter le pas ; sinon, augmenter la taille du pas ; répéter. » Pour une résolution en 12 étapes, cela fait 12 allers-retours entre Python et le code compilé, 12 allocations mémoire, 12 calculs d'erreur.
Avec Diffrax, le premier appel avec @jax.jit trace toute la computation, y compris la boucle adaptative. Cette boucle est transformée en un lax.while_loop et compilée en un noyau machine par XLA. Tous les appels suivants exécutent directement ce noyau. Plus de Python, plus d'allocation mémoire, plus de dispatch. Juste du code machine ultra-rapide.
GRADIENTS EXACTS : LE VRAI BOUSTON DE SCIPI
Un autre problème avec SciPy ? Calculer un gradient de la fonction de perte par rapport aux paramètres Om et H0 coûte très cher. Pour deux paramètres, il faut quatre évaluations de la fonction (méthode des différences finies centrales) :
∂ℱ/∂Om ≈ [ℱ(Om+h, H0) – ℱ(Om-h, H0)] / (2h)
∂ℱ/∂H0 ≈ [ℱ(Om, H0+h) – ℱ(Om, H0-h)] / (2h)
Avec 10 paramètres, il faut 20 évaluations. Avec 50 paramètres, 100. Le coût grandit linéairement avec le nombre de paramètres.
Avec Diffrax, grâce à autodiff de JAX, on obtient un gradient exact en un seul passage arrière, quel que soit le nombre de paramètres. Pas besoin d'écrire les équations adjointes soi-même : JAX s'en charge automatiquement. Le gradient coûte à peu près le même temps qu'un appel avant.
def loss(theta):
mupred = forwarddiffrax(theta, z_obs)
return 0.5 * jnp.sum(((mupred - muobs) / sigma_mu)**2)
grad_fn = jax.jit(jax.grad(loss)) # c'est tout ce qu'il faut changer
g = grad_fn(jnp.array([0.3, 70.0])) # gradient exact
TROIS PIÈGES À ÉVITER ABSOLUMENT
Diffrax est puissant, mais il faut faire attention à trois détails qui peuvent tout gâcher.
Par défaut, JAX utilise des nombres à 32 bits. Si on pousse les tolérances d'erreur (rtol < 10⁻⁷), les résultats peuvent devenir très étranges. Sur une équation différentielle, le solveur a besoin de 69 étapes en 32 bits, mais seulement 12 en 64 bits. Si on serre encore plus les tolérances, il peut même échouer complètement. La solution ? Activer les 64 bits dès le début :
jax.config.update("jaxenablex64", True) # doit être fait en premier
Le premier appel à une fonction décorée par @jax.jit inclut une compilation qui prend environ 90 à 100 millisecondes. Si on inclut ce temps dans les mesures, Diffrax semblera plus lent que SciPy pour la mauvaise raison. Il faut donc lancer une fois la fonction et ignorer ce premier appel :
= forwarddiffrax(theta, zobs).blockuntil_ready() # compilation
# MAINTENANT, on peut benchmarker — c'est la vraie vitesse
Et attention : JAX exécute les calculs de manière asynchrone. Toujours appeler .blockuntilready() dans les boucles de timing, sinon on mesure le temps pour soumettre le travail, pas pour le terminer.
Avec odeint de SciPy, la fonction attend f(y, t) : d'abord l'état, puis le temps. Presque tout le reste (solve_ivp, Diffrax) attend f(t, y). Si on porte du code ancien sans inverser les arguments, on résout une équation différente… et on obtient une mauvaise réponse sans même s'en rendre compte.
UN TEST CONCRET : INFÉRER LES PARAMÈTRES DE L'UNIVERS
Prenons un cas réel : inférer les paramètres Om (densité de matière) et H0 (constante de Hubble) à partir de 30 supernovas simulées. On part d'une mauvaise estimation initiale : (Om, H0) = (0.10, 60), loin de la vérité (0.30, 70). On fait 350 pas de gradient avec l'optimiseur Adam.
Avec SciPy, le résultat final est catastrophique : Om = 0.652 (faux) et H0 = 60.10 (bloqué). Pourquoi ? Parce que la descente de gradient avec des différences finies ne supporte pas l'écart d'échelle de 200 fois entre Om (~0.3) et H0 (~70). Adam, avec ses taux d'apprentissage adaptatifs, ne peut pas rattraper ça.
Avec Diffrax, les résultats sont bien meilleurs : Om = 0.270 et H0 = 70.94, proches de la vérité. Et tout ça, grâce aux gradients exacts fournis par JAX.
LES CHIFFRES QUI PARLENT
Voici un résumé des performances :
Le résultat final avec SciPy est faux, car la méthode ne peut pas gérer l'échelle des paramètres. Avec Diffrax, tout fonctionne comme sur des roulettes.
DIFFRAX N'EST PAS UNE IA : C'EST DE LA MATHS CLASSIQUE OPTIMISÉE
Attention : Diffrax n'est pas une solution basée sur l'apprentissage automatique au sens où on l'entend habituellement (avec des réseaux de neurones). C'est de la mathématique classique : des équations différentielles résolues avec des méthodes de Runge-Kutta, mais écrites en JAX. L'accélération vient de la compilation JIT et de l'autodiff, deux Outils issus du monde de l'IA, mais appliqués à un solveur numérique classique.
Une approche vraiment basée sur l'IA serait un surrogate neural : un réseau de neurones qui apprend à prédire directement μ(z) à partir de θ, sans résoudre d'équation. Mais c'est un autre sujet, plus avancé.
LE SCRIPT COMPLET : DE A À Z EN UN SEUL FICHIER
Tout le code nécessaire pour reproduire cette expérience est disponible dans un seul fichier, prêt à l'emploi :
flatlcdminference.py
Infer (Omega_m, H0) from 30 mock supernovae using diffrax + Adam.
pip install jax diffrax optax
import jax, jax.numpy as jnp, numpy as np
import diffrax as dfx, optax
from scipy.integrate import solve_ivp # seulement pour générer les données simulées
jax.config.update("jaxenablex64", True)
# -- Constantes et données -----------------------------------------------
C_KMS = 299792.458
z_obs = jnp.linspace(0.05, 1.5, 30)
SIGMA = 0.10
# Données simulées à la vérité (Om=0.30, H0=70)
def chi_np(Om, H0):
sol = solveivp(lambda z, y: CKMS/(H0np.sqrt(Om(1+z)**3+(1-Om))),
(0, 1.5), [0.], teval=np.array(zobs), rtol=1e-10)
return sol.y[0]
mutrue = 5np.log10((1+np.array(zobs))chi_np(0.3, 70.)*1e5)
muobs = jnp.array(mutrue + 0.10*np.random.defaultrng(42).standardnormal(30))
# -- Modèle avant avec Diffrax --------------------------------------------
@jax.jit
def forward(theta):
Om, H0 = theta[0], theta[1]
sol = dfx.diffeqsolve(
dfx.ODETerm(lambda z, chi, a:
C_KMS/(a[1]jnp.sqrt(a[0](1+z)**3+(1-a[0])))),
dfx.Tsit5(),
t0=0., t1=1.5, dt0=1e-3, y0=jnp.array(0.),
args=(Om, H0),
saveat=dfx.SaveAt(ts=z_obs),
stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
maxsteps=10000,
).ys
return 5jnp.log10((1+z_obs)sol*1e5)
# -- Fonction de perte et gradient ----------------------------------------
def loss(th_s): # optimisation dans des coordonnées mises à l'échelle (Om, h=H0/100)
mu = forward(jnp.array([ths[0], 100.*ths[1]]))
return 0.5*jnp.sum(((mu - mu_obs)/SIGMA)**2)
grad_fn = jax.jit(jax.grad(loss))
# Réchauffer le compilateur JIT
theta_init = jnp.array([0.10, 0.60])
= forward(jnp.array([0.3, 0.7])).blockuntil_ready()
= gradfn(thetainit).blockuntil_ready()
# -- Optimiseur Adam avec plan de décroissance du taux d'apprentissage ----
sched = optax.cosinedecayschedule(initvalue=0.05, decaysteps=350, alpha=0.04)
opt = optax.adam(sched)
theta = theta_init
state = opt.init(theta)
print(f"{'Step':>5} {'Om':>7} {'H0':>7} {'Loss':>8}")
for step in range(350):
g = grad_fn(theta)
upd, state = opt.update(g, state)
theta = optax.apply_updates(theta, upd)
if (step + 1) % 70 == 0 or step == 0:
L = float(loss(theta))
print(f"{step+1:5d} {float(theta[0]):7.4f} {100*float(theta[1]):7.3f} {L:8.2f}")
Omfit, H0fit = float(theta[0]), 100*float(theta[1])
print(f"\nFinal: Om = {Omfit:.3f} H0 = {H0fit:.2f}")
print(f"Truth: Om = 0.300 H0 = 70.00")
CE QUE ÇA CHANGE POUR UN COSMOLOGUE (ET POUR TOI)
Passer de SciPy à Diffrax n'a pas changé la physique ni la méthode d'inférence. Ça a changé la faisabilité pratique de faire cette inférence du tout. Une analyse par échantillonnage niché qui devait durer des heures est devenue une analyse de moins d'une minute. Les gradients qui coûtaient 20 évaluations supplémentaires par pas sont devenus quasi gratuits.
Le temps d'apprentissage ? Une après-midi. Le temps de débogage ? Principalement les pièges des 64 bits et de la compilation JIT. Le retour sur investissement ? Immédiat et réel.
Si tu es physicien, astronome, ou ingénieur et que tu utilises SciPy pour des évaluations répétées de fonctions de vraisemblance, et que tu n'as pas encore regardé Diffrax… il est temps de le faire. La courbe d'apprentissage est faible, et les gains sont énormes.
POUR ALLER PLUS LOIN : REPRODUCTIBILITÉ ET RESSOURCES
Tous les exemples de code présentés ici sont reproductibles. Les données simulées sont générées avec une graine fixe pour garantir les mêmes résultats à chaque exécution. Les bibliothèques utilisées (JAX, Diffrax, Optax) sont open source et bien documentées.
Pour en savoir plus, consulte les ressources suivantes :
- Documentation officielle de Diffrax : https://docs.kidger.site/diffrax/
- Tutoriel JAX pour les débutants : https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
- Exemple de cosmologie avec Diffrax : https://github.com/finite-differences/diffrax-examples
- Towards Data Science
L'indépendance de CLODCO est votre garantie.
Pour que l'actualité de l'IA reste sans filtre et sans concession, votre soutien est indispensable. Votre contribution est le seul moteur de notre liberté éditoriale.
Soutenir CLODCO


