MatFormer: Nested Transformer for Elastic Inference
Devvrit, Sneha Kudugunta, Aditya Kusupati, Tim Dettmers, Kaifeng Chen, Inderjit Dhillon, Yulia Tsvetkov, Hannaneh Hajishirzi, Sham Kakade, Ali Farhadi, Prateek Jain
2024-12-15
Abstract: Foundation models are applied in a broad spectrum of settings with different inference constraints, from massive multi-accelerator clusters to resource-constrained standalone mobile devices. However, the substantial costs associated with training these models often limit the number of unique model sizes that can be offered. Consequently, practitioners are compelled to select a model that may not be optimally aligned with their specific latency and cost requirements. We present MatFormer, a novel Transformer architecture designed to provide elastic inference across diverse deployment constraints. MatFormer achieves this by incorporating a nested Feed Forward Network (FFN) block structure within a standard Transformer model. During training, we optimize the parameters of multiple nested FFN blocks with varying sizes, enabling the extraction of hundreds of accurate smaller models without incurring additional computational costs. We empirically validate the efficacy of MatFormer across different model classes (decoders and encoders) and modalities (language and vision), demonstrating its potential for real-world deployment. We show that an 850M decoder-only MatFormer language model (MatLM) allows us to extract multiple smaller models spanning from 582M to 850M parameters, each exhibiting better validation loss and one-shot downstream evaluations than independently trained counterparts. Furthermore, we observe that smaller encoders extracted from a universal MatFormer-based ViT (MatViT) encoder preserve the metric-space structure for adaptive large-scale retrieval. Finally, we showcase that speculative decoding with the accurate and consistent submodels extracted from MatFormer can lead to a significant reduction in inference latency. Project website: https://devvrit.github.io/matformer/
Machine Learning, Computation and Language, Computer Vision and Pattern Recognition
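
The nested FFN structure described in the abstract can be illustrated with a toy example. The sketch below is not the authors' implementation. It assumes that a smaller FFN is formed from the first m hidden units of the full FFN, so every submodel reuses a prefix of the same weights. The class name `NestedFFN`, its methods, the GELU activation, and all sizes are hypothetical choices made for illustration; the joint training of the nested blocks is not shown.

```python
import math
import random


def gelu(x):
    # Tanh approximation of GELU (the activation is an assumption of this sketch).
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


class NestedFFN:
    """Toy FFN in which the first m hidden units form a smaller FFN (hypothetical)."""

    def __init__(self, d_model, d_ff, seed=0):
        rng = random.Random(seed)
        self.d_model = d_model
        self.d_ff = d_ff
        # w_in[j] maps the input vector to hidden unit j.
        self.w_in = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]
        # w_out[j] maps hidden unit j back to the model dimension.
        self.w_out = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]

    def forward(self, x, m=None):
        # Use only the first m hidden units; m=None selects the full FFN.
        m = self.d_ff if m is None else m
        if not 1 <= m <= self.d_ff:
            raise ValueError("m must be between 1 and d_ff")
        out = [0.0] * self.d_model
        for j in range(m):
            h = gelu(sum(w * xi for w, xi in zip(self.w_in[j], x)))
            for k in range(self.d_model):
                out[k] += h * self.w_out[j][k]
        return out

    def num_params(self, m):
        # Weights used by the width-m sub-FFN (biases are omitted in this sketch).
        return 2 * m * self.d_model


ffn = NestedFFN(d_model=4, d_ff=16)
x = [0.5, -1.0, 0.25, 2.0]
for m in (4, 8, 16):
    # Each width reads a prefix of the same weights; nothing is copied or retrained.
    print(m, ffn.num_params(m), ffn.forward(x, m))
```

Because each smaller width only reads a prefix of the shared weight lists, extracting a submodel in this sketch amounts to choosing m at inference time.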