Add Sample-level Logging API #309
base: main
Conversation
…estamp_logging
Hey, thanks! Did a first pass. Just a heads-up: my PR will add a considerable number of conflicts, so it might be better to rebase before continuing.
src/forge/observability/metrics.py
Outdated
record_metric(key, sample, Reduce.SAMPLE)
...
def reduce_metrics_states(
It seems that the only change is checking if Reduce.SAMPLE.value, but the core stays the same. I don't think this is the correct place to put this abstraction; otherwise, what happens when the user wants to add another "if this", "if that"?

Please take a look at this PR. I introduced a dataclass Metric that also holds the reduction type. In later steps, users can just iterate over the metrics and check: if Reduce.SAMPLE.value, do this; otherwise, do that.

TL;DR: let's keep the code here as it was and do the if/else elsewhere.
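The suggested shape could look roughly like this; a minimal sketch with hypothetical names (the actual `Metric` dataclass and helper in the referenced PR may differ):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Tuple

class Reduce(Enum):
    MEAN = "mean"
    SAMPLE = "sample"

@dataclass
class Metric:
    # Each metric carries its own reduction type, so downstream code
    # can branch on it without touching reduce_metrics_states itself.
    key: str
    value: Any
    reduction: Reduce

def split_metrics(metrics: List[Metric]) -> Tuple[List[Metric], List[Metric]]:
    """Do the if/else at flush time: sample metrics go one way, scalars the other."""
    samples = [m for m in metrics if m.reduction is Reduce.SAMPLE]
    scalars = [m for m in metrics if m.reduction is not Reduce.SAMPLE]
    return samples, scalars
```

This keeps the reduction core untouched: the branch lives wherever metrics are consumed, not inside the reducer.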
Sounds good. I have rebased this PR onto #303, reverted the changes to reduce_metrics_states, and added the if/else checks in MetricCollector.flush.
src/forge/observability/metrics.py
Outdated
for key, rows in samples.items():
    logger.info(f"[{key}] ({len(rows)} samples)")
    for sample in rows:
        pretty = json.dumps(sample, indent=2, ensure_ascii=False)
        logger.info(pretty)
What's a key here? Is it a whole table? Perhaps we could rename it: for table_name, table_rows in samples.items()?

I am fine if you want to push back; it just didn't seem 100% obvious to me at first sight. "Rows" reminds me of tables, so maybe we should try to stay close to wandb nomenclature and rename everything SAMPLE -> TABLE?

Also, should we just do json.dumps(rows) instead of dumping one sample at a time?
The key here is rollout/sample. And you are right, we can rename them to table_name, table_rows to be more readable. And this can indeed be simplified with json.dumps(rows). Done!
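For context, the simplification is just serializing the whole list in one call instead of one sample at a time (the rows below are made-up example data):

```python
import json

table_rows = [
    {"prompt": "2+2?", "response": "4", "reward": 1.0},
    {"prompt": "3+3?", "response": "7", "reward": 0.0},
]

# One dumps() call emits the whole table as a single JSON array,
# instead of one log line per sample.
pretty = json.dumps(table_rows, indent=2, ensure_ascii=False)
```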
src/forge/observability/metrics.py
Outdated
table = wandb.Table(columns=columns)

for s in rows:
    values = [s.get(c) for c in columns]
What happens if c is not present? Does it return None, or does it break?
It would be None. I added a comment here to clarify.
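The behavior in question is just dict.get's default: a missing column yields None rather than a KeyError:

```python
# A row that is missing the "response" column (made-up data).
sample = {"prompt": "hi", "reward": 1.0}
columns = ["prompt", "response", "reward"]

# dict.get(key) returns None when the key is absent, so the wandb row
# simply gets an empty cell instead of the loop raising KeyError.
values = [sample.get(c) for c in columns]
```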
else:
    logger.debug(f"WandbBackend: No run started, skipping log for {self.name}")

async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
Following my previous comment, maybe we should keep wandb nomenclature and name it log_table.
Good point. W&B uses the term "table", but this method is part of the generic LoggerBackend interface (and the console backend doesn't actually log tables):
class LoggerBackend(ABC):
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
...
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
pass
Let's keep the name log_samples() for consistency across backends?
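To illustrate the point, here is a minimal console-backend sketch; this is not the actual forge implementation, just the shape the interface implies, combined with the json.dumps(rows) simplification discussed earlier:

```python
import json
import logging
from abc import ABC
from typing import Dict, List

logger = logging.getLogger(__name__)

class LoggerBackend(ABC):
    """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""

    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
        pass

class ConsoleBackend(LoggerBackend):
    # Sketch only: dump each table's rows as one JSON blob.
    # No wandb.Table here, which is why "samples" fits better than "table".
    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
        for table_name, table_rows in samples.items():
            logger.info(f"[{table_name}] ({len(table_rows)} samples) at step {step}")
            logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
```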
src/forge/observability/metrics.py
Outdated
"""

def __init__(
    self, reduction: Reduce, filter: TopBottomKFilter | None = TopBottomKFilter()
Since this is not configurable, I think we should hardcode this filter logic in the SampleAccumulator. We could drop self.samples here (and the if/else filter checks). I also don't think we need the self._counter logic.

If the user wants to change the logic, I think it's more convenient for them to just create a new SampleAccumulatorV2; they have no way to pass the filter arg down from the config to this accumulator.
That's a fair point. But I'd like to ask whether we envision making this configurable in the near future. If the answer is a solid yes, then I'd still prefer to keep the filter logic separate for now. It's much easier (and more natural) for users to define a lightweight filter class than to implement an entirely new SampleAccumulatorV2.

Merging the logic into SampleAccumulator would make it harder to evolve later: we'd eventually need to split them apart again once customization becomes necessary.

Also, from a practical standpoint, the current design doesn't introduce meaningful overhead: TopBottomKFilter.filter_append() always returns False, so we're not storing samples twice. The filter only materializes the top/bottom-k subset at flush time.

Given that, I'd suggest keeping them separate for now to preserve flexibility for near-term extensions (the top/bottom-k filter may already be too limited for internal debugging needs).

But I agree that since we are hardcoding it, we don't need to overcomplicate it. How about we do this and not make filter a parameter:
class SampleAccumulator(MetricAccumulator):
    def __init__(self, reduction: Reduce):
        super().__init__(reduction)
        self.samples: List[Dict[str, Any]] = []
        self.filter = TopBottomKFilter()
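For reference, one way the heap-based filter described above could look. The contract (filter_append() always returns False; the top/bottom-k subset is only materialized at flush time) is from the discussion; the k parameter, method names, and heap layout are illustrative, not the actual forge code:

```python
import heapq
from typing import Any, Dict, List

class TopBottomKFilter:
    """Keep the k highest- and k lowest-reward samples seen so far.

    Heap-based, O(log k) per append. Nothing is stored outside the two
    heaps, so callers should not also buffer the sample themselves.
    """

    def __init__(self, k: int = 1, key: str = "reward"):
        self.k = k
        self.key = key
        self._top: list = []     # min-heap: root is the weakest of the current top-k
        self._bottom: list = []  # min-heap over negated rewards: tracks the bottom-k
        self._idx = 0            # insertion counter as tie-breaker (dicts don't compare)

    def filter_append(self, sample: Dict[str, Any]) -> bool:
        r = sample[self.key]
        heapq.heappush(self._top, (r, self._idx, sample))
        heapq.heappush(self._bottom, (-r, self._idx, sample))
        self._idx += 1
        if len(self._top) > self.k:
            heapq.heappop(self._top)      # evict the smallest top candidate
        if len(self._bottom) > self.k:
            heapq.heappop(self._bottom)   # evict the largest bottom candidate
        return False  # sample is tracked here; don't store it a second time

    def materialize(self) -> List[Dict[str, Any]]:
        """Return the retained top-k and bottom-k samples (called at flush)."""
        return [s for _, _, s in self._top] + [s for _, _, s in self._bottom]
```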
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
- episode.reward = await reward_actor.evaluate_response.route(
+ episode.reward_breakdown = await reward_actor.evaluate_response.route(
I like this, but I think it can be a bit dangerous. We don't have a dataclass that declares the fields it will hold. You would also need to make sure that the other actors are aware of this change. I am thinking that maybe we should keep episode.reward: float and add an extra optional field episode.reward_breakdown: dict[str, float]. Wdyt?
Exactly! It comes with two fields, reward and reward_breakdown. If you look at the line below it:

episode.reward = episode.reward_breakdown["reward"]
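A minimal sketch of the agreed shape; the breakdown component names ("format") are hypothetical, and the real Episode dataclass has more fields:

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class Episode:
    # reward stays a float so actors that only consume the scalar are unaffected;
    # reward_breakdown is the new optional per-component dict.
    reward: float = 0.0
    reward_breakdown: Optional[Dict[str, float]] = None

# The scalar lives under the "reward" key of the breakdown, as in the PR:
ep = Episode(reward_breakdown={"reward": 0.8, "format": 1.0})
ep.reward = ep.reward_breakdown["reward"]
```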
@felipemello1 do I review #303 before this?

@allenwang28 I have rebased this on #303. So let's land 303 first?

Yes, let's land 303 first. I am debugging one thing about wandb. I think I found a solution. Then we should be able to merge it.
Codecov Report

@@ Coverage Diff @@
##             main    #309   +/-   ##
=======================================
  Coverage        ?   64.88%
=======================================
  Files           ?       80
  Lines           ?     7982
  Branches        ?        0
=======================================
  Hits            ?     5179
  Misses          ?     2803
  Partials        ?        0

☔ View full report in Codecov by Sentry.
This PR introduces structured sample-level logging to complement existing scalar metrics, allowing users to inspect concrete prompt–response–reward examples during RL training.
More discussion: #301
The current implementation logs 2 samples per step: one top (highest reward) and one bottom (lowest reward). Supporting customized sampling strategies is out of scope for this PR. For now, this can be achieved by changing the filter for SampleAccumulator.

Summary of Changes

- SampleAccumulator to support logging structured dict samples (e.g., per-episode data) via the record_episode_sample API with Reduce.SAMPLE.
- TopBottomKFilter for selecting top/bottom samples based on reward (heap-based, O(log k) per append).
- wandb.Table for live sample inspection.

Logged Fields
Each logged sample includes:
Tests:

- python apps/toy_rl/toy_metrics/main.py
- python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml for modes:
  - logging_mode=global_reduce
  - logging_mode=per_rank_reduce with per_rank_share_run=False. The table is logged in the run Controller_xxx.
  - logging_mode=per_rank_reduce with per_rank_share_run=True
  - logging_mode=per_rank_no_reduce with per_rank_share_run=False. Because there is no reduce, we log more than 2 samples each step.
  - logging_mode=per_rank_no_reduce with per_rank_share_run=True (not working)

Backend: wandb
Backend: Console:
Notes