
Temporal GNN Lightweight Framework


TGLite - A Framework for Temporal GNNs

TGLite is a lightweight framework that provides core abstractions and building blocks for practitioners and researchers to implement efficient TGNN models. TGNNs, or Temporal Graph Neural Networks, learn node embeddings for graphs that change dynamically over time by jointly aggregating structural and temporal information from neighboring nodes. TGLite represents the temporal graph dependencies involved in aggregating from neighbors with an abstraction called a TBlock, which explicitly captures temporal details such as edge timestamps, and pairs it with composable operators and optimizations. Compared to prior art, TGLite trains up to 3x faster than the TGL framework.
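To make the idea of a temporal neighborhood concrete, the following plain-Python sketch gathers, for each target node and query time, the most recent neighbors that interacted with it strictly before that time, and records the resulting time deltas. The ToyBlock class and sample_recent function are hypothetical illustrations of the concept only; they are not TGLite's TBlock or TSampler API, and the "strictly before" cutoff is an assumption of this sketch.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

# Hypothetical toy structures for illustration only; NOT TGLite's TBlock/TSampler API.
Edge = Tuple[int, int, float]  # (src, dst, timestamp)


@dataclass
class ToyBlock:
    dst_nodes: List[int]     # target nodes whose embeddings are needed
    dst_times: List[float]   # query time of each target node
    nbr_index: List[int] = field(default_factory=list)    # which target each neighbor belongs to
    nbr_nodes: List[int] = field(default_factory=list)    # sampled neighbor node ids
    nbr_times: List[float] = field(default_factory=list)  # timestamp of the connecting edge

    def time_deltas(self) -> List[float]:
        # Elapsed time between each target's query time and its neighbor's edge timestamp.
        return [self.dst_times[i] - t for i, t in zip(self.nbr_index, self.nbr_times)]


def sample_recent(edges: List[Edge], blk: ToyBlock, k: int) -> ToyBlock:
    """Attach up to k most recent neighbors seen strictly before each target's query time."""
    for i, (node, t) in enumerate(zip(blk.dst_nodes, blk.dst_times)):
        past = []
        for src, dst, e_t in edges:
            if e_t >= t:
                continue  # future (or simultaneous) edges must not leak into the embedding
            if src == node:
                past.append((e_t, dst))
            elif dst == node:
                past.append((e_t, src))
        past.sort(reverse=True)  # most recent first
        for e_t, other in past[:k]:
            blk.nbr_index.append(i)
            blk.nbr_nodes.append(other)
            blk.nbr_times.append(e_t)
    return blk


edges = [(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (0, 3, 4.0)]
blk = sample_recent(edges, ToyBlock(dst_nodes=[0, 2], dst_times=[5.0, 3.5]), k=2)
print(blk.nbr_nodes)      # [3, 2, 1, 0]
print(blk.time_deltas())  # [1.0, 3.0, 0.5, 1.5]
```

A practical implementation would keep per-node adjacency lists sorted by timestamp instead of scanning every edge for every query; the linear scan here only keeps the logic easy to follow.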

End-to-end training epoch time comparison on an Nvidia A100 GPU.

Installation

See our documentation for instructions on how to install the TGLite binaries, as well as examples and references for supported functionality. To install from source or for local development, see the Building from source section, which also explains how to run the examples.

Getting Started

TGLite is currently designed to be used with PyTorch as a training backend, typically with GPU devices. A TGNN model can be defined and trained in the usual way using PyTorch, with the computations constructed using a mix of PyTorch functions and operators/optimizations from TGLite. Below is a simple example (not a real network architecture, just for demonstration purposes):

import torch
import tglite as tg

class TGNN(torch.nn.Module):
    def __init__(self, ctx: tg.TContext, dim_node=100, dim_time=100):
        super().__init__()
        self.ctx = ctx
        self.linear = torch.nn.Linear(dim_node + dim_time, dim_node)
        self.sampler = tg.TSampler(num_nbrs=10, strategy='recent')
        self.encoder = tg.nn.TimeEncode(dim_time)

    def forward(self, batch: tg.TBatch):
        blk = batch.block(self.ctx)
        blk = tg.op.dedup(blk)
        blk = self.sampler.sample(blk)
        blk.srcdata['h'] = blk.srcfeat()
        return tg.op.aggregate(blk, self.compute, key='h')

    def compute(self, blk: tg.TBlock):
        feats = self.encoder(blk.time_deltas())
        feats = torch.cat([blk.srcdata['h'], feats], dim=1)
        embeds = self.linear(feats)
        embeds = tg.op.edge_reduce(blk, embeds, op='sum')
        return torch.relu(embeds)

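graph = tg.from_csv(...)  # load a temporal graph from CSV (arguments elided)
ctx = tg.TContext(graph)  # wrap the graph in a TGLite context
model = TGNN(ctx)
train(model)  # user-defined training loop, not shown here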

The example model first constructs the graph dependencies (a TBlock) for the nodes in the current batch of edges. It then applies the dedup() optimization before sampling the 10 most recent neighbors of each node. Node embeddings are computed by concatenating node features with time encodings, applying a linear layer, summing across neighbors, and applying a ReLU. More complex computations and aggregations, such as the temporal self-attention commonly used in TGNNs, can be built from the same building blocks.
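The stages described above can be mimicked in plain Python to see what each one does to the data. This is a minimal sketch under stated assumptions: dedup is modeled as collapsing repeated (node, time) targets while keeping an inverse index to scatter results back; the time encoding is assumed to be cos(w * dt) with fixed frequencies, a common choice in TGNNs that may differ from tg.nn.TimeEncode; and the linear layer uses hand-picked weights. None of the functions below (dedup, time_encode, linear, aggregate_sum_relu) are TGLite APIs, and TGLite performs these steps on tensors rather than Python lists.

```python
import math
from typing import Dict, List, Sequence, Tuple

# Hypothetical helpers that mimic the example's stages in plain Python; not TGLite APIs.


def dedup(targets: Sequence[Tuple[int, float]]) -> Tuple[List[Tuple[int, float]], List[int]]:
    """Collapse repeated (node, time) targets; return the unique ones plus an inverse index."""
    unique: List[Tuple[int, float]] = []
    position: Dict[Tuple[int, float], int] = {}
    inverse: List[int] = []
    for key in targets:
        if key not in position:
            position[key] = len(unique)
            unique.append(key)
        inverse.append(position[key])
    return unique, inverse


def time_encode(dt: float, freqs: Sequence[float]) -> List[float]:
    """Assumed cosine time encoding: one feature cos(w * dt) per frequency w."""
    return [math.cos(w * dt) for w in freqs]


def linear(x: Sequence[float], weight: Sequence[Sequence[float]], bias: Sequence[float]) -> List[float]:
    """Compute W x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def aggregate_sum_relu(
    num_dst: int,
    nbr_index: Sequence[int],
    nbr_feats: Sequence[Sequence[float]],
    nbr_dts: Sequence[float],
    freqs: Sequence[float],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[List[float]]:
    """Per neighbor: concat(feature, time encoding) -> linear; then sum per target and ReLU."""
    out = [[0.0] * len(bias) for _ in range(num_dst)]
    for i, h, dt in zip(nbr_index, nbr_feats, nbr_dts):
        msg = linear(list(h) + time_encode(dt, freqs), weight, bias)
        out[i] = [acc + m for acc, m in zip(out[i], msg)]
    return [[max(0.0, v) for v in row] for row in out]


# The batch asks for node 7 at t=5.0 twice; after dedup it is computed only once.
uniq, inv = dedup([(7, 5.0), (9, 3.0), (7, 5.0)])
print(uniq, inv)  # [(7, 5.0), (9, 3.0)] [0, 1, 0]
emb = aggregate_sum_relu(
    num_dst=len(uniq),
    nbr_index=[0, 0, 1],              # unique target that each sampled neighbor belongs to
    nbr_feats=[[1.0], [2.0], [-3.0]],
    nbr_dts=[0.0, 0.0, 0.0],
    freqs=[1.0],
    weight=[[1.0, 0.5]],              # one output from [feature, time encoding]
    bias=[0.0],
)
print([emb[j] for j in inv])  # scatter back to original order: [[4.0], [0.0], [4.0]]
```

The saving from deduplication grows with the number of repeated (node, time) pairs in a batch: neighbor sampling and the per-neighbor computation run once per unique target, and the inverse index restores one embedding per original request.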

Publication

If you find TGLite useful, please consider citing the following paper:

@inproceedings{wang2024tglite,
  author = {Wang, Yufeng and Mendis, Charith},
  title = {TGLite: A Lightweight Programming Framework for Continuous-Time Temporal Graph Neural Networks},
  year = {2024},
  booktitle = {Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2},
  doi = {10.1145/3620665.3640414}
}


Download files


Source Distribution

tglite-0.0.4.tar.gz (35.7 kB)


File details

Details for the file tglite-0.0.4.tar.gz.

File metadata

  • Download URL: tglite-0.0.4.tar.gz
  • Upload date:
  • Size: 35.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.16

File hashes

Hashes for tglite-0.0.4.tar.gz

  Algorithm     Hash digest
  SHA256        6c626f4f26b75e7e882521f743186a65e15a1525aa46852c818f0ab449ee8744
  MD5           870406409bd6edba2cc6c68dff3cae43
  BLAKE2b-256   ff53bba05766b833b6672a1773c68b1247685d6c4c606a79f6a2c170b46268d7
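To check a downloaded archive against the SHA256 digest listed above, Python's standard hashlib module is enough. This is a minimal sketch; the helper names sha256_of and verify are illustrative, and it assumes the source distribution was saved locally as tglite-0.0.4.tar.gz.

```python
import hashlib

# Published SHA256 digest for tglite-0.0.4.tar.gz (copied from the table above).
EXPECTED_SHA256 = "6c626f4f26b75e7e882521f743186a65e15a1525aa46852c818f0ab449ee8744"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash the file in fixed-size chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """Return True only if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For example, verify("tglite-0.0.4.tar.gz") returns True only when the local file matches the published digest.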

