Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions stable_codec/model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import json
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchaudio

from typing import Optional, List, Tuple, Union
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.data.utils import VolumeNorm
from stable_audio_tools.models import create_model_from_config
from stable_audio_tools.models.fsq import DitheredFSQ
from stable_audio_tools.models.utils import load_ckpt_state_dict
from stable_audio_tools.training.utils import copy_state_dict
from stable_audio_tools.data.utils import VolumeNorm
from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict

from .residual_fsq import ResidualFSQBottleneck
from stable_audio_tools import get_pretrained_model


class StableCodec(nn.Module):
def __init__(self,
Expand Down
16 changes: 10 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import copy
import json
import os
import pytorch_lightning as pl

from typing import Optional

import pytorch_lightning as pl
from prefigure.prefigure import get_all_args, push_wandb_config
from stable_audio_tools.models import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from stable_audio_tools.training.utils import copy_state_dict
from stable_audio_tools.models.utils import (
copy_state_dict,
load_ckpt_state_dict,
remove_weight_norm_from_model,
)

from stable_codec.training_module import create_training_wrapper_from_config
from stable_codec.training_demo import create_demo_callback_from_config
from stable_codec.data.dataset import create_dataloader_from_config
from stable_codec.training_demo import create_demo_callback_from_config
from stable_codec.training_module import create_training_wrapper_from_config


class ExceptionCallback(pl.Callback):
def on_exception(self, trainer, module, err):
Expand Down