•介紹
•安裝
•層
•數學功能
• Pytorch後衛
•測試
Attorch是Pytorch的nn
模塊的子集,使用Openai的Triton純粹用Python編寫。它的目標是成為一個易於黑客入侵,獨立且可讀性的神經網絡模塊集合,同時維持或提高Pytorch的效率。換句話說,它打算成為一個具有簡單,直觀的設計的可分配項目,對於那些尋求開發自定義深度學習操作但對純Pytorch實施的速度並且沒有技術專業知識或資源不滿意的人來說,它可以作為可訪問的起點。
由Triton提供動力的類似Pytorch的框架已經存在,包括Kernl,Xformers,Unspher和fla
,但大多數主要集中於變形金剛和NLP應用程序上,而Attorch的目標是通過與NLP之外的各種區域相關的區域來更具包容性。此外,Attorch不是僅推理的軟件包,並且可以完全支持向前和向後的通行證,這意味著它可以在訓練和推理過程中使用,儘管後者的性能通常與專用推理引擎不相同。
Attorch的唯一依賴性是torch==2.4.0
和triton==3.0.0
。請安裝這兩個庫的指定版本,然後克隆此存儲庫以開始。
當前具有自動混合精度(AMP)支持的實施圖層是,
attorch.Conv1d
:使用權重的輸入上的1d-convolves,可選地添加偏差。attorch.Conv2d
:使用權重的輸入上的2D-Convolves,可選地添加偏差。attorch.MultiheadAttention
:對輸入應用多頭縮放的點產生關注。attorch.Hardsigmoid
:將硬sigmoid應用於輸入,可選地融合掉落。attorch.Hardswish
:將艱苦的旋轉應用於輸入,可選地融合輟學。attorch.LeakyReLU
:將洩漏的依賴應用於輸入,可選地融合掉落。attorch.GELU
:將gelu應用於輸入,可選地融合掉落。attorch.ReLU
:適用於輸入,可選地融合輟學。attorch.ReLU6
:將relu6應用於輸入,可選地將掉落。attorch.SELU
:將SELU應用於輸入,可選地融合輟學。attorch.SiLU
:將SILU應用於輸入,可選地將掉落掉。attorch.Mish
:將Mish應用於輸入,可選地將掉落。attorch.Sigmoid
:將sigmoid應用於輸入,可選地融合掉落。attorch.Tanh
:將tanh應用於輸入,可選地融合輟學。attorch.GLU
:將帶有任意激活函數的封閉線性單元應用於輸入。attorch.LogSoftmax
:使用SoftMax將輸入歸一化,並進行日誌。attorch.Softmax
:使用SoftMax對輸入進行歸一化。attorch.Softmin
:使用軟敏化將輸入歸一化。attorch.BatchNorm1d
:批量歸一量將2D或3D輸入歸功於構層融合,從而融合激活函數並將殘留物添加到前激活結果中。attorch.BatchNorm2d
:批量歸一量4D輸入,可選地融合激活函數並將殘餘添加到前激活結果中。attorch.LayerNorm
:層歸一化輸入。attorch.RMSNorm
:根平方均衡輸入。attorch.Linear
:線性地使用權重轉換輸入,可選地添加偏差並融合激活函數。attorch.Dropout
:訓練期間輸入中的隨機零元素。attorch.L1Loss
:測量輸入和目標之間的平均絕對誤差。attorch.MSELoss
:測量輸入和目標之間的平方平方誤差。attorch.CrossEntropyLoss
:測量輸入和目標之間的平均交叉熵損失,每個類可選。attorch.NLLLoss
:測量輸入和目標之間的負模可能性損失,每個班級可選。除非在其Docstrings中另有說明,否則上述層的行為與它們的質量等效物相同。
Triton內核通常由兩個部分組成:一個部分處理相關張量的負載和存儲,另一個使用適當的數學函數來轉換數據。例如,一層歸一化的內核從輸入(負載)讀取一排或幾行,標準化功能(數學),並將結果寫入容器(存儲)。 attorch.math
提供了這些純數學功能的選擇,目的是促進定制內核和操作融合的實施。儘管只有triton-autodiff
功能的正向通行證可以在attorch.math
中獲得。可以通過用相應的attorch.math
或它們的衍生品取代其數學位來重構Attorch的內核中的很大一部分,但是這樣做會犧牲單檔和獨立的自動化設計的設計,因此attorch.math
。
為了更輕鬆地集成Attorch和Pytorch層,提供了attorch.nn
,如果不提供所需的圖層,它為帶有Pytorch後備的Attorch模塊提供了接口,如下所示。
from attorch import nn
lin = nn . Linear ( 10 , 20 ) # Uses attorch's linear layer
gap = nn . AdaptiveAvgPool2d ( 1 ) # Uses PyTorch's global pooling since GAP is not available in attorch
每個模塊都可以針對其Pytorch對應物進行測試,以確保正確性。這些測試包括在tests/
並且可以使用pytest
執行。應該注意的是,由於數值精度問題,有些人可能會失敗,但是在大多數實際用例中,這不是問題。