TorchSurv: A Lightweight Package for Deep Survival Analysis

Mélodie Monod, Peter Krusche, Qian Cao, Berkman Sahiner, Nicholas Petrick, David Ohlssen, Thibaud Coroller
2024-04-17
Abstract: TorchSurv is a Python package that serves as a companion tool to perform deep survival modeling within the PyTorch environment. Unlike existing libraries that impose specific parametric forms, TorchSurv enables the use of custom PyTorch-based deep survival models. With its lightweight design, minimal input requirements, full PyTorch backend, and freedom from restrictive survival model parameterizations, TorchSurv facilitates efficient deep survival model implementation and is particularly beneficial for high-dimensional and complex input data scenarios.
Machine Learning
What problem does this paper attempt to address?
This paper addresses the limitations of existing survival analysis libraries when they are used with deep learning models, in particular their reliance on predefined parameter forms and their poor integration with the PyTorch framework. Specifically:

1. **Restrictions on predefined parameter forms**: Existing survival analysis libraries usually limit users to specific parametric forms (such as linear functions) when defining the parameters \(\theta\) of a survival model. These restrictions make it difficult for researchers to build more complex models, especially for high-dimensional and complex input data.
2. **Integration issues with PyTorch**: Existing libraries do not integrate seamlessly with PyTorch, so users cannot take full advantage of PyTorch's automatic differentiation and optimization. In addition, some libraries rely on external libraries (such as NumPy or Pandas), which further hinders automatic gradient calculation.
3. **Numerical stability issues**: Some libraries work with the likelihood rather than the log-likelihood, which can cause numerical instability during model training.

To solve these problems, the paper introduces **TorchSurv**, a lightweight Python package that provides flexible tools for deep survival analysis. Its main features are:

- **Custom PyTorch models**: Users can define the parameters \(\theta\) of a survival model with their own PyTorch neural networks, which gives greater flexibility.
- **Seamless integration with PyTorch**: The package is built entirely on the PyTorch backend and supports automatic gradient calculation and maximum-likelihood estimation.
- **Efficient calculation of the log-likelihood**: Computations are carried out on the log scale to ensure numerical stability and efficient model training.
- **Rich evaluation metrics**: Multiple evaluation metrics (such as AUC, C-index, and Brier score) are provided to evaluate the predictive performance of survival models.

With these features, TorchSurv lets researchers build and evaluate deep survival models more flexibly, especially when working with complex and high-dimensional data.
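
To make the points about custom networks, automatic differentiation, and log-scale computation concrete, here is a minimal sketch in plain PyTorch. It does not use TorchSurv's own API: the function name `neg_cox_partial_log_likelihood`, the network architecture, and the simulated data are all assumptions chosen for illustration, and the loss assumes there are no tied event times and at least one observed event.

```python
import torch
from torch import nn


def neg_cox_partial_log_likelihood(log_hz, event, time):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Hypothetical helper for illustration (not TorchSurv's implementation).
    Assumes no tied event times and at least one observed event.
    """
    # Sort by descending time so the risk set of row i is rows 0..i.
    order = torch.argsort(time, descending=True)
    log_hz, event = log_hz[order], event[order]
    # Log of the summed hazards over each risk set, computed stably on the
    # log scale instead of exponentiating and summing directly.
    log_risk = torch.logcumsumexp(log_hz, dim=0)
    # Only subjects with an observed event contribute a term.
    pll = (log_hz - log_risk)[event]
    return -pll.sum() / event.sum()


# Simulated data (assumption): 64 subjects, 10 covariates.
torch.manual_seed(0)
n, p = 64, 10
x = torch.randn(n, p)
time = torch.rand(n) * 10 + 0.1   # continuous times, so ties are improbable
event = torch.rand(n) < 0.7       # True = event observed, False = censored
event[0] = True                   # guarantee at least one observed event

# Any PyTorch module can produce the log relative hazard; this small MLP
# is an arbitrary choice.
model = nn.Sequential(nn.Linear(p, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for _ in range(50):
    optimizer.zero_grad()
    log_hz = model(x).squeeze(-1)  # shape (n,)
    loss = neg_cox_partial_log_likelihood(log_hz, event, time)
    loss.backward()                # gradients come from autograd
    optimizer.step()
    losses.append(loss.item())

# Large log hazards would overflow if exponentiated directly in float32,
# but the log-scale computation stays finite.
big_log_hz = torch.tensor([1000.0, 999.0, 998.0])
stable_loss = neg_cox_partial_log_likelihood(
    big_log_hz,
    torch.tensor([True, True, True]),
    torch.tensor([1.0, 2.0, 3.0]),
)
```

Because the loss is an ordinary differentiable PyTorch function, the network that produces `log_hz` can be swapped for any other architecture without changing the training loop.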