From 995b23e48c26bd92e66403d2199b3206006d7efd Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Fri, 5 Jun 2026 07:10:27 +0000 Subject: [PATCH 01/46] [feat] SID: add SidRqkmeans model (FAISS-trained residual K-Means) Second of three PRs splitting the Semantic-ID models onto the shared base from #538. Adds the concrete RQ-KMeans backend on top of ResidualQuantizer / BaseSidModel; RQ-VAE follows in PR3. - tzrec/modules/sid/kmeans.py: KMeansLayer centroid container + recon_diagnostics. - tzrec/modules/sid/residual_kmeans_quantizer.py: ResidualKMeansQuantizer (FAISS-trained, FX-traceable forward, non-uniform per-layer codebooks). - tzrec/models/sid_rqkmeans.py: SidRqkmeans(BaseSidModel) - gradient -free; reservoir-samples embeddings during the train loop and fits FAISS once in on_train_end. - tzrec/models/model.py: BaseModel.on_train_end() no-op lifecycle hook. - tzrec/main.py: invoke on_train_end after the train loop and force the tail checkpoint so post-hook state is persisted. - protos: SidRqkmeans message + ModelConfig registration (601; 600 is reserved for SidRqvae in PR3). - tests: kmeans_test, ResidualKMeansQuantizerTest, sid_rqkmeans_test. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 9 + tzrec/models/model.py | 9 + tzrec/models/sid_rqkmeans.py | 343 ++++++++++++++++++ tzrec/models/sid_rqkmeans_test.py | 306 ++++++++++++++++ tzrec/modules/sid/kmeans.py | 222 ++++++++++++ tzrec/modules/sid/kmeans_test.py | 100 +++++ .../modules/sid/residual_kmeans_quantizer.py | 248 +++++++++++++ tzrec/modules/sid/residual_quantizer_test.py | 92 +++++ tzrec/protos/model.proto | 5 + tzrec/protos/models/sid_model.proto | 31 ++ 10 files changed, 1365 insertions(+) create mode 100644 tzrec/models/sid_rqkmeans.py create mode 100644 tzrec/models/sid_rqkmeans_test.py create mode 100644 tzrec/modules/sid/kmeans.py create mode 100644 tzrec/modules/sid/kmeans_test.py create mode 100644 tzrec/modules/sid/residual_kmeans_quantizer.py create mode 100644 tzrec/protos/models/sid_model.proto diff --git a/tzrec/main.py b/tzrec/main.py index 87f2984fb..8824e8373 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -500,6 +500,15 @@ def _train_and_evaluate( if lr.by_epoch: lr.step() + # One-shot end-of-loop hook (default no-op). Some models do real work + # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings + # collected during training. Since that mutates model state, force the + # tail-save below to fire so the post-hook state is persisted even when + # the last in-loop checkpoint coincided with the final step. + _model.on_train_end() + if last_ckpt_step == i_step: + last_ckpt_step = -1 + _log_train( i_step, losses, diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 40da5335a..10fa8aae5 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,6 +150,15 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results + def on_train_end(self) -> None: + """Hook fired once after the train_eval loop exits. + + Default: no-op. Override in models that need one-shot end-of-loop + work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS + codebook from the embedding sample it collected during training. + """ + pass + def sparse_parameters( self, ) -> Tuple[Iterable[nn.Parameter], Iterable[nn.Parameter]]: diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py new file mode 100644 index 000000000..b9c3c8800 --- /dev/null +++ b/tzrec/models/sid_rqkmeans.py @@ -0,0 +1,343 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SidRqkmeans: SID generation model using residual K-Means. + +Training is FAISS-only: ``predict`` collects embeddings into a CPU +buffer; the actual FAISS fit is triggered ONCE after the train_eval +loop ends, via the :meth:`BaseModel.on_train_end` lifecycle hook +(``tzrec.main`` calls ``_model.on_train_end()`` unconditionally). +""" + +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +import torchmetrics +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.sid_model import BaseSidModel +from tzrec.modules.sid.kmeans import recon_diagnostics +from tzrec.modules.sid.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils import config_util +from tzrec.utils.logging_util import logger + + +def _coerce_proto_numbers(d: Dict) -> Dict: + """Coerce float-typed integers back to int. + + ``google.protobuf.Struct.number_value`` is always float, but most + ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require + Python ``int``. This helper converts any float that is an exact + integer to ``int`` for downstream consumption. + """ + out: Dict = {} + for k, v in d.items(): + if isinstance(v, float) and v.is_integer(): + out[k] = int(v) + else: + out[k] = v + return out + + +class SidRqkmeans(BaseSidModel): + """SID generation model using residual K-Means (FAISS-only). + + No gradient-based training. The codebook is built once at the end + of the train_eval loop via a single FAISS K-Means pass over the + embeddings collected during training. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. + """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + cfg = self._model_config # SidRqkmeans proto message + + # config_to_kwargs returns Struct numbers as floats (it is + # MessageToDict under the hood), so _coerce_proto_numbers restores + # the ints faiss.Kmeans expects (niter, seed, nredo, ...). + self._faiss_kwargs = ( + _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) + if cfg.HasField("faiss_kmeans_kwargs") + else {} + ) + + self._quantizer = ResidualKMeansQuantizer( + embed_dim=self._input_dim, + n_layers=self._n_layers, + n_embed=self._n_embed_list, + normalize_residuals=self._normalize_residuals, + faiss_kmeans_kwargs=self._faiss_kwargs, + ) + + # Per-rank reservoir cap. FAISS K-Means only ever consumes + # K * max_points_per_centroid points (it subsamples internally), so + # buffering the full corpus is wasted memory. We reservoir-sample to + # that target instead, split across ranks so the gathered set on + # rank0 is ~train_sample_size and FAISS does no further subsampling. + # Use the LARGEST per-layer K so non-uniform codebooks (e.g. + # [256, 512, 1024]) still feed their biggest layer enough points. + k = max(self._n_embed_list) + max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) + global_target = ( + cfg.train_sample_size if cfg.train_sample_size > 0 else k * max_ppc + ) + world_size = dist.get_world_size() if dist.is_initialized() else 1 + self._sample_cap = max(1, -(-global_target // world_size)) # ceil div + + # Bounded host-resident reservoir (allocated lazily on first batch, + # once the embedding dim/device is known). ``_n_filled`` slots hold + # data; ``_n_seen`` is the running count for the sampling probability. + self._reservoir: Optional[torch.Tensor] = None + self._n_filled = 0 + self._n_seen = 0 + + # KMeans has no learnable parameters (centroids use register_buffer). + # Add dummy param to keep optimizer/DDP happy. + self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) + + @torch.no_grad() + def _reservoir_add(self, x: torch.Tensor) -> None: + """Add a batch to the bounded reservoir (Vitter's Algorithm R). + + Keeps a uniform random ``self._sample_cap`` subset of every embedding + seen so far in O(cap) host memory, in a single streaming pass. + + Args: + x (Tensor): a batch of embeddings, shape (B, D); copied to host. + """ + x = x.detach().to("cpu", dtype=torch.float32) + cap = self._sample_cap + if self._reservoir is None: + self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) + + # Phase 1: fill empty slots first. + if self._n_filled < cap: + take = min(x.shape[0], cap - self._n_filled) + self._reservoir[self._n_filled : self._n_filled + take] = x[:take] + self._n_filled += take + self._n_seen += take + x = x[take:] + if x.shape[0] == 0: + return + + # Phase 2: replacement. Row j (0-indexed in x) is the + # (n_seen + j)-th item seen; it enters the reservoir with prob + # cap / (n_seen + j + 1), displacing a uniformly-random slot. + r = x.shape[0] + pos = self._n_seen + torch.arange(r) + accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) + idx = accept.nonzero(as_tuple=True)[0] + if idx.numel() > 0: + slots = torch.randint(0, cap, (idx.numel(),)) + # Intra-batch slot collisions resolve last-write-wins; the bias is + # O(B/cap) per step and negligible for codebook fitting. + self._reservoir[slots] = x[idx] + self._n_seen += r + + def _reservoir_sample(self) -> torch.Tensor: + """Return the filled portion of the reservoir, shape (n_filled, D).""" + if self._reservoir is None or self._n_filled == 0: + return torch.empty(0, self._input_dim, dtype=torch.float32) + return self._reservoir[: self._n_filled] + + def _reset_reservoir(self) -> None: + """Drop the reservoir after the FAISS fit to free host memory.""" + self._reservoir = None + self._n_filled = 0 + self._n_seen = 0 + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Training: buffer embeddings only (codes are dummy until FAISS fits). + Eval/inference (after ``on_train_end``): real predict + lookup. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + embedding = self._extract_feature(batch) + + # Training: reservoir-sample into a bounded host buffer for the + # end-of-loop FAISS fit, and return dummy codes — the codebook does + # not exist yet. The reservoir caps memory at _sample_cap rows + # regardless of corpus size (FAISS only consumes a subset anyway). + if self.is_train: + self._reservoir_add(embedding) + B = embedding.shape[0] + return { + "codes": torch.zeros( + B, self._n_layers, dtype=torch.long, device=embedding.device + ) + } + + codes, quantized = self._quantizer(embedding) + + predictions: Dict[str, torch.Tensor] = { + "codes": codes, + } + + if self.is_eval: + predictions["quantized"] = quantized + predictions["input_embedding"] = embedding + + return predictions + + def loss( + self, predictions: Dict[str, torch.Tensor], batch: Batch + ) -> Dict[str, torch.Tensor]: + """Compute loss of the model. + + Returns zero loss to keep TrainWrapper backward happy. + _dummy_param * 0.0 ensures a compute graph exists so DDP + does not complain about unused parameters. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + + Return: + losses (dict): a dict of loss tensor. + """ + return {"dummy_loss": self._dummy_param.sum() * 0.0} + + def init_metric(self) -> None: + """Initialize metric modules (shared eval metrics + rel_loss). + + Only eval metrics are registered. During training ``predict`` + returns dummy zero codes (the codebook does not exist yet), so + any train-time metric would be either NaN or trivially constant; + the inherited no-op ``update_train_metric`` keeps the train path + empty (``compute_train_metric`` then returns an empty dict, which + the framework already tolerates). + """ + super().init_metric() + self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() + + def update_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + losses: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + """Update metric state. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + losses (dict, optional): a dict of loss. + """ + if "input_embedding" in predictions: + _, rel = recon_diagnostics( + predictions["input_embedding"], + predictions["quantized"], + ) + # MeanSquaredError aggregates (preds, target) itself; rel_loss has + # no torchmetrics equivalent so it stays a MeanMetric. + self._metric_modules["mse"].update( + predictions["quantized"], predictions["input_embedding"] + ) + self._metric_modules["rel_loss"].update(rel) + + self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + + @torch.no_grad() + def on_train_end(self) -> None: + """Trigger one-shot FAISS fit after the train_eval loop ends. + + Overrides :meth:`BaseModel.on_train_end`. Called unconditionally + by ``tzrec.main.train_and_evaluate`` after the training loop exits. + + DDP behavior: + - rank0: receive each rank's reservoir sample via gather_object, + concat, run FAISS fit, then broadcast centroids to all ranks. + - other ranks: ship their reservoir sample via gather_object + (dst=0) and wait for the broadcast. + + No cross-rank empty-buffer handshake is needed: the dataset layer + enforces ``num_files >= world_size`` (``tzrec.datasets.dataset`` + raises otherwise), so in synchronized training every rank receives + at least one shard and reaches the gather with a non-empty sample. + """ + is_ddp = ( + dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 + ) + + local = self._reservoir_sample() + self._reset_reservoir() + + if is_ddp: + # DDP path: every rank ships its reservoir sample to rank 0 via + # gather_object. Each sample is bounded by _sample_cap, so the + # gathered set on rank0 is ~train_sample_size and FAISS does no + # further subsampling. + rank = dist.get_rank() + gathered: Optional[List[Optional[torch.Tensor]]] = ( + [None] * dist.get_world_size() if rank == 0 else None + ) + dist.gather_object(local, gathered, dst=0) + del local + if rank == 0: + assert gathered is not None + full = torch.cat([g for g in gathered if g is not None], dim=0) + del gathered + logger.info( + "[SidRqkmeans.on_train_end] rank0 fitting FAISS " + "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + ) + self._quantizer.train_offline(full, verbose=True) + del full + # Broadcast centroids and set the init flag locally on every + # rank. ``_is_initialized`` is a bool buffer and NCCL's bool + # dtype support is inconsistent across versions, so we avoid + # a separate broadcast for it — all ranks enter this block in + # lockstep, so a local fill_() keeps state consistent. + for layer in self._quantizer.layers: + dist.broadcast(layer.centroids, src=0) + layer._is_initialized.fill_(True) + dist.barrier() + return + + # Single-process path. Guard an empty sample with a plain local check + # (no collective): on_train_end may be invoked without a training pass. + if local.shape[0] == 0: + logger.warning( + "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " + "fit. Did the train_eval loop run?" + ) + return + + logger.info( + "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." + % (local.shape[0], local.shape[1]) + ) + self._quantizer.train_offline(local, verbose=True) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py new file mode 100644 index 000000000..8b224afac --- /dev/null +++ b/tzrec/models/sid_rqkmeans_test.py @@ -0,0 +1,306 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torchrec import KeyedTensor + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.models.sid_rqkmeans import SidRqkmeans +from tzrec.protos import model_pb2 +from tzrec.protos.models import sid_model_pb2 +from tzrec.utils import misc_util +from tzrec.utils.state_dict_util import init_parameters + +WORLD_SIZE = 2 + + +def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: + """Create a minimal Batch with dense embedding features.""" + dense_feature = KeyedTensor.from_tensor_list( + keys=["item_emb"], + tensors=[torch.randn(batch_size, input_dim, device=device)], + ) + return Batch( + dense_features={BASE_DATA_GROUP: dense_feature}, + sparse_features={}, + labels={}, + ) + + +def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmeans: + """Build a SidRqkmeans configured for offline FAISS fit. + + Module-level (not a method) so the spawned DDP workers below can build + the same model; callers move it to a device / init params as needed. + SID models read the item-embedding dense feature directly from the batch + and do not consume feature_groups, so none is set. + """ + from google.protobuf.struct_pb2 import Struct + + n_embed_list = codebook if codebook is not None else [16] * n_layers + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=n_embed_list, + normalize_residuals=False, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + ) + return SidRqkmeans( + model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), + features=[], + labels=[], + ) + + +class SidRqkmeansOfflineTest(unittest.TestCase): + """Single-process tests for SidRqkmeans (FAISS-only).""" + + def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): + """Create a SidRqkmeans on CPU with params initialized.""" + model = _build_model(input_dim, n_layers, niter, codebook) + init_parameters(model, device=torch.device("cpu")) + return model + + def test_proto_parse(self) -> None: + """Verify faiss_kmeans_kwargs are parsed correctly.""" + model = self._create_model() + self.assertEqual(model._faiss_kwargs.get("niter"), 5) + self.assertEqual(model._faiss_kwargs.get("seed"), 1234) + self.assertFalse(model._faiss_kwargs.get("verbose")) + self.assertEqual(model._n_seen, 0) + self.assertIsNone(model._reservoir) + + def test_predict_collects_buffer(self) -> None: + """In train mode, predict reservoir-samples; never fits.""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim) + model.train() + + for _ in range(4): + batch = _make_batch(B, input_dim) + preds = model.predict(batch) + self.assertIn("codes", preds) + + # Reservoir holds all 4*B samples (well under the cap) and tracks + # the running count. + self.assertEqual(model._n_seen, 4 * B) + self.assertEqual(model._n_filled, 4 * B) + # FAISS not yet triggered: layers should be uninitialized + for layer in model._quantizer.layers: + self.assertFalse(layer.is_initialized) + + def test_reservoir_caps_memory(self) -> None: + """Reservoir bounds the buffer at _sample_cap regardless of corpus.""" + B, input_dim = 16, 8 + model = self._create_model(input_dim=input_dim) + model._sample_cap = 10 # force a tiny cap + model._reset_reservoir() + model.train() + for _ in range(20): # 320 rows >> cap + model.predict(_make_batch(B, input_dim)) + self.assertEqual(model._n_seen, 20 * B) + self.assertEqual(model._n_filled, 10) + self.assertEqual(model._reservoir.shape, (10, input_dim)) + + def test_on_train_end_runs_faiss(self) -> None: + """on_train_end triggers FAISS fit and clears buffer.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + + # Accumulate enough samples (FAISS K-Means needs at least K points) + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + self.assertGreater(model._n_seen, 0) + + # Trigger one-shot FAISS fit + model.on_train_end() + + # Reservoir should be released after the fit + self.assertEqual(model._n_seen, 0) + self.assertIsNone(model._reservoir) + # All layers should be initialized + centroids non-zero + for layer in model._quantizer.layers: + self.assertTrue(bool(layer._is_initialized.item())) + self.assertGreater(layer.centroids.abs().sum().item(), 0.0) + + # After fit, predict on eval should produce valid codes + model.eval() + preds = model.predict(_make_batch(B, input_dim)) + codes = preds["codes"] + self.assertEqual(codes.shape, (B, 2)) + self.assertTrue((codes >= 0).all() and (codes < 16).all()) + + def test_non_uniform_codebook_end_to_end(self) -> None: + """Non-uniform codebook [8, 4, 16]: fit then emit per-layer codes.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + codebook = [8, 4, 16] + model = self._create_model(input_dim=input_dim, codebook=codebook) + # Reservoir cap derives from the LARGEST K (16), not the first (8). + self.assertEqual( + model._sample_cap, + 16 * int(model._faiss_kwargs.get("max_points_per_centroid", 256)), + ) + + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + for k, layer in zip(codebook, model._quantizer.layers): + self.assertTrue(bool(layer._is_initialized.item())) + self.assertEqual(layer.centroids.shape[0], k) + + model.eval() + codes = model.predict(_make_batch(B, input_dim))["codes"] + self.assertEqual(codes.shape, (B, 3)) + for i, k in enumerate(codebook): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + + def test_on_train_end_noop_on_empty_buffer(self) -> None: + """on_train_end on an empty buffer is a warned no-op.""" + model = self._create_model() + model.on_train_end() # should not raise + + def test_post_fit_checkpoint_round_trips(self) -> None: + """Fit → save state_dict → load into fresh instance → predict. + + After loading, ``predict`` must return real (non-zero) codes — + the centroids and the ``_is_initialized`` flag both need to come + through the state_dict. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + src = self._create_model(input_dim=input_dim) + src.train() + for _ in range(8): + src.predict(_make_batch(B, input_dim)) + src.on_train_end() + sd = src.state_dict() + + dst = self._create_model(input_dim=input_dim) + dst.load_state_dict(sd) + dst.eval() + codes = dst.predict(_make_batch(B, input_dim))["codes"] + self.assertGreater( + codes.abs().sum().item(), + 0, + "post-fit checkpoint resume produced all-zero codes", + ) + + def test_mid_fit_checkpoint_rejected_on_load(self) -> None: + """Tampered state (_is_initialized=True + zero centroids) raises.""" + model = self._create_model() + sd = model.state_dict() + # Simulate a checkpoint that captured the flag mid-fit (before + # load_centroids_ ran): True flag, zero centroids. + layer0_prefix = next( + k.rsplit("._is_initialized", 1)[0] + for k in sd + if k.endswith("._is_initialized") + ) + sd[f"{layer0_prefix}._is_initialized"] = torch.tensor(True) + + fresh = self._create_model() + with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): + fresh.load_state_dict(sd) + + +# -------------------------------------------------------------------------- +# Distributed (multi-process) test for the DDP on_train_end path: the +# cross-rank gather_object -> FAISS fit -> broadcast sequence the in-process +# tests above cannot reach. NCCL on GPU when >=2 devices, else gloo/CPU. +# -------------------------------------------------------------------------- +def _init_dist(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: + device = _init_dist(rank, world_size, port) + input_dim, n_layers, k = 16, 2, 16 + model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) + model.train() + + torch.manual_seed(100 + rank) + for _ in range(6): + model.predict(_make_batch(32, input_dim, device)) + assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" + + # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. + model.on_train_end() + + for layer in model._quantizer.layers: + assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" + assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" + # Centroids were broadcast from rank0 -> must be bit-identical across ranks. + for layer in model._quantizer.layers: + cmin, cmax = layer.centroids.clone(), layer.centroids.clone() + dist.all_reduce(cmin, op=dist.ReduceOp.MIN) + dist.all_reduce(cmax, op=dist.ReduceOp.MAX) + assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" + + model.eval() + codes = model.predict(_make_batch(8, input_dim, device))["codes"] + assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" + assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" + dist.destroy_process_group() + + +class SidRqkmeansDistTest(unittest.TestCase): + """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" + + def test_on_train_end_ddp(self) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=_on_train_end_worker, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py new file mode 100644 index 000000000..0b6fe4255 --- /dev/null +++ b/tzrec/modules/sid/kmeans.py @@ -0,0 +1,222 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""K-Means utilities for the SID-generation stack. + +This module is the single home for torch-native K-Means code used by +SID models: + +* :class:`KMeansLayer` — per-layer centroid container used by + :class:`ResidualKMeansQuantizer`. Centroids are injected + by the FAISS backend via ``load_centroids_``; the only forward path + is ``predict``. +* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by + :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the + first training batch (same FAISS backend as the offline RQ-KMeans fit). +""" + +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn + + +def recon_diagnostics( + x: torch.Tensor, + out: torch.Tensor, + epsilon: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """MSE + relative-L1 reconstruction diagnostics. + + Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for + ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeansQuantizer.train_offline`'s + per-layer log line (which converts to Python floats via ``.item()``). + + Args: + x: ground-truth embedding, shape (B, D). + out: quantized reconstruction, shape (B, D). + epsilon: numerical stabilizer for the relative-L1 denominator. + + Returns: + mse: scalar ``((out - x) ** 2).mean()``. + rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. + """ + mse = ((out - x) ** 2).mean() + rel = ( + torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) + ).mean() + return mse, rel + + +@torch.no_grad() +def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Squared L2 distance between rows of ``x`` and ``y``. + + Args: + x (Tensor): data points, shape (N, D). + y (Tensor): centroids, shape (K, D). + + Returns: + Tensor: squared distances, shape (N, K). + + Called per-batch from :meth:`KMeansLayer.predict`, so ``N`` is the batch + size and the full (N, K) product is small. Kept branch-free (no + data-dependent chunking on ``N``) so the predict forward stays + FX-traceable: torchrec's inference pipeline symbolically traces the + model, and a ``if N <= chunk_size`` on the traced batch dim raises a + ``torch.fx`` TraceError. + """ + x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) + y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) + return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) + + +@torch.no_grad() +def faiss_residual_kmeans( + samples: torch.Tensor, + n_clusters_list: List[int], + faiss_kmeans_kwargs: Optional[Dict] = None, +) -> List[torch.Tensor]: + """Residual K-Means warm-start via FAISS, one pass per layer. + + Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned + centroid, and repeats on the residual for every layer. Used by + :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook + from the first training batch — the same FAISS backend the offline + RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. + + Args: + samples (Tensor): data points, shape (N, D). + n_clusters_list (List[int]): per-layer cluster counts. + faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` + (e.g. ``{'niter': 10, 'seed': 123}``). + + Returns: + List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for RQ-VAE kmeans_init. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + kwargs = dict(faiss_kmeans_kwargs or {}) + device = samples.device + _, D = samples.shape + # Own a contiguous fp32 numpy copy we mutate in place to form residuals. + x = samples.detach().cpu().float().numpy().copy() + + res_centers: List[torch.Tensor] = [] + for n_clusters in n_clusters_list: + kmeans = faiss.Kmeans(D, n_clusters, **kwargs) + kmeans.train(x) + centroids = kmeans.centroids.copy() # (K, D) + res_centers.append(torch.from_numpy(centroids).to(device)) + _, idx = kmeans.index.search(x, 1) + x -= centroids[idx.ravel()] # residual, in place + return res_centers + + +class KMeansLayer(nn.Module): + """Single layer of a residual K-Means stack. + + Centroids are populated externally by ``load_centroids_`` (called per + layer by the FAISS backend in :class:`ResidualKMeansQuantizer`); ``predict`` + is the only forward path. PyTorch state-dict keys are scoped by + attribute path (``layers..centroids``), so renaming the class + does not break existing checkpoints. + + Args: + n_clusters (int): number of clusters (codebook size). + n_features (int): feature dimension. + """ + + def __init__( + self, + n_clusters: int, + n_features: int, + ) -> None: + super().__init__() + self.n_clusters = n_clusters + self.n_features = n_features + + self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) + # Flipped by ``load_centroids_`` after the FAISS fit. Persistent + # so a normal post-fit checkpoint round-trips; mid-fit poisoning + # (True flag + still-zero centroids) is caught in _load_from_state_dict. + self.register_buffer("_is_initialized", torch.tensor(False)) + + @property + def is_initialized(self) -> bool: + """Whether centroids have been injected via ``load_centroids_``.""" + return self._is_initialized.item() + + @torch.no_grad() + def load_centroids_(self, centroids: torch.Tensor) -> None: + """Inject offline-trained centroids. + + Args: + centroids (Tensor): externally trained centroids, + shape (n_clusters, n_features). + """ + assert centroids.shape == self.centroids.shape, ( + f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " + f"got {tuple(centroids.shape)}" + ) + self.centroids.copy_( + centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) + ) + self._is_initialized.fill_(True) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + """Reject mid-fit-checkpoint state dicts (True flag + zero centroids).""" + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + if bool(self._is_initialized.item()) and self.centroids.abs().sum() == 0: + error_msgs.append( + f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " + "are all zero — checkpoint was likely taken mid-FAISS-fit. " + "Re-run on_train_end to produce a valid checkpoint." + ) + + @torch.no_grad() + def predict(self, batch: torch.Tensor) -> torch.Tensor: + """Assign points to nearest centroid. + + Args: + batch (Tensor): data points, shape (B, D). + + Returns: + Tensor: cluster indices, shape (B,). + """ + dists = _squared_euclidean_distance(batch, self.centroids) + return torch.argmin(dists, dim=-1) diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py new file mode 100644 index 000000000..8fed1f83a --- /dev/null +++ b/tzrec/modules/sid/kmeans_test.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.kmeans import ( + KMeansLayer, + _squared_euclidean_distance, + faiss_residual_kmeans, + recon_diagnostics, +) + + +class KmeansHelpersTest(unittest.TestCase): + """Tests for the K-Means helper functions.""" + + def test_recon_diagnostics_zero_on_identity(self) -> None: + x = torch.randn(8, 4) + mse, rel = recon_diagnostics(x, x.clone()) + self.assertAlmostEqual(mse.item(), 0.0, places=6) + self.assertAlmostEqual(rel.item(), 0.0, places=6) + + def test_squared_euclidean_distance(self) -> None: + x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) + d = _squared_euclidean_distance(x, y) + self.assertEqual(d.shape, (2, 2)) + # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 + torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) + + def test_faiss_residual_kmeans_per_layer_centers(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + samples = torch.randn(512, 6) + centers = faiss_residual_kmeans( + samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} + ) + self.assertEqual(len(centers), 2) + self.assertEqual(centers[0].shape, (8, 6)) + self.assertEqual(centers[1].shape, (4, 6)) + self.assertTrue(torch.isfinite(centers[0]).all()) + self.assertEqual(centers[0].device, samples.device) + + +class KMeansLayerTest(unittest.TestCase): + """Tests for the single KMeansLayer.""" + + def test_uninitialized_by_default(self) -> None: + layer = KMeansLayer(n_clusters=4, n_features=3) + self.assertFalse(layer.is_initialized) + self.assertEqual(layer.centroids.abs().sum().item(), 0.0) + + def test_load_centroids_and_predict(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) + layer.load_centroids_(centroids) + self.assertTrue(layer.is_initialized) + + batch = torch.tensor([[0.1, 0.0], [9.0, 11.0]]) + codes = layer.predict(batch) + torch.testing.assert_close(codes, torch.tensor([0, 1])) + + def test_load_centroids_shape_mismatch_raises(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + with self.assertRaises(AssertionError): + layer.load_centroids_(torch.zeros(3, 2)) + + def test_mid_fit_checkpoint_rejected(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + sd = layer.state_dict() + # Simulate a mid-fit checkpoint: flag True but centroids still zero. + sd["_is_initialized"] = torch.tensor(True) + fresh = KMeansLayer(n_clusters=2, n_features=2) + with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): + fresh.load_state_dict(sd) + + def test_post_fit_checkpoint_round_trips(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh.load_state_dict(layer.state_dict()) + self.assertTrue(fresh.is_initialized) + torch.testing.assert_close(fresh.centroids, layer.centroids) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py new file mode 100644 index 000000000..505a1b1dc --- /dev/null +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -0,0 +1,248 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-layer residual K-Means: ResidualKMeansQuantizer. + +Training is FAISS-only: the codebook is built once via ``train_offline`` +over the full embedding matrix; ``forward`` is read-only (predict + lookup). +""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from tzrec.modules.sid.kmeans import KMeansLayer, recon_diagnostics +from tzrec.modules.sid.residual_quantizer import ResidualQuantizer +from tzrec.utils.logging_util import logger + + +class ResidualKMeansQuantizer(ResidualQuantizer): + """Multi-layer residual K-Means with offline FAISS training. + + Each layer quantizes the residual from the previous layer: + residual_0 = input + for each layer i: + (optionally) residual_i = L2_normalize(residual_i) + code_i = layer_i.predict(residual_i) + quantized_i = layer_i.centroids[code_i] + residual_{i+1} = residual_i - quantized_i + output = sum of all quantized_i + + Semantic ID = (code_0, code_1, ..., code_{n_layers-1}) + + Args: + embed_dim (int): feature dimension. + n_layers (int): number of residual quantization layers. + n_embed (int|List[int]): number of clusters per layer. Default: 256. + May differ per layer (non-uniform codebooks such as + ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a + separate ``faiss.Kmeans`` per layer. + normalize_residuals (bool): whether to L2-normalize residuals + before each layer. Default: False. + faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to + ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, + 'gpu': True, 'verbose': True, 'spherical': False}). + """ + + def __init__( + self, + embed_dim: int, + n_layers: int, + n_embed: Union[int, List[int]] = 256, + normalize_residuals: bool = False, + faiss_kmeans_kwargs: Optional[Dict] = None, + ) -> None: + super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) + self.faiss_kmeans_kwargs = dict(faiss_kmeans_kwargs or {}) + + self.layers = nn.ModuleList( + [ + KMeansLayer( + n_clusters=self.n_embed_list[i], + n_features=embed_dim, + ) + for i in range(n_layers) + ] + ) + + def _quantize_layer( + self, + layer_idx: int, + residual: torch.Tensor, + temperature: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Nearest-centroid assignment for one layer. + + Uninitialized layers (before ``train_offline``) return zeros, so the + residual walk is a no-op and the model stays callable. ``temperature`` + is unused (no soft assignment). + + Args: + layer_idx (int): quantization layer index. + residual (Tensor): current residual, shape (B, D). + temperature (float): unused. + + Returns: + codes (Tensor): cluster indices, shape (B,). + quantized (Tensor): selected centroids, shape (B, D). + """ + layer = self.layers[layer_idx] + if not layer.is_initialized: + codes = torch.zeros( + residual.shape[0], dtype=torch.long, device=residual.device + ) + return codes, torch.zeros_like(residual) + codes = layer.predict(residual) + return codes, layer.centroids[codes] + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Assign codes per layer and sum the centroids. + + Codebook is read-only here; training happens in ``train_offline``. + Uninitialized layers contribute zeros (see :meth:`_quantize_layer`) so + the model is callable before the one-shot FAISS fit completes. + + Args: + input (Tensor): input embeddings, shape (B, D). + + Returns: + codes (Tensor): cluster indices per layer, shape (B, n_layers). + quantized (Tensor): sum of quantized embeddings, shape (B, D). + """ + cluster_ids, quantized_sum, _ = self._residual_pass(input) + return cluster_ids, quantized_sum + + @torch.no_grad() + def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: + """Get centroid weights for a specific layer. + + Args: + layer_idx (int): index of the quantization layer. + + Returns: + Tensor: centroids, shape (n_embed, embed_dim). + """ + return self.layers[layer_idx].centroids + + def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: + """Look up codebook vectors via the layer's centroid table.""" + return self.layers[layer_idx].centroids[code_idx] + + @torch.no_grad() + def train_offline( + self, + inputs: Union[torch.Tensor, "np.ndarray"], + verbose: bool = True, + ) -> None: + """Train the multi-layer codebook via offline FAISS K-Means. + + FAISS consumes torch tensors directly (via ``faiss.contrib. + torch_utils``) — no numpy round-trips. The residual matrix stays a + host (CPU) tensor; when a faiss-gpu build is present, ``gpu=`` + moves only FAISS's internal, subsampled working set to the GPU, so we + never hold (N, D) in VRAM. On a faiss-cpu build it runs on CPU + unchanged. Either way the code path is identical. + + Args: + inputs: full embedding matrix, shape (N, D), ``torch.Tensor`` or + ``np.ndarray``. Copied once to an owned CPU float32 tensor; + the caller's input is not mutated. + verbose (bool): whether to print per-layer reconstruction + loss. Default: True. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + import faiss.contrib.torch_utils # noqa: F401 (torch tensor I/O) + except ImportError as e: + raise ImportError( + "faiss is required for ResidualKMeansQuantizer training. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + # Own a contiguous CPU float32 tensor we can update in place for + # residuals, without mutating the caller's input. + if isinstance(inputs, torch.Tensor): + assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + x = inputs.detach().to("cpu", torch.float32).contiguous().clone() + else: + assert inputs.ndim == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() + N = x.shape[0] + out = torch.zeros_like(x) + + # Use FAISS GPU compute when a GPU build is available (data stays on + # host; FAISS streams only its subsampled training set to the device). + # An explicit ``gpu`` in faiss_kmeans_kwargs always wins. + kwargs = dict(self.faiss_kmeans_kwargs) + if "gpu" not in kwargs: + kwargs["gpu"] = ( + torch.cuda.current_device() + if faiss.get_num_gpus() > 0 and torch.cuda.is_available() + else False + ) + + # Chunk size for index.search to limit peak memory. + # 500K × 512 × 4B ≈ 1 GB per chunk. + SEARCH_CHUNK = 500_000 + + for layer_idx in range(self.n_layers): + if self.normalize_residuals: + x = F.normalize(x, dim=-1) + + # Fresh Kmeans per layer so each layer can use its own K + # (non-uniform codebooks supported). Index construction is a cheap + # O(K*D) allocation next to train(), so this is effectively free. + kmeans = faiss.Kmeans( + self.embed_dim, self.n_embed_list[layer_idx], **kwargs + ) + kmeans.train(x) + centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32).cpu() + + for start in range(0, N, SEARCH_CHUNK): + end = min(start + SEARCH_CHUNK, N) + _, idx = kmeans.index.search(x[start:end], 1) + idx = torch.as_tensor(idx, device="cpu").reshape(-1).long() + q = centroids[idx] # (chunk, D) + out[start:end] += q + x[start:end] -= q # residual + del idx, q + + if verbose: + logger.info( + "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", + layer_idx, + self._calc_loss(out + x, out), # x_in = out + residual + ) + + self.layers[layer_idx].load_centroids_(centroids) + if verbose: + logger.info( + "[ResidualKMeansQuantizer][offline_faiss] layer %d finished", + layer_idx, + ) + + @staticmethod + def _calc_loss( + x: torch.Tensor, out: torch.Tensor, epsilon: float = 1e-4 + ) -> Dict[str, float]: + """Reconstruction loss diagnostics (MSE + relative L1).""" + loss, rel_loss = recon_diagnostics(x, out, epsilon=epsilon) + return {"loss": float(loss.item()), "rel_loss": float(rel_loss.item())} diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index c94cc545d..d23ef1cf5 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -14,6 +14,9 @@ import torch from torch import nn +from tzrec.modules.sid.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) from tzrec.modules.sid.residual_quantizer import ( ResidualQuantizer, normalize_n_embed, @@ -142,5 +145,94 @@ def test_decode_codes_sum_and_dtype(self) -> None: self.assertEqual(recon16.dtype, torch.bfloat16) +class ResidualKMeansQuantizerTest(unittest.TestCase): + def test_is_subclass(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertIsInstance(rkq, ResidualQuantizer) + + def test_non_uniform_codebook_supported(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) + self.assertEqual(rkq.n_embed_list, [8, 4, 16]) + self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) + + def test_forward_returns_zeros_before_fit(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) + codes, quantized = rkq(torch.randn(5, 4)) + self.assertEqual(codes.shape, (5, 2)) + self.assertEqual(quantized.shape, (5, 4)) + + def test_forward_is_fx_traceable(self) -> None: + """Predict forward must FX-trace. + + torchrec's inference pipeline symbolically traces the model, so the + per-batch distance path must be free of data-dependent control flow. + """ + import torch.fx as fx + + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + for layer in rkq.layers: # populate centroids -> is_initialized=True + layer.load_centroids_(torch.randn(8, 4)) + traced = fx.symbolic_trace(rkq) + x = torch.randn(5, 4) + c_eager, q_eager = rkq(x) + c_traced, q_traced = traced(x) + torch.testing.assert_close(c_traced, c_eager) + torch.testing.assert_close(q_traced, q_eager) + + def test_train_offline_non_uniform(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + n_embed = [8, 4, 16] + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(512, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + # Each layer fit its own K centroids; codes stay in per-layer range. + codes, _ = rkq(torch.randn(7, 4)) + self.assertEqual(codes.shape, (7, 3)) + for i, k in enumerate(n_embed): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + + def test_train_offline_then_decode(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + + codes, _ = rkq(torch.randn(5, 4)) + self.assertTrue((codes >= 0).all() and (codes < 8).all()) + recon = rkq.decode_codes(codes) # inherited from the base + self.assertEqual(recon.shape, (5, 4)) + + def test_forward_get_codes_consistent(self) -> None: + """Forward ids and get_codes both route through the shared walk.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + x = torch.randn(9, 4) + fwd_ids, fwd_quant = rkq(x) + torch.testing.assert_close(rkq.get_codes(x), fwd_ids) + # forward's residual-sum equals the centroid-sum reconstruction. + torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index bef2062ea..58b719a7a 100644 --- a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -5,6 +5,7 @@ import "tzrec/protos/models/rank_model.proto"; import "tzrec/protos/models/multi_task_rank.proto"; import "tzrec/protos/models/match_model.proto"; import "tzrec/protos/models/general_rank_model.proto"; +import "tzrec/protos/models/sid_model.proto"; import "tzrec/protos/loss.proto"; import "tzrec/protos/metric.proto"; import "tzrec/protos/seq_encoder.proto"; @@ -76,6 +77,10 @@ message ModelConfig { TDM tdm = 400; RocketLaunching rocket_launching = 500; + + // SID generation models + // (600 is reserved for SidRqvae, arriving in the follow-up PR) + SidRqkmeans sid_rqkmeans = 601; } optional uint32 num_class = 2 [default = 1]; diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto new file mode 100644 index 000000000..065013614 --- /dev/null +++ b/tzrec/protos/models/sid_model.proto @@ -0,0 +1,31 @@ +syntax = "proto2"; +package tzrec.protos; + +import "google/protobuf/struct.proto"; + +message SidRqkmeans { + // Input embedding dimension (K-Means runs directly on raw embeddings, + // no encoder). + optional uint32 input_dim = 1 [default = 512]; + // Per-layer cluster counts, e.g. [256, 256, 256]. + // List length is the number of residual quantization layers. Entries + // may differ per layer (non-uniform codebooks such as [256, 512, 1024] + // are supported — the FAISS backend fits a separate ``faiss.Kmeans`` + // per layer). + repeated uint32 codebook = 3; + // L2-normalize residuals before each layer. + optional bool normalize_residuals = 4 [default = true]; + // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a + // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, + // spherical: false, seed: 1234}. + optional google.protobuf.Struct faiss_kmeans_kwargs = 5; + // Target number of embeddings to reservoir-sample for the FAISS fit + // (global, across all ranks). Bounds host memory regardless of corpus + // size. 0 (the default) auto-derives it as K * max_points_per_centroid + // — exactly what FAISS subsamples to internally (default 256), so no + // training points are wasted. + optional uint32 train_sample_size = 6 [default = 0]; + + // Name of the item embedding feature inside the input Batch. + optional string embedding_feature_name = 40 [default = "item_emb"]; +} From c7f3a091fa3d02e4d3ed5824f6312759612bc30a Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 02:59:22 +0000 Subject: [PATCH 02/46] [review] SID: drop forced tail-checkpoint after on_train_end Remove the `last_ckpt_step == i_step -> -1` override (and its stale comment) in the train loop's end-of-loop hook. The normal checkpoint cadence already persists the post-hook state. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index 8824e8373..9efb8d8e4 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -502,12 +502,8 @@ def _train_and_evaluate( # One-shot end-of-loop hook (default no-op). Some models do real work # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings - # collected during training. Since that mutates model state, force the - # tail-save below to fire so the post-hook state is persisted even when - # the last in-loop checkpoint coincided with the final step. + # collected during training. _model.on_train_end() - if last_ckpt_step == i_step: - last_ckpt_step = -1 _log_train( i_step, From 61ec842c89f91b75eb45f87e2af20cce09e4bb69 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 05:48:00 +0000 Subject: [PATCH 03/46] [review] SID: address code-review findings on PR #539 - on_train_end() now returns is_ckpt_after_train; the tail save fires on `last_ckpt_step != i_step or is_ckpt_after_train`, so the fitted FAISS codebook is always persisted even when the last periodic checkpoint landed on the final step (main.py, model.py, sid_rqkmeans.py). (#1) - DDP on_train_end: wrap the rank0 FAISS fit in try/except and broadcast a fit-status flag so a rank0-only failure (or an empty reservoir) makes all ranks raise together instead of deadlocking on the centroid broadcast; correct the empty-reservoir docstring. (#2, #3) - KMeansLayer: cache is_initialized as a plain Python bool to drop a per-layer per-batch GPU->CPU .item() sync on the eval/predict path, kept in lockstep with the _is_initialized buffer. (#6) - _reservoir_add: copy only the kept rows to host instead of the whole batch every training step (keep float64 for n_seen exactness). (#7) - train_offline: per-layer fit-loss log now reports cumulative reconstruction of the original input (correct under normalize_residuals); align the module normalize_residuals default to True to match the proto. (#8, #10) - Drop dead faiss_residual_kmeans (RQ-VAE-only, lands in PR3) and its test; tidy _coerce_proto_numbers into a comprehension. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 8 +- tzrec/models/model.py | 10 ++- tzrec/models/sid_rqkmeans.py | 90 +++++++++++++------ tzrec/models/sid_rqkmeans_test.py | 10 ++- tzrec/modules/sid/kmeans.py | 70 +++------------ tzrec/modules/sid/kmeans_test.py | 17 ---- .../modules/sid/residual_kmeans_quantizer.py | 13 ++- 7 files changed, 106 insertions(+), 112 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index 9efb8d8e4..b71df9fe1 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -502,8 +502,10 @@ def _train_and_evaluate( # One-shot end-of-loop hook (default no-op). Some models do real work # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings - # collected during training. - _model.on_train_end() + # collected during training. When the hook mutated state that must be + # persisted, it returns True so the tail save below fires even if the + # last in-loop checkpoint already landed on the final step. + is_ckpt_after_train = _model.on_train_end() _log_train( i_step, @@ -518,7 +520,7 @@ def _train_and_evaluate( summary_writer.close() if train_config.is_profiling: prof.stop() - if last_ckpt_step != i_step: + if last_ckpt_step != i_step or is_ckpt_after_train: ckpt_manager.save(i_step, model, optimizer, dataloader_state) if eval_dataloader is not None: _evaluate( diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 10fa8aae5..09ffa1f58 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,14 +150,20 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results - def on_train_end(self) -> None: + def on_train_end(self) -> bool: """Hook fired once after the train_eval loop exits. Default: no-op. Override in models that need one-shot end-of-loop work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS codebook from the embedding sample it collected during training. + + Returns: + is_ckpt_after_train (bool): whether the hook mutated model state + that must be persisted, so the train loop should force a final + checkpoint even when one was already saved at the last step. + Default ``False`` (no-op hooks change nothing). """ - pass + return False def sparse_parameters( self, diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index b9c3c8800..00aefa12b 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -44,13 +44,10 @@ def _coerce_proto_numbers(d: Dict) -> Dict: Python ``int``. This helper converts any float that is an exact integer to ``int`` for downstream consumption. """ - out: Dict = {} - for k, v in d.items(): - if isinstance(v, float) and v.is_integer(): - out[k] = int(v) - else: - out[k] = v - return out + return { + k: int(v) if isinstance(v, float) and v.is_integer() else v + for k, v in d.items() + } class SidRqkmeans(BaseSidModel): @@ -132,15 +129,17 @@ def _reservoir_add(self, x: torch.Tensor) -> None: Args: x (Tensor): a batch of embeddings, shape (B, D); copied to host. """ - x = x.detach().to("cpu", dtype=torch.float32) + x = x.detach() cap = self._sample_cap if self._reservoir is None: self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) - # Phase 1: fill empty slots first. + # Phase 1: fill empty slots first. Copy only the rows we keep to host. if self._n_filled < cap: take = min(x.shape[0], cap - self._n_filled) - self._reservoir[self._n_filled : self._n_filled + take] = x[:take] + self._reservoir[self._n_filled : self._n_filled + take] = x[:take].to( + "cpu", dtype=torch.float32 + ) self._n_filled += take self._n_seen += take x = x[take:] @@ -149,7 +148,12 @@ def _reservoir_add(self, x: torch.Tensor) -> None: # Phase 2: replacement. Row j (0-indexed in x) is the # (n_seen + j)-th item seen; it enters the reservoir with prob - # cap / (n_seen + j + 1), displacing a uniformly-random slot. + # cap / (n_seen + j + 1), displacing a uniformly-random slot. The + # accept decision needs only counts (not embedding values), so we + # compute it on small host index tensors and copy ONLY the accepted + # rows to host — in steady state (reservoir full, n_seen >> cap) + # almost none are accepted, so the whole-batch GPU->CPU copy is + # avoided. float64 keeps (n_seen + j + 1) exact past 2**24. r = x.shape[0] pos = self._n_seen + torch.arange(r) accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) @@ -158,7 +162,7 @@ def _reservoir_add(self, x: torch.Tensor) -> None: slots = torch.randint(0, cap, (idx.numel(),)) # Intra-batch slot collisions resolve last-write-wins; the bias is # O(B/cap) per step and negligible for codebook fitting. - self._reservoir[slots] = x[idx] + self._reservoir[slots] = x[idx.to(x.device)].to("cpu", dtype=torch.float32) self._n_seen += r def _reservoir_sample(self) -> torch.Tensor: @@ -271,7 +275,7 @@ def update_metric( self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) @torch.no_grad() - def on_train_end(self) -> None: + def on_train_end(self) -> bool: """Trigger one-shot FAISS fit after the train_eval loop ends. Overrides :meth:`BaseModel.on_train_end`. Called unconditionally @@ -283,10 +287,19 @@ def on_train_end(self) -> None: - other ranks: ship their reservoir sample via gather_object (dst=0) and wait for the broadcast. - No cross-rank empty-buffer handshake is needed: the dataset layer - enforces ``num_files >= world_size`` (``tzrec.datasets.dataset`` - raises otherwise), so in synchronized training every rank receives - at least one shard and reaches the gather with a non-empty sample. + Empty-reservoir handling: for any real-scale dataset every rank gets + a non-empty reservoir — the default ParquetDataset (``rebalance=True``) + splits rows across ``num_workers * world_size`` workers, so a rank only + ends up empty for a pathologically tiny corpus (``total_rows`` smaller + than that worker count). That degenerate case does not hang: rank0's + FAISS fit raises on too-few points and the fit-status broadcast below + makes every rank raise a coordinated ``RuntimeError`` instead. + + Returns: + is_ckpt_after_train (bool): ``True`` once the codebook has been + fitted here (the centroid buffers changed and must be persisted, + so the train loop forces a final checkpoint); ``False`` when the + fit was skipped (empty reservoir — nothing to persist). """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 @@ -306,16 +319,39 @@ def on_train_end(self) -> None: ) dist.gather_object(local, gathered, dst=0) del local + fit_ok = True if rank == 0: assert gathered is not None - full = torch.cat([g for g in gathered if g is not None], dim=0) - del gathered - logger.info( - "[SidRqkmeans.on_train_end] rank0 fitting FAISS " - "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + try: + full = torch.cat([g for g in gathered if g is not None], dim=0) + del gathered + logger.info( + "[SidRqkmeans.on_train_end] rank0 fitting FAISS " + "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + ) + self._quantizer.train_offline(full, verbose=True) + del full + except Exception as e: # noqa: BLE001 + # Swallow on rank0 only long enough to tell the peers — if + # we let it propagate here, ranks 1..N-1 would block forever + # on the centroid broadcast below with no sender. + fit_ok = False + logger.error( + "[SidRqkmeans.on_train_end] rank0 FAISS fit failed: %s", e + ) + # Sync rank0's status to every rank (int flag, not bool — see the + # NCCL note below) so a rank0-only failure makes all ranks raise + # together instead of deadlocking on the centroid broadcast. + status = torch.tensor( + [1 if fit_ok else 0], + device=self._quantizer.layers[0].centroids.device, + ) + dist.broadcast(status, src=0) + if int(status.item()) == 0: + raise RuntimeError( + "[SidRqkmeans.on_train_end] FAISS fit failed on rank0; " + "see rank0 logs for the underlying error." ) - self._quantizer.train_offline(full, verbose=True) - del full # Broadcast centroids and set the init flag locally on every # rank. ``_is_initialized`` is a bool buffer and NCCL's bool # dtype support is inconsistent across versions, so we avoid @@ -324,8 +360,9 @@ def on_train_end(self) -> None: for layer in self._quantizer.layers: dist.broadcast(layer.centroids, src=0) layer._is_initialized.fill_(True) + layer._initialized = True dist.barrier() - return + return True # Single-process path. Guard an empty sample with a plain local check # (no collective): on_train_end may be invoked without a training pass. @@ -334,10 +371,11 @@ def on_train_end(self) -> None: "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " "fit. Did the train_eval loop run?" ) - return + return False logger.info( "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." % (local.shape[0], local.shape[1]) ) self._quantizer.train_offline(local, verbose=True) + return True diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 8b224afac..30e204116 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -133,8 +133,8 @@ def test_on_train_end_runs_faiss(self) -> None: model.predict(_make_batch(B, input_dim)) self.assertGreater(model._n_seen, 0) - # Trigger one-shot FAISS fit - model.on_train_end() + # Trigger one-shot FAISS fit; a real fit must request a tail checkpoint + self.assertTrue(model.on_train_end()) # Reservoir should be released after the fit self.assertEqual(model._n_seen, 0) @@ -185,7 +185,8 @@ def test_non_uniform_codebook_end_to_end(self) -> None: def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() - model.on_train_end() # should not raise + # No fit happened, so no tail checkpoint is requested. + self.assertFalse(model.on_train_end()) # should not raise def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. @@ -266,7 +267,8 @@ def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. - model.on_train_end() + # Every rank fitted/received the codebook, so each requests a tail ckpt. + assert model.on_train_end(), f"rank{rank}: on_train_end should request ckpt" for layer in model._quantizer.layers: assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 0b6fe4255..ecc554aa5 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -18,12 +18,9 @@ :class:`ResidualKMeansQuantizer`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. -* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by - :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the - first training batch (same FAISS backend as the offline RQ-KMeans fit). """ -from typing import Dict, List, Optional, Tuple +from typing import Tuple import torch from torch import nn @@ -79,57 +76,6 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) -@torch.no_grad() -def faiss_residual_kmeans( - samples: torch.Tensor, - n_clusters_list: List[int], - faiss_kmeans_kwargs: Optional[Dict] = None, -) -> List[torch.Tensor]: - """Residual K-Means warm-start via FAISS, one pass per layer. - - Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned - centroid, and repeats on the residual for every layer. Used by - :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook - from the first training batch — the same FAISS backend the offline - RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. - - Args: - samples (Tensor): data points, shape (N, D). - n_clusters_list (List[int]): per-layer cluster counts. - faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` - (e.g. ``{'niter': 10, 'seed': 123}``). - - Returns: - List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. - - Raises: - ImportError: if ``faiss`` is not installed. - """ - try: - import faiss - except ImportError as e: - raise ImportError( - "faiss is required for RQ-VAE kmeans_init. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - kwargs = dict(faiss_kmeans_kwargs or {}) - device = samples.device - _, D = samples.shape - # Own a contiguous fp32 numpy copy we mutate in place to form residuals. - x = samples.detach().cpu().float().numpy().copy() - - res_centers: List[torch.Tensor] = [] - for n_clusters in n_clusters_list: - kmeans = faiss.Kmeans(D, n_clusters, **kwargs) - kmeans.train(x) - centroids = kmeans.centroids.copy() # (K, D) - res_centers.append(torch.from_numpy(centroids).to(device)) - _, idx = kmeans.index.search(x, 1) - x -= centroids[idx.ravel()] # residual, in place - return res_centers - - class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. @@ -158,11 +104,17 @@ def __init__( # so a normal post-fit checkpoint round-trips; mid-fit poisoning # (True flag + still-zero centroids) is caught in _load_from_state_dict. self.register_buffer("_is_initialized", torch.tensor(False)) + # Plain-Python mirror of ``_is_initialized``, read on the per-batch + # forward path (``_quantize_layer``) so the hot path never pays a + # ``.item()`` GPU->CPU sync. Kept in lockstep with the buffer wherever + # the buffer changes: ``load_centroids_``, ``_load_from_state_dict``, + # and the DDP broadcast in ``SidRqkmeans.on_train_end``. + self._initialized: bool = False @property def is_initialized(self) -> bool: """Whether centroids have been injected via ``load_centroids_``.""" - return self._is_initialized.item() + return self._initialized @torch.no_grad() def load_centroids_(self, centroids: torch.Tensor) -> None: @@ -180,6 +132,7 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) ) self._is_initialized.fill_(True) + self._initialized = True def _load_from_state_dict( self, @@ -201,7 +154,10 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ) - if bool(self._is_initialized.item()) and self.centroids.abs().sum() == 0: + # Mirror the restored buffer into the cached Python flag (one sync at + # load time, off the hot path). + self._initialized = bool(self._is_initialized.item()) + if self._initialized and self.centroids.abs().sum() == 0: error_msgs.append( f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " "are all zero — checkpoint was likely taken mid-FAISS-fit. " diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index 8fed1f83a..cb86a39d8 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -16,7 +16,6 @@ from tzrec.modules.sid.kmeans import ( KMeansLayer, _squared_euclidean_distance, - faiss_residual_kmeans, recon_diagnostics, ) @@ -38,22 +37,6 @@ def test_squared_euclidean_distance(self) -> None: # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - def test_faiss_residual_kmeans_per_layer_centers(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - samples = torch.randn(512, 6) - centers = faiss_residual_kmeans( - samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} - ) - self.assertEqual(len(centers), 2) - self.assertEqual(centers[0].shape, (8, 6)) - self.assertEqual(centers[1].shape, (4, 6)) - self.assertTrue(torch.isfinite(centers[0]).all()) - self.assertEqual(centers[0].device, samples.device) - class KMeansLayerTest(unittest.TestCase): """Tests for the single KMeansLayer.""" diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 505a1b1dc..72a539654 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -49,7 +49,8 @@ class ResidualKMeansQuantizer(ResidualQuantizer): ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a separate ``faiss.Kmeans`` per layer. normalize_residuals (bool): whether to L2-normalize residuals - before each layer. Default: False. + before each layer. Default: True, matching the ``SidRqkmeans`` + proto default so direct instantiation agrees with the config path. faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, 'gpu': True, 'verbose': True, 'spherical': False}). @@ -60,7 +61,7 @@ def __init__( embed_dim: int, n_layers: int, n_embed: Union[int, List[int]] = 256, - normalize_residuals: bool = False, + normalize_residuals: bool = True, faiss_kmeans_kwargs: Optional[Dict] = None, ) -> None: super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) @@ -187,6 +188,12 @@ def train_offline( x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() N = x.shape[0] out = torch.zeros_like(x) + # Keep the original input only when we log: the per-layer diagnostic + # is the cumulative reconstruction error of the *original* input by + # the centroid sum so far (the same quantity update_metric reports). + # ``out + x`` would equal it only when normalize_residuals is off; with + # normalization the residual is rescaled each layer, so track x0. + x0 = x.clone() if verbose else None # Use FAISS GPU compute when a GPU build is available (data stays on # host; FAISS streams only its subsampled training set to the device). @@ -229,7 +236,7 @@ def train_offline( logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, - self._calc_loss(out + x, out), # x_in = out + residual + self._calc_loss(x0, out), # cumulative recon of original input ) self.layers[layer_idx].load_centroids_(centroids) From 753f3fe94cd94ce05d819b8425fd76baafcf6959 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 06:04:30 +0000 Subject: [PATCH 04/46] [review] SID: default normalize_residuals to False Flip the default for RQ-KMeans residual normalization to False, in both the SidRqkmeans proto field and the ResidualKMeansQuantizer constructor (kept consistent to avoid the proto/module mismatch). This matches OpenOneRec's residual k-means, which fits raw residuals with no per-layer L2 normalization. Configs that set normalize_residuals explicitly are unaffected. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_kmeans_quantizer.py | 7 ++++--- tzrec/protos/models/sid_model.proto | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 72a539654..a3bfb1dae 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -49,8 +49,9 @@ class ResidualKMeansQuantizer(ResidualQuantizer): ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a separate ``faiss.Kmeans`` per layer. normalize_residuals (bool): whether to L2-normalize residuals - before each layer. Default: True, matching the ``SidRqkmeans`` - proto default so direct instantiation agrees with the config path. + before each layer. Default: False, matching the ``SidRqkmeans`` + proto default (and OpenOneRec's residual k-means, which fits raw + residuals with no per-layer normalization). faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, 'gpu': True, 'verbose': True, 'spherical': False}). @@ -61,7 +62,7 @@ def __init__( embed_dim: int, n_layers: int, n_embed: Union[int, List[int]] = 256, - normalize_residuals: bool = True, + normalize_residuals: bool = False, faiss_kmeans_kwargs: Optional[Dict] = None, ) -> None: super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 065013614..6c3d1b297 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -14,7 +14,7 @@ message SidRqkmeans { // per layer). repeated uint32 codebook = 3; // L2-normalize residuals before each layer. - optional bool normalize_residuals = 4 [default = true]; + optional bool normalize_residuals = 4 [default = false]; // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, // spherical: false, seed: 1234}. From 52c745224431ac5e32969a68db4bdc64028b6467 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 07:47:17 +0000 Subject: [PATCH 05/46] [review] SID: encapsulation, comment, and import cleanups - KMeansLayer: add mark_initialized_() so the buffer + cached-bool init flag is owned by the layer; the DDP broadcast in SidRqkmeans uses it instead of poking the private fields. - SidRqkmeans: extract the reservoir-cap setup into _init_reservoir(). - residual_kmeans_quantizer: import faiss at module level (it's a pinned requirement) instead of a lazy in-function import; narrow train_offline(inputs) to torch.Tensor (all callers pass tensors) and drop the dead numpy branch. - Tighten the verbose comments/docstrings across the SID files. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 8 +- tzrec/models/model.py | 12 +- tzrec/models/sid_rqkmeans.py | 156 +++++++----------- tzrec/modules/sid/kmeans.py | 49 +++--- .../modules/sid/residual_kmeans_quantizer.py | 72 +++----- 5 files changed, 114 insertions(+), 183 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index b71df9fe1..8b4b5357b 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -500,11 +500,9 @@ def _train_and_evaluate( if lr.by_epoch: lr.step() - # One-shot end-of-loop hook (default no-op). Some models do real work - # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings - # collected during training. When the hook mutated state that must be - # persisted, it returns True so the tail save below fires even if the - # last in-loop checkpoint already landed on the final step. + # One-shot end-of-loop hook (default no-op; e.g. SidRqkmeans fits its FAISS + # codebook here). Returns True if it mutated persistable state, forcing the + # tail save below even when the last in-loop checkpoint hit the final step. is_ckpt_after_train = _model.on_train_end() _log_train( diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 09ffa1f58..c6b2b952c 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -153,15 +153,13 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: def on_train_end(self) -> bool: """Hook fired once after the train_eval loop exits. - Default: no-op. Override in models that need one-shot end-of-loop - work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS - codebook from the embedding sample it collected during training. + Default no-op; override for one-shot end-of-loop work (e.g. + :class:`SidRqkmeans` fits its FAISS codebook here). Returns: - is_ckpt_after_train (bool): whether the hook mutated model state - that must be persisted, so the train loop should force a final - checkpoint even when one was already saved at the last step. - Default ``False`` (no-op hooks change nothing). + is_ckpt_after_train (bool): whether the hook mutated state that must + be persisted, so the loop forces a final checkpoint even if one was + already saved at the last step. Default ``False``. """ return False diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 00aefa12b..3859a3ef0 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -37,12 +37,10 @@ def _coerce_proto_numbers(d: Dict) -> Dict: - """Coerce float-typed integers back to int. + """Coerce whole-valued floats back to int. - ``google.protobuf.Struct.number_value`` is always float, but most - ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require - Python ``int``. This helper converts any float that is an exact - integer to ``int`` for downstream consumption. + ``Struct.number_value`` is always float, but faiss.Kmeans kwargs + (``niter``, ``seed``, ...) need ``int``. """ return { k: int(v) if isinstance(v, float) and v.is_integer() else v @@ -76,9 +74,7 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message - # config_to_kwargs returns Struct numbers as floats (it is - # MessageToDict under the hood), so _coerce_proto_numbers restores - # the ints faiss.Kmeans expects (niter, seed, nredo, ...). + # config_to_kwargs yields Struct numbers as floats; coerce back to int. self._faiss_kwargs = ( _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) if cfg.HasField("faiss_kmeans_kwargs") @@ -93,41 +89,41 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - # Per-rank reservoir cap. FAISS K-Means only ever consumes - # K * max_points_per_centroid points (it subsamples internally), so - # buffering the full corpus is wasted memory. We reservoir-sample to - # that target instead, split across ranks so the gathered set on - # rank0 is ~train_sample_size and FAISS does no further subsampling. - # Use the LARGEST per-layer K so non-uniform codebooks (e.g. - # [256, 512, 1024]) still feed their biggest layer enough points. + self._init_reservoir() + + # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. + self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) + + def _init_reservoir(self) -> None: + """Set up the bounded host reservoir for the end-of-loop FAISS fit. + + Per-rank cap: FAISS subsamples to K*max_points_per_centroid internally, + so reservoir-sample to that target (split across ranks) rather than + buffer the whole corpus. Use the largest per-layer K so non-uniform + codebooks still feed their biggest layer enough points. + """ k = max(self._n_embed_list) max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) - global_target = ( - cfg.train_sample_size if cfg.train_sample_size > 0 else k * max_ppc - ) + target = self._model_config.train_sample_size + global_target = target if target > 0 else k * max_ppc world_size = dist.get_world_size() if dist.is_initialized() else 1 self._sample_cap = max(1, -(-global_target // world_size)) # ceil div - # Bounded host-resident reservoir (allocated lazily on first batch, - # once the embedding dim/device is known). ``_n_filled`` slots hold - # data; ``_n_seen`` is the running count for the sampling probability. + # Allocated lazily on the first batch. _n_filled = used slots; + # _n_seen = running count for the accept prob. self._reservoir: Optional[torch.Tensor] = None self._n_filled = 0 self._n_seen = 0 - # KMeans has no learnable parameters (centroids use register_buffer). - # Add dummy param to keep optimizer/DDP happy. - self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) - @torch.no_grad() def _reservoir_add(self, x: torch.Tensor) -> None: - """Add a batch to the bounded reservoir (Vitter's Algorithm R). + """Stream a batch into the reservoir (Vitter Algorithm R). - Keeps a uniform random ``self._sample_cap`` subset of every embedding - seen so far in O(cap) host memory, in a single streaming pass. + Keeps a uniform ``_sample_cap`` sample of all embeddings seen, in + O(cap) host memory. Args: - x (Tensor): a batch of embeddings, shape (B, D); copied to host. + x (Tensor): batch of embeddings, shape (B, D). """ x = x.detach() cap = self._sample_cap @@ -146,22 +142,17 @@ def _reservoir_add(self, x: torch.Tensor) -> None: if x.shape[0] == 0: return - # Phase 2: replacement. Row j (0-indexed in x) is the - # (n_seen + j)-th item seen; it enters the reservoir with prob - # cap / (n_seen + j + 1), displacing a uniformly-random slot. The - # accept decision needs only counts (not embedding values), so we - # compute it on small host index tensors and copy ONLY the accepted - # rows to host — in steady state (reservoir full, n_seen >> cap) - # almost none are accepted, so the whole-batch GPU->CPU copy is - # avoided. float64 keeps (n_seen + j + 1) exact past 2**24. + # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random + # slot. The accept decision needs only counts, so compute it on host and + # copy ONLY accepted rows (in steady state, almost none) — avoiding the + # whole-batch GPU->CPU copy. float64 keeps n_seen+j+1 exact past 2**24. r = x.shape[0] pos = self._n_seen + torch.arange(r) accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) idx = accept.nonzero(as_tuple=True)[0] if idx.numel() > 0: slots = torch.randint(0, cap, (idx.numel(),)) - # Intra-batch slot collisions resolve last-write-wins; the bias is - # O(B/cap) per step and negligible for codebook fitting. + # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. self._reservoir[slots] = x[idx.to(x.device)].to("cpu", dtype=torch.float32) self._n_seen += r @@ -191,10 +182,8 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """ embedding = self._extract_feature(batch) - # Training: reservoir-sample into a bounded host buffer for the - # end-of-loop FAISS fit, and return dummy codes — the codebook does - # not exist yet. The reservoir caps memory at _sample_cap rows - # regardless of corpus size (FAISS only consumes a subset anyway). + # Training: just reservoir-sample for the end-of-loop FAISS fit and + # return dummy codes — the codebook does not exist yet. if self.is_train: self._reservoir_add(embedding) B = embedding.shape[0] @@ -221,9 +210,8 @@ def loss( ) -> Dict[str, torch.Tensor]: """Compute loss of the model. - Returns zero loss to keep TrainWrapper backward happy. - _dummy_param * 0.0 ensures a compute graph exists so DDP - does not complain about unused parameters. + Zero loss via ``_dummy_param * 0`` — gives TrainWrapper/DDP a compute + graph despite there being no real trainable params. Args: predictions (dict): a dict of predicted result. @@ -235,14 +223,11 @@ def loss( return {"dummy_loss": self._dummy_param.sum() * 0.0} def init_metric(self) -> None: - """Initialize metric modules (shared eval metrics + rel_loss). - - Only eval metrics are registered. During training ``predict`` - returns dummy zero codes (the codebook does not exist yet), so - any train-time metric would be either NaN or trivially constant; - the inherited no-op ``update_train_metric`` keeps the train path - empty (``compute_train_metric`` then returns an empty dict, which - the framework already tolerates). + """Register eval metrics (shared ``mse`` + ``rel_loss``). + + Train-time metrics are intentionally absent: ``predict`` returns dummy + codes pre-fit, so the inherited no-op ``update_train_metric`` keeps the + train path empty. """ super().init_metric() self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() @@ -265,8 +250,8 @@ def update_metric( predictions["input_embedding"], predictions["quantized"], ) - # MeanSquaredError aggregates (preds, target) itself; rel_loss has - # no torchmetrics equivalent so it stays a MeanMetric. + # mse aggregates (preds, target) itself; rel_loss has no + # torchmetrics equivalent, so it stays a MeanMetric. self._metric_modules["mse"].update( predictions["quantized"], predictions["input_embedding"] ) @@ -276,30 +261,20 @@ def update_metric( @torch.no_grad() def on_train_end(self) -> bool: - """Trigger one-shot FAISS fit after the train_eval loop ends. - - Overrides :meth:`BaseModel.on_train_end`. Called unconditionally - by ``tzrec.main.train_and_evaluate`` after the training loop exits. + """Fit the FAISS codebook once, after the train_eval loop exits. - DDP behavior: - - rank0: receive each rank's reservoir sample via gather_object, - concat, run FAISS fit, then broadcast centroids to all ranks. - - other ranks: ship their reservoir sample via gather_object - (dst=0) and wait for the broadcast. + Overrides :meth:`BaseModel.on_train_end` (called unconditionally by + ``tzrec.main``). DDP: every rank gather_objects its reservoir to rank0, + which fits and broadcasts the centroids back. - Empty-reservoir handling: for any real-scale dataset every rank gets - a non-empty reservoir — the default ParquetDataset (``rebalance=True``) - splits rows across ``num_workers * world_size`` workers, so a rank only - ends up empty for a pathologically tiny corpus (``total_rows`` smaller - than that worker count). That degenerate case does not hang: rank0's - FAISS fit raises on too-few points and the fit-status broadcast below - makes every rank raise a coordinated ``RuntimeError`` instead. + An empty reservoir only happens for a pathologically tiny corpus + (rebalance splits rows across ``num_workers * world_size``); it then + fails fast via the fit-status broadcast rather than hanging. Returns: - is_ckpt_after_train (bool): ``True`` once the codebook has been - fitted here (the centroid buffers changed and must be persisted, - so the train loop forces a final checkpoint); ``False`` when the - fit was skipped (empty reservoir — nothing to persist). + is_ckpt_after_train (bool): ``True`` if the codebook was fitted + (centroids changed → force a final checkpoint), ``False`` if the + fit was skipped (empty reservoir). """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 @@ -309,10 +284,7 @@ def on_train_end(self) -> bool: self._reset_reservoir() if is_ddp: - # DDP path: every rank ships its reservoir sample to rank 0 via - # gather_object. Each sample is bounded by _sample_cap, so the - # gathered set on rank0 is ~train_sample_size and FAISS does no - # further subsampling. + # Each rank ships its (capped) reservoir to rank0, which fits. rank = dist.get_rank() gathered: Optional[List[Optional[torch.Tensor]]] = ( [None] * dist.get_world_size() if rank == 0 else None @@ -332,16 +304,14 @@ def on_train_end(self) -> bool: self._quantizer.train_offline(full, verbose=True) del full except Exception as e: # noqa: BLE001 - # Swallow on rank0 only long enough to tell the peers — if - # we let it propagate here, ranks 1..N-1 would block forever - # on the centroid broadcast below with no sender. + # Don't raise yet — peers would hang on the broadcast below. + # Signal failure via the status flag so all ranks raise. fit_ok = False logger.error( "[SidRqkmeans.on_train_end] rank0 FAISS fit failed: %s", e ) - # Sync rank0's status to every rank (int flag, not bool — see the - # NCCL note below) so a rank0-only failure makes all ranks raise - # together instead of deadlocking on the centroid broadcast. + # Broadcast rank0's status (int, not bool — see NCCL note below) so + # a rank0-only failure makes all ranks raise instead of deadlocking. status = torch.tensor( [1 if fit_ok else 0], device=self._quantizer.layers[0].centroids.device, @@ -352,20 +322,16 @@ def on_train_end(self) -> bool: "[SidRqkmeans.on_train_end] FAISS fit failed on rank0; " "see rank0 logs for the underlying error." ) - # Broadcast centroids and set the init flag locally on every - # rank. ``_is_initialized`` is a bool buffer and NCCL's bool - # dtype support is inconsistent across versions, so we avoid - # a separate broadcast for it — all ranks enter this block in - # lockstep, so a local fill_() keeps state consistent. + # Broadcast centroids; set the init flag locally (avoids + # broadcasting a bool buffer — NCCL bool support is inconsistent). + # All ranks are in lockstep, so a local mark_initialized_() agrees. for layer in self._quantizer.layers: dist.broadcast(layer.centroids, src=0) - layer._is_initialized.fill_(True) - layer._initialized = True + layer.mark_initialized_() dist.barrier() return True - # Single-process path. Guard an empty sample with a plain local check - # (no collective): on_train_end may be invoked without a training pass. + # Single-process: guard an empty reservoir with a plain local check. if local.shape[0] == 0: logger.warning( "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index ecc554aa5..d6e34acd9 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -33,9 +33,8 @@ def recon_diagnostics( ) -> Tuple[torch.Tensor, torch.Tensor]: """MSE + relative-L1 reconstruction diagnostics. - Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for - ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeansQuantizer.train_offline`'s - per-layer log line (which converts to Python floats via ``.item()``). + Shared by :meth:`SidRqkmeans.update_metric` and + :meth:`ResidualKMeansQuantizer.train_offline`'s per-layer log. Args: x: ground-truth embedding, shape (B, D). @@ -64,12 +63,8 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso Returns: Tensor: squared distances, shape (N, K). - Called per-batch from :meth:`KMeansLayer.predict`, so ``N`` is the batch - size and the full (N, K) product is small. Kept branch-free (no - data-dependent chunking on ``N``) so the predict forward stays - FX-traceable: torchrec's inference pipeline symbolically traces the - model, and a ``if N <= chunk_size`` on the traced batch dim raises a - ``torch.fx`` TraceError. + Kept branch-free (no data-dependent control flow on ``N``) so the + per-batch predict forward stays FX-traceable for torchrec inference. """ x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) @@ -79,11 +74,9 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. - Centroids are populated externally by ``load_centroids_`` (called per - layer by the FAISS backend in :class:`ResidualKMeansQuantizer`); ``predict`` - is the only forward path. PyTorch state-dict keys are scoped by - attribute path (``layers..centroids``), so renaming the class - does not break existing checkpoints. + Centroids are populated externally by ``load_centroids_`` (the FAISS + backend in :class:`ResidualKMeansQuantizer`); ``predict`` is the only + forward path. Args: n_clusters (int): number of clusters (codebook size). @@ -100,15 +93,12 @@ def __init__( self.n_features = n_features self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) - # Flipped by ``load_centroids_`` after the FAISS fit. Persistent - # so a normal post-fit checkpoint round-trips; mid-fit poisoning - # (True flag + still-zero centroids) is caught in _load_from_state_dict. + # Persistent so a post-fit checkpoint round-trips; a mid-fit poison + # (True flag + zero centroids) is caught in _load_from_state_dict. self.register_buffer("_is_initialized", torch.tensor(False)) - # Plain-Python mirror of ``_is_initialized``, read on the per-batch - # forward path (``_quantize_layer``) so the hot path never pays a - # ``.item()`` GPU->CPU sync. Kept in lockstep with the buffer wherever - # the buffer changes: ``load_centroids_``, ``_load_from_state_dict``, - # and the DDP broadcast in ``SidRqkmeans.on_train_end``. + # Plain-Python mirror of the buffer, read on the per-batch forward + # path to avoid a .item() GPU->CPU sync. Synced only via + # mark_initialized_ and _load_from_state_dict. self._initialized: bool = False @property @@ -116,6 +106,15 @@ def is_initialized(self) -> bool: """Whether centroids have been injected via ``load_centroids_``.""" return self._initialized + def mark_initialized_(self) -> None: + """Flag centroids populated, syncing buffer + cached mirror. + + For callers that fill ``centroids`` in place (e.g. the DDP broadcast + in :meth:`SidRqkmeans.on_train_end`) rather than via ``load_centroids_``. + """ + self._is_initialized.fill_(True) + self._initialized = True + @torch.no_grad() def load_centroids_(self, centroids: torch.Tensor) -> None: """Inject offline-trained centroids. @@ -131,8 +130,7 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: self.centroids.copy_( centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) ) - self._is_initialized.fill_(True) - self._initialized = True + self.mark_initialized_() def _load_from_state_dict( self, @@ -154,8 +152,7 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ) - # Mirror the restored buffer into the cached Python flag (one sync at - # load time, off the hot path). + # Mirror the restored buffer into the cached flag (one load-time sync). self._initialized = bool(self._is_initialized.item()) if self._initialized and self.centroids.abs().sum() == 0: error_msgs.append( diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index a3bfb1dae..e28891f9c 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -17,7 +17,8 @@ from typing import Dict, List, Optional, Tuple, Union -import numpy as np +import faiss +import faiss.contrib.torch_utils # noqa: F401 (registers torch tensor I/O) import torch from torch import nn from torch.nn import functional as F @@ -144,61 +145,34 @@ def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: @torch.no_grad() def train_offline( self, - inputs: Union[torch.Tensor, "np.ndarray"], + inputs: torch.Tensor, verbose: bool = True, ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. - FAISS consumes torch tensors directly (via ``faiss.contrib. - torch_utils``) — no numpy round-trips. The residual matrix stays a - host (CPU) tensor; when a faiss-gpu build is present, ``gpu=`` - moves only FAISS's internal, subsampled working set to the GPU, so we - never hold (N, D) in VRAM. On a faiss-cpu build it runs on CPU - unchanged. Either way the code path is identical. + The residual matrix stays a host (CPU) tensor; with a faiss-gpu build, + ``gpu=`` moves only FAISS's subsampled working set to the GPU, so + we never hold (N, D) in VRAM. faiss-cpu runs the same path on CPU. Args: - inputs: full embedding matrix, shape (N, D), ``torch.Tensor`` or - ``np.ndarray``. Copied once to an owned CPU float32 tensor; - the caller's input is not mutated. - verbose (bool): whether to print per-layer reconstruction - loss. Default: True. - - Raises: - ImportError: if ``faiss`` is not installed. + inputs (Tensor): embedding matrix (N, D). Copied once to an owned + CPU float32 tensor; not mutated. + verbose (bool): print per-layer reconstruction loss. Default: True. """ - try: - import faiss - import faiss.contrib.torch_utils # noqa: F401 (torch tensor I/O) - except ImportError as e: - raise ImportError( - "faiss is required for ResidualKMeansQuantizer training. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - # Own a contiguous CPU float32 tensor we can update in place for - # residuals, without mutating the caller's input. - if isinstance(inputs, torch.Tensor): - assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( - f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" - ) - x = inputs.detach().to("cpu", torch.float32).contiguous().clone() - else: - assert inputs.ndim == 2 and inputs.shape[1] == self.embed_dim, ( - f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" - ) - x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() + # Own a contiguous CPU float32 copy to update in place as the residual. + assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + x = inputs.detach().to("cpu", torch.float32).contiguous().clone() N = x.shape[0] out = torch.zeros_like(x) - # Keep the original input only when we log: the per-layer diagnostic - # is the cumulative reconstruction error of the *original* input by - # the centroid sum so far (the same quantity update_metric reports). - # ``out + x`` would equal it only when normalize_residuals is off; with - # normalization the residual is rescaled each layer, so track x0. + # Original input, kept only for the log: the per-layer diagnostic is the + # cumulative recon error of x0 by the centroid sum (what update_metric + # reports). ``out + x`` would equal it only without normalization. x0 = x.clone() if verbose else None - # Use FAISS GPU compute when a GPU build is available (data stays on - # host; FAISS streams only its subsampled training set to the device). - # An explicit ``gpu`` in faiss_kmeans_kwargs always wins. + # Use FAISS GPU compute when a faiss-gpu build is present; an explicit + # ``gpu`` in faiss_kmeans_kwargs always wins. kwargs = dict(self.faiss_kmeans_kwargs) if "gpu" not in kwargs: kwargs["gpu"] = ( @@ -207,17 +181,15 @@ def train_offline( else False ) - # Chunk size for index.search to limit peak memory. - # 500K × 512 × 4B ≈ 1 GB per chunk. + # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). SEARCH_CHUNK = 500_000 for layer_idx in range(self.n_layers): if self.normalize_residuals: x = F.normalize(x, dim=-1) - # Fresh Kmeans per layer so each layer can use its own K - # (non-uniform codebooks supported). Index construction is a cheap - # O(K*D) allocation next to train(), so this is effectively free. + # Fresh Kmeans per layer so each can use its own K (non-uniform + # codebooks). kmeans = faiss.Kmeans( self.embed_dim, self.n_embed_list[layer_idx], **kwargs ) From fbd973ffd1e391b1df4579dfe9987a9e90760704 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 07:56:52 +0000 Subject: [PATCH 06/46] [review] SID: move FAISS fit-sample sizing into the quantizer Add ResidualKMeansQuantizer.default_fit_sample_size() (max(K) * max_points_per_centroid) so the FAISS default lives in the FAISS-owning class; SidRqkmeans._init_reservoir asks the quantizer instead of reading faiss_kwargs and hardcoding 256. Behavior-identical. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 13 ++++++------- tzrec/modules/sid/residual_kmeans_quantizer.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3859a3ef0..9f29b4eac 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -97,15 +97,14 @@ def __init__( def _init_reservoir(self) -> None: """Set up the bounded host reservoir for the end-of-loop FAISS fit. - Per-rank cap: FAISS subsamples to K*max_points_per_centroid internally, - so reservoir-sample to that target (split across ranks) rather than - buffer the whole corpus. Use the largest per-layer K so non-uniform - codebooks still feed their biggest layer enough points. + Per-rank cap: target the points the FAISS fit will subsample to + (``ResidualKMeansQuantizer.default_fit_sample_size``), split across + ranks, rather than buffer the whole corpus. """ - k = max(self._n_embed_list) - max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) target = self._model_config.train_sample_size - global_target = target if target > 0 else k * max_ppc + global_target = ( + target if target > 0 else self._quantizer.default_fit_sample_size() + ) world_size = dist.get_world_size() if dist.is_initialized() else 1 self._sample_cap = max(1, -(-global_target // world_size)) # ceil div diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index e28891f9c..8e5960e99 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -142,6 +142,16 @@ def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: """Look up codebook vectors via the layer's centroid table.""" return self.layers[layer_idx].centroids[code_idx] + def default_fit_sample_size(self) -> int: + """Points the FAISS fit subsamples to: max(K) * max_points_per_centroid. + + ``faiss.Kmeans`` caps each layer's training set at + ``K * max_points_per_centroid`` (default 256), so fitting on more is + wasted. Callers use this to size their training-sample reservoir. + """ + max_ppc = int(self.faiss_kmeans_kwargs.get("max_points_per_centroid", 256)) + return max(self.n_embed_list) * max_ppc + @torch.no_grad() def train_offline( self, From 893a62794b17fc5aaf22bd0b9710b6a26ef8bdf1 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 08:22:38 +0000 Subject: [PATCH 07/46] [review] SID: log rank0 FAISS-fit failure with traceback Use logger.exception() in on_train_end's rank0 except so the underlying error's stack trace is captured (peers raise a coordinated RuntimeError pointing at the rank0 log); drop the now-unused `as e`. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 9f29b4eac..3401f7568 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -302,12 +302,14 @@ def on_train_end(self) -> bool: ) self._quantizer.train_offline(full, verbose=True) del full - except Exception as e: # noqa: BLE001 + except Exception: # noqa: BLE001 # Don't raise yet — peers would hang on the broadcast below. # Signal failure via the status flag so all ranks raise. + # logger.exception keeps the traceback so the rank0-only + # failure is diagnosable from the log. fit_ok = False - logger.error( - "[SidRqkmeans.on_train_end] rank0 FAISS fit failed: %s", e + logger.exception( + "[SidRqkmeans.on_train_end] rank0 FAISS fit failed" ) # Broadcast rank0's status (int, not bool — see NCCL note below) so # a rank0-only failure makes all ranks raise instead of deadlocking. From 3734fc21dfee8d7ebc98b9b4e6d957ed84df8e11 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 02:36:08 +0000 Subject: [PATCH 08/46] [review] SID: clarify the reservoir ceil-div comment Comment-only; pushed to re-trigger CI. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3401f7568..25bd4f19a 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -106,7 +106,8 @@ def _init_reservoir(self) -> None: target if target > 0 else self._quantizer.default_fit_sample_size() ) world_size = dist.get_world_size() if dist.is_initialized() else 1 - self._sample_cap = max(1, -(-global_target // world_size)) # ceil div + # ceil div: round up so the per-rank caps together cover global_target. + self._sample_cap = max(1, -(-global_target // world_size)) # Allocated lazily on the first batch. _n_filled = used slots; # _n_seen = running count for the accept prob. From 795c676569188719079e72183fe33083f06d1547 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 03:40:15 +0000 Subject: [PATCH 09/46] [review] SID: fix FAISS gpu kwarg + close test gaps from PR review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Should-fix: - train_offline: faiss reads `gpu` as a GPU *count*, not a device index, so `gpu=current_device()` was 0 (single-GPU / rank0) -> falsy -> silent CPU fallback. Pass `gpu=1` so the fit actually runs on the (rank0) GPU. Test gaps: - reservoir Phase-2 replacement correctness (identifiable rows: intact, in-range, replacement actually occurs) — beyond the count/shape checks. - normalize_residuals=True end-to-end through train_offline (the F.normalize site the other tests never reached). - eval vs inference predict contract (quantized/input_embedding vs codes-only) and the init_metric/update_metric/compute_metric path. - checkpoint round-trip now asserts codes match the source model exactly (assert_close), not merely non-zero. Minor docs: - on_train_end Returns: clarify only the single-process path returns False; DDP raises on an empty gather. - train_offline docstring: the post-fit index.search streams all N in chunks. - proto train_sample_size comment: K -> max(K) for non-uniform codebooks. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 7 +- tzrec/models/sid_rqkmeans_test.py | 169 ++++++++++++++++-- .../modules/sid/residual_kmeans_quantizer.py | 17 +- tzrec/protos/models/sid_model.proto | 7 +- 4 files changed, 173 insertions(+), 27 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 25bd4f19a..384c049a9 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -273,8 +273,11 @@ def on_train_end(self) -> bool: Returns: is_ckpt_after_train (bool): ``True`` if the codebook was fitted - (centroids changed → force a final checkpoint), ``False`` if the - fit was skipped (empty reservoir). + (centroids changed → force a final checkpoint). Only the + single-process path can return ``False`` (empty reservoir, fit + skipped); the DDP path either returns ``True`` or raises (an empty + gather makes rank0's fit fail, which the status broadcast turns + into a coordinated ``RuntimeError``). """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 30e204116..2cfad30a1 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -27,12 +27,9 @@ WORLD_SIZE = 2 -def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: - """Create a minimal Batch with dense embedding features.""" - dense_feature = KeyedTensor.from_tensor_list( - keys=["item_emb"], - tensors=[torch.randn(batch_size, input_dim, device=device)], - ) +def _batch_from_rows(rows: torch.Tensor) -> Batch: + """Wrap explicit ``item_emb`` rows in a minimal Batch.""" + dense_feature = KeyedTensor.from_tensor_list(keys=["item_emb"], tensors=[rows]) return Batch( dense_features={BASE_DATA_GROUP: dense_feature}, sparse_features={}, @@ -40,7 +37,14 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: ) -def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmeans: +def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: + """Create a minimal Batch with random dense embedding features.""" + return _batch_from_rows(torch.randn(batch_size, input_dim, device=device)) + + +def _build_model( + input_dim=32, n_layers=2, niter=5, codebook=None, normalize_residuals=False +) -> SidRqkmeans: """Build a SidRqkmeans configured for offline FAISS fit. Module-level (not a method) so the spawned DDP workers below can build @@ -56,7 +60,7 @@ def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmean cfg = sid_model_pb2.SidRqkmeans( input_dim=input_dim, codebook=n_embed_list, - normalize_residuals=False, + normalize_residuals=normalize_residuals, faiss_kmeans_kwargs=faiss_kwargs, embedding_feature_name="item_emb", ) @@ -70,9 +74,16 @@ def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmean class SidRqkmeansOfflineTest(unittest.TestCase): """Single-process tests for SidRqkmeans (FAISS-only).""" - def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): + def _create_model( + self, + input_dim=32, + n_layers=2, + niter=5, + codebook=None, + normalize_residuals=False, + ): """Create a SidRqkmeans on CPU with params initialized.""" - model = _build_model(input_dim, n_layers, niter, codebook) + model = _build_model(input_dim, n_layers, niter, codebook, normalize_residuals) init_parameters(model, device=torch.device("cpu")) return model @@ -117,6 +128,51 @@ def test_reservoir_caps_memory(self) -> None: self.assertEqual(model._n_filled, 10) self.assertEqual(model._reservoir.shape, (10, input_dim)) + def test_reservoir_phase2_replacement(self) -> None: + """Phase-2 replacement keeps a valid reservoir of real, in-range rows. + + Feeds identifiable rows (each row's value == its global stream index), + then asserts every reservoir slot still holds an intact fed row, all + indices are in range, and replacement past the initial fill actually + happened — exercising the accept-prob / slot-write logic that the + count/shape-only ``test_reservoir_caps_memory`` cannot. + """ + torch.manual_seed(0) + input_dim, cap, B, n_batches = 4, 8, 4, 50 + model = self._create_model(input_dim=input_dim) + model._sample_cap = cap + model._reset_reservoir() + model.train() + + gidx = 0 + for _ in range(n_batches): + rows = ( + torch.arange(gidx, gidx + B, dtype=torch.float32) + .unsqueeze(1) + .expand(B, input_dim) + .contiguous() + ) + gidx += B + model.predict(_batch_from_rows(rows)) + + total = B * n_batches + self.assertEqual(model._n_seen, total) + self.assertEqual(model._n_filled, cap) + + res = model._reservoir + idx = res[:, 0].round().long() + # Each stored row is an intact fed row (all columns equal its index), + # never zeros/garbage. + self.assertTrue( + torch.equal(res, idx.unsqueeze(1).float().expand_as(res)), + "reservoir holds corrupted (non-fed) rows", + ) + # All indices are valid stream positions. + self.assertTrue((idx >= 0).all() and (idx < total).all()) + # Phase-2 replacement happened: at least one slot holds a row added + # after the reservoir filled (index >= cap). + self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") + def test_on_train_end_runs_faiss(self) -> None: """on_train_end triggers FAISS fit and clears buffer.""" try: @@ -182,6 +238,83 @@ def test_non_uniform_codebook_end_to_end(self) -> None: for i, k in enumerate(codebook): self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + def test_normalize_residuals_end_to_end(self) -> None: + """train_offline with normalize_residuals=True fits + predicts. + + Exercises the ``F.normalize`` site inside ``train_offline`` (a second + normalize independent of ``_residual_pass``), which the other tests — + all built with normalize_residuals=False — never reach. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim, normalize_residuals=True) + self.assertTrue(model._quantizer.normalize_residuals) + + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + self.assertTrue(model.on_train_end()) + + for layer in model._quantizer.layers: + self.assertTrue(layer.is_initialized) + + model.eval() + codes = model.predict(_make_batch(B, input_dim))["codes"] + self.assertEqual(codes.shape, (B, 2)) + self.assertTrue((codes >= 0).all() and (codes < 16).all()) + + def test_eval_and_inference_predict_contract(self) -> None: + """Eval exposes quantized/input_embedding; inference is codes-only.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + # Eval mode: reconstruction outputs are present for update_metric. + model.eval() + eval_preds = model.predict(_make_batch(B, input_dim)) + self.assertIn("quantized", eval_preds) + self.assertIn("input_embedding", eval_preds) + + # Inference (serving) mode: codes-only contract. + model.set_is_inference(True) + inf_preds = model.predict(_make_batch(B, input_dim)) + self.assertEqual(set(inf_preds.keys()), {"codes"}) + + def test_eval_metric_path(self) -> None: + """init_metric/update_metric report finite mse + rel_loss in eval.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + model.init_metric() + model.eval() + preds = model.predict(_make_batch(B, input_dim)) + model.update_metric(preds, _make_batch(B, input_dim)) + metrics = model.compute_metric() + for key in ("mse", "rel_loss", "unique_sid_ratio"): + self.assertIn(key, metrics) + self.assertTrue(torch.isfinite(torch.as_tensor(metrics[key])).all()) + def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() @@ -191,9 +324,9 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. - After loading, ``predict`` must return real (non-zero) codes — - the centroids and the ``_is_initialized`` flag both need to come - through the state_dict. + The reloaded model must produce the *same* codes as the source on the + same batch — verifying the centroids round-trip exactly, not merely + that they came through as non-zero. """ try: import faiss # noqa: F401 @@ -210,13 +343,19 @@ def test_post_fit_checkpoint_round_trips(self) -> None: dst = self._create_model(input_dim=input_dim) dst.load_state_dict(sd) + + # Same batch through both → identical codes (exact round-trip). + batch = _make_batch(B, input_dim) + src.eval() dst.eval() - codes = dst.predict(_make_batch(B, input_dim))["codes"] + src_codes = src.predict(batch)["codes"] + dst_codes = dst.predict(batch)["codes"] self.assertGreater( - codes.abs().sum().item(), + dst_codes.abs().sum().item(), 0, "post-fit checkpoint resume produced all-zero codes", ) + torch.testing.assert_close(dst_codes, src_codes) def test_mid_fit_checkpoint_rejected_on_load(self) -> None: """Tampered state (_is_initialized=True + zero centroids) raises.""" diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 8e5960e99..83e095192 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -160,9 +160,12 @@ def train_offline( ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. - The residual matrix stays a host (CPU) tensor; with a faiss-gpu build, - ``gpu=`` moves only FAISS's subsampled working set to the GPU, so - we never hold (N, D) in VRAM. faiss-cpu runs the same path on CPU. + The residual matrix stays a host (CPU) tensor. With a faiss-gpu build, + ``faiss.Kmeans`` runs the K-Means training (over its internally + subsampled set) on the GPU; the post-fit ``index.search`` assignment + still streams all N rows through in ``SEARCH_CHUNK``-sized chunks, so we + never hold the full (N, D) on the device. faiss-cpu runs the same path + on CPU. Args: inputs (Tensor): embedding matrix (N, D). Copied once to an owned @@ -182,13 +185,13 @@ def train_offline( x0 = x.clone() if verbose else None # Use FAISS GPU compute when a faiss-gpu build is present; an explicit - # ``gpu`` in faiss_kmeans_kwargs always wins. + # ``gpu`` in faiss_kmeans_kwargs always wins. NB faiss reads ``gpu`` as a + # GPU *count* (1 = one GPU = the current/rank0 device), not a device + # index — passing an index of 0 is falsy and silently falls back to CPU. kwargs = dict(self.faiss_kmeans_kwargs) if "gpu" not in kwargs: kwargs["gpu"] = ( - torch.cuda.current_device() - if faiss.get_num_gpus() > 0 and torch.cuda.is_available() - else False + 1 if (faiss.get_num_gpus() > 0 and torch.cuda.is_available()) else False ) # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 6c3d1b297..fdd41a22c 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -21,9 +21,10 @@ message SidRqkmeans { optional google.protobuf.Struct faiss_kmeans_kwargs = 5; // Target number of embeddings to reservoir-sample for the FAISS fit // (global, across all ranks). Bounds host memory regardless of corpus - // size. 0 (the default) auto-derives it as K * max_points_per_centroid - // — exactly what FAISS subsamples to internally (default 256), so no - // training points are wasted. + // size. 0 (the default) auto-derives it as max(K) * max_points_per_centroid + // (the largest per-layer codebook, for non-uniform codebooks) — exactly + // what FAISS subsamples to internally (default 256), so no training points + // are wasted. optional uint32 train_sample_size = 6 [default = 0]; // Name of the item embedding feature inside the input Batch. From 2bb5abc117692678bea701c6e94cd26fa771cfe8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 06:02:49 +0000 Subject: [PATCH 10/46] [review] SID: default FAISS fit to CPU + DDP fit-failure test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier gpu=1 "fix" was itself wrong and broke the GPU unittest_ci (cpu_ci/h20 passed): faiss reads `gpu` as a COUNT and 1 == True collapses to all-GPUs, so the rank0-only fit sharded over every rank's device and the GPU faiss path (newly activated — it was a silent CPU fallback before) failed on the tiny test data. faiss's count kwarg cannot pin to a single device, so default the fit to CPU (a bounded one-shot; set gpu in faiss_kmeans_kwargs to opt in explicitly). Also: - _init_reservoir docstring: note the cap targets train_sample_size when set, else default_fit_sample_size(). - Add test_on_train_end_ddp_rank0_failure: forces rank0's fit to raise and asserts every rank raises the coordinated RuntimeError, with join(timeout) so a reintroduced deadlock fails CI instead of hanging. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 +- tzrec/models/sid_rqkmeans_test.py | 64 +++++++++++++++++++ .../modules/sid/residual_kmeans_quantizer.py | 15 ++--- 3 files changed, 74 insertions(+), 10 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 384c049a9..66f87cdd2 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -97,9 +97,10 @@ def __init__( def _init_reservoir(self) -> None: """Set up the bounded host reservoir for the end-of-loop FAISS fit. - Per-rank cap: target the points the FAISS fit will subsample to + Per-rank cap: target ``train_sample_size`` when set (>0), else the + points the FAISS fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``), split across - ranks, rather than buffer the whole corpus. + ranks — rather than buffer the whole corpus. """ target = self._model_config.train_sample_size global_target = ( diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 2cfad30a1..e54e44f53 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -426,6 +426,41 @@ def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: dist.destroy_process_group() +def _on_train_end_fail_worker(rank: int, world_size: int, port: int) -> None: + """Worker that forces rank0's FAISS fit to fail. + + Every rank must then raise the coordinated ``RuntimeError`` (driven by the + fit-status broadcast) instead of deadlocking on the centroid broadcast. A + worker returns 0 only if it caught that expected error. + """ + device = _init_dist(rank, world_size, port) + input_dim, n_layers, k = 16, 2, 16 + model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) + model.train() + for _ in range(6): + model.predict(_make_batch(32, input_dim, device)) + + # Force the rank0-only fit to raise (no faiss needed: only rank0 fits, and + # we replace its fit). The status flag must turn this into an all-ranks + # raise, not a hang. + if rank == 0: + + def _boom(*args, **kwargs): + raise RuntimeError("forced rank0 fit failure") + + model._quantizer.train_offline = _boom + + try: + model.on_train_end() + except RuntimeError: + dist.destroy_process_group() + return # expected: coordinated failure reached this rank + dist.destroy_process_group() + raise AssertionError( + f"rank{rank}: on_train_end did not raise on a rank0 fit failure" + ) + + class SidRqkmeansDistTest(unittest.TestCase): """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" @@ -442,6 +477,35 @@ def test_on_train_end_ddp(self) -> None: if p.exitcode != 0: raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + def test_on_train_end_ddp_rank0_failure(self) -> None: + """A rank0-only fit failure raises on every rank — never deadlocks. + + Guards the status-flag-before-centroid-broadcast ordering: a regression + that reordered/dropped it would hang here. ``join(timeout=...)`` turns a + reintroduced deadlock into a CI failure instead of a hung job. + """ + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process( + target=_on_train_end_fail_worker, args=(rank, WORLD_SIZE, port) + ) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join(timeout=120) + if p.is_alive(): + p.terminate() + raise RuntimeError( + f"worker-{i} deadlocked on a rank0 fit failure (timed out)." + ) + if p.exitcode != 0: + raise RuntimeError( + f"worker-{i} did not raise the coordinated error " + f"(exitcode={p.exitcode})." + ) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 83e095192..d3ebf82c2 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -184,15 +184,14 @@ def train_offline( # reports). ``out + x`` would equal it only without normalization. x0 = x.clone() if verbose else None - # Use FAISS GPU compute when a faiss-gpu build is present; an explicit - # ``gpu`` in faiss_kmeans_kwargs always wins. NB faiss reads ``gpu`` as a - # GPU *count* (1 = one GPU = the current/rank0 device), not a device - # index — passing an index of 0 is falsy and silently falls back to CPU. + # Default to a CPU fit. faiss reads ``gpu`` as a GPU *count*, not a + # device index (and ``1 == True`` collapses to all GPUs), so it cannot + # pin this rank0-only fit to a single device without sharding faiss + # memory onto the other ranks' GPUs. The fit is a bounded one-shot over + # the reservoir subsample, so CPU is cheap; set ``gpu`` explicitly in + # faiss_kmeans_kwargs (e.g. ``True`` for all GPUs) to opt into GPU. kwargs = dict(self.faiss_kmeans_kwargs) - if "gpu" not in kwargs: - kwargs["gpu"] = ( - 1 if (faiss.get_num_gpus() > 0 and torch.cuda.is_available()) else False - ) + kwargs.setdefault("gpu", False) # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). SEARCH_CHUNK = 500_000 From 33acbe6caad213a603acd544fef82928956be55b Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 06:16:00 +0000 Subject: [PATCH 11/46] [review] SID: log the FAISS fit device (CPU/GPU) Announce CPU vs GPU + N/D at the start of train_offline so the CPU default isn't silent (configs that don't set faiss_kmeans_kwargs.gpu now fit on CPU). Gated by verbose (on_train_end passes verbose=True). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_kmeans_quantizer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index d3ebf82c2..21af2f9af 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -192,6 +192,15 @@ def train_offline( # faiss_kmeans_kwargs (e.g. ``True`` for all GPUs) to opt into GPU. kwargs = dict(self.faiss_kmeans_kwargs) kwargs.setdefault("gpu", False) + if verbose: + logger.info( + "[ResidualKMeansQuantizer] fitting %d-layer codebook on %s " + "(N=%d, D=%d); set faiss_kmeans_kwargs.gpu to change.", + self.n_layers, + "GPU" if kwargs["gpu"] else "CPU", + N, + self.embed_dim, + ) # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). SEARCH_CHUNK = 500_000 From 23c552cf06981f730aa13b1174410d50ddeb79e7 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 06:17:31 +0000 Subject: [PATCH 12/46] [chore] bump version to 1.2.18 Merge upstream/master (1.2.17, incl. #540 DlrmHSTU fix) and bump. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index c0c16b619..52c53fa4f 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.17" +__version__ = "1.2.18" From 3261c2ce04dc36ef0e2d462c063f7532cd9c19d8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 07:41:33 +0000 Subject: [PATCH 13/46] [review] SID: address 23c552c review (test timeout, N>=K assert, cap test, doc) - test_on_train_end_ddp: route both DDP tests through a shared _run_dist_workers(... timeout=120) so a success-path deadlock (e.g. a dropped barrier) fails CI instead of hanging. (#1) - train_offline: assert N >= max(n_embed_list) so a too-small corpus fails loudly instead of faiss silently fitting a degenerate codebook. (#2) - add test_sample_cap_from_train_sample_size covering the explicit train_sample_size branch + per-rank ceil-div across world_size. (#3) - update_metric docstring: note mse/rel_loss are meaningful only with normalize_residuals=False. (#4) Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 7 ++ tzrec/models/sid_rqkmeans_test.py | 96 ++++++++++++------- .../modules/sid/residual_kmeans_quantizer.py | 7 ++ 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 66f87cdd2..477580e4b 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -241,6 +241,13 @@ def update_metric( ) -> None: """Update metric state. + Note: ``mse``/``rel_loss`` compare ``input_embedding`` against the + centroid-sum reconstruction. They are meaningful reconstruction + metrics only with ``normalize_residuals=False`` (the default); with + normalization the centroids live on the rescaled-residual scale, so + the two quantities don't share a scale (same caveat the train_offline + per-layer log carries). + Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index e54e44f53..76a1bda0e 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -43,7 +43,12 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: def _build_model( - input_dim=32, n_layers=2, niter=5, codebook=None, normalize_residuals=False + input_dim=32, + n_layers=2, + niter=5, + codebook=None, + normalize_residuals=False, + train_sample_size=0, ) -> SidRqkmeans: """Build a SidRqkmeans configured for offline FAISS fit. @@ -63,6 +68,7 @@ def _build_model( normalize_residuals=normalize_residuals, faiss_kmeans_kwargs=faiss_kwargs, embedding_feature_name="item_emb", + train_sample_size=train_sample_size, ) return SidRqkmeans( model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), @@ -81,9 +87,17 @@ def _create_model( niter=5, codebook=None, normalize_residuals=False, + train_sample_size=0, ): """Create a SidRqkmeans on CPU with params initialized.""" - model = _build_model(input_dim, n_layers, niter, codebook, normalize_residuals) + model = _build_model( + input_dim, + n_layers, + niter, + codebook, + normalize_residuals, + train_sample_size, + ) init_parameters(model, device=torch.device("cpu")) return model @@ -96,6 +110,23 @@ def test_proto_parse(self) -> None: self.assertEqual(model._n_seen, 0) self.assertIsNone(model._reservoir) + def test_sample_cap_from_train_sample_size(self) -> None: + """Explicit train_sample_size drives the per-rank cap (ceil-div).""" + from unittest import mock + + # Single process (world_size=1): cap == train_sample_size. + model = self._create_model(train_sample_size=900) + self.assertEqual(model._sample_cap, 900) + + # Per-rank ceil-div across world_size (patch dist + recompute the cap). + for world_size, expected in [(4, 225), (7, 129), (1000, 1)]: + with ( + mock.patch.object(dist, "is_initialized", return_value=True), + mock.patch.object(dist, "get_world_size", return_value=world_size), + ): + model._init_reservoir() + self.assertEqual(model._sample_cap, expected) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 @@ -461,50 +492,43 @@ def _boom(*args, **kwargs): ) +def _run_dist_workers(worker, world_size: int, timeout: int = 120) -> None: + """Spawn ``world_size`` procs running ``worker(rank, world_size, port)``. + + Joins with a timeout so a deadlock (e.g. a dropped barrier / reordered + broadcast) fails the test instead of hanging CI, and raises on a hung or + nonzero-exit worker. + """ + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(world_size): + p = ctx.Process(target=worker, args=(rank, world_size, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join(timeout=timeout) + if p.is_alive(): + p.terminate() + raise RuntimeError(f"worker-{i} deadlocked (timed out after {timeout}s).") + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + class SidRqkmeansDistTest(unittest.TestCase): """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" def test_on_train_end_ddp(self) -> None: - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process(target=_on_train_end_worker, args=(rank, WORLD_SIZE, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join() - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + _run_dist_workers(_on_train_end_worker, WORLD_SIZE) def test_on_train_end_ddp_rank0_failure(self) -> None: """A rank0-only fit failure raises on every rank — never deadlocks. Guards the status-flag-before-centroid-broadcast ordering: a regression - that reordered/dropped it would hang here. ``join(timeout=...)`` turns a - reintroduced deadlock into a CI failure instead of a hung job. + that reordered/dropped it would hang, which the join timeout turns into + a CI failure instead of a hung job. """ - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process( - target=_on_train_end_fail_worker, args=(rank, WORLD_SIZE, port) - ) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join(timeout=120) - if p.is_alive(): - p.terminate() - raise RuntimeError( - f"worker-{i} deadlocked on a rank0 fit failure (timed out)." - ) - if p.exitcode != 0: - raise RuntimeError( - f"worker-{i} did not raise the coordinated error " - f"(exitcode={p.exitcode})." - ) + _run_dist_workers(_on_train_end_fail_worker, WORLD_SIZE) if __name__ == "__main__": diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 21af2f9af..a2648d2b8 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -178,6 +178,13 @@ def train_offline( ) x = inputs.detach().to("cpu", torch.float32).contiguous().clone() N = x.shape[0] + # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not + # errors) when N < K and returns a degenerate codebook, which the + # all-zero poison guard in KMeansLayer would not catch. + max_k = max(self.n_embed_list) + assert N >= max_k, ( + f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" + ) out = torch.zeros_like(x) # Original input, kept only for the log: the per-layer diagnostic is the # cumulative recon error of x0 by the centroid sum (what update_metric From 39017abf87849b6863d4b02b6bfd41f935a3d3ef Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 08:20:49 +0000 Subject: [PATCH 14/46] [review] checkpoint_util: force only overrides the dedupe Drop the redundant `or force` in `want` (the only caller pairs force with final=True, so final already sets want). `force` now purely bypasses the per-step dedupe, matching its docstring; behavior is identical for the on_train_end tail-save caller. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index 5bafd825c..78555a550 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -419,16 +419,17 @@ def maybe_save( data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. final: force a save (still subject to the dedupe), e.g. at train end. - force: save even if this step was already saved (bypasses the - per-step dedupe), e.g. when end-of-train work mutated the model - state at the final step (see ``on_train_end``). + force: when a save is already requested (e.g. ``final``), bypass the + per-step dedupe so it fires even if this step was already saved + — e.g. when end-of-train work mutated the model state at the + final step (see ``on_train_end``). No effect on its own. Returns: True if a checkpoint was saved. """ data_ts = self._reconcile_event_time(data_timestamp) - want = final or force + want = final if self._save_steps > 0 and step > 0 and step % self._save_steps == 0: want = True if ( From 5afbd5ed23437f0360900f74e6f99a3326f1576a Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 08:42:07 +0000 Subject: [PATCH 15/46] [review] checkpoint maybe_save: clarify final vs force docstrings Reword the `final`/`force` param docs to remove the verbal collision (both previously described as "force a save"). `final` sets `want`; `force` only relaxes the per-step dedupe and is a no-op on its own. Docstring-only; no behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index 78555a550..612cf023d 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -418,11 +418,15 @@ def maybe_save( epoch: current epoch; enables the epoch trigger when not None. data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. - final: force a save (still subject to the dedupe), e.g. at train end. - force: when a save is already requested (e.g. ``final``), bypass the - per-step dedupe so it fires even if this step was already saved - — e.g. when end-of-train work mutated the model state at the - final step (see ``on_train_end``). No effect on its own. + final: request a save unconditionally (still subject to the dedupe), + e.g. at train end. This sets ``want``; it does not bypass the + per-step dedupe — that is what ``force`` is for. + force: bypass the per-step dedupe so a wanted save fires even if this + step was already saved — e.g. when end-of-train work mutated the + model state at the already-saved final step (see ``on_train_end``). + Orthogonal to ``final``: ``force`` only relaxes the dedupe and has + no effect on its own (it still needs ``want``, which ``final`` or a + cadence trigger supplies). Returns: True if a checkpoint was saved. From 415b8a38dcac74d428abee3e8045498e59c51699 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:01:39 +0000 Subject: [PATCH 16/46] [refactor] SidRqkmeans: single-process only; raise under DDP Drop the DDP path in on_train_end (gather_object -> rank0 FAISS fit -> status/centroid broadcast -> barrier). SidRqkmeans now supports single-process training only: on_train_end raises RuntimeError when world_size > 1, and fits the codebook on the local reservoir otherwise. Simplify _init_reservoir accordingly (no per-rank cap split). Replace the multi-process DDP tests (gather/broadcast/rank0-failure) with a guard test asserting on_train_end raises under world_size>1; trim now-unused imports. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 99 +++++------------- tzrec/models/sid_rqkmeans_test.py | 162 ++++-------------------------- 2 files changed, 44 insertions(+), 217 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 477580e4b..8375ff981 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -97,18 +97,15 @@ def __init__( def _init_reservoir(self) -> None: """Set up the bounded host reservoir for the end-of-loop FAISS fit. - Per-rank cap: target ``train_sample_size`` when set (>0), else the - points the FAISS fit subsamples to - (``ResidualKMeansQuantizer.default_fit_sample_size``), split across - ranks — rather than buffer the whole corpus. + Caps at ``train_sample_size`` when set (>0), else the points the FAISS + fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``) + — rather than buffer the whole corpus. Single-process only (see the + world_size guard in :meth:`on_train_end`), so no per-rank split. """ target = self._model_config.train_sample_size - global_target = ( - target if target > 0 else self._quantizer.default_fit_sample_size() + self._sample_cap = max( + 1, target if target > 0 else self._quantizer.default_fit_sample_size() ) - world_size = dist.get_world_size() if dist.is_initialized() else 1 - # ceil div: round up so the per-rank caps together cover global_target. - self._sample_cap = max(1, -(-global_target // world_size)) # Allocated lazily on the first batch. _n_filled = used slots; # _n_seen = running count for the accept prob. @@ -272,79 +269,35 @@ def on_train_end(self) -> bool: """Fit the FAISS codebook once, after the train_eval loop exits. Overrides :meth:`BaseModel.on_train_end` (called unconditionally by - ``tzrec.main``). DDP: every rank gather_objects its reservoir to rank0, - which fits and broadcasts the centroids back. + ``tzrec.main``). Single-process only: the fit runs on one process over + its local reservoir, with no cross-rank gather/broadcast. - An empty reservoir only happens for a pathologically tiny corpus - (rebalance splits rows across ``num_workers * world_size``); it then - fails fast via the fit-status broadcast rather than hanging. + An empty reservoir only happens for a pathologically tiny corpus; the + fit is then skipped and ``False`` returned. Returns: is_ckpt_after_train (bool): ``True`` if the codebook was fitted - (centroids changed → force a final checkpoint). Only the - single-process path can return ``False`` (empty reservoir, fit - skipped); the DDP path either returns ``True`` or raises (an empty - gather makes rank0's fit fail, which the status broadcast turns - into a coordinated ``RuntimeError``). + (centroids changed → force a final checkpoint), ``False`` if the + fit was skipped (empty reservoir). + + Raises: + RuntimeError: if launched under distributed training + (``world_size > 1``). SidRqkmeans is single-process only. """ - is_ddp = ( - dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 - ) + if ( + dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ): + raise RuntimeError( + "SidRqkmeans supports single-process training only " + f"(world_size=1); got world_size={dist.get_world_size()}. " + "Launch with --nproc-per-node=1." + ) local = self._reservoir_sample() self._reset_reservoir() - if is_ddp: - # Each rank ships its (capped) reservoir to rank0, which fits. - rank = dist.get_rank() - gathered: Optional[List[Optional[torch.Tensor]]] = ( - [None] * dist.get_world_size() if rank == 0 else None - ) - dist.gather_object(local, gathered, dst=0) - del local - fit_ok = True - if rank == 0: - assert gathered is not None - try: - full = torch.cat([g for g in gathered if g is not None], dim=0) - del gathered - logger.info( - "[SidRqkmeans.on_train_end] rank0 fitting FAISS " - "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) - ) - self._quantizer.train_offline(full, verbose=True) - del full - except Exception: # noqa: BLE001 - # Don't raise yet — peers would hang on the broadcast below. - # Signal failure via the status flag so all ranks raise. - # logger.exception keeps the traceback so the rank0-only - # failure is diagnosable from the log. - fit_ok = False - logger.exception( - "[SidRqkmeans.on_train_end] rank0 FAISS fit failed" - ) - # Broadcast rank0's status (int, not bool — see NCCL note below) so - # a rank0-only failure makes all ranks raise instead of deadlocking. - status = torch.tensor( - [1 if fit_ok else 0], - device=self._quantizer.layers[0].centroids.device, - ) - dist.broadcast(status, src=0) - if int(status.item()) == 0: - raise RuntimeError( - "[SidRqkmeans.on_train_end] FAISS fit failed on rank0; " - "see rank0 logs for the underlying error." - ) - # Broadcast centroids; set the init flag locally (avoids - # broadcasting a bool buffer — NCCL bool support is inconsistent). - # All ranks are in lockstep, so a local mark_initialized_() agrees. - for layer in self._quantizer.layers: - dist.broadcast(layer.centroids, src=0) - layer.mark_initialized_() - dist.barrier() - return True - - # Single-process: guard an empty reservoir with a plain local check. if local.shape[0] == 0: logger.warning( "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 76a1bda0e..00a320046 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -9,23 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.models.sid_rqkmeans import SidRqkmeans from tzrec.protos import model_pb2 from tzrec.protos.models import sid_model_pb2 -from tzrec.utils import misc_util from tzrec.utils.state_dict_util import init_parameters -WORLD_SIZE = 2 - def _batch_from_rows(rows: torch.Tensor) -> Batch: """Wrap explicit ``item_emb`` rows in a minimal Batch.""" @@ -52,8 +47,6 @@ def _build_model( ) -> SidRqkmeans: """Build a SidRqkmeans configured for offline FAISS fit. - Module-level (not a method) so the spawned DDP workers below can build - the same model; callers move it to a device / init params as needed. SID models read the item-embedding dense feature directly from the batch and do not consume feature_groups, so none is set. """ @@ -111,21 +104,14 @@ def test_proto_parse(self) -> None: self.assertIsNone(model._reservoir) def test_sample_cap_from_train_sample_size(self) -> None: - """Explicit train_sample_size drives the per-rank cap (ceil-div).""" - from unittest import mock - - # Single process (world_size=1): cap == train_sample_size. + """train_sample_size (when set) drives the reservoir cap directly.""" + # Explicit train_sample_size: cap == train_sample_size. model = self._create_model(train_sample_size=900) self.assertEqual(model._sample_cap, 900) - # Per-rank ceil-div across world_size (patch dist + recompute the cap). - for world_size, expected in [(4, 225), (7, 129), (1000, 1)]: - with ( - mock.patch.object(dist, "is_initialized", return_value=True), - mock.patch.object(dist, "get_world_size", return_value=world_size), - ): - model._init_reservoir() - self.assertEqual(model._sample_cap, expected) + # Default (train_sample_size=0): cap == the FAISS fit's subsample size. + model = self._create_model() + self.assertEqual(model._sample_cap, model._quantizer.default_fit_sample_size()) def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" @@ -352,6 +338,19 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: # No fit happened, so no tail checkpoint is requested. self.assertFalse(model.on_train_end()) # should not raise + def test_on_train_end_raises_under_ddp(self) -> None: + """SidRqkmeans is single-process only: world_size>1 must raise.""" + from unittest import mock + + model = self._create_model() + with ( + mock.patch.object(dist, "is_available", return_value=True), + mock.patch.object(dist, "is_initialized", return_value=True), + mock.patch.object(dist, "get_world_size", return_value=2), + self.assertRaisesRegex(RuntimeError, "single-process"), + ): + model.on_train_end() + def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. @@ -406,130 +405,5 @@ def test_mid_fit_checkpoint_rejected_on_load(self) -> None: fresh.load_state_dict(sd) -# -------------------------------------------------------------------------- -# Distributed (multi-process) test for the DDP on_train_end path: the -# cross-rank gather_object -> FAISS fit -> broadcast sequence the in-process -# tests above cannot reach. NCCL on GPU when >=2 devices, else gloo/CPU. -# -------------------------------------------------------------------------- -def _init_dist(rank: int, world_size: int, port: int) -> torch.device: - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size - if use_cuda: - torch.cuda.set_device(rank) - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - return torch.device(f"cuda:{rank}") - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - return torch.device("cpu") - - -def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: - device = _init_dist(rank, world_size, port) - input_dim, n_layers, k = 16, 2, 16 - model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) - model.train() - - torch.manual_seed(100 + rank) - for _ in range(6): - model.predict(_make_batch(32, input_dim, device)) - assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" - - # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. - # Every rank fitted/received the codebook, so each requests a tail ckpt. - assert model.on_train_end(), f"rank{rank}: on_train_end should request ckpt" - - for layer in model._quantizer.layers: - assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" - assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" - # Centroids were broadcast from rank0 -> must be bit-identical across ranks. - for layer in model._quantizer.layers: - cmin, cmax = layer.centroids.clone(), layer.centroids.clone() - dist.all_reduce(cmin, op=dist.ReduceOp.MIN) - dist.all_reduce(cmax, op=dist.ReduceOp.MAX) - assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" - - model.eval() - codes = model.predict(_make_batch(8, input_dim, device))["codes"] - assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" - assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" - dist.destroy_process_group() - - -def _on_train_end_fail_worker(rank: int, world_size: int, port: int) -> None: - """Worker that forces rank0's FAISS fit to fail. - - Every rank must then raise the coordinated ``RuntimeError`` (driven by the - fit-status broadcast) instead of deadlocking on the centroid broadcast. A - worker returns 0 only if it caught that expected error. - """ - device = _init_dist(rank, world_size, port) - input_dim, n_layers, k = 16, 2, 16 - model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) - model.train() - for _ in range(6): - model.predict(_make_batch(32, input_dim, device)) - - # Force the rank0-only fit to raise (no faiss needed: only rank0 fits, and - # we replace its fit). The status flag must turn this into an all-ranks - # raise, not a hang. - if rank == 0: - - def _boom(*args, **kwargs): - raise RuntimeError("forced rank0 fit failure") - - model._quantizer.train_offline = _boom - - try: - model.on_train_end() - except RuntimeError: - dist.destroy_process_group() - return # expected: coordinated failure reached this rank - dist.destroy_process_group() - raise AssertionError( - f"rank{rank}: on_train_end did not raise on a rank0 fit failure" - ) - - -def _run_dist_workers(worker, world_size: int, timeout: int = 120) -> None: - """Spawn ``world_size`` procs running ``worker(rank, world_size, port)``. - - Joins with a timeout so a deadlock (e.g. a dropped barrier / reordered - broadcast) fails the test instead of hanging CI, and raises on a hung or - nonzero-exit worker. - """ - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(world_size): - p = ctx.Process(target=worker, args=(rank, world_size, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join(timeout=timeout) - if p.is_alive(): - p.terminate() - raise RuntimeError(f"worker-{i} deadlocked (timed out after {timeout}s).") - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") - - -class SidRqkmeansDistTest(unittest.TestCase): - """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" - - def test_on_train_end_ddp(self) -> None: - _run_dist_workers(_on_train_end_worker, WORLD_SIZE) - - def test_on_train_end_ddp_rank0_failure(self) -> None: - """A rank0-only fit failure raises on every rank — never deadlocks. - - Guards the status-flag-before-centroid-broadcast ordering: a regression - that reordered/dropped it would hang, which the join timeout turns into - a CI failure instead of a hung job. - """ - _run_dist_workers(_on_train_end_fail_worker, WORLD_SIZE) - - if __name__ == "__main__": unittest.main() From b27eb7b5b55e7dab7a7d8b014feb51df7567ca73 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:05:27 +0000 Subject: [PATCH 17/46] [refactor] SidRqkmeans: move DDP guard to __init__ (fail fast) Raise the single-process world_size>1 guard at construction instead of in on_train_end, so an accidental multi-rank launch fails immediately rather than after a full training pass. Update the guard test to assert __init__ raises. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 36 +++++++++++++++---------------- tzrec/models/sid_rqkmeans_test.py | 7 +++--- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 8375ff981..3c87cedae 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -72,6 +72,20 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) + # Single-process only: the FAISS fit runs on one process over its local + # reservoir, with no cross-rank gather/broadcast. Fail fast here rather + # than after a full (wasted) training pass. + if ( + dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ): + raise RuntimeError( + "SidRqkmeans supports single-process training only " + f"(world_size=1); got world_size={dist.get_world_size()}. " + "Launch with --nproc-per-node=1." + ) + cfg = self._model_config # SidRqkmeans proto message # config_to_kwargs yields Struct numbers as floats; coerce back to int. @@ -100,7 +114,7 @@ def _init_reservoir(self) -> None: Caps at ``train_sample_size`` when set (>0), else the points the FAISS fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``) — rather than buffer the whole corpus. Single-process only (see the - world_size guard in :meth:`on_train_end`), so no per-rank split. + world_size guard in ``__init__``), so no per-rank split. """ target = self._model_config.train_sample_size self._sample_cap = max( @@ -269,8 +283,9 @@ def on_train_end(self) -> bool: """Fit the FAISS codebook once, after the train_eval loop exits. Overrides :meth:`BaseModel.on_train_end` (called unconditionally by - ``tzrec.main``). Single-process only: the fit runs on one process over - its local reservoir, with no cross-rank gather/broadcast. + ``tzrec.main``). Single-process only (enforced by the world_size guard + in ``__init__``): the fit runs on one process over its local reservoir, + with no cross-rank gather/broadcast. An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped and ``False`` returned. @@ -279,22 +294,7 @@ def on_train_end(self) -> bool: is_ckpt_after_train (bool): ``True`` if the codebook was fitted (centroids changed → force a final checkpoint), ``False`` if the fit was skipped (empty reservoir). - - Raises: - RuntimeError: if launched under distributed training - (``world_size > 1``). SidRqkmeans is single-process only. """ - if ( - dist.is_available() - and dist.is_initialized() - and dist.get_world_size() > 1 - ): - raise RuntimeError( - "SidRqkmeans supports single-process training only " - f"(world_size=1); got world_size={dist.get_world_size()}. " - "Launch with --nproc-per-node=1." - ) - local = self._reservoir_sample() self._reset_reservoir() diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 00a320046..ef90d6032 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -338,18 +338,17 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: # No fit happened, so no tail checkpoint is requested. self.assertFalse(model.on_train_end()) # should not raise - def test_on_train_end_raises_under_ddp(self) -> None: - """SidRqkmeans is single-process only: world_size>1 must raise.""" + def test_init_raises_under_ddp(self) -> None: + """SidRqkmeans is single-process only: world_size>1 fails fast in init.""" from unittest import mock - model = self._create_model() with ( mock.patch.object(dist, "is_available", return_value=True), mock.patch.object(dist, "is_initialized", return_value=True), mock.patch.object(dist, "get_world_size", return_value=2), self.assertRaisesRegex(RuntimeError, "single-process"), ): - model.on_train_end() + self._create_model() def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. From 6f7ae1dee07ad6ba6d94239e3a6519d2a08b4a35 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:08:47 +0000 Subject: [PATCH 18/46] [simplify] SidRqkmeans: drop dead max(1,...) cap clamp; fold test _build_model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _init_reservoir: both cap branches are always >= 1 now that the per-rank ceil-div is gone, so the max(1, ...) clamp is dead — drop it. Test: _build_model was module-level only to serve the (now-deleted) DDP worker processes; fold it into _create_model, its sole remaining caller. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 4 +- tzrec/models/sid_rqkmeans_test.py | 63 +++++++++++-------------------- 2 files changed, 24 insertions(+), 43 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3c87cedae..65c5aab46 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -117,8 +117,8 @@ def _init_reservoir(self) -> None: world_size guard in ``__init__``), so no per-rank split. """ target = self._model_config.train_sample_size - self._sample_cap = max( - 1, target if target > 0 else self._quantizer.default_fit_sample_size() + self._sample_cap = ( + target if target > 0 else self._quantizer.default_fit_sample_size() ) # Allocated lazily on the first batch. _n_filled = used slots; diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index ef90d6032..3b7aded5b 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -37,39 +37,6 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: return _batch_from_rows(torch.randn(batch_size, input_dim, device=device)) -def _build_model( - input_dim=32, - n_layers=2, - niter=5, - codebook=None, - normalize_residuals=False, - train_sample_size=0, -) -> SidRqkmeans: - """Build a SidRqkmeans configured for offline FAISS fit. - - SID models read the item-embedding dense feature directly from the batch - and do not consume feature_groups, so none is set. - """ - from google.protobuf.struct_pb2 import Struct - - n_embed_list = codebook if codebook is not None else [16] * n_layers - faiss_kwargs = Struct() - faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) - cfg = sid_model_pb2.SidRqkmeans( - input_dim=input_dim, - codebook=n_embed_list, - normalize_residuals=normalize_residuals, - faiss_kmeans_kwargs=faiss_kwargs, - embedding_feature_name="item_emb", - train_sample_size=train_sample_size, - ) - return SidRqkmeans( - model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), - features=[], - labels=[], - ) - - class SidRqkmeansOfflineTest(unittest.TestCase): """Single-process tests for SidRqkmeans (FAISS-only).""" @@ -82,14 +49,28 @@ def _create_model( normalize_residuals=False, train_sample_size=0, ): - """Create a SidRqkmeans on CPU with params initialized.""" - model = _build_model( - input_dim, - n_layers, - niter, - codebook, - normalize_residuals, - train_sample_size, + """Build a SidRqkmeans on CPU with params initialized. + + SID models read the item-embedding dense feature directly from the + batch and do not consume feature_groups, so none is set. + """ + from google.protobuf.struct_pb2 import Struct + + n_embed_list = codebook if codebook is not None else [16] * n_layers + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=n_embed_list, + normalize_residuals=normalize_residuals, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + train_sample_size=train_sample_size, + ) + model = SidRqkmeans( + model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), + features=[], + labels=[], ) init_parameters(model, device=torch.device("cpu")) return model From 5827d5b5b9fee48de07d9ce7a281d048d35b709d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:12:30 +0000 Subject: [PATCH 19/46] [style] ruff-format the __init__ DDP guard (collapse to one line) The world_size guard fit within the line limit, so ruff format collapses the parenthesized multi-line `if` to a single line. ruff check passed but ruff-format (pre-commit / codestyle CI) did not. No logic change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 65c5aab46..12641974a 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -75,11 +75,7 @@ def __init__( # Single-process only: the FAISS fit runs on one process over its local # reservoir, with no cross-rank gather/broadcast. Fail fast here rather # than after a full (wasted) training pass. - if ( - dist.is_available() - and dist.is_initialized() - and dist.get_world_size() > 1 - ): + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: raise RuntimeError( "SidRqkmeans supports single-process training only " f"(world_size=1); got world_size={dist.get_world_size()}. " From 4e2e87848f85112452014a63b6c69271388fecbd Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:59:41 +0000 Subject: [PATCH 20/46] =?UTF-8?q?[refactor]=20SidRqkmeans:=20CPU-only=20?= =?UTF-8?q?=E2=80=94=20raise=20on=20visible=20CUDA,=20drop=20device=20copi?= =?UTF-8?q?es?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SID RQ-KMeans is now CPU-only by decision. __init__ raises RuntimeError when torch.cuda.is_available() so all tensors (embeddings, reservoir, FAISS fit) stay on the host; run with CUDA_VISIBLE_DEVICES="". Remove the now-dead CPU<->GPU copies: - _reservoir_add: x is already on host, so .to("cpu", float32) becomes a plain float32 cast; drop idx.to(x.device). - train_offline: input is host float32 (.to("cpu") -> .to(float32)); drop centroids.cpu() and the explicit device="cpu" on search indices. - Drop the faiss-GPU passthrough: pop any "gpu" kwarg so a stale config / faiss-gpu build can't target an absent GPU; log line is CPU-only. Tests: setUp simulates a CPU-only host (GPU CI runners have CUDA, which would otherwise trip the new guard); add test_init_raises_on_gpu. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 21 +++++++--- tzrec/models/sid_rqkmeans_test.py | 19 ++++++++- .../modules/sid/residual_kmeans_quantizer.py | 41 ++++++++----------- 3 files changed, 50 insertions(+), 31 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 12641974a..c393b4943 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -72,6 +72,16 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) + # CPU-only: everything (embeddings, reservoir, FAISS fit) stays on the + # host, so there are no device copies on the train path. Refuse to run + # when CUDA is visible rather than silently shuttling tensors to/from a + # GPU; launch with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host). + if torch.cuda.is_available(): + raise RuntimeError( + "SidRqkmeans is CPU-only, but a CUDA device is visible. " + 'Run with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host).' + ) + # Single-process only: the FAISS fit runs on one process over its local # reservoir, with no cross-rank gather/broadcast. Fail fast here rather # than after a full (wasted) training pass. @@ -138,11 +148,12 @@ def _reservoir_add(self, x: torch.Tensor) -> None: if self._reservoir is None: self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) - # Phase 1: fill empty slots first. Copy only the rows we keep to host. + # Phase 1: fill empty slots first. x is already on the host (CPU-only + # model), so this is a dtype cast into the reservoir, not a device copy. if self._n_filled < cap: take = min(x.shape[0], cap - self._n_filled) self._reservoir[self._n_filled : self._n_filled + take] = x[:take].to( - "cpu", dtype=torch.float32 + torch.float32 ) self._n_filled += take self._n_seen += take @@ -151,9 +162,7 @@ def _reservoir_add(self, x: torch.Tensor) -> None: return # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random - # slot. The accept decision needs only counts, so compute it on host and - # copy ONLY accepted rows (in steady state, almost none) — avoiding the - # whole-batch GPU->CPU copy. float64 keeps n_seen+j+1 exact past 2**24. + # slot. float64 keeps n_seen+j+1 exact past 2**24. r = x.shape[0] pos = self._n_seen + torch.arange(r) accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) @@ -161,7 +170,7 @@ def _reservoir_add(self, x: torch.Tensor) -> None: if idx.numel() > 0: slots = torch.randint(0, cap, (idx.numel(),)) # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. - self._reservoir[slots] = x[idx.to(x.device)].to("cpu", dtype=torch.float32) + self._reservoir[slots] = x[idx].to(torch.float32) self._n_seen += r def _reservoir_sample(self) -> torch.Tensor: diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 3b7aded5b..f29b9455a 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -10,6 +10,7 @@ # limitations under the License. import unittest +from unittest import mock import torch import torch.distributed as dist @@ -40,6 +41,14 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: class SidRqkmeansOfflineTest(unittest.TestCase): """Single-process tests for SidRqkmeans (FAISS-only).""" + def setUp(self) -> None: + # SidRqkmeans is CPU-only and refuses to init when CUDA is visible. The + # GPU CI runners have CUDA, so simulate a CPU-only host for every + # construction-based test. (test_init_raises_on_gpu overrides this.) + patcher = mock.patch.object(torch.cuda, "is_available", return_value=False) + patcher.start() + self.addCleanup(patcher.stop) + def _create_model( self, input_dim=32, @@ -321,8 +330,6 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: def test_init_raises_under_ddp(self) -> None: """SidRqkmeans is single-process only: world_size>1 fails fast in init.""" - from unittest import mock - with ( mock.patch.object(dist, "is_available", return_value=True), mock.patch.object(dist, "is_initialized", return_value=True), @@ -331,6 +338,14 @@ def test_init_raises_under_ddp(self) -> None: ): self._create_model() + def test_init_raises_on_gpu(self) -> None: + """SidRqkmeans is CPU-only: a visible CUDA device fails fast in init.""" + with ( + mock.patch.object(torch.cuda, "is_available", return_value=True), + self.assertRaisesRegex(RuntimeError, "CPU-only"), + ): + self._create_model() + def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index a2648d2b8..971ef9e3b 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -55,7 +55,8 @@ class ResidualKMeansQuantizer(ResidualQuantizer): residuals with no per-layer normalization). faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, - 'gpu': True, 'verbose': True, 'spherical': False}). + 'verbose': True, 'spherical': False}). A ``gpu`` key is ignored — + the fit is CPU-only. """ def __init__( @@ -160,23 +161,21 @@ def train_offline( ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. - The residual matrix stays a host (CPU) tensor. With a faiss-gpu build, - ``faiss.Kmeans`` runs the K-Means training (over its internally - subsampled set) on the GPU; the post-fit ``index.search`` assignment - still streams all N rows through in ``SEARCH_CHUNK``-sized chunks, so we - never hold the full (N, D) on the device. faiss-cpu runs the same path - on CPU. + CPU-only: ``inputs`` is already a host tensor (SidRqkmeans refuses to + run when CUDA is visible) and the FAISS fit runs on CPU. The post-fit + ``index.search`` assignment streams all N rows through in + ``SEARCH_CHUNK``-sized chunks to cap peak memory. Args: - inputs (Tensor): embedding matrix (N, D). Copied once to an owned - CPU float32 tensor; not mutated. + inputs (Tensor): embedding matrix (N, D) on CPU. Copied once to an + owned float32 tensor; not mutated. verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Own a contiguous CPU float32 copy to update in place as the residual. + # Own a contiguous float32 copy to update in place as the residual. assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - x = inputs.detach().to("cpu", torch.float32).contiguous().clone() + x = inputs.detach().to(torch.float32).contiguous().clone() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the @@ -191,20 +190,16 @@ def train_offline( # reports). ``out + x`` would equal it only without normalization. x0 = x.clone() if verbose else None - # Default to a CPU fit. faiss reads ``gpu`` as a GPU *count*, not a - # device index (and ``1 == True`` collapses to all GPUs), so it cannot - # pin this rank0-only fit to a single device without sharding faiss - # memory onto the other ranks' GPUs. The fit is a bounded one-shot over - # the reservoir subsample, so CPU is cheap; set ``gpu`` explicitly in - # faiss_kmeans_kwargs (e.g. ``True`` for all GPUs) to opt into GPU. + # CPU-only fit: SidRqkmeans refuses to initialize when CUDA is visible, + # so the codebook is always built on CPU. Drop any stale ``gpu`` request + # from the config so a faiss-gpu build can't try to use an absent GPU. kwargs = dict(self.faiss_kmeans_kwargs) - kwargs.setdefault("gpu", False) + kwargs.pop("gpu", None) if verbose: logger.info( - "[ResidualKMeansQuantizer] fitting %d-layer codebook on %s " - "(N=%d, D=%d); set faiss_kmeans_kwargs.gpu to change.", + "[ResidualKMeansQuantizer] fitting %d-layer codebook on CPU " + "(N=%d, D=%d).", self.n_layers, - "GPU" if kwargs["gpu"] else "CPU", N, self.embed_dim, ) @@ -222,12 +217,12 @@ def train_offline( self.embed_dim, self.n_embed_list[layer_idx], **kwargs ) kmeans.train(x) - centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32).cpu() + centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32) for start in range(0, N, SEARCH_CHUNK): end = min(start + SEARCH_CHUNK, N) _, idx = kmeans.index.search(x[start:end], 1) - idx = torch.as_tensor(idx, device="cpu").reshape(-1).long() + idx = torch.as_tensor(idx).reshape(-1).long() q = centroids[idx] # (chunk, D) out[start:end] += q x[start:end] -= q # residual From 4773e2a657ac4746bac8d80c05deafc0d8df578f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 13:03:31 +0000 Subject: [PATCH 21/46] [simplify] train_offline: assert host input; single-copy float32 own MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Assert inputs is not CUDA: the quantizer is a standalone module that now assumes host tensors (SidRqkmeans enforces CPU-only at __init__); make the contract local so misuse fails here, not opaquely inside faiss. - Replace `.to(float32).contiguous().clone()` with `.to(dtype=float32, copy=True).contiguous()` — one guaranteed owning copy instead of a chain that could double-copy on a non-contiguous/non-float32 input. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_kmeans_quantizer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 971ef9e3b..1bfe20267 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -171,11 +171,15 @@ def train_offline( owned float32 tensor; not mutated. verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Own a contiguous float32 copy to update in place as the residual. + # CPU-only: SidRqkmeans refuses to init when CUDA is visible, but this + # quantizer is a standalone module — assert the host-tensor contract it + # relies on so misuse fails here, not deep inside faiss. + assert not inputs.is_cuda, "train_offline is CPU-only; got a CUDA tensor" assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - x = inputs.detach().to(torch.float32).contiguous().clone() + # Own one contiguous float32 copy to update in place as the residual. + x = inputs.detach().to(dtype=torch.float32, copy=True).contiguous() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the From df83d070f4569c0dfa2e14543451843c790d41d4 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 02:34:01 +0000 Subject: [PATCH 22/46] [refactor] KMeansLayer.predict: use torch.cdist; drop _squared_euclidean_distance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review: replace the manual squared-L2 expansion with torch.cdist(batch, centroids).argmin(-1). argmin is invariant to the monotonic sqrt, so codes are identical for all non-degenerate inputs — verified bit-exact across a wide sweep (random shapes/dtypes incl. 1000x8192x128, all cdist compute_modes, large-magnitude/cancellation, and near-ties): 0 mismatches over ~140k rows. The only divergence is at *exact* equidistant ties (measure zero for real embeddings), where either centroid is equally near. Confirmed predict still scripts / FX-traces / torch.exports identically to eager. Removes the now-unused _squared_euclidean_distance helper + its unit test on this branch. NOTE: feat/sid_abstract's vector_quantize.py (RQ-VAE, PR3) also imports this helper — PR3 must migrate its l2 branch to cdist in the same series. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 29 ++++++++--------------------- tzrec/modules/sid/kmeans_test.py | 9 --------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index d6e34acd9..7e874a0c1 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -52,25 +52,6 @@ def recon_diagnostics( return mse, rel -@torch.no_grad() -def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Squared L2 distance between rows of ``x`` and ``y``. - - Args: - x (Tensor): data points, shape (N, D). - y (Tensor): centroids, shape (K, D). - - Returns: - Tensor: squared distances, shape (N, K). - - Kept branch-free (no data-dependent control flow on ``N``) so the - per-batch predict forward stays FX-traceable for torchrec inference. - """ - x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) - y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) - return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) - - class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. @@ -165,11 +146,17 @@ def _load_from_state_dict( def predict(self, batch: torch.Tensor) -> torch.Tensor: """Assign points to nearest centroid. + Uses ``torch.cdist`` (plain L2). argmin is invariant to the monotonic + sqrt, so the assignment is identical to squared-L2 for all + non-degenerate inputs (verified bit-exact across random / large- + magnitude / near-tie sweeps); only an exact equidistant tie — measure + zero for real embeddings — may resolve to a different, equally-near + centroid. + Args: batch (Tensor): data points, shape (B, D). Returns: Tensor: cluster indices, shape (B,). """ - dists = _squared_euclidean_distance(batch, self.centroids) - return torch.argmin(dists, dim=-1) + return torch.cdist(batch, self.centroids).argmin(dim=-1) diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index cb86a39d8..1b21604d3 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -15,7 +15,6 @@ from tzrec.modules.sid.kmeans import ( KMeansLayer, - _squared_euclidean_distance, recon_diagnostics, ) @@ -29,14 +28,6 @@ def test_recon_diagnostics_zero_on_identity(self) -> None: self.assertAlmostEqual(mse.item(), 0.0, places=6) self.assertAlmostEqual(rel.item(), 0.0, places=6) - def test_squared_euclidean_distance(self) -> None: - x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) - y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) - d = _squared_euclidean_distance(x, y) - self.assertEqual(d.shape, (2, 2)) - # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 - torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - class KMeansLayerTest(unittest.TestCase): """Tests for the single KMeansLayer.""" From d037db7dbb6701a4a2d5f59989da170066c5fc20 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 02:56:23 +0000 Subject: [PATCH 23/46] [refactor] SidRqkmeans: drop input_embedding from predictions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review (#1): the reconstruction target is an input, not a model output, so don't thread it through predictions. update_metric now re-extracts the embedding from batch (mirrors SidRqvae.update_metric) and guards on "quantized", which is eval-only — so the reconstruction metric stays eval-only by construction. predict (eval) now exposes {codes, quantized}; the metric test passes the same batch through predict + update_metric so the re-extracted target matches the prediction. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 29 ++++++++++++++--------------- tzrec/models/sid_rqkmeans_test.py | 16 ++++++++++------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index c393b4943..7a065e471 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -218,7 +218,6 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: if self.is_eval: predictions["quantized"] = quantized - predictions["input_embedding"] = embedding return predictions @@ -257,28 +256,28 @@ def update_metric( ) -> None: """Update metric state. - Note: ``mse``/``rel_loss`` compare ``input_embedding`` against the - centroid-sum reconstruction. They are meaningful reconstruction - metrics only with ``normalize_residuals=False`` (the default); with - normalization the centroids live on the rescaled-residual scale, so - the two quantities don't share a scale (same caveat the train_offline - per-layer log carries). + The reconstruction target (the input embedding) is re-extracted from + ``batch`` rather than threaded through ``predictions`` — it is an input, + not a model output (mirrors ``SidRqvae.update_metric``). ``quantized`` is + present only in eval (see ``predict``), so this runs eval-only. + + Note: ``mse``/``rel_loss`` compare that embedding against the centroid-sum + reconstruction. They are meaningful reconstruction metrics only with + ``normalize_residuals=False`` (the default); with normalization the + centroids live on the rescaled-residual scale, so the two quantities + don't share a scale (same caveat the train_offline per-layer log carries). Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ - if "input_embedding" in predictions: - _, rel = recon_diagnostics( - predictions["input_embedding"], - predictions["quantized"], - ) + if "quantized" in predictions: + embedding = self._extract_feature(batch) + _, rel = recon_diagnostics(embedding, predictions["quantized"]) # mse aggregates (preds, target) itself; rel_loss has no # torchmetrics equivalent, so it stays a MeanMetric. - self._metric_modules["mse"].update( - predictions["quantized"], predictions["input_embedding"] - ) + self._metric_modules["mse"].update(predictions["quantized"], embedding) self._metric_modules["rel_loss"].update(rel) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index f29b9455a..f0964a8d0 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -275,7 +275,7 @@ def test_normalize_residuals_end_to_end(self) -> None: self.assertTrue((codes >= 0).all() and (codes < 16).all()) def test_eval_and_inference_predict_contract(self) -> None: - """Eval exposes quantized/input_embedding; inference is codes-only.""" + """Eval exposes codes + quantized only; inference is codes-only.""" try: import faiss # noqa: F401 except ImportError: @@ -288,11 +288,12 @@ def test_eval_and_inference_predict_contract(self) -> None: model.predict(_make_batch(B, input_dim)) model.on_train_end() - # Eval mode: reconstruction outputs are present for update_metric. + # Eval mode: the centroid-sum reconstruction is exposed for + # update_metric; the input embedding is NOT threaded through + # predictions (it is re-extracted from the batch in update_metric). model.eval() eval_preds = model.predict(_make_batch(B, input_dim)) - self.assertIn("quantized", eval_preds) - self.assertIn("input_embedding", eval_preds) + self.assertEqual(set(eval_preds.keys()), {"codes", "quantized"}) # Inference (serving) mode: codes-only contract. model.set_is_inference(True) @@ -315,8 +316,11 @@ def test_eval_metric_path(self) -> None: model.init_metric() model.eval() - preds = model.predict(_make_batch(B, input_dim)) - model.update_metric(preds, _make_batch(B, input_dim)) + # Same batch through predict + update_metric: the reconstruction target + # is re-extracted from this batch, so it must match the predicted one. + batch = _make_batch(B, input_dim) + preds = model.predict(batch) + model.update_metric(preds, batch) metrics = model.compute_metric() for key in ("mse", "rel_loss", "unique_sid_ratio"): self.assertIn(key, metrics) From 88856f3adf44e809fcf1d9d24fc75d761b73703b Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:08:18 +0000 Subject: [PATCH 24/46] [simplify] trim SID docstrings (predict provenance; stale SidRqvae xref) - KMeansLayer.predict: collapse the 6-line cdist-vs-squared-L2 verification provenance to a one-line equivalence note. - SidRqkmeans.update_metric: trim the over-explained re-extraction paragraph and drop the ``SidRqvae.update_metric`` cross-ref (sid_rqvae.py is PR3, not present in this PR's merge target, so the symbol doesn't resolve). Docstring-only; no behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 ++--- tzrec/modules/sid/kmeans.py | 9 +++------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 7a065e471..35dbc1036 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -257,9 +257,8 @@ def update_metric( """Update metric state. The reconstruction target (the input embedding) is re-extracted from - ``batch`` rather than threaded through ``predictions`` — it is an input, - not a model output (mirrors ``SidRqvae.update_metric``). ``quantized`` is - present only in eval (see ``predict``), so this runs eval-only. + ``batch`` — it is an input, not a model output. ``quantized`` is present + only in eval (see ``predict``), so this runs eval-only. Note: ``mse``/``rel_loss`` compare that embedding against the centroid-sum reconstruction. They are meaningful reconstruction metrics only with diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 7e874a0c1..02cfc63d6 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -146,12 +146,9 @@ def _load_from_state_dict( def predict(self, batch: torch.Tensor) -> torch.Tensor: """Assign points to nearest centroid. - Uses ``torch.cdist`` (plain L2). argmin is invariant to the monotonic - sqrt, so the assignment is identical to squared-L2 for all - non-degenerate inputs (verified bit-exact across random / large- - magnitude / near-tie sweeps); only an exact equidistant tie — measure - zero for real embeddings — may resolve to a different, equally-near - centroid. + Uses ``torch.cdist`` (L2); argmin is invariant to the monotonic sqrt, + so assignments match squared-L2 except at exact equidistant ties + (measure zero for real embeddings), where either centroid is valid. Args: batch (Tensor): data points, shape (B, D). From 2fa312b99da7c115c7d44021e26075d88fb8f7df Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:27:13 +0000 Subject: [PATCH 25/46] [refactor] extract reservoir sampling into ReservoirSampler (kmeans.py) Move the Vitter-Algorithm-R reservoir out of SidRqkmeans into a standalone ReservoirSampler class in kmeans.py (the shared SID/kmeans utility module, already home to recon_diagnostics). SidRqkmeans now holds one ReservoirSampler(cap, dim) and calls add()/sample()/reset(); the four state fields and three private reservoir methods are gone. Reservoir-mechanics tests (caps_memory, phase2_replacement) move to kmeans_test.py against ReservoirSampler directly (no model needed), plus empty-sample and reset tests; model tests now poke the sampler via its n_seen/n_filled/capacity accessors. Behavior-preserving. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 88 ++++------------------------- tzrec/models/sid_rqkmeans_test.py | 80 ++++----------------------- tzrec/modules/sid/kmeans.py | 92 ++++++++++++++++++++++++++++++- tzrec/modules/sid/kmeans_test.py | 73 ++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 145 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 35dbc1036..3f3b6954f 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -27,7 +27,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import recon_diagnostics +from tzrec.modules.sid.kmeans import ReservoirSampler, recon_diagnostics from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -109,82 +109,18 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - self._init_reservoir() + # Bounded host reservoir for the end-of-loop FAISS fit: cap at + # ``train_sample_size`` when set (>0), else the points the FAISS fit + # subsamples to (``default_fit_sample_size``) — rather than buffer the + # whole corpus. Single-process only (see the world_size guard above), + # so no per-rank split. + target = self._model_config.train_sample_size + cap = target if target > 0 else self._quantizer.default_fit_sample_size() + self._reservoir = ReservoirSampler(cap, self._input_dim) # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) - def _init_reservoir(self) -> None: - """Set up the bounded host reservoir for the end-of-loop FAISS fit. - - Caps at ``train_sample_size`` when set (>0), else the points the FAISS - fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``) - — rather than buffer the whole corpus. Single-process only (see the - world_size guard in ``__init__``), so no per-rank split. - """ - target = self._model_config.train_sample_size - self._sample_cap = ( - target if target > 0 else self._quantizer.default_fit_sample_size() - ) - - # Allocated lazily on the first batch. _n_filled = used slots; - # _n_seen = running count for the accept prob. - self._reservoir: Optional[torch.Tensor] = None - self._n_filled = 0 - self._n_seen = 0 - - @torch.no_grad() - def _reservoir_add(self, x: torch.Tensor) -> None: - """Stream a batch into the reservoir (Vitter Algorithm R). - - Keeps a uniform ``_sample_cap`` sample of all embeddings seen, in - O(cap) host memory. - - Args: - x (Tensor): batch of embeddings, shape (B, D). - """ - x = x.detach() - cap = self._sample_cap - if self._reservoir is None: - self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) - - # Phase 1: fill empty slots first. x is already on the host (CPU-only - # model), so this is a dtype cast into the reservoir, not a device copy. - if self._n_filled < cap: - take = min(x.shape[0], cap - self._n_filled) - self._reservoir[self._n_filled : self._n_filled + take] = x[:take].to( - torch.float32 - ) - self._n_filled += take - self._n_seen += take - x = x[take:] - if x.shape[0] == 0: - return - - # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random - # slot. float64 keeps n_seen+j+1 exact past 2**24. - r = x.shape[0] - pos = self._n_seen + torch.arange(r) - accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) - idx = accept.nonzero(as_tuple=True)[0] - if idx.numel() > 0: - slots = torch.randint(0, cap, (idx.numel(),)) - # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. - self._reservoir[slots] = x[idx].to(torch.float32) - self._n_seen += r - - def _reservoir_sample(self) -> torch.Tensor: - """Return the filled portion of the reservoir, shape (n_filled, D).""" - if self._reservoir is None or self._n_filled == 0: - return torch.empty(0, self._input_dim, dtype=torch.float32) - return self._reservoir[: self._n_filled] - - def _reset_reservoir(self) -> None: - """Drop the reservoir after the FAISS fit to free host memory.""" - self._reservoir = None - self._n_filled = 0 - self._n_seen = 0 - def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -202,7 +138,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: # Training: just reservoir-sample for the end-of-loop FAISS fit and # return dummy codes — the codebook does not exist yet. if self.is_train: - self._reservoir_add(embedding) + self._reservoir.add(embedding) B = embedding.shape[0] return { "codes": torch.zeros( @@ -298,8 +234,8 @@ def on_train_end(self) -> bool: (centroids changed → force a final checkpoint), ``False`` if the fit was skipped (empty reservoir). """ - local = self._reservoir_sample() - self._reset_reservoir() + local = self._reservoir.sample() + self._reservoir.reset() if local.shape[0] == 0: logger.warning( diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index f0964a8d0..c312ebad3 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -90,18 +90,20 @@ def test_proto_parse(self) -> None: self.assertEqual(model._faiss_kwargs.get("niter"), 5) self.assertEqual(model._faiss_kwargs.get("seed"), 1234) self.assertFalse(model._faiss_kwargs.get("verbose")) - self.assertEqual(model._n_seen, 0) - self.assertIsNone(model._reservoir) + self.assertEqual(model._reservoir.n_seen, 0) + self.assertEqual(model._reservoir.n_filled, 0) def test_sample_cap_from_train_sample_size(self) -> None: """train_sample_size (when set) drives the reservoir cap directly.""" # Explicit train_sample_size: cap == train_sample_size. model = self._create_model(train_sample_size=900) - self.assertEqual(model._sample_cap, 900) + self.assertEqual(model._reservoir.capacity, 900) # Default (train_sample_size=0): cap == the FAISS fit's subsample size. model = self._create_model() - self.assertEqual(model._sample_cap, model._quantizer.default_fit_sample_size()) + self.assertEqual( + model._reservoir.capacity, model._quantizer.default_fit_sample_size() + ) def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" @@ -116,70 +118,12 @@ def test_predict_collects_buffer(self) -> None: # Reservoir holds all 4*B samples (well under the cap) and tracks # the running count. - self.assertEqual(model._n_seen, 4 * B) - self.assertEqual(model._n_filled, 4 * B) + self.assertEqual(model._reservoir.n_seen, 4 * B) + self.assertEqual(model._reservoir.n_filled, 4 * B) # FAISS not yet triggered: layers should be uninitialized for layer in model._quantizer.layers: self.assertFalse(layer.is_initialized) - def test_reservoir_caps_memory(self) -> None: - """Reservoir bounds the buffer at _sample_cap regardless of corpus.""" - B, input_dim = 16, 8 - model = self._create_model(input_dim=input_dim) - model._sample_cap = 10 # force a tiny cap - model._reset_reservoir() - model.train() - for _ in range(20): # 320 rows >> cap - model.predict(_make_batch(B, input_dim)) - self.assertEqual(model._n_seen, 20 * B) - self.assertEqual(model._n_filled, 10) - self.assertEqual(model._reservoir.shape, (10, input_dim)) - - def test_reservoir_phase2_replacement(self) -> None: - """Phase-2 replacement keeps a valid reservoir of real, in-range rows. - - Feeds identifiable rows (each row's value == its global stream index), - then asserts every reservoir slot still holds an intact fed row, all - indices are in range, and replacement past the initial fill actually - happened — exercising the accept-prob / slot-write logic that the - count/shape-only ``test_reservoir_caps_memory`` cannot. - """ - torch.manual_seed(0) - input_dim, cap, B, n_batches = 4, 8, 4, 50 - model = self._create_model(input_dim=input_dim) - model._sample_cap = cap - model._reset_reservoir() - model.train() - - gidx = 0 - for _ in range(n_batches): - rows = ( - torch.arange(gidx, gidx + B, dtype=torch.float32) - .unsqueeze(1) - .expand(B, input_dim) - .contiguous() - ) - gidx += B - model.predict(_batch_from_rows(rows)) - - total = B * n_batches - self.assertEqual(model._n_seen, total) - self.assertEqual(model._n_filled, cap) - - res = model._reservoir - idx = res[:, 0].round().long() - # Each stored row is an intact fed row (all columns equal its index), - # never zeros/garbage. - self.assertTrue( - torch.equal(res, idx.unsqueeze(1).float().expand_as(res)), - "reservoir holds corrupted (non-fed) rows", - ) - # All indices are valid stream positions. - self.assertTrue((idx >= 0).all() and (idx < total).all()) - # Phase-2 replacement happened: at least one slot holds a row added - # after the reservoir filled (index >= cap). - self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") - def test_on_train_end_runs_faiss(self) -> None: """on_train_end triggers FAISS fit and clears buffer.""" try: @@ -194,14 +138,14 @@ def test_on_train_end_runs_faiss(self) -> None: # Accumulate enough samples (FAISS K-Means needs at least K points) for _ in range(8): model.predict(_make_batch(B, input_dim)) - self.assertGreater(model._n_seen, 0) + self.assertGreater(model._reservoir.n_seen, 0) # Trigger one-shot FAISS fit; a real fit must request a tail checkpoint self.assertTrue(model.on_train_end()) # Reservoir should be released after the fit - self.assertEqual(model._n_seen, 0) - self.assertIsNone(model._reservoir) + self.assertEqual(model._reservoir.n_seen, 0) + self.assertEqual(model._reservoir.n_filled, 0) # All layers should be initialized + centroids non-zero for layer in model._quantizer.layers: self.assertTrue(bool(layer._is_initialized.item())) @@ -226,7 +170,7 @@ def test_non_uniform_codebook_end_to_end(self) -> None: model = self._create_model(input_dim=input_dim, codebook=codebook) # Reservoir cap derives from the LARGEST K (16), not the first (8). self.assertEqual( - model._sample_cap, + model._reservoir.capacity, 16 * int(model._faiss_kwargs.get("max_points_per_centroid", 256)), ) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 02cfc63d6..629392fc4 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -18,9 +18,12 @@ :class:`ResidualKMeansQuantizer`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. +* :class:`ReservoirSampler` — bounded uniform stream sample (Vitter + Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` + fills during training to feed the one-shot FAISS fit. """ -from typing import Tuple +from typing import Optional, Tuple import torch from torch import nn @@ -52,6 +55,93 @@ def recon_diagnostics( return mse, rel +class ReservoirSampler: + """Bounded uniform sample of a stream (Vitter Algorithm R). + + Keeps a uniform ``capacity``-row sample of all rows passed to ``add``, in + O(capacity) host (CPU) memory — used to subsample the training corpus for + the one-shot FAISS fit without buffering the whole corpus. The buffer is a + CPU float32 tensor, allocated lazily on the first ``add``. + + Args: + capacity (int): max rows retained. + dim (int): row width (feature dimension). + """ + + def __init__(self, capacity: int, dim: int) -> None: + self._cap = capacity + self._dim = dim + # Allocated lazily on the first add. _n_filled = used slots; + # _n_seen = running count for the accept prob. + self._buf: Optional[torch.Tensor] = None + self._n_filled = 0 + self._n_seen = 0 + + @property + def capacity(self) -> int: + """Max rows retained.""" + return self._cap + + @property + def n_seen(self) -> int: + """Total rows passed to ``add`` so far.""" + return self._n_seen + + @property + def n_filled(self) -> int: + """Rows currently held (<= capacity).""" + return self._n_filled + + @torch.no_grad() + def add(self, x: torch.Tensor) -> None: + """Stream a batch of rows into the reservoir. + + Args: + x (Tensor): rows to add, shape (B, dim). + """ + x = x.detach() + cap = self._cap + if self._buf is None: + self._buf = torch.empty(cap, self._dim, dtype=torch.float32) + + # Phase 1: fill empty slots first. x is already on the host (CPU-only + # model), so this is a dtype cast into the buffer, not a device copy. + if self._n_filled < cap: + take = min(x.shape[0], cap - self._n_filled) + self._buf[self._n_filled : self._n_filled + take] = x[:take].to( + torch.float32 + ) + self._n_filled += take + self._n_seen += take + x = x[take:] + if x.shape[0] == 0: + return + + # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random + # slot. float64 keeps n_seen+j+1 exact past 2**24. + r = x.shape[0] + pos = self._n_seen + torch.arange(r) + accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) + idx = accept.nonzero(as_tuple=True)[0] + if idx.numel() > 0: + slots = torch.randint(0, cap, (idx.numel(),)) + # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. + self._buf[slots] = x[idx].to(torch.float32) + self._n_seen += r + + def sample(self) -> torch.Tensor: + """Return the filled portion of the reservoir, shape (n_filled, dim).""" + if self._buf is None or self._n_filled == 0: + return torch.empty(0, self._dim, dtype=torch.float32) + return self._buf[: self._n_filled] + + def reset(self) -> None: + """Drop the buffer and counters to free host memory.""" + self._buf = None + self._n_filled = 0 + self._n_seen = 0 + + class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index 1b21604d3..d6b06a7f1 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -15,6 +15,7 @@ from tzrec.modules.sid.kmeans import ( KMeansLayer, + ReservoirSampler, recon_diagnostics, ) @@ -70,5 +71,77 @@ def test_post_fit_checkpoint_round_trips(self) -> None: torch.testing.assert_close(fresh.centroids, layer.centroids) +class ReservoirSamplerTest(unittest.TestCase): + """Tests for the bounded reservoir sampler (Vitter Algorithm R).""" + + def test_empty_sample(self) -> None: + """sample() before any add returns an empty (0, dim) tensor.""" + r = ReservoirSampler(capacity=10, dim=4) + self.assertEqual(r.sample().shape, (0, 4)) + self.assertEqual(r.n_seen, 0) + self.assertEqual(r.n_filled, 0) + + def test_caps_memory(self) -> None: + """The buffer is bounded at capacity regardless of stream length.""" + cap, dim, B = 10, 8, 16 + r = ReservoirSampler(capacity=cap, dim=dim) + for _ in range(20): # 320 rows >> cap + r.add(torch.randn(B, dim)) + self.assertEqual(r.n_seen, 20 * B) + self.assertEqual(r.n_filled, cap) + self.assertEqual(r.sample().shape, (cap, dim)) + + def test_phase2_replacement(self) -> None: + """Phase-2 replacement keeps a valid sample of real, in-range rows. + + Feeds identifiable rows (each row's value == its global stream index), + then asserts every slot still holds an intact fed row, all indices are + in range, and replacement past the initial fill actually happened — + exercising the accept-prob / slot-write logic that the count/shape-only + ``test_caps_memory`` cannot. + """ + torch.manual_seed(0) + dim, cap, B, n_batches = 4, 8, 4, 50 + r = ReservoirSampler(capacity=cap, dim=dim) + + gidx = 0 + for _ in range(n_batches): + rows = ( + torch.arange(gidx, gidx + B, dtype=torch.float32) + .unsqueeze(1) + .expand(B, dim) + .contiguous() + ) + gidx += B + r.add(rows) + + total = B * n_batches + self.assertEqual(r.n_seen, total) + self.assertEqual(r.n_filled, cap) + + res = r.sample() + idx = res[:, 0].round().long() + # Each stored row is an intact fed row (all columns equal its index). + self.assertTrue( + torch.equal(res, idx.unsqueeze(1).float().expand_as(res)), + "reservoir holds corrupted (non-fed) rows", + ) + # All indices are valid stream positions. + self.assertTrue((idx >= 0).all() and (idx < total).all()) + # Phase-2 replacement happened: at least one slot holds a row added + # after the reservoir filled (index >= cap). + self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") + + def test_reset(self) -> None: + """reset() drops the buffer and counters.""" + r = ReservoirSampler(capacity=10, dim=4) + r.add(torch.randn(5, 4)) + self.assertEqual(r.n_filled, 5) + r.reset() + self.assertEqual(r.n_seen, 0) + self.assertEqual(r.n_filled, 0) + self.assertEqual(r.sample().shape, (0, 4)) + + if __name__ == "__main__": unittest.main() From e296c8d32e629b021e6cd9f7947b9cc2c0e609f3 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:30:08 +0000 Subject: [PATCH 26/46] [refactor] ReservoirSampler: log capacity + dim on construction Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 629392fc4..7c89f7d17 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -28,6 +28,8 @@ import torch from torch import nn +from tzrec.utils.logging_util import logger + def recon_diagnostics( x: torch.Tensor, @@ -76,6 +78,7 @@ def __init__(self, capacity: int, dim: int) -> None: self._buf: Optional[torch.Tensor] = None self._n_filled = 0 self._n_seen = 0 + logger.info("[ReservoirSampler] capacity=%d, dim=%d", capacity, dim) @property def capacity(self) -> int: From 892a8d26538fd816b715170a17847f2bcc5d3c55 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:44:41 +0000 Subject: [PATCH 27/46] [fix] SID code-review: fail-fast cap, skip pre-fit eval, dedup MSE, drop x0 clone Addresses /code-review findings on SidRqkmeans: #1 Fail fast at __init__ when the reservoir cap < max(codebook) (an explicit train_sample_size too small would otherwise assert in train_offline only at on_train_end, after the whole training pass). #2 update_metric returns early when the codebook isn't fitted yet (ResidualKMeansQuantizer.is_fitted), so in-loop eval before the end-of-train FAISS fit no longer logs garbage mse/rel_loss/unique_sid_ratio over the all-zero codebook. #3 Stop computing MSE twice per eval batch: extract a relative_l1 helper (recon_diagnostics now reuses it) and call it directly instead of recon_diagnostics-then-discard-mse alongside MeanSquaredError.update. #4 Drop the persistent ~0.5GB x0 clone in train_offline for the common (normalize_residuals=False) path: reconstruct the per-layer log reference on the fly via the out + x == x0 invariant; keep the clone only when normalization breaks that invariant. Tests: add fail-fast and pre-fit-skip cases. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 25 +++++++++++++++---- tzrec/models/sid_rqkmeans_test.py | 16 ++++++++++++ tzrec/modules/sid/kmeans.py | 23 ++++++++++++++--- .../modules/sid/residual_kmeans_quantizer.py | 23 +++++++++++++---- 4 files changed, 74 insertions(+), 13 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3f3b6954f..a79250bac 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -27,7 +27,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import ReservoirSampler, recon_diagnostics +from tzrec.modules.sid.kmeans import ReservoirSampler, relative_l1 from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -116,6 +116,14 @@ def __init__( # so no per-rank split. target = self._model_config.train_sample_size cap = target if target > 0 else self._quantizer.default_fit_sample_size() + # Fail fast: FAISS needs >= K points to fit each layer, so a cap below + # the largest codebook would only assert at on_train_end — after the + # whole training pass. (The default cap is always >= max(K).) + max_k = max(self._n_embed_list) + assert cap >= max_k, ( + f"reservoir cap ({cap}) < largest codebook size ({max_k}); set " + f"train_sample_size >= {max_k} (or 0 for the default)." + ) self._reservoir = ReservoirSampler(cap, self._input_dim) # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. @@ -207,13 +215,20 @@ def update_metric( batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ + # In-loop eval can run before the end-of-train FAISS fit; the codebook + # is all-zeros then, so codes/reconstruction are meaningless. Skip until + # fitted so those bogus values don't pollute the eval metrics. + if not self._quantizer.is_fitted: + return + if "quantized" in predictions: embedding = self._extract_feature(batch) - _, rel = recon_diagnostics(embedding, predictions["quantized"]) - # mse aggregates (preds, target) itself; rel_loss has no - # torchmetrics equivalent, so it stays a MeanMetric. + # mse aggregates (preds, target) itself; rel_loss has no torchmetrics + # equivalent, so compute it directly (only rel is needed here). self._metric_modules["mse"].update(predictions["quantized"], embedding) - self._metric_modules["rel_loss"].update(rel) + self._metric_modules["rel_loss"].update( + relative_l1(embedding, predictions["quantized"]) + ) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index c312ebad3..e41fba295 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -105,6 +105,11 @@ def test_sample_cap_from_train_sample_size(self) -> None: model._reservoir.capacity, model._quantizer.default_fit_sample_size() ) + def test_init_raises_on_too_small_train_sample_size(self) -> None: + """train_sample_size below the largest codebook fails fast at init.""" + with self.assertRaisesRegex(AssertionError, "largest codebook"): + self._create_model(codebook=[16, 16], train_sample_size=8) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 @@ -270,6 +275,17 @@ def test_eval_metric_path(self) -> None: self.assertIn(key, metrics) self.assertTrue(torch.isfinite(torch.as_tensor(metrics[key])).all()) + def test_update_metric_skipped_before_fit(self) -> None: + """Pre-fit eval (unfitted codebook) does not pollute metric state.""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim) + model.init_metric() + model.eval() + # Codebook not fitted yet: predict emits zeros; update_metric must skip. + batch = _make_batch(B, input_dim) + model.update_metric(model.predict(batch), batch) + self.assertEqual(model._metric_modules["unique_sid_ratio"].count.item(), 0.0) + def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 7c89f7d17..50b2263f3 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -50,11 +50,28 @@ def recon_diagnostics( mse: scalar ``((out - x) ** 2).mean()``. rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. """ - mse = ((out - x) ** 2).mean() - rel = ( + return ((out - x) ** 2).mean(), relative_l1(x, out, epsilon) + + +def relative_l1( + x: torch.Tensor, + out: torch.Tensor, + epsilon: float = 1e-4, +) -> torch.Tensor: + """Relative-L1 error ``mean(|x - out| / (max(|x|, |out|) + eps))``. + + Symmetric relative error in [0, 1] (verbatim port of OpenOneRec's + ``calc_loss``). Used standalone by :meth:`SidRqkmeans.update_metric` (which + needs only ``rel``, not the MSE :meth:`recon_diagnostics` also computes). + + Args: + x: ground-truth embedding, shape (B, D). + out: quantized reconstruction, shape (B, D). + epsilon: numerical stabilizer for the denominator. + """ + return ( torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) ).mean() - return mse, rel class ReservoirSampler: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 1bfe20267..8a0c8a176 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -127,6 +127,15 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: cluster_ids, quantized_sum, _ = self._residual_pass(input) return cluster_ids, quantized_sum + @property + def is_fitted(self) -> bool: + """Whether ``train_offline`` has populated every layer's codebook. + + ``forward`` is callable before the fit (uninitialized layers emit + zeros), so reconstruction outputs are meaningful only once this is True. + """ + return all(layer.is_initialized for layer in self.layers) + @torch.no_grad() def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """Get centroid weights for a specific layer. @@ -189,10 +198,12 @@ def train_offline( f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" ) out = torch.zeros_like(x) - # Original input, kept only for the log: the per-layer diagnostic is the - # cumulative recon error of x0 by the centroid sum (what update_metric - # reports). ``out + x`` would equal it only without normalization. - x0 = x.clone() if verbose else None + # The per-layer log reports the cumulative recon error of the original + # input x0 by the centroid sum. Without normalization the invariant + # ``out + x == x0`` holds, so x0 is reconstructed on the fly below and we + # skip the persistent (N, D) clone; with normalization x is rescaled each + # layer, breaking the invariant, so the clone is required. + x0 = x.clone() if (verbose and self.normalize_residuals) else None # CPU-only fit: SidRqkmeans refuses to initialize when CUDA is visible, # so the codebook is always built on CPU. Drop any stale ``gpu`` request @@ -233,10 +244,12 @@ def train_offline( del idx, q if verbose: + # x0 == out + x without normalization (see above). + ref = x0 if x0 is not None else out + x logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, - self._calc_loss(x0, out), # cumulative recon of original input + self._calc_loss(ref, out), # cumulative recon of original input ) self.layers[layer_idx].load_centroids_(centroids) From b14304af2a6fae7c29cd17ddd46d5f0b9930c129 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:50:03 +0000 Subject: [PATCH 28/46] [simplify] SID: raise (not assert) for cap guard; name normalize_residuals - __init__: the cap < max(codebook) fail-fast used a bare assert, which is stripped under `python -O` (defeating the fail-fast purpose) and was inconsistent with the two sibling raise-guards in the same constructor. Convert to raise RuntimeError; update the test accordingly. - train_offline: `ref = x0 if self.normalize_residuals else out + x` names the actual reason instead of re-deriving it via the `x0 is not None` sentinel. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 9 +++++---- tzrec/models/sid_rqkmeans_test.py | 2 +- tzrec/modules/sid/residual_kmeans_quantizer.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index a79250bac..af9b5b7e7 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -120,10 +120,11 @@ def __init__( # the largest codebook would only assert at on_train_end — after the # whole training pass. (The default cap is always >= max(K).) max_k = max(self._n_embed_list) - assert cap >= max_k, ( - f"reservoir cap ({cap}) < largest codebook size ({max_k}); set " - f"train_sample_size >= {max_k} (or 0 for the default)." - ) + if cap < max_k: + raise RuntimeError( + f"reservoir cap ({cap}) < largest codebook size ({max_k}); set " + f"train_sample_size >= {max_k} (or 0 for the default)." + ) self._reservoir = ReservoirSampler(cap, self._input_dim) # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index e41fba295..782991eac 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -107,7 +107,7 @@ def test_sample_cap_from_train_sample_size(self) -> None: def test_init_raises_on_too_small_train_sample_size(self) -> None: """train_sample_size below the largest codebook fails fast at init.""" - with self.assertRaisesRegex(AssertionError, "largest codebook"): + with self.assertRaisesRegex(RuntimeError, "largest codebook"): self._create_model(codebook=[16, 16], train_sample_size=8) def test_predict_collects_buffer(self) -> None: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 8a0c8a176..9816341e9 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -245,7 +245,7 @@ def train_offline( if verbose: # x0 == out + x without normalization (see above). - ref = x0 if x0 is not None else out + x + ref = x0 if self.normalize_residuals else out + x logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, From eb39b5e294b3a71599f871b8d505777e5393addf Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:53:02 +0000 Subject: [PATCH 29/46] [style] SID: trim verbose comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tighten the multi-line block comments added across the recent SID work (CPU-only/single-process guards, reservoir cap, x0 invariant, gpu-kwarg drop, reservoir Phase-1, update_metric) — keep the load-bearing "why", drop the over-explanation the error messages and code already convey. Comments only. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 27 +++++++------------ tzrec/modules/sid/kmeans.py | 4 +-- .../modules/sid/residual_kmeans_quantizer.py | 14 +++++----- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index af9b5b7e7..8e181b1c8 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -72,19 +72,16 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - # CPU-only: everything (embeddings, reservoir, FAISS fit) stays on the - # host, so there are no device copies on the train path. Refuse to run - # when CUDA is visible rather than silently shuttling tensors to/from a - # GPU; launch with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host). + # CPU-only: embeddings, reservoir, and FAISS fit all stay on the host, + # so there are no device copies. Refuse to run when CUDA is visible. if torch.cuda.is_available(): raise RuntimeError( "SidRqkmeans is CPU-only, but a CUDA device is visible. " 'Run with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host).' ) - # Single-process only: the FAISS fit runs on one process over its local - # reservoir, with no cross-rank gather/broadcast. Fail fast here rather - # than after a full (wasted) training pass. + # Single-process only: the fit runs over one process's local reservoir, + # with no cross-rank gather. Fail fast before the (wasted) train pass. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: raise RuntimeError( "SidRqkmeans supports single-process training only " @@ -109,16 +106,13 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - # Bounded host reservoir for the end-of-loop FAISS fit: cap at - # ``train_sample_size`` when set (>0), else the points the FAISS fit - # subsamples to (``default_fit_sample_size``) — rather than buffer the - # whole corpus. Single-process only (see the world_size guard above), - # so no per-rank split. + # Bounded host reservoir for the end-of-loop fit: cap at + # ``train_sample_size`` (when >0) else the fit's subsample size, rather + # than buffer the whole corpus. target = self._model_config.train_sample_size cap = target if target > 0 else self._quantizer.default_fit_sample_size() - # Fail fast: FAISS needs >= K points to fit each layer, so a cap below - # the largest codebook would only assert at on_train_end — after the - # whole training pass. (The default cap is always >= max(K).) + # Fail fast: a cap below the largest codebook would only fail deep in + # train_offline, after the whole training pass. max_k = max(self._n_embed_list) if cap < max_k: raise RuntimeError( @@ -224,8 +218,7 @@ def update_metric( if "quantized" in predictions: embedding = self._extract_feature(batch) - # mse aggregates (preds, target) itself; rel_loss has no torchmetrics - # equivalent, so compute it directly (only rel is needed here). + # rel_loss has no torchmetrics equivalent, so compute it directly. self._metric_modules["mse"].update(predictions["quantized"], embedding) self._metric_modules["rel_loss"].update( relative_l1(embedding, predictions["quantized"]) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 50b2263f3..11df2b65e 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -124,8 +124,8 @@ def add(self, x: torch.Tensor) -> None: if self._buf is None: self._buf = torch.empty(cap, self._dim, dtype=torch.float32) - # Phase 1: fill empty slots first. x is already on the host (CPU-only - # model), so this is a dtype cast into the buffer, not a device copy. + # Phase 1: fill empty slots first. x is on the host, so ``.to`` is a + # dtype cast into the buffer, not a device copy. if self._n_filled < cap: take = min(x.shape[0], cap - self._n_filled) self._buf[self._n_filled : self._n_filled + take] = x[:take].to( diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 9816341e9..29ad037d1 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -198,16 +198,14 @@ def train_offline( f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" ) out = torch.zeros_like(x) - # The per-layer log reports the cumulative recon error of the original - # input x0 by the centroid sum. Without normalization the invariant - # ``out + x == x0`` holds, so x0 is reconstructed on the fly below and we - # skip the persistent (N, D) clone; with normalization x is rescaled each - # layer, breaking the invariant, so the clone is required. + # x0 (original input) feeds the per-layer recon log. Without + # normalization ``out + x == x0``, so it's rebuilt on the fly below and + # the persistent (N, D) clone is skipped; normalization rescales x and + # breaks that invariant, so clone then. x0 = x.clone() if (verbose and self.normalize_residuals) else None - # CPU-only fit: SidRqkmeans refuses to initialize when CUDA is visible, - # so the codebook is always built on CPU. Drop any stale ``gpu`` request - # from the config so a faiss-gpu build can't try to use an absent GPU. + # CPU-only fit (SidRqkmeans refuses CUDA). Drop any stale ``gpu`` kwarg + # so a faiss-gpu build can't target an absent GPU. kwargs = dict(self.faiss_kmeans_kwargs) kwargs.pop("gpu", None) if verbose: From 8bf50aa857674606ed89cff19e184e16451f0ffb Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 04:41:40 +0000 Subject: [PATCH 30/46] [refactor] SID: move init_metric/update_metric to BaseSidModel + RelativeL1 metric MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses maintainer review #3. - New custom torchmetrics metric RelativeL1 (tzrec/metrics/relative_l1.py): symmetric relative-L1 |t-p|/(max(|t|,|p|)+eps), count-weighted aggregation. A proper Metric class (like UniqueRatio), NOT torchmetrics MeanAbsolutePercentageError — MAPE's asymmetric |t-p|/|t| denominator differs from OpenOneRec's calc_loss, which this is a verbatim port of. - BaseSidModel now owns init_metric (mse + rel_loss + unique_sid_ratio) and a generic update_metric that re-extracts the target embedding and gates all eval metrics on a non-None _reconstruction() hook (so a not-yet-fitted model logs nothing). - SidRqkmeans drops its init_metric/update_metric overrides and implements _reconstruction() -> quantized (or None until the FAISS fit), inheriting the shared metric logic. Drop now-unused torchmetrics / relative_l1 imports. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/relative_l1.py | 50 ++++++++++++++++++++++++ tzrec/metrics/relative_l1_test.py | 49 +++++++++++++++++++++++ tzrec/models/sid_model.py | 65 +++++++++++++++++++++++++++---- tzrec/models/sid_rqkmeans.py | 57 +++++++-------------------- 4 files changed, 169 insertions(+), 52 deletions(-) create mode 100644 tzrec/metrics/relative_l1.py create mode 100644 tzrec/metrics/relative_l1_test.py diff --git a/tzrec/metrics/relative_l1.py b/tzrec/metrics/relative_l1.py new file mode 100644 index 000000000..72a55c28d --- /dev/null +++ b/tzrec/metrics/relative_l1.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics import Metric + + +class RelativeL1(Metric): + """Mean symmetric relative-L1 error ``|t - p| / (max(|t|, |p|) + eps)``. + + A bounded reconstruction-error metric (0 = exact, → 1 = unrelated). It is a + verbatim port of OpenOneRec's residual-K-Means ``calc_loss`` and is + deliberately **not** ``torchmetrics.MeanAbsolutePercentageError``, which uses + the asymmetric ``|t - p| / |t|`` denominator. Aggregation is element-wise + (count-weighted), so the reported value is the mean over all elements seen. + """ + + higher_is_better = False + is_differentiable = True + + def __init__(self, epsilon: float = 1e-4, **kwargs) -> None: + super().__init__(**kwargs) + self.epsilon = epsilon + self.add_state("sum_rel", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Accumulate the relative-L1 error between ``preds`` and ``target``. + + Args: + preds (Tensor): reconstruction, shape (B, D). + target (Tensor): ground-truth embedding, shape (B, D). + """ + rel = torch.abs(target - preds) / ( + torch.maximum(torch.abs(target), torch.abs(preds)) + self.epsilon + ) + self.sum_rel += rel.sum() + self.count += rel.numel() + + def compute(self) -> torch.Tensor: + """Mean relative-L1 over all elements (NaN before any update).""" + return self.sum_rel / self.count diff --git a/tzrec/metrics/relative_l1_test.py b/tzrec/metrics/relative_l1_test.py new file mode 100644 index 000000000..0f89c2ccd --- /dev/null +++ b/tzrec/metrics/relative_l1_test.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.metrics.relative_l1 import RelativeL1 + + +class RelativeL1Test(unittest.TestCase): + def test_zero_on_identity(self) -> None: + metric = RelativeL1() + x = torch.randn(8, 4) + metric.update(x, x.clone()) + self.assertAlmostEqual(metric.compute().item(), 0.0, places=6) + + def test_matches_formula(self) -> None: + metric = RelativeL1(epsilon=1e-4) + p = torch.tensor([[1.0, 0.0]]) + t = torch.tensor([[0.0, 2.0]]) + # |t-p|/(max(|t|,|p|)+eps): [1/(1+eps), 2/(2+eps)], mean of the two. + expected = (1.0 / (1.0 + 1e-4) + 2.0 / (2.0 + 1e-4)) / 2 + metric.update(p, t) + self.assertAlmostEqual(metric.compute().item(), expected, places=5) + + def test_count_weighted_across_updates(self) -> None: + """Aggregation is element-wise, not a mean of per-batch means.""" + metric = RelativeL1() + metric.update(torch.zeros(1, 4), torch.ones(1, 4)) # 4 elems, rel ~1 + metric.update(torch.ones(3, 4), torch.ones(3, 4)) # 12 elems, rel 0 + # Element-weighted: 4 nonzero over 16 elems -> ~0.25, NOT (1+0)/2 = 0.5. + per = 1.0 / (1.0 + 1e-4) # rel of a 0-vs-1 element (with epsilon) + self.assertAlmostEqual(metric.compute().item(), 4 * per / 16, places=6) + + def test_nan_before_update(self) -> None: + self.assertTrue(torch.isnan(RelativeL1().compute())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 973fcf99f..51fd9a179 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -18,6 +18,7 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature +from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel from tzrec.protos.model_pb2 import ModelConfig @@ -39,10 +40,10 @@ class BaseSidModel(BaseModel): proxy). Subclasses build their quantizer in ``__init__`` (after calling - ``super().__init__``) and implement :meth:`predict` and :meth:`loss`. - They extend :meth:`init_metric` (via ``super()``) and implement - :meth:`update_metric` to populate the registered metrics - (:meth:`update_train_metric` defaults to a no-op). + ``super().__init__``) and implement :meth:`predict`, :meth:`loss`, and + :meth:`_reconstruction` (which exposes the model's reconstruction of the + input embedding for the shared :meth:`update_metric`). + (:meth:`update_train_metric` defaults to a no-op.) Args: model_config (ModelConfig): an instance of ModelConfig. @@ -99,14 +100,62 @@ def init_loss(self) -> None: def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. - ``mse``: reconstruction error (input vs. quantized / decoded). - ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows / - batch size; a batch-size-sensitive diversity proxy, not global - coverage). Subclasses call ``super().init_metric()`` then add extras. + - ``mse``: reconstruction error (input vs. quantized / decoded). + - ``rel_loss``: symmetric relative-L1 reconstruction error + (:class:`~tzrec.metrics.relative_l1.RelativeL1`); meaningful only with + ``normalize_residuals=False`` (else the reconstruction and the input + live on different scales). + - ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows / + batch size; a batch-size-sensitive diversity proxy, not global + coverage). + + Subclasses that add extras call ``super().init_metric()`` first. """ self._metric_modules["mse"] = torchmetrics.MeanSquaredError() + self._metric_modules["rel_loss"] = RelativeL1() self._metric_modules["unique_sid_ratio"] = UniqueRatio() + def _reconstruction( + self, predictions: Dict[str, torch.Tensor] + ) -> Optional[torch.Tensor]: + """The model's reconstruction of the input embedding, or None. + + Returns the (B, D) tensor that ``mse``/``rel_loss`` compare against the + input embedding — e.g. ``predictions["quantized"]`` (RQ-KMeans) or + ``predictions["x_hat"]`` (RQ-VAE). Returns None when it is unavailable or + not yet meaningful this step (e.g. before a K-Means fit), in which case + :meth:`update_metric` skips the eval metrics entirely. + + Args: + predictions (dict): a dict of predicted result. + """ + raise NotImplementedError + + def update_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + losses: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + """Update eval metrics from a reconstruction + the re-extracted input. + + The target embedding is re-extracted from ``batch`` (it is an input, not + a model output). All three metrics are gated on a non-None + :meth:`_reconstruction` so a not-yet-fitted model does not log garbage. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + losses (dict, optional): a dict of loss. + """ + recon = self._reconstruction(predictions) + if recon is None: + return + embedding = self._extract_feature(batch) + self._metric_modules["mse"].update(recon, embedding) + self._metric_modules["rel_loss"].update(recon, embedding) + self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + def update_train_metric( self, predictions: Dict[str, torch.Tensor], diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 8e181b1c8..d8fd2d677 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -21,13 +21,12 @@ import torch import torch.distributed as dist -import torchmetrics from torch import nn from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import ReservoirSampler, relative_l1 +from tzrec.modules.sid.kmeans import ReservoirSampler from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -177,54 +176,24 @@ def loss( """ return {"dummy_loss": self._dummy_param.sum() * 0.0} - def init_metric(self) -> None: - """Register eval metrics (shared ``mse`` + ``rel_loss``). + def _reconstruction( + self, predictions: Dict[str, torch.Tensor] + ) -> Optional[torch.Tensor]: + """Centroid-sum reconstruction, or None until the codebook is fit. - Train-time metrics are intentionally absent: ``predict`` returns dummy - codes pre-fit, so the inherited no-op ``update_train_metric`` keeps the - train path empty. - """ - super().init_metric() - self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() - - def update_metric( - self, - predictions: Dict[str, torch.Tensor], - batch: Batch, - losses: Optional[Dict[str, torch.Tensor]] = None, - ) -> None: - """Update metric state. - - The reconstruction target (the input embedding) is re-extracted from - ``batch`` — it is an input, not a model output. ``quantized`` is present - only in eval (see ``predict``), so this runs eval-only. - - Note: ``mse``/``rel_loss`` compare that embedding against the centroid-sum - reconstruction. They are meaningful reconstruction metrics only with - ``normalize_residuals=False`` (the default); with normalization the - centroids live on the rescaled-residual scale, so the two quantities - don't share a scale (same caveat the train_offline per-layer log carries). + ``quantized`` is present only in eval and is all-zeros before the + end-of-train FAISS fit, so gate on the fit — the shared + :meth:`BaseSidModel.update_metric` then skips the eval metrics until the + reconstruction is meaningful. (Meaningful only with + ``normalize_residuals=False``; with normalization the centroids live on + the rescaled-residual scale, so the two quantities don't share a scale.) Args: predictions (dict): a dict of predicted result. - batch (Batch): input batch data. - losses (dict, optional): a dict of loss. """ - # In-loop eval can run before the end-of-train FAISS fit; the codebook - # is all-zeros then, so codes/reconstruction are meaningless. Skip until - # fitted so those bogus values don't pollute the eval metrics. if not self._quantizer.is_fitted: - return - - if "quantized" in predictions: - embedding = self._extract_feature(batch) - # rel_loss has no torchmetrics equivalent, so compute it directly. - self._metric_modules["mse"].update(predictions["quantized"], embedding) - self._metric_modules["rel_loss"].update( - relative_l1(embedding, predictions["quantized"]) - ) - - self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + return None + return predictions.get("quantized") @torch.no_grad() def on_train_end(self) -> bool: From e8a3609dcac60ad07e79c242812dabed4ca0121e Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 04:54:39 +0000 Subject: [PATCH 31/46] [test] SID: add sid_integration_test (train -> fit -> checkpoint -> eval) Addresses maintainer review #2 (integration test in tzrec/tests/, like the match/rank integration tests). Drives a real train_eval -> eval over a tiny prepared embedding parquet and asserts on_train_end forced a final checkpoint (the codebook exists only after the fit) and a post-fit eval_result was written. Because SidRqkmeans is CPU-only + single-process, the test forces CUDA_VISIBLE_DEVICES="" and TEST_NPROC_PER_NODE=1 (the harness otherwise defaults to GPU + nproc=2). Verified passing on the DSW remote (torchrec 1.6); the local container can't run train_eval (torchrec 1.5). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/tests/configs/sid_rqkmeans_mock.config | 54 ++++++++++ tzrec/tests/sid_integration_test.py | 105 +++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 tzrec/tests/configs/sid_rqkmeans_mock.config create mode 100644 tzrec/tests/sid_integration_test.py diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config new file mode 100644 index 000000000..d473dd705 --- /dev/null +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -0,0 +1,54 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/sid_rqkmeans_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 +} +feature_configs { + raw_feature { + feature_name: "item_emb" + expression: "item:embedding" + value_dim: 16 + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "item_emb" + group_type: DEEP + } + sid_rqkmeans { + input_dim: 16 + codebook: 16 + codebook: 16 + codebook: 16 + normalize_residuals: false + embedding_feature_name: "item_emb" + faiss_kmeans_kwargs { + fields { key: "niter" value { number_value: 5 } } + fields { key: "seed" value { number_value: 42 } } + } + } +} diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py new file mode 100644 index 000000000..711e69ec0 --- /dev/null +++ b/tzrec/tests/sid_integration_test.py @@ -0,0 +1,105 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import shutil +import tempfile +import unittest +from unittest import mock + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + +from tzrec.tests import utils +from tzrec.utils import config_util + + +class SidIntegrationTest(unittest.TestCase): + def setUp(self): + self.success = False + if not os.path.exists("./tmp"): + os.makedirs("./tmp") + self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") + os.chmod(self.test_dir, 0o755) + # SID models are CPU-only (refuse a visible CUDA device) and + # single-process (refuse world_size > 1), so hide CUDA and pin + # nproc=1 — the GPU CI harness otherwise defaults to GPU + nproc=2. + patcher = mock.patch.dict( + os.environ, {"CUDA_VISIBLE_DEVICES": "", "TEST_NPROC_PER_NODE": "1"} + ) + patcher.start() + self.addCleanup(patcher.stop) + + def tearDown(self): + if self.success and os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def _prepare_config(self, num_rows: int, dim: int) -> str: + """Write an embedding parquet + a SID config pointed at it. + + Single dense ``embedding`` column, no labels — SID reads the item + embedding straight from the batch. Returns the saved config path. + """ + data_dir = os.path.join(self.test_dir, "sid_data") + os.makedirs(data_dir, exist_ok=True) + emb = np.random.rand(num_rows, dim).astype(np.float32) + pq.write_table( + pa.table({"embedding": pa.array(list(emb))}), + os.path.join(data_dir, "part-0.parquet"), + ) + data_glob = os.path.join(data_dir, "*.parquet") + + # train_input_path set -> load_config_for_test uses it as-is (the + # FG_DAG auto-mock path is match-model-specific; SID is single-table). + config = config_util.load_pipeline_config( + "tzrec/tests/configs/sid_rqkmeans_mock.config" + ) + config.train_input_path = data_glob + config.eval_input_path = data_glob + config_path = os.path.join(self.test_dir, "sid.config") + config_util.save_message(config, config_path) + return config_path + + def test_sid_rqkmeans_train_eval(self): + """End-to-end train -> on_train_end FAISS fit -> checkpoint -> eval. + + Locks down the load-bearing path: the codebook exists only after + ``on_train_end``, which forces the final checkpoint; the post-fit eval + then reports finite reconstruction metrics. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + config_path = self._prepare_config(num_rows=2048, dim=16) + + self.success = utils.test_train_eval(config_path, self.test_dir) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + # on_train_end fitted the codebook and forced a final checkpoint. + self.assertTrue( + glob.glob(os.path.join(self.test_dir, "train", "model.ckpt-*")), + "no checkpoint persisted after on_train_end", + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train", "eval_result.txt")), + "no eval_result.txt produced", + ) + + +if __name__ == "__main__": + unittest.main() From 3dfbde07693018ff0732e640045bbe5ba653c919 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 05:17:49 +0000 Subject: [PATCH 32/46] [test] checkpoint: verify force re-save overwrites the same step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test_force_overwrite_same_step: save step 5 (centroids=0), then re-save the SAME step with different params (centroids=7) — assert a non-force re-save dedupes (no overwrite) while a force re-save overwrites, and the reloaded checkpoint holds the later params. This is the on_train_end post-fit path: a periodic save at the final step, then a forced re-save of the fitted codebook at the same step. Verified on the DSW remote (torchrec 1.6); the local container can't import checkpoint_util (torchrec 1.5). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util_test.py | 59 +++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tzrec/utils/checkpoint_util_test.py b/tzrec/utils/checkpoint_util_test.py index 8fc6130f4..6f3b38757 100644 --- a/tzrec/utils/checkpoint_util_test.py +++ b/tzrec/utils/checkpoint_util_test.py @@ -171,6 +171,52 @@ def _remap_restore_worker(test_dir, rank, world_size, port, remap_file_path): shard_w_2_m2.gather(0) +def _force_overwrite_worker(test_dir, rank, world_size, port): + """force=True re-save at an already-saved step must overwrite it. + + Saves step 5 with centroids=0, then re-saves the SAME step with different + params (centroids=7): a non-force re-save dedupes (no overwrite), a force + re-save overwrites. Reloads and asserts the persisted step-5 checkpoint + holds the later params. (This is the on_train_end post-fit checkpoint path: + a periodic save at the final step, then a forced re-save of the fitted + codebook at the same step.) + """ + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend="gloo") + + class BufModel(nn.Module): + def __init__(self, fill): + super().__init__() + self.register_buffer("centroids", torch.full((4, 3), float(fill))) + + manager = checkpoint_util.CheckpointManager(test_dir, keep_checkpoint_max=0) + + # Initial save at step 5 (pre-fit: centroids = 0). + assert manager.maybe_save(5, BufModel(0.0), final=True), "initial save" + + # Same step, different params: non-force dedupes; force overwrites. + model = BufModel(7.0) + assert not manager.maybe_save(5, model, final=True, force=False), ( + "non-force same-step save must dedupe (not overwrite)" + ) + assert manager.maybe_save(5, model, final=True, force=True), ( + "force same-step save must fire" + ) + manager.close() # drain the async prune worker + + # Reload: the persisted step-5 checkpoint must hold the LATER params (7), + # i.e. the force-save overwrote the earlier (0) one. + restored = BufModel(0.0) + checkpoint_util.restore_model(os.path.join(test_dir, "model.ckpt-5"), restored) + assert torch.allclose(restored.centroids, torch.full((4, 3), 7.0)), ( + f"overwrite failed: centroids={restored.centroids.flatten().tolist()}" + ) + dist.destroy_process_group() + + class CheckpointUtilTest(unittest.TestCase): def setUp(self): if not os.path.exists("./tmp"): @@ -327,6 +373,19 @@ def test_checkpoint_manager_discovery(self): ) self.assertEqual(manager.best_checkpoint()[1], 10) + def test_force_overwrite_same_step(self): + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + p = ctx.Process( + target=_force_overwrite_worker, args=(self.test_dir, 0, 1, port) + ) + p.start() + p.join(timeout=120) + if p.is_alive(): + p.terminate() + raise RuntimeError("force-overwrite worker timed out.") + self.assertEqual(p.exitcode, 0, "force-overwrite worker failed") + def test_dist_save_restore_model(self): port = misc_util.get_free_port() procs = [] From d67ccd1f01570c00a0390a138ab0fc1888a6abc8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:22:22 +0000 Subject: [PATCH 33/46] [review] split quantizer tests by module; clarify copy=True MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #A: residual_quantizer_test.py tested both the base ResidualQuantizer and ResidualKMeansQuantizer. Split the K-Means tests into the matching residual_kmeans_quantizer_test.py (so each module has its own test file); the base tests stay put. #B: expand the train_offline copy=True comment — the residual loop mutates x in place and the input is a view into the reservoir buffer, so it must own a fresh copy (copy=True is a single guaranteed copy vs a double-copy clone). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../modules/sid/residual_kmeans_quantizer.py | 5 +- .../sid/residual_kmeans_quantizer_test.py | 114 ++++++++++++++++++ tzrec/modules/sid/residual_quantizer_test.py | 92 -------------- 3 files changed, 118 insertions(+), 93 deletions(-) create mode 100644 tzrec/modules/sid/residual_kmeans_quantizer_test.py diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 29ad037d1..0074331da 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -187,7 +187,10 @@ def train_offline( assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # Own one contiguous float32 copy to update in place as the residual. + # The loop below mutates x in place (the residual ``x -= q``), and the + # input is a view into the caller's float32 reservoir buffer — so own a + # fresh copy (copy=True forces one even when the dtype already matches, + # avoiding the double copy a separate ``.clone()`` would add). x = inputs.detach().to(dtype=torch.float32, copy=True).contiguous() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not diff --git a/tzrec/modules/sid/residual_kmeans_quantizer_test.py b/tzrec/modules/sid/residual_kmeans_quantizer_test.py new file mode 100644 index 000000000..42647468e --- /dev/null +++ b/tzrec/modules/sid/residual_kmeans_quantizer_test.py @@ -0,0 +1,114 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) +from tzrec.modules.sid.residual_quantizer import ( + ResidualQuantizer, +) + + +class ResidualKMeansQuantizerTest(unittest.TestCase): + def test_is_subclass(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertIsInstance(rkq, ResidualQuantizer) + + def test_non_uniform_codebook_supported(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) + self.assertEqual(rkq.n_embed_list, [8, 4, 16]) + self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) + + def test_forward_returns_zeros_before_fit(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) + codes, quantized = rkq(torch.randn(5, 4)) + self.assertEqual(codes.shape, (5, 2)) + self.assertEqual(quantized.shape, (5, 4)) + + def test_forward_is_fx_traceable(self) -> None: + """Predict forward must FX-trace. + + torchrec's inference pipeline symbolically traces the model, so the + per-batch distance path must be free of data-dependent control flow. + """ + import torch.fx as fx + + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + for layer in rkq.layers: # populate centroids -> is_initialized=True + layer.load_centroids_(torch.randn(8, 4)) + traced = fx.symbolic_trace(rkq) + x = torch.randn(5, 4) + c_eager, q_eager = rkq(x) + c_traced, q_traced = traced(x) + torch.testing.assert_close(c_traced, c_eager) + torch.testing.assert_close(q_traced, q_eager) + + def test_train_offline_non_uniform(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + n_embed = [8, 4, 16] + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(512, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + # Each layer fit its own K centroids; codes stay in per-layer range. + codes, _ = rkq(torch.randn(7, 4)) + self.assertEqual(codes.shape, (7, 3)) + for i, k in enumerate(n_embed): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + + def test_train_offline_then_decode(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + + codes, _ = rkq(torch.randn(5, 4)) + self.assertTrue((codes >= 0).all() and (codes < 8).all()) + recon = rkq.decode_codes(codes) # inherited from the base + self.assertEqual(recon.shape, (5, 4)) + + def test_forward_get_codes_consistent(self) -> None: + """Forward ids and get_codes both route through the shared walk.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + x = torch.randn(9, 4) + fwd_ids, fwd_quant = rkq(x) + torch.testing.assert_close(rkq.get_codes(x), fwd_ids) + # forward's residual-sum equals the centroid-sum reconstruction. + torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index d23ef1cf5..c94cc545d 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -14,9 +14,6 @@ import torch from torch import nn -from tzrec.modules.sid.residual_kmeans_quantizer import ( - ResidualKMeansQuantizer, -) from tzrec.modules.sid.residual_quantizer import ( ResidualQuantizer, normalize_n_embed, @@ -145,94 +142,5 @@ def test_decode_codes_sum_and_dtype(self) -> None: self.assertEqual(recon16.dtype, torch.bfloat16) -class ResidualKMeansQuantizerTest(unittest.TestCase): - def test_is_subclass(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertIsInstance(rkq, ResidualQuantizer) - - def test_non_uniform_codebook_supported(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) - self.assertEqual(rkq.n_embed_list, [8, 4, 16]) - self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) - - def test_forward_returns_zeros_before_fit(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) - codes, quantized = rkq(torch.randn(5, 4)) - self.assertEqual(codes.shape, (5, 2)) - self.assertEqual(quantized.shape, (5, 4)) - - def test_forward_is_fx_traceable(self) -> None: - """Predict forward must FX-trace. - - torchrec's inference pipeline symbolically traces the model, so the - per-batch distance path must be free of data-dependent control flow. - """ - import torch.fx as fx - - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - for layer in rkq.layers: # populate centroids -> is_initialized=True - layer.load_centroids_(torch.randn(8, 4)) - traced = fx.symbolic_trace(rkq) - x = torch.randn(5, 4) - c_eager, q_eager = rkq(x) - c_traced, q_traced = traced(x) - torch.testing.assert_close(c_traced, c_eager) - torch.testing.assert_close(q_traced, q_eager) - - def test_train_offline_non_uniform(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - n_embed = [8, 4, 16] - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(512, 4), verbose=False) - self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) - # Each layer fit its own K centroids; codes stay in per-layer range. - codes, _ = rkq(torch.randn(7, 4)) - self.assertEqual(codes.shape, (7, 3)) - for i, k in enumerate(n_embed): - self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) - - def test_train_offline_then_decode(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(256, 4), verbose=False) - self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) - - codes, _ = rkq(torch.randn(5, 4)) - self.assertTrue((codes >= 0).all() and (codes < 8).all()) - recon = rkq.decode_codes(codes) # inherited from the base - self.assertEqual(recon.shape, (5, 4)) - - def test_forward_get_codes_consistent(self) -> None: - """Forward ids and get_codes both route through the shared walk.""" - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(256, 4), verbose=False) - x = torch.randn(9, 4) - fwd_ids, fwd_quant = rkq(x) - torch.testing.assert_close(rkq.get_codes(x), fwd_ids) - # forward's residual-sum equals the centroid-sum reconstruction. - torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) - - if __name__ == "__main__": unittest.main() From 6a736c582fda1a212b26d5410d39bf18469d58d2 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:27:25 +0000 Subject: [PATCH 34/46] [refactor] drop CheckpointManager force param; SID uses no periodic ckpts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review: remove the `force` knob from maybe_save instead of threading it through main.py. SID models run with save_checkpoints_steps and save_checkpoints_epochs = 0, so no periodic save lands on the final step and the tail final=True save is never deduped away — `force` isn't needed. - checkpoint_util.maybe_save: drop `force`; dedupe is `step == _last_ckpt_step`. - main.py: call on_train_end() for its side effect (the fit); tail save is maybe_save(..., final=True). - BaseModel/SidRqkmeans.on_train_end now return None (the bool existed only to feed `force`). - Remove the now-obsolete checkpoint force-overwrite test; update on_train_end return assertions. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 11 ++---- tzrec/models/model.py | 12 ++---- tzrec/models/sid_rqkmeans.py | 16 +++----- tzrec/models/sid_rqkmeans_test.py | 10 ++--- tzrec/utils/checkpoint_util.py | 14 ++----- tzrec/utils/checkpoint_util_test.py | 59 ----------------------------- 6 files changed, 22 insertions(+), 100 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index fe4dd1079..e0b43b329 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -516,9 +516,10 @@ def run_eval(step: int, epoch: int) -> None: lr.step() # One-shot end-of-loop hook (default no-op; e.g. SidRqkmeans fits its FAISS - # codebook here). Returns True if it mutated persistable state, forcing the - # tail save below even when the last in-loop checkpoint hit the final step. - is_ckpt_after_train = _model.on_train_end() + # codebook here). SID models run with periodic checkpointing disabled + # (save_checkpoints_steps/epochs = 0), so the tail final=True save below is + # the only checkpoint and persists whatever on_train_end produced. + _model.on_train_end() _log_train( i_step, @@ -533,9 +534,6 @@ def run_eval(step: int, epoch: int) -> None: summary_writer.close() if train_config.is_profiling: prof.stop() - # ``force`` re-fires the save past maybe_save's per-step dedupe when - # on_train_end mutated persistable state (e.g. SidRqkmeans fit its codebook) - # after the last in-loop save landed on the final step. if ckpt_manager.maybe_save( i_step, model, @@ -543,7 +541,6 @@ def run_eval(step: int, epoch: int) -> None: dataloader_state, data_timestamp=data_timestamp, final=True, - force=is_ckpt_after_train, ): run_eval(i_step, i_epoch) ckpt_manager.close() diff --git a/tzrec/models/model.py b/tzrec/models/model.py index c6b2b952c..26ec63dbc 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,18 +150,14 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results - def on_train_end(self) -> bool: + def on_train_end(self) -> None: """Hook fired once after the train_eval loop exits. Default no-op; override for one-shot end-of-loop work (e.g. - :class:`SidRqkmeans` fits its FAISS codebook here). - - Returns: - is_ckpt_after_train (bool): whether the hook mutated state that must - be persisted, so the loop forces a final checkpoint even if one was - already saved at the last step. Default ``False``. + :class:`SidRqkmeans` fits its FAISS codebook here). The tail + ``final=True`` checkpoint persists whatever it produced. """ - return False + return def sparse_parameters( self, diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index d8fd2d677..b2188e8c3 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -196,21 +196,18 @@ def _reconstruction( return predictions.get("quantized") @torch.no_grad() - def on_train_end(self) -> bool: + def on_train_end(self) -> None: """Fit the FAISS codebook once, after the train_eval loop exits. Overrides :meth:`BaseModel.on_train_end` (called unconditionally by ``tzrec.main``). Single-process only (enforced by the world_size guard in ``__init__``): the fit runs on one process over its local reservoir, - with no cross-rank gather/broadcast. + with no cross-rank gather/broadcast. The tail ``final=True`` checkpoint + then persists the fitted codebook (SID runs with periodic checkpointing + disabled, so that save is never deduped away). An empty reservoir only happens for a pathologically tiny corpus; the - fit is then skipped and ``False`` returned. - - Returns: - is_ckpt_after_train (bool): ``True`` if the codebook was fitted - (centroids changed → force a final checkpoint), ``False`` if the - fit was skipped (empty reservoir). + fit is then skipped. """ local = self._reservoir.sample() self._reservoir.reset() @@ -220,11 +217,10 @@ def on_train_end(self) -> bool: "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " "fit. Did the train_eval loop run?" ) - return False + return logger.info( "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." % (local.shape[0], local.shape[1]) ) self._quantizer.train_offline(local, verbose=True) - return True diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 782991eac..c41cb1cf1 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -145,8 +145,8 @@ def test_on_train_end_runs_faiss(self) -> None: model.predict(_make_batch(B, input_dim)) self.assertGreater(model._reservoir.n_seen, 0) - # Trigger one-shot FAISS fit; a real fit must request a tail checkpoint - self.assertTrue(model.on_train_end()) + # Trigger one-shot FAISS fit. + model.on_train_end() # Reservoir should be released after the fit self.assertEqual(model._reservoir.n_seen, 0) @@ -213,7 +213,7 @@ def test_normalize_residuals_end_to_end(self) -> None: model.train() for _ in range(8): model.predict(_make_batch(B, input_dim)) - self.assertTrue(model.on_train_end()) + model.on_train_end() for layer in model._quantizer.layers: self.assertTrue(layer.is_initialized) @@ -289,8 +289,8 @@ def test_update_metric_skipped_before_fit(self) -> None: def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() - # No fit happened, so no tail checkpoint is requested. - self.assertFalse(model.on_train_end()) # should not raise + model.on_train_end() # warns and returns without fitting; must not raise + self.assertFalse(model._quantizer.is_fitted) def test_init_raises_under_ddp(self) -> None: """SidRqkmeans is single-process only: world_size>1 fails fast in init.""" diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index 612cf023d..c601fd432 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -399,7 +399,6 @@ def maybe_save( epoch: Optional[int] = None, data_timestamp: float = -1.0, final: bool = False, - force: bool = False, ) -> bool: """Save a checkpoint if a configured trigger fires; return whether it did. @@ -418,15 +417,8 @@ def maybe_save( epoch: current epoch; enables the epoch trigger when not None. data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. - final: request a save unconditionally (still subject to the dedupe), - e.g. at train end. This sets ``want``; it does not bypass the - per-step dedupe — that is what ``force`` is for. - force: bypass the per-step dedupe so a wanted save fires even if this - step was already saved — e.g. when end-of-train work mutated the - model state at the already-saved final step (see ``on_train_end``). - Orthogonal to ``final``: ``force`` only relaxes the dedupe and has - no effect on its own (it still needs ``want``, which ``final`` or a - cadence trigger supplies). + final: request a save unconditionally (still subject to the per-step + dedupe), e.g. at train end. Returns: True if a checkpoint was saved. @@ -452,7 +444,7 @@ def maybe_save( ): want = True - if not want or (step == self._last_ckpt_step and not force): + if not want or step == self._last_ckpt_step: return False self._last_ckpt_step = step diff --git a/tzrec/utils/checkpoint_util_test.py b/tzrec/utils/checkpoint_util_test.py index 6f3b38757..8fc6130f4 100644 --- a/tzrec/utils/checkpoint_util_test.py +++ b/tzrec/utils/checkpoint_util_test.py @@ -171,52 +171,6 @@ def _remap_restore_worker(test_dir, rank, world_size, port, remap_file_path): shard_w_2_m2.gather(0) -def _force_overwrite_worker(test_dir, rank, world_size, port): - """force=True re-save at an already-saved step must overwrite it. - - Saves step 5 with centroids=0, then re-saves the SAME step with different - params (centroids=7): a non-force re-save dedupes (no overwrite), a force - re-save overwrites. Reloads and asserts the persisted step-5 checkpoint - holds the later params. (This is the on_train_end post-fit checkpoint path: - a periodic save at the final step, then a forced re-save of the fitted - codebook at the same step.) - """ - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - dist.init_process_group(backend="gloo") - - class BufModel(nn.Module): - def __init__(self, fill): - super().__init__() - self.register_buffer("centroids", torch.full((4, 3), float(fill))) - - manager = checkpoint_util.CheckpointManager(test_dir, keep_checkpoint_max=0) - - # Initial save at step 5 (pre-fit: centroids = 0). - assert manager.maybe_save(5, BufModel(0.0), final=True), "initial save" - - # Same step, different params: non-force dedupes; force overwrites. - model = BufModel(7.0) - assert not manager.maybe_save(5, model, final=True, force=False), ( - "non-force same-step save must dedupe (not overwrite)" - ) - assert manager.maybe_save(5, model, final=True, force=True), ( - "force same-step save must fire" - ) - manager.close() # drain the async prune worker - - # Reload: the persisted step-5 checkpoint must hold the LATER params (7), - # i.e. the force-save overwrote the earlier (0) one. - restored = BufModel(0.0) - checkpoint_util.restore_model(os.path.join(test_dir, "model.ckpt-5"), restored) - assert torch.allclose(restored.centroids, torch.full((4, 3), 7.0)), ( - f"overwrite failed: centroids={restored.centroids.flatten().tolist()}" - ) - dist.destroy_process_group() - - class CheckpointUtilTest(unittest.TestCase): def setUp(self): if not os.path.exists("./tmp"): @@ -373,19 +327,6 @@ def test_checkpoint_manager_discovery(self): ) self.assertEqual(manager.best_checkpoint()[1], 10) - def test_force_overwrite_same_step(self): - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - p = ctx.Process( - target=_force_overwrite_worker, args=(self.test_dir, 0, 1, port) - ) - p.start() - p.join(timeout=120) - if p.is_alive(): - p.terminate() - raise RuntimeError("force-overwrite worker timed out.") - self.assertEqual(p.exitcode, 0, "force-overwrite worker failed") - def test_dist_save_restore_model(self): port = misc_util.get_free_port() procs = [] From 5bc89d4e193269e2772f0ca891fbe133002ae9f1 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:33:12 +0000 Subject: [PATCH 35/46] [refactor] typed FaissKmeansConfig proto; drop Struct + _coerce_proto_numbers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review: replace the loosely-typed google.protobuf.Struct faiss_kmeans_kwargs with a strictly-typed FaissKmeansConfig message (niter, nredo, seed, max/min_points_per_centroid, spherical, verbose). Struct numbers arrive as floats and _coerce_proto_numbers heuristically int-ified them — a typed message is type-safe and removes that hack. gpu is omitted (CPU-only). SidRqkmeans builds the faiss kwargs from the typed message's set fields (ListFields), so unset fields fall back to faiss's own defaults. Updated the mock config and the test builder to the typed form. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 24 ++++---------------- tzrec/models/sid_rqkmeans_test.py | 7 +++--- tzrec/protos/models/sid_model.proto | 20 ++++++++++++---- tzrec/tests/configs/sid_rqkmeans_mock.config | 4 ++-- 4 files changed, 25 insertions(+), 30 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index b2188e8c3..17b94e1ff 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -31,22 +31,9 @@ ResidualKMeansQuantizer, ) from tzrec.protos.model_pb2 import ModelConfig -from tzrec.utils import config_util from tzrec.utils.logging_util import logger -def _coerce_proto_numbers(d: Dict) -> Dict: - """Coerce whole-valued floats back to int. - - ``Struct.number_value`` is always float, but faiss.Kmeans kwargs - (``niter``, ``seed``, ...) need ``int``. - """ - return { - k: int(v) if isinstance(v, float) and v.is_integer() else v - for k, v in d.items() - } - - class SidRqkmeans(BaseSidModel): """SID generation model using residual K-Means (FAISS-only). @@ -90,12 +77,11 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message - # config_to_kwargs yields Struct numbers as floats; coerce back to int. - self._faiss_kwargs = ( - _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) - if cfg.HasField("faiss_kmeans_kwargs") - else {} - ) + # Typed faiss kwargs: only the explicitly-set fields are forwarded, so + # unset ones fall back to faiss's own defaults (no float->int coercion). + self._faiss_kwargs = { + f.name: v for f, v in cfg.faiss_kmeans_kwargs.ListFields() + } self._quantizer = ResidualKMeansQuantizer( embed_dim=self._input_dim, diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index c41cb1cf1..db7fc6143 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -63,11 +63,10 @@ def _create_model( SID models read the item-embedding dense feature directly from the batch and do not consume feature_groups, so none is set. """ - from google.protobuf.struct_pb2 import Struct - n_embed_list = codebook if codebook is not None else [16] * n_layers - faiss_kwargs = Struct() - faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + faiss_kwargs = sid_model_pb2.FaissKmeansConfig( + niter=niter, verbose=False, seed=1234 + ) cfg = sid_model_pb2.SidRqkmeans( input_dim=input_dim, codebook=n_embed_list, diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index fdd41a22c..f6f07da2f 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -1,7 +1,19 @@ syntax = "proto2"; package tzrec.protos; -import "google/protobuf/struct.proto"; +// Strictly-typed subset of faiss.Kmeans(D, K, **kwargs) knobs. Unset fields +// fall back to faiss's own defaults (so it is safe to leave partially set). +// ``gpu`` is intentionally omitted — the fit is CPU-only (SidRqkmeans refuses +// a visible CUDA device). +message FaissKmeansConfig { + optional uint32 niter = 1; + optional uint32 nredo = 2; + optional uint32 seed = 3; + optional uint32 max_points_per_centroid = 4; + optional uint32 min_points_per_centroid = 5; + optional bool spherical = 6; + optional bool verbose = 7; +} message SidRqkmeans { // Input embedding dimension (K-Means runs directly on raw embeddings, @@ -15,10 +27,8 @@ message SidRqkmeans { repeated uint32 codebook = 3; // L2-normalize residuals before each layer. optional bool normalize_residuals = 4 [default = false]; - // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a - // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, - // spherical: false, seed: 1234}. - optional google.protobuf.Struct faiss_kmeans_kwargs = 5; + // Strictly-typed extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs). + optional FaissKmeansConfig faiss_kmeans_kwargs = 5; // Target number of embeddings to reservoir-sample for the FAISS fit // (global, across all ranks). Bounds host memory regardless of corpus // size. 0 (the default) auto-derives it as max(K) * max_points_per_centroid diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config index d473dd705..0aad49cfb 100644 --- a/tzrec/tests/configs/sid_rqkmeans_mock.config +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -47,8 +47,8 @@ model_config { normalize_residuals: false embedding_feature_name: "item_emb" faiss_kmeans_kwargs { - fields { key: "niter" value { number_value: 5 } } - fields { key: "seed" value { number_value: 42 } } + niter: 5 + seed: 42 } } } From feeb4af1eccd79cd3a4461e365f4ee3fc7708ec1 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:39:42 +0000 Subject: [PATCH 36/46] [refactor] add QuantizeLayer base; KMeansLayer -> KMeansQuantizeLayer Per maintainer review: introduce a QuantizeLayer ABC (quantize / lookup / get_codebook_embeddings) so the K-Means and (PR3) RQ-VAE vector-quantize layers share one interface and the residual quantizer drives either uniformly. - new types.py: QuantizeOutput(embeddings, ids) NamedTuple (matches the PR3 feat/sid_abstract definition for a clean merge). - kmeans.py: add QuantizeLayer(nn.Module) ABC; rename KMeansLayer -> KMeansQuantizeLayer(QuantizeLayer); replace predict() with quantize()->QuantizeOutput (incl. the uninitialized-zeros path) + lookup() + get_codebook_embeddings(). - ResidualKMeansQuantizer: _quantize_layer/_lookup_code/get_codebook_embeddings delegate to the layer's quantize/lookup/get_codebook_embeddings. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 77 ++++++++++++++----- tzrec/modules/sid/kmeans_test.py | 36 +++++---- .../modules/sid/residual_kmeans_quantizer.py | 30 +++----- tzrec/modules/sid/types.py | 28 +++++++ 4 files changed, 120 insertions(+), 51 deletions(-) create mode 100644 tzrec/modules/sid/types.py diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 11df2b65e..5701230b7 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -14,20 +14,23 @@ This module is the single home for torch-native K-Means code used by SID models: -* :class:`KMeansLayer` — per-layer centroid container used by - :class:`ResidualKMeansQuantizer`. Centroids are injected - by the FAISS backend via ``load_centroids_``; the only forward path - is ``predict``. +* :class:`QuantizeLayer` — the per-layer quantizer interface + (``quantize`` / ``lookup`` / ``get_codebook_embeddings``) shared with the + RQ-VAE backend's vector-quantize layer. +* :class:`KMeansQuantizeLayer` — the K-Means implementation: a centroid + container populated by the FAISS backend via ``load_centroids_``. * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. """ +from abc import abstractmethod from typing import Optional, Tuple import torch from torch import nn +from tzrec.modules.sid.types import QuantizeOutput from tzrec.utils.logging_util import logger @@ -162,11 +165,35 @@ def reset(self) -> None: self._n_seen = 0 -class KMeansLayer(nn.Module): +class QuantizeLayer(nn.Module): + """One quantize layer: assign inputs to a codebook and look codes up. + + Shared interface for the K-Means backend (:class:`KMeansQuantizeLayer`) + and the RQ-VAE backend's vector-quantize layer, so the residual quantizer + can drive either uniformly. + """ + + @abstractmethod + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" + raise NotImplementedError + + @abstractmethod + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + """Gather codebook embeddings for ``ids``.""" + raise NotImplementedError + + @abstractmethod + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the full codebook, shape (n_clusters, D).""" + raise NotImplementedError + + +class KMeansQuantizeLayer(QuantizeLayer): """Single layer of a residual K-Means stack. Centroids are populated externally by ``load_centroids_`` (the FAISS - backend in :class:`ResidualKMeansQuantizer`); ``predict`` is the only + backend in :class:`ResidualKMeansQuantizer`); ``quantize`` is the only forward path. Args: @@ -198,11 +225,7 @@ def is_initialized(self) -> bool: return self._initialized def mark_initialized_(self) -> None: - """Flag centroids populated, syncing buffer + cached mirror. - - For callers that fill ``centroids`` in place (e.g. the DDP broadcast - in :meth:`SidRqkmeans.on_train_end`) rather than via ``load_centroids_``. - """ + """Flag centroids populated, syncing buffer + cached mirror.""" self._is_initialized.fill_(True) self._initialized = True @@ -247,23 +270,39 @@ def _load_from_state_dict( self._initialized = bool(self._is_initialized.item()) if self._initialized and self.centroids.abs().sum() == 0: error_msgs.append( - f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " - "are all zero — checkpoint was likely taken mid-FAISS-fit. " - "Re-run on_train_end to produce a valid checkpoint." + f"KMeansQuantizeLayer at '{prefix}': _is_initialized=True but " + "centroids are all zero — checkpoint was likely taken " + "mid-FAISS-fit. Re-run on_train_end to produce a valid checkpoint." ) @torch.no_grad() - def predict(self, batch: torch.Tensor) -> torch.Tensor: - """Assign points to nearest centroid. + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign points to the nearest centroid and gather them. Uses ``torch.cdist`` (L2); argmin is invariant to the monotonic sqrt, so assignments match squared-L2 except at exact equidistant ties (measure zero for real embeddings), where either centroid is valid. + Before the FAISS fit (uninitialized) this returns all-zero codes + + embeddings so the residual walk stays a no-op and the model is callable. + ``temperature`` is unused (no soft assignment). Args: - batch (Tensor): data points, shape (B, D). + x (Tensor): data points, shape (B, D). + temperature (float): unused. Returns: - Tensor: cluster indices, shape (B,). + QuantizeOutput: ``ids`` (B,) and ``embeddings`` (B, D). """ - return torch.cdist(batch, self.centroids).argmin(dim=-1) + if not self.is_initialized: + ids = torch.zeros(x.shape[0], dtype=torch.long, device=x.device) + return QuantizeOutput(embeddings=torch.zeros_like(x), ids=ids) + ids = torch.cdist(x, self.centroids).argmin(dim=-1) + return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + """Gather centroids for ``ids``, shape (..., D).""" + return self.centroids[ids] + + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the centroid table, shape (n_clusters, n_features).""" + return self.centroids diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index d6b06a7f1..66a8de1a9 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -14,7 +14,7 @@ import torch from tzrec.modules.sid.kmeans import ( - KMeansLayer, + KMeansQuantizeLayer, ReservoirSampler, recon_diagnostics, ) @@ -30,42 +30,52 @@ def test_recon_diagnostics_zero_on_identity(self) -> None: self.assertAlmostEqual(rel.item(), 0.0, places=6) -class KMeansLayerTest(unittest.TestCase): - """Tests for the single KMeansLayer.""" +class KMeansQuantizeLayerTest(unittest.TestCase): + """Tests for the single KMeansQuantizeLayer.""" def test_uninitialized_by_default(self) -> None: - layer = KMeansLayer(n_clusters=4, n_features=3) + layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) self.assertFalse(layer.is_initialized) self.assertEqual(layer.centroids.abs().sum().item(), 0.0) - def test_load_centroids_and_predict(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + def test_load_centroids_and_quantize(self) -> None: + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) layer.load_centroids_(centroids) self.assertTrue(layer.is_initialized) batch = torch.tensor([[0.1, 0.0], [9.0, 11.0]]) - codes = layer.predict(batch) - torch.testing.assert_close(codes, torch.tensor([0, 1])) + out = layer.quantize(batch) + torch.testing.assert_close(out.ids, torch.tensor([0, 1])) + # embeddings are the gathered centroids; lookup matches. + torch.testing.assert_close(out.embeddings, centroids[out.ids]) + torch.testing.assert_close(layer.lookup(out.ids), out.embeddings) + + def test_quantize_uninitialized_returns_zeros(self) -> None: + layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) + out = layer.quantize(torch.randn(5, 3)) + self.assertEqual(out.ids.shape, (5,)) + self.assertEqual(int(out.ids.abs().sum()), 0) + torch.testing.assert_close(out.embeddings, torch.zeros(5, 3)) def test_load_centroids_shape_mismatch_raises(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) with self.assertRaises(AssertionError): layer.load_centroids_(torch.zeros(3, 2)) def test_mid_fit_checkpoint_rejected(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) sd = layer.state_dict() # Simulate a mid-fit checkpoint: flag True but centroids still zero. sd["_is_initialized"] = torch.tensor(True) - fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): fresh.load_state_dict(sd) def test_post_fit_checkpoint_round_trips(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) - fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) fresh.load_state_dict(layer.state_dict()) self.assertTrue(fresh.is_initialized) torch.testing.assert_close(fresh.centroids, layer.centroids) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 0074331da..2b9f522c6 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans import KMeansLayer, recon_diagnostics +from tzrec.modules.sid.kmeans import KMeansQuantizeLayer, recon_diagnostics from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -35,8 +35,7 @@ class ResidualKMeansQuantizer(ResidualQuantizer): residual_0 = input for each layer i: (optionally) residual_i = L2_normalize(residual_i) - code_i = layer_i.predict(residual_i) - quantized_i = layer_i.centroids[code_i] + code_i, quantized_i = layer_i.quantize(residual_i) residual_{i+1} = residual_i - quantized_i output = sum of all quantized_i @@ -72,7 +71,7 @@ def __init__( self.layers = nn.ModuleList( [ - KMeansLayer( + KMeansQuantizeLayer( n_clusters=self.n_embed_list[i], n_features=embed_dim, ) @@ -86,29 +85,22 @@ def _quantize_layer( residual: torch.Tensor, temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Nearest-centroid assignment for one layer. + """Nearest-centroid assignment for one layer (delegates to the layer). Uninitialized layers (before ``train_offline``) return zeros, so the - residual walk is a no-op and the model stays callable. ``temperature`` - is unused (no soft assignment). + residual walk is a no-op and the model stays callable. Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): unused. + temperature (float): unused (no soft assignment). Returns: codes (Tensor): cluster indices, shape (B,). quantized (Tensor): selected centroids, shape (B, D). """ - layer = self.layers[layer_idx] - if not layer.is_initialized: - codes = torch.zeros( - residual.shape[0], dtype=torch.long, device=residual.device - ) - return codes, torch.zeros_like(residual) - codes = layer.predict(residual) - return codes, layer.centroids[codes] + out = self.layers[layer_idx].quantize(residual, temperature) + return out.ids, out.embeddings def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Assign codes per layer and sum the centroids. @@ -146,11 +138,11 @@ def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: Returns: Tensor: centroids, shape (n_embed, embed_dim). """ - return self.layers[layer_idx].centroids + return self.layers[layer_idx].get_codebook_embeddings() def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: """Look up codebook vectors via the layer's centroid table.""" - return self.layers[layer_idx].centroids[code_idx] + return self.layers[layer_idx].lookup(code_idx) def default_fit_sample_size(self) -> int: """Points the FAISS fit subsamples to: max(K) * max_points_per_centroid. @@ -195,7 +187,7 @@ def train_offline( N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the - # all-zero poison guard in KMeansLayer would not catch. + # all-zero poison guard in KMeansQuantizeLayer would not catch. max_k = max(self.n_embed_list) assert N >= max_k, ( f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" diff --git a/tzrec/modules/sid/types.py b/tzrec/modules/sid/types.py new file mode 100644 index 000000000..2f0cf3c60 --- /dev/null +++ b/tzrec/modules/sid/types.py @@ -0,0 +1,28 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data types for SID generation: output tuples shared across quantizers.""" + +from typing import NamedTuple + +import torch + + +class QuantizeOutput(NamedTuple): + """One quantize layer's output. + + Attributes: + embeddings (Tensor): quantized embeddings, shape (B, D). + ids (Tensor): codebook indices, shape (B,). + """ + + embeddings: torch.Tensor + ids: torch.Tensor From a5d43b2e4640a0cabf0793caf65021da2b40c295 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:13:57 +0000 Subject: [PATCH 37/46] [refactor] unify reconstruction key to x_hat; drop _reconstruction hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit predict now exposes the reconstruction under predictions["x_hat"] (the same key RQ-VAE uses) instead of "quantized", and only in eval once the codebook is fit. With the key and readiness decided by the producer, BaseSidModel.update_metric is fully concrete — it gates on `"x_hat" in predictions` and needs no per-model _reconstruction hook (removed). RQ-VAE reuses update_metric as-is (it already emits x_hat); SidRqkmeans just gates the x_hat exposure on _quantizer.is_fitted. Update the predict-contract test to assert {codes, x_hat}. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 36 ++++++++++--------------------- tzrec/models/sid_rqkmeans.py | 28 ++++++------------------ tzrec/models/sid_rqkmeans_test.py | 10 ++++----- 3 files changed, 23 insertions(+), 51 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 51fd9a179..c0a0e9e56 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -40,9 +40,9 @@ class BaseSidModel(BaseModel): proxy). Subclasses build their quantizer in ``__init__`` (after calling - ``super().__init__``) and implement :meth:`predict`, :meth:`loss`, and - :meth:`_reconstruction` (which exposes the model's reconstruction of the - input embedding for the shared :meth:`update_metric`). + ``super().__init__``) and implement :meth:`predict` and :meth:`loss`. + :meth:`predict` exposes the reconstruction under ``predictions["x_hat"]`` + (only when meaningful) so the shared :meth:`update_metric` can score it. (:meth:`update_train_metric` defaults to a no-op.) Args: @@ -115,42 +115,28 @@ def init_metric(self) -> None: self._metric_modules["rel_loss"] = RelativeL1() self._metric_modules["unique_sid_ratio"] = UniqueRatio() - def _reconstruction( - self, predictions: Dict[str, torch.Tensor] - ) -> Optional[torch.Tensor]: - """The model's reconstruction of the input embedding, or None. - - Returns the (B, D) tensor that ``mse``/``rel_loss`` compare against the - input embedding — e.g. ``predictions["quantized"]`` (RQ-KMeans) or - ``predictions["x_hat"]`` (RQ-VAE). Returns None when it is unavailable or - not yet meaningful this step (e.g. before a K-Means fit), in which case - :meth:`update_metric` skips the eval metrics entirely. - - Args: - predictions (dict): a dict of predicted result. - """ - raise NotImplementedError - def update_metric( self, predictions: Dict[str, torch.Tensor], batch: Batch, losses: Optional[Dict[str, torch.Tensor]] = None, ) -> None: - """Update eval metrics from a reconstruction + the re-extracted input. + """Update eval metrics from the reconstruction + the re-extracted input. - The target embedding is re-extracted from ``batch`` (it is an input, not - a model output). All three metrics are gated on a non-None - :meth:`_reconstruction` so a not-yet-fitted model does not log garbage. + ``predictions["x_hat"]`` is the model's reconstruction of the input + embedding (the centroid sum for RQ-KMeans, the decoder output for + RQ-VAE). Subclasses expose it only when it is meaningful, so a + not-yet-fitted model omits it and this logs nothing. The target + embedding is re-extracted from ``batch`` (it is an input, not an output). Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ - recon = self._reconstruction(predictions) - if recon is None: + if "x_hat" not in predictions: return + recon = predictions["x_hat"] embedding = self._extract_feature(batch) self._metric_modules["mse"].update(recon, embedding) self._metric_modules["rel_loss"].update(recon, embedding) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 17b94e1ff..88568c051 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -140,8 +140,13 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - if self.is_eval: - predictions["quantized"] = quantized + # Expose the centroid-sum reconstruction (``x_hat``, the scoring target + # for update_metric) only in eval AND once the codebook is fit — before + # on_train_end it is all-zeros, so omitting it makes update_metric skip. + # (Meaningful only with normalize_residuals=False; with normalization the + # centroids live on the rescaled-residual scale, off the input's scale.) + if self.is_eval and self._quantizer.is_fitted: + predictions["x_hat"] = quantized return predictions @@ -162,25 +167,6 @@ def loss( """ return {"dummy_loss": self._dummy_param.sum() * 0.0} - def _reconstruction( - self, predictions: Dict[str, torch.Tensor] - ) -> Optional[torch.Tensor]: - """Centroid-sum reconstruction, or None until the codebook is fit. - - ``quantized`` is present only in eval and is all-zeros before the - end-of-train FAISS fit, so gate on the fit — the shared - :meth:`BaseSidModel.update_metric` then skips the eval metrics until the - reconstruction is meaningful. (Meaningful only with - ``normalize_residuals=False``; with normalization the centroids live on - the rescaled-residual scale, so the two quantities don't share a scale.) - - Args: - predictions (dict): a dict of predicted result. - """ - if not self._quantizer.is_fitted: - return None - return predictions.get("quantized") - @torch.no_grad() def on_train_end(self) -> None: """Fit the FAISS codebook once, after the train_eval loop exits. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index db7fc6143..ecc96db86 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -223,7 +223,7 @@ def test_normalize_residuals_end_to_end(self) -> None: self.assertTrue((codes >= 0).all() and (codes < 16).all()) def test_eval_and_inference_predict_contract(self) -> None: - """Eval exposes codes + quantized only; inference is codes-only.""" + """Eval (post-fit) exposes codes + x_hat; inference is codes-only.""" try: import faiss # noqa: F401 except ImportError: @@ -236,12 +236,12 @@ def test_eval_and_inference_predict_contract(self) -> None: model.predict(_make_batch(B, input_dim)) model.on_train_end() - # Eval mode: the centroid-sum reconstruction is exposed for - # update_metric; the input embedding is NOT threaded through - # predictions (it is re-extracted from the batch in update_metric). + # Eval mode (fitted): the reconstruction is exposed as ``x_hat`` for + # update_metric; the input embedding is re-extracted from the batch + # there, not threaded through predictions. model.eval() eval_preds = model.predict(_make_batch(B, input_dim)) - self.assertEqual(set(eval_preds.keys()), {"codes", "quantized"}) + self.assertEqual(set(eval_preds.keys()), {"codes", "x_hat"}) # Inference (serving) mode: codes-only contract. model.set_is_inference(True) From c4c361a96253ef97e9de9fedc7f04d60e3fd6ac6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:16:47 +0000 Subject: [PATCH 38/46] [style] SID: trim redundant comments Tighten a few over-long/obvious comments in predict (x_hat exposure, train branch), drop the `cfg = self._model_config` inline note, and shorten the host-tensor assert comment in train_offline. Comments only; load-bearing "why" notes kept. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 13 +++++-------- tzrec/modules/sid/residual_kmeans_quantizer.py | 5 ++--- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 88568c051..28c8a16bc 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -75,7 +75,7 @@ def __init__( "Launch with --nproc-per-node=1." ) - cfg = self._model_config # SidRqkmeans proto message + cfg = self._model_config # Typed faiss kwargs: only the explicitly-set fields are forwarded, so # unset ones fall back to faiss's own defaults (no float->int coercion). @@ -123,8 +123,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """ embedding = self._extract_feature(batch) - # Training: just reservoir-sample for the end-of-loop FAISS fit and - # return dummy codes — the codebook does not exist yet. + # Training: reservoir-sample only; codes are dummy until the fit. if self.is_train: self._reservoir.add(embedding) B = embedding.shape[0] @@ -140,11 +139,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - # Expose the centroid-sum reconstruction (``x_hat``, the scoring target - # for update_metric) only in eval AND once the codebook is fit — before - # on_train_end it is all-zeros, so omitting it makes update_metric skip. - # (Meaningful only with normalize_residuals=False; with normalization the - # centroids live on the rescaled-residual scale, off the input's scale.) + # Expose the centroid-sum reconstruction (``x_hat``) for update_metric + # only once fitted — pre-fit it is all-zeros, so omitting it skips the + # eval metrics. (Meaningful only with normalize_residuals=False.) if self.is_eval and self._quantizer.is_fitted: predictions["x_hat"] = quantized diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 2b9f522c6..3ed9d7087 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -172,9 +172,8 @@ def train_offline( owned float32 tensor; not mutated. verbose (bool): print per-layer reconstruction loss. Default: True. """ - # CPU-only: SidRqkmeans refuses to init when CUDA is visible, but this - # quantizer is a standalone module — assert the host-tensor contract it - # relies on so misuse fails here, not deep inside faiss. + # Assert the host-tensor contract locally (this is a standalone module) + # so misuse fails here, not deep inside faiss. assert not inputs.is_cuda, "train_offline is CPU-only; got a CUDA tensor" assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" From db7f2beb8e726ba116b89fca8eff55c8e51c482b Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:21:21 +0000 Subject: [PATCH 39/46] [refactor] QuantizeLayer: make lookup concrete in the base MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit lookup(ids) is backend-independent given the codebook, so define it once in QuantizeLayer as get_codebook_embeddings()[ids] and drop KMeansQuantizeLayer's override. get_codebook_embeddings stays abstract — the codebook lives in a backend-specific attribute (centroids buffer vs nn.Embedding), so only it (and quantize) need a per-subclass implementation. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 5701230b7..9dc15fb55 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -178,14 +178,18 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" raise NotImplementedError - @abstractmethod def lookup(self, ids: torch.Tensor) -> torch.Tensor: - """Gather codebook embeddings for ``ids``.""" - raise NotImplementedError + """Gather codebook embeddings for ``ids`` (indexes the codebook).""" + return self.get_codebook_embeddings()[ids] @abstractmethod def get_codebook_embeddings(self) -> torch.Tensor: - """Return the full codebook, shape (n_clusters, D).""" + """Return the full codebook, shape (n_clusters, D). + + The codebook lives in a backend-specific attribute (a ``centroids`` + buffer for K-Means, an ``nn.Embedding`` for RQ-VAE), so this stays + abstract; :meth:`lookup` is then concrete in terms of it. + """ raise NotImplementedError @@ -299,10 +303,6 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: ids = torch.cdist(x, self.centroids).argmin(dim=-1) return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) - def lookup(self, ids: torch.Tensor) -> torch.Tensor: - """Gather centroids for ``ids``, shape (..., D).""" - return self.centroids[ids] - def get_codebook_embeddings(self) -> torch.Tensor: """Return the centroid table, shape (n_clusters, n_features).""" return self.centroids From ed12cff7cee38739dac6118d5ed51256d30d235a Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:24:37 +0000 Subject: [PATCH 40/46] [refactor] QuantizeLayer: own n_clusters/n_features in the base Every quantize layer has a codebook of n_clusters x n_features, so store that shape in QuantizeLayer.__init__; KMeansQuantizeLayer passes them via super() and builds its centroid buffer from them. (PR3's vector-quantize layer maps its n_embed/embed_dim onto the same base params.) Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 9dc15fb55..4f6450388 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -170,9 +170,19 @@ class QuantizeLayer(nn.Module): Shared interface for the K-Means backend (:class:`KMeansQuantizeLayer`) and the RQ-VAE backend's vector-quantize layer, so the residual quantizer - can drive either uniformly. + can drive either uniformly. Owns the codebook shape; subclasses build the + backend-specific codebook (a buffer, an ``nn.Embedding``, …) from it. + + Args: + n_clusters (int): number of codebook entries. + n_features (int): feature dimension. """ + def __init__(self, n_clusters: int, n_features: int) -> None: + super().__init__() + self.n_clusters = n_clusters + self.n_features = n_features + @abstractmethod def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" @@ -210,10 +220,7 @@ def __init__( n_clusters: int, n_features: int, ) -> None: - super().__init__() - self.n_clusters = n_clusters - self.n_features = n_features - + super().__init__(n_clusters, n_features) self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) # Persistent so a post-fit checkpoint round-trips; a mid-fit poison # (True flag + zero centroids) is caught in _load_from_state_dict. From d2697eb7a140de8c3b7a048e4b87ba3a362ed24d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 08:00:03 +0000 Subject: [PATCH 41/46] [refactor] SID: extract QuantizeLayer ABC; rename kmeans -> kmeans_quantize - Add tzrec/modules/sid/quantize_layer.py: QuantizeLayer ABC shared by the K-Means backend and (PR3) the RQ-VAE vector-quantize layer. Owns the codebook shape (n_embed, embed_dim); concrete lookup() in terms of an abstract get_codebook_embeddings(). Adds quantize_layer_test.py. - Rename kmeans.py -> kmeans_quantize.py (parallel to vector_quantize.py) and CentroidQuantizeLayer -> KMeansQuantizeLayer; KMeansQuantizeLayer now subclasses QuantizeLayer. - ResidualKMeansQuantizer.train_offline now CONSUMES its input (may mutate in place); the copy decision is the caller's. on_train_end hands over the reservoir buffer by ownership (no copy) since nothing reads it afterward. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 +- .../sid/{kmeans.py => kmeans_quantize.py} | 71 +++------------- ...kmeans_test.py => kmeans_quantize_test.py} | 18 ++--- tzrec/modules/sid/quantize_layer.py | 58 +++++++++++++ tzrec/modules/sid/quantize_layer_test.py | 81 +++++++++++++++++++ .../modules/sid/residual_kmeans_quantizer.py | 22 ++--- 6 files changed, 177 insertions(+), 78 deletions(-) rename tzrec/modules/sid/{kmeans.py => kmeans_quantize.py} (79%) rename tzrec/modules/sid/{kmeans_test.py => kmeans_quantize_test.py} (91%) create mode 100644 tzrec/modules/sid/quantize_layer.py create mode 100644 tzrec/modules/sid/quantize_layer_test.py diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 28c8a16bc..07ce132f9 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -26,7 +26,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import ReservoirSampler +from tzrec.modules.sid.kmeans_quantize import ReservoirSampler from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -178,6 +178,9 @@ def on_train_end(self) -> None: An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped. """ + # train_offline consumes its input; we hand it the reservoir buffer + # directly (no copy) since nothing reads it after this — reset() drops + # the sampler's reference and ``local`` is the last user of the storage. local = self._reservoir.sample() self._reservoir.reset() diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans_quantize.py similarity index 79% rename from tzrec/modules/sid/kmeans.py rename to tzrec/modules/sid/kmeans_quantize.py index 4f6450388..872783893 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -14,22 +14,18 @@ This module is the single home for torch-native K-Means code used by SID models: -* :class:`QuantizeLayer` — the per-layer quantizer interface - (``quantize`` / ``lookup`` / ``get_codebook_embeddings``) shared with the - RQ-VAE backend's vector-quantize layer. -* :class:`KMeansQuantizeLayer` — the K-Means implementation: a centroid - container populated by the FAISS backend via ``load_centroids_``. +* :class:`KMeansQuantizeLayer` — the K-Means :class:`QuantizeLayer`: a + centroid container populated by the FAISS backend via ``load_centroids_``. * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. """ -from abc import abstractmethod from typing import Optional, Tuple import torch -from torch import nn +from tzrec.modules.sid.quantize_layer import QuantizeLayer from tzrec.modules.sid.types import QuantizeOutput from tzrec.utils.logging_util import logger @@ -165,63 +161,22 @@ def reset(self) -> None: self._n_seen = 0 -class QuantizeLayer(nn.Module): - """One quantize layer: assign inputs to a codebook and look codes up. - - Shared interface for the K-Means backend (:class:`KMeansQuantizeLayer`) - and the RQ-VAE backend's vector-quantize layer, so the residual quantizer - can drive either uniformly. Owns the codebook shape; subclasses build the - backend-specific codebook (a buffer, an ``nn.Embedding``, …) from it. - - Args: - n_clusters (int): number of codebook entries. - n_features (int): feature dimension. - """ - - def __init__(self, n_clusters: int, n_features: int) -> None: - super().__init__() - self.n_clusters = n_clusters - self.n_features = n_features - - @abstractmethod - def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: - """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" - raise NotImplementedError - - def lookup(self, ids: torch.Tensor) -> torch.Tensor: - """Gather codebook embeddings for ``ids`` (indexes the codebook).""" - return self.get_codebook_embeddings()[ids] - - @abstractmethod - def get_codebook_embeddings(self) -> torch.Tensor: - """Return the full codebook, shape (n_clusters, D). - - The codebook lives in a backend-specific attribute (a ``centroids`` - buffer for K-Means, an ``nn.Embedding`` for RQ-VAE), so this stays - abstract; :meth:`lookup` is then concrete in terms of it. - """ - raise NotImplementedError - - class KMeansQuantizeLayer(QuantizeLayer): - """Single layer of a residual K-Means stack. + """K-Means :class:`QuantizeLayer`: a centroid codebook + nearest assignment. Centroids are populated externally by ``load_centroids_`` (the FAISS backend in :class:`ResidualKMeansQuantizer`); ``quantize`` is the only - forward path. + forward path. (The k-means *fit* lives in the quantizer; this layer just + holds the resulting centroids.) Args: - n_clusters (int): number of clusters (codebook size). - n_features (int): feature dimension. + n_embed (int): number of centroids (codebook size). + embed_dim (int): feature dimension. """ - def __init__( - self, - n_clusters: int, - n_features: int, - ) -> None: - super().__init__(n_clusters, n_features) - self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) + def __init__(self, n_embed: int, embed_dim: int) -> None: + super().__init__(n_embed, embed_dim) + self.register_buffer("centroids", torch.zeros(n_embed, embed_dim)) # Persistent so a post-fit checkpoint round-trips; a mid-fit poison # (True flag + zero centroids) is caught in _load_from_state_dict. self.register_buffer("_is_initialized", torch.tensor(False)) @@ -246,7 +201,7 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: Args: centroids (Tensor): externally trained centroids, - shape (n_clusters, n_features). + shape (n_embed, embed_dim). """ assert centroids.shape == self.centroids.shape, ( f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " @@ -311,5 +266,5 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) def get_codebook_embeddings(self) -> torch.Tensor: - """Return the centroid table, shape (n_clusters, n_features).""" + """Return the centroid table, shape (n_embed, embed_dim).""" return self.centroids diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_quantize_test.py similarity index 91% rename from tzrec/modules/sid/kmeans_test.py rename to tzrec/modules/sid/kmeans_quantize_test.py index 66a8de1a9..9c2df2611 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -13,7 +13,7 @@ import torch -from tzrec.modules.sid.kmeans import ( +from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, recon_diagnostics, @@ -34,12 +34,12 @@ class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" def test_uninitialized_by_default(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) + layer = KMeansQuantizeLayer(n_embed=4, embed_dim=3) self.assertFalse(layer.is_initialized) self.assertEqual(layer.centroids.abs().sum().item(), 0.0) def test_load_centroids_and_quantize(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) layer.load_centroids_(centroids) self.assertTrue(layer.is_initialized) @@ -52,30 +52,30 @@ def test_load_centroids_and_quantize(self) -> None: torch.testing.assert_close(layer.lookup(out.ids), out.embeddings) def test_quantize_uninitialized_returns_zeros(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) + layer = KMeansQuantizeLayer(n_embed=4, embed_dim=3) out = layer.quantize(torch.randn(5, 3)) self.assertEqual(out.ids.shape, (5,)) self.assertEqual(int(out.ids.abs().sum()), 0) torch.testing.assert_close(out.embeddings, torch.zeros(5, 3)) def test_load_centroids_shape_mismatch_raises(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) with self.assertRaises(AssertionError): layer.load_centroids_(torch.zeros(3, 2)) def test_mid_fit_checkpoint_rejected(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) sd = layer.state_dict() # Simulate a mid-fit checkpoint: flag True but centroids still zero. sd["_is_initialized"] = torch.tensor(True) - fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_embed=2, embed_dim=2) with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): fresh.load_state_dict(sd) def test_post_fit_checkpoint_round_trips(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) - fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_embed=2, embed_dim=2) fresh.load_state_dict(layer.state_dict()) self.assertTrue(fresh.is_initialized) torch.testing.assert_close(fresh.centroids, layer.centroids) diff --git a/tzrec/modules/sid/quantize_layer.py b/tzrec/modules/sid/quantize_layer.py new file mode 100644 index 000000000..e7f344fda --- /dev/null +++ b/tzrec/modules/sid/quantize_layer.py @@ -0,0 +1,58 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QuantizeLayer: the per-layer quantizer interface shared by SID backends.""" + +from abc import abstractmethod + +import torch +from torch import nn + +from tzrec.modules.sid.types import QuantizeOutput + + +class QuantizeLayer(nn.Module): + """One quantize layer: assign inputs to a codebook and look codes up. + + Shared interface for the K-Means backend + (:class:`~tzrec.modules.sid.kmeans_quantize.KMeansQuantizeLayer`) and the RQ-VAE + backend's vector-quantize layer, so the residual quantizer can drive either + uniformly. Owns the codebook shape; subclasses build the backend-specific + codebook (a buffer, an ``nn.Embedding``, …) from it. + + Args: + n_embed (int): number of codebook entries. + embed_dim (int): feature dimension. + """ + + def __init__(self, n_embed: int, embed_dim: int) -> None: + super().__init__() + self.n_embed = n_embed + self.embed_dim = embed_dim + + @abstractmethod + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" + raise NotImplementedError + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + """Gather codebook embeddings for ``ids`` (indexes the codebook).""" + return self.get_codebook_embeddings()[ids] + + @abstractmethod + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the full codebook, shape (n_embed, embed_dim). + + The codebook lives in a backend-specific attribute (a ``centroids`` + buffer for K-Means, an ``nn.Embedding`` for RQ-VAE), so this stays + abstract; :meth:`lookup` is then concrete in terms of it. + """ + raise NotImplementedError diff --git a/tzrec/modules/sid/quantize_layer_test.py b/tzrec/modules/sid/quantize_layer_test.py new file mode 100644 index 000000000..28eb4849b --- /dev/null +++ b/tzrec/modules/sid/quantize_layer_test.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.quantize_layer import QuantizeLayer +from tzrec.modules.sid.types import QuantizeOutput + + +class _StubQuantizeLayer(QuantizeLayer): + """Minimal concrete subclass: a fixed codebook, nearest-row assignment. + + Exercises the base class's concrete ``__init__`` / ``lookup`` without + pulling in a backend (FAISS / nn.Embedding). + """ + + def __init__(self, n_embed: int, embed_dim: int) -> None: + super().__init__(n_embed, embed_dim) + # A deterministic codebook so lookup/quantize are checkable by hand. + self._codebook = torch.arange(n_embed * embed_dim, dtype=torch.float32).reshape( + n_embed, embed_dim + ) + + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + dist = torch.cdist(x, self._codebook) + ids = dist.argmin(dim=-1) + return QuantizeOutput(embeddings=self.lookup(ids), ids=ids) + + def get_codebook_embeddings(self) -> torch.Tensor: + return self._codebook + + +class QuantizeLayerTest(unittest.TestCase): + """Tests for the shared QuantizeLayer base class.""" + + def test_init_stores_codebook_shape(self) -> None: + layer = _StubQuantizeLayer(n_embed=4, embed_dim=3) + self.assertEqual(layer.n_embed, 4) + self.assertEqual(layer.embed_dim, 3) + + def test_lookup_gathers_codebook_rows(self) -> None: + layer = _StubQuantizeLayer(n_embed=4, embed_dim=3) + ids = torch.tensor([0, 2, 3, 1]) + out = layer.lookup(ids) + torch.testing.assert_close(out, layer.get_codebook_embeddings()[ids]) + self.assertEqual(out.shape, (4, 3)) + + def test_quantize_assigns_exact_codebook_rows(self) -> None: + # Feeding codebook rows back in must recover their own indices. + layer = _StubQuantizeLayer(n_embed=4, embed_dim=3) + x = layer.get_codebook_embeddings().clone() + out = layer.quantize(x) + torch.testing.assert_close(out.ids, torch.arange(4)) + torch.testing.assert_close(out.embeddings, x) + + def test_abstract_methods_unoverridden_raise(self) -> None: + # The abstract methods are documented to raise if a subclass forgets + # to implement them; QuantizeLayer relies on nn.Module (no ABCMeta), + # so this guards that the bodies still fail loudly rather than no-op. + class _Incomplete(QuantizeLayer): + pass + + layer = _Incomplete(n_embed=2, embed_dim=2) + with self.assertRaises(NotImplementedError): + layer.get_codebook_embeddings() + with self.assertRaises(NotImplementedError): + layer.quantize(torch.zeros(1, 2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 3ed9d7087..ddd6154f2 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans import KMeansQuantizeLayer, recon_diagnostics +from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer, recon_diagnostics from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -72,8 +72,8 @@ def __init__( self.layers = nn.ModuleList( [ KMeansQuantizeLayer( - n_clusters=self.n_embed_list[i], - n_features=embed_dim, + n_embed=self.n_embed_list[i], + embed_dim=embed_dim, ) for i in range(n_layers) ] @@ -168,8 +168,9 @@ def train_offline( ``SEARCH_CHUNK``-sized chunks to cap peak memory. Args: - inputs (Tensor): embedding matrix (N, D) on CPU. Copied once to an - owned float32 tensor; not mutated. + inputs (Tensor): embedding matrix (N, D) on CPU. CONSUMED: the + residual pass may mutate it in place, so the caller must not + rely on its contents afterward (copy first if it needs them). verbose (bool): print per-layer reconstruction loss. Default: True. """ # Assert the host-tensor contract locally (this is a standalone module) @@ -178,11 +179,12 @@ def train_offline( assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # The loop below mutates x in place (the residual ``x -= q``), and the - # input is a view into the caller's float32 reservoir buffer — so own a - # fresh copy (copy=True forces one even when the dtype already matches, - # avoiding the double copy a separate ``.clone()`` would add). - x = inputs.detach().to(dtype=torch.float32, copy=True).contiguous() + # train_offline CONSUMES its input: the residual loop below mutates x + # in place (``x -= q``). We only normalize dtype/layout for faiss — a + # no-op view when the input is already float32 + contiguous, so the + # mutation lands in the caller's buffer (intended; the caller copies + # first if it still needs the data). + x = inputs.detach().to(dtype=torch.float32).contiguous() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the From 097e9eb0fd0f352004247be4237cbb092e30c561 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 08:08:08 +0000 Subject: [PATCH 42/46] [docs] checkpoint_util: tighten maybe_save `final` param docstring Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index c601fd432..ede4ef4de 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -417,8 +417,7 @@ def maybe_save( epoch: current epoch; enables the epoch trigger when not None. data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. - final: request a save unconditionally (still subject to the per-step - dedupe), e.g. at train end. + final: force a save (still subject to the dedupe), e.g. at train end. Returns: True if a checkpoint was saved. From a9a889c2ddff65bd907f535db57c2c949b0e8802 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 11:44:09 +0000 Subject: [PATCH 43/46] [fix] SID: review fixes + fail-fast validation; fix integration test CPU pin - sid_integration_test: force CPU with CUDA_VISIBLE_DEVICES="-1" not "" (empty is treated inconsistently across CUDA runtimes; the GPU CI runner didn't hide devices, tripping the CPU-only guard in the train_eval child). - BaseSidModel: validate codebook entries >=1 and input_dim >=1 at construction; guard feature width in _extract_feature (a (B,1) tensor would otherwise broadcast into a degenerate rank-1 codebook). assert -> raise. - residual_kmeans_quantizer / kmeans_quantize: assert -> raise for the data-corruption guards (N>=max_k, load_centroids_ shape, CPU/shape contract) so they survive python -O. - RelativeL1: float64 sum / long count to avoid float32 rounding past 2**24. - kmeans_quantize: drop the duplicate relative_l1/recon_diagnostics helpers; RelativeL1 (tzrec/metrics) is the single home of the formula. Per-layer offline-fit log now reports MSE only. - sid_rqkmeans: TODO documenting the periodic-checkpointing-disabled contract (codebook can be dropped by save dedupe otherwise). - sid_model.proto: drop stale "(global, across all ranks)" wording. - mock config: set save_checkpoints_steps/epochs = 0 (the documented convention). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/relative_l1.py | 15 ++++- tzrec/models/sid_model.py | 23 +++++++- tzrec/models/sid_rqkmeans.py | 13 ++++- tzrec/modules/sid/kmeans_quantize.py | 56 +++---------------- tzrec/modules/sid/kmeans_quantize_test.py | 13 +---- .../modules/sid/residual_kmeans_quantizer.py | 34 +++++------ tzrec/protos/models/sid_model.proto | 4 +- tzrec/tests/configs/sid_rqkmeans_mock.config | 2 + tzrec/tests/sid_integration_test.py | 5 +- 9 files changed, 79 insertions(+), 86 deletions(-) diff --git a/tzrec/metrics/relative_l1.py b/tzrec/metrics/relative_l1.py index 72a55c28d..5aa00f4e4 100644 --- a/tzrec/metrics/relative_l1.py +++ b/tzrec/metrics/relative_l1.py @@ -29,8 +29,17 @@ class RelativeL1(Metric): def __init__(self, epsilon: float = 1e-4, **kwargs) -> None: super().__init__(**kwargs) self.epsilon = epsilon - self.add_state("sum_rel", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum") + # float64 sum / long count: element-wise aggregation crosses 2**24 at + # only ~32K rows of a 512-dim embedding, past which float32 increments + # round (mirrors the float64 care in ``ReservoirSampler.add``). + self.add_state( + "sum_rel", + default=torch.tensor(0.0, dtype=torch.float64), + dist_reduce_fx="sum", + ) + self.add_state( + "count", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum" + ) def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: """Accumulate the relative-L1 error between ``preds`` and ``target``. @@ -42,7 +51,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: rel = torch.abs(target - preds) / ( torch.maximum(torch.abs(target), torch.abs(preds)) + self.epsilon ) - self.sum_rel += rel.sum() + self.sum_rel += rel.sum().double() self.count += rel.numel() def compute(self) -> torch.Tensor: diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index c0a0e9e56..579ab6702 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -70,8 +70,17 @@ def __init__( self._input_dim = cfg.input_dim self._normalize_residuals = cfg.normalize_residuals - assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" + if not cfg.codebook: + raise ValueError("codebook must be set, e.g. [256, 256, 256]") self._n_embed_list = list(cfg.codebook) + # Fail fast: a zero codebook entry / input_dim==0 only errors opaquely + # deep inside faiss, after the whole training pass. + if any(k < 1 for k in self._n_embed_list): + raise ValueError( + f"every codebook entry must be >= 1, got {self._n_embed_list}" + ) + if self._input_dim < 1: + raise ValueError(f"input_dim must be >= 1, got {self._input_dim}") self._n_layers = len(self._n_embed_list) def _extract_feature( @@ -87,7 +96,17 @@ def _extract_feature( if feature_name is None: feature_name = self._embedding_feature_name kt = batch.dense_features[BASE_DATA_GROUP] - return kt[feature_name] + embedding = kt[feature_name] + # Guard a misconfigured feature width: a (B, 1) tensor (raw_feature + # missing value_dim, which defaults to 1) would otherwise broadcast + # silently downstream and fit a degenerate rank-1 codebook. + if embedding.dim() != 2 or embedding.shape[1] != self._input_dim: + raise ValueError( + f"feature '{feature_name}' has shape {tuple(embedding.shape)}, " + f"expected (B, {self._input_dim}); check that its value_dim " + "matches the SID input_dim." + ) + return embedding def init_loss(self) -> None: """Initialize loss modules. diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 07ce132f9..317fa4779 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -58,8 +58,10 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - # CPU-only: embeddings, reservoir, and FAISS fit all stay on the host, - # so there are no device copies. Refuse to run when CUDA is visible. + # CPU-only: training and inference both run on the host (embeddings, + # reservoir, FAISS fit, and post-fit assignment), so there are no device + # copies. v1 deliberately restricts the whole model to CPU; refuse to + # run when CUDA is visible. if torch.cuda.is_available(): raise RuntimeError( "SidRqkmeans is CPU-only, but a CUDA device is visible. " @@ -175,6 +177,13 @@ def on_train_end(self) -> None: then persists the fitted codebook (SID runs with periodic checkpointing disabled, so that save is never deduped away). + TODO: the "periodic checkpointing disabled" requirement is currently a + convention, not enforced. If a user sets save_checkpoints_steps/epochs + > 0 and the last in-loop save lands on the final step, the tail save is + deduped away and the fitted codebook is silently dropped. Harden the + save logic (enforce the contract / bypass the dedupe for this save) in a + future update. + An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped. """ diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index 872783893..6eb5b940a 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -21,7 +21,7 @@ fills during training to feed the one-shot FAISS fit. """ -from typing import Optional, Tuple +from typing import Optional import torch @@ -30,49 +30,6 @@ from tzrec.utils.logging_util import logger -def recon_diagnostics( - x: torch.Tensor, - out: torch.Tensor, - epsilon: float = 1e-4, -) -> Tuple[torch.Tensor, torch.Tensor]: - """MSE + relative-L1 reconstruction diagnostics. - - Shared by :meth:`SidRqkmeans.update_metric` and - :meth:`ResidualKMeansQuantizer.train_offline`'s per-layer log. - - Args: - x: ground-truth embedding, shape (B, D). - out: quantized reconstruction, shape (B, D). - epsilon: numerical stabilizer for the relative-L1 denominator. - - Returns: - mse: scalar ``((out - x) ** 2).mean()``. - rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. - """ - return ((out - x) ** 2).mean(), relative_l1(x, out, epsilon) - - -def relative_l1( - x: torch.Tensor, - out: torch.Tensor, - epsilon: float = 1e-4, -) -> torch.Tensor: - """Relative-L1 error ``mean(|x - out| / (max(|x|, |out|) + eps))``. - - Symmetric relative error in [0, 1] (verbatim port of OpenOneRec's - ``calc_loss``). Used standalone by :meth:`SidRqkmeans.update_metric` (which - needs only ``rel``, not the MSE :meth:`recon_diagnostics` also computes). - - Args: - x: ground-truth embedding, shape (B, D). - out: quantized reconstruction, shape (B, D). - epsilon: numerical stabilizer for the denominator. - """ - return ( - torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) - ).mean() - - class ReservoirSampler: """Bounded uniform sample of a stream (Vitter Algorithm R). @@ -203,10 +160,13 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: centroids (Tensor): externally trained centroids, shape (n_embed, embed_dim). """ - assert centroids.shape == self.centroids.shape, ( - f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " - f"got {tuple(centroids.shape)}" - ) + # raise (not assert): under ``python -O`` a dropped assert would let a + # (1, D) tensor broadcast-replicate into all K centroid rows silently. + if centroids.shape != self.centroids.shape: + raise RuntimeError( + f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " + f"got {tuple(centroids.shape)}" + ) self.centroids.copy_( centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) ) diff --git a/tzrec/modules/sid/kmeans_quantize_test.py b/tzrec/modules/sid/kmeans_quantize_test.py index 9c2df2611..2f2883562 100644 --- a/tzrec/modules/sid/kmeans_quantize_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -16,20 +16,9 @@ from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, - recon_diagnostics, ) -class KmeansHelpersTest(unittest.TestCase): - """Tests for the K-Means helper functions.""" - - def test_recon_diagnostics_zero_on_identity(self) -> None: - x = torch.randn(8, 4) - mse, rel = recon_diagnostics(x, x.clone()) - self.assertAlmostEqual(mse.item(), 0.0, places=6) - self.assertAlmostEqual(rel.item(), 0.0, places=6) - - class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" @@ -60,7 +49,7 @@ def test_quantize_uninitialized_returns_zeros(self) -> None: def test_load_centroids_shape_mismatch_raises(self) -> None: layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) - with self.assertRaises(AssertionError): + with self.assertRaises(RuntimeError): layer.load_centroids_(torch.zeros(3, 2)) def test_mid_fit_checkpoint_rejected(self) -> None: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index ddd6154f2..4056ca861 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer, recon_diagnostics +from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -173,12 +173,15 @@ def train_offline( rely on its contents afterward (copy first if it needs them). verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Assert the host-tensor contract locally (this is a standalone module) - # so misuse fails here, not deep inside faiss. - assert not inputs.is_cuda, "train_offline is CPU-only; got a CUDA tensor" - assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( - f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" - ) + # Check the host-tensor contract locally (this is a standalone module) + # so misuse fails here, not deep inside faiss. raise (not assert): these + # guard silent data corruption and must survive ``python -O``. + if inputs.is_cuda: + raise RuntimeError("train_offline is CPU-only; got a CUDA tensor") + if inputs.dim() != 2 or inputs.shape[1] != self.embed_dim: + raise RuntimeError( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) # train_offline CONSUMES its input: the residual loop below mutates x # in place (``x -= q``). We only normalize dtype/layout for faiss — a # no-op view when the input is already float32 + contiguous, so the @@ -190,9 +193,11 @@ def train_offline( # errors) when N < K and returns a degenerate codebook, which the # all-zero poison guard in KMeansQuantizeLayer would not catch. max_k = max(self.n_embed_list) - assert N >= max_k, ( - f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" - ) + if N < max_k: + raise RuntimeError( + f"need >= {max_k} points to fit the codebook (largest layer K), " + f"got N={N}" + ) out = torch.zeros_like(x) # x0 (original input) feeds the per-layer recon log. Without # normalization ``out + x == x0``, so it's rebuilt on the fly below and @@ -254,9 +259,6 @@ def train_offline( ) @staticmethod - def _calc_loss( - x: torch.Tensor, out: torch.Tensor, epsilon: float = 1e-4 - ) -> Dict[str, float]: - """Reconstruction loss diagnostics (MSE + relative L1).""" - loss, rel_loss = recon_diagnostics(x, out, epsilon=epsilon) - return {"loss": float(loss.item()), "rel_loss": float(rel_loss.item())} + def _calc_loss(x: torch.Tensor, out: torch.Tensor) -> Dict[str, float]: + """Per-layer reconstruction MSE for the offline-fit log.""" + return {"loss": float(((out - x) ** 2).mean().item())} diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index f6f07da2f..e51462efa 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -29,8 +29,8 @@ message SidRqkmeans { optional bool normalize_residuals = 4 [default = false]; // Strictly-typed extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs). optional FaissKmeansConfig faiss_kmeans_kwargs = 5; - // Target number of embeddings to reservoir-sample for the FAISS fit - // (global, across all ranks). Bounds host memory regardless of corpus + // Target number of embeddings to reservoir-sample for the FAISS fit. + // Bounds host memory regardless of corpus // size. 0 (the default) auto-derives it as max(K) * max_points_per_centroid // (the largest per-layer codebook, for non-uniform codebooks) — exactly // what FAISS subsamples to internally (default 256), so no training points diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config index 0aad49cfb..0e6dec907 100644 --- a/tzrec/tests/configs/sid_rqkmeans_mock.config +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -17,6 +17,8 @@ train_config { } } num_epochs: 1 + save_checkpoints_steps: 0 + save_checkpoints_epochs: 0 } eval_config { } diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 711e69ec0..94c5216e7 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -34,8 +34,11 @@ def setUp(self): # SID models are CPU-only (refuse a visible CUDA device) and # single-process (refuse world_size > 1), so hide CUDA and pin # nproc=1 — the GPU CI harness otherwise defaults to GPU + nproc=2. + # Use "-1", not "" — an empty CUDA_VISIBLE_DEVICES is treated + # inconsistently across CUDA runtimes (the GPU CI runner does not hide + # the devices), which trips the CPU-only guard in the train_eval child. patcher = mock.patch.dict( - os.environ, {"CUDA_VISIBLE_DEVICES": "", "TEST_NPROC_PER_NODE": "1"} + os.environ, {"CUDA_VISIBLE_DEVICES": "-1", "TEST_NPROC_PER_NODE": "1"} ) patcher.start() self.addCleanup(patcher.stop) From 3b41df9b2921ae280fd0ce92375ef680bc1e61f8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 12:22:18 +0000 Subject: [PATCH 44/46] [review] SID: doc fixes, negative tests, stronger integration assertions - CPU-only guard message recommends CUDA_VISIBLE_DEVICES="-1" (not "", which this PR found unreliable on the GPU CI runner). - Correct the train_offline comment: faiss throws (not warns) for N < K. - Add negative tests for the fail-fast guards: empty/zero codebook, input_dim<1, feature-width mismatch, and train_offline too-few-points / wrong-dim. - sid_integration_test: assert the post-fit eval reports finite mse/rel_loss/ unique_sid_ratio (rel_loss < 1.0, unique_sid_ratio > 0) so a degenerate / unfitted codebook can't keep the test green. - Trim verbose comments. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/relative_l1.py | 5 ++- tzrec/models/sid_model.py | 5 ++- tzrec/models/sid_rqkmeans.py | 23 ++++++------- tzrec/models/sid_rqkmeans_test.py | 21 ++++++++++++ .../modules/sid/residual_kmeans_quantizer.py | 19 +++++------ .../sid/residual_kmeans_quantizer_test.py | 12 +++++++ tzrec/tests/sid_integration_test.py | 32 +++++++++++++------ 7 files changed, 76 insertions(+), 41 deletions(-) diff --git a/tzrec/metrics/relative_l1.py b/tzrec/metrics/relative_l1.py index 5aa00f4e4..685307608 100644 --- a/tzrec/metrics/relative_l1.py +++ b/tzrec/metrics/relative_l1.py @@ -29,9 +29,8 @@ class RelativeL1(Metric): def __init__(self, epsilon: float = 1e-4, **kwargs) -> None: super().__init__(**kwargs) self.epsilon = epsilon - # float64 sum / long count: element-wise aggregation crosses 2**24 at - # only ~32K rows of a 512-dim embedding, past which float32 increments - # round (mirrors the float64 care in ``ReservoirSampler.add``). + # float64 sum / long count: float32 loses integer precision past 2**24 + # (~32K rows of a 512-dim embedding) under element-wise aggregation. self.add_state( "sum_rel", default=torch.tensor(0.0, dtype=torch.float64), diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 579ab6702..d3023090c 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -97,9 +97,8 @@ def _extract_feature( feature_name = self._embedding_feature_name kt = batch.dense_features[BASE_DATA_GROUP] embedding = kt[feature_name] - # Guard a misconfigured feature width: a (B, 1) tensor (raw_feature - # missing value_dim, which defaults to 1) would otherwise broadcast - # silently downstream and fit a degenerate rank-1 codebook. + # Guard a misconfigured width: a (B, 1) tensor (raw_feature missing + # value_dim) would broadcast silently into a degenerate rank-1 codebook. if embedding.dim() != 2 or embedding.shape[1] != self._input_dim: raise ValueError( f"feature '{feature_name}' has shape {tuple(embedding.shape)}, " diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 317fa4779..59b05af41 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -58,14 +58,12 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - # CPU-only: training and inference both run on the host (embeddings, - # reservoir, FAISS fit, and post-fit assignment), so there are no device - # copies. v1 deliberately restricts the whole model to CPU; refuse to - # run when CUDA is visible. + # CPU-only: v1 restricts the whole model (train + inference) to the + # host. Refuse to run when CUDA is visible. if torch.cuda.is_available(): raise RuntimeError( "SidRqkmeans is CPU-only, but a CUDA device is visible. " - 'Run with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host).' + 'Run with CUDA_VISIBLE_DEVICES="-1" (or on a CPU-only host).' ) # Single-process only: the fit runs over one process's local reservoir, @@ -177,19 +175,16 @@ def on_train_end(self) -> None: then persists the fitted codebook (SID runs with periodic checkpointing disabled, so that save is never deduped away). - TODO: the "periodic checkpointing disabled" requirement is currently a - convention, not enforced. If a user sets save_checkpoints_steps/epochs - > 0 and the last in-loop save lands on the final step, the tail save is - deduped away and the fitted codebook is silently dropped. Harden the - save logic (enforce the contract / bypass the dedupe for this save) in a - future update. + TODO: "periodic checkpointing disabled" is a convention, not enforced. + With save_checkpoints_steps/epochs > 0, a final-step in-loop save can + dedupe the tail save away, silently dropping the fitted codebook. Harden + this (enforce the contract / bypass the dedupe) in a future update. An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped. """ - # train_offline consumes its input; we hand it the reservoir buffer - # directly (no copy) since nothing reads it after this — reset() drops - # the sampler's reference and ``local`` is the last user of the storage. + # train_offline consumes its input; hand it the reservoir buffer + # directly (no copy) — nothing reads it after this. local = self._reservoir.sample() self._reservoir.reset() diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index ecc96db86..273d123a2 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -109,6 +109,27 @@ def test_init_raises_on_too_small_train_sample_size(self) -> None: with self.assertRaisesRegex(RuntimeError, "largest codebook"): self._create_model(codebook=[16, 16], train_sample_size=8) + def test_init_raises_on_empty_codebook(self) -> None: + """An empty codebook fails fast at construction.""" + with self.assertRaisesRegex(ValueError, "codebook must be set"): + self._create_model(codebook=[]) + + def test_init_raises_on_zero_codebook_entry(self) -> None: + """A zero codebook entry fails fast at construction.""" + with self.assertRaisesRegex(ValueError, "codebook entry must be >= 1"): + self._create_model(codebook=[16, 0]) + + def test_init_raises_on_zero_input_dim(self) -> None: + """input_dim < 1 fails fast at construction.""" + with self.assertRaisesRegex(ValueError, "input_dim must be >= 1"): + self._create_model(input_dim=0) + + def test_predict_raises_on_wrong_feature_width(self) -> None: + """A feature whose width != input_dim fails fast (missing value_dim).""" + model = self._create_model(input_dim=32) + with self.assertRaisesRegex(ValueError, "value_dim"): + model.predict(_batch_from_rows(torch.randn(8, 1))) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 4056ca861..11b06951c 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -173,25 +173,22 @@ def train_offline( rely on its contents afterward (copy first if it needs them). verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Check the host-tensor contract locally (this is a standalone module) - # so misuse fails here, not deep inside faiss. raise (not assert): these - # guard silent data corruption and must survive ``python -O``. + # Host-tensor contract, checked here (not deep in faiss). raise (not + # assert): these guard data corruption and must survive ``python -O``. if inputs.is_cuda: raise RuntimeError("train_offline is CPU-only; got a CUDA tensor") if inputs.dim() != 2 or inputs.shape[1] != self.embed_dim: raise RuntimeError( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # train_offline CONSUMES its input: the residual loop below mutates x - # in place (``x -= q``). We only normalize dtype/layout for faiss — a - # no-op view when the input is already float32 + contiguous, so the - # mutation lands in the caller's buffer (intended; the caller copies - # first if it still needs the data). + # The loop below mutates x in place (``x -= q``); the dtype/layout + # normalize is a no-op view when already float32 + contiguous, so the + # mutation lands in the caller's buffer (intended — see Args: CONSUMED). x = inputs.detach().to(dtype=torch.float32).contiguous() N = x.shape[0] - # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not - # errors) when N < K and returns a degenerate codebook, which the - # all-zero poison guard in KMeansQuantizeLayer would not catch. + # Clear message before faiss's own opaque C++ throw for N < K. (The + # K <= N < K * min_points_per_centroid case, where faiss only warns and + # returns a degenerate codebook, is not guarded here.) max_k = max(self.n_embed_list) if N < max_k: raise RuntimeError( diff --git a/tzrec/modules/sid/residual_kmeans_quantizer_test.py b/tzrec/modules/sid/residual_kmeans_quantizer_test.py index 42647468e..265991143 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer_test.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer_test.py @@ -31,6 +31,18 @@ def test_non_uniform_codebook_supported(self) -> None: self.assertEqual(rkq.n_embed_list, [8, 4, 16]) self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) + def test_train_offline_raises_on_too_few_points(self) -> None: + """N < largest K fails fast (clear message before faiss's own throw).""" + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=1, n_embed=8) + with self.assertRaisesRegex(RuntimeError, "largest layer K"): + rkq.train_offline(torch.randn(4, 4), verbose=False) + + def test_train_offline_raises_on_wrong_dim(self) -> None: + """An input whose width != embed_dim fails fast.""" + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=1, n_embed=8) + with self.assertRaisesRegex(RuntimeError, "inputs must be"): + rkq.train_offline(torch.randn(16, 8), verbose=False) + def test_forward_returns_zeros_before_fit(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 94c5216e7..0c3595414 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -10,6 +10,8 @@ # limitations under the License. import glob +import json +import math import os import shutil import tempfile @@ -31,12 +33,10 @@ def setUp(self): os.makedirs("./tmp") self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") os.chmod(self.test_dir, 0o755) - # SID models are CPU-only (refuse a visible CUDA device) and - # single-process (refuse world_size > 1), so hide CUDA and pin - # nproc=1 — the GPU CI harness otherwise defaults to GPU + nproc=2. - # Use "-1", not "" — an empty CUDA_VISIBLE_DEVICES is treated - # inconsistently across CUDA runtimes (the GPU CI runner does not hide - # the devices), which trips the CPU-only guard in the train_eval child. + # SID is CPU-only + single-process, so hide CUDA and pin nproc=1 (the + # GPU CI harness defaults to GPU + nproc=2). Use "-1", not "" — an empty + # CUDA_VISIBLE_DEVICES is treated inconsistently and the GPU CI runner + # doesn't hide the devices, tripping the CPU-only guard in the child. patcher = mock.patch.dict( os.environ, {"CUDA_VISIBLE_DEVICES": "-1", "TEST_NPROC_PER_NODE": "1"} ) @@ -98,10 +98,22 @@ def test_sid_rqkmeans_train_eval(self): glob.glob(os.path.join(self.test_dir, "train", "model.ckpt-*")), "no checkpoint persisted after on_train_end", ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train", "eval_result.txt")), - "no eval_result.txt produced", - ) + # A fitted codebook yields finite metrics; a degenerate/unfitted one + # never exposes x_hat -> metrics stay NaN. So assert finiteness, plus + # rel_loss < 1.0 (all-zero baseline ~ 1.0) and nonzero SID variety. + result_path = os.path.join(self.test_dir, "train", "eval_result.txt") + self.assertTrue(os.path.exists(result_path), "no eval_result.txt produced") + with open(result_path) as f: + lines = [ln for ln in f.read().splitlines() if ln.strip()] + self.assertTrue(lines, "eval_result.txt is empty") + metrics = json.loads(lines[-1]) + for key in ("mse", "rel_loss", "unique_sid_ratio"): + self.assertIn(key, metrics) + self.assertTrue( + math.isfinite(metrics[key]), f"{key} not finite: {metrics[key]}" + ) + self.assertLess(metrics["rel_loss"], 1.0) + self.assertGreater(metrics["unique_sid_ratio"], 0.0) if __name__ == "__main__": From 5f5af01400400312f02e1e063c6efd8dfd5f0efb Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 12:27:44 +0000 Subject: [PATCH 45/46] [review] SID: drop _extract_feature width guard (embedding width is never 1) The (B, 1) broadcast footgun isn't reachable in practice, so revert _extract_feature to the plain feature read and remove its negative test. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 11 +---------- tzrec/models/sid_rqkmeans_test.py | 6 ------ 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index d3023090c..8db468799 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -96,16 +96,7 @@ def _extract_feature( if feature_name is None: feature_name = self._embedding_feature_name kt = batch.dense_features[BASE_DATA_GROUP] - embedding = kt[feature_name] - # Guard a misconfigured width: a (B, 1) tensor (raw_feature missing - # value_dim) would broadcast silently into a degenerate rank-1 codebook. - if embedding.dim() != 2 or embedding.shape[1] != self._input_dim: - raise ValueError( - f"feature '{feature_name}' has shape {tuple(embedding.shape)}, " - f"expected (B, {self._input_dim}); check that its value_dim " - "matches the SID input_dim." - ) - return embedding + return kt[feature_name] def init_loss(self) -> None: """Initialize loss modules. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 273d123a2..0b68fefa6 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -124,12 +124,6 @@ def test_init_raises_on_zero_input_dim(self) -> None: with self.assertRaisesRegex(ValueError, "input_dim must be >= 1"): self._create_model(input_dim=0) - def test_predict_raises_on_wrong_feature_width(self) -> None: - """A feature whose width != input_dim fails fast (missing value_dim).""" - model = self._create_model(input_dim=32) - with self.assertRaisesRegex(ValueError, "value_dim"): - model.predict(_batch_from_rows(torch.randn(8, 1))) - def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 From 43e84cadb2a254439092c919677d3757611de65d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 01:58:04 +0000 Subject: [PATCH 46/46] [fix] SID integration test: skip on CUDA, run on CPU CI The end-to-end train_eval is CPU-only (SidRqkmeans refuses a visible CUDA device). Forcing CPU on the CUDA-built GPU CI image is unreliable (the prior CUDA_VISIBLE_DEVICES="" / "-1" workarounds both still failed in the train_eval child). Skip when CUDA is available so the test runs on the CPU CI job (where it passes) and skips on the GPU runner. Keep nproc=1 for the single-process guard. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/tests/sid_integration_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 0c3595414..53f24a1d3 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -21,6 +21,7 @@ import numpy as np import pyarrow as pa import pyarrow.parquet as pq +import torch from tzrec.tests import utils from tzrec.utils import config_util @@ -33,13 +34,9 @@ def setUp(self): os.makedirs("./tmp") self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") os.chmod(self.test_dir, 0o755) - # SID is CPU-only + single-process, so hide CUDA and pin nproc=1 (the - # GPU CI harness defaults to GPU + nproc=2). Use "-1", not "" — an empty - # CUDA_VISIBLE_DEVICES is treated inconsistently and the GPU CI runner - # doesn't hide the devices, tripping the CPU-only guard in the child. - patcher = mock.patch.dict( - os.environ, {"CUDA_VISIBLE_DEVICES": "-1", "TEST_NPROC_PER_NODE": "1"} - ) + # SidRqkmeans is single-process; pin nproc=1 (the CI harness defaults + # to 2, which would trip the world_size>1 guard). + patcher = mock.patch.dict(os.environ, {"TEST_NPROC_PER_NODE": "1"}) patcher.start() self.addCleanup(patcher.stop) @@ -73,6 +70,11 @@ def _prepare_config(self, num_rows: int, dim: int) -> str: config_util.save_message(config, config_path) return config_path + @unittest.skipIf( + torch.cuda.is_available(), + "SidRqkmeans is CPU-only; this end-to-end test runs on the CPU CI job. " + "Forcing CPU on a CUDA-built (GPU) image is unreliable.", + ) def test_sid_rqkmeans_train_eval(self): """End-to-end train -> on_train_end FAISS fit -> checkpoint -> eval.