A Gentle Introduction to Geometric Graph Neural Networks
📖 Getting Started Guide
This document can be executed directly on the Bohrium Notebook. To begin, click the Connect button located at the top of the interface, then select the bohrium-notebook:2023-05-31 Image and choose a GPU machine configuration to proceed.
Graph Neural Networks (GNNs) are part of an emerging research paradigm called Geometric Deep Learning -- devising neural network architectures that respect the invariances and symmetries in data. This practical aims to be a gentle introduction to the world of Geometric Deep Learning.
The aims of this practical are as follows:
- Understanding the concepts of invariance and equivariance to symmetries, which are fundamental properties of Graph Neural Networks. We will cover theory and proofs, as well as programming and unit testing.
- Becoming hands-on with PyTorch Geometric (PyG), a popular library for developing state-of-the-art GNNs and Geometric Deep Learning models. In particular, gaining familiarity with the MessagePassing base class for designing novel GNN layers and the Data object for representing graph datasets.
- Gaining an appreciation of the fundamental principles behind constructing GNN layers that take advantage of geometric information for graphs embedded in 3D space, such as biomolecules, materials, and other physical systems.
Authors and Acknowledgements
Here are the authors: do not hesitate to reach out to us for any queries and feedback on your solutions!
- Chaitanya K. Joshi (ckj24@cl.cam.ac.uk)
- Charlie Harris (cch57@cam.ac.uk)
- Ramon Viñas Torné (rv340@cam.ac.uk)
This notebook was initially developed for students taking the following courses:
- Representation Learning on Graphs and Networks, at University of Cambridge's Department of Computer Science and Technology (instructors: Prof. Pietro Liò, Dr. Petar Veličković).
- Geometric Deep Learning, at the African Master’s in Machine Intelligence (instructors: Prof. Michael Bronstein, Prof. Joan Bruna, Dr. Taco Cohen, Dr. Petar Veličković).
⚙️ Part 0: Installation and Setup
❗️Note: You will need a GPU to complete this practical. If using Colab, remember to click Runtime -> Change runtime type, and set the hardware accelerator to GPU.
PyTorch version 1.13.1+cu116 PyG version 2.0.3
Great! We are ready to dive into the practical!
🧪 Part 0: Introduction to Molecular Property Prediction with PyTorch Geometric
This section covers the fundamentals. We will study how Graph Neural Networks (GNNs) can be employed for predicting chemical properties of molecules, an impactful real-world application of Geometric Deep Learning. To achieve this, we will first introduce PyTorch Geometric, a widely-used Python library that facilitates the implementation of GNNs.
PyTorch Geometric
PyTorch Geometric (PyG) is an excellent library for graph representation learning research and development:
PyTorch Geometric (PyG) consists of various methods for deep learning on graphs and other irregular structures, also known as Geometric Deep Learning, from a variety of published papers. In addition, it provides easy-to-use mini-batch loaders for operating on many small and single giant graphs, multi GPU-support, distributed graph learning, a large number of common benchmark datasets, and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds.
In this practical, we will make extensive use of PyG. If you have never worked with PyG before, do not worry, we will provide you with some examples and guide you through all the fundamentals in a detailed manner. We also highly recommend this self-contained official tutorial, which will help you get started. Among other things, you will learn how to implement state-of-the-art GNN layers via the generic PyG MessagePassing class (more on this later).
Now, let's turn our attention to the problem of predicting molecular properties.
Molecular Property Prediction
Molecules are a great example of an object from nature that can easily be represented as a graph of atoms (nodes) connected by bonds (edges). A popular application of GNNs in chemistry is the task of Molecular Property Prediction. The goal is to train a GNN model from historical experimental data that can predict useful properties of drug-like molecules. The model's predictions can then be used to guide the drug design process.
One famous example of GNNs being used for molecular property prediction is in the world of antibiotic discovery, an area with a potentially massive impact on humanity and infamously little innovation. A GNN trained to predict how strongly a molecule would inhibit bacterial growth was able to identify the previously overlooked compound Halicin (below) during virtual screening. Not only did halicin show powerful results during in vitro (in cell) testing, but it also had a completely novel mechanism of action that no bacteria has developed resistance to (yet).
The QM9 Dataset
QM9 (Quantum Mechanics dataset 9) is a dataset consisting of about 130,000 small molecules with 19 regression targets. Since being used by MoleculeNet, it has become a popular dataset to benchmark new architectures for molecular property prediction.
Specifically, we will be predicting the electric dipole moment of drug-like molecules. According to Wikipedia:
"The electric dipole moment is a measure of the separation of positive and negative electrical charges within a system, that is, a measure of the system's overall polarity."
We can visualize this concept via the water molecule H2O, which forms a dipole due to its slightly different distribution of negative (blue) and positive (red) charge.
You do not need to worry about the exact physical and chemical principles that underpin dipole moments. As you might imagine, writing the equations from first principles to predict a property like this, especially for complex molecules (e.g. proteins), is very difficult. All you need to know (for the sake of this practical, anyway) is that these molecules can be represented as graphs with node and edge features, as well as spatial information, which we can use to train a GNN model on the ground truth labels.
Now let us load the QM9 dataset and explore how molecular graphs are represented. PyG makes this extremely convenient.
(The dataset may take a few minutes to download.)
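Below is a minimal, hedged sketch of what loading QM9 with PyG looks like (the root path and variable names are illustrative; the notebook's own loading cell may apply additional transforms):

```python
import torch
from torch_geometric.datasets import QM9

# Download and preprocess QM9 (~130k small molecular graphs) into a local folder.
dataset = QM9(root="./qm9_data")

# In the raw dataset, target index 0 corresponds to the electric dipole moment (mu).
print(f"Total number of samples: {len(dataset)}")
print(dataset[0])
```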
Data Preparation and Splitting
The QM9 dataset has over 130,000 molecular graphs!
Let us create a more tractable subset of 3,000 molecular graphs for the purposes of this practical and separate it into training, validation, and test sets. We shall use 1,000 graphs each for training, validation, and testing.
Towards the end of this practical, you will get to experiment with the full dataset or larger subsets of QM9, too.
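A minimal sketch of the split described above (the exact shuffling, seed, and any target normalisation used in the notebook may differ):

```python
# Shuffle once for a reproducible split, then slice off three disjoint subsets.
torch.manual_seed(42)
dataset = dataset.shuffle()

train_dataset = dataset[:1000]
val_dataset   = dataset[1000:2000]
test_dataset  = dataset[2000:3000]

print(f"Total number of samples: {len(dataset)}.")
print(f"Created dataset splits with {len(train_dataset)} training, "
      f"{len(val_dataset)} validation, {len(test_dataset)} test samples.")
```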
Total number of samples: 130831. Created dataset splits with 1000 training, 1000 validation, 1000 test samples.
Visualising Molecular Graphs
To get a better understanding of what the QM9 molecular graphs look like, let's visualise a few samples from the training set along with their corresponding target (their dipole moment).
In the following plot we visualise sparse graphs where edges represent physical connections (i.e. bonds). In this practical, however, we will use fully-connected graphs and encode the bond structure in the attributes of each edge. Later in this practical, we will study the advantages and downsides of both approaches.
❗️Note: we have implemented some code for you to convert the PyG graph into a Molecule object that can be used by RDKit, a Python package for chemistry and visualising molecules. It is not important for you to understand RDKit beyond visualisation purposes.
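The notebook ships its own conversion helper; the following is only a hedged sketch of one way such a PyG-to-RDKit conversion could look (the function name `to_rdkit_mol` and the details of sanitisation are assumptions, not the notebook's exact code):

```python
from rdkit import Chem
from rdkit.Chem.Draw import MolsToGridImage

ATOM_SYMBOLS = {1: "H", 6: "C", 7: "N", 8: "O", 9: "F"}   # the QM9 element set
BOND_TYPES = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE,
              Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]

def to_rdkit_mol(data):
    """Illustrative helper: build an RDKit molecule from a QM9 PyG graph."""
    mol = Chem.RWMol()
    for z in data.z.tolist():                       # add atoms by atomic number
        mol.AddAtom(Chem.Atom(ATOM_SYMBOLS[int(z)]))
    seen = set()
    for (i, j), attr in zip(data.edge_index.t().tolist(), data.edge_attr):
        if attr.sum() == 0 or (j, i) in seen:       # skip non-bonded / duplicate edges
            continue
        mol.AddBond(i, j, BOND_TYPES[int(attr.argmax())])
        seen.add((i, j))
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        pass                                        # keep the raw molecule if sanitisation fails
    return mol.GetMol()

mols = [to_rdkit_mol(train_dataset[k]) for k in range(4)]
legends = [f"Dipole moment: {train_dataset[k].y.view(-1)[0].item():.2f}" for k in range(4)]
MolsToGridImage(mols, legends=legends, molsPerRow=4)
```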
Understanding PyG Data Objects
Each graph in our dataset is encapsulated in a PyG Data object, a convenient way of representing all structured data for use in Geometric Deep Learning (including graphs, point clouds, and meshes).
Let us print all the attributes (along with their shapes) that our PyG molecular graph contains: Data(x=[5, 11], edge_index=[2, 20], edge_attr=[20, 4], y=[1], pos=[5, 3], z=[5], name='gdb_1', idx=[1])
Within an instance of a Data object, individual torch.Tensor attributes (or any other variable type) can be easily dot-accessed within a neural network layer. The graphs from PyG come with a number of pre-computed features, which we describe below (do not worry if you are unfamiliar with the chemistry terms here):
Atom features (data.x):
- 1st-5th features: Atom type (one-hot: H, C, N, O, F)
- 6th feature (also data.z): Atomic number (number of protons)
- 7th feature: Aromatic (binary)
- 8th-10th features: Electron orbital hybridization (one-hot: sp, sp2, sp3)
- 11th feature: Number of hydrogens
Edge index (data.edge_index):
- A tensor of dimensions 2 x num_edges that describes the edge connectivity of the graph
Edge features (data.edge_attr):
- 1st-4th features: Bond type (one-hot: single, double, triple, aromatic)
Atom positions (data.pos):
- 3D coordinates of each atom. (We will talk about their importance later in the practical.)
Target (data.y):
- A scalar value corresponding to the molecule's electric dipole moment
❗️Note: We will be using fully-connected graphs (i.e. all atoms in a molecule are connected to each other, except self-loops). The information about the molecular structure will be available to the models through the edge features (data.edge_attr) as follows:
- When two atoms are physically connected, the edge attributes indicate the bond type (single, double, triple, or aromatic) through a one-hot vector.
- When two atoms are not physically connected, all edge attributes are zero.
We will later study the advantages/downsides of fully-connected adjacency matrices versus sparse adjacency matrices (where an edge between two atoms is present only when there exists a physical connection between them).
This molecule has 5 atoms, and 20 edges. For each atom, we are given a feature vector with 11 entries (described above). For each edge, we are given a feature vector with 4 entries (also described above). In the next section, we will learn how to build a GNN in the Message Passing flavor to process the node and edge features of molecular graphs and predict their properties. Each atom also has a 3-dimensional coordinate associated with it. We will talk about their importance later in the practical. Finally, we have 1 regression target for the entire molecule.
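As a quick illustration of how these attributes are dot-accessed (assuming `train_dataset` from the split above):

```python
data = train_dataset[0]

print(f"Atom features:   {data.x.shape}")           # [num_atoms, 11]
print(f"Edge index:      {data.edge_index.shape}")  # [2, num_edges]
print(f"Edge features:   {data.edge_attr.shape}")   # [num_edges, 4]
print(f"Atom positions:  {data.pos.shape}")         # [num_atoms, 3]
print(f"Target:          {data.y.shape}")           # graph-level dipole moment
print(f"Number of atoms: {data.num_nodes}, number of edges: {data.num_edges}")
```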
Using PyG for batching
As you might remember from the previous practical, batching graphs can be quite a tedious and fiddly process. Thankfully, using PyG makes this super simple! Given a list of Data objects, we can easily batch them into a PyG Batch object, as well as unbatch back into a list of graphs. Furthermore, in simple cases like ours, the PyG DataLoader object (different from the vanilla PyTorch one) handles all of the batching under the hood for us!
Let's quickly batch and unbatch some graphs anyway as a demonstration:
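A minimal sketch of this demonstration (variable names are illustrative, and `train_dataset` is assumed from the split above):

```python
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

data_list = [train_dataset[i] for i in range(3)]

# Batch a list of Data objects into one big block-diagonal graph...
batch = Batch.from_data_list(data_list)
print(batch)
print(batch.batch)            # maps every node to the graph it belongs to

# ...and unbatch back into a list of Data objects.
print(batch.to_data_list())

# In practice, the PyG DataLoader does all of this for us under the hood.
loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
print(next(iter(loader)))
```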
Awesome! We have downloaded and prepared the QM9 dataset, visualised some samples, understood the attributes associated with each molecular graph, and reviewed how batching works in PyG. Now, we are ready to understand how we can develop GNNs in PyG for molecular property prediction.
📩 Part 0: Introduction to Message Passing Neural Networks in PyTorch Geometric
As a gentle introduction to PyTorch Geometric, we will walk you through the first steps of developing a GNN in the Message Passing flavor.
Formalism
Firstly, let us formalise our molecular property prediction pipeline. (Our notation will mostly follow what has been introduced in the lectures, but we do make some different choices for variable names.)
Graph
Consider a molecular graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$, where $\mathcal{V}$ is a set of nodes and $\mathcal{E}$ is a set of edges associated with the nodes. For each node $i \in \mathcal{V}$, we are given a $d$-dimensional initial feature vector $x_i$. For each edge $(i, j) \in \mathcal{E}$, we are given a $d_e$-dimensional initial feature vector $e_{ij}$. For QM9 graphs, $d = 11$ and $d_e = 4$.
Label/target
Associated with each graph $\mathcal{G}$ is a scalar target or label $y$, which we would like to predict.
We will design a Message Passing Neural Network for graph property prediction to do this. Our MPNN will consist of several layers of message passing, followed by a global pooling and prediction head.
MPNN Layer
The Message Passing operation iteratively updates node features $h_i^{(\ell)}$ from layer $\ell$ to layer $\ell + 1$ via the following equation:
$$h_i^{(\ell+1)} = \phi^{(\ell)} \Big( h_i^{(\ell)}, \ \bigoplus_{j \in \mathcal{N}(i)} \psi^{(\ell)} \big( h_i^{(\ell)}, h_j^{(\ell)}, e_{ij} \big) \Big),$$
where $\psi^{(\ell)}, \phi^{(\ell)}$ are Multi-Layer Perceptrons (MLPs), and $\bigoplus$ is a permutation-invariant local neighborhood aggregation function such as summation, maximization, or averaging.
Let us break down the MPNN layer into three pedagogical steps:
- Step (1): Message. For each pair of linked nodes $(i, j) \in \mathcal{E}$, the network first computes a message $m_{ij}^{(\ell)} = \psi^{(\ell)} \big( h_i^{(\ell)} \,\|\, h_j^{(\ell)} \,\|\, e_{ij} \big)$. The MLP $\psi^{(\ell)}$ takes as input the concatenation of the feature vectors from the source node, destination node, and edge.
  - Note that for the first layer $\ell = 0$, $h_i^{(0)} = \mathrm{Lin}(x_i)$, where $\mathrm{Lin}$ is a simple linear projection (torch.nn.Linear) of the initial node features to the hidden dimension $d$.
- Step (2): Aggregate. At each node $i$, the incoming messages from all its neighbors are then aggregated as $m_i^{(\ell)} = \bigoplus_{j \in \mathcal{N}(i)} m_{ij}^{(\ell)}$, where $\bigoplus$ is a permutation-invariant function. We will use summation, i.e. $m_i^{(\ell)} = \sum_{j \in \mathcal{N}(i)} m_{ij}^{(\ell)}$.
- Step (3): Update. Finally, the network updates the node feature vector $h_i^{(\ell+1)} = \phi^{(\ell)} \big( h_i^{(\ell)} \,\|\, m_i^{(\ell)} \big)$ by concatenating the aggregated message $m_i^{(\ell)}$ and the previous node feature vector $h_i^{(\ell)}$, and passing them through an MLP $\phi^{(\ell)}$.
Global Pooling and Prediction Head
After $L$ layers of message passing, we obtain the final node features $h_i^{(L)}$. As we have a single target per graph, we must pool all node features into a single graph feature or graph embedding $h_{\mathcal{G}}$ via another permutation-invariant function $R$, sometimes called the 'readout' function, as follows:
$$h_{\mathcal{G}} = R \big( \{ h_i^{(L)} \mid i \in \mathcal{V} \} \big).$$
We will use global average pooling over all node features, i.e. $h_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{i \in \mathcal{V}} h_i^{(L)}$.
The graph embedding $h_{\mathcal{G}}$ is passed through a linear prediction head to obtain the overall prediction $\hat{y}$:
$$\hat{y} = \mathrm{Lin}(h_{\mathcal{G}}).$$
Loss Function
Our MPNN graph property prediction model can be trained end-to-end by minimizing the standard mean-squared error loss for regression:
$$\mathcal{L} = \frac{1}{N} \sum_{k=1}^{N} \big( y_k - \hat{y}_k \big)^2,$$
averaged over the $N$ training graphs.
Coding the basic Message Passing Neural Network Layer
We are now ready to define a basic MPNN layer which implements what we have described above. In particular, we will code up the MPNN Layer first. (We will code up the other parts subsequently.)
To do so, we will inherit from the MessagePassing base class, which automatically takes care of message propagation and is extremely useful for developing advanced GNN models. To implement a custom MPNN, the user only needs to define the behaviour of the message function (i.e. $\psi$), the aggregate function (i.e. $\bigoplus$), and the update function (i.e. $\phi$). You may also refer to the PyG documentation for implementing custom message passing layers.
Below, we provide the implementation of a standard MPNN layer as an example, with extensive inline comments to help you figure out what is going on.
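The original implementation cell is not reproduced here; the following is a hedged sketch that follows the three steps above and matches the layer summary printed in the training log later on (the exact MLP widths and normalisation layers are assumptions):

```python
import torch
from torch.nn import Linear, ReLU, BatchNorm1d, Sequential
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

class MPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        """Vanilla MPNN layer: psi and phi are small MLPs, aggregation is a sum."""
        super().__init__(aggr=aggr)
        self.emb_dim = emb_dim
        # psi: maps the concatenation (h_i || h_j || e_ij) to a message
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + edge_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
        )
        # phi: maps the concatenation (h_i || aggregated message) to the updated h_i
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
        )

    def forward(self, h, edge_index, edge_attr):
        # propagate() internally calls message() -> aggregate() -> update()
        return self.propagate(edge_index, h=h, edge_attr=edge_attr)

    def message(self, h_i, h_j, edge_attr):
        # Step (1): compute a message for every edge (i, j)
        return self.mlp_msg(torch.cat([h_i, h_j, edge_attr], dim=-1))

    def aggregate(self, inputs, index):
        # Step (2): sum the incoming messages at every node
        return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)

    def update(self, aggr_out, h):
        # Step (3): update node features from (previous features || aggregated message)
        return self.mlp_upd(torch.cat([h, aggr_out], dim=-1))

    def __repr__(self):
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
```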
Great! We have defined a Message Passing layer following the equation we had introduced previously. Let us use this layer to code up the full MPNN graph property prediction model. This model will take as input molecular graphs, process them via multiple MPNN layers, and predict a single property for each of them.
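A hedged sketch of the full model, consistent with the architecture summary printed during training (linear input projection, a stack of MPNN layers, mean pooling, linear prediction head); the residual connection inside the loop is an assumption beyond that summary:

```python
from torch_geometric.nn import global_mean_pool

class MPNNModel(torch.nn.Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """MPNN model: input projection, message passing, mean pooling, prediction head."""
        super().__init__()
        self.lin_in = Linear(in_dim, emb_dim)          # project 11-d atom features to emb_dim
        self.convs = torch.nn.ModuleList(
            [MPNNLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)]
        )
        self.lin_pred = Linear(emb_dim, out_dim)       # graph embedding -> scalar prediction

    def forward(self, data):
        h = self.lin_in(data.x)                        # (n, 11) -> (n, d)
        for conv in self.convs:
            h = h + conv(h, data.edge_index, data.edge_attr)   # residual connection (optional)
        h_graph = global_mean_pool(h, data.batch)      # (n, d) -> (batch_size, d)
        return self.lin_pred(h_graph).view(-1)         # (batch_size, d) -> (batch_size,)
```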
Awesome! We are done defining our first MPNN model for graph property prediction.
But wait! Before we dive into training and evaluating this model, let us write some sanity checks for a fundamental property of the model and the layer.
Unit tests for Permutation Invariance and Equivariance
The lectures have repeatedly emphasised certain fundamental properties for machine learning on graphs:
- A GNN layer is equivariant to permutations of the set of nodes in the graph; i.e. as we permute the nodes, the node features produced by the GNN must permute accordingly.
- A GNN model for graph-level property prediction is invariant to permutations of the set of nodes in the graph; i.e. as we permute the nodes, the graph-level property remains unchanged.
(But wait... what is a permutation? Essentially, it is an ordering of the nodes in a graph. In general, there is no canonical way of assigning an ordering to the nodes, unlike for textual or image data. However, graphs need to be stored and processed on computers in order to perform machine learning on them (which is what this course is about!). Thus, we need to ensure that our models can handle this lack of a canonical ordering of graph nodes in a principled way. This is what the above statements are trying to say.)
Formalism
Let us try to formalise these notions of permutation invariance and equivariance via matrix notation (it is easier that way).
- Let $\mathbf{X} \in \mathbb{R}^{n \times d}$ be a matrix of node features for a given molecular graph, where $n$ is the number of nodes/atoms and each row $x_i$ is the $d$-dimensional feature for node $i$.
- Let $\mathbf{A} \in \{0, 1\}^{n \times n}$ be the adjacency matrix, where each entry $a_{ij}$ denotes the presence or absence of an edge between nodes $i$ and $j$.
- Let $F(\mathbf{X}, \mathbf{A})$ be a GNN layer that takes as input the node features and adjacency matrix, and returns the updated node features.
- Let $f(\mathbf{X}, \mathbf{A})$ be a GNN model that takes as input the node features and adjacency matrix, and returns the predicted graph-level property.
- Let $\mathbf{P} \in \{0, 1\}^{n \times n}$ be a permutation matrix, which has exactly one 1 in every row and column, and 0s elsewhere. Left-multiplying a matrix by $\mathbf{P}$ changes the ordering of its rows.
Permutation Equivariance
The GNN layer $F$ is permutation equivariant as follows:
$$F(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = \mathbf{P}\, F(\mathbf{X}, \mathbf{A}).$$
Another way to formulate the above could be: (1) Consider the updated node features $\mathbf{H} = F(\mathbf{X}, \mathbf{A})$. (2) Applying any permutation matrix $\mathbf{P}$ to the inputs of the GNN layer should produce the same result as applying the same permutation to $\mathbf{H}$: $F(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = \mathbf{P}\mathbf{H}$.
Permutation Invariance
The GNN model $f$ for graph-level prediction is permutation invariant as follows:
$$f(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = f(\mathbf{X}, \mathbf{A}).$$
Another way to formulate the above could be: (1) Consider the predicted molecular property $\hat{y} = f(\mathbf{X}, \mathbf{A})$. (2) Applying any permutation matrix $\mathbf{P}$ to the inputs of the GNN model should produce the same result as not applying it: $f(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = \hat{y}$.
With that formalism out of the way, let us write some unit tests to confirm that our MPNNModel and MPNNLayer are indeed permutation invariant and equivariant, respectively.
Now that we have defined the unit tests for permutation invariance (for the full MPNN model) and permutation equivariance (for the MPNN layer), let us perform the sanity check:
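The notebook provides these tests in a code cell; the version below is a hedged re-sketch of what they could look like (helper names, the QM9-style attributes `x`, `pos`, `z`, and the `emb_dim` attribute on the layer are assumptions):

```python
def permute_graph(data, perm):
    """Permute the node ordering of a graph according to `perm`."""
    data = data.clone()
    data.x, data.pos, data.z = data.x[perm], data.pos[perm], data.z[perm]
    if getattr(data, "batch", None) is not None:
        data.batch = data.batch[perm]
    # Re-index the edges: old node perm[k] becomes new node k.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0))
    data.edge_index = inv[data.edge_index]
    return data

def permutation_invariance_unit_test(module, dataloader):
    """The model's graph-level output should not change when nodes are permuted."""
    data = next(iter(dataloader))
    out_1 = module(data)
    perm = torch.randperm(data.x.size(0))
    out_2 = module(permute_graph(data, perm))
    return torch.allclose(out_1, out_2, atol=1e-4)

def permutation_equivariance_unit_test(module, dataloader):
    """The layer's node-level output should permute exactly like its input."""
    data = next(iter(dataloader))
    h = torch.randn(data.x.size(0), module.emb_dim)   # random node embeddings
    out_1 = module(h, data.edge_index, data.edge_attr)
    perm = torch.randperm(h.size(0))
    data_perm = permute_graph(data, perm)
    out_2 = module(h[perm], data_perm.edge_index, data_perm.edge_attr)
    return torch.allclose(out_1[perm], out_2, atol=1e-4)
```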
Is MPNNModel permutation invariant? --> True! Is MPNNLayer permutation equivariant? --> True!
Training and Evaluating Models
Great! We are finally ready to train and evaluate our model on QM9. We have provided a basic experiment loop which takes as input the model and dataloaders, performs training, and returns the final performance on the validation and test set.
We will be training an MPNNModel consisting of 4 layers of message passing with a hidden dimension of 64.
Running experiment for MPNNModel, training on 1000 samples for 100 epochs.
Model architecture:
MPNNModel(
  (lin_in): Linear(in_features=11, out_features=64, bias=True)
  (convs): ModuleList(
    (0): MPNNLayer(emb_dim=64, aggr=add)
    (1): MPNNLayer(emb_dim=64, aggr=add)
    (2): MPNNLayer(emb_dim=64, aggr=add)
    (3): MPNNLayer(emb_dim=64, aggr=add)
  )
  (lin_pred): Linear(in_features=64, out_features=1, bias=True)
)
Total parameters: 103233
Start training:
Epoch: 010, LR: 0.000900, Loss: 0.4851226, Val MAE: 1.0571234, Test MAE: 0.8492622
Epoch: 020, LR: 0.000810, Loss: 0.4219062, Val MAE: 0.9702439, Test MAE: 0.6901035
Epoch: 030, LR: 0.000729, Loss: 0.3758734, Val MAE: 8.7811255, Test MAE: 0.6909492
Epoch: 040, LR: 0.000590, Loss: 0.3854174, Val MAE: 2.7061053, Test MAE: 0.6557090
Epoch: 050, LR: 0.000531, Loss: 0.2343547, Val MAE: 0.8977432, Test MAE: 0.6557090
Epoch: 060, LR: 0.000430, Loss: 0.2004967, Val MAE: 0.9157710, Test MAE: 0.6557090
Epoch: 070, LR: 0.000387, Loss: 0.1646426, Val MAE: 0.8549278, Test MAE: 0.6379384
Epoch: 080, LR: 0.000314, Loss: 0.1231778, Val MAE: 0.8455599, Test MAE: 0.6379384
Epoch: 090, LR: 0.000282, Loss: 0.1172216, Val MAE: 0.9022236, Test MAE: 0.6573651
Epoch: 100, LR: 0.000282, Loss: 0.0914796, Val MAE: 0.7459125, Test MAE: 0.5835653
Done! Training took 2.08 mins. Best validation MAE: 0.7459125, corresponding test MAE: 0.5835653.
{'MPNNModel': (0.7459124717712402, 0.5835652933120727, 2.082554332415263)}
Super! Everything up to this point has already been covered in the lectures, and we hope that the practical so far has been a useful recap, along with the accompanying code.
Now for the fun part, where you will be required to think about what you have studied so far!
🧊 Part 1: Geometric Graphs and Message Passing with 3D Coordinates
Remember that we were given 3D coordinates with each atom in our molecular graph?
Molecular graphs, and other structured data occurring in nature, do not simply exist on flat planes. Instead, molecules have an inherent 3D structure that influences their properties and functions.
Let us visualize a molecule from QM9 in all of its 3D glory!
Go ahead and try moving this molecule with your mouse cursor!
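The interactive viewer cell is not reproduced here. A hedged sketch of one way to get such a view, assuming the `py3Dmol` package (the notebook may use a different visualisation library):

```python
import py3Dmol

ATOM_SYMBOLS = {1: "H", 6: "C", 7: "N", 8: "O", 9: "F"}

def to_xyz(data):
    """Serialise a QM9 graph into an XYZ-format string."""
    lines = [str(data.num_nodes), ""]
    for z, (x, y, zc) in zip(data.z.tolist(), data.pos.tolist()):
        lines.append(f"{ATOM_SYMBOLS[int(z)]} {x:.4f} {y:.4f} {zc:.4f}")
    return "\n".join(lines)

view = py3Dmol.view(width=400, height=400)
view.addModel(to_xyz(train_dataset[0]), "xyz")
view.setStyle({"stick": {}, "sphere": {"scale": 0.3}})
view.zoomTo()
view.show()
```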
💻Task 1.1: Develop a Message Passing Neural Network that incorporates the atom coordinates as node features (0.5 Marks).
Our initial and somewhat 'vanilla' MPNNModel ignores the atom coordinates and only uses the node features to perform message passing. This means that the model is not leveraging useful 3D structural information to predict the target property.
Your first task is to modify the original MPNNModel to incorporate atom coordinates into the node features.
We have defined most of the new CoordMPNNModel class for you, and you have to fill in the YOUR CODE HERE sections.
🤔 Hint: As a reminder, the 3D atom positions are stored in data.pos. You don't have to do something very clever right now (that will come later). A simple solution, e.g. concatenation or summation, is okay to get started.
💻Task 1.2: Test the permutation invariance and equivariance properties of your new CoordMPNNModel with node features and coordinates, as well as the constituent MPNNLayer. (0.5 Marks)
Super! You have successfully implemented an MPNN which utilises both the atom features as well as coordinates to predict molecular properties.
Before we evaluate it, let us once again run the permutation sanity checks again to make sure the model and layer have the desired properties that constitute every basic GNN:
- The MPNNLayer should be permutation equivariant (we have already shown this previously, but we want you to repeat the exercise in order to thoroughly understand it).
- The CoordMPNNModel should be permutation invariant.
Your task is to fill in the YOUR CODE HERE sections to run the required unit tests. You do not need to write new unit tests yet; the ones we defined previously can be re-used.
💻Task 1.3. Prove that your new CoordMPNNModel is invariant to permutations of both the node features as well as the node coordinates. (0.5 Marks)
🤔 Hint: We are looking for simple statements that follow how we formalised permutation invariance for the vanilla MPNN model. We expect you to be copy-pasting most of the formalism and accounting for how your MPNN incorporates both the node features and coordinates. You can additionally introduce $\bar{\mathbf{X}} \in \mathbb{R}^{n \times 3}$ as the matrix of node coordinates for a given molecular graph.
❗️YOUR ANSWER HERE
💻Task 1.4. Train and evaluate your CoordMPNNModel with node features and coordinates on QM9. (0.5 Marks)
Awesome! You are now ready to train and evaluate our new MPNN with node features and coordinates on QM9.
Re-use the experiment loop we have provided and fill in the YOUR CODE HERE sections to run the experiment.
You will be training a CoordMPNNModel consisting of 4 layers of message passing with a hidden dimension of 64, in order to compare your result fairly to the previous vanilla MPNNModel.
Hmm... If you've implemented the CoordMPNNModel correctly up till now, you may see a very curious result -- the performance of CoordMPNNModel is about equal to or marginally worse than the vanilla MPNNModel!
This is because the CoordMPNNModel is not using 3D structural information in a principled manner.
The next sections will help us formalise and understand why this is happening.
🔄 Part 2: Invariance to 3D Symmetries: Rotation and Translation
We saw that the performance of CoordMPNNModel is unexpectedly mediocre compared to MPNNModel, despite using both node features and coordinates. (But please do not panic if your results say otherwise.) In order to determine why, we must understand the concept of 3D symmetries.
Geometric Invariance
Recall that molecular graphs have 3D coordinates for each atom. A key detail which we have purposely withheld from you up till this point (😈) is that these 3D coordinates are not inherently fixed or permanent. Instead, they were experimentally determined relative to a frame of reference.
To fully grasp these statements, here is a GIF of a drug-like molecule moving around in 3D space...
The atoms' 3D coordinates are constantly rotating and translating. However, the properties of this molecule will always remain the same no matter how we rotate or translate it. In other words, the molecule's properties are invariant to 3D rotations and translations.
In this block we will study how to design GNN layers and models that respect these regularities.
Formalism
Let us try to formalise the notion of invariance to 3D rotations and translations in GNNs via matrix notation.
- Let $\mathbf{X} \in \mathbb{R}^{n \times d}$ be a matrix of node features for a given molecular graph, where $n$ is the number of nodes/atoms and each row $x_i$ is the $d$-dimensional feature for node $i$.
- Let $\bar{\mathbf{X}} \in \mathbb{R}^{n \times 3}$ be a matrix of node coordinates for a given molecular graph, where $n$ is the number of nodes/atoms and each row $\bar{x}_i$ is the 3D coordinate for node $i$.
- Let $\mathbf{A} \in \{0, 1\}^{n \times n}$ be the adjacency matrix, where each entry $a_{ij}$ denotes the presence or absence of an edge between nodes $i$ and $j$.
- Let $F(\mathbf{X}, \bar{\mathbf{X}}, \mathbf{A})$ be a GNN layer that takes as input the node features, node coordinates, and adjacency matrix, and returns the updated node features.
- Let $f(\mathbf{X}, \bar{\mathbf{X}}, \mathbf{A})$ be a GNN model that takes as input the node features, node coordinates, and adjacency matrix, and returns the predicted graph-level property.
(Notice that we have updated the notation for the GNN layer and GNN model to include the matrix of node coordinates as an additional input.)
💻Task 2.1: What does it mean for the GNN model and the GNN layer to be invariant to 3D rotations and translations? Express this mathematically using the definitions above. (0.5 Mark)
🤔 Hint: Revisit the formalisms for permutation invariance and equivariance to get an idea of how to go about this. You should use the matrix notation we have provided above. Similar to the permutation matrix $\mathbf{P}$, you may now define an orthogonal rotation matrix $\mathbf{Q} \in \mathbb{R}^{3 \times 3}$ and a translation vector $\mathbf{t} \in \mathbb{R}^{3}$ in your answer. These would operate on the matrix of node coordinates $\bar{\mathbf{X}}$.
❗️YOUR ANSWER HERE
Before you start coding up a more principled MPNN model, we would like you to take a moment to think about why invariance to 3D rotations and translations is something desirable for GNNs predicting molecular properties...
💻Task 2.2: Is invariance to 3D rotations and translations a desirable property for GNNs? Explain why. (0.5 Marks)
🤔 Hint: We are not looking for an essay, a few sentences will suffice here.
❗️YOUR ANSWER HERE
💻Task 2.3: Write the unit test to check your CoordMPNNModel for 3D rotation and translation invariance. (0.5 Mark)
🤔 Hint: Show that the output of the model varies when:
- All the atom coordinates in data.pos are multiplied by any random orthogonal rotation matrix $\mathbf{Q}$. (We have provided a helper function for creating rotation matrices.)
- All the atom coordinates in data.pos are displaced by any random translation vector $\mathbf{t}$.
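The notebook ships its own rotation-matrix helper; the sketch below shows one hedged way such a helper could be implemented (the function name is illustrative):

```python
import torch

def random_orthogonal_matrix(dim=3):
    """Sample a random rotation matrix Q with Q^T Q = I and det(Q) = +1."""
    Q, R = torch.linalg.qr(torch.randn(dim, dim))
    Q = Q * torch.sign(torch.diagonal(R))   # fix signs so the distribution is not biased
    if torch.det(Q) < 0:                    # flip one axis if we sampled a reflection
        Q[:, 0] = -Q[:, 0]
    return Q

Q = random_orthogonal_matrix()
t = torch.randn(3)                          # a random translation vector
print(Q @ Q.T)                              # ≈ identity
print(torch.det(Q))                         # ≈ +1
```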
Now that you have defined the unit tests for rotation and translation invariance, perform the sanity check on your CoordMPNNModel:
(Spoiler alert: if you have implemented things as expected, the unit test should return False for the CoordMPNNModel.)
In this part, you have formalised how a GNN can be 3D rotation and translation invariant, thought about why this is desirable for molecular property prediction, and shown that the CoordMPNNModel was not rotation and translation invariant.
At this point, you should have a concrete understanding of why the performance of CoordMPNNModel is equal to or worse than that of the vanilla MPNNModel, and what we meant by our initial statement before we began this part:
"The CoordMPNNModel is not using 3D structural information in a principled manner."
Let us try fixing this in the next part!
✈️ Part 3: Message Passing with Invariance to 3D Rotations and Translations
This section will dive into how we may design GNN models which operate on graphs with 3D coordinates in a more theoretically sound way.
💻Task 3.1: Design a new Message Passing Layer as well as the accompanying MPNN Model that are both invariant to 3D rotations and translations. (2 Marks)
❗️ Note: There is no single correct answer to this question.
Our initial 'vanilla' MPNNModel and MPNNLayer ignored the atom coordinates and only used the node features to perform message passing. This means that the model was not leveraging 3D structural information to predict the target property.
Our second 'naive' coordinate MPNN, CoordMPNNModel, used the node features along with the atom coordinates in an unprincipled manner, resulting in the model not being invariant to 3D rotations and translations of the coordinates (which is a desirable property, as we saw in the previous part).
Your task is to define a new InvariantMPNNLayer that utilises both atom coordinates and node features.
We have defined most of the new InvariantMPNNLayer, and you have to fill in the YOUR CODE HERE sections. We have also already defined the InvariantMPNNModel that instantiates your new layer to compose the model. You only need to define the new layer.
🤔 Hint 1: Unlike the previous CoordMPNNModel, we would suggest using the coordinate information to construct the messages, as opposed to incorporating it into the node features. In particular, we would like you to think about how to use the coordinates in a principled manner to construct the messages: what is a measurement that we can compute from a pair of coordinates that will be invariant to rotating and translating them?
🤔 Hint 2: Tensors passed to propagate() can be mapped to the respective nodes $i$ and $j$ by appending _i or _j to the variable name, e.g. h_i and h_j for the node features h. Note that we generally refer to _i as the central nodes that aggregate information, and refer to _j as the neighboring nodes.
Super! You have now defined a more geometrically principled message passing layer and used it to construct an MPNN model which is invariant to 3D rotations and translations.
💻Task 3.2: Write down the update equation of your new InvariantMPNNLayer and use that to prove that the layer and model are invariant to 3D rotations and translations. (1 Mark)
❗️YOUR ANSWER HERE
Great! You have successfully written the update equation for your new InvariantMPNNLayer and shown how it is indeed invariant to 3D rotations and translations.
Let us just perform some sanity checks to verify this.
💻Task 3.3: Perform unit tests for your InvariantMPNNLayer and InvariantMPNNModel. Show that the layer and model are both invariant to 3D rotations and translations. (0.5 Mark)
🤔 Hint: Run the unit tests defined previously.
Good job! You have defined the InvariantMPNNLayer and InvariantMPNNModel, after which you have proved and experimentally verified their invariance to 3D rotations and translations.
It is finally time to run an experiment with our geometrically principled model!
💻Task 3.4: Train and evaluate your InvariantMPNNModel. Additionally, provide a few sentences explaining the model's results compared to the basic MPNNModel and the naive CoordMPNNModel defined previously. Is the new model better? By a significant margin, or only marginally better? (0.5 Mark)
Re-use the experiment loop we have provided and fill in the YOUR CODE HERE sections to run the experiment.
You will be training an InvariantMPNNModel consisting of 4 layers of message passing with a hidden dimension of 64, in order to compare your result fairly to the previous vanilla MPNNModel and naive CoordMPNNModel.
❗️YOUR ANSWER HERE
Awesome! You have now gone from a vanilla MPNNModel, to a naive use of coordinate information in CoordMPNNModel, to a more geometrically principled approach in the InvariantMPNNModel.
In the next part, we will try to further push the limits of how much information we can derive from the geometry of molecules!
🚀 Part 4: Message Passing with Equivariance to 3D Rotations and Translations
In the previous part of the practical, we studied the concepts of 3D rotation and translation invariance. Now, we will go one step further. We will consider a GNN for molecular property prediction that is composed of message passing layers that are equivariant to 3D rotations and translations.
But why...you may ask. Let us take a step back.
Why Geometric Equivariance over Invariance?
In order to motivate the need for geometric equivariance and symmetries, we would like to take you back to the notion of permutation symmetries in GNNs for graphs, as well as translation symmetries in ConvNets for 2D images.
Permutation Symmetry in GNNs vs. DeepSets
Earlier in the practical, we reviewed the concept of permutation invariance and equivariance. Fundamentally, a GNN layer must be a permutation equivariant operation on the graph nodes, i.e. changing the node ordering of the graph results in the same permutation applied to the node outputs of the layer. However, the overall GNN model for graph-level property prediction is still a permutation invariant function on the graph nodes, i.e. changing the node ordering does not impact the predicted graph property.
Recall from the lectures that the DeepSets model is yet another permutation invariant architecture over sets of nodes, and is a perfectly reasonable option for predicting graph-level properties (which are also permutation invariant, as we just stated). This raises a critical question: why did we build permutation invariant GNN models composed of permutation equivariant GNN layers?
The answer is that permutation equivariant GNN layers enable the model to better leverage the relational structure of the underlying nodes, as well as to construct more powerful node representations by stacking several layers of these permutation equivariant operations. (You can try running a DeepSets model on QM9 yourself and see the performance drop.)
Now, consider the same analogy for 3D rotation and translation symmetries for your molecular property prediction models. Consider your InvariantMPNNModel so far -- it is composed of InvariantMPNNLayers, which are merely invariant to 3D rotations and translations.
Analogous to how permutation equivariant layers enabled GNNs to leverage relational structure in a more principled manner, a 3D rotation and translation equivariant layer may enable your model to leverage geometric structure in a more principled manner, too.
Translation Symmetry in ConvNets for 2D Images
Yet another example where invariant models are composed of equivariant layers is the ubiquitous Convolutional Neural Network for 2D images.
The ConvNet model is invariant to translations, in the sense that it will detect a cat in an image, regardless of where the cat is positioned in the image.
Importantly, the ConvNet is composed of convolution filters, which are akin to sliding a rectangular window over the input image. Convolution filters match low-level patterns within the image. Intuitively, one of these filters may be a cat detection filter, in that it will fire whenever it comes across cat-like pixels. Thus, convolution filters are translation equivariant functions, since their output translates along with their input.
(Source)
Translation invariant ConvNets are composed of translation equivariant convolution filters in order to build hierarchical features across multiple layers. Stacking deep ConvNets enables the features across layers to interact in a compositional manner and enables the overall network to learn increasingly complex visual concepts.
The following video shows yet another demonstration of the translational equivariance of convolution filters: a shift to the input image directly corresponds to a shift of the output features.
(Source)
Formalism
Hopefully, we have sufficiently motivated the need for 3D rotation and translation equivariant GNN layers. Let us now try to formalise the notion of equivariance to 3D rotations and translations via matrix notation.
- Let $\mathbf{X} \in \mathbb{R}^{n \times d}$ be a matrix of node features for a given molecular graph, where $n$ is the number of nodes/atoms and each row $x_i$ is the $d$-dimensional feature for node $i$.
- Let $\bar{\mathbf{X}} \in \mathbb{R}^{n \times 3}$ be a matrix of node coordinates for a given molecular graph, where $n$ is the number of nodes/atoms and each row $\bar{x}_i$ is the 3D coordinate for node $i$.
- Let $\mathbf{A} \in \{0, 1\}^{n \times n}$ be the adjacency matrix, where each entry $a_{ij}$ denotes the presence or absence of an edge between nodes $i$ and $j$.
- Let $F(\mathbf{X}, \bar{\mathbf{X}}, \mathbf{A})$ be a GNN layer that takes as input the node features, node coordinates, and adjacency matrix, and returns the updated node features as well as the updated node coordinates.
- Let $f(\mathbf{X}, \bar{\mathbf{X}}, \mathbf{A})$ be a GNN model that takes as input the node features, node coordinates, and adjacency matrix, and returns the predicted graph-level property.
Our GNN model $f$ is composed of multiple rotation and translation equivariant GNN layers $F$.
How is this different from Geometrically Invariant Message Passing?
Importantly, and in contrast to rotation and translation invariant message passing layers, each round of equivariant message passing updates both the node features as well as the node coordinates:
$$(\mathbf{H}^{(\ell+1)}, \bar{\mathbf{X}}^{(\ell+1)}) = F\big(\mathbf{H}^{(\ell)}, \bar{\mathbf{X}}^{(\ell)}, \mathbf{A}\big).$$
Such a formulation is highly beneficial for GNNs to learn useful node features in settings where we are modelling a dynamical system and have reason to believe that the node coordinates are continuously being updated, e.g. by the action of intermolecular forces.
Do note the following nuances about geometrically equivariant message passing layers $F$:
- The updated node coordinates $\bar{\mathbf{X}}^{(\ell+1)}$ are equivariant to 3D rotations and translations of the input coordinates $\bar{\mathbf{X}}^{(\ell)}$.
- The updated node features $\mathbf{H}^{(\ell+1)}$ are still invariant to 3D rotations and translations of the input coordinates (similar to the geometrically invariant message passing layer).
- The overall MPNN model will still be invariant to 3D rotations and translations. This is because we are predicting a single scalar quantity (the electric dipole moment) per molecule, which remains unchanged under any rotations and translations of the atoms' coordinates. Thus, the final node feature vectors after $L$ layers of message passing are aggregated into a graph embedding $h_{\mathcal{G}}$ (and the final node coordinates are ignored). The graph embedding is then used to predict the target.
The following figure aims to succinctly capture these nuances about geometrically equivariant message passing layers $F$, which are used to compose a geometrically invariant GNN $f$:
What we want you to investigate in this part is how we may improve a GNN model that is invariant to 3D rotations and translations by using message passing layers that are equivariant to these 3D symmetries.
Let us get started!
💻Task 4.1: What does it mean for a GNN layer to be equivariant to 3D rotations and translations? Express this mathematically using the definitions above. (0.5 Marks)
🤔 Hint: Revisit the formalisms introduced previously for permutation invariance and equivariance, as well as 3D rotation and translation invariance.
❗️YOUR ANSWER HERE
💻Task 4.2: Design a new Message Passing Layer that is equivariant to 3D rotations and translations. (2.5 Marks)
🤔 Hint 1: To ensure equivariance to 3D rotations and translations, your message passing layer should now update both the node features as well as the node coordinates. This means that each of the message(), aggregate(), and update() functions will be passing around a tuple of outputs, consisting of the node features and node coordinates.
🤔 Hint 2: Certain quantities that can be computed among a pair of node coordinates do not change when the coordinates are rotated or translated -- these are invariant quantities. On the other hand, certain quantities may rotate or translate along with the coordinates -- these are equivariant quantities. We want you to think about how you can set up the message passing in a way that messages for the node feature updates are invariant to 3D rotations and translations, while messages for the node coordinates are equivariant to the same.
❗️Note: This task has multiple valid approaches. Directly importing or copying implementations from PyG will not be accepted as a valid answer.
❗️Note: The trivial solution will not be accepted as a valid answer. A general intuition about GNNs is that each node learns how to borrow information from its neighbours — here, this holds true for both node feature information as well as node coordinate information. Thus, we want you to use message passing to update the node coordinates by aggregating from the node coordinates of the neighbours. The ‘game’ here is about how to design a coordinate message function such that it is equivariant to 3D symmetries.
Awesome! You have now defined a new message passing layer that is equivariant to 3D rotations and translations, and used it to construct your final MPNN model for molecular property prediction.
💻Task 4.3: Write down the update equation of your new EquivariantMPNNLayer and use that to prove that the layer is equivariant to 3D rotations and translations. (1 Mark)
❗️YOUR ANSWER HERE
Great! You have successfully written the update equation for your new EquivariantMPNNLayer and shown how it is indeed equivariant to 3D rotations and translations.
Let us just perform some sanity checks to verify this.
💻Task 4.4: Perform unit tests for your EquivariantMPNNLayer and FinalMPNNModel. Firstly, write the unit test for 3D rotation and translation equivariance for the layer. Then, show that the layer is equivariant to 3D rotations and translations, and that the model is invariant to 3D rotations and translations. (1 Mark)
At last! You have defined the EquivariantMPNNLayer and FinalMPNNModel, after which you have proved and experimentally verified that the new layer is equivariant to 3D rotations and translations.
It is finally time to run an experiment with our final geometrically principled model!
💻Task 4.5: Train and evaluate your FinalMPNNModel. Additionally, provide a few sentences explaining the model's results compared to the basic MPNNModel, the naive CoordMPNNModel, and the InvariantMPNNModel defined previously. Is the new model better? By a significant margin, or only marginally better? (0.5 Mark)
Re-use the experiment loop we have provided and fill in the YOUR CODE HERE sections to run the experiment.
You will be training a FinalMPNNModel consisting of 4 layers of message passing with a hidden dimension of 64, in order to compare your result fairly to the previous vanilla MPNNModel, naive CoordMPNNModel, and InvariantMPNNModel.
❗️YOUR ANSWER HERE
Congratulations! You have now gone from a vanilla MPNNModel, to a naive use of coordinate information in CoordMPNNModel, to a more geometrically principled approach in InvariantMPNNModel, and finally arrived at FinalMPNNModel, a GNN that is invariant to 3D rotations and translations while consisting of message passing layers that are equivariant to these 3D symmetries.
In the next parts, we will compare these models under two different settings.
🌯 Part 5: Wrapping up
In this section, we will wrap up the practical by analysing two important aspects of the models that we have studied so far: sample efficiency and choice of graph structure.
❗️Note: Ideally, you do not need to write any new code for the tasks in this part. You are only required to run the cells in the notebook and infer the empirical results that you see. This is an exercise to simulate how you may need to infer tables and figures when reading or writing your own research papers.
Sample Efficiency
We firstly want you to think about sample efficiency -- model A is more sample efficient than model B if it can get the most out of every sample, in the sense that it can reach better performance with less data.
💻Task 5.1: Study all the models' performance across the number of training epochs. What do you observe? Explain your findings. (1 Mark)
You can consider the number of training epochs as a proxy for the number of training samples, i.e. a model is more sample efficient if it converges to better performance within fewer epochs.
Compare the models' performance across the number of training samples. How do the different modelling assumptions of the standard MPNNModel, the CoordMPNNModel, the InvariantMPNNModel, and the FinalMPNNModel influence sample efficiency? Which models perform best in low-sample regimes? What happens as we increase the sample size?
Use the sns.lineplot() function provided, along with the results from DF_RESULTS, to visualise the validation and test set MAE w.r.t. the number of training epochs in order to answer this question.
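A hedged sketch of such a comparison plot. The column names below ("Epoch", "Val MAE", "Test MAE", "Model") are assumptions about how DF_RESULTS is structured; adapt them to the actual dataframe used in the notebook:

```python
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
sns.lineplot(data=DF_RESULTS, x="Epoch", y="Val MAE", hue="Model", ax=axes[0])
sns.lineplot(data=DF_RESULTS, x="Epoch", y="Test MAE", hue="Model", ax=axes[1])
axes[0].set_title("Validation MAE vs. training epochs")
axes[1].set_title("Test MAE vs. training epochs")
axes[0].set_ylim(0, 2)   # optional zoom; occasional divergence spikes would otherwise dominate
plt.show()
```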
❗️Note: It is highly encouraged that you attempt this task even if you have not been successful in implementing all of the models in the practical. Just answer based on the models you did understand and implement!
❗️YOUR ANSWER HERE
Dense vs. Sparse Graphs
Now, let's turn our attention to the choice of the underlying graph structure.
In this practical we have been using fully-connected adjacency matrices to represent molecules (i.e. all atoms in a molecule are connected to each other, except self-loops). Note, however, that the information about the molecular graph has always been available to the models through the edge attributes data.edge_attr:
- When two atoms are physically connected, the edge attributes indicate the bond type (single, double, triple, or aromatic) through a one-hot vector.
- When two atoms are not physically connected, all edge attributes are zero.
In the following task, we will study the advantages/downsides of fully-connected adjacency matrices versus sparse adjacency matrices (where an edge between two atoms is present only when there exists a physical connection between them).
💻Task 5.2: Compare the models' performance in the two scenarios (fully-connected versus sparse graphs). Explain your findings. (1 Mark)
The code to load datasets in the sparse format is provided to you. You may need to wait for some time to let all the models finish training with the sparse format.
Grab a coffee/tea! ☕️
❗️Note: Once again, it is highly encouraged that you attempt this task even if you have not been successful in implementing all of the models in the practical. Just answer based on the models you did understand and implement!
Let's now check that the sparse dataset is actually more sparse than the fully-connected dataset that we have been using throughout the practical:
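A quick hedged sketch of this check; `train_dataset` and `train_dataset_sparse` are assumed names for the dense and sparse training splits used in the notebook:

```python
import numpy as np

dense_edges  = np.mean([data.num_edges for data in train_dataset])
sparse_edges = np.mean([data.num_edges for data in train_dataset_sparse])

print(f"Average edges per graph (fully-connected):   {dense_edges:.1f}")
print(f"Average edges per graph (sparse/bonds only): {sparse_edges:.1f}")
```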
Let's now compare the models under the two scenarios:
Compare the models' performance under the two scenarios. Which models performed better/worse? Why do you think that is the case? Did you observe any differences between the fully-connected and sparse scenarios? Provide at least two arguments to explain the differences.
❗️YOUR ANSWER HERE
FAQ and General Advice
Unit Testing
Consider using the unit test functions as sanity checks to make sure what you think is going on is actually going on. You can even consider writing your own unit tests beyond what we have asked in order to test specific properties of layers/models or numerical issues during training.
Theory vs. Empirical Results
If you are confident that your proposed layer is theoretically sound (e.g. in Task 4.2, your solution satisfies 3D equivariance) but the results are not impressive, or you are unable to achieve stable training, it may be due to numerical instability or engineering issues. If that is the case and you are not able to overcome those issues eventually, you may still submit whatever you have as a solution. You will be awarded partial marks as long as the theory is correct.
GPU rate-limit on Google Colab
TL;DR Don’t panic, start early, save and load your results instead of re-running every time.
We experienced rate-limits several times during the testing of the practical. It seems that there is an upper limit to the amount of GPU computation per user per 12-24 hours. When the limit is hit, Colab disconnects the GPU runtime (you do reconnect back to the GPU by the next day). Thus, we have tried to keep the practical as computationally simple as we can.
We have several suggestions to make life more manageable, which we enumerate in the following bullet points:
- If possible, do not leave things for the last moment. Start early so that you are not struggling with the rate limit on the day of the deadline!
- If you do get rate-limited, you can consider writing and testing your model implementations with very small dataset sizes, e.g. 100 samples each. When you reconnect to the GPU, you can re-run your models with the full dataset.
- If you find yourself hit by regular rate-limiting (e.g. if you also have other Colab projects running, this can happen a lot), you can save the results of each task to your Google Drive/local storage and simply load them each time you re-run the notebook.
- You can use some combination of a new Google account, your cam.ac.uk account, and/or a new IP address to get a fresh GPU runtime if you have been rate-limited on one account.