ไลบรารีสำหรับตรวจสอบและแยกเลเยอร์กลางของโมเดล PyTorch
บ่อยครั้งเป็นกรณีที่เราต้องการ ตรวจสอบ เลเยอร์กลางของโมเดล PyTorch โดยไม่ต้องแก้ไขโค้ด สิ่งนี้มีประโยชน์ในการรับเมทริกซ์ความสนใจของโมเดลภาษา แสดงภาพการฝังเลเยอร์ หรือใช้ฟังก์ชันการสูญเสียกับเลเยอร์ระดับกลาง บางครั้งเราต้องการ แยก ส่วนย่อยของโมเดลและรันแยกกัน ไม่ว่าจะเพื่อดีบักหรือฝึกแยกกัน ทั้งหมดนี้สามารถทำได้ด้วย Surgeon โดยไม่ต้องเปลี่ยนโมเดลดั้งเดิมแม้แต่บรรทัดเดียว
$ pip install surgeon-pytorch
ด้วยโมเดล PyTorch เราสามารถแสดงเลเยอร์ทั้งหมดโดยใช้ get_layers
:
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
จากนั้นเราสามารถรวม model
ของเราเพื่อรับการตรวจสอบโดยใช้ Inspect
และในการส่งต่อโมเดลใหม่ทุกครั้ง เราจะส่งออกเอาต์พุตเลเยอร์ที่ให้มาด้วย (ในค่าส่งคืนที่สอง):
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
เราสามารถจัดเตรียมรายการเลเยอร์ได้:
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
เราสามารถจัดเตรียมพจนานุกรมเพื่อรับผลลัพธ์ที่มีชื่อได้:
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
ด้วยโมเดล PyTorch เราสามารถแสดงโหนดกลางทั้งหมดของกราฟได้โดยใช้ get_nodes
:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
จากนั้นเราสามารถแยกเอาต์พุตโดยใช้ Extract
ซึ่งจะสร้างโมเดลใหม่ที่ส่งคืนโหนดเอาต์พุตที่ร้องขอ:
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
นอกจากนี้เรายังสามารถแยกโมเดลที่มีโหนดอินพุตใหม่ได้:
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
นอกจากนี้เรายังสามารถจัดเตรียมอินพุตและเอาต์พุตหลายรายการและตั้งชื่อได้:
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
โปรดทราบว่าการเปลี่ยนโหนดอินพุตอาจไม่เพียงพอที่จะตัดกราฟ (อาจมีการขึ้นต่อกันอื่นๆ ที่เชื่อมต่อกับอินพุตก่อนหน้า) หากต้องการดูอินพุตทั้งหมดของกราฟใหม่ เราสามารถเรียก model_ext.summary
ซึ่งจะให้ภาพรวมของอินพุตที่จำเป็นและเอาต์พุตที่ส่งคืนทั้งหมด:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
คลาส Inspect
จะดำเนินการกับโมเดลทั้งหมดที่มีให้เป็นอินพุตเสมอ และใช้ hooks พิเศษเพื่อบันทึกค่าเทนเซอร์ขณะที่พวกมันไหลผ่าน วิธีการนี้มีข้อดีตรงที่ (1) เราไม่ได้สร้างโมดูลใหม่ (2) อนุญาตให้มีกราฟการดำเนินการแบบไดนามิก (เช่น for
ลูปและคำสั่ง if
ที่ขึ้นอยู่กับอินพุต) ข้อเสียของ Inspect
คือ (1) หากเราต้องการดำเนินการเพียงบางส่วนของโมเดล การคำนวณบางอย่างจะสูญเปล่า และ (2) เราจะส่งออกเฉพาะค่าจากเลเยอร์ nn.Module
เท่านั้น ไม่มีค่าฟังก์ชันระดับกลาง
คลาส Extract
สร้างโมเดลใหม่ทั้งหมดโดยใช้การติดตามเชิงสัญลักษณ์ ข้อดีของแนวทางนี้คือ (1) เราสามารถครอบตัดกราฟได้ทุกที่และรับโมเดลใหม่ที่คำนวณเฉพาะส่วนนั้น (2) เราสามารถแยกค่าจากฟังก์ชันระดับกลาง (ไม่ใช่แค่เลเยอร์) และ (3) เรายังสามารถเปลี่ยนแปลงได้ เทนเซอร์อินพุต ข้อเสียของ Extract
คืออนุญาตให้ใช้เฉพาะกราฟคงที่เท่านั้น (โปรดทราบว่าโมเดลส่วนใหญ่มีกราฟคงที่)