著者: Henry Ndubuaku (Discord と Docs のバッジはクリック可能です)
N/B: コードは繰り返しを犠牲にして教育的に実装されています。各モデルは、ファイル間の依存関係を持たずに意図的にファイルに含まれています。
トランスフォーマーベースのモデルの開発とトレーニングは通常、リソースを大量に消費し、時間がかかるため、AI/ML の専門家は特定の問題に対応するためにこれらのモデルの小規模バージョンを構築する必要があることがよくあります。 Jax は低リソースながら強力なフレームワークであり、ニューラル ネットワークの開発を加速し、分散トレーニングを抽象化しますが、Jax でのトランスフォーマー開発のための既存のリソースは限られています。 NanoDL は、次の機能でこの課題に対処します。
幅広いブロックとレイヤーにより、カスタマイズされた変圧器モデルを最初から作成することが容易になります。
Gemma、LlaMa3、Mistral、GPT3、GPT4 (推定)、T5、Whisper、ViT、Mixers、CLIP などのモデルの幅広いセレクション。
データ並列分散トレーナーは、手動トレーニング ループを必要とせずに、複数の GPU または TPU 上でモデルを作成します。
データローダー。Jax/Flax のデータ処理プロセスをより簡単かつ効果的にします。
RoPE、GQA、MQA、SW などの Flax/Jax にはないレイヤーに注目し、より柔軟なモデル開発を可能にします。
PCA、KMeans、回帰、ガウス過程などの GPU/TPU アクセラレーションの古典的な ML モデル。
冗長なコードを必要としない、Jax の真の乱数ジェネレーター。
ガウスぼかし、BLEU、トークナイザーなど、NLP およびコンピューター ビジョン タスク用の一連の高度なアルゴリズム。
各モデルは外部依存関係のない単一のファイルに含まれているため、ソース コードも簡単に使用できます。
冗長なコードを必要としない Jax の真の乱数ジェネレーター (次のセクションで例を示します)。
リポジトリには実験的な機能や未完成の機能 (MAMBA、KAN、BitNet、GAT、RLHF など) があり、これらはパッケージからはまだ利用できませんが、このリポジトリからコピーできます。ディスカッション、問題、プル リクエストのスレッドに関するフィードバックは大歓迎です。機能のリクエスト、問題、質問、懸念事項があれば Discord で報告していただくか、現在取り組んでいることをお知らせください。
Python 3.9 以降、および動作する JAX インストール、FLAX インストール、OPTAX インストール (トレーニングを実行するための GPU サポートあり、なしでは作成のサポートのみ可能) が必要です。モデルは CPU で設計およびテストできますが、トレーナーはすべて分散データ並列であるため、1 ~ N の GPUS/TPUS を備えた GPU が必要になります。 CPU のみのバージョンの JAX の場合:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
次に、PyPi から nanodl をインストールします。
pip install nanodl
nanodl API のさまざまな使用例を提供します。
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# データセットの準備batch_size = 8max_length = 50vocab_size = 1000# ランダムなデータを作成data = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# 次のトークン予測データセットを作成するためにシフトdummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# データセットとデータローダーを作成dataset = ArrayDataset(dummy_inputs) 、dummy_targets)dataloader = DataLoader(データセット、batch_size=batch_size、 shuffle=True、drop_last=False)# モデルパラメータhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': vocab_size,' embed_dim': 256,'max_length': max_length、'start_token': 0、'end_token': 50、 }# 推論された GPT4 モデル model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # 実際の val データを使用# 開始から生成 tokenstart_tokens = jnp.array([[123, 456]])# トレーニング済みパラメーターを忘れずにロードしてください params = Trainer.load_params('params.pkl')outputs = model.apply( {'params': params}、start_tokens,rngs={'dropout': nanodl.time_rng_key()}、method=model.generate)
ビジョンの例
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_ Depth = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# 独自の画像を使用しますdataset = ArrayDataset(images) dataloader = DataLoader(データセット、batch_size=batch_size、shuffle=True、drop_last=False) # 拡散モデルを作成しますdiffusion_model = DiffusionModel(image_size, width, block_ Depth)# データトレーナーでトレーニング = DiffusionDataParallelTrainer(diffusion_model, input_shape=画像.形状、 weights_filename='params.pkl', learning_rate=1e-4)trainer.train(dataloader, 10)# いくつかのサンプルを生成します: 各モデルは Flax.linen モジュールです# 通常どおりに使用しますparams = Trainer.load_params('params.pkl')generated_images = diffusion_model.apply( {'params': パラメータ}, num_images=5、 拡散ステップ=5、 メソッド=diffusion_model.generate)
オーディオの例
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# ダミー データ パラメーターbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # データの生成: 実際のトークン化/量子化データに置き換えるdummy_targets = jnp.ones((101, max_length)、dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader(dataset,batch_size=batch_size、shuffle=True、drop_last=False)#モデルパラメータハイパーパラメータ = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0.1,'vocab_size': 1000,'embed_dim': embed_dim,'max_length': max_length,'開始トークン': 0、'終了トークン': 50、 }#modelmodel = Whisper(**hyperparams)# datatrainer でのトレーニング = WhisperDataParallelTrainer(model, ダミー_入力.形状、 dummy_targets.shape, 'params.pkl')trainer.train(dataloader, 2, dataloader)# サンプル推論params = Trainer.load_params('params.pkl')# 複数のサンプルの場合、多くの場合、model.generate_batchtranscripts = model.apply({'params ':パラメータ}, dummy_inputs[:1]、メソッド=model.generate)
RLHF の報酬モデルの例
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer# ダミーの databatch_size = 8max_length = 10# を生成します 実際のトークン化された datadummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) を生成しますダミー_拒否 = jnp.zeros((101, max_length), dtype=jnp.int32)# データセットとデータローダーを作成しますdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset,batch_size=batch_size, shuffle=True,drop_last=False) #モデルパラメータhyperparams = {'層数': 1、'hidden_dim': 256、'num_heads': 2、'feedforward_dim': 256、'dropout': 0.1、'vocab_size': 1000、'embed_dim': 256、'max_length': max_length、'start_token': 0、 'end_token': 50、'num_groups': 2,'window_size': 5,'shift_size': 2}# Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)# から報酬モデルを初期化します# 報酬をトレーニングしますmodeltrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params =trainer.load_params('reward_model_weights.pkl')# 通常の Flax モデルと同じように呼び出しますrewards = award_model.apply({'params': params}, dummy_chosen、rngs={'ドロップアウト': nanodl.time_rng_key()})
PCAの例
import nanodlfrom nanodl import PCA# 実際のデータを使用data = nanodl.normal(shape=(1000, 10))# PCA モデルを初期化してトレーニングするpca = PCA(n_components=2)pca.fit(data)# PCA 変換を取得transformed_data = pca.transform( data)# 逆変換を取得original_data = pca.inverse_transform(transformed_data)# からのサンプルdistributionX_sampled = pca.sample(n_samples=1000、key=なし)
これはまだ開発中であり、うまく機能しますが、粗さが予想されるため、貢献を強くお勧めします。
デザインパターンを変更せずに変更を加えます。
必要に応じて、変更に対するテストを作成します。
pip3 install -e .
。
python3 -m unittest discover -s tests
を使用してテストを実行します。
次に、プル リクエストを送信します。
貢献はさまざまな形で行うことができます。
ドキュメントの作成。
バグの修正。
論文の実施。
高カバレッジのテストを作成します。
既存のコードの最適化。
実際の例を実験し、例セクションに提出します。
バグの報告。
報告された問題への対応。
さらに詳しく知りたい場合は、Discord サーバーに参加してください。
「NanoDL」という名前は、Nano Deep Learning の略です。モデルのサイズは爆発的に増大しているため、ゲートキーピングの専門家やリソースが限られている企業は、法外なコストをかけずに柔軟なモデルを構築することができません。ファイ モデルの成功に続き、長期的な目標は、パラメータの合計数が 1B を超えないように、パフォーマンスにおいて元のモデルと確実に競合できるようにしながら、利用可能なすべてのモデルのナノ バージョンを構築してトレーニングすることです。トレーニングされたウェイトは、このライブラリを通じて利用できるようになります。あらゆる形式のスポンサーシップや資金提供がトレーニング リソースの提供に役立ちます。こちらから GitHub 経由でスポンサーすることも、[email protected] 経由で連絡することもできます。
このリポジトリを引用するには:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }