Yunguan Fu, Yiwen Li, Shaheer U. Saeed, Matthew J. Clarkson, Yipeng Hu
Abstract: Recently, denoising diffusion probabilistic models (DDPM) have been applied to image segmentation by generating segmentation masks conditioned on images, while the applications were mainly limited to 2D networks without exploiting potential benefits from the 3D formulation. In this work, we studied the DDPM-based segmentation model for 3D multiclass segmentation on two large multiclass data sets (prostate MR and abdominal CT). We observed that the difference between training and test methods led to inferior performance for existing DDPM methods. To mitigate the inconsistency, we proposed a recycling method which generated corrupted masks based on the model's prediction at a previous time step instead of using ground truth. The proposed method achieved statistically significantly improved performance compared to existing DDPMs, independent of a number of other techniques for reducing train-test discrepancy, including performing mask prediction, using Dice loss, and reducing the number of diffusion time steps during training. The performance of diffusion models was also competitive and visually similar to non-diffusion-based U-net, within the same compute budget. The JAX-based diffusion framework has been released at https://github.com/mathpluscode/ImgX-DiffSeg.
What problem does this paper attempt to address?
The paper addresses the performance degradation that arises when existing denoising diffusion probabilistic models (DDPMs) are trained and tested in inconsistent ways on 3D multiclass segmentation tasks. Specifically:
1. **Differences between training and testing methods**: During training, existing DDPM methods generate noisy masks from the ground-truth labels, but the ground truth is not available at test time. This mismatch between the training and testing processes degrades model performance.
2. **Challenges in 3D image processing**: Most existing applications of diffusion models to segmentation use 2D networks, whereas 3D volumetric medical images impose much higher memory and compute costs. For example, a 3D diffusion model must process the image and the noisy mask together, which increases memory use, limits the batch size, and can lead to very long training times. In addition, most diffusion models assume that hundreds of denoising steps are required for training and inference; the inference stage in particular may take several days to run on TPUs/GPUs.
To address these problems, the authors propose the following:
- **Recycling**: The noisy mask used as training input is generated from the model's own prediction at the previous time step rather than from the ground-truth labels. This brings training closer to the actual testing process, reduces the risk of data leakage, and significantly improves model performance.
- **Direct prediction of the segmentation mask**: Instead of predicting the noise, the model predicts the segmentation mask directly. This allows training with Dice loss and cross-entropy loss, rather than only an L2 loss on the predicted noise.
- **Reducing denoising steps**: A five-step denoising process, combined with a resampled variance schedule, replaces the hundreds of denoising steps used in traditional methods and speeds up both training and inference.
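The recycling and reduced-step ideas above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the released ImgX-DiffSeg code: the function names (`alphas_cumprod_linear`, `respace`, `q_sample`, `recycled_input`), the linear beta schedule, the 1000-step starting schedule, and the evenly spaced choice of retained steps are all assumptions made for the example, and `model` stands in for any network that maps an image, a noisy mask, and a time step to a predicted clean mask.

```python
import numpy as np


def alphas_cumprod_linear(num_steps, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of alpha_t = 1 - beta_t for a linear beta schedule (assumed)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)


def respace(alphas_cumprod, num_kept):
    """Keep `num_kept` evenly spaced steps and recompute the betas between them."""
    idx = np.linspace(0, len(alphas_cumprod) - 1, num_kept).round().astype(int)
    kept = alphas_cumprod[idx]
    prev = np.concatenate([[1.0], kept[:-1]])
    betas = 1.0 - kept / prev  # chosen so that cumprod(1 - betas) equals `kept`
    return kept, betas


def q_sample(x0, alpha_bar_t, noise):
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * noise


def recycled_input(model, image, x0_true, alpha_bar, t, rng):
    """Build the training input x_t from the model's prediction at step t + 1.

    `t` is a 0-based index into `alpha_bar` and must satisfy t + 1 < len(alpha_bar).
    """
    # Step 1: corrupt the ground truth at the noisier step t + 1.
    x_next = q_sample(x0_true, alpha_bar[t + 1], rng.standard_normal(x0_true.shape))
    # Step 2: predict the clean mask from x_{t+1}, as would happen at test time.
    x0_pred = model(image, x_next, t + 1)
    # Step 3: corrupt the prediction (not the ground truth) at step t.
    return q_sample(x0_pred, alpha_bar[t], rng.standard_normal(x0_true.shape))
```

With a five-step schedule from `respace(alphas_cumprod_linear(1000), 5)`, the network would then be trained on the output of `recycled_input` instead of a noisy mask built directly from the ground truth. In an autodiff framework one would probably treat `x0_pred` as a constant when back-propagating, though that detail is an assumption here.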
With these improvements, the authors ran experiments on two large multiclass segmentation datasets (prostate MR and abdominal CT) and report statistically significant performance gains:
- On the prostate MR dataset, the Dice score increased from 0.815 to 0.830 (p < 0.001).
- On the abdominal CT dataset, the Dice score increased from 0.753 to 0.801 (p < 0.001).
In addition, within the same compute budget, the diffusion model performed comparably to a non-diffusion supervised U-net, and the segmentations were visually similar.
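For readers unfamiliar with the metric behind these numbers, the following is a minimal sketch of a per-class Dice score on integer label volumes. The function name `dice_per_class`, the `eps` smoothing term, and the choice to score every class including background are assumptions for illustration; the paper's exact averaging protocol may differ.

```python
import numpy as np


def dice_per_class(pred, target, num_classes, eps=1e-8):
    """Dice score per class for two integer label volumes of the same shape."""
    scores = []
    for c in range(num_classes):
        p = pred == c    # voxels predicted as class c
        g = target == c  # voxels labelled as class c
        intersection = np.logical_and(p, g).sum()
        denominator = p.sum() + g.sum()
        # eps keeps the score defined (and equal to 1) when the class is absent from both.
        scores.append((2.0 * intersection + eps) / (denominator + eps))
    return np.array(scores)
```

A single summary value, like those quoted above, would typically be obtained by averaging such per-class scores over classes and test cases.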
### Formula summary
The formulas involved in the paper include:
- Noise mask generation:
\[
x_{t + 1}=\sqrt{\bar{\alpha}_{t + 1}}x_0+\sqrt{1-\bar{\alpha}_{t + 1}}\epsilon_{t + 1}
\]
\[
x_t=\sqrt{\bar{\alpha}_t}x_{0,\theta}+\sqrt{1-\bar{\alpha}_t}\epsilon_t
\]
- Model prediction:
\[
\mu_\theta(x_t,t)=\frac{\sqrt{\bar{\alpha}_{t - 1}}\beta_t}{1-\bar{\alpha}_t}x_{0,\theta}(x_t,t)+\frac{1-\bar{\alpha}_{t - 1}}{1-\bar{\alpha}_t}\sqrt{\alpha_t}x_t
\]
\[
\mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{t,\theta}(x_t,t)\right)
\]
- Loss function:
\[
L_{seg,x_0}(\theta)=E_{t,x_0,\epsilon_t,I}L_{seg}(x_0,x_{0,\theta})
\]
\[
L_{seg,\epsilon_t}(\theta)