・・ |
Ignite 是一個進階庫,可協助在 PyTorch 中靈活、透明地訓練和評估神經網路。
點擊圖片看完整程式碼
比純 PyTorch 更少的程式碼,同時確保最大程度的控制和簡單性
庫方法和無程序控制反轉 -在需要的地方和時間使用 ignite
用於指標、實驗管理器和其他元件的可擴展 API
Ignite 是一個提供三個進階功能的函式庫:
不再需要對紀元和迭代進行for/while
循環編碼。使用者實例化引擎並運行它們。
from ignite . engine import Engine , Events , create_supervised_evaluator
from ignite . metrics import Accuracy
# Setup training engine:
def train_step ( engine , batch ):
# Users can do whatever they need on a single iteration
# Eg. forward/backward pass for any number of models, optimizers, etc
# ...
trainer = Engine ( train_step )
# Setup single model evaluation engine
evaluator = create_supervised_evaluator ( model , metrics = { "accuracy" : Accuracy ()})
def validation ():
state = evaluator . run ( validation_data_loader )
# print computed metrics
print ( trainer . state . epoch , state . metrics )
# Run model's validation at the end of each epoch
trainer . add_event_handler ( Events . EPOCH_COMPLETED , validation )
# Start the training
trainer . run ( training_data_loader , max_epochs = 100 )
處理程序的最酷之處在於它們提供了無與倫比的靈活性(例如與回呼相比)。處理程序可以是任何函數:例如 lambda、簡單函數、類別方法等。
trainer . add_event_handler ( Events . STARTED , lambda _ : print ( "Start training" ))
# attach handler with args, kwargs
mydata = [ 1 , 2 , 3 , 4 ]
logger = ...
def on_training_ended ( data ):
print ( f"Training is ended. mydata= { data } " )
# User can use variables from another scope
logger . info ( "Training is ended" )
trainer . add_event_handler ( Events . COMPLETED , on_training_ended , mydata )
# call any number of functions on a single event
trainer . add_event_handler ( Events . COMPLETED , lambda engine : print ( engine . state . times ))
@ trainer . on ( Events . ITERATION_COMPLETED )
def log_something ( engine ):
print ( engine . state . output )
# run the validation every 5 epochs
@ trainer . on ( Events . EPOCH_COMPLETED ( every = 5 ))
def run_validation ():
# run validation
# change some training variable once on 20th epoch
@ trainer . on ( Events . EPOCH_STARTED ( once = 20 ))
def change_training_variable ():
# ...
# Trigger handler with customly defined frequency
@ trainer . on ( Events . ITERATION_COMPLETED ( event_filter = first_x_iters ))
def log_gradients ():
# ...
事件可以堆疊在一起以啟用多個呼叫:
@ trainer . on ( Events . COMPLETED | Events . EPOCH_COMPLETED ( every = 10 ))
def run_validation ():
# ...
與後向和優化器步驟呼叫相關的自訂事件:
from ignite . engine import EventEnum
class BackpropEvents ( EventEnum ):
BACKWARD_STARTED = 'backward_started'
BACKWARD_COMPLETED = 'backward_completed'
OPTIM_STEP_COMPLETED = 'optim_step_completed'
def update ( engine , batch ):
# ...
loss = criterion ( y_pred , y )
engine . fire_event ( BackpropEvents . BACKWARD_STARTED )
loss . backward ()
engine . fire_event ( BackpropEvents . BACKWARD_COMPLETED )
optimizer . step ()
engine . fire_event ( BackpropEvents . OPTIM_STEP_COMPLETED )
# ...
trainer = Engine ( update )
trainer . register_events ( * BackpropEvents )
@ trainer . on ( BackpropEvents . BACKWARD_STARTED )
def function_before_backprop ( engine ):
# ...
各種任務的指標:精確度、回想率、準確度、混淆矩陣、IoU 等,約 20 個迴歸指標。
使用者還可以使用算術運算或火炬方法輕鬆地根據現有指標組成自己的指標。
precision = Precision ( average = False )
recall = Recall ( average = False )
F1_per_class = ( precision * recall * 2 / ( precision + recall ))
F1_mean = F1_per_class . mean () # torch mean method
F1_mean . attach ( engine , "F1" )
來自點:
pip install pytorch-ignite
來自康達:
conda install ignite -c pytorch
來自來源:
pip install git+https://github.com/pytorch/ignite
來自點:
pip install --pre pytorch-ignite
來自 conda (這建議安裝 pytorch nightly 版本而不是穩定版本作為依賴項):
conda install ignite -c pytorch-nightly
從我們的 Docker Hub 中提取預先建置的 docker 映像並使用 docker v19.03+ 運行它。
docker run --gpus all -it -v $PWD :/workspace/project --network=host --shm-size 16G pytorchignite/base:latest /bin/bash
根據
pytorchignite/base:latest
pytorchignite/apex:latest
pytorchignite/hvd-base:latest
pytorchignite/hvd-apex:latest
pytorchignite/msdp-apex:latest
想像:
pytorchignite/vision:latest
pytorchignite/hvd-vision:latest
pytorchignite/apex-vision:latest
pytorchignite/hvd-apex-vision:latest
pytorchignite/msdp-apex-vision:latest
自然語言處理:
pytorchignite/nlp:latest
pytorchignite/hvd-nlp:latest
pytorchignite/apex-nlp:latest
pytorchignite/hvd-apex-nlp:latest
pytorchignite/msdp-apex-nlp:latest
有關更多詳細信息,請參見此處。
一些幫助您入門的建議:
受 torchvision/references 的啟發,我們為視覺任務提供了幾個可重複的基準:
特徵:
使用 PyTorch-Ignite 創建訓練腳本的最簡單方法:
GitHub 問題:問題、錯誤回報、功能請求等。
討論.PyTorch,類別「點燃」。
PyTorch-Ignite Discord Server:與社群聊天
GitHub Discussions:與庫相關的一般討論、想法、問答等。
我們創建了一個「使用者回饋」表格。我們感謝任何類型的回饋,這就是我們希望看到我們的社群的方式:
謝謝你!
請參閱貢獻指南以獲取更多資訊。
一如既往,歡迎 PR :)
請參閱「使用者」中的其他項目
如果您的專案實作了一篇論文,代表了我們的官方教學、Kaggle 競賽程式碼中未涵蓋的其他用例,或者只是您的程式碼提供了有趣的結果並使用了 Ignite。我們希望將您的項目新增到此清單中,因此請發送包含該項目簡短描述的 PR。
如果您在科學出版物中使用 PyTorch-Ignite,我們將不勝感激我們項目的引用。
@misc{pytorch-ignite,
author = {V. Fomin and J. Anmol and S. Desroziers and J. Kriss and A. Tejani},
title = {High-level library to help with training neural networks in PyTorch},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {url{https://github.com/pytorch/ignite}},
}
PyTorch-Ignite 是一個 NumFOCUS 附屬項目,由 PyTorch 社區的志工以個人身分(而非雇主代表)運作和維護。請參閱「關於我們」頁面以取得核心貢獻者清單。有關使用問題和問題,請參閱此處的各個管道。對於所有其他問題和詢問,請發送電子郵件至 [email protected]。