BatchTopK Sparse Autoencoders

Bart Bussmann, Patrick Leask, Neel Nanda
2024-12-09
Abstract: Sparse autoencoders (SAEs) have emerged as a powerful tool for interpreting language model activations by decomposing them into sparse, interpretable features. A popular approach is the TopK SAE, which uses a fixed number of the most active latents per sample to reconstruct the model activations. We introduce BatchTopK SAEs, a training method that improves upon TopK SAEs by relaxing the top-k constraint to the batch level, allowing for a variable number of latents to be active per sample. As a result, BatchTopK adaptively allocates more or fewer latents depending on the sample, improving reconstruction without sacrificing average sparsity. We show that BatchTopK SAEs consistently outperform TopK SAEs in reconstructing activations from GPT-2 Small and Gemma 2 2B, and achieve comparable performance to state-of-the-art JumpReLU SAEs. However, an advantage of BatchTopK is that the average number of latents can be directly specified, rather than approximately tuned through a costly hyperparameter sweep. We provide code for training and evaluating BatchTopK SAEs at https://github.com/bartbussmann/BatchTopK.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The paper addresses how to improve the reconstruction performance of sparse autoencoders (SAEs) on language model activations while keeping the average sparsity fixed. Specifically:

1. **Limitations of TopK SAEs**:
   - TopK SAEs reconstruct model activations using a fixed number \( k \) of the most active latent variables in each sample.
   - Although effective, this implicitly assumes that every sample needs the same number of active latents, even though some samples may need more than others, which limits the model's flexibility.
2. **Improvements of BatchTopK SAEs**:
   - To address this, the paper introduces BatchTopK SAEs, a new training method.
   - BatchTopK relaxes the top-k constraint to the batch level: the \( n\times k \) largest activations across a batch of \( n \) samples are kept, so each sample can use a different number of latents as needed.
   - This lets the model allocate latents more flexibly, improving reconstruction without sacrificing average sparsity.
3. **Specific Objectives**:
   - Improve the reconstruction quality of language model activations.
   - Achieve better reconstruction at the same average sparsity.
   - Offer a more flexible alternative to TopK SAEs that is competitive with JumpReLU SAEs while letting the average number of active latents be specified directly, rather than tuned approximately through a costly hyperparameter sweep.

Through experiments on GPT-2 Small and Gemma 2 2B, the authors show that BatchTopK SAEs consistently outperform TopK SAEs across different dictionary sizes and sparsity levels, and in some cases perform comparably to state-of-the-art JumpReLU SAEs.
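To make the difference between the two selection rules concrete, here is a minimal plain-Python sketch. It is not the authors' code (their implementation is in the repository linked in the abstract); the function names, the toy activations, and the assumption that the inputs are already non-negative activations are all illustrative. TopK keeps the \( k \) largest values in each row, while BatchTopK keeps the \( n\times k \) largest values across all \( n \) rows of the batch.

```python
from typing import List

Matrix = List[List[float]]  # one row per sample, one column per latent


def topk(acts: Matrix, k: int) -> Matrix:
    """Per-sample TopK: keep the k largest activations in each row, zero the rest."""
    out = []
    for row in acts:
        # Indices of the k largest values in this row (ties broken by position).
        keep = set(sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k])
        out.append([v if j in keep else 0.0 for j, v in enumerate(row)])
    return out


def batch_topk(acts: Matrix, k: int) -> Matrix:
    """BatchTopK: keep the n*k largest activations across the whole batch of n rows."""
    n = len(acts)
    flat = [(v, i, j) for i, row in enumerate(acts) for j, v in enumerate(row)]
    flat.sort(key=lambda t: t[0], reverse=True)  # stable, so ties keep row-major order
    keep = {(i, j) for _, i, j in flat[: n * k]}
    return [[v if (i, j) in keep else 0.0 for j, v in enumerate(row)]
            for i, row in enumerate(acts)]


def active_counts(z: Matrix) -> List[int]:
    """Number of non-zero latents per sample."""
    return [sum(1 for v in row if v != 0.0) for row in z]


if __name__ == "__main__":
    # Hypothetical non-negative activations: sample 0 has three strong latents,
    # sample 1 has only one.
    acts = [[0.9, 0.8, 0.7, 0.1],
            [0.2, 0.1, 0.05, 0.0]]
    print(active_counts(topk(acts, k=2)))        # [2, 2]
    print(active_counts(batch_topk(acts, k=2)))  # [3, 1]
```

With \( k=2 \), TopK gives each of the two samples exactly two active latents, whereas BatchTopK keeps the four largest values overall: three go to the first sample and one to the second, so the average is still two latents per sample.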
In formulas, the training objective of BatchTopK SAEs can be written as:

\[ L(X)=\left\| X - \left(\text{BatchTopK}(W_{\text{enc}}X + b_{\text{enc}})\,W_{\text{dec}}+b_{\text{dec}}\right) \right\|_2^2+\alpha L_{\text{aux}} \]

where \( X \) is a batch of \( n \) input activations; \( W_{\text{enc}} \) and \( b_{\text{enc}} \) are the encoder weights and bias; \( W_{\text{dec}} \) and \( b_{\text{dec}} \) are the decoder weights and bias; \(\text{BatchTopK}\) keeps the \( n\times k \) largest activations across the entire batch and sets all others to zero, so that on average \( k \) latents are active per sample; and \( L_{\text{aux}} \) is an auxiliary loss, weighted by the coefficient \( \alpha \), that prevents "dead" latents (latents that never activate).

Because BatchTopK makes each sample's activations depend on the other samples in its batch, the authors remove this dependence at inference by introducing a global threshold \( \theta \), estimated as:

\[ \theta=\mathbb{E}_X\left[ \min \{ z_{i,j}(X)\mid z_{i,j}(X)>0 \} \right] \]

where \( z_{i,j}(X) \) is the \( j \)-th latent activation of the \( i \)-th sample in batch \( X \) after BatchTopK. In other words, \( \theta \) is the average smallest activation that survives batch-level selection; at inference, a latent is kept only if its activation exceeds \( \theta \), so each sample is encoded independently of the others.
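The objective and the inference threshold can be sketched end to end in plain Python. This is a minimal sketch under stated assumptions, not the authors' implementation: samples are rows of \( X \), so the encoder is applied as \( XW_{\text{enc}} \) rather than the formula's \( W_{\text{enc}}X \); a ReLU is applied before BatchTopK (an assumption); the auxiliary term \( \alpha L_{\text{aux}} \) is omitted because its form is not given here; the expectation defining \( \theta \) is estimated by averaging over a list of training batches; and all function names and toy weights are hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]  # one row per sample


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add_bias(m: Matrix, bias: Sequence[float]) -> Matrix:
    """Add a bias vector to every row."""
    return [[v + bias[j] for j, v in enumerate(row)] for row in m]


def batch_topk(acts: Matrix, k: int) -> Matrix:
    """Keep the n*k largest activations across the batch of n rows; zero the rest."""
    n = len(acts)
    flat = sorted(((v, i, j) for i, row in enumerate(acts) for j, v in enumerate(row)),
                  key=lambda t: t[0], reverse=True)
    keep = {(i, j) for _, i, j in flat[: n * k]}
    return [[v if (i, j) in keep else 0.0 for j, v in enumerate(row)]
            for i, row in enumerate(acts)]


def encode_train(X: Matrix, W_enc: Matrix, b_enc: Sequence[float], k: int) -> Matrix:
    """Training-time encoder: BatchTopK over the batch's pre-activations."""
    pre = add_bias(matmul(X, W_enc), b_enc)
    # ASSUMPTION: ReLU before selection, so every kept activation is non-negative.
    relu = [[max(v, 0.0) for v in row] for row in pre]
    return batch_topk(relu, k)


def decode(Z: Matrix, W_dec: Matrix, b_dec: Sequence[float]) -> Matrix:
    """Reconstruction X_hat = Z W_dec + b_dec."""
    return add_bias(matmul(Z, W_dec), b_dec)


def reconstruction_loss(X: Matrix, X_hat: Matrix) -> float:
    """Squared L2 error per sample, averaged over the batch (alpha * L_aux omitted)."""
    return sum(sum((x - y) ** 2 for x, y in zip(r, r_hat))
               for r, r_hat in zip(X, X_hat)) / len(X)


def estimate_threshold(batch_codes: List[Matrix]) -> float:
    """theta: mean over batches of the smallest positive post-BatchTopK activation.

    Assumes each batch has at least one positive activation.
    """
    mins = [min(v for row in Z for v in row if v > 0) for Z in batch_codes]
    return sum(mins) / len(mins)


def encode_inference(X: Matrix, W_enc: Matrix, b_enc: Sequence[float],
                     theta: float) -> Matrix:
    """Batch-independent encoder: keep an activation only if it exceeds theta."""
    pre = add_bias(matmul(X, W_enc), b_enc)
    return [[v if v > theta else 0.0 for v in row] for row in pre]


if __name__ == "__main__":
    # Toy, hypothetical weights: 2-dimensional inputs, a 3-latent dictionary.
    W_enc = [[1.0, 0.0, 0.4], [0.0, 1.0, 0.4]]
    b_enc = [0.0, 0.0, 0.0]
    W_dec = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    b_dec = [0.0, 0.0]
    X_a = [[1.0, 0.0], [0.0, 2.0]]
    X_b = [[3.0, 0.0], [0.0, 0.5]]
    Z_a = encode_train(X_a, W_enc, b_enc, k=1)
    Z_b = encode_train(X_b, W_enc, b_enc, k=1)
    print(reconstruction_loss(X_a, decode(Z_a, W_dec, b_dec)))  # 0.0
    theta = estimate_threshold([Z_a, Z_b])
    print(encode_inference(X_a, W_enc, b_enc, theta))  # [[0.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
```

In this toy run, BatchTopK with \( k=1 \) gives both active slots of the second batch to its first sample, which has the larger activations, and none to the second. The estimated \( \theta \) (about 1.1) then zeroes the first sample of `X_a` entirely at inference, illustrating that the threshold reproduces the batch-level sparsity only on average, not per sample.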