BrainNPT: Pre-training of Transformer networks for brain network classification

Jinlong Hu, Yangmin Huang, Nan Wang, Shoubin Dong
2023-08-02
Abstract: Deep learning methods have advanced quickly in brain imaging analysis over the past few years, but they are usually restricted by limited labeled data. Pre-training models on unlabeled data has yielded promising improvements in feature learning in many domains, including natural language processing and computer vision. However, this technique is under-explored in brain network analysis. In this paper, we focused on pre-training methods with Transformer networks to leverage existing unlabeled data for brain functional network classification. First, we proposed a Transformer-based neural network, named BrainNPT, for brain functional network classification. The proposed method used a <cls> token as a classification embedding vector so that the Transformer model can effectively capture the representation of the brain network. Second, we proposed a pre-training framework for the BrainNPT model that leverages unlabeled brain network data to learn the structural information of brain networks. The classification experiments showed that the BrainNPT model without pre-training achieved performance comparable to the best state-of-the-art models, and that the BrainNPT model with pre-training strongly outperformed the state-of-the-art models. Pre-training improved the accuracy of the BrainNPT model by 8.75% compared with the model without pre-training. We further compared the pre-training strategies, analyzed the influence of the model's parameters, and interpreted the trained model.
Neurons and Cognition, Machine Learning, Neural and Evolutionary Computing
What problem does this paper attempt to address?
### Main Problems Addressed by the Paper

This paper primarily addresses the issue of limited labeled data constraining the performance of deep learning methods in brain functional network classification tasks. Specifically, the authors propose a new Transformer-based model, BrainNPT, and a pre-training framework to leverage unlabeled brain functional network data.

### Summary of Main Contributions

1. **Proposed BrainNPT Model**: This is a Transformer-based neural network specifically designed for brain functional network classification. The model introduces a learnable `<cls>` token as a classification embedding vector to capture the representation information of the entire brain network. Additionally, the model can be interpreted through Layer-wise Relevance Propagation (LRP).
2. **Proposed Pre-training Framework**: This framework utilizes a large amount of unlabeled brain functional network data for pre-training, adopting an effective pre-training strategy, Replaced Region Prediction (RRP). This strategy enables the model to capture the intrinsic structural information of brain functional networks without labels, significantly improving the model's accuracy.
3. **Data Augmentation and Pre-training**: To further expand the pre-training dataset, the authors use a sliding window technique to extract multiple functional connectivity matrices from Resting-State Functional Magnetic Resonance Imaging (rs-fMRI) data and pre-train the model with these data.

### Method Overview

- **BrainNPT Model Architecture**: The model is based on the Transformer structure, including the `<cls>` token, Transformer blocks, and Multilayer Perceptron (MLP) modules. The `<cls>` token is treated as a virtual brain region to gather information from the entire brain network.
- **Pre-training Framework**: A pre-training task is constructed using a random replaced region strategy. The model is trained through a binary classification task to identify whether a brain region has been replaced, allowing the model to learn the structural features of brain networks.

### Experimental Results and Discussion

- **Experimental Datasets**: The study uses four public datasets for experiments: ABIDE II, REST-meta-MDD, HCP, and ABIDE I. The first two are used for downstream classification tasks, while the latter two are used for pre-training.
- **Comparison with Other Models**: The BrainNPT model performs excellently even without pre-training, achieving the best or near-best performance on two datasets. Compared to other types of models, the Transformer-based approach demonstrates strong brain functional network classification capabilities.
- **Effect of Pre-training**: With pre-training, the accuracy of the BrainNPT model improves by 8.75% compared to the non-pre-trained version, indicating that the pre-training strategy effectively enhances model performance.

### Conclusion

This paper proposes a new Transformer-based brain functional network classification model, BrainNPT, and its pre-training framework, aiming to address the issue of data scarcity limiting the performance of deep learning methods. Experimental results demonstrate the effectiveness and superiority of this approach.
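
The sliding-window augmentation mentioned in the summary can be illustrated with a short sketch. It is not the authors' code: the function name `sliding_window_fc`, the window and stride values, and the use of Pearson correlation as the connectivity measure are all assumptions made for illustration, since the summary does not specify them.

```python
import numpy as np

def sliding_window_fc(timeseries, window, stride):
    """Return one connectivity matrix per sliding window.

    timeseries: array of shape (T, R), T time points for R brain regions.
    Assumption: connectivity is measured with Pearson correlation.
    """
    n_timepoints = timeseries.shape[0]
    matrices = []
    for start in range(0, n_timepoints - window + 1, stride):
        segment = timeseries[start:start + window]
        # np.corrcoef treats rows as variables, so transpose to (R, window).
        matrices.append(np.corrcoef(segment.T))
    return np.stack(matrices)

# Toy data standing in for one rs-fMRI scan: 200 time points, 10 regions.
rng = np.random.default_rng(0)
ts = rng.standard_normal((200, 10))
fc = sliding_window_fc(ts, window=100, stride=50)
print(fc.shape)  # (3, 10, 10): windows start at time points 0, 50, and 100
```

A single scan therefore yields several connectivity matrices instead of one, which is how windowing can enlarge a pre-training set.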
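
The replaced-region pre-training task can be sketched in the same spirit. The summary only states that regions are randomly replaced and that the model predicts, per region, whether a replacement happened. The rest is assumed here: each region is represented by its row of the connectivity matrix, the replacement is drawn from a different region of the same matrix, and `make_rrp_sample` and `replace_prob` are hypothetical names.

```python
import numpy as np

def make_rrp_sample(fc, replace_prob, rng):
    """Corrupt a connectivity matrix for a replaced-region prediction task.

    fc: (R, R) connectivity matrix; row i is taken as region i's feature vector.
    Returns the corrupted matrix and a 0/1 label per region (1 = replaced).
    """
    n_regions = fc.shape[0]
    corrupted = fc.copy()
    labels = np.zeros(n_regions, dtype=np.int64)
    for i in range(n_regions):
        if rng.random() < replace_prob:
            # Assumption: the replacement row comes from another region
            # of the same matrix.
            j = int(rng.integers(n_regions - 1))
            if j >= i:
                j += 1  # shift past i so that j != i
            corrupted[i] = fc[j]
            labels[i] = 1
    return corrupted, labels

rng = np.random.default_rng(0)
fc = np.corrcoef(rng.standard_normal((8, 50)))  # toy 8-region matrix
corrupted, labels = make_rrp_sample(fc, replace_prob=0.25, rng=rng)
print(labels)  # binary targets for the per-region classifier
```

The pairs `(corrupted, labels)` need no diagnostic annotations, which is what lets a task of this kind use unlabeled scans.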
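
The role of the `<cls>` token as a "virtual brain region" can also be shown with a single self-attention step. This is a deliberately reduced sketch, not the BrainNPT architecture: it omits residual connections, layer normalization, multiple heads, and the MLP modules, and random weights stand in for learned parameters. The function name `cls_readout` and all dimensions are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cls_readout(fc, cls_token, w_q, w_k, w_v):
    """One single-head self-attention step with a prepended <cls> token."""
    # Prepend the <cls> vector as an extra "virtual region" at position 0.
    tokens = np.vstack([cls_token[None, :], fc])         # (R + 1, R)
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v   # each (R + 1, d)
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))       # (R + 1, R + 1)
    out = attn @ v
    # Position 0 mixes information from every region into one vector.
    return out[0], attn[0]

rng = np.random.default_rng(0)
n_regions, d = 8, 4
fc = np.corrcoef(rng.standard_normal((n_regions, 50)))   # toy matrix
cls_token = rng.standard_normal(n_regions)               # learnable in practice
w_q, w_k, w_v = (rng.standard_normal((n_regions, d)) for _ in range(3))
embedding, cls_attention = cls_readout(fc, cls_token, w_q, w_k, w_v)
print(embedding.shape, cls_attention.shape)  # (4,) (9,)
```

In a full model the vector at position 0 would be passed to a classification head. Here it only shows that the `<cls>` position attends over all regions at once.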