Adaptive Nearest Neighbor: A General Framework for Distance Metric Learning

Kun Song
DOI: https://doi.org/10.48550/arXiv.1911.10674
2019-11-22
Abstract: The $K$-NN classifier is one of the most famous classification algorithms, and its performance is crucially dependent on the distance metric. When we consider the distance metric as a parameter of $K$-NN, learning an appropriate distance metric for $K$-NN can be seen as minimizing the empirical risk of $K$-NN. In this paper, we design a new type of continuous decision function of the $K$-NN classification rule which can be used to construct the continuous empirical risk function of $K$-NN. By minimizing this continuous empirical risk function, we obtain a novel distance metric learning algorithm named adaptive nearest neighbor (ANN). We prove that current algorithms such as large margin nearest neighbor (LMNN), neighbourhood components analysis (NCA), and the pairwise constraint methods are special cases of the proposed ANN under different parameter settings. Compared with LMNN, NCA, and the pairwise constraint methods, our method has a broader search space which may contain better solutions. Finally, extensive experiments on various data sets are conducted to demonstrate the effectiveness and efficiency of the proposed method.
Machine Learning
What problem does this paper attempt to address?
The paper asks how to learn a distance metric that improves the performance of the K-Nearest Neighbor (K-NN) classifier. Specifically, the author proposes a new continuous decision function, uses it to construct a continuous empirical risk function for the K-NN classification rule, and minimizes that risk to obtain a new distance metric learning algorithm called Adaptive Nearest Neighbor (ANN).

### Main contributions of the paper:

1. **Theoretical connection**: Several existing distance metric learning algorithms, namely Large Margin Nearest Neighbor (LMNN), Neighbourhood Components Analysis (NCA), and the methods based on pairwise constraints, are proved to be special cases of ANN under specific parameter settings. This links the convex models (LMNN and the pairwise methods) to the non-convex model NCA.
2. **Wider search space**: Because the objective function of ANN can be regarded as the empirical risk of the K-NN classification rule, the search space of ANN is broader than that of LMNN, NCA, and the pairwise constraint methods, so it may contain better solutions.
3. **Higher computational efficiency**: Computing the gradient of the ANN objective function involves \( N(N - 1) \) distance calculations, the same as the pairwise-constraint metric learning algorithms. Therefore, ANN runs faster than LMNN.
4. **Experimental verification**: The effectiveness and efficiency of the proposed method are verified through extensive classification experiments on multiple datasets.

### Key problems solved:

- **Discontinuity problem**: The decision function of the traditional K-NN classifier is discontinuous, which makes it difficult to minimize the empirical risk directly. The paper solves this problem by designing a continuous decision function.
- **Strict constraint problem**: Existing triplet-based methods (such as LMNN) impose constraints that are stricter than the K-NN classification rule requires, which limits the search space. ANN relaxes these constraints by using a continuous decision function.
- **Prior information dependence problem**: Many triplet-based methods need prior information to construct the triplets, while ANN does not, which reduces its dependence on prior information.

### Summary:

By designing a new continuous decision function, the paper proposes the Adaptive Nearest Neighbor (ANN) algorithm, which addresses the discontinuity and strict-constraint problems of K-NN in distance metric learning, improves classification performance, and runs faster than LMNN.
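
The discontinuity problem can be illustrated with a small sketch. The code below is not the paper's ANN objective, which this summary does not spell out. It assumes a linear map `L` as the metric parameter and swaps in an NCA-style softmax over neighbors as the continuous surrogate (NCA being one of the special cases named above). The function names `pairwise_sq_dists`, `knn_error`, and `soft_error`, as well as the toy data, are hypothetical and chosen only for illustration.

```python
import numpy as np

def pairwise_sq_dists(X, L):
    """Squared distances ||L x_i - L x_j||^2 for all pairs, shape (N, N)."""
    Z = X @ L.T                              # project every sample with L
    diff = Z[:, None, :] - Z[None, :, :]     # (N, N, d) pairwise differences
    return np.sum(diff ** 2, axis=-1)

def knn_error(X, y, L, k=1):
    """Leave-one-out 0-1 empirical risk of K-NN under the metric induced by L.

    Piecewise constant in L, so it gives no useful gradient.
    Assumes y holds non-negative integer class labels.
    """
    D = pairwise_sq_dists(X, L)
    np.fill_diagonal(D, np.inf)              # a sample is never its own neighbor
    nn = np.argsort(D, axis=1)[:, :k]        # indices of the k nearest neighbors
    pred = np.array([np.bincount(v).argmax() for v in y[nn]])  # majority vote
    return float(np.mean(pred != y))

def soft_error(X, y, L):
    """NCA-style continuous surrogate (an assumption, not the paper's function).

    Each sample picks a neighbor with probability softmax(-distance); the risk
    is one minus the mean probability of picking a same-class neighbor.
    """
    D = pairwise_sq_dists(X, L)
    np.fill_diagonal(D, np.inf)              # leaves the N(N - 1) ordered pairs
    logits = -D
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    P = np.exp(logits)
    P = P / P.sum(axis=1, keepdims=True)
    same = y[:, None] == y[None, :]
    return float(1.0 - np.mean(np.sum(P * same, axis=1)))

# Toy data: feature 0 separates the classes, feature 1 is large-scale noise.
X = np.array([[0.0, 0.0], [0.1, 10.0], [1.0, 0.5], [1.1, 10.5]])
y = np.array([0, 0, 1, 1])

L_identity = np.eye(2)                       # plain Euclidean distance
L_learned = np.diag([1.0, 0.0])              # hand-picked map that drops feature 1

for name, L in [("identity", L_identity), ("drop noisy feature", L_learned)]:
    print(name, knn_error(X, y, L, k=1), soft_error(X, y, L))
```

On this toy data the 0-1 risk only jumps when a neighbor ordering changes, whereas the surrogate varies smoothly with `L` and could therefore be minimized by gradient descent. Both functions touch each of the \( N(N - 1) \) ordered sample pairs once, which matches the distance count quoted for pairwise methods.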