Autor: Henry Ndubuaku (se puede hacer clic en las insignias de Discord y Docs)
N/B: Los códigos se implementan pedagógicamente a expensas de la repetición. Cada modelo está contenido intencionalmente en un archivo sin dependencias entre archivos.
El desarrollo y entrenamiento de modelos basados en transformadores suele requerir mucho tiempo y recursos, y los expertos en IA/ML con frecuencia necesitan crear versiones a menor escala de estos modelos para problemas específicos. Jax, un marco poderoso pero de bajos recursos, acelera el desarrollo de redes neuronales y abstrae la capacitación distribuida, pero los recursos existentes para el desarrollo de transformadores en Jax son limitados. NanoDL aborda este desafío con las siguientes características:
Una amplia gama de bloques y capas, facilitando la creación de modelos de transformadores personalizados desde cero.
Una amplia selección de modelos como Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferido), T5, Whisper, ViT, Mixers, CLIP, etc.
Modelos de entrenadores distribuidos en paralelo con datos en múltiples GPU o TPU, sin necesidad de bucles de entrenamiento manuales.
Cargadores de datos, que hacen que el proceso de manejo de datos para Jax/Flax sea más sencillo y efectivo.
Capas que no se encuentran en Flax/Jax, como RoPE, GQA, MQA y SWin, se presta atención, lo que permite un desarrollo de modelos más flexible.
Modelos de aprendizaje automático clásicos acelerados por GPU/TPU como PCA, KMeans, regresión, procesos gaussianos, etc.
Verdaderos generadores de números aleatorios en Jax que no necesitan el código detallado.
Una gama de algoritmos avanzados para tareas de PNL y visión por computadora, como Gaussian Blur, BLEU, Tokenizer, etc.
Cada modelo está contenido en un único archivo sin dependencias externas, por lo que el código fuente también se puede utilizar fácilmente.
Verdaderos generadores de números aleatorios en Jax que no necesitan el código detallado (los ejemplos se muestran en las siguientes secciones).
Hay características experimentales y/o inacabadas (como MAMBA, KAN, BitNet, GAT y RLHF) en el repositorio que aún no están disponibles a través del paquete, pero se pueden copiar desde este repositorio. ¡Se agradecen los comentarios sobre cualquiera de nuestros hilos de discusión, problemas y solicitudes de extracción! Informe cualquier solicitud de función, problema, pregunta o inquietud en Discord, o simplemente háganos saber en qué está trabajando.
Necesitará Python 3.9 o posterior, y la instalación de JAX, la instalación de FLAX y la instalación de OPTAX en funcionamiento (con soporte de GPU para ejecutar la capacitación, pero solo puede admitir creaciones). Los modelos se pueden diseñar y probar en CPU, pero todos los entrenadores son datos distribuidos en paralelo, lo que requeriría una GPU con 1 a N GPUS/TPUS. Para la versión de JAX solo para CPU:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
Luego, instale nanodl desde PyPi:
pip install nanodl
Proporcionamos varios usos de ejemplo de la API nanodl.
importar jaximport nanodlimport jax.numpy como jnp de nanodl importar ArrayDataset, DataLoader de nanodl importar GPT4, GPTDataParallelTrainer# Preparando su conjunto de datosbatch_size = 8max_length = 50vocab_size = 1000# Crear datos aleatoriosdata = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Mayús para crear el conjunto de datos de predicción del siguiente tokendummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# Crear conjunto de datos y cargador de datosdataset = ArrayDataset(dummy_inputs , dummy_targets)cargador de datos = DataLoader(conjunto de datos, tamaño_lote=tamaño_lote, shuffle=True, drop_last=False)# parámetros del modelohyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': vocab_size,' embed_dim': 256,'max_length': max_length, 'start_token': 0, 'end_token': 50, }# Modelo GPT4 inferido model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # usa datos val reales # Generando desde un inicio tokenstart_tokens = jnp.array([[123, 456]])# Recuerde cargar los parámetros entrenados params = trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, método=model.generate)
Ejemplo de visión
importar nanodlimport jax.numpy como jnp de nanodl importar ArrayDataset, DataLoader de nanodl importar DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_ Depth = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# Utilice su propio conjunto de datos de imágenes = ArrayDataset(imágenes) cargador de datos = Cargador de datos (conjunto de datos, tamaño_lote = tamaño_lote, aleatorio = Verdadero, drop_last = Falso) # Crear modelo de difusióndiffusion_model = DiffusionModel(image_size, widths, block_ Depth)# Entrenamiento en tu datatrainer = DiffusionDataParallelTrainer(diffusion_model, input_shape=imagenes.forma, pesos_filename='params.pkl', learning_rate=1e-4)trainer.train(dataloader, 10)# Genere algunas muestras: cada modelo es un módulo Flax.linen# Úselo como lo haría normalmenteparams = trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'params': parámetros}, núm_imagenes=5, pasos_difusión=5, método = modelo_difusión.generar)
Ejemplo de audio
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# Parámetros de datos ficticiosbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # Generar datos: reemplazar con datos tokenizados/cuantificados realesdummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)cargador de datos = DataLoader(conjunto de datos, tamaño_por lotes=tamaño_por lotes, aleatorio= Verdadero, drop_last=False)# modelo parámetroshyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1000,'embed_dim': embed_dim,'max_length': max_length ,'start_token': 0,'end_token': 50, }# Inicializar modelmodel = Whisper(**hyperparams)# Entrenamiento en su entrenador de datos = WhisperDataParallelTrainer(model, entradas_ficticias.forma, objetivos_ficticios.forma, 'params.pkl')trainer.train(dataloader, 2, dataloader)# Inferencia de muestraparams = trainer.load_params('params.pkl')# para más de una muestra, a menudo se usa model.generate_batchtranscripts = model.apply({'params ': parámetros}, entradas_ficticias[:1], método=modelo.generar)
Ejemplo de modelo de recompensa para RLHF
importar nanodlimport jax.numpy como jnpdesde nanodl importar ArrayDataset, DataLoaderdesde nanodl importar Mistral, RewardModel, RewardDataParallelTrainer# Generar datos ficticiosbatch_size = 8max_length = 10# Reemplazar con datos tokenizados realesdummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) tonto_rechazado = jnp.zeros((101, max_length), dtype=jnp.int32)# Crear conjunto de datos y cargador de datosdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # parámetros del modelohyperparams = {'num_capas': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1000,'embed_dim': 256,'max_length': max_length,'start_token': 0, 'end_token': 50, 'num_groups': 2,'window_size': 5,'shift_size': 2}# Inicializar el modelo de recompensa de Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)# Entrenar la recompensa modeltrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# Llame como lo haría con un modelo Flax normalrewards = recompensa_model.apply({'params': params}, dummy_chosen, rngs={'abandono': nanodl.time_rng_key()})
ejemplo de PCA
import nanodlfrom nanodl import PCA# Usar datos realesdata = nanodl.normal(shape=(1000, 10))# Inicializar y entrenar el modelo PCApca = PCA(n_components=2)pca.fit(data)# Obtener transformaciones PCAtransformed_data = pca.transform( data)# Obtener transformaciones inversasoriginal_data = pca.inverse_transform(transformed_data)# Muestra del distribuciónX_sampled = pca.sample(n_samples=1000, clave=Ninguno)
Esto todavía está en desarrollo, funciona muy bien, pero se espera que sea rudo y, por lo tanto, se recomiendan encarecidamente las contribuciones.
Realice sus cambios sin cambiar los patrones de diseño.
Escriba pruebas para sus cambios si es necesario.
Instale localmente con pip3 install -e .
.
Ejecute pruebas con python3 -m unittest discover -s tests
.
Luego envíe una solicitud de extracción.
Las contribuciones se pueden realizar de diversas formas:
Redacción de documentación.
Corrección de errores.
Documentos de implementación.
Redacción de pruebas de alta cobertura.
Optimización de códigos existentes.
Experimentar y enviar ejemplos del mundo real a la sección de ejemplos.
Informar errores.
Responder a los problemas reportados.
Únase al servidor de Discord para obtener más información.
El nombre "NanoDL" significa Nano Deep Learning. Los modelos están aumentando en tamaño, por lo que los expertos y las empresas con recursos limitados impiden que construyan modelos flexibles sin costos prohibitivos. Tras el éxito de los modelos Phi, el objetivo a largo plazo es construir y entrenar versiones nano de todos los modelos disponibles, garantizando al mismo tiempo que compitan con los modelos originales en rendimiento, con un número total de parámetros que no supere los 1.000 millones. Los pesos entrenados estarán disponibles a través de esta biblioteca. Cualquier forma de patrocinio o financiación ayudará con los recursos de formación. Puede patrocinar a través de GitHub aquí o comunicarse con [email protected].
Para citar este repositorio:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }