Automated Multi-Task Learning for Joint Disease Prediction on Electronic Health Records

Suhan Cui, Prasenjit Mitra
2024-10-09
Abstract: In the realm of big data and digital healthcare, Electronic Health Records (EHR) have become a rich source of information with the potential to improve patient care and medical research. In recent years, machine learning models have proliferated for analyzing EHR data to predict patients' future health conditions. Among them, some studies advocate multi-task learning (MTL) to jointly predict multiple target diseases and thereby improve prediction performance over single-task learning. Nevertheless, current MTL frameworks for EHR data have significant limitations because they rely heavily on human experts to identify task groups for joint training and to design model architectures. To reduce human intervention and improve the framework design, we propose an automated approach named AutoDP, which searches for the optimal configuration of task grouping and architectures simultaneously. To tackle the vast joint search space encompassing task combinations and architectures, we employ surrogate model-based optimization, enabling us to efficiently discover the optimal solution. Experimental results on real-world EHR data demonstrate the efficacy of the proposed AutoDP framework. It achieves significant performance improvements over both hand-crafted and automated state-of-the-art methods while maintaining a feasible search cost. Source code is available at https://github.com/SH-Src/AutoDP.
Machine Learning
What problem does this paper attempt to address?
The paper addresses two major challenges in multi-task learning (MTL) on electronic health record (EHR) data:

1. **How to determine which tasks should be trained together**: The task grouping problem is to find groups of tasks that can be trained jointly. Multi-task learning only provides an advantage when there is synergy between tasks, that is, when training them jointly helps the model learn general knowledge that improves performance on the test set and prevents overfitting. Therefore, given a large number of related tasks in a domain, it may be necessary to group the tasks (allowing a task to belong to multiple groups) and train one model per task group. However, existing work usually relies on human experts to pre-select multiple tasks and build a shared model for them. This approach is not only time-consuming and labor-intensive (trying every possible task combination), but may also introduce task interference (putting unrelated diseases together), which degrades performance. How to design an appropriate task grouping for multi-task learning on EHR data is therefore a key challenge.

2. **How to design the model architecture for multi-task learning**: Existing multi-task learning research usually relies on manually designed architectures, which consist of a shared EHR encoder and several task-specific classifiers. However, because of the large number of possible operations and network topologies, finding the optimal architecture by manual tuning is impractical. Moreover, the optimal architecture may differ from one task group to another, so the problem becomes more complex as the number of tasks grows and different task combinations are involved. This calls for a more efficient and effective method of designing multi-task learning architectures on EHR data.
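To give a feel for why trying every possible task combination does not scale, the following minimal sketch counts candidate task groups and groupings. The function names are hypothetical and the counting scheme is an illustrative assumption, not the paper's formulation: a group is any non-empty subset of tasks, and a grouping is any unordered choice of distinct groups, including choices that leave some tasks uncovered.

```python
from math import comb


def count_candidate_groups(n_tasks: int) -> int:
    """Number of non-empty subsets of n_tasks tasks (each is a candidate group)."""
    return 2 ** n_tasks - 1


def count_groupings(n_tasks: int, n_groups: int) -> int:
    """Number of ways to pick n_groups distinct candidate groups.

    Groups may overlap (a task may belong to several groups). This count
    also includes selections that do not cover every task, so it is only
    a rough indication of the size of the search space.
    """
    return comb(count_candidate_groups(n_tasks), n_groups)


if __name__ == "__main__":
    for n in (3, 5, 10):
        print(n, count_candidate_groups(n), count_groupings(n, 2))
```

The number of candidate groups doubles with each added task, and every candidate group would need its own trained model to be evaluated, before any choice of architecture is taken into account.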
To address these challenges, the authors propose AutoDP, an automated multi-task learning framework that searches simultaneously for the optimal task grouping and the corresponding neural network architecture, thereby maximizing the performance gain of multi-task learning. By using surrogate-model-based optimization, AutoDP can discover the optimal solution efficiently. Experimental results show that the framework achieves significantly better classification performance on real-world EHR data than existing manually designed and automated methods, while maintaining a feasible search cost.
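The general idea of surrogate-model-based optimization can be sketched as follows. This is a generic toy illustration, not AutoDP's actual algorithm: the search space, the synthetic `expensive_gain` function (a stand-in for training a multi-task model and measuring its gain), the nearest-neighbor surrogate, and all names and parameters are assumptions made for this example.

```python
import random

# Hypothetical toy search space: a configuration is (task_mask, arch_id),
# where task_mask is a bitmask over N_TASKS tasks and arch_id picks one of
# N_ARCHS candidate architectures.
N_TASKS = 4
N_ARCHS = 3


def all_configs():
    return [(mask, arch) for mask in range(1, 2 ** N_TASKS) for arch in range(N_ARCHS)]


def expensive_gain(config):
    """Synthetic stand-in for training a model and measuring its gain."""
    mask, arch = config
    n_selected = bin(mask).count("1")
    return -abs(n_selected - 2) + 0.5 * (arch == 1)


def features(config):
    """Encode a configuration as task bits followed by a one-hot architecture."""
    mask, arch = config
    bits = [(mask >> i) & 1 for i in range(N_TASKS)]
    onehot = [1 if arch == a else 0 for a in range(N_ARCHS)]
    return bits + onehot


def surrogate_predict(config, history, k=3):
    """Cheap surrogate: mean gain of the k nearest evaluated configurations."""
    x = features(config)
    scored = sorted(
        (sum(a != b for a, b in zip(x, features(c))), gain)
        for c, gain in history.items()
    )
    nearest = scored[:k]
    return sum(gain for _, gain in nearest) / len(nearest)


def surrogate_search(n_init=5, n_rounds=5, batch=3, seed=0):
    rng = random.Random(seed)
    pool = all_configs()
    # Evaluate a small random sample with the expensive function.
    history = {c: expensive_gain(c) for c in rng.sample(pool, n_init)}
    for _ in range(n_rounds):
        # Rank unseen configurations by predicted gain; evaluate only the top few.
        unseen = [c for c in pool if c not in history]
        unseen.sort(key=lambda c: surrogate_predict(c, history), reverse=True)
        for c in unseen[:batch]:
            history[c] = expensive_gain(c)
    best = max(history, key=history.get)
    return best, history[best], len(history)
```

With the default settings the loop runs the expensive evaluation on 20 of the 45 configurations in this toy space; the surrogate decides which ones are worth evaluating. There is no guarantee that such a loop finds the true optimum.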