Uma biblioteca para inspecionar e extrair camadas intermediárias de modelos PyTorch.
Muitas vezes queremos inspecionar camadas intermediárias de modelos PyTorch sem modificar o código. Isso pode ser útil para obter matrizes de atenção de modelos de linguagem, visualizar incorporações de camadas ou aplicar uma função de perda a camadas intermediárias. Às vezes queremos extrair subpartes do modelo e executá-las de forma independente, seja para depurá-las ou para treiná-las separadamente. Tudo isso pode ser feito com o Surgeon sem alterar uma linha do modelo original.
$ pip install surgeon-pytorch
Dado um modelo PyTorch, podemos exibir todas as camadas usando get_layers
:
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
Então podemos agrupar nosso model
para ser inspecionado usando Inspect
e em cada chamada direta do novo modelo também produziremos as saídas da camada fornecidas (no segundo valor de retorno):
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
Podemos fornecer uma lista de camadas:
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
Podemos fornecer um dicionário para obter saídas nomeadas:
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
Dado um modelo PyTorch, podemos exibir todos os nós intermediários do gráfico usando get_nodes
:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
Então podemos extrair as saídas usando Extract
, que criará um novo modelo que retorna o nó de saída solicitado:
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
Também podemos extrair um modelo com novos nós de entrada:
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
Também podemos fornecer várias entradas e saídas e nomeá-las:
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
Observe que alterar um nó de entrada pode não ser suficiente para cortar o gráfico (pode haver outras dependências conectadas às entradas anteriores). Para visualizar todas as entradas do novo gráfico, podemos chamar model_ext.summary
, que nos dará uma visão geral de todas as entradas necessárias e saídas retornadas:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
A classe Inspect
sempre executa todo o modelo fornecido como entrada e usa ganchos especiais para registrar os valores do tensor à medida que eles fluem. Esta abordagem tem as vantagens de (1) não criarmos um novo módulo (2) ela permite um gráfico de execução dinâmico (ou seja, loops for
e instruções if
que dependem de entradas). As desvantagens do Inspect
são que (1) se precisarmos executar apenas parte do modelo, algum cálculo será desperdiçado e (2) só poderemos gerar valores de camadas nn.Module
– sem valores de função intermediários.
A classe Extract
cria um modelo totalmente novo usando rastreamento simbólico. As vantagens desta abordagem são (1) podemos cortar o gráfico em qualquer lugar e obter um novo modelo que calcula apenas aquela parte, (2) podemos extrair valores de funções intermediárias (não apenas camadas) e (3) também podemos alterar tensores de entrada. A desvantagem do Extract
é que apenas gráficos estáticos são permitidos (observe que a maioria dos modelos possui gráficos estáticos).