用於檢查和提取 PyTorch 模型中間層的庫。
通常情況下,我們希望在不修改程式碼的情況下檢查PyTorch 模型的中間層。這對於獲取語言模型的注意力矩陣、視覺化層嵌入或將損失函數應用於中間層非常有用。有時我們想要提取模型的子部分並獨立運行它們,以調試它們或單獨訓練它們。所有這些都可以透過 Surgeon 完成,而無需更改原始模型的任何一條線。
$ pip install surgeon-pytorch
給定 PyTorch 模型,我們可以使用get_layers
顯示所有層:
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
然後我們可以使用Inspect
包裝要檢查的model
,並且在每次前向呼叫新模型時,我們還將輸出提供的層輸出(在第二個返回值中):
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
我們可以提供圖層列表:
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
我們可以提供一個字典來取得命名輸出:
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
給定 PyTorch 模型,我們可以使用get_nodes
顯示圖形的所有中間節點:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
然後我們可以使用Extract
提取輸出,這將創建一個返回請求的輸出節點的新模型:
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
我們也可以使用新的輸入節點來提取模型:
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
我們還可以提供多個輸入和輸出並命名它們:
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
請注意,更改輸入節點可能不足以切割圖形(可能還有其他依賴項連接到先前的輸入)。要查看新圖的所有輸入,我們可以呼叫model_ext.summary
,它將為我們提供所有所需輸入和返回輸出的概述:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
Inspect
類別始終執行作為輸入提供的整個模型,並使用特殊的鉤子來記錄流過的張量值。這種方法的優點是 (1) 我們不建立新模組 (2) 它允許動態執行圖(即依賴輸入的for
迴圈和if
語句)。 Inspect
的缺點是(1)如果我們只需要執行模型的一部分,則會浪費一些計算,(2)我們只能從nn.Module
層輸出值 - 沒有中間函數值。
Extract
類別使用符號追蹤建立一個全新的模型。這種方法的優點是(1)我們可以在任何地方裁剪圖形並獲得僅計算該部分的新模型,(2)我們可以從中間函數(不僅僅是層)中提取值,(3)我們還可以更改輸入張量。 Extract
的缺點是只允許靜態圖(請注意,大多數模型都有靜態圖)。