Pustaka untuk memeriksa dan mengekstrak lapisan perantara model PyTorch.
Seringkali kita ingin memeriksa lapisan perantara model PyTorch tanpa mengubah kodenya. Hal ini berguna untuk mendapatkan matriks perhatian model bahasa, memvisualisasikan penyematan lapisan, atau menerapkan fungsi kerugian pada lapisan perantara. Terkadang kita ingin mengekstrak subbagian model dan menjalankannya secara independen, baik untuk melakukan debug atau melatihnya secara terpisah. Semua ini dapat dilakukan dengan Surgeon tanpa mengubah satu baris pun model aslinya.
$ pip install surgeon-pytorch
Dengan model PyTorch kita dapat menampilkan semua lapisan menggunakan get_layers
:
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
Kemudian kita dapat membungkus model
kita untuk diperiksa menggunakan Inspect
dan dalam setiap panggilan penerusan model baru kita juga akan menampilkan keluaran lapisan yang disediakan (dalam nilai kembalian kedua):
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
Kami dapat memberikan daftar lapisan:
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
Kami dapat menyediakan kamus untuk mendapatkan keluaran bernama:
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
Dengan adanya model PyTorch, kita dapat menampilkan semua node perantara pada grafik menggunakan get_nodes
:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
Kemudian kita dapat mengekstrak keluaran menggunakan Extract
, yang akan membuat model baru yang mengembalikan simpul keluaran yang diminta:
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
Kita juga dapat mengekstrak model dengan node masukan baru:
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
Kami juga dapat memberikan beberapa masukan dan keluaran dan menamainya:
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
Perhatikan bahwa mengubah node masukan mungkin tidak cukup untuk memotong grafik (mungkin ada dependensi lain yang terhubung ke masukan sebelumnya). Untuk melihat semua masukan pada grafik baru, kita dapat memanggil model_ext.summary
yang akan memberi kita gambaran umum tentang semua masukan yang diperlukan dan keluaran yang dikembalikan:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
Kelas Inspect
selalu mengeksekusi seluruh model yang diberikan sebagai masukan, dan menggunakan kait khusus untuk mencatat nilai tensor saat nilai tersebut mengalir. Pendekatan ini memiliki kelebihan yaitu (1) kita tidak membuat modul baru (2) memungkinkan grafik eksekusi dinamis (yaitu for
loop dan pernyataan if
yang bergantung pada input). Kelemahan dari Inspect
adalah (1) jika kita hanya perlu mengeksekusi sebagian model, beberapa komputasi akan sia-sia, dan (2) kita hanya dapat mengeluarkan nilai dari lapisan nn.Module
– tidak ada nilai fungsi perantara.
Kelas Extract
membangun model yang sepenuhnya baru menggunakan penelusuran simbolik. Keuntungan dari pendekatan ini adalah (1) kita dapat memotong grafik di mana saja dan mendapatkan model baru yang hanya menghitung bagian tersebut, (2) kita dapat mengekstrak nilai dari fungsi perantara (tidak hanya lapisan), dan (3) kita juga dapat mengubah tensor masukan. Kelemahan dari Extract
adalah hanya grafik statis yang diperbolehkan (perhatikan bahwa sebagian besar model memiliki grafik statis).