UniFold Multimer Symmetry Renaming Improvement
Tags: AI4S, D2L.AI, English, python, notebook, OpenScience
Published by ericalcaide1@gmail.com on 2023-09-25
Recommended image: Third-party software:unifold-notebook:v2
Recommended machine type: c16_m32_cpu

UniFold is a deep learning model for protein folding. UniFold Multimer extends UniFold to the prediction of protein complexes.

Symmetric protein complexes are formed by several copies of the same chain(s), which assemble in 3D space to carry out a particular function. Complexes built from copies of a single chain type are called homo-N-mers (homodimers, homotrimers, homohexamers, etc.), while those built from more than one chain type are called hetero-N-mers (heterohexamers, etc.). An example of a heterotetramer is hemoglobin, an α2β2 tetramer made of two α and two β chains.

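In UniFold Multimer, the researchers noticed that in symmetric complexes identical chains are interchangeable, so a prediction that places them correctly but lists them in a different order than the labels gets penalized with a high training loss, which is not ideal. They proposed a permutation strategy that takes this symmetry into account and finds a better alignment of the predicted chains to the label chains before the loss is computed.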
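A toy illustration of the problem (our own sketch, not UniFold code): a plain per-chain RMSD stands in for FAPE, and the "permutation strategy" is simply a brute-force minimum over every renaming of the interchangeable label chains. The names `chain_rmsd`, `naive_loss` and `symmetric_loss` are hypothetical.

```python
import itertools
import math

def chain_rmsd(a, b):
    """Root-mean-square distance between two equally long lists of 3D points."""
    return math.sqrt(sum(math.dist(p, q) ** 2 for p, q in zip(a, b)) / len(a))

def naive_loss(pred, label):
    """Average per-chain error, always pairing predicted chain i with label chain i."""
    return sum(chain_rmsd(p, lab) for p, lab in zip(pred, label)) / len(pred)

def symmetric_loss(pred, label):
    """Smallest naive loss over every renaming of the (interchangeable) label chains."""
    return min(
        naive_loss(pred, [label[j] for j in perm])
        for perm in itertools.permutations(range(len(label)))
    )

# hypothetical homodimer: two copies of the same 3-point chain, 5 units apart
copy_1 = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
copy_2 = [(0.0, 5.0, 0.0), (1.0, 5.0, 0.0), (2.0, 5.0, 0.0)]
label = [copy_1, copy_2]
pred = [copy_2, copy_1]  # the right structure, chains just listed in the other order

print(naive_loss(pred, label))      # 5.0: penalized although nothing is wrong
print(symmetric_loss(pred, label))  # 0.0: after renaming the label chains
```

Trying every renaming is only viable for a handful of chains; the rest of the notebook replaces it with an assignment solver.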

```
%%bash
# download coordinates of a sample protein to test the algorithm
# (the files are hosted as .txt attachments and saved locally as .pt)
wget https://github.com/dptech-corp/Uni-Fold/files/12474758/coords_saved.txt -O coords_saved.pt
wget https://github.com/dptech-corp/Uni-Fold/files/12490827/coords_mask_saved.txt -O coords_mask_saved.pt
```

Proving FAPE SE(3) Invariance

```python
import torch as th
from scipy.spatial.transform import Rotation as R
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
from unifold.modules.frame import Frame, Rotation


th.manual_seed(17)
# point cloud: 16 clusters of 3 nearby points, each cluster will define one frame
a = th.randn(16, 1, 3) * 3.
aux = th.randn(16, 3, 3) * 0.1
a = a + aux
# global rigid transform: random rotation + translation
t = th.randn(1, 1, 3)
rot = th.from_numpy(R.random().as_matrix()).float()
```

```python
# sanity check: a rotation matrix is orthogonal with determinant +1
rot @ rot.T, rot.det()
```
Output:
```
(tensor([[1.0000e+00, 0.0000e+00, 2.9802e-08],
         [0.0000e+00, 1.0000e+00, 0.0000e+00],
         [2.9802e-08, 0.0000e+00, 1.0000e+00]]),
 tensor(1.))
```

This demonstrates that `rot` is a rotation matrix: every $R \in SO(3)$, the 3D rotation group, satisfies $R R^\top = I$ and $\det(R) = 1$, and `rot` does so up to floating-point error.

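```python
# apply the same global roto-translation to every point (row vectors, so a @ rot rotates by rot.T)
b = a @ rot + t
```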

```python
# one frame per cluster of 3 points: (16, 3, 3) -> three tensors of shape (16, 3)
fa = Frame.from_3_points(*a.unbind(dim=-2))
fb = Frame.from_3_points(*b.unbind(dim=-2))
```

```python
def compute_fape(
    pred_frames: Frame,
    target_frames: Frame,
) -> th.Tensor:
    """ Computes the Frame Aligned Point Error as introduced in AlphaFold 2
    (unclamped variant), using the frame origins (translations) as the points.
    Inputs:
    * pred_frames: (f,). Frames have both a rotation (f, 3, 3) and a translation (f, 3)
    * target_frames: (f,). Frames have both a rotation (f, 3, 3) and a translation (f, 3)
    Outputs: (,) th.Tensor
    """
    # express every frame origin in every frame: (f, 1) frames x (1, f, 3) points -> (f, f, 3)
    local_pred_pos = pred_frames.invert()[..., None].apply(
        pred_frames._t[..., None, :, :].float(),
    )
    local_target_pos = target_frames.invert()[..., None].apply(
        target_frames._t[..., None, :, :].float(),
    )

    # distance with a small epsilon (stable gradient at 0), averaged over frames and points
    d_pt2 = (local_pred_pos - local_target_pos).square().sum(-1)
    d_pt = d_pt2.add_(1e-5).sqrt()
    fape = d_pt.mean(dim=(-1, -2))
    return fape
```

```python
# ~0 if FAPE is invariant to the global roto-translation that maps a to b
compute_fape(fa, fa) - compute_fape(fa, fb)
```

The result is ~0 (up to floating-point precision), which shows numerically that FAPE is invariant to global 3D roto-translations, as we wanted to prove.
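The same result follows from a short calculation. Write a frame as $T = (R, t)$ acting on a point by $T \circ x = R x + t$, and let $G = (Q, s)$ be any global rotation $Q \in SO(3)$ plus translation $s$. Assuming, as in AlphaFold, that `from_3_points` builds frames by Gram-Schmidt (which commutes with proper rotations and translations), transforming the whole structure maps every frame $T_k$ to $G \circ T_k$ and every point $x_l$ to $G \circ x_l$, so the local coordinates that FAPE compares do not change:

$$
(G \circ T_k)^{-1} \circ (G \circ x_l) = T_k^{-1} \circ G^{-1} \circ G \circ x_l = T_k^{-1} \circ x_l
$$

Every term of FAPE is therefore identical before and after the transform. The argument needs $\det Q = +1$: a reflection would flip the cross product that defines the third frame axis.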


Proof that X-FAPE is not enough

Now we seek a candidate loss that compares every label with every prediction, so that we can later select the best matching. The result should be a cost matrix $C \in \mathbb{R}^{n \times n}$ whose entry $C_{ij}$ is the cost of assigning label $j$ to prediction $i$. The matching could be done by trying all permutations, but beware that their number grows as $n!$ with the number of chains $n$. It turns out this is the linear sum assignment problem, and efficient algorithms (e.g. the Hungarian algorithm) solve it in $O(n^3)$ time, much more efficiently.
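To make the $n!$ blow-up concrete, here is a brute-force matcher over a made-up 3×3 cost matrix (a stdlib-only sketch; `brute_force_assignment` is a hypothetical helper, while the notebook itself uses scipy's `linear_sum_assignment`):

```python
import itertools
import math

def brute_force_assignment(cost):
    """Return the column order q minimising sum(cost[i][q[i]]) by trying all n! permutations."""
    n = len(cost)
    return min(
        itertools.permutations(range(n)),
        key=lambda q: sum(cost[i][q[i]] for i in range(n)),
    )

# hypothetical cost matrix: cost[i][j] = cost of matching prediction i to label j
cost = [
    [4.0, 0.5, 3.0],
    [0.2, 5.0, 2.0],
    [3.0, 2.5, 0.1],
]
print(brute_force_assignment(cost))  # (1, 0, 2): predictions 0 and 1 swap labels
print(math.factorial(12))            # 479001600 permutations for a 12-chain complex
```

For the same matrix, `scipy.optimize.linear_sum_assignment(np.array(cost))` returns rows `[0, 1, 2]` and columns `[1, 0, 2]` without enumerating any permutation.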

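Let's try a simple point-expansion cross-FAPE (X-FAPE), which compares the points of every predicted chain $i$ with the points of every label chain $j$ (from all $i$ to all $j$).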

```python
import torch as th

def compute_x_fape(
    pred_frames: Frame,
    target_frames: Frame,
    pred_points: th.Tensor,
    target_points: th.Tensor
) -> th.Tensor:
    """ Cross-FAPE over points only: predicted frame k is always paired with target frame k.
    (..., (n f), 1, 1) @ (..., 1, n, p, d) -> (..., (n f), n, p, d)
    (..., (n f), ni, p, d) - (..., (n f), nj, p, d) -> (..., ni, nj)
    Inputs are:
    * pred_frames, target_frames: (..., (n f)), frames of all chains flattened
    * pred_points, target_points: (..., n, p, d), with p = f when frame origins are used as points
    Returns: (..., n, n)
    """
    # (..., (n f), n, p, d): every chain's points expressed in every frame
    local_pred_pos = pred_frames[..., None, None].invert().apply(
        pred_points[..., None, :, :, :].float(),
    )
    local_target_pos = target_frames[..., None, None].invert().apply(
        target_points[..., None, :, :, :].float(),
    )
    # (..., (n f), ni, 1, p, d) - (..., (n f), 1, nj, p, d) -> (..., (n f), ni, nj, p)
    d_pt2 = (local_pred_pos.unsqueeze(dim=-3) - local_target_pos.unsqueeze(dim=-4)).square().sum(-1)
    d_pt = d_pt2.add_(1e-5).sqrt()
    # average over points and frames -> (..., ni, nj)
    x_fape = d_pt.mean(dim=(-1, -4))
    return x_fape
```

```python
# check for file linked here: https://github.com/dptech-corp/Uni-Fold/pull/129
# coordinates of shape (n chains, residues, 3 atoms, 3)
coords = th.load("coords_saved.pt")[0]
```

Proving Cross-FAPE renaming

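```python
th.manual_seed(17)
n = coords.shape[0]  # number of chains (12)
p = th.randperm(n)
# the random permutation is overridden by a simple case: swap chains 0 and 1
p = th.tensor([1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
```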

```python
# frames of the true structure, and a "prediction" whose chains are listed in permuted order
frames_true = frames = Frame.from_3_points(*coords.unbind(dim=-2))
frames_perm = frames[p]
```

```python
frames_perm.shape
```
Output:
```
torch.Size([12, 515])
```

```python
# flatten (n, f) -> (n f) frames, as expected by compute_x_fape
frames_flat_true = Frame.cat([f for f in frames], dim=0)
frames_flat_perm = Frame.cat([f for f in frames_perm], dim=0)
```
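
```python
# cost matrix (predicted chain i, label chain j), then optimal assignment
mat = compute_x_fape(frames_flat_perm, frames_flat_true, frames_perm._t, frames_true._t)
rows, cols = linear_sum_assignment(mat.detach().cpu().numpy())
# success if prediction i is matched to label p[i]
cols.tolist() == p.tolist(), mat[rows, cols].sum()
```
Output:
```
(True, tensor(51.0860))
```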

Rotation and translation of the whole body does not change the result

```python
p = th.tensor([1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
frames_true = frames = Frame.from_3_points(*coords.unbind(dim=-2))
frames_perm = frames[p]

# move the whole "prediction" with one random rotation (drawn by scipy, so it changes between runs)
# and a translation of 15 along each axis
rot = th.from_numpy(R.random().as_matrix()).float()
global_frame = Frame(rotation=Rotation(rot[None]), translation=th.ones(1, 3) * 15)[None]
frames_perm = global_frame.compose(frames_perm)
```

```python
# flatten (n, f) -> (n f) frames
frames_flat_true = Frame.cat([f for f in frames], dim=0)
frames_flat_perm = Frame.cat([f for f in frames_perm], dim=0)
```

```python
mat = compute_x_fape(frames_flat_perm, frames_flat_true, frames_perm._t, frames_true._t)
rows, cols = linear_sum_assignment(mat.detach().cpu().numpy())
cols.tolist() == p.tolist(), mat[rows, cols].sum()
```
Output:
```
(True, tensor(51.0860))
```
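
```python
# with the reversed colormap, low cost (good match) shows as yellow
plt.imshow(mat.numpy(), cmap="viridis_r")
plt.xlabel("label chain j")
plt.ylabel("predicted chain i")
plt.colorbar()
```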

Now a difficult case - it fails

```python
# a scrambled permutation of the 12 chains
p = th.tensor([9, 0, 8, 1, 11, 5, 10, 7, 2, 3, 6, 4])
```

```python
frames_true = frames = Frame.from_3_points(*coords.unbind(dim=-2))
frames_perm = frames[p]
```

```python
frames_flat_true = Frame.cat([f for f in frames], dim=0)
frames_flat_perm = Frame.cat([f for f in frames_perm], dim=0)
```

```python
mat = compute_x_fape(frames_flat_perm, frames_flat_true, frames_perm._t, frames_true._t)
rows, cols = linear_sum_assignment(mat.detach().cpu().numpy())
cols.tolist() == p.tolist(), mat[rows, cols].sum()
```
Output:
```
(False, tensor(601.9054))
```

```python
plt.imshow(mat.numpy(), cmap="viridis_r")
plt.xlabel("label chain j")
plt.ylabel("predicted chain i")
plt.colorbar()
```

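Essentially, it fails to recover the permutation that would minimize the loss. This is because X-FAPE crosses only the points: predicted frame k is always compared with label frame k, so the frames are never re-paired. With a near-identity permutation most frame pairs still belong to matching chains, but with a scrambled permutation most of them do not, which drowns out the signal. FAPE, however, is a function of both the frames and the points!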

Therefore we need a cost that accounts for both cross dependencies, over the frames and over the points.
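For reference, this is what `compute_x_fape` computes, written in our own notation: with the frames of all chains flattened into $T_1, \dots, T_N$ ($N = n f$), $x_{i,l}$ the $l$-th of the $p$ points of chain $i$, and $\epsilon = 10^{-5}$,

$$
\mathrm{X}(i, j) = \frac{1}{N p} \sum_{k=1}^{N} \sum_{l=1}^{p} \sqrt{\big\lVert (T^{\text{pred}}_{k})^{-1} \circ x^{\text{pred}}_{i,l} - (T^{\text{true}}_{k})^{-1} \circ x^{\text{true}}_{j,l} \big\rVert^2 + \epsilon}
$$

Only the chain indices $i, j$ of the points appear crossed; the frame index $k$ is shared between prediction and label, so no choice of $(i, j)$ can undo a mismatch between the frames.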


Proof that XX-FAPE is enough


Function from PR

This function was later adapted into a pull request to the original UniFold codebase to improve the training methodology. The idea is to expand the FAPE loss over both the frames and the points of all symmetric chains, and then find the permutation that minimizes the loss with an efficient assignment algorithm (the Hungarian algorithm).

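With $T_{i',k}$ the $k$-th frame of chain $i'$, $x_{i,l}$ the $l$-th point of chain $i$ and $\epsilon = 10^{-5}$ (as in the function below), the mathematical formulation can be written as:

$$
\mathrm{XX}(i', j', i, j) = \frac{1}{f\,p} \sum_{k=1}^{f} \sum_{l=1}^{p} \sqrt{\big\lVert (T^{\text{pred}}_{i',k})^{-1} \circ x^{\text{pred}}_{i,l} - (T^{\text{true}}_{j',k})^{-1} \circ x^{\text{true}}_{j,l} \big\rVert^2 + \epsilon}
$$

The assignment cost averages out either the point indices or the frame indices,

$$
C_{ab} = \frac{1}{n^2} \sum_{i,j} \mathrm{XX}(a, b, i, j) + \frac{1}{n^2} \sum_{i',j'} \mathrm{XX}(i', j', a, b),
$$

and the recovered permutation is $\hat\sigma = \arg\min_{\sigma} \sum_{a} C_{a\,\sigma(a)}$. When masks are given, the means over $k$ and $l$ become masked means.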

```python
from typing import Optional

def compute_xx_fape(
    pred_frames: Frame,
    target_frames: Frame,
    pred_points: th.Tensor,
    target_points: th.Tensor,
    frames_mask: Optional[th.Tensor] = None,
    points_mask: Optional[th.Tensor] = None,
) -> th.Tensor:
    """ FAPE cross-matrix from frames to the cross-matrix of points,
    used to find a permutation which gives the optimal loss for a symmetric structure.

    Notation for use with n chains of length p, under f=k frames:
    - n: number of protein chains
    - p: number of points (length of chain)
    - d: dimension of points = 3
    - f: arbitrary number of frames
    - ': frames dimension
    (..., n', f, 1, 1) @ (..., 1, 1, n, p, d) -> (..., n', f, n, p, d)
    (..., ni', f, ni, p, d) - (..., nj', f, nj, p, d) -> (..., ni', nj', ni, nj)

    Args:
        pred_frames: (..., n(i'), f)
        target_frames: (..., n(j'), f)
        pred_points: (..., n(i), p, d)
        target_points: (..., n(j), p, d)
        frames_mask: (..., n', f) float tensor
        points_mask: (..., n, p) float tensor

    Returns:
        (..., n(i'), n(j'), n(i), n(j)) th.Tensor
    """
    # (..., n', f) · (..., n, p, d) -> (..., n', f, n, p, d)
    local_pred_pos = pred_frames[..., None, None].invert().apply(
        pred_points[..., None, None, :, :, :].float(),
    )
    local_target_pos = target_frames[..., None, None].invert().apply(
        target_points[..., None, None, :, :, :].float(),
    )
    n_, f, n, p = local_pred_pos.shape[-5:-1]

    # masks for the reduction, shape (..., ni', nj', f, ni, nj, p);
    # a pair is kept if at least one of its two entries is valid (sum, then bool)
    use_mask = frames_mask is not None or points_mask is not None
    mask = 1.
    if frames_mask is not None:
        mask = mask * (
            frames_mask[..., :, None, :, None, None, None] + frames_mask[..., None, :, :, None, None, None]
        ).bool().float()
    if points_mask is not None:
        mask = mask * (
            points_mask[..., None, None, None, :, None, :] + points_mask[..., None, None, None, None, :, :]
        ).bool().float()
    if use_mask:
        # broadcast as a view, so slicing by (i', j') and the f*p normalization work with either mask alone
        mask = mask.expand(*mask.shape[:-6], n_, n_, f, n, n, p)

    # chunk over (ni', nj') to avoid memory errors
    xx_fape = local_pred_pos.new_zeros(*local_pred_pos.shape[:-5], n_, n_, n, n)
    for i_ in range(n_):
        for j_ in range(n_):
            # (..., 1, 1, f, ni, 1, p, d) - (..., 1, 1, f, 1, nj, p, d) -> (..., 1, 1, f, ni, nj, p)
            d_pt2 = (
                local_pred_pos[..., i_:i_ + 1, None, :, :, None, :, :] -
                local_target_pos[..., None, j_:j_ + 1, :, None, :, :, :]
            ).square().sum(-1)
            d_pt = d_pt2.add_(1e-5).sqrt()
            # average over frames (f) and points (p): (..., 1, 1, ni, nj)
            if use_mask:
                mask_ = mask[..., i_:i_ + 1, j_:j_ + 1, :, :, :, :]
                x_fape_ij = (d_pt * mask_).sum(dim=(-1, -4)) / mask_.sum(dim=(-1, -4)).clamp(min=1.)
            else:
                x_fape_ij = d_pt.mean(dim=(-1, -4))
            xx_fape[..., i_, j_, :, :] = x_fape_ij
            # save memory
            del d_pt2, d_pt, x_fape_ij

    return xx_fape
```

Difficult case, it works

```python
# the same scrambled permutation that made X-FAPE fail
p = th.tensor([9, 0, 8, 1, 11, 5, 10, 7, 2, 3, 6, 4])
```

```python
frames_true = frames = Frame.from_3_points(*coords.unbind(dim=-2))
frames_perm = frames[p]
```
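
```python
# rank-4 tensor (ni', nj', ni, nj); frames keep their (n, f) shape, no flattening needed
mat = compute_xx_fape(frames_perm, frames_true, frames_perm._t, frames_true._t).detach().cpu()
# average out the point indices -> frame cost (ni', nj'); average out the frame indices -> point cost (ni, nj)
x_mat_frames = mat.sum(dim=(-1, -2)) / (mat.shape[-1] * mat.shape[-2])
x_mat_points = mat.sum(dim=(-3, -4)) / (mat.shape[-3] * mat.shape[-4])
rows, cols = linear_sum_assignment((x_mat_frames + x_mat_points).numpy())
cols.tolist() == p.tolist(), mat[rows, cols, rows, cols].sum()
```
Output:
```
(True, tensor(0.0379))
```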
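For readers without a UniFold installation, the same recipe can be tried on toy data with NumPy alone. This is a sketch under our own assumptions, not the PR code: frames are built by Gram-Schmidt from three atoms (the same idea as `Frame.from_3_points`, with our own axis conventions), there are no masks, and a brute-force search over permutations stands in for `linear_sum_assignment`. The toy "homotetramer" places 4 copies on the vertices of a tetrahedron, and the "prediction" is the same structure with its chains listed in a scrambled order; all helper names are hypothetical.

```python
import itertools
import numpy as np

def frames_from_3_points(p_neg_x, origin, p_xy):
    """Gram-Schmidt frames from three (..., 3) arrays -> rotations (..., 3, 3), translations (..., 3)."""
    e0 = origin - p_neg_x
    e0 = e0 / np.linalg.norm(e0, axis=-1, keepdims=True)
    e1 = p_xy - origin
    e1 = e1 - np.sum(e1 * e0, axis=-1, keepdims=True) * e0
    e1 = e1 / np.linalg.norm(e1, axis=-1, keepdims=True)
    e2 = np.cross(e0, e1)
    return np.stack([e0, e1, e2], axis=-1), origin  # columns are the frame axes

def to_local(rot, trans, points):
    """Every chain's points in every frame: (n', f) frames, (n, p, 3) points -> (n', f, n, p, 3)."""
    diff = points[None, None] - trans[:, :, None, None, :]
    return np.einsum("akdc,akbld->akblc", rot, diff)  # R^T (x - t)

def xx_fape(rot_p, t_p, x_p, rot_t, t_t, x_t, eps=1e-5):
    """Unmasked XX-FAPE tensor (n', n', n, n), averaged over the f frames and p points."""
    lp = to_local(rot_p, t_p, x_p)
    lt = to_local(rot_t, t_t, x_t)
    d2 = ((lp[:, None, :, :, None] - lt[None, :, :, None, :]) ** 2).sum(-1)
    return np.sqrt(d2 + eps).mean(axis=(2, 5))

def best_permutation(cost):
    """Brute-force assignment; fine for a handful of chains."""
    n = cost.shape[0]
    return min(itertools.permutations(range(n)),
               key=lambda q: sum(cost[i, q[i]] for i in range(n)))

# toy homotetramer: 4 copies, 20 residues x 3 atoms each, scattered around tetrahedron vertices
rng = np.random.default_rng(0)
n, f = 4, 20
centers = 25.0 * np.array([[1, 1, 1], [1, -1, -1], [-1, 1, -1], [-1, -1, 1]], dtype=float)
atoms = centers[:, None, None, :] + rng.normal(scale=3.0, size=(n, f, 3, 3))
rot_t, t_t = frames_from_3_points(atoms[:, :, 0], atoms[:, :, 1], atoms[:, :, 2])

# the "prediction": identical chains, listed in a scrambled order
perm = [2, 0, 3, 1]
rot_p, t_p = rot_t[perm], t_t[perm]

# frame origins double as the points, as in the notebook
xx = xx_fape(rot_p, t_p, t_p, rot_t, t_t, t_t)
cost = xx.mean(axis=(2, 3)) + xx.mean(axis=(0, 1))
recovered = best_permutation(cost)
print(recovered)
```

Swapping `best_permutation` for `scipy.optimize.linear_sum_assignment(cost)` gives the same matching while scaling to many more chains.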

Try with a heteromer, it works too!

```python
import itertools
```

```python
# coords: (n, f, 3, 3) and mask: (n, f), with n = 12 chains and f ~ 300 residues;
# the 12 chains are 4 copies each of 3 chain types (a, b, c) from PDB entry 4F4O
# (haemoglobin alpha, beta and haptoglobin)
coords, mask = th.load("coords_mask_saved.pt")
```

```python
# chain-type label of each chain: 4 x type 0, 4 x type 1, 4 x type 2
labels = th.tensor([*[0]*4, *[1]*4, *[2]*4])
uniq_labels, uniq_counts = th.unique(labels, return_counts=True)
```

```python
for seed in [17, 42, 23, 10, 19, 34]:
    th.manual_seed(seed)
    # random permutation within each chain type (chains can only be renamed among identical copies)
    p = [th.randperm(int(c)).add_(uniq_counts[:i].sum()).tolist() for i, c in enumerate(uniq_counts)]
    p = list(itertools.chain.from_iterable(p))
    p_ = th.tensor(p[:])
    print(f"running with overall perm: {p}")
    # solve one assignment per chain type
    for i, c in enumerate(uniq_counts):
        c = c.item()
        offset = uniq_counts[:i].sum().item()
        # (n, f) frames of this chain type
        frames_true = frames = Frame.from_3_points(*coords[offset:offset+c].unbind(dim=-2))
        frames_perm = frames[p_[offset:offset+c] - offset]
        frames_mask = points_mask = mask[p[offset:offset+c]]
        # rank 4 tensor to derive assignment
        mat = compute_xx_fape(frames_perm, frames_true, frames_perm._t, frames_true._t,
                              frames_mask=frames_mask, points_mask=points_mask)
        # run assignment
        x_mat_frames = mat.sum(dim=(-1, -2)) / (mat.shape[-1] * mat.shape[-2])
        x_mat_points = mat.sum(dim=(-3, -4)) / (mat.shape[-3] * mat.shape[-4])
        rows, cols = linear_sum_assignment((x_mat_frames + x_mat_points).numpy())
        # write the recovered assignment back into the global permutation
        p_[offset:offset+c] = offset + th.tensor(cols.tolist())
        print(f"Following algorithm, FAPE for iter {i} is : {mat[rows, cols, rows, cols].sum().item()}")
    assert p_.tolist() == p, f"permutation was not recovered: {p_} vs correct permutation: {p}"
```
Output:
```
running with overall perm: [3, 1, 0, 2, 6, 7, 5, 4, 8, 10, 9, 11]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [2, 3, 0, 1, 6, 5, 7, 4, 8, 11, 10, 9]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [3, 0, 2, 1, 5, 7, 6, 4, 11, 9, 10, 8]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [1, 0, 2, 3, 7, 5, 4, 6, 11, 10, 8, 9]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [1, 2, 3, 0, 6, 5, 7, 4, 8, 9, 11, 10]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [1, 2, 0, 3, 5, 7, 4, 6, 8, 9, 11, 10]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
```

Conclusion

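In every case tested here, XX-FAPE recovers the permutation that minimizes the loss, even for a heteromer (solving one assignment per chain type). The reported losses confirm that the matching is exact: each matched chain contributes only the $\sqrt{\epsilon} = \sqrt{10^{-5}} \approx 0.00316$ floor, giving $12 \times 0.00316 \approx 0.0379$ for the 12-chain homomer and $4 \times 0.00316 \approx 0.0126$ per chain type for the heteromer.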

Now it's time to add this improvement to the UniFold codebase so it can be leveraged in the next training run!
