Penulis: Henry Ndubuaku (Lencana Discord & Dokumen dapat diklik)
N/B: Kode diterapkan secara pedagogis dengan mengorbankan pengulangan. Setiap model sengaja dimasukkan ke dalam file tanpa ketergantungan antar file.
Mengembangkan dan melatih model berbasis transformator biasanya membutuhkan banyak sumber daya dan waktu, serta pakar AI/ML sering kali perlu membuat versi skala kecil dari model ini untuk masalah tertentu. Jax, kerangka kerja dengan sumber daya rendah namun kuat, mempercepat pengembangan jaringan saraf dan pelatihan terdistribusi abstrak, tetapi sumber daya yang ada untuk pengembangan transformator di Jax terbatas. NanoDL mengatasi tantangan ini dengan fitur-fitur berikut:
Beragam blok dan lapisan, memfasilitasi pembuatan model transformator yang disesuaikan dari awal.
Berbagai pilihan model seperti Gemma, LlaMa3, Mistral, GPT3, GPT4 (disimpulkan), T5, Whisper, ViT, Mixers, CLIP dll.
Model pelatih yang terdistribusi secara paralel data pada beberapa GPU atau TPU, tanpa memerlukan loop pelatihan manual.
Dataloader, menjadikan proses penanganan data untuk Jax/Flax lebih mudah dan efektif.
Lapisan yang tidak ditemukan di Flax/Jax, seperti RoPE, GQA, MQA, dan SWin perhatian, memungkinkan pengembangan model yang lebih fleksibel.
Model ML klasik dengan akselerasi GPU/TPU seperti PCA, KMeans, Regression, Gaussian Processes, dll.
Generator nomor acak sejati di Jax yang tidak memerlukan kode verbose.
Berbagai algoritma canggih untuk tugas NLP dan visi komputer, seperti Gaussian Blur, BLEU, Tokenizer, dll.
Setiap model terkandung dalam satu file tanpa ketergantungan eksternal, sehingga kode sumbernya juga dapat digunakan dengan mudah.
Generator angka acak sejati di Jax yang tidak memerlukan kode verbose (contoh ditunjukkan di bagian selanjutnya).
Terdapat fitur eksperimental dan/atau belum selesai (seperti MAMBA, KAN, BitNet, GAT, dan RLHF) di repo yang belum tersedia melalui paket, namun dapat disalin dari repo ini. Masukan pada setiap diskusi, masalah, dan rangkaian permintaan penarikan kami disambut baik! Silakan laporkan permintaan fitur, masalah, pertanyaan, atau kekhawatiran apa pun di Discord, atau beri tahu kami apa yang sedang Anda kerjakan!
Anda memerlukan Python 3.9 atau lebih baru, dan instalasi JAX yang berfungsi, instalasi FLAX, instalasi OPTAX (dengan dukungan GPU untuk menjalankan pelatihan, tanpa hanya dapat mendukung kreasi). Model dapat dirancang dan diuji pada CPU tetapi semua model pelatih adalah Data-Paralel Terdistribusi yang memerlukan GPU dengan 1 hingga N GPUS/TPUS. Untuk JAX versi khusus CPU:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
Kemudian, instal nanodl dari PyPi:
pip install nanodl
Kami menyediakan berbagai contoh penggunaan API nanodl.
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# Mempersiapkan kumpulan data Andabatch_size = 8max_length = 50vocab_size = 1000# Buat data acakdata = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Pergeseran untuk membuat kumpulan data prediksi token berikutnyadummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# Membuat kumpulan data dan dataloaderdataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(kumpulan data, batch_size=batch_size, shuffle=Benar, drop_last=False)# parameter modelhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': vocab_size,' semat_dim': 256,'panjang_maks': max_length,'start_token': 0,'end_token': 50, }# Model model GPT4 yang disimpulkan = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # gunakan data val aktual # Menghasilkan dari awal tokenstart_tokens = jnp.array([[123, 456]])# Ingatlah untuk memuat parameter yang dilatih params = trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, metode=model.generate)
Contoh visi
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_ depth = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# Gunakan imagesdataset Anda sendiri = ArrayDataset(images) dataloader = DataLoader(kumpulan data, batch_size=batch_size, shuffle=Benar, drop_last=False) # Buat model difusidiffusion_model = DiffusionModel(ukuran_gambar, lebar, kedalaman_blok)# Pelatihan pada datatrainer Anda = DiffusionDataParallelTrainer(diffusion_model, input_shape=gambar.bentuk, bobot_namafile='params.pkl', learning_rate=1e-4)trainer.train(dataloader, 10)# Hasilkan beberapa sampel: Setiap model adalah modul Flax.linen# Gunakan seperti biasaparams = trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'param': param}, jumlah_gambar=5, langkah_difusi=5, metode=diffusion_model.generate)
Contoh audio
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# Parameter data tiruanbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # Hasilkan data: ganti dengan data yang diberi token/terkuantisasidummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle= Benar, drop_last=False)# model parameterhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1000,'embed_dim': embed_dim,'max_length': max_length ,'token_mulai': 0,'token_akhir': 50, }# Inisialisasi modelmodel = Whisper(**hyperparams)# Pelatihan pada datatrainer Anda = WhisperDataParallelTrainer(model, dummy_inputs.bentuk, dummy_targets.bentuk, 'params.pkl')trainer.train(dataloader, 2, dataloader)# Contoh inferenceparams = trainer.load_params('params.pkl')# untuk lebih dari satu sampel, sering menggunakan model.generate_batchtranscripts = model.apply({'params ': param}, dummy_inputs[:1], metode=model.hasilkan)
Contoh Model Hadiah untuk RLHF
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer# Hasilkan dummy databatch_size = 8max_length = 10# Ganti dengan data yang diberi token yang sebenarnyadummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)# Membuat dataset dan dataloaderdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # parameter modelhyperparams = {'jumlah_lapisan': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0,1,'vocab_size': 1000,'embed_dim': 256,'max_length': max_length,'start_token': 0, 'end_token': 50,'num_groups': 2,'window_size': 5,'shift_size': 2}# Inisialisasi model reward dari Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)# Latih rewardnya modeltrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = trainer.load_params('reward_model_weights.pkl')# Panggil seperti yang Anda lakukan pada Flax modelrewards = reward_model.apply({'params': params} biasa, dummy_chosen, rngs={'putus sekolah': nanodl.time_rng_key()})
Contoh PCA
import nanodlfrom nanodl import PCA# Gunakan data aktualdata = nanodl.normal(shape=(1000, 10))# Inisialisasi dan latih model PCApca = PCA(n_components=2)pca.fit(data)# Dapatkan PCA transformstransformed_data = pca.transform( data)# Dapatkan reverse transformsoriginal_data = pca.inverse_transform(transformed_data)# Contoh dari distributionX_sampled = pca.sample(n_samples=1000, kunci=Tidak Ada)
Ini masih dalam pengembangan, berfungsi dengan baik tetapi diharapkan ada kekasaran, dan oleh karena itu kontribusi sangat dianjurkan!
Buat perubahan Anda tanpa mengubah pola desain.
Tulis tes untuk perubahan Anda jika perlu.
Instal secara lokal dengan pip3 install -e .
.
Jalankan tes dengan python3 -m unittest discover -s tests
.
Kemudian kirimkan permintaan tarik.
Kontribusi dapat diberikan dalam berbagai bentuk:
Menulis dokumentasi.
Memperbaiki bug.
makalah implementasi.
Menulis tes dengan cakupan tinggi.
Mengoptimalkan kode yang ada.
Bereksperimen dan mengirimkan contoh dunia nyata ke bagian contoh.
Melaporkan bug.
Menanggapi masalah yang dilaporkan.
Bergabunglah dengan Server Perselisihan untuk mengetahui lebih lanjut.
Nama "NanoDL" adalah singkatan dari Nano Deep Learning. Ukuran model semakin meningkat, oleh karena itu para ahli dan perusahaan yang memiliki sumber daya terbatas tidak dapat membuat model yang fleksibel tanpa biaya yang mahal. Menyusul kesuksesan model Phi, tujuan jangka panjangnya adalah membuat dan melatih versi nano dari semua model yang tersedia, sekaligus memastikan model tersebut bersaing dengan model asli dalam hal performa, dengan jumlah total parameter tidak melebihi 1 miliar. Beban terlatih akan tersedia melalui perpustakaan ini. Segala bentuk sponsorship, pendanaan akan membantu sumber daya pelatihan. Anda dapat mensponsori melalui GitHub di sini atau menghubungi melalui [email protected].
Mengutip repositori ini:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }