Deep reinforced active learning for multi-class image classification

Emma Slade, Kim M. Branson
DOI: https://doi.org/10.48550/arXiv.2206.13391
2022-06-20
Abstract: High-accuracy medical image classification can be limited by the costs of acquiring more data as well as the time and expertise needed to label existing images. In this paper, we apply active learning, a method that aims to maximise model performance on a minimal subset of a larger pool of data, to medical image classification. We present a new active learning framework, based on deep reinforcement learning, that learns an active learning query strategy for labeling images based on predictions from a convolutional neural network. Our framework modifies the deep Q-network formulation so that data can also be picked based on geometric arguments in the latent space of the classifier. This allows high-accuracy multi-class classification in a batch-based active learning setting and enables the agent to label datapoints that are both diverse and about which it is most uncertain. We apply our framework to two medical imaging datasets and compare it with standard query strategies as well as the most recent reinforcement-learning-based active learning approach for image classification.
Computer Vision and Pattern Recognition, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
The paper addresses how to maximise model performance in medical image classification through active learning (AL) when labeled data are limited and costly to obtain. Specifically, the authors propose an active learning framework based on deep reinforcement learning. It selects the most valuable samples from a large pool of unlabeled data for labeling, thereby improving the accuracy of multi-class image classification.

### Main problem summary:
1. **High cost of data acquisition**: Obtaining and labeling high-quality medical images requires substantial time and domain expertise.
2. **Low efficiency of traditional methods**: Traditional active learning methods often label only one sample per query. This is inefficient, especially for multi-class, high-dimensional data.
3. **Limitations of existing methods**: Existing reinforcement-learning-based active learning methods struggle with multi-class classification tasks and cannot efficiently label data in large batches.

### Solutions proposed in the paper:
- **Deep reinforcement active learning framework**: Combines deep learning and reinforcement learning, using a deep Q-network (DQN) to learn the query strategy.
- **Batch labeling**: Labels multiple data points per query, which makes the labeling loop more efficient than one-sample-at-a-time selection.
- **Geometric distance metric**: Adds a distance metric in the classifier's latent space. The selected samples are then both diverse and ones the classifier is uncertain about, which improves classification accuracy.
- **Noise robustness**: Tests the framework's robustness to noise by simulating real-world noise conditions.
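The idea of picking samples that are "both diverse and uncertain" can be made concrete with a hand-crafted heuristic. The sketch below is **not** the paper's method: the paper learns this trade-off with a DQN. Here it is swapped for a simple greedy rule, assuming predictive entropy as the uncertainty measure and Euclidean distance in the latent space as the diversity measure. The function names and the mixing weight `alpha` are hypothetical.

```python
import numpy as np


def entropy(probs: np.ndarray) -> np.ndarray:
    """Predictive entropy per sample; higher means the classifier is less certain."""
    eps = 1e-12  # avoid log(0)
    return -np.sum(probs * np.log(probs + eps), axis=1)


def select_diverse_uncertain(probs, latents, batch_size, alpha=0.5):
    """Greedily pick `batch_size` pool indices (hypothetical heuristic).

    probs:   (n, classes) softmax outputs of the classifier on the unlabeled pool
    latents: (n, d) latent-space embeddings of the same samples
    alpha:   assumed weight between uncertainty (1.0) and diversity (0.0)
    """
    n = probs.shape[0]
    if batch_size > n:
        raise ValueError("batch_size exceeds pool size")
    unc = entropy(probs)
    unc = unc / (unc.max() + 1e-12)  # scale uncertainty to [0, 1]
    min_dist = np.full(n, np.inf)    # distance to the nearest already-chosen point
    chosen = []
    for _ in range(batch_size):
        if chosen:
            div = min_dist / (min_dist.max() + 1e-12)  # scale diversity to [0, 1]
        else:
            div = np.zeros(n)  # no reference points yet, so diversity is neutral
        score = alpha * unc + (1.0 - alpha) * div
        if chosen:
            score[chosen] = -np.inf  # never pick the same sample twice
        pick = int(np.argmax(score))
        chosen.append(pick)
        dist = np.linalg.norm(latents - latents[pick], axis=1)
        min_dist = np.minimum(min_dist, dist)
    return chosen


if __name__ == "__main__":
    probs = np.array([[0.5, 0.5], [0.5, 0.5], [0.9, 0.1], [0.6, 0.4]])
    latents = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [3.0, 0.0]])
    print(select_diverse_uncertain(probs, latents, 2, alpha=0.5))  # [0, 2]
    print(select_diverse_uncertain(probs, latents, 2, alpha=1.0))  # [0, 1]
```

With `alpha=1.0` (uncertainty only), the batch contains two near-duplicate points (indices 0 and 1). Mixing in latent-space distance replaces the second one with a distant sample. This redundancy in pure-uncertainty batches is the motivation for adding a geometric term.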
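### Formula representation:
- Reinforcement learning seeks a policy that maximises the expected discounted return. The action-value function of a policy \(\pi\) is
\[ Q^{\pi}(s, a)=\mathbb{E}_{\pi}\left[\sum_{i=0}^{\infty}\gamma^{i}r_{t+i+1}\,\middle|\,s_{t}=s,\,a_{t}=a\right] \]
where \(\gamma\) is the discount factor, set to 0.99.
- The reward is the change in the classifier's accuracy on the validation set \(\{(\tilde{x}_{i},\tilde{y}_{i})\}\) between the previous and current time steps, where \(\hat{y}_{t}\) denotes the classifier's predictions at time step \(t\):
\[ r_{t+1}=\text{ACC}(\hat{y}_{t+1}(\tilde{x}_{i}),\tilde{y}_{i})-\text{ACC}(\hat{y}_{t}(\tilde{x}_{i}),\tilde{y}_{i}) \]
- For batch labeling, the agent selects the set of \(N\) actions with the largest total Q-value:
\[ \{a_{t}^{n}\}_{n=1}^{N}=\arg\max_{B\subseteq A,\;|B|=N}\sum_{a\in B}Q(s_{t},a;\phi) \]
where \(A\) is the set of available actions (choices of unlabeled data points to label) and \(\phi\) denotes the parameters of the Q-network. Because the objective is a sum of per-action terms, this is equivalent to taking the \(N\) actions with the highest individual Q-values.

Using these methods, the paper reports superior performance for its framework on two medical imaging datasets, with good results in both binary and multi-class classification tasks.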
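The three formulas above can be checked numerically with a short sketch. It assumes predictions are integer class labels and that Q-values for the candidate actions have already been produced by some Q-network. The network itself is not implemented here, and the Q-values in the usage are made-up placeholders. `discounted_return` truncates the infinite sum in \(Q^{\pi}\) to a finite episode. All function names are hypothetical.

```python
import numpy as np

GAMMA = 0.99  # discount factor value given in the summary


def discounted_return(rewards, gamma=GAMMA):
    """Finite-episode version of the sum inside Q: sum_i gamma**i * r_{t+i+1}."""
    return sum((gamma ** i) * r for i, r in enumerate(rewards))


def accuracy(preds, labels):
    """Fraction of validation predictions that match the labels."""
    return float(np.mean(np.asarray(preds) == np.asarray(labels)))


def accuracy_reward(preds_prev, preds_next, val_labels):
    """r_{t+1} = ACC(y_hat_{t+1}) - ACC(y_hat_t) on the validation set."""
    return accuracy(preds_next, val_labels) - accuracy(preds_prev, val_labels)


def select_top_n(q_values, n):
    """Indices of the n actions with the largest Q(s_t, a; phi).

    The batch objective is a sum of per-action Q-values, so picking the
    n individually largest values maximises it.
    """
    order = np.argsort(np.asarray(q_values))[::-1]  # indices in descending Q order
    return order[:n].tolist()


if __name__ == "__main__":
    val_labels = [0, 1, 2, 1]
    # Accuracy rises from 0.5 to 0.75 after the classifier is retrained.
    print(accuracy_reward([0, 0, 2, 2], [0, 1, 2, 2], val_labels))  # 0.25
    print(discounted_return([1.0, 1.0], gamma=0.5))  # 1.5
    # Placeholder Q-values standing in for a trained Q-network's output.
    print(select_top_n([0.1, 0.7, -0.2, 0.4, 0.9], 3))  # [4, 1, 3]
```

Because the reward is a difference of accuracies, summing undiscounted rewards over an episode telescopes to the total accuracy gain from start to finish. The discount factor close to 1 keeps the agent oriented toward that long-run gain.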