@@ -887,6 +887,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]):
887887 self .run = None
888888 self .logging_mode = LoggingMode (logger_backend_config ["logging_mode" ])
889889 self .per_rank_share_run = logger_backend_config .get ("per_rank_share_run" , False )
890+ self ._tables : dict [str , "wandb.Table" ] = {}
890891
891892 async def init (
892893 self ,
@@ -989,30 +990,36 @@ def log_stream(self, metric: Metric, step: int, *args, **kwargs) -> None:
989990 self .run .log (log_data )
990991
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
    """Append sample-level rows to persistent, incrementally logged WandB Tables.

    One ``wandb.Table`` per table name is created on first use in
    ``INCREMENTAL`` log mode, cached in ``self._tables``, and re-logged under
    the key ``"<table_name>_table"`` each time rows are appended to it.

    Args:
        samples: Mapping of table name to the list of row dicts to append.
            ``None`` or an empty mapping is a no-op, as are empty row lists.
        step: Step the samples belong to; only used in the log message.
    """
    import wandb

    # Nothing to do without an active run or without any payload.
    if not self.run or not samples:
        return

    for table_name, table_rows in samples.items():
        if not table_rows:
            continue

        table = self._tables.get(table_name)
        if table is None:
            # The schema is fixed when the table is created, so take the
            # union of keys over the whole first batch (first-seen order)
            # instead of trusting the first row to carry every field.
            columns = list(dict.fromkeys(k for row in table_rows for k in row))
            table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
            self._tables[table_name] = table
            logger.info(
                f"WandbBackend: Created new incremental table: {table_name}"
            )

        columns = list(table.columns)

        # Keys that are not part of the table schema cannot be stored;
        # surface that instead of dropping them silently.
        unknown_keys = {k for row in table_rows for k in row} - set(columns)
        if unknown_keys:
            logger.warning(
                f"WandbBackend: Dropping keys {sorted(map(str, unknown_keys))} "
                f"not present in columns of table '{table_name}'"
            )

        # Add rows; dict.get fills columns missing from a row with None.
        for row in table_rows:
            table.add_data(*(row.get(column) for column in columns))

        # Re-log the same table object so WandB records an incremental update.
        self.run.log({f"{table_name}_table": table})
        logger.info(
            f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
        )
10171024
10181025 def get_metadata_for_secondary_ranks (self ) -> Dict [str , Any ]:
@@ -1021,7 +1028,19 @@ def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]:
10211028 return {}
10221029
async def finish(self) -> None:
    """Finalize every incremental table and close the WandB run.

    Each table cached in ``self._tables`` is copied into an ``IMMUTABLE``
    ``wandb.Table`` and logged under the same ``"<table_name>_table"`` key
    used for the incremental updates, so the final table lands next to
    (rather than beside a differently named copy of) the incremental one.
    The run is finished even if finalizing a table raises.
    """
    if not self.run:
        return

    # Imported lazily and only once a run exists.
    import wandb

    try:
        # Convert each incremental table to immutable before finishing.
        for table_name, incr_table in self._tables.items():
            final_table = wandb.Table(
                columns=incr_table.columns,
                data=incr_table.data,
                log_mode="IMMUTABLE",
            )
            self.run.log({f"{table_name}_table": final_table})
            logger.info(f"WandbBackend: Finalized table {table_name}")
    finally:
        # Always close the run so a failed finalization does not leave it open.
        self.run.finish()

    logger.info(f"WandbBackend {self.name}: Finished run")
10271046
0 commit comments