Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions tests/trainer/ppo/test_metric_utils_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
from verl.utils.metric import (
reduce_metrics,
)
from verl.utils.metric.utils import (
AggregationType,
Metric,
)


class TestReduceMetrics(unittest.TestCase):
Expand Down Expand Up @@ -67,6 +71,167 @@ def test_reduce_metrics_single_value(self):
self.assertEqual(result["single"], 5.0)


class TestMetric(unittest.TestCase):
    """Unit tests for construction, accumulation and aggregation of ``Metric``."""

    def test_init_with_string_aggregation(self):
        """A string aggregation name is converted to the matching enum member."""
        created = Metric(aggregation="mean")
        self.assertEqual(created.aggregation, AggregationType.MEAN)
        self.assertEqual(created.values, [])

    def test_init_with_enum_aggregation(self):
        """An ``AggregationType`` member is stored unchanged."""
        created = Metric(aggregation=AggregationType.SUM)
        self.assertEqual(created.aggregation, AggregationType.SUM)
        self.assertEqual(created.values, [])

    def test_init_with_value(self):
        """An initial scalar becomes the first collected value."""
        created = Metric(aggregation="mean", value=5.0)
        expected = [5.0]
        self.assertEqual(created.values, expected)

    def test_init_with_invalid_aggregation(self):
        """An unknown aggregation name is rejected with ValueError."""
        with self.assertRaises(ValueError):
            Metric(aggregation="invalid")

    def test_append_float(self):
        """Floats are collected in insertion order."""
        collector = Metric(aggregation="mean")
        for number in (1.0, 2.0):
            collector.append(number)
        self.assertEqual(collector.values, [1.0, 2.0])

    def test_append_int(self):
        """Ints are collected in insertion order."""
        collector = Metric(aggregation="mean")
        for number in (1, 2):
            collector.append(number)
        self.assertEqual(collector.values, [1, 2])

    def test_append_tensor(self):
        """Scalar tensors are stored as plain Python numbers."""
        collector = Metric(aggregation="mean")
        first = torch.tensor(3.0)
        second = torch.tensor(4.0)
        collector.append(first)
        collector.append(second)
        self.assertEqual(collector.values, [3.0, 4.0])

    def test_append_non_scalar_tensor_raises(self):
        """A tensor holding more than one element is rejected with ValueError."""
        collector = Metric(aggregation="mean")
        with self.assertRaises(ValueError):
            collector.append(torch.tensor([1.0, 2.0]))

    def test_append_metric(self):
        """Appending a Metric adds all of its values after the existing ones."""
        source = Metric(aggregation="mean", value=1.0)
        source.append(2.0)

        target = Metric(aggregation="mean", value=3.0)
        target.append(source)

        self.assertEqual(target.values, [3.0, 1.0, 2.0])

    def test_extend_with_list(self):
        """Extending with a list adds each element."""
        collector = Metric(aggregation="mean")
        collector.extend([1.0, 2.0, 3.0])
        self.assertEqual(collector.values, [1.0, 2.0, 3.0])

    def test_extend_with_metric(self):
        """Extending with a Metric adds that Metric's values."""
        source = Metric(aggregation="mean")
        source.extend([1.0, 2.0])

        target = Metric(aggregation="mean")
        target.extend([3.0, 4.0])
        target.extend(source)

        self.assertEqual(target.values, [3.0, 4.0, 1.0, 2.0])

    def test_extend_aggregation_mismatch_raises(self):
        """Extending with a Metric of another aggregation type raises ValueError."""
        mean_metric = Metric(aggregation="mean")
        sum_metric = Metric(aggregation="sum")

        with self.assertRaises(ValueError):
            mean_metric.extend(sum_metric)

    def test_aggregate_mean(self):
        """``mean`` aggregation returns the average of the values."""
        collector = Metric(aggregation="mean")
        collector.extend([1.0, 2.0, 3.0, 4.0])
        outcome = collector.aggregate()
        self.assertEqual(outcome, 2.5)

    def test_aggregate_sum(self):
        """``sum`` aggregation returns the total of the values."""
        collector = Metric(aggregation="sum")
        collector.extend([1.0, 2.0, 3.0, 4.0])
        outcome = collector.aggregate()
        self.assertEqual(outcome, 10.0)

    def test_aggregate_min(self):
        """``min`` aggregation returns the smallest value."""
        collector = Metric(aggregation="min")
        collector.extend([3.0, 1.0, 4.0, 2.0])
        outcome = collector.aggregate()
        self.assertEqual(outcome, 1.0)

    def test_aggregate_max(self):
        """``max`` aggregation returns the largest value."""
        collector = Metric(aggregation="max")
        collector.extend([3.0, 1.0, 4.0, 2.0])
        outcome = collector.aggregate()
        self.assertEqual(outcome, 4.0)

    def test_chain_multiple_metrics(self):
        """``chain`` concatenates the values and keeps the aggregation type."""
        left = Metric(aggregation="sum")
        left.extend([1.0, 2.0])

        right = Metric(aggregation="sum")
        right.extend([3.0, 4.0])

        combined = Metric.chain([left, right])

        self.assertEqual(combined.aggregation, AggregationType.SUM)
        self.assertEqual(combined.values, [1.0, 2.0, 3.0, 4.0])
        self.assertEqual(combined.aggregate(), 10.0)

    def test_from_dict(self):
        """``from_dict`` wraps every dictionary value in its own Metric."""
        raw = {"loss": 1.0, "accuracy": 0.9}
        wrapped = Metric.from_dict(raw, aggregation="mean")

        for expected_key in ("loss", "accuracy"):
            self.assertIn(expected_key, wrapped)
        self.assertEqual(wrapped["loss"].values, [1.0])
        self.assertEqual(wrapped["accuracy"].values, [0.9])
        self.assertEqual(wrapped["loss"].aggregation, AggregationType.MEAN)

    def test_init_list(self):
        """``init_list`` yields an empty Metric with the same aggregation type."""
        populated = Metric(aggregation="max")
        populated.extend([1.0, 2.0])

        fresh = populated.init_list()

        self.assertEqual(fresh.aggregation, AggregationType.MAX)
        self.assertEqual(fresh.values, [])

    def test_reduce_metrics_with_metric(self):
        """``reduce_metrics`` reduces Metric objects and plain lists side by side."""
        collector = Metric(aggregation="mean")
        collector.extend([1.0, 2.0, 3.0])

        reduced = reduce_metrics(
            {
                "custom_metric": collector,
                "list_metric": [4.0, 5.0, 6.0],
            }
        )

        self.assertEqual(reduced["custom_metric"], 2.0)
        self.assertEqual(reduced["list_metric"], 5.0)


class TestComputeDataMetrics(unittest.TestCase):
"""Tests for the compute_data_metrics function."""

Expand Down
4 changes: 2 additions & 2 deletions verl/utils/metric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import reduce_metrics
from .utils import AggregationType, Metric, reduce_metrics

__all__ = ["reduce_metrics"]
__all__ = ["reduce_metrics", "AggregationType", "Metric"]
101 changes: 98 additions & 3 deletions verl/utils/metric/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
Metrics utils.
"""

from typing import Any
from enum import Enum
from typing import Any, Optional, Union

import numpy as np
import torch


def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
def reduce_metrics(metrics: dict[str, Union["Metric", list[Any]]]) -> dict[str, Any]:
"""
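Reduces a dictionary of metrics to one value per key. Metric objects are reduced with their
own aggregate() method; plain lists are reduced by computing the mean, max, or min of each list.
For plain lists, the reduce operation is determined by the key name: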
Expand All @@ -45,10 +47,103 @@ def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
{"loss": 2.0, "accuracy": 0.8, "max_reward": 8.0, "min_error": 0.05}
"""
for key, val in metrics.items():
if "max" in key:
if isinstance(val, Metric):
metrics[key] = val.aggregate()
elif "max" in key:
metrics[key] = np.max(val)
elif "min" in key:
metrics[key] = np.min(val)
else:
metrics[key] = np.mean(val)
return metrics


class AggregationType(Enum):
    """Reduction that can be applied to a collection of metric values.

    Each member's value is the lowercase name of the reduction, so a member can
    be looked up from its string form, e.g. ``AggregationType("mean")``.
    """

    MEAN = "mean"
    SUM = "sum"
    MIN = "min"
    MAX = "max"


NumericType = int, float, torch.Tensor
Numeric = int | float | torch.Tensor


class Metric:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JacobHelwig Docstring Coverage ci failed, please add doc string for Metric.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

"""
A metric aggregator for collecting and aggregating numeric values.

This class accumulates numeric values (int, float, or scalar tensors) and computes
an aggregate statistic based on the specified aggregation type (MEAN, SUM, MIN, or MAX).

Args:
aggregation: The aggregation method to use. Can be a string ("mean", "sum", "min", "max")
or an AggregationType enum value.
value: Optional initial value(s) to add. Can be a single numeric value or a list of values.

Example:
>>> metric = Metric(aggregation="mean", value=1.0)
>>> metric.append(2.0)
>>> metric.append(3.0)
>>> metric.aggregate()
2.0
"""

def __init__(self, aggregation: str | AggregationType, value: Optional[Numeric | list[Numeric]] = None) -> None:
    """Create a metric accumulator, optionally seeded with initial value(s).

    Args:
        aggregation: Aggregation method, given either as an ``AggregationType``
            member or as its string value ("mean", "sum", "min", "max").
        value: Optional initial content. A list is added element by element
            through ``extend``; anything else is handed to ``append``.

    Raises:
        ValueError: If ``aggregation`` is not a valid aggregation string or
            ``AggregationType`` member, or if an initial value is rejected by
            ``append``.
    """
    if isinstance(aggregation, str):
        # Enum lookup by value; an unknown string raises ValueError here.
        self.aggregation = AggregationType(aggregation)
    else:
        self.aggregation = aggregation
    if not isinstance(self.aggregation, AggregationType):
        raise ValueError(f"Unsupported aggregation type: {aggregation}")
    self.values = []
    if value is None:
        return
    if isinstance(value, list):
        # ``append`` only accepts scalars and Metric objects, so a list must be
        # routed through ``extend`` to honour the documented ``list`` input.
        self.extend(value)
    else:
        self.append(value)
Comment on lines +100 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The __init__ method does not correctly handle initialization with a list of values. The type hint for value is Optional[Numeric | list[Numeric]], but if a list is passed, self.append(value) is called. This will raise a ValueError because append is designed for single numeric values or Metric objects, not lists. You should check if value is a list and call self.extend(value) in that case.

Suggested change
if value is not None:
self.append(value)
if value is not None:
if isinstance(value, list):
self.extend(value)
else:
self.append(value)


def append(self, value: Union[Numeric, "Metric"]) -> None:
    """Add a single value, or all values of another ``Metric``, to this metric.

    Args:
        value: An int, a float, a tensor holding exactly one element, or a
            ``Metric``. A ``Metric`` is forwarded to ``extend``; a tensor is
            detached and converted to a Python number before being stored.

    Raises:
        ValueError: If a tensor holds more or fewer than one element, or if the
            value is not one of the supported types.
    """
    if isinstance(value, Metric):
        self.extend(value)
    else:
        scalar = value
        if isinstance(scalar, torch.Tensor):
            if scalar.numel() != 1:
                raise ValueError("Only scalar tensors can be converted to float")
            scalar = scalar.detach().item()
        if isinstance(scalar, NumericType):
            self.values.append(scalar)
        else:
            raise ValueError(f"Unsupported value type: {type(scalar)}")

def extend(self, values: Union["Metric", list[Numeric]]) -> None:
    """Add every value from a list or from another ``Metric``.

    Args:
        values: Either a list of values, each of which is passed to ``append``,
            or another ``Metric`` whose collected values are added.

    Raises:
        ValueError: If ``values`` is a ``Metric`` whose aggregation type differs
            from this one, or if an individual element is rejected by ``append``.
    """
    if isinstance(values, Metric):
        if values.aggregation != self.aggregation:
            raise ValueError(f"Aggregation type mismatch: {self.aggregation} != {values.aggregation}")
        values = values.values
    # Iterate over a snapshot: when the source is this metric itself (or its own
    # ``values`` list), looping over the live list while appending to it would
    # never terminate.
    for value in list(values):
        self.append(value)

def aggregate(self) -> float:
    """Reduce the collected values according to this metric's aggregation type.

    Returns:
        The result of ``np.mean``, ``np.sum``, ``np.min`` or ``np.max`` applied
        to the collected values, chosen by ``self.aggregation``. ``None`` is
        returned if the aggregation matches none of the known types.
    """
    reducers = (
        (AggregationType.MEAN, np.mean),
        (AggregationType.SUM, np.sum),
        (AggregationType.MIN, np.min),
        (AggregationType.MAX, np.max),
    )
    for kind, reducer in reducers:
        if self.aggregation == kind:
            return reducer(self.values)
    return None

@classmethod
def chain(cls, metric_lists: list["Metric"]) -> "Metric":
    """Concatenate several metrics into one new metric.

    Args:
        metric_lists: Metrics to combine, in order. The aggregation type of
            the first one is used for the result.

    Returns:
        A new metric holding the values of every input metric in order.

    Raises:
        ValueError: Propagated from ``extend`` when an input metric has a
            different aggregation type than the first one.
    """
    if len(metric_lists) > 0:
        merged = cls(aggregation=metric_lists[0].aggregation)
        for metric in metric_lists:
            merged.extend(metric)
        return merged
    # NOTE(review): an empty input falls back to MEAN because there is no metric
    # to take the aggregation type from. Callers that intended another
    # aggregation (e.g. SUM) should handle the empty case themselves.
    return cls(aggregation=AggregationType.MEAN)

@classmethod
def from_dict(cls, data: dict[str, Numeric], aggregation: str | AggregationType) -> dict[str, "Metric"]:
    """Wrap each value of a dictionary in its own metric.

    Args:
        data: Mapping from metric name to its initial value.
        aggregation: Aggregation type given to every created metric.

    Returns:
        A dictionary with the same keys as ``data``, each mapped to a metric
        created from the corresponding value.
    """
    wrapped = {}
    for name, number in data.items():
        wrapped[name] = cls(value=number, aggregation=aggregation)
    return wrapped

def init_list(self) -> "Metric":
    """Return a new, empty ``Metric`` with the same aggregation type as this one."""
    empty = Metric(aggregation=self.aggregation)
    return empty
4 changes: 3 additions & 1 deletion verl/utils/py_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from types import SimpleNamespace
from typing import Any, Callable, Iterator, Optional

from verl.utils.metric import Metric


# --- Top-level helper for multiprocessing timeout ---
# This function MUST be defined at the top level to be pickleable
Expand Down Expand Up @@ -196,7 +198,7 @@ def append_to_dict(data: dict, new_data: dict, prefix: str = ""):
for key, val in new_data.items():
new_key = f"{prefix}{key}" if not key.startswith(prefix) else key
if new_key not in data:
data[new_key] = []
data[new_key] = val.init_list() if isinstance(val, Metric) else []
if isinstance(val, list):
data[new_key].extend(val)
else:
Expand Down
5 changes: 4 additions & 1 deletion verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.flops_counter import FlopsCounter
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.metric.utils import Metric
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import allgather_dict_into_dict
Expand Down Expand Up @@ -242,7 +243,9 @@ def train_mini_batch(self, data: TensorDict) -> TensorDict:
for key, val in output.items():
# flatten dp and micro batch
if isinstance(val, list):
output[key] = list(chain.from_iterable(val))
output[key] = (
Metric.chain(val) if isinstance(val[0], Metric) else list(chain.from_iterable(val))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The expression isinstance(val[0], Metric) will raise an IndexError if val is an empty list. This will crash the worker. You should first check if the list is not empty before accessing its first element.

Suggested change
Metric.chain(val) if isinstance(val[0], Metric) else list(chain.from_iterable(val))
Metric.chain(val) if val and isinstance(val[0], Metric) else list(chain.from_iterable(val))

)
append_to_dict(metrics, output)

output = tu.get_tensordict(tensor_dict={}, non_tensor_dict={"metrics": metrics}).cpu()
Expand Down
Loading
Loading