Skip to content

Commit f995ff5

Browse files
committedMar 4, 2022
vector field example
1 parent 9e84bdd commit f995ff5

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed
 

‎README.md

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Problems solved by this project:
99

1010
See [Usage](#usage) for a high-level example of using the library. See [How it works](#how-it-works) to see which modules are supported.
1111

12+
For fun, here's a vector field produced by differentiating the probability predictions of a two-class SVM (produced by [this script](examples/svm_vector_field.py)):
13+
14+
<img src="examples/svm_vector_field.png" width="300" alt="A vector field quiver plot with two modes">
15+
1216
# Usage
1317

1418
First, train a model with scikit-learn as usual:
@@ -47,6 +51,8 @@ torch.jit.script(torch_model).save("path.pt")
4751
loaded_model = torch.jit.load("path.pt")
4852
```
4953

54+
For a full example of training a model and using its PyTorch translation, see [examples/svm_vector_field.py](examples/svm_vector_field.py).
55+
5056
# How it works
5157

5258
**sk2torch** contains PyTorch re-implementations of supported scikit-learn models. For a supported estimator `X`, a class `TorchX` in sk2torch will be able to read the attributes of `X` and convert them to `torch.Tensor` or simple Python types. `TorchX` subclasses `torch.nn.Module` and has a method for each inference API of `X` (e.g. `predict`, `decision_function`, etc.).

‎examples/svm_vector_field.png

22.6 KB
Loading

‎examples/svm_vector_field.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Train an SVM on a 2D classification problem, then plot a gradient vector field
3+
for the predicted class probabilty.
4+
"""
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import sk2torch
9+
import torch
10+
from sklearn.svm import SVC
11+
12+
# Create a dataset of two Gaussians. There will be some overlap
13+
# between the two classes, which adds some uncertainty to the model.
14+
xs = np.concatenate(
15+
[
16+
np.random.random(size=(256, 2)) + [1, 0],
17+
np.random.random(size=(256, 2)) + [-1, 0],
18+
],
19+
axis=0,
20+
)
21+
ys = np.array([False] * 256 + [True] * 256)
22+
23+
# Train an SVM on the data and wrap it in PyTorch.
24+
sk_model = SVC(probability=True)
25+
sk_model.fit(xs, ys)
26+
model = sk2torch.wrap(sk_model)
27+
28+
# Create a coordinate grid to compute a vector field on.
29+
spaced = np.linspace(-2, 2, num=25)
30+
grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)
31+
32+
# Compute the gradients of the SVM output.
33+
outputs = model.predict_proba(grid_xs)[:, 1]
34+
(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))
35+
36+
# Create a quiver plot of the vector field.
37+
plt.quiver(
38+
grid_xs[:, 0].detach().numpy(),
39+
grid_xs[:, 1].detach().numpy(),
40+
input_grads[:, 0].detach().numpy(),
41+
input_grads[:, 1].detach().numpy(),
42+
)
43+
plt.savefig("svm_vector_field.png")

0 commit comments

Comments
 (0)