・・ |
Ignite 是一个高级库,可帮助在 PyTorch 中灵活、透明地训练和评估神经网络。
点击图片查看完整代码
比纯 PyTorch 更少的代码,同时确保最大程度的控制和简单性
库方法和无程序控制反转 -在需要的地方和时间使用 ignite
用于指标、实验管理器和其他组件的可扩展 API
Ignite 是一个提供三个高级功能的库:
不再需要对纪元和迭代进行for/while
循环编码。用户实例化引擎并运行它们。
from ignite . engine import Engine , Events , create_supervised_evaluator
from ignite . metrics import Accuracy
# Setup training engine:
def train_step ( engine , batch ):
# Users can do whatever they need on a single iteration
# Eg. forward/backward pass for any number of models, optimizers, etc
# ...
trainer = Engine ( train_step )
# Setup single model evaluation engine
evaluator = create_supervised_evaluator ( model , metrics = { "accuracy" : Accuracy ()})
def validation ():
state = evaluator . run ( validation_data_loader )
# print computed metrics
print ( trainer . state . epoch , state . metrics )
# Run model's validation at the end of each epoch
trainer . add_event_handler ( Events . EPOCH_COMPLETED , validation )
# Start the training
trainer . run ( training_data_loader , max_epochs = 100 )
处理程序的最酷之处在于它们提供了无与伦比的灵活性(例如与回调相比)。处理程序可以是任何函数:例如 lambda、简单函数、类方法等。因此,我们不需要从接口继承并重写其抽象方法,这可能会不必要地增加代码及其复杂性。
trainer . add_event_handler ( Events . STARTED , lambda _ : print ( "Start training" ))
# attach handler with args, kwargs
mydata = [ 1 , 2 , 3 , 4 ]
logger = ...
def on_training_ended ( data ):
print ( f"Training is ended. mydata= { data } " )
# User can use variables from another scope
logger . info ( "Training is ended" )
trainer . add_event_handler ( Events . COMPLETED , on_training_ended , mydata )
# call any number of functions on a single event
trainer . add_event_handler ( Events . COMPLETED , lambda engine : print ( engine . state . times ))
@ trainer . on ( Events . ITERATION_COMPLETED )
def log_something ( engine ):
print ( engine . state . output )
# run the validation every 5 epochs
@ trainer . on ( Events . EPOCH_COMPLETED ( every = 5 ))
def run_validation ():
# run validation
# change some training variable once on 20th epoch
@ trainer . on ( Events . EPOCH_STARTED ( once = 20 ))
def change_training_variable ():
# ...
# Trigger handler with customly defined frequency
@ trainer . on ( Events . ITERATION_COMPLETED ( event_filter = first_x_iters ))
def log_gradients ():
# ...
事件可以堆叠在一起以启用多个调用:
@ trainer . on ( Events . COMPLETED | Events . EPOCH_COMPLETED ( every = 10 ))
def run_validation ():
# ...
与后向和优化器步骤调用相关的自定义事件:
from ignite . engine import EventEnum
class BackpropEvents ( EventEnum ):
BACKWARD_STARTED = 'backward_started'
BACKWARD_COMPLETED = 'backward_completed'
OPTIM_STEP_COMPLETED = 'optim_step_completed'
def update ( engine , batch ):
# ...
loss = criterion ( y_pred , y )
engine . fire_event ( BackpropEvents . BACKWARD_STARTED )
loss . backward ()
engine . fire_event ( BackpropEvents . BACKWARD_COMPLETED )
optimizer . step ()
engine . fire_event ( BackpropEvents . OPTIM_STEP_COMPLETED )
# ...
trainer = Engine ( update )
trainer . register_events ( * BackpropEvents )
@ trainer . on ( BackpropEvents . BACKWARD_STARTED )
def function_before_backprop ( engine ):
# ...
各种任务的指标:精确度、召回率、准确度、混淆矩阵、IoU 等,约 20 个回归指标。
用户还可以使用算术运算或火炬方法轻松地根据现有指标组成自己的指标。
precision = Precision ( average = False )
recall = Recall ( average = False )
F1_per_class = ( precision * recall * 2 / ( precision + recall ))
F1_mean = F1_per_class . mean () # torch mean method
F1_mean . attach ( engine , "F1" )
来自点:
pip install pytorch-ignite
来自康达:
conda install ignite -c pytorch
来自来源:
pip install git+https://github.com/pytorch/ignite
来自点:
pip install --pre pytorch-ignite
来自 conda (这建议安装 pytorch nightly 版本而不是稳定版本作为依赖项):
conda install ignite -c pytorch-nightly
从我们的 Docker Hub 中提取预构建的 docker 映像并使用 docker v19.03+ 运行它。
docker run --gpus all -it -v $PWD :/workspace/project --network=host --shm-size 16G pytorchignite/base:latest /bin/bash
根据
pytorchignite/base:latest
pytorchignite/apex:latest
pytorchignite/hvd-base:latest
pytorchignite/hvd-apex:latest
pytorchignite/msdp-apex:latest
想象:
pytorchignite/vision:latest
pytorchignite/hvd-vision:latest
pytorchignite/apex-vision:latest
pytorchignite/hvd-apex-vision:latest
pytorchignite/msdp-apex-vision:latest
自然语言处理:
pytorchignite/nlp:latest
pytorchignite/hvd-nlp:latest
pytorchignite/apex-nlp:latest
pytorchignite/hvd-apex-nlp:latest
pytorchignite/msdp-apex-nlp:latest
有关更多详细信息,请参见此处。
一些帮助您入门的建议:
受 torchvision/references 的启发,我们为视觉任务提供了几个可重复的基线:
特征:
使用 PyTorch-Ignite 创建训练脚本的最简单方法:
GitHub 问题:问题、错误报告、功能请求等。
讨论.PyTorch,类别“点燃”。
PyTorch-Ignite Discord Server:与社区聊天
GitHub Discussions:与库相关的一般讨论、想法、问答等。
我们创建了一个“用户反馈”表格。我们感谢任何类型的反馈,这就是我们希望看到我们的社区的方式:
谢谢你!
请参阅贡献指南以获取更多信息。
一如既往,欢迎 PR :)
请参阅“使用者”中的其他项目
如果您的项目实现了一篇论文,代表了我们的官方教程、Kaggle 竞赛代码中未涵盖的其他用例,或者只是您的代码提供了有趣的结果并使用了 Ignite。我们希望将您的项目添加到此列表中,因此请发送包含该项目简短描述的 PR。
如果您在科学出版物中使用 PyTorch-Ignite,我们将不胜感激对我们项目的引用。
@misc{pytorch-ignite,
author = {V. Fomin and J. Anmol and S. Desroziers and J. Kriss and A. Tejani},
title = {High-level library to help with training neural networks in PyTorch},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {url{https://github.com/pytorch/ignite}},
}
PyTorch-Ignite 是一个 NumFOCUS 附属项目,由 PyTorch 社区的志愿者以个人身份(而非雇主代表)运营和维护。请参阅“关于我们”页面以获取核心贡献者列表。有关使用问题和问题,请参阅此处的各个渠道。对于所有其他问题和询问,请发送电子邮件至 [email protected]。