Understanding Attention and Generalization in Graph Neural Networks

Boris Knyazev, Graham W. Taylor, and Mohamed R. Amer, NeurIPS 2019

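Attention in CNNs reweights the feature map $X \in \mathbb{R}^{N \times C}$ with attention coefficients $\alpha \in \mathbb{R}^{N}$ to focus on some nodes (spatial locations):

$$Z = \alpha \odot X, \qquad \text{i.e.} \qquad z_i = \alpha_i x_i$$

Note: $\alpha_i$ is a scalar and $x_i$ is a vector of size $C$, so $z_i$ is also a vector of size $C$.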
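As a minimal plain-Python sketch of this reweighting, assuming nodes are the rows of `X`, each scalar coefficient scales a whole feature row. The helper name `node_attention` is hypothetical and not from the paper.

```python
from typing import List

Matrix = List[List[float]]


def node_attention(alpha: List[float], X: Matrix) -> Matrix:
    """Reweight node features: z_i = alpha_i * x_i (Z = alpha ⊙ X).

    alpha holds one scalar per node; X holds one C-dimensional row per node.
    """
    if len(alpha) != len(X):
        raise ValueError("need exactly one attention coefficient per node")
    # Each scalar alpha_i scales every channel of row x_i,
    # so z_i keeps the same size C as x_i.
    return [[a * v for v in row] for a, row in zip(alpha, X)]


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # N = 3 nodes, C = 2 channels
alpha = [0.5, 0.0, 0.5]                   # one attention coefficient per node
Z = node_attention(alpha, X)
print(Z)  # [[0.5, 1.0], [0.0, 0.0], [2.5, 3.0]]
```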

Pooling in CNNs

Pooling in CNNs divides the grid uniformly into local regions (fixed by position, not by neighbors in a graph) and aggregates each region to reduce the spatial dimension.

src: stackoverflow
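To make the contrast concrete, here is a minimal sketch of non-overlapping 2x2 max pooling on a single-channel grid in plain Python. The helper name `max_pool_2x2` is hypothetical, and even height and width are assumed. The regions are fixed purely by position, with no notion of graph neighbors.

```python
from typing import List

Grid = List[List[float]]


def max_pool_2x2(grid: Grid) -> Grid:
    """Non-overlapping 2x2 max pooling on a single-channel grid.

    The grid is split into fixed 2x2 blocks by position only.
    Height and width are assumed to be even.
    """
    h, w = len(grid), len(grid[0])
    if h % 2 or w % 2:
        raise ValueError("height and width must be even")
    return [
        [max(grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
         for c in range(0, w, 2)]
        for r in range(0, h, 2)
    ]


image = [
    [1, 3, 2, 0],
    [4, 2, 1, 5],
    [0, 1, 7, 2],
    [6, 2, 3, 1],
]
print(max_pool_2x2(image))  # [[4, 5], [6, 7]]
```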

So, in CNNs there is no direct parallel between attention and pooling.

In GNNs, however, pooling also operates over nodes and their neighborhoods rather than over a fixed grid.

Top K Pooling

Top-K pooling was proposed by Gao and Ji (Graph U-Nets, ICML 2019). It is meant to be the equivalent of k-max pooling (a generalization of max pooling) in CNNs, where each feature map is reduced to size $k$ by picking the units with the highest values. In a GNN, the $k$ highest values of each feature map can come from different nodes, so the straightforward extension of k-max pooling does not work: each channel would keep a different set of nodes, and there would be no single pooled graph. So Gao and Ji propose to project all nodes to 1D scores and then select the top $k$ nodes by that score.

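Given the feature matrix $X^l \in \mathbb{R}^{N \times C}$ and adjacency matrix $A^l \in \mathbb{R}^{N \times N}$, first project the features to 1D using a learnable projection vector $\mathbf{p}^l \in \mathbb{R}^{C}$, normalized by its length:

$$\mathbf{y} = \frac{X^l \mathbf{p}^l}{\lVert \mathbf{p}^l \rVert}$$

From this 1D score of each node ($y_i$), select the top-$k$ nodes and use their indices ($\text{idx}$) to retrieve the relevant rows of the feature matrix and the relevant rows and columns of the adjacency matrix:

$$\text{idx} = \text{rank}(\mathbf{y}, k), \qquad \tilde{X}^l = X^l(\text{idx}, :), \qquad A^{l+1} = A^l(\text{idx}, \text{idx})$$

However, since the selection $\text{idx}$ is discrete, it gives no gradient to $\mathbf{p}^l$. The authors therefore use a gate operation, $\tilde{\mathbf{y}} = \sigma(\mathbf{y}(\text{idx}))$, which turns the selected scores into real values and makes $\mathbf{p}^l$ trainable by back-propagation. The final feature matrix for the next layer is the element-wise product of the feature vectors of the selected nodes and $\tilde{\mathbf{y}}$:

$$X^{l+1} = \tilde{X}^l \odot \left(\tilde{\mathbf{y}} \, \mathbf{1}_C^\top\right)$$

Overall:

$$\mathbf{y} = \frac{X^l \mathbf{p}^l}{\lVert \mathbf{p}^l \rVert}, \quad \text{idx} = \text{rank}(\mathbf{y}, k), \quad \tilde{\mathbf{y}} = \sigma(\mathbf{y}(\text{idx})), \quad X^{l+1} = X^l(\text{idx}, :) \odot \left(\tilde{\mathbf{y}} \, \mathbf{1}_C^\top\right), \quad A^{l+1} = A^l(\text{idx}, \text{idx})$$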
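Below is a minimal forward-pass sketch of the top-k pooling steps described above, in plain Python. No gradients are computed. The function name `top_k_pool` and the toy graph are hypothetical, and this is not Gao and Ji's implementation.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def top_k_pool(X: Matrix, A: Matrix, p: List[float], k: int) -> Tuple[Matrix, Matrix, List[int]]:
    """Forward pass of Graph U-Nets style top-k pooling (no gradients here).

    X: N x C node features, A: N x N adjacency, p: projection vector of size C.
    Returns the pooled features, the pooled adjacency and the kept node indices.
    """
    norm_p = math.sqrt(sum(v * v for v in p))
    # y = X p / ||p||: one scalar score per node
    y = [sum(x * w for x, w in zip(row, p)) / norm_p for row in X]
    # idx = rank(y, k): indices of the k highest scores, highest first
    idx = sorted(range(len(y)), key=lambda i: y[i], reverse=True)[:k]
    # Gate: sigmoid of the selected scores. The ranking itself is not
    # differentiable; multiplying by the gate is what lets gradients reach p.
    gate = [1.0 / (1.0 + math.exp(-y[i])) for i in idx]
    # X^{l+1} = X(idx, :) ⊙ (gate 1_C^T): scale each kept row by its gate value
    X_new = [[g * v for v in X[i]] for g, i in zip(gate, idx)]
    # A^{l+1} = A(idx, idx): keep only the rows and columns of the selected nodes
    A_new = [[A[i][j] for j in idx] for i in idx]
    return X_new, A_new, idx


X = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 0.0]]
A = [[0.0, 1.0, 0.0, 0.0],
     [1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [0.0, 0.0, 1.0, 0.0]]  # path graph 0-1-2-3
X_new, A_new, idx = top_k_pool(X, A, p=[1.0, 0.0], k=2)
print(idx)    # [2, 0]
print(A_new)  # [[0.0, 0.0], [0.0, 0.0]]
```

Nodes 2 and 0 are kept because they have the largest projections onto `p`. They are not adjacent in the path graph, so the pooled adjacency has no edges. This shows how a fixed-k selection can disconnect the pooled graph.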

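In GNNs, there is a parallel between pooling and attention: pooling can be viewed as node attention followed by node selection,

$$Z = \alpha \odot X, \qquad X' = Z_{\mathcal{P}}, \qquad A' = A_{\mathcal{P}\mathcal{P}}$$

where $\alpha \in \mathbb{R}^{N}$ are the attention coefficients, $\mathcal{P}$ is the set of indices of the retained nodes, and $X'$ and $A'$ are the pooled feature and adjacency matrices. In top-k pooling, $\mathcal{P}$ is obtained by finding the indices of the top-k values of $\alpha$, which is computed from a projection vector $\mathbf{p}$ learned by back-propagation on the input graph.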

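This paper proposes to combine attention and pooling into a single computational block that does not have a fixed $k$. Instead, the set of retained nodes $\mathcal{P}$ is determined by a threshold $\tilde{\alpha}$:

$$\mathcal{P} = \{\, i \mid \alpha_i > \tilde{\alpha} \,\}$$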
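Here is a minimal plain-Python sketch of attention followed by threshold-based selection. The function name `threshold_pool` is hypothetical. The parameter `threshold` plays the role of $\tilde{\alpha}$, and the attention coefficients are taken as given, whereas in the paper they come from the attention network.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def threshold_pool(X: Matrix, A: Matrix, alpha: List[float],
                   threshold: float) -> Tuple[Matrix, Matrix, List[int]]:
    """Attention followed by threshold-based node selection.

    Z = alpha ⊙ X, P = {i : alpha_i > threshold}, X' = Z_P, A' = A_PP.
    Unlike top-k, the number of kept nodes |P| depends on the attention values.
    """
    # Attention: scale every feature row by its node's coefficient
    Z = [[a * v for v in row] for a, row in zip(alpha, X)]
    # Keep every node whose attention exceeds the threshold (original order)
    P = [i for i, a in enumerate(alpha) if a > threshold]
    X_pooled = [Z[i] for i in P]
    A_pooled = [[A[i][j] for j in P] for i in P]
    return X_pooled, A_pooled, P


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
A = [[0.0, 1.0, 1.0, 0.0],
     [1.0, 0.0, 0.0, 1.0],
     [1.0, 0.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 0.0]]  # 4-cycle 0-1-3-2-0
X_small, A_small, P_small = threshold_pool(X, A, [0.4, 0.05, 0.5, 0.05], threshold=0.1)
_, _, P_all = threshold_pool(X, A, [0.25, 0.25, 0.25, 0.25], threshold=0.1)
print(P_small, P_all)  # [0, 2] [0, 1, 2, 3]
```

With the same threshold, a peaked attention distribution keeps two nodes and a flat one keeps all four. This adaptive pooled size is the practical difference from a fixed $k$.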

They also propose ChebyGIN, a combination of GIN and ChebyNet, to be used for the graph convolution after pooling.

ChebyGIN

Graph Convolutional Network (GCN), Graph Isomorphism Network (GIN) and ChebyNet have similar formulations with minor changes, and the proposed ChebyGIN is an extension of these changes. This section highlights the equivalences and differences in the mathematical forms of these networks. We compare the convolution layers of these networks. Each takes as input $X^{(l)} \in \mathbb{R}^{N \times C}$ (equivalently, the node feature row vectors $x_i^{(l)} \in \mathbb{R}^{C}$) and outputs $X^{(l+1)} \in \mathbb{R}^{N \times F}$ (equivalently, $x_i^{(l+1)} \in \mathbb{R}^{F}$). $\mathcal{N}(i)$ denotes the neighborhood of node $i$.

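GCN:

$$x_i^{(l+1)} = f\Big( \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}} \, x_j^{(l)} W^{(l)} \Big)$$

where $W^{(l)} \in \mathbb{R}^{C \times F}$ is a learnable weight matrix, $f$ is a nonlinearity and $\tilde{d}_i$ is the degree of node $i$ including its self-loop.

GIN:

$$x_i^{(l+1)} = \text{MLP}^{(l)}\Big( (1 + \epsilon^{(l)}) \, x_i^{(l)} + \sum_{j \in \mathcal{N}(i)} x_j^{(l)} \Big)$$

GIN replaces the nonlinearity $f$ and weight matrix $W^{(l)}$ of GCN with a multi-layer perceptron (MLP). Since the MLP has its own weights and does the rescaling from $C$ to $F$ dimensions, we need neither $W^{(l)}$ nor the normalization $1/\sqrt{\tilde{d}_i \tilde{d}_j}$.

Here, when $\epsilon^{(l)} > 0$ the current node is given more importance than each of its neighbors. When $\epsilon^{(l)} = 0$, the current node has the same importance as its neighbors.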
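Here is a minimal plain-Python sketch contrasting only the aggregation step of the two layers. The learnable $W^{(l)}$, $f$ and the MLP are left out, and the helper names are hypothetical. `A` is assumed to be a symmetric 0/1 adjacency matrix without self-loops.

```python
import math
from typing import List

Matrix = List[List[float]]


def gcn_aggregate(X: Matrix, A: Matrix) -> Matrix:
    """GCN neighborhood aggregation before W and f are applied.

    Adds self-loops and weights neighbor j of node i by 1 / sqrt(d~_i d~_j),
    where d~ are degrees including the self-loop.
    """
    n, C = len(X), len(X[0])
    A_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in A_hat]
    return [[sum(A_hat[i][j] * X[j][c] / math.sqrt(deg[i] * deg[j]) for j in range(n))
             for c in range(C)]
            for i in range(n)]


def gin_aggregate(X: Matrix, A: Matrix, eps: float = 0.0) -> Matrix:
    """GIN input to the MLP: (1 + eps) * x_i plus the unnormalized sum of neighbors."""
    n, C = len(X), len(X[0])
    return [[(1.0 + eps) * X[i][c] + sum(A[i][j] * X[j][c] for j in range(n))
             for c in range(C)]
            for i in range(n)]


X = [[1.0], [1.0], [1.0]]
A = [[0.0, 1.0, 1.0],
     [1.0, 0.0, 0.0],
     [1.0, 0.0, 0.0]]  # star graph: node 0 is the center
print(gin_aggregate(X, A))  # [[3.0], [2.0], [2.0]]
```

With identical inputs, GIN's unnormalized sum keeps the neighborhood size visible: 3 for the center versus 2 for the leaves. GCN's symmetric normalization instead rescales each contribution by the degrees of both endpoints.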

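ChebyNet:
Generalization of GCN to a $K$-th order Chebyshev polynomial approximation of the spectral filter:

$$X^{(l+1)} = f\Big( \sum_{k=0}^{K-1} T_k(\tilde{L}) \, X^{(l)} W_k^{(l)} \Big)$$

where $T_0(x) = 1$, $T_1(x) = x$ and $T_k(x) = 2x \, T_{k-1}(x) - T_{k-2}(x)$ are the Chebyshev polynomials, $\tilde{L} = 2L/\lambda_{\max} - I$ is the rescaled normalized graph Laplacian, and $W_k^{(l)} \in \mathbb{R}^{C \times F}$ is a separate weight matrix for each order $k$ (each $k$-hop neighborhood).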
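Here is a minimal plain-Python sketch computing the Chebyshev basis $T_k(\tilde{L})X$ that ChebyNet multiplies by its per-order weights $W_k$. It assumes $\lambda_{\max} = 2$, a common approximation that is also used to derive GCN, so that $\tilde{L} = -D^{-1/2} A D^{-1/2}$. It also assumes that every node has at least one neighbor. The learnable weights and nonlinearity are left out, and the helper names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(M: Matrix, X: Matrix) -> Matrix:
    """Plain matrix product M @ X."""
    return [[sum(M[i][k] * X[k][c] for k in range(len(X))) for c in range(len(X[0]))]
            for i in range(len(M))]


def scaled_laplacian(A: Matrix) -> Matrix:
    """L~ = 2 L / lambda_max - I, assuming lambda_max = 2.

    With the normalized Laplacian L = I - D^{-1/2} A D^{-1/2} this reduces to
    L~ = -D^{-1/2} A D^{-1/2}. Assumes no isolated nodes (all degrees > 0).
    """
    deg = [sum(row) for row in A]
    n = len(A)
    return [[-A[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def chebyshev_basis(X: Matrix, A: Matrix, K: int) -> List[Matrix]:
    """Return [T_0(L~) X, ..., T_{K-1}(L~) X] using T_k = 2 L~ T_{k-1} - T_{k-2}."""
    L = scaled_laplacian(A)
    basis = [X]                      # T_0(L~) X = X
    if K > 1:
        basis.append(matmul(L, X))   # T_1(L~) X = L~ X
    for _ in range(2, K):
        LT = matmul(L, basis[-1])
        basis.append([[2.0 * a - b for a, b in zip(r1, r2)]
                      for r1, r2 in zip(LT, basis[-2])])
    return basis


A = [[0.0, 1.0], [1.0, 0.0]]  # two connected nodes
X = [[1.0], [0.0]]
basis = chebyshev_basis(X, A, K=3)
print(basis[1])  # [[0.0], [-1.0]]
print(basis[2])  # [[1.0], [0.0]]
```

A ChebyNet layer would then apply $f$ to $\sum_k$ `basis[k]` $W_k$. Using the recurrence avoids forming powers of $\tilde{L}$ explicitly; each extra order costs one sparse-style matrix product.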

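ChebyGIN:

ChebyGIN replaces the nonlinearity $f$ and the degree normalization of ChebyNet in the same way GIN does for GCN.

The weights $W_k$ are still multiplied at the neighborhood level, giving different weights for each neighborhood (order $k$). For the first layer, all node feature vectors ($x_i$) are multiplied by the node degree ($d_i$).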

Proposed architecture

src: Boris Knyazev's slides

The proposed architecture is as follows. The first layer performs attention/pooling on the input graph. The second layer is a GNN that aggregates features from local neighborhoods. The third layer is a fully connected (FC) layer, which can also perform global pooling. Finally, an output layer produces the prediction used for training. A separate fully connected MLP, called the attention network, is trained to obtain the attention value of each node.

Attention Network

For supervised learning of the attention network, the ground-truth attention value of each node ($\alpha_i^{GT}$) in the graph is obtained by a heuristic.

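For example, in the experimental dataset for counting colors (COLORS), where the task is to count the green nodes in a graph, the attention on each node is defined as follows:

$$\alpha_i^{GT} = \begin{cases} 1 / N_{\text{green}} & \text{if node } i \text{ is green} \\ 0 & \text{otherwise} \end{cases}$$

where $N_{\text{green}}$ is the number of green nodes in the graph.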

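In the experimental dataset for counting triangles (TRIANGLES), the following heuristic is used:

$$\alpha_i^{GT} = \frac{T_i}{\sum_{j} T_j}$$

where $T_i$ is the number of triangles that node $i$ belongs to.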
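Here is a minimal plain-Python sketch of a triangle-based ground-truth attention. It assumes the heuristic normalizes per-node triangle counts to sum to 1, and it assumes a symmetric 0/1 adjacency matrix without self-loops. The function names are hypothetical, and the brute-force triple loop is only meant for small graphs.

```python
from typing import List

Matrix = List[List[int]]


def triangles_per_node(A: Matrix) -> List[int]:
    """Number of triangles each node belongs to (0/1 adjacency, no self-loops)."""
    n = len(A)
    counts = [0] * n
    # Visit each triangle once as i < j < k and credit all three corners
    for i in range(n):
        for j in range(i + 1, n):
            if not A[i][j]:
                continue
            for k in range(j + 1, n):
                if A[i][k] and A[j][k]:
                    counts[i] += 1
                    counts[j] += 1
                    counts[k] += 1
    return counts


def triangle_attention(A: Matrix) -> List[float]:
    """Assumed heuristic: alpha_i^GT proportional to the triangles containing node i."""
    t = triangles_per_node(A)
    total = sum(t)
    if total == 0:
        return [0.0] * len(t)  # no triangles: no node receives attention
    return [c / total for c in t]


A1 = [[0, 1, 1, 0],
      [1, 0, 1, 0],
      [1, 1, 0, 1],
      [0, 0, 1, 0]]  # triangle 0-1-2 plus a pendant node 3
print(triangles_per_node(A1))  # [1, 1, 1, 0]
```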

For the MNIST- dataset, where each node is a superpixel and edges are formed based on the spatial distance between superpixel centers, the following heuristic was used:

Training

These networks are trained by back-propagation to minimize the mean squared error (MSE) or cross-entropy (CE) loss of the overall prediction, together with the Kullback-Leibler (KL) divergence between the ground-truth attention $\alpha^{GT}$ and the predicted coefficients $\alpha$. The KL term is weighted by a scale $\beta$ and the number of nodes $N$:

$$\mathcal{L} = \mathcal{L}_{\text{MSE/CE}} + \frac{\beta}{N} \sum_{i=1}^{N} \alpha_i^{GT} \log \frac{\alpha_i^{GT}}{\alpha_i}$$

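Since $\sum_i \alpha_i = 1$, $\alpha$ can be thought of as a probability distribution of attention over all the nodes, so minimizing the KL divergence is an obvious first choice. The equation below shows the relationship between cross-entropy, entropy and KL divergence:

$$H(\alpha^{GT}, \alpha) = H(\alpha^{GT}) + D_{KL}(\alpha^{GT} \,\|\, \alpha)$$

The entropy $H(\alpha^{GT})$ of the ground truth does not depend on the model. Minimizing the KL divergence with respect to $\alpha$ is therefore equivalent to minimizing the cross-entropy between $\alpha^{GT}$ and $\alpha$.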
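The identity between cross-entropy, entropy and KL divergence can be checked numerically with a minimal plain-Python sketch. The two attention distributions are made-up toy values, and zero-probability terms of the first argument are skipped by the usual $0 \log 0 = 0$ convention.

```python
import math
from typing import List


def entropy(p: List[float]) -> float:
    """H(p) = -sum_i p_i log p_i (terms with p_i = 0 contribute 0)."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p: List[float], q: List[float]) -> float:
    """H(p, q) = -sum_i p_i log q_i."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def kl_divergence(p: List[float], q: List[float]) -> float:
    """D_KL(p || q) = sum_i p_i log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


alpha_gt = [0.5, 0.5, 0.0]    # ground-truth attention (sums to 1)
alpha_pred = [0.4, 0.4, 0.2]  # predicted attention (sums to 1)

H = entropy(alpha_gt)
CE = cross_entropy(alpha_gt, alpha_pred)
KL = kl_divergence(alpha_gt, alpha_pred)
print(abs(CE - (H + KL)) < 1e-12)  # True
```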

Weakly supervised model

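For domains where ground-truth attention is hard to obtain for each node, the authors propose the following weakly supervised setting. First, train a model (model B) that has the same structure as the proposed architecture (model A) except for the attention/pooling layer. Model B is trained to minimize only the prediction loss $\mathcal{L}_{\text{MSE/CE}}$. Then, the weak attention targets $\alpha^{WS}$ are calculated using the trained model B and the input graph $G$. For each node $i$, the target measures how much model B's prediction changes when node $i$ is removed:

$$\alpha_i^{WS} \propto \left| y_B(G) - y_B(G_{-i}) \right|, \qquad \sum_i \alpha_i^{WS} = 1$$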

Next, the proposed architecture (model A) is trained using $\alpha^{WS}$ as the attention target, optimizing both the MSE loss and the KL divergence.
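Here is a minimal plain-Python sketch of computing weak attention targets by node removal. It assumes the node-removal form of $\alpha^{WS}$: the absolute change in model B's prediction when node $i$ is dropped, normalized to sum to 1. `model_b` is any callable standing in for the trained model B. `toy_model_b` is a hypothetical stand-in, not a trained network: it simply counts green nodes, assuming one-hot RGB features.

```python
from typing import Callable, List, Tuple

Matrix = List[List[float]]
Model = Callable[[Matrix, Matrix], float]


def remove_node(X: Matrix, A: Matrix, i: int) -> Tuple[Matrix, Matrix]:
    """Return the features and adjacency of the graph with node i deleted."""
    keep = [j for j in range(len(X)) if j != i]
    return [X[j] for j in keep], [[A[r][c] for c in keep] for r in keep]


def weak_attention(model_b: Model, X: Matrix, A: Matrix) -> List[float]:
    """alpha_i^WS proportional to |y_B(G) - y_B(G without node i)|, summing to 1."""
    y_full = model_b(X, A)
    diffs = [abs(y_full - model_b(*remove_node(X, A, i))) for i in range(len(X))]
    total = sum(diffs)
    if total == 0:
        # Assumption (not from the paper): no node changes the prediction,
        # so fall back to uniform attention
        return [1.0 / len(X)] * len(X)
    return [d / total for d in diffs]


def toy_model_b(X: Matrix, A: Matrix) -> float:
    """Hypothetical stand-in for model B on COLORS: counts green nodes.

    Features are assumed one-hot RGB, so channel 1 marks green; A is unused.
    """
    return sum(row[1] for row in X)


X = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # R, G, G, B
A = [[0.0, 1.0, 0.0, 0.0],
     [1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [0.0, 0.0, 1.0, 0.0]]
print(weak_attention(toy_model_b, X, A))  # [0.0, 0.5, 0.5, 0.0]
```

For this idealized model B, removal-based targets spread attention uniformly over the green nodes. In practice, a trained model B only approximates this, which is one reason weak supervision tends to trail full supervision.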

src: Boris Knyazev's slides

For the COLORS task, the authors use 2 GNN layers. So the mathematical form of model B is:

whereas the mathematical form of model A is:

where $\alpha^{WS}$ is as defined above, obtained from model B.

Analysis

How powerful is attention over nodes in GNNs?

Contrary to what the authors state in the paper, I feel that the experimental results show little correlation between attention and model accuracy. The example result below shows that, even though the accuracy of the proposed model correlates strongly with attention AUC, other models do not perform better even when their attention AUC is high. This observation is also backed by Jain and Wallace, "Attention is not Explanation", NAACL 2019.

src: Knyazev et al. 2019 [Fig. 3a]

So I think the power of attention over nodes needs more study.

What are the factors influencing performance of GNNs with attention?

From the experimental results, the following factors seem to influence the performance of GNNs with attention:

  1. initialization vector -- optimal initialization gives better accuracy (Fig. 4(c))
  2. quality of the attention -- supervised attention gives better results than weakly supervised attention
  3. strength of the GNN model used -- ChebyGIN gives better results than GIN/GCN

Why is the variance of some results so high?

The variance of some results is high because the model is very sensitive to the initialization of its parameters. It can only recover from a bad initialization when the attention is good, and a bad initialization of the attention itself was not recoverable.

How top-k compares to our threshold-based pooling method?

Experiments show that threshold-based pooling gives better results than top-k pooling on larger datasets (with high-dimensional features).

How results change with increase of attention model input dimensionality or capacity?

As the input dimensionality of the attention model increases, the distribution of attention values becomes flat ($\alpha_i \approx 1/N$). Experiments show that in such cases, deeper GNN models for attention are useful.

Can we improve initialization of attention?

The authors observe that for unsupervised attention models, a normal or uniform distribution with large values is preferred for initializing the parameters of the attention model. For supervised or weakly supervised models, however, smaller values are preferred. The paper offers no intuition for why one is preferred over the other; it simply states the observation based on empirical evaluation.

What is the recipe for more powerful attention GNNs?

The recipe for powerful attention is to obtain supervision for it. If supervision is not possible, use the weakly supervised method.

How results differ depending on to which layer we apply the attention model?

Although it is desirable to apply the attention model closer to the input layer, to reduce the graph size early and keep the attention weights interpretable, the experiments show that attention applied at deeper layers has a higher impact on performance.

Why is initialization of attention important?

Since the final model is trained treating the attention weights $\alpha$ as final, badly initialized attention weights cause the weights learnt in the rest of the model to be wrong. As a result, the model is not able to recover.

However, I feel that the models should be able to recover from a bad initialization with more iterations. The literature on expectation-maximization and bi-level optimization indicates that this is possible.

Doubts

  1. Why use a sigmoid in Top-K pooling? Gate operation -- why is the projection discrete?

Questions

  1. What is the dimensionality of ?

  2. How to decide from input graph?

  3. Provide the mathematical form of ChebyGIN and show all the parameters.

  4. Why is the KL divergence selected as the loss function, rather than cross-entropy or squared error?

  5. Relation between cross-entropy and KL divergence.

  6. Give the mathematical forms of models A and B for COLORS.

  7. Summarize: How powerful is attention over nodes in GNNs?

  8. Summarize: What are the factors influencing performance of GNNs with attention?

  9. Summarize: Why is the variance of some results so high?

  10. Summarize: How top-k compares to our threshold-based pooling method?

  11. Summarize: How results change with increase of attention model input dimensionality or capacity?

  12. Summarize: Can we improve initialization of attention?

  13. Summarize: What is the recipe for more powerful attention GNNs?

  14. Summarize: How results differ depending on to which layer we apply the attention model?

  15. Summarize: Why is initialization of attention important?

Extra questions to be considered

  1. Find the source code for the weakly supervised attention component and explain each line of it.

  2. Why does GIN move from a weighted mean to a sum?

  3. How to do back-propagation through ranking?

  4. Doesn't attention lead to overfitting? A higher number of parameters means a higher chance of overfitting.

References

  1. Boris Knyazev's slides
  2. Gao and Ji, Graph U-Nets, ICML 2019