ไลบรารีไฮกุที่ใช้ตัวดำเนินการ xmap
/ pjit
ใน JAX สำหรับโมเดลความขนานของหม้อแปลง
รูปแบบความเท่าเทียมนั้นคล้ายคลึงกับ Megatron-LM ดั้งเดิม ซึ่งมีประสิทธิภาพบน TPU เนื่องจากเครือข่ายตาข่าย 2d ความเร็วสูง นอกจากนี้ยังมีเวอร์ชันทดลองซึ่งใช้การแบ่งส่วนสไตล์ ZeRo
ไลบรารีนี้ออกแบบมาเพื่อความสามารถในการปรับขนาดได้สูงสุดประมาณ 40B พารามิเตอร์บน TPUv3 ซึ่งนอกเหนือจากนั้นควรใช้กลยุทธ์การทำงานแบบขนานที่แตกต่างกัน ดูการใช้งานอื่นๆ เช่น GPT-NeoX หรือ DeepSpeed สำหรับสิ่งนั้น
ทิศทางหนึ่งในอนาคตสำหรับการวิจัยคือการรวมโค้ดเบสนี้เข้ากับ swarm-jax เพื่อให้บรรลุถึงความสามารถในการขยายขนาดเพิ่มเติมด้วยความขนานของไปป์ไลน์
12-07-21 : เพิ่มคำแนะนำในการปรับแต่งอย่างละเอียด
โมเดลการสร้างข้อความแบบถอยหลังอัตโนมัติพารามิเตอร์ 6 พันล้านพารามิเตอร์ที่ได้รับการฝึกบน The Pile
ดาวน์โหลดสลิมตุ้มน้ำหนัก (เฉพาะตุ้มน้ำหนัก bf16 สำหรับการอนุมาน 9GB)
ดาวน์โหลดน้ำหนักเต็ม (รวมถึงพารามิเตอร์เครื่องมือเพิ่มประสิทธิภาพ 61GB)
จุดตรวจที่ได้รับการฝึกอบรมบางส่วน
การสาธิต Colab
การสาธิตเว็บ
โพสต์ในบล็อกของอรัญ
โปรเจ็กต์นี้คงเป็นไปไม่ได้หากไม่ได้รับการประมวลผลจาก TPU Research Cloud ที่ได้รับความช่วยเหลือจาก EleutherAI
ขอขอบคุณทีม Cloud TPU ที่ Google ที่ให้สิทธิ์เข้าถึง Cloud TPU VM อัลฟ่าก่อนใคร (ขณะนี้พร้อมใช้งานแบบสาธารณะแล้ว!)
ขอขอบคุณทุกคนที่ช่วยเหลือไม่ทางใดก็ทางหนึ่ง (เรียงตามตัวอักษร):
น้ำหนักของ GPT-J-6B ได้รับอนุญาตภายใต้เวอร์ชัน 2.0 ของ Apache License
ไฮเปอร์พารามิเตอร์ | ค่า |
---|---|
n_พารามิเตอร์ | 6,053,381,344 |
n_ชั้น | 28* |
d_model | 4,096 |
d_ff | 16,384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2,048 |
n_คำศัพท์ | 50,257 (โทเค็นไนเซอร์เดียวกันกับ GPT-2/3) |
การเข้ารหัสตำแหน่ง | การเข้ารหัสตำแหน่งโรตารี (RoPE) |
ขนาดเชือก | 64 |
*
แต่ละชั้นประกอบด้วยบล็อกฟีดฟอร์เวิร์ดหนึ่งบล็อกและบล็อกความสนใจตนเองหนึ่งบล็อก
แบบจำลองประกอบด้วย 28 เลเยอร์โดยมีมิติแบบจำลอง 4096 และมิติฟีดไปข้างหน้า 16384 มิติแบบจำลองแบ่งออกเป็น 16 หัว โดยแต่ละหัวมีขนาด 256 มีการใช้การเข้ารหัสตำแหน่งแบบหมุน (RoPE) กับ 64 มิติของแต่ละหัว . โมเดลนี้ได้รับการฝึกฝนด้วยคำศัพท์โทเค็น 50257 โดยใช้ชุด BPE ชุดเดียวกันกับ GPT-2/GPT-3
โมเดลจะจัดเรียงคร่าวๆ ตามประสิทธิภาพ หรือตาม FLOP หากไม่มี
แบบอย่าง | ตุ้มน้ำหนัก | การฝึกอบรม FLOP | แลมบาดา พีพีแอล ↓ | แลมบาดาตาม↑ | วิโนแกรนด์ ↑ | เฮลลาสวัก ↑ | PIQA ↑ | ขนาดชุดข้อมูล (GB) |
---|---|---|---|---|---|---|---|---|
โอกาส | 0 | ~มาก | ~0% | 50% | 25% | 25% | 0 | |
GPT-3-อาดา‡ | - | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | - | |
GPT-2-1.5B | - | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 | |
GPTNEO-1.3B‡ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 | |
เมกะตรอน-2.5B* | 2.4e21 | - | 61.7% | - | - | - | 174 | |
GPTNEO-2.7B‡ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 | |
GPT-3-1.3B*‡ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 | |
GPT-3-แบบเบจ‡ | - | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | - | |
เมกะตรอน-8.3B* | 7.8e21 | - | 66.5% | - | - | - | 174 | |
GPT-3-2.7B*‡ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 | |
เมกะตรอน-11B† | 1.0e22 | - | - | - | - | - | 161 | |
GPT-J-6B ‡ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 | |
GPT-3-6.7B*‡ | 1.2e22 | 4.00 น | 70.3% | 64.5% | 67.4% | 78.0% | ~800 | |
GPT-3-คูรี‡ | - | 4.00 น | 69.3% | 65.6% | 68.5% | 77.9% | - | |
GPT-3-13B*‡ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 | |
GPT-3-175B*‡ | 3.1จ23 | 03.00 น | 76.2% | 70.2% | 78.9% | 81.0% | ~800 | |
GPT-3-ดาวินชี่‡ | - | 3.0 | 75% | 72% | 78% | 80% | - | |
โกเฟอร์ 230B* | 6.31E+23 | - | 74.50% | 70.10% | 79.20% | 81.80% | 1344 | |
MT-NLG 530B*‡ | - | - | 76.6% | 73.0% | 80.2% | 82.0% | - |
*
หมายถึงหมายเลขการประเมินที่รายงานโดยผู้เขียนที่เกี่ยวข้อง หมายเลขอื่นๆ ทั้งหมดได้มาจากการดำเนินการ lm-evalue-harness ด้วยน้ำหนักที่ปล่อยออกมาหรือด้วยการเข้าถึง API เนื่องจากความแตกต่างในการใช้งานที่ละเอียดอ่อนและการวางเฟรมงานแบบ Zero Shot ที่แตกต่างกัน สิ่งเหล่านี้จึงอาจเทียบไม่ได้โดยตรง ดูโพสต์บล็อกนี้สำหรับรายละเอียดเพิ่มเติม
†
โมเดล Megatron-11B ไม่มีหน่วยเมตริกที่เทียบเคียงได้ และการใช้งานหลายอย่างโดยใช้ตุ้มน้ำหนักที่ปล่อยออกมาไม่ได้สร้างคุณภาพการสร้างและการประเมินขึ้นมาใหม่ (ดู 1 2 3) จึงไม่พยายามประเมินผล
‡
แบบจำลองเหล่านี้ได้รับการฝึกอบรมเกี่ยวกับข้อมูลที่มีการปนเปื้อนของชุดทดสอบที่เป็นไปได้ โมเดล OpenAI GPT-3 ล้มเหลวในการขจัดข้อมูลซ้ำซ้อนการฝึกสำหรับชุดการทดสอบบางชุด ในขณะที่รุ่น GPT-Neo และรุ่นนี้ได้รับการฝึกบน The Pile ซึ่งไม่ได้รับการขจัดความซ้ำซ้อนกับชุดการทดสอบใด ๆ
สคริปต์ส่วนใหญ่ในพื้นที่เก็บข้อมูลนี้ได้รับการออกแบบมาให้ทำงานบน TPU ซึ่งภายใต้สถาปัตยกรรม TPU-VM นั้นเป็นเครื่องเสมือนที่สามารถเรียกใช้โค้ดที่กำหนดเองได้ สคริปต์ส่วนใหญ่ได้รับการออกแบบให้หมุน TPU, SSH เข้าไปเพื่อตั้งค่าการขึ้นต่อกันและคัดลอกโค้ดจากไดเร็กทอรีในเครื่อง จากนั้นเริ่มการทำงานของ Ray ซึ่งสามารถรับการเรียก RPC ได้
TPUVM จัดการขั้นตอนการฝึกโมเดลที่กำลังรันและการประเมินผล การบันทึกและการโหลดจุดตรวจสอบ ในขณะที่โปรแกรม Python ของไดรเวอร์จะจัดการการโหลดข้อมูลและการจัดการทั่วไป (เช่น เมื่อใดที่ต้องบันทึกจุดตรวจสอบ ฯลฯ)
ซึ่งหมายความว่าสคริปต์ส่วนใหญ่ ( train.py
, eval_harness.py
ฯลฯ ) คาดว่าจะทำงานบนเครื่องเสมือน GCE ในภูมิภาคเดียวกับ TPU เพื่อลดเวลาในการตอบสนองของ RPC และต้นทุนการถ่ายโอนข้อมูลให้เหลือน้อยที่สุด สคริปต์อื่นๆ (โดยปกติแล้วจะไม่ใช้อาร์กิวเมนต์ --tpu
เช่น device_sample.py
, device_serve.py
หรือ device_train.py
) คาดว่าจะทำงานโดยตรงบน TPUVM สคริปต์ device_* ใช้งานได้กับ v3-8 เท่านั้น และไม่สามารถใช้ได้กับพ็อดที่ใหญ่กว่า
นอกจากนี้ยังมีตัวอย่าง ( resharding_example.py
) ของวิธีแปลงจุดตรวจสอบที่ให้ไว้ (ซึ่งมี 8 ชิ้นส่วนในกรณีของ GPT-J-6B) ให้เป็นจำนวนที่น้อยกว่า เช่น เมื่อทำงานบน GPU
หากต้องการปรับแต่งโมเดลอย่างละเอียด ให้เรียกใช้ device_train.py
บน TPU VM เมื่อใช้ TPU v3-8 คุณสามารถปรับแต่งแบบละเอียดได้ในอัตรา ~5,000 โทเค็น/วินาที ซึ่งน่าจะเพียงพอสำหรับชุดข้อมูลขนาดเล็กถึงขนาดกลาง
โปรดอ่านคำแนะนำทีละขั้นตอนเพื่อดูคำแนะนำในการปรับแต่งอย่างละเอียด
โปรดทราบว่าไลบรารีนี้มีข้อกำหนดเฉพาะบางประการสำหรับเวอร์ชัน JAX โดยเฉพาะ หากต้องการใช้รุ่น v1 (รวมถึง GPT-J 6B) จำเป็นต้องมี jax==0.2.12
สิ่งนี้จะขึ้นอยู่กับ jaxlib==0.1.68
หากยังไม่เสร็จสิ้น คุณจะได้รับข้อผิดพลาด xmap ที่เป็นความลับ
อย่างไรก็ตาม หากต้องการใช้โค้ดโมเดล v2 (ไม่มีน้ำหนักที่เปิดเผยต่อสาธารณะ) คุณสามารถใช้ JAX เวอร์ชันใหม่ล่าสุดได้
หากต้องการอ้างอิงที่เก็บนี้:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
หากต้องการอ้างอิงน้ำหนักของ GPT-J-6B:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
หากคุณใช้พื้นที่เก็บข้อมูลนี้หรือตุ้มน้ำหนักที่ได้รับการฝึกมาล่วงหน้าเพื่อทำสิ่งเจ๋งๆ เรายินดีอย่างยิ่งที่จะได้ยินเกี่ยวกับเรื่องนี้ อย่าลังเลที่จะเปิดปัญหา GitHub หรือติดต่อทางอีเมล (ในโปรไฟล์)