SegFormer3D: an Efficient Transformer for 3D Medical Image Segmentation

Shehan Perera, Pouyan Navard, Alper Yilmaz
2024-04-24
Abstract: The adoption of Vision Transformer (ViT)-based architectures represents a significant advancement in 3D Medical Image (MI) segmentation, surpassing traditional Convolutional Neural Network (CNN) models by enhancing global contextual understanding. While this paradigm shift has significantly enhanced 3D segmentation performance, state-of-the-art approaches rely on extremely large and complex architectures that require large-scale computing resources for training and deployment. Furthermore, with the limited datasets often encountered in medical imaging, larger models can present hurdles in both generalization and convergence. In response to these challenges, and to demonstrate that lightweight models are a valuable area of research in 3D medical imaging, we present SegFormer3D, a hierarchical Transformer that calculates attention across multiscale volumetric features. Additionally, SegFormer3D avoids complex decoders and uses an all-MLP decoder to aggregate local and global attention features into highly accurate segmentation masks. The proposed memory-efficient Transformer preserves the performance characteristics of a significantly larger model in a compact design. SegFormer3D democratizes deep learning for 3D medical image segmentation by offering a model with 33x fewer parameters and a 13x reduction in GFLOPS compared to the current state-of-the-art (SOTA). We benchmark SegFormer3D against the current SOTA models on three widely used datasets, Synapse, BraTS, and ACDC, achieving competitive results. Code:
Computer Vision and Pattern Recognition
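The all-MLP decoder mentioned in the abstract can be pictured with a minimal NumPy sketch. It follows the general SegFormer-style recipe (project every encoder stage to a shared channel width, upsample all stages to the finest resolution, concatenate, fuse, and classify each voxel) and is not the authors' implementation. The function names, the four-stage strides (4, 8, 16, 32), the channel widths, the random weights, and the use of nearest-neighbour upsampling in place of interpolation are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(x, w, b):
    """Per-voxel linear layer: x is (D, H, W, C_in), w is (C_in, C_out)."""
    return x @ w + b


def upsample_nearest(x, factor):
    """Nearest-neighbour upsampling of a (D, H, W, C) volume by an integer factor.
    Assumption: stands in for the interpolation a real decoder would use."""
    for axis in range(3):
        x = np.repeat(x, factor, axis=axis)
    return x


def random_weight(c_in, c_out):
    """Random stand-in weights; a trained model would learn these."""
    return rng.standard_normal((c_in, c_out)) / np.sqrt(c_in)


def all_mlp_decoder(features, strides, embed_dim, num_classes):
    """Hypothetical all-MLP decoder over multiscale encoder features."""
    finest = min(strides)
    unified = []
    for feat, stride in zip(features, strides):
        # Project this stage to the shared width, then resample it to the finest grid.
        proj = linear(feat, random_weight(feat.shape[-1], embed_dim), np.zeros(embed_dim))
        unified.append(upsample_nearest(proj, stride // finest))
    stacked = np.concatenate(unified, axis=-1)  # (D, H, W, num_stages * embed_dim)
    fused = linear(stacked, random_weight(stacked.shape[-1], embed_dim), np.zeros(embed_dim))
    fused = np.maximum(fused, 0.0)  # ReLU after the fusion layer
    # Per-voxel class logits.
    return linear(fused, random_weight(embed_dim, num_classes), np.zeros(num_classes))


# Assumed encoder outputs for a 64^3 input: strides 4/8/16/32, widths 32/64/160/256.
strides = [4, 8, 16, 32]
widths = [32, 64, 160, 256]
features = [rng.standard_normal((64 // s,) * 3 + (c,)) for s, c in zip(strides, widths)]
logits = all_mlp_decoder(features, strides, embed_dim=128, num_classes=4)
print(logits.shape)  # (16, 16, 16, 4)
```

Every step is either a per-voxel linear map or a resampling, which is why such a decoder stays small compared with a stack of 3D convolutions. In this sketch the logits sit at 1/4 of the input resolution and would still need upsampling to the full volume.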
What problem does this paper attempt to address?
The paper addresses the computational cost of state-of-the-art 3D medical image segmentation. Methods based on Vision Transformers (ViTs) have made significant progress on this task, but they typically rely on very large architectures that need substantial computational resources for training and deployment. In addition, when datasets are small (a common situation in medical imaging), large models can struggle to generalize and converge. To address these challenges, the paper proposes SegFormer3D, a hierarchical Transformer that computes attention over multiscale volumetric features and aggregates local and global attention features through an all-MLP decoder to produce accurate segmentation masks. Its main features are:

1. **Lightweight and efficient**: SegFormer3D is a compact, memory-efficient model with 33x fewer parameters and a 13x reduction in GFLOPS compared to the current state-of-the-art model, which greatly lowers the demand for computational resources while keeping performance competitive.
2. **Hierarchical structure**: A hierarchical Transformer encoder captures both fine and coarse features of the input volume at multiple scales, which improves the model's representational capacity.
3. **Efficient self-attention**: To limit the cost of the long token sequences produced by 3D volumetric inputs, SegFormer3D uses an efficient self-attention module that substantially reduces computational complexity and improves scalability.
4. **Positional-encoding-free design**: A Mix-FFN module lets the model learn positional cues automatically, removing the need for fixed positional encodings and supporting scalability and performance.
5. **All-MLP decoder**: Instead of a conventional convolutional decoder, SegFormer3D uses an all-MLP decoder to produce high-quality segmentation masks, which simplifies decoding and performs consistently across datasets.

Experimental results show that SegFormer3D achieves competitive results on three widely used benchmark datasets (Synapse, BraTS, and ACDC), demonstrating its effectiveness for 3D medical image segmentation.
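The efficient self-attention idea can be sketched in NumPy, assuming a SegFormer-style sequence reduction: queries are kept for all N tokens, while keys and values are computed from a copy of the sequence downsampled by a ratio r along each of the three spatial axes, so the score matrix is N x N/r^3 rather than N x N. This is a hypothetical single-head sketch, not the authors' code. Average pooling replaces the learned strided 3D convolution, normalization and multiple heads are omitted, and the 16^3 grid, width 32, and r = 4 are illustrative assumptions.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along one axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def reduce_tokens_3d(x, grid, r):
    """Downsample an (N, C) token sequence laid out on a D x H x W grid by r per axis.
    Assumption: average pooling stands in for a learned strided 3D convolution."""
    d, h, w = grid
    c = x.shape[-1]
    blocks = x.reshape(d // r, r, h // r, r, w // r, r, c)
    return blocks.mean(axis=(1, 3, 5)).reshape(-1, c)  # (N / r**3, C)


def efficient_self_attention(x, grid, r, wq, wk, wv):
    """Hypothetical single-head attention with reduced keys and values."""
    q = x @ wq                                      # (N, C): a query for every token
    kv_src = reduce_tokens_3d(x, grid, r) if r > 1 else x
    k = kv_src @ wk                                 # (N / r**3, C)
    v = kv_src @ wv                                 # (N / r**3, C)
    scores = q @ k.T / np.sqrt(q.shape[-1])         # (N, N / r**3) instead of (N, N)
    return softmax(scores) @ v, scores.shape


rng = np.random.default_rng(0)
grid, c, r = (16, 16, 16), 32, 4                    # assumed grid, width, reduction ratio
x = rng.standard_normal((16 * 16 * 16, c))
wq, wk, wv = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(3))
out, score_shape = efficient_self_attention(x, grid, r, wq, wk, wv)
print(out.shape, score_shape)  # (4096, 32) (4096, 64)
```

With the assumed 16^3 grid and r = 4, each query attends to 64 reduced tokens instead of 4,096, so the score matrix shrinks by a factor of r^3 = 64 while the output keeps one vector per input token.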