Learning multi-cellular representations of single-cell transcriptomics data enables characterization of patient-level disease states

Tianyu Liu, Edward De Brouwer, Tony Kuo, Nathaniel Diamant, Alsu Missarova, Hanchen Wang, Minsheng Hao, Hector Corrada Bravo, Gabriele Scalia, Aviv Regev, Graham Heimberg
DOI: https://doi.org/10.1101/2024.11.18.624166
2024-11-20
Abstract: Single-cell RNA-seq (scRNA-seq) has become a prominent tool for studying human biology and disease. The availability of massive scRNA-seq datasets and advanced machine learning techniques has recently driven the development of single-cell foundation models that provide informative and versatile cell representations based on expression profiles. However, to understand disease states, we need to consider entire tissue ecosystems, simultaneously considering many different interacting cells. Here, we tackle this challenge by generating representations derived from multi-cellular expression context measured with scRNA-seq of tissues. We develop PaSCient, a novel model that employs a multi-level representation learning paradigm and provides importance scores at the individual cell and gene levels for fine-grained analysis across multiple cell types and gene programs characteristic of a given disease. We apply PaSCient to learn a disease model across a large-scale scRNA-seq atlas of 24.3 million cells from over 5,000 patients. Comprehensive and rigorous benchmarking demonstrates the superiority of PaSCient in disease classification and its multiple downstream applications, including dimensionality reduction, gene/cell type prioritization, and patient subgroup discovery.
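As a rough illustration of the multi-level paradigm (cells, then patient, then disease), the dependency-free Python sketch below chains the steps the model architecture describes: cell embeddings \( z_j = f_\theta(x_j) \), attention weights \( w_i = \text{softmax}(a_\theta(Z_i)) \), patient embedding \( e_i = w_i^T Z_i \), and prediction \( \hat{p}_i = \text{softmax}(h_\theta(e_i)) \). It is a minimal sketch under stated assumptions, not the authors' implementation: the layer sizes, the ReLU encoder, the tanh scoring network, the linear classifier, the random untrained weights, and every function name (`linear`, `init`, `make_model`, `forward`) are hypothetical. The per-cell weights it returns show how attention pooling yields cell-level importance scores; gene-level scores are not covered.

```python
import math
import random

# Hypothetical sketch of a PaSCient-style forward pass (not the authors' code).
# Assumed: layer sizes, ReLU cell encoder, tanh scoring MLP, linear classifier,
# random untrained weights. Only the cells -> patient -> disease structure
# follows the equations in the summary.


def linear(x, W, b):
    """Affine map W @ x + b; W is a list of rows, x and b are lists."""
    return [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]


def softmax(v):
    """Numerically stable softmax over a list of scores."""
    m = max(v)
    exps = [math.exp(s - m) for s in v]
    total = sum(exps)
    return [e / total for e in exps]


def init(rows, cols, rng):
    """Small random weights and zero bias (hypothetical initialisation)."""
    W = [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]
    return W, [0.0] * rows


def make_model(d_g, d_h, d_c, seed=0):
    """Parameters for f_theta, a_theta and h_theta."""
    rng = random.Random(seed)
    return {
        "enc": init(d_h, d_g, rng),   # f_theta: R^{d_g} -> R^{d_h}
        "att1": init(d_h, d_h, rng),  # hidden layer of a_theta
        "att2": init(1, d_h, rng),    # a_theta output: one scalar per cell
        "clf": init(d_c, d_h, rng),   # h_theta: R^{d_h} -> R^{d_c}
    }


def forward(model, X):
    """X: one patient's M_i cells, each a length-d_g expression vector.

    Returns (p_hat, w): disease probabilities (length d_c) and
    per-cell attention weights (length M_i, summing to 1).
    """
    # z_j = f_theta(x_j) for every cell; rows of Z form Z_i (M_i x d_h)
    Z = [[max(0.0, h) for h in linear(x, *model["enc"])] for x in X]
    # a_theta applied to each row of Z_i gives one score per cell
    scores = [
        linear([math.tanh(h) for h in linear(z, *model["att1"])], *model["att2"])[0]
        for z in Z
    ]
    w = softmax(scores)  # w_i = softmax(a_theta(Z_i)), taken over cells
    d_h = len(Z[0])
    # e_i = w_i^T Z_i: attention-weighted average of the cell embeddings
    e = [sum(w_j * z[k] for w_j, z in zip(w, Z)) for k in range(d_h)]
    p_hat = softmax(linear(e, *model["clf"]))  # p_hat_i = softmax(h_theta(e_i))
    return p_hat, w


if __name__ == "__main__":
    rng = random.Random(1)
    X = [[rng.random() for _ in range(20)] for _ in range(50)]  # toy: 50 cells, 20 genes
    model = make_model(d_g=20, d_h=8, d_c=3)
    p_hat, w = forward(model, X)
    top_cells = sorted(range(len(w)), key=w.__getitem__, reverse=True)[:5]
    print(len(p_hat), round(sum(p_hat), 6))  # 3 1.0
    print("highest-weighted cells:", top_cells)
```

Because the softmax is taken over cells, the weights sum to one and the pooled embedding does not depend on the order in which a patient's cells are listed; ranking `w` gives a per-cell importance ordering for that patient.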
Biology
What problem does this paper attempt to address?
This paper aims to generate multi-cellular representations of single-cell transcriptomics data that characterize patient-level disease states. To this end, the researchers developed PaSCient, a model that learns patient-level representations from large-scale single-cell RNA sequencing (scRNA-seq) data by jointly considering the many interacting cells of a tissue ecosystem. Beyond disease classification, these representations support downstream applications such as dimensionality reduction, gene/cell-type prioritization, and patient subgroup discovery.

### Main Problems and Methods

1. **Problem Definition**:
   - scRNA-seq technology has made it possible to study human biology and disease at single-cell resolution. However, understanding a disease state requires considering many different interacting cells of the tissue ecosystem simultaneously.
   - Existing single-cell models focus mainly on representing individual cells and do not account for the multi-cellular context.
2. **Solution**:
   - The authors developed PaSCient, which generates patient-level representations from single-cell transcriptome data through a multi-level representation learning paradigm.
   - PaSCient uses an attention mechanism to weight cells and provides importance scores at both the cell and gene levels, enabling fine-grained analysis.

### Model Architecture

1. **Input Representation**:
   - Each patient's sample is represented as a matrix \( X_i\in\mathbb{R}^{M_i\times d_g} \), where \( M_i \) is the number of cells of patient \( i \) and \( d_g \) is the number of measured genes.
2. **Cell Embedding**:
   - A learnable cell embedding function \( f_\theta:\mathbb{R}^{d_g}\to\mathbb{R}^{d_h} \) encodes each cell \( j \) into an embedding vector \( z_j \).
   - The cell embeddings of patient \( i \) form a matrix \( Z_i\in\mathbb{R}^{M_i\times d_h} \).
3. **Patient-Level Embedding**:
   - A softmax-attention pooling layer aggregates the cell embeddings into a patient-level embedding \( e_i \):
     \[ w_i = \text{softmax}(a_\theta(Z_i)) \]
     \[ e_i = w_i^T Z_i \]
   - Here \( a_\theta:\mathbb{R}^{d_h}\to\mathbb{R} \) is a neural network applied to each row of \( Z_i \), so \( w_i\in\mathbb{R}^{M_i} \) holds one weight per cell and \( e_i\in\mathbb{R}^{d_h} \).
4. **Classifier**:
   - The patient-level embedding \( e_i \) is passed to a neural network classifier \( h_\theta:\mathbb{R}^{d_h}\to\mathbb{R}^{d_c} \), where \( d_c \) is the number of disease categories in the dataset.
   - The final disease prediction is:
     \[ \hat{p}_i=\text{softmax}(h_\theta(e_i)) \]

### Experiments and Results

1. **Dataset**:
   - The dataset contains 24.3 million scRNA-seq count profiles from more than 5,000 patients, covering 135 distinct disease-state labels from 413 studies and 189 tissues (organs).
2. **Performance Evaluation**:
   - PaSCient performs strongly on multi-disease classification, achieving a weighted F1-score significantly higher than those of the baseline models.
   - Experiments comparing aggregation mechanisms and sampling strategies showed that non-linear attention combined with disease- and tissue-based oversampling performed best.
3. **Interpretability**