概要|クイックインストール|亜麻はどのように見えますか? |ドキュメント
2024年にリリースされたFlax NNXは、JAXのニューラルネットワークの作成、検査、デバッグ、分析を容易にするように設計された新しい簡素化されたFlax APIです。 Pythonリファレンスセマンティクスのファーストクラスのサポートを追加することにより、これを達成します。これにより、ユーザーは通常のPythonオブジェクトを使用してモデルを表現でき、参照共有と変動を可能にします。
Flax NNXは、2020年にGoogle Brainのエンジニアと研究者によってJAXチームと緊密なコラボレーションでリリースされたFlax Linen APIから発展しました。
Flax NNXの詳細については、専用のFlaxドキュメントサイトをご覧ください。必ずチェックアウトしてください:
注: Flax Linenのドキュメントには独自のサイトがあります。
Flaxチームの使命は、アルファベット内とより広範なコミュニティの両方で、成長するJax Neural Network Research Ecosystemにサービスを提供し、Jaxが輝くユースケースを探索することです。 GitHubを使用して、ほぼすべての調整と計画、および今後のデザイン変更について説明します。ディスカッション、発行、リクエストスレッドのいずれかに関するフィードバックを歓迎します。
機能リクエストを作成したり、取り組んでいる内容をお知らせしたり、問題を報告したり、Flax Githubディスカッションフォーラムで質問をしたりできます。
亜麻の改善を期待していますが、コアAPIの大幅な破損の変化は予想していません。可能であれば、Changelogエントリと非推奨警告を使用します。
あなたが私たちに直接連絡したい場合は、[email protected]にいます。
Flaxは、柔軟性のために設計されたJaxの高性能ニューラルネットワークライブラリおよびエコシステムです。例を分岐し、フレームワークに機能を追加するのではなく、トレーニングループを変更することにより、新しい形式のトレーニングを試してください。
FLAXはJAXチームと緊密なコラボレーションで開発されており、次のような研究を開始するために必要なすべてのものが付属しています。
ニューラルネットワークAPI ( flax.nnx
): Linear
、 Conv
、 BatchNorm
、 LayerNorm
、 GroupNorm
、attence( MultiHeadAttention
)、 LSTMCell
、 GRUCell
、 Dropout
を含む。
ユーティリティとパターン:複製されたトレーニング、シリアル化とチェックポイント、メトリック、デバイスでのプリフェッチ。
教育の例:MNIST、Gemma Language Model(Transformer)、Transformer LM1Bによる推論/サンプリング。
FlaxはJAXを使用しているため、CPU、GPU、TPUのJAXインストール手順をチェックしてください。
Python 3.8以降が必要です。 PypiからFlaxをインストールします。
pip install flax
Flaxの最新バージョンにアップグレードするには、以下を使用できます。
pip install --upgrade git+https://github.com/google/flax.git
いくつかの追加の依存関係( matplotlib
など)をインストールするには、一部の依存関係に含まれていないが、使用できます。
pip install " flax[all] "
Flax APIを使用して3つの例を提供します。単純な多層パーセプトロン、CNN、自動エンコーダーです。
Module
抽象化の詳細については、モジュールの抽象化の幅広いイントロであるDocsをご覧ください。ベストプラクティスの追加の具体的なデモンストレーションについては、ガイドと開発者のメモを参照してください。
MLPの例:
class MLP ( nnx . Module ):
def __init__ ( self , din : int , dmid : int , dout : int , * , rngs : nnx . Rngs ):
self . linear1 = Linear ( din , dmid , rngs = rngs )
self . dropout = nnx . Dropout ( rate = 0.1 , rngs = rngs )
self . bn = nnx . BatchNorm ( dmid , rngs = rngs )
self . linear2 = Linear ( dmid , dout , rngs = rngs )
def __call__ ( self , x : jax . Array ):
x = nnx . gelu ( self . dropout ( self . bn ( self . linear1 ( x ))))
return self . linear2 ( x )
CNNの例:
class CNN ( nnx . Module ):
def __init__ ( self , * , rngs : nnx . Rngs ):
self . conv1 = nnx . Conv ( 1 , 32 , kernel_size = ( 3 , 3 ), rngs = rngs )
self . conv2 = nnx . Conv ( 32 , 64 , kernel_size = ( 3 , 3 ), rngs = rngs )
self . avg_pool = partial ( nnx . avg_pool , window_shape = ( 2 , 2 ), strides = ( 2 , 2 ))
self . linear1 = nnx . Linear ( 3136 , 256 , rngs = rngs )
self . linear2 = nnx . Linear ( 256 , 10 , rngs = rngs )
def __call__ ( self , x ):
x = self . avg_pool ( nnx . relu ( self . conv1 ( x )))
x = self . avg_pool ( nnx . relu ( self . conv2 ( x )))
x = x . reshape ( x . shape [ 0 ], - 1 ) # flatten
x = nnx . relu ( self . linear1 ( x ))
x = self . linear2 ( x )
return x
自動エンコーダーの例:
Encoder = lambda rngs : nnx . Linear ( 2 , 10 , rngs = rngs )
Decoder = lambda rngs : nnx . Linear ( 10 , 2 , rngs = rngs )
class AutoEncoder ( nnx . Module ):
def __init__ ( self , rngs ):
self . encoder = Encoder ( rngs )
self . decoder = Decoder ( rngs )
def __call__ ( self , x ) -> jax . Array :
return self . decoder ( self . encoder ( x ))
def encode ( self , x ) -> jax . Array :
return self . encoder ( x )
このリポジトリを引用するには:
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.10.2},
year = {2024},
}
上記のbibtexエントリでは、名前はアルファベット順に、バージョン番号はFlax/version.pyからのものであり、年はプロジェクトのオープンソースリリースに対応することを目的としています。
Flaxは、Google Deepmindの専任チームが維持するオープンソースプロジェクトですが、Googleの公式製品ではありません。