Keras 3 &
das Multi-Backend
Paradigma.
Zwei Fragen: Was hat Keras 3 im KI-Ökosystem verändert? Und welchen Vorteil bringt das für physikalische Simulationen: Wetter, Strömungen, alles wo Physik mitrechnen muss.
Drei Frameworks. Acht Jahre nebeneinander her.
Zwischen 2015 und 2023 entstehen drei große KI-Frameworks: TensorFlow, PyTorch und JAX. Jedes spricht seine eigene Sprache. Wer zwischen ihnen wechseln will, schreibt Code neu. Bis Keras 3 kommt und alle drei unter einen Hut bringt.
tf.keras ein. Vier Jahre lang ist Keras quasi nur noch mit TensorFlow nutzbar. Die ursprüngliche Idee, mehrere Deep-Learning-Frameworks zuzulassen, verschwindet.
Ein API. Drei Backends.
Derselbe Code.
Keras 3 wurde 2023 komplett neu gebaut. Es ist nur noch die Bedien-Schicht. Darunter arbeitet der Deep-Learning-Framework, den du aussuchst: TensorFlow, PyTorch oder JAX. Dein Modell-Code bleibt dabei immer gleich.
Was du bedienst
Motor deiner Wahl
Hardware
Eine Umgebungsvariable entscheidet.
Der Trick heißt keras.ops: jede Rechen-Operation läuft über
eine einheitliche Schicht, die im Hintergrund das ausgewählte Deep-Learning-Framework anspricht.
KERAS_BACKEND ist eine Zeile Code. Der Rest bleibt identisch.
import tensorflow as tf class Net(tf.Module): def __init__(self): self.W1 = tf.Variable(tf.random.normal([2,64])) self.W2 = tf.Variable(tf.random.normal([64,1])) def __call__(self, x): h = tf.nn.tanh(x @ self.W1) return h @ self.W2
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(2, 64) self.fc2 = nn.Linear(64, 1) def forward(self, x): return self.fc2(torch.tanh(self.fc1(x)))
import jax.numpy as jnp from flax import linen as nn class Net(nn.Module): @nn.compact def __call__(self, x): x = nn.tanh(nn.Dense(64)(x)) return nn.Dense(1)(x)
import keras from keras import layers model = keras.Sequential([ layers.Dense(64, activation="tanh"), layers.Dense(1), ]) model.compile(optimizer="adamw", loss=physics_loss) model.fit(X, y, epochs=500)
KERAS_BACKEND auf
tensorflow, torch oder jax steht. Das Modell bleibt identisch, nur der Deep-Learning-Framework dahinter wechselt.
Warum JAX für Physik so passt.
Physik muss Naturgesetze einhalten, und dafür braucht man Ableitungen. Hier kommen zwei Werkzeuge ins Spiel: JAX übernimmt die Mathe, XLA übernimmt die Geschwindigkeit.
Zusammen: Du schreibst ganz normales Python, am Ende läuft es auf spezieller Hardware und hält gleichzeitig die Physik ein.
Echte Physik durchs Netz.
Physics-Informed Neural Networks lernen so, dass sie eine physikalische Gleichung erfüllen, zum Beispiel die Strömungsgleichung rechts. Vorteil: Du löst komplexe Physik-Probleme ohne klassisches Rechengitter. Gerade bei schwierigen Strömungen wie dem Wetter spart das massiv Aufwand.
grad liefert die Ableitungen durchs Netz, automatisch mit einer Zeile.
GPU oder TPU?
Zwei Chips, zwei Geschichten.
GPUs wurden ursprünglich für Computerspiele gebaut und sind eher zufällig gut in KI geworden. Heute werden sie natürlich explizit dafür entwickelt. TPUs hingegen hat Google von Anfang an für neuronale Netze gebaut. Das Bindeglied zwischen beiden ist XLA.
XLA fusioniert. TPUs ziehen vorbei.
Wie sich das in Zahlen auswirkt: Durchsatz relativ zu PyTorch auf einer NVIDIA H100 (= 1.0). Mittelwert über Transformer-Training und PINN-Szenarien. Konkrete Zahlen schwanken je nach Modell. Das Muster bleibt.
Was bleibt hängen, für Physik und alle.
Literatur, Referenzen und euer Input.
Fragen?