Zustandslos skalieren · Boot-Sequenz
jax · flax · pure functionsvortrag · mai 2026
01 / Vortrag · 5 Minuten
// JAX// Flax// pure functions// PRNGKey// init() · apply()// stateless// XLA// jax.grad// jit// vmap// pmap// optax// JAX// Flax// pure functions// PRNGKey// init() · apply()// stateless
Darf ein neuronales Netz seine eigenen Gewichte vergessen?

Zustandslos
skalieren.

Warum erzwingt JAX absolute Zustandslosigkeit und opfert dafür den klassischen Zufall?
Wie ermöglicht Flax Deep Learning, wenn Modelle eigentlich keine internen Gewichte speichern dürfen?

VortragenderMax Lightning TalkJAX · Flax · Pure Functions VortragMai 2026 Dauer≈ 5 Minuten
02 / Grundlage
Definition

Was ist eine Pure Function?

Ganz einfach: gleicher Input, immer gleicher Output. Keine Side Effects, kein Gedächtnis. Das einzige was sie tut: einen Wert zurückgeben.

⚠ Imperativ · OOP-Stil

Keras: Gewichte leben im Objekt

Das Modell besitzt seine Gewichte. Jeder Trainings-Schritt manipuliert sie still im Hintergrund.

! model ├── architecture
├── weights ← mutates
└── optimizer.state ← mutates
✓ Funktional · JAX-Stil

JAX: Input rein, Output raus

Gewichte sind ein Argument: f(params, x) → y. Reine Funktion = vom XLA-Compiler tracebar = JIT-kompilierbar auf 1000 GPUs.

params, x f new_params, y

„Eine Pure Function ist wie ein Rezept: gleiche Zutaten, immer gleicher Kuchen — der Koch hat kein Gedächtnis."

— Funktionale Programmierung, kurz erklärt
03 / Konsequenz
Determinismus

Das PRNG-Problem.

PRNG · Pseudo-Random Number Generator

Wenn alles pure sein muss, gibt es ein Problem: Zufall. Es kann keinen globalen Generator geben. Sonst hängt das Ergebnis davon ab, wer wann zuerst würfelt.

Globaler Seed

Race Condition

1.000 GPUs, alle gleichzeitig auf einer Zufallsquelle. Wer würfelt zuerst? Undefiniert → korrelierte Zufallszahlen, leise kaputte Statistik.

SEED
→ chaos

JAX PRNGKey

Deterministisch

Ein Root-Key, mit split() in N unabhängige Keys aufgeteilt. Jede Operation kriegt ihren eigenen.

key k₁ k₂ k₃ jax.random.split(key, n)
~/code/prng.py
key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)   # deterministisch geteilt

x = jax.random.normal(key1, shape=(100,))
Kein globaler Zustand. Niemals.
04 / Beweis
Determinismus, sichtbar

Selber Key. Selber Pfad.

Zwei identische Loss-Landschaften. Links rollt der Optimizer mit globalem RNG, rechts mit PRNGKey(42).

RNG · Random Number Generator
Globaler RNG Run 0
Math.random() · jeder Lauf wandert woanders hin
PRNGKey(42) Run 0
mulberry32(42) · jeder Lauf legt sich auf den letzten
Reproduzierbarkeit ist kein Feature. Es ist eine Folge.
ƒ 05 / Architektur
Flax · Neural Networks auf JAX

Ein Netz ohne Gedächtnis… das trotzdem Gewichte hat?

Flax ist die Standard-Bibliothek für neuronale Netze auf JAX. Architektur und Gewichte werden strikt getrennt: das Modell ist nur eine Funktion, die Parameter leben extern als PyTree.

01
JAX-nativ

Nutzt direkt jit, grad, vmap. Kein Wrapper, kein Overhead.

02
Zustandslos

Das Modul speichert keine Daten. Nur Topologie.

03
PyTree-Params

Gewichte als verschachteltes Dict. Immutable Tensoren an den Blättern, jax.grad + jax.jit arbeiten direkt darauf.

Phase 1 · Init params = model.init(key, x_dummy) erzeugt den PyTree
PyTreeparams
verschachteltes Dict · immutable · jeder Knoten ist ein Tensor
Dense_0
kernel(784, 256)
bias(256,)
Dense_1
kernel(256, 64)
bias(64,)
Dense_2
kernel(64, 10)
bias(10,)
Phase 2 · Apply y = model.apply(params, x) params + x rein, y raus · zustandslos
06 / Fazit
Drei Säulen

Was bleibt hängen?

01
ƒ

Pure Functions

Reproduzierbarkeit und Parallelität sind keine Optimierung, die man später draufpackt. Sie sind Folge der Reinheit.

f(x) → y · always
02

Explizite PRNGKeys

Kein globaler Seed, keine Race Conditions. Jede Zufallsoperation kriegt ihren eigenen Schlüssel.

jax.random.split
03

init() / apply()

Die Architektur ist eine Funktion. Die Gewichte sind Daten. Diese Trennung ist der ganze Punkt.

Architektur ≠ Daten

Zustandslosigkeit ist kein Verzicht
es ist Architektur.

07 / Referenzen
Literatur

Worauf das basiert.

[01]
Burle, M.-H.: Flax's handling of model states
mint.westdri.ca/ai/jx/fl_state
Research Computing
[02]
CodeSignal: Pure Functions — The Cornerstone of JAX
codesignal.com · jax-fundamentals
CodeSignal Learn
[03]
CodeSignal: Reproducible Randomness with jax.random — Keys, Splitting, and Determinism
codesignal.com · advanced-jax-transformations
CodeSignal Learn
[04]
The Flax Authors: Flax Basics
flax.readthedocs.io · guides/flax_basics
Read the Docs
[05]
The JAX Authors: JAX — The Sharp Bits
docs.jax.dev · Common_Gotchas_in_JAX
JAX Documentation
[06]
N.N.: The Functional Transformation — Deciphering the Stateless Paradigm Shift from Keras to Flax
Markdown-Dokument
Internes Skript
// Präsentation mit Hilfe von KI erstellt