-
Notifications
You must be signed in to change notification settings - Fork 3.2k
[worker, training_utils] fix: Engine Metric Aggregation #4778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
cdc2a81
a172c15
b1170a9
3061d12
cb6c02c
d882004
939e940
bff8cb1
1d0e23e
1d1ac3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,12 +15,14 @@ | |||||||||
| Metrics utils. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| from typing import Any | ||||||||||
| from enum import Enum | ||||||||||
| from typing import Any, Optional, Union | ||||||||||
|
|
||||||||||
| import numpy as np | ||||||||||
| import torch | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: | ||||||||||
| def reduce_metrics(metrics: dict[str, Union["MetricList", list[Any]]]) -> dict[str, Any]: | ||||||||||
| """ | ||||||||||
| Reduces a dictionary of metric lists by computing the mean, max, or min of each list. | ||||||||||
| The reduce operation is determined by the key name: | ||||||||||
|
|
@@ -45,10 +47,112 @@ def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: | |||||||||
| {"loss": 2.0, "accuracy": 0.8, "max_reward": 8.0, "min_error": 0.05} | ||||||||||
| """ | ||||||||||
| for key, val in metrics.items(): | ||||||||||
| if "max" in key: | ||||||||||
| if isinstance(val, MetricList): | ||||||||||
| metrics[key] = val.aggregate() | ||||||||||
| elif "max" in key: | ||||||||||
| metrics[key] = np.max(val) | ||||||||||
| elif "min" in key: | ||||||||||
| metrics[key] = np.min(val) | ||||||||||
| else: | ||||||||||
| metrics[key] = np.mean(val) | ||||||||||
| return metrics | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class AggregationType(Enum): | ||||||||||
| MEAN = "mean" | ||||||||||
| SUM = "sum" | ||||||||||
| MIN = "min" | ||||||||||
| MAX = "max" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Numeric = Union[int, float, torch.Tensor] | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def tensor_to_float(value: torch.Tensor) -> float: | ||||||||||
| if not isinstance(value, torch.Tensor): | ||||||||||
| raise ValueError(f"Expected torch.Tensor, got {type(value)}") | ||||||||||
| if value.numel() != 1: | ||||||||||
| raise ValueError("Only scalar tensors can be converted to float") | ||||||||||
| return value.detach().item() | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class MetricValue: | ||||||||||
| def __init__(self, value: Numeric, aggregation: str | AggregationType) -> None: | ||||||||||
| if isinstance(value, torch.Tensor): | ||||||||||
| value = tensor_to_float(value) | ||||||||||
| if not isinstance(value, Numeric): | ||||||||||
| raise ValueError(f"Unsupported value type: {type(value)}") | ||||||||||
JacobHelwig marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| self.value = value | ||||||||||
| if isinstance(aggregation, str): | ||||||||||
| self.aggregation = AggregationType(aggregation) | ||||||||||
| else: | ||||||||||
| self.aggregation = aggregation | ||||||||||
| if not isinstance(self.aggregation, AggregationType): | ||||||||||
| raise ValueError(f"Unsupported aggregation type: {aggregation}") | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| def from_dict(cls, data: dict[str, Numeric], aggregation: str | AggregationType) -> dict[str, "MetricValue"]: | ||||||||||
| return {key: cls(value, aggregation) for key, value in data.items()} | ||||||||||
|
|
||||||||||
| def init_list(self) -> "MetricList": | ||||||||||
| return MetricList(aggregation=self.aggregation) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class MetricList: | ||||||||||
| def __init__(self, aggregation: str | AggregationType, values: Optional[list[float]] = None) -> None: | ||||||||||
| if isinstance(aggregation, str): | ||||||||||
| self.aggregation = AggregationType(aggregation) | ||||||||||
| else: | ||||||||||
| self.aggregation = aggregation | ||||||||||
| if not isinstance(self.aggregation, AggregationType): | ||||||||||
| raise ValueError(f"Unsupported aggregation type: {aggregation}") | ||||||||||
| self.values = values if values is not None else [] | ||||||||||
|
|
||||||||||
| def append(self, value: float | MetricValue) -> None: | ||||||||||
|
||||||||||
| def append(self, value: float | MetricValue) -> None: | |
| def append(self, value: Union[Numeric, MetricValue]) -> None: |
JacobHelwig marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When chain is called with an empty list of metrics, it returns a new Metric with AggregationType.MEAN by default. This arbitrary default can lead to incorrect behavior. For example, if a list of metrics intended for summation is empty, chaining them should result in a metric that aggregates to 0 (the sum of an empty set), not NaN (the mean of an empty set). It would be safer to raise a ValueError if the list is empty, forcing the caller to handle this case explicitly.
| if len(metric_lists) == 0: | |
| return cls(aggregation=AggregationType.MEAN) | |
| if not metric_lists: | |
| raise ValueError("Cannot chain an empty list of metrics.") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,8 @@ | |
| from types import SimpleNamespace | ||
| from typing import Any, Callable, Iterator, Optional | ||
|
|
||
| from verl.utils.metric import MetricList, MetricValue | ||
|
|
||
|
|
||
| # --- Top-level helper for multiprocessing timeout --- | ||
| # This function MUST be defined at the top level to be pickleable | ||
|
|
@@ -196,8 +198,8 @@ def append_to_dict(data: dict, new_data: dict, prefix: str = ""): | |
| for key, val in new_data.items(): | ||
| new_key = f"{prefix}{key}" if not key.startswith(prefix) else key | ||
| if new_key not in data: | ||
| data[new_key] = [] | ||
| if isinstance(val, list): | ||
| data[new_key] = MetricValue.init_list(val) if isinstance(val, (MetricValue, MetricList)) else [] | ||
| if isinstance(val, (list, MetricList)): | ||
| data[new_key].extend(val) | ||
| else: | ||
| data[new_key].append(val) | ||
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using wildcard imports (
import *) is generally discouraged as it can pollute the namespace and make it unclear which names are being imported. It's better to explicitly import the names you need from theutilsmodule. This improves code readability and maintainability.