Una biblioteca para inspeccionar y extraer capas intermedias de modelos PyTorch.
A menudo ocurre que queremos inspeccionar capas intermedias de modelos de PyTorch sin modificar el código. Esto puede resultar útil para llamar la atención sobre matrices de modelos de lenguaje, visualizar incrustaciones de capas o aplicar una función de pérdida a capas intermedias. A veces queremos extraer subpartes del modelo y ejecutarlas de forma independiente, ya sea para depurarlas o entrenarlas por separado. Todo esto se puede hacer con Surgeon sin cambiar una línea del modelo original.
$ pip install surgeon-pytorch
Dado un modelo de PyTorch, podemos mostrar todas las capas usando get_layers
:
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
Luego podemos ajustar nuestro model
para inspeccionarlo usando Inspect
y en cada llamada directa al nuevo modelo también generaremos las salidas de capa proporcionadas (en el segundo valor de retorno):
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
Podemos proporcionar una lista de capas:
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
Podemos proporcionar un diccionario para obtener resultados con nombre:
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
Dado un modelo de PyTorch, podemos mostrar todos los nodos intermedios del gráfico usando get_nodes
:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
Luego podemos extraer resultados usando Extract
, lo que creará un nuevo modelo que devuelve el nodo de salida solicitado:
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
También podemos extraer un modelo con nuevos nodos de entrada:
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
También podemos proporcionar múltiples entradas y salidas y nombrarlas:
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
Tenga en cuenta que cambiar un nodo de entrada puede no ser suficiente para cortar el gráfico (puede haber otras dependencias conectadas a entradas anteriores). Para ver todas las entradas del nuevo gráfico, podemos llamar a model_ext.summary
, que nos brindará una descripción general de todas las entradas requeridas y las salidas devueltas:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
La clase Inspect
siempre ejecuta todo el modelo proporcionado como entrada y utiliza ganchos especiales para registrar los valores del tensor a medida que fluyen. Este enfoque tiene las ventajas de que (1) no creamos un nuevo módulo (2) permite un gráfico de ejecución dinámica (es decir, bucles for
y declaraciones if
que dependen de las entradas). Las desventajas de Inspect
son que (1) si solo necesitamos ejecutar parte del modelo, se desperdicia algo de cálculo y (2) solo podemos generar valores de nn.Module
capas, sin valores de función intermedia.
La clase Extract
crea un modelo completamente nuevo mediante el seguimiento simbólico. Las ventajas de este enfoque son (1) podemos recortar el gráfico en cualquier lugar y obtener un nuevo modelo que calcule solo esa parte, (2) podemos extraer valores de funciones intermedias (no solo capas) y (3) también podemos cambiar tensores de entrada. La desventaja de Extract
es que solo se permiten gráficos estáticos (tenga en cuenta que la mayoría de los modelos tienen gráficos estáticos).