adds available_device to test_precision_recall_curve #3335 #3368

Merged
Changes from all commits · 29 commits
092da88  adds available_device to test_precision_recall_curve #3335 (BanzaiTokyo, Mar 28, 2025)
4970f2c  forces float32 when converting to tensor on mps (BanzaiTokyo, Mar 28, 2025)
3c27487  Merge branch 'master' into test_precision_recall_curve_available_device (vfdev-5, Mar 29, 2025)
fae2950  creates the data directly with torch tensors instead of numpy arrays (BanzaiTokyo, Mar 29, 2025)
41aa987  ensures compatibility with MPS by converting to float32 (BanzaiTokyo, Mar 29, 2025)
e207a8f  Merge branch 'master' into test_precision_recall_curve_available_device (BanzaiTokyo, Mar 29, 2025)
429d803  comments on float32 conversion (BanzaiTokyo, Mar 29, 2025)
51d2dc3  makes sure that sklearn does not convert float32 to float64 (BanzaiTokyo, Apr 14, 2025)
af3ee49  another attempt of avoiding float64 (BanzaiTokyo, Apr 14, 2025)
0d6f930  avoiding float64 for MPS (BanzaiTokyo, Apr 14, 2025)
622c2d7  Merge branch 'master' into test_precision_recall_curve_available_device (BanzaiTokyo, Apr 14, 2025)
1c96abe  avoiding float64 for MPS (BanzaiTokyo, Apr 15, 2025)
52972de  another attempt at avoiding float64 on MPS (BanzaiTokyo, Apr 15, 2025)
093e13a  moves conversion to float32 before assertions (BanzaiTokyo, Apr 15, 2025)
74185e5  conversion to float32 (BanzaiTokyo, Apr 15, 2025)
d0be0a9  more conversion to float32 (BanzaiTokyo, Apr 15, 2025)
80574ad  more conversion to float32 (BanzaiTokyo, Apr 15, 2025)
e0fd412  more conversion to float32 (BanzaiTokyo, Apr 15, 2025)
cf57e09  more conversion to float32 (BanzaiTokyo, Apr 15, 2025)
b30cfcd  in precision_recall_curve.py add dtype when creating tensors for prec… (BanzaiTokyo, Apr 23, 2025)
7488bdf  Merge branch 'master' into test_precision_recall_curve_available_device (BanzaiTokyo, Apr 24, 2025)
5ef9215  Merge branch 'master' into test_precision_recall_curve_available_device (BanzaiTokyo, Apr 24, 2025)
949357f  removes unnecessary conversions (BanzaiTokyo, Apr 24, 2025)
fc4075c  move tensors to CPU before passing them to precision_recall_curve (BanzaiTokyo, Apr 28, 2025)
9057d5d  move tensors to CPU before passing them to precision_recall_curve (BanzaiTokyo, Apr 28, 2025)
fe00e65  move tensors to CPU before passing them to precision_recall_curve (BanzaiTokyo, Apr 28, 2025)
f3e4ae8  replace np.testing.assert_array_almost_equal with pytest.approx (BanzaiTokyo, Apr 28, 2025)
1adec2b  Merge branch 'master' into test_precision_recall_curve_available_device (BanzaiTokyo, Apr 28, 2025)
875b15e  removes manual_seed (BanzaiTokyo, Apr 28, 2025)
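
Two themes run through these commits: sklearn's precision_recall_curve only accepts CPU numpy arrays, so tensors are moved off the device first, and commit f3e4ae8 swaps numpy's decimal-places assertion for pytest.approx. A minimal sketch of that assertion swap (the array values are invented for illustration):

    import numpy as np
    import pytest

    expected = np.array([0.1, 0.2, 0.3], dtype=np.float32)
    actual = expected + 1e-8  # tiny drift, e.g. from a float32 round-trip

    # Old style: compares to a fixed number of decimal places (6 by default).
    np.testing.assert_array_almost_equal(actual, expected)

    # Style adopted by the PR: relative tolerance with a plain assert.
    assert actual == pytest.approx(expected, rel=1e-6)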
6 changes: 3 additions & 3 deletions ignite/metrics/precision_recall_curve.py

@@ -110,11 +110,11 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i
         if idist.get_rank() == 0:
             # Run compute_fn on zero rank only
             precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
-            precision = torch.tensor(precision, device=_prediction_tensor.device)
-            recall = torch.tensor(recall, device=_prediction_tensor.device)
+            precision = torch.tensor(precision, device=_prediction_tensor.device, dtype=self._double_dtype)
+            recall = torch.tensor(recall, device=_prediction_tensor.device, dtype=self._double_dtype)
             # thresholds can have negative strides, not compatible with torch tensors
             # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
-            thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
+            thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device, dtype=self._double_dtype)
         else:
             precision, recall, thresholds = None, None, None
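
The dtype argument is the heart of this change: sklearn returns float64 numpy arrays, and Apple's MPS backend has no float64 support, so re-wrapping sklearn's output as tensors on an MPS device would fail without an explicit 32-bit dtype. self._double_dtype comes from the metric base class; judging by the commit history's focus on avoiding float64 on MPS, it presumably resolves to float32 there and float64 elsewhere. A minimal sketch of the underlying constraint:

    import numpy as np
    import torch

    # sklearn.metrics.precision_recall_curve returns float64 numpy arrays.
    values = np.array([0.5, 0.75, 1.0])  # dtype is float64

    device = "mps" if torch.backends.mps.is_available() else "cpu"

    # On MPS, torch.tensor(values, device="mps") would fail, since the backend
    # cannot represent float64; forcing a 32-bit dtype avoids the error.
    t = torch.tensor(values, device=device, dtype=torch.float32)
    print(t.dtype, t.device)  # torch.float32, mps:0 (or cpu)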

136 changes: 74 additions & 62 deletions tests/ignite/metrics/test_precision_recall_curve.py

@@ -2,7 +2,6 @@
 from typing import Tuple
 from unittest.mock import patch

-import numpy as np
 import pytest
 import sklearn
 import torch
@@ -28,85 +27,97 @@ def test_no_sklearn(mock_no_sklearn):
         pr_curve.compute()


-def test_precision_recall_curve():
+def test_precision_recall_curve(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
+        y_true.cpu().numpy(), y_pred.cpu().numpy()
+    )

-    precision_recall_curve_metric = PrecisionRecallCurve()
-    y_pred = torch.from_numpy(np_y_pred)
-    y = torch.from_numpy(np_y)
+    precision_recall_curve_metric = PrecisionRecallCurve(device=available_device)
+    assert precision_recall_curve_metric._device == torch.device(available_device)

-    precision_recall_curve_metric.update((y_pred, y))
+    precision_recall_curve_metric.update((y_pred, y_true))
     precision, recall, thresholds = precision_recall_curve_metric.compute()
-    precision = precision.numpy()
-    recall = recall.numpy()
-    thresholds = thresholds.numpy()
+    precision = precision.cpu().numpy()
+    recall = recall.cpu().numpy()
+    thresholds = thresholds.cpu().numpy()

-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
-    # assert thresholds almost equal, due to numpy->torch->numpy conversion
-    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+    assert pytest.approx(precision) == expected_precision
+    assert pytest.approx(recall) == expected_recall
+    assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)


-def test_integration_precision_recall_curve_with_output_transform():
-    np.random.seed(1)
+def test_integration_precision_recall_curve_with_output_transform(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    np.random.shuffle(np_y)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    perm = torch.randperm(size)
+    y_pred = y_pred[perm]
+    y_true = y_true[perm]

-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
+        y_true.cpu().numpy(), y_pred.cpu().numpy()
+    )

     batch_size = 10

     def update_fn(engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
-        y_true_batch = np_y[idx : idx + batch_size]
-        y_pred_batch = np_y_pred[idx : idx + batch_size]
-        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+        y_true_batch = y_true[idx : idx + batch_size]
+        y_pred_batch = y_pred[idx : idx + batch_size]
+        return idx, y_pred_batch, y_true_batch

     engine = Engine(update_fn)

-    precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (x[1], x[2]))
+    precision_recall_curve_metric = PrecisionRecallCurve(
+        output_transform=lambda x: (x[1], x[2]), device=available_device
+    )
+    assert precision_recall_curve_metric._device == torch.device(available_device)
     precision_recall_curve_metric.attach(engine, "precision_recall_curve")

     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-    precision = precision.numpy()
-    recall = recall.numpy()
-    thresholds = thresholds.numpy()
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
-    # assert thresholds almost equal, due to numpy->torch->numpy conversion
-    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+    precision = precision.cpu().numpy()
+    recall = recall.cpu().numpy()
+    thresholds = thresholds.cpu().numpy()
+    assert pytest.approx(precision) == expected_precision
+    assert pytest.approx(recall) == expected_recall
+    assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)


-def test_integration_precision_recall_curve_with_activated_output_transform():
-    np.random.seed(1)
+def test_integration_precision_recall_curve_with_activated_output_transform(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y_pred_sigmoid = torch.sigmoid(torch.from_numpy(np_y_pred)).numpy()
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    np.random.shuffle(np_y)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    perm = torch.randperm(size)
+    y_pred = y_pred[perm]
+    y_true = y_true[perm]

-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred_sigmoid)
+    sigmoid_y_pred = torch.sigmoid(y_pred).cpu().numpy()
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
+        y_true.cpu().numpy(), sigmoid_y_pred
+    )

     batch_size = 10

     def update_fn(engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
-        y_true_batch = np_y[idx : idx + batch_size]
-        y_pred_batch = np_y_pred[idx : idx + batch_size]
-        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+        y_true_batch = y_true[idx : idx + batch_size]
+        y_pred_batch = y_pred[idx : idx + batch_size]
+        return idx, y_pred_batch, y_true_batch

     engine = Engine(update_fn)

-    precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
+    precision_recall_curve_metric = PrecisionRecallCurve(
+        output_transform=lambda x: (torch.sigmoid(x[1]), x[2]), device=available_device
+    )
+    assert precision_recall_curve_metric._device == torch.device(available_device)
     precision_recall_curve_metric.attach(engine, "precision_recall_curve")

     data = list(range(size // batch_size))
@@ -115,25 +126,26 @@ def update_fn(engine, batch):
     recall = recall.cpu().numpy()
     thresholds = thresholds.cpu().numpy()

-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
-    # assert thresholds almost equal, due to numpy->torch->numpy conversion
-    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+    assert pytest.approx(precision) == expected_precision
+    assert pytest.approx(recall) == expected_recall
+    assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)


-def test_check_compute_fn():
+def test_check_compute_fn(available_device):
     y_pred = torch.zeros((8, 13))
     y_pred[:, 1] = 1
     y_true = torch.zeros_like(y_pred)
     output = (y_pred, y_true)

-    em = PrecisionRecallCurve(check_compute_fn=True)
+    em = PrecisionRecallCurve(check_compute_fn=True, device=available_device)
+    assert em._device == torch.device(available_device)

     em.reset()
     with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
         em.update(output)

-    em = PrecisionRecallCurve(check_compute_fn=False)
+    em = PrecisionRecallCurve(check_compute_fn=False, device=available_device)
+    assert em._device == torch.device(available_device)
     em.update(output)


@@ -225,14 +237,14 @@ def update(engine, i):
     np_y_true = y_true.cpu().numpy().ravel()
     np_y_preds = y_preds.cpu().numpy().ravel()

-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(np_y_true, np_y_preds)

-    assert precision.shape == sk_precision.shape
-    assert recall.shape == sk_recall.shape
-    assert thresholds.shape == sk_thresholds.shape
-    assert pytest.approx(precision.cpu().numpy()) == sk_precision
-    assert pytest.approx(recall.cpu().numpy()) == sk_recall
-    assert pytest.approx(thresholds.cpu().numpy()) == sk_thresholds
+    assert precision.shape == expected_precision.shape
+    assert recall.shape == expected_recall.shape
+    assert thresholds.shape == expected_thresholds.shape
+    assert pytest.approx(precision.cpu().numpy()) == expected_precision
+    assert pytest.approx(recall.cpu().numpy()) == expected_recall
+    assert pytest.approx(thresholds.cpu().numpy()) == expected_thresholds

     metric_devices = ["cpu"]
     if device.type != "xla":
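
For context, the available_device fixture used throughout these tests is not part of the diff; it presumably lives in the test suite's conftest. A hypothetical sketch of what such a fixture typically looks like, not the repository's actual definition:

    import pytest
    import torch

    @pytest.fixture
    def available_device():
        # Illustrative only: yield the most capable backend on the machine so
        # the same test body exercises CPU, CUDA, and MPS code paths.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"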