Masked Completion via Structured Diffusion with White-Box Transformers

Druv Pai, Ziyang Wu, Sam Buchanan, Yaodong Yu, Yi Ma
2024-04-03
Abstract: Modern learning frameworks often train deep neural networks with massive amounts of unlabeled data to learn representations by solving simple pretext tasks, then use the representations as foundations for downstream tasks. These networks are empirically designed; as such, they are usually not interpretable, their representations are not structured, and their designs are potentially redundant. White-box deep networks, in which each layer explicitly identifies and transforms structures in the data, present a promising alternative. However, existing white-box architectures have only been shown to work at scale in supervised settings with labeled data, such as classification. In this work, we provide the first instantiation of the white-box design paradigm that can be applied to large-scale unsupervised representation learning. We do this by exploiting a fundamental connection between diffusion, compression, and (masked) completion, deriving a deep transformer-like masked autoencoder architecture, called CRATE-MAE, in which the role of each layer is mathematically fully interpretable: they transform the data distribution to and from a structured representation. Extensive empirical evaluations confirm our analytical insights. CRATE-MAE demonstrates highly promising performance on large-scale imagery datasets while using only ~30% of the parameters compared to the standard masked autoencoder with the same model configuration. The representations learned by CRATE-MAE have explicit structure and also contain semantic meaning. Code is available at
Machine Learning
What problem does this paper attempt to address?
This paper aims to bring the white-box design paradigm to large-scale unsupervised representation learning. Existing white-box architectures have only been shown to work at scale in supervised settings with labeled data, such as classification, and had not been demonstrated for large-scale unsupervised representation learning. By exploiting a fundamental connection between diffusion, compression, and (masked) completion, the paper derives a deep transformer-like masked autoencoder architecture called CRATE-MAE. CRATE-MAE handles large-scale image datasets while using only about 30% of the parameters of a standard masked autoencoder with the same model configuration (a reduction of roughly 70%), and its learned representations retain explicit structure and semantic meaning.

### Main Contributions

1. **Theoretical connection**: The paper shows that, under certain natural conditions, denoising and compression are closely related basic data-processing operations: when the target distribution has low-dimensional structure, both perform a projection onto that structure.
2. **Architectural innovation**: Building on this connection, the paper proposes a new white-box transformer architecture, CRATE-MAE, in which the function of each layer is mathematically interpretable: the layers transform the data distribution to and from a structured representation.
3. **Empirical validation**: Extensive experiments on large-scale image datasets show that, with significantly fewer parameters, CRATE-MAE performs comparably to or better than standard masked autoencoders.
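The claim that denoising and compression both "perform a projection" can be checked in the simplest case: a single Gaussian supported on one subspace. Assume \( x = U\alpha \) with \( \alpha \sim \mathcal{N}(0, I_p) \) and \( U \) having orthonormal columns, and observe \( y = x + \sigma w \). Standard Gaussian conditioning then gives the optimal (MMSE) denoiser as \( \mathbb{E}[x \mid y] = \frac{1}{1+\sigma^2} U U^\top y \), which is a scaled orthogonal projection onto the subspace. The NumPy sketch below verifies this numerically. It illustrates the idea only; it is not the paper's general result, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d, p, sigma = 16, 3, 0.1  # arbitrary illustrative sizes (assumption)

# Orthonormal basis U of a random p-dimensional subspace of R^d.
U, _ = np.linalg.qr(rng.standard_normal((d, p)))

def mmse_denoise(y):
    """E[x | y] for x = U a, a ~ N(0, I_p), y = x + sigma * w, w ~ N(0, I_d)."""
    cov = U @ U.T  # covariance of x
    return cov @ np.linalg.solve(cov + sigma**2 * np.eye(d), y)

def project(y):
    """Orthogonal projection onto span(U): the 'compression' view."""
    return U @ (U.T @ y)

x = U @ rng.standard_normal(p)
y = x + sigma * rng.standard_normal(d)
print(np.allclose(mmse_denoise(y), project(y) / (1 + sigma**2)))  # True
```

As \( \sigma \to 0 \) the scale factor tends to 1, so the optimal denoiser becomes exactly the projection onto the low-dimensional structure.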
### Method Overview

- **Signal model**: The paper assumes that the token representations \( Z = [z_1, \dots, z_N] \) follow a mixture of low-dimensional Gaussians supported on \( K \) subspaces with orthonormal bases \( U_{[K]} = (U_1, \dots, U_K) \). Each token is generated as
  \[ z_i = U_{s_i} \alpha_i, \quad \forall i \in [N], \]
  where \( s_i \in [K] \) is the index of the subspace containing the token and \( \alpha_i \) is a zero-mean Gaussian vector. With a noise level \( \sigma \geq 0 \), the noisy model is
  \[ z_i = U_{s_i} \alpha_i + \sigma w_i, \quad \forall i \in [N], \]
  where the \( w_i \) are independent standard Gaussian vectors.
- **Optimization objective**: To learn such representations, the paper uses the sparse rate reduction objective, which the representations aim to maximize:
  \[ \mathbb{E}_Z\big[\Delta R(Z \mid U_{[K]}) - \lambda \| Z \|_0\big] = \mathbb{E}_Z\big[R(Z) - R_c(Z \mid U_{[K]}) - \lambda \| Z \|_0\big], \]
  where \( R(Z) \) and \( R_c(Z \mid U_{[K]}) \) are the unconditional and conditional lossy coding rates (rate-distortion quantities), and \( \lambda > 0 \) weights the sparsity penalty.
- **Encoder and decoder**: The encoder and decoder are constructed by incrementally optimizing the sparse rate reduction objective, one step per layer. Each encoder layer \( f_\ell \) consists of a multi-head subspace self-attention (MSSA) block followed by an iterative shrinkage-thresholding algorithm (ISTA) block; these compress and sparsify the tokens, respectively. Each decoder layer \( g_\ell \) approximately reverses these operations to map the structured representation back to the data.

### Experimental Results

- **Layer-wise functional analysis**: The experiments show that each encoder layer indeed compresses and sparsifies the token features, supporting the theoretical design.
- **Autoencoding performance**: On ImageNet-1K, CRATE-MAE achieves reconstruction performance comparable to ViT-MAE while using only about 30% of its parameters.
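The signal model, objective, and encoder layer from the Method Overview can be illustrated with a small NumPy sketch. It samples tokens from the noisy union-of-subspaces model and evaluates the coding rates and the sparse rate reduction. It then applies one encoder layer: an MSSA compression step followed by an ISTA sparsification step. Several elements are assumptions borrowed loosely from the earlier CRATE formulation, not the authors' implementation:

- the coding rate \( R(Z) = \tfrac12 \log\det\big(I + \tfrac{d}{N\epsilon^2} Z Z^\top\big) \) with precision \( \epsilon \);
- \( R_c(Z \mid U_{[K]}) = \sum_k R(U_k^\top Z) \);
- the residual MSSA update and the non-negative ISTA step.

All sizes, the step sizes `kappa` and `eta`, the dictionary `D`, and the coordinate-block subspaces are hypothetical.

```python
import numpy as np

# Hypothetical sizes: N tokens of dimension d, K subspaces of dimension p.
rng = np.random.default_rng(0)
d, N, K, p = 32, 64, 4, 4
eps, sigma, lam = 0.5, 0.05, 0.1   # assumed precision, noise level, sparsity weight
kappa, eta = 0.5, 0.1              # assumed step sizes for MSSA and ISTA

# Orthonormal bases U_1..U_K: disjoint coordinate blocks, for simplicity.
U = [np.eye(d)[:, k * p:(k + 1) * p] for k in range(K)]

def sample_tokens():
    """Noisy signal model: column i of Z is U_{s_i} alpha_i + sigma * w_i."""
    s = rng.integers(K, size=N)
    Z = np.stack([U[k] @ rng.standard_normal(p) for k in s], axis=1)
    return Z + sigma * rng.standard_normal((d, N))

def coding_rate(Z):
    """R(Z) = 1/2 logdet(I + dim / (n eps^2) Z Z^T) for Z of shape (dim, n)."""
    dim, n = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(dim) + dim / (n * eps**2) * Z @ Z.T)
    return 0.5 * logdet

def rate_reduction(Z):
    """Delta R = R(Z) - R_c(Z | U), with R_c = sum_k R(U_k^T Z) (assumed form)."""
    return coding_rate(Z) - sum(coding_rate(Uk.T @ Z) for Uk in U)

def sparse_rate_reduction(Z):
    """Delta R - lam * ||Z||_0; l0 is simply counted since we only evaluate it."""
    return rate_reduction(Z) - lam * np.count_nonzero(Z)

def softmax_cols(A):
    """Column-wise softmax (each column sums to 1), numerically stabilized."""
    E = np.exp(A - A.max(axis=0, keepdims=True))
    return E / E.sum(axis=0, keepdims=True)

def mssa(Z):
    """p/(N eps^2) * sum_k U_k (U_k^T Z) softmax((U_k^T Z)^T (U_k^T Z))."""
    out = np.zeros_like(Z)
    for Uk in U:
        P = Uk.T @ Z                              # (p, N) tokens projected onto subspace k
        out += Uk @ (P @ softmax_cols(P.T @ P))   # attention restricted to subspace k
    return p / (N * eps**2) * out

def ista(Z, D):
    """Non-negative ISTA step: ReLU(Z + eta D^T (Z - D Z) - eta * lam)."""
    return np.maximum(Z + eta * D.T @ (Z - D @ Z) - eta * lam, 0.0)

def encoder_layer(Z, D):
    Z_half = Z + kappa * mssa(Z)   # compression (simplified residual form, assumed)
    return ista(Z_half, D)         # sparsification

D = np.eye(d) + 0.1 * rng.standard_normal((d, d)) / np.sqrt(d)  # hypothetical dictionary
Z = sample_tokens()
Z_next = encoder_layer(Z, D)
print(rate_reduction(Z), np.count_nonzero(Z), np.count_nonzero(Z_next))
```

The column-wise softmax makes each output token a convex combination of the tokens projected onto one subspace, so MSSA behaves like attention restricted to each subspace. The ReLU in the ISTA step produces exact zeros, which is what lowers the \( \ell_0 \) term.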
In conclusion, by introducing new theoretical connections and architectural innovations, this paper applies the white-box design paradigm to large-scale unsupervised representation learning, providing new ideas and tools for future research.