Introduction
Con-CDVAE is a diffusion-based model that generates crystals with the physical properties you specify. It was developed based on CDVAE and inspired by DALL·E2.
Figure 1: Training and generation flow chart of Con-CDVAE.
Figure 1 shows the model framework of Con-CDVAE together with the training and crystal-generation processes. Like CDVAE, Con-CDVAE uses a VAE structure. The Encoder is a graph model that converts crystals into latent variables. The Decoder consists of several MLPs and a second graph model: the MLPs generate the lattice constants and the number of atoms, and, given these, the graph model generates the atomic positions and atomic species through diffusion. In addition, the crystal properties are fed to the Decoder as input, which enables it to generate crystals with the target properties we set.
Note that we use a Predictor, which consists of MLPs, to predict crystal properties from the latent variables. This pushes crystals with similar properties close together in the latent space, which may help when the Prior is used to generate new latent variables for given properties.
The Prior block is inspired by DALL·E2. It is another diffusion model composed of MLPs: taking properties as input, it samples crystal latent variables from the latent space. Because the Prior needs crystal latent variables as training labels, Con-CDVAE is trained in two steps: first the main model, then the Prior on the latent variables that the main model produces.
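Since the Prior is a diffusion model over latent vectors, the quantity it learns to undo can be illustrated with the standard DDPM forward process, z_t = sqrt(ᾱ_t) · z_0 + sqrt(1 − ᾱ_t) · ε. The sketch below is a minimal illustration under stated assumptions: the linear β schedule, its endpoints, and the 1000 steps are common DDPM defaults, not values taken from Con-CDVAE, and the Prior's actual schedule and parameterization may differ. The function names are hypothetical.

```python
import math
import random


def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule (a common DDPM default, assumed here)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def alpha_bars(betas):
    """Cumulative products: alpha_bar_t = prod over s <= t of (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def q_sample(z0, t, abar, rng=random):
    """Noise a clean latent vector z0 to step t; returns (z_t, eps).

    A Prior-style denoiser would be trained to recover eps (or z0)
    from (z_t, t, property condition)."""
    eps = [rng.gauss(0.0, 1.0) for _ in z0]
    keep = math.sqrt(abar[t])          # weight of the clean signal
    noise = math.sqrt(1.0 - abar[t])   # weight of the Gaussian noise
    zt = [keep * x + noise * e for x, e in zip(z0, eps)]
    return zt, eps


# Toy usage with a 256-dimensional latent vector.
abar = alpha_bars(linear_beta_schedule())
zt, eps = q_sample([0.5] * 256, t=500, abar=abar, rng=random.Random(0))
```

Late steps (ᾱ_t close to 0) are almost pure noise, which is why sampling from the Prior can start from a random Gaussian vector and denoise it step by step under the property condition.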
After training, Con-CDVAE generates crystals from target properties in two stages. First, the Prior samples latent variables conditioned on the properties. Then the latent variables and the properties are fed into the Decoder to generate crystals.
More details can be found in the paper and on GitHub:
Paper: Con-CDVAE: A method for the conditional generation of crystal structures
Code: Con-CDVAE
Training
Let's train a Con-CDVAE that generates crystals conditioned on formation energy. In this notebook I use only the toy dataset mptest and train for just a few epochs to show how to run the code. To train a useful Con-CDVAE, train for enough epochs on a sufficiently large dataset, such as the datasets provided with CDVAE or data from the Materials Project.
- First, after downloading the code from GitHub, build the environment. We recommend using conda.
- Then modify the following environment variables in .env:
  - PROJECT_ROOT: path to the folder that contains this repo
  - HYDRA_JOBS: path to a folder to store hydra outputs
  - WABDB: path to a folder to store wabdb outputs
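To check the .env settings before launching a long run, a small stdlib parser can help. The helpers load_env_file and missing_or_bad below are hypothetical (they are not part of Con-CDVAE), and they assume simple KEY=VALUE lines, optionally prefixed with export and optionally quoted; the repository itself may load .env with a different library, and the exact variable spelling should be checked against the repository's .env.

```python
import os
import shlex

# Variable names as listed in this tutorial; check the repo's .env for exact spelling.
REQUIRED_KEYS = ("PROJECT_ROOT", "HYDRA_JOBS", "WABDB")


def load_env_file(path):
    """Parse 'KEY=VALUE' or 'export KEY=VALUE' lines from a .env file into a dict."""
    env = {}
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue  # skip blank lines and comments
            if line.startswith("export "):
                line = line[len("export "):].strip()
            key, sep, value = line.partition("=")
            if not sep:
                continue  # not an assignment
            value = value.strip()
            try:
                # shlex removes surrounding quotes and trailing '# comments'
                value = " ".join(shlex.split(value, comments=True))
            except ValueError:
                pass  # unbalanced quotes: keep the raw text
            env[key.strip()] = value
    return env


def missing_or_bad(env):
    """List required keys that are unset or do not point to an existing directory."""
    return [k for k in REQUIRED_KEYS if not env.get(k) or not os.path.isdir(env[k])]
```

For example, missing_or_bad(load_env_file(".env")) returns an empty list when all three folders exist. The helper only reports problems; it does not create any folders.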
Step-one training
To train a Con-CDVAE, run the following command first.
After training, the model checkpoint can be found at $HYDRA_JOBS/singlerun/YYYY-MM-DD/model_expname.pth.
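Because Hydra writes each run into a dated folder, it can be handy to list the checkpoints that exist under $HYDRA_JOBS/singlerun. The function find_checkpoints below is a hypothetical convenience helper, not part of Con-CDVAE; it only assumes the singlerun/YYYY-MM-DD/... layout described above and that checkpoint files end in .pth.

```python
import os
from pathlib import Path


def find_checkpoints(hydra_jobs=None, pattern="*.pth"):
    """Return files matching `pattern` under <HYDRA_JOBS>/singlerun, newest first."""
    # Fall back to the HYDRA_JOBS environment variable when no path is given.
    root = Path(hydra_jobs or os.environ["HYDRA_JOBS"]) / "singlerun"
    if not root.is_dir():
        return []
    files = [p for p in root.rglob(pattern) if p.is_file()]
    # Sort by modification time so the most recent run comes first.
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)
```

The pattern can be narrowed, for example pattern="conz_model_*_diffu.pth" to list only the Prior checkpoints produced by step-two training.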
Step-two training
After finishing step-one training, you can train the Prior block with the following command.
The trained Prior for the default condition is then saved as /your_path_to_model_checkpoints/conz_model_your_label_diffu.pth.
Generating
To generate materials, you first need to prepare a condition file. Examples are provided in /output/hydra/singlerun/2024-01-25/test/: "general_full.csv" is for the default or full strategy, and "general_less.csv" is for the less strategy.
For simplicity, we copy .../2024-01-25/test/general_full.csv directly into the new model_path and then run the following command:
The command first prints several warnings: hydra.experimental deprecation notices, torch.load FutureWarnings about weights_only, and a UserWarning from concdvae/common/data_utils.py about constructing tensors with torch.tensor(sourceTensor). They are warnings only and do not stop the run. The script then prints the Prior architecture (condition_diff_z, built from SinusoidalPositionEmbeddings, a ScalarConditionEmbedding with a GaussianRBF expansion, and an MLP-based UNet) and loops over the 7 conditions in general_full.csv (labels CS0 to CS6). For each condition it samples 2 batches; each batch prints two mae values, reports "after down sample torch.Size([2, 256])" (2 latent vectors of dimension 256), and shows a 50-iteration progress bar that took roughly 2 to 10 minutes per batch in this run. The output ends with "end".
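The printed Prior architecture shows that the scalar condition (here the formation energy) first passes through a GaussianRBF expansion and then a dense network whose first Linear layer takes 15 inputs, which suggests 15 Gaussian basis functions. A minimal sketch of such an expansion follows. The function name, the value range (-5 to 5), and the width (equal to the center spacing) are assumptions for illustration, not Con-CDVAE's settings.

```python
import math


def gaussian_rbf(value, start=-5.0, stop=5.0, n_rbf=15):
    """Expand a scalar into n_rbf Gaussian features with evenly spaced centers.

    Range and width are illustrative assumptions; width equals the center spacing."""
    spacing = (stop - start) / (n_rbf - 1)
    centers = [start + i * spacing for i in range(n_rbf)]
    gamma = 0.5 / spacing ** 2
    return [math.exp(-gamma * (value - c) ** 2) for c in centers]


# e.g. a target formation energy; the 15 features would then feed an MLP
features = gaussian_rbf(-1.0)
```

Expanding a scalar this way gives the following MLP a smooth, localized encoding of the target value, which is usually easier to learn from than the raw number.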
To filter the sampled latent variables with the Predictor block, set --down_sample 100, which keeps one out of every hundred latent variables.
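Conceptually, this filtering keeps the latent vectors whose Predictor output lies closest to the target property. The sketch below illustrates that idea with plain lists. The function down_sample_latents, the absolute-error ranking, and the single scalar property are assumptions; the script's actual selection rule may differ.

```python
def down_sample_latents(latents, predicted, target, down_sample):
    """Keep len(latents) // down_sample latents whose prediction is closest to target.

    latents:   latent vectors sampled by the Prior
    predicted: one Predictor output (scalar property) per latent vector
    """
    if len(latents) != len(predicted):
        raise ValueError("need one prediction per latent vector")
    if down_sample < 1:
        raise ValueError("down_sample must be >= 1")
    keep = len(latents) // down_sample
    # Rank indices by absolute error between predicted and target property.
    order = sorted(range(len(latents)), key=lambda i: abs(predicted[i] - target))
    return [latents[i] for i in order[:keep]]
```

With down_sample=1 every latent vector is kept, so filtering only takes effect for values above 1.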
The generated crystals are stored in eval_gen_your_label_xxx.pt. Each eval_gen_your_label_xxx.pt contains num_batches_to_samples * batch_size / down_sample structures, so you can adjust these settings to control how many structures are generated.
For example, with --num_batches_to_samples 10 --batch_size 500 --down_sample 1 you will get 10 * 500 / 1 = 5000 structures in each eval_gen_your_label_xxx.pt.
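The count formula can be checked with a small helper. The function name is hypothetical; it only encodes num_batches_to_samples * batch_size / down_sample as stated above, using integer division on the assumption that the total is divisible by down_sample.

```python
def n_generated_structures(num_batches_to_samples, batch_size, down_sample=1):
    """Structures per eval_gen_*.pt file: num_batches_to_samples * batch_size / down_sample."""
    if down_sample < 1:
        raise ValueError("down_sample must be >= 1")
    return num_batches_to_samples * batch_size // down_sample


print(n_generated_structures(10, 500, 1))    # → 5000 (the example above)
print(n_generated_structures(10, 500, 100))  # → 50 (same budget, Predictor filtering)
```

In other words, raising down_sample to 100 without raising num_batches_to_samples or batch_size leaves 100 times fewer structures per file.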
Finally, use Con-CDVAE/scripts/pt2cif.py to convert the structures in xxx.pt into CIF files.
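pt2cif.py remains the way to convert Con-CDVAE output. To show what such a conversion involves, the sketch below writes a minimal P1 CIF from lattice lengths, lattice angles, element symbols, and fractional coordinates. The function write_p1_cif is a hypothetical stand-in, not the code of pt2cif.py, and it assumes these quantities have already been extracted for one crystal from the .pt file (the key names inside eval_gen_*.pt are not covered here).

```python
def write_p1_cif(lengths, angles, species, frac_coords, name="generated"):
    """Return a minimal CIF (space group P1) as a string.

    lengths: (a, b, c) in angstrom; angles: (alpha, beta, gamma) in degrees;
    species: element symbols; frac_coords: one (x, y, z) triple per atom.
    """
    if len(species) != len(frac_coords):
        raise ValueError("need one coordinate triple per atom")
    a, b, c = lengths
    alpha, beta, gamma = angles
    lines = [
        f"data_{name}",
        "_symmetry_space_group_name_H-M   'P 1'",
        "_symmetry_Int_Tables_number   1",
        f"_cell_length_a   {a:.6f}",
        f"_cell_length_b   {b:.6f}",
        f"_cell_length_c   {c:.6f}",
        f"_cell_angle_alpha   {alpha:.6f}",
        f"_cell_angle_beta   {beta:.6f}",
        f"_cell_angle_gamma   {gamma:.6f}",
        "loop_",
        " _symmetry_equiv_pos_as_xyz",
        "  'x, y, z'",
        "loop_",
        " _atom_site_label",
        " _atom_site_type_symbol",
        " _atom_site_fract_x",
        " _atom_site_fract_y",
        " _atom_site_fract_z",
        " _atom_site_occupancy",
    ]
    for i, (el, (x, y, z)) in enumerate(zip(species, frac_coords)):
        # wrap coordinates into the unit cell before writing them
        lines.append(f"  {el}{i} {el} {x % 1.0:.6f} {y % 1.0:.6f} {z % 1.0:.6f} 1")
    return "\n".join(lines) + "\n"
```

Because every atom is listed explicitly in P1, no symmetry information from the generated structure is needed; the resulting string can simply be written to a .cif file.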