Use Con-CDVAE to generate the crystals you need
Con-CDVAE
Updated 2024-08-06
Recommended image: Basic Image:ubuntu22.04-py3.10-intel2022
Recommended machine type: c2_m4_cpu
Introduction
Training
Generating

Introduction

Con-CDVAE is a diffusion-based model that generates crystals according to the physical properties you need. It was developed on the basis of CDVAE and inspired by DALL·E 2.


Figure 1: Training and generation flow chart of Con-CDVAE.

Figure 1 shows the model framework of Con-CDVAE and the processes of training the model and generating crystals. Like CDVAE, we use a VAE structure. The Encoder is a graph model that converts crystals into latent variables. The Decoder consists of several MLPs and another graph model: the MLPs generate the lattice constants and the number of atoms, and, given these, the graph model generates atomic positions and atomic species through a diffusion process. In addition, the properties of the crystal are fed to the Decoder as input, which enables it to generate crystals with the target properties we set.
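To make the conditioning concrete, here is a minimal numpy sketch of the first decoder stage, under stated assumptions: the latent vector z and an embedding of the target property are concatenated and passed through small MLP heads that predict the number of atoms and six lattice parameters. The layer sizes, MAX_ATOMS, the random untrained weights, and all function names are hypothetical illustrations, not the Con-CDVAE implementation; the graph-based diffusion stage for positions and species is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

LATENT_DIM = 256   # hypothetical latent size
PROP_DIM = 64      # hypothetical size of the property embedding
HIDDEN = 128       # hypothetical hidden width
MAX_ATOMS = 20     # hypothetical upper bound on atoms per cell


def init_mlp(sizes):
    """Random weights for a toy MLP (illustration only, untrained)."""
    return [(rng.normal(0, 0.1, (m, n)), np.zeros(n)) for m, n in zip(sizes[:-1], sizes[1:])]


def mlp(params, x):
    """Apply Linear -> ReLU layers, with no activation after the last layer."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x


num_atoms_head = init_mlp([LATENT_DIM + PROP_DIM, HIDDEN, MAX_ATOMS + 1])
lattice_head = init_mlp([LATENT_DIM + PROP_DIM, HIDDEN, 6])


def decode_cell(z, prop_emb):
    """Predict an atom count and 6 lattice parameters from [z, property embedding]."""
    h = np.concatenate([z, prop_emb], axis=-1)
    num_atoms = mlp(num_atoms_head, h).argmax(axis=-1)   # class index = atom count
    lattice = mlp(lattice_head, h)                       # (a, b, c, alpha, beta, gamma), unscaled
    return num_atoms, lattice


z = rng.normal(size=(4, LATENT_DIM))        # 4 crystals in latent space
prop_emb = rng.normal(size=(4, PROP_DIM))   # stand-in for an embedded target property
n, lat = decode_cell(z, prop_emb)
print(n.shape, lat.shape)  # (4,) (4, 6)
```

Because the property embedding enters every head, changing the target property changes the predicted cell even for the same z, which is the mechanism the paragraph above describes.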

Note that we also use a Predictor, consisting of MLPs, to predict crystal properties. This block takes the latent variables as input, which encourages crystals with similar properties to lie close together in the latent space; this may help when the Prior is used to generate new latent variables from properties.
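The Predictor's effect on the latent space can be illustrated with a loss sketch: a property-regression term on the latent variables is added to the usual VAE objective, so the encoder is pushed to organise latents by property. The KL term is the standard closed form for a diagonal Gaussian posterior against N(0, I); the weights beta and gamma and the function names are assumptions for illustration, not the values or code used by Con-CDVAE.

```python
import math


def kl_diag_gaussian(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) for one latent vector."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def mse(pred, target):
    """Mean squared error between predicted and true properties."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def total_loss(recon_loss, mu, logvar, prop_pred, prop_true, beta=0.01, gamma=1.0):
    """Hypothetical weighted sum: reconstruction + beta * KL + gamma * property MSE."""
    return recon_loss + beta * kl_diag_gaussian(mu, logvar) + gamma * mse(prop_pred, prop_true)


# A posterior equal to the prior has zero KL, so only the other terms remain.
print(kl_diag_gaussian([0.0, 0.0], [0.0, 0.0]))   # 0.0
print(total_loss(1.5, [0.0, 0.0], [0.0, 0.0], [-0.8], [-1.0]))
```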

The Prior block is inspired by DALL·E 2. It is another diffusion model composed of MLPs: taking properties as input, it samples crystal latent variables from the latent space. Because the Prior needs crystal latent variables as training labels, and these come from the trained Encoder, Con-CDVAE is trained in two steps: first the VAE, then the Prior.
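A generic DDPM-style training step for an MLP diffusion model over latent variables is sketched below: a clean latent z0 is noised to a random timestep t in closed form, and a network conditioned on the property regresses the added noise. The linear beta schedule, T = 1000, and the zero-output `eps_model` placeholder are assumptions; Con-CDVAE's actual schedule and network may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000                                  # hypothetical number of diffusion steps
betas = np.linspace(1e-4, 0.02, T)        # hypothetical linear noise schedule
alpha_bars = np.cumprod(1.0 - betas)      # cumulative product of (1 - beta_t)


def q_sample(z0, t, eps):
    """Closed-form forward noising: z_t = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * eps."""
    ab = alpha_bars[t][:, None]
    return np.sqrt(ab) * z0 + np.sqrt(1.0 - ab) * eps


def eps_model(z_t, t, prop):
    """Placeholder for the property-conditioned noise predictor (an MLP in practice)."""
    return np.zeros_like(z_t)


def prior_training_loss(z0, prop):
    """One training step: sample t and noise, then regress the noise with MSE."""
    t = rng.integers(0, T, size=z0.shape[0])
    eps = rng.normal(size=z0.shape)
    z_t = q_sample(z0, t, eps)
    return np.mean((eps_model(z_t, t, prop) - eps) ** 2)


z0 = rng.normal(size=(8, 256))   # stand-in for latents from the step-one Encoder
prop = rng.normal(size=(8, 1))   # stand-in for the conditioning properties
print(prior_training_loss(z0, prop))
```

This is why the two-step order matters: z0 must come from an already trained Encoder before the Prior can be fitted.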

After training, Con-CDVAE can generate crystals from target properties. First, the Prior samples latent variables conditioned on the properties. Then the latent variables and the properties are fed to the Decoder to generate crystals.
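The two-stage generation can be sketched as DDPM ancestral sampling in latent space followed by a decoder call. `eps_model` and `decode` are hypothetical stand-ins for the trained Prior and Decoder, and the linear schedule is an assumption; the sketch only illustrates the data flow from properties to latents to crystals.

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def eps_model(z_t, t, prop):
    """Hypothetical stand-in for the trained, property-conditioned Prior network."""
    return np.zeros_like(z_t)


def decode(z, prop):
    """Hypothetical stand-in for the Decoder; returns a summary instead of crystals."""
    return {"n_crystals": z.shape[0], "latent_dim": z.shape[1]}


def sample_latents(prop, latent_dim=256):
    """Ancestral sampling: start from N(0, I) and denoise from t = T-1 down to 0."""
    z = rng.normal(size=(prop.shape[0], latent_dim))
    for t in reversed(range(T)):
        eps = eps_model(z, t, prop)
        mean = (z - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
        noise = rng.normal(size=z.shape) if t > 0 else 0.0
        z = mean + np.sqrt(betas[t]) * noise   # uses the sigma_t^2 = beta_t variant
    return z


target = np.full((3, 1), -1.0)   # e.g. one target formation energy for three crystals
z = sample_latents(target)       # step 1: the Prior samples latents from properties
print(decode(z, target))         # step 2: the Decoder turns (z, properties) into crystals
```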


More details can be found in the paper and on GitHub:
Paper: Con-CDVAE: A method for the conditional generation of crystal structures
Code: Con-CDVAE (https://github.com/cyye001/Con-CDVAE)


Training

Let's start by training a Con-CDVAE that generates crystals based on formation energy. In this notebook, I use only the toy dataset mptest and train for just a few epochs to show how to run the code. To train a useful Con-CDVAE, you should train for enough epochs on a sufficiently large dataset, such as the datasets provided with CDVAE or data downloaded from the Materials Project.

  1. First, clone the code from GitHub and build the environment. We recommend using conda for this.
!git clone https://github.com/cyye001/Con-CDVAE
Cloning into 'Con-CDVAE'...
remote: Enumerating objects: 265, done.
remote: Counting objects: 100% (178/178), done.
remote: Compressing objects: 100% (148/148), done.
remote: Total 265 (delta 42), reused 86 (delta 18), pack-reused 87
Receiving objects: 100% (265/265), 52.48 MiB | 4.87 MiB/s, done.
Resolving deltas: 100% (49/49), done.
!conda env create -f Con-CDVAE/environment.yml
Collecting package metadata (repodata.json): \ Killed
  2. Next, modify the following environment variables in .env:
  • PROJECT_ROOT: path to the folder that contains this repo
  • HYDRA_JOBS: path to a folder to store Hydra outputs
  • WABDB: path to a folder to store Weights & Biases (wandb) outputs
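Editing .env by hand works fine; if you prefer to script it, below is a minimal sketch that rewrites existing `NAME=value` or `export NAME=value` lines and appends any missing keys. The helper name and the example paths (which mirror this notebook's layout) are assumptions; check the exact variable names against the .env file shipped in the repository. The demo writes a throwaway .env in a temporary folder.

```python
import re
import tempfile
from pathlib import Path

# Matches "NAME=" or "export NAME=" at the start of a line.
ASSIGN = re.compile(r"\s*(export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=")


def set_env_vars(env_path, updates):
    """Set variables in a .env file, keeping unrelated lines untouched."""
    path = Path(env_path)
    lines = path.read_text().splitlines() if path.exists() else []
    remaining = dict(updates)
    out = []
    for line in lines:
        m = ASSIGN.match(line)
        if m and m.group(2) in remaining:
            prefix, key = m.group(1) or "", m.group(2)
            out.append(f'{prefix}{key}="{remaining.pop(key)}"')
        else:
            out.append(line)
    out.extend(f'export {k}="{v}"' for k, v in remaining.items())  # keys not yet present
    path.write_text("\n".join(out) + "\n")


with tempfile.TemporaryDirectory() as tmp:
    env = Path(tmp) / ".env"
    env.write_text('export PROJECT_ROOT="/path/to/repo"\nexport HYDRA_JOBS="/path/to/hydra"\n')
    set_env_vars(env, {
        "PROJECT_ROOT": "/Con-CDVAE",               # hypothetical paths
        "HYDRA_JOBS": "/Con-CDVAE/output/hydra",
        "WABDB": "/Con-CDVAE/output/wandb",
    })
    text = env.read_text()
print(text)
```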
  3. Step-one training

    To train a Con-CDVAE, first run the following command.

    After training, model checkpoints can be found in $HYDRA_JOBS/singlerun/YYYY-MM-DD/expname/model_expname.pth.

%cd /Con-CDVAE/
!python concdvae/run.py train=new data=mptest expname=test model=vae_mp_format
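Once several runs accumulate, locating the right checkpoint by hand gets tedious. The sketch below lists checkpoints for one experiment, newest first; the layout $HYDRA_JOBS/singlerun/YYYY-MM-DD/<expname>/model_<expname>.pth is inferred from the step-two command in this notebook, and the helper name is hypothetical. The demo builds two fake runs in a temporary folder.

```python
import tempfile
from pathlib import Path


def find_checkpoints(hydra_jobs, expname):
    """List (run_dir, model_file) pairs for an experiment, newest run first.

    Assumes $HYDRA_JOBS/singlerun/YYYY-MM-DD/<expname>/model_<expname>.pth;
    ISO dates sort correctly as strings, so a reverse sort puts the newest first.
    """
    pattern = f"singlerun/*/{expname}/model_{expname}.pth"
    hits = sorted(Path(hydra_jobs).glob(pattern), key=str, reverse=True)
    return [(p.parent, p.name) for p in hits]


with tempfile.TemporaryDirectory() as tmp:
    for day in ("2024-08-01", "2024-08-05"):   # two fake runs of expname=test
        run_dir = Path(tmp) / "singlerun" / day / "test"
        run_dir.mkdir(parents=True)
        (run_dir / "model_test.pth").touch()
    summary = [(d.parent.name, f) for d, f in find_checkpoints(tmp, "test")]
print(summary)  # [('2024-08-05', 'model_test.pth'), ('2024-08-01', 'model_test.pth')]
```

The returned run directory is what the next step passes as --model_path, and the file name is its --model_file.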
  4. Step-two training

    After step-one training finishes, train the Prior block with the following command.

    The trained Prior for the default condition is then saved as /your_path_to_model_checkpoints/conz_model_your_label_diffu.pth.

!python scripts/condition_diff_z.py --model_path /Con-CDVAE/output/hydra/singlerun/2024-08-05/test --model_file model_test.pth --fullfea 0 --label conztest
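The commands in this notebook are tied together by file-name conventions: step one saves model_<expname>.pth, step two with --label <label> saves conz_model_<label>_diffu.pth, and generation loads that file through --conz_file. The hypothetical helpers below build both argument lists from a few variables, using only the flags shown in this notebook, so the names cannot drift apart.

```python
import shlex


def step_two_cmd(model_path, expname, prior_label, fullfea=0):
    """Arguments for training the Prior, mirroring the step-two command above."""
    return ["python", "scripts/condition_diff_z.py",
            "--model_path", model_path,
            "--model_file", f"model_{expname}.pth",
            "--fullfea", str(fullfea),
            "--label", prior_label]


def generate_cmd(model_path, expname, prior_label, gen_label, prop_path):
    """Arguments for generation; the Prior file name is derived from prior_label."""
    return ["python", "scripts/evaluate_diff.py",
            "--model_path", model_path,
            "--model_file", f"model_{expname}.pth",
            "--conz_file", f"conz_model_{prior_label}_diffu.pth",
            "--label", gen_label,
            "--prop_path", prop_path]


run_dir = "/Con-CDVAE/output/hydra/singlerun/2024-08-05/test"
train_line = shlex.join(step_two_cmd(run_dir, "test", "conztest"))
gen_line = shlex.join(generate_cmd(run_dir, "test", "conztest", "test", "general_full.csv"))
print(train_line)
print(gen_line)
```

To execute instead of printing, pass a list to subprocess.run(cmd, check=True, cwd="/Con-CDVAE") so the relative scripts/ paths resolve from the repository root.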

Generating

To generate materials, you first need to prepare a condition file. An example is provided in /output/hydra/singlerun/2024-01-25/test/, where "general_full.csv" is used for the default or full strategy and "general_less.csv" for the less strategy.

For simplicity, we copy .../2024-01-25/test/general_full.csv directly into the new model_path, then run the following command:

!python scripts/evaluate_diff.py --model_path /Con-CDVAE/output/hydra/singlerun/2024-08-05/test --model_file model_test.pth --conz_file conz_model_conztest_diffu.pth --label test --prop_path general_full.csv
/opt/mamba/lib/python3.10/site-packages/hydra/experimental/initialize.py:116: UserWarning: hydra.experimental.initialize_config_dir() is no longer experimental. Use hydra.initialize_config_dir().
  deprecation_warning(message=message)
/opt/mamba/lib/python3.10/site-packages/hydra/experimental/initialize.py:118: UserWarning: 
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  self.delegate = real_initialize_config_dir(
/opt/mamba/lib/python3.10/site-packages/hydra/experimental/compose.py:25: UserWarning: hydra.experimental.compose() is no longer experimental. Use hydra.compose()
  deprecation_warning(message=message)
/Con-CDVAE/scripts/eval_utils.py:33: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(model_root, map_location=torch.device('cpu'))
/Con-CDVAE/scripts/eval_utils.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  lattice_scaler = torch.load(Path(model_path) / 'lattice_scaler.pt')
/Con-CDVAE/scripts/evaluate_diff.py:162: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(conz_model_root, map_location=torch.device('cpu'))
use default feature
condition_diff_z(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
  )
  (condition_model): ConditioningModule(
    (condition_embModel): ModuleList(
      (0): ScalarConditionEmbedding(
        (gaussian_expansion): GaussianRBF()
        (dense_net): Sequential(
          (0): Linear(in_features=15, out_features=64, bias=True)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): ReLU()
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): ReLU()
          (6): Linear(in_features=64, out_features=64, bias=True)
        )
      )
    )
    (dense_net): Sequential(
      (0): Linear(in_features=64, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (decoder): UNet(
    (downmodel): ModuleList(
      (0): Linear(in_features=448, out_features=128, bias=True)
      (1): Linear(in_features=320, out_features=64, bias=True)
      (2): Linear(in_features=256, out_features=32, bias=True)
    )
    (downact): ModuleList(
      (0-2): 3 x ReLU()
    )
    (upmodel): ModuleList(
      (0): Linear(in_features=224, out_features=64, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
      (2): Linear(in_features=320, out_features=256, bias=True)
    )
    (upact): ModuleList(
      (0-2): 3 x ReLU()
    )
    (middle_mpl): Linear(in_features=224, out_features=32, bias=True)
    (middle_act): ReLU()
    (output_mlp): Linear(in_features=448, out_features=256, bias=True)
  )
)
Evaluate model on the generation task.
No. 1  in  7 with label =  CS0
in no_grad!!!!
No. 1  in  2
mae 0.3208746910095215 0.6436481475830078
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [06:12<00:00,  7.45s/it]
No. 2  in  2
mae 0.28677281737327576 0.4612818658351898
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [07:50<00:00,  9.42s/it]
No. 2  in  7 with label =  CS1
in no_grad!!!!
No. 1  in  2
mae 0.06959331035614014 0.46275168657302856
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [05:41<00:00,  6.84s/it]
No. 2  in  2
mae 0.16774910688400269 0.6675240993499756
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [07:26<00:00,  8.93s/it]
No. 3  in  7 with label =  CS2
in no_grad!!!!
No. 1  in  2
mae 0.2971403896808624 0.5597849488258362
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [03:08<00:00,  3.78s/it]
No. 2  in  2
mae 0.3593195378780365 0.8800906538963318
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [03:13<00:00,  3.87s/it]
No. 4  in  7 with label =  CS3
in no_grad!!!!
No. 1  in  2
mae 0.015590429306030273 0.5026229023933411
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [02:19<00:00,  2.79s/it]
No. 2  in  2
mae 0.1418730616569519 0.36814606189727783
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [03:35<00:00,  4.31s/it]
No. 5  in  7 with label =  CS4
in no_grad!!!!
No. 1  in  2
mae 0.5420414805412292 0.8714357018470764
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [02:35<00:00,  3.12s/it]
No. 2  in  2
mae 0.6836747527122498 1.2558685541152954
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [02:57<00:00,  3.56s/it]
No. 6  in  7 with label =  CS5
in no_grad!!!!
No. 1  in  2
mae 0.1297561526298523 1.2116953134536743
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [10:05<00:00, 12.12s/it]
No. 2  in  2
mae 0.21289175748825073 0.9193131327629089
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [08:28<00:00, 10.16s/it]
No. 7  in  7 with label =  CS6
in no_grad!!!!
No. 1  in  2
mae 0.1584688425064087 0.21139904856681824
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [03:14<00:00,  3.89s/it]
No. 2  in  2
mae 0.1455262303352356 1.4346026182174683
after down sample torch.Size([2, 256])
/Con-CDVAE/concdvae/common/data_utils.py:313: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X = torch.tensor(X, dtype=torch.float)
100%|███████████████████████████████████████████| 50/50 [04:39<00:00,  5.59s/it]
end

To filter latent variables with the Predictor block, set --down_sample 100, which filters at a ratio of 100 to 1 (one latent variable is kept for every 100 sampled).
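How evaluate_diff.py applies --down_sample internally is not shown in this notebook. As an assumption for illustration only, the sketch below reads it as: sample N latent variables, let the Predictor estimate each one's property, and keep the N // down_sample latents whose prediction is closest to the target. All names are hypothetical.

```python
def filter_by_predictor(latents, predicted, target, down_sample):
    """Keep len(latents) // down_sample latents whose predicted property is closest to target."""
    keep = max(1, len(latents) // down_sample)
    ranked = sorted(range(len(latents)), key=lambda i: abs(predicted[i] - target))
    return [latents[i] for i in sorted(ranked[:keep])]   # preserve original order


latents = ["z0", "z1", "z2", "z3", "z4", "z5"]
predicted = [-0.1, -0.95, -0.4, -1.3, -1.02, 0.2]   # Predictor outputs, one per latent
print(filter_by_predictor(latents, predicted, target=-1.0, down_sample=3))  # ['z1', 'z4']
```

Oversampling and then filtering trades extra sampling time for latents whose predicted property is closer to the requested value.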

The generated crystals are stored in eval_gen_your_label_xxx.pt. Each eval_gen_your_label_xxx.pt contains num_batches_to_samples * batch_size / down_sample structures, so you can change these settings to control the number of generated structures.

For example, --num_batches_to_samples 10 --batch_size 500 --down_sample 1 gives 10 * 500 / 1 = 5000 structures in each eval_gen_your_label_xxx.pt.
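The count formula can also be inverted to plan a run. The hypothetical helpers below evaluate it and compute the smallest --num_batches_to_samples needed for a target number of structures; integer division is an assumption about how the script rounds.

```python
def structures_per_file(num_batches_to_samples, batch_size, down_sample=1):
    """Structures in each eval_gen_*.pt, per the formula above (integer division assumed)."""
    return num_batches_to_samples * batch_size // down_sample


def batches_needed(target, batch_size, down_sample=1):
    """Smallest num_batches_to_samples that yields at least `target` structures."""
    return -(-target * down_sample // batch_size)   # ceiling division on integers


print(structures_per_file(10, 500, 1))     # 5000, as in the example above
print(structures_per_file(10, 500, 100))   # 50 after 100:1 filtering
print(batches_needed(1000, 500, 100))      # 200
```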

Then you can use Con-CDVAE/scripts/pt2cif.py to convert xxx.pt into CIF files.
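pt2cif.py is the tool to use here. To show what the conversion involves, here is a dependency-free sketch that writes P1 CIF text from per-crystal arrays. It assumes, as in CDVAE-style outputs, that the atoms of all crystals are stored in flat arrays (atom_types, frac_coords) split by num_atoms, with lengths and angles given per crystal; check this layout against the actual .pt contents (loaded with torch.load) before relying on it. The element table is a tiny demo subset and the structures are made-up toy values; with pymatgen installed, pymatgen.io.cif.CifWriter is a more complete option.

```python
# Illustrative P1 CIF writer; not the implementation of scripts/pt2cif.py.
SYMBOLS = {3: "Li", 8: "O", 11: "Na", 17: "Cl"}   # demo subset: atomic number -> symbol


def split_by_counts(flat, counts):
    """Split a flat per-atom list into per-crystal chunks using num_atoms."""
    chunks, start = [], 0
    for n in counts:
        chunks.append(flat[start:start + n])
        start += n
    return chunks


def to_cif(name, lengths, angles, atom_types, frac_coords):
    """Return CIF text for one crystal in space group P1 (all atoms listed explicitly)."""
    (a, b, c), (alpha, beta, gamma) = lengths, angles
    lines = [
        f"data_{name}",
        "_symmetry_space_group_name_H-M   'P 1'",
        "_symmetry_Int_Tables_number   1",
        f"_cell_length_a   {a:.6f}",
        f"_cell_length_b   {b:.6f}",
        f"_cell_length_c   {c:.6f}",
        f"_cell_angle_alpha   {alpha:.6f}",
        f"_cell_angle_beta   {beta:.6f}",
        f"_cell_angle_gamma   {gamma:.6f}",
        "loop_",
        " _atom_site_label",
        " _atom_site_type_symbol",
        " _atom_site_fract_x",
        " _atom_site_fract_y",
        " _atom_site_fract_z",
        " _atom_site_occupancy",
    ]
    for i, (z, (x, y, w)) in enumerate(zip(atom_types, frac_coords)):
        sym = SYMBOLS[z]
        lines.append(f" {sym}{i} {sym} {x:.6f} {y:.6f} {w:.6f} 1")
    return "\n".join(lines) + "\n"


# Two toy crystals stored batch-style: 2 atoms each, flattened.
num_atoms = [2, 2]
atom_types = [11, 17, 3, 8]
frac_coords = [[0, 0, 0], [0.5, 0.5, 0.5], [0, 0, 0], [0.5, 0.5, 0.5]]
lengths = [[5.64, 5.64, 5.64], [4.0, 4.0, 4.0]]
angles = [[90.0, 90.0, 90.0]] * 2

per_types = split_by_counts(atom_types, num_atoms)
per_coords = split_by_counts(frac_coords, num_atoms)
cifs = [to_cif(f"gen_{k}", lengths[k], angles[k], per_types[k], per_coords[k])
        for k in range(len(num_atoms))]
print(cifs[0])
```

Each string can then be written to its own .cif file, e.g. with Path(f"gen_{k}.cif").write_text(cif).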
