一个俳句库,使用 JAX 中的xmap
/ pjit
运算符来实现变压器的模型并行性。
并行方案类似于原始的 Megatron-LM,由于高速 2d 网格网络,它在 TPU 上非常高效。还有一个实现 ZeRo 风格分片的实验模型版本。
该库设计用于在 TPUv3 上扩展至约 40B 参数,超出此范围应使用不同的并行策略。请参阅其他实现,例如 GPT-NeoX 或 DeepSpeed。
未来的研究方向之一是将这个代码库与 swarm-jax 集成,以通过管道并行性实现进一步的可扩展性。
12-07-21 :添加了微调指南
在 The Pile 上训练的 60 亿个参数的自回归文本生成模型。
下载 slim 权重(仅限 BF16 权重,用于推理,9GB)
下载完整权重(包括优化器参数,61GB)
部分训练的检查站
Colab 演示
网络演示
阿兰的博客文章
如果没有 TPU 研究云在 EleutherAI 的协助下慷慨提供的计算能力,这个项目就不可能实现。
感谢 Google Cloud TPU 团队提供对 Cloud TPU VM alpha 的早期访问(现已公开发布!)
感谢所有以这种或那种方式提供帮助的人(按字母顺序列出):
GPT-J-6B 的权重根据 Apache 许可证 2.0 版本获得许可。
超参数 | 价值 |
---|---|
n_参数 | 6,053,381,344 |
n_层数 | 28* |
d_模型 | 4,096 |
d_ff | 16,384 |
n_头 | 16 |
d_头 | 256 |
n_ctx | 2,048 |
n_词汇 | 50,257(与 GPT-2/3 相同的标记器) |
位置编码 | 旋转位置编码 (RoPE) |
绳索尺寸 | 64 |
*
每层由一个前馈块和一个自注意力块组成
该模型由 28 层组成,模型维度为 4096,前馈维度为 16384。模型维度分为 16 个头,每个头的维度为 256。每个头的 64 个维度应用了旋转位置编码(RoPE) 。该模型使用与 GPT-2/GPT-3 相同的 BPE 集,使用 50257 的标记化词汇进行训练。
模型大致按性能排序,如果不可用,则按失败次数排序。
模型 | 重量 | 训练失败次数 | 兰巴达 PPL ↓ | 兰巴达加速器 ↑ | 温诺格兰德 ↑ | 海拉斯瓦格 ↑ | PIQA↑ | 数据集大小 (GB) |
---|---|---|---|---|---|---|---|---|
机会 | ✔ | 0 | 〜很多 | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 第825章 |
威震天-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 第825章 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
GPT-3-巴贝奇‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
威震天-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
威震天 11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 第825章 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
GPT-3-居里‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
GPT-3-达芬奇‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
地鼠230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 第1344章 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
*
代表各自作者报告的评估数字,所有其他数字均通过使用发布的权重或 API 访问运行 lm-evaluation-harness 来提供。由于细微的实现差异以及不同的零样本任务框架,这些可能无法直接比较。请参阅此博客文章了解更多详细信息。
†
Megatron-11B 模型没有提供可比较的指标,并且使用发布的权重的几种实现无法重现生成质量和评估。 (参见 1 2 3)因此,未尝试进行评估。
‡
这些模型已经使用包含可能的测试集污染的数据进行了训练。 OpenAI GPT-3 模型未能对某些测试集的训练数据进行重复数据删除,而 GPT-Neo 模型以及该模型是在 The Pile 上进行训练的,而 The Pile 尚未针对任何测试集进行重复数据删除。
该存储库中的大多数脚本都设计为在 TPU 上运行,TPU-VM 架构下的 TPU 是可以运行任意代码的虚拟机。大多数脚本旨在启动 TPU、通过 SSH 连接到其中以设置依赖项并从本地目录复制代码,然后启动可以接受 RPC 调用的 Ray 工作线程。
TPUVM 处理运行模型训练步骤和评估、检查点保存和加载,而驱动程序 python 程序处理数据加载和一般编排(例如何时保存检查点等)。
这意味着大多数脚本( train.py
、 eval_harness.py
等)希望在与 TPU 相同区域的 GCE 虚拟机上运行,以最大限度地减少 RPC 延迟和数据传输成本。其他脚本(通常是不带--tpu
参数的脚本,例如device_sample.py
、 device_serve.py
或device_train.py
)期望直接在 TPUVM 上运行。 device_* 脚本仅适用于 v3-8 ,不适用于较大的 pod。
此外,还有一个示例 ( resharding_example.py
),说明如何将提供的检查点(在 GPT-J-6B 的情况下有 8 个分片)转换为较小的数量,例如在 GPU 上运行时。
要微调模型,请在 TPU VM 上运行device_train.py
。使用 TPU v3-8,您可以以约 5000 个令牌/秒的速率进行微调,这对于中小型数据集来说应该足够了。
请阅读分步指南以获取完整的微调说明。
请注意,该库对 JAX 版本有一些特定要求。具体来说,要使用 v1 模型(包括 GPT-J 6B),需要jax==0.2.12
。这又取决于jaxlib==0.1.68
。如果不这样做,您将收到神秘的 xmap 错误
但是,要使用 v2 模型代码(没有公开发布的权重),可以使用最新的 JAX 版本。
引用这个存储库:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
引用GPT-J-6B的重量:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
如果您使用此存储库或任何预训练的权重来做一些很酷的事情,我们很乐意听到它。请随意打开 github 问题或通过电子邮件联系(在个人资料中)。