Mahalanobis-Aware Training for Out-of-Distribution Detection

Connor Mclaughlin, Jason Matterer, Michael Yee
2023-11-02
Abstract: While deep learning models have seen widespread success in controlled environments, there are still barriers to their adoption in open-world settings. One critical task for safe deployment is the detection of anomalous or out-of-distribution samples that may require human intervention. In this work, we present a novel loss function and recipe for training networks with improved density-based out-of-distribution sensitivity. We demonstrate the effectiveness of our method on CIFAR-10, notably reducing the false-positive rate of the relative Mahalanobis distance method on far-OOD tasks by over 50%.
Machine Learning
What problem does this paper attempt to address?
This paper addresses a key problem in deploying deep-learning models in open-world environments: how to reliably detect anomalous or out-of-distribution (OOD) samples. Specifically, the authors propose a new loss function and training recipe that improve the sensitivity of density-based OOD detection. Their goal is to reduce the false-positive rate on data far from the training distribution, and thereby improve the safety and reliability of the model.

### Background of the Paper

Deep-learning models are widely used in controlled environments, but deploying them in the open world remains challenging. One important task is to detect promptly when the model encounters anomalous or out-of-distribution samples that may require human intervention. Existing OOD detection methods fall roughly into two categories. Output-based methods assess confidence from the logits or probabilities predicted by the model. Representation-based methods measure how similar an intermediate-layer representation is to the data seen during training. This paper focuses on the latter, especially methods that assume the data has a Gaussian structure.

### Research Motivation

Some studies that assume the data follows a Gaussian distribution have achieved empirical success, but the theoretical basis for that assumption remains weak. Other studies adopt more complex non-parametric methods that avoid making any assumptions about the data. The authors ask whether Gaussian-based OOD detection can be improved by explicitly training the model to produce Gaussian-like data representations.

### Main Contributions

1. **Propose a new regularization loss**: This loss better aligns the training objective with the OOD detector used at test time.
It computes predicted probabilities for samples from online estimates of the Gaussian parameters and minimizes the resulting cross-entropy loss.
2. **Provide a training scheme**: This scheme includes noise-reduction techniques that allow the method to be run with limited computational resources.

### Method Overview

- **OOD Detection Task**: Learn a scoring function \( S(x) \) that captures the similarity between a test input and the training distribution. For density-estimation-based methods, this scoring function resembles the likelihood under a probabilistic model of the training data, here a Gaussian.
- **Mahalanobis Distance Method**: Assume that the network's latent representations follow a class-conditional Gaussian distribution, and use the Mahalanobis distance to the nearest class mean as the scoring function.
- **Proposed Training Objective**: Compute the predicted probabilities from online estimates of the Gaussian parameters via Bayes' rule, and minimize the cross-entropy loss. The final loss is a weighted combination of a base cross-entropy loss \( L_{\text{base}} \) and a Mahalanobis-distance-based cross-entropy loss \( L_{\text{maha}} \):

  \[ L_{\text{reg}} = (1 - \alpha) L_{\text{base}} + \alpha L_{\text{maha}} \]

  where \( \alpha \) is a hyperparameter that controls the balance between the two losses.
- **Gaussian Parameter Estimation**: To reduce the noise introduced by mini-batch estimation, use a shrinkage estimator for the covariance matrix and maintain an exponential moving average (EMA) of the mean and covariance.

### Experimental Results

The authors used CIFAR-10 as the in-distribution dataset and conducted multiple experiments.
The results show that their method outperforms existing baselines on both far-OOD and near-OOD benchmarks. The gain is largest on far-OOD tasks, where it reduces the false-positive rate of the relative Mahalanobis distance method by more than 50%.

### Conclusion

This paper proposes a training method based on Mahalanobis-distance regularization that significantly improves OOD detection performance while maintaining performance on in-distribution data. Future work will explore how the method scales to large-scale datasets.
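
To make the Mahalanobis distance method from the overview concrete, the following is a minimal NumPy sketch rather than the authors' implementation. It assumes a single covariance matrix shared across classes, which the summary does not specify, and it uses synthetic 2-D features in place of network representations. The function names `fit_class_gaussians` and `mahalanobis_ood_score` are hypothetical, and the relative Mahalanobis variant is not shown.

```python
import numpy as np

def fit_class_gaussians(features, labels, num_classes):
    """Estimate per-class means and one shared covariance (an assumption)."""
    dim = features.shape[1]
    means = np.zeros((num_classes, dim))
    cov = np.zeros((dim, dim))
    for c in range(num_classes):
        class_feats = features[labels == c]
        means[c] = class_feats.mean(axis=0)
        centered = class_feats - means[c]
        cov += centered.T @ centered
    cov /= len(features)
    return means, cov

def mahalanobis_ood_score(features, means, cov):
    """Squared Mahalanobis distance to the nearest class mean.

    Larger values mean the input looks less like the training data.
    """
    precision = np.linalg.pinv(cov)
    diffs = features[:, None, :] - means[None, :, :]  # (n, classes, dim)
    sq_dists = np.einsum("ncd,de,nce->nc", diffs, precision, diffs)
    return sq_dists.min(axis=1)

# Synthetic stand-in for latent features: three Gaussian classes in 2-D.
rng = np.random.default_rng(0)
class_centers = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
train_labels = rng.integers(0, 3, size=600)
train_feats = class_centers[train_labels] + rng.normal(size=(600, 2))
means, cov = fit_class_gaussians(train_feats, train_labels, num_classes=3)

# In-distribution test points versus points placed far from every class.
in_dist = class_centers[rng.integers(0, 3, size=100)] + rng.normal(size=(100, 2))
far_ood = np.array([20.0, 20.0]) + rng.normal(size=(100, 2))
in_scores = mahalanobis_ood_score(in_dist, means, cov)
ood_scores = mahalanobis_ood_score(far_ood, means, cov)
print("mean in-distribution score:", in_scores.mean())
print("mean far-OOD score:", ood_scores.mean())
```

In this toy setting the far-OOD points receive much larger scores than the in-distribution points, so thresholding the score separates them. Real latent features are high-dimensional and only approximately Gaussian, which is the gap the paper's training objective aims to narrow.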
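
The combined loss \( L_{\text{reg}} \) and the noise-reduction steps from the overview can be sketched in the same spirit. Several details here are assumptions that the summary does not confirm: the shrinkage target (a scaled identity with fixed intensity `gamma`), the EMA momentum, a shared covariance, and uniform class priors, under which Bayes' rule reduces to a softmax over \( -\tfrac{1}{2} \) times the squared Mahalanobis distances. All function names are hypothetical. The sketch only evaluates the loss for one batch; actual training would need an automatic-differentiation framework to backpropagate through the features.

```python
import numpy as np

def shrinkage_covariance(sample_cov, gamma=0.1):
    """Shrink toward a scaled identity; target and gamma are assumptions."""
    dim = sample_cov.shape[0]
    target = (np.trace(sample_cov) / dim) * np.eye(dim)
    return (1.0 - gamma) * sample_cov + gamma * target

def ema_update(running, batch_estimate, momentum=0.9):
    """Exponential moving average of a running statistic."""
    return momentum * running + (1.0 - momentum) * batch_estimate

def cross_entropy(logits, labels):
    """Mean cross-entropy from unnormalized logits (stable log-softmax)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def mahalanobis_logits(features, means, cov):
    """Class log-likelihoods up to a constant, given a shared covariance."""
    precision = np.linalg.inv(cov)
    diffs = features[:, None, :] - means[None, :, :]
    sq_dists = np.einsum("ncd,de,nce->nc", diffs, precision, diffs)
    return -0.5 * sq_dists

def regularized_loss(base_logits, features, labels, means, cov, alpha=0.5):
    """L_reg = (1 - alpha) * L_base + alpha * L_maha."""
    l_base = cross_entropy(base_logits, labels)
    l_maha = cross_entropy(mahalanobis_logits(features, means, cov), labels)
    return (1.0 - alpha) * l_base + alpha * l_maha, l_base, l_maha

# One illustrative step on synthetic data (stand-ins for network outputs).
rng = np.random.default_rng(0)
num_classes, dim, batch = 3, 4, 32
labels = np.arange(batch) % num_classes  # every class appears in the batch
features = rng.normal(size=(batch, dim)) + 3.0 * np.eye(num_classes, dim)[labels]
base_logits = rng.normal(size=(batch, num_classes))

# Running Gaussian parameters, initialized arbitrarily for the sketch.
running_means = np.zeros((num_classes, dim))
running_cov = np.eye(dim)

# Mini-batch estimates, then shrinkage and EMA to damp batch noise.
batch_means = np.stack([features[labels == c].mean(axis=0) for c in range(num_classes)])
centered = features - batch_means[labels]
batch_cov = shrinkage_covariance(centered.T @ centered / batch)
running_means = ema_update(running_means, batch_means)
running_cov = ema_update(running_cov, batch_cov)

loss, l_base, l_maha = regularized_loss(
    base_logits, features, labels, running_means, running_cov, alpha=0.5
)
print("L_base:", l_base, "L_maha:", l_maha, "L_reg:", loss)
```

Setting `alpha` to 0 recovers the base cross-entropy loss alone, and setting it to 1 trains purely on the Mahalanobis-based term, which matches the role of \( \alpha \) as a balance between the two losses.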