From 542d6de83e3199a977c97c639e7c2787d8abd18b Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 24 Nov 2025 22:12:27 +0300 Subject: [PATCH] fix copy_state_dict imports --- stable_codec/model.py | 12 ++++++------ train.py | 16 ++++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/stable_codec/model.py b/stable_codec/model.py index 92e93a4..14f541a 100644 --- a/stable_codec/model.py +++ b/stable_codec/model.py @@ -1,18 +1,18 @@ import json +from typing import List, Optional, Tuple, Union + import torch import torch.nn as nn import torchaudio - -from typing import Optional, List, Tuple, Union from einops import rearrange +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.data.utils import VolumeNorm from stable_audio_tools.models import create_model_from_config from stable_audio_tools.models.fsq import DitheredFSQ -from stable_audio_tools.models.utils import load_ckpt_state_dict -from stable_audio_tools.training.utils import copy_state_dict -from stable_audio_tools.data.utils import VolumeNorm +from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict from .residual_fsq import ResidualFSQBottleneck -from stable_audio_tools import get_pretrained_model + class StableCodec(nn.Module): def __init__(self, diff --git a/train.py b/train.py index 6384c0d..4d1ecd8 100644 --- a/train.py +++ b/train.py @@ -1,17 +1,21 @@ import copy import json import os -import pytorch_lightning as pl - from typing import Optional + +import pytorch_lightning as pl from prefigure.prefigure import get_all_args, push_wandb_config from stable_audio_tools.models import create_model_from_config -from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model -from stable_audio_tools.training.utils import copy_state_dict +from stable_audio_tools.models.utils import ( + copy_state_dict, + load_ckpt_state_dict, + remove_weight_norm_from_model, +) -from stable_codec.training_module import create_training_wrapper_from_config -from stable_codec.training_demo import create_demo_callback_from_config from stable_codec.data.dataset import create_dataloader_from_config +from stable_codec.training_demo import create_demo_callback_from_config +from stable_codec.training_module import create_training_wrapper_from_config + class ExceptionCallback(pl.Callback): def on_exception(self, trainer, module, err):