主要特点|最新更新|愿景|快速入门|参考文档|执照
EasyDeL 是一个开源框架,旨在增强和简化机器学习模型的训练过程,主要关注 Jax/Flax。它为在 TPU/GPU 上大规模训练和服务 Flax/Jax 模型提供了便捷有效的解决方案。
EasyDeL 因提供无与伦比的灵活性和透明度而脱颖而出:
开放式架构:EasyDeL 的每个组件都是开放的,可供检查、修改和定制。这里没有黑匣子。
可破解性的核心:我们相信给您完全的控制权。无论您是想调整一个小功能还是彻底修改一个训练循环,EasyDeL 都可以让您做到。
自定义代码访问:所有自定义实现都是现成的并且有详细的文档记录,允许您根据需要理解、学习和修改内部结构。
鼓励实验:我们积极鼓励用户尝试、扩展和改进现有的代码库。您的创新可能成为下一个重大功能!
社区驱动的开发:与社区分享您的自定义实现和改进,为推进机器学习研究和开发营造协作环境。
有了 EasyDeL,您就不再受严格框架的限制。相反,您拥有一个灵活、强大的工具包,可以满足您的需求,无论它们有多么独特或专业。无论您是进行前沿研究还是构建可立即投入生产的机器学习系统,EasyDeL 都能提供不受限制的创新自由。
EasyDeL 在定制和优化模型方面提供了无与伦比的灵活性:
分片策略:轻松定制和试验不同的分片策略,以优化多个设备的性能。
算法定制:修改和微调算法以满足您的特定需求和硬件配置。
注意力机制:从针对 GPU/TPU/CPU 优化的 10 多种注意力机制中进行选择,包括:
这种级别的定制允许您充分利用硬件的性能,同时根据您的具体要求定制模型行为。
EasyDeL 不断发展以满足机器学习社区的需求。在即将到来的更新中,我们计划推出:
灵活性:EasyDeL 提供模块化设计,使研究人员和开发人员能够轻松混合和匹配组件,尝试不同的架构(包括 Transformers、Mamba、RWKV 等),并使模型适应特定的用例。
性能:利用 JAX 和 Flax 的强大功能,EasyDeL 提供最先进模型和训练技术的高性能实现,并针对 TPU 和 GPU 进行了优化。
可扩展性:从小型实验到大规模模型训练,EasyDeL 提供工具和优化来有效扩展您的模型和工作流程。
易于使用:尽管 EasyDeL 具有强大的功能,但它仍保留了直观的 API,使初学者和经验丰富的从业者都可以使用它。
前沿研究:快速实施模型架构、训练技术和优化方法方面的最新进展。
pip install easydel
import easydel as ed
ed . FlexibleAttentionModule . run_attention_benchmarks ()
EasyDeL 文档中提供了全面的文档和示例。
以下是您最新更新的改进版本:
jax_flash_attn2
,并且默认注意现在设置为 CPU/GPU/TPU 中的jax_flash_attn2
。inference
性能8bit_cache
的支持DPO
和ORPO
培训师均已升级。 params = model . shard_params ( params )
params = model . gather_params ( params )
do_shard_params
已从TrainingArguments
中删除。要对参数进行分片,您现在必须在训练之前手动执行此操作。ApiEngine
和engine_client
SFT
、 DPO
、 ORPO
、 CLM
培训师vInference
类提供了一个简化的界面,用于使用 JAX 中预先训练的语言模型生成文本。
import easydel as ed
from transformers import AutoTokenizer
model , params = ed . AutoEasyDeLModelForCausalLM . from_pretrained (...)
tokenizer = AutoTokenizer . from_pretrained (...)
inference = ed . vInference (
model = model ,
params = params ,
tokenizer = tokenizer ,
generation_config = ed . vInferenceConfig (
temperature = model . generation_config . temperature ,
top_k = model . generation_config . top_k ,
top_p = model . generation_config . top_p ,
bos_token_id = model . generation_config . bos_token_id ,
eos_token_id = model . generation_config . eos_token_id ,
pad_token_id = model . generation_config . pad_token_id ,
streaming_chunks = 32 ,
max_new_tokens = 1024 ,
),
)
vInferenceApiServer
是一个用于生产或研究目的的 Serve API 引擎,提供稳定、高效、类似 OpenAI API 的 API。
import easydel as ed
api_inference = ed . vInferenceApiServer (
{ inference . inference_name : inference }
) # you can load multi inferences together
api_inference . fire ()
EasyDeLState
充当 EasyDeL 模型的综合容器,包括训练进度、模型参数和优化器信息。
from easydel import EasyDeLState
state = EasyDeLState . from_pretrained (
pretrained_model_name_or_path = "model_name" ,
dtype = jnp . bfloat16 ,
param_dtype = jnp . bfloat16 ,
sharding_axis_dims = ( 1 , - 1 , 1 , 1 )
)
from easydel import SFTTrainer , TrainingArguments
trainer = SFTTrainer (
arguments = train_arguments ,
train_dataset = train_dataset ,
eval_dataset = eval_dataset ,
tokenizer = tokenizer ,
formatting_func = prompter ,
packing = True ,
num_of_sequences = max_length ,
)
output = trainer . train ( flax . core . FrozenDict ({ "params" : params }))
from easydel import DPOTrainer
dpo_trainer = DPOTrainer (
model_state = state ,
ref_model_state = ref_state ,
beta = 0.1 ,
train_dataset = train_dataset ,
eval_dataset = eval_dataset ,
tokenizer = tokenizer ,
arguments = arguments ,
max_length = max_length ,
max_completion_length = max_completion_length ,
max_prompt_length = max_prompt_length ,
)
output = dpo_trainer . train ()
欢迎为 EasyDeL 做出贡献!请分叉存储库,进行更改,然后提交拉取请求。
EasyDeL 是在 Apache v2 许可证下发布的。有关更多详细信息,请参阅许可证文件。
如果您对 EasyDeL 有任何疑问或意见,可以通过[email protected]与我联系。
在您的工作中引用 EasyDeL:
@misc { Zare Chavoshi_2023,
title = { EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models } ,
url = { https://github.com/erfanzar/EasyDeL } ,
author = { Zare Chavoshi, Erfan } ,
year = { 2023 }
}