Cross-Global Attention Graph Kernel Network Prediction of Drug Prescription

Hao-Ren Yao, Der-Chen Chang, Ophir Frieder, Wendy Huang, I-Chia Liang, Chi-Feng Hung
DOI: https://doi.org/10.1145/3388440.3412459
2020-08-05
Abstract: We present an end-to-end, interpretable, deep-learning architecture to learn a graph kernel that predicts the outcome of chronic disease drug prescription. This is achieved through deep metric learning combined with a Support Vector Machine objective, using a graphical representation of Electronic Health Records. We formulate the predictive model as a binary graph classification problem with an adaptive learned graph kernel through novel cross-global attention node matching between patient graphs, simultaneously computing on multiple graphs without training pair or triplet generation. Results using the Taiwanese National Health Insurance Research Database demonstrate that our approach outperforms current state-of-the-art models both in terms of accuracy and interpretability.
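The abstract's claim of "simultaneously computing on multiple graphs without training pair or triplet generation" can be illustrated with a small NumPy sketch. It assumes each patient graph has already been reduced to a fixed-length embedding. It then builds, in one matrix product, the cosine-similarity kernel implied by the cosine distance \(\text{Dist}_C\) in the formula summary (kernel = 1 - Dist_C) for a whole batch, and scores that kernel with a standard soft-margin kernel SVM objective. The function names, the primal SVM form, and the {-1, +1} label encoding are illustrative assumptions, not the authors' implementation or their exact loss.

```python
import numpy as np


def cosine_kernel_matrix(G: np.ndarray) -> np.ndarray:
    """Cosine-similarity kernel for a batch of graph embeddings.

    Row i of G is one graph embedding. Entry (i, j) equals 1 - Dist_C(G_i, G_j),
    so every pair in the batch is scored by a single matrix product, with no
    explicit pair or triplet sampling.
    """
    norms = np.linalg.norm(G, axis=1, keepdims=True)
    G_unit = G / np.clip(norms, 1e-12, None)  # guard against zero-length embeddings
    return G_unit @ G_unit.T


def kernel_svm_objective(K: np.ndarray, y: np.ndarray, alpha: np.ndarray,
                         b: float, C: float = 1.0) -> float:
    """Primal soft-margin kernel SVM objective (assumed form, not the paper's exact loss).

    Decision values are f = K @ alpha + b (representer theorem). The objective is
    0.5 * alpha^T K alpha + C * sum_i max(0, 1 - y_i * f_i), with y_i in {-1, +1}.
    """
    f = K @ alpha + b
    hinge = np.maximum(0.0, 1.0 - y * f)  # per-graph margin violations
    return float(0.5 * alpha @ K @ alpha + C * hinge.sum())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    G = rng.normal(size=(6, 4))            # random stand-ins for learned graph embeddings
    y = np.array([1, -1, 1, 1, -1, -1])    # hypothetical labels: success (+1) / failure (-1)
    K = cosine_kernel_matrix(G)
    print(np.allclose(np.diag(K), 1.0))    # True: each graph is maximally similar to itself
    print(kernel_svm_objective(K, y, alpha=np.zeros(6), b=0.0))  # 6.0: every margin is violated
```

Because the kernel matrix covers every pair in the batch, a gradient step on such an objective would adjust all pairwise similarities at once. In the end-to-end architecture the embeddings themselves would be learned; here they are random placeholders.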
Machine Learning
What problem does this paper attempt to address?
The problem that this paper attempts to solve is the prediction of chronic disease drug prescription outcomes. Specifically, the authors propose an end-to-end, interpretable deep-learning architecture that predicts the outcomes of chronic disease drug prescriptions through a learned graph kernel. The problem matters for three reasons:

1. **Complexity of chronic disease drug prescriptions**: Chronic diseases such as hypertension, hyperlipidemia, and diabetes require long-term treatment plans, including multiple drug prescriptions, to control disease progression. Whether a prescription is effective depends on the patient's future risk of severe complications and comorbidities.
2. **Utilization of electronic health records (EHRs)**: Although EHRs provide detailed patient medical histories, the data are typically noisy and highly variable, which makes it difficult for traditional methods to predict prescription outcomes accurately.
3. **Limitations of existing models**: According to the authors, existing attention-based RNN models and graph similarity learning methods handle complex EHR data and long-term disease progression poorly and are prone to overfitting.

### Overview of the solution

To address these problems, the authors propose a Cross-Global Attention Graph Kernel Network with the following main features:

- **Graph representation learning**: Represent each patient's EHR as a directed acyclic graph (DAG), where each node represents a medical event and the edges represent time intervals.
- **Cross-global attention mechanism**: Match nodes across patient graphs through cross-global attention, capturing relevant information about long-term disease progression without generating training pairs or triplets.
- **Graph kernel learning**: Learn a graph kernel optimized for the binary classification task of predicting whether a drug prescription succeeds or fails.
- **SVM combination**: Optimize the graph kernel with a support vector machine (SVM) objective to support both classification performance and model interpretability.

### Formula summary

- **Graph Convolutional Network (GCN)**:
  \[ H = f(\tilde{D}^{-1}\tilde{A}XW) \]
  where \(\tilde{A}\) is the adjacency matrix with self-loops, \(\tilde{D}\) is its diagonal node-degree matrix, \(X\) is the one-hot encoding matrix of node attributes, \(W\) is a trainable weight matrix, and \(f\) is a non-linear activation function (such as ReLU).
- **Multi-layer GCN embedding**:
  \[ H_{k+1} = f(\tilde{D}^{-1}\tilde{A}H_kW_k), \quad H_0 = X \]
- **Final node embedding**: the layer outputs \(H_1, \dots, H_t\) are concatenated into \(H_{1:t}\) and projected:
  \[ H_{final} = \text{ReLU}(H_{1:t}W_{concat}) \]
- **Global node clustering learning**:
  \[ A = \text{Sparsemax}(\text{ReLU}(H_{final}M^T)) \]
  \[ Q = \text{Tanh}(AM) \]
  Each row of \(A\) holds one node's sparse assignment weights over the rows of \(M\), and \(Q\) is the resulting cluster-based reconstruction of the node embeddings.
- **Reconstruction error auxiliary loss**:
  \[ L_{recon} = \|H_{final} - Q\|_F \]
- **Attention pooling**:
  \[ \alpha = \text{Softmax}(\text{Sim}(H_{final}, Q)) \]
  \[ G_{emb} = \sum_{i=1}^{n} \alpha_i H_{final}^i \]
  where \(H_{final}^i\) is the embedding of node \(i\) and \(n\) is the number of nodes in the graph.
- **Graph kernel definition**:
  \[ \text{Dist}_C(G_{emb1}, G_{emb2}) = 1 - \frac{\langle G_{emb1}, G_{emb2}\rangle}{\|G_{emb1}\| \cdot \|G_{emb2}\|} \]
  \[ \text{Dist}_E(G_{emb1}, G_{emb2})=\|G_{emb1}-