Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ timesteps: 1000
max_beta: 0.02
enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
use_stretch_embed: true
use_variance_scaling: true
rel_pos: true
Expand Down
1 change: 1 addition & 0 deletions configs/templates/config_acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ augmentation_args:
diffusion_type: reflow
enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
use_stretch_embed: true
use_variance_scaling: true
use_shallow_diffusion: true
Expand Down
1 change: 1 addition & 0 deletions configs/templates/config_variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ tension_logit_max: 10.0

enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
use_stretch_embed: false
use_variance_scaling: true
hidden_size: 384
Expand Down
1 change: 1 addition & 0 deletions configs/variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ predict_tension: false

enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
use_stretch_embed: false
use_variance_scaling: true
rel_pos: true
Expand Down
17 changes: 16 additions & 1 deletion deployment/exporters/acoustic_exporter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from pathlib import Path
from typing import List, Union, Tuple, Dict
import warnings

import onnx
import onnxsim
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
if self.freeze_spk is not None:
self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
self.rope_interleaved = hparams.get('rope_interleaved', None)

def build_model(self) -> DiffSingerAcousticONNX:
model = DiffSingerAcousticONNX(
Expand All @@ -88,8 +90,21 @@ def build_model(self) -> DiffSingerAcousticONNX:
for p in self.phoneme_dictionary.cross_lingual_phonemes
})
).eval().to(self.device)
if self.rope_interleaved is None:
warnings.warn(
"After RoPE is refactored, the checkpoint no longer contains relevant parameters. "
"(https://github.com/openvpi/DiffSinger/pull/276)"
"In order to export ONNX with behavior compatible with past checkpoints, "
"it will be set to 'strict=False', which will no longer check the validity of the checkpoint. "
"Please understand what you are doing.",
UserWarning,
stacklevel=2
)
strict=False
else:
strict=True
load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
prefix_in_ckpt='model', strict=True, device=self.device)
prefix_in_ckpt='model', strict=strict, device=self.device)
return model

def export(self, path: Path):
Expand Down
15 changes: 15 additions & 0 deletions deployment/exporters/variance_exporter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from pathlib import Path
from typing import Union, List, Tuple, Dict
import warnings

import onnx
import onnxsim
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
if self.freeze_spk is not None:
self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
self.rope_interleaved = hparams.get('rope_interleaved', None)

def build_model(self) -> DiffSingerVarianceONNX:
model = DiffSingerVarianceONNX(
Expand All @@ -90,6 +92,19 @@ def build_model(self) -> DiffSingerVarianceONNX:
for p in self.phoneme_dictionary.cross_lingual_phonemes
})
).eval().to(self.device)
if self.rope_interleaved is None:
warnings.warn(
"After RoPE is refactored, the checkpoint no longer contains relevant parameters. "
"(https://github.com/openvpi/DiffSinger/pull/276)"
"In order to export ONNX with behavior compatible with past checkpoints, "
"it will be set to 'strict=False', which will no longer check the validity of the checkpoint. "
"Please understand what you are doing.",
UserWarning,
stacklevel=2
)
strict=False
else:
strict=True
load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
prefix_in_ckpt='model', strict=True, device=self.device)
model.build_smooth_op(self.device)
Expand Down
Loading