Inductive bias in GCN: a spectral perspective

author: Marc Lelarge, course: dataflowr

run the code or open it in Colab

date: April 15, 2021

Here, we focus on Graph Convolution Networks (GCN) introduced by Kipf and Welling in their paper Semi-Supervised Classification with Graph Convolutional Networks. The GCN layer is one of the simplest Graph Neural Network layers, defined by:

$$h_i^{(\ell+1)} = \frac{1}{d_i+1}h_i^{(\ell)}W^{(\ell)} + \sum_{j\sim i} \frac{h_j^{(\ell)}W^{(\ell)}}{\sqrt{(d_i+1)(d_j+1)}}, \tag{1}$$

where $i\sim j$ means that nodes $i$ and $j$ are neighbors in the graph $G$, $d_i$ and $d_j$ are the respective degrees of nodes $i$ and $j$ (i.e. their number of neighbors in the graph), $h_i^{(\ell)}$ is the embedding representation of node $i$ at layer $\ell$, and $W^{(\ell)}$ is a trainable weight matrix of shape [size_input_feature, size_output_feature].
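
To make formula (1) concrete, here is a minimal dense sketch of the layer written from scratch (for illustration only, not the pytorch-geometric implementation used below; the function and argument names are ours):

```python
import torch

def gcn_layer(A: torch.Tensor, H: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # A: (n, n) adjacency matrix without self-loops, H: (n, f_in) embeddings h^{(l)},
    # W: (f_in, f_out) weight matrix W^{(l)}
    deg = A.sum(dim=1)                                                  # degrees d_i
    HW = H @ W                                                          # h_j^{(l)} W^{(l)} for every node j
    norm = torch.sqrt((deg + 1).unsqueeze(1) * (deg + 1).unsqueeze(0))  # sqrt((d_i+1)(d_j+1))
    return HW / (deg + 1).unsqueeze(1) + (A / norm) @ HW                # self term + sum over neighbors
```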

The inductive bias of a learning algorithm is the set of assumptions that the learner uses to predict outputs for inputs that it has not encountered. For GCN, we argue that the inductive bias can be formulated as a simple spectral property of the algorithm: GCN acts as a low-pass filter. This argument follows from recent works: Simplifying Graph Convolutional Networks by Wu, Souza, Zhang, Fifty, Yu, Weinberger and Revisiting Graph Neural Networks: All We Have is Low-Pass Filters by NT and Maehara.

Here we will study a very simple case and relate the inductive bias of GCN to a property of the Fiedler vector of the graph. We'll consider the more general setting in a subsequent post.

Notations

We consider undirected graphs $G=(V,E)$ with $n$ vertices denoted by $i,j \in [n]$. $i\sim j$ means that nodes $i$ and $j$ are neighbors in $G$, i.e. $\{i,j\}\in E$. We denote by $A$ its adjacency matrix and by $D$ the diagonal matrix of degrees. The vector of degrees is denoted by $d$ so that $d = A1$. The components of a vector $x\in \mathbb{R}^n$ are denoted $x_i$, but sometimes it is convenient to see the vector $x$ as a function from $V$ to $\mathbb{R}$ and use the notation $x(i)$ instead of $x_i$.

Community detection in the Karate Club

We'll start with an unsupervised problem: given one graph, find a partition of its nodes into communities. In this case, we make the hypothesis that individuals tend to associate and bond with similar others, which is known as homophily.

To study this problem, we will focus on Zachary's karate club and try to recover the split of the club from the graph of connections. The pytorch-geometric library will be very convenient.

Note that GCNs are not appropriate in an unsupervised setting, as no learning is possible without any label on the vertices. However, this is not a problem here as we will not train the GCN! In more practical settings, GCNs are used in a semi-supervised setting where labels are revealed for only a few nodes (more on this in the section with the Cora dataset).

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
Dataset: KarateClub():
    ======================
    Number of graphs: 1
    Number of features: 34
    Number of classes: 4

As shown above, the default number of classes (i.e. subgroups) in pytorch-geometric is 4. For simplicity, we'll focus on a partition into two groups only:

data = dataset[0] 
biclasses = [int(b) for b in ((data.y == data.y[0]) + (data.y==data.y[5]))]

We will use networkx for drawing the graph. On the picture below, the color of each node is given by its "true" class.
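
The visualize helper used throughout is defined in the accompanying notebook; a possible sketch of it (assuming it draws either a networkx graph or a 2-d embedding tensor, colored by the given labels) is:

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

def visualize(h, color, cmap="Set1"):
    # draw either a 2-d embedding (torch tensor) or a networkx graph, colored by `color`
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    if isinstance(h, torch.Tensor):
        h = h.detach().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap=cmap)
        # medians of the two coordinates (they enforce a balanced split, see below)
        plt.axvline(np.median(h[:, 0]), color="grey")
        plt.axhline(np.median(h[:, 1]), color="grey")
    else:
        nx.draw_networkx(h, pos=nx.spring_layout(h, seed=42), with_labels=True,
                         node_color=color, cmap=cmap)
    plt.show()
```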

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
visualize(G, color=biclasses)

[figure: the karate club graph, nodes colored by their true community]

The Kernighan–Lin algorithm is a heuristic for finding partitions of graphs, and the results below show that it captures our homophily assumption well. Indeed, the algorithm tries to minimize the number of edges crossing between the 2 communities.

c1,c2 = nx.algorithms.community.kernighan_lin_bisection(G)
classes_kl = [0 if i in c1 else 1 for i in range(34)]
visualize(G, color=classes_kl, cmap="Set2")

[figure: the karate club graph, nodes colored by the Kernighan–Lin bisection]

def acc(predictions, classes):
    # the two classes are defined up to a label swap, so count the best of the two assignments
    n_tot = len(classes)
    acc = np.sum([int(pred)==cla for pred,cla in zip(predictions,classes)])
    return max(acc, n_tot-acc), n_tot

n_simu = 1000
all_acc = np.zeros(n_simu)
for i in range(n_simu):
    c1,c2 = nx.algorithms.community.kernighan_lin_bisection(G)
    classes_kl = [0 if i in c1 else 1 for i in range(34)]
    all_acc[i],_ = acc(classes_kl, biclasses)

The algorithm is not deterministic but performs poorly on only a small fraction of the trials, as shown below in the histogram of the number of correct predictions (note that there are 34 nodes in total):

bin_list = range(17,35)
_ = plt.hist(all_acc, bins=bin_list,rwidth=0.8)

[figure: histogram of the number of correct predictions over 1000 runs of Kernighan–Lin]

Inductive bias for GCN

To demonstrate the inductive bias for the GCN architecture, we consider a simple GCN with 3 layers and look at its performance without any training. To be more precise, the GCN takes as input the graph and outputs a vector $(x_i,y_i)\in \mathbb{R}^2$ for each node $i$.

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(data.num_nodes, 4) # no node features: the input is the one-hot encoding of the nodes
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        return h

torch.manual_seed(12345)
model = GCN()
print(model)
GCN(
      (conv1): GCNConv(34, 4)
      (conv2): GCNConv(4, 4)
      (conv3): GCNConv(4, 2)
    )

Below, we draw the points $(x_i,y_i)$ for all nodes $i$ of the graph. The vertical and horizontal lines are the medians of the $x_i$'s and $y_i$'s respectively. The colors are the true classes. We see that, without any learning, the points are almost separated into the lower-left and upper-right corners according to their community!

h = model(data.x, data.edge_index)
visualize(h, color=biclasses)

[figure: scatter plot of the 2-d output of the untrained GCN, colored by the true communities]

Note that by drawing the medians above, we enforce a balanced partition of the graph. Below, we draw the original graph where the color of node $i$ depends on whether $x_i$ is larger or smaller than the median.
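
The color_from_vec helper used below also comes from the accompanying notebook; a plausible sketch of it (splitting the vector at its median) is:

```python
def color_from_vec(vec, thresh=None):
    # return 0/1 colors obtained by thresholding the vector at its median (balanced split)
    vec = vec.detach().numpy() if hasattr(vec, "detach") else np.asarray(vec)
    if thresh is None:
        thresh = np.median(vec)
    return [int(v >= thresh) for v in vec]
```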

color_out = color_from_vec(h[:,0])
visualize(G, color=color_out, cmap="Set2")

[figure: the karate club graph, nodes colored by the median split of the GCN output]

We made only a few errors without any training!

Our result might depend on the particular initialization, so we run a few more experiments below:
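
The corresponding loop is not shown in the post; a minimal sketch of it (reusing the GCN class, the acc helper and the median split described above; the variable names are ours) could look like this:

```python
n_simu = 300
all_acc = np.zeros(n_simu)
for k in range(n_simu):
    model = GCN()                                           # fresh random initialization, no training
    h = model(data.x, data.edge_index).detach()
    pred = [int(v >= h[:, 0].median()) for v in h[:, 0]]    # balanced split at the median of the x_i's
    all_acc[k], _ = acc(pred, biclasses)
```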

_ = plt.hist(all_acc, bins=bin_list,rwidth=0.8)

[figure: histogram of the number of correct predictions over random initializations of the GCN]

We see that on average, we have an accuracy over 24/34, which is much better than chance!
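
For reference, here is a small sketch estimating the chance baseline, i.e. the score of a uniformly random balanced partition measured with the same acc function (which allows swapping the two labels):

```python
rng = np.random.default_rng(0)
chance = np.zeros(1000)
for k in range(1000):
    random_classes = rng.permutation([0]*17 + [1]*17)   # random balanced partition
    chance[k], _ = acc(random_classes, biclasses)
print(chance.mean())   # expected to be only slightly above 17 (half of the 34 nodes)
```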

We now explain why the GCN architecture with random initialization achieves such good results.

Spectral analysis of GCN

We start by rewriting equation (1) in matrix form:

$$h^{(\ell+1)} = S\, h^{(\ell)}W^{(\ell)}, \tag{2}$$

where the scaled adjacency matrix $S\in\mathbb{R}^{n\times n}$ is defined by $S_{ij} = \frac{1}{\sqrt{(d_i+1)(d_j+1)}}$ if $i\sim j$ or $i=j$, and $S_{ij}=0$ otherwise; $h^{(\ell)}\in \mathbb{R}^{n\times f^{(\ell)}}$ is the embedding representation of the nodes at layer $\ell$ and $W^{(\ell)}$ is the learnable weight matrix in $\mathbb{R}^{f^{(\ell)}\times f^{(\ell+1)}}$.

To simplify, we now ignore the $\tanh$ non-linearities in our GCN above, so that we get

$$y = S^3\, W^{(1)}W^{(2)}W^{(3)}, \tag{3}$$

where $W^{(1)}\in \mathbb{R}^{n\times 4}$, $W^{(2)}\in \mathbb{R}^{4\times 4}$, $W^{(3)}\in \mathbb{R}^{4\times 2}$ and $y\in \mathbb{R}^{n\times 2}$ is the output of the network (note that data.x is the identity matrix here). The matrix $W^{(1)}W^{(2)}W^{(3)}\in \mathbb{R}^{n\times 2}$ is a random matrix with no particular structure, so to understand the inductive bias of our GCN, we need to understand the action of the matrix $S^3$.

The matrix $S$ is symmetric with eigenvalues $\nu_1\geq \nu_2\geq \dots$ and associated eigenvectors $U_1,U_2,\dots$ We can show that indeed $1=\nu_1>\nu_2\geq \dots\geq \nu_n\geq -1$ by applying the Perron-Frobenius theorem. This is illustrated below.

from numpy import linalg as LA

A = nx.adjacency_matrix(G).todense()      # dense adjacency matrix of G
A_l = A + np.eye(A.shape[0],dtype=int)    # add self-loops
deg_l = np.dot(A_l,np.ones(A.shape[0]))   # degrees d_i + 1
scaling = np.dot(np.transpose(1/np.sqrt(deg_l)),(1/np.sqrt(deg_l)))  # outer product 1/sqrt((d_i+1)(d_j+1))
S = np.multiply(scaling,A_l)              # scaled adjacency matrix S (a numpy matrix)
eigen_values, eigen_vectors = LA.eigh(S)  # eigenvalues in increasing order

_ = plt.hist(eigen_values, bins = 40)

[figure: histogram of the eigenvalues of S]

But the most interesting fact for us here concerns the eigenvector $U_2$ associated with the second largest eigenvalue, which is also known as the Fiedler vector.

A first result due to Fiedler tells us that the subgraph induced by $G$ on the vertices with $U_2(i)\geq 0$ is connected. This is known as Fiedler's Nodal Domain Theorem (see Chapter 24 in Spectral and Algebraic Graph Theory by Daniel Spielman). We check this fact below both on $U_2$ and $-U_2$, so that here we get a partition of our graph into 2 connected subgraphs (since there is no node $i$ with $U_2(i)=0$).

fiedler = np.array(eigen_vectors[:,-2]).squeeze()  # eigh sorts eigenvalues in increasing order: column -2 is the Fiedler vector
H1 = G.subgraph([i for (i,f) in enumerate(fiedler) if f>=0])
H2 = G.subgraph([i for (i,f) in enumerate(fiedler) if -f>=0])
H = nx.union(H1,H2)
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(H, pos=nx.spring_layout(G, seed=42), with_labels=True)

[figure: the two connected subgraphs induced by the sign of the Fiedler vector]

There are many possible partitions of our graph into 2 connected subgraphs, and we see here that the Fiedler vector actually gives a very particular partition, corresponding almost exactly to the true communities!

visualize(G, color=[fiedler>=0], cmap="Set2")

[figure: the karate club graph, nodes colored by the sign of the Fiedler vector]

There are actually very few errors made by the Fiedler vector. Another way to see the performance of the Fiedler vector is to sort its entries and color each dot with its community label, as done below:

fiedler_c = np.sort([biclasses,fiedler], axis=1)
fiedler_1 = [v for (c,v) in np.transpose(fiedler_c) if c==1]
l1 = len(fiedler_1)
fiedler_0 = [v for (c,v) in np.transpose(fiedler_c) if c==0]
l0 = len(fiedler_0)
plt.plot(range(l0),fiedler_0,'o',color='red')
plt.plot(range(l0,l1+l0),fiedler_1,'o',color='grey')
plt.plot([0]*35);

[figure: sorted entries of the Fiedler vector, colored by community]

To understand why the partition given by the Fiedler vector is so good requires a bit of calculus. To simplify a bit, we make a small modification to the matrix $S$ and define it as $S_{ij} = \frac{1}{\sqrt{d_i d_j}}$ if $i\sim j$ and $S_{ij}=0$ otherwise (in particular, we drop the self-loops). We still denote by $\nu_i$ and $U_i$ its eigenvalues and eigenvectors.

Define the (normalized) Laplacian $L=Id-S$, so that the eigenvalues of $L$ are $\lambda_i=1-\nu_i$, associated with the same eigenvector $U_i$ as for $S$. We also define the combinatorial Laplacian $L^* = D-A$.
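
Before the calculation, here is a quick numerical check of the relation $L = D^{-1/2}L^*D^{-1/2}$ between the two Laplacians on the karate club graph (a sketch reusing $G$ and the numpy/networkx imports above):

```python
A_np = np.asarray(nx.adjacency_matrix(G).todense(), dtype=float)
deg = A_np.sum(axis=1)
D_inv_sqrt = np.diag(1 / np.sqrt(deg))
S_mod = D_inv_sqrt @ A_np @ D_inv_sqrt       # the modified S (no self-loops)
L = np.eye(len(deg)) - S_mod                 # normalized Laplacian L = Id - S
L_star = np.diag(deg) - A_np                 # combinatorial Laplacian L* = D - A
print(np.allclose(L, D_inv_sqrt @ L_star @ D_inv_sqrt))  # True
```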

We then have

$$\frac{x^TLx}{x^Tx} = \frac{x^TD^{-1/2}L^* D^{-1/2}x}{x^Tx} = \frac{y^T L^* y}{y^TDy}, \tag{4}$$

where $y = D^{-1/2}x$. In particular, we get:

$$\lambda_2 = 1-\nu_2 = \min_{x\perp U_1}\frac{x^TLx}{x^Tx} = \min_{y\perp d} \frac{y^T L^* y}{y^TDy}, \tag{5}$$

where $d$ is the vector of degrees.

Rewriting this last equation thanks to the identity $y^T L^* y = \sum_{i\sim j}\left(y(i)-y(j)\right)^2$ (where the sum runs over the edges of the graph), we obtain

$$\lambda_2 = \min \frac{\sum_{i\sim j}\left(y(i)-y(j)\right)^2}{\sum_i d_i y(i)^2}, \tag{6}$$

where the minimum is taken over vectors $y$ such that $\sum_i d_i y_i = 0$.

Now if $y^*$ is a vector achieving the minimum, then we get the Fiedler vector (up to a sign) by $U_2 = \frac{D^{1/2}y^*}{\|D^{1/2}y^*\|}$. In particular, we see that the sign of the elements of $U_2$ is the same as the sign of the elements of $y^*$.

To get an intuition about (6), consider the same minimization but with the constraint that $y(i) \in \{-1,1\}$, with the meaning that if $y(i)=1$ then node $i$ is in community $0$, and if $y(i)=-1$ then node $i$ is in community $1$. In this case, we see that the numerator $\sum_{i\sim j}\left(y(i)-y(j)\right)^2$ is the number of edges between the two communities multiplied by 4, and the denominator $\sum_i d_i y(i)^2$ is twice the total number of edges in the graph. Hence the minimization problem is now a combinatorial problem asking for a partition $(P_1,P_2)$ of the graph under the constraint that $\sum_{i\in P_1}d_i = \sum_{j\in P_2} d_j$. This last condition simply says that the number of edges in the graph induced by $G$ on $P_1$ should be the same as the number of edges in the graph induced by $G$ on $P_2$ (note that this condition might not have a solution). Hence the minimization problem defining $y^*$ in (6) can be seen as a relaxation of this bisection problem. We can then expect the Fiedler vector to be close to the vector encoding this partition $(P_1,P_2)$, at least in the signs of its elements, which explains why the partition obtained from the Fiedler vector is balanced and has a small cut, corresponding exactly to our goal here.
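
To make this concrete, here is a sketch that solves the relaxation (6) directly as the generalized eigenvalue problem $L^*y = \lambda D y$ and compares the signs of the minimizer $y^*$ with the Fiedler vector computed earlier (the helper variables are ours; scipy is assumed to be available):

```python
from scipy.linalg import eigh

A_np = np.asarray(nx.adjacency_matrix(G).todense(), dtype=float)
deg = A_np.sum(axis=1)
L_star = np.diag(deg) - A_np                # combinatorial Laplacian L* = D - A
lam, Y = eigh(L_star, np.diag(deg))         # generalized problem, eigenvalues in increasing order
y_star = Y[:, 1]                            # minimizer of (6), associated with lambda_2
agree = (np.sign(y_star) == np.sign(fiedler)).mean()
print(max(agree, 1 - agree))                # fraction of signs in agreement (up to a global flip)
```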

So now that we understand the Fiedler vector, we are ready to go back to GCN. First, we check that the small simplifications made (removing non-linearities...) are really unimportant:

torch.manual_seed(12345)
model = GCN()
W1 = model.conv1.weight.detach().numpy()
W2 = model.conv2.weight.detach().numpy()
W3 = model.conv3.weight.detach().numpy()

iteration = S**3*W1*W2*W3  # S is a numpy matrix, so ** and * act as matrix power and matrix product here
visualize(torch.tensor(iteration), color=biclasses)

[figure: scatter plot of the linearized output $S^3 W^{(1)}W^{(2)}W^{(3)}$, colored by the true communities]

OK, we get (almost) the same embeddings as with the untrained network but we now have a simpler math formula for the output:

$$[Y_1,Y_2] = S^3\, [R_1, R_2], \tag{7}$$

where $R_1,R_2$ are random vectors in $\mathbb{R}^n$ and $Y_1, Y_2$ are the output vectors in $\mathbb{R}^n$ used to do the scatter plot above.

But we can write the spectral decomposition $S = \sum_{i}\nu_i U_i U_i^T$, so that $S^3 = \sum_{i}\nu_i^3 U_i U_i^T \approx U_1U_1^T + \nu_2^3 U_2U_2^T$ because all the other $|\nu_i|^3$ are much smaller than $\nu_2^3$. Hence, we get

$$Y_1 \approx U_1^T R_1\, U_1 + \nu_2^3\, U_2^T R_1\, U_2,\\ Y_2 \approx U_1^T R_2\, U_1 + \nu_2^3\, U_2^T R_2\, U_2.$$
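
We can check the quality of this rank-2 approximation of $S^3$ numerically (a sketch reusing S, eigen_values and eigen_vectors computed above):

```python
S3 = np.linalg.matrix_power(np.asarray(S), 3)
U1 = np.asarray(eigen_vectors[:, -1]).reshape(-1, 1)     # eigenvector for nu_1 = 1
U2 = np.asarray(eigen_vectors[:, -2]).reshape(-1, 1)     # Fiedler vector
approx = U1 @ U1.T + eigen_values[-2]**3 * (U2 @ U2.T)
print(np.linalg.norm(S3 - approx) / np.linalg.norm(S3))  # relative error of the rank-2 approximation
```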

Recall that the signal about the communities is in the $U_2$ vector, so that we can rewrite this more explicitly as

$$Y_1(i) \approx a_1 + b_1 U_2(i),\\ Y_2(i) \approx a_2 + b_2 U_2(i),$$

where $a_1,a_2,b_1,b_2$ are random numbers of the same magnitude. In other words, the points $(Y_1(i), Y_2(i))$ should be approximately aligned on a line, and the two extremes of the corresponding segment should correspond to the 2 communities $U_2(i)\geq 0$ and $U_2(i)\leq 0$.

from sklearn import linear_model
from sklearn.metrics import mean_squared_error
regr = linear_model.LinearRegression()
regr.fit(iteration[:,0].reshape(-1, 1), iteration[:,1])
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
h = np.array(iteration)
plt.scatter(h[:, 0], h[:, 1], s=140, c=biclasses, cmap="Set1")
plt.plot(h[:, 0],regr.predict(iteration[:,0].reshape(-1, 1)))

[figure: output points with the best fitting line]

Below, we run a few simulations and compute the mean squared error between the points and the best interpolating line, for the random input $[R_1,R_2]$ in blue and for the output $[Y_1, Y_2]$ in orange (which you can hardly see because the error is much smaller). Our theory seems to be nicely validated ;-)
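
The simulation loop itself is not shown in the post; a minimal sketch populating the base and coef arrays used in the plot below (under the assumption that base holds the errors for the random input and coef those for the output) could be:

```python
n_simu = 100
base, coef = np.zeros(n_simu), np.zeros(n_simu)
S3 = np.linalg.matrix_power(np.asarray(S), 3)
for k in range(n_simu):
    R = np.random.randn(S3.shape[0], 2)          # random input [R1, R2]
    Y = S3 @ R                                   # linearized GCN output [Y1, Y2]
    for points, store in ((R, base), (Y, coef)):
        regr = linear_model.LinearRegression()
        regr.fit(points[:, 0].reshape(-1, 1), points[:, 1])
        store[k] = mean_squared_error(points[:, 1], regr.predict(points[:, 0].reshape(-1, 1)))
```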

_ = plt.hist(base, bins = 34)
_ = plt.hist(coef, bins = 34)

[figure: histograms of the mean squared errors for the random input (blue) and for the output (orange)]

Here we studied a very simple case, but more general statements are possible, as we will see in a subsequent post. Generalizing the analysis made with the Fiedler vector requires a little bit of spectral graph theory, as explained in the module on spectral Graph Neural Networks; see Deep Learning on graphs (2).


Thanks for reading!