Meta-Learned Modality-Weighted Knowledge Distillation for Robust Multi-Modal Learning with Missing Data

Hu Wang, Salma Hassan, Yuyuan Liu, Congbo Ma, Yuanhong Chen, Yutong Xie, Mostafa Salem, Yu Tian, Jodie Avery, Louise Hull, Ian Reid, Mohammad Yaqub, Gustavo Carneiro
2024-11-13
Abstract: In multi-modal learning, some modalities are more influential than others, and their absence can have a significant impact on classification/segmentation accuracy. To address this challenge, we propose a novel approach called Meta-learned Modality-weighted Knowledge Distillation (MetaKD), which enables multi-modal models to maintain high accuracy even when key modalities are missing. MetaKD adaptively estimates the importance weight of each modality through a meta-learning process. These learned importance weights guide a pairwise modality-weighted knowledge distillation process, allowing high-importance modalities to transfer knowledge to lower-importance ones and yielding robust performance despite missing inputs. Unlike previous methods in the field, which are often task-specific and require significant modifications, our approach is designed to work across multiple tasks (e.g., segmentation and classification) with minimal adaptation. Experimental results on five prevalent datasets, including three Brain Tumor Segmentation datasets (BraTS2018, BraTS2019 and BraTS2020), the Alzheimer's Disease Neuroimaging Initiative (ADNI) classification dataset and the Audiovision-MNIST classification dataset, demonstrate that the proposed model outperforms the compared models by a large margin.
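The core mechanism described in the abstract can be sketched in plain Python. In this mechanism, per-modality importance weights steer knowledge from high-importance to low-importance modalities. This is a minimal illustration under stated assumptions, not the authors' implementation. The weights `w` are fixed here rather than meta-learned, and features are plain lists of floats. The pairwise distillation term is also an assumption: its exact form is cut off in the inner objective of the method overview, so it is taken to be `w_i * ||f_i - f_j||^2` for every pair with `w_i > w_j`. The mean imputation of missing modalities and the weighted concatenation `w_i * f_i` follow the formulas in the method overview. All function names are hypothetical.

```python
from itertools import permutations
from typing import List, Optional, Sequence

Vector = List[float]


def impute_missing(features: Sequence[Optional[Vector]]) -> List[Vector]:
    """Fill each missing modality (None) with the mean of the available
    modality features: f_n = 1 / (N - |Q|) * sum of f_i over M minus Q."""
    available = [f for f in features if f is not None]
    if not available:
        raise ValueError("at least one modality must be present")
    dim = len(available[0])
    mean = [sum(f[k] for f in available) / len(available) for k in range(dim)]
    return [list(f) if f is not None else list(mean) for f in features]


def weighted_concat(features: Sequence[Vector], weights: Sequence[float]) -> Vector:
    """Concatenate w_i * f_i over all modalities, as in the meta-objective."""
    if len(features) != len(weights):
        raise ValueError("need one importance weight per modality")
    out: Vector = []
    for f, w in zip(features, weights):
        out.extend(w * x for x in f)
    return out


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def pairwise_kd_loss(features: Sequence[Vector], weights: Sequence[float]) -> float:
    """ASSUMED pairwise modality-weighted distillation term (hypothetical form).

    For every ordered pair (i, j) with w_i > w_j, modality i is the teacher
    and modality j the student; the pair contributes w_i * ||f_i - f_j||^2.
    In a real training loop f_i would be detached (stop-gradient).
    """
    if len(features) != len(weights):
        raise ValueError("need one importance weight per modality")
    loss = 0.0
    for i, j in permutations(range(len(features)), 2):
        if weights[i] > weights[j]:  # knowledge flows from high to low importance
            loss += weights[i] * squared_l2(features[i], features[j])
    return loss


def inner_objective(task_loss: float, features: Sequence[Vector],
                    weights: Sequence[float], alpha: float) -> float:
    """Inner-level loss: task loss plus alpha times the distillation term."""
    return task_loss + alpha * pairwise_kd_loss(features, weights)


if __name__ == "__main__":
    feats = impute_missing([[1.0, 0.0], None, [0.0, 1.0]])  # second modality missing
    w = [0.6, 0.3, 0.1]  # hypothetical importance weights, fixed here
    print(feats)
    print(weighted_concat(feats, w))
    print(inner_objective(task_loss=0.5, features=feats, weights=w, alpha=0.1))
```

In the example, the second modality is missing, so it is imputed as `[0.5, 0.5]`. The distillation term then sums over the three teacher/student pairs whose weights are strictly ordered. In the full method, `w` would be re-estimated in the outer loop on the validation set rather than held fixed.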
Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
### Problems the paper attempts to solve

The paper addresses a specific failure in multimodal learning: classification or segmentation performance drops sharply when some modalities are missing. Multimodal learning typically draws on several data sources, such as text, images, audio, video, or sensor data. Used together, these sources provide complementary information that improves task performance. In practical applications, however, some modalities may be unavailable, which seriously degrades the robustness and performance of the model.

### Main contributions

1. **A new multimodal learning model**: Meta-learned Modality-weighted Knowledge Distillation (MetaKD). The model maintains high performance when some key modalities are absent by distilling knowledge from high-importance modalities to low-importance ones.
2. **Flexible design**: MetaKD adapts to multiple tasks, such as classification and segmentation, without extensive task-specific modifications.
3. **Experimental verification**: Experiments were carried out on five popular benchmark datasets. These are three brain tumor segmentation datasets (BraTS2018, BraTS2019 and BraTS2020), the Alzheimer's Disease Neuroimaging Initiative (ADNI) classification dataset, and the Audiovision-MNIST classification dataset. The results show that MetaKD significantly outperforms the compared models.

### Method overview

1. **Model structure**:
   - **Encoder**: Each available (non-missing) modality \(x_i\) is passed through an encoder with parameters \(\theta_i\) to extract features \(f_i\):
     \[ f_i = f_{\theta_i}(x_i) \]
   - **Feature generation**: Let \(M\) denote the set of \(N\) input modalities and \(Q \subseteq M\) the subset of missing ones. The features of a missing modality \(x_n \in Q\) are generated by averaging the features of the available modalities:
     \[ f_n = \frac{1}{N - |Q|} \sum_{x_i \in M \setminus Q} f_{\theta_i}(x_i) \]
   - **Decoder**: All extracted or generated features are concatenated (\(\oplus\)) and fed to the decoder with parameters \(\zeta\) to produce the prediction:
     \[ \hat{y} = f_\zeta \left( \bigoplus_{i = 1}^N f_i \right) \]
2. **Meta-learned knowledge distillation**:
   - **Bi-level optimization**: MetaKD is trained with the bi-level optimization scheme of meta-learning. Its meta-parameters form the importance weight vector (IWV) \(w = [w_1,\ldots,w_N]^\top \in \mathbb{R}^N\).
   - **Inner-level learning**: The training set \(D_t\) is used to optimize the model parameters responsible for the main task (such as classification or segmentation).
   - **Outer-level learning**: The validation set \(D_v\) is used to optimize the meta-objective and estimate the meta-parameters:
     \[ w^* = \arg \min_w \sum_{(M, y) \in D_v} \ell_{\text{meta}} \left( y, f_{\zeta^*} \left( \bigoplus_{i = 1}^N w_i \times f_i^* \right) \right) \]
     where:
     \[ \theta^*, \zeta^* = \arg \min_{\theta, \zeta} \sum_{(M, y) \in D_t} \left[ \ell_{\text{task}} \left( y, f_\zeta \left( \bigoplus_{i = 1}^N f_i \right) \right) + \alpha \sum_{i, j = 1}^N