Diagonal Hierarchical Consistency Learning for Semi-supervised Medical Image Segmentation

Heejoon Koo
2024-04-29
Abstract: Medical image segmentation, which is essential for many clinical applications, has achieved almost human-level performance via data-driven deep learning technologies. Nevertheless, its performance is predicated upon the costly process of manually annotating a vast amount of medical images. To this end, we propose a novel framework for robust semi-supervised medical image segmentation using diagonal hierarchical consistency learning (DiHC-Net). First, it is composed of multiple sub-models with identical multi-scale architecture but with distinct sub-layers, such as up-sampling and normalisation layers. Second, with mutual consistency, a novel consistency regularisation is enforced between one model's intermediate and final prediction and soft pseudo labels from other models in a diagonal hierarchical fashion. A series of experiments verifies the efficacy of our simple framework, which outperforms all previous approaches on public benchmark datasets covering organ and tumour segmentation.
Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
The problem that this paper attempts to solve is how to use limited labeled data together with a large amount of unlabeled data to improve model performance in medical image segmentation. Specifically, the paper proposes a new framework, Diagonal Hierarchical Consistency Learning (DiHC-Net), which aims to enhance the diversity among models and reduce prediction differences in uncertain regions, thereby improving semi-supervised medical image segmentation.

### Main Contributions

1. **Model Design**: The network is composed of three sub-models with the same multi-scale architecture but different sub-layers (such as up-sampling and normalization layers), which increases the internal diversity of the model.
2. **Optimization Method**: The model is optimized on the labeled data through deep supervision, and the proposed Diagonal Hierarchical Consistency Learning (DiHC) is combined with mutual consistency learning to optimize on both the labeled and unlabeled data.
3. **Experimental Verification**: Experiments on the publicly available Left Atrium (LA) and Brain Tumor Segmentation (BraTS) datasets show that the framework significantly outperforms the existing baseline methods.

### Method Overview

1. **Task Definition**:
   - The training data set contains \( N \) labeled samples and \( M \) unlabeled samples, where \( N \ll M \).
   - The labeled data set is represented as \( D_L=\{(x_i^l, y_i)\}_{i = 1}^N \), and the unlabeled data set is represented as \( D_U=\{x_i^u\}_{i = 1}^M \).
   - The goal is to train a model \( f_\theta \) with parameters \( \theta \) so that it can correctly perform pixel-level classification.
2. **Diversified Multi-scale Sub-models and Deep Supervision**:
   - Use three sub-models, each with the same multi-scale architecture but different sub-layer configurations (such as normalization layers and up-sampling layers).
   - Perform deep supervision on the labeled data by minimizing the difference between the up-sampled intermediate predictions and the true label.
   - The supervised segmentation loss \( L_{\text{sup}} \) is defined as:
     \[ L_{\text{sup}}=\sum_{m = 1}^M\sum_{s = 1}^S L_{\text{dice}}(f_s^m(x_i^l), y_i) \]
     where \( L_{\text{dice}} \) represents the Dice loss, and \( f_s^m(\cdot) \) represents the prediction of the \( m \)-th model at the \( s \)-th scale. In this sum, \( M \) counts the sub-models (three in this framework), not the unlabeled samples from the task definition.
3. **Diagonal Hierarchical Consistency Learning**:
   - Introduce the mutual consistency loss \( L_{\text{mc}} \), which enforces consistency between soft pseudo-labels and final predictions on the labeled and unlabeled data.
   - Introduce the diagonal hierarchical consistency loss \( L_{\text{dihc}} \), which minimizes the difference between the pseudo-label of one model and the intermediate and final representations of the other models on the labeled and unlabeled data.
   - The overall loss \( L_{\text{total}} \) is defined as:
     \[ L_{\text{total}}=\lambda_{\text{sup}} L_{\text{sup}}+\lambda_{\text{cst}} L_{\text{cst}} \]
     where \( L_{\text{cst}} = L_{\text{mc}}+L_{\text{dihc}} \), \( \lambda_{\text{sup}} \) and \( \lambda_{\text{cst}} \)
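
To make the two loss formulas above concrete, here is a minimal NumPy sketch of the deep-supervision Dice sum and the weighted total loss. It is an illustration under stated assumptions, not the authors' implementation: the function names, the smoothing constant `eps`, and the default weights of 1.0 are hypothetical, the predictions are assumed to be already up-sampled to the label's shape, and \( L_{\text{mc}} \) and \( L_{\text{dihc}} \) are passed in as precomputed scalars because their exact form is not given in this summary.

```python
import numpy as np


def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between a probability map and a binary mask.

    eps is a hypothetical smoothing constant that avoids division by zero.
    """
    pred = np.asarray(pred, dtype=float).ravel()
    target = np.asarray(target, dtype=float).ravel()
    intersection = np.sum(pred * target)
    denom = np.sum(pred) + np.sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)


def supervised_loss(preds, target):
    """Sum the Dice loss over sub-models m and scales s.

    preds[m][s] is the prediction of sub-model m at scale s, assumed to be
    already up-sampled to the shape of target.
    """
    return sum(
        dice_loss(p, target)
        for model_preds in preds
        for p in model_preds
    )


def total_loss(l_sup, l_mc, l_dihc, lambda_sup=1.0, lambda_cst=1.0):
    """L_total = lambda_sup * L_sup + lambda_cst * (L_mc + L_dihc).

    The default weights are placeholders, not values from the paper.
    """
    return lambda_sup * l_sup + lambda_cst * (l_mc + l_dihc)


if __name__ == "__main__":
    target = np.array([[0, 1], [1, 1]])
    good = target.astype(float)      # a prediction that matches the label
    bad = np.zeros((2, 2))           # a prediction that misses everything
    # Three sub-models, two scales each (toy values).
    preds = [[good, bad] for _ in range(3)]
    l_sup = supervised_loss(preds, target)
    print(l_sup, total_loss(l_sup, l_mc=0.2, l_dihc=0.3, lambda_cst=0.5))
```

Keeping the consistency terms as plain inputs keeps the sketch faithful to what the summary states: only the way the terms are combined is shown, not how each consistency term is computed.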