White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?

Yaodong Yu, Sam Buchanan, Druv Pai, Tianzhe Chu, Ziyang Wu, Shengbang Tong, Hao Bai, Yuexiang Zhai, Benjamin D. Haeffele, Yi Ma
2024-09-06
Abstract: In this paper, we contend that a natural objective of representation learning is to compress and transform the distribution of the data, say sets of tokens, towards a low-dimensional Gaussian mixture supported on incoherent subspaces. The goodness of such a representation can be evaluated by a principled measure, called sparse rate reduction, that simultaneously maximizes the intrinsic information gain and extrinsic sparsity of the learned representation. From this perspective, popular deep network architectures, including transformers, can be viewed as realizing iterative schemes to optimize this measure. Particularly, we derive a transformer block from alternating optimization on parts of this objective: the multi-head self-attention operator compresses the representation by implementing an approximate gradient descent step on the coding rate of the features, and the subsequent multi-layer perceptron sparsifies the features. This leads to a family of white-box transformer-like deep network architectures, named CRATE, which are mathematically fully interpretable. We show, by way of a novel connection between denoising and compression, that the inverse to the aforementioned compressive encoding can be realized by the same class of CRATE architectures. Thus, the so-derived white-box architectures are universal to both encoders and decoders. Experiments show that these networks, despite their simplicity, indeed learn to compress and sparsify representations of large-scale real-world image and text datasets, and achieve performance very close to highly engineered transformer-based models: ViT, MAE, DINO, BERT, and GPT2. We believe the proposed computational framework demonstrates great potential in bridging the gap between theory and practice of deep learning, from a unified perspective of data compression.
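The sparse rate reduction measure named in the abstract can be made concrete with a small NumPy sketch. It assumes tokens are stored as the columns of a d×N matrix Z, that each of the K subspaces is given by an orthonormal basis U_k of shape d×p, and that ε is a quantization precision. The formulas follow our reading of the paper, namely R(Z) = 1/2 logdet(I + d/(Nε²) Z Zᵀ), R^c(Z; U) = Σ_k R(U_kᵀ Z), and the objective R(Z) − R^c(Z; U) − λ‖Z‖₀; the function names and default values (eps=0.5, lam=0.1) are hypothetical illustrative choices, not taken from the CRATE code release.

```python
import numpy as np


def coding_rate(Z: np.ndarray, eps: float = 0.5) -> float:
    """R(Z) = 1/2 * logdet(I + d / (N * eps^2) * Z Z^T) for tokens Z of shape (d, N)."""
    d, n = Z.shape
    alpha = d / (n * eps**2)
    # I + alpha * Z Z^T is positive definite, so the sign from slogdet is +1;
    # slogdet avoids the overflow/underflow that log(det(...)) can hit for larger d.
    _, logdet = np.linalg.slogdet(np.eye(d) + alpha * Z @ Z.T)
    return 0.5 * logdet


def compressed_rate(Z: np.ndarray, U_list, eps: float = 0.5) -> float:
    """R^c(Z; U_[K]) = sum_k R(U_k^T Z), with each basis U_k of shape (d, p)."""
    # After projection coding_rate sees p rows, so it uses p / (N * eps^2) here.
    return sum(coding_rate(U.T @ Z, eps) for U in U_list)


def sparse_rate_reduction(Z: np.ndarray, U_list, lam: float = 0.1, eps: float = 0.5) -> float:
    """R(Z) - R^c(Z; U_[K]) - lam * ||Z||_0, the quantity a good representation maximizes."""
    l0 = np.count_nonzero(Z)  # ||Z||_0: number of nonzero entries
    return coding_rate(Z, eps) - compressed_rate(Z, U_list, eps) - lam * l0


# Toy usage: 10 tokens in R^4 lying on two orthogonal 2-D subspaces.
rng = np.random.default_rng(0)
U1, U2 = np.eye(4)[:, :2], np.eye(4)[:, 2:]
Z = np.hstack([U1 @ rng.standard_normal((2, 5)),   # 5 tokens on subspace 1
               U2 @ rng.standard_normal((2, 5))])  # 5 tokens on subspace 2
print("rate reduction:", sparse_rate_reduction(Z, [U1, U2], lam=0.0))
print("with l0 penalty:", sparse_rate_reduction(Z, [U1, U2], lam=0.1))
```

Because each token lies on one of two orthogonal subspaces, Z Zᵀ is block diagonal, and the whole-set rate (factor d/(Nε²)) exceeds the sum of per-subspace rates (factor p/(Nε²)), so the rate reduction is positive. The ℓ0 term is only a count of nonzero entries and is not differentiable, so any gradient-based scheme has to relax or approximate it.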
Code is available at: https://ma-lab-berkeley.github.io/CRATE.
Machine Learning, Computation and Language, Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
The core problem this paper addresses is how to improve the quality of learned representations through compression and sparsification, and how to provide a unified theoretical framework that explains and guides the design of deep network architectures, especially transformers. Specifically, the authors argue that a natural goal of representation learning is to compress and transform the data distribution into a low-dimensional Gaussian mixture supported on incoherent subspaces. The quality of such a representation can be evaluated by a principled measure called sparse rate reduction, which simultaneously maximizes the intrinsic information gain and the extrinsic sparsity of the learned representation.

From this perspective, popular deep network architectures, including transformers, can be regarded as iterative schemes for optimizing this measure. In particular, the paper derives a transformer block from alternating optimization on parts of this objective: the multi-head self-attention operator compresses the representation by taking an approximate gradient descent step on the coding rate of the features, and the subsequent multi-layer perceptron sparsifies the features. This yields a family of white-box, transformer-like deep network architectures named CRATE (Coding RAte reduction TransformEr), which are mathematically fully interpretable. Furthermore, by establishing a new connection between denoising and compression, the paper shows that the inverse of this compressive encoding can be realized by the same class of CRATE architectures, so the derived white-box architectures serve as both encoders and decoders.
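The alternating structure described above (a compression step followed by a sparsification step) can be sketched as one CRATE-style layer in NumPy. This is a minimal illustration under stated assumptions, not the reference implementation: the subspace self-attention head SSA(Z | U_k) = (U_kᵀ Z) softmax((U_kᵀ Z)ᵀ (U_kᵀ Z)) with a column-wise softmax, the p/(Nε²) scaling, the simplified residual update Z + κ·MSSA(Z), and the single nonnegative ISTA step with a square dictionary D reflect our reading of the paper. All names and defaults (crate_layer, kappa, eta, lam) are hypothetical, and learned parameters, normalization layers, and the exact step-size constants used in practice are omitted.

```python
import numpy as np


def softmax_columns(A: np.ndarray) -> np.ndarray:
    """Column-wise softmax: each column of the result sums to 1."""
    A = A - A.max(axis=0, keepdims=True)  # shift for numerical stability
    E = np.exp(A)
    return E / E.sum(axis=0, keepdims=True)


def ssa(Z: np.ndarray, U: np.ndarray) -> np.ndarray:
    """Subspace self-attention head: (U^T Z) softmax((U^T Z)^T (U^T Z))."""
    V = U.T @ Z                           # (p, N): tokens projected onto subspace U
    return V @ softmax_columns(V.T @ V)   # (p, N): similarity-weighted token mixing


def mssa(Z: np.ndarray, U_list, eps: float = 0.5) -> np.ndarray:
    """Multi-head SSA: p/(N eps^2) * [U_1 ... U_K] [SSA_1; ...; SSA_K]."""
    p, n = U_list[0].shape[1], Z.shape[1]
    # The block product [U_1 ... U_K] @ [SSA_1; ...; SSA_K] equals sum_k U_k @ SSA_k.
    return (p / (n * eps**2)) * sum(U @ ssa(Z, U) for U in U_list)


def ista_step(Z: np.ndarray, D: np.ndarray, eta: float = 0.1, lam: float = 0.1) -> np.ndarray:
    """One nonnegative ISTA step on lam*||A||_1 + 1/2||Z - D A||^2, started at A = Z."""
    return np.maximum(Z + eta * D.T @ (Z - D @ Z) - eta * lam, 0.0)


def crate_layer(Z, U_list, D, kappa=1.0, eps=0.5, eta=0.1, lam=0.1):
    """Compression (residual MSSA step), then sparsification (ISTA step)."""
    Z_half = Z + kappa * mssa(Z, U_list, eps)
    return ista_step(Z_half, D, eta, lam)


# Toy usage: N = 16 tokens in R^8, K = 4 random orthonormal subspaces of dimension p = 2.
rng = np.random.default_rng(0)
d, p, K, N = 8, 2, 4, 16
U_list = [np.linalg.qr(rng.standard_normal((d, p)))[0] for _ in range(K)]
D = np.eye(d)
Z = rng.standard_normal((d, N))
Z_next = crate_layer(Z, U_list, D)
print(Z_next.shape, "fraction of zeros:", (Z_next == 0).mean())
```

Stacking such layers gives an encoder. With D = I the ISTA step reduces to shifting every entry down by ηλ and clipping at zero, which makes the sparsifying effect of the second half of the layer easy to see.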
Experiments show that, despite their simplicity, these networks do learn to compress and sparsify the representations of large-scale real-world image and text datasets, and achieve performance very close to that of highly engineered transformer-based models such as ViT, MAE, DINO, BERT, and GPT2. Overall, the computational framework proposed in the paper shows great potential for bridging the gap between the theory and practice of deep learning from a unified data-compression perspective.