主要特點|最新更新|願景|快速入門|參考文件|執照
EasyDeL 是一個開源框架,旨在增強和簡化機器學習模型的訓練過程,主要專注於 Jax/Flax。它為在 TPU/GPU 上大規模訓練和服務 Flax/Jax 車型提供了便捷有效的解決方案。
EasyDeL 因提供無與倫比的靈活性和透明度而脫穎而出:
開放式架構:EasyDeL 的每個元件都是開放式的,可供檢查、修改和自訂。這裡沒有黑盒子。
可破解性的核心:我們相信給您完全的控制權。無論您是想調整一個小功能還是徹底修改一個訓練循環,EasyDeL 都可以讓您做到。
自訂程式碼存取:所有自訂實作都是現成的並且有詳細的文件記錄,可讓您根據需要理解、學習和修改內部結構。
鼓勵實驗:我們積極鼓勵使用者嘗試、擴展和改進現有的程式碼庫。您的創新可能成為下一個重大功能!
社群驅動的開發:與社群分享您的自訂實作和改進,為推進機器學習研究和開發創造協作環境。
有了 EasyDeL,您就不再受嚴格框架的限制。相反,您擁有一個靈活、強大的工具包,可以滿足您的需求,無論它們有多麼獨特或專業。無論您是進行前沿研究還是建立可立即投入生產的機器學習系統,EasyDeL 都能提供不受限制的創新自由。
EasyDeL 在客製化和優化模型方面提供了無與倫比的靈活性:
分片策略:輕鬆自訂和試驗不同的分片策略,以優化多個設備的性能。
演算法客製化:修改和微調演算法以滿足您的特定需求和硬體配置。
注意力機制:從針對 GPU/TPU/CPU 最佳化的 10 多種注意力機制中進行選擇,包括:
這種程度的客製化可讓您充分利用硬體的效能,同時根據您的特定要求自訂模型行為。
EasyDeL 不斷發展以滿足機器學習社群的需求。在即將到來的更新中,我們計劃推出:
靈活性:EasyDeL 提供模組化設計,使研究人員和開發人員能夠輕鬆混合和匹配元件,嘗試不同的架構(包括 Transformers、Mamba、RWKV 等),並使模型適應特定的用例。
性能:利用 JAX 和 Flax 的強大功能,EasyDeL 提供最先進模型和訓練技術的高效能實現,並針對 TPU 和 GPU 進行了最佳化。
可擴展性:從小型實驗到大規模模型訓練,EasyDeL 提供工具和最佳化來有效擴展您的模型和工作流程。
易於使用:儘管 EasyDeL 具有強大的功能,但它仍保留了直覺的 API,使初學者和經驗豐富的從業者都可以使用它。
前沿研究:快速實施模型架構、訓練技術和最佳化方法的最新進展。
pip install easydel
import easydel as ed
ed . FlexibleAttentionModule . run_attention_benchmarks ()
EasyDeL 文件中提供了全面的文檔和範例。
以下是您最新更新的改進版本:
jax_flash_attn2
,並且預設注意現在設定為 CPU/GPU/TPU 中的jax_flash_attn2
。inference
性能8bit_cache
的支持DPO
和ORPO
培訓師均已升級。 params = model . shard_params ( params )
params = model . gather_params ( params )
do_shard_params
已從TrainingArguments
中刪除。若要對參數進行分片,您現在必須在訓練之前手動執行此操作。ApiEngine
和engine_client
SFT
、 DPO
、 ORPO
、 CLM
培訓師vInference
類別提供了一個簡化的介面,用於使用 JAX 中預先訓練的語言模型產生文字。
import easydel as ed
from transformers import AutoTokenizer
model , params = ed . AutoEasyDeLModelForCausalLM . from_pretrained (...)
tokenizer = AutoTokenizer . from_pretrained (...)
inference = ed . vInference (
model = model ,
params = params ,
tokenizer = tokenizer ,
generation_config = ed . vInferenceConfig (
temperature = model . generation_config . temperature ,
top_k = model . generation_config . top_k ,
top_p = model . generation_config . top_p ,
bos_token_id = model . generation_config . bos_token_id ,
eos_token_id = model . generation_config . eos_token_id ,
pad_token_id = model . generation_config . pad_token_id ,
streaming_chunks = 32 ,
max_new_tokens = 1024 ,
),
)
vInferenceApiServer
是一個用於生產或研究目的的 Serve API 引擎,提供穩定、高效、類似 OpenAI API 的 API。
import easydel as ed
api_inference = ed . vInferenceApiServer (
{ inference . inference_name : inference }
) # you can load multi inferences together
api_inference . fire ()
EasyDeLState
充當 EasyDeL 模型的綜合容器,包括訓練進度、模型參數和優化器資訊。
from easydel import EasyDeLState
state = EasyDeLState . from_pretrained (
pretrained_model_name_or_path = "model_name" ,
dtype = jnp . bfloat16 ,
param_dtype = jnp . bfloat16 ,
sharding_axis_dims = ( 1 , - 1 , 1 , 1 )
)
from easydel import SFTTrainer , TrainingArguments
trainer = SFTTrainer (
arguments = train_arguments ,
train_dataset = train_dataset ,
eval_dataset = eval_dataset ,
tokenizer = tokenizer ,
formatting_func = prompter ,
packing = True ,
num_of_sequences = max_length ,
)
output = trainer . train ( flax . core . FrozenDict ({ "params" : params }))
from easydel import DPOTrainer
dpo_trainer = DPOTrainer (
model_state = state ,
ref_model_state = ref_state ,
beta = 0.1 ,
train_dataset = train_dataset ,
eval_dataset = eval_dataset ,
tokenizer = tokenizer ,
arguments = arguments ,
max_length = max_length ,
max_completion_length = max_completion_length ,
max_prompt_length = max_prompt_length ,
)
output = dpo_trainer . train ()
歡迎為 EasyDeL 做出貢獻!請分叉儲存庫,進行更改,然後提交拉取請求。
EasyDeL 是在 Apache v2 許可證下發布的。有關更多詳細信息,請參閱許可證文件。
如果您對 EasyDeL 有任何疑問或意見,可以透過[email protected]與我聯繫。
在您的作品中引用 EasyDeL:
@misc { Zare Chavoshi_2023,
title = { EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models } ,
url = { https://github.com/erfanzar/EasyDeL } ,
author = { Zare Chavoshi, Erfan } ,
year = { 2023 }
}