Skip to content

Commit 488afd6

Browse files
committed
add everything
1 parent 87885b8 commit 488afd6

File tree

9 files changed

+368
-2
lines changed

9 files changed

+368
-2
lines changed

README.md

+58-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,58 @@
1-
# torchplot
2-
Plotting pytorch tensors made easy!
1+
# torchplot - Plotting pytorch tensors made easy!
2+
3+
Ask yourself the following:
4+
* Are you using `matplotlib.pyplot` to plot pytorch tensors?
5+
* Do you forget to call `.cpu().detach().numpy()` everytime you want to plot a tensor
6+
7+
Then `torchplot` may be something for you. `torchplot` is a simple drop-in replacement
8+
for plotting pytorch tensors. We simply override every `matplotlib.pyplot` function such
9+
that pytorch tensors are automatically converted.
10+
11+
Simply just change your default `matplotlib` import statement:
12+
13+
14+
Instead of
15+
```
16+
from matplotlib.pyplot import *
17+
```
18+
use
19+
```
20+
from torchplot import *
21+
```
22+
and instead of
23+
```
24+
import matplotlib.pyplot as plt
25+
```
26+
use
27+
```
28+
import torchplot as plt
29+
```
30+
Herafter, then you can remove every `.cpu().detach().numpy()` (or variations heroff) from
31+
your code and everything should just work. If you do not want to mix implementations,
32+
we recommend importing `torchplot` as seperaly package:
33+
```
34+
import torchplot as tp
35+
```
36+
37+
## Installation
38+
Simple as
39+
```
40+
pip install torchplot
41+
```
42+
43+
## Example
44+
45+
```
46+
# lets make a scatter plot of two pytorch variables
47+
import torch
48+
import torchplot as plt
49+
x = torch.randn(100, requires_grad=True, device='cuda')
50+
y = torch.randn(100, requires_grad=True, device='cuda')
51+
plt.plot(x, y, '.') # easy and simple
52+
```
53+
54+
55+
56+
57+
58+

pyproject.toml

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
[build-system]
2+
requires = [
3+
"setuptools",
4+
"wheel",
5+
]
6+
7+
[tool.autopep8]
8+
max_line_length = 120
9+
ignore = ["W504", "W504", "E402", "E731", "C40", "E741", "F40", "F841"]
10+
11+
[tool.black]
12+
# https://github.com/psf/black
13+
line-length = 120
14+
target-version = ["py38"]
15+
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist|docs)"
16+
17+
[tool.isort]
18+
known_first_party = [
19+
"benchmarks",
20+
"docs",
21+
"pl_examples",
22+
"pytorch_lightning",
23+
"tests",
24+
]
25+
skip_glob = [
26+
"pytorch_lightning/accelerators/*",
27+
"pytorch_lightning/callbacks/*",
28+
"pytorch_lightning/cluster_environments/*",
29+
"pytorch_lightning/core/*",
30+
"pytorch_lightning/distributed/*",
31+
"pytorch_lightning/loggers/*",
32+
"pytorch_lightning/metrics/*",
33+
"pytorch_lightning/overrides/*",
34+
"pytorch_lightning/plugins/*",
35+
"pytorch_lightning/profiler/*",
36+
"pytorch_lightning/trainer/*",
37+
"pytorch_lightning/tuner/*",
38+
"pytorch_lightning/utilities/*",
39+
"tests/backends/*",
40+
"tests/base/*",
41+
"tests/callbacks/*",
42+
"tests/checkpointing/*",
43+
"tests/core/*",
44+
"tests/loggers/*",
45+
"tests/metrics/*",
46+
"tests/models/*",
47+
"tests/plugins/*",
48+
"tests/trainer/*",
49+
"tests/tuner/*",
50+
"tests/utilities/*",
51+
]
52+
profile = "black"
53+
line_length = 120
54+
force_sort_within_sections = "False"
55+
order_by_type = "False"

requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch>=1.3
2+
numpy>=1.16.4

setup.cfg

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[metadata]
2+
description-file = README.md
3+
4+
[flake8]
5+
# TODO: this should be 88 or 100 according PEP8
6+
max-line-length = 120
7+
exclude = .tox,*.egg,build,temp
8+
select = E,W,F
9+
doctests = True
10+
verbose = 2
11+
# https://pep8.readthedocs.io/en/latest/intro.html#error-codes
12+
format = pylint
13+
ignore = E731,W504,F401,F841,E722,W503
14+
15+
[build_sphinx]
16+
source-dir = doc/source
17+
build-dir = doc/build
18+
all_files = 1
19+
20+
[upload_sphinx]
21+
upload-dir = doc/build/html

setup.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright The GeoML Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from io import open
16+
17+
from setuptools import setup, find_packages, Command
18+
19+
try:
20+
import builtins
21+
except ImportError:
22+
import __builtin__ as builtins
23+
24+
PATH_ROOT = os.path.dirname(__file__)
25+
builtins.__STOCHMAN_SETUP__ = True
26+
27+
import torchplot
28+
29+
30+
class CleanCommand(Command):
31+
"""Custom clean command to tidy up the project root."""
32+
33+
user_options = []
34+
35+
def initialize_options(self):
36+
pass
37+
38+
def finalize_options(self):
39+
pass
40+
41+
def run(self):
42+
os.system("rm -vrf ./build ./dist ./*.pyc ./*.tgz ./*.egg-info")
43+
44+
45+
PATH_ROOT = os.path.dirname(__file__)
46+
47+
48+
def load_readme(path_dir=PATH_ROOT):
49+
with open(os.path.join(path_dir, "README.md"), encoding="utf-8") as f:
50+
long_description = f.read()
51+
return long_description
52+
53+
54+
setup(
55+
name="torchplot",
56+
version=torchplot.__version__,
57+
description=torchplot.__docs__,
58+
long_description=load_readme(PATH_ROOT),
59+
author=torchplot.__author__,
60+
author_email=torchplot.__author_email__,
61+
license=torchplot.__license__,
62+
packages=find_packages(exclude=["tests", "tests/*"]),
63+
python_requires=">=3.8",
64+
install_requires=['torch>=1.3', 'matplotlib>=3.3.3'],
65+
download_url="https://github.com/CenterBioML/torchplot/archive/0.1.0.zip",
66+
classifiers=[
67+
"Environment :: Console",
68+
"Natural Language :: English",
69+
# How mature is this project? Common values are
70+
# 3 - Alpha, 4 - Beta, 5 - Production/Stable
71+
"Development Status :: 3 - Alpha",
72+
# Indicate who your project is intended for
73+
"Intended Audience :: Developers",
74+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
75+
# Pick your license as you wish
76+
"License :: OSI Approved :: Apache Software License",
77+
"Operating System :: OS Independent",
78+
# Specify the Python versions you support here. In particular, ensure
79+
# that you indicate whether you support Python 2, Python 3 or both.
80+
"Programming Language :: Python :: 3.8",
81+
],
82+
)

tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

tests/test_torchplot.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright The GeoML Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
from inspect import getmembers, isfunction
16+
from collections import namedtuple
17+
import string
18+
19+
import torch
20+
import numpy as np
21+
22+
import matplotlib.pyplot as plt
23+
import torchplot as tp
24+
25+
Inputs = namedtuple("case", ["x", "y"])
26+
27+
_cpu_cases = [Inputs(x=torch.randn(100,), y=torch.randn(100,)),
28+
Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100,requires_grad=True)),
29+
# test that list/numpy arrays still works
30+
Inputs(x=[1,2,3,4], y=[1,2,3,4]),
31+
Inputs(x=np.random.randn(100,), y=np.random.randn(100,)),
32+
# test that we can mix
33+
Inputs(x=torch.randn(100,), y=torch.randn(100, requires_grad=True)),
34+
Inputs(x=np.random.randn(100,), y=torch.randn(100, requires_grad=True)),
35+
Inputs(x=torch.randn(5,), y=[1,2,3,4,5]),
36+
]
37+
38+
_gpu_cases = [Inputs(x=torch.randn(100, device='cuda'), y=torch.randn(100, device='cuda')),
39+
Inputs(x=torch.randn(100,requires_grad=True, device='cuda'), y=torch.randn(100,requires_grad=True, device='cuda')),
40+
]
41+
42+
43+
44+
_members_to_check = [name for name, member in getmembers(plt)
45+
if isfunction(member) and not name.startswith('_')]
46+
47+
48+
def string_compare(text1, text2):
49+
if text1 is None and text2 is None:
50+
return True
51+
remove = string.punctuation + string.whitespace
52+
return text1.translate(str.maketrans(dict.fromkeys(remove))) == text2.translate(str.maketrans(dict.fromkeys(remove)))
53+
54+
55+
@pytest.mark.parametrize("member", _members_to_check)
56+
def test_members(member):
57+
""" test that all members have been copied """
58+
assert member in dir(plt)
59+
assert member in dir(tp)
60+
61+
62+
@pytest.mark.parametrize('test_case', _cpu_cases)
63+
def test_cpu(test_case):
64+
""" test that it works on cpu """
65+
assert tp.plot(test_case.x, test_case.y, '.')
66+
67+
68+
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
69+
@pytest.mark.parametrize('test_case', _gpu_cases)
70+
def test_gpu(test_case):
71+
""" test that it works on gpu """
72+
assert tp.plot(test_case.x, test_case.y, '.')
73+

torchplot/__init__.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright The GeoML Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Root package info."""
15+
import os
16+
import time
17+
18+
_this_year = time.strftime("%Y")
19+
__version__ = "0.1"
20+
__author__ = "Nicki Skafte Detlefsen et al."
21+
__author_email__ = "[email protected]"
22+
__license__ = "Apache-2.0"
23+
__copyright__ = f"Copyright (c) 2018-{_this_year}, {__author__}."
24+
__homepage__ = "https://github.com/CenterBioML/torchplot"
25+
26+
__docs__ = "Plotting pytorch tensors made easy"
27+
28+
PACKAGE_ROOT = os.path.dirname(__file__)
29+
PROJECT_ROOT = os.path.dirname(PACKAGE_ROOT)
30+
31+
from .core import *

torchplot/core.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright The GeoML Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import matplotlib.pyplot as plt
15+
import torch
16+
from inspect import getmembers, isfunction, getdoc
17+
18+
# Function to convert a list of arguments containing torch tensors, into
19+
# a corresponding list of arguments containing numpy arrays
20+
def _torch2np(*args, **kwargs):
21+
def convert(arg):
22+
return arg.detach().cpu().numpy() if isinstance(arg, torch.Tensor) else arg
23+
24+
# first unnamed arguments
25+
outargs = [convert(arg) for arg in args]
26+
27+
# then keyword arguments
28+
outkwargs = dict()
29+
for key, value in kwargs.items():
30+
outkwargs[key] = convert(value)
31+
32+
return outargs, kwargs
33+
34+
# Iterate over all members of 'plt' in order to duplicate them
35+
for name, member in getmembers(plt):
36+
if isfunction(member):
37+
doc = getdoc(member)
38+
strdoc = "" if doc is None else doc
39+
exec(('def {name}(*args, **kwargs):\n' +
40+
'\t"""{doc}"""\n' +
41+
'\tnew_args, new_kwargs = _torch2np(*args, **kwargs)\n' +
42+
'\treturn plt.{name}(*new_args, **new_kwargs)').format(name=name, doc=strdoc))
43+
else:
44+
exec('{name} = plt.{name}'.format(name=name))
45+
#break

0 commit comments

Comments
 (0)