作者:Henry Ndubuaku(Discord 和文档徽章可点击)
注意:代码在教学上的实施是以重复为代价的。每个模型都有目的地包含在一个文件中,没有文件间依赖关系。
开发和训练基于 Transformer 的模型通常是资源密集型且耗时的,AI/ML 专家经常需要针对特定问题构建这些模型的较小规模版本。 Jax 是一个资源匮乏但功能强大的框架,它加速了神经网络的开发并抽象了分布式训练,但 Jax 中用于 Transformer 开发的现有资源有限。 NanoDL 通过以下功能应对这一挑战:
各种各样的块和层,有助于从头开始创建定制的变压器模型。
广泛的模型选择,如 Gemma、LlaMa3、Mistral、GPT3、GPT4(推断)、T5、Whisper、ViT、Mixers、CLIP 等。
数据并行分布式训练器在多个 GPU 或 TPU 上建模,无需手动训练循环。
数据加载器,使 Jax/Flax 的数据处理过程更加简单和有效。
Flax/Jax 中没有的层,例如 RoPE、GQA、MQA 和 SWin 注意力,允许更灵活的模型开发。
GPU/TPU 加速的经典 ML 模型,如 PCA、KMeans、回归、高斯过程等。
Jax 中的真正随机数生成器不需要详细的代码。
一系列用于 NLP 和计算机视觉任务的高级算法,例如高斯模糊、BLEU、Tokenizer 等。
每个模型都包含在一个文件中,没有外部依赖关系,因此源代码也可以轻松使用。
Jax 中的真正随机数生成器不需要详细代码(下一节中显示的示例)。
存储库中存在实验性和/或未完成的功能(例如 MAMBA、KAN、BitNet、GAT 和 RLHF),这些功能尚未通过软件包提供,但可以从此存储库中复制。欢迎对我们的任何讨论、问题和拉取请求线程提供反馈!请在 Discord 中报告任何功能请求、问题、问题或疑虑,或者只是让我们知道您正在做什么!
您将需要Python 3.9或更高版本,以及可运行的JAX安装、FLAX安装、OPTAX安装(有GPU支持运行训练,没有则只能支持创作)。模型可以在 CPU 上设计和测试,但训练器都是分布式数据并行的,这需要具有 1 到 N 个 GPUS/TPUS 的 GPU。对于仅 CPU 版本的 JAX:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
然后,从 PyPi 安装 nanodl:
pip install nanodl
我们提供 nanodl API 的各种示例用法。
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# 准备数据集batch_size = 8max_length = 50vocab_size = 1000# 创建随机数据data = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# 转移以创建下一个标记预测数据集dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# 创建数据集和数据加载器dataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(数据集,batch_size=batch_size,shuffle=True, drop_last=False)# 模型参数hyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': vocab_size,'embed_dim': 256 ,'max_length': max_length,'start_token': 0,'结束令牌': 50, }# 推断的 GPT4 模型 model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # 使用实际的 val 数据# 从起始令牌生成start_tokens = jnp.array([[123, 456]])# 记得加载训练好的参数 params = trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, method=model.generate)
愿景示例
导入nanodlimport jax.numpy为jnpfrom nanodl导入ArrayDataset,DataLoaderfrom nanodl导入DiffusionModel,DiffusionDataParallelTrainerimage_size = 32block_深度= 2batch_size = 8widths = [32,64,128]input_shape =(101,image_size,image_size,3)images = nanodl.normal(shape=input_shape)# 使用自己的图像dataset = ArrayDataset(images) dataloader = DataLoader(数据集,batch_size=batch_size,shuffle=True,drop_last=False) # 创建扩散模型diffusion_model = DiffusionModel(image_size, widths, block_depth)# 在数据训练器上进行训练 = DiffusionDataParallelTrainer(diffusion_model, 输入形状=图像.形状, weights_filename='params.pkl', Learning_rate=1e-4)trainer.train(dataloader, 10)# 生成一些样本:每个模型都是一个 Flax.linen 模块# 像平常一样使用params = trainer.load_params('params.pkl') generated_images = iteration_model.apply( {'参数':参数}, 图片数量=5, 扩散步数=5, 方法=diffusion_model.generate)
音频示例
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# 虚拟数据参数batch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # 生成数据:替换为实际标记化/量化数据dummy_targets = jnp.ones((101,max_length),dtype=jnp.int32)dummy_inputs = jnp.ones((101,max_length,embed_dim))数据集=ArrayDataset(dummy_inputs,dummy_targets)dataloader = DataLoader(数据集,batch_size=batch_size,shuffle= True, drop_last=False)# 模型参数hyperparams = {'num_layers':1,'hidden_dim':256,'num_heads':2,'feedforward_dim':256,'dropout':0.1,'vocab_size':1000,'embed_dim':embed_dim,'max_length':max_length ,'start_token': 0,'end_token': 50、 }# 初始化模型 model = Whisper(**hyperparams)# 在 datatrainer 上进行训练 = WhisperDataParallelTrainer(model, 虚拟输入.形状, 虚拟目标.形状, 'params.pkl')trainer.train(dataloader, 2, dataloader)# 样本推理params = trainer.load_params('params.pkl')# 对于多个样本,经常使用 model.generate_batchtranscripts = model.apply({'params ': 参数}, dummy_inputs[:1], method=model.generate)
RLHF 奖励模型示例
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer# 生成虚拟数据batch_size = 8max_length = 10# 替换为实际标记化数据dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)# 创建数据集和数据加载器dataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # 模型参数hyperparams = {'num_layers':1,'hidden_dim':256,'num_heads':2,'feedforward_dim':256,'dropout':0.1,'vocab_size':1000,'embed_dim':256,'max_length':max_length,'开始令牌':0,'结束令牌': 50,'num_groups': 2,'window_size': 5,'shift_size': 2}# 从 Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout= 初始化奖励模型0.1)# 训练奖励模型trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# 像常规 Flax 模型一样调用rewards =reward_model.apply({'params' : params}, dummy_chosen, rngs={'dropout': nanodl.time_rng_key()})
主成分分析示例
import nanodlfrom nanodl import PCA# 使用实际数据data = nanodl.normal(shape=(1000, 10))# 初始化并训练PCA模型pca = PCA(n_components=2)pca.fit(data)# 获取PCA变换transformed_data = pca.transform( data)# 获取反向变换soriginal_data = pca.inverse_transform(transformed_data)# 样本distributionX_sampled = pca.sample(n_samples=1000, key=None)
这仍然处于开发阶段,效果很好,但预计会很粗糙,因此强烈鼓励贡献!
在不改变设计模式的情况下进行更改。
如有必要,为您的更改编写测试。
使用pip3 install -e .
。
使用python3 -m unittest discover -s tests
运行测试。
然后提交拉取请求。
贡献可以通过多种形式进行:
编写文档。
修复错误。
实施文件。
编写高覆盖率测试。
优化现有代码。
实验并向示例部分提交真实示例。
报告错误。
回应报告的问题。
加入 Discord 服务器了解更多信息。
“NanoDL”这个名字代表纳米深度学习。模型的规模呈爆炸式增长,因此资源有限的专家和公司无法在无需高昂成本的情况下构建灵活的模型。随着 Phi 模型的成功,长期目标是构建和训练所有可用模型的纳米版本,同时确保它们在性能上与原始模型竞争,参数总数不超过 1B。经过训练的权重将通过该库提供。任何形式的赞助、资助都将有助于提供培训资源。您可以通过 GitHub 进行赞助,也可以通过 [email protected] 联系。
引用这个存储库:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }