Uni-Fold Notebook-Fixed
Author: SkyRui
Updated: 2024-10-18
Recommended image: Third-party software:unifold-notebook:v2
Recommended machine: c12_m92_1 * NVIDIA V100
Uni-Fold Notebook
1. CONFIGURATION
1.1 Input and Output
1.2 Hyper-parameters
1.3 Data-processing Functions
2. Inference
2.1 Data Input
2.2 Feature Generation
2.3 Model Prediction
3. Visualization

©️ Copyright 2023 @ Authors
Authors: 杨舒文, 李子尧
Date: 2023-07-13
License: this work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0).
Quick start: click the Start Connection (开始连接) button above, then select the unifold-notebook:v2 image and any machine configuration to begin.


Uni-Fold Notebook

This notebook provides the protein structure prediction service of Uni-Fold as well as UF-Symmetry. Predictions of both protein monomers and multimers are supported. The homology search in this notebook uses the MMseqs2 server provided by ColabFold. For results more consistent with the original AlphaFold(-Multimer), please refer to the open-source repository of Uni-Fold, or our convenient web server at Hermite™.

Please note that this notebook is provided as an early-access prototype and is NOT an official product of DP Technology. It is intended for theoretical modeling only, and caution should be exercised in its use.

Licenses

This notebook uses the Uni-Fold model parameters, and its outputs are made available under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode. The notebook itself is provided under the Apache 2.0 license.

Citations

Please cite the following papers if you use this notebook:

Acknowledgements

The model architecture of Uni-Fold is largely based on AlphaFold and AlphaFold-Multimer. The design of this notebook directly follows ColabFold. We especially thank @sokrypton for his helpful suggestions on this notebook.

Copyright © 2022 DP Technology. All rights reserved.


1. CONFIGURATION

1.1 Input and Output

Set up the input contents (from a file, or by directly filling in input_json) and the output path.

  • jobname (str): name of the job, which serves as the prefix of the output directories.
  • input_json_path (str): path of the input JSON file, which contains a list or dict of proteins. For a list, the indices are taken as IDs. Each protein is a dict with keys:
    • symmetry: the protein's symmetry group. Use "C1" as the default.
    • sequence: the sequences of the asymmetric unit (separated by ";").
  • output_dir_base (str): root directory of the output files.

For multimers, it is recommended to specify a cyclic symmetry group (e.g. C4) and the sequences of the asymmetric unit (i.e. do not copy them multiple times) to predict with UF-Symmetry.
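As a minimal illustration of the two equivalent input styles (with a shortened, hypothetical sequence), the same homodimer can be requested either through UF-Symmetry or as an explicit C1 multimer:

```python
# Two equivalent ways to request a homodimeric assembly (sketch; the
# sequence below is shortened for illustration, not a real input).

# With UF-Symmetry: give only the asymmetric unit and its cyclic group.
uf_symmetry_entry = {
    "symmetry": "C2",
    "sequence": "GSHMKNVLIGVQTN",
}

# Without UF-Symmetry: use "C1" and repeat the chain once per copy,
# joining the chains of the assembly with ";".
plain_multimer_entry = {
    "symmetry": "C1",
    "sequence": "GSHMKNVLIGVQTN;GSHMKNVLIGVQTN",
}

# The C1 form expands to two identical chains.
chains = plain_multimer_entry["sequence"].split(";")
assert len(chains) == 2 and chains[0] == chains[1]
```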

[11]
import os
import json

jobname = 'unifold_bohrium' #@param {type:"string"}

# A DEMO CASE (FILE)
# Upload the folding task first
# input_json_path = 'input.json'
# with open(input_json_path, encoding="utf-8") as fp:
#     input_json = json.load(fp)

# A DEMO CASE (LIST)
# input_json = [
#     {'sequence': 'MGSSHHHHHHSSGLVPRGSHMEDRDPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEFFKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG'},
#     {'symmetry': 'C2', 'sequence': 'GSHMKNVLIGVQTNLGVNKTGTEFGPDDLIQAYPDTFDEMELISVERQKEDFNDKKLKFKNTVLDTCEKIAKRVNEAVIDGYRPILVGGDHSISLGSVSGVSLEKEIGVLWISAHGDMNTPESTLTGNIHGMPLALLQGLGDRELVNCFYEGAKLDSRNIVIFGAREIEVEERKIIEKTGVKIVYYDDILRKGIDNVLDEVKDYLKIDNLHISIDMNVFDPEIAPGVSVPVRRGMSYDEMFKSLKFAFKNYSVTSADITEFNPLNDINGKTAELVNGIVQYMMNPDY'},
#     {'symmetry': 'C2', 'sequence': 'GGSGGSGGSGGSLFCEQVTTVTNLFEKWNDCERTVVMYALLKRLRYPSLKFLQYSIDSNLTQNLGTSQTNLSSVVIDINANNPVYLQNLLNAYKTARKEDILHEVLNMLPLLKPGNEEAKLIYLTLIPVAVKDTMQQIVPTELVQQIFSYLLIHPAITSEDRRSLNIWLRHLEDHIQ;SVPSYGEDELQQAMRLLNAASRQRTEAANEDFGGT'},
#     {'symmetry': 'C3', 'sequence': 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVCTVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'},
# ]

# A DEMO CASE (DICT)
input_json = {
    '7teu': {'sequence': 'MGSSHHHHHHSSGLVPRGSHMEDRDPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEFFKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG'},
    '8d27': {'symmetry': 'C2', 'sequence': 'GSHMKNVLIGVQTNLGVNKTGTEFGPDDLIQAYPDTFDEMELISVERQKEDFNDKKLKFKNTVLDTCEKIAKRVNEAVIDGYRPILVGGDHSISLGSVSGVSLEKEIGVLWISAHGDMNTPESTLTGNIHGMPLALLQGLGDRELVNCFYEGAKLDSRNIVIFGAREIEVEERKIIEKTGVKIVYYDDILRKGIDNVLDEVKDYLKIDNLHISIDMNVFDPEIAPGVSVPVRRGMSYDEMFKSLKFAFKNYSVTSADITEFNPLNDINGKTAELVNGIVQYMMNPDY'},
    '8oij': {'symmetry': 'C2', 'sequence': 'GGSGGSGGSGGSLFCEQVTTVTNLFEKWNDCERTVVMYALLKRLRYPSLKFLQYSIDSNLTQNLGTSQTNLSSVVIDINANNPVYLQNLLNAYKTARKEDILHEVLNMLPLLKPGNEEAKLIYLTLIPVAVKDTMQQIVPTELVQQIFSYLLIHPAITSEDRRSLNIWLRHLEDHIQ;SVPSYGEDELQQAMRLLNAASRQRTEAANEDFGGT'},
    'c2404': {'symmetry': 'C3', 'sequence': 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVCTVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'},
}

1.2 Hyper-parameters

Set up the inference parameters:

  • use_templates (bool): whether to use template features.
  • msa_mode (str): set to "MMseqs2" to use MSA features, or "single_sequence" to skip them.
  • max_recycling_iters (int): maximum number of recycling iterations.
  • num_ensembles (int): number of ensembles.
  • manual_seed (int): random seed.
  • times (int): number of Uni-Fold inference attempts.
  • max_display_cnt (int): maximum number of proteins displayed in the visualization stage.
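manual_seed is typically propagated to every random-number source in use. The sketch below shows the common seeding pattern (an illustration, not the notebook's exact implementation, which passes manual_seed into Uni-Fold's data pipeline):

```python
import random

def set_all_seeds(seed: int) -> None:
    """Seed the RNGs used during inference (a minimal sketch).

    In the full pipeline, numpy and torch would be seeded as well, e.g.:
    np.random.seed(seed); torch.manual_seed(seed)
    """
    random.seed(seed)

set_all_seeds(42)
first = random.randint(0, 10**6)
set_all_seeds(42)
second = random.randint(0, 10**6)
assert first == second  # reseeding reproduces the same draw
```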
[12]
use_templates = True #@param {type:"boolean"}
msa_mode = "MMseqs2" #@param ["MMseqs2","single_sequence"]

max_recycling_iters = 3 #@param {type:"integer"}
num_ensembles = 1 #@param {type:"integer"}
manual_seed = 42 #@param {type:"integer"}
times = 1 #@param {type:"integer"}

max_display_cnt = 3 #@param {type:"integer"}

1.3 Data-processing Functions

The following code cell is long; it is recommended to keep it collapsed if your viewer supports folding.

[13]
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import random
import logging
import time
import tarfile
import requests
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Sequence, Tuple, Union, Any, Optional

from unifold.msa import templates
from unifold.msa import pipeline
from unifold.msa.tools import hhsearch
from unifold.dataset import load_and_process, UnifoldDataset
from unifold.symmetry import load_and_process_symmetry

logger = logging.getLogger(__name__)
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
DEFAULT_API_SERVER = "https://api.colabfold.com"


def clean_and_validate_sequence(
    input_sequence: str, min_length: int, max_length: int) -> str:
    """Checks that the input sequence is ok and returns a clean version of it."""
    # Remove all whitespaces, tabs and end lines; upper-case.
    clean_sequence = input_sequence.translate(
        str.maketrans('', '', ' \n\t')).upper()
    aatypes = set(residue_constants.restypes)  # 20 standard aatypes.
    if not set(clean_sequence).issubset(aatypes):
        raise ValueError(
            f'Input sequence contains non-amino acid letters: '
            f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '
            'amino acids as inputs.')
    if len(clean_sequence) < min_length:
        raise ValueError(
            f'Input sequence is too short: {len(clean_sequence)} amino acids, '
            f'while the minimum is {min_length}')
    if len(clean_sequence) > max_length:
        raise ValueError(
            f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
            f'the maximum is {max_length}. You may be able to run it with the full '
            f'Uni-Fold system depending on your resources (system memory, '
            f'GPU memory).')
    return clean_sequence


def validate_input(
    input_sequences: Sequence[str],
    symmetry_group: str,
    min_length: int,
    max_length: int,
    max_multimer_length: int) -> Tuple[Sequence[str], bool, Optional[str]]:
    """Validates and cleans input sequences and determines which model to use."""
    sequences = []

    for input_sequence in input_sequences:
        if input_sequence.strip():
            input_sequence = clean_and_validate_sequence(
                input_sequence=input_sequence,
                min_length=min_length,
                max_length=max_length)
            sequences.append(input_sequence)

    if symmetry_group != 'C1':
        if symmetry_group.startswith('C') and symmetry_group[1:].isnumeric():
            print(f'Using UF-Symmetry with group {symmetry_group}. If you do not '
                  f'want to use UF-Symmetry, please use `C1` and copy the AU '
                  f'sequences to the count in the assembly.')
            is_multimer = (len(sequences) > 1)
            return sequences, is_multimer, symmetry_group
        else:
            raise ValueError(f"UF-Symmetry does not support symmetry group "
                             f"{symmetry_group} currently. Cyclic groups (Cx) are "
                             f"supported only.")

    elif len(sequences) == 1:
        print('Using the single-chain model.')
        return sequences, False, None

    elif len(sequences) > 1:
        total_multimer_length = sum([len(seq) for seq in sequences])
        if total_multimer_length > max_multimer_length:
            raise ValueError(f'The total length of multimer sequences is too long: '
                             f'{total_multimer_length}, while the maximum is '
                             f'{max_multimer_length}. Please use the full AlphaFold '
                             f'system for long multimers.')
        print(f'Using the multimer model with {len(sequences)} sequences.')
        return sequences, True, None

    else:
        raise ValueError('No input amino acid sequence provided, please provide at '
                         'least one sequence.')


def run_mmseqs2(x, prefix, use_env=True,
                use_templates=False, use_pairing=False,
                host_url=DEFAULT_API_SERVER) -> Tuple[List[str], List[str]]:
    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    def submit(seqs, mode, N=101):
        n, query = N, ""
        for seq in seqs:
            query += f">{n}\n{seq}\n"
            n += 1

        res = requests.post(f'{host_url}/{submission_endpoint}', data={'q': query, 'mode': mode})
        try:
            out = res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def status(ID):
        res = requests.get(f'{host_url}/ticket/{ID}')
        try:
            out = res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def download(ID, path):
        res = requests.get(f'{host_url}/result/download/{ID}')
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    mode = "env"
    if use_pairing:
        mode = ""
        use_templates = False
        use_env = False

    # define path
    path = f"{prefix}"
    if not os.path.isdir(path):
        os.mkdir(path)

    # call mmseqs2 api
    tar_gz_file = f'{path}/out_{mode}.tar.gz'
    N, REDO = 101, True

    # deduplicate and keep track of order
    seqs_unique = []
    # TODO this might be slow for large sets
    [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
    Ms = [N + seqs_unique.index(seq) for seq in seqs]
    # lets do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    raise Exception(
                        'MMseqs2 API is giving errors. Please confirm your input is a '
                        'valid protein sequence. If error persists, please try again '
                        'an hour later.')

                if out["status"] == "MAINTENANCE":
                    raise Exception('MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

                # wait for job to finish
                ID, TIME = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    raise Exception(
                        'MMseqs2 API is giving errors. Please confirm your input is a '
                        'valid protein sequence. If error persists, please try again '
                        'an hour later.')

            # Download results
            download(ID, tar_gz_file)

    # prep list of a3m files
    if use_pairing:
        a3m_files = [f"{path}/pair.a3m"]
    else:
        a3m_files = [f"{path}/uniref.a3m"]
        if use_env:
            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # templates
    if use_templates:
        templates = {}

        for line in open(f"{path}/pdb70.m8", "r"):
            p = line.rstrip().split()
            M, pdb, qid, e_value = p[0], p[1], p[2], p[10]
            M = int(M)
            if M not in templates:
                templates[M] = []
            templates[M].append(pdb)

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f"{prefix}/templates_{k}"
            if not os.path.isdir(TMPL_PATH):
                os.mkdir(TMPL_PATH)
                TMPL_LINE = ",".join(TMPL[:20])
                os.system(f"curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/")
                os.system(f"cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex")
                os.system(f"touch {TMPL_PATH}/pdb70_cs219.ffdata")
            template_paths[k] = TMPL_PATH
    else:
        template_paths = {}

    # gather a3m lines
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        for line in open(a3m_file, "r"):
            if len(line) > 0:
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                    if M not in a3m_lines:
                        a3m_lines[M] = []
                a3m_lines[M].append(line)

    # return results
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

    if use_templates:
        template_paths_ = []
        for n in Ms:
            if n not in template_paths:
                template_paths_.append(None)
                # print(f"{n-N}\tno_templates_found")
            else:
                template_paths_.append(template_paths[n])
        template_paths = template_paths_
    else:
        template_paths = []

    return (a3m_lines, template_paths) if use_templates else a3m_lines


def get_null_template(
    query_sequence: Union[List[str], str], num_temp: int = 1
) -> Dict[str, Any]:
    ln = (
        len(query_sequence)
        if isinstance(query_sequence, str)
        else sum(len(s) for s in query_sequence)
    )
    output_templates_sequence = "A" * ln
    output_confidence_scores = np.full(ln, 1.0)

    templates_all_atom_positions = np.zeros(
        (ln, templates.residue_constants.atom_type_num, 3)
    )
    templates_all_atom_masks = np.zeros((ln, templates.residue_constants.atom_type_num))
    templates_aatype = templates.residue_constants.sequence_to_onehot(
        output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID
    )
    template_features = {
        "template_all_atom_positions": np.tile(
            templates_all_atom_positions[None], [num_temp, 1, 1, 1]
        ),
        "template_all_atom_masks": np.tile(
            templates_all_atom_masks[None], [num_temp, 1, 1]
        ),
        "template_sequence": ["none".encode()] * num_temp,
        "template_aatype": np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]),
        "template_domain_names": ["none".encode()] * num_temp,
        "template_sum_probs": np.zeros([num_temp], dtype=np.float32),
    }
    return template_features


def get_template(
    a3m_lines: str, template_path: str, query_sequence: str
) -> Dict[str, Any]:
    template_featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir=template_path,
        max_template_date="2100-01-01",
        max_hits=20,
        kalign_binary_path="kalign",
        release_dates_path=None,
        obsolete_pdbs_path=None,
    )

    hhsearch_pdb70_runner = hhsearch.HHSearch(
        binary_path="hhsearch", databases=[f"{template_path}/pdb70"]
    )

    hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines)
    hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result)
    templates_result = template_featurizer.get_templates(
        query_sequence=query_sequence, hits=hhsearch_hits
    )
    return dict(templates_result.features)


def get_msa_and_templates(
    jobname: str,
    query_seqs_unique: Union[str, List[str]],
    result_dir: Path,
    msa_mode: str,
    use_templates: bool,
    homooligomers_num: int = 1,
    host_url: str = DEFAULT_API_SERVER,
) -> Tuple[
    Optional[List[str]], Optional[List[str]], List[Dict[str, Any]]
]:
    use_env = msa_mode == "MMseqs2"

    template_features = []
    if use_templates:
        a3m_lines_mmseqs2, template_paths = run_mmseqs2(
            query_seqs_unique,
            str(result_dir.joinpath(jobname)),
            use_env,
            use_templates=True,
            host_url=host_url,
        )
        if template_paths is None:
            logger.info("No template detected")
            for index in range(0, len(query_seqs_unique)):
                template_feature = get_null_template(query_seqs_unique[index])
                template_features.append(template_feature)
        else:
            for index in range(0, len(query_seqs_unique)):
                if template_paths[index] is not None:
                    template_feature = get_template(
                        a3m_lines_mmseqs2[index],
                        template_paths[index],
                        query_seqs_unique[index],
                    )
                    if len(template_feature["template_domain_names"]) == 0:
                        template_feature = get_null_template(query_seqs_unique[index])
                        logger.info(f"Sequence {index} found no templates")
                    else:
                        logger.info(
                            f"Sequence {index} found templates: {template_feature['template_domain_names'].astype(str).tolist()}"
                        )
                else:
                    template_feature = get_null_template(query_seqs_unique[index])
                    logger.info(f"Sequence {index} found no templates")

                template_features.append(template_feature)
    else:
        for index in range(0, len(query_seqs_unique)):
            template_feature = get_null_template(query_seqs_unique[index])
            template_features.append(template_feature)

    if msa_mode == "single_sequence":
        a3m_lines = []
        num = 101
        for i, seq in enumerate(query_seqs_unique):
            a3m_lines.append(">" + str(num + i) + "\n" + seq)
    else:
        # find normal a3ms
        a3m_lines = run_mmseqs2(
            query_seqs_unique,
            str(result_dir.joinpath(jobname)),
            use_env,
            use_pairing=False,
            host_url=host_url,
        )
    if len(query_seqs_unique) > 1:
        # find paired a3m if not a homooligomer
        paired_a3m_lines = run_mmseqs2(
            query_seqs_unique,
            str(result_dir.joinpath(jobname)),
            use_env,
            use_pairing=True,
            host_url=host_url,
        )
    else:
        num = 101
        paired_a3m_lines = []
        for i in range(0, homooligomers_num):
            paired_a3m_lines.append(
                ">" + str(num + i) + "\n" + query_seqs_unique[0] + "\n"
            )

    return (
        a3m_lines,
        paired_a3m_lines,
        template_features,
    )


def load_feature_for_one_target(
    config, data_folder, seed=0, is_multimer=False, use_uniprot=False, symmetry_group=None,
):
    if not is_multimer:
        uniprot_msa_dir = None
        sequence_ids = ["A"]
        if use_uniprot:
            uniprot_msa_dir = data_folder

    else:
        uniprot_msa_dir = data_folder
        sequence_ids = open(os.path.join(data_folder, "chains.txt")).readline().split()

    if symmetry_group is None:
        batch, _ = load_and_process(
            config=config.data,
            mode="predict",
            seed=seed,
            batch_idx=None,
            data_idx=0,
            is_distillation=False,
            sequence_ids=sequence_ids,
            monomer_feature_dir=data_folder,
            uniprot_msa_dir=uniprot_msa_dir,
        )

    else:
        batch, _ = load_and_process_symmetry(
            config=config.data,
            mode="predict",
            seed=seed,
            batch_idx=None,
            data_idx=0,
            is_distillation=False,
            symmetry=symmetry_group,
            sequence_ids=sequence_ids,
            monomer_feature_dir=data_folder,
            uniprot_msa_dir=uniprot_msa_dir,
        )
    batch = UnifoldDataset.collater([batch])
    return batch
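
For reference, the sequence-cleaning step above can be exercised in isolation. The sketch below re-implements its core check with the 20 standard amino-acid letters hard-coded, instead of reading them from unifold's `residue_constants.restypes`:

```python
# Standalone sketch of clean_and_validate_sequence's core logic, with the
# 20 standard amino-acid letters hard-coded (an assumption standing in for
# residue_constants.restypes).
STANDARD_AATYPES = set("ACDEFGHIKLMNPQRSTVWY")

def clean_sequence(seq: str) -> str:
    # Strip whitespace, tabs, and newlines; upper-case the result.
    cleaned = seq.translate(str.maketrans("", "", " \n\t")).upper()
    bad = set(cleaned) - STANDARD_AATYPES
    if bad:
        raise ValueError(f"non-standard letters: {sorted(bad)}")
    return cleaned

print(clean_sequence("mgs shh\nhhh"))  # -> MGSSHHHHH
```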


2. Inference

2.1 Data Input

Parse and validate the input protein sequence(s); each entry is assigned a target ID and written out as a FASTA file.
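
The cell below splits each 'sequence' field on ";" into chains and detects homooligomers. That bookkeeping can be sketched on its own (a simplified illustration, not the notebook's exact code):

```python
def chain_stats(sequence_field: str):
    """Split a ';'-joined multi-chain input and count identical copies."""
    chains = sequence_field.strip().split(";")
    unique = list(set(chains))
    # The copy count matters only when all chains are identical (homooligomer);
    # otherwise it falls back to 1, mirroring homooligomers_num below.
    homooligomers_num = len(chains) if len(unique) == 1 else 1
    return chains, unique, homooligomers_num

chains, unique, n = chain_stats("AAAA;AAAA;AAAA")
print(len(chains), len(unique), n)  # -> 3 1 3
```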

[16]
import os
import re
import hashlib

from unifold.data import residue_constants, protein
from unifold.msa.utils import divide_multi_chains

def add_hash(x, y):
    # Short hash suffix for job names (ColabFold-style helper; this
    # definition was missing from the original cell).
    return x + "_" + hashlib.sha1(y.encode()).hexdigest()[:5]

MIN_SINGLE_SEQUENCE_LENGTH = 16
MAX_SINGLE_SEQUENCE_LENGTH = 1000
MAX_MULTIMER_LENGTH = 1000

output_dir_base = "/root"
os.makedirs(output_dir_base, exist_ok=True)

if isinstance(input_json, list):
    for i, input_dict in enumerate(input_json):
        input_dict['id'] = str(i)
elif isinstance(input_json, dict):
    new_input_json = []
    for k, v in input_json.items():
        v['id'] = k
        new_input_json.append(v)
    input_json = new_input_json
else:
    assert False, f"Input JSON file type {type(input_json)} is neither list nor dict."
print(f"Number of input proteins: {len(input_json)}")


for input_dict in input_json:
    if 'sequence' not in input_dict.keys():
        raise KeyError(f"'sequence' not found in dict keys: {input_dict.keys()}")
    input_sequences = input_dict['sequence'].strip().split(';')
    if 'id' in input_dict.keys():
        target_id = f"{jobname}_{input_dict['id']}"
    else:
        basejobname = "".join(input_sequences)
        basejobname = re.sub(r'\W+', '', basejobname)
        target_id = add_hash(jobname, basejobname)
    input_dict['target_id'] = target_id
    if 'symmetry' not in input_dict.keys():
        input_dict['symmetry'] = 'C1'
    symmetry_group = input_dict['symmetry']

    # Validate the input.
    sequences, is_multimer, symmetry_group = validate_input(
        input_sequences=input_sequences,
        symmetry_group=symmetry_group,
        min_length=MIN_SINGLE_SEQUENCE_LENGTH,
        max_length=MAX_SINGLE_SEQUENCE_LENGTH,
        max_multimer_length=MAX_MULTIMER_LENGTH)
    input_dict['sequences'] = sequences
    input_dict['is_multimer'] = is_multimer
    input_dict['symmetry_group'] = symmetry_group

    descriptions = ['> ' + target_id + ' seq' + str(ii) for ii in range(len(sequences))]
    if is_multimer:
        divide_multi_chains(target_id, output_dir_base, sequences, descriptions)

    unique_sequences = list(set(sequences))
    homooligomers_num = len(sequences) if len(unique_sequences) == 1 else 1
    input_dict['unique_sequences'] = unique_sequences
    input_dict['homooligomers_num'] = homooligomers_num

    with open(f"{output_dir_base}/{target_id}.fasta", "w") as f:
        for des, seq in zip(descriptions, sequences):
            f.write(f"{des}\n{seq}\n")
Number of input proteins: 4
Using the single-chain model.
Using UF-Symmetry with group C2. If you do not want to use UF-Symmetry, please use `C1` and copy the AU sequences to the count in the assembly.
Using UF-Symmetry with group C2. If you do not want to use UF-Symmetry, please use `C1` and copy the AU sequences to the count in the assembly.
Using UF-Symmetry with group C3. If you do not want to use UF-Symmetry, please use `C1` and copy the AU sequences to the count in the assembly.

2.2 Feature Generation

Generate and serialize the input features (MSAs, templates, and paired MSAs for multimers) for Uni-Fold prediction.
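
Each unique sequence is written out under a chain ID taken from `PDB_CHAIN_IDS`. The indexing can be sketched as follows; the character ordering below is an assumption based on the standard AlphaFold convention (upper-case, lower-case, then digits):

```python
# Assumed ordering of unifold.data.protein.PDB_CHAIN_IDS (AlphaFold convention).
PDB_CHAIN_IDS = (
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
)

unique_sequences = ["SEQONE", "SEQTWO"]  # hypothetical deduplicated chains
feature_files = {
    PDB_CHAIN_IDS[idx]: f"{PDB_CHAIN_IDS[idx]}.feature.pkl.gz"
    for idx, _ in enumerate(unique_sequences)
}
print(feature_files)  # -> {'A': 'A.feature.pkl.gz', 'B': 'B.feature.pkl.gz'}
```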

[17]
import pickle
import gzip
from unifold.msa import parsers
from unifold.data.utils import compress_features
from unifold.data.protein import PDB_CHAIN_IDS

result_dir = Path(output_dir_base)

for input_dict in input_json:
    output_dir = os.path.join(output_dir_base, input_dict['target_id'])
    input_dict['output_dir'] = output_dir

    (
        unpaired_msa,
        paired_msa,
        template_results,
    ) = get_msa_and_templates(
        input_dict['target_id'],
        input_dict['unique_sequences'],
        result_dir=result_dir,
        msa_mode=msa_mode,
        use_templates=use_templates,
        homooligomers_num=input_dict['homooligomers_num']
    )

    for idx, seq in enumerate(input_dict['unique_sequences']):
        chain_id = PDB_CHAIN_IDS[idx]
        sequence_features = pipeline.make_sequence_features(
            sequence=seq, description=f'> {jobname} seq {chain_id}', num_res=len(seq)
        )
        monomer_msa = parsers.parse_a3m(unpaired_msa[idx])
        msa_features = pipeline.make_msa_features([monomer_msa])
        template_features = template_results[idx]
        feature_dict = {**sequence_features, **msa_features, **template_features}
        feature_dict = compress_features(feature_dict)
        features_output_path = os.path.join(
            output_dir, "{}.feature.pkl.gz".format(chain_id)
        )
        pickle.dump(
            feature_dict,
            gzip.GzipFile(features_output_path, "wb"),
            protocol=4
        )
        if input_dict['is_multimer']:
            multimer_msa = parsers.parse_a3m(paired_msa[idx])
            pair_features = pipeline.make_msa_features([multimer_msa])
            pair_feature_dict = compress_features(pair_features)
            uniprot_output_path = os.path.join(
                output_dir, "{}.uniprot.pkl.gz".format(chain_id)
            )
            pickle.dump(
                pair_feature_dict,
                gzip.GzipFile(uniprot_output_path, "wb"),
                protocol=4,
            )

COMPLETE: 100%|██████████| 150/150 [elapsed: 00:06 remaining: 00:00]
WARNING:absl:The exact sequence DPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 7ll5_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence GDPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 6vnk_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence PTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 5cf4_B. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence DPTQFEERHLKFLQQLGKGFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDN was not found in 4d0w_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence FEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 5cf8_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence PTQFEERHLKFLRQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYNLKLIMEFLPYGSLREYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 4e6d_B. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence EERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 3tjd_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence FEDRDPTQFEERHLKFLQQLGKGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDN was not found in 5wev_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence DPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNM was not found in 5usy_B. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence EERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDQMAG was not found in 2b7a_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence QFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEYYKVSPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG was not found in 3rvg_A. Realigning the template to the actual sequence.
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:07 remaining: 00:00]
WARNING:absl:The exact sequence KEISVIGVPMDLGQMRRGVDMGPSAIRYAGVIERIEEIGYDVKDMGDICIENTKLRNLTQVATVCNELASKVDHIIEEGRFPLVLGGDHSIAIGTLAGVAKHYKNLGVIWYDAHGDLNTEETSPSGNIHGMSLAASLGYGHSSLVDLYGAYPKVKKENVVIIGARALDEGEKDFIRNEGIKVFSMHEIDRMGMTAVMEETIAYLSHTDGVHLSLDLDGLDPHDAPGVGTPVIGGLSYRESHLAMEMLAEADIITSAEFVEVNTILDERNRTATTAVALMGSLFGE was not found in 6nbk_D. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence KEISVIGVPMDLGQMRRGVDMGPSAIRYAGVIERIEEIGYDVKDMGDICINTKLRNLTQVATVCNELASKVDHIIEEGRFPLVLGGDHSIAIGTLAGVAKHYKNLGVIWYDAHGDLNTEETSPSGNIHGMSLAASLGYGHSSLVDLYGAYPKVKKENVVIIGARALDEGEKDFIRNEGIKVFSMHEIDRMGMTAVMEETIAYLSHTDGVHLSLDLDGLDPHDAPGVGTPVIGGLSYRESHLAMEMLAEADIITSAEFVEVNTILDERNRTATTAVALMGSLFGE was not found in 6nbk_C. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence KEISVIGVPMDLGQMRRGVDMGPSAIRYAGVIERIEEIGYDVKDMGDICIEENTKLRNLTQVATVCNELASKVDHIIEEGRFPLVLGGDHSIAIGTLAGVAKHYKNLGVIWYDAHGDLNTEETSPSGNIHGMSLAASLGYGHSSLVDLYGAYPKVKKENVVIIGARALDEGEKDFIRNEGIKVFSMHEIDRMGMTAVMEETIAYLSHTDGVHLSLDLDGLDPHDAPGVGTPVIGGLSYRESHLAMEMLAEADIITSAEFVEVNTILDERNRTATTAVALMGSLFGE was not found in 6nbk_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence DKTISVIGMPMDLGQARRGVDMGPSAIRYAHLIERLSDMGYTVEDLGDIPINELKNLNSVLAGNEKLAQKVNKVIEEKKFPLVLGGDHSIAIGTLAGTAKHYDNLGVIWYDAHGDLNTLETSPSGNIHGMPLAVSLGIGHESLVNLEGYAPKIKPENVVIIGARSLDEGERKYIKESGMKVYTMHEIDRLGMTKVIEETLDYLSACDGVHLSLDLDGLDPNDAPGVGTPVVGGISYRESHLAMEMLYDAGIITSAEFVEVNPILDHKNKTGKTAVELVESLLGK was not found in 6nfp_E. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence DKTISVIGMPMDLGQARRGVDMGPSAIRYAHLIERLSDMGYTVEDLGDIPINREELKNLNSVLAGNEKLAQKVNKVIEEKKFPLVLGGDHSIAIGTLAGTAKHYDNLGVIWYDAHGDLNTLETSPSGNIHGMPLAVSLGIGHESLVNLEGYAPKIKPENVVIIGARSLDEGERKYIKESGMKVYTMHEIDRLGMTKVIEETLDYLSACDGVHLSLDLDGLDPNDAPGVGTPVVGGISYRESHLAMEMLYDAGIITSAEFVEVNPILDHKNKTGKTAVELVESLLGK was not found in 6nfp_F. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence KTISVIGMPMDLGQARRGVDMGPSAIRYAHLIERLSDMGYTVEDLGDIPINREDEELKNLNSVLAGNEKLAQKVNKVIEEKKFPLVLGGDHSIAIGTLAGTAKHYDNLGVIWYDAHGDLNTLETSPSGNIHGMPLAVSLGIGHESLVNLEGYAPKIKPENVVIIGARSLDEGERKYIKESGMKVYTMHEIDRLGMTKVIEETLDYLSACDGVHLSLDLDGLDPNDAPGVGTPVVGGISYRESHLAMEMLYDAGIITSAEFVEVNPILDHKNKTGKTAVELVESLLGK was not found in 6nfp_C. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence KTISVIGMPMDLGQARRGVDMGPSAIRYAHLIERLSDMGYTVEDLGDIPINREKIDEELKNLNSVLAGNEKLAQKVNKVIEEKKFPLVLGGDHSIAIGTLAGTAKHYDNLGVIWYDAHGDLNTLETSPSGNIHGMPLAVSLGIGHESLVNLEGYAPKIKPENVVIIGARSLDEGERKYIKESGMKVYTMHEIDRLGMTKVIEETLDYLSACDGVHLSLDLDGLDPNDAPGVGTPVVGGISYRESHLAMEMLYDAGIITSAEFVEVNPILDHKNKTGKTAVELVESLLGKK was not found in 6nfp_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence KTISVIGMPMDLGQARRGVDMGPSAIRYAHLIERLSDMGYTVEDLGDIPINNLNSVLAGNEKLAQKVNKVIEEKKFPLVLGGDHSIAIGTLAGTAKHYDNLGVIWYDAHGDLNTLETSPSGNIHGMPLAVSLGIGHESLVNLEGYAPKIKPENVVIIGARSLDEGERKYIKESGMKVYTMHEIDRLGMTKVIEETLDYLSACDGVHLSLDLDGLDPNDAPGVGTPVVGGISYRESHLAMEMLYDAGIITSAEFVEVNPILDHKNKTGKTAVELVESLLGK was not found in 6dkt_D. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence KTISVIGMPMDLGQARRGVDMGPSAIRYAHLIERLSDMGYTVEDLGDIPINNLNSVLAGNEKLAQKVNKVIEEKKFPLVLGGDHSIAIGTLAGTAKHYDNLGVIWYDAHGDLNTLESGNIHGMPLAVSLGIGHESLVNLEGYAPKIKPENVVIIGARSLDEGERKYIKESGMKVYTMHEIDRLGMTKVIEETLDYLSACDGVHLSLDLDGLDPNDAPGVGTPVVGGISYRESHLAMEMLYDAGIITSAEFVEVNPILDHKNKTGKTAVELVESLLGK was not found in 6dkt_F. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence RVAVVGVPMDLGANRRGVDMGPSALRYARLLEQLEDLGYTVEDLGDVPVSLARLAYLEEIRAAALVLKERLAALPEGVFPIVLGGDHSLSMGSVAGAARGRRVGVVWVDAHADFNTPETSPSGNVHGMPLAVLSGLGHPRLTEVFRAVDPKDVVLVGVRSLDPGEKRLLKEAGVRVYTMHEVDRLGVARIAEEVLKHLQGLPLHVSLDADVLDPTLAPGVGTPVPGGLTYREAHLLMEILAESGRVQSLDLVEVNPILDERNRTAEMLVGLALSLLGKR was not found in 2ef4_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence RVAVVGVPMDLGVDMGPSALRYARLLEQLEDLGYTVEDLGDVPVSLAYLEEIRAAALVLKERLAALPEGVFPIVLGGDHSLSMGSVAGAARGRRVGVVWVDAHADFNTPETSSGNVHGMPLAVLSGLGHPRLTEVFRAVDPKDVVLVGVRSLDPGEKRLLKEAGVRVYTMHEVDRLGVARIAEEVLKHLQGLPLHVSLDADVLDPTLAPGVGTPVPGGLTYREAHLLMEILAESGRVQSLDLVEVNPILDERNRTAEMLVGLALSLLGKR was not found in 2eiv_M. Realigning the template to the actual sequence.
COMPLETE: 100%|██████████| 300/300 [elapsed: 00:05 remaining: 00:00]
COMPLETE: 100%|██████████| 300/300 [elapsed: 00:05 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:06 remaining: 00:00]
WARNING:absl:The exact sequence CLAEGTRIFDPVTGTTHRIEDVVDGRKPIHVVAAAKDGTLHARPVVSWFDQGTRDVIGLRIAGGAILWATPDHKVLTEYGWRAAGELRKGDRVAQPRRFDGFMLAEELRYSVIREVLPTRRARTFDLEVEELHTLVAEGVVVH was not found in 2imz_A. Realigning the template to the actual sequence.
WARNING:absl:The exact sequence CLAEGTRIFDPVTGTTHRIEDVVDGRKPIHVVAAAKDGTLHARPVVSWFDQGTRDVIGLRIAGGAILWATPDHKVLTEYGWRAAGELRKGDRVAQPRRFDGFEELRYSVIREVLPTRRARTFDLEVEELHTLVAEGVVVH was not found in 2imz_B. Realigning the template to the actual sequence.

2.3 Model Prediction

Uni-Fold prediction

[18]
import torch
import json
from unifold.config import model_config
from unifold.modules.alphafold import AlphaFold
from unicore.utils import (
    tensor_tree_map,
)
from unifold.symmetry import (
    UFSymmetry,
    uf_symmetry_config,
    assembly_from_prediction,
)

def automatic_chunk_size(seq_len):
    # Larger chunks are faster but need more GPU memory, so scale
    # the chunk size down as the sequence length grows.
    if seq_len < 512:
        chunk_size = 256
    elif seq_len < 1024:
        chunk_size = 128
    elif seq_len < 2048:
        chunk_size = 32
    elif seq_len < 3072:
        chunk_size = 16
    else:
        chunk_size = 1
    return chunk_size

for input_dict in input_json:
    output_dir = input_dict['output_dir']
    target_id = input_dict['target_id']
    is_multimer = input_dict['is_multimer']
    symmetry_group = input_dict['symmetry_group']

    # Select model and parameters: UF-Symmetry for symmetric assemblies,
    # otherwise the multimer or monomer Uni-Fold model.
    if symmetry_group is not None:
        model_name = "uf_symmetry"
        param_path = "/root/params/uf_symmetry.pt"
    elif is_multimer:
        model_name = "multimer_ft"
        param_path = "/root/params/multimer.unifold.pt"
    else:
        model_name = "model_2_ft"
        param_path = "/root/params/monomer.unifold.pt"

    if symmetry_group is None:
        config = model_config(model_name)
    else:
        config = uf_symmetry_config()
    config.data.common.max_recycling_iters = max_recycling_iters
    config.globals.max_recycling_iters = max_recycling_iters
    config.data.predict.num_ensembles = num_ensembles

    # faster prediction with large chunk size
    config.globals.chunk_size = 128
    model = AlphaFold(config) if symmetry_group is None else UFSymmetry(config)
    print("start to load params {}".format(param_path))
    state_dict = torch.load(param_path)["ema"]["params"]
    state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to("cuda:0")
    model.eval()
    model.inference_mode()

    # the parameter file name becomes part of the output file names
    cur_param_path_postfix = os.path.split(param_path)[-1]

    print("start to predict {}".format(target_id))
    plddts = {}
    ptms = {}
    best_protein = None
    best_score = 0
    best_plddt = None
    best_pae = None

    for seed in range(times):
        cur_seed = hash((manual_seed, seed)) % 100000
        batch = load_feature_for_one_target(
            config,
            output_dir,
            cur_seed,
            is_multimer=is_multimer,
            use_uniprot=is_multimer,
            symmetry_group=symmetry_group,
        )
        seq_len = batch["aatype"].shape[-1]
        model.globals.chunk_size = automatic_chunk_size(seq_len)

        with torch.no_grad():
            batch = {
                k: torch.as_tensor(v, device="cuda:0")
                for k, v in batch.items()
            }
            shapes = {k: v.shape for k, v in batch.items()}
            print(shapes)
            t = time.perf_counter()
            out = model(batch)
            print(f"Inference time: {time.perf_counter() - t}")

        def to_float(x):
            if x.dtype == torch.bfloat16 or x.dtype == torch.half:
                return x.float()
            else:
                return x

        # Toss out the recycling dimensions --- we don't need them anymore
        batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)
        batch = tensor_tree_map(to_float, batch)
        out = tensor_tree_map(lambda t: t[0, ...], out)
        out = tensor_tree_map(to_float, out)
        batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)
        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

        plddt = out["plddt"]
        mean_plddt = np.mean(plddt)
        plddt_b_factors = np.repeat(
            plddt[..., None], residue_constants.atom_type_num, axis=-1
        )
        # TODO: chains may need to be reordered based on entity_ids
        if symmetry_group is None:
            cur_protein = protein.from_prediction(
                features=batch, result=out, b_factors=plddt_b_factors
            )
        else:
            # Replicate the b-factors for every symmetry operator in the assembly.
            plddt_b_factors_assembly = np.concatenate(
                [plddt_b_factors for _ in range(batch["symmetry_opers"].shape[0])])
            cur_protein = assembly_from_prediction(
                result=out, b_factors=plddt_b_factors_assembly,
            )
        cur_save_name = f"{cur_param_path_postfix}_{cur_seed}"
        plddts[cur_save_name] = str(mean_plddt)
        if is_multimer and symmetry_group is None:
            ptms[cur_save_name] = str(np.mean(out["iptm+ptm"]))
        with open(os.path.join(output_dir, cur_save_name + '.pdb'), "w") as f:
            f.write(protein.to_pdb(cur_protein))

        # Keep the best prediction: by ipTM+pTM for multimers, mean pLDDT otherwise.
        if is_multimer and symmetry_group is None:
            mean_ptm = np.mean(out["iptm+ptm"])
            if mean_ptm > best_score:
                best_protein = cur_protein
                best_pae = out["predicted_aligned_error"]
                best_plddt = out["plddt"]
                best_score = mean_ptm
        else:
            if mean_plddt > best_score:
                best_protein = cur_protein
                best_plddt = out["plddt"]
                best_score = mean_plddt

    input_dict['best_protein'] = best_protein
    input_dict['best_pae'] = best_pae
    input_dict['best_plddt'] = best_plddt
    print("plddts", plddts)
    score_name = f"{model_name}_{cur_param_path_postfix}"
    plddt_fname = score_name + "_plddt.json"
    json.dump(plddts, open(os.path.join(output_dir, plddt_fname), "w"), indent=4)
    if ptms:
        print("ptms", ptms)
        ptm_fname = score_name + "_ptm.json"
        json.dump(ptms, open(os.path.join(output_dir, ptm_fname), "w"), indent=4)
start to load params /root/params/monomer.unifold.pt
start to predict unifold_bohrium_0
{'aatype': torch.Size([1, 1, 317]), 'residue_index': torch.Size([1, 1, 317]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([4, 1, 508, 1]), 'template_aatype': torch.Size([1, 1, 4, 317]), 'template_all_atom_mask': torch.Size([1, 1, 4, 317, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 317, 37, 3]), 'bert_mask': torch.Size([4, 1, 508, 317]), 'msa_mask': torch.Size([4, 1, 508, 317]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([4, 1]), 'seq_mask': torch.Size([1, 1, 317]), 'msa_row_mask': torch.Size([4, 1, 508]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 317, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 317]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 317, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 317, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 317, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 317, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 317, 37]), 'atom14_atom_exists': torch.Size([1, 1, 317, 14]), 'atom37_atom_exists': torch.Size([1, 1, 317, 37]), 'target_feat': torch.Size([1, 1, 317, 22]), 'extra_msa': torch.Size([4, 1, 1024, 317]), 'extra_msa_mask': torch.Size([4, 1, 1024, 317]), 'extra_msa_row_mask': torch.Size([4, 1, 1024]), 'true_msa': torch.Size([4, 1, 508, 317]), 'extra_msa_has_deletion': torch.Size([4, 1, 1024, 317]), 'extra_msa_deletion_value': torch.Size([4, 1, 1024, 317]), 'msa_feat': torch.Size([4, 1, 508, 317, 49])}
Inference time: 29.34663464399995
plddts {'monomer.unifold.pt_97923': '0.914682'}
start to load params /root/params/uf_symmetry.pt
start to predict unifold_bohrium_1
{'aatype': torch.Size([1, 1, 287]), 'residue_index': torch.Size([1, 1, 287]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([4, 1, 252, 1]), 'template_aatype': torch.Size([1, 1, 4, 287]), 'template_all_atom_mask': torch.Size([1, 1, 4, 287, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 287, 37, 3]), 'asym_id': torch.Size([1, 1, 287]), 'sym_id': torch.Size([1, 1, 287]), 'entity_id': torch.Size([1, 1, 287]), 'num_sym': torch.Size([1, 1, 287]), 'assembly_num_chains': torch.Size([1, 1, 1]), 'cluster_bias_mask': torch.Size([1, 1, 252]), 'bert_mask': torch.Size([4, 1, 252, 287]), 'msa_mask': torch.Size([4, 1, 252, 287]), 'asym_len': torch.Size([1, 1, 1]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([4, 1]), 'seq_mask': torch.Size([1, 1, 287]), 'msa_row_mask': torch.Size([4, 1, 252]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 287, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 287]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 287, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 287, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 287, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 287, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 287, 37]), 'atom14_atom_exists': torch.Size([1, 1, 287, 14]), 'atom37_atom_exists': torch.Size([1, 1, 287, 37]), 'target_feat': torch.Size([1, 1, 287, 22]), 'extra_msa': torch.Size([4, 1, 1152, 287]), 'extra_msa_mask': torch.Size([4, 1, 1152, 287]), 'extra_msa_row_mask': torch.Size([4, 1, 1152]), 'true_msa': torch.Size([4, 1, 252, 287]), 'msa_feat': torch.Size([4, 1, 252, 287, 49]), 'extra_msa_has_deletion': torch.Size([4, 1, 1152, 287]), 'extra_msa_deletion_value': torch.Size([4, 1, 1152, 287]), 'symmetry_opers': torch.Size([1, 1, 2, 4, 4]), 'pseudo_residue_feat': torch.Size([1, 1, 8]), 'num_asym': torch.Size([1, 1])}
Inference time: 15.05725186899997
plddts {'uf_symmetry.pt_97923': '0.93517303'}
start to load params /root/params/uf_symmetry.pt
start to predict unifold_bohrium_2
{'aatype': torch.Size([1, 1, 212]), 'residue_index': torch.Size([1, 1, 212]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([4, 1, 252, 1]), 'template_aatype': torch.Size([1, 1, 4, 212]), 'template_all_atom_mask': torch.Size([1, 1, 4, 212, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 212, 37, 3]), 'asym_id': torch.Size([1, 1, 212]), 'sym_id': torch.Size([1, 1, 212]), 'entity_id': torch.Size([1, 1, 212]), 'num_sym': torch.Size([1, 1, 212]), 'assembly_num_chains': torch.Size([1, 1, 1]), 'cluster_bias_mask': torch.Size([1, 1, 252]), 'bert_mask': torch.Size([4, 1, 252, 212]), 'msa_mask': torch.Size([4, 1, 252, 212]), 'asym_len': torch.Size([1, 1, 2]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([4, 1]), 'seq_mask': torch.Size([1, 1, 212]), 'msa_row_mask': torch.Size([4, 1, 252]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 212, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 212]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 212, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 212, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 212, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 212, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 212, 37]), 'atom14_atom_exists': torch.Size([1, 1, 212, 14]), 'atom37_atom_exists': torch.Size([1, 1, 212, 37]), 'target_feat': torch.Size([1, 1, 212, 22]), 'extra_msa': torch.Size([4, 1, 336, 212]), 'extra_msa_mask': torch.Size([4, 1, 336, 212]), 'extra_msa_row_mask': torch.Size([4, 1, 336]), 'true_msa': torch.Size([4, 1, 252, 212]), 'msa_feat': torch.Size([4, 1, 252, 212, 49]), 'extra_msa_has_deletion': torch.Size([4, 1, 336, 212]), 'extra_msa_deletion_value': torch.Size([4, 1, 336, 212]), 'symmetry_opers': torch.Size([1, 1, 2, 4, 4]), 'pseudo_residue_feat': torch.Size([1, 1, 8]), 'num_asym': torch.Size([1, 1])}
Inference time: 8.176409836000005
plddts {'uf_symmetry.pt_97923': '0.83992827'}
start to load params /root/params/uf_symmetry.pt
start to predict unifold_bohrium_3
{'aatype': torch.Size([1, 1, 156]), 'residue_index': torch.Size([1, 1, 156]), 'seq_length': torch.Size([1, 1]), 'msa_chains': torch.Size([4, 1, 252, 1]), 'template_aatype': torch.Size([1, 1, 4, 156]), 'template_all_atom_mask': torch.Size([1, 1, 4, 156, 37]), 'template_all_atom_positions': torch.Size([1, 1, 4, 156, 37, 3]), 'asym_id': torch.Size([1, 1, 156]), 'sym_id': torch.Size([1, 1, 156]), 'entity_id': torch.Size([1, 1, 156]), 'num_sym': torch.Size([1, 1, 156]), 'assembly_num_chains': torch.Size([1, 1, 1]), 'cluster_bias_mask': torch.Size([1, 1, 252]), 'bert_mask': torch.Size([4, 1, 252, 156]), 'msa_mask': torch.Size([4, 1, 252, 156]), 'asym_len': torch.Size([1, 1, 1]), 'num_recycling_iters': torch.Size([1, 1]), 'is_distillation': torch.Size([4, 1]), 'seq_mask': torch.Size([1, 1, 156]), 'msa_row_mask': torch.Size([4, 1, 252]), 'template_mask': torch.Size([1, 1, 4]), 'template_pseudo_beta': torch.Size([1, 1, 4, 156, 3]), 'template_pseudo_beta_mask': torch.Size([1, 1, 4, 156]), 'template_torsion_angles_sin_cos': torch.Size([1, 1, 4, 156, 7, 2]), 'template_alt_torsion_angles_sin_cos': torch.Size([1, 1, 4, 156, 7, 2]), 'template_torsion_angles_mask': torch.Size([1, 1, 4, 156, 7]), 'residx_atom14_to_atom37': torch.Size([1, 1, 156, 14]), 'residx_atom37_to_atom14': torch.Size([1, 1, 156, 37]), 'atom14_atom_exists': torch.Size([1, 1, 156, 14]), 'atom37_atom_exists': torch.Size([1, 1, 156, 37]), 'target_feat': torch.Size([1, 1, 156, 22]), 'extra_msa': torch.Size([4, 1, 1152, 156]), 'extra_msa_mask': torch.Size([4, 1, 1152, 156]), 'extra_msa_row_mask': torch.Size([4, 1, 1152]), 'true_msa': torch.Size([4, 1, 252, 156]), 'msa_feat': torch.Size([4, 1, 252, 156, 49]), 'extra_msa_has_deletion': torch.Size([4, 1, 1152, 156]), 'extra_msa_deletion_value': torch.Size([4, 1, 1152, 156]), 'symmetry_opers': torch.Size([1, 1, 3, 4, 4]), 'pseudo_residue_feat': torch.Size([1, 1, 8]), 'num_asym': torch.Size([1, 1])}
Inference time: 5.369833519999986
plddts {'uf_symmetry.pt_97923': '0.9115123'}
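The prediction cell writes one entry per seed to `<model_name>_<params>_plddt.json`, mapping the save name to the mean pLDDT (stored as a string). A minimal sketch of loading such a file's contents and picking the best-scoring seed, using made-up scores in that format:

```python
import json

# Hypothetical scores in the format written by the prediction cell:
# save name -> mean pLDDT, serialized as a string.
plddts = json.loads("""{
    "monomer.unifold.pt_97923": "0.914682",
    "monomer.unifold.pt_12345": "0.874210"
}""")

# Convert back to float to rank the seeds.
best_name = max(plddts, key=lambda k: float(plddts[k]))
print(best_name)  # → monomer.unifold.pt_97923
```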

3. Visualization

Visualize the predicted structure and per-residue pLDDT of the Uni-Fold output.

Construct multiclass b-factors to indicate confidence bands

  • 0=very low, 1=low, 2=confident, 3=very high
  • Color bands for visualizing plddt
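The four bands use the thresholds 0.50, 0.70, and 0.90 on the per-residue pLDDT (scaled to [0, 1]). A minimal sketch of mapping pLDDT values to band indices, with a made-up confidence array:

```python
import numpy as np

# Hypothetical per-residue pLDDT values in [0, 1].
plddt = np.array([0.32, 0.55, 0.78, 0.95])

# Band edges as in PLDDT_BANDS below:
# 0 = very low (<0.50), 1 = low, 2 = confident, 3 = very high (>=0.90).
bands = np.digitize(plddt, [0.50, 0.70, 0.90])
print(bands.tolist())  # → [0, 1, 2, 3]
```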
[19]
import py3Dmol
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from IPython import display
from ipywidgets import GridspecLayout
from ipywidgets import Output

show_sidechains = False #@param {type:"boolean"}
dpi = 100 #@param {type:"integer"}

PLDDT_BANDS = [(0., 0.50, '#FF7D45'),
               (0.50, 0.70, '#FFDB13'),
               (0.70, 0.90, '#65CBF3'),
               (0.90, 1.00, '#0053D6')]


# --- Visualise the prediction & confidence ---
def plot_plddt_legend():
    """Plots the legend for pLDDT."""
    thresh = ['Very low (pLDDT < 50)',
              'Low (70 > pLDDT > 50)',
              'Confident (90 > pLDDT > 70)',
              'Very high (pLDDT > 90)']

    colors = [x[2] for x in PLDDT_BANDS]

    plt.figure(figsize=(2, 2))
    for c in colors:
        plt.bar(0, 0, color=c)
    plt.legend(thresh, frameon=False, loc='center', fontsize=20)
    plt.xticks([])
    plt.yticks([])
    ax = plt.gca()
    for spine in ('right', 'top', 'left', 'bottom'):
        ax.spines[spine].set_visible(False)
    plt.title('Model Confidence', fontsize=20, pad=20)
    return plt

max_display_cnt = min(max_display_cnt, len(input_json))
for input_dict in input_json[:max_display_cnt]:
    output_dir = input_dict['output_dir']
    is_multimer = input_dict['is_multimer']
    symmetry_group = input_dict['symmetry_group']
    best_protein = input_dict['best_protein']
    best_pae = input_dict['best_pae']
    plddt = input_dict['best_plddt']
    to_visualize_pdb = protein.to_pdb(best_protein)

    if is_multimer and symmetry_group is None:
        # Color by chain to show the multimeric assembly.
        multichain_view = py3Dmol.view(width=800, height=600)
        multichain_view.addModelsAsFrames(to_visualize_pdb)
        multichain_style = {'cartoon': {'colorscheme': 'chain'}}
        multichain_view.setStyle({'model': -1}, multichain_style)
        multichain_view.zoomTo()
        multichain_view.show()

    # Color the structure by per-residue pLDDT (stored in the b-factor column)
    view = py3Dmol.view(width=800, height=600)
    view.addModelsAsFrames(to_visualize_pdb)
    style = {'cartoon': {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 0.5, 'max': 0.9}}}
    if show_sidechains:
        style['stick'] = {}
    view.setStyle({'model': -1}, style)
    view.zoomTo()

    grid = GridspecLayout(1, 2)
    out = Output()
    with out:
        view.show()
    grid[0, 0] = out

    out = Output()
    with out:
        plot_plddt_legend().show()
    grid[0, 1] = out

    display.display(grid)

    # Display pLDDT and predicted aligned error (if output by the model).
    if is_multimer and symmetry_group is None:
        num_plots = 2
    else:
        num_plots = 1

    plt.figure(figsize=[8 * num_plots, 6])
    plt.subplot(1, num_plots, 1)
    plt.plot(plddt * 100)
    plt.title('Predicted LDDT')
    plt.xlabel('Residue')
    plt.ylabel('pLDDT')
    plt.grid()
    # Save before plt.show(): with the inline backend, show() clears the
    # current figure, so saving afterwards would write an empty file.
    plddt_svg_path = os.path.join(output_dir, 'plddt.svg')
    plt.savefig(plddt_svg_path, dpi=dpi, bbox_inches='tight')

    if num_plots == 2:
        plt.subplot(1, 2, 2)
        max_pae = np.max(best_pae)
        colors = ['#0F006F', '#245AE6', '#55CCFF', '#FFFFFF']
        cmap = LinearSegmentedColormap.from_list('mymap', colors)
        im = plt.imshow(best_pae, vmin=0., vmax=max_pae, cmap=cmap)
        plt.colorbar(im, fraction=0.046, pad=0.04)

        # Display lines at chain boundaries.
        total_num_res = best_protein.residue_index.shape[-1]
        chain_ids = best_protein.chain_index
        for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):
            if chain_boundary.size:
                plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')
                plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')

        plt.title('Predicted Aligned Error')
        plt.xlabel('Scored residue')
        plt.ylabel('Aligned residue')
        pae_svg_path = os.path.join(output_dir, 'pae.svg')
        plt.savefig(pae_svg_path, dpi=dpi, bbox_inches='tight')

    plt.show()

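The red chain-boundary lines in the PAE plot come from differencing consecutive `chain_index` values: a boundary lies between residues i and i+1 whenever the chain id changes. A standalone sketch of that computation, using a made-up chain index for a hypothetical three-chain complex:

```python
import numpy as np

# Hypothetical chain indices for a complex with chains of length 4, 3, and 5.
chain_ids = np.array([0] * 4 + [1] * 3 + [2] * 5)

# Nonzero differences between neighbors mark the chain boundaries.
boundaries = np.nonzero(chain_ids[:-1] - chain_ids[1:])[0]
print(boundaries.tolist())  # → [3, 6]
```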