UniFold Multimer Symmetry Renaming Improvement
UniFold is a deep learning model for protein folding. UniFold Multimer improves on UniFold by enabling the prediction of protein complexes.
Symmetric protein complexes are complexes formed by symmetric monomers that assemble in 3D space to carry out a particular function. A complex built from copies of a single monomer is called a homo-(N)-mer (homodimer, homotrimer, homohexamer, etc.), while one built from more than one distinct monomer is called a hetero-(N)-mer (heterodimer, etc.). An example of a heterotetramer is hemoglobin, which is a tetramer of two α and two β subunits:
In UniFold Multimer, the researchers identified that penalizing the model with a high training loss when its prediction matched the labels only up to a reordering of symmetric chains was not ideal, and proposed a permutation strategy that finds a better alignment of the predictions to the labels, taking this symmetry into account.
Proving FAPE SE(3) Invariance
(tensor([[1.0000e+00, 0.0000e+00, 2.9802e-08],
         [0.0000e+00, 1.0000e+00, 0.0000e+00],
         [2.9802e-08, 0.0000e+00, 1.0000e+00]]),
 tensor(1.))
This demonstrates that rot is a rotation matrix, since for all $R \in SO(3)$, the 3D rotation group, $R R^\top = I$ and $\det(R) = 1$.
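As a minimal sketch of how such a check can be produced (the random_rotation helper below is our own, not from the original notebook):

```python
import torch

# Hypothetical helper (ours, not from the original notebook): a random
# rotation matrix from the QR decomposition of a Gaussian matrix.
def random_rotation() -> torch.Tensor:
    q, r = torch.linalg.qr(torch.randn(3, 3))
    q = q * torch.sign(torch.diagonal(r))  # fix the sign ambiguity of QR
    if torch.det(q) < 0:                   # force det = +1 (proper rotation)
        q[:, 0] = -q[:, 0]
    return q

rot = random_rotation()
# Should print (approximately) the identity matrix and 1.0, as above.
print(rot @ rot.T, torch.det(rot))
```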
tensor(0.0032)
This demonstrates the invariance of FAPE to global 3D roto-translations (up to numerical precision), which is what we set out to show.
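A minimal sketch of the whole experiment, assuming a simplified stand-in for UniFold's FAPE (clamped, without the length-scale normalization), could look like this:

```python
import torch

# Same hypothetical helper as in the previous sketch.
def random_rotation() -> torch.Tensor:
    q, r = torch.linalg.qr(torch.randn(3, 3))
    q = q * torch.sign(torch.diagonal(r))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def fape(R1, t1, x1, R2, t2, x2, clamp=10.0, eps=1e-8):
    # Express every point in every local frame, x_ij = R_i^T (x_j - t_i),
    # then compare the two sets of local coordinates.
    l1 = torch.einsum('fab,fpa->fpb', R1, x1[None] - t1[:, None])
    l2 = torch.einsum('fab,fpa->fpb', R2, x2[None] - t2[:, None])
    return torch.sqrt(((l1 - l2) ** 2).sum(-1) + eps).clamp(max=clamp).mean()

F, P = 8, 64  # made-up sizes for the demonstration
R_lab = torch.stack([random_rotation() for _ in range(F)])
t_lab, x_lab = torch.randn(F, 3), torch.randn(P, 3)
R_prd = torch.stack([random_rotation() for _ in range(F)])
t_prd, x_prd = torch.randn(F, 3), torch.randn(P, 3)

# Apply a global roto-translation (g_R, g_t) to the labels only:
# frames (R_i, t_i) -> (g_R R_i, g_R t_i + g_t), points x_j -> g_R x_j + g_t.
g_R, g_t = random_rotation(), torch.randn(3)
before = fape(R_lab, t_lab, x_lab, R_prd, t_prd, x_prd)
after = fape(torch.einsum('ab,fbc->fac', g_R, R_lab),
             t_lab @ g_R.T + g_t, x_lab @ g_R.T + g_t,
             R_prd, t_prd, x_prd)
print(before, after, (before - after).abs())  # difference ~0 up to float error
```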
Proof that X-FAPE is not enough
Now we seek a candidate for a loss that allows us to compute all labels against all predictions, so that we can later select the best matching. The result should be a matrix expressing the cost of assigning label $i$ to prediction $j$. The matching could then be done by trying all permutations (but beware that the number of permutations grows with the number of examples as $N!$). It turns out this problem is the linear sum assignment problem, and efficient algorithms exist that solve it in $O(N^3)$, much more efficiently.
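As an illustration of the matching step, SciPy's linear_sum_assignment solves it directly (the cost values below are made up):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: cost[i, j] = cost of assigning label chain i to
# predicted chain j (values invented purely for illustration).
cost = np.array([[4.0, 1.0, 3.0],
                 [2.0, 0.0, 5.0],
                 [3.0, 2.0, 2.0]])
row_ind, col_ind = linear_sum_assignment(cost)
print(col_ind)                        # -> [1 0 2], the optimal permutation
print(cost[row_ind, col_ind].sum())   # -> 5.0, total cost of the assignment
```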
Let's try a simple point-expansion cross-FAPE (from all i to all j):
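A sketch of the idea, reusing the fape helper from the invariance sketch above (the per-chain layout is our assumption): only the points are crossed between all label/prediction chain pairs, while the frames stay fixed in their original order.

```python
import torch

# Sketch only; assumes the `fape` helper defined in the sketch above.
def x_fape_cost(R_lab, t_lab, x_lab, R_prd, t_prd, x_prd):
    """Point-only cross-FAPE. x_lab, x_prd: (n_chains, n_points, 3);
    the frames cover the whole complex and are NOT permuted."""
    n = x_lab.shape[0]
    cost = torch.zeros(n, n)
    for i in range(n):
        for j in range(n):
            # Label points of chain i vs predicted points of chain j,
            # each under its own full, unpermuted set of frames.
            cost[i, j] = fape(R_lab, t_lab, x_lab[i], R_prd, t_prd, x_prd[j])
    return cost
```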
Testing cross-FAPE renaming
torch.Size([12, 515])
(True, tensor(51.0860))
Rotation and translation of the whole body do not change the result:
(True, tensor(51.0860))
Now a difficult case: it fails, so we need a better candidate. Since only the points are expanded, the frames are never matched between permuted chains, and the cost matrix can no longer identify the correct assignment.
(False, tensor(601.9054))
Proof that XX-FAPE is enough
Function from PR
This function was later adapted into a pull request to the original UniFold codebase to improve the training methodology. The idea is to perform an expansion over both the frames and the points of the FAPE loss for all symmetric chains, and then find the permutation that minimizes the loss with an efficient Hungarian algorithm.
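A hedged sketch of that idea (not the exact PR code; it reuses the fape helper from earlier and assumes frames and points are already grouped per chain):

```python
import torch
from scipy.optimize import linear_sum_assignment

# Sketch only, not the PR implementation; reuses the `fape` helper above.
def xx_fape_permutation(R_lab, t_lab, x_lab, R_prd, t_prd, x_prd):
    """Frames AND points are both indexed per chain:
    R: (n_chains, n_frames, 3, 3), t: (n_chains, n_frames, 3),
    x: (n_chains, n_points, 3). Returns the best chain permutation."""
    n = x_lab.shape[0]
    cost = torch.zeros(n, n)
    for i in range(n):
        for j in range(n):
            # Label chain i (its own frames AND points) vs predicted chain j.
            cost[i, j] = fape(R_lab[i], t_lab[i], x_lab[i],
                              R_prd[j], t_prd[j], x_prd[j])
    # Hungarian matching: optimal assignment in O(n^3) instead of O(n!).
    row, col = linear_sum_assignment(cost.detach().numpy())
    return col  # col[i] = predicted chain assigned to label chain i
```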
The difficult case now works:
(True, tensor(0.0379))
Try with a heteromer, it works too!
running with overall perm: [3, 1, 0, 2, 6, 7, 5, 4, 8, 10, 9, 11]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [2, 3, 0, 1, 6, 5, 7, 4, 8, 11, 10, 9]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [3, 0, 2, 1, 5, 7, 6, 4, 11, 9, 10, 8]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [1, 0, 2, 3, 7, 5, 4, 6, 11, 10, 8, 9]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [1, 2, 3, 0, 6, 5, 7, 4, 8, 9, 11, 10]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132
running with overall perm: [1, 2, 0, 3, 5, 7, 4, 6, 8, 9, 11, 10]
Following algorithm, FAPE for iter 0 is : 0.012649103999137878
Following algorithm, FAPE for iter 1 is : 0.012649117037653923
Following algorithm, FAPE for iter 2 is : 0.012649094685912132