개요 | 빠른 설치 | Flax는 어떻게 생겼습니까? | 선적 서류 비치
2024 년에 출시 된 Flax NNX는 JAX에서 신경망을보다 쉽게 만들고 검사하고 디버깅하고 분석 할 수 있도록 설계된 새로운 단순화 된 Flax API입니다. Python Reference Semantics에 대한 일등석 지원을 추가하여이를 달성합니다. 이를 통해 사용자는 일반 파이썬 객체를 사용하여 모델을 표현하여 참조 공유 및 돌연변이를 가능하게합니다.
Flax NNX는 2020 년 Google Brain의 엔지니어 및 연구원이 JAX 팀과 긴밀한 협력으로 출시 한 Flax Linen API에서 발전했습니다.
전용 Flax 문서 사이트에서 Flax NNX에 대해 자세히 알아볼 수 있습니다. 체크 아웃을 확인하십시오.
참고 : 아마 리넨의 문서에는 자체 사이트가 있습니다.
Flax 팀의 사명은 알파벳 내 및 더 넓은 커뮤니티 모두에서 성장하는 JAX 신경망 연구 생태계에 서비스를 제공하고 JAX가 빛나는 사용 사례를 탐색하는 것입니다. 우리는 거의 모든 조정 및 계획뿐만 아니라 다가오는 디자인 변경에 대해 논의하는 곳에 Github를 사용합니다. 우리는 토론, 문제 및 요청 스레드에 대한 피드백을 환영합니다.
기능 요청을하고, 작업중인 내용을 알려주고, 문제를보고하고, Flax Github 토론 포럼에서 질문을 할 수 있습니다.
우리는 아마를 향상시킬 것으로 예상하지만, 코어 API의 상당한 파괴 변화를 예상하지는 않습니다. 가능한 경우 변경 사항 항목 및 감가 상각 경고를 사용합니다.
직접 연락하고 싶다면 [email protected]에 있습니다.
Flax는 유연성을 위해 설계된 JAX의 고성능 신경망 라이브러리 및 생태계입니다. 프레임 워크에 기능을 추가하지 않고 예제를 포킹하고 교육 루프를 수정하여 새로운 형태의 교육을 시도하십시오.
Flax는 JAX 팀과 긴밀한 협력을 통해 개발되고 있으며 다음을 포함하여 연구를 시작하는 데 필요한 모든 것을 제공합니다.
신경망 API ( flax.nnx
) : Linear
, Conv
, BatchNorm
, LayerNorm
, GroupNorm
, 관심 ( MultiHeadAttention
), LSTMCell
, GRUCell
, Dropout
포함.
유틸리티 및 패턴 : 복제 된 교육, 직렬화 및 체크 포인팅, 메트릭, 장치의 프리 페치.
교육 사례 : Gemma Language Model (Transformer)을 통한 MNIST, 추론/샘플링, 변압기 LM1B.
Flax는 JAX를 사용하므로 CPU, GPU 및 TPU에서 JAX 설치 지침을 확인하십시오.
Python 3.8 이상이 필요합니다. PYPI에서 아마 설치 :
pip install flax
최신 버전의 Flax로 업그레이드하려면 다음을 사용할 수 있습니다.
pip install --upgrade git+https://github.com/google/flax.git
일부 의존성에 포함되어 있지 않지만 포함되지 않은 일부 추가 종속성 (예 : matplotlib
)을 설치하려면 다음을 사용할 수 있습니다.
pip install " flax[all] "
우리는 아마 API를 사용하여 세 가지 예를 제공합니다 : 간단한 다층 퍼셉트론, CNN 및 자동 인코더.
Module
추상화에 대한 자세한 내용은 모듈 추상화에 대한 광범위한 소개 인 문서를 확인하십시오. 모범 사례에 대한 추가적인 구체적인 시연은 가이드 및 개발자 메모를 참조하십시오.
MLP의 예 :
class MLP ( nnx . Module ):
def __init__ ( self , din : int , dmid : int , dout : int , * , rngs : nnx . Rngs ):
self . linear1 = Linear ( din , dmid , rngs = rngs )
self . dropout = nnx . Dropout ( rate = 0.1 , rngs = rngs )
self . bn = nnx . BatchNorm ( dmid , rngs = rngs )
self . linear2 = Linear ( dmid , dout , rngs = rngs )
def __call__ ( self , x : jax . Array ):
x = nnx . gelu ( self . dropout ( self . bn ( self . linear1 ( x ))))
return self . linear2 ( x )
CNN의 예 :
class CNN ( nnx . Module ):
def __init__ ( self , * , rngs : nnx . Rngs ):
self . conv1 = nnx . Conv ( 1 , 32 , kernel_size = ( 3 , 3 ), rngs = rngs )
self . conv2 = nnx . Conv ( 32 , 64 , kernel_size = ( 3 , 3 ), rngs = rngs )
self . avg_pool = partial ( nnx . avg_pool , window_shape = ( 2 , 2 ), strides = ( 2 , 2 ))
self . linear1 = nnx . Linear ( 3136 , 256 , rngs = rngs )
self . linear2 = nnx . Linear ( 256 , 10 , rngs = rngs )
def __call__ ( self , x ):
x = self . avg_pool ( nnx . relu ( self . conv1 ( x )))
x = self . avg_pool ( nnx . relu ( self . conv2 ( x )))
x = x . reshape ( x . shape [ 0 ], - 1 ) # flatten
x = nnx . relu ( self . linear1 ( x ))
x = self . linear2 ( x )
return x
Autoencoder의 예 :
Encoder = lambda rngs : nnx . Linear ( 2 , 10 , rngs = rngs )
Decoder = lambda rngs : nnx . Linear ( 10 , 2 , rngs = rngs )
class AutoEncoder ( nnx . Module ):
def __init__ ( self , rngs ):
self . encoder = Encoder ( rngs )
self . decoder = Decoder ( rngs )
def __call__ ( self , x ) -> jax . Array :
return self . decoder ( self . encoder ( x ))
def encode ( self , x ) -> jax . Array :
return self . encoder ( x )
이 저장소를 인용하려면 :
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.10.2},
year = {2024},
}
위의 Bibtex 항목에서 이름은 알파벳 순서로 표시되며, 버전 번호는 Flax/version.py에서의 것이며, 연도는 프로젝트의 오픈 소스 릴리스에 해당합니다.
Flax는 Google Deepmind의 전용 팀이 관리하는 오픈 소스 프로젝트이지만 공식적인 Google 제품은 아닙니다.