-
Notifications
You must be signed in to change notification settings - Fork 3.2k
[worker, training_utils] fix: Engine Metric Aggregation #4778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cdc2a81
a172c15
b1170a9
3061d12
cb6c02c
d882004
939e940
bff8cb1
1d0e23e
1d1ac3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,12 +15,14 @@ | |||||||||||||||
| Metrics utils. | ||||||||||||||||
| """ | ||||||||||||||||
|
|
||||||||||||||||
| from typing import Any | ||||||||||||||||
| from enum import Enum | ||||||||||||||||
| from typing import Any, Optional, Union | ||||||||||||||||
|
|
||||||||||||||||
| import numpy as np | ||||||||||||||||
| import torch | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: | ||||||||||||||||
| def reduce_metrics(metrics: dict[str, Union["Metric", list[Any]]]) -> dict[str, Any]: | ||||||||||||||||
| """ | ||||||||||||||||
| Reduces a dictionary of metric lists by computing the mean, max, or min of each list. | ||||||||||||||||
| The reduce operation is determined by the key name: | ||||||||||||||||
|
|
@@ -45,10 +47,103 @@ def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: | |||||||||||||||
| {"loss": 2.0, "accuracy": 0.8, "max_reward": 8.0, "min_error": 0.05} | ||||||||||||||||
| """ | ||||||||||||||||
| for key, val in metrics.items(): | ||||||||||||||||
| if "max" in key: | ||||||||||||||||
| if isinstance(val, Metric): | ||||||||||||||||
| metrics[key] = val.aggregate() | ||||||||||||||||
| elif "max" in key: | ||||||||||||||||
| metrics[key] = np.max(val) | ||||||||||||||||
| elif "min" in key: | ||||||||||||||||
| metrics[key] = np.min(val) | ||||||||||||||||
| else: | ||||||||||||||||
| metrics[key] = np.mean(val) | ||||||||||||||||
| return metrics | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class AggregationType(Enum): | ||||||||||||||||
| MEAN = "mean" | ||||||||||||||||
| SUM = "sum" | ||||||||||||||||
| MIN = "min" | ||||||||||||||||
| MAX = "max" | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| NumericType = int, float, torch.Tensor | ||||||||||||||||
| Numeric = int | float | torch.Tensor | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class Metric: | ||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @JacobHelwig Docstring Coverage ci failed, please add doc string for Metric.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added! |
||||||||||||||||
| """ | ||||||||||||||||
| A metric aggregator for collecting and aggregating numeric values. | ||||||||||||||||
|
|
||||||||||||||||
| This class accumulates numeric values (int, float, or scalar tensors) and computes | ||||||||||||||||
| an aggregate statistic based on the specified aggregation type (MEAN, SUM, MIN, or MAX). | ||||||||||||||||
|
|
||||||||||||||||
| Args: | ||||||||||||||||
| aggregation: The aggregation method to use. Can be a string ("mean", "sum", "min", "max") | ||||||||||||||||
| or an AggregationType enum value. | ||||||||||||||||
| value: Optional initial value(s) to add. Can be a single numeric value or a list of values. | ||||||||||||||||
|
|
||||||||||||||||
| Example: | ||||||||||||||||
| >>> metric = Metric(aggregation="mean", value=1.0) | ||||||||||||||||
| >>> metric.append(2.0) | ||||||||||||||||
| >>> metric.append(3.0) | ||||||||||||||||
| >>> metric.aggregate() | ||||||||||||||||
| 2.0 | ||||||||||||||||
| """ | ||||||||||||||||
|
|
||||||||||||||||
| def __init__(self, aggregation: str | AggregationType, value: Optional[Numeric | list[Numeric]] = None) -> None: | ||||||||||||||||
| if isinstance(aggregation, str): | ||||||||||||||||
| self.aggregation = AggregationType(aggregation) | ||||||||||||||||
| else: | ||||||||||||||||
| self.aggregation = aggregation | ||||||||||||||||
| if not isinstance(self.aggregation, AggregationType): | ||||||||||||||||
| raise ValueError(f"Unsupported aggregation type: {aggregation}") | ||||||||||||||||
| self.values = [] | ||||||||||||||||
| if value is not None: | ||||||||||||||||
| self.append(value) | ||||||||||||||||
|
Comment on lines
+100
to
+101
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| def append(self, value: Union[Numeric, "Metric"]) -> None: | ||||||||||||||||
| if isinstance(value, Metric): | ||||||||||||||||
| self.extend(value) | ||||||||||||||||
| return | ||||||||||||||||
| if isinstance(value, torch.Tensor): | ||||||||||||||||
| if value.numel() != 1: | ||||||||||||||||
| raise ValueError("Only scalar tensors can be converted to float") | ||||||||||||||||
| value = value.detach().item() | ||||||||||||||||
| if not isinstance(value, NumericType): | ||||||||||||||||
| raise ValueError(f"Unsupported value type: {type(value)}") | ||||||||||||||||
| self.values.append(value) | ||||||||||||||||
|
|
||||||||||||||||
| def extend(self, values: Union["Metric", list[Numeric]]) -> None: | ||||||||||||||||
| if isinstance(values, Metric): | ||||||||||||||||
| if values.aggregation != self.aggregation: | ||||||||||||||||
| raise ValueError(f"Aggregation type mismatch: {self.aggregation} != {values.aggregation}") | ||||||||||||||||
| values = values.values | ||||||||||||||||
| for value in values: | ||||||||||||||||
| self.append(value) | ||||||||||||||||
|
|
||||||||||||||||
| def aggregate(self) -> float: | ||||||||||||||||
| match self.aggregation: | ||||||||||||||||
| case AggregationType.MEAN: | ||||||||||||||||
| return np.mean(self.values) | ||||||||||||||||
| case AggregationType.SUM: | ||||||||||||||||
| return np.sum(self.values) | ||||||||||||||||
| case AggregationType.MIN: | ||||||||||||||||
| return np.min(self.values) | ||||||||||||||||
| case AggregationType.MAX: | ||||||||||||||||
| return np.max(self.values) | ||||||||||||||||
JacobHelwig marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||
|
|
||||||||||||||||
| @classmethod | ||||||||||||||||
| def chain(cls, metric_lists: list["Metric"]) -> "Metric": | ||||||||||||||||
| if len(metric_lists) == 0: | ||||||||||||||||
| return cls(aggregation=AggregationType.MEAN) | ||||||||||||||||
|
Comment on lines
+136
to
+137
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When
Suggested change
|
||||||||||||||||
| aggregation = metric_lists[0].aggregation | ||||||||||||||||
| chained = cls(aggregation=aggregation) | ||||||||||||||||
| for ml in metric_lists: | ||||||||||||||||
| chained.extend(ml) | ||||||||||||||||
| return chained | ||||||||||||||||
|
|
||||||||||||||||
| @classmethod | ||||||||||||||||
| def from_dict(cls, data: dict[str, Numeric], aggregation: str | AggregationType) -> dict[str, "Metric"]: | ||||||||||||||||
| return {key: cls(value=value, aggregation=aggregation) for key, value in data.items()} | ||||||||||||||||
|
|
||||||||||||||||
| def init_list(self) -> "Metric": | ||||||||||||||||
| return Metric(aggregation=self.aggregation) | ||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |||||
| from verl.utils.distributed import initialize_global_process_group_ray | ||||||
| from verl.utils.flops_counter import FlopsCounter | ||||||
| from verl.utils.memory_utils import aggressive_empty_cache | ||||||
| from verl.utils.metric.utils import Metric | ||||||
| from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage | ||||||
| from verl.utils.py_functional import append_to_dict | ||||||
| from verl.utils.torch_functional import allgather_dict_into_dict | ||||||
|
|
@@ -242,7 +243,9 @@ def train_mini_batch(self, data: TensorDict) -> TensorDict: | |||||
| for key, val in output.items(): | ||||||
| # flattn dp and micro batch | ||||||
| if isinstance(val, list): | ||||||
| output[key] = list(chain.from_iterable(val)) | ||||||
| output[key] = ( | ||||||
| Metric.chain(val) if isinstance(val[0], Metric) else list(chain.from_iterable(val)) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expression
Suggested change
|
||||||
| ) | ||||||
JacobHelwig marked this conversation as resolved.
Show resolved
Hide resolved
JacobHelwig marked this conversation as resolved.
Show resolved
Hide resolved
JacobHelwig marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| append_to_dict(metrics, output) | ||||||
|
|
||||||
| output = tu.get_tensordict(tensor_dict={}, non_tensor_dict={"metrics": metrics}).cpu() | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.