Skip to content

Commit ba0b1be

Browse files
committed
use incremental table
1 parent 9a52da5 commit ba0b1be

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

src/forge/observability/metrics.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]):
887887
self.run = None
888888
self.logging_mode = LoggingMode(logger_backend_config["logging_mode"])
889889
self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False)
890+
self._tables: dict[str, "wandb.Table"] = {}
890891

891892
async def init(
892893
self,
@@ -989,30 +990,36 @@ def log_stream(self, metric: Metric, step: int, *args, **kwargs) -> None:
989990
self.run.log(log_data)
990991

991992
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
    """Append sample-level rows to persistent, incrementally logged WandB tables.

    One ``wandb.Table`` created with ``log_mode="INCREMENTAL"`` is kept per
    table name in ``self._tables``. Each call appends the new rows to that
    table and logs the same table object again under ``"<table_name>_table"``.

    Args:
        samples: Mapping of table name to a list of row dicts. Tables with
            no rows are skipped.
        step: Step the samples belong to. In this method it is only used
            in the log message.
    """
    if not self.run:
        return

    # Imported after the guard so a backend without an active run never
    # requires wandb to be importable.
    import wandb

    for table_name, table_rows in samples.items():
        if not table_rows:
            continue

        # Union of keys over the whole batch, in first-seen order. Using
        # only the first row's keys would silently drop any field that
        # first appears in a later row.
        batch_keys = list(dict.fromkeys(k for row in table_rows for k in row))

        table = self._tables.get(table_name)
        if table is None:
            # First batch for this name: create the table in INCREMENTAL mode.
            table = wandb.Table(columns=batch_keys, log_mode="INCREMENTAL")
            self._tables[table_name] = table
            logger.info(
                f"WandbBackend: Created new incremental table: {table_name}"
            )
        else:
            # The column set was fixed when the table was created; surface
            # fields that cannot be stored instead of dropping them silently.
            known_columns = set(table.columns)
            dropped = [k for k in batch_keys if k not in known_columns]
            if dropped:
                logger.warning(
                    f"WandbBackend: Dropping fields {dropped} not present in "
                    f"columns of incremental table '{table_name}'"
                )

        # Add rows (missing columns are filled with None by dict.get).
        for row in table_rows:
            table.add_data(*(row.get(column) for column in table.columns))

        # Log the same table object again (INCREMENTAL update).
        self.run.log({f"{table_name}_table": table})
        logger.info(
            f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
        )
10171024

10181025
def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]:
@@ -1021,7 +1028,19 @@ def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]:
10211028
return {}
10221029

10231030
async def finish(self) -> None:
    """Finalize incremental tables, then finish the WandB run.

    Each table held in ``self._tables`` is copied into a new
    ``wandb.Table`` with ``log_mode="IMMUTABLE"`` and logged under its
    table name before the run is finished. Does nothing when there is no
    active run.
    """
    if not self.run:
        return

    try:
        if self._tables:
            # Only needed when there is at least one table to finalize.
            import wandb

            # Convert each incremental table to immutable before finishing.
            for table_name, incr_table in self._tables.items():
                final_table = wandb.Table(
                    columns=incr_table.columns,
                    data=incr_table.data,
                    log_mode="IMMUTABLE",
                )
                self.run.log({table_name: final_table})
                logger.info(f"WandbBackend: Finalized table {table_name}")
    finally:
        # Always close the run, even if finalizing a table raised; the
        # exception (if any) still propagates to the caller afterwards.
        self.run.finish()
        logger.info(f"WandbBackend {self.name}: Finished run")
10271046

0 commit comments

Comments
 (0)