Graph neural network



Graph neural networks (GNNs) are specialized artificial neural networks designed for tasks whose inputs are graphs.[1][2][3][4][5]

One prominent example is molecular drug design.[6][7][8] Each input sample is a graph representation of a molecule, where atoms form the nodes and chemical bonds between atoms form the edges. In addition to the graph representation, the input also includes known chemical properties for each of the atoms. Dataset samples may thus differ in length, reflecting the varying numbers of atoms in molecules, and the varying number of bonds between them. The task is to predict the efficacy of a given molecule for a specific medical application, like eliminating E. coli bacteria.

The key design element of GNNs is the use of pairwise message passing, such that graph nodes iteratively update their representations by exchanging information with their neighbors. Several GNN architectures have been proposed,[2][3][9][10][11] which implement different flavors of message passing,[12][13] beginning with recursive[2] and convolutional constructive[3] approaches. It remains an open question whether it is possible to define GNN architectures "going beyond" message passing, or whether every GNN can be built on message passing over suitably defined graphs.[14]

Basic building blocks of a graph neural network (GNN). (1) Permutation equivariant layer. (2) Local pooling layer. (3) Global pooling (or readout) layer. Colors indicate features.

In the more general subject of "geometric deep learning", certain existing neural network architectures can be interpreted as GNNs operating on suitably defined graphs.[12] A convolutional neural network layer, in the context of computer vision, can be considered a GNN applied to graphs whose nodes are pixels and only adjacent pixels are connected by edges in the graph. A transformer layer, in natural language processing, can be considered a GNN applied to complete graphs whose nodes are words or tokens in a passage of natural language text.

Relevant application domains for GNNs include natural language processing,[15] social networks,[16] citation networks,[17] molecular biology,[18] chemistry,[19][20] physics[21] and NP-hard combinatorial optimization problems.[22]

Open source libraries implementing GNNs include PyTorch Geometric[23] (PyTorch), TensorFlow GNN[24] (TensorFlow), Deep Graph Library[25] (framework agnostic), jraph[26] (Google JAX), and GraphNeuralNetworks.jl[27]/GeometricFlux.jl[28] (Julia, Flux).

Architecture

The architecture of a generic GNN implements the following fundamental layers:[12]

  1. Permutation equivariant: a permutation equivariant layer maps a representation of a graph into an updated representation of the same graph. In the literature, permutation equivariant layers are implemented via pairwise message passing between graph nodes.[12][14] Intuitively, in a message passing layer, nodes update their representations by aggregating the messages received from their immediate neighbours. As such, each message passing layer increases the receptive field of the GNN by one hop.
  2. Local pooling: a local pooling layer coarsens the graph via downsampling. Local pooling is used to increase the receptive field of a GNN, in a similar fashion to pooling layers in convolutional neural networks. Examples include k-nearest neighbours pooling, top-k pooling,[29] and self-attention pooling.[30]
  3. Global pooling: a global pooling layer, also known as a readout layer, provides a fixed-size representation of the whole graph. The global pooling layer must be permutation invariant, such that permutations in the ordering of graph nodes and edges do not alter the final output (see the sketch after this list).[31] Examples include element-wise sum, mean, or maximum.
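As a concrete illustration of the third building block, the following is a minimal sketch of a permutation-invariant global pooling (readout) layer, using plain NumPy; the function name and the convention of storing nodes as rows of a feature matrix are assumptions made for this example.

```python
import numpy as np

def global_readout(node_features: np.ndarray, mode: str = "sum") -> np.ndarray:
    """Collapse a variable-size set of node features into one fixed-size vector."""
    if mode == "sum":
        return node_features.sum(axis=0)
    if mode == "mean":
        return node_features.mean(axis=0)
    if mode == "max":
        return node_features.max(axis=0)
    raise ValueError(f"unknown readout mode: {mode}")

# Permutation invariance: shuffling the node order leaves the readout unchanged.
X = np.random.rand(5, 8)             # 5 nodes, 8 features each
perm = np.random.permutation(5)
assert np.allclose(global_readout(X), global_readout(X[perm]))
```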

It has been demonstrated that GNNs cannot be more expressive than the Weisfeiler–Leman Graph Isomorphism Test.[32][33] In practice, this means that there exist different graph structures (e.g., molecules with the same atoms but different bonds) that cannot be distinguished by GNNs. More powerful GNNs operating on higher-dimensional geometries such as simplicial complexes can be designed.[34][35][13] Whether or not future architectures will overcome the message passing primitive remains an open research question.[14]

Non-isomorphic graphs that cannot be distinguished by a GNN due to the limitations of the Weisfeiler-Lehman Graph Isomorphism Test. Colors indicate node features.

Message passing layers

Node representation update in a Message Passing Neural Network (MPNN) layer. Node $\mathbf{x}_0$ receives messages sent by all of its immediate neighbours $\mathbf{x}_1$ to $\mathbf{x}_4$. Messages are computed via the message function $\psi$, which accounts for the features of both the sender and the receiver.

Message passing layers are permutation-equivariant layers mapping a graph into an updated representation of the same graph. Formally, they can be expressed as message passing neural networks (MPNNs).[12]

Let $G = (V, E)$ be a graph, where $V$ is the node set and $E$ is the edge set. Let $N_u$ be the neighbourhood of some node $u \in V$. Additionally, let $\mathbf{x}_u$ be the features of node $u \in V$, and $\mathbf{e}_{uv}$ be the features of edge $(u, v) \in E$. An MPNN layer can be expressed as follows:[12]

$\mathbf{h}_u = \phi\left(\mathbf{x}_u, \bigoplus_{v \in N_u} \psi(\mathbf{x}_u, \mathbf{x}_v, \mathbf{e}_{uv})\right)$

where $\phi$ and $\psi$ are differentiable functions (e.g., artificial neural networks), and $\bigoplus$ is a permutation-invariant aggregation operator that can accept an arbitrary number of inputs (e.g., element-wise sum, mean, or max). In particular, $\phi$ and $\psi$ are referred to as the update and message functions, respectively. Intuitively, in an MPNN computational block, graph nodes update their representations by aggregating the messages received from their neighbours.
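The following is a minimal sketch of such a layer, assuming PyTorch, a sum aggregator for $\bigoplus$, and small multilayer perceptrons standing in for $\phi$ and $\psi$; the class name and the edge-list input format are illustrative choices rather than a reference implementation.

```python
import torch
import torch.nn as nn

class MPNNLayer(nn.Module):
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int):
        super().__init__()
        # psi: message function, applied to (x_u, x_v, e_uv)
        self.psi = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # phi: update function, applied to (x_u, aggregated message)
        self.phi = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, edge_index, edge_attr):
        # x: (N, node_dim); edge_index: (2, E) rows = (sender v, receiver u);
        # edge_attr: (E, edge_dim)
        src, dst = edge_index
        messages = self.psi(torch.cat([x[dst], x[src], edge_attr], dim=-1))
        # Permutation-invariant sum aggregation of incoming messages per receiver.
        agg = torch.zeros(x.size(0), messages.size(-1), device=x.device)
        agg.index_add_(0, dst, messages)
        return self.phi(torch.cat([x, agg], dim=-1))
```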

The outputs of one or more MPNN layers are node representations $\mathbf{h}_u$ for each node $u \in V$ in the graph. Node representations can be employed for any downstream task, such as node/graph classification or edge prediction.

Graph nodes in an MPNN update their representation by aggregating information from their immediate neighbours. As such, stacking $n$ MPNN layers means that one node will be able to communicate with nodes that are at most $n$ "hops" away. In principle, to ensure that every node receives information from every other node, one would need to stack a number of MPNN layers equal to the graph diameter. However, stacking many MPNN layers may cause issues such as oversmoothing[36] and oversquashing.[37] Oversmoothing refers to the issue of node representations becoming indistinguishable. Oversquashing refers to the bottleneck that is created by squeezing long-range dependencies into fixed-size representations. Countermeasures such as skip connections[10][38] (as in residual neural networks), gated update rules[39] and jumping knowledge[40] can mitigate oversmoothing. Modifying the final layer to be a fully-adjacent layer, i.e., by considering the graph as a complete graph, can mitigate oversquashing in problems where long-range dependencies are required.[37]
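To illustrate the receptive-field and oversmoothing discussion, here is a minimal sketch of stacking several message passing layers with residual skip connections, assuming PyTorch; `SimpleConv` is a toy neighbour-averaging layer used only for this example, not a published architecture.

```python
import torch
import torch.nn as nn

class SimpleConv(nn.Module):
    """One hop of neighbour averaging followed by a linear map (toy example)."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x, adj):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)   # node degrees
        return torch.relu(self.lin(adj @ x / deg))        # average over neighbours

class DeepGNN(nn.Module):
    """Stacking n layers gives an n-hop receptive field; skip connections help
    keep node representations distinguishable as depth grows."""
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([SimpleConv(dim) for _ in range(num_layers)])

    def forward(self, x, adj):
        for layer in self.layers:
            x = x + layer(x, adj)   # residual (skip) connection
        return x
```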

Other "flavours" of MPNN have been developed in the literature,[12] such as graph convolutional networks[9] and graph attention networks,[11] whose definitions can be expressed in terms of the MPNN formalism.

Graph convolutional network

The graph convolutional network (GCN) was first introduced by Thomas Kipf and Max Welling in 2017.[9]

A GCN layer defines a first-order approximation of a localized spectral filter on graphs. GCNs can be understood as a generalization of convolutional neural networks to graph-structured data.

The formal expression of a GCN layer reads as follows:

$\mathbf{H} = \sigma\left(\tilde{\mathbf{D}}^{-\frac{1}{2}} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-\frac{1}{2}} \mathbf{X} \mathbf{\Theta}\right)$

where $\mathbf{H}$ is the matrix of node representations $\mathbf{h}_u$, $\mathbf{X}$ is the matrix of node features $\mathbf{x}_u$, $\sigma(\cdot)$ is an activation function (e.g., ReLU), $\tilde{\mathbf{A}}$ is the graph adjacency matrix with the addition of self-loops, $\tilde{\mathbf{D}}$ is the graph degree matrix with the addition of self-loops, and $\mathbf{\Theta}$ is a matrix of trainable parameters.

In particular, let $\mathbf{A}$ be the graph adjacency matrix: then, one can define $\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$ and $\tilde{\mathbf{D}}_{ii} = \sum_{j \in V} \tilde{A}_{ij}$, where $\mathbf{I}$ denotes the identity matrix. This normalization ensures that the eigenvalues of $\tilde{\mathbf{D}}^{-\frac{1}{2}} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-\frac{1}{2}}$ are bounded in the range $[-1, 1]$, avoiding numerical instabilities and exploding/vanishing gradients.
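A minimal dense sketch of this propagation rule, using NumPy; the random weight matrix and the 4-node example graph below are placeholders for illustration only.

```python
import numpy as np

def gcn_layer(A: np.ndarray, X: np.ndarray, Theta: np.ndarray) -> np.ndarray:
    A_tilde = A + np.eye(A.shape[0])              # adjacency with self-loops
    d_tilde = A_tilde.sum(axis=1)                 # degrees with self-loops
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d_tilde))
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt     # symmetrically normalized adjacency
    return np.maximum(A_hat @ X @ Theta, 0.0)     # ReLU activation

# Example: a 4-node cycle graph, 3 input features, 2 output features.
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
X = np.random.rand(4, 3)
Theta = np.random.rand(3, 2)
H = gcn_layer(A, X, Theta)                        # node representations, shape (4, 2)
```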

A limitation of GCNs is that they do not allow multidimensional edge features $\mathbf{e}_{uv}$.[9] It is, however, possible to associate scalar weights $w_{uv}$ to each edge by imposing $A_{uv} = w_{uv}$, i.e., by setting each nonzero entry in the adjacency matrix equal to the weight of the corresponding edge.

Graph attention network

The graph attention network (GAT) was introduced by Petar Veličković et al. in 2018.[11]

A graph attention network combines a GNN with an attention mechanism. Incorporating an attention layer into a GNN lets each node focus on the most relevant information received from its neighbours, rather than weighting all of it equally.

A multi-head GAT layer can be expressed as follows:

$\mathbf{h}_u = \Big\Vert_{k=1}^{K} \sigma\left(\sum_{v \in N_u} \alpha_{uv} \mathbf{W}^k \mathbf{x}_v\right)$

where $K$ is the number of attention heads, $\Vert$ denotes vector concatenation, $\sigma(\cdot)$ is an activation function (e.g., ReLU), $\alpha_{uv}$ are attention coefficients, and $\mathbf{W}^k$ is a matrix of trainable parameters for the $k$-th attention head.

For the final GAT layer, the outputs from each attention head are averaged before the application of the activation function. Formally, the final GAT layer can be written as:

$\mathbf{h}_u = \sigma\left(\frac{1}{K}\sum_{k=1}^{K} \sum_{v \in N_u} \alpha_{uv} \mathbf{W}^k \mathbf{x}_v\right)$

Attention in machine learning is a technique that mimics cognitive attention. In the context of learning on graphs, the attention coefficient $\alpha_{uv}$ measures how important node $u \in V$ is to node $v \in V$.

Normalized attention coefficients are computed as follows:

$\alpha_{uv} = \frac{\exp\left(\text{LeakyReLU}\left(\mathbf{a}^T [\mathbf{W}\mathbf{x}_u \Vert \mathbf{W}\mathbf{x}_v \Vert \mathbf{e}_{uv}]\right)\right)}{\sum_{z \in N_u} \exp\left(\text{LeakyReLU}\left(\mathbf{a}^T [\mathbf{W}\mathbf{x}_u \Vert \mathbf{W}\mathbf{x}_z \Vert \mathbf{e}_{uz}]\right)\right)}$

where $\mathbf{a}$ is a vector of learnable weights, $\cdot^T$ indicates transposition, $\Vert$ denotes vector concatenation, $\mathbf{e}_{uv}$ are the edge features (if present), and $\text{LeakyReLU}$ is a modified ReLU activation function. Attention coefficients are normalized to make them easily comparable across different nodes.[11]
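Below is a minimal single-head sketch of these attention coefficients, assuming NumPy, a dense adjacency matrix, and no edge features; the weight matrix `W`, the attention vector `a`, and the function name are placeholders rather than the reference implementation.

```python
import numpy as np

def gat_attention(A, X, W, a, negative_slope=0.2):
    """Return alpha[u, v], the attention of node u over its neighbour v."""
    Z = X @ W                                     # transformed features W x_u
    N = A.shape[0]
    alpha = np.zeros((N, N))
    for u in range(N):
        neigh = np.flatnonzero(A[u])              # neighbourhood N_u
        if neigh.size == 0:
            continue
        # a^T [W x_u || W x_v] for every neighbour v, then LeakyReLU.
        e = np.array([a @ np.concatenate([Z[u], Z[v]]) for v in neigh])
        e = np.where(e > 0, e, negative_slope * e)
        e = np.exp(e - e.max())                   # softmax over N_u
        alpha[u, neigh] = e / e.sum()
    return alpha
```

With `Z = X @ W`, the single-head layer output for node $u$ would then be $\sigma\big(\sum_{v \in N_u} \alpha_{uv}\mathbf{W}\mathbf{x}_v\big)$, i.e., an activation applied to `alpha @ Z` row by row.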

A GCN can be seen as a special case of a GAT where attention coefficients are not learnable, but fixed and equal to the edge weights $w_{uv}$.

Gated graph sequence neural network

The gated graph sequence neural network (GGS-NN) was introduced by Yujia Li et al. in 2015.[39] The GGS-NN extends the GNN formulation by Scarselli et al.[2] to output sequences. The message passing framework is implemented as an update rule to a gated recurrent unit (GRU) cell.

A GGS-NN can be expressed as follows:

$\mathbf{h}_u^{(0)} = \mathbf{x}_u \,\Vert\, \mathbf{0}$
$\mathbf{m}_u^{(l+1)} = \sum_{v \in N_u} \Theta \mathbf{h}_v^{(l)}$
$\mathbf{h}_u^{(l+1)} = \text{GRU}(\mathbf{m}_u^{(l+1)}, \mathbf{h}_u^{(l)})$

where $\Vert$ denotes vector concatenation, $\mathbf{0}$ is a vector of zeros, $\Theta$ is a matrix of learnable parameters, GRU is a gated recurrent unit cell, and $l$ denotes the sequence index. In a GGS-NN, the node representations are regarded as the hidden states of a GRU cell. The initial node features $\mathbf{x}_u$ are zero-padded up to the hidden state dimension of the GRU cell. The same GRU cell is used for updating the representation of every node.
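A minimal sketch of this update rule, assuming PyTorch and a dense adjacency matrix; the class and attribute names are illustrative, and message passing is written as a dense matrix product for brevity.

```python
import torch
import torch.nn as nn

class GGSNNLayer(nn.Module):
    def __init__(self, feature_dim: int, hidden_dim: int):
        super().__init__()
        assert hidden_dim >= feature_dim, "features are zero-padded up to hidden_dim"
        self.pad = hidden_dim - feature_dim
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)  # Theta
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)               # shared GRU cell

    def init_state(self, x):
        # h_u^(0) = x_u || 0 : zero-pad node features to the hidden dimension.
        return torch.cat([x, x.new_zeros(x.size(0), self.pad)], dim=-1)

    def forward(self, h, adj):
        # m_u^(l+1) = sum over neighbours v of Theta h_v^(l)
        m = adj @ self.theta(h)
        # h_u^(l+1) = GRU(m_u^(l+1), h_u^(l)); the same cell updates every node.
        return self.gru(m, h)
```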

Local pooling layers

Local pooling layers coarsen the graph via downsampling. Several learnable local pooling strategies have been proposed.[31] In each case, the input is the initial graph, represented by a matrix $\mathbf{X}$ of node features and the graph adjacency matrix $\mathbf{A}$. The output is the new matrix $\mathbf{X}'$ of node features and the new graph adjacency matrix $\mathbf{A}'$.

Top-k pooling

We first set

$\mathbf{y} = \frac{\mathbf{X}\mathbf{p}}{\Vert\mathbf{p}\Vert}$

where $\mathbf{p}$ is a learnable projection vector. The projection vector $\mathbf{p}$ computes a scalar projection value for each graph node.

The top-k pooling layer[29] can then be formalised as follows:

$\mathbf{X}' = (\mathbf{X} \odot \text{sigmoid}(\mathbf{y}))_{\mathbf{S}}$
$\mathbf{A}' = \mathbf{A}_{\mathbf{S},\mathbf{S}}$

where $\mathbf{S} = \text{top}_k(\mathbf{y})$ is the subset of nodes with the top-k highest projection scores, $\odot$ denotes element-wise matrix multiplication, and $\text{sigmoid}(\cdot)$ is the sigmoid function. In other words, the nodes with the top-k highest projection scores are retained in the new adjacency matrix $\mathbf{A}'$. The $\text{sigmoid}(\cdot)$ operation makes the projection vector $\mathbf{p}$ trainable by backpropagation; without it, the layer would produce discrete outputs.[29]
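The following is a minimal dense sketch of top-k pooling as formalised above, using NumPy; in a real implementation the projection vector $\mathbf{p}$ is learned, whereas here it is a random placeholder.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def topk_pool(X: np.ndarray, A: np.ndarray, p: np.ndarray, k: int):
    y = X @ p / np.linalg.norm(p)            # projection score for each node
    S = np.argsort(y)[-k:]                   # indices of the k highest scores
    X_new = (X * sigmoid(y)[:, None])[S]     # gate features, keep selected nodes
    A_new = A[np.ix_(S, S)]                  # adjacency induced on the selected nodes
    return X_new, A_new

# Example: keep the 3 highest-scoring nodes of a random 5-node graph.
X = np.random.rand(5, 4)
A = (np.random.rand(5, 5) > 0.5).astype(float)
p = np.random.rand(4)
X_new, A_new = topk_pool(X, A, p, k=3)
```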

Self-attention pooling

We first set

$\mathbf{y} = \text{GNN}(\mathbf{X}, \mathbf{A})$

where GNN is a generic permutation equivariant GNN layer (e.g., GCN, GAT, MPNN).

The self-attention pooling layer[30] can then be formalised as follows:

$\mathbf{X}' = (\mathbf{X} \odot \mathbf{y})_{\mathbf{S}}$
$\mathbf{A}' = \mathbf{A}_{\mathbf{S},\mathbf{S}}$

where $\mathbf{S} = \text{top}_k(\mathbf{y})$ is the subset of nodes with the top-k highest projection scores, and $\odot$ denotes element-wise matrix multiplication.

The self-attention pooling layer can be seen as an extension of the top-k pooling layer. Unlike top-k pooling, the self-attention scores computed in self-attention pooling account for both the graph features and the graph topology.
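A minimal sketch of self-attention pooling under the same dense conventions as the top-k example above; `score_gnn` is a toy stand-in for any permutation-equivariant GNN layer that outputs one attention score per node.

```python
import numpy as np

def score_gnn(X: np.ndarray, A: np.ndarray, w: np.ndarray) -> np.ndarray:
    # Toy stand-in for y = GNN(X, A): aggregate self + neighbour features,
    # then project to a single score per node.
    A_tilde = A + np.eye(A.shape[0])
    return (A_tilde @ X) @ w

def self_attention_pool(X, A, w, k):
    y = score_gnn(X, A, w)                   # scores depend on features and topology
    S = np.argsort(y)[-k:]                   # top-k nodes by attention score
    X_new = (X * y[:, None])[S]              # weight features by their scores
    A_new = A[np.ix_(S, S)]
    return X_new, A_new
```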

Heterophilic Graph Learning

The homophily principle, i.e., the tendency of nodes with the same labels or similar attributes to be connected, has commonly been believed to be the main reason for the superiority of graph neural networks (GNNs) over traditional neural networks (NNs) on graph-structured data, especially on node-level tasks.[41] However, recent work has identified a non-trivial set of datasets where GNNs' performance compared to NNs' is not satisfactory.[42] Heterophily, i.e., low homophily, has been considered the main cause of this empirical observation.[43] Researchers have begun to revisit and re-evaluate most existing graph models in the heterophily scenario across various kinds of graphs, e.g., heterogeneous graphs, temporal graphs and hypergraphs. Moreover, numerous graph-related applications are found to be closely related to the heterophily problem, e.g., graph fraud/anomaly detection, graph adversarial attacks and robustness, privacy, federated learning, point cloud segmentation, graph clustering, recommender systems, generative models, link prediction, graph classification and coloring. In the past few years, considerable effort has been devoted to studying and addressing the heterophily issue in graph learning.[41][43][44]

Applications

Protein folding


Graph neural networks are one of the main building blocks of AlphaFold, an artificial intelligence program developed by Google's DeepMind for solving the protein folding problem in biology. AlphaFold achieved first place in several CASP competitions.[45][46][40]

Social networks

Social networks are a major application domain for GNNs due to their natural representation as social graphs. GNNs are used to develop recommender systems based on both social relations and item relations.[47][16]

Combinatorial optimization

GNNs are used as fundamental building blocks for several combinatorial optimization algorithms.[48] Examples include computing shortest paths or Eulerian circuits for a given graph,[39] deriving chip placements superior or competitive to handcrafted human solutions,[49] and improving expert-designed branching rules in branch and bound.[50]

Cyber security

When viewed as a graph, a network of computers can be analyzed with GNNs for anomaly detection. Anomalies within provenance graphs often correlate to malicious activity within the network. GNNs have been used to identify these anomalies on individual nodes[51] and within paths[52] to detect malicious processes, or on the edge level[53] to detect lateral movement.

Water distribution networks


Water distribution systems can be modelled as graphs, making them a straightforward application of GNNs. This kind of algorithm has been applied to water demand forecasting,[54] interconnecting District Measuring Areas to improve forecasting capacity. Another application of this algorithm in water distribution modelling is the development of metamodels.[55]

References



  1. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named wucuipeizhao2022
  2. ↑ 2.0 2.1 2.2 2.3 Cite error: Invalid <ref> tag; no text was provided for refs named scarselli2009
  3. ↑ 3.0 3.1 3.2 Cite error: Invalid <ref> tag; no text was provided for refs named micheli2009
  4. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named sanchez2021
  5. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named daigavane2021
  6. ↑ Template:Cite journal
  7. ↑ Template:Citation
  8. ↑ Template:Cite journal
  9. ↑ 9.0 9.1 9.2 9.3 Cite error: Invalid <ref> tag; no text was provided for refs named kipf2016
  10. ↑ 10.0 10.1 Cite error: Invalid <ref> tag; no text was provided for refs named hamilton2017
  11. ↑ 11.0 11.1 11.2 11.3 Cite error: Invalid <ref> tag; no text was provided for refs named velickovic2018
  12. ↑ 12.0 12.1 12.2 12.3 12.4 12.5 12.6 Cite error: Invalid <ref> tag; no text was provided for refs named bronstein2021
  13. ↑ 13.0 13.1 Cite error: Invalid <ref> tag; no text was provided for refs named hajij2022
  14. ↑ 14.0 14.1 14.2 Cite error: Invalid <ref> tag; no text was provided for refs named velickovic2022
  15. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named wuchen2023
  16. ↑ 16.0 16.1 Cite error: Invalid <ref> tag; no text was provided for refs named ying2018
  17. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named stanforddata
  18. ↑ Template:Cite journal
  19. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named gilmer2017
  20. ↑ Template:Cite journal
  21. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named qasim2019
  22. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named li2018
  23. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named fey2019
  24. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named tfgnn2022
  25. ↑ Template:Cite web
  26. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named jraph2022
  27. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named Lucibello2021GNN
  28. ↑ Template:Citation
  29. ↑ 29.0 29.1 29.2 Cite error: Invalid <ref> tag; no text was provided for refs named gao2019
  30. ↑ 30.0 30.1 Cite error: Invalid <ref> tag; no text was provided for refs named lee2019
  31. ↑ 31.0 31.1 Cite error: Invalid <ref> tag; no text was provided for refs named lui2022
  32. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named douglas2011
  33. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named xu2019
  34. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named bronstein2021-2
  35. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named grady2011discrete
  36. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named chen2021
  37. ↑ 37.0 37.1 Cite error: Invalid <ref> tag; no text was provided for refs named alon2021
  38. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named xu2021
  39. ↑ 39.0 39.1 39.2 Cite error: Invalid <ref> tag; no text was provided for refs named li2016
  40. ↑ 40.0 40.1 Cite error: Invalid <ref> tag; no text was provided for refs named xu2018
  41. ↑ 41.0 41.1 Template:Citation
  42. ↑ Template:Cite journal
  43. ↑ 43.0 43.1 Template:Cite journal
  44. ↑ Template:Cite journal
  45. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named guardian2018
  46. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named mit2020
  47. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named fan2019
  48. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named cappart2021
  49. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named mirhoseini2021
  50. ↑ Cite error: Invalid <ref> tag; no text was provided for refs named gasse2019
  51. ↑ Template:Cite journal
  52. ↑ Template:Cite journal
  53. ↑ Template:Cite journal
  54. ↑ Template:Cite journal
  55. ↑ Template:Cite journal