Semi-Supervised Learning via Weight-aware Distillation under Class Distribution Mismatch

Pan Du, Suyun Zhao, Zisen Sheng, Cuiping Li, Hong Chen
2023-08-23
Abstract: Semi-Supervised Learning (SSL) under class distribution mismatch tackles a challenging problem in which the unlabeled data contain many instances from unknown categories that never appear in the labeled data. In such mismatch scenarios, traditional SSL suffers severe performance degradation because instances from unknown categories invade the target classifier. In this study, through rigorous mathematical analysis, we reveal that the SSL error under class distribution mismatch is composed of pseudo-labeling error and invasion error, which jointly bound the SSL population risk. To reduce this error, we propose a robust SSL framework called Weight-Aware Distillation (WAD) that uses weights to selectively transfer knowledge beneficial to the target task from an unsupervised contrastive representation to the target classifier. Specifically, WAD assigns adaptive weights and high-quality pseudo-labels to target instances by exploiting pointwise mutual information (PMI) in the representation space, which maximizes the contribution of unlabeled data while filtering out unknown categories. Theoretically, we prove that WAD has a tight upper bound on the population risk under class distribution mismatch. Experimentally, extensive results demonstrate that WAD outperforms five state-of-the-art SSL approaches and one standard baseline on two benchmark datasets, CIFAR10 and CIFAR100, and on an artificial cross-dataset benchmark. The code is available at https://github.com/RUC-DWBI-ML/research/tree/main/WAD-master.
Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
This paper addresses the effectiveness of semi-supervised learning (SSL) under class distribution mismatch. When the unlabeled data contain many instances from unknown classes that never appear in the labeled data, traditional SSL methods suffer severe performance losses, because instances of these unknown classes wrongly influence the training of the target classifier. Through rigorous mathematical analysis, the paper shows that the SSL error under class distribution mismatch consists of two parts, pseudo-labeling error and invasion error, and proposes a robust SSL framework named Weight-Aware Distillation (WAD) to mitigate both. WAD selectively transfers knowledge that benefits the target task from unsupervised contrastive representations to the target classifier. It uses pointwise mutual information (PMI) to obtain adaptive weights and high-quality pseudo-labels, which maximizes the contribution of unlabeled data while filtering out unknown classes. Theoretically, the paper proves that WAD admits a tight upper bound on the population risk under class distribution mismatch. Experimentally, WAD outperforms five state-of-the-art SSL methods and a standard baseline on two benchmark datasets, CIFAR10 and CIFAR100, and on a synthetic cross-dataset benchmark.
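The summary above does not give WAD's actual formulas, so the following is only a minimal sketch of the general idea under stated assumptions, not the authors' method (their implementation is in the repository linked in the abstract). It assumes that each known class is represented by a prototype embedding from the contrastive encoder, that exp(similarity / tau) is proportional to the density ratio p(x, c) / (p(x) p(c)) as in InfoNCE-style analyses, and that an instance's weight is the sigmoid of its highest PMI. The names `pmi_scores`, `weights_and_pseudo_labels`, `weighted_cross_entropy`, `prototypes`, and `tau`, and the sigmoid weighting rule, are all hypothetical. The weighted cross-entropy stands in for the weight-aware transfer to the target classifier.

```python
import numpy as np


def l2_normalize(z: np.ndarray) -> np.ndarray:
    """Scale each row to unit length, as is common for contrastive embeddings."""
    return z / np.linalg.norm(z, axis=1, keepdims=True)


def pmi_scores(unlabeled: np.ndarray, prototypes: np.ndarray, tau: float = 0.1) -> np.ndarray:
    """Hypothetical PMI estimate between unlabeled embeddings and known-class prototypes.

    Assumes exp(sim / tau) is proportional to p(x, c) / (p(x) p(c)), as in
    InfoNCE-style analyses. Because true PMI satisfies E_x[exp(PMI(x, c))] = 1,
    the unknown constant is removed by dividing by the mean ratio per class.
    """
    sim = l2_normalize(unlabeled) @ l2_normalize(prototypes).T / tau  # shape (n, k)
    peak = sim.max(axis=0, keepdims=True)  # per-class max, for a stable log-mean-exp
    log_mean = peak + np.log(np.mean(np.exp(sim - peak), axis=0, keepdims=True))
    return sim - log_mean


def weights_and_pseudo_labels(pmi: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Pseudo-label = class with highest PMI; weight = sigmoid of that PMI (hypothetical rule)."""
    pseudo = pmi.argmax(axis=1)
    best = pmi.max(axis=1)
    weights = 1.0 / (1.0 + np.exp(-best))  # near 0 when no known class is associated
    return weights, pseudo


def weighted_cross_entropy(logits: np.ndarray, pseudo: np.ndarray, weights: np.ndarray) -> float:
    """Weight-normalized cross-entropy of classifier logits against pseudo-labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(pseudo)), pseudo]
    return float(np.sum(weights * nll) / np.sum(weights))


# Toy example: two known classes; the last unlabeled row belongs to neither.
prototypes = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
unlabeled = np.array([[1.0, 0.1, 0.0], [0.1, 1.0, 0.0], [0.0, 0.0, 1.0]])
pmi = pmi_scores(unlabeled, prototypes)
weights, pseudo = weights_and_pseudo_labels(pmi)
logits = np.array([[2.0, 0.0], [0.0, 2.0], [0.0, 0.0]])  # target-classifier outputs
loss = weighted_cross_entropy(logits, pseudo, weights)
```

With these toy inputs, the first two rows receive pseudo-labels 0 and 1 with weights close to 0.75, while the third row, orthogonal to both prototypes, gets a weight near zero and contributes almost nothing to the loss. The sigmoid is only one possible mapping from PMI to weight; a hard threshold on PMI would filter suspected unknown categories more aggressively at the cost of discarding borderline known-class instances.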