Knowledge Distillation and Training Balance for Heterogeneous Decentralized Multi-Modal Learning over Wireless Networks

Benshun Yin, Zhiyong Chen, Meixia Tao
2023-11-11
Abstract: Decentralized learning is widely employed for collaboratively training models using distributed data over wireless networks. Existing decentralized learning methods primarily focus on training single-modal networks. For the decentralized multi-modal learning (DMML), the modality heterogeneity and the non-independent and non-identically distributed (non-IID) data across devices make it difficult for the training model to capture the correlated features across different modalities. Moreover, modality competition can result in training imbalance among different modalities, which can significantly impact the performance of DMML. To improve the training performance in the presence of non-IID data and modality heterogeneity, we propose a novel DMML with knowledge distillation (DMML-KD) framework, which decomposes the extracted feature into the modality-common and the modality-specific components. In the proposed DMML-KD, a generator is applied to learn the global conditional distribution of the modality-common features, thereby guiding the modality-common features of different devices towards the same distribution. Meanwhile, we propose to decrease the number of local iterations for the modalities with fast training speed in DMML-KD to address the imbalanced training. We design a balance metric based on the parameter variation to evaluate the training speed of different modalities in DMML-KD. Using this metric, we optimize the number of local iterations for different modalities on each device under the constraint of remaining energy on devices. Experimental results demonstrate that the proposed DMML-KD with training balance can effectively improve the training performance of DMML.
Information Theory
What problem does this paper attempt to address?
### Problems the paper attempts to solve

This paper addresses several key problems in decentralized multi-modal learning (DMML) over wireless networks:

1. **Non-IID (non-independent and identically distributed) data**: Data on different devices may not be independently and identically distributed, which makes it difficult for the model to capture correlated features across modalities.
2. **Modality heterogeneity**: The modalities available on different devices may differ. For example, some cars may be equipped with visual sensors but lack radar sensors. This difference makes cross-modal features harder to learn.
3. **Training imbalance**: Different modalities may train at different speeds, leaving some modalities insufficiently trained and degrading overall performance.

To solve these problems, the authors propose a new framework, Decentralized Multi-Modal Learning with Knowledge Distillation (DMML-KD). The framework improves the training performance of DMML in two ways:

- **Feature decomposition**: The extracted features are decomposed into a modality-common part and a modality-specific part. A generator learns the global conditional distribution of the modality-common features and guides the modality-common features of different devices toward the same distribution.
- **Training balance**: The number of local iterations is reduced for modalities that train quickly. A balance metric based on parameter variation is used to optimize the number of local iterations of each modality on each device.

### Formula Presentation

1. **Global Loss Function**:
   \[ F(w) = \frac{1}{K}\sum_{k\in\mathcal{K}} F_{k}(w) \]
   where \(\mathcal{K}\) is the set of devices with \(K = |\mathcal{K}|\), \(F_{k}(w) = \frac{1}{D_{k}}\sum_{d=1}^{D_{k}} f(\{x_{m}^{k}(d)\}_{m\in\mathcal{M}_{k}}, y_{k}(d); w)\) is the task-specific loss function of device \(k\), and \(f(\cdot)\) is determined by the specific task (for example, cross-entropy loss).
2. **Similarity Loss**:
   \[ F_{\text{sim}}^{k} = \frac{1}{D_{k}|\mathcal{M}_{k}|(|\mathcal{M}_{k}|-1)}\sum_{d=1}^{D_{k}}\sum_{(m,m')\in\hat{\mathcal{M}}_{k}} \text{KL}\left(\sigma(\hat{w}_{n}^{k}\hat{h}_{m}^{k}(d)) \,\|\, \sigma(\hat{w}_{n}^{k}\hat{h}_{m'}^{k}(d))\right) \]
   where \(\hat{\mathcal{M}}_{k}\) is the set of ordered pairs formed by selecting any two different elements from \(\mathcal{M}_{k}\), and \(\text{KL}(z\|\hat{z}) = \sum_{i} z_{i}\ln\frac{z_{i}}{\hat{z}_{i}}\) is the KL divergence.
3. **Auxiliary Classification Loss**:
   \[ F_{\text{cls}}^{k} = \frac{1}{D_{k}|\mathcal{M}_{k}|}\sum_{d=1}^{D_{k}}\sum_{m\in\mathcal{M}_{k}} \text{CE}\left(\sigma(\hat{w}_{n}^{k}\hat{h}_{m}^{k}(d)), y_{k}(d)\right) \]
   where \(\text{CE}(\sigma(\hat{w}_{n}^{k}\hat{h}_{m}^{k}(d)), y_{k}(d))\) is the cross-entropy loss.
4. **Difference Loss**:
   \[ F_{\text{dif}}^{k} = \sum_{m\in\mathcal{M}_{k}} \left\|(\hat{H}_{m}^{k})^{T}\check{H}_{m}^{k}\right\|_{2}^{2} \]
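The three device-level losses above (similarity, auxiliary classification, and difference) can be illustrated with a small standard-library Python sketch. This is not the authors' implementation, and it rests on several assumptions the summary does not confirm. All function names are hypothetical. \(\sigma\) is taken to be softmax, and \(\hat{w}_{n}^{k}\) is treated as a plain weight matrix `W` with one row per class. `common[m][d]` and `specific[m][d]` are assumed to hold the modality-common and modality-specific feature vectors of modality `m` for sample `d`, so that stacking them over `d` gives \(\hat{H}_{m}^{k}\) and \(\check{H}_{m}^{k}\). The squared norm in the difference loss is computed as the sum of squared entries (a Frobenius-style reading). A small `EPS` is added only to keep logarithms finite.

```python
import math
from itertools import permutations

EPS = 1e-12  # assumption: small constant that keeps logarithms finite


def softmax(logits):
    """Numerically stable softmax, standing in for sigma(.)."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(W, h):
    """sigma(W h), with W given as a list of rows (classes x feature_dim)."""
    return softmax([sum(w * x for w, x in zip(row, h)) for row in W])


def kl(z, z_hat):
    """KL(z || z_hat) = sum_i z_i * ln(z_i / z_hat_i)."""
    return sum(a * math.log((a + EPS) / (b + EPS)) for a, b in zip(z, z_hat))


def similarity_loss(W, common):
    """Mean KL divergence over samples and ordered modality pairs (m, m')."""
    mods = list(common)
    if len(mods) < 2:  # assumption: a single modality yields no penalty
        return 0.0
    num_samples = len(common[mods[0]])
    total = 0.0
    for d in range(num_samples):
        probs = {m: predict(W, common[m][d]) for m in mods}
        for m, m2 in permutations(mods, 2):  # ordered pairs with m != m'
            total += kl(probs[m], probs[m2])
    return total / (num_samples * len(mods) * (len(mods) - 1))


def classification_loss(W, common, labels):
    """Mean cross-entropy of sigma(W h) against integer class labels."""
    mods = list(common)
    total = 0.0
    for d, y in enumerate(labels):
        for m in mods:
            total -= math.log(predict(W, common[m][d])[y] + EPS)
    return total / (len(labels) * len(mods))


def difference_loss(common, specific):
    """Sum over modalities of the squared entries of H_common^T H_specific."""
    total = 0.0
    for m in common:
        Hc, Hs = common[m], specific[m]
        for i in range(len(Hc[0])):
            for j in range(len(Hs[0])):
                g = sum(rc[i] * rs[j] for rc, rs in zip(Hc, Hs))
                total += g * g
    return total


# Toy device with two modalities, two samples, two classes (made-up numbers).
W_demo = [[1.0, 0.0], [0.0, 1.0]]
common_demo = {"image": [[1.0, 0.0], [0.0, 1.0]],
               "radar": [[0.0, 1.0], [1.0, 0.0]]}
specific_demo = {"image": [[0.5, 0.5], [0.5, 0.5]],
                 "radar": [[1.0, 0.0], [1.0, 0.0]]}
print(similarity_loss(W_demo, common_demo))
print(classification_loss(W_demo, common_demo, [0, 1]))
print(difference_loss(common_demo, specific_demo))
```

In this sketch the similarity loss is zero exactly when every modality yields the same class distribution for each sample, and the difference loss is zero when the columns of the common and specific feature matrices are mutually orthogonal across samples. That matches the roles the summary assigns to the two terms.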