Warum erzwingt JAX absolute Zustandslosigkeit und opfert dafür den klassischen Zufall?
Wie ermöglicht Flax Deep Learning, wenn Modelle eigentlich keine internen Gewichte speichern dürfen?
Ganz einfach: gleicher Input, immer gleicher Output. Keine Side Effects, kein Gedächtnis. Das einzige was sie tut: einen Wert zurückgeben.
Das Modell besitzt seine Gewichte. Jeder Trainings-Schritt manipuliert sie still im Hintergrund.
Gewichte sind ein Argument: f(params, x) → y. Reine Funktion = vom XLA-Compiler tracebar = JIT-kompilierbar auf 1000 GPUs.
„Eine Pure Function ist wie ein Rezept: gleiche Zutaten, immer gleicher Kuchen — der Koch hat kein Gedächtnis."
— Funktionale Programmierung, kurz erklärt
Wenn alles pure sein muss, gibt es ein Problem: Zufall. Es kann keinen globalen Generator geben. Sonst hängt das Ergebnis davon ab, wer wann zuerst würfelt.
1.000 GPUs, alle gleichzeitig auf einer Zufallsquelle. Wer würfelt zuerst? Undefiniert → korrelierte Zufallszahlen, leise kaputte Statistik.
PRNGKeyEin Root-Key, mit split() in N unabhängige Keys aufgeteilt. Jede Operation kriegt ihren eigenen.
key = jax.random.PRNGKey(42) key1, key2 = jax.random.split(key) # deterministisch geteilt x = jax.random.normal(key1, shape=(100,))
Zwei identische Loss-Landschaften. Links rollt der Optimizer mit globalem RNG, rechts mit PRNGKey(42).
Flax ist die Standard-Bibliothek für neuronale Netze auf JAX. Architektur und Gewichte werden strikt getrennt: das Modell ist nur eine Funktion, die Parameter leben extern als PyTree.
Nutzt direkt jit, grad, vmap. Kein Wrapper, kein Overhead.
Das Modul speichert keine Daten. Nur Topologie.
Gewichte als verschachteltes Dict. Immutable Tensoren an den Blättern, jax.grad + jax.jit arbeiten direkt darauf.
params = model.init(key, x_dummy)
erzeugt den PyTree
y = model.apply(params, x)
params + x rein, y raus · zustandslos
Reproduzierbarkeit und Parallelität sind keine Optimierung, die man später draufpackt. Sie sind Folge der Reinheit.
Kein globaler Seed, keine Race Conditions. Jede Zufallsoperation kriegt ihren eigenen Schlüssel.
Die Architektur ist eine Funktion. Die Gewichte sind Daten. Diese Trennung ist der ganze Punkt.
Zustandslosigkeit ist kein Verzicht —
es ist Architektur.