Skip to content

Commit 80cc76e

Browse files
authored
Initial commit
0 parents  commit 80cc76e

File tree

8 files changed

+308
-0
lines changed

8 files changed

+308
-0
lines changed

.github/workflows/tests.yml

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
name: Tests

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  tests:
    name: tests
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ ubuntu-latest ]
        python-version: [3.9, 3.11]
    steps:
      # checkout@v2 / setup-python@v2 target the retired node12 runtime;
      # v4 is the current supported major with identical inputs.
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: pip install nox
      - name: Test with pytest
        run: nox -s test

.gitignore

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# Compiled object files
7+
*.o
8+
*.obj
9+
10+
# Compiled dynamic libraries
11+
*.so
12+
*.dylib
13+
14+
# Compiled static libraries
15+
*.a
16+
*.lib
17+
18+
# Executables
19+
*.exe
20+
*.out
21+
22+
# Build directories
23+
build/
24+
bin/
25+
dist/
26+
27+
# IDE specific files
28+
.vscode/
29+
.idea/
30+
31+
# Dependency directories
32+
vendor/
33+
34+
# test cache files
35+
.pytest_cache/
36+
.nox
37+
38+
# egg-infos
39+
sympad.egg-info/

LICENSE

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
NO LIMIT PUBLIC LICENSE
2+
Terms and conditions for copying, distribution, modification
3+
or anything else.
4+
5+
0. No limit to do anything with this work and this license.

README.md

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Symmetric padding for Pytorch
2+
3+
Welcome to the `sympad_pytorch` repository!
4+
5+
## Description
6+
7+
This repository implements a `symmetric` padding extension for PyTorch. Symmetric padding, for example, is the default in `pywt` (https://pywavelets.readthedocs.io). Providing this functionality as a C++ module in PyTorch will allow us to speed up Wavelet computations in PyTorch.
8+
9+
## Testing and Verification
10+
11+
Follow these steps:
12+
13+
1. Clone the repository: `git clone https://github.com/your-username/sympad_pytorch.git`.
14+
2. Navigate to the project directory: `cd sympad_pytorch`.
15+
3. Run the tests with `nox -s test`.
16+
17+

noxfile.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""This file contains the configuration for Nox sessions.
2+
3+
This file defines the following Nox sessions:
4+
- test: Installs dependencies, runs setup.py, and runs pytest.
5+
- format: Fixes common convention problems automatically using black and isort.
6+
- lint: Checks code conventions using flake8.
7+
- typing: Checks type hints using mypy.
8+
"""
9+
10+
import nox
11+
12+
13+
@nox.session(name="test")
def test_pad(session):
    """Build the sympad sdist, install it, and run the pytest suite."""
    # Same dependencies as before, installed one at a time in the same order.
    for dependency in ("pytest", "numpy", "torch", "build"):
        session.install(dependency)
    # --no-isolation so the build reuses the torch just installed above.
    session.run("python", "-m", "build", "--no-isolation")
    session.install("dist/sympad-0.0.1.tar.gz")
    session.run("pytest")
22+
23+
24+
@nox.session(name="format")
def format(session):
    """Automatically fix common formatting and import-order problems."""
    for tool in ("black", "isort"):
        session.install(tool)
    # isort first so black gets the final word on layout.
    session.run("isort", ".")
    session.run("black", ".")
31+
32+
33+
@nox.session(name="lint")
def lint(session):
    """Check code conventions with flake8, and docs with doc8 when present."""
    import pathlib

    session.install("flake8")
    session.install(
        "flake8-black",
        "flake8-docstrings",
        "flake8-bugbear",
        "flake8-broken-line",
        "pep8-naming",
        "pydocstyle",
        "darglint",
    )
    session.run("flake8", "test", "noxfile.py")

    # The repository ships no docs/ folder yet; running doc8 against a
    # missing directory aborts the whole session, so guard the step.
    if pathlib.Path("docs").is_dir():
        session.install("sphinx", "doc8")
        session.run("doc8", "--max-line-length", "120", "docs/")
50+
51+
52+
@nox.session(name="typing")
def mypy(session):
    """Check type hints in the test suite with mypy."""
    # torch is needed so mypy can resolve the tensor imports in the tests.
    for dependency in ("torch", "mypy"):
        session.install(dependency)

    session.run("mypy", "--ignore-missing-imports", "test")

setup.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Build configuration for the sympad PyTorch C++ extension."""

from setuptools import setup
from torch.utils import cpp_extension

setup(
    name="sympad",
    version="0.0.1",
    # CppExtension/BuildExtension wire the torch include paths and flags
    # into the compilation of src/sympad.cpp.
    ext_modules=[cpp_extension.CppExtension("sympad", ["src/sympad.cpp"])],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
    install_requires=["torch"],
)

src/sympad.cpp

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include <torch/extension.h>
2+
3+
#include <iostream>
4+
using namespace torch::indexing;
5+
using namespace std;
6+
7+
/**
 * Pads a tensor symmetrically along one dimension (edge values are
 * mirrored, not zero-filled). Helper for pad_symmetric.
 *
 * @param signal The input tensor to be padded.
 * @param pad_tuple (left, right) counts of mirrored entries to prepend and
 *                  append along `dim`.
 * @param dim The dimension along which to pad the tensor.
 *
 * @return The padded tensor.
 */
torch::Tensor _pad_symmetric_1d(torch::Tensor signal, pair<int, int> pad_tuple, int dim)
{
    int padl = pad_tuple.first;
    int padr = pad_tuple.second;
    int dimlen = signal.size(dim);
    // If the padding is greater than the dimension length,
    // pad recursively until we have enough values.
    if (padl > dimlen || padr > dimlen)
    {
        if (padl > dimlen)
        {
            // Mirror one full copy of the dimension onto the left first;
            // the next recursion sees a signal twice as long, which
            // reproduces numpy's "symmetric" wrap-around behaviour.
            signal = _pad_symmetric_1d(signal, make_pair(dimlen, 0), dim);
            padl = padl - dimlen;
        }
        else
        {
            // Same trick on the right side.
            signal = _pad_symmetric_1d(signal, make_pair(0, dimlen), dim);
            padr = padr - dimlen;
        }
        // Recurse with the remaining (reduced) padding amounts.
        return _pad_symmetric_1d(signal, make_pair(padl, padr), dim);
    }
    else
    {
        // Base case: both pads fit within the dimension; concatenate the
        // flipped edge slices around the original signal.
        vector<torch::Tensor> cat_list = {signal};
        if (padl > 0)
        {
            cat_list.insert(cat_list.begin(), signal.slice(dim, 0, padl).flip(dim));
        }
        if (padr > 0)
        {
            cat_list.push_back(signal.slice(dim, dimlen-padr, dimlen).flip(dim));
        }
        return torch::cat(cat_list, dim);
    }
}
52+
53+
54+
/**
55+
* Pads a given signal symmetrically along multiple dimensions.
56+
*
57+
* @param signal The input signal to be padded.
58+
* @param pad_lists A vector of pairs representing the padding amounts for each dimension.
59+
* Each pair contains the left and right padding amounts for a dimension.
60+
* @return The padded signal.
61+
* @throws std::invalid_argument if the input signal has fewer dimensions than the specified padding dimensions.
62+
*/
63+
torch::Tensor pad_symmetric(torch::Tensor signal, vector<pair<int, int>> pad_lists)
64+
{
65+
int pad_dims = pad_lists.size();
66+
if (signal.dim() < pad_dims)
67+
{
68+
throw std::invalid_argument("not enough dimensions to pad.");
69+
}
70+
71+
int dims = signal.dim() - 1;
72+
reverse(pad_lists.begin(), pad_lists.end());
73+
for (int pos = 0; pos < pad_dims; pos++)
74+
{
75+
int current_axis = dims - pos;
76+
signal = _pad_symmetric_1d(signal, pad_lists[pos], current_axis);
77+
}
78+
return signal;
79+
}
80+
81+
// Expose the padding routines to Python as the `sympad` extension module.
PYBIND11_MODULE(sympad, m) {
    m.def("pad_symmetric", &pad_symmetric, "A function that pads a tensor symmetrically");
    m.def("_pad_symmetric_1d", &_pad_symmetric_1d, "A function that pads a tensor symmetrically in 1D.");
}

test/test_sympad.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Test the sympad modules 1D, 2D and 3D padding."""
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
from sympad import _pad_symmetric_1d, pad_symmetric
7+
8+
9+
@pytest.mark.parametrize("size", [[5], [6], [9], [3]])
@pytest.mark.parametrize(
    "pad_list",
    [(1, 4), (2, 2), (3, 3), (4, 1), (5, 0), (0, 5), (0, 0), (1, 1), (3, 1), (1, 3)],
)
def test_pad_symmetric_1d(size: list[int], pad_list: tuple[int, int]) -> None:
    """Compare 1D symmetric padding against numpy's reference implementation."""
    data = np.random.randint(0, 9, size=size)
    result = _pad_symmetric_1d(torch.from_numpy(data), pad_list, 0)
    expected = np.pad(data, pad_list, mode="symmetric")
    assert np.allclose(result.numpy(), expected)
20+
21+
22+
@pytest.mark.parametrize("size", [[6, 5], [5, 6], [5, 5], [9, 9], [3, 3], [4, 4]])
@pytest.mark.parametrize("pad_list", [[(1, 4), (4, 1)], [(2, 2), (3, 3)]])
def test_pad_symmetric_2d(size: list[int], pad_list: list[tuple[int, int]]) -> None:
    """Compare 2D symmetric padding against numpy's reference implementation."""
    data = np.random.randint(0, 9, size=size)
    result = pad_symmetric(torch.from_numpy(data), pad_list)
    expected = np.pad(data, pad_list, mode="symmetric")
    assert np.allclose(result.numpy(), expected)
30+
31+
32+
@pytest.mark.parametrize("size", [[3, 6, 5], [1, 6, 7]])
@pytest.mark.parametrize(
    "pad_list", [[(0, 0), (1, 4), (4, 1)], [(1, 1), (2, 2), (3, 3)]]
)
def test_pad_symmetric_3d(size: list[int], pad_list: list[tuple[int, int]]) -> None:
    """Compare 3D symmetric padding against numpy's reference implementation."""
    data = np.random.randint(0, 9, size=size)
    result = pad_symmetric(torch.from_numpy(data), pad_list)
    expected = np.pad(data, pad_list, mode="symmetric")
    assert np.allclose(result.numpy(), expected)
42+
43+
44+
def test_pad_symmetric_small() -> None:
    """Pad a tiny 2x2 input where the pad width equals the dimension length."""
    pads = ((1, 1), (1, 1))
    data = np.random.randint(0, 9, size=(2, 2))
    result = pad_symmetric(torch.from_numpy(data), pads)
    expected = np.pad(data, pads, mode="symmetric")
    assert np.allclose(result.numpy(), expected)
50+
51+
52+
@pytest.mark.parametrize(
    "pad_list",
    [
        ((6, 6), (6, 6)),
        ((5, 6), (6, 5)),
        ((6, 5), (5, 6)),
        ((5, 5), (5, 5)),
        ((7, 7), (7, 7)),
    ],
)
# Original list repeated (2, 1) twice; the duplicate added no coverage and
# was presumably meant to be the mirrored shape (1, 2).
@pytest.mark.parametrize("size", [(3, 3), (4, 4), (2, 2), (1, 1), (2, 1), (1, 2)])
def test_pad_symmetric_wrap(pad_list, size: tuple[int, int]) -> None:
    """Test wrap-around padding where pads exceed the dimension length."""
    array = np.random.randint(0, 9, size=size)
    my_pad = pad_symmetric(torch.from_numpy(array), pad_list)
    np_pad = np.pad(array, pad_list, mode="symmetric")
    assert np.allclose(my_pad.numpy(), np_pad)

0 commit comments

Comments
 (0)