ผู้แต่ง: Henry Ndubuaku (สามารถคลิกป้าย Discord & Docs ได้)
N/B: รหัสต่างๆ ได้รับการนำไปใช้ในเชิงการสอน โดยต้องเสียค่าใช้จ่ายในการทำซ้ำ แต่ละโมเดลมีจุดประสงค์ในไฟล์โดยไม่มีการพึ่งพาระหว่างไฟล์
โดยทั่วไปการพัฒนาและการฝึกอบรมโมเดลที่ใช้หม้อแปลงไฟฟ้านั้นต้องใช้ทรัพยากรมากและใช้เวลานาน และผู้เชี่ยวชาญด้าน AI/ML มักจำเป็นต้องสร้างโมเดลเหล่านี้ในเวอร์ชันที่มีขนาดเล็กกว่าสำหรับปัญหาเฉพาะ Jax ซึ่งเป็นเฟรมเวิร์กทรัพยากรต่ำแต่ทรงพลัง ช่วยเร่งการพัฒนาโครงข่ายประสาทเทียมและการฝึกอบรมแบบกระจายบทคัดย่อ แต่ทรัพยากรที่มีอยู่สำหรับการพัฒนาหม้อแปลงไฟฟ้าใน Jax นั้นมีจำกัด NanoDL จัดการกับความท้าทายนี้ด้วยคุณสมบัติดังต่อไปนี้:
บล็อกและเลเยอร์ที่หลากหลาย อำนวยความสะดวกในการสร้างสรรค์โมเดลหม้อแปลงแบบกำหนดเองตั้งแต่เริ่มต้น
มีรุ่นให้เลือกมากมาย เช่น Gemma, LlaMa3, Mistral, GPT3, GPT4 (อนุมาน), T5, Whisper, ViT, Mixers, CLIP เป็นต้น
โมเดลเทรนเนอร์แบบกระจายข้อมูลแบบขนานบน GPU หรือ TPU หลายตัว โดยไม่จำเป็นต้องใช้ลูปการฝึกด้วยตนเอง
Dataloaders ทำให้กระบวนการจัดการข้อมูลสำหรับ Jax/Flax ตรงไปตรงมาและมีประสิทธิภาพมากขึ้น
เลเยอร์ที่ไม่พบใน Flax/Jax เช่น RoPE, GQA, MQA และ SWin ช่วยให้การพัฒนาโมเดลมีความยืดหยุ่นมากขึ้น
โมเดล ML คลาสสิกที่เร่งด้วย GPU/TPU เช่น PCA, KMeans, Regression, Gaussian Processes เป็นต้น
เครื่องกำเนิดตัวเลขสุ่มที่แท้จริงใน Jax ซึ่งไม่ต้องการรหัสรายละเอียด
อัลกอริธึมขั้นสูงที่หลากหลายสำหรับงาน NLP และคอมพิวเตอร์วิทัศน์ เช่น Gaussian Blur, BLEU, Tokenizer เป็นต้น
แต่ละโมเดลจะอยู่ในไฟล์เดียวโดยไม่มีการพึ่งพาภายนอก ดังนั้นซอร์สโค้ดจึงสามารถใช้งานได้ง่าย
เครื่องกำเนิดตัวเลขสุ่มที่แท้จริงใน Jax ซึ่งไม่ต้องการรหัสรายละเอียด (ตัวอย่างแสดงในส่วนถัดไป)
มีคุณลักษณะทดลองและ/หรือที่ยังไม่เสร็จ (เช่น MAMBA, KAN, BitNet, GAT และ RLHF) ใน repo ซึ่งยังไม่พร้อมใช้งานผ่านแพ็คเกจ แต่สามารถคัดลอกได้จาก repo นี้ ยินดีรับฟังข้อเสนอแนะเกี่ยวกับการสนทนา ปัญหา และคำขอดึงข้อมูลของเรา! กรุณารายงานคำขอคุณสมบัติ ปัญหา คำถามหรือข้อกังวลใน Discord หรือเพียงแจ้งให้เราทราบว่าคุณกำลังทำอะไรอยู่!
คุณจะต้องใช้ Python 3.9 หรือใหม่กว่า และการติดตั้ง JAX ที่ใช้งานได้ การติดตั้ง FLAX การติดตั้ง OPTAX (พร้อมการรองรับ GPU สำหรับการรันการฝึกอบรม โดยไม่รองรับเฉพาะการสร้างสรรค์เท่านั้น) โมเดลสามารถออกแบบและทดสอบบน CPU ได้ แต่อุปกรณ์ฝึกสอนเป็นแบบ Distributed Data-Parallel ทั้งหมด ซึ่งจะต้องใช้ GPU ที่มี 1 ถึง N GPUS/TPUS สำหรับ JAX เวอร์ชัน CPU เท่านั้น:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
จากนั้นติดตั้ง nanodl จาก PyPi:
pip install nanodl
เรามีตัวอย่างการใช้งาน nanodl API มากมาย
นำเข้า jaximport nanodlimport jax.numpy เป็น jnpfrom nanodl นำเข้า ArrayDataset, DataLoaderfrom nanodl นำเข้า GPT4, GPTDataParallelTrainer# การเตรียมชุดข้อมูลของคุณbatch_size = 8max_length = 50vocab_size = 1,000# สร้าง datadata แบบสุ่ม = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Shift เพื่อสร้างชุดข้อมูลการทำนายโทเค็นถัดไปdummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# สร้างชุดข้อมูลและ dataloaderdataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(ชุดข้อมูล, batch_size=batch_size, shuffle=True, drop_last=False)# พารามิเตอร์โมเดลhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': vocab_size,' embed_dim': 256,'max_length': max_length,'start_token': 0,'end_token': 50, }# โมเดล GPT4 ที่อนุมาน model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # ใช้ข้อมูล val จริง # การสร้างจากโทเค็นเริ่มต้นstart_tokens = jnp.array([[123, 456]])# อย่าลืมโหลดพารามิเตอร์ที่ได้รับการฝึกอบรม params = trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, method=model.generate)
ตัวอย่างการมองเห็น
นำเข้า nanodlimport jax.numpy เป็น jnpfrom nanodl นำเข้า ArrayDataset, DataLoader จาก nanodl นำเข้า DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_ledge = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# ใช้ชุดข้อมูลรูปภาพของคุณเอง = ArrayDataset (รูปภาพ) dataloader = DataLoader (ชุดข้อมูล, batch_size=batch_size, shuffle=True, drop_last=False) # สร้าง diffusion modeldiffusion_model = DiffusionModel(image_size, widths, block_ allowance)# การฝึกอบรมเกี่ยวกับ datatrainer ของคุณ = DiffusionDataParallelTrainer(diffusion_model, input_shape=images.รูปร่าง Weights_filename='params.pkl', Learning_rate=1e-4)trainer.train(dataloader, 10)# สร้างตัวอย่าง: แต่ละรุ่นเป็นโมดูล Flax.linen# ใช้ตามปกติที่คุณจะ params = trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'พารามิเตอร์': พารามิเตอร์}, num_images=5, การแพร่กระจาย_ขั้นตอน=5, วิธีการ=diffusion_model.สร้าง)
ตัวอย่างเสียง
นำเข้า jaximport jax.numpy เป็น jnpfrom nanodl นำเข้า ArrayDataset, DataLoaderfrom nanodl นำเข้า Whisper, WhisperDataParallelTrainer# พารามิเตอร์ข้อมูลจำลอง batch_size = 8max_length = 50embed_dim = 256 vocab_size = 1,000 # สร้างข้อมูล: แทนที่ด้วย datadummy_targets โทเค็น / ปริมาณจริง = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))ชุดข้อมูล = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader (ชุดข้อมูล, batch_size=batch_size, shuffle= จริง drop_last=False)# พารามิเตอร์โมเดลhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1,000,'embed_dim': embed_dim,'max_length': max_length,'start_token': 0,'end_token': 50, }# เริ่มต้น modelmodel = Whisper(**hyperparams)# การฝึกอบรมเกี่ยวกับ datatrainer ของคุณ = WhisperDataParallelTrainer(model, dummy_inputs.รูปร่าง, dummy_targets.รูปร่าง, 'params.pkl')trainer.train(dataloader, 2, dataloader)# ตัวอย่าง inferenceparams = trainer.load_params('params.pkl')# สำหรับตัวอย่างมากกว่าหนึ่งตัวอย่าง มักใช้ model.generate_batchtranscripts = model.apply({'params ': พารามิเตอร์}, dummy_inputs[:1], วิธีการ=model.สร้าง)
ตัวอย่างโมเดลรางวัลสำหรับ RLHF
นำเข้า nanodlimport jax.numpy เป็น jnpfrom nanodl นำเข้า ArrayDataset, DataLoaderfrom nanodl นำเข้า Mistral, RewardModel, RewardDataParallelTrainer# สร้างหุ่นจำลอง databatch_size = 8max_length = 10# แทนที่ด้วยโทเค็นจริง datadummy_chosen = jnp.ones ((101, max_length) dtype=jnp.int32)dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)# สร้างชุดข้อมูลและ dataloaderdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(ชุดข้อมูล, bat_size=batch_size, shuffle=True, drop_last=False) # รุ่น พารามิเตอร์hyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1,000,'embed_dim': 256,'max_length': max_length ,'start_token': 0,'end_token': 50,'num_groups': 2,'window_size': 5,'shift_size': 2}# เริ่มต้นโมเดลรางวัลจาก Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams[' Hidden_dim'], dropout=0.1)# ฝึกเทรนเนอร์โมเดลรางวัล = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# โทรเหมือนกับที่คุณทำกับ Flax modelrewards ปกติ = รางวัล_model.apply( {'params': params}, dummy_chosen, rngs={'ออกกลางคัน': nanodl.time_rng_key()})
ตัวอย่าง PCA
นำเข้า nanodlfrom nanodl นำเข้า PCA# ใช้ datadata จริง = nanodl.normal(shape=(1000, 10))# เริ่มต้นและฝึกอบรม PCA modelpca = PCA(n_components=2)pca.fit(data)# รับ PCA Transformerstransformed_data = pca.transform( data)# รับการแปลงย้อนกลับOriginal_data = pca.inverse_transform(transformed_data)# ตัวอย่างจาก distributionX_sampled = pca.sample (n_samples=1,000, คีย์=ไม่มี)
สิ่งนี้ยังอยู่ในช่วงการพัฒนา ใช้งานได้ดี แต่คาดว่าจะมีความหยาบ ดังนั้นจึงสนับสนุนให้มีส่วนร่วมอย่างสูง!
ทำการเปลี่ยนแปลงโดยไม่ต้องเปลี่ยนรูปแบบการออกแบบ
เขียนการทดสอบการเปลี่ยนแปลงของคุณหากจำเป็น
ติดตั้งในเครื่องด้วย pip3 install -e .
-
รันการทดสอบด้วย python3 -m unittest discover -s tests
จากนั้นจึงส่งคำขอดึง
การบริจาคสามารถทำได้หลายรูปแบบ:
การเขียนเอกสาร
แก้ไขข้อบกพร่อง
เอกสารการดำเนินงาน
การเขียนแบบทดสอบความครอบคลุมสูง
การเพิ่มประสิทธิภาพรหัสที่มีอยู่
การทดลองและส่งตัวอย่างจากโลกแห่งความเป็นจริงไปยังส่วนตัวอย่าง
รายงานข้อบกพร่อง
ตอบสนองต่อประเด็นที่ได้รับรายงาน
เข้าร่วมเซิร์ฟเวอร์ Discord เพื่อรับข้อมูลเพิ่มเติม
ชื่อ "NanoDL" ย่อมาจาก Nano Deep Learning โมเดลมีขนาดใหญ่มาก ดังนั้นผู้เชี่ยวชาญด้านการเฝ้าประตูและบริษัทที่มีทรัพยากรจำกัดจึงไม่สามารถสร้างโมเดลที่ยืดหยุ่นได้โดยไม่มีต้นทุนที่ห้ามปราม หลังจากความสำเร็จของโมเดล Phi เป้าหมายระยะยาวคือการสร้างและฝึกฝนเวอร์ชันนาโนของโมเดลที่มีอยู่ทั้งหมด ขณะเดียวกันก็รับประกันว่าโมเดลเหล่านี้จะแข่งขันกับโมเดลดั้งเดิมในด้านประสิทธิภาพ โดยมีจำนวนพารามิเตอร์รวมไม่เกิน 1B ตุ้มน้ำหนักที่ฝึกแล้วจะมีให้บริการผ่านห้องสมุดนี้ การสนับสนุนทุกรูปแบบ เงินทุนจะช่วยในเรื่องทรัพยากรการฝึกอบรม คุณสามารถสนับสนุนผ่าน GitHub ที่นี่ หรือติดต่อผ่าน [email protected]
หากต้องการอ้างอิงที่เก็บนี้:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }