UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

Abdelrahman Shaker, Muhammad Maaz, Hanoona Rasheed, Salman Khan, Ming-Hsuan Yang, Fahad Shahbaz Khan
2024-05-04
Abstract: Owing to the success of transformer models, recent works study their applicability in 3D medical segmentation tasks. Within the transformer models, the self-attention mechanism is one of the main building blocks that strives to capture long-range dependencies. However, the self-attention operation has quadratic complexity, which proves to be a computational bottleneck, especially in volumetric medical imaging, where the inputs are 3D with numerous slices. In this paper, we propose a 3D medical image segmentation approach, named UNETR++, that offers both high-quality segmentation masks and efficiency in terms of parameters, compute cost, and inference speed. The core of our design is the introduction of a novel efficient paired attention (EPA) block that efficiently learns spatial and channel-wise discriminative features using a pair of inter-dependent branches based on spatial and channel attention. Our spatial attention formulation is efficient, having linear complexity with respect to the input sequence length. To enable communication between spatial and channel-focused branches, we share the weights of query and key mapping functions that provide a complementary benefit (paired attention), while also reducing the overall network parameters. Our extensive evaluations on five benchmarks, Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung, reveal the effectiveness of our contributions in terms of both efficiency and accuracy. On Synapse, our UNETR++ sets a new state-of-the-art with a Dice Score of 87.2%, while being significantly more efficient, with a reduction of over 71% in both parameters and FLOPs compared to the best method in the literature. Code: https://github.com/Amshaker/unetr_plus_plus
Computer Vision and Pattern Recognition
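The abstract describes EPA as two inter-dependent branches (spatial and channel attention) that share their query and key weights. The following is a minimal single-head NumPy sketch of that idea under stated assumptions; it is not the authors' implementation, which lives in the linked repository. Assumptions: the spatial branch reaches linear complexity by projecting keys and values from n tokens down to p tokens with a fixed linear map (`e_proj`, similar in spirit to Linformer-style projection), the channel branch uses a plain scaled dot product over channels, and the two outputs are fused by concatenation plus a linear map (`w_out`). All class and parameter names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


class PairedAttentionSketch:
    """Single-head, EPA-style paired attention (illustrative; names are hypothetical)."""

    def __init__(self, n_tokens, channels, proj_dim, seed=0):
        rng = np.random.default_rng(seed)
        s = 1.0 / np.sqrt(channels)
        # Query/key weights shared by BOTH branches (the "paired" part; also saves parameters).
        self.w_q = rng.normal(0.0, s, (channels, channels))
        self.w_k = rng.normal(0.0, s, (channels, channels))
        # Each branch keeps its own value weights.
        self.w_v_spatial = rng.normal(0.0, s, (channels, channels))
        self.w_v_channel = rng.normal(0.0, s, (channels, channels))
        # Assumed fixed projection of the token axis from n to p tokens (spatial branch only).
        self.e_proj = rng.normal(0.0, 1.0 / np.sqrt(n_tokens), (proj_dim, n_tokens))
        # Assumed fusion: concatenate both branch outputs, then map back to `channels`.
        self.w_out = rng.normal(0.0, s, (2 * channels, channels))

    def __call__(self, x):
        n, c = x.shape                       # x: (n tokens, C channels), a flattened 3D volume
        q = x @ self.w_q                     # shared queries, (n, C)
        k = x @ self.w_k                     # shared keys,    (n, C)

        # Spatial branch: attention map is (n, p) instead of (n, n), so cost is linear in n.
        k_p = self.e_proj @ k                                      # (p, C)
        v_p = self.e_proj @ (x @ self.w_v_spatial)                 # (p, C)
        self.last_spatial_attn = softmax(q @ k_p.T / np.sqrt(c))   # (n, p)
        x_spatial = self.last_spatial_attn @ v_p                   # (n, C)

        # Channel branch: attention map is (C, C); forming it also costs O(n * C^2).
        v_c = x @ self.w_v_channel                                 # (n, C)
        self.last_channel_attn = softmax(q.T @ k / np.sqrt(n))     # (C, C)
        x_channel = v_c @ self.last_channel_attn.T                 # (n, C)

        # Fuse the two branches and add a residual connection.
        fused = np.concatenate([x_spatial, x_channel], axis=-1) @ self.w_out
        return x + fused


# Usage: 64 tokens (e.g. a 4x4x4 patch grid) with 16 channels, keys/values projected to 8 tokens.
block = PairedAttentionSketch(n_tokens=64, channels=16, proj_dim=8)
x = np.random.default_rng(1).normal(size=(64, 16))
y = block(x)
print(y.shape, block.last_spatial_attn.shape, block.last_channel_attn.shape)
# (64, 16) (64, 8) (16, 16)
```

Because `e_proj` is a fixed (p, n) matrix, this sketch only accepts the token count it was built for; that is a simplification of the sketch, not a claim about how UNETR++ handles varying input sizes.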
What problem does this paper attempt to address?
The paper addresses a problem in 3D medical image segmentation: although existing methods have made significant progress in accuracy, they have high model complexity, which leads to high computational cost, large parameter counts, and slow inference. Specifically:

1. **Trade-off between model efficiency and accuracy**: Existing hybrid architectures (such as UNETR and nnFormer) improve segmentation accuracy but substantially increase the number of model parameters and floating-point operations (FLOPs), resulting in unsatisfactory efficiency and robustness.
2. **Computational bottleneck of self-attention**: Standard self-attention has quadratic complexity in the number of input tokens. This is particularly costly for 3D medical images, whose inputs are usually large volumes, and it limits the model's efficiency.
3. **Dependency between spatial and channel features**: Existing methods fail to capture the explicit dependency between spatial and channel features, which affects segmentation quality.

To address these issues, the authors propose UNETR++, which aims to improve segmentation accuracy and model efficiency simultaneously. The main contributions are:

- **Efficient Paired Attention (EPA) block**: The EPA block learns spatial and channel features jointly, improving segmentation accuracy while remaining efficient. Its spatial attention has linear complexity with respect to the input sequence length, and its channel attention emphasizes the dependencies between channel feature maps.
- **Hierarchical architecture design**: UNETR++ adopts a hierarchical encoder-decoder structure that progressively reduces feature resolution, lowering computational cost and parameter count.
- **Experimental validation**: Extensive experiments on five benchmark datasets (Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung) demonstrate the strong segmentation accuracy and efficiency of UNETR++.

In summary, the paper tackles the trade-off between efficiency and accuracy in existing 3D medical image segmentation methods by introducing an efficient paired attention mechanism and a hierarchical architecture design.
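To make the quadratic-versus-linear distinction concrete, the short calculation below counts multiply-accumulates for forming a single attention map in one head. The volume size (64 × 128 × 128), patch size (4), channel count (32), and projected token count p = 64 are illustrative assumptions, not the paper's configuration, and the counts ignore softmax, value aggregation, and multiple heads. The "projected_spatial" entry assumes keys are projected to p tokens, as in a Linformer-style spatial branch.

```python
def attention_costs(n, c, p):
    """Rough multiply-accumulate counts for forming one attention map (single head).

    n: number of tokens, c: channels, p: projected token count (assumed hyperparameter).
    """
    return {
        "standard_spatial": n * n * c,   # Q K^T over all token pairs: quadratic in n
        "projected_spatial": n * p * c,  # Q (E K)^T with keys projected to p tokens: linear in n
        "channel": c * c * n,            # Q^T K over channels: also linear in n
    }


def num_tokens(depth, height, width, patch):
    # Non-overlapping cubic patches (illustrative tokenization).
    return (depth // patch) * (height // patch) * (width // patch)


n = num_tokens(64, 128, 128, patch=4)    # 16 * 32 * 32 = 16384 tokens
costs = attention_costs(n, c=32, p=64)
for name, macs in costs.items():
    print(f"{name:>18}: {macs:,}")

# Doubling the token count quadruples the standard cost but only doubles the linear ones.
doubled = attention_costs(2 * n, c=32, p=64)
print(doubled["standard_spatial"] / costs["standard_spatial"])    # 4.0
print(doubled["projected_spatial"] / costs["projected_spatial"])  # 2.0
```

With these assumed sizes, the full n × n map needs 256 times more work than the projected n × p map, and the gap widens as the volume grows.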