Abstract: In this paper, we consider the problem of disease diagnosis. Unlike the conventional learning paradigm that treats labels independently, we propose a knowledge-enhanced framework that enables training visual representations with the guidance of medical domain knowledge. In particular, we make the following contributions: First, to explicitly incorporate experts' knowledge, we propose to learn a neural representation for the medical knowledge graph via contrastive learning, implicitly establishing relations between different medical concepts. Second, while training the visual encoder, we keep the parameters of the knowledge encoder frozen and propose to learn a set of prompt vectors for efficient adaptation. Third, we adopt a Transformer-based disease-query module for cross-modal fusion, which naturally enables explainable diagnosis results via cross attention. To validate the effectiveness of our proposed framework, we conduct thorough experiments on three X-ray imaging datasets across different anatomical structures, showing that our model is able to exploit the implicit relations between diseases/findings and is thus beneficial for problems commonly encountered in the medical domain, namely long-tailed and zero-shot recognition, which conventional methods either struggle with or completely fail to realize.
What problem does this paper attempt to address?
This paper addresses disease diagnosis from radiological images, focusing on the limitations of the conventional supervised learning paradigm, which treats disease labels independently and therefore struggles with rare or unseen disease categories. Specifically:
1. **Long-tailed recognition**: Conventional methods struggle when the label distribution is highly imbalanced, i.e., when some disease categories have very few training samples.
2. **Zero-shot recognition**: Conventional models cannot predict diseases that never appear in the training set.
To address these problems, the authors propose K-Diag, a knowledge-enhanced disease diagnosis framework that guides visual feature learning with prior knowledge from the medical domain. Its main contributions are:
- **Explicit incorporation of expert knowledge**: A neural representation of the medical knowledge graph is learned via contrastive learning, implicitly establishing relations between different medical concepts.
- **Prompt vectors for efficient adaptation**: While the visual encoder is trained, the knowledge encoder's parameters stay frozen, and a set of learnable prompt vectors is introduced to adapt its outputs efficiently.
- **Cross-modal fusion module**: A Transformer-based disease-query module fuses text and image features, and its cross-attention makes the diagnosis results interpretable.
To verify the framework's effectiveness, the authors conducted experiments on three X-ray imaging datasets covering different anatomical structures. The results show that the model exploits the implicit relations between diseases and findings and therefore performs well on long-tailed and zero-shot recognition.
### Formula Summary
The formulas in this paper mainly describe the model's training objective and its forward computation. The key formulas are:
1. **Contrastive Learning Loss Function**:
\[
L_{\text{contrastive}} = -\frac{1}{2N} \sum_{i=1}^{N} \left( \log \frac{\exp\left(\langle \mathbf{n}_i, \mathbf{d}_i \rangle / \tau\right)}{\sum_{k=1}^{N} \exp\left(\langle \mathbf{n}_i, \mathbf{d}_k \rangle / \tau\right)} + \log \frac{\exp\left(\langle \mathbf{d}_i, \mathbf{n}_i \rangle / \tau\right)}{\sum_{k=1}^{N} \exp\left(\langle \mathbf{d}_i, \mathbf{n}_k \rangle / \tau\right)} \right)
\]
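where $\tau$ is the temperature parameter, $\langle \cdot, \cdot \rangle$ denotes the inner product, $N$ is the number of concept-definition pairs in a batch, and $\mathbf{n}_i$ and $\mathbf{d}_i$ are the embeddings of the $i$-th medical concept and its definition, respectively.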
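The symmetric loss above can be sketched in plain Python as follows. This is a minimal, non-authoritative illustration that assumes the embeddings are already L2-normalized (so the inner product acts as cosine similarity), as is common in CLIP-style training; the helper names and the default `tau=0.07` are hypothetical choices, not values taken from the paper.

```python
import math


def dot(u, v):
    """Inner product <u, v> of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def log_softmax_at(logits, idx):
    """Return log(exp(logits[idx]) / sum_k exp(logits[k])), computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[idx] - log_z


def contrastive_loss(names, defs, tau=0.07):
    """Symmetric InfoNCE over N paired (concept, definition) embeddings.

    Assumes names[i] and defs[i] describe the same concept and that all
    vectors are already L2-normalized (hypothetical preprocessing).
    """
    n = len(names)
    total = 0.0
    for i in range(n):
        # Concept -> definition: defs[i] is the positive, other defs are negatives.
        name_to_def = [dot(names[i], defs[k]) / tau for k in range(n)]
        # Definition -> concept: names[i] is the positive, other names are negatives.
        def_to_name = [dot(defs[i], names[k]) / tau for k in range(n)]
        total += log_softmax_at(name_to_def, i) + log_softmax_at(def_to_name, i)
    return -total / (2 * n)


# Toy usage: aligned pairs should give a lower loss than mismatched ones.
names = [[1.0, 0.0], [0.0, 1.0]]
aligned = contrastive_loss(names, [[1.0, 0.0], [0.0, 1.0]], tau=1.0)
swapped = contrastive_loss(names, [[0.0, 1.0], [1.0, 0.0]], tau=1.0)
```

Averaging the two directions mirrors the $\frac{1}{2N}$ factor in the formula: each pair contributes one concept-to-definition term and one definition-to-concept term.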
2. **Prompt Module Output Calculation**:
\[
\mathbf{k} = \Phi_{\text{prompt}}(\mathbf{T}) = \mathbf{p} \cdot \mathbf{h}
\]
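where $\mathbf{T}$ denotes the disease embeddings produced by the frozen knowledge encoder, $\mathbf{p} = \text{SoftMax}(\text{MLP}(\mathbf{T}))$ is a probability distribution over the prompt vectors, and $\mathbf{h}$ is the set of learnable prompt vectors, so each output embedding is a weighted combination of the prompts.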
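A minimal forward-pass sketch of the prompt module, under stated assumptions: each row of `T` is one frozen disease embedding, the MLP is a hypothetical one-hidden-layer ReLU network that outputs one logit per prompt vector, and `h` holds `M` learnable prompt vectors of dimension `d`. The paper's actual MLP configuration is not given here, so layer sizes and helper names are illustrative only, and training of `h` is not shown.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def linear(x, W, b):
    """y_j = sum_i x_i * W[i][j] + b_j, with W stored as in_dim x out_dim."""
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) + b[j] for j in range(len(b))]


def mlp(x, W1, b1, W2, b2):
    """Hypothetical two-layer MLP: linear -> ReLU -> linear (M output logits)."""
    hidden = [max(0.0, v) for v in linear(x, W1, b1)]
    return linear(hidden, W2, b2)


def prompt_module(T, mlp_params, h):
    """Map each frozen embedding t in T to k = softmax(MLP(t)) . h.

    h is an M x d list of prompt vectors; each output k is a convex
    combination of those prompts, weighted by p.
    """
    d = len(h[0])
    K = []
    for t in T:
        p = softmax(mlp(t, *mlp_params))  # distribution over the M prompts
        K.append([sum(p[m] * h[m][j] for m in range(len(h))) for j in range(d)])
    return K


# Toy usage: with all-zero MLP weights, p is uniform, so k is the mean prompt.
d_t, hidden_dim, M = 3, 4, 3
zero_params = (
    [[0.0] * hidden_dim for _ in range(d_t)], [0.0] * hidden_dim,
    [[0.0] * M for _ in range(hidden_dim)], [0.0] * M,
)
h = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
K = prompt_module([[0.5, -1.0, 2.0], [1.0, 1.0, 1.0]], zero_params, h)
```

Because every output is a softmax-weighted mix of the same prompt set, only the MLP and `h` would need gradients in this sketch, which is consistent with keeping the knowledge encoder frozen.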
3. **Final Prediction**:
\[
\mathbf{s}_i = \Phi_{\text{query}}(\mathbf{x}_i, \mathbf{k}) \in \mathbb{R}^{Q \times C}
\]
where $\mathbf{x}_i$ is the feature of the $i$-th input image, $\mathbf{k}$ is the adapted disease embedding, $Q$ is the number of queried diseases, and $C$ is the number of output classes per disease (2 for binary present/absent classification).
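The query step can be illustrated with a single-head cross-attention sketch: each disease embedding attends over the patch features of one image, and a shared linear head maps the attended feature to `C` logits, giving a $Q \times C$ score matrix. This is a deliberate simplification that assumes no learned query/key/value projections, residual connections, or feed-forward layers, whereas the paper's module is Transformer-based; all names below are hypothetical. The returned per-disease attention weights are the kind of signal that makes the prediction explainable.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def disease_query(x, K, W_cls, b_cls):
    """Simplified single-head cross-attention (hypothetical sketch).

    x: P x d patch features of one image; K: Q x d disease embeddings.
    W_cls: d x C classifier weights; b_cls: C biases.
    Returns (scores, attn): scores is Q x C, attn is Q x P.
    """
    d = len(K[0])
    scale = 1.0 / math.sqrt(d)
    scores, attn = [], []
    for k in K:
        # Scaled dot-product attention of one disease query over all patches.
        a = softmax([scale * sum(ki * xi for ki, xi in zip(k, patch)) for patch in x])
        attended = [sum(a[p] * x[p][j] for p in range(len(x))) for j in range(d)]
        # Shared linear head: d-dimensional feature -> C logits.
        logits = [sum(attended[j] * W_cls[j][c] for j in range(d)) + b_cls[c]
                  for c in range(len(b_cls))]
        scores.append(logits)
        attn.append(a)
    return scores, attn


# Toy usage: Q = 2 disease queries, P = 3 patches, d = 2, C = 2.
patches = [[5.0, 0.0], [0.0, 5.0], [0.0, 0.0]]
queries = [[1.0, 0.0], [0.0, 1.0]]
W_cls = [[1.0, -1.0], [-1.0, 1.0]]
scores, attn = disease_query(patches, queries, W_cls, [0.0, 0.0])
```

Since diseases enter only as query embeddings, scoring an additional disease in this sketch just means appending one more row to `queries`; no classifier weights are tied to a fixed label set.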
By combining these components, K-Diag improves diagnosis on the three X-ray datasets studied, particularly for long-tailed and zero-shot recognition.