Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions verl/utils/metric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import reduce_metrics
from .utils import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using wildcard imports (import *) is generally discouraged as it can pollute the namespace and make it unclear which names are being imported. It's better to explicitly import the names you need from the utils module. This improves code readability and maintainability.

Suggested change
from .utils import *
from .utils import AggregationType, MetricList, MetricValue, reduce_metrics


__all__ = ["reduce_metrics"]
__all__ = ["reduce_metrics", "AggregationType", "MetricValue", "MetricList"]
110 changes: 107 additions & 3 deletions verl/utils/metric/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
Metrics utils.
"""

from typing import Any
from enum import Enum
from typing import Any, Optional, Union

import numpy as np
import torch


def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
def reduce_metrics(metrics: dict[str, Union["MetricList", list[Any]]]) -> dict[str, Any]:
"""
Reduces a dictionary of metric lists by computing the mean, max, or min of each list.
The reduce operation is determined by the key name:
Expand All @@ -45,10 +47,112 @@ def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
{"loss": 2.0, "accuracy": 0.8, "max_reward": 8.0, "min_error": 0.05}
"""
for key, val in metrics.items():
if "max" in key:
if isinstance(val, MetricList):
metrics[key] = val.aggregate()
elif "max" in key:
metrics[key] = np.max(val)
elif "min" in key:
metrics[key] = np.min(val)
else:
metrics[key] = np.mean(val)
return metrics


class AggregationType(Enum):
    """How a list of metric samples is reduced to a single scalar (used by MetricList.aggregate)."""
    MEAN = "mean"
    SUM = "sum"
    MIN = "min"
    MAX = "max"


Numeric = Union[int, float, torch.Tensor]


def tensor_to_float(value: torch.Tensor) -> float:
    """Convert a single-element torch tensor to a plain Python number.

    Raises:
        ValueError: if *value* is not a tensor, or holds more than one element.
    """
    if not isinstance(value, torch.Tensor):
        raise ValueError(f"Expected torch.Tensor, got {type(value)}")
    if value.numel() != 1:
        raise ValueError("Only scalar tensors can be converted to float")
    # Detach first so no autograd graph is kept alive by the extracted scalar.
    detached = value.detach()
    return detached.item()


class MetricValue:
    """One metric sample tagged with the aggregation rule it should be reduced with."""

    def __init__(self, value: Numeric, aggregation: str | AggregationType) -> None:
        # Scalar tensors are unwrapped immediately so downstream code only sees numbers.
        if isinstance(value, torch.Tensor):
            value = tensor_to_float(value)
        if not isinstance(value, Numeric):
            raise ValueError(f"Unsupported value type: {type(value)}")
        self.value = value
        # Accept either the enum itself or its string value (e.g. "mean").
        self.aggregation = AggregationType(aggregation) if isinstance(aggregation, str) else aggregation
        if not isinstance(self.aggregation, AggregationType):
            raise ValueError(f"Unsupported aggregation type: {aggregation}")

    @classmethod
    def from_dict(cls, data: dict[str, Numeric], aggregation: str | AggregationType) -> dict[str, "MetricValue"]:
        """Wrap every value of *data* in a MetricValue sharing one aggregation rule."""
        wrapped: dict[str, "MetricValue"] = {}
        for name, raw in data.items():
            wrapped[name] = cls(raw, aggregation)
        return wrapped

    def init_list(self) -> "MetricList":
        """Create an empty MetricList carrying this sample's aggregation rule."""
        return MetricList(aggregation=self.aggregation)


class MetricList:
def __init__(self, aggregation: str | AggregationType, values: Optional[list[float]] = None) -> None:
    """Create a container of samples that all share one aggregation rule.

    Args:
        aggregation: the reduction rule, as an AggregationType or its string value.
        values: optional initial samples; a fresh list is used when omitted.
    """
    # Accept either the enum itself or its string value (e.g. "mean").
    self.aggregation = AggregationType(aggregation) if isinstance(aggregation, str) else aggregation
    if not isinstance(self.aggregation, AggregationType):
        raise ValueError(f"Unsupported aggregation type: {aggregation}")
    self.values = [] if values is None else values

def append(self, value: float | MetricValue) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The type hint value: float | MetricValue is inconsistent with the method's implementation, which also handles torch.Tensor and int values. To improve type safety and clarity, the hint should be expanded to include all supported types. Using the Numeric alias, which is Union[int, float, torch.Tensor], would be appropriate here.

Suggested change
def append(self, value: float | MetricValue) -> None:
def append(self, value: Union[Numeric, MetricValue]) -> None:

if isinstance(value, MetricValue):
if value.aggregation != self.aggregation:
raise AggregationTypeMismatchError(self.aggregation, value.aggregation)
value = value.value
if isinstance(value, torch.Tensor):
value = tensor_to_float(value)
if not isinstance(value, Numeric):
raise ValueError(f"Unsupported value type: {type(value)}")
self.values.append(value)

def extend(self, values: Union["MetricList", list[float | MetricValue]]) -> None:
    """Append every sample from *values*.

    Raises:
        AggregationTypeMismatchError: if *values* is a MetricList with a different
            aggregation rule than this one.
    """
    if isinstance(values, MetricList):
        if values.aggregation != self.aggregation:
            raise AggregationTypeMismatchError(self.aggregation, values.aggregation)
        items = values.values
    else:
        items = values
    # Route through append so each element gets the same validation/conversion.
    for item in items:
        self.append(item)

def aggregate(self) -> float:
    """Reduce the collected samples to one scalar according to the aggregation rule."""
    if self.aggregation == AggregationType.MEAN:
        return np.mean(self.values)
    if self.aggregation == AggregationType.SUM:
        return np.sum(self.values)
    if self.aggregation == AggregationType.MIN:
        return np.min(self.values)
    if self.aggregation == AggregationType.MAX:
        return np.max(self.values)

@classmethod
def chain(cls, metric_lists: list["MetricList"]) -> "MetricList":
if len(metric_lists) == 0:
return cls(aggregation=AggregationType.MEAN)
Comment on lines +136 to +137
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When chain is called with an empty list of metrics, it returns a new MetricList with AggregationType.MEAN by default. This arbitrary default can lead to incorrect behavior. For example, if a list of metrics intended for summation is empty, chaining them should result in a metric that aggregates to 0 (the sum of an empty set), not NaN (the mean of an empty set). It would be safer to raise a ValueError if the list is empty, forcing the caller to handle this case explicitly.

Suggested change
if len(metric_lists) == 0:
return cls(aggregation=AggregationType.MEAN)
if not metric_lists:
raise ValueError("Cannot chain an empty list of metrics.")

aggregation = metric_lists[0].aggregation
chained = cls(aggregation=aggregation)
for ml in metric_lists:
chained.extend(ml)
return chained

def init_list(self) -> "MetricList":
    """Return a fresh, empty MetricList with the same aggregation rule."""
    empty = MetricList(aggregation=self.aggregation)
    return empty


class AggregationTypeMismatchError(Exception):
    """Raised when metric containers with different aggregation rules are combined."""

    def __init__(self, agg1: AggregationType, agg2: AggregationType):
        super().__init__(f"Aggregation type mismatch: {agg1.value} != {agg2.value}")
6 changes: 4 additions & 2 deletions verl/utils/py_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from types import SimpleNamespace
from typing import Any, Callable, Iterator, Optional

from verl.utils.metric import MetricList, MetricValue


# --- Top-level helper for multiprocessing timeout ---
# This function MUST be defined at the top level to be pickleable
Expand Down Expand Up @@ -196,8 +198,8 @@ def append_to_dict(data: dict, new_data: dict, prefix: str = ""):
for key, val in new_data.items():
new_key = f"{prefix}{key}" if not key.startswith(prefix) else key
if new_key not in data:
data[new_key] = []
if isinstance(val, list):
data[new_key] = MetricValue.init_list(val) if isinstance(val, (MetricValue, MetricList)) else []
if isinstance(val, (list, MetricList)):
data[new_key].extend(val)
else:
data[new_key].append(val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a potential TypeError here. If data[new_key] is a standard Python list (which can happen if an empty list [] was appended for this key previously) and val is a MetricList, the call data[new_key].extend(val) will fail because MetricList is not iterable by default for list.extend. You should handle this case by 'upgrading' the list to a MetricList before extending to ensure type consistency.

        if isinstance(data[new_key], list) and isinstance(val, (MetricValue, MetricList)):
            # Upgrade list to MetricList if it's not already one
            new_list = val.init_list()
            new_list.extend(data[new_key])
            data[new_key] = new_list

        if isinstance(val, (list, MetricList)):
            data[new_key].extend(val)
        else:
            data[new_key].append(val)

Expand Down
7 changes: 6 additions & 1 deletion verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.flops_counter import FlopsCounter
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.metric.utils import MetricList
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import allgather_dict_into_dict
Expand Down Expand Up @@ -242,7 +243,11 @@ def train_mini_batch(self, data: TensorDict) -> TensorDict:
for key, val in output.items():
# flatten dp and micro batch
if isinstance(val, list):
output[key] = list(chain.from_iterable(val))
output[key] = (
MetricList.chain(val)
if isinstance(val[0], MetricList)
else list(chain.from_iterable(val))
)
append_to_dict(metrics, output)

output = tu.get_tensordict(tensor_dict={}, non_tensor_dict={"metrics": metrics}).cpu()
Expand Down
17 changes: 14 additions & 3 deletions verl/workers/utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from verl.trainer.ppo.core_algos import agg_loss, compute_value_loss, get_policy_loss_fn, kl_penalty
from verl.utils import tensordict_utils as tu
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.metric import AggregationType, MetricValue
from verl.utils.torch_functional import masked_mean, masked_sum
from verl.workers.config import ActorConfig, CriticConfig

Expand Down Expand Up @@ -103,6 +104,15 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
config.global_batch_info["batch_num_tokens"] = data["batch_num_tokens"]
config.global_batch_info["global_batch_size"] = data["global_batch_size"]
config.global_batch_info["loss_scale_factor"] = config.loss_scale_factor
if (
data["dp_size"] > 1
or data["batch_num_tokens"] is not None
or data["global_batch_size"] is not None
or config.loss_scale_factor is not None
):
metric_aggregation = AggregationType.SUM
else:
metric_aggregation = AggregationType.MEAN

metrics = {}

Expand All @@ -127,8 +137,8 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
rollout_is_weights=rollout_is_weights,
)

metrics.update(pg_metrics)
metrics["actor/pg_loss"] = pg_loss.detach().item()
metrics.update(MetricValue.from_dict(pg_metrics, aggregation=metric_aggregation))
metrics["actor/pg_loss"] = MetricValue(value=pg_loss, aggregation=metric_aggregation)
policy_loss = pg_loss

# add entropy loss
Expand All @@ -138,6 +148,7 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
)
entropy_coeff = config.entropy_coeff
policy_loss -= entropy_coeff * entropy_loss
metrics["actor/entropy_loss"] = MetricValue(value=entropy_loss, aggregation=metric_aggregation)

# add kl loss
if config.use_kl_loss:
Expand All @@ -149,7 +160,7 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
)

policy_loss += kl_loss * config.kl_loss_coef
metrics["kl_loss"] = kl_loss.detach().item()
metrics["kl_loss"] = MetricValue(value=kl_loss, aggregation=metric_aggregation)
metrics["kl_coef"] = config.kl_loss_coef

return policy_loss, metrics
Expand Down