Counterfactual Explanations for Medical Image Classification and Regression using Diffusion Autoencoder

Matan Atad, David Schinz, Hendrik Moeller, Robert Graf, Benedikt Wiestler, Daniel Rueckert, Nassir Navab, Jan S. Kirschke, Matthias Keicher
DOI: https://doi.org/10.59275/j.melba.2024-4862
2024-10-01
Abstract: Counterfactual explanations (CEs) aim to enhance the interpretability of machine learning models by illustrating how alterations in input features would affect the resulting predictions. Common CE approaches require an additional model and are typically constrained to binary counterfactuals. In contrast, we propose a novel method that operates directly on the latent space of a generative model, specifically a Diffusion Autoencoder (DAE). This approach offers inherent interpretability by enabling the generation of CEs and the continuous visualization of the model's internal representation across decision boundaries. Our method leverages the DAE's ability to encode images into a semantically rich latent space in an unsupervised manner, eliminating the need for labeled data or separate feature extraction models. We show that these latent representations are helpful for medical condition classification and the ordinal regression of severity pathologies, such as vertebral compression fractures (VCF) and diabetic retinopathy (DR). Beyond binary CEs, our method supports the visualization of ordinal CEs using a linear model, providing deeper insights into the model's decision-making process and enhancing interpretability. Experiments across various medical imaging datasets demonstrate the method's advantages in interpretability and versatility. The linear manifold of the DAE's latent space allows for meaningful interpolation and manipulation, making it a powerful tool for exploring medical image properties. Our code is available at https://doi.org/10.5281/zenodo.13859266.
Computer Vision and Pattern Recognition, Machine Learning
What problem does this paper attempt to address?
This paper aims to improve the interpretability of machine-learning models for medical image classification and regression. Specifically, the authors propose a method based on a Diffusion Autoencoder (DAE) to generate counterfactual explanations (CEs). Because the method operates directly in the latent space of the generative model, it can produce continuous visual explanations, making the model's decision-making process more transparent and understandable.

### Main problems

1. **Limitations of existing methods**:
   - Existing counterfactual explanation methods typically require an additional model and are usually limited to binary counterfactuals.
   - Because that additional model operates independently of the classifier, these methods suffer from a semantic gap and added complexity.
2. **Generating continuous counterfactual explanations**:
   - The proposed method generates continuous counterfactual explanations in the latent space and applies to both classification and ordinal regression tasks.
   - Interpolating and editing in the latent space yields images that reflect different pathological grades, visualizing the model's decision boundary.
3. **Unsupervised feature extraction**:
   - The DAE is trained on unlabeled data as an unsupervised feature extractor, removing the dependence on labeled data.
   - Counterfactual explanations for multiple downstream tasks can therefore be generated from the same latent space without retraining.
4. **Continuous modeling of pathological grading**:
   - Regressing continuously in the latent space models the gradual progression of a pathology, such as the grading of vertebral compression fractures (VCF) and diabetic retinopathy (DR).
   - This requires only binary labels, reducing the need for detailed annotations.

### Method overview

1. **Unsupervised feature extraction**:
   - Train the DAE on unlabeled data to learn a semantically rich latent space.
   - Use the latent representations for downstream tasks such as classification and regression.
2. **Linear decision boundary**:
   - Encode the labeled images and train linear models (such as linear regression and SVMs) to predict the presence of the target pathology.
   - Use the distance to the decision hyperplane to estimate the severity of the pathological grade.
3. **Generating counterfactual explanations**:
   - Edit the latent code along the direction normal to the decision hyperplane and decode it to obtain counterfactual images.
   - Generate counterfactuals for a specific pathological grade by determining the required edit magnitude through reverse-calibrating the regression.

### Experimental results

- **Image reconstruction**: On the vertebral compression fracture (VCF) task, the DAE reconstructs images better than StyleGAN2 with an e4e encoder, as shown by lower LPIPS scores, although its FID scores are higher (for FID, lower is better).
- **Discriminative tasks**: Linear models on the DAE latent representations perform well across several detection and regression tasks, comparable to a fully supervised DenseNet121 baseline.
- **Counterfactual explanations**: The generated counterfactual images visualize the model's decision boundary and show how changes in input features drive its decisions.

### Conclusion

By generating continuous counterfactual explanations in the DAE's latent space, the proposed method substantially improves the interpretability of models for medical image classification and regression. It simplifies the generation of counterfactual explanations, supports multiple downstream tasks on the same latent space, and is broadly applicable.
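The latent-space pipeline described in this summary can be sketched in a few lines of Python. It has three parts: a linear model fitted on latent codes, the signed distance to its hyperplane used as a severity score, and an edit along the hyperplane normal to reach a target grade. This is a minimal illustration under stated assumptions, not the authors' implementation. The latent codes are synthetic Gaussian stand-ins for DAE semantic codes, and logistic regression trained by plain gradient descent is swapped in for the paper's linear models. The linear distance-to-grade calibration (`SLOPE`, `INTERCEPT`) is hypothetical, as are all function names, and decoding the edited codes back into images with the DAE is not shown.

```python
import math
import random


def make_latents(n_per_class, dim, shift, seed=0):
    """Hypothetical stand-in for DAE semantic latent codes.

    Both classes are unit Gaussians; class 1 ("pathology") is offset from
    class 0 ("healthy") by `shift` along the first latent dimension.
    """
    rng = random.Random(seed)
    data = []
    for label in (0, 1):
        for _ in range(n_per_class):
            z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
            z[0] += shift * (label - 0.5)  # class centres at -shift/2 and +shift/2
            data.append((z, label))
    return data


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def train_logistic(data, lr=0.5, epochs=300):
    """Fit a hyperplane w.z + b = 0 by batch gradient descent on the
    logistic loss (swapped in for the paper's linear models)."""
    dim = len(data[0][0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for z, y in data:
            err = 1.0 / (1.0 + math.exp(-(dot(w, z) + b))) - y  # sigmoid(score) - label
            grad_w = [g + err * zi for g, zi in zip(grad_w, z)]
            grad_b += err
        w = [wi - lr * g / len(data) for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / len(data)
    return w, b


def signed_distance(z, w, b):
    """Signed distance of latent code z to the hyperplane w.z + b = 0.
    Positive values lie on the pathology side; larger means more severe."""
    return (dot(w, z) + b) / math.sqrt(dot(w, w))


def edit_to_distance(z, w, b, target):
    """Move z along the unit normal w/|w| so that its signed distance
    equals `target`; components orthogonal to w are left unchanged."""
    norm = math.sqrt(dot(w, w))
    step = target - signed_distance(z, w, b)
    return [zi + step * wi / norm for zi, wi in zip(z, w)]


# Hypothetical calibration: grade = SLOPE * distance + INTERCEPT, chosen so
# that the binary boundary (distance 0) falls between grade 0 and grade 1.
SLOPE, INTERCEPT = 1.0, 0.5


def distance_for_grade(grade):
    """Invert the assumed linear calibration to get a target distance."""
    return (grade - INTERCEPT) / SLOPE


data = make_latents(n_per_class=100, dim=4, shift=4.0)
w, b = train_logistic(data)
accuracy = sum((signed_distance(z, w, b) > 0) == (y == 1) for z, y in data) / len(data)

# Edit one healthy code toward grades 0..3. Each edited code would be passed
# to the DAE decoder (not shown) to render one frame of an ordinal counterfactual.
z_healthy = data[0][0]
path = [edit_to_distance(z_healthy, w, b, distance_for_grade(g)) for g in range(4)]
print(f"training accuracy: {accuracy:.2f}")
print([round(signed_distance(z, w, b), 3) for z in path])
```

The edit moves only the component of the latent code that the linear model reads. Each edited code therefore lands exactly at its requested distance, and sweeping the target grade produces a sequence of codes that can be decoded into a continuous counterfactual.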