변환기의 모델 병렬 처리를 위해 JAX에서 xmap
/ pjit
연산자를 사용하는 하이쿠 라이브러리입니다.
병렬 처리 방식은 고속 2D 메시 네트워크로 인해 TPU에서 효율적인 원래 Megatron-LM과 유사합니다. ZeRo 스타일 샤딩을 구현하는 실험적인 모델 버전도 있습니다.
이 라이브러리는 TPUv3에서 최대 약 40B 매개변수까지 확장할 수 있도록 설계되었으며, 그 이상에서는 다양한 병렬 처리 전략을 사용해야 합니다. 이에 대해서는 GPT-NeoX 또는 DeepSpeed와 같은 다른 구현을 참조하세요.
향후 연구 방향 중 하나는 이 코드베이스를 swarm-jax와 통합하여 파이프라인 병렬 처리를 통해 추가 확장성을 달성하는 것입니다.
12-07-21 : 미세조정 가이드 추가
The Pile에서 훈련된 60억 개의 매개변수, 자동 회귀 텍스트 생성 모델.
슬림 웨이트 다운로드(bf16 웨이트만, 추론용, 9GB)
전체 가중치 다운로드(최적화 매개변수 포함, 61GB)
부분적으로 훈련된 체크포인트
Colab 데모
웹 데모
Aran의 블로그 게시물
이 프로젝트는 EleutherAI의 지원을 받아 TPU Research Cloud에서 아낌없이 제공하는 컴퓨팅이 없었다면 불가능했을 것입니다.
Cloud TPU VM 알파(이제 공개적으로 사용 가능)에 대한 조기 액세스를 제공한 Google의 Cloud TPU 팀에 감사드립니다.
어떤 식으로든 도움을 주신 모든 분들께 감사드립니다(알파벳순으로 나열).
GPT-J-6B의 가중치는 Apache 라이센스 버전 2.0에 따라 라이센스가 부여됩니다.
초매개변수 | 값 |
---|---|
n_매개변수 | 6,053,381,344 |
n_레이어 | 28* |
d_모델 | 4,096 |
d_ff | 16,384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2,048 |
n_vocab | 50,257(GPT-2/3과 동일한 토크나이저) |
위치 인코딩 | RoPE(로터리 위치 인코딩) |
RoPE 치수 | 64 |
*
각 레이어는 하나의 피드포워드 블록과 하나의 셀프 어텐션 블록으로 구성됩니다.
모델은 모델 차원이 4096이고 피드포워드 차원이 16384인 28개의 레이어로 구성됩니다. 모델 차원은 16개의 헤드로 분할되며 각각의 차원은 256입니다. RoPE(회전 위치 인코딩)는 각 헤드의 64 차원에 적용되었습니다. . 모델은 GPT-2/GPT-3과 동일한 BPE 세트를 사용하여 50257의 토큰화 어휘로 학습되었습니다.
성능을 기준으로 대략적으로 정렬된 모델 또는 사용할 수 없는 경우 FLOP를 기준으로 모델을 정렬합니다.
모델 | 가중치 | 훈련 FLOP | 람바다 PPL ↓ | 람바다 Acc ↑ | 위노그란데 ↑ | 헬라스와그 ↑ | 피카 ↑ | 데이터 세트 크기(GB) |
---|---|---|---|---|---|---|---|---|
가능성 | ✔ | 0 | ~많이 | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7시 50분 | 57.2% | 55.0% | 48.9% | 71.1% | 825 |
메가트론-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
GPT-3-배비지‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
메가트론-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
메가트론-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
GPT-3-퀴리‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
GPT-3-다빈치‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
고퍼 230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
*
해당 작성자가 보고한 평가 번호를 나타내며, 다른 모든 번호는 릴리스된 가중치 또는 API 액세스를 사용하여 lm-evaluation-harness를 실행하여 제공됩니다. 미묘한 구현 차이와 다양한 제로 샷 작업 프레이밍으로 인해 이러한 항목은 직접 비교할 수 없을 수도 있습니다. 자세한 내용은 이 블로그 게시물을 참조하세요.
†
Megatron-11B 모델은 비교 가능한 측정항목을 제공하지 않으며, 발표된 가중치를 사용한 여러 구현에서는 세대 품질 및 평가를 재현하지 않습니다. (1 2 3 참조) 따라서 평가는 시도되지 않았습니다.
‡
이러한 모델은 테스트 세트 오염 가능성이 포함된 데이터로 학습되었습니다. OpenAI GPT-3 모델은 특정 테스트 세트에 대한 교육 데이터 중복을 제거하지 못한 반면, GPT-Neo 모델과 이 모델은 어떤 테스트 세트에 대해서도 중복 제거되지 않은 The Pile에서 교육되었습니다.
이 저장소에 있는 대부분의 스크립트는 TPU-VM 아키텍처에서 임의의 코드를 실행할 수 있는 가상 머신인 TPU에서 실행되도록 설계되었습니다. 대부분의 스크립트는 TPU, SSH를 실행하여 종속성을 설정하고 로컬 디렉터리에서 코드를 복사한 다음 RPC 호출을 수락할 수 있는 Ray 작업자를 시작하도록 설계되었습니다.
TPUVM은 실행 중인 모델 학습 단계 및 평가, 체크포인트 저장 및 로드를 처리하는 반면, Driver Python 프로그램은 데이터 로드 및 일반 조정(예: 체크포인트 저장 시기 등)을 처리합니다.
이는 RPC 지연 시간과 데이터 전송 비용을 최소화하기 위해 대부분의 스크립트( train.py
, eval_harness.py
등)가 TPU와 동일한 지역의 GCE 가상 머신에서 실행될 것으로 예상한다는 의미입니다. 다른 스크립트(일반적으로 device_sample.py
, device_serve.py
또는 device_train.py
등 --tpu
인수를 사용하지 않는 스크립트)는 TPUVM에서 직접 실행될 것으로 예상됩니다. device_* 스크립트는 v3-8에서만 작동하며 더 큰 Pod에서는 작동하지 않습니다.
또한 제공된 체크포인트(GPT-J-6B의 경우 샤드가 8개 있음)를 GPU에서 실행하는 경우와 같이 더 작은 수로 변환하는 방법에 대한 예( resharding_example.py
)가 있습니다.
모델을 미세 조정하려면 TPU VM에서 device_train.py
실행하세요. TPU v3-8을 사용하면 ~5000개 토큰/초의 속도로 미세 조정할 수 있으며, 이는 중소 규모 데이터세트에 충분합니다.
철저한 미세 조정 지침을 보려면 단계별 가이드를 읽어보세요.
이 라이브러리에는 JAX 버전에 대한 몇 가지 특정 요구 사항이 있습니다. 특히 v1 모델(GPT-J 6B 포함)을 사용하려면 jax==0.2.12
가 필요합니다. 이는 차례로 jaxlib==0.1.68
에 따라 달라집니다. 그렇지 않으면 암호 같은 Xmap 오류가 발생합니다.
그러나 v2 모델 코드(공개적으로 릴리스된 가중치 없음)를 사용하려면 최신 JAX 버전을 사용할 수 있습니다.
이 저장소를 인용하려면:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
GPT-J-6B의 가중치를 인용하려면:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
이 저장소나 사전 훈련된 가중치를 사용하여 멋진 작업을 수행한다면 이에 대해 듣고 싶습니다. 자유롭게 github 문제를 열거나 이메일(프로필)을 통해 문의하세요.