Accurate flood forecasting remains a critical challenge for water-resource management, as it demands simultaneous modeling of local, time-varying runoff drivers (e.g., rainfall-induced peaks, baseflow trends) and complex spatial interactions across a river network.
Traditional data-driven approaches, such as convolutional networks and sequence-based models, ignore the topology of the region. Graph Neural Networks (GNNs), in contrast, propagate information directly along the river network, making them well suited to learning hydrological routing. However, state-of-the-art GNN-based flood prediction models still collapse pixels into coarse catchment polygons because training cost explodes with graph size and resolution. Furthermore, most existing methods treat spatial and temporal dependencies separately, applying GNNs only to spatial graphs or transformers only to temporal sequences, and thus fail to capture the spatiotemporal interactions critical for accurate flood prediction.
To address these limitations, we introduce a heterogeneous basin graph that represents every land and river pixel as a node, connected by both physical hydrological flow directions and inter-catchment relationships. We also propose HydroGAT, a novel spatiotemporal network that adaptively learns local temporal importance and the most influential upstream locations.
Evaluated on two Midwestern US basins against five baseline architectures, our model achieves higher NSE (up to 0.97), improved KGE (up to 0.96), and low bias (PBIAS within ±5%) in hourly discharge prediction, while offering interpretable attention maps that reveal sparse, structured intercatchment influences. To support high-resolution basin-scale training, we develop a distributed data-parallel pipeline that scales efficiently to 64 NVIDIA A100 GPUs on the NERSC Perlmutter supercomputer, demonstrating up to 15× speedup across nodes.
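For reference, NSE, KGE, and PBIAS are standard hydrological skill scores. The sketch below is not from this repository; it shows the common definitions (KGE in its 2009 form) so the numbers above have a concrete meaning. Exact variants differ across papers, so treat it as illustrative:

```python
import numpy as np

def nse(obs, sim):
    """Nash-Sutcliffe Efficiency: 1 is perfect; 0 is no better than the mean of obs."""
    obs, sim = np.asarray(obs, float), np.asarray(sim, float)
    return 1.0 - np.sum((obs - sim) ** 2) / np.sum((obs - obs.mean()) ** 2)

def kge(obs, sim):
    """Kling-Gupta Efficiency (2009 form): combines correlation, variability, and bias."""
    obs, sim = np.asarray(obs, float), np.asarray(sim, float)
    r = np.corrcoef(obs, sim)[0, 1]
    alpha = sim.std() / obs.std()   # variability ratio
    beta = sim.mean() / obs.mean()  # bias ratio
    return 1.0 - np.sqrt((r - 1.0) ** 2 + (alpha - 1.0) ** 2 + (beta - 1.0) ** 2)

def pbias(obs, sim):
    """Percent bias: 0 is unbiased (sign conventions vary across papers)."""
    obs, sim = np.asarray(obs, float), np.asarray(sim, float)
    return 100.0 * np.sum(sim - obs) / np.sum(obs)
```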
If you use HydroGAT, please cite:
```bibtex
@inproceedings{sarkar2025hydrogat,
  author    = {Sarkar, Aishwarya and Hakimi, Autrin and Chen, Xiaoqiong and Huang, Hai and Lu, Chaoqun and Demir, Ibrahim and Jannesari, Ali},
  title     = {HydroGAT: Distributed Heterogeneous Graph Attention Transformer for Spatiotemporal Flood Prediction},
  booktitle = {Proceedings of the ACM SIGSPATIAL International Conference on Advances in Geographic Information Systems (SIGSPATIAL '25)},
  year      = {2025},
  address   = {Minneapolis, MN, USA},
  publisher = {Association for Computing Machinery},
  doi       = {10.1145/3748636.3764172},
  url       = {https://doi.org/10.1145/3748636.3764172},
  note      = {Licensed under CC BY 4.0}
}
```

Requirements:

- Python 3.10 or higher
- CUDA-enabled GPU (for training and inference)
- CUDA toolkit 11.7 or compatible version (11.8, 12.1, etc.)
1. Install PyTorch (GPU version)

First, install PyTorch with CUDA support. Check PyTorch's official website for installation commands matching your CUDA version. Replace `cu117` in the URL with your CUDA version (e.g., `cu118`, `cu121`).

Example for CUDA 11.7:

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
```

Example for CUDA 11.8:

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
```
2. Install DGL (Deep Graph Library)

Install DGL compatible with your CUDA version. Adjust the CUDA version in the URL to match your PyTorch installation (e.g., `cu117`, `cu118`, `cu121`).

Example for CUDA 11.7:

```bash
pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html
```

Example for CUDA 11.8:

```bash
pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html
```
3. Install other dependencies

Install the remaining dependencies:

```bash
pip install -r requirements.txt
```
Important Notes:

- Do NOT install `torch-cpu`, as it will break DGL compatibility.
- When installing packages via `pip install`, verify that PyTorch and DGL are not being downgraded to CPU-only versions.
- Some packages from conda-forge may have dependencies that require CPU-only PyTorch; check the installation output carefully.
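A quick sanity check after installation helps catch a CPU-only wheel early. The snippet below uses only standard PyTorch/DGL calls:

```python
import torch
import dgl

# A CPU-only PyTorch build prints False here.
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# Moving a toy graph to the GPU fails immediately on a CPU-only DGL wheel.
g = dgl.graph(([0, 1], [1, 2]))  # edges 0->1 and 1->2
g = g.to("cuda")
print("dgl", dgl.__version__, "| graph device:", g.device)
```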
The following packages are required (see `requirements.txt` for versions):
- Core DL: `torch` (GPU), `dgl` (GPU)
- Scientific: `numpy`, `scipy`, `scikit-learn`
- Data: `pandas`, `xarray`, `netCDF4`
- Geospatial: `geopandas`, `rasterio`, `rioxarray`, `pyproj`
- Utilities: `Pillow`, `matplotlib`, `seaborn`, `networkx`, `tqdm`
Before training, you need to process your raw data (precipitation, discharge, flow direction, DEM, etc.) into a graph format. Use the `graph_processor.py` script:
```bash
python preprocess/graph_processor.py \
  --basin_name "cedar-river-basin" \
  --precip_raw_dir "data/cedar-river-basin/NCPrecipitation_fixed" \
  --flow_dir_file "data/cedar-river-basin/BasinAttribution/flow_dir.nc" \
  --target_node_file "data/cedar-river-basin/BasinAttribution/CedarRiverBasin_Guage.shp" \
  --discharge_dir "data/cedar-river-basin/Discharge_NA_Fill" \
  --discharge_raw_dir "data/discharge/data_time_series" \
  --distance_file "data/cedar-river-basin/BasinAttribution/distance.nc" \
  --dem_file "data/cedar-river-basin/BasinAttribution/dem.nc" \
  --catchment_relationship_file "data/discharge/catchment_relationship.csv" \
  --output_dir "data/cedar-river-basin/graph-with-catchment-relationship" \
  --months "5,6,7,8,9" \
  --years "2012,2013,2014,2015,2016,2017,2018" \
  --add_catchment_relationship \
  --node_coordinates \
  --update_output
```

Required Arguments:
- `--basin_name`: Name of the basin
- `--precip_raw_dir`: Directory containing raw precipitation NetCDF files
- `--flow_dir_file`: Flow direction NetCDF file path
- `--target_node_file`: Target node shapefile path
- `--discharge_dir`: Directory containing processed discharge CSV files
- `--discharge_raw_dir`: Directory containing raw discharge CSV files
- `--distance_file`: Distance data NetCDF file path
- `--dem_file`: Digital Elevation Model NetCDF file path
- `--catchment_relationship_file`: Catchment relationship CSV file path
- `--output_dir`: Output directory for the processed graph
Optional Arguments:
- `--months`: Comma-separated months (default: "5,6,7,8,9")
- `--years`: Comma-separated years (default: "2012,2013,2014,2015,2016,2017,2018")
- `--debug`: Enable debug mode (generates debug plots)
- `--clip`: Clip precipitation data using flow direction mask
- `--flip`: Flip data for coordinate system
- `--add_catchment_relationship`: Add catchment relationship edges
- `--use_distance`: Use distance information for edge weights
- `--connect_to_main_component`: Connect all nodes to main component
- `--node_coordinates`: Include node coordinates (default: True)
- `--update_output`: Write output files (if not set, runs in dry-run mode)
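Once preprocessing finishes, it can be useful to inspect the generated graph before training. If the processor stores the graph with DGL's standard serialization (an assumption; check `graph_processor.py` for the actual file name and format), a minimal inspection looks like this:

```python
import dgl

# Assumes the processor writes a DGL binary graph; "graph.bin" is a
# placeholder - check --output_dir for the actual file name.
graphs, _ = dgl.load_graphs(
    "data/cedar-river-basin/graph-with-catchment-relationship/graph.bin"
)
g = graphs[0]
print(g)                   # node/edge counts and feature schemes
print(g.ntypes, g.etypes)  # node/edge types of the heterogeneous basin graph
```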
For training on a single GPU:
```bash
python src/dist_train.py \
  --launcher slurm \
  --mode train \
  --graph_dir "data/cedar-river-basin/graph-with-catchment-relationship" \
  --checkpoint_dir "checkpoints/cedar-river-basin" \
  --model hydrogat \
  --basin_name "cedar-river-basin" \
  --epochs 100 \
  --batch_size 8 \
  --input_steps 72 \
  --output_steps 72 \
  --tgcn_hidden_size 32 \
  --s_heads 2 \
  --t_heads 2 \
  --catchment_relationship
```

HydroGAT supports distributed training across multiple GPUs and nodes using PyTorch's DistributedDataParallel (DDP). The code supports SLURM job schedulers.
Create a SLURM job script (`train.sh`):
```bash
#!/bin/bash
#SBATCH --nodes=1              # Number of nodes
#SBATCH --gpus-per-node=4      # GPUs per node
#SBATCH --constraint=gpu
#SBATCH --time=08:00:00
#SBATCH --job-name=hydrogat

# Set environment variables
NNODES=${SLURM_NNODES:-1}
NGPUS_PER_NODE=4
WORLD_SIZE=$((NNODES * NGPUS_PER_NODE))
echo "Nodes=$NNODES GPUs/node=$NGPUS_PER_NODE World size=$WORLD_SIZE"

MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n1)
MASTER_PORT=29501
export MASTER_ADDR=$MASTER_ADDR
export MASTER_PORT=$MASTER_PORT
export WORLD_SIZE=$WORLD_SIZE

nvidia-smi

# Activate your environment
# source activate your_env  # or conda/micromamba activate

# Run distributed training
srun --ntasks=${WORLD_SIZE} \
     --ntasks-per-node=${NGPUS_PER_NODE} \
     --gres=gpu:${NGPUS_PER_NODE} \
     --cpu-bind=cores \
     python src/dist_train.py \
       --launcher slurm \
       --mode train \
       --graph_dir "data/cedar-river-basin/graph-with-catchment-relationship" \
       --checkpoint_dir "checkpoints/cedar-river-basin" \
       --model hydrogat \
       --basin_name "cedar-river-basin" \
       --epochs 100 \
       --batch_size 8 \
       --input_steps 72 \
       --output_steps 72 \
       --tgcn_hidden_size 32 \
       --s_heads 2 \
       --t_heads 2 \
       --catchment_relationship
```

Submit with:

```bash
sbatch train.sh
```

Key Distributed Training Parameters:
- `--launcher`: Must be `slurm` to set up the distributed environment correctly
- `WORLD_SIZE`: Total number of GPUs (nodes × GPUs per node)
- `MASTER_ADDR` and `MASTER_PORT`: Used for inter-node communication
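For context, a SLURM launcher conventionally derives each process's rank from SLURM environment variables and initializes the NCCL process group from `MASTER_ADDR`/`MASTER_PORT`. The sketch below illustrates that mechanism; it is not a copy of what `src/dist_train.py` does:

```python
import os
import torch
import torch.distributed as dist

# Illustrative SLURM-based DDP bootstrap; see src/dist_train.py for what
# --launcher slurm actually does in this repository.
rank = int(os.environ["SLURM_PROCID"])      # one srun task per GPU
world_size = int(os.environ["WORLD_SIZE"])  # exported in train.sh
local_rank = rank % torch.cuda.device_count()

torch.cuda.set_device(local_rank)
dist.init_process_group(
    backend="nccl",        # GPU collectives
    init_method="env://",  # reads MASTER_ADDR / MASTER_PORT from the environment
    rank=rank,
    world_size=world_size,
)
```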
Note: According to the HydroGAT paper, when using `--model hydrogat`, the learning rate is automatically set to 0.01 (AdamW optimizer) with weight decay 1e-4 and scheduler patience 5. The default hyperparameters for HydroGAT are: 72-hour input/output windows, 32 hidden features, 2 attention heads per module (spatial and temporal), and batch size 8.
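In code, those defaults correspond to a configuration like the following; the model stand-in and the scheduler class are assumptions (the paper states only the patience):

```python
import torch

model = torch.nn.Linear(32, 32)  # placeholder for the actual HydroGAT model

# The stated hydrogat defaults: AdamW, lr=0.01, weight decay 1e-4, and an LR
# scheduler with patience 5 (the scheduler class here is an assumption).
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
```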
To evaluate a trained model, you need to provide the checkpoint path and ensure all model parameters match those used during training.
For evaluating on a single GPU:
```bash
python src/dist_train.py \
  --launcher slurm \
  --mode evaluate \
  --graph_dir "data/cedar-river-basin/graph-with-catchment-relationship" \
  --checkpoint "checkpoints/cedar-river-basin/cedar-river-basin_hydrogat_epoch100_lr0.001_patience5_input72_output12_batch8_controlTTSS_y2012_2018_m5-9.pt" \
  --model hydrogat \
  --basin_name "cedar-river-basin" \
  --input_steps 72 \
  --output_steps 72 \
  --tgcn_hidden_size 32 \
  --s_heads 2 \
  --t_heads 2 \
  --catchment_relationship
```

Required Arguments for Evaluation:
- `--mode evaluate`: Set mode to evaluation
- `--checkpoint`: Path to the trained model checkpoint file
- `--graph_dir`: Directory containing the processed graph file
- `--basin_name`: Name of the basin (must match training configuration)
- `--model`: Model name (must match training configuration, e.g., `hydrogat`)
- `--input_steps`: Number of input timesteps (must match training)
- `--output_steps`: Number of output timesteps (must match training)
- `--tgcn_hidden_size`: Hidden size for temporal GCN (must match training)
- `--s_heads`: Number of spatial attention heads (must match training)
- `--t_heads`: Number of temporal attention heads (must match training)
Evaluation Flags:
- `--catchment_relationship`: Include if the model was trained with catchment relationships
- `--target_only`: Include if the model was trained with target-only nodes
- `--use_distance`: Include if the model was trained with distance information
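If you are unsure which settings a checkpoint was trained with, the file name encodes most of them (input/output steps, batch size, years, months). You can also open the file directly; the layout of the saved object depends on how `dist_train.py` saves it, so treat the keys below as unknowns to inspect rather than guaranteed fields:

```python
import torch

# Load on CPU so no GPU is needed just to inspect the checkpoint.
ckpt = torch.load(
    "checkpoints/cedar-river-basin/cedar-river-basin_hydrogat_epoch100_lr0.001_"
    "patience5_input72_output12_batch8_controlTTSS_y2012_2018_m5-9.pt",
    map_location="cpu",
)
# A bare state_dict prints parameter names; a wrapper dict prints its keys.
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])
```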
The code supports multiple models:
- `hydrogat`: HydroGAT (proposed model)
- `stgcnwave`: STGCN-Wave baseline
- `rgcn`: Recurrent GCN baseline
- `dcrnn`: DCRNN baseline
- `graphwavenet`: GraphWaveNet baseline
- `gcrnn`: GCRNN baseline
Use `--model <model_name>` to select the model.
For a complete list of training and evaluation options:
```bash
python src/dist_train.py --help
```

Key options include:
- `--input_steps`: Number of input timesteps (default: 72)
- `--output_steps`: Number of output timesteps to predict (default: 12; the paper uses 72)
- `--tgcn_hidden_size`: Hidden size for temporal GCN (default: 32)
- `--s_heads`: Number of spatial attention heads for GAT (default: 2)
- `--t_heads`: Number of temporal attention heads for Transformer (default: 2)
- `--learning_rate`: Learning rate (default: 0.001; automatically set to 0.01 for the hydrogat model)
- `--batch_size`: Batch size (default: 8)
- `--epochs`: Number of training epochs (default: 30; we use 100)
- `--catchment_relationship`: Use catchment relationship edges
- `--target_only`: Only use target nodes
- `--use_distance`: Use distance information for edges