Overview
During my undergraduate studies at IIT Bombay, I completed a research internship at the National University of Singapore with Prof. Bryan Hooi’s group. The focus was temporal graph representation learning — specifically, using attention mechanisms to capture how relationships evolve over time in dynamic networks, and using that representation to predict which edges will form next.
This was my first exposure to cutting-edge ML research, and it shaped how I think about representation learning and temporal modeling.
The Problem
Most graph ML work at the time treated graphs as static — you had a fixed set of nodes and edges, and learned embeddings from that structure. But real-world networks evolve. Social connections form and dissolve. Financial transactions create temporal patterns. Information spreads through a network in time-ordered sequences.
For link prediction — predicting which connections will form in the future — the temporal dynamics are not noise, they’re signal. A model that ignores when interactions happened is throwing away critical information.
Why It Mattered
Temporal link prediction has applications in social network recommendation, fraud detection (identifying unusual connection patterns), academic collaboration prediction, and dynamic knowledge graph completion. The research contributed to the broader question of how to represent time-evolving graphs in a way that is useful for downstream prediction tasks.
Data & Inputs
- College Messages dataset: A temporal graph of direct messages exchanged between students at a US college, with timestamps. The task is to predict which node pairs will interact in the future.
- Additional temporal graph datasets, used to check that the approach generalized beyond a single network
- Node features where available; structural features derived from graph topology otherwise
The temporal nature of the data was the key feature — the timestamp on each interaction was the signal we were trying to leverage.
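For concreteness, the raw data is just a timestamped edge list. Below is a minimal loading sketch, assuming a whitespace-separated `src dst unix_timestamp` file as in the SNAP CollegeMsg release of this dataset; the filename is illustrative.

```python
import pandas as pd

# Illustrative filename; the file is assumed to hold one
# "src dst unix_timestamp" triple per line.
edges = pd.read_csv("CollegeMsg.txt", sep=" ", names=["src", "dst", "ts"])

# Sort chronologically: temporal methods consume interactions in time order.
edges = edges.sort_values("ts").reset_index(drop=True)

print(edges.head())
print(edges[["src", "dst"]].stack().nunique(), "nodes,", len(edges), "interactions")
```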
Approach
I implemented and benchmarked multiple temporal graph representation methods:
- node2vec: Static random-walk embeddings; the non-temporal baseline
- TMF (Temporal Matrix Factorization): Time-aware matrix factorization
- CTDNE (Continuous-Time Dynamic Network Embeddings): Random walks biased by temporal information (sketched after this list)
- BANE (Binarized Attributed Network Embedding): Attribute-aware embeddings
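To make the temporal bias in CTDNE concrete, here is a sketch of a time-respecting random walk: each hop may only follow interactions that happened at or after the previous one. This captures the core idea only; the actual CTDNE sampling distributions are more elaborate, and all names here are illustrative.

```python
import random
from collections import defaultdict

def temporal_random_walk(adj, start, walk_len):
    """One time-respecting walk: hops never go backwards in time.

    adj maps node -> list of (neighbor, timestamp) interactions.
    A sketch of the idea behind CTDNE, not the paper's exact sampler.
    """
    walk, node, t = [start], start, float("-inf")
    for _ in range(walk_len - 1):
        # Only interactions at or after the current time are valid next hops.
        candidates = [(nbr, ts) for nbr, ts in adj[node] if ts >= t]
        if not candidates:
            break
        node, t = random.choice(candidates)
        walk.append(node)
    return walk

# Toy usage: adjacency built from (src, dst, ts) triples.
adj = defaultdict(list)
for src, dst, ts in [(0, 1, 10), (1, 2, 20), (1, 3, 5), (2, 0, 30)]:
    adj[src].append((dst, ts))
    adj[dst].append((src, ts))  # treat messages as undirected for walking

print(temporal_random_walk(adj, start=0, walk_len=4))  # e.g. [0, 1, 2, 0]
```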
The key contribution was developing a temporal attention model that could:
- Encode the sequence of interactions a node has participated in
- Weight recent interactions more heavily using an attention mechanism
- Combine temporal interaction history with structural graph features for link prediction
The attention mechanism was the critical design choice — it let the model learn which past interactions were most predictive for future connections, rather than assuming recency was always most important.
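Here is a minimal sketch of what such a temporal attention layer can look like in PyTorch. The dimensions, the linear encoding of time deltas, and all names are illustrative assumptions, not the project's exact architecture:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TemporalAttention(nn.Module):
    """Attend over a node's past interactions, conditioned on time deltas.

    Illustrative sketch only; the real architecture differed in detail.
    """
    def __init__(self, dim):
        super().__init__()
        self.time_proj = nn.Linear(1, dim)  # embed time-since-interaction
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, node_emb, hist_emb, hist_dt):
        # node_emb: (B, d)    -- embedding of the target node
        # hist_emb: (B, T, d) -- embeddings of its past interaction partners
        # hist_dt:  (B, T)    -- time elapsed since each interaction
        hist = hist_emb + self.time_proj(hist_dt.unsqueeze(-1))
        q = self.query(node_emb).unsqueeze(1)                  # (B, 1, d)
        k, v = self.key(hist), self.value(hist)                # (B, T, d)
        scores = (q @ k.transpose(1, 2)) / k.size(-1) ** 0.5   # (B, 1, T)
        attn = F.softmax(scores, dim=-1)  # learned weights over past events
        return (attn @ v).squeeze(1)      # (B, d) temporal summary

# Combining with structural features to score a candidate link, e.g.:
# score = mlp(torch.cat([temporal_u, temporal_v, structural_uv], dim=-1))
```

Because the attention weights are learned rather than fixed to a recency decay, the model can discover that, for some nodes, older interactions remain predictive.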
Engineering & Implementation
- Implemented in PyTorch from scratch — no existing graph learning libraries had the exact temporal attention architecture I needed
- Negative sampling strategy for training: random negative sampling with hard negative mining near the decision boundary
- Evaluation protocol: strict temporal splitting — training only on edges before a cutoff time, evaluating on edges after
- Hyperparameter tuning via grid search on validation split
- Baseline reimplementations for fair comparison using identical train/test splits
The implementation discipline (same data splits, same negative sampling, same evaluation metric) was essential for meaningful comparisons; the sketch below shows the shape of that protocol.
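This sketch combines a strict temporal split, simple random negative sampling (hard negative mining omitted), and AUC scoring. The cutoff quantile, the toy data, and the stand-in scorer are all illustrative:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def temporal_split(edges, quantile=0.7):
    """Strict temporal split: train on edges at or before the cutoff time,
    test on edges after it. `edges` is an (E, 3) array of (src, dst, ts)
    rows; the 0.7 quantile is illustrative, not the project's split point."""
    cutoff = np.quantile(edges[:, 2], quantile)
    return edges[edges[:, 2] <= cutoff], edges[edges[:, 2] > cutoff]

def sample_negatives(pos_edges, num_nodes, rng):
    """Random negative sampling: corrupt the destination of each positive.
    (Hard negative mining and collision checks are omitted in this sketch.)"""
    neg = pos_edges.copy()
    neg[:, 1] = rng.integers(0, num_nodes, size=len(neg))
    return neg

# Toy usage with a stand-in scorer; a real run would use the trained model.
rng = np.random.default_rng(0)
edges = np.column_stack([rng.integers(0, 50, 200), rng.integers(0, 50, 200),
                         np.sort(rng.integers(0, 10_000, 200))])
train, test = temporal_split(edges)
negs = sample_negatives(test, num_nodes=50, rng=rng)

score = lambda e: rng.random(len(e))          # stand-in for model scores
y_true = np.r_[np.ones(len(test)), np.zeros(len(negs))]
y_pred = np.r_[score(test), score(negs)]
print("AUC:", roc_auc_score(y_true, y_pred))  # ~0.5 for random scores
```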
Results & Impact
- 86% AUC on the College Messages dataset for link prediction
- Outperformed all benchmarked baselines (node2vec, TMF, CTDNE, BANE)
- Demonstrated that temporal attention captures patterns that static methods miss
- Presented findings to Prof. Hooi’s research group
Limitations & What I’d Do Differently
The attention model was trained end-to-end on the link prediction objective; it might have benefited from pre-training on a self-supervised temporal prediction task before fine-tuning.
The model also assumed that every node had enough temporal interaction history to learn a meaningful representation. Cold-start nodes, those newly arrived in the network, would need special handling that this implementation did not provide.
Stack
Python, PyTorch, NetworkX, NumPy, Pandas, Matplotlib