Enhancing Neural Network Interpretability with Feature-Aligned Sparse Autoencoders

Luke Marks, Alasdair Paren, David Krueger, Fazl Barez
2024-11-06
Abstract: Sparse Autoencoders (SAEs) have shown promise in improving the interpretability of neural network activations, but can learn features that are not features of the input, limiting their effectiveness. We propose Mutual Feature Regularization (MFR), a regularization technique for improving feature learning by encouraging SAEs trained in parallel to learn similar features. We motivate MFR by showing that features learned by multiple SAEs are more likely to correlate with features of the input. By training on synthetic data with known features of the input, we show that MFR can help SAEs learn those features, as we can directly compare the features learned by the SAE with the input features for the synthetic data. We then scale MFR to SAEs that are trained to denoise electroencephalography (EEG) data and SAEs that are trained to reconstruct GPT-2 Small activations. We show that MFR can improve the reconstruction loss of SAEs by up to 21.21% on GPT-2 Small, and 6.67% on EEG data. Our results suggest that the similarity between features learned by different SAEs can be leveraged to improve SAE training, thereby enhancing performance and the usefulness of SAEs for model interpretability.
Machine Learning
What problem does this paper attempt to address?
This paper aims to make Sparse Autoencoders (SAEs) more effective at explaining the internal activations of neural networks. The authors point out that although SAEs show promise for improving the interpretability of neural network activations, they can learn features that are not features of the input data, which limits their effectiveness. To address this, they propose **Mutual Feature Regularization (MFR)**, which improves feature learning by encouraging multiple SAEs trained in parallel to learn similar features. The goal is for the features learned by SAEs to be closer to the true features of the input data, making SAEs more useful for model interpretability.

### Main contributions of the paper:

1. **Proposing the MFR technique**: MFR encourages multiple SAEs to learn more consistent features that are related to features of the input.
2. **Verifying the hypothesis**: Using synthetic data with known input features, the authors verify their hypothesis that features learned by multiple SAEs are more likely to correlate with features of the input data.
3. **Applying MFR to real data**: The authors apply MFR to SAEs that reconstruct GPT-2 Small activations and to SAEs that denoise EEG data, showing that MFR scales to real-world data.
4. **Performance improvement**: MFR reduces the reconstruction loss of SAEs by up to 21.21% on GPT-2 Small and 6.67% on EEG data, making SAEs more practical for explaining neural networks.

### Specific problem description:

- **Existing problems**: SAEs may learn features that are not features of the input data, which reduces interpretability.
- **Solutions**: MFR encourages multiple SAEs to learn similar features, increasing the correlation between the learned features and the input features.
- **Experimental verification**: Experiments on synthetic data and on real data (GPT-2 Small activations and EEG data) support the effectiveness of MFR.

### Key formulas:

- **TopK activation function**:

\[ \sigma_k(h)_i = \begin{cases} h_i & \text{if } h_i \geq \tau_k(h) \\ 0 & \text{otherwise} \end{cases} \]

where \(\tau_k(h)\) is the \(k\)-th largest activation value in \(h\), so only the \(k\) largest activations are kept (more if there are ties at the threshold).

- **Auxiliary penalty term**:

\[ \alpha \binom{N}{2} \sum_{i = 1}^{N - 1} \sum_{j = i + 1}^N \left(1-\text{MMCS}(W^{(i)}, W^{(j)})\right) \]

where \(\alpha\) is the weight coefficient, \(N\) is the number of SAEs, \(W^{(i)}\) is the weight matrix of the \(i\)-th SAE, and \(\text{MMCS}\) is the mean max cosine similarity: for each feature vector in \(W^{(i)}\), the largest cosine similarity with any feature vector in \(W^{(j)}\), averaged over all features in \(W^{(i)}\). The penalty is zero when every pair of SAEs learns the same set of feature directions.

With these components, the authors report improvements in how effectively and accurately SAEs explain neural network activations.
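A minimal sketch of the TopK activation above, assuming a plain Python list stands in for the activation vector \(h\). The function name `topk_activation` is hypothetical and not taken from the authors' code. Because the formula uses \(\geq\), every entry tied with the threshold \(\tau_k(h)\) is kept.

```python
from typing import List


def topk_activation(h: List[float], k: int) -> List[float]:
    """sigma_k(h): keep h_i if h_i >= tau_k(h), the k-th largest value, else 0."""
    if not 1 <= k <= len(h):
        raise ValueError("k must satisfy 1 <= k <= len(h)")
    # tau_k(h): the k-th largest activation in h
    tau = sorted(h, reverse=True)[k - 1]
    # Entries tied with tau are kept, so more than k entries can survive.
    return [x if x >= tau else 0.0 for x in h]


print(topk_activation([0.5, -1.2, 2.0, 0.1, 1.5], k=2))  # [0.0, 0.0, 2.0, 0.0, 1.5]
```

Sorting costs \(O(d \log d)\) for a vector of length \(d\); `heapq.nlargest(k, h)[-1]` would find the same threshold without a full sort.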
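The auxiliary penalty can be sketched in plain Python as follows. This is an illustrative reading of the formula, not the authors' implementation: it assumes each \(W^{(i)}\) is given as a list of feature vectors (one per learned feature), and the names `cosine`, `mmcs`, and `mfr_penalty` are hypothetical. It only computes the penalty's value; in actual training the term would be added to the SAEs' loss and differentiated with an autodiff framework. The \(\binom{N}{2}\) factor follows the formula exactly as written above.

```python
import math
from itertools import combinations
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def mmcs(w_a: List[Vector], w_b: List[Vector]) -> float:
    """Mean max cosine similarity: for each feature in w_a, find its best
    match in w_b, then average those best-match similarities over w_a."""
    return sum(max(cosine(u, v) for v in w_b) for u in w_a) / len(w_a)


def mfr_penalty(weights: List[List[Vector]], alpha: float) -> float:
    """Sum of (1 - MMCS) over all SAE pairs (i, j) with i < j, scaled by
    alpha * C(N, 2) as in the formula above (hypothetical sketch)."""
    n = len(weights)
    pair_sum = sum(
        1.0 - mmcs(weights[i], weights[j]) for i, j in combinations(range(n), 2)
    )
    return alpha * math.comb(n, 2) * pair_sum


# Hypothetical 2-feature SAEs: sae_b holds the same directions as sae_a,
# permuted and rescaled, so their MMCS is 1 and the penalty vanishes.
sae_a = [[1.0, 0.0], [0.0, 1.0]]
sae_b = [[0.0, 2.0], [3.0, 0.0]]
sae_c = [[1.0, 1.0], [1.0, -1.0]]  # directions rotated by 45 degrees

print(mfr_penalty([sae_a, sae_b], alpha=0.1))  # 0.0
print(round(mmcs(sae_a, sae_c), 4))  # 0.7071
```

Note that MMCS as sketched here is asymmetric (`mmcs(a, b)` need not equal `mmcs(b, a)`), which matches the formula's use of ordered pairs with \(i < j\).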