EMC$^2$: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence

Chung-Yiu Yau, Hoi-To Wai, Parameswaran Raman, Soumajyoti Sarkar, Mingyi Hong
2024-04-16
Abstract: A key challenge in contrastive learning is to generate negative samples from a large sample set to contrast with positive samples, for learning better encoding of the data. These negative samples often follow a softmax distribution which is dynamically updated during the training process. However, sampling from this distribution is non-trivial due to the high computational costs in computing the partition function. In this paper, we propose an Efficient Markov Chain Monte Carlo negative sampling method for Contrastive learning (EMC$^2$). We follow the global contrastive learning loss as introduced in SogCLR, and propose EMC$^2$ which utilizes an adaptive Metropolis-Hastings subroutine to generate hardness-aware negative samples in an online fashion during the optimization. We prove that EMC$^2$ finds an $\mathcal{O}(1/\sqrt{T})$-stationary point of the global contrastive loss in $T$ iterations. Compared to prior works, EMC$^2$ is the first algorithm that exhibits global convergence (to stationarity) regardless of the choice of batch size while exhibiting low computation and memory cost. Numerical experiments validate that EMC$^2$ is effective with small batch training and achieves comparable or better performance than baseline algorithms. We report the results for pre-training image encoders on STL-10 and ImageNet-100.
Machine Learning, Artificial Intelligence, Computer Vision and Pattern Recognition, Optimization and Control
What problem does this paper attempt to address?
This paper addresses the problem of efficiently generating negative samples in contrastive learning. Contrastive learning is a self-supervised method for learning encoded representations from large amounts of data, and a key challenge is drawing negative samples from a large sample set to contrast with positive samples. Existing methods typically rely on large batch sizes to maintain the quality of the negative sample distribution, which incurs high computational cost.
The paper proposes EMC$^2$ (Efficient Markov Chain Monte Carlo negative sampling for Contrastive learning), which uses an adaptive Metropolis-Hastings subroutine to generate hardness-aware negative samples online during optimization. The authors prove that EMC$^2$ finds an $\mathcal{O}(1/\sqrt{T})$-stationary point of the global contrastive loss in $T$ iterations, with low computation and memory cost. According to the paper, it is the first algorithm to exhibit global convergence (to stationarity) regardless of the batch size.
The global contrastive loss aims to reduce the distance between positive sample pairs while pushing negative sample pairs apart. Because the normalization term (the partition function) over negative samples is expensive to compute, existing methods such as SimCLR approximate the gradient using randomly selected batches of negative samples; this requires large batches, which is not feasible in resource-limited scenarios. EMC$^2$ instead tracks the negative sample distribution directly with Markov chain Monte Carlo, using a Metropolis-Hastings sampler that dynamically adjusts to the SGD iterations, thereby reducing the memory and computational cost per iteration.
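To make the Metropolis-Hastings idea above concrete, here is a minimal sketch of how one can draw samples from a softmax distribution without ever computing its partition function. It is not the paper's implementation. The uniform proposal, the fixed `score` function, the temperature-like factor `beta`, and the names `mh_step` and `run_chain` are all assumptions made for illustration. The adaptive part of EMC$^2$, where the scores change as the encoder is updated by SGD, is deliberately left out.

```python
import numpy as np


def mh_step(score, current, num_items, beta, rng):
    """One Metropolis-Hastings step targeting p(j) proportional to exp(beta * score(j)).

    Assumption: a uniform (hence symmetric) proposal over all candidate items.
    Only the difference of two scores is needed, so the partition function
    (the sum of exp(beta * score(j)) over all j) is never evaluated.
    """
    proposal = int(rng.integers(num_items))
    log_ratio = beta * (score(proposal) - score(current))
    accept_prob = np.exp(min(0.0, log_ratio))  # min(1, p(proposal) / p(current))
    return proposal if rng.random() < accept_prob else current


def run_chain(score, num_items, beta, num_steps, seed=0):
    """Run the chain and return the empirical visit frequency of each item."""
    rng = np.random.default_rng(seed)
    state = int(rng.integers(num_items))
    counts = np.zeros(num_items)
    for _ in range(num_steps):
        state = mh_step(score, state, num_items, beta, rng)
        counts[state] += 1
    return counts / num_steps


if __name__ == "__main__":
    # Hypothetical similarity scores between one anchor and five candidates.
    sims = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
    empirical = run_chain(lambda j: sims[j], len(sims), beta=1.0, num_steps=50_000)
    # The exact softmax is computed here only to check the toy example.
    target = np.exp(sims) / np.exp(sims).sum()
    print("empirical:", np.round(empirical, 3))
    print("softmax:  ", np.round(target, 3))
```

In this toy setting, candidates with higher scores (harder negatives) are visited more often, and each step costs two score evaluations rather than a pass over the whole candidate set.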
The paper also analyzes the convergence of EMC$^2$ and verifies its effectiveness in small-batch training through numerical experiments, where it performs comparably to or better than baseline algorithms. In summary, the main contribution is EMC$^2$, a new and efficient negative sampling method that addresses the computational bottleneck in optimizing the global contrastive loss and converges even in small-batch training.