Protein Inverse Folding with PiFold
AI4S
Deep Learning
ABOWLofFish
Published 2023-11-07
Recommended image: Basic Image:ubuntu18.04-py3.10-mamba-cuda10.2
Recommended machine: c12_m92_1 * NVIDIA V100
Introduction
Related Work
1. Environment Setup
2. Loading the Model
3. CATH4.2 Results
4. Training PiFold from Scratch

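Introduction

Protein inverse folding is the task of predicting the amino acid sequence that corresponds to a given protein fold. It is the reverse of protein folding: rather than predicting structure from sequence, it infers the amino acid sequence from a known structure. The task matters both for understanding the relationship between protein structure and function and for designing new proteins.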

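Related Work

Most existing work follows an end-to-end architecture: a graph neural network (GNN) encodes the protein's 3D structure, and the sequence is then decoded autoregressively.

Compared with this earlier work, the defining feature of the newer PiFold model is Efficiency & Accuracy Balancing.
PiFold updates the nodes and edges of the protein graph with multilayer perceptrons and, at prediction time, drops the autoregressive decoding used by earlier methods in favor of predicting all residues directly. With the parameter count kept at 5.5M, this greatly reduces inference time. Compared with the previous SOTA, PiFold improves sequence recovery on three test sets by 4~5% each.
This notebook reproduces its results on the CATH4.2 dataset.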
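
To make the contrast between autoregressive and one-shot decoding concrete, here is a minimal sketch in plain Python. It is not PiFold's code: the `logits` table, the toy `step_fn`, and both decoder functions are hypothetical and only illustrate why one-shot prediction needs a single pass while autoregressive decoding needs one model call per residue.

```python
# Toy illustration only: all names and scores here are hypothetical,
# not taken from the PiFold code base.
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids

def argmax(scores):
    """Index of the largest score (first one on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)

def decode_one_shot(logits):
    """Pick every residue independently from precomputed per-position scores.

    All positions come from a single forward pass."""
    return "".join(ALPHABET[argmax(row)] for row in logits)

def decode_autoregressive(step_fn, length):
    """Pick residues left to right; each step receives the prefix decoded so far.

    Needs `length` sequential model calls."""
    seq = ""
    for position in range(length):
        scores = step_fn(position, seq)  # one model call per residue
        seq += ALPHABET[argmax(scores)]
    return seq

# Toy scores: position i prefers the residue at index (3 * i) % 20.
logits = [[1.0 if j == (3 * i) % 20 else 0.0 for j in range(20)] for i in range(5)]
calls = []

def step_fn(position, prefix):
    calls.append(prefix)      # record each sequential call
    return logits[position]   # a real model would condition on `prefix`

one_shot = decode_one_shot(logits)
autoregressive = decode_autoregressive(step_fn, len(logits))
print(one_shot, autoregressive, len(calls))
```

In this toy setup both decoders return the same string because `step_fn` ignores the prefix; the point is the call count, which grows with sequence length only in the autoregressive case.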

Reference Colab notebook: PiFold Benchmark
GitHub: https://github.com/A4Bio/PiFold

1. Environment Setup

[1]
! nvidia-smi
Tue Nov  7 05:51:26 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P8    12W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
[2]
# Add this in a Google Colab cell to install the correct version of PyTorch Geometric.
import torch

def format_pytorch_version(version):
    return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
    return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric
Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu118.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_scatter-2.1.2%2Bpt21cu118-cp310-cp310-linux_x86_64.whl (10.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.2/10.2 MB 99.6 MB/s eta 0:00:00
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt21cu118
Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu118.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_sparse-0.6.18%2Bpt21cu118-cp310-cp310-linux_x86_64.whl (4.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.9/4.9 MB 33.0 MB/s eta 0:00:00
Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from torch-sparse) (1.11.3)
Requirement already satisfied: numpy<1.28.0,>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from scipy->torch-sparse) (1.23.5)
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt21cu118
Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu118.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_cluster-1.6.3%2Bpt21cu118-cp310-cp310-linux_x86_64.whl (3.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.3/3.3 MB 30.1 MB/s eta 0:00:00
Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from torch-cluster) (1.11.3)
Requirement already satisfied: numpy<1.28.0,>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from scipy->torch-cluster) (1.23.5)
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.3+pt21cu118
Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu118.html
Collecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_spline_conv-1.2.2%2Bpt21cu118-cp310-cp310-linux_x86_64.whl (887 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 887.8/887.8 kB 16.8 MB/s eta 0:00:00
Installing collected packages: torch-spline-conv
Successfully installed torch-spline-conv-1.2.2+pt21cu118
Collecting torch-geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 10.7 MB/s eta 0:00:00
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (4.66.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (1.23.5)
Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (1.11.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (3.1.2)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (2.31.0)
Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (3.1.1)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (1.2.2)
Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from torch-geometric) (5.9.5)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch-geometric) (2.1.3)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric) (3.3.1)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torch-geometric) (2023.7.22)
Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->torch-geometric) (1.3.2)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->torch-geometric) (3.2.0)
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.4.0
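
The install cell builds the wheel index URL from `torch.__version__` and `torch.version.cuda`. On a CPU-only PyTorch build `torch.version.cuda` is `None`, so calling `.replace` on it would fail. The sketch below shows the same URL construction as a pure function with that case handled. The function name `pyg_wheel_index` is hypothetical, and the `cpu` tag is an assumption about how the PyG wheel index names CPU builds; check the PyG installation page before relying on it.

```python
def pyg_wheel_index(torch_version, cuda_version):
    """Build the PyG wheel index URL for a torch / CUDA version pair.

    torch_version: e.g. torch.__version__ ("2.1.0+cu118")
    cuda_version:  e.g. torch.version.cuda ("11.8"), or None on CPU-only builds
    """
    base = torch_version.split("+")[0]  # drop the local "+cu118" suffix
    # Assumption: CPU-only wheels are published under the "cpu" tag.
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://pytorch-geometric.com/whl/torch-{base}+{tag}.html"

gpu_url = pyg_wheel_index("2.1.0+cu118", "11.8")
cpu_url = pyg_wheel_index("2.1.0+cpu", None)
print(gpu_url)
print(cpu_url)
```

The first URL matches the "Looking in links" line in the pip output above.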
[3]
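import json, time, os, sys, glob
import numpy as np

# Clone the code once; later cells import from it.
if not os.path.isdir("ProDesign"):
    os.system("git clone -q https://github.com/A4Bio/ProDesign.git")
# '/content' is the Colab working directory; adjust this path on other platforms.
sys.path.append('/content/ProDesign')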
[4]
!mkdir -p ProDesign/data/cath
!mkdir -p ProDesign/data/ts
!wget -O ProDesign/data/cath.zip https://github.com/A4Bio/PiFold/releases/download/Training%26Data/cath4.2.zip
!unzip ProDesign/data/cath.zip -d ProDesign/data/cath
!mv ProDesign/data/cath/cath4.2/* ProDesign/data/cath/

!wget -O ProDesign/data/ts.zip https://github.com/A4Bio/PiFold/releases/download/Training%26Data/ts.zip
!unzip ProDesign/data/ts.zip -d ProDesign/data/

!mkdir -p ProDesign/results/ProDesign
!wget -O ProDesign/results/ProDesign/checkpoint.pth https://github.com/A4Bio/PiFold/releases/download/Training%26Data/checkpoint.pth
--2023-11-07 05:52:01--  https://github.com/A4Bio/PiFold/releases/download/Training%26Data/cath4.2.zip
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/538405150/36fd2f8c-5e3b-4a07-9741-dc881be4a4e5?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231107%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231107T055201Z&X-Amz-Expires=300&X-Amz-Signature=5f1938c826837e08ac56930a5c0b4e582912694e95f6a786bee0a557732af9cd&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=538405150&response-content-disposition=attachment%3B%20filename%3Dcath4.2.zip&response-content-type=application%2Foctet-stream [following]
--2023-11-07 05:52:01--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/538405150/36fd2f8c-5e3b-4a07-9741-dc881be4a4e5?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231107%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231107T055201Z&X-Amz-Expires=300&X-Amz-Signature=5f1938c826837e08ac56930a5c0b4e582912694e95f6a786bee0a557732af9cd&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=538405150&response-content-disposition=attachment%3B%20filename%3Dcath4.2.zip&response-content-type=application%2Foctet-stream
Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...
Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 188153738 (179M) [application/octet-stream]
Saving to: ‘ProDesign/data/cath.zip’

ProDesign/data/cath 100%[===================>] 179.44M  48.4MB/s    in 4.0s    

2023-11-07 05:52:05 (44.7 MB/s) - ‘ProDesign/data/cath.zip’ saved [188153738/188153738]

Archive:  ProDesign/data/cath.zip
  inflating: ProDesign/data/cath/cath4.2/chain_set.jsonl  
  inflating: ProDesign/data/cath/cath4.2/chain_set_splits.json  
  inflating: ProDesign/data/cath/cath4.2/download_cath.sh  
  inflating: ProDesign/data/cath/cath4.2/ollikainen_set.jsonl  
  inflating: ProDesign/data/cath/cath4.2/remove.json  
  inflating: ProDesign/data/cath/cath4.2/test_split_L100.json  
  inflating: ProDesign/data/cath/cath4.2/test_split_sc.json  
--2023-11-07 05:52:13--  https://github.com/A4Bio/PiFold/releases/download/Training%26Data/ts.zip
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/538405150/1c1dd67b-49ba-4d5f-a8de-9b4d48ae7fee?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231107%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231107T055213Z&X-Amz-Expires=300&X-Amz-Signature=10794e82452afec5455b63dd111e6d898f2cb5fc68c51eefdb831cbb8a3b84d4&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=538405150&response-content-disposition=attachment%3B%20filename%3Dts.zip&response-content-type=application%2Foctet-stream [following]
--2023-11-07 05:52:13--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/538405150/1c1dd67b-49ba-4d5f-a8de-9b4d48ae7fee?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231107%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231107T055213Z&X-Amz-Expires=300&X-Amz-Signature=10794e82452afec5455b63dd111e6d898f2cb5fc68c51eefdb831cbb8a3b84d4&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=538405150&response-content-disposition=attachment%3B%20filename%3Dts.zip&response-content-type=application%2Foctet-stream
Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5451231 (5.2M) [application/octet-stream]
Saving to: ‘ProDesign/data/ts.zip’

ProDesign/data/ts.z 100%[===================>]   5.20M  --.-KB/s    in 0.1s    

2023-11-07 05:52:13 (52.7 MB/s) - ‘ProDesign/data/ts.zip’ saved [5451231/5451231]

Archive:  ProDesign/data/ts.zip
  inflating: ProDesign/data/ts/ts50.json  
  inflating: ProDesign/data/ts/ts500.json  
--2023-11-07 05:52:14--  https://github.com/A4Bio/PiFold/releases/download/Training%26Data/checkpoint.pth
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/538405150/9a5583fd-598e-4249-b69a-0684ce1d796a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231107%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231107T055214Z&X-Amz-Expires=300&X-Amz-Signature=a0ff47ef131b8dbd0ec5a32144a20e01cb61e7b6c206fdac5c38347792f3a860&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=538405150&response-content-disposition=attachment%3B%20filename%3Dcheckpoint.pth&response-content-type=application%2Foctet-stream [following]
--2023-11-07 05:52:14--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/538405150/9a5583fd-598e-4249-b69a-0684ce1d796a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231107%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231107T055214Z&X-Amz-Expires=300&X-Amz-Signature=a0ff47ef131b8dbd0ec5a32144a20e01cb61e7b6c206fdac5c38347792f3a860&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=538405150&response-content-disposition=attachment%3B%20filename%3Dcheckpoint.pth&response-content-type=application%2Foctet-stream
Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26768195 (26M) [application/octet-stream]
Saving to: ‘ProDesign/results/ProDesign/checkpoint.pth’

ProDesign/results/P 100%[===================>]  25.53M  48.0MB/s    in 0.5s    

2023-11-07 05:52:15 (48.0 MB/s) - ‘ProDesign/results/ProDesign/checkpoint.pth’ saved [26768195/26768195]
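
Before loading the checkpoint it can help to confirm that the three downloads completed. The sketch below compares on-disk sizes with the byte counts reported in the wget log above. The helper `check_downloads` is hypothetical, and a size mismatch may simply mean the release assets were updated after this notebook was written.

```python
import os

# Byte sizes copied from the wget log above.
EXPECTED_SIZES = {
    "ProDesign/data/cath.zip": 188153738,
    "ProDesign/data/ts.zip": 5451231,
    "ProDesign/results/ProDesign/checkpoint.pth": 26768195,
}

def check_downloads(expected, root="."):
    """Return a list of problems; an empty list means every file has the expected size."""
    problems = []
    for rel_path, size in expected.items():
        path = os.path.join(root, rel_path)
        if not os.path.isfile(path):
            problems.append(f"missing: {rel_path}")
        elif os.path.getsize(path) != size:
            problems.append(f"size mismatch: {rel_path} ({os.path.getsize(path)} != {size})")
    return problems

for problem in check_downloads(EXPECTED_SIZES):
    print(problem)
```

Run from the directory that contains `ProDesign`, the loop prints nothing when all three files match.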

2. Loading the Model

[5]
def get_parser():
    import argparse
    parser = argparse.ArgumentParser()
    # Set-up parameters
    parser.add_argument('--device', default='cuda', type=str, help='Name of device to use for tensor computations (cuda/cpu)')
    parser.add_argument('--display_step', default=10, type=int, help='Interval in batches between display of training metrics')
    parser.add_argument('--res_dir', default='ProDesign/results', type=str)
    parser.add_argument('--ex_name', default='ProDesign', type=str)
    parser.add_argument('--use_gpu', default=True, type=bool)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--seed', default=111, type=int)

    # CATH
    # dataset parameters
    parser.add_argument('--data_name', default='CATH', choices=['CATH', 'TS50'])
    parser.add_argument('--data_root', default='ProDesign/data/')
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--num_workers', default=8, type=int)

    # method parameters
    parser.add_argument('--method', default='ProDesign', choices=['ProDesign'])
    parser.add_argument('--config_file', '-c', default=None, type=str)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--node_features', default=128, type=int)
    parser.add_argument('--edge_features', default=128, type=int)
    parser.add_argument('--k_neighbors', default=30, type=int)
    # dropout is a fraction, so it is parsed as float (the original declared type=int)
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--num_encoder_layers', default=10, type=int)

    # Training parameters
    parser.add_argument('--epoch', default=100, type=int, help='end epoch')
    parser.add_argument('--log_step', default=1, type=int)
    parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
    parser.add_argument('--patience', default=100, type=int)

    # ProDesign parameters
    parser.add_argument('--updating_edges', default=4, type=int)
    parser.add_argument('--node_dist', default=1, type=int)
    parser.add_argument('--node_angle', default=1, type=int)
    parser.add_argument('--node_direct', default=1, type=int)
    parser.add_argument('--edge_dist', default=1, type=int)
    parser.add_argument('--edge_angle', default=1, type=int)
    parser.add_argument('--edge_direct', default=1, type=int)
    parser.add_argument('--virtual_num', default=3, type=int)
    # An empty argument list keeps every default, which is what a notebook needs.
    args = parser.parse_args([])
    return args

import torch
from ProDesign.main import Exp
from ProDesign.parser import create_parser  # unused here; get_parser() above takes its place
args = get_parser()
exp = Exp(args)
# Load the pretrained weights downloaded in the previous step.
svpath = 'ProDesign/results/ProDesign/'
exp.method.model.load_state_dict(torch.load(svpath+'checkpoint.pth'))
Use GPU: cuda:0
100%|██████████| 21668/21668 [00:33<00:00, 655.05it/s] 

device: 	cuda	
display_step: 	10	
res_dir: 	ProDesign/results	
ex_name: 	ProDesign	
use_gpu: 	True	
gpu: 	0	
seed: 	111	
data_name: 	CATH	
data_root: 	ProDesign/data/	
batch_size: 	8	
num_workers: 	8	
method: 	ProDesign	
config_file: 	None	
hidden_dim: 	128	
node_features: 	128	
edge_features: 	128	
k_neighbors: 	30	
dropout: 	0.1	
num_encoder_layers: 	10	
epoch: 	100	
log_step: 	1	
lr: 	0.001	
patience: 	100	
updating_edges: 	4	
node_dist: 	1	
node_angle: 	1	
node_direct: 	1	
edge_dist: 	1	
edge_angle: 	1	
edge_direct: 	1	
virtual_num: 	3	
<All keys matched successfully>
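
One detail of the parser above is worth a note: `--use_gpu` is declared with `type=bool`. Because the cell calls `parse_args([])`, the default `True` is used and nothing goes wrong here. If the same option is ever passed on a command line, however, argparse applies `bool()` to the raw string, and any non-empty string (including `"False"`) is truthy. The sketch below demonstrates the behavior with the standard library only; `str2bool` is a hypothetical helper, not part of ProDesign.

```python
import argparse

def str2bool(text):
    """Parse common spellings of true/false (hypothetical helper, not in ProDesign)."""
    value = text.strip().lower()
    if value in {"1", "true", "yes", "y"}:
        return True
    if value in {"0", "false", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")

naive = argparse.ArgumentParser()
naive.add_argument("--use_gpu", default=True, type=bool)  # same declaration as in the cell above

strict = argparse.ArgumentParser()
strict.add_argument("--use_gpu", default=True, type=str2bool)

# bool("False") is True because the string is non-empty.
naive_value = naive.parse_args(["--use_gpu", "False"]).use_gpu
strict_value = strict.parse_args(["--use_gpu", "False"]).use_gpu
print(naive_value, strict_value)
```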

3. CATH4.2 Results

[6]
from API.dataloader import make_cath_loader
from API.cath_dataset import CATH

# Split definitions: the full test set plus the L100 and sc subsets shipped with the data.
with open('ProDesign/data/cath/chain_set_splits.json', 'r') as f:
    test_split = json.load(f)

with open('ProDesign/data/cath/test_split_L100.json', 'r') as f:
    test_short_split = json.load(f)

with open('ProDesign/data/cath/test_split_sc.json', 'r') as f:
    test_SC_split = json.load(f)

alphabet = 'ACDEFGHIKLMNPQRSTVWY'
alphabet_set = set(alphabet)
max_length = 500

# Keep only chains made of the 20 standard residues and no longer than max_length.
with open('ProDesign/data/cath/chain_set.jsonl') as f:
    lines = f.readlines()
data_list = []
for line in lines:
    entry = json.loads(line)
    seq = entry['seq']

    for key, val in entry['coords'].items():
        entry['coords'][key] = np.asarray(val)

    bad_chars = set(seq).difference(alphabet_set)

    if len(bad_chars) == 0 and len(seq) <= max_length:
        data_list.append({
            'title': entry['name'],
            'seq': entry['seq'],
            'CA': entry['coords']['CA'],
            'C': entry['coords']['C'],
            'O': entry['coords']['O'],
            'N': entry['coords']['N']
        })

# Assign each chain to the test subsets it belongs to.
test_full_list = []
test_SC_list = []
test_Short_list = []
for data in data_list:
    if data['title'] in test_split['test']:
        test_full_list.append(data)
    if data['title'] in test_SC_split['test']:
        test_SC_list.append(data)
    if data['title'] in test_short_split['test']:
        test_Short_list.append(data)
print(">Loading", len(test_full_list), "samples for testAll")
print(">Loading", len(test_SC_list), "samples for testSC")
print(">Loading", len(test_Short_list), "samples for testShort")

# Evaluate on each subset in turn: testAll, testSC, testShort.
for test_list in (test_full_list, test_SC_list, test_Short_list):
    exp.test_loader = make_cath_loader(CATH(data=test_list), 'SimDesign', 8)
    exp.test()
    print("median: {:.4f}\t mean: {:.4f}\t std: {:.4f}\t min: {:.4f}\t max: {:.4f}".format(exp.method.median_recovery, exp.method.mean_recovery, exp.method.std_recovery, exp.method.min_recovery, exp.method.max_recovery))
>Loading 1120 samples for testAll
>Loading 103 samples for testSC
>Loading 94 samples for testShort
test loss: 1.3224: 100%|██████████| 140/140 [00:23<00:00,  6.07it/s]
100%|██████████| 1120/1120 [00:58<00:00, 19.13it/s]
Test Perp: 4.5533, Test Rec: 0.5166

Category Unknown Rec: 0.5166

median: 0.5166	 mean: 0.4939	 std: 0.1113	 min: 0.0896	 max: 0.7698
test loss: 1.5457: 100%|██████████| 13/13 [00:01<00:00,  7.08it/s]
100%|██████████| 103/103 [00:05<00:00, 17.75it/s]
Test Perp: 6.2962, Test Rec: 0.3846

Category Unknown Rec: 0.3846

median: 0.3846	 mean: 0.4050	 std: 0.1338	 min: 0.1522	 max: 0.7556
test loss: 1.7012: 100%|██████████| 12/12 [00:01<00:00,  7.99it/s]
100%|██████████| 94/94 [00:04<00:00, 21.64it/s]Test Perp: 6.0340, Test Rec: 0.3984

Category Unknown Rec: 0.3984

median: 0.3984	 mean: 0.4134	 std: 0.1407	 min: 0.1167	 max: 0.7556
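
In the log above, "Test Rec" equals the printed median for all three subsets (0.5166, 0.3846, 0.3984), which suggests the headline recovery is the median of per-protein recoveries. How ProDesign computes these values internally is not shown in this notebook, so the following is only a sketch of the metric under that assumption. The function names and the toy sequences are made up, and using the population standard deviation is also an assumption.

```python
from statistics import mean, median, pstdev

def recovery(predicted, native):
    """Fraction of positions where the predicted residue equals the native one."""
    if len(predicted) != len(native):
        raise ValueError("sequences must have the same length")
    return sum(p == n for p, n in zip(predicted, native)) / len(native)

def summarize(recoveries):
    """Summary statistics in the same order as the printed result line."""
    return {
        "median": median(recoveries),
        "mean": mean(recoveries),
        "std": pstdev(recoveries),  # population std; an assumption
        "min": min(recoveries),
        "max": max(recoveries),
    }

# Made-up (predicted, native) pairs, for illustration only.
pairs = [("ACDE", "ACDF"), ("GHIK", "GHIK"), ("LMNP", "AMNQ")]
scores = [recovery(p, n) for p, n in pairs]
stats = summarize(scores)
print(scores)
print({name: round(value, 4) for name, value in stats.items()})
```

Reporting the median alongside the mean is useful here because per-protein recovery varies widely (the testAll run above ranges from 0.0896 to 0.7698).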

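4. Training PiFold from Scratch

The PiFold repository (cloned above as ProDesign) provides a training script, so you can run main.py directly.
The default arguments defined in its parser are sufficient.
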
[ ]
%cd ProDesign
!python main.py
/content/ProDesign
{'device': 'cuda', 'display_step': 10, 'res_dir': './results', 'ex_name': 'debug', 'use_gpu': True, 'gpu': 0, 'seed': 111, 'data_name': 'CATH', 'data_root': './data/', 'batch_size': 8, 'num_workers': 8, 'method': 'ProDesign', 'config_file': None, 'hidden_dim': 128, 'node_features': 128, 'edge_features': 128, 'k_neighbors': 30, 'dropout': 0.1, 'num_encoder_layers': 10, 'epoch': 100, 'log_step': 1, 'lr': 0.001, 'patience': 100, 'updating_edges': 4, 'node_dist': 1, 'node_angle': 1, 'node_direct': 1, 'edge_dist': 1, 'edge_angle': 1, 'edge_direct': 1, 'virtual_num': 3}
Use GPU: cuda:0
100% 21668/21668 [00:21<00:00, 1015.98it/s]

device: 	cuda	
display_step: 	10	
res_dir: 	./results	
ex_name: 	debug	
use_gpu: 	True	
gpu: 	0	
seed: 	111	
data_name: 	CATH	
data_root: 	./data/	
batch_size: 	8	
num_workers: 	8	
method: 	ProDesign	
config_file: 	None	
hidden_dim: 	128	
node_features: 	128	
edge_features: 	128	
k_neighbors: 	30	
dropout: 	0.1	
num_encoder_layers: 	10	
epoch: 	100	
log_step: 	1	
lr: 	0.001	
patience: 	100	
updating_edges: 	4	
node_dist: 	1	
node_angle: 	1	
node_direct: 	1	
edge_dist: 	1	
edge_angle: 	1	
edge_direct: 	1	
virtual_num: 	3	
>>>>>>>>>>>>>>>>>>>>>>>>>> training <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
train loss: 2.6155:   6% 135/2253 [00:59<14:09,  2.49it/s]
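
The progress bar gives enough information for a rough time budget. The sketch below uses only numbers visible in the log (2253 batches per epoch, about 2.49 batches per second, batch size 8, 100 epochs). It assumes the rate stays constant, that the last batch may be partial, and it ignores validation and checkpointing time, so treat the result as a lower bound. The rate was measured on the Tesla T4 shown in the nvidia-smi output; other GPUs will differ. The function name is hypothetical.

```python
def estimate_training_hours(batches_per_epoch, batches_per_second, epochs):
    """Wall-clock estimate for the training loop alone, in hours."""
    seconds_per_epoch = batches_per_epoch / batches_per_second
    return seconds_per_epoch * epochs / 3600

batches_per_epoch = 2253  # from the progress bar above
batch_size = 8            # parser default
# The batch count bounds the number of training samples (assuming a partial last batch is kept).
max_samples = batches_per_epoch * batch_size
min_samples = (batches_per_epoch - 1) * batch_size + 1

hours = estimate_training_hours(batches_per_epoch, 2.49, epochs=100)
print(f"{min_samples}-{max_samples} training samples")
print(f"about {hours:.1f} hours for 100 epochs")
```

This works out to roughly 15 minutes per epoch, consistent with the remaining-time estimate in the progress bar, and about a day for the full default schedule.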