Das Repository enthält eine einfache PyTorch-basierte Demonstration von Diffusionsmodellen zur Rauschunterdrückung. Es zielt lediglich darauf ab, ein erstes Verständnis dieses generativen Modellierungsansatzes zu vermitteln.
Eine kurze theoretische Einführung in Standard-DDPMs finden Sie hier. DDIMs für beschleunigtes Sampling werden im Begleitnotizbuch besprochen. Zwei Beispielanwendungen schaffen einen kleinen Experimentierspielplatz. Sie sind so aufbereitet, dass sie problemlos geändert und erweitert werden können.
Einführung in DDPMs
Einführung in DDIMs
Beispiel für eine Biskuitrolle
Bedingungsloses Modell auf MNIST
Bedingtes Modell auf MNIST
Als erstes Beispiel wird ein generatives DDPM auf eine 2D-Risk-Roll-Verteilung trainiert. Das Haupttrainingsskript kann zu diesem Zweck mit einer Konfigurationsdatei aufgerufen werden, die es einem ermöglicht, die Problemeinrichtung und Modelldefinition anzupassen:
python scripts/main.py fit --config config/swissroll.yaml
Nach Abschluss des Trainings kann das endgültige Modell in diesem Notebook getestet und analysiert werden.
Zur Überwachung des Experiments kann man mit tensorboard --logdir run/swissroll/
lokal einen TensorBoard-Server betreiben. Es ist standardmäßig unter localhost:6006 in Ihrem Browser erreichbar. Alternativ kann MLfLow zur Verwaltung von Experimenten verwendet werden. In diesem Fall kann man das Training mit den entsprechenden Einstellungen starten und einen Tracking-Server per mlflow server --backend-store-uri file:./run/mlruns/
einrichten. Es ist dann unter localhost:5000 erreichbar.