Auteur : Henry Ndubuaku (les badges Discord & Docs sont cliquables)
N/B : Les codes sont mis en œuvre de manière pédagogique au détriment de la répétition. Chaque modèle est volontairement contenu dans un fichier sans dépendances inter-fichiers.
Le développement et la formation de modèles basés sur des transformateurs nécessitent généralement beaucoup de ressources et de temps, et les experts en IA/ML doivent souvent créer des versions à plus petite échelle de ces modèles pour des problèmes spécifiques. Jax, un framework puissant mais à faibles ressources, accélère le développement de réseaux de neurones et résume la formation distribuée, mais les ressources existantes pour le développement de transformateurs dans Jax sont limitées. NanoDL relève ce défi avec les fonctionnalités suivantes :
Une large gamme de blocs et de couches, facilitant la création de modèles de transformateurs personnalisés à partir de zéro.
Une vaste sélection de modèles comme Gemma, LlaMa3, Mistral, GPT3, GPT4 (inféré), T5, Whisper, ViT, Mixers, CLIP etc.
Modèles d'entraînement distribués parallèles aux données sur plusieurs GPU ou TPU, sans avoir besoin de boucles de formation manuelles.
Chargeurs de données, rendant le processus de traitement des données pour Jax/Flax plus simple et efficace.
Couches introuvables dans Flax/Jax, telles que RoPE, GQA, MQA et SWin, permettant un développement de modèles plus flexible.
Modèles ML classiques accélérés par GPU/TPU comme PCA, KMeans, régression, processus gaussiens, etc.
De vrais générateurs de nombres aléatoires dans Jax qui n'ont pas besoin du code détaillé.
Une gamme d'algorithmes avancés pour les tâches de PNL et de vision par ordinateur, tels que le flou gaussien, BLEU, Tokenizer, etc.
Chaque modèle est contenu dans un seul fichier sans dépendances externes, de sorte que le code source peut également être facilement utilisé.
De vrais générateurs de nombres aléatoires dans Jax qui n'ont pas besoin du code détaillé (exemples présentés dans les sections suivantes).
Il existe des fonctionnalités expérimentales et/ou inachevées (comme MAMBA, KAN, BitNet, GAT et RLHF) dans le dépôt qui ne sont pas encore disponibles via le package, mais peuvent être copiées à partir de ce dépôt. Les commentaires sur l’un de nos fils de discussion, de problèmes et de demandes d’extraction sont les bienvenus ! Veuillez signaler toute demande de fonctionnalité, problème, question ou préoccupation dans Discord, ou faites-nous simplement savoir sur quoi vous travaillez !
Vous aurez besoin de Python 3.9 ou version ultérieure et d'une installation JAX fonctionnelle, d'une installation FLAX, d'une installation OPTAX (avec prise en charge GPU pour exécuter la formation, sans pouvoir prendre en charge uniquement les créations). Les modèles peuvent être conçus et testés sur des processeurs, mais les formateurs sont tous en parallèle de données distribuées, ce qui nécessiterait un GPU avec 1 à N GPUS/TPUS. Pour la version CPU uniquement de JAX :
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
Ensuite, installez nanodl depuis PyPi :
pip install nanodl
Nous fournissons divers exemples d'utilisation de l'API nanodl.
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# Préparation de votre ensemble de donnéesbatch_size = 8max_length = 50vocab_size = 1000# Créer des données aléatoires = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Shift pour créer l'ensemble de données de prédiction du jeton suivantdummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# Créer un ensemble de données et un dataloaderdataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(ensemble de données, batch_size=batch_size, shuffle=True, drop_last=False)# paramètres du modèlehyperparams = {'num_layers' : 1,'hidden_dim' : 256,'num_heads' : 2,'feedforward_dim' : 256,'dropout' : 0.1,'vocab_size' : vocab_size,' embed_dim' : 256,'max_length' : max_length, 'start_token' : 0, 'end_token' : 50, }# Modèle GPT4 déduit model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # utiliser les données val réelles # Génération à partir d'un tokenstartstart_tokens = jnp.array([[123, 456]])# N'oubliez pas de charger les paramètres entraînés params = trainer.load_params('params.pkl')outputs = model.apply( {'params' : params}, start_tokens,rngs={'dropout' : nanodl.time_rng_key()}, method=model.generate)
Exemple de vision
importer nanodlimport jax.numpy en tant que jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_degree = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# Utilisez votre propre imagesdataset = ArrayDataset(images) dataloader = DataLoader (ensemble de données, batch_size=batch_size, shuffle=True, drop_last=False) # Créer un modèle de diffusiondiffusion_model = DiffusionModel(image_size, widths, block_degree)# Formation sur votre datatrainer = DiffusionDataParallelTrainer(diffusion_model, input_shape=images.shape, poids_fichier='params.pkl', learning_rate=1e-4)trainer.train(dataloader, 10)# Générez quelques échantillons : chaque modèle est un module Flax.linen# Utilisez comme vous le feriez normalementparams = trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'params': paramètres}, num_images=5, diffusion_steps=5, méthode=diffusion_model.generate)
Exemple audio
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# Paramètres de données facticesbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # Générer des données : remplacer par des données réelles tokenisées/quantifiéesdummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle= Vrai, drop_last=False)# modèle paramètreshyperparams = {'num_layers' : 1,'hidden_dim' : 256,'num_heads' : 2,'feedforward_dim' : 256,'dropout' : 0.1,'vocab_size' : 1000,'embed_dim' : embed_dim,'max_length' : max_length ,'start_token' : 0,'end_token' : 50, }# Initialiser modelmodel = Whisper(**hyperparams)# Formation sur votre datatrainer = WhisperDataParallelTrainer(model, dummy_inputs.shape, dummy_targets.shape, 'params.pkl')trainer.train(dataloader, 2, dataloader)# Exemple d'inferenceparams = trainer.load_params('params.pkl')# pour plus d'un échantillon, utilisez souvent model.generate_batchtranscripts = model.apply({'params ': paramètres}, dummy_inputs[:1], méthode=model.generate)
Exemple de modèle de récompense pour RLHF
importer nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer# Générer un databatch_size factice = 8max_length = 10# Remplacer par datadummy_chosen = jnp.ones ((101, max_length), dtype=jnp.int32) factice_rejecté = jnp.zeros((101, max_length), dtype=jnp.int32)# Créer un ensemble de données et un dataloaderdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # paramètres du modèlehyperparams = {'num_layers' : 1,'hidden_dim' : 256,'num_heads' : 2,'feedforward_dim' : 256,'dropout' : 0.1,'vocab_size' : 1000,'embed_dim' : 256,'max_length' : max_length,'start_token' : 0, 'end_token' : 50, "num_groups" : 2,'window_size': 5,'shift_size': 2}# Initialiser le modèle de récompense à partir de Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)# Entraîner la récompense modeltrainer = RewardDataParallelTrainer (reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# Appelez comme vous le feriez pour un modèle Flax normalrewards = récompense_model.apply({'params': params}, dummy_chosen, rngs={'dropout' : nanodl.time_rng_key()})
Exemple d'ACP
import nanodlfrom nanodl import PCA# Utiliser les données réellesdata = nanodl.normal(shape=(1000, 10))# Initialiser et entraîner le modèle PCApca = PCA(n_components=2)pca.fit(data)# Obtenir PCA transformstransformed_data = pca.transform( data)# Obtenir des transformations inversesoriginal_data = pca.inverse_transform(transformed_data)# Échantillon du distributionX_sampled = pca.sample(n_samples=1000, key=Aucun)
Ceci est encore en développement, fonctionne très bien mais des difficultés sont attendues, et les contributions sont donc fortement encouragées !
Apportez vos modifications sans modifier les modèles de conception.
Écrivez des tests pour vos modifications si nécessaire.
Installez localement avec pip3 install -e .
.
Exécutez des tests avec python3 -m unittest discover -s tests
.
Soumettez ensuite une pull request.
Les cotisations peuvent être versées sous diverses formes :
Rédaction de documentation.
Correction de bugs.
Documents de mise en œuvre.
Rédaction de tests à haute couverture.
Optimisation des codes existants.
Expérimenter et soumettre des exemples concrets à la section exemples.
Signaler des bugs.
Répondre aux problèmes signalés.
Rejoignez le serveur Discord pour en savoir plus.
Le nom « NanoDL » signifie Nano Deep Learning. Les modèles explosent en taille, ce qui empêche les experts et les entreprises disposant de ressources limitées de créer des modèles flexibles sans coûts prohibitifs. Suite au succès des modèles Phi, l'objectif à long terme est de créer et de former des versions nano de tous les modèles disponibles, tout en garantissant qu'ils rivalisent avec les modèles originaux en termes de performances, avec un nombre total de paramètres ne dépassant pas 1 milliard. Les poids entraînés seront mis à disposition via cette bibliothèque. Toute forme de parrainage, de financement contribuera aux ressources de formation. Vous pouvez soit parrainer via GitHub ici, soit contacter via [email protected].
Pour citer ce référentiel :
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }