Acte 2

L'entraînement

On dispose de 20 000 transitions. Il faut maintenant entraîner un modèle à prédire l'avenir — non pas en pixels, mais dans un espace de représentation abstrait qu'il construit lui-même.

L'architecture

Trois composants qui apprennent ensemble

Le world model est une architecture JEPA composée de trois réseaux de neurones. Chacun a un rôle précis. Ils ne s'entraînent pas séparément : leurs paramètres sont tous optimisés simultanément par rétropropagation.

obst — image 10×10
actiont — entier 0 à 3
Encodeur CNN
Conv2d(1→16) + ReLU
Conv2d(16→32) + ReLU
Linear(3200→32) + BatchNorm
Embedding d'action
Embedding(4, 32)
 
(4 vecteurs apprenables)
zt — 32 dimensions
ea — 32 dimensions
concaténation → 64 dimensions
Prédicteur MLP
Linear(64→64) + ReLU
Linear(64→64) + ReLU
Linear(64→32)
t+1 — prédiction latente 32D

Pourquoi ces choix ?

Le CNN pour l'encodeur — une image 10×10, c'est 100 pixels. Un réseau dense les traiterait tous de façon indépendante. Le CNN, lui, utilise des filtres locaux qui se déplacent sur l'image : il apprend à détecter des motifs spatiaux (la position d'un objet, un coin de mur) avant de produire un résumé global de 32 dimensions. C'est la bonne inductive bias pour des données structurées spatialement.

L'embedding pour l'action — l'action est un entier (0, 1, 2 ou 3). Passer cet entier brut au prédicteur n'aurait pas de sens : 3 n'est pas "plus grand" que 0 dans le sens physique. L'embedding apprend un vecteur de 32 nombres pour chaque action, que le prédicteur peut combiner avec l'état latent de façon souple.

Le MLP pour le prédicteur — une fois l'état et l'action fusionnés en 64 dimensions, le prédicteur est un réseau entièrement connecté classique. Sa tâche est de modéliser des transformations non-linéaires dans l'espace latent : "si l'état est celui-ci et que l'action est celle-là, le prochain état latent ressemble à ça."

Au total : ~15 000 paramètres apprenables. C'est minuscule comparé aux milliards des grands modèles — mais suffisant pour un monde à 10×10 cases. Le modèle tient entièrement en mémoire vive et s'entraîne en quelques secondes sur CPU.

La fonction de loss

Deux objectifs en tension

L'entraînement consiste à minimiser une fonction de perte qui combine deux termes aux rôles opposés : l'un pousse le modèle à bien prédire, l'autre l'empêche de tricher.

ℒ =   ‖ẑt+1 − zt+1‖²   +   λ · CovReg(Z)    (λ = 0.1)
Terme 1 — Erreur de prédiction (MSE)

On mesure la distance au carré entre le vecteur latent prédit (ẑt+1, produit par le prédicteur) et le vecteur latent réel (zt+1, produit en encodant l'observation suivante). Ce terme force le modèle à prédire juste.

Terme 2 — Régularisation de covariance

Pénalise les corrélations entre dimensions du vecteur latent. Si deux dimensions encodent toujours la même chose, la pénalité augmente. Ce terme force chacune des 32 dimensions à porter une information différente.

Pourquoi le deuxième terme est indispensable

Sans régularisation, le modèle découvre rapidement une solution triviale : projeter toutes les images sur le même vecteur constant. La loss de prédiction tombe alors à zéro — ẑt+1 = zt+1 si les deux valent toujours (0, 0, …, 0) — sans que le modèle ait appris quoi que ce soit sur le monde. Ce phénomène s'appelle l'effondrement de représentation (representation collapse).

La régularisation par covariance s'inspire de VICReg (2022) et du papier LeWM. Elle pénalise les termes hors-diagonaux de la matrice de covariance du batch : si dimi et dimj sont systématiquement corrélées, elles encodent la même chose — du gaspillage de capacité. La BatchNorm sur la sortie de l'encodeur joue un rôle complémentaire : elle normalise chaque dimension, empêchant certaines de saturer ou s'effondrer.

Visualisation

Regardez le modèle apprendre

50 époques, 250 batches de 64 transitions par époque, optimiseur Adam. À chaque étape, les poids sont ajustés pour réduire la loss sur les 16 000 transitions d'entraînement — et on mesure au passage si le modèle généralise sur les 4 000 transitions de validation.

Courbes de loss — échelle logarithmique
Les vrais chiffres du notebook
Train
Validation
Époque
Loss train
Loss val
Pred loss
Cov reg × 0.1
✓ Convergence atteinte — loss finale ≈ 0.0001

Analyse

Ce que la convergence nous apprend

14
Époques pour converger
×2880
Réduction de loss (époque 1 → 50)
0
Surapprentissage (train ≈ val)

La convergence est rapide et propre. En 14 époques, la loss passe de 0.29 à 0.001 — une chute de deux ordres de grandeur. Puis elle s'affine progressivement jusqu'à ~0.0001 à l'époque 30, et reste stable jusqu'à la 50e.

L'époque 13 est intéressante : la loss d'entraînement remonte brièvement de 0.0044 à 0.0135 avant de plonger à 0.0016. C'est une instabilité transitoire — le modèle "réorganise" ses représentations internes. La loss de validation, elle, continue de descendre sans discontinuité, ce qui confirme qu'il ne s'agit pas d'un problème réel.

L'absence de surapprentissage est notable : les courbes train et validation restent collées tout au long de l'entraînement. C'est cohérent avec un problème bien posé — le dataset est suffisamment varié et la tâche (prédire dans l'espace latent) est naturellement régularisante.

Une loss finale de ~0.0001 signifie que le prédicteur est capable d'anticiper la représentation latente du prochain état avec une erreur quadratique de 0.01% de l'amplitude typique du vecteur latent. Mais une faible loss de prédiction ne garantit pas que l'espace latent encode ce dont on a besoin pour planifier. C'est précisément ce qu'on explore à l'acte 3.

← Acte 1 — Environnement Acte 3 — Représentation →