Esta biblioteca entrena a los autoencoders de K -sparse (SAE) en las activaciones de la corriente residual de los modelos de lenguaje Huggingface, siguiendo aproximadamente la receta detallada en la escala y la evaluación de autoencoders dispersos (Gao et al. 2024).
Esta es una biblioteca delgada y simple con pocas opciones de configuración. A diferencia de la mayoría de las otras bibliotecas SAE (por ejemplo, Saelens), no almacena activaciones en el disco, sino que las calcula sobre la marcha. Esto nos permite escalar a modelos y conjuntos de datos muy grandes con una sobrecarga de almacenamiento cero, pero tiene la desventaja de que probar diferentes hiperparámetros para el mismo modelo y conjunto de datos será más lento que si estuviéramos en caché las activaciones (ya que las activaciones se volverán computadas). Podemos agregar almacenamiento en caché como una opción en el futuro.
Después de Gao et al., Utilizamos una función de activación de TOPK que impone directamente un nivel de dispersión deseado en las activaciones. Esto contrasta con otras bibliotecas que usan una penalización L1 en la función de pérdida. Creemos que Topk es una mejora de Pareto sobre el enfoque L1 y, por lo tanto, no planeamos apoyarlo.
Para cargar un SAE previamente provocado desde el Hub Huggingface, puede usar el método Sae.load_from_hub
de la siguiente manera:
from sae import Sae
sae = Sae . load_from_hub ( "EleutherAI/sae-llama-3-8b-32x" , hookpoint = "layers.10" )
Esto cargará el SAE para la capa de corriente residual 10 de Llama 3 8B, que fue entrenado con un factor de expansión de 32. También puede cargar los SAE para todas las capas a la vez usando Sae.load_many
:
saes = Sae . load_many ( "EleutherAI/sae-llama-3-8b-32x" )
saes [ "layers.10" ]
Se garantiza que el diccionario devuelto por load_many
se clasificará naturalmente con el nombre del punto de gancho. Para el caso común en el que los puntos de gancho se llaman embed_tokens
, layers.0
, ..., layers.n
, esto significa que los SAE se clasificarán por número de capa. Luego podemos reunir las activaciones de SAE para un modelo de avance de la siguiente manera:
from transformers import AutoModelForCausalLM , AutoTokenizer
import torch
tokenizer = AutoTokenizer . from_pretrained ( "meta-llama/Meta-Llama-3-8B" )
inputs = tokenizer ( "Hello, world!" , return_tensors = "pt" )
with torch . inference_mode ():
model = AutoModelForCausalLM . from_pretrained ( "meta-llama/Meta-Llama-3-8B" )
outputs = model ( ** inputs , output_hidden_states = True )
latent_acts = []
for sae , hidden_state in zip ( saes . values (), outputs . hidden_states ):
latent_acts . append ( sae . encode ( hidden_state ))
# Do stuff with the latent activations
Para entrenar a SAES desde la línea de comandos, puede usar el siguiente comando:
python -m sae EleutherAI/pythia-160m togethercomputer/RedPajama-Data-1T-Sample
La CLI admite todas las opciones de configuración proporcionadas por la clase TrainConfig
. Puedes verlos ejecutando python -m sae --help
.
El uso programático es simple. Aquí hay un ejemplo:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from sae import SaeConfig , SaeTrainer , TrainConfig
from sae . data import chunk_and_tokenize
MODEL = "EleutherAI/pythia-160m"
dataset = load_dataset (
"togethercomputer/RedPajama-Data-1T-Sample" ,
split = "train" ,
trust_remote_code = True ,
)
tokenizer = AutoTokenizer . from_pretrained ( MODEL )
tokenized = chunk_and_tokenize ( dataset , tokenizer )
gpt = AutoModelForCausalLM . from_pretrained (
MODEL ,
device_map = { "" : "cuda" },
torch_dtype = torch . bfloat16 ,
)
cfg = TrainConfig (
SaeConfig ( gpt . config . hidden_size ), batch_size = 16
)
trainer = SaeTrainer ( cfg , tokenized , gpt )
trainer . fit ()
Por defecto, los SAE están entrenados en las activaciones de la corriente residual del modelo. Sin embargo, también puede entrenar a SAE en las activaciones de cualquier otro submódulo especificando patrones de punto de manejo personalizados. Estos patrones son como nombres de módulos Pytorch estándar (por ejemplo, h.0.ln_1
) pero también permiten la sintaxis de coincidencia de patrones UNIX, incluidos comodines y conjuntos de caracteres. Por ejemplo, para entrenar SAE en la salida de cada módulo de atención y las activaciones internas de cada MLP en GPT-2, puede usar el siguiente código:
python -m sae gpt2 togethercomputer/RedPajama-Data-1T-Sample --hookpoints " h.*.attn " " h.*.mlp.act "
Para restringir las primeras tres capas:
python -m sae gpt2 togethercomputer/RedPajama-Data-1T-Sample --hookpoints " h.[012].attn " " h.[012].mlp.act "
Actualmente no apoyamos el control manual de grano fino sobre la tasa de aprendizaje, el número de latentes u otros hiperparámetros en forma de punto de manejo por gancho. De manera predeterminada, la opción expansion_ratio
se utiliza para seleccionar el número apropiado de latentes para cada punto de manejo en función del ancho de la salida de ese punto de conexión. La tasa de aprendizaje predeterminada para cada punto de conexión se establece utilizando una ley de escala de raíz cuadrada inversa basada en el número de latentes. Si establece manualmente el número de latentes o la tasa de aprendizaje, se aplicará a todos los puntos de manejo.
Apoyamos la capacitación distribuida a través del comando torchrun
de Pytorch. Por defecto, utilizamos el método paralelo de datos distribuidos, lo que significa que los pesos de cada SAE se replican en cada GPU.
torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
Esto es simple, pero muy ineficiente. Si desea entrenar a SAE para muchas capas de un modelo, recomendamos usar el indicador --distribute_modules
, que asigna los SAE para diferentes capas a diferentes GPU. Actualmente, requerimos que el número de GPU divida uniformemente el número de capas para las que está entrenando SAES.
torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
El comando anterior entrena un SAE para cada capa uniforme de Llama 3 8B, utilizando todas las GPU disponibles. Acumula gradientes en más de 8 minibates y divide cada minibatch en 2 microbatches antes de alimentarlos en el codificador SAE, ahorrando así mucha memoria. También carga el modelo en una precisión de 8 bits usando bitsandbytes
. Este comando no requiere más de 48 GB de memoria por GPU en un nodo de 8 GPU.
Hay varias características que nos gustaría agregar en el futuro cercano:
Si desea ayudar con alguno de estos, ¡no dude en abrir un PR! Puede colaborar con nosotros en el canal de escasos autores de la discordia de Eleutherai.