キーロックランクワン編集の実装。プロジェクトページ
この論文のセールスポイントは、追加されたコンセプトごとの追加パラメーターが 100kb までと非常に少ないことです。
彼らは、LLM 用の記憶編集用紙の Rank-1 編集技術をいくつかの改良を加えて適用することに成功したようです。彼らはまた、キーが新しい概念の「場所」を決定し、値が「何を」決定するかを特定し、(値を学習しながら) スーパークラスの概念に対するローカル/グローバル キー ロックを提案しました。
研究者の皆様、もしこの論文が確認できれば、このリポジトリのツールは、クロスアテンションコンディショニングを使用する他のテキストから<insert modality>
ネットワークでも機能するはずです。ちょっとした考え
StabilityAI の寛大なスポンサーシップと、他のスポンサーの皆様
複数のコードレビューと電子メールの明確化のための Yoad Tewel
Brad Vidler は、Stable Diffusion 1.5 で使用される CLIP の共分散行列を事前計算してくれました。
SOTA オープンソース対照学習テキスト画像モデルの OpenClip のメンテナ全員
$ pip install perfusion-pytorch
import torch
from torch import nn
from perfusion_pytorch import Rank1EditModule
to_keys = nn . Linear ( 768 , 320 , bias = False )
to_values = nn . Linear ( 768 , 320 , bias = False )
wrapped_to_keys = Rank1EditModule (
to_keys ,
is_key_proj = True
)
wrapped_to_values = Rank1EditModule (
to_values
)
text_enc = torch . randn ( 4 , 77 , 768 ) # regular input
text_enc_with_superclass = torch . randn ( 4 , 77 , 768 ) # init_input in algorithm 1, for key-locking
concept_indices = torch . randint ( 0 , 77 , ( 4 ,)) # index where the concept or superclass concept token is in the sequence
key_pad_mask = torch . ones ( 4 , 77 ). bool ()
keys = wrapped_to_keys (
text_enc ,
concept_indices = concept_indices ,
text_enc_with_superclass = text_enc_with_superclass ,
)
values = wrapped_to_values (
text_enc ,
concept_indices = concept_indices ,
text_enc_with_superclass = text_enc_with_superclass ,
)
# after much training ...
wrapped_to_keys . eval ()
wrapped_to_values . eval ()
keys = wrapped_to_keys ( text_enc )
values = wrapped_to_values ( text_enc )
このリポジトリには、新しい概念のトレーニング (および複数の概念での最終的な推論) を容易にするEmbeddingWrapper
も含まれています。
import torch
from torch import nn
from perfusion_pytorch import EmbeddingWrapper
embed = nn . Embedding ( 49408 , 512 ) # open clip embedding, somewhere in the module tree of stable diffusion
# wrap it, and will automatically create a new concept for learning, based on the superclass embed string
wrapped_embed = EmbeddingWrapper (
embed ,
superclass_string = 'dog'
)
# now just pass in your prompts with the superclass id
embeds_with_new_concept , embeds_with_superclass , embed_mask , concept_indices = wrapped_embed ([
'a portrait of dog' ,
'dog running through a green field' ,
'a man walking his dog'
]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)
# now pass both embeds through clip text transformer
# the embed_mask needs to be passed to the cross attention as key padding mask
安定した拡散インスタンス内でCLIP
インスタンスを識別できる場合は、それをOpenClipEmbedWrapper
に直接渡して、クロス アテンション レイヤーに必要なものすべてを取得することもできます。
元。
from perfusion_pytorch import OpenClipEmbedWrapper
texts = [
'a portrait of dog' ,
'dog running through a green field' ,
'a man walking his dog'
]
wrapped_clip_with_new_concept = OpenClipEmbedWrapper (
stable_diffusion . path . to . clip ,
superclass_string = 'dog'
)
text_enc , superclass_enc , mask , indices = wrapped_clip_with_new_concept ( texts )
# (3, 77, 512), (3, 77, 512), (3, 77), (3,)
xiao の dreambooth-sd から始めて、SD 1.5 と接続します。
複数の概念を使用した推論の例を Readme に表示
make_key_value_proj_rank1_edit_modules_
関数に指定されていない場合、キーと値の射影がどこにあるかを自動的に推測します。
埋め込みラッパーは、スーパークラスのトークン ID での置換を処理し、スーパークラスで埋め込みを返す必要があります。
複数のコンセプトを確認 - Yoad のおかげで
クロスアテンションを結びつける機能を提供する
推論時に 1 つのプロンプトで複数の概念を処理します - シグモイド項 + 出力の合計
複数のRank1EditModule
から個別に学習した概念を 1 つに結合して推論する方法を提供します
Rank1EditModule
をマージするための関数を提供します論文で提案されたコンセプトのゼロショット マスキングを追加
データセットとテキスト エンコーダーを受け取り、ランク 1 の更新に必要な共分散行列を事前計算する関数を処理します。
研究者に学習率の違いを心配させる代わりに、(概念の埋め込みを学ぶために)他の論文からの分数勾配トリックを提供します。
@article { Tewel2023KeyLockedRO ,
title = { Key-Locked Rank One Editing for Text-to-Image Personalization } ,
author = { Yoad Tewel and Rinon Gal and Gal Chechik and Yuval Atzmon } ,
journal = { ACM SIGGRAPH 2023 Conference Proceedings } ,
year = { 2023 } ,
url = { https://api.semanticscholar.org/CorpusID:258436985 }
}
@inproceedings { Meng2022LocatingAE ,
title = { Locating and Editing Factual Associations in GPT } ,
author = { Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov } ,
booktitle = { Neural Information Processing Systems } ,
year = { 2022 } ,
url = { https://api.semanticscholar.org/CorpusID:255825985 }
}