Initialising runtime
keras 3.0 · jax · xlaboot sequence
BTU Cottbus · Angewante Modellierung und Systemsimulation
April 2026 01 / 07
Vortrag · 5 Minuten

Keras 3 &
das Multi-Backend
Paradigma.

Zwei Fragen: Was hat Keras 3 im KI-Ökosystem verändert? Und welchen Vorteil bringt das für physikalische Simulationen: Wetter, Strömungen, alles wo Physik mitrechnen muss.

JAX XLA TPU jit() · grad() · vmap() · pmap() PINN
§ 02 — Geschichte & Problem

Drei Frameworks. Acht Jahre nebeneinander her.

Zwischen 2015 und 2023 entstehen drei große KI-Frameworks: TensorFlow, PyTorch und JAX. Jedes spricht seine eigene Sprache. Wer zwischen ihnen wechseln will, schreibt Code neu. Bis Keras 3 kommt und alle drei unter einen Hut bringt.

März 2015
Keras
Die Geburt von Keras
Der französische Entwickler François Chollet veröffentlicht Keras. Wenige Monate später wechselt er zu Google. Ziel von Anfang an: neuronale Netze bauen soll so einfach wie möglich sein, egal welches Deep-Learning-Framework dahinter läuft.
Nov 2015
TensorFlow
Googles Antwort: schnell, aber starr
Google bringt TensorFlow heraus. Man muss das ganze Modell erst fertig bauen, dann läuft es. Schnell und industrie-tauglich, aber schwer zu testen und anzupassen.
2016
PyTorch
Facebooks Antwort: wie normales Python
Facebook (heute Meta) veröffentlicht PyTorch als Antwort auf das sperrige TensorFlow. Code wird Zeile für Zeile ausgeführt, ganz wie normales Python. Die Forschung wechselt nach und nach dorthin.
2018
JAX
Massiv parallel, funktional, skalierbar
Google Research veröffentlicht JAX. Idee: Funktionen werden funktional geschrieben und laufen dann massiv parallel auf Hardware-Clustern. Der Preis: man muss anders programmieren. Der Gewinn: skaliert fast beliebig.
Sep 2019
TF 2.x Keras
Keras wird an TensorFlow gekettet
TensorFlow 2 wird einfacher zu bedienen und bindet Keras fest als tf.keras ein. Vier Jahre lang ist Keras quasi nur noch mit TensorFlow nutzbar. Die ursprüngliche Idee, mehrere Deep-Learning-Frameworks zuzulassen, verschwindet.
Ende 2023
Keras 3.0
Keras 3: neu gedacht
Keras wird komplett neu gebaut. Eine Datei, drei Deep-Learning-Frameworks zur Auswahl: TensorFlow, PyTorch oder JAX, umschaltbar mit einem einzigen Befehl. Die Trennung zwischen den Welten ist vorbei.
Stack A
TensorFlow
Stark in Produktion und großen Systemen, aber sperrig im Alltag.
tf.Tensor tf.GradientTape SavedModel
Stack B
PyTorch
Der Liebling der Forschung, aber auf TPU und beim Kompilieren schwach.
torch.Tensor autograd torchscript
Stack C
JAX
Extrem skalierbar auf TPU, mathematisch elegant, aber anspruchsvoll im Einstieg.
jax.Array grad / vmap pjit
§ 03 — Keras 3 · Multi-Backend

Ein API. Drei Backends.
Derselbe Code.

Keras 3 wurde 2023 komplett neu gebaut. Es ist nur noch die Bedien-Schicht. Darunter arbeitet der Deep-Learning-Framework, den du aussuchst: TensorFlow, PyTorch oder JAX. Dein Modell-Code bleibt dabei immer gleich.

Ebene 1
Was du bedienst
keras
Die Bedien-Schicht. Hier baust du dein Modell, immer gleich, egal welcher Motor unten läuft.
keras.layers.Dense keras.Model keras.ops.matmul model.fit()
Ebene 2
Motor deiner Wahl
TensorFlow
Produktion · Serving
JAX
Wissenschaft · TPU · Skalierung
Empfohlen
PyTorch
Forschung · Community
Ebene 3
Hardware
CPUStandard
GPUNVIDIA · AMD
TPUGoogle
AppleSilicon
§ 03 · b — Backend-Switch in der Praxis

Eine Umgebungs­variable entscheidet.

Der Trick heißt keras.ops: jede Rechen-Operation läuft über eine einheitliche Schicht, die im Hintergrund das ausgewählte Deep-Learning-Framework anspricht. KERAS_BACKEND ist eine Zeile Code. Der Rest bleibt identisch.

export KERAS_BACKEND="jax"  # oder "tensorflow" · "torch"
Ohne Keras 3
Drei Frameworks, drei Dialekte.
TensorFlow · tf.Module
import tensorflow as tf

class Net(tf.Module):
  def __init__(self):
    self.W1 = tf.Variable(tf.random.normal([2,64]))
    self.W2 = tf.Variable(tf.random.normal([64,1]))
  def __call__(self, x):
    h = tf.nn.tanh(x @ self.W1)
    return h @ self.W2
PyTorch · nn.Module
import torch
import torch.nn as nn

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(2, 64)
    self.fc2 = nn.Linear(64, 1)
  def forward(self, x):
    return self.fc2(torch.tanh(self.fc1(x)))
JAX · Flax · linen
import jax.numpy as jnp
from flax import linen as nn

class Net(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.tanh(nn.Dense(64)(x))
    return nn.Dense(1)(x)
Mit Keras 3
Ein Code. Drei Backends.
keras · egal welches Backend
import keras
from keras import layers

model = keras.Sequential([
    layers.Dense(64, activation="tanh"),
    layers.Dense(1),
])

model.compile(optimizer="adamw", loss=physics_loss)
model.fit(X, y, epochs=500)
Dieser eine Block läuft gleich, egal ob KERAS_BACKEND auf tensorflow, torch oder jax steht. Das Modell bleibt identisch, nur der Deep-Learning-Framework dahinter wechselt.
Code-Länge
~3× kürzer
8 Zeilen statt 24 · weniger Boilerplate
Portabilität
3 Backends
TensorFlow · PyTorch · JAX, alle aus einer Datei
Beim Backend-Wechsel
0 Zeilen
am Modell zu ändern · nur die Umgebungsvariable
§ 04 — JAX für Simulationen

Warum JAX für Physik so passt.

Physik muss Naturgesetze einhalten, und dafür braucht man Ableitungen. Hier kommen zwei Werkzeuge ins Spiel: JAX übernimmt die Mathe, XLA übernimmt die Geschwindigkeit.

JAX
Die Mathe-Bibliothek
Python-Werkzeug von Google. Der Trick: Sie kann jede Funktion automatisch ableiten, ohne dass du die Ableitung selbst hinschreiben musst. Genau das ist das Herzstück für Physik-Simulationen.
XLA
Der Hardware-Compiler
Ein Übersetzer, auch von Google. Er nimmt den JAX-Code und macht daraus hardware-optimierten Maschinen-Code, der auf TPUs und GPUs richtig Tempo aufnimmt. Das ist der Grund, warum JAX so schnell ist.

Zusammen: Du schreibst ganz normales Python, am Ende läuft es auf spezieller Hardware und hält gleichzeitig die Physik ein.

Use Case 01 · PINNs

Echte Physik durchs Netz.

Physics-Informed Neural Networks lernen so, dass sie eine physikalische Gleichung erfüllen, zum Beispiel die Strömungsgleichung rechts. Vorteil: Du löst komplexe Physik-Probleme ohne klassisches Rechengitter. Gerade bei schwierigen Strömungen wie dem Wetter spart das massiv Aufwand.

01
Ableiten ist das Herzstück. JAX' grad liefert die Ableitungen durchs Netz, automatisch mit einer Zeile.
02
Rechenleistung + Naturgesetze. XLA optimiert für TPU und GPU, das Netz darf riesig werden.
∂u/∂t + (u·∇)u = −∇p/ρ + ν∇²u Navier-Stokes · Echtzeit-Flow
§ 05 — Hardware · XLA · TPU · GPU

GPU oder TPU?
Zwei Chips, zwei Geschichten.

GPUs wurden ursprünglich für Computerspiele gebaut und sind eher zufällig gut in KI geworden. Heute werden sie natürlich explizit dafür entwickelt. TPUs hingegen hat Google von Anfang an für neuronale Netze gebaut. Das Bindeglied zwischen beiden ist XLA.

Der Allrounder
GPU
Kommt aus der Grafik-Welt. NVIDIA, AMD, alle großen. Flexibel für fast alles: Gaming, Rendering, KI. Für Physik-Simulationen solide, aber nicht speziell optimiert.
Universell Überall verfügbar
Der Spezialist
TPU
Von Google speziell für neuronale Netze gebaut. Eng mit JAX & XLA verzahnt. Skaliert fast beliebig, von 1 Chip auf einen Pod mit tausenden. Perfekt für große Physik-Modelle.
KI-native Skaliert massiv
§ 05 · b — Was man am Ende davon merkt

XLA fusioniert. TPUs ziehen vorbei.

Was TPU heißt
Matrix-Spezialist
„Tensor Processing Unit", also ein Google-Chip speziell für Matrix-Multiplikation. Genau das ist die Kern-Operation in neuronalen Netzen.
Preis-Leistung
2–5× pro Dollar
Höherer Durchsatz als vergleichbare NVIDIA GPUs bei gleichen Workloads. Weniger Allzweck-Ballast, spezialisierte Architektur.
Zusammenspiel
Nativ mit JAX
XLA und TPUs wurden bei Google parallel entwickelt. JAX spricht XLA direkt. Keine Umwege, keine Kompromisse.

Wie sich das in Zahlen auswirkt: Durchsatz relativ zu PyTorch auf einer NVIDIA H100 (= 1.0). Mittelwert über Transformer-Training und PINN-Szenarien. Konkrete Zahlen schwanken je nach Modell. Das Muster bleibt.

Geschwindigkeit im Vergleich (Trainings-Schritte pro Sekunde)
0 ———— 4×
PyTorchauf NVIDIA GPU
1.0×
PyTorch 2GPU · mit Compiler
1.4×
TensorFlowauf NVIDIA GPU
1.3×
JAXGPU · mit XLA
2.1×
JAXauf einer TPU
3.1×
JAXauf 8 TPU-Chips
3.8×
// Referenz: PyTorch auf NVIDIA H100 = 1.0. TF & JAX mit XLA-Acceleration. Quellen: Keras 3 Launch-Benchmarks & MLPerf.
0
× so schnell auf 8 TPU-Chips wie PyTorch auf einer NVIDIA H100.
Zeilen Code musst du am Modell ändern, um von GPU auf TPU zu wechseln.
0
Umgebungs­variable entscheidet über die gesamte Hardware-Kette.
§ 06 — Fazit · Nutzen für Physik & alle

Was bleibt hängen, für Physik und alle.

— 01
Einfachheit trifft auf Rechenpower.
Keras macht das Bauen von Modellen einfach. JAX mit TPU macht sie massiv skalierbar. Beides zusammen gab es so vorher nicht.
— 02
Freiheit statt Lock-in.
Der Entwickler entscheidet selbst, welchen Motor er nutzt, je nach Aufgabe, Hardware oder Budget. Kein Umschreiben bei einem Wechsel. Das spart massiv Zeit und Geld.
— 03
Ein großer Schritt für die Physik.
PINNs, Simulationen, Wetter- und Strömungsmodelle werden durch Keras 3 + JAX zugänglich, ohne dass man Hardware-Experte oder ML-Ingenieur sein muss.
§ 07 — Quellen & Fragen

Literatur, Referenzen und euer Input.

[01]
Chollet, F. & Watson, M.: Introduction to TensorFlow, PyTorch, JAX, and Keras Deep Learning with Python · deeplearningwithpython.io
2024
[02]
Zhernovyi, V.: Choosing Your AI Stack: PyTorch, TensorFlow, or JAX? blackthorn-vision.com
2024
[03]
Keras Team: Deep Learning for humans, Introducing Keras 3.0 keras.io
2023
[04]
Iyer, R. & Kilaru, S.: Building production AI on Google Cloud TPUs with JAX developers.googleblog.com
2024
[05]
Toscano, J. D. et al.: From PINNs to PIKANs: Recent Advances in Physics-Informed Machine Learning arXiv · 2410.13228
2024
[06]
Alawi, Z. B.: A Comparative Survey of PyTorch vs TensorFlow for Deep Learning arXiv · 2508.04035
2025
[07]
Bradbury, J. et al.: JAX: Composable transformations of Python+NumPy programs GitHub · jax-ml/jax
2018
Danke.
Fragen?
// Präsentation mit Hilfe von KI erstellt
BTU Cottbus-Senftenberg
April 2026
keras.io · jax.readthedocs.io