Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ PyTorch API

sequential

ONNX API
********

.. currentmodule:: gurobi_ml.onnx

.. autosummary::
:toctree: auto_generated/
:caption: ONNX API
:template: modeling_object.rst

onnx_model

XGBoost API
***********

Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Gurobi Machine Learning
A Python package to help use *trained* regression models in
mathematical optimization models. The package supports a variety of regression models
(linear, logistic, neural networks, decision trees,...) trained by
different machine learning frameworks (scikit-learn, LightGBM, XGBoost, Keras and PyTorch).
different machine learning frameworks (scikit-learn, LightGBM, XGBoost, Keras, PyTorch, and ONNX).


.. only:: html
Expand Down
15 changes: 15 additions & 0 deletions docs/source/user/supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ Currently, only two types of layers are supported:
* :external+torch:py:class:`Linear layers <torch.nn.Linear>`,
* :external+torch:py:class:`ReLU layers <torch.nn.ReLU>`.

ONNX
----

`ONNX <https://onnx.ai/>`_ models for sequential multi-layer perceptrons are
supported when composed of `Gemm` (dense) operators and `Relu` activations.

They can be formulated in a Gurobi model with the function
:py:func:`add_onnx_constr <gurobi_ml.onnx.onnx_model.add_onnx_constr>`.

Currently, only the following are supported:

* `Gemm` nodes with default attributes (`alpha=1`, `beta=1`) and optional
`transB` attribute,
* `Relu` activations.

XGBoost
-------

Expand Down
291 changes: 291 additions & 0 deletions notebooks/adversarial/adversarial_onnx.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Adversarial example using ONNX\n",
"\n",
"In this example, we demonstrate finding adversarial examples for a neural network using Gurobi and gurobi-ml's ONNX support.\n",
"\n",
"We load a pre-trained MNIST classifier (stored as an ONNX model) and use optimization to find small perturbations to an input image that cause misclassification.\n",
"\n",
"This example requires:\n",
" - [matplotlib](https://matplotlib.org/)\n",
" - [onnx](https://onnx.ai/)\n",
" - [keras](https://keras.io/) (only for loading MNIST data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import onnx\n",
"from tensorflow import keras\n",
"\n",
"import gurobipy as gp\n",
"from gurobi_ml import add_predictor_constr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load MNIST data\n",
"\n",
"We use Keras only to load the MNIST dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"\n",
"# Reshape and normalize\n",
"x_test = x_test.astype(\"float32\") / 255.0\n",
"x_test_flat = x_test.reshape(-1, 28 * 28)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load pre-trained ONNX model\n",
"\n",
"We load a pre-trained neural network with 2 hidden layers of 50 neurons each and ReLU activations.\n",
"The model was trained on MNIST and converted to ONNX format."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onnx_model = onnx.load(\"mnist_model.onnx\")\n",
"print(\"ONNX model loaded successfully\")\n",
"print(f\"Model has {len(onnx_model.graph.node)} operations\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify model predictions\n",
"\n",
"Let's verify the model works by making a prediction on a test sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use onnxruntime for inference\n",
"import onnxruntime as ort\n",
"\n",
"session = ort.InferenceSession(onnx_model.SerializeToString())\n",
"input_name = session.get_inputs()[0].name\n",
"\n",
"# Predict on a test sample\n",
"sample_idx = 18\n",
"sample = x_test_flat[sample_idx : sample_idx + 1]\n",
"prediction = session.run(None, {input_name: sample})[0]\n",
"\n",
"print(f\"True label: {y_test[sample_idx]}\")\n",
"print(f\"Predicted: {np.argmax(prediction)}\")\n",
"\n",
"# Display the image\n",
"plt.imshow(x_test[sample_idx], cmap=\"gray\")\n",
"plt.title(f\"True: {y_test[sample_idx]}, Predicted: {np.argmax(prediction)}\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Select an example for adversarial attack\n",
"\n",
"We choose a test example that is correctly classified and define the target misclassification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"example = x_test_flat[sample_idx : sample_idx + 1]\n",
"right_label = int(y_test[sample_idx])\n",
"wrong_label = 8\n",
"\n",
"print(f\"Original label: {right_label}\")\n",
"print(f\"Target (wrong) label: {wrong_label}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the optimization model\n",
"\n",
"We create a Gurobi model to find an adversarial example.\n",
"The objective is to maximize the score difference between the wrong label and correct label,\n",
"subject to the perturbed image being close to the original (measured by L1 distance)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = gp.Model()\n",
"delta = 5 # Maximum L1 distance from original image\n",
"\n",
"# Decision variables\n",
"x = m.addMVar(example.shape, lb=0.0, ub=1.0, name=\"x\")\n",
"y = m.addMVar((1, 10), lb=-gp.GRB.INFINITY, name=\"y\") # Network output logits\n",
"abs_diff = m.addMVar(example.shape, lb=0, ub=1, name=\"abs_diff\")\n",
"\n",
"# Objective: maximize score of wrong label minus score of correct label\n",
"m.setObjective(y[0, wrong_label] - y[0, right_label], gp.GRB.MAXIMIZE)\n",
"\n",
"# Constraints: bound L1 distance from original\n",
"m.addConstr(abs_diff >= x - example)\n",
"m.addConstr(abs_diff >= -x + example)\n",
"m.addConstr(abs_diff.sum() <= delta)\n",
"\n",
"# Add neural network constraints\n",
"pred_constr = add_predictor_constr(m, onnx_model, x, y)\n",
"\n",
"pred_constr.print_stats()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solve the optimization problem\n",
"\n",
"We solve the model to find an adversarial example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m.Params.BestBdStop = 0.0\n",
"m.Params.BestObjStop = 0.0\n",
"m.optimize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Display the adversarial example\n",
"\n",
"If an adversarial example was found, we display it and verify the misclassification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_constr.get_error()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adversarial_image = x.X.reshape(28, 28)\n",
"\n",
"# Verify classification\n",
"adv_flat = x.X.reshape(1, -1).astype(np.float32)\n",
"adv_prediction = session.run(None, {input_name: adv_flat})[0]\n",
"predicted_label = np.argmax(adv_prediction)\n",
"\n",
"# Display original and adversarial images\n",
"fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
"\n",
"axes[0].imshow(example.reshape(28, 28), cmap=\"gray\")\n",
"axes[0].set_title(f\"Original (label: {right_label})\")\n",
"axes[0].axis(\"off\")\n",
"\n",
"axes[1].imshow(adversarial_image, cmap=\"gray\")\n",
"axes[1].set_title(f\"Adversarial (classified as: {predicted_label})\")\n",
"axes[1].axis(\"off\")\n",
"\n",
"# Show difference\n",
"diff = np.abs(adversarial_image - example.reshape(28, 28))\n",
"axes[2].imshow(diff, cmap=\"hot\")\n",
"axes[2].set_title(f\"Difference (L1: {diff.sum():.2f})\")\n",
"axes[2].axis(\"off\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"if m.ObjVal > 0.0:\n",
" print(\"\\nAdversarial example found!\")\n",
" print(f\"Original label: {right_label}\")\n",
" print(f\"Predicted label: {predicted_label}\")\n",
" print(f\"L1 distance: {diff.sum():.2f}\")\n",
"else:\n",
" print(\"No adversarial example exists within the specified distance bound.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"copyright © 2023 Gurobi Optimization, LLC"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
},
"license": {
"full_text": "# Copyright © 2023-2025 Gurobi Optimization, LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =============================================================================="
},
"vscode": {
"interpreter": {
"hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Binary file added notebooks/adversarial/mnist_model.onnx
Binary file not shown.
2 changes: 2 additions & 0 deletions requirements.onnx.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
onnx>=1.14.0
onnxruntime>=1.15.0
22 changes: 22 additions & 0 deletions src/gurobi_ml/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright © 2025 Gurobi Optimization, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ONNX support for formulating simple feed-forward neural networks.

Currently supports sequential MLPs represented with ONNX `Gemm` layers and
`Relu` activations, matching the capabilities of the Keras and PyTorch
converters (Dense/Linear + ReLU)."""

from .onnx_model import add_onnx_constr as add_onnx_constr # noqa: F401
Loading
Loading