|
8 | 8 | import itertools |
9 | 9 | import logging |
10 | 10 | import os |
11 | | -import random |
12 | 11 | from abc import ABC, abstractmethod |
13 | 12 | from dataclasses import dataclass |
14 | 13 | from datetime import datetime |
@@ -133,7 +132,6 @@ def record_episode_sample(key: str, episode): |
133 | 132 | Args: |
134 | 133 | key (str): logging prefix (e.g. "rollout/sample"). |
135 | 134 | episode (Episode): episode object with filled attributes. |
136 | | - reward_breakdown (dict[str, float]): per-function rewards, e.g. {"MathReward": 0.8, "FormatReward": 1.0}. |
137 | 135 | """ |
138 | 136 | sample = { |
139 | 137 | "episode_id": episode.episode_id, |
@@ -246,65 +244,7 @@ def reduce_metrics_states( |
246 | 244 | ################# |
247 | 245 |
|
248 | 246 |
|
249 | | -class SampleFilter(ABC): |
250 | | - """Abstract base class for sample filtering.""" |
251 | | - |
252 | | - @abstractmethod |
253 | | - def filter_append(self, sample: Dict) -> bool: |
254 | | - """ |
255 | | - Decide whether a sample should be kept at append time. |
256 | | - Return True if the sample should be stored, False otherwise. |
257 | | - """ |
258 | | - pass |
259 | | - |
260 | | - def filter_flush(self, samples: List[Dict]) -> List[Dict]: |
261 | | - """ |
262 | | - Optionally filter or transform the collected samples at flush time. |
263 | | - Default: return the samples unchanged. |
264 | | - """ |
265 | | - return samples |
266 | | - |
267 | | - def reset(self) -> None: |
268 | | - """Clears for next accumulation cycle.""" |
269 | | - pass |
270 | | - |
271 | | - |
272 | | -class RandomRatioFilter: |
273 | | - """Randomly keep a fraction of samples.""" |
274 | | - |
275 | | - def __init__(self, ratio=0.05): |
276 | | - self.ratio = ratio |
277 | | - |
278 | | - def filter_append(self, sample): |
279 | | - return random.random() < self.ratio |
280 | | - |
281 | | - |
282 | | -class RewardThresholdFilter: |
283 | | - """ |
284 | | - Keep samples only if their reward is < lt or > gt (depending on which bound is set). |
285 | | - If a bound is None, that side of the filter is disabled. |
286 | | - """ |
287 | | - |
288 | | - def __init__(self, lt=None, gt=None): |
289 | | - self.lt = lt |
290 | | - self.gt = gt |
291 | | - |
292 | | - def filter_append(self, sample): |
293 | | - r = sample.get("reward", 0.0) |
294 | | - |
295 | | - # If lt is set: drop samples with reward >= lt |
296 | | - if self.lt is not None and r >= self.lt: |
297 | | - return False |
298 | | - |
299 | | - # If gt is set: drop samples with reward <= gt |
300 | | - if self.gt is not None and r <= self.gt: |
301 | | - return False |
302 | | - |
303 | | - # Otherwise, keep this sample |
304 | | - return True |
305 | | - |
306 | | - |
307 | | -class TopBottomKFilter(SampleFilter): |
| 247 | +class TopBottomKFilter: |
308 | 248 | """Keep the top-k and bottom-k samples by a given key (e.g., reward).""" |
309 | 249 |
|
310 | 250 | def __init__(self, top_k=1, bottom_k=1, key="reward"): |
@@ -539,18 +479,19 @@ def reset(self) -> None: |
539 | 479 | class SampleAccumulator(MetricAccumulator): |
540 | 480 | """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts). |
541 | 481 |
|
542 | | - Optionally uses a SampleFilter to decide what to keep at append/flush time. |
| 482 | + Optionally uses a sample filter to decide what to keep at append/flush time. |
543 | 483 | """ |
544 | 484 |
|
545 | 485 | def __init__( |
546 | | - self, reduction: Reduce, filter: SampleFilter | None = TopBottomKFilter() |
| 486 | + self, reduction: Reduce, filter: TopBottomKFilter | None = TopBottomKFilter() |
547 | 487 | ): |
548 | 488 | super().__init__(reduction) |
549 | 489 | self.samples: List[Dict[str, Any]] = [] |
550 | 490 | self.filter = filter |
551 | 491 |
|
552 | 492 | def append(self, value: dict) -> None: |
553 | | - assert isinstance(value, dict) |
| 493 | + if not isinstance(value, dict): |
| 494 | + raise ValueError(f"Expected dict, got {type(value)}") |
554 | 495 |
|
555 | 496 | # If filter is provided, only keep the sample if filter_append returns True |
556 | 497 | if self.filter: |
|
0 commit comments