Entwicklung & Code
Künstliche Neuronale Netze im Überblick 10: Graphneuronale Netzwerke
Neuronale Netze sind der Motor vieler Anwendungen in KI und GenAI. Diese Artikelserie gibt einen Einblick in die einzelnen Elemente. Der zehnte Teil der Serie stellt graphneuronale Netze vor.
Prof. Dr. Michael Stal arbeitet seit 1991 bei Siemens Technology. Seine Forschungsschwerpunkte umfassen Softwarearchitekturen für große komplexe Systeme (Verteilte Systeme, Cloud Computing, IIoT), Eingebettte Systeme und Künstliche Intelligenz.
Er berät Geschäftsbereiche in Softwarearchitekturfragen und ist für die Architekturausbildung der Senior-Software-Architekten bei Siemens verantwortlich.
Graphneuronale Netzwerke (Graph Neural Networks, GNN) erweitern das Konzept der neuronalen Berechnung von regulären Gitternetzen auf unregelmäßige Graphstrukturen und ermöglichen so Deep Learning für Daten, deren Beziehungen sich am besten durch Knoten und Kanten ausdrücken lassen. Ein Graph G besteht aus einer Menge von Knoten V und einer Menge von Kanten E zwischen diesen Knoten. Jeder Knoten i trägt einen Merkmalsvektor xᵢ, und das Muster der Kanten codiert, wie Informationen zwischen den Knoten fließen sollen.
Im Zentrum vieler GNNs steht ein Paradigma der Nachrichtenübermittlung. In jeder Schicht des Netzwerks sammelt jeder Knoten Informationen von seinen Nachbarn (aggregiert sie), transformiert diese aggregierte Nachricht und aktualisiert dann seine eigene Merkmalsdarstellung. Durch das Stapeln mehrerer Schichten können Knoten Informationen aus immer größeren Nachbarschaften einbeziehen.
Eine der einfachsten und am weitesten verbreiteten Formen der Graphfaltung ist das Graph Convolutional Network (GCN). Angenommen, wir haben N Knoten mit jeweils einem d-dimensionalen Merkmalsvektor, die in einer Matrix X ∈ ℝᴺˣᵈ gesammelt sind. Sei A ∈ ℝᴺˣᴺ die Adjazenzmatrix des Graphen, wobei Aᵢⱼ = 1 ist, wenn eine Kante vom Knoten i zum Knoten j besteht, und sonst Null. Um die eigenen Merkmale jedes Knotens einzubeziehen, addieren wir die Identitätsmatrix I zu A, wodurch à = A + I entsteht. Anschließend berechnen wir die Gradmatrix D̃, wobei D̃ᵢᵢ = Σⱼ Ãᵢⱼ ist. Eine einzelne GCN-Schicht transformiert X nach folgender Regel in neue Merkmale H ∈ ℝᴺˣᵈ′:
H = σ( D̃⁻½ · Ã · D̃⁻½ · X · W )
Hier ist W ∈ ℝᵈˣᵈ′ eine lernbare Gewichtungsmatrix und σ eine elementweise Nichtlinearität wie ReLU. Die symmetrische Normalisierung D̃⁻½ Ã D̃⁻½ stellt sicher, dass Nachrichten von Knoten mit hohem Grad diejenigen von Knoten mit niedrigem Grad nicht überlagern.
Nachfolgend steht eine minimale PyTorch-Implementierung einer einzelnen GCN-Schicht. Ich erkläre jeden Schritt ausführlich.
In diesem Code ist die Adjazenzmatrix ein dichter Tensor der Form (N, N). Zunächst fügen wir Selbstschleifen hinzu, indem wir mit der Identität summieren. Anschließend berechnen wir den Grad jedes Knotens, indem wir die Zeilen von à summieren. Durch Ziehen der inversen Quadratwurzel dieser Grade und Bilden einer Diagonalmatrix erhalten wir D̃⁻½. Multipliziert man D̃⁻½ mit beiden Seiten von Ã, erhält man die normalisierte Adjazenz. Die Knotenmerkmale X werden mit der Gewichtungsmatrix W multipliziert, um sie in einen neuen Merkmalsraum zu transformieren, und schließlich mischt die normalisierte Adjazenzmatrix diese transformierten Merkmale entsprechend der Graphstruktur. Eine ReLU-Aktivierung fügt Nichtlinearität hinzu.
import torch
import torch.nn as nn
class GCNLayer(nn.Module):
def __init__(self, in_features, out_features):
super(GCNLayer, self).__init__()
# Gewichtungsmatrix W der Form (in_features, out_features)
self.weight = nn.Parameter(torch.randn(in_features, out_features))
def forward(self, X, adjacency):
# Selbstschleifen hinzufügen, indem die Identitätsmatrix zur Adjazenz hinzugefügt wird
A_tilde = adjacency + torch.eye(adjacency.size(0), device=adjacency.device)
# Berechne die Gradmatrix von A_tilde
degrees = A_tilde.sum(dim=1)
# D_tilde^(-1/2) berechnen
D_inv_sqrt = torch.diag(degrees.pow(-0.5))
# Symmetrische Normalisierung: D^(-1/2) * A_tilde * D^(-1/2)
A_normalized = D_inv_sqrt @ A_tilde @ D_inv_sqrt
# Lineare Transformation: X * W
support = X @ self.weight
# Nachrichten weiterleiten: A_normalized * support
out = A_normalized @ support
# Nichtlinearität anwenden
return torch.relu(out)
Durch Stapeln mehrerer solcher Schichten verbessern sich die Ausgaben, zum Beispiel:
class SimpleGCN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(SimpleGCN, self).__init__()
self.gcn1 = GCNLayer(input_dim, hidden_dim)
self.gcn2 = GCNLayer(hidden_dim, output_dim)
def forward(self, X, adjacency):
h1 = self.gcn1(X, adjacency)
# h1 dient als Eingabe für die nächste Schicht
h2 = self.gcn2(h1, adjacency)
return h2
Wir ermöglichen jedem Knoten, Informationen von Knoten zu sammeln, die bis zu zwei Hops entfernt sind. Für eine Klassifizierungsaufgabe, bei der jeder Knoten i ein Label yᵢ in {1,…,C} hat, können wir die endgültigen Ausgaben H ∈ ℝᴺˣᶜ mit einem Kreuzentropieverlust paaren, genau wie bei einer gewöhnlichen Klassifizierung, und durch Gradientenabstieg trainieren.
Über GCNs hinaus berechnen aufmerksamkeitsbasierte Graphennetzwerke kantenspezifische Gewichte, die einem Knoten mitteilen, wie stark er sich auf jeden Nachbarn konzentrieren soll. Das Graph Attention Network (GAT) führt lernbare Aufmerksamkeitskoeffizienten αᵢⱼ ein, die wie folgt definiert sind:
eᵢⱼ = LeakyReLU( aᵀ · [ W·xᵢ ∥ W·xⱼ ] )
αᵢⱼ = softmax_j( eᵢⱼ )
wobei ∥ die Verkettung bezeichnet, a ∈ ℝ²ᵈ′ ein lernbarer Vektor ist und softmax_j über alle Nachbarn von i normalisiert. Die Knotenaktualisierung lautet dann:
hᵢ′ = σ( Σⱼ αᵢⱼ · W·xⱼ ).
Die Implementierung einer GAT-Schicht von Grund auf folgt dem gleichen Muster der Nachrichtenübermittlung, erfordert jedoch die Berechnung von eᵢⱼ für jede Kante und anschließende Normalisierung. Bei großen Graphen verwendet man spärliche Darstellungen oder Bibliotheken wie PyTorch Geometric, um die Effizienz zu gewährleisten.
Graph Neural Networks eröffnen Anwendungsmöglichkeiten in der Chemie, der Analyse sozialer Netzwerke, Empfehlungssystemen und der kombinatorischen Optimierung. Sie bieten eine prinzipielle Möglichkeit, Darstellungen strukturierter Daten zu lernen, bei denen der Kontext jeder Entität durch ihre Beziehungen definiert ist.
Der nächste Teil der Serie beschäftigt sich mit Transformern, einer neuronalen Architektur, die vollständig auf Aufmerksamkeitsmechanismen basiert und ohne Rekursion und Faltung auskommt, um Sequenzen parallel zu verarbeiten.
(rme)