Learning feature relationships in CNN model via relational embedding convolution layer
Shengzhou Xiong, Yihua Tan, Guoyou Wang, Pei Yan, Xuanyu Xiang
DOI: https://doi.org/10.1016/j.neunet.2024.106510
Abstract: Establishing the relationships among hierarchical visual attributes of objects in the visual world is crucial for human cognition. The classic convolutional neural network (CNN) can successfully extract hierarchical features but ignores the relationships among them, which leaves it behind humans in areas such as interpretability and domain generalization. Recent algorithms introduce feature relationships through external prior knowledge or special auxiliary modules, and these have been shown to bring multiple improvements in many computer vision tasks. However, prior knowledge is often difficult to obtain, and auxiliary modules consume additional computing and storage resources, which limits the flexibility and practicality of such algorithms. In this paper, we aim to drive the CNN model to learn the relationships among hierarchical deep features without prior knowledge or additional resource consumption, while also improving some aspects of its fundamental performance. First, we define the task of learning the relationships among hierarchical features in a CNN and identify three key problems related to this task: the quantitative metric of connection intensity, the threshold for useless connections, and the updating strategy of the relation graph. Second, we propose the Relational Embedding Convolution (RE-Conv) layer to represent feature relationships within a convolution layer, followed by a scheme called the use & disuse strategy, which addresses the three problems of feature relation learning. Finally, numerous experiments demonstrate the improvements brought by the proposed feature relation learning scheme in interpretability, domain generalization, noise robustness, and inference efficiency. In particular, the proposed scheme outperforms many state-of-the-art methods in the domain generalization community and can be seamlessly integrated with existing methods for further improvement. Meanwhile, it maintains precision comparable to that of the original CNN model while reducing floating point operations (FLOPs) by approximately 50%.
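The abstract names three sub-problems (a connection-intensity metric, a threshold for useless connections, and a relation-graph update rule) but does not define them. Purely as an illustration of how such pieces could fit together, the sketch below treats the relation graph as a binary mask over input-to-output channel connections of one convolution layer. Every concrete choice here is an assumption, not the paper's RE-Conv or use & disuse definition: intensity is taken as the mean absolute weight of the k x k kernel linking input channel i to output channel o, the threshold is a fixed fraction (`ratio`, hypothetical) of the mean intensity of currently active connections, and the update only prunes. The helper names `connection_intensity`, `update_relation_graph`, and `conv_macs` are hypothetical. The last helper shows why pruning connections reduces compute: if pruned connections are skipped, each active (o, i) pair costs k*k multiply-accumulates per output pixel, so cost scales with the number of active connections.

```python
# Hypothetical sketch: a convolution layer whose input->output channel
# connections are gated by a binary relation graph, mask[o][i] in {0, 1}.
# The intensity metric, threshold rule, and update rule are illustrative
# assumptions, not the definitions used in the paper.


def connection_intensity(weights):
    """Assumed metric: mean |w| of the k x k kernel linking input i to output o.

    weights[o][i] is a k x k nested list; returns intensity[o][i].
    """
    intensity = []
    for kernels_o in weights:
        row = []
        for kernel in kernels_o:
            vals = [abs(w) for kernel_row in kernel for w in kernel_row]
            row.append(sum(vals) / len(vals))
        intensity.append(row)
    return intensity


def update_relation_graph(mask, intensity, ratio=0.5):
    """Assumed rule: disable active connections whose intensity is below
    ratio * (mean intensity of the currently active connections).

    This sketch only prunes; a disabled connection is never re-enabled.
    """
    active = [intensity[o][i]
              for o in range(len(mask))
              for i in range(len(mask[o]))
              if mask[o][i]]
    if not active:
        return mask
    threshold = ratio * sum(active) / len(active)
    return [[1 if mask[o][i] and intensity[o][i] >= threshold else 0
             for i in range(len(mask[o]))]
            for o in range(len(mask))]


def conv_macs(mask, k, h_out, w_out):
    """Multiply-accumulates of the masked convolution, assuming pruned
    connections are skipped: k*k MACs per active (o, i) pair per output pixel."""
    active = sum(sum(row) for row in mask)
    return active * k * k * h_out * w_out


if __name__ == "__main__":
    k = 3
    # Toy layer: 2 output x 2 input channels, each kernel filled with one
    # hand-picked magnitude (not data from the paper).
    scales = [[1.0, 0.9], [0.8, 0.02]]
    weights = [[[[s] * k for _ in range(k)] for s in row] for row in scales]
    full_mask = [[1, 1], [1, 1]]
    pruned = update_relation_graph(full_mask, connection_intensity(weights))
    print(pruned)  # [[1, 1], [1, 0]]
    print(conv_macs(pruned, k, 8, 8) / conv_macs(full_mask, k, 8, 8))  # 0.75
```

In this toy layer only the weak 0.02 connection falls below the threshold, so one of four connections is removed and the layer's cost drops to 75% of the dense version. The figure is a property of the hand-picked weights, not a reproduction of the roughly 50% FLOPs reduction reported in the abstract. Real savings also depend on the implementation actually skipping masked connections rather than multiplying by zero.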