مكتبة لفحص واستخراج الطبقات المتوسطة لنماذج PyTorch.
غالبًا ما نريد فحص الطبقات المتوسطة لنماذج PyTorch دون تعديل الكود. يمكن أن يكون هذا مفيدًا لجذب انتباه مصفوفات نماذج اللغة، أو تصور تضمينات الطبقة، أو تطبيق دالة الخسارة على الطبقات المتوسطة. في بعض الأحيان نرغب في استخراج أجزاء فرعية من النموذج وتشغيلها بشكل مستقل، إما لتصحيح الأخطاء أو لتدريبها بشكل منفصل. كل هذا يمكن القيام به مع Surgeon دون تغيير سطر واحد من النموذج الأصلي.
$ pip install surgeon-pytorch
بالنظر إلى نموذج PyTorch يمكننا عرض جميع الطبقات باستخدام get_layers
:
import torch
import torch . nn as nn
from surgeon_pytorch import Inspect , get_layers
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
x1 = self . layer1 ( x )
x2 = self . layer2 ( x1 )
y = self . layer3 ( x2 )
return y
model = SomeModel ()
print ( get_layers ( model )) # ['layer1', 'layer2', 'layer3']
بعد ذلك يمكننا تغليف model
ليتم فحصه باستخدام Inspect
وفي كل استدعاء أمامي للنموذج الجديد، سنقوم أيضًا بإخراج مخرجات الطبقة المقدمة (في قيمة الإرجاع الثانية):
model_wrapped = Inspect ( model , layer = 'layer2' )
x = torch . rand ( 1 , 5 )
y , x2 = model_wrapped ( x )
print ( x2 ) # tensor([[-0.2726, 0.0910]], grad_fn=<AddmmBackward0>)
يمكننا تقديم قائمة الطبقات:
model_wrapped = Inspect ( model , layer = [ 'layer1' , 'layer2' ])
x = torch . rand ( 1 , 5 )
y , [ x1 , x2 ] = model_wrapped ( x )
print ( x1 ) # tensor([[ 0.1739, 0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print ( x2 ) # tensor([[-0.2238, 0.0107]], grad_fn=<AddmmBackward0>)
يمكننا توفير قاموس للحصول على المخرجات المسماة:
model_wrapped = Inspect ( model , layer = { 'layer1' : 'x1' , 'layer2' : 'x2' })
x = torch . rand ( 1 , 5 )
y , layers = model_wrapped ( x )
print ( layers )
"""
{
'x1': tensor([[ 0.3707, 0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
model = Inspect (
model : nn . Module ,
layer : Union [ str , Sequence [ str ], Dict [ str , str ]],
keep_output : bool = True ,
)
بالنظر إلى نموذج PyTorch يمكننا عرض جميع العقد الوسيطة للرسم البياني باستخدام get_nodes
:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1 = nn . Linear ( 5 , 3 )
self . layer2 = nn . Linear ( 3 , 2 )
self . layer3 = nn . Linear ( 1 , 1 )
def forward ( self , x ):
x1 = torch . relu ( self . layer1 ( x ))
x2 = torch . sigmoid ( self . layer2 ( x1 ))
y = self . layer3 ( x2 ). tanh ()
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
بعد ذلك يمكننا استخراج المخرجات باستخدام Extract
، والذي سيُنشئ نموذجًا جديدًا يُرجع عقدة المخرجات المطلوبة:
model_ext = Extract ( model , node_out = 'sigmoid' )
x = torch . rand ( 1 , 5 )
sigmoid = model_ext ( x )
print ( sigmoid ) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
يمكننا أيضًا استخراج نموذج بعقد إدخال جديدة:
model_ext = Extract ( model , node_in = 'layer1' , node_out = 'sigmoid' )
layer1 = torch . rand ( 1 , 3 )
sigmoid = model_ext ( layer1 )
print ( sigmoid ) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
يمكننا أيضًا توفير مدخلات ومخرجات متعددة وتسميتها:
model_ext = Extract ( model , node_in = { 'layer1' : 'x' }, node_out = { 'sigmoid' : 'y1' , 'relu' : 'y2' })
out = model_ext ( x = torch . rand ( 1 , 3 ))
print ( out )
"""
{
'y1': tensor([[0.4437, 0.7152]], grad_fn=<SigmoidBackward0>),
'y2': tensor([[0.0555, 0.9014, 0.8297]]),
}
"""
لاحظ أن تغيير عقدة الإدخال قد لا يكون كافيًا لقص الرسم البياني (قد تكون هناك تبعيات أخرى مرتبطة بالمدخلات السابقة). لعرض جميع مدخلات الرسم البياني الجديد يمكننا استدعاء model_ext.summary
والذي سيعطينا نظرة عامة على جميع المدخلات المطلوبة والمخرجات التي تم إرجاعها:
import torch
import torch . nn as nn
from surgeon_pytorch import Extract , get_nodes
class SomeModel ( nn . Module ):
def __init__ ( self ):
super (). __init__ ()
self . layer1a = nn . Linear ( 2 , 2 )
self . layer1b = nn . Linear ( 2 , 2 )
self . layer2 = nn . Linear ( 2 , 1 )
def forward ( self , x ):
a = self . layer1a ( x )
b = self . layer1b ( x )
c = torch . add ( a , b )
y = self . layer2 ( c )
return y
model = SomeModel ()
print ( get_nodes ( model )) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']
model_ext = Extract ( model , node_in = { 'layer1a' : 'my_input' }, node_out = { 'add' : 'my_add' })
print ( model_ext . summary ) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}
out = model_ext ( x = torch . rand ( 1 , 2 ), my_input = torch . rand ( 1 , 2 ))
print ( out ) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
model = Extract (
model : nn . Module ,
node_in : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
node_out : Optional [ Union [ str , Sequence [ str ], Dict [ str , str ]]] = None ,
tracer : Optional [ Type [ Tracer ]] = None , # Tracer class used, default: torch.fx.Tracer
concrete_args : Optional [ Dict [ str , Any ]] = None , # Tracer concrete_args, default: None
keep_output : bool = None , # Set to `True` to return original outputs as first argument, default: True except if node_out are provided
share_modules : bool = False , # Set to true if you want to share module weights with original model
)
تقوم فئة Inspect
دائمًا بتنفيذ النموذج بأكمله المقدم كمدخل، وتستخدم خطافات خاصة لتسجيل قيم الموتر أثناء تدفقها. يتميز هذا الأسلوب بأنه (1) لا نقوم بإنشاء وحدة نمطية جديدة (2) فهو يسمح برسم بياني للتنفيذ الديناميكي (أي for
وعبارات if
التي تعتمد على المدخلات). تتمثل سلبيات Inspect
في أنه (1) إذا كنا نحتاج فقط إلى تنفيذ جزء من النموذج، فسيتم إهدار بعض العمليات الحسابية، و(2) يمكننا فقط إخراج القيم من طبقات nn.Module
- بدون قيم دالة وسيطة.
تقوم فئة Extract
ببناء نموذج جديد تمامًا باستخدام التتبع الرمزي. مزايا هذا النهج هي (1) يمكننا قص الرسم البياني في أي مكان والحصول على نموذج جديد يحسب ذلك الجزء فقط، (2) يمكننا استخراج القيم من الوظائف الوسيطة (وليس الطبقات فقط)، و (3) يمكننا أيضًا تغيير موترات الإدخال الجانب السلبي Extract
هو أنه يُسمح فقط بالرسوم البيانية الثابتة (لاحظ أن معظم النماذج تحتوي على رسوم بيانية ثابتة).