Tabular Data Contrastive Learning via Class-Conditioned and Feature-Correlation Based Augmentation

Wei Cui, Rasa Hosseinzadeh, Junwei Ma, Tongzi Wu, Yi Sui, Keyvan Golestan
2024-04-30
Abstract: Contrastive learning is a model pre-training technique that first creates similar views of the original data and then encourages the data and its corresponding views to be close in the embedding space. Contrastive learning has seen success on image and natural language data, thanks to domain-specific augmentation techniques that are both intuitive and effective. In the tabular domain, however, the predominant augmentation technique for creating views is corrupting tabular entries by swapping values, which is neither as sound nor as effective. We propose a simple yet powerful improvement to this augmentation technique: corrupting tabular data conditioned on class identity. Specifically, when corrupting a specific tabular entry of an anchor row, instead of sampling a value uniformly at random from the same feature column across the entire table, we sample only from rows identified as belonging to the same class as the anchor row. We assume the semi-supervised learning setting and adopt pseudo labeling to obtain class identities for all table rows. We also explore the novel idea of selecting the features to be corrupted based on feature correlation structures. Extensive experiments show that the proposed approach consistently outperforms the conventional corruption method on tabular data classification tasks. Our code is available at
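The augmentation described in the abstract can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the authors' implementation: all function names are hypothetical, pseudo labels are assumed to be already available (true labels where known, classifier predictions elsewhere), and the correlation-based selection rule used here (pick one random seed feature, then add the features with the largest absolute Pearson correlation to it) is only one illustrative way to select features "based on feature correlation structures"; the paper's exact rule may differ.

```python
import numpy as np


def correlated_features(corr, n_corrupt, rng):
    """Hypothetical rule: a random seed feature plus its most correlated peers.

    corr: (n_features, n_features) matrix of absolute Pearson correlations.
    Returns the indices of the n_corrupt features to corrupt in one row.
    """
    seed = int(rng.integers(corr.shape[0]))
    # Rank the other features by |correlation| with the seed, highest first.
    ranked = [int(j) for j in np.argsort(-corr[seed], kind="stable") if j != seed]
    return np.array([seed] + ranked[: n_corrupt - 1])


def class_conditioned_corrupt(X, labels, mask, rng):
    """Replace each masked entry with a value from a same-class donor row.

    X: (n_rows, n_features); labels: (n_rows,) true or pseudo class ids;
    mask: (n_rows, n_features) boolean, True marks an entry to corrupt.
    """
    X_view = X.copy()
    for c in np.unique(labels):
        rows = np.flatnonzero(labels == c)  # candidate donors for class c
        for i in rows:
            cols = np.flatnonzero(mask[i])
            # One donor per corrupted column, drawn uniformly from class c only.
            donors = rng.choice(rows, size=cols.size)
            X_view[i, cols] = X[donors, cols]
    return X_view


def make_views(X, labels, corruption_rate=0.3, seed=0):
    """Build one corrupted view per row (corruption_rate is an assumed default)."""
    rng = np.random.default_rng(seed)
    n_rows, n_features = X.shape
    n_corrupt = max(1, int(round(corruption_rate * n_features)))
    # Absolute feature-feature correlations; NaNs (constant columns) become 0.
    corr = np.nan_to_num(np.abs(np.atleast_2d(np.corrcoef(X, rowvar=False))))
    mask = np.zeros((n_rows, n_features), dtype=bool)
    for i in range(n_rows):
        mask[i, correlated_features(corr, n_corrupt, rng)] = True
    return class_conditioned_corrupt(X, labels, mask, rng)
```

Passing an array of identical labels makes every row a donor candidate, which recovers the conventional uniform-sampling corruption, so the same function can double as the baseline in a comparison.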
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The paper addresses a weakness of existing data augmentation for tabular data: the predominant technique, corrupting tabular entries by swapping values, is neither well founded nor very effective. The authors therefore propose an improved augmentation method to boost the performance of contrastive learning on tabular data classification tasks. Specifically, the paper focuses on two questions:

1. **How to Corrupt**: When corrupting tabular data, existing augmentation methods typically sample replacement values at random from the entire table, so the generated views may be semantically dissimilar to the original data. To address this, the authors propose **Class-Conditioned Corruption**: when corrupting a specific entry of an anchor row, the replacement value is sampled only from rows that belong to the same class as the anchor row, rather than from the entire table. Because the paper assumes a semi-supervised setting, class identities for unlabeled rows are obtained through pseudo labeling. This makes the generated views more likely to remain semantically similar to the anchor, and thus closer to it in the embedding space.
2. **Where to Corrupt**: Beyond how values are replaced, the authors also explore which features should be corrupted. They propose **Correlation-Based Feature Masking**: using the correlation structure among features, a subset of highly correlated features is selected for corruption, which better preserves the semantic information of the data and thereby strengthens contrastive learning.

### Key Formulas

The contrastive objective is built from the following quantities:

- **Cosine Similarity**:
  \[ s_{i,j} = \frac{\hat{z}_i^\top \hat{z}_j}{\|\hat{z}_i\|_2 \, \|\hat{z}_j\|_2} \]
  where \(\hat{z}_i\) and \(\hat{z}_j\) are two embedding vectors.
- **Contrastive Loss Function**:
  \[ L = \frac{1}{2N^2} \sum_{i=1}^{2N} -\log \left( \frac{e^{s_{i,i'}/\tau}}{\sum_{j=1}^{2N} \mathbb{1}[j \neq i] \, e^{s_{i,j}/\tau}} \right) \]
  where the batch holds \(2N\) embeddings (\(N\) anchors and their \(N\) corrupted views), \(i'\) is the index of the paired view (or anchor) embedding, and \(\tau\) is the temperature parameter.

Through these two improvements, the paper aims to make contrastive pre-training more effective for tabular data classification, and its experiments show that the proposed approach consistently outperforms the conventional corruption method.
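The two formulas can be checked numerically with a small NumPy sketch. It is illustrative rather than the authors' code: the function names and the default temperature are assumptions, rows \(k\) and \(k+N\) of the stacked batch are assumed to form a positive pair, and the loss is averaged over the \(2N\) anchors (normalization \(1/(2N)\)), which differs from the \(1/(2N^2)\) factor in the formula only by the constant \(N\) and therefore leaves the minimizer unchanged.

```python
import numpy as np


def cosine_similarity_matrix(Z):
    """s_{i,j} = z_i^T z_j / (||z_i||_2 ||z_j||_2) for every pair of rows of Z."""
    Zn = Z / np.linalg.norm(Z, axis=1, keepdims=True)
    return Zn @ Zn.T


def contrastive_loss(z_anchor, z_view, tau=0.5):
    """Contrastive loss over N anchors and their N views (assumed layout).

    Embedding i and its partner i' = (i + N) mod 2N form the positive pair.
    """
    N = z_anchor.shape[0]
    Z = np.concatenate([z_anchor, z_view], axis=0)  # shape (2N, d)
    logits = cosine_similarity_matrix(Z) / tau
    np.fill_diagonal(logits, -np.inf)  # excludes j == i, i.e. the 1[j != i] term
    partner = (np.arange(2 * N) + N) % (2 * N)
    # Numerically stable log-sum-exp over j != i (the denominator).
    m = logits.max(axis=1, keepdims=True)
    log_denom = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    log_num = logits[np.arange(2 * N), partner]  # s_{i,i'} / tau
    # -log(num / denom), averaged over the 2N anchors.
    return np.mean(log_denom - log_num)
```

As a sanity check, with `z_anchor` and `z_view` both set to the 2×2 identity and `tau=1.0`, every positive pair has similarity 1 and both negatives have similarity 0, so each term equals \(\log(2+e) - 1 \approx 0.55\); swapping the two views raises the loss to \(\log(2+e)\).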