Abstract: $k$-subset sampling is ubiquitous in machine learning, enabling regularization and interpretability through sparsity. The challenge lies in rendering $k$-subset sampling amenable to end-to-end learning. This has typically involved relaxing the reparameterized samples to allow for backpropagation, with the risk of introducing high bias and high variance. In this work, we fall back to discrete $k$-subset sampling on the forward pass. This is coupled with using the gradient with respect to the exact marginals, computed efficiently, as a proxy for the true gradient. We show that our gradient estimator, SIMPLE, exhibits lower bias and variance compared to state-of-the-art estimators, including the straight-through Gumbel estimator when $k = 1$. Empirical results show improved performance on learning to explain and sparse linear regression. We provide an algorithm for computing the exact ELBO for the $k$-subset distribution, obtaining significantly lower loss compared to SOTA.
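The abstract mentions computing the exact ELBO for the $k$-subset distribution. The following is a hedged illustration of one ingredient, not the authors' algorithm. It assumes the distribution is $q(z) \propto \exp(\theta \cdot z)$ over binary vectors with exactly $k$ ones. Under that assumption the entropy has the closed form $H = \log Z - \theta \cdot \mu$, where $Z$ is the normalizer and $\mu$ the exact marginals. The KL term against a uniform prior over $k$-subsets (also an assumption) is then $\log \binom{n}{k} - H$. The helper names `esp`, `marginals`, `entropy`, and `kl_to_uniform` are hypothetical.

```python
import math


def esp(weights, k):
    """e_k(weights): sum over all k-subsets of the product of their weights."""
    if k < 0:
        return 0.0
    e = [1.0] + [0.0] * k
    for w in weights:
        # descend over j so each weight enters a given product at most once
        for j in range(k, 0, -1):
            e[j] += w * e[j - 1]
    return e[k]


def marginals(theta, k):
    """Exact inclusion probabilities mu_i = q(z_i = 1 | sum(z) = k)."""
    w = [math.exp(t) for t in theta]
    norm = esp(w, k)
    # mu_i = w_i * e_{k-1}(all weights except i) / e_k(all weights)
    return [w[i] * esp(w[:i] + w[i + 1:], k - 1) / norm for i in range(len(w))]


def entropy(theta, k):
    """H(q) = log Z - theta . mu for q(z) proportional to exp(theta . z), |z| = k."""
    log_norm = math.log(esp([math.exp(t) for t in theta], k))
    mu = marginals(theta, k)
    return log_norm - sum(t * m for t, m in zip(theta, mu))


def kl_to_uniform(theta, k):
    """KL(q || uniform prior over k-subsets) = log C(n, k) - H(q)."""
    return math.log(math.comb(len(theta), k)) - entropy(theta, k)


print(kl_to_uniform([0.0, 0.0, 0.0, 0.0], 2))  # uniform q, so the KL is 0
```

This sketch computes the marginals by leave-one-out recomputation for readability. The reconstruction term of an ELBO would still have to be estimated from samples.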
What problem does this paper attempt to address?
The paper addresses how to make k-subset sampling compatible with end-to-end, gradient-based learning. k-subset sampling (selecting exactly k of n items) is common in machine learning tasks such as sparse feature representation, parametric k-nearest neighbors, explaining model predictions (learning to explain), discrete variational autoencoders, and sparse regression. A discrete sample is not differentiable with respect to the parameters of its distribution. Existing gradient estimators therefore typically relax the reparameterized samples into continuous ones so that gradients can be backpropagated, which can introduce high bias and high variance.
To solve this problem, the authors propose a new gradient estimator, SIMPLE (Subset Implicit Likelihood Estimation). On the forward pass, SIMPLE draws exact, discrete k-subset samples. On the backward pass, it uses the gradient with respect to the exact conditional marginals as a proxy for the true gradient. These are the probabilities that each item is included given that exactly k items are selected, and they can be computed efficiently. This enables end-to-end learning with lower bias and variance than existing relaxation-based estimators.
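The estimator described above can be sketched in pure Python for small n. This is a minimal illustration under stated assumptions, not the authors' implementation. The k-subset distribution is assumed to be $p(z) \propto \exp(\theta \cdot z)$ restricted to $|z| = k$. The forward pass draws an exact sample by deciding items one at a time. The backward pass multiplies the upstream gradient $\partial L / \partial z$ (evaluated at the sample) by the Jacobian of the exact marginals, which for this exponential family is the covariance matrix $\mathrm{Cov}(z_i, z_j)$. The function names are hypothetical. The $O(n^3 k)$ pairwise computation is chosen for clarity, not speed. In an autodiff framework the same backward pass would typically come from the straight-through form $z + \mu - \mathrm{stopgrad}(\mu)$.

```python
import math
import random


def esp(weights, k):
    """e_k(weights): sum over all k-subsets of the product of their weights."""
    if k < 0:
        return 0.0
    e = [1.0] + [0.0] * k
    for w in weights:
        # descend over j so each weight enters a given product at most once
        for j in range(k, 0, -1):
            e[j] += w * e[j - 1]
    return e[k]


def sample_k_subset(theta, k, rng=random):
    """Forward pass: exact sample z ~ p(z | sum(z) = k), p(z) proportional to exp(theta . z)."""
    w = [math.exp(t) for t in theta]
    z, remaining = [], k
    for i in range(len(w)):
        take = False
        if remaining > 0:
            # P(z_i = 1 | earlier decisions, `remaining` items still to choose)
            p = w[i] * esp(w[i + 1:], remaining - 1) / esp(w[i:], remaining)
            take = rng.random() < p
        z.append(int(take))
        remaining -= int(take)
    return z


def simple_backward(theta, k, grad_z):
    """Backward pass (sketch): dL/dtheta ~= J^T dL/dz, where J is the Jacobian
    of the exact marginals mu; here J[i][j] = Cov(z_i, z_j), which is symmetric."""
    w = [math.exp(t) for t in theta]
    n, norm = len(w), esp(w, k)
    mu = [w[i] * esp(w[:i] + w[i + 1:], k - 1) / norm for i in range(n)]
    grad_theta = []
    for j in range(n):
        total = 0.0
        for i in range(n):
            if i == j:
                cov = mu[i] * (1.0 - mu[i])  # variance of a binary variable
            else:
                rest = [w[m] for m in range(n) if m not in (i, j)]
                joint = w[i] * w[j] * esp(rest, k - 2) / norm  # P(z_i = z_j = 1)
                cov = joint - mu[i] * mu[j]
            total += cov * grad_z[i]
        grad_theta.append(total)
    return grad_theta


# Usage with a hypothetical linear loss L(z) = c . z, so dL/dz = c.
theta = [0.5, -0.3, 1.2, 0.0]
c = [1.0, -2.0, 0.5, 3.0]
z = sample_k_subset(theta, 2, random.Random(0))  # hard 2-hot vector
grad_theta = simple_backward(theta, 2, c)
print(z, grad_theta)
```

For a linear loss, $\mathbb{E}[L] = c \cdot \mu$, so in that special case this backward pass equals the exact gradient of the expected loss. For nonlinear losses it is a proxy, which is where the bias discussed in the paper comes from.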
The main contributions of the paper include:
1. **Proposing a new gradient estimator**: SIMPLE uses discrete sampling during forward propagation and the gradient with respect to the exact conditional marginals during backpropagation. Because no relaxed samples are needed, the resulting estimator has lower bias and variance than relaxation-based alternatives.
2. **Providing an efficient algorithm**: The authors provide an efficient algorithm for computing k-subset probabilities and exact conditional marginal probabilities, which keeps the entire process differentiable and computationally efficient. They also give an algorithm for computing the exact ELBO of the k-subset distribution.
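3. **Experimental validation**: SIMPLE is evaluated on synthetic experiments, discrete variational autoencoders, learning to explain, and sparse linear regression. The results show lower gradient bias and variance than state-of-the-art estimators, improved performance on learning to explain and sparse linear regression, and significantly lower loss for the discrete VAE when the exact ELBO is used.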
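The efficient computation in contribution 2 can be illustrated with a standard forward-backward dynamic program over elementary symmetric polynomials. This is a sketch under the assumption $p(z) \propto \exp(\theta \cdot z)$ with $|z| = k$, not the authors' code. The names `prefix_tables` and `normalizer_and_marginals` are hypothetical. Prefix tables give the normalizer $Z$ in $O(nk)$. Combining prefix and suffix tables gives every marginal $\mu_i$ in $O(nk)$ overall.

```python
import math


def prefix_tables(w, k):
    """table[i][j] = e_j(w[:i]), the elementary symmetric polynomial of the
    first i weights, for 0 <= i <= len(w) and 0 <= j <= k."""
    table = [[1.0] + [0.0] * k]
    for wi in w:
        prev = table[-1]
        # e_j(w[:i+1]) = e_j(w[:i]) + w_i * e_{j-1}(w[:i])
        table.append([prev[0]] + [prev[j] + wi * prev[j - 1] for j in range(1, k + 1)])
    return table


def normalizer_and_marginals(theta, k):
    """Return Z = sum over k-subsets of exp(theta . z) and the exact marginals
    mu_i = p(z_i = 1 | sum(z) = k), both in O(n * k) time."""
    w = [math.exp(t) for t in theta]
    n = len(w)
    fwd = prefix_tables(w, k)              # fwd[i][j] = e_j(w[:i])
    bwd = prefix_tables(w[::-1], k)[::-1]  # bwd[i][j] = e_j(w[i:])
    norm = fwd[n][k]
    mu = []
    for i in range(n):
        # item i is selected: j items come from before it, k - 1 - j from after it
        s = sum(fwd[i][j] * bwd[i + 1][k - 1 - j] for j in range(k))
        mu.append(w[i] * s / norm)
    return norm, mu


norm, mu = normalizer_and_marginals([0.3, -1.2, 0.7, 0.0, 1.5], k=2)
print(norm, mu, sum(mu))  # the marginals of a k-subset distribution sum to k
```

For large n or extreme logits, the same recurrences are usually run in log space (log-sum-exp) to avoid overflow. Another common route is to obtain $\mu$ as the gradient of $\log Z$ with respect to $\theta$ via automatic differentiation. For k = 1 the marginals reduce to a softmax over $\theta$.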
Overall, the paper uses SIMPLE to reduce the bias and variance of gradient estimates for k-subset sampling in end-to-end learning, improving the performance and interpretability of models that rely on such sampling.