Diffusion models in SO3 space
Diffusion Model
fanjh@dp.tech
Published 2023-09-25
Recommended image: ufconf:v4
Recommended machine type: c12_m46_1 * NVIDIA GPU B
Diffusion models in SO3 space
Score function and noise adding process
Random walk in 3D space
Random walk in SO3 space
Denoising dynamics in general and in SO3 space
Score function in SO3 space
Mean and std of the forward and backward process

Diffusion models in SO3 space

A typical way to represent a protein structure is by the alpha-carbon position and the rotation matrix of each residue. The diffusion process on the alpha-carbon positions therefore lives in ordinary Euclidean space, while the diffusion process on the rotation matrices (SO(3) space) is non-Euclidean.

In this notebook, I will show the general formulation of diffusion models in SO(3) space.
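As a minimal, hypothetical sketch (illustrative names and shapes only, not the ufconf data model), each residue carries a translation in R^3 and a rotation matrix in SO(3):

import numpy as np
from scipy.spatial.transform import Rotation

n_res = 4
ca_positions = np.random.randn(n_res, 3)      # alpha-carbon coordinates, ordinary R^3
frames = Rotation.random(n_res).as_matrix()   # per-residue orientations, elements of SO(3)
# Gaussian noise can be added to ca_positions directly, but adding Gaussian noise
# entrywise to the 3x3 matrices would leave SO(3); rotational noise has to be
# composed through the group, as shown in the random walk below.
print(ca_positions.shape, frames.shape)       # (4, 3), (4, 3, 3)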

[1]
import numpy as np
import torch
from ufconf.diffusion.diffuser import (
    RotationDiffuser,
    PositionDiffuser,
)
from ufconf.diffusion import so3
from ufconf.config import model_config
from ufconf.diffusion.angle import IGSO3

import matplotlib.pyplot as plt
import math

# load the model configuration and set the random-walk approximation threshold
cfg = model_config("ufconf_af2_v3")
cfg.diffusion.rotation.rw_approx_thres = 0.2

# size of the diffusion mask and number of time steps
n = 100
nt = 100

diff_mask = torch.ones(n)
ts = torch.linspace(0, 1, nt + 1)

# rotation diffuser and its IGSO(3) angle distribution
rd = RotationDiffuser(**cfg.diffusion.rotation)
distrib = rd.igso3
[2]
# definition of some probability density functions
def f_angle_igso3(omega, t, eps=1e-4, L=1024):
    """Truncated series of the IGSO(3) angle density, including the (1 - cos omega) factor.
    """
    ls = torch.arange(L)[None, :]  # of shape [1, L]

    a = torch.sin(omega[:, None]*(ls + 1/2) + (ls + 1/2)*1e-6) / torch.sin(omega[:, None]/2 + 1/2*1e-6)
    c = (1 - torch.cos(omega[:, None])) / math.pi  # (N, *)
    f = c * (2*ls + 1) * torch.exp(-1/2*ls*(ls + 1)*t**2) * a
    f = f.sum(dim=-1) + eps
    return f

def f_igso3(omega, t, eps=1e-4, L=200):
    """Truncated series of the IGSO(3) density (without the angle marginal factor).
    """
    ls = torch.arange(L)[None, :]  # of shape [1, L]

    a = torch.sin(omega[:, None]*(ls + 1/2) + (ls + 1/2)*1e-6) / torch.sin(omega[:, None]/2 + 1/2*1e-6)
    f = (2*ls + 1) * torch.exp(-1/2*ls*(ls + 1)*t**2) * a
    f = f.sum(dim=-1) + eps
    return f

def f_gaussian(x, std):
    """Unnormalised Gaussian density."""
    return np.exp(-.5 * x**2 / std**2)

def f_maxwell_gaussian(x, std):
    """Unnormalised Maxwell-Boltzmann-like density x^2 exp(-x^2 / (2 std^2))."""
    return x**2 * np.exp(-.5 * x**2 / std**2)
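
# A quick optional sanity check using the helpers above (an added sketch): at a small
# noise level the normalised IGSO(3) angle density should be close to the Maxwell-like form.
omega = torch.linspace(1e-3, math.pi, 500)
sigma = 0.3
p_igso3 = f_angle_igso3(omega, sigma)
p_igso3 = (p_igso3 / (p_igso3.sum() * (math.pi / 500))).numpy()
p_mg = f_maxwell_gaussian(omega.numpy(), sigma)
p_mg = p_mg / (p_mg.sum() * (math.pi / 500))
print("max abs difference:", np.abs(p_igso3 - p_mg).max())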


Score function and noise adding process

Just like diffusion in Euclidean space, the core ingredients of the diffusion process are the score function and the noise-adding process. To build intuition for the diffusion process in SO(3) space, I will first show the noise-adding process (a random walk) in ordinary 3D space.


Random walk in 3D space

We generate a step from a Gaussian distribution with std = 0.1 in each of the x, y, z directions. After 200 steps, the distance from the starting point follows a Maxwell-Boltzmann distribution with std = 0.1 * sqrt(200) = sqrt(2) ≈ 1.41.
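
For reference (this matches std_all and the density used in the plot below), the accumulated standard deviation and the Maxwell-Boltzmann density are

$$ \sigma_{\mathrm{all}} = \sigma\sqrt{N} = 0.1\sqrt{200} = \sqrt{2}, \qquad p(r) \;\propto\; r^{2}\exp\!\left(-\frac{r^{2}}{2\sigma_{\mathrm{all}}^{2}}\right). $$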

[3]
# test the random walk in 3D position space
import numpy as np
from scipy.stats import maxwell

# Number of points and dimensions
num_points = 1000
dim = 3

# Generate initial points
points_initial = np.random.rand(num_points, dim)

# Standard deviation for Gaussian distribution
std_dev = 0.1

# Number of iterations
num_iterations = 200

# Copy of the initial points to apply translations
points_translated = points_initial.copy()

# Apply translations
for _ in range(num_iterations):
    # Generate a translation vector for each point
    translations = np.random.normal(scale=std_dev, size=(num_points, dim))
    # Apply the translation
    points_translated += translations

translation_vec = points_translated - points_initial
print("translation_vec", translation_vec.shape)
# Calculate Euclidean distance
distances = np.sqrt(np.sum((points_translated - points_initial) ** 2, axis=1))

import matplotlib.pyplot as plt

# Values for the x axis (distances)
x = np.linspace(min(distances), max(distances), 1000)

# calculate the standard deviation of the distances
std_all = np.sqrt(std_dev**2 * num_iterations)
print("std_all", std_all)

# Maxwell-Boltzmann distribution
y = x**2*np.exp(-x**2/(2*std_all**2))
y = y / (y.sum()*(max(distances) - min(distances))/1000)

# Plot histogram and PDF
plt.figure(figsize=(12, 6))
plt.hist(distances, bins=20, density=True, alpha=0.5, label='Histogram of distances')
plt.plot(x, y, label='maxwell-boltzmann')
plt.title('Histogram of distances and estimated PDF')
plt.xlabel('Distance')
plt.ylabel('Density')
plt.legend()
plt.grid(True)
loc = 0
scale = std_all
x = np.linspace(maxwell.ppf(0.01,loc=loc, scale=scale),maxwell.ppf(0.99,loc=loc, scale=scale), 100)
plt.plot(x, maxwell.pdf(x,loc=loc, scale=scale),'r-', lw=5, alpha=0.6, label='maxwell pdf')
plt.legend()
plt.show()
# Calculate the translation
translation_vec = points_translated - points_initial

# Values for the x axis (distances)
x_2 = np.linspace(-5, 5, 1000)

# Gaussian distribution
z = np.exp(-x_2**2/(2*std_all**2))
z = z / (z.sum()*10/1000)

# Plot histogram and PDF
plt.figure(figsize=(12, 6))
plt.hist(translation_vec, bins=20, density=True, alpha=0.5, label='Histogram of translation')
plt.plot(x_2, z, label='gaussian')
plt.title('Histogram of translation components and estimated PDF')
plt.xlabel('Translation component')
plt.ylabel('Density')
plt.legend()
plt.grid(True)
plt.show()
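# An optional quantitative version of the same comparison (added sketch, uses the
# variables computed above): a Kolmogorov-Smirnov test of the end-to-end distances
# against the Maxwell distribution with scale std_all.
from scipy.stats import kstest
res = kstest(distances, "maxwell", args=(0, std_all))
print("KS statistic:", res.statistic, "p-value:", res.pvalue)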

Random walk in SO3 space

As in the 3D case, we can perform a random walk in SO(3) space by sampling angles from Gaussian distributions and composing them with the Lie algebra basis matrices. Formally, each step updates

$$ R_{k+1} = \exp\!\Big(\sum_{i=1}^{3}\epsilon_{i}G_{i}\Big)\,R_{k}, \qquad \epsilon_{i} \sim \mathcal{N}(0, \sigma_{k}^{2}), $$

where $G_{1}, G_{2}, G_{3}$ are the basis matrices of the Lie algebra $\mathfrak{so}(3)$.

[4]
# test the random walk and denoising dynamics in SO3 space
from scipy.spatial.transform import Rotation as R

# Function to generate random rotation matrix
def random_rotation_matrix():
    return R.from_rotvec(np.random.rand(3)).as_matrix()

from scipy.linalg import expm
from tqdm import tqdm
# lie algebra matrix
G1 = np.array([[0, 0, 0], [0, 0, -1], [0, 1, 0]])
G2 = np.array([[0, 0, 1], [0, 0, 0], [-1, 0, 0]])
G3 = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 0]])

G = np.array([G1,G2,G3])

# Number of matrices generated
matrix_num = 5000

# Standard deviation of Gaussian distribution
std_dev = torch.tensor([0.2]*100)
std_all = torch.tensor([torch.sqrt(torch.sum(std_dev[:index+1]**2)) for index in range(len(std_dev))])

# Initial rotation matrices
initial_rot_mats = np.array([random_rotation_matrix() for _ in range(matrix_num)])

# Store the current rotation matrices
rot_mats = np.copy(initial_rot_mats)

rot_vec_list = []

# initial_vec = so3.Log(initial_rot_mats)
# initial_vec_norm = initial_vec/(so3.theta_and_axis(initial_vec)[0][:,None] + 1e-6)
# Apply rotations iteratively
steps = len(std_dev)

progress_bar = tqdm(range(steps), desc="Progress:", ncols=100)

angles_list = []
for process_index in progress_bar:
    # Generate rotation angles from a Gaussian distribution (clipped to [-pi, pi])
    angles = np.clip(np.random.normal(loc=0, scale=std_dev[process_index], size=(matrix_num, 3)), -np.pi, np.pi)
    # Build the Lie algebra element and exponentiate it to a rotation matrix
    rot_matrix = np.einsum('ijk,li->ljk', G, angles)
    perturbations = expm(rot_matrix)
    # Apply the perturbation rotations
    rot_mats = np.matmul(perturbations, rot_mats)
    # Rotation angle of each matrix relative to its initial rotation
    angles = np.array([np.arccos((np.trace(np.matmul(initial_rot_mat.T, rot_mat)) - 1) / 2) for initial_rot_mat, rot_mat in zip(initial_rot_mats, rot_mats)])
    angles_list.append(angles)

print("std",std_all)
angles_list = np.array(angles_list)
print("angle list shape", angles_list.shape)
Progress:: 100%|██████████████████████████████████████████████████| 100/100 [00:13<00:00,  7.23it/s]std tensor([0.2000, 0.2828, 0.3464, 0.4000, 0.4472, 0.4899, 0.5292, 0.5657, 0.6000,
        0.6325, 0.6633, 0.6928, 0.7211, 0.7483, 0.7746, 0.8000, 0.8246, 0.8485,
        0.8718, 0.8944, 0.9165, 0.9381, 0.9592, 0.9798, 1.0000, 1.0198, 1.0392,
        1.0583, 1.0770, 1.0954, 1.1136, 1.1314, 1.1489, 1.1662, 1.1832, 1.2000,
        1.2166, 1.2329, 1.2490, 1.2649, 1.2806, 1.2961, 1.3115, 1.3267, 1.3416,
        1.3565, 1.3711, 1.3856, 1.4000, 1.4142, 1.4283, 1.4422, 1.4560, 1.4697,
        1.4832, 1.4967, 1.5100, 1.5232, 1.5362, 1.5492, 1.5620, 1.5748, 1.5875,
        1.6000, 1.6125, 1.6248, 1.6371, 1.6492, 1.6613, 1.6733, 1.6852, 1.6971,
        1.7088, 1.7205, 1.7321, 1.7436, 1.7550, 1.7664, 1.7776, 1.7889, 1.8000,
        1.8111, 1.8221, 1.8330, 1.8439, 1.8547, 1.8655, 1.8762, 1.8868, 1.8974,
        1.9079, 1.9183, 1.9287, 1.9391, 1.9494, 1.9596, 1.9698, 1.9799, 1.9900,
        2.0000])
angle list shape (100, 5000)
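
As an optional consistency check (an added sketch using the objects defined above), the empirical mean rotation angle at the final noising step can be compared with the mean of the truncated IGSO(3) angle density:

omega = torch.linspace(1e-3, math.pi, 1000)
p = f_angle_igso3(omega, std_all[-1])
p = p / p.sum()
print("IGSO(3) mean angle:", (omega * p).sum().item())
print("empirical mean angle:", angles_list[-1].mean())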


We find that the rotation-angle distribution of this noise-adding process follows the IGSO(3) distribution, whose angle density can be written as the series

$$ p(\omega \mid t) \;\propto\; \frac{1-\cos\omega}{\pi}\sum_{l=0}^{\infty}(2l+1)\,e^{-\frac{1}{2}l(l+1)t^{2}}\,\frac{\sin\!\big((l+\tfrac{1}{2})\omega\big)}{\sin(\omega/2)}, $$

which is what f_angle_igso3 above evaluates with the sum truncated at a finite L.

[5]
# make plots
fig_index = 0
f, axs = plt.subplots(1, 3, figsize=(14, 4))
axs = axs.ravel()
print("axs",axs)
for index in [int(i) for i in np.linspace(0, len(std_all) - 1, 3)]:
    print("index", index)
    print("fig_index", fig_index)
    std = std_all[index]
    xlist = torch.linspace(0, math.pi, 500)

    # IGSO(3) angle density, normalised on [0, pi]
    zlist = f_angle_igso3(xlist, std)
    zlist = zlist / (zlist.sum() * (math.pi / 500))

    # Maxwell-Boltzmann-like approximation, normalised on [0, pi]
    wlist = f_maxwell_gaussian(xlist, std)
    wlist = wlist / (wlist.sum() * (math.pi / 500))

    # Histogram of sampled rotation angles at this noise level
    axs[fig_index].hist(angles_list[index], bins='auto', density=True)

    axs[fig_index].plot(xlist, zlist, linewidth=2, label='angle IGSO3')
    axs[fig_index].plot(xlist, wlist, linewidth=2, label='maxwell gaussian')
    axs[fig_index].set_xlabel("Rotation Angle")
    axs[fig_index].set_ylabel("Probability Density")
    axs[fig_index].legend()
    fig_index += 1
# plt.show()
# plt.title("Histogram of Rotation Angles")

Denoising dynamics in general and in SO3 space

The denoising dynamics can in general be formulated as a reverse SDE corresponding to a forward SDE in continuous time. A remarkable result from Anderson states that for a forward SDE

$$ \mathrm{d}\mathbf{x} = f(\mathbf{x}, t)\,\mathrm{d}t + g(t)\,\mathrm{d}\mathbf{w}, $$

the reverse-time process can be modeled as

$$ \mathrm{d}\mathbf{x} = \big[f(\mathbf{x}, t) - g(t)^{2}\,\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\big]\,\mathrm{d}t + g(t)\,\mathrm{d}\bar{\mathbf{w}}. $$

Note that this formulation is in Euclidean space; in the SO(3) case we have to use the formulation in the Lie algebra and then exponentiate back to SO(3).

To write it more explicitly, the forward random walk in the Lie algebra is

$$ R_{k+1} = \exp\!\Big(\sum_{i=1}^{3}\epsilon_{i}^{(k)}G_{i}\Big)\,R_{k}, \qquad \epsilon_{i}^{(k)} \sim \mathcal{N}(0, \beta_{k}), $$

where $\epsilon_{i}^{(k)}$ is sampled from a Gaussian with variance $\beta_{k}$ for each of the three axes and $G_{1}, G_{2}, G_{3}$ are the three Lie algebra basis matrices.

The reverse process of this SDE, written in the Lie algebra and exponentiated back to a rotation matrix, is

$$ R_{k-1} = R_{k}\,\exp\!\Big(\beta_{k}\,\nabla\log p_{t_{k}}(R_{k}) + \sqrt{\beta_{k}}\,\mathbf{z}\Big), \qquad \mathbf{z} \sim \mathcal{N}(0, I_{3}), $$

where $p_{t_{k}}(R)$ is the marginal probability density of the rotation matrix in the forward process. It now only remains to compute $\nabla\log p_{t}(R)$, which is also called the score function.
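
To make Anderson's result concrete, here is a minimal one-dimensional Euclidean toy with an analytic score (only an illustration, not the SO(3) sampler used below):

import numpy as np

betas = np.full(100, 0.01)                       # per-step forward variances
sigma_sq = np.cumsum(betas)                      # accumulated variances sigma_k^2
x0 = np.zeros(5000)                              # clean samples (all at the origin)
x = np.random.normal(scale=np.sqrt(sigma_sq[-1]), size=x0.shape)  # fully noised
print("std before denoising:", x.std())

for k in range(len(betas) - 1, -1, -1):          # reverse-time Euler-Maruyama steps
    score = -(x - x0) / sigma_sq[k]              # analytic score of the forward marginal
    x = x + betas[k] * score + np.random.normal(scale=np.sqrt(betas[k]), size=x.shape)

print("std after denoising:", x.std())           # shrinks from ~1.0 toward ~sqrt(betas[0])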


Score function in SO3 space

The score function in SO(3) space can be written using the chain rule. Since the forward marginal density depends on $R$ only through the rotation angle $\omega$ of the relative rotation $R_{0}^{\top}R$,

$$ \nabla\log p_{t}(R \mid R_{0}) \;=\; \frac{\partial\log f(\omega, t)}{\partial\omega}\,\nabla\omega, \qquad \omega = \arccos\!\Big(\frac{\operatorname{tr}(R_{0}^{\top}R) - 1}{2}\Big). $$

The gradient of the rotation angle has a compact form: it points along the unit rotation axis of the relative rotation, $\nabla\omega = \log(R_{0}^{\top}R)/\omega$ viewed as a vector in the Lie algebra, so the score is this axis scaled by the angle derivative of $\log f$. This is the quantity evaluated numerically in the next cell.
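
As a rough numerical illustration (a sketch based on the truncated density f_igso3 defined earlier, not the ufconf implementation), the angle derivative appearing in the chain rule can be approximated by finite differences:

def dlog_igso3_domega(omega, t, h=1e-4):
    # d/d omega of log f(omega, t), by central finite differences on the truncated series
    return (torch.log(f_igso3(omega + h, t)) - torch.log(f_igso3(omega - h, t))) / (2 * h)

omega = torch.linspace(0.1, math.pi - 0.1, 5)
print(dlog_igso3_domega(omega, 0.5))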

[6]
# from diffold.data.utils import to_numpy
std_all = torch.tensor(std_all)
# Store the current rotation matrices
rot_back_mats = np.copy(rot_mats)
rot_back_mats = torch.tensor(rot_back_mats).float()
initial_rot_mats = torch.tensor(initial_rot_mats).float()
print("std all",std_all)
print("rot_mats shape",rot_mats.shape)
print("initial rot mats shape",initial_rot_mats.shape)

progress_bar = tqdm(range(steps-1,-1,-1), desc="Progress:", ncols=100)
angles_back_list = []
score_list = []
for process_index in progress_bar:
    sigma_t = std_all[process_index]
    # per-step variance increment and accumulated variances at times t and s
    if process_index == 0:
        delta_sigma_sqs = std_all[process_index]**2
    else:
        delta_sigma_sqs = std_all[process_index]**2 - std_all[process_index-1]**2
    sigma_sqs_t = std_all[process_index]**2
    if process_index == 0:
        sigma_sqs_s = 0
    else:
        sigma_sqs_s = std_all[process_index-1]**2
    # method 1
    # rot_back_mats, _ = rd.denoise_sigma(rot_back_mats, initial_rot_mats, torch.ones(matrix_num), sigma_sqs_t, sigma_sqs_s)
    # method 2: explicit reverse step, v = beta * score + noise, then R <- R exp(v)
    betas = sigma_sqs_t - sigma_sqs_s
    delta_so3 = rd.random_so3(
        sigmas=betas.sqrt(),
        size=rot_back_mats.shape[-3],
        igso3_obj=rd.igso3,
        rw_approx_thres=rd.rw_approx_thres
    )  # * L 3
    score = rd.score(rot_back_mats, initial_rot_mats, sigma_sqs_t.sqrt())  # * L 3
    update_vec = score * betas[..., None, None]
    noisy_update_vec = update_vec + delta_so3
    update_mat = rd.Exp(noisy_update_vec)
    rot_back_mats = rot_back_mats @ update_mat
    # rotation angle of each matrix relative to its clean rotation
    angles = np.array([so3.theta_and_axis(so3.Log(np.matmul(initial_rot_mat.T, rot_mat)))[0] for initial_rot_mat, rot_mat in zip(initial_rot_mats, rot_back_mats)])
    angles_back_list.append(angles)

score_list = np.array(score_list)
angles_back_list = np.array(angles_back_list)
print("angle back list shape", angles_back_list.shape)
print("score_list shape",score_list.shape)
/tmp/ipykernel_69/3614484125.py:2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  std_all = torch.tensor(std_all)
std all tensor([0.2000, 0.2828, 0.3464, 0.4000, 0.4472, 0.4899, 0.5292, 0.5657, 0.6000,
        0.6325, 0.6633, 0.6928, 0.7211, 0.7483, 0.7746, 0.8000, 0.8246, 0.8485,
        0.8718, 0.8944, 0.9165, 0.9381, 0.9592, 0.9798, 1.0000, 1.0198, 1.0392,
        1.0583, 1.0770, 1.0954, 1.1136, 1.1314, 1.1489, 1.1662, 1.1832, 1.2000,
        1.2166, 1.2329, 1.2490, 1.2649, 1.2806, 1.2961, 1.3115, 1.3267, 1.3416,
        1.3565, 1.3711, 1.3856, 1.4000, 1.4142, 1.4283, 1.4422, 1.4560, 1.4697,
        1.4832, 1.4967, 1.5100, 1.5232, 1.5362, 1.5492, 1.5620, 1.5748, 1.5875,
        1.6000, 1.6125, 1.6248, 1.6371, 1.6492, 1.6613, 1.6733, 1.6852, 1.6971,
        1.7088, 1.7205, 1.7321, 1.7436, 1.7550, 1.7664, 1.7776, 1.7889, 1.8000,
        1.8111, 1.8221, 1.8330, 1.8439, 1.8547, 1.8655, 1.8762, 1.8868, 1.8974,
        1.9079, 1.9183, 1.9287, 1.9391, 1.9494, 1.9596, 1.9698, 1.9799, 1.9900,
        2.0000])
rot_mats shape (5000, 3, 3)
initial rot mats shape torch.Size([5000, 3, 3])
Progress:: 100%|██████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.11it/s]angle back list shape (100, 5000)
score_list shape (0,)


Mean and std of the forward and backward process

We can plot the mean and standard deviation of the rotation angle along the forward and backward processes and find them to be consistent, as expected if the backward samples reproduce the forward marginals at each noise level.

[7]
import pandas as pd
# pd.options.display.max_seq_items = 2000
print("angles_list shape", angles_list.shape)
print("angles_back_list shape", angles_back_list.shape)
plt.figure()
plt.plot(np.mean(angles_list,axis = 1), label = "forward mean")
plt.plot(np.mean(angles_back_list[::-1],axis = 1), label = "backward mean")
plt.legend()

plt.figure()
plt.plot(np.std(angles_list,axis = 1), label = "forward std")
plt.plot(np.std(angles_back_list[::-1],axis = 1), label = "backward std")
plt.legend()

We can also plot the rotation angle distribution along the denoising process.

[8]
# make plots
fig_index = 0
f, axs = plt.subplots(1, 3, figsize=(14, 4))
axs = axs.ravel()
print("axs",axs)
for index in [int(i) for i in np.linspace(0, len(std_all) - 1, 3)]:
    print("index", index)
    print("fig_index", fig_index)
    std = std_all[-(index+1)]
    xlist = torch.linspace(0, math.pi, 500)

    # IGSO(3) angle density at the matching noise level, normalised on [0, pi]
    zlist = f_angle_igso3(xlist, std)
    zlist = zlist / (zlist.sum() * (math.pi / 500))

    # Maxwell-Boltzmann-like approximation, normalised on [0, pi]
    wlist = f_maxwell_gaussian(xlist, std)
    wlist = wlist / (wlist.sum() * (math.pi / 500))

    # Histogram of rotation angles after this many denoising steps
    axs[fig_index].hist(angles_back_list[index], bins='auto', density=True)

    axs[fig_index].plot(xlist, zlist, linewidth=2, label='angle IGSO3')
    axs[fig_index].plot(xlist, wlist, linewidth=2, label='maxwell gaussian')
    axs[fig_index].set_xlabel("Rotation Angle")
    axs[fig_index].set_ylabel("Probability Density")
    axs[fig_index].legend()
    fig_index += 1
# plt.show()
# plt.title("Histogram of Rotation Angles")
Comments
Linfeng Zhang

09-26 02:10
It would be better to give a "teaser example" showing what would be qualitatively wrong if SO3 is not treated properly