Training Data Attribution via Approximate Unrolled Differentiation

Juhan Bae, Wu Lin, Jonathan Lorraine, Roger Grosse
2024-05-21
Abstract: Many training data attribution (TDA) methods aim to estimate how a model's behavior would change if one or more data points were removed from the training set. Methods based on implicit differentiation, such as influence functions, can be made computationally efficient, but fail to account for underspecification, the implicit bias of the optimization algorithm, or multi-stage training pipelines. By contrast, methods based on unrolling address these issues but face scalability challenges. In this work, we connect the implicit-differentiation-based and unrolling-based approaches and combine their benefits by introducing Source, an approximate unrolling-based TDA method that is computed using an influence-function-like formula. While being computationally efficient compared to unrolling-based approaches, Source is suitable in cases where implicit-differentiation-based approaches struggle, such as in non-converged models and multi-stage training pipelines. Empirically, Source outperforms existing TDA techniques in counterfactual prediction, especially in settings where implicit-differentiation-based approaches fall short.
Machine Learning
What problem does this paper attempt to address?
The paper addresses how to efficiently estimate the effect of training data on model behavior, in particular how a model's behavior would change if certain data points were removed from the training set. The authors focus on the limitations of existing training data attribution (TDA) methods and propose a new method to overcome them.

### Problem Background

Existing TDA methods fall into two main categories:

1. **Implicit differentiation**:
   - Methods in this category, such as influence functions, can be computed efficiently, but they struggle with non-converged models and multi-stage training pipelines.
   - They assume that the model parameters converge to a unique optimal solution, which does not always hold for modern neural networks.
2. **Unrolled differentiation**:
   - Methods in this category, such as SGD-Influence, handle non-converged models and multi-stage training pipelines better, but they are computationally expensive: they must store all intermediate variables (e.g., the parameters at every training step), which is prohibitive for large-scale models.

### The Method Proposed in the Paper

To address these problems, the authors propose **Source**, which combines the advantages of implicit and unrolled differentiation:

- **Source** is a TDA method based on approximate unrolled differentiation, computed with a formula similar to the influence function.
- It reduces computational cost by dividing the training trajectory into multiple segments and assuming that the gradients and Hessians are stationary within each segment.
- **Source** only needs to save a few checkpoints rather than all intermediate variables, which makes it substantially cheaper to compute.

### Main Contributions

1. **Wider applicability**: Compared with implicit differentiation, **Source** can be applied to non-converged models and multi-stage training pipelines.
2. **Higher computational efficiency**: Compared with unrolled differentiation, **Source** only needs to save a small number of checkpoints, greatly reducing memory and compute costs.
3. **Experimental verification**: Experiments show that **Source** outperforms existing TDA methods in counterfactual prediction, especially in settings where implicit differentiation performs poorly.

### Formula Summary

- **Influence function (implicit differentiation)**:

  \[ \tau_{\text{IF}}(z_q, z_m, \mathcal{D}) := \nabla_\theta f(z_q, \theta^\star)^\top H^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star) \]

  where \(\theta^\star\) are the trained (assumed optimal) parameters, \(H\) is the Hessian of the training objective at \(\theta^\star\), \(f(z_q, \theta)\) is the measurement on the query example \(z_q\), and \(\mathcal{L}(z_m, \theta)\) is the training loss on the candidate example \(z_m\).

- **Core formula of the Source method**:

  \[ \mathbb{E}\left[\frac{d\theta_T}{d\epsilon}\right] \approx -\sum_{\ell=1}^{L} \left( \prod_{\ell'=L}^{\ell+1} \mathbb{E}[S_{\ell'}] \right) \mathbb{E}[r_\ell] \]

  where

  \[ \mathbb{E}[S_\ell] \approx \exp\left(-\bar{\eta}_\ell K_\ell \bar{H}_\ell\right) \]

  and

  \[ \mathbb{E}[r_\ell] \approx \frac{1}{N}\left(I - \exp\left(-\bar{\eta}_\ell K_\ell \bar{H}_\ell\right)\right) \bar{H}_\ell^{-1} \bar{g}_\ell \]

  Here \(\theta_T\) is the final parameter vector, \(\epsilon\) is the weight on the loss of \(z_m\), and training is split into \(L\) segments. Segment \(\ell\) has \(K_\ell\) steps with average learning rate \(\bar{\eta}_\ell\); \(\bar{H}_\ell\) and \(\bar{g}_\ell\) are the Hessian of the training objective and the gradient of the loss on \(z_m\), averaged over that segment; \(N\) is the training set size; \(\exp\) is the matrix exponential; and the product is ordered with \(\ell' = L\) on the left.

By combining these approximations, **Source** substantially reduces computational cost while remaining accurate in counterfactual prediction, and it applies to a wider range of settings than implicit differentiation.
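To make the influence-function formula concrete, here is a minimal pure-Python sketch on a toy problem: a two-parameter ridge regression \(\hat{y} = w x + b\) with squared loss, small enough that the Hessian can be inverted by hand. Everything in it is an illustrative assumption, not the paper's implementation: the helper names (`solve2`, `fit`, `influence`), the data, and `lam` are hypothetical; the measurement \(f\) is taken to be the model's prediction at a query input; and \(H\) is assumed to be the Hessian of the \(1/N\)-averaged objective, so the predicted change in \(f\) from removing \(z_m\) is \(\tau_{\text{IF}}/N\). The sketch compares that prediction with actually retraining without \(z_m\).

```python
# Hypothetical toy: y_hat = w * x + b, per-example loss 0.5 * (y_hat - y)^2,
# objective J(theta) = (1/N) * sum_i loss_i + (lam / 2) * ||theta||^2.

def solve2(A, v):
    """Solve the 2x2 linear system A u = v with Cramer's rule."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [(A[1][1] * v[0] - A[0][1] * v[1]) / det,
            (A[0][0] * v[1] - A[1][0] * v[0]) / det]

def features(x):
    # phi(x) = (x, 1): the gradient of y_hat with respect to (w, b).
    return [x, 1.0]

def hessian(xs, lam, n):
    """Hessian of J: (1/n) * sum phi phi^T + lam * I, with n the full set size N."""
    H = [[lam, 0.0], [0.0, lam]]
    for x in xs:
        p = features(x)
        for i in range(2):
            for j in range(2):
                H[i][j] += p[i] * p[j] / n
    return H

def fit(xs, ys, lam, n):
    """Exact minimizer of (1/n) * sum of losses over (xs, ys) plus the ridge term."""
    rhs = [0.0, 0.0]
    for x, y in zip(xs, ys):
        p = features(x)
        rhs[0] += p[0] * y / n
        rhs[1] += p[1] * y / n
    return solve2(hessian(xs, lam, n), rhs)

def predict(theta, x):
    return theta[0] * x + theta[1]

def influence(xs, ys, lam, m, x_query):
    """tau_IF = grad f(z_q)^T H^{-1} grad L(z_m), with f the prediction at x_query."""
    n = len(xs)
    theta = fit(xs, ys, lam, n)
    residual = predict(theta, xs[m]) - ys[m]
    grad_loss = [residual * p for p in features(xs[m])]  # gradient of 0.5 * r^2
    h_inv_g = solve2(hessian(xs, lam, n), grad_loss)
    return sum(a * b for a, b in zip(features(x_query), h_inv_g))

xs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
ys = [1.0, 3.0, 5.0, 8.0, 9.0, 11.0, 13.0, 15.0]
lam, m, x_q = 0.01, 3, 10.0
n = len(xs)

# First-order prediction of the change in f after removing z_m.
predicted = influence(xs, ys, lam, m, x_q) / n

# Ground truth: retrain without z_m, keeping the 1/N normalization
# (the epsilon = -1/N objective that the first-order expansion targets).
keep = [i for i in range(n) if i != m]
theta_full = fit(xs, ys, lam, n)
theta_drop = fit([xs[i] for i in keep], [ys[i] for i in keep], lam, n)
actual = predict(theta_drop, x_q) - predict(theta_full, x_q)
print(predicted, actual)
```

Because this toy objective is quadratic and \(f\) is linear in the parameters, the prediction and the retrained difference differ only by the factor \(1 - h\), where \(h = \phi_m^\top H^{-1} \phi_m / N\) is the leverage of \(z_m\). In deep networks the gap also comes from non-convexity and non-convergence, which is where the paper argues implicit differentiation falls short.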
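The Source formula can be sanity-checked in a toy setting where its stationarity assumption holds exactly: within each segment the learning rate, the Hessian, and the gradient of the loss on \(z_m\) are constant, and all Hessians are diagonal (i.e., they share an eigenbasis), so the matrix exponential acts coordinate-wise. This is a sketch under stated assumptions, not the paper's implementation. It assumes the derivative follows the linearized recursion \(v \leftarrow (I - \eta H) v - \eta g / N\) at every step, which is the scaling implied by the \(1/N\) factor in \(\mathbb{E}[r_\ell]\). The function names and segment values are hypothetical. `unrolled_derivative` iterates over every step, while `source_derivative` uses only per-segment quantities, which is why Source needs only a few checkpoints.

```python
import math

# Each segment stores only averaged quantities (hypothetical toy values):
#   eta   - average learning rate, steps - number of steps K,
#   h     - diagonal of the averaged Hessian, g - averaged gradient of z_m's loss.

def segment_factor(seg, j):
    """exp(-eta * K * h_j): the Source approximation of E[S_ell] in coordinate j."""
    return math.exp(-seg["eta"] * seg["steps"] * seg["h"][j])

def source_derivative(segments, n):
    """Closed-form Source approximation of E[d theta_T / d eps] (diagonal Hessians)."""
    dim = len(segments[0]["h"])
    total = [0.0] * dim
    for ell, seg in enumerate(segments):
        eta, k = seg["eta"], seg["steps"]
        for j in range(dim):
            h, g = seg["h"][j], seg["g"][j]
            if h == 0.0:
                # Limit of (1 - exp(-eta K h)) / h as h -> 0 is eta * K.
                r = eta * k * g / n
            else:
                # E[r_ell] = (1/N) (1 - exp(-eta K h)) h^{-1} g in coordinate j.
                r = (1.0 - segment_factor(seg, j)) * g / (h * n)
            # Propagate through all later segments (the ordered product of E[S_ell']).
            for later in segments[ell + 1:]:
                r *= segment_factor(later, j)
            total[j] -= r
    return total

def unrolled_derivative(segments, n):
    """Step-by-step unrolling of the assumed recursion v <- (1 - eta h) v - eta g / N."""
    dim = len(segments[0]["h"])
    v = [0.0] * dim
    for seg in segments:
        for _ in range(seg["steps"]):
            for j in range(dim):
                v[j] = (1.0 - seg["eta"] * seg["h"][j]) * v[j] - seg["eta"] * seg["g"][j] / n
    return v

# Two segments, e.g. pretraining followed by fine-tuning with a smaller learning rate.
segments = [
    {"eta": 0.1, "steps": 50, "h": [1.0, 0.2], "g": [0.5, -1.0]},
    {"eta": 0.01, "steps": 200, "h": [2.0, 0.05], "g": [0.3, 0.4]},
]
n = 100
approx = source_derivative(segments, n)
exact = unrolled_derivative(segments, n)
print(approx, exact)

# A single long segment recovers the influence-function-like limit -h^{-1} g / N.
long_run = source_derivative([{"eta": 0.1, "steps": 1000, "h": [1.0], "g": [2.0]}], 10)
print(long_run)
```

The only discrepancy between the two results here comes from replacing \((I - \eta H)^K\) with \(\exp(-\eta K H)\). Multiplying either result by \(\nabla_\theta f(z_q, \theta_T)\) gives a first-order estimate of how the measurement responds to \(\epsilon\). The last call illustrates the link to influence functions: with one long segment, \(\exp(-\bar{\eta} K \bar{H}) \to 0\) and the formula reduces to \(-\bar{H}^{-1}\bar{g}/N\).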