
pytorch_cmspepr

pytorch bindings for optimized knn and aggregation kernels

Example

>>> import torch
>>> import torch_cmspepr

# Two events with 5 nodes and 4 nodes, respectively.
# Nodes here are on a diagonal line in 2D, with d^2 = 0.02 between them.
>>> nodes = torch.FloatTensor([
    # Event 0
    [.1, .1],
    [.2, .2],
    [.3, .3],
    [.4, .4],
    [100., 100.],
    # Event 1
    [.1, .1],
    [.2, .2],
    [.3, .3],
    [.4, .4]
    ])
# Designate which nodes belong to which event
>>> batch = torch.LongTensor([0,0,0,0,0,1,1,1,1])

# Generate edges: k=2, max_radius^2 of 0.04
>>> torch_cmspepr.knn_graph(nodes, 2, batch, max_radius=.2)
tensor([[0, 1, 1, 2, 2, 3, 5, 6, 6, 7, 7, 8],
        [1, 0, 2, 1, 3, 2, 6, 5, 7, 6, 8, 7]])

# Generate edges: k=3 with loops allowed
>>> torch_cmspepr.knn_graph(nodes, 3, batch, max_radius=.2, loop=True)
tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8],
        [0, 1, 1, 0, 2, 2, 1, 3, 3, 2, 4, 5, 6, 6, 5, 7, 7, 6, 8, 8, 7]])

# If CUDA is available, the CUDA version of the knn_graph is used automatically:
>>> gpu = torch.device('cuda')
>>> torch_cmspepr.knn_graph(nodes.to(gpu), 2, batch.to(gpu), max_radius=.2)
tensor([[0, 1, 1, 2, 2, 3, 5, 6, 6, 7, 7, 8],
        [1, 0, 2, 1, 3, 2, 6, 5, 7, 6, 8, 7]], device='cuda:0')

Installation and requirements

v1 is tested with CUDA 11.7 and PyTorch 2.0. Verify that nvcc is available:

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

A gcc version of 5 or higher is also recommended.
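
You can also check from Python that the installed PyTorch build matches this CUDA release; a minimal sketch:

import torch
# Should print a CUDA release matching nvcc above (e.g. 11.7) and True
print(torch.version.cuda)
print(torch.cuda.is_available())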

The package is not (yet) available on PyPI, so a local installation is currently the preferred method:

git clone [email protected]:cms-pepr/pytorch_cmspepr.git
cd pytorch_cmspepr
pip install -e .

Installing only the CPU or CUDA extensions is supported:

FORCE_CPU_ONLY=1 pip install -e .   # Only compile the C++ extensions
FORCE_CUDA_ONLY=1 pip install -e .  # Only compile the CUDA extensions
FORCE_CUDA=1 pip install -e .       # Try to compile the CUDA extensions even if no device is found

If you only want to test the compilation of the extensions:

python setup.py develop
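
After either install method, a quick smoke test verifies that the compiled extension loads and runs on CPU; this is an illustrative sketch, not part of the test suite:

import torch
import torch_cmspepr

# Tiny single-event example; should print a [2, num_edges] shape
x = torch.rand((10, 3))
batch = torch.zeros(10, dtype=torch.long)
print(torch_cmspepr.knn_graph(x, k=3, batch=batch).shape)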

Containerization

It is recommended to install and run inside a container. At the time of writing (29 Sep 2023), the pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel docker container works well.

Example Singularity instructions:

singularity pull docker://pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
singularity run --nv pytorch_2.0.0-cuda11.7-cudnn8-devel.sif

And then once in the container:

export PYTHONPATH="/opt/conda/lib/python3.10/site-packages"
python -m venv env
source env/bin/activate
pip install torch_geometric
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html # Make sure to pick the right torch and CUDA versions here
git clone [email protected]:cms-pepr/pytorch_cmspepr.git
cd pytorch_cmspepr
pip install -e .
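
Once inside the container, it can be useful to confirm that the GPU is actually visible to PyTorch (this requires the --nv flag above); a short check:

import torch
# Should report True and the GPU name if --nv forwarded the device
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no GPU')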

Tests

pip install pytest
pytest tests

Performance

The following profiling code can be used:

import time
import torch
import torch_cmspepr
import torch_cluster
gpu = torch.device('cuda')

def gen(cuda=False):
    # 10k nodes with 5 node features
    x = torch.rand((10000, 5))
    # Split nodes over 4 events with 2500 nodes/evt
    batch = torch.repeat_interleave(torch.arange(4), 2500)
    if cuda: x, batch = x.to(gpu), batch.to(gpu)
    return x, batch

def profile(name, unit):
    t0 = time.time()
    for _ in range(100): unit()
    print(f'{name} took {(time.time() - t0)/100.} sec/evt')

def cpu_cmspepr():
    x, batch = gen()
    torch_cmspepr.knn_graph(x, k=10, batch=batch)
profile('CPU (torch_cmspepr)', cpu_cmspepr)

def cpu_cluster():
    x, batch = gen()
    torch_cluster.knn_graph(x, k=10, batch=batch)
profile('CPU (torch_cluster)', cpu_cluster)

def cuda_cmspepr():
    x, batch = gen(cuda=True)
    torch_cmspepr.knn_graph(x, k=10, batch=batch)
profile('CUDA (torch_cmspepr)', cuda_cmspepr)

def cuda_cluster():
    x, batch = gen(cuda=True)
    torch_cluster.knn_graph(x, k=10, batch=batch)
profile('CUDA (torch_cluster)', cuda_cluster)

On an NVIDIA Tesla P100 with 12GB of RAM, this produces:

CPU (torch_cmspepr) took 0.22623349189758302 sec/evt
CPU (torch_cluster) took 0.2259768319129944 sec/evt
CUDA (torch_cmspepr) took 0.026673252582550048 sec/evt
CUDA (torch_cluster) took 0.22262062072753908 sec/evt
