TIGON: Constructing Cellular Dynamic Trajectories via Optimal Transport (OT) and Neural Networks (NN)
Transcriptomic dynamics inference
PyTorch
python
AI4S
Machine Learning
Author: 佩奇
Updated 2024-10-29
Recommended image: Basic Image:ubuntu:22.04-py3.10-pytorch2.0
Recommended machine: c40_m155_1 * NVIDIA T4
Processed Data in TIGON(v2)


In this notebook you will learn the basic principles of TIGON, a trajectory inference method, together with its code implementation.

Reference: Sha, Y., Qiu, Y., Zhou, P. et al. Reconstructing growth and dynamic trajectories from single-cell transcriptomics data. Nat Mach Intell 6, 25–39 (2024).


A brief introduction to TIGON

TIGON stands for Trajectory Inference with Growth via Optimal transport and Neural network. It is a method for inferring cellular dynamics that combines optimal transport with neural networks.

  • Motivation

Researchers want to use omics data (genomic, transcriptomic, and so on) to reveal how the gene expression of cells changes over time. This task is called omics dynamics inference, and it helps explain the mechanisms of important biological processes such as disease progression, embryonic development, and cell differentiation. A general difficulty, however, is that omics measurement destroys the cells being measured: we cannot track the same cell across time, and only obtain a series of unpaired snapshots.

Concretely, the input is a series of scRNA-seq snapshots: there are $T$ measurement times $t_1,\dots,t_T$; at time $t_i$ we sample $N_i$ cells, and for each cell we measure the transcript abundance of $d$ genes (that is, each time point yields $N_i$ points in a $d$-dimensional gene expression space).

We want to use these inputs, via suitable modeling, to infer how each cell moves through gene expression space.
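As a concrete picture of the input (a hypothetical sketch with made-up sizes; the real data are loaded later from `.npy` files), unpaired snapshots are simply a list of arrays with different numbers of cells:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical unpaired snapshots: T = 3 time points, each with a different
# number of cells N_i, in a d = 5 dimensional gene expression space.
d = 5
snapshots = [rng.normal(size=(n_i, d)) for n_i in (120, 80, 150)]

# Cells are NOT matched across time points: row j of snapshots[0] and
# row j of snapshots[1] are unrelated cells.
shapes = [s.shape for s in snapshots]
print(shapes)  # [(120, 5), (80, 5), (150, 5)]
```

Because the rows are unpaired, any method must match entire distributions across time rather than individual cells, which is exactly what optimal transport provides.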

  • (Unbalanced & dynamic) Optimal Transport

The optimal transport problem produces a map between two distributions that minimizes a transport cost; see (Kantorovich, 1942) for the precise mathematical formulation. Waddington-OT (2019), an earlier cell-dynamics inference method, is based on this static formulation.

Dynamic optimal transport was proposed in 2000 (Benamou, 2000) and can be viewed as a dynamical generalization of the optimal transport problem. Given an initial distribution $\rho_0$ and a final distribution $\rho_1$, the optimal transport problem is equivalent to constructing the following dynamical system driven by a velocity field $v(x,t)$,

$$\partial_t \rho(x,t) + \nabla \cdot \big(\rho(x,t)\, v(x,t)\big) = 0, \qquad \rho(\cdot,0)=\rho_0,\quad \rho(\cdot,1)=\rho_1,$$

and minimizing the Wasserstein-2 distance, defined by

$$W_2^2(\rho_0,\rho_1) = \min_{v} \int_0^1\!\!\int \|v(x,t)\|^2\, \rho(x,t)\, \mathrm{d}x\, \mathrm{d}t.$$

The velocity field obtained from this optimization drives a dynamical system that realizes the optimal transport plan. TrajectoryNet, a method widely used in transcriptomic dynamics, is based on dynamic optimal transport.
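To make the $W_2$ distance concrete, here is a small standalone check (not TIGON code): in one dimension the optimal coupling matches sorted samples, and for two Gaussians $W_2$ has the closed form $\sqrt{(m_0-m_1)^2+(s_0-s_1)^2}$:

```python
import numpy as np

rng = np.random.default_rng(1)

m0, s0 = 0.0, 1.0
m1, s1 = 3.0, 2.0
n = 200_000

x0 = rng.normal(m0, s0, size=n)
x1 = rng.normal(m1, s1, size=n)

# In 1D the optimal transport plan matches sorted samples (quantile coupling),
# so W2^2 is approximated by the mean squared gap between order statistics.
w2_mc = np.sqrt(np.mean((np.sort(x0) - np.sort(x1)) ** 2))

# Closed form for two 1D Gaussians: W2^2 = (m0 - m1)^2 + (s0 - s1)^2.
w2_exact = np.sqrt((m0 - m1) ** 2 + (s0 - s1) ** 2)

print(w2_mc, w2_exact)
```

The sort-and-match coupling is special to one dimension; in higher dimensions computing plans explicitly is expensive, which is one motivation for parameterizing the dynamics with neural networks instead.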

In the problem above, $\rho$ is a probability distribution: its integral over the whole space is 1 and stays 1. But cell numbers are not constant; on the contrary, changes in cell number are exactly what we want to capture, especially in topics such as embryonic development and cancer research. The setting where total mass changes over time corresponds to unbalanced dynamic optimal transport. Concretely, the dynamics are driven jointly by a velocity field $v(x,t)$ that moves mass and a birth-death (growth) rate $g(x,t)$ that changes mass:

$$\partial_t \rho(x,t) + \nabla \cdot \big(\rho(x,t)\, v(x,t)\big) = g(x,t)\,\rho(x,t),$$

and the quantity to optimize becomes the Wasserstein-Fisher-Rao distance

$$\mathrm{WFR}^2(\rho_0,\rho_1) = \min_{v,\,g} \int_0^1\!\!\int \big(\|v(x,t)\|^2 + \alpha\, g(x,t)^2\big)\, \rho(x,t)\, \mathrm{d}x\, \mathrm{d}t.$$

Here $\rho$ denotes the density of cells in gene expression space, a quantity that varies with the gene expression vector $x$ and with time $t$. The weight $\alpha$ is user-chosen and set to 1 in practice; the original TIGON paper checks robustness with respect to the choice of $\alpha$.
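A quick numerical sanity check of the growth term (a standalone sketch, not TIGON code): integrating the birth-death continuity equation over all space shows that the total mass $M(t)=\int\rho\,\mathrm{d}x$ obeys $\mathrm{d}M/\mathrm{d}t=\int g\,\rho\,\mathrm{d}x$, so for a spatially constant growth rate $g$ the mass grows exponentially:

```python
import numpy as np

# For spatially constant g, total mass follows dM/dt = g * M,
# i.e. M(t) = M(0) * exp(g t).  A forward-Euler check:
g = 0.5
t_end, dt = 2.0, 1e-4
M = 1.0
for _ in range(int(t_end / dt)):
    M += dt * g * M

print(M, np.exp(g * t_end))  # both close to e^1 ≈ 2.718
```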

  • Reconstruction Error

The unbalanced optimal transport problem above is the core idea of TIGON. The problem carries two constraints, namely the initial and terminal conditions on $\rho$. Neural networks are not good at enforcing such hard constraints, so TIGON converts them into a "reconstruction error" and adds it to the objective function (loss).

Here is an intuitive reading of the reconstruction error $\mathcal{L}_{\mathrm{recon}}$: when the cellular dynamical system we build fits the input data perfectly, the reconstruction error is 0; when it fits the data well, the error is small; and when it fails to fit the data, the error is large. Deriving $\mathcal{L}_{\mathrm{recon}}$ requires a fair amount of partial differential equation (PDE) background and is rather involved; interested readers can consult the original paper. Schematically (see the paper for the exact formula), it compares, at each measured time point $t_i$, the data density with the density transported by the learned dynamics:

$$\mathcal{L}_{\mathrm{recon}} = \sum_i \big\| \hat\rho(\cdot, t_i) - \rho_{\mathrm{data}}(\cdot, t_i) \big\|^2,$$

where $\hat\rho$ is the model density obtained by integrating the learned dynamics and $\rho_{\mathrm{data}}$ is a Gaussian-mixture estimate built from the snapshot cells.

  • Neural Networks (NN)

The method uses two MLPs to fit the unknown functions: the velocity $v(x,t)$ and the growth $g(x,t)$. The objective function is chosen as

$$\mathcal{L} = \mathcal{L}_{\mathrm{recon}} + \lambda\,\mathrm{WFR}^2,$$

where $\lambda$ is a user-chosen parameter balancing data fit against transport cost. (The original paper includes a schematic overview figure of TIGON; it is not reproduced here.)

  • Inferring cell trajectories

Once the velocity field is learned, we can draw cell trajectories the way one traces electric field lines: integrate each cell's position along $v$.

  • Inferring gene-gene regulation and each gene's contribution to growth

The entries of the Jacobian matrix of the velocity field, $\partial v_i/\partial x_j$, describe the regulatory effect of source gene $j$ (the column index) on target gene $i$.

The gradient of the growth term, $\partial g/\partial x_j$, describes how gene $j$ regulates the change in cell numbers.

Both quantities can be drawn as heat maps.
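The Jacobian readout can be illustrated with a standalone toy example (not TIGON code): for a linear velocity field $v(x) = Ax$, automatic differentiation recovers exactly $A$, whose entry $(i,j)$ is the effect of gene $j$ on gene $i$:

```python
import torch

# Toy linear velocity field v(x) = A x: its Jacobian dv_i/dx_j is exactly A,
# so autograd should recover A row by row (as in TIGON's Jacobian heatmaps).
A = torch.tensor([[0.5, -1.0, 0.0],
                  [2.0,  0.3, 0.1],
                  [0.0,  0.0, -0.7]])

x = torch.randn(1, 3, requires_grad=True)
v = x @ A.T  # shape (1, 3)

rows = []
for i in range(3):
    grad_i = torch.autograd.grad(v[:, i].sum(), x, retain_graph=True)[0]
    rows.append(grad_i)
jac = torch.cat(rows, dim=0)  # jac[i, j] = dv_i / dx_j

print(torch.allclose(jac, A))  # True
```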

Below we implement the code.


Setup


Install and import the required Python packages

[1]
! pip install TorchDiffEqPack==1.0.1
! pip install torchdiffeq==0.2.3
! pip install matplotlib
! pip install seaborn
! pip install matplotlib==3.5.3 # downgrade so the plotting code works with this version
Successfully installed TorchDiffEqPack-1.0.1
Successfully installed torchdiffeq-0.2.3
Successfully installed contourpy-1.3.0 cycler-0.12.1 fonttools-4.54.1 kiwisolver-1.4.7 matplotlib-3.9.2 pillow-11.0.0 pyparsing-3.2.0
Successfully installed seaborn-0.13.2
Successfully installed matplotlib-3.5.3
[2]
import torch
torch.cuda.empty_cache()
import torch.nn as nn
import torch.optim as optim
import numpy as np
from TorchDiffEqPack import odesolve
import sys
import os
import matplotlib.pyplot as plt
import scipy.io as sio
import random
from torchdiffeq import odeint
from functools import partial
import getpass
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

Define the inputs

[3]
class Args:
    pass

def create_args():
    args = Args()
    args.dataset = input("Name of the data set. Options: EMT; Lineage; Bifurcation; Simulation (default: EMT): ") or 'EMT'
    timepoints = input("Time points of data (default: 0, 0.1, 0.3, 0.9, 2.1): ")
    args.timepoints = [float(tp.strip()) for tp in timepoints.split(",")] if timepoints else [0, 0.1, 0.3, 0.9, 2.1]
    args.niters = int(input("Number of training iterations (default: 5000): ") or 5000)
    args.lr = float(input("Learning rate (default: 3e-3): ") or 3e-3)
    args.num_samples = int(input("Number of sampling points per epoch (default: 100): ") or 100)
    args.hidden_dim = int(input("Dimension of the hidden layer (default: 16): ") or 16)
    args.n_hiddens = int(input("Number of hidden layers (default: 4): ") or 4)
    args.activation = input("Activation function (default: Tanh): ") or 'Tanh'
    args.gpu = int(input("GPU device index (default: 0): ") or 0)
    args.input_dir = input("Input Files Directory (default: bohr/ProcessedDataInTIGON-9q96/v2/Input/): ") or 'bohr/ProcessedDataInTIGON-9q96/v2/Input/'
    args.save_dir = input("Output Files Directory (default: Output/): ") or 'Output/'
    args.seed = int(input("Random seed (default: 1): ") or 1)
    return args
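For non-interactive runs (e.g., batch jobs), one can skip the prompts and fill in the same fields directly. The helper below is hypothetical but mirrors the defaults of `create_args()`:

```python
# Non-interactive alternative (a sketch): fill the same fields directly
# instead of prompting.  Field names match create_args().
class Args:
    pass

def default_args():
    args = Args()
    args.dataset = 'EMT'
    args.timepoints = [0, 0.1, 0.3, 0.9, 2.1]
    args.niters = 5000
    args.lr = 3e-3
    args.num_samples = 100
    args.hidden_dim = 16
    args.n_hiddens = 4
    args.activation = 'Tanh'
    args.gpu = 0
    args.input_dir = 'bohr/ProcessedDataInTIGON-9q96/v2/Input/'
    args.save_dir = 'Output/'
    args.seed = 1
    return args

args = default_args()
print(args.dataset, args.lr)  # EMT 0.003
```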

Build the neural networks (MLPs)

[4]
class UOT(nn.Module):
    def __init__(self, in_out_dim, hidden_dim, n_hiddens, activation):
        super().__init__()
        self.in_out_dim = in_out_dim
        self.hidden_dim = hidden_dim
        self.hyper_net1 = HyperNetwork1(in_out_dim, hidden_dim, n_hiddens, activation)  # v = dx/dt
        self.hyper_net2 = HyperNetwork2(in_out_dim, hidden_dim, activation)  # g

    def forward(self, t, states):
        z = states[0]
        g_z = states[1]
        logp_z = states[2]

        batchsize = z.shape[0]

        with torch.set_grad_enabled(True):
            z.requires_grad_(True)

            dz_dt = self.hyper_net1(t, z)
            g = self.hyper_net2(t, z)

            dlogp_z_dt = g - trace_df_dz(dz_dt, z).view(batchsize, 1)

        return (dz_dt, g, dlogp_z_dt)


def trace_df_dz(f, z):
    """Calculates the trace of the Jacobian df/dz.
    Stolen from: https://github.com/rtqichen/ffjord/blob/master/lib/layers/odefunc.py#L13
    """
    sum_diag = 0.
    for i in range(z.shape[1]):
        sum_diag += torch.autograd.grad(f[:, i].sum(), z, create_graph=True)[0].contiguous()[:, i].contiguous()
    return sum_diag.contiguous()


class HyperNetwork1(nn.Module):
    # input x, t to get v = dx/dt
    def __init__(self, in_out_dim, hidden_dim, n_hiddens, activation='Tanh'):
        super().__init__()
        Layers = [in_out_dim + 1]
        for i in range(n_hiddens):
            Layers.append(hidden_dim)
        Layers.append(in_out_dim)
        if activation == 'Tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'elu':
            self.activation = nn.ELU()
        elif activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()

        self.net = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(Layers[i], Layers[i + 1]),
                self.activation,
            )
                for i in range(len(Layers) - 2)
            ]
        )
        self.out = nn.Linear(Layers[-2], Layers[-1])

    def forward(self, t, x):
        # x is N*2
        batchsize = x.shape[0]
        t = torch.tensor(t).repeat(batchsize).reshape(batchsize, 1)
        t.requires_grad = True
        state = torch.cat((t, x), dim=1)
        ii = 0
        for layer in self.net:
            if ii == 0:
                x = layer(state)
            else:
                x = layer(x)
            ii = ii + 1
        x = self.out(x)
        return x


class HyperNetwork2(nn.Module):
    # input x, t to get g
    def __init__(self, in_out_dim, hidden_dim, activation='Tanh'):
        super().__init__()
        if activation == 'Tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'elu':
            self.activation = nn.ELU()
        elif activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()

        self.net = nn.Sequential(
            nn.Linear(in_out_dim + 1, hidden_dim),
            self.activation,
            nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            nn.Linear(hidden_dim, 1))

    def forward(self, t, x):
        # x is N*2
        batchsize = x.shape[0]
        t = torch.tensor(t).repeat(batchsize).reshape(batchsize, 1)
        t.requires_grad = True
        state = torch.cat((t, x), dim=1)
        return self.net(state)


def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

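A quick sanity check of the trace trick used by `trace_df_dz` (a standalone sketch, not TIGON code): for the linear map $f(z) = Az$, the trace of the Jacobian is $\mathrm{tr}(A)$:

```python
import torch

def trace_jacobian(f, z):
    # Same diagonal-accumulation trick as trace_df_dz above:
    # differentiate each output coordinate and keep only the matching
    # input coordinate, summing the diagonal of df/dz.
    total = 0.
    for i in range(z.shape[1]):
        total = total + torch.autograd.grad(
            f[:, i].sum(), z, create_graph=True)[0][:, i]
    return total

A = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
z = torch.randn(4, 3, requires_grad=True)
f = z @ A.T

tr = trace_jacobian(f, z)  # per-sample trace, shape (4,)
print(tr)  # each entry equals trace(A) = 6.0
```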

Write the training function (essentially just computing the loss; somewhat tedious)


Computing the loss relies on the ODE solvers from the TorchDiffEqPack and torchdiffeq packages. I have written a small example for each of the two solvers for your reference. The following two code cells are not TIGON source code.

[5]
from TorchDiffEqPack import odesolve

# Define the differential equation dy/dt = f(t, y)
def model(t, y):
    # y is the state, t is the time
    return -y  # e.g., solve dy/dt = -y

# Initial state
y0 = torch.tensor([1.0])

# Solver options: integrate from t0 = 0 to t1 = 1
options = {}
options.update({'t0': 0})
options.update({'t1': 1})
options.update({'method': 'Dopri5'})
options.update({'h': None})
options.update({'rtol': 1e-3})
options.update({'atol': 1e-5})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'safety': None})

# Solve with odesolve
solution = odesolve(model, y0, options)

solution
tensor([0.3752])
[6]
from torchdiffeq import odeint

# Define the differential equation dy/dt = f(t, y)
def model(t, y):
    # y is the state vector, t is the time
    dydt = -y  # e.g., dy/dt = -y, a simple exponential decay
    return dydt

# Initial state y0
y0 = torch.tensor([1.0])

# Time points at which to report the solution
t = torch.linspace(0, 5, steps=6)

# Solve with odeint
result = odeint(model, y0, t)

# result holds the solution at each requested time point
result

tensor([[1.0000],
        [0.3679],
        [0.1353],
        [0.0498],
        [0.0183],
        [0.0067]])
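The decay equation above has the closed form $y(t)=e^{-t}$, so the solver's table of values can be checked against it. Here is a standalone RK4 integrator (not part of TIGON, and independent of either package) that reproduces the same numbers:

```python
import numpy as np

# dy/dt = -y has closed form y(t) = exp(-t).  A small standalone
# classical Runge-Kutta (RK4) integrator reproduces the same table.
def rk4(f, y0, ts):
    ys = [y0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h, y = t1 - t0, ys[-1]
        k1 = f(t0, y)
        k2 = f(t0 + h / 2, y + h * k1 / 2)
        k3 = f(t0 + h / 2, y + h * k2 / 2)
        k4 = f(t1, y + h * k3)
        ys.append(y + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6)
    return np.array(ys)

ts = np.linspace(0, 5, 501)
ys = rk4(lambda t, y: -y, 1.0, ts)
print(ys[::100])  # ≈ [1.0000, 0.3679, 0.1353, 0.0498, 0.0183, 0.0067]
```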

From here on is the source code of the TIGON training functions.

[7]
class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


# Gaussian-mixture probability density; returns a vector of length x.shape[0]
def MultimodalGaussian_density(x, time_all, time_pt, data_train, sigma, device):
    """density function for MultimodalGaussian"""
    mu = data_train[time_all[time_pt]]
    num_gaussian = mu.shape[0]  # mu is number_sample * dimension
    dim = mu.shape[1]
    sigma_matrix = sigma * torch.eye(dim).type(torch.float32).to(device)
    p_unn = torch.zeros([x.shape[0]]).type(torch.float32).to(device)
    for i in range(num_gaussian):
        m = torch.distributions.multivariate_normal.MultivariateNormal(mu[i, :], sigma_matrix)
        p_unn = p_unn + torch.exp(m.log_prob(x)).type(torch.float32).to(device)
    p_n = p_unn / num_gaussian
    return p_n


# Draw one batch of samples from the Gaussian mixture; returns a matrix <num_samples * dim>
def Sampling(num_samples, time_all, time_pt, data_train, sigma, device):
    # perturb the coordinate x with Gaussian noise N(0, sigma*I)
    mu = data_train[time_all[time_pt]]
    num_gaussian = mu.shape[0]  # mu is number_sample * dimension
    dim = mu.shape[1]
    sigma_matrix = sigma * torch.eye(dim)
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(dim), sigma_matrix)
    noise_add = m.rsample(torch.Size([num_samples])).type(torch.float32).to(device)
    # check if the number of points is < num_samples
    if num_gaussian < num_samples:
        samples = mu[random.choices(range(0, num_gaussian), k=num_samples)] + noise_add
    else:
        samples = mu[random.sample(range(0, num_gaussian), num_samples)] + noise_add
    return samples


# .npy -> tensor list: data_train
def loaddata(args, device):
    data = np.load(os.path.join(args.input_dir, (args.dataset + '.npy')), allow_pickle=True)
    data_train = []
    for i in range(data.shape[1]):
        data_train.append(torch.from_numpy(data[0, i]).type(torch.float32).to(device))
    return data_train


def ggrowth(t, y, func, device):
    y_0 = torch.zeros(y[0].shape).type(torch.float32).to(device)
    y_00 = torch.zeros(y[1].shape).type(torch.float32).to(device)
    gg = func.forward(t, y)[1]
    return (y_0, y_00, gg)


def trans_loss(t, y, func, device, odeint_setp):
    outputs = func.forward(t, y)
    v = outputs[0]
    g = outputs[1]
    y_0 = torch.zeros(g.shape).type(torch.float32).to(device)
    y_00 = torch.zeros(v.shape).type(torch.float32).to(device)
    g_growth = partial(ggrowth, func=func, device=device)
    if torch.is_nonzero(t):
        _, _, exp_g = odeint(g_growth, (y_00, y_0, y_0),
                             torch.tensor([0, t]).type(torch.float32).to(device),
                             atol=1e-5, rtol=1e-5, method='midpoint', options={'step_size': odeint_setp})
        f_int = (torch.norm(v, dim=1)**2 + torch.norm(g, dim=1)**2).unsqueeze(1) * torch.exp(exp_g[-1])
        return (y_00, y_0, f_int)
    else:
        return (y_00, y_0, y_0)


# Greatest common divisor of a list of integers
def gcd_list(numbers):
    def _gcd(a, b):
        while b:
            a, b = b, a % b
        return a

    gcd_value = numbers[0]
    for i in range(1, len(numbers)):
        gcd_value = _gcd(gcd_value, numbers[i])
    return gcd_value


def train_model(mse, func, args, data_train, train_time, integral_time, sigma_now, options, device, itr):
    warnings.filterwarnings("ignore")
    loss = 0
    L2_value1 = torch.zeros(1, len(data_train) - 1).type(torch.float32).to(device)
    L2_value2 = torch.zeros(1, len(data_train) - 1).type(torch.float32).to(device)
    odeint_setp = gcd_list([num * 100 for num in integral_time]) / 100
    for i in range(len(train_time) - 1):
        x = Sampling(args.num_samples, train_time, i + 1, data_train, 0.02, device)
        x.requires_grad = True
        logp_diff_t1 = torch.zeros(x.shape[0], 1).type(torch.float32).to(device)
        g_t1 = logp_diff_t1
        # transport samples from time point i+1 back to the initial time point
        options.update({'t0': integral_time[i + 1]})
        options.update({'t1': integral_time[0]})
        z_t0, g_t0, logp_diff_t0 = odesolve(func, y0=(x, g_t1, logp_diff_t1), options=options)
        aa = MultimodalGaussian_density(z_t0, train_time, 0, data_train, sigma_now, device)  # normalized density
        zero_den = (aa < 1e-16).nonzero(as_tuple=True)[0]
        aa[zero_den] = torch.tensor(1e-16).type(torch.float32).to(device)
        logp_x = torch.log(aa) - logp_diff_t0.view(-1)
        aaa = MultimodalGaussian_density(x, train_time, i + 1, data_train, sigma_now, device) * torch.tensor(data_train[i + 1].shape[0] / data_train[0].shape[0])  # mass
        L2_value1[0][i] = mse(aaa, torch.exp(logp_x.view(-1)))
        loss = loss + L2_value1[0][i] * 1e4

        # loss between each two adjacent time points
        options.update({'t0': integral_time[i + 1]})
        options.update({'t1': integral_time[i]})
        z_t0, g_t0, logp_diff_t0 = odesolve(func, y0=(x, g_t1, logp_diff_t1), options=options)
        aa = MultimodalGaussian_density(z_t0, train_time, i, data_train, sigma_now, device) * torch.tensor(data_train[i].shape[0] / data_train[0].shape[0])
        # find zero density
        zero_den = (aa < 1e-16).nonzero(as_tuple=True)[0]
        aa[zero_den] = torch.tensor(1e-16).type(torch.float32).to(device)
        logp_x = torch.log(aa) - logp_diff_t0.view(-1)
        L2_value2[0][i] = mse(aaa, torch.exp(logp_x.view(-1)))
        loss = loss + L2_value2[0][i] * 1e4

    # compute transport cost (the WFR energy term)
    transport_cost = partial(trans_loss, func=func, device=device, odeint_setp=odeint_setp)
    x0 = Sampling(args.num_samples, train_time, 0, data_train, 0.02, device)
    logp_diff_t00 = torch.zeros(x0.shape[0], 1).type(torch.float32).to(device)
    g_t00 = logp_diff_t00
    _, _, loss1 = odeint(transport_cost, y0=(x0, g_t00, logp_diff_t00),
                         t=torch.tensor([0, integral_time[-1]]).type(torch.float32).to(device),
                         atol=1e-5, rtol=1e-5, method='midpoint', options={'step_size': odeint_setp})
    loss = loss + integral_time[-1] * loss1[-1].mean(0)

    if itr > 1:
        if (itr % 100 == 0) and (itr <= args.niters - 400) and (sigma_now > 0.02) and (L2_value1.mean() <= 0.0003):
            sigma_now = sigma_now / 2

    return loss, loss1, sigma_now, L2_value1, L2_value2
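The density matching inside `train_model` can be illustrated standalone (a numpy sketch of the same idea as `MultimodalGaussian_density` above, not TIGON code): place an isotropic Gaussian at each cell and average:

```python
import numpy as np

def mixture_density(x, mu, sigma):
    # Average of isotropic Gaussians N(mu_i, sigma * I) evaluated at the rows
    # of x, mirroring MultimodalGaussian_density above (numpy version).
    d = mu.shape[1]
    norm = (2 * np.pi * sigma) ** (-d / 2)
    # squared distances between every query point and every mixture centre
    sq = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(-1)
    return norm * np.exp(-sq / (2 * sigma)).mean(axis=1)

rng = np.random.default_rng(2)
cells = rng.normal(size=(50, 2))           # one snapshot of 50 cells in 2-D
query = np.array([[0.0, 0.0], [10.0, 10.0]])

dens = mixture_density(query, cells, sigma=0.1)
print(dens)  # high near the data cloud, ~0 far away
```

The training loss then compares such a data density against the model density transported along the learned dynamics.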

Write the plotting functions

(covering the velocity field, the Jacobian of the velocity field, and the gradient of the growth term)

[8]
# plot 3d of inferred trajectory of 20 cells
def plot_3d(func, data_train, train_time, integral_time, args, device):
    viz_samples = 20
    sigma_a = 0.001

    t_list = []
    z_t_samples = []
    z_t_data = []
    v = []
    g = []
    t_list2 = []
    odeint_setp = gcd_list([num * 100 for num in integral_time]) / 100
    integral_time2 = np.arange(integral_time[0], integral_time[-1] + odeint_setp, odeint_setp)
    integral_time2 = np.round_(integral_time2, decimals=2)
    plot_time = list(reversed(integral_time2))
    sample_time = np.where(np.isin(np.array(plot_time), integral_time))[0]
    sample_time = list(reversed(sample_time))

    with torch.no_grad():
        for i in range(len(integral_time)):
            z_t0 = data_train[i]
            z_t_data.append(z_t0.cpu().detach().numpy())
            t_list2.append(integral_time[i])

        # trajectory backward in time, starting from the last snapshot
        z_t0 = Sampling(viz_samples, train_time, len(train_time) - 1, data_train, sigma_a, device)
        logp_diff_t0 = torch.zeros(z_t0.shape[0], 1).type(torch.float32).to(device)
        g0 = torch.zeros(z_t0.shape[0], 1).type(torch.float32).to(device)
        v_t = func(torch.tensor(integral_time[-1]).type(torch.float32).to(device), (z_t0, g0, logp_diff_t0))[0]
        g_t = func(torch.tensor(integral_time[-1]).type(torch.float32).to(device), (z_t0, g0, logp_diff_t0))[1]
        v.append(v_t.cpu().detach().numpy())
        g.append(g_t.cpu().detach().numpy())
        z_t_samples.append(z_t0.cpu().detach().numpy())
        t_list.append(plot_time[0])

        options = {}
        options.update({'method': 'Dopri5'})
        options.update({'h': None})
        options.update({'rtol': 1e-3})
        options.update({'atol': 1e-5})
        options.update({'print_neval': False})
        options.update({'neval_max': 1000000})
        options.update({'safety': None})
        options.update({'t0': integral_time[-1]})
        options.update({'t1': 0})
        options.update({'t_eval': plot_time})

        z_t1, _, logp_diff_t1 = odesolve(func, y0=(z_t0, g0, logp_diff_t0), options=options)
        for i in range(len(plot_time) - 1):
            v_t = func(torch.tensor(plot_time[i + 1]).type(torch.float32).to(device), (z_t1[i + 1], g0, logp_diff_t1))[0]
            g_t = func(torch.tensor(plot_time[i + 1]).type(torch.float32).to(device), (z_t1[i + 1], g0, logp_diff_t1))[1]
            z_t_samples.append(z_t1[i + 1].cpu().detach().numpy())
            g.append(g_t.cpu().detach().numpy())
            v.append(v_t.cpu().detach().numpy())
            t_list.append(plot_time[i + 1])

        aa = 5          # marker size scale
        angle1 = 10     # elevation angle
        angle2 = 75     # azimuth angle
        widths = 0.2    # arrow width
        fig = plt.figure(figsize=(4 * 2, 3 * 2), dpi=200)
        plt.tight_layout()
        plt.margins(0, 0)
        v_scale = 5

        plt.axis('off')
        ax1 = plt.axes(projection='3d')
        ax1.grid(False)
        ax1.set_xlabel('UMAP1')
        ax1.set_ylabel('UMAP2')
        ax1.set_zlabel('UMAP3')
        ax1.set_xlim(-2, 2)
        ax1.set_ylim(-2, 2)
        ax1.set_zlim(-2, 2)
        ax1.set_xticks([-2, 2])
        ax1.set_yticks([-2, 2])
        ax1.set_zticks([-2, 2])
        ax1.view_init(elev=angle1, azim=angle2)
        ax1.invert_xaxis()
        ax1.get_proj = lambda: np.dot(Axes3D.get_proj(ax1), np.diag([1, 1, 0.7, 1]))
        line_width = 0.3

        color_wanted = [np.array([250, 187, 110]) / 255,
                        np.array([173, 219, 136]) / 255,
                        np.array([250, 199, 179]) / 255,
                        np.array([238, 68, 49]) / 255,
                        np.array([206, 223, 239]) / 255,
                        np.array([3, 149, 198]) / 255,
                        np.array([180, 180, 213]) / 255,
                        np.array([178, 143, 237]) / 255]
        for j in range(viz_samples):  # individual trajectories
            for i in range(len(plot_time) - 1):
                ax1.plot([z_t_samples[i][j, 0], z_t_samples[i + 1][j, 0]],
                         [z_t_samples[i][j, 1], z_t_samples[i + 1][j, 1]],
                         [z_t_samples[i][j, 2], z_t_samples[i + 1][j, 2]],
                         linewidth=0.5, color='grey', zorder=2)

        # add inferred trajectory points and velocity arrows
        for i in range(len(sample_time)):
            ax1.scatter(z_t_samples[sample_time[i]][:, 0], z_t_samples[sample_time[i]][:, 1],
                        z_t_samples[sample_time[i]][:, 2],
                        s=aa * 10, linewidth=0, color=color_wanted[i], zorder=3)
            ax1.quiver(z_t_samples[sample_time[i]][:, 0], z_t_samples[sample_time[i]][:, 1],
                       z_t_samples[sample_time[i]][:, 2],
                       v[sample_time[i]][:, 0] / v_scale, v[sample_time[i]][:, 1] / v_scale,
                       v[sample_time[i]][:, 2] / v_scale,
                       color='k', alpha=1, linewidths=widths * 2, arrow_length_ratio=0.3, zorder=4)

        for i in range(len(integral_time)):
            ax1.scatter(z_t_data[i][:, 0], z_t_data[i][:, 1], z_t_data[i][:, 2],
                        s=aa, linewidth=line_width, alpha=0.7, facecolors='none',
                        edgecolors=color_wanted[i], zorder=1)

        # plt.savefig(os.path.join(args.save_dir, "traj_3d.pdf"), format="pdf", pad_inches=0.1, bbox_inches='tight')
        plt.show()


def Jacobian(f, z):
    """Calculates the Jacobian df/dz."""
    jac = []
    for i in range(z.shape[1]):
        df_dz = torch.autograd.grad(f[:, i], z, torch.ones_like(f[:, i]),
                                    retain_graph=True, create_graph=True)[0].view(z.shape[0], -1)
        jac.append(torch.unsqueeze(df_dz, 1))
    jac = torch.cat(jac, 1)
    return jac


# plot the average Jacobian of v over cells (z_t) at time (time_pt)
def plot_jac_v(func, z_t, time_pt, title, gene_list, args, device):
    g_xt0 = torch.zeros(1, 1).type(torch.float32).to(device)
    logp_diff_xt0 = g_xt0
    # compute the mean Jacobian of v over cells z_t at time (time_pt)
    dim = z_t.shape[1]
    jac = np.zeros((dim, dim))
    for i in range(z_t.shape[0]):
        x_t = z_t[i, :].reshape([1, dim])
        v_xt = func(torch.tensor(time_pt).type(torch.float32).to(device), (x_t, g_xt0, logp_diff_xt0))[0]
        jac = jac + Jacobian(v_xt, x_t).reshape(dim, dim).detach().cpu().numpy()
    jac = jac / z_t.shape[0]
    fig = plt.figure(figsize=(5, 4), dpi=200)
    ax = fig.add_subplot(111)
    plt.tight_layout()
    plt.axis('off')
    plt.margins(0, 0)
    ax.set_title('Jacobian of velocity')
    sns.heatmap(jac, cmap="coolwarm", xticklabels=gene_list, yticklabels=gene_list)
    ax.set_xticks([])  # remove x-axis tick marks
    ax.set_yticks([])  # remove y-axis tick marks
    ax.axis('off')
    # plt.savefig(os.path.join(args.save_dir, title), format="pdf", pad_inches=0.2, bbox_inches='tight')
    plt.show()


# plot the average gradient of g over cells (z_t) at time (time_pt)
def plot_grad_g(func, z_t, time_pt, title, gene_list, args, device):
    g_xt0 = torch.zeros(1, 1).type(torch.float32).to(device)
    logp_diff_xt0 = g_xt0
    dim = z_t.shape[1]
    gg = np.zeros((dim, 1))  # column vector: one gradient entry per gene
    for i in range(z_t.shape[0]):
        x_t = z_t[i, :].reshape([1, dim])
        g_xt = func(torch.tensor(time_pt).type(torch.float32).to(device), (x_t, g_xt0, logp_diff_xt0))[1]
        gg = gg + torch.autograd.grad(g_xt, x_t, torch.ones_like(g_xt),
                                      retain_graph=True, create_graph=True)[0].view(x_t.shape[0], -1).reshape(dim, 1).detach().cpu().numpy()
    gg = gg / z_t.shape[0]
    fig = plt.figure(figsize=(1, 4), dpi=200)
    ax = fig.add_subplot(111)
    plt.tight_layout()
    plt.axis('off')
    plt.margins(0, 0)
    ax.set_title('Gradient of growth')
    sns.heatmap(gg, cmap="coolwarm", xticklabels=[], yticklabels=gene_list)
    ax.set_xticks([])  # remove x-axis tick marks
    ax.set_yticks([])  # remove y-axis tick marks
    ax.axis('off')
    # plt.savefig(os.path.join(args.save_dir, title), format="pdf", pad_inches=0.2, bbox_inches='tight')
    plt.show()

输入 (Input)

Just press Enter at each prompt to accept the defaults.

[9]
args=create_args()
Name of the data set. Options: EMT; Lineage; Bifurcation; Simulation (default: EMT):
Time points of data (default: 0, 0.1, 0.3, 0.9, 2.1):
Number of training iterations (default: 5000):
Learning rate (default: 3e-3):
Number of sampling points per epoch (default: 100):
Dimension of the hidden layer (default: 16):
Number of hidden layers (default: 4):
Activation function (default: Tanh):
GPU device index (default: 0):
Input Files Directory (default: bohr/ProcessedDataInTIGON-9q96/v2/Input/):
Output Files Directory (default: Output/):
Random seed (default: 1):

训练神经网络 (Train the neural network)

Tip: training is computationally expensive, so use a reasonably good GPU if you can. If you don't have time to run it, that's fine: the trained results are already included in the mounted dataset files. Simply skip this cell; it does not affect the code below.

[9]
torch.enable_grad()
random.seed(args.seed)
torch.manual_seed(args.seed)

device = torch.device('cuda:' + str(args.gpu)
                      if torch.cuda.is_available() else 'cpu')
# load dataset
data_train = loaddata(args, device)
integral_time = args.timepoints

time_pts = range(len(data_train))
leave_1_out = []
train_time = [x for i, x in enumerate(time_pts) if i not in leave_1_out]


# model
func = UOT(in_out_dim=data_train[0].shape[1], hidden_dim=args.hidden_dim,
           n_hiddens=args.n_hiddens, activation=args.activation).to(device)
func.apply(initialize_weights)


# configure training options
options = {}
options.update({'method': 'Dopri5'})
options.update({'h': None})
options.update({'rtol': 1e-3})
options.update({'atol': 1e-5})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'safety': None})

optimizer = optim.Adam(func.parameters(), lr=args.lr, weight_decay=0.01)
lr_adjust = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.niters - 400, args.niters - 200],
                                           gamma=0.5, last_epoch=-1)
mse = nn.MSELoss()

LOSS = []
L2_1 = []
L2_2 = []
Trans = []
Sigma = []

if args.save_dir is not None:
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    ckpt_path = os.path.join(args.save_dir, 'ckpt.pth')
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path)
        func.load_state_dict(checkpoint['func_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print('Loaded ckpt from {}'.format(ckpt_path))

try:
    sigma_now = 1
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        loss, loss1, sigma_now, L2_value1, L2_value2 = train_model(mse, func, args, data_train, train_time,
                                                                   integral_time, sigma_now, options, device, itr)

        loss.backward()
        optimizer.step()
        lr_adjust.step()

        LOSS.append(loss.item())
        Trans.append(loss1[-1].mean(0).item())
        Sigma.append(sigma_now)
        L2_1.append(L2_value1.tolist())
        L2_2.append(L2_value2.tolist())
        print('Iter: {}, loss: {:.4f}'.format(itr, loss.item()))
        if itr % 500 == 0:
            ckpt_path = os.path.join(args.save_dir, 'ckpt_itr{}.pth'.format(itr))
            torch.save({'func_state_dict': func.state_dict()}, ckpt_path)
            print('Iter {}, Stored ckpt at {}'.format(itr, ckpt_path))

except KeyboardInterrupt:
    if args.save_dir is not None:
        ckpt_path = os.path.join(args.save_dir, 'ckpt.pth')
        torch.save({
            'func_state_dict': func.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, ckpt_path)
        print('Stored ckpt at {}'.format(ckpt_path))
print('Training complete after {} iters.'.format(itr))


ckpt_path = os.path.join(args.save_dir, 'ckpt.pth')
torch.save({
    'func_state_dict': func.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'LOSS': LOSS,
    'TRANS': Trans,
    'L2_1': L2_1,
    'L2_2': L2_2,
    'Sigma': Sigma
}, ckpt_path)
print('Stored ckpt at {}'.format(ckpt_path))




可视化 (Visualization)


载入训练结果 (Load the trained results)

The output of the training code above is already provided in the dataset folder, so whether or not you actually ran the training, you can continue with the code below:

[10]
torch.enable_grad()
random.seed(args.seed)
torch.manual_seed(args.seed)

device = torch.device('cuda:' + str(args.gpu)
                      if torch.cuda.is_available() else 'cpu')
# load dataset
args.save_dir = 'bohr/ProcessedDataInTIGON-9q96/v2/Output/'
data_train = loaddata(args, device)
integral_time = args.timepoints

time_pts = range(len(data_train))
leave_1_out = []
train_time = [x for i, x in enumerate(time_pts) if i not in leave_1_out]


# model
func = UOT(in_out_dim=data_train[0].shape[1], hidden_dim=args.hidden_dim,
           n_hiddens=args.n_hiddens, activation=args.activation).to(device)

# load trained networks
if args.save_dir is not None:
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    ckpt_path = os.path.join(args.save_dir, 'ckpt_EMT.pth')
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        func.load_state_dict(checkpoint['func_state_dict'])
        print('Loaded ckpt from {}'.format(ckpt_path))
Loaded ckpt from bohr/ProcessedDataInTIGON-9q96/v2/Output/ckpt_EMT.pth

绘制热图 (Plot the heatmap)

It represents the gene regulatory relationships, given by the averaged Jacobian of the velocity field.
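A minimal sketch of how to read this kind of heatmap, using a hypothetical 3×3 averaged Jacobian (the values and the 0.25 threshold below are invented for illustration): entry (i, j) approximates ∂v_i/∂x_j, so a positive value suggests feature j up-regulates feature i, and a negative value suggests repression.

```python
import numpy as np

genes = ['UMAP1', 'UMAP2', 'UMAP3']
# hypothetical averaged Jacobian of the velocity field (invented values)
jac = np.array([[ 0.1, -0.4,  0.0],
                [ 0.3,  0.2, -0.1],
                [ 0.0,  0.5, -0.2]])

# entry (i, j) = d v_i / d x_j: positive -> j promotes i, negative -> j inhibits i
edges = []
for i, gi in enumerate(genes):
    for j, gj in enumerate(genes):
        if abs(jac[i, j]) > 0.25:   # keep only the strong interactions
            kind = 'activation' if jac[i, j] > 0 else 'repression'
            edges.append((gj, gi, kind))
print(edges)
```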

[11]
time_pt = 0
z_t = data_train[time_pt]
plot_jac_v(func, z_t, time_pt, 'Average_jac_d0.pdf', ['UMAP1', 'UMAP2', 'UMAP3'], args, device)

绘制热图 (Plot the heatmap)

It represents each gene's effect on growth, given by the averaged gradient of the growth rate g.
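To make this quantity concrete, here is a toy autograd sketch (a hypothetical sigmoid growth function, not TIGON's trained network): each entry of ∇g tells whether increasing that coordinate raises (positive) or lowers (negative) the growth rate.

```python
import torch

# hypothetical growth function g(x) = sigmoid(w . x), purely illustrative
w = torch.tensor([0.5, -1.0, 0.2])

def g(x):
    return torch.sigmoid(x @ w)

x = torch.zeros(3, requires_grad=True)   # one "cell" at the origin
grad_g = torch.autograd.grad(g(x), x)[0]
# at x = 0, sigmoid'(0) = 0.25, so grad_g = 0.25 * w
print(grad_g)                            # tensor([ 0.1250, -0.2500,  0.0500])
```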

[12]
time_pt = 0
z_t = data_train[time_pt]
plot_grad_g(func, z_t, time_pt, 'Average_grad_d0.pdf', ['UMAP1', 'UMAP2', 'UMAP3'], args, device)

绘制三维向量场 (Plot the 3D vector field)

[13]
plot_3d(func,data_train,train_time,integral_time,args,device)