Autor: Henry Ndubuaku (os emblemas do Discord e Docs são clicáveis)
N/B: Os códigos são implementados pedagogicamente em detrimento da repetição. Cada modelo é propositalmente contido em um arquivo sem dependências entre arquivos.
O desenvolvimento e o treinamento de modelos baseados em transformadores normalmente consomem muitos recursos e tempo, e os especialistas em IA/ML frequentemente precisam construir versões em menor escala desses modelos para problemas específicos. Jax, uma estrutura poderosa, mas com poucos recursos, acelera o desenvolvimento de redes neurais e abstrai o treinamento distribuído, mas os recursos existentes para o desenvolvimento de transformadores em Jax são limitados. NanoDL aborda esse desafio com os seguintes recursos:
Uma ampla gama de blocos e camadas, facilitando a criação de modelos de transformadores customizados do zero.
Uma extensa seleção de modelos como Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferido), T5, Whisper, ViT, Mixers, CLIP etc.
Modelos de treinadores distribuídos com dados paralelos em várias GPUs ou TPUs, sem a necessidade de loops de treinamento manuais.
Dataloaders, tornando o processo de manipulação de dados para Jax/Flax mais simples e eficaz.
Camadas não encontradas em Flax/Jax, como RoPE, GQA, MQA e SWin atenção, permitindo um desenvolvimento de modelo mais flexível.
Modelos clássicos de ML acelerados por GPU/TPU, como PCA, KMeans, Regressão, Processos Gaussianos, etc.
Geradores de números aleatórios verdadeiros em Jax que não precisam de código detalhado.
Uma variedade de algoritmos avançados para tarefas de PNL e visão computacional, como Gaussian Blur, BLEU, Tokenizer etc.
Cada modelo está contido em um único arquivo sem dependências externas, portanto o código-fonte também pode ser facilmente utilizado.
Geradores de números aleatórios verdadeiros em Jax que não precisam de código detalhado (exemplos mostrados nas próximas seções).
Existem recursos experimentais e/ou inacabados (como MAMBA, KAN, BitNet, GAT e RLHF) no repositório que ainda não estão disponíveis através do pacote, mas podem ser copiados deste repositório. Comentários sobre qualquer um de nossos tópicos de discussão, problema e solicitação pull são bem-vindos! Por favor, relate quaisquer solicitações de recursos, problemas, dúvidas ou preocupações no Discord ou apenas conte-nos no que você está trabalhando!
Você precisará do Python 3.9 ou posterior e da instalação JAX funcional, instalação FLAX, instalação OPTAX (com suporte de GPU para treinamento em execução, sem suporte apenas para criações). Os modelos podem ser projetados e testados em CPUs, mas os treinadores são todos distribuídos em paralelo com dados, o que exigiria uma GPU com 1 a N GPUS/TPUS. Para versão somente CPU do JAX:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
Em seguida, instale o nanodl do PyPi:
pip install nanodl
Fornecemos vários exemplos de uso da API nanodl.
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# Preparando seu conjunto de dadosbatch_size = 8max_length = 50vocab_size = 1000# Crie dados aleatórios = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Shift para criar dataset de previsão do próximo tokendummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# Criar conjunto de dados e dataloaderdataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(conjunto de dados, batch_size=batch_size, shuffle=True, drop_last=False)# parâmetros do modelohyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0,1,'vocab_size': vocab_size,' embed_dim': 256,'max_length': max_length,'start_token': 0,'end_token': 50, }# Modelo de modelo GPT4 inferido = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # use dados val reais # Gerando a partir de um início tokenstart_tokens = jnp.array([[123, 456]])# Lembre-se de carregar os parâmetros treinados params = trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, método=model.generate)
Exemplo de visão
importar nanodlimport jax.numpy como jnpfrom nanodl importar ArrayDataset, DataLoaderfrom nanodl importar DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_profundidade = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# Use seu próprio imagesdataset = ArrayDataset(images) dataloader = DataLoader(conjunto de dados, batch_size=batch_size, shuffle=True, drop_last=False) # Crie um modelo de difusãodiffusion_model = DiffusionModel(image_size, widths, block_thought)# Treinamento em seu datatrainer = DiffusionDataParallelTrainer(diffusion_model, input_shape=imagens.shape, pesos_filename='params.pkl', learning_rate=1e-4)trainer.train(dataloader, 10)# Gere algumas amostras: Cada modelo é um módulo Flax.linen# Use como faria normalmenteparams = trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'params': params}, num_imagens=5, difusão_passos = 5, método = difusão_model.generate)
Exemplo de áudio
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# Parâmetros de dados fictíciosbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # Gere dados: substitua por dados reais tokenizados/quantizadosdummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader(conjunto de dados, batch_size=batch_size, shuffle= Verdadeiro, drop_last=Falso)# parâmetros do modelohyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0,1,'vocab_size': 1000,'embed_dim': embed_dim,'max_length': max_length,'start_token': 0,'end_token': 50, }# Inicialize modelmodel = Whisper(**hyperparams)# Treinamento em seu datatrainer = WhisperDataParallelTrainer(model, dummy_inputs.shape, dummy_targets.shape, 'params.pkl')trainer.train(dataloader, 2, dataloader)# Sample inferenceparams = trainer.load_params('params.pkl')# para mais de uma amostra, use frequentemente model.generate_batchtranscripts = model.apply({'params ': parâmetros}, dummy_inputs[:1], método=modelo.generate)
Exemplo de modelo de recompensa para RLHF
importar nanodlimport jax.numpy como jnpfrom nanodl importar ArrayDataset, DataLoaderfrom nanodl importar Mistral, RewardModel, RewardDataParallelTrainer# Gerar dados fictícios databatch_size = 8max_length = 10# Substituir por datadummy_chosen tokenizado real = jnp.ones((101, max_length), dtype=jnp.int32)dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)# Criar conjunto de dados e dataloaderdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=Falso) # modelo parâmetroshyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0,1,'vocab_size': 1000,'embed_dim': 256,'max_length': max_length ,'start_token': 0,'end_token': 50,'num_groups': 2,'window_size': 5,'shift_size': 2}# Inicializar modelo de recompensa de Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout= 0.1)# Treine o modelo de recompensatrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# Chame como faria com um modelo Flax normalrewards = reward_model.apply({'params' : params}, dummy_chosen, rngs={'dropout': nanodl.time_rng_key()})
Exemplo de PCA
importar nanodlfrom nanodl importar PCA# Usar dados reais = nanodl.normal(shape=(1000, 10))# Inicializar e treinar modelo PCApca = PCA(n_components=2)pca.fit(data)# Obter PCA transformstransformed_data = pca.transform( data)# Obtenha transformsoriginal_data reverso = pca.inverse_transform(transformed_data)# Amostra do distribuiçãoX_sampled = pca.sample(n_samples=1000, key=None)
Isso ainda está em desenvolvimento, funciona muito bem, mas são esperadas irregularidades e, portanto, as contribuições são altamente incentivadas!
Faça suas alterações sem alterar os padrões de design.
Escreva testes para suas alterações, se necessário.
Instale localmente com pip3 install -e .
.
Execute testes com python3 -m unittest discover -s tests
.
Em seguida, envie uma solicitação pull.
As contribuições podem ser feitas de diversas formas:
Escrevendo documentação.
Correção de bugs.
Implementando documentos.
Escrever testes de alta cobertura.
Otimizando códigos existentes.
Experimentar e enviar exemplos do mundo real para a seção de exemplos.
Relatando bugs.
Respondendo a problemas relatados.
Junte-se ao servidor Discord para mais.
O nome "NanoDL" significa Nano Deep Learning. Os modelos estão explodindo em tamanho, portanto, impedindo que especialistas e empresas com recursos limitados construam modelos flexíveis sem custos proibitivos. Seguindo o sucesso dos modelos Phi, o objetivo de longo prazo é construir e treinar versões nano de todos os modelos disponíveis, garantindo ao mesmo tempo que concorram com os modelos originais em desempenho, com um número total de parâmetros não superior a 1B. Pesos treinados serão disponibilizados através desta biblioteca. Qualquer forma de patrocínio, o financiamento ajudará nos recursos de treinamento. Você pode patrocinar via GitHub aqui ou entrar em contato via [email protected].
Para citar este repositório:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }