NIPS2018 《DiffPool:Hierarchical Graph Representation Learning with Differentiable Pooling》 Reading Notes
论文地址: DiffPool Introduction 传统的GNN算法在Node-level的任务如节点分类、链路预测上有着较好的效果。但是,现有的GNN方法由于其存在平面化的局限性,因此无法学习图的层级表示(意味着无法预测整个图的标签),顾无法实现图分类任务。举个栗子,一个Graph,可以分成600个subgraph,每个节点都存在于其中的某个subgraph(一个节点只存在于一个subgraph中),每个subgraph拥有一个标签,如何预测subgraph的标签是这篇文章主要想解决的问题。传统的GNN的图分类方法都是为Graph中的所有节点生成Embedding,然后将对这些Embedding做全局聚合(池化),如简单的把属于同一个subgraph的节点求和或者输入到MLP中生成一个标签向量来表示整个subgraph,但是这样可能忽略的图的层级结构信息。 本文提出了一种端到端的可微可微图池化模块DiffPool,原理如下图所示: 在深度GNN中的每层中为节点学习可微的软簇分配,将节点映射到簇中,这些簇作为新的节点作为下一层GNN的输入。上图的Original Network部分是一个Subgraph,传统的方法是直接求出这个Subgraph中每个节点的Embedding,然后相加或输入到一个神经网络中,得到一个预测向量,这种方法可以称为“全局池化”。DiffPool中,假设第$l$层的输入是$1000$个簇(如果是第一层输入就是1000个节点),我们先设置第$l+1$层需要输入的簇的个数(假设为$100$),也就是第$l$层输出的簇个数,然后在$l$层中通过一个分配矩阵将$1000$个簇做合并,合并成100个“节点”,然后将这100个节点输入到$l+1$层中,最后图中的节点数逐渐减少,最后,图中的节点只有一个,这个节点的embedding就是整个图的表示,然后将图输入到一个多层感知机MLP中,得到预测向量,在于真值的one-hot向量做cross-entropy,得到Loss。 Model:DiffPool 一个Graph表示为$\mathcal{G} = (A,F)$,其中$A \in {0,1}^{n \times n}$是Graph的邻接矩阵,$F \in \mathbb{R}^{n \times d}$表示节点特征矩阵,每个节点有$d$维的特征。给定一个带标签的子图集$\mathcal{D}=\left\{\left(G_{1}, y_{1}\right),\left(G_{2}, y_{2}\right), \ldots\right\}$, 其中 $y_{i} \in \mathcal{Y}$表示每个子图$G_i \in \mathcal{G}$的标签,任务目标是寻找映射$f: \mathcal{G} \rightarrow \mathcal{Y}$,将图映射到标签集。我们需要一个过程来将每个子图转化为一个有限维度的向量$\mathbb{R}^D$。 Graph Neural Networks 一般,GNN可以表示成"Message Passing"框架: $$ H^{(k)}=M\left(A, H^{(k-1)} ; \theta^{(k)}\right) $$ 其中$H^{(k)} \in \mathbb{R}^{n \times d}$表示GNN迭代$k$次后的node embedding,$M$是一个Message扩散函数,由邻接矩阵$A$和一个可训练的参数$\theta^{(k)}$决定。$H^{(k-1)}$是由前一个message passing过程生成的node embedding。当$k = 1$时,第一个GNN的输入为$H^{(0)}$是原始的节点特征$H^{(0)} = F$。 GNN的一个主要目标是设计一个Message Passage函数$M$,GCN(kipf.2016)是一种流行的GNN,$M$的实现方式是将线性变换和ReLU非线性激活结合起来: $$ H^{(k)}=M\left(A, H^{(k-1)} ; W^{(k)}\right)=\operatorname{ReLU}\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(k-1)} W^{(k-1)}\right) $$ 其中,$\tilde{A} = A+I$是一个加上自环的邻接矩阵,$\tilde{D}=\sum_{j} \tilde{A}_{i j}$是$\tilde{A}$的度矩阵,$W^{(k)} \in \mathbb{R}^{d \times d}$是一个可训练的权重矩阵,$W$与节点个数以及每个节点的度无关,可以看做一个特征增强矩阵,用来规定GCN的输出维度。 ...