Autor: Henry Ndubuaku (Discord- und Docs-Abzeichen sind anklickbar)
Hinweis: Kodizes werden auf Kosten der Wiederholung pädagogisch umgesetzt. Jedes Modell ist gezielt in einer Datei enthalten, ohne Abhängigkeiten zwischen den Dateien.
Die Entwicklung und das Training transformatorbasierter Modelle ist in der Regel ressourcenintensiv und zeitaufwändig, und KI/ML-Experten müssen häufig kleinere Versionen dieser Modelle für bestimmte Probleme erstellen. Jax, ein ressourcenarmes, aber leistungsstarkes Framework, beschleunigt die Entwicklung neuronaler Netze und abstraktes verteiltes Training, aber die vorhandenen Ressourcen für die Transformatorentwicklung in Jax sind begrenzt. NanoDL begegnet dieser Herausforderung mit den folgenden Funktionen:
Eine große Auswahl an Blöcken und Ebenen erleichtert die Erstellung benutzerdefinierter Transformatormodelle von Grund auf.
Eine umfangreiche Auswahl an Modellen wie Gemma, LlaMa3, Mistral, GPT3, GPT4 (abgeleitet), T5, Whisper, ViT, Mixer, CLIP usw.
Datenparallel verteilte Trainermodelle auf mehreren GPUs oder TPUs, ohne dass manuelle Trainingsschleifen erforderlich sind.
Datenlader, die den Prozess der Datenverarbeitung für Jax/Flax einfacher und effektiver machen.
Ebenen, die in Flax/Jax nicht zu finden sind, wie z. B. RoPE, GQA, MQA und SWin, ermöglichen eine flexiblere Modellentwicklung.
GPU/TPU-beschleunigte klassische ML-Modelle wie PCA, KMeans, Regression, Gaußsche Prozesse usw.
Echte Zufallszahlengeneratoren in Jax, die keinen ausführlichen Code benötigen.
Eine Reihe fortschrittlicher Algorithmen für NLP- und Computer-Vision-Aufgaben, wie Gaussian Blur, BLEU, Tokenizer usw.
Jedes Modell ist in einer einzigen Datei ohne externe Abhängigkeiten enthalten, sodass auch der Quellcode problemlos verwendet werden kann.
Echte Zufallszahlengeneratoren in Jax, die keinen ausführlichen Code benötigen (Beispiele finden Sie in den nächsten Abschnitten).
Es gibt experimentelle und/oder unvollendete Funktionen (wie MAMBA, KAN, BitNet, GAT und RLHF) im Repo, die noch nicht über das Paket verfügbar sind, aber aus diesem Repo kopiert werden können. Feedback zu unseren Diskussions-, Issue- und Pull-Request-Threads ist willkommen! Bitte melden Sie alle Funktionswünsche, Probleme, Fragen oder Bedenken im Discord oder lassen Sie uns einfach wissen, woran Sie arbeiten!
Sie benötigen Python 3.9 oder höher und eine funktionierende JAX-Installation, FLAX-Installation, OPTAX-Installation (mit GPU-Unterstützung für die Ausführung von Schulungen, ohne dass nur Kreationen unterstützt werden können). Modelle können auf CPUs entworfen und getestet werden, aber Trainer sind alle verteilte Daten-Parallel-Modelle, was eine GPU mit 1 bis N GPUS/TPUS erfordern würde. Für reine CPU-Version von JAX:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
Dann installieren Sie nanodl von PyPi:
pip install nanodl
Wir stellen verschiedene Beispielverwendungen der nanodl-API zur Verfügung.
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# Vorbereiten Ihres Datensatzesbatch_size = 8max_length = 50vocab_size = 1000# Erstellen Sie zufällige Datendata = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Umschalten, um die nächste Token-Vorhersage zu erstellen datasetdummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# Dataset und Dataloader erstellendataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)# Modellparameterhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': vocab_size,' embed_dim': 256,'max_length': max_length,'start_token': 0,'end_token': 50, }# Abgeleitetes GPT4-Modell model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # tatsächliche Val-Daten verwenden # Generieren von einem Start-Tokenstart_tokens = jnp.array([[123, 456]])# Denken Sie daran, die trainierten Parameter zu laden params = trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, method=model.generate)
Visionsbeispiel
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_ Depth = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# Verwenden Sie Ihre eigenen Bilderdataset = ArrayDataset(images) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # Diffusionsmodell erstellendiffusion_model = DiffusionModel(image_size, widths, block_ Depth)# Training auf Ihrem Datentrainer = DiffusionDataParallelTrainer(diffusion_model, input_shape=images.shape, weights_filename='params.pkl', learning_rate=1e-4)trainer.train(dataloader, 10)# Generieren Sie einige Beispiele: Jedes Modell ist ein Flax.linen-Modul# Verwenden Sie es wie gewohnt params = trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'params': params}, num_images=5, diffusion_steps=5, method=diffusion_model.generate)
Hörbeispiel
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# Dummy-Datenparameterbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # Daten generieren: durch tatsächliche tokenisierte/quantisierte Daten ersetzendummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle= True, drop_last=False)# Modell Parameterhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1000,'embed_dim': embed_dim,'max_length': max_length ,'start_token': 0,'end_token': 50, }# Modell initialisierenmodel = Whisper(**hyperparams)# Training auf Ihrem Datentrainer = WhisperDataParallelTrainer(model, dummy_inputs.shape, dummy_targets.shape, 'params.pkl')trainer.train(dataloader, 2, dataloader)# Sample inferenceparams = trainer.load_params('params.pkl')# Für mehr als ein Beispiel verwenden Sie häufig model.generate_batchtranscripts = model.apply({'params ': params}, dummy_inputs[:1], method=model.generate)
Beispiel für ein Belohnungsmodell für RLHF
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer# Dummy-Daten generierenbatch_size = 8max_length = 10# Durch tatsächliche tokenisierte Daten ersetzendummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)# Datensatz und Datenlader erstellendataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # Modellparameterhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1000,'embed_dim': 256,'max_length': max_length,'start_token': 0, 'end_token': 50,'num_groups': 2,'window_size': 5,'shift_size': 2}# Belohnungsmodell von Mistralmodel initialisieren = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)# Trainiere die Belohnung modeltrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# Rufen Sie wie ein normales Flax-Modell auf.rewards = reward_model.apply({'params': params}, dummy_chosen, rngs={'dropout': nanodl.time_rng_key()})
PCA-Beispiel
import nanodlfrom nanodl import PCA# Aktuelle Daten verwendendata = nanodl.normal(shape=(1000, 10))# PCA-Modell initialisieren und trainierenpca = PCA(n_components=2)pca.fit(data)# PCA abrufen transformstransformed_data = pca.transform( data)# Get reverse transformsoriginal_data = pca.inverse_transform(transformed_data)# Beispiel aus der DistributionX_sampled = pca.sample(n_samples=1000, key=None)
Dies befindet sich noch in der Entwicklung, funktioniert großartig, aber es wird erwartet, dass es grob wird, und daher sind Beiträge sehr erwünscht!
Nehmen Sie Ihre Änderungen vor, ohne die Designmuster zu ändern.
Schreiben Sie bei Bedarf Tests für Ihre Änderungen.
Lokal installieren mit pip3 install -e .
.
Führen Sie Tests mit python3 -m unittest discover -s tests
aus.
Senden Sie dann eine Pull-Anfrage.
Beiträge können in verschiedenen Formen erfolgen:
Dokumentation schreiben.
Fehler beheben.
Durchführungspapiere.
Schreiben von Tests mit hoher Abdeckung.
Optimierung bestehender Codes.
Experimentieren Sie und senden Sie Beispiele aus der Praxis an den Abschnitt „Beispiele“.
Fehler melden.
Reagieren auf gemeldete Probleme.
Treten Sie dem Discord-Server bei, um mehr zu erfahren.
Der Name „NanoDL“ steht für Nano Deep Learning. Die Größe der Modelle nimmt explosionsartig zu, weshalb Gate-Keeping-Experten und Unternehmen mit begrenzten Ressourcen davon abgehalten werden, flexible Modelle ohne unerschwingliche Kosten zu entwickeln. Nach dem Erfolg der Phi-Modelle besteht das langfristige Ziel darin, Nanoversionen aller verfügbaren Modelle zu erstellen und zu trainieren und gleichzeitig sicherzustellen, dass sie hinsichtlich der Leistung mit den Originalmodellen konkurrieren und die Gesamtzahl der Parameter 1 B nicht überschreitet. Über diese Bibliothek werden trainierte Gewichte zur Verfügung gestellt. Jede Form von Sponsoring und Finanzierung hilft bei der Bereitstellung von Schulungsressourcen. Sie können entweder hier über GitHub sponsern oder sich über [email protected] an uns wenden.
Um dieses Repository zu zitieren:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }