PyTorch モデルの中間層を検査して抽出するためのライブラリ。
コードを変更せずに PyTorch モデルの中間層を検査したい場合がよくあります。これは、言語モデルのアテンション行列を取得したり、層の埋め込みを視覚化したり、中間層に損失関数を適用したりするのに役立ちます。場合によっては、モデルのサブパートを抽出して個別に実行して、デバッグまたは個別にトレーニングしたい場合があります。これらすべては、元のモデルの 1 行も変更することなく、Surgeon を使用して実行できます。
$ pip install surgeon-pytorch
PyTorch モデルを指定すると、 get_layers
使用してすべてのレイヤーを表示できます。
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
次に、 Inspect
使用して検査対象のmodel
ラップし、新しいモデルを前方呼び出しするたびに、提供されたレイヤー出力 (2 番目の戻り値) も出力します。
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
レイヤーのリストを提供できます。
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
名前付き出力を取得するための辞書を提供できます。
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
PyTorch モデルを指定すると、 get_nodes
使用してグラフのすべての中間ノードを表示できます。
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
次に、 Extract
使用して出力を抽出できます。これにより、要求された出力ノードを返す新しいモデルが作成されます。
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
新しい入力ノードを使用してモデルを抽出することもできます。
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
複数の入力と出力を提供し、それらに名前を付けることもできます。
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
入力ノードを変更するだけではグラフを切り取るのに十分ではない可能性があることに注意してください (前の入力に接続されている他の依存関係がある可能性があります)。新しいグラフのすべての入力を表示するには、 model_ext.summary
を呼び出します。これにより、すべての必要な入力と返された出力の概要が得られます。
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
Inspect
クラスは、入力として提供されたモデル全体を常に実行し、特別なフックを使用してテンソル値が流れるときにそれを記録します。このアプローチには、(1) 新しいモジュールを作成しない、(2) 動的な実行グラフ (入力に依存するfor
ループやif
ステートメントなど) が可能になるという利点があります。 Inspect
の欠点は、(1) モデルの一部のみを実行する必要がある場合、一部の計算が無駄になること、(2) nn.Module
層からの値のみを出力できるため、中間関数値が存在しないことです。
Extract
クラスは、シンボリック トレースを使用してまったく新しいモデルを構築します。このアプローチの利点は、(1) グラフを任意の場所でトリミングして、その部分のみを計算する新しいモデルを取得できること、(2) 中間関数 (レイヤーだけでなく) から値を抽出できること、(3) 変更もできることです。入力テンソル。 Extract
の欠点は、静的グラフのみが許可されることです (ほとんどのモデルには静的グラフがあることに注意してください)。