diff --git a/src/seamless_communication/cli/m4t/evaluate/evaluate.py b/src/seamless_communication/cli/m4t/evaluate/evaluate.py index f7bf8177..a020717b 100644 --- a/src/seamless_communication/cli/m4t/evaluate/evaluate.py +++ b/src/seamless_communication/cli/m4t/evaluate/evaluate.py @@ -363,7 +363,7 @@ def run_eval( def load_checkpoint(model: UnitYModel, path: str, device = torch.device("cpu")) -> None: - saved_model = torch.load(path, map_location=device)["model"] + saved_model = torch.load(path, map_location=device, weights_only=True)["model"] saved_model = { k.replace("model.", ""): v for k, v in saved_model.items() } def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]: