
NASER

A neural network model for stereo spatial restoration, covering legacy mono recordings and degraded stereo width enhancement.

Abstract

Stereo spatial information is absent in legacy mono recordings and attenuated in degraded or narrow mixes. This perceptual deficit is pronounced on modern headphone and loudspeaker systems and costly to correct manually in music production workflows. NASER is a neural network model that addresses this deficit directly, operating in the mid-side domain to estimate the missing or attenuated side component without altering the monophonic content. Rather than predicting left and right channels directly, the model conditions on the mid signal as a stable reference and reconstructs $\hat{x}_{s}$ in the time-frequency plane via a learned complex mask applied to the mid STFT. The training pipeline jointly covers two degradation regimes: complete side removal (mono-to-stereo reconstruction) and partial side attenuation (stereo width enhancement), with the attenuation factor sampled online from $\mathcal{U}(0, 0.5)$ at each training step. The architecture combines a convolutional frequency encoder, a band-split transformer with dynamic global context, a complex-mask decoder, and an auxiliary psychoacoustic head that predicts ILD, ICC, and IPD at full 1025-bin resolution via a learned transposed convolution. The model is optimized jointly for waveform fidelity ($\mathcal{L}_{\text{time}}$, $\mathcal{L}_{\text{spec}}$), complex spectral consistency ($\mathcal{L}_{\text{complex}}$), perceptual stereo width ($\mathcal{L}_{\text{width}}$), and psychoacoustic spatial accuracy ($\mathcal{L}_{\text{ps}}$).

Motivation

Stereo spatial information is absent or degraded in a wide range of practically encountered audio. Legacy recordings (mono broadcasts, early studio sessions, and tape-digitized archives) carry no spatial content by construction, and the perceptual gap between mono playback and natural stereo imaging is pronounced on modern headphone and loudspeaker systems. In music production, spatial width is a deliberate expressive property; restoring or enhancing it for narrow or collapsed mixes is a routine but labor-intensive task that currently relies on manual processing by audio engineers. In both cases, the underlying requirement is the same: given an audio signal with missing or attenuated spatial cues, reconstruct a plausible and perceptually coherent stereo field without altering the monophonic content.

Existing approaches to this problem are either rule-based signal processors that generalize poorly across content types, or deep learning models designed for source separation that do not target spatial structure as an explicit objective. NASER is designed to fill this gap as a neural network model that treats spatial restoration as a first-class learning objective, jointly handling the full range from complete spatial absence to partial stereo width degradation within a single model.

Problem Formulation

The two target use cases (legacy mono restoration and stereo width enhancement) reduce to a common mathematical structure. In both cases, the available input is a mid signal $x_m$ and a degraded or absent side signal $\tilde{x}_{s}$, and the objective is to recover the original side signal $x_s$. The model is designed for two input conditions.

  • mono reconstruction: the side component is entirely absent ($\tilde{x}_{s} = 0$), as in legacy mono recordings or intentional mono downmixes
  • degraded stereo enhancement: the side component is attenuated ($\tilde{x}_{s} = \alpha x_s$, $\alpha \in (0, 0.5)$), as in over-compressed or width-reduced mixes

Given stereo channels $x_L$ and $x_R$, the mid-side representation is defined as

$$ x_m = \frac{x_L + x_R}{2}, \qquad x_s = \frac{x_L - x_R}{2} $$

The model receives the mid component $x_m$ and an incomplete spatial cue $\tilde{x}_{s}$, and estimates a restored side signal $\hat{x}_{s}$:

$$ \hat{x}_s = f_\theta(x_m, \tilde{x}_s) $$

The final stereo output is reconstructed as

$$ \hat{x}_L = x_m + \hat{x}_s, \qquad \hat{x}_R = x_m - \hat{x}_s $$
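The transform pair above can be sketched in a few lines of NumPy; a minimal illustration (function names are chosen here for clarity, not taken from the codebase):

```python
import numpy as np

def ms_encode(x_l, x_r):
    """Left/right channels -> mid/side components."""
    return (x_l + x_r) / 2.0, (x_l - x_r) / 2.0

def ms_decode(x_m, x_s):
    """Mid/side components -> left/right channels."""
    return x_m + x_s, x_m - x_s

# Round trip: encoding then decoding recovers the original channels exactly.
rng = np.random.default_rng(0)
x_l = rng.standard_normal(1024)
x_r = rng.standard_normal(1024)
x_m, x_s = ms_encode(x_l, x_r)
y_l, y_r = ms_decode(x_m, x_s)
assert np.allclose(y_l, x_l) and np.allclose(y_r, x_r)
```

Because the decode step is exact, any error in the final stereo output can come only from the predicted side signal, never from the mid content.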

Method

Notation

The following diagram shows how the principal symbols relate to one another through the signal processing chain. Tables below provide exact definitions.

flowchart TD
    subgraph INPUT["Input"]
        LR["x_L, x_R\nStereo waveforms"]
    end

    subgraph MS_DOMAIN["Mid-Side Decomposition"]
        XM["x_m = (x_L + x_R) / 2\nMid waveform"]
        XS["x_s = (x_L - x_R) / 2\nSide waveform"]
    end

    subgraph DEGRADE["Degradation  ·  α ~ U(0, 0.5)"]
        XST["x̃_s = α · x_s\nDegraded side input"]
        MONO["x̃_s = 0\nMono condition"]
    end

    subgraph STFT_DOMAIN["STFT Domain  ·  F = 1025,  T frames"]
        XMF["X_m ∈ ℂ^{F×T}\nMid STFT"]
        XSF["X̃_s ∈ ℂ^{F×T}\nDegraded side STFT"]
    end

    subgraph MODEL["NASER  f_θ"]
        ENC["E ∈ ℝ^{B×D×F'×T}\nEncoder output  ·  F' = 513"]
        BAND["Z ∈ ℝ^{K×T×D}\nBand tokens  ·  K = 24,  D = 256"]
        CTX["c ∈ ℝ^D\nGlobal context"]
        MASK["M_θ = (M_r, M_i)\nComplex mask"]
    end

    subgraph OUTPUT["Output"]
        XSHAT["X̂_s = M_θ ★ X_m  (complex mul)\nPredicted side STFT"]
        XSOUT["x̂_s = ISTFT(X̂_s)\nPredicted side waveform"]
        LROUT["x̂_L = x_m + x̂_s\nx̂_R = x_m − x̂_s"]
    end

    LR -->|"M/S transform"| XM
    LR -->|"M/S transform"| XS
    XS --> XST
    XS --> MONO
    XM -->|STFT| XMF
    XST -->|STFT| XSF
    MONO -->|STFT| XSF
    XMF --> ENC
    XSF --> ENC
    ENC -->|BandSplit| BAND
    BAND -->|"mean-pool + proj"| CTX
    CTX -->|cross-attn| BAND
    BAND -->|MaskDecoder| MASK
    XMF --> XSHAT
    MASK --> XSHAT
    XSHAT -->|ISTFT| XSOUT
    XSOUT --> LROUT
    XM --> LROUT

Signals

Symbol Definition
$x_L, x_R$ Original stereo left and right channel waveforms
$x_m$ Mid channel: $x_m = (x_L + x_R)/2$
$x_s$ Side channel: $x_s = (x_L - x_R)/2$
$\tilde{x}_{s}$ Degraded side input: $\tilde{x}_{s} = \alpha x_s$
$\hat{x}_{s}$ Predicted side output
$\hat{x}_{L}, \hat{x}_{R}$ Reconstructed stereo output
$X_m, X_s$ Complex STFTs of mid and side signals, $X \in \mathbb{C}^{F \times T}$
$\hat{X}_{s}$ Predicted side STFT
$M_{\theta}$ Complex mask with real and imaginary components $(M_r, M_i)$
$\alpha$ Stereo degradation attenuation factor, $\alpha \sim \mathcal{U}(0, 0.5)$
$f, t$ Frequency bin index and time frame index
$\epsilon$ Numerical stability floor, $\epsilon = 10^{-5}$

Architecture

Symbol Definition
$f_{\theta}$ NASER model function
$D$ Model dimension ($D = 256$)
$K$ Number of frequency bands ($K = 24$)
$F$ STFT frequency bin count ($F = 1025$)
$F'$ Encoder-compressed frequency bin count: $F' = \lceil F/2 \rceil = 513$
$T$ Number of STFT time frames
$B$ Batch size
$H$ Number of attention heads
$d_h$ Per-head dimension: $d_h = D/H$
$E$ Encoder output tensor: $E \in \mathbb{R}^{B \times D \times F' \times T}$
$\mathbf{Z}$ Band token tensor: $\mathbf{Z} \in \mathbb{R}^{K \times T \times D}$
$\mathbf{z}_{k}$ $k$-th band token: $\mathbf{z}_{k} \in \mathbb{R}^{T \times D}$
$c$ Global context token: $c \in \mathbb{R}^{D}$
$[s_k, e_k)$ Frequency bin range of band $k$
$W_k, W_c$ BandSplit and GlobalContext projection weight matrices
$p, p'$ Position indices in time attention
$\theta_i$ RoPE rotation frequency: $\theta_i = 10000^{-2i/d_h}$

Training and Losses

Symbol Definition
$r_b$ Band-$b$ side-to-mid magnitude ratio
$\hat{p}_{d}, p_d$ Predicted and target spatial parameter for descriptor $d \in \{\mathrm{ILD}, \mathrm{ICC}, \mathrm{IPD}\}$
$\eta_0$ Initial learning rate ($2 \times 10^{-4}$)
$r_{\mathrm{min}}$ Minimum learning rate ratio ($0.01$)
$S$ Total training step count
$T_w$ Warmup step count
$N$ Equal-power crossfade length in samples
$\ell$ Chunk index during inference

Overall Pipeline

flowchart TD
    A[Raw Stereo Audio] --> B[Mid-Side Transform]
    B --> C[Mid and Side]
    C --> D[Zero or Random Width Degradation on Side]
    C --> E[Mid STFT]
    D --> F[Side STFT]
    E --> G[Frequency Encoder]
    F --> G
    G --> H[Band-Split Transformer]
    H --> I[Complex Mask Decoder]
    E --> I
    I --> J[Estimated Side STFT]
    J --> K[ISTFT]
    K --> L[Reconstructed Side]
    C --> M[Stereo Reconstruction]
    L --> M
    M --> N[Enhanced Stereo Output]

Mid-Side Conditioning

The model is conditioned on mid and side components rather than raw left-right channels. This formulation directly reflects the two target use cases: in legacy mono restoration, $x_m$ is the sole available signal and $\tilde{x}_s = 0$; in stereo width enhancement, $x_m$ is the stable monophonic reference and $\tilde{x}_s$ carries partial spatial information. In both cases, the mid channel carries the monophonic content that must be preserved, while the side channel encodes the spatial width and directional contrast that must be reconstructed.

Time-Frequency Representation

NASER operates in the STFT domain. All signals are processed at a fixed sample rate of 48 kHz using the following configuration.

Parameter Value
Sample rate 48,000 Hz
FFT size 2048
Hop length 960 samples (20 ms)
Window length 2048 samples
Frequency bins 1025

Each waveform $x(n)$ is transformed via the short-time Fourier transform with a Hann window $w$ of length $N_{\text{fft}} = 2048$:

$$ X(f, t) = \sum_{n=0}^{N_{\text{fft}}-1} x(n + t \cdot H)\, w(n)\, e^{-j 2\pi f n / N_{\text{fft}}} $$

where $H = 960$ is the hop length, $f \in \{0, \ldots, F - 1\}$ with $F = N_{\text{fft}}/2 + 1 = 1025$, and $t$ indexes the time frame. This yields a frequency resolution of

$$ \Delta f = \frac{f_s}{N_{\text{fft}}} = \frac{48000}{2048} \approx 23.4\,\text{Hz} $$

and a time resolution of

$$ \Delta t = \frac{H}{f_s} = \frac{960}{48000} = 20\,\text{ms} $$

The mid and side signals are each transformed into complex STFTs, and their real and imaginary parts are stacked along the channel dimension to form a four-channel input tensor:

$$ X = \mathrm{concat}\bigl(\Re(X_m),\, \Im(X_m),\, \Re(\tilde{X}_s),\, \Im(\tilde{X}_s)\bigr) $$
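The derived quantities above can be checked numerically; the frame count for a 15 s chunk assumes one frame per hop, which matches the $T = 750$ context length cited later:

```python
fs, n_fft, hop = 48_000, 2048, 960   # sample rate, FFT size, hop length

n_bins = n_fft // 2 + 1              # one-sided spectrum: 1025 frequency bins
delta_f = fs / n_fft                 # frequency resolution in Hz
delta_t = hop / fs                   # time resolution in seconds

# Frames in a 15 s chunk, assuming one frame per hop.
n_frames = 15 * fs // hop

assert n_bins == 1025
assert round(delta_f, 1) == 23.4     # ~23.4 Hz per bin
assert delta_t == 0.02               # 20 ms per frame
assert n_frames == 750
```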

Frequency Encoder

The front-end encoder is a two-stage convolutional stack. The first stage extracts local time-frequency patterns with two successive $3 \times 3$ convolutions. The second stage applies a strided $3 \times 1$ convolution along the frequency axis, compressing the frequency dimension from 1025 to 513 bins, followed by a second $3 \times 3$ convolution at the compressed resolution. Group normalization and GELU activations follow each layer.

Layer Type Kernel Freq. Stride In Channels Out Channels
Stage 1 – 1 Conv2d 3×3 1 4 $D/4$
Stage 1 – 2 Conv2d 3×3 1 $D/4$ $D/2$
Stage 2 – 1 Conv2d 3×1 2 $D/2$ $D$
Stage 2 – 2 Conv2d 3×3 1 $D$ $D$

For $D = 256$, the channel progression is $4 \to 64 \to 128 \to 256 \to 256$. The input consists of 4 channels: real and imaginary parts of the mid and side STFTs, concatenated along the channel dimension. Stage 2 – 1 performs the frequency compression from 1025 to 513 bins via stride-2 convolution along the frequency axis.

The encoder maps

$$ X \in \mathbb{R}^{B \times 4 \times F \times T} \;\xrightarrow{\mathrm{Enc}}\; E \in \mathbb{R}^{B \times D \times F' \times T}, \qquad F' = \left\lceil \frac{F}{2} \right\rceil = 513 $$

where $F = 1025$ is the number of STFT frequency bins.
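The 1025 → 513 compression follows the standard strided-convolution output formula. Assuming a padding of 1 along frequency for the 3×1 kernel (a plausible but unstated choice), the arithmetic works out as:

```python
def conv_out(n_in, kernel, stride, padding):
    """Standard convolution output-length formula (floor division)."""
    return (n_in + 2 * padding - kernel) // stride + 1

# Stage 2 - 1: kernel 3 along frequency, stride 2, padding 1 assumed.
f_prime = conv_out(1025, kernel=3, stride=2, padding=1)
assert f_prime == 513        # = ceil(1025 / 2)
```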

Band-Split Transformer

The compressed representation is partitioned into 24 frequency bands with perceptually motivated boundaries: finer resolution at low frequencies and coarser resolution at high frequencies. The default model configuration is as follows.

Hyperparameter Value
Model dimension $D$ 256
Number of bands 24
Number of transformer blocks 12
Time attention heads 4
Band attention heads 4
Dropout 0.1

The 24 frequency bands and their boundaries are defined as follows.

Band Range (Hz) Bandwidth (Hz) Band Range (Hz) Bandwidth (Hz)
1 0 – 100 100 13 2500 – 3000 500
2 100 – 200 100 14 3000 – 3800 800
3 200 – 300 100 15 3800 – 5000 1200
4 300 – 450 150 16 5000 – 6500 1500
5 450 – 600 150 17 6500 – 8000 1500
6 600 – 800 200 18 8000 – 10000 2000
7 800 – 1000 200 19 10000 – 12500 2500
8 1000 – 1250 250 20 12500 – 15000 2500
9 1250 – 1500 250 21 15000 – 17500 2500
10 1500 – 1750 250 22 17500 – 20000 2500
11 1750 – 2000 250 23 20000 – 22000 2000
12 2000 – 2500 500 24 22000 – 24000 2000

Band boundaries are finer below 2 kHz (11 bands spanning 0–2000 Hz at 100–250 Hz resolution) and coarser above 2 kHz (13 bands spanning 2000–24000 Hz at 500–2500 Hz resolution).

Each band $k$ with bin range $[s_k, e_k)$ is projected to the model dimension via a dedicated linear layer:

$$ \mathbf{z}_k = W_k\, \mathrm{vec}\left(E_{s_k : e_k}\right) + \mathbf{b}_k \in \mathbb{R}^{T \times D}, \qquad W_k \in \mathbb{R}^{D \times (e_k - s_k)D},\quad \mathbf{b}_k \in \mathbb{R}^{D} $$

The full split produces the band tensor $\mathbf{Z} = [\mathbf{z}_1, \ldots, \mathbf{z}_K] \in \mathbb{R}^{K \times T \times D}$, which serves as input to the transformer stack.
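The per-band projection can be sketched in NumPy as follows; the band edges and dimensions here are deliberately tiny illustrations, not the model's actual 24-band table:

```python
import numpy as np

def band_split(E, edges, weights, biases):
    """Project each frequency slice of E (D, F', T) to a (T, D) band token.

    edges:   list of (s_k, e_k) bin ranges covering the F' bins
    weights: per-band matrices W_k of shape (D, (e_k - s_k) * D)
    biases:  per-band vectors b_k of shape (D,)
    """
    D, _, T = E.shape
    tokens = []
    for (s, e), W, b in zip(edges, weights, biases):
        # vec over (channels x bins) for every time frame: ((e-s)*D, T)
        flat = E[:, s:e, :].reshape(-1, T)
        tokens.append((W @ flat).T + b)       # (T, D)
    return np.stack(tokens)                    # (K, T, D)

D, Fp, T = 8, 16, 5
edges = [(0, 4), (4, 10), (10, 16)]            # illustrative 3-band split
rng = np.random.default_rng(0)
E = rng.standard_normal((D, Fp, T))
Ws = [rng.standard_normal((D, (e - s) * D)) for s, e in edges]
bs = [np.zeros(D) for _ in edges]
Z = band_split(E, edges, Ws, bs)
assert Z.shape == (3, T, D)
```

Because each $W_k$ is sized to its band's width, narrow low-frequency bands and wide high-frequency bands all map to the same token dimension $D$.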

Each transformer block applies three forms of interaction in sequence:

  • time attention: temporal context within each band, using rotary positional embeddings (RoPE)
  • band attention: interactions across all 24 bands at each time step
  • cross attention: injection of a dynamic global context token

Rotary Positional Embeddings. Time attention applies RoPE to query and key vectors. For position $p$ and each pair of head-dimension indices $(2i, 2i + 1)$, the rotation is

$$ \begin{pmatrix} q'_p[2i] \\ q'_p[2i+1] \end{pmatrix} = \begin{pmatrix} \cos p\theta_i & -\sin p\theta_i \\ \sin p\theta_i & \cos p\theta_i \end{pmatrix} \begin{pmatrix} q_p[2i] \\ q_p[2i+1] \end{pmatrix}, \qquad \theta_i = 10000^{-2i/d_h} $$

with the same rotation applied to keys. The resulting inner product $\langle q'_{p}, k'_{p'} \rangle$ depends exclusively on the relative offset $(p - p')$, eliminating the need for absolute position encodings and allowing the model to process audio segments of any length.
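The relative-position property can be verified numerically. This sketch rotates a single 2-D pair (one $(2i, 2i{+}1)$ slice) and checks that the attention score is unchanged when both positions shift by the same amount:

```python
import numpy as np

def rope_rotate(v, p, theta):
    """Rotate a 2-D pair vector v by angle p * theta."""
    c, s = np.cos(p * theta), np.sin(p * theta)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])

theta = 10000 ** (-2 * 3 / 64)   # theta_i for i = 3, d_h = 64
q = np.array([0.3, -1.2])
k = np.array([0.7, 0.4])

# Same offset (p - p'), different absolute positions: identical score.
s1 = rope_rotate(q, 5, theta) @ rope_rotate(k, 2, theta)
s2 = rope_rotate(q, 105, theta) @ rope_rotate(k, 102, theta)
assert np.isclose(s1, s2)
```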

The first 8 blocks use local time attention with a window of 150 frames; the remaining 4 blocks use full time attention over the entire sequence. Cross-attention is enabled in every other block, specifically blocks 2, 4, 6, 8, 10, and 12 under 1-based indexing (blocks where the 0-based index satisfies $i \bmod 2 = 1$). The internal structure of a single block is as follows.

flowchart LR
    Z["Z\n K×T×D"] --> LN1[LN]
    LN1 --> TA["Time Attn\nRoPE"]
    TA --> A1(( + ))
    Z --> A1

    A1 --> LN2[LN]
    LN2 --> BA["Band Attn\nK bands"]
    BA --> A2(( + ))
    A1 --> A2

    A2 --> LN3[LN]
    LN3 --> XA["Cross Attn\n← c"]
    XA --> A3(( + ))
    A2 --> A3

    A3 --> LN4[LN]
    LN4 --> FF["FFN\n4D · GELU"]
    FF --> A4(( + ))
    A3 --> A4

    A4 --> Z2["Z'\n K×T×D"]

Cross-attention (the third sub-layer) is present only in alternating blocks; the remaining blocks proceed directly from band attention to the feed-forward network.

At each cross-attention block, a global context token $c$ is freshly computed from the current band representation:

$$ c = W_c\, \mathrm{vec}\left(\frac{1}{T}\sum_{t=1}^{T} \mathbf{Z}_{:,t,:}\right) \in \mathbb{R}^{D}, \qquad W_c \in \mathbb{R}^{D \times KD} $$

This dynamic context evolves with the representation as it passes through the transformer stack, maintaining semantic consistency between the cross-attention keys and the query features at each depth.
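Computing the context token is a mean-pool followed by one projection; a NumPy sketch with illustrative dimensions:

```python
import numpy as np

def global_context(Z, W_c):
    """Mean-pool band tokens over time, flatten bands, project to dimension D.

    Z:   (K, T, D) band tensor
    W_c: (D, K * D) projection matrix
    """
    pooled = Z.mean(axis=1)            # (K, D): time-averaged band summary
    return W_c @ pooled.reshape(-1)    # (D,): single global context token

K, T, D = 4, 6, 8
rng = np.random.default_rng(0)
Z = rng.standard_normal((K, T, D))
W_c = rng.standard_normal((D, K * D))
c = global_context(Z, W_c)
assert c.shape == (D,)
```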

flowchart TD
    A[Compressed Frequency Features] --> B[Band Split: 24 bands]
    B --> C[Blocks 1–8\nLocal Time Attn · Band Attn · Cross Attn ×4]
    C --> D[Blocks 9–12\nFull Time Attn · Band Attn · Cross Attn ×2]
    D --> E[Band-Aware Latent Representation]

Complex Mask Decoding

The decoder first reconstructs the full frequency representation from the band-aware latent via BandMerge. Each band token $\mathbf{z}_k \in \mathbb{R}^{T \times D}$ is projected back to its frequency slice through a dedicated linear layer:

$$ \hat{E}_{s_k:e_k} = \mathbf{z}_k V_k + \mathbf{c}_k, \qquad V_k \in \mathbb{R}^{D \times (e_k - s_k)D},\quad \mathbf{c}_k \in \mathbb{R}^{(e_k-s_k)D} $$

The contributions from all bands are summed to reconstruct the full-resolution feature map $\hat{E} \in \mathbb{R}^{B \times D \times F' \times T}$, which is then upsampled to $F = 1025$ bins via transposed convolution. The decoder then predicts a complex-valued mask. The estimated mask $M_{\theta}$ is applied to the mid STFT via complex multiplication:

$$ \hat{X}_s = M_\theta(X_m, \tilde{X}_s) \circledast X_m $$

where $\circledast$ denotes complex multiplication.

The restored side waveform is obtained by inverse STFT:

$$ \hat{x}_s = \mathrm{ISTFT}(\hat{X}_s) $$

The complex multiplication is performed explicitly on real and imaginary parts, where $M_{\theta} = (M_r, M_i)$:

$$ \Re(\hat{X}_s) = M_r \cdot \Re(X_m) - M_i \cdot \Im(X_m) $$ $$ \Im(\hat{X}_s) = M_r \cdot \Im(X_m) + M_i \cdot \Re(X_m) $$

Masking the mid spectrum, rather than predicting the side spectrum directly, enforces phase-consistent reconstruction and maintains structural coherence with the input mid channel.
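The explicit real/imaginary form is exactly native complex multiplication, as a quick check confirms:

```python
import numpy as np

def apply_complex_mask(M_r, M_i, X_r, X_i):
    """(M_r + j M_i) * (X_r + j X_i), written out on real/imaginary parts."""
    re = M_r * X_r - M_i * X_i
    im = M_r * X_i + M_i * X_r
    return re, im

rng = np.random.default_rng(0)
M = rng.standard_normal((5, 7)) + 1j * rng.standard_normal((5, 7))
X = rng.standard_normal((5, 7)) + 1j * rng.standard_normal((5, 7))

re, im = apply_complex_mask(M.real, M.imag, X.real, X.imag)
assert np.allclose(re + 1j * im, M * X)
```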

Auxiliary Spatial Head

In addition to the main decoder, the model includes a psychoacoustic parameter head that operates on the same band-aware latent representation and predicts three spatial descriptors at each time-frequency bin across the full 1025-bin frequency range. The band-aware latent is first merged and upsampled from 513 to 1025 frequency bins via a learned transposed convolution (stride 2 along the frequency axis), then projected to three output channels by a point-wise convolution. This produces predictions at the same resolution as the STFT without any interpolation.

The three predicted spatial descriptors are:

  • ILD (inter-channel level difference), in dB:

$$ \mathrm{ILD}(f,t) = \mathrm{clip}\left(20\log_{10}\frac{\max\bigl(|S(f,t)|,\,\epsilon\bigr)}{\max\bigl(|M(f,t)|,\,\epsilon\bigr)},\; -100,\; 100\right) $$

  • ICC (inter-channel coherence):

$$ \mathrm{ICC}(f,t) = \frac{\bigl|\langle M(f,t),\, S^*(f,t)\rangle\bigr|}{|M(f,t)|\cdot|S(f,t)|} $$

  • IPD (inter-channel phase difference), in radians:

$$ \mathrm{IPD}(f,t) = \angle\bigl(M(f,t)\cdot S^*(f,t)\bigr) $$

where $\epsilon = 10^{-5}$ is applied as a minimum threshold before the logarithm to prevent numerical divergence in near-silent frequency bins. The ILD is additionally clamped to $[-100, 100]$ dB to suppress instability from extreme level ratios. This auxiliary branch constrains the model to reproduce the spatial structure of the target across the entire audible spectrum, not merely its side-channel energy.
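The descriptor targets can be computed directly from the mid and side STFTs. In this sketch, ILD and IPD follow the per-bin definitions above; for ICC the inner product $\langle \cdot, \cdot \rangle$ is taken as a local moving average over time, which is an assumption here since the averaging scope is not specified (per-bin without averaging, the ratio would be trivially 1):

```python
import numpy as np

EPS = 1e-5

def ild(M, S):
    """Inter-channel level difference in dB, clipped to [-100, 100]."""
    r = np.maximum(np.abs(S), EPS) / np.maximum(np.abs(M), EPS)
    return np.clip(20 * np.log10(r), -100, 100)

def ipd(M, S):
    """Inter-channel phase difference in radians, in (-pi, pi]."""
    return np.angle(M * np.conj(S))

def icc(M, S, win=5):
    """Coherence with a moving time average (window length is assumed)."""
    k = np.ones(win) / win
    avg = lambda a: np.apply_along_axis(lambda v: np.convolve(v, k, "same"), 1, a)
    num = np.abs(avg(M * np.conj(S)))
    den = np.sqrt(avg(np.abs(M) ** 2) * avg(np.abs(S) ** 2)) + EPS
    return num / den   # Cauchy-Schwarz bounds this in [0, 1]

rng = np.random.default_rng(0)
M = rng.standard_normal((8, 20)) + 1j * rng.standard_normal((8, 20))
S = rng.standard_normal((8, 20)) + 1j * rng.standard_normal((8, 20))
assert np.all(np.abs(ild(M, S)) <= 100)
assert np.all((ipd(M, S) >= -np.pi) & (ipd(M, S) <= np.pi))
assert np.all((icc(M, S) >= 0) & (icc(M, S) <= 1 + 1e-9))
```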

Model Complexity

Parameter counts for the default configuration ($D = 256$, 12 transformer blocks, 24 bands).

Component Parameters
Frequency Encoder 0.77 M
Band-Split Projections 33.63 M
Transformer Stack 15.78 M
Mask Decoder 33.85 M
Psychoacoustic Head 33.95 M
Total (training) 118.0 M
Total (inference) 84.0 M

The per-band linear projections (BandSplit and BandMerge) are the dominant source of parameters. Each projection layer maps variable-width frequency slices to or from the model dimension $D$; the total parameter count per layer is

$$ \sum_{k=1}^{K} (e_k - s_k) \cdot D^2 = F' \cdot D^2 = 513 \times 256^2 \approx 33.6\,\text{M} $$

Three such layers exist across the full model: one BandSplit in the transformer input, one BandMerge inside the Mask Decoder, and one BandMerge inside the Psychoacoustic Head. These three layers together account for $3 \times 33.6 \approx 100.8\,\text{M}$ parameters, or 85.4% of the 118.0 M training total. The remaining 17.2 M are distributed across the convolutional encoder, transformer attention and feed-forward layers, and output convolutions.

Design Rationale

The architectural and training decisions in NASER are each grounded in the requirements of the two target tasks: reconstructing a fully absent side component from mono recordings, and restoring attenuated spatial width from degraded stereo mixes. This section explains the reasoning behind the principal choices and identifies what breaks under alternative designs.

Domain and Output Parameterization

Working in the mid-side domain rather than directly on left and right channels is a structural choice that constrains the output space in a useful way. Since $x_m$ is fixed and appears unchanged in the reconstruction $\hat{x}_{L} = x_m + \hat{x}_{s}$, $\hat{x}_{R} = x_m - \hat{x}_{s}$, the model cannot distort the monophonic content regardless of its prediction for $\hat{x}_{s}$. Predicting $(\hat{x}_{L}, \hat{x}_{R})$ directly offers no such guarantee: even a small error in the predicted left-right balance introduces audible content-level distortion, and the high correlation between the two channels makes the regression target poorly conditioned.

The choice to apply a complex mask to $X_m$ rather than predict $\hat{X}_{s}$ directly follows from the same principle. In natural stereo recordings, $X_m$ and $X_s$ share a common phase structure since they originate from the same acoustic sources. Expressing $\hat{X}_{s} = M_{\theta} \circledast X_m$ anchors the predicted phase to that of $X_m$ by construction, preventing the phase misalignment that produces comb filtering and tonal instability in the reconstructed output. A real-valued mask would avoid this problem as well, but it is restricted to the magnitude domain and cannot capture the inter-channel phase relationships that IPD supervision targets. The unconstrained complex mask corresponds to the complex ideal ratio mask (cIRM), which is the theoretically optimal solution under the MSE criterion in the complex domain.

Frequency Decomposition

Applying full-frequency self-attention directly over $F' = 513$ bins is computationally intractable within a 12-block stack at 15-second context lengths ($T = 750$ frames). The band-split factorization reduces per-layer complexity from $O(F'^2 T)$ to $O(KT^2 + K^2 T)$, a reduction of approximately 17.5$\times$ at these dimensions. Beyond the computational argument, the perceptually non-uniform band boundaries concentrate modeling capacity where it matters most for spatial impression: 11 bands below 2 kHz at 100–250 Hz resolution, versus coarser partitioning above. Equal-width frequency bands would allocate the same capacity to a 100 Hz slice at 200 Hz as to one at 20 kHz, which inverts the actual perceptual weighting of spatial cues.

Global Context and Positional Encoding

The dynamic global context addresses a depth-consistency problem. A context token computed once from the encoder output and reused across all cross-attention layers carries shallow features that become semantically inconsistent with the queries at deeper blocks. Recomputing $c^{(\ell)}$ from the current representation at each cross-attention layer ensures that the injected global information remains aligned with the state of the model at that depth. The cost of the recomputation — a mean-pool and a linear projection — is $O(KTD + KD^2)$ per block, which is negligible against the attention operations.

RoPE is preferred over absolute positional encodings because inference operates on audio segments of arbitrary length, including recordings longer than any training chunk. Absolute encodings are bounded to the training context; positions outside that range produce embedding values that the model has never encountered, leading to inconsistent attention patterns. RoPE encodes only relative position, so the model's behavior is length-agnostic and the same temporal pattern receives identical treatment regardless of where it appears in the audio.

Training Distribution

Fixing $\alpha$ at preprocessing time ties each stored chunk to a single attenuation level for the entire training run. After 100 epochs, the model has seen the same $(x_m, \tilde{x}_s)$ pair 100 times with no variation in the degradation condition. Online sampling from $\mathcal{U}(0, 0.5)$ at each training step ensures that the same chunk is encountered under a different $\alpha$ in each epoch, covering the full continuous distribution rather than $N_{\text{files}}$ discrete levels. This is especially important at deployment, where the degree of width degradation in a real recording is unknown and continuous.

Data Construction

Preprocessing Strategy

Raw stereo audio is segmented into fixed-length chunks and stored as .npz files. The default configuration is as follows.

Parameter Value
Chunk length 15.0 s
Margin (overlap) 5.0 s
Effective content per chunk 10.0 s
Validation split 10%
Preprocessing workers 4

Each chunk is structured as a symmetric overlap region: 2.5 s margin on each side around 10 s of content. For chunk index $i$, the sample range extracted from the full waveform is

$$ \left[\, i \cdot C_{\text{content}} - \frac{C_{\text{margin}}}{2},\;\; i \cdot C_{\text{content}} + C_{\text{content}} + \frac{C_{\text{margin}}}{2} \,\right) $$

where $C_{\text{content}} = 10 \times f_s = 480{,}000$ samples and $C_{\text{margin}} = 5 \times f_s = 240{,}000$ samples. Out-of-bounds regions are zero-padded. The margin preserves local context across chunk boundaries during both training and inference. Each .npz file stores the following fields.

Field Description
mid Mid channel waveform
side Target side channel waveform
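The chunk sample ranges defined above can be computed and checked directly for the default configuration:

```python
FS = 48_000
CONTENT = 10 * FS        # 480,000 content samples per chunk
MARGIN = 5 * FS          # 240,000 total margin samples (2.5 s per side)

def chunk_range(i):
    """Sample range [start, end) for chunk i; a negative start is zero-padded."""
    start = i * CONTENT - MARGIN // 2
    end = i * CONTENT + CONTENT + MARGIN // 2
    return start, end

assert chunk_range(0) == (-120_000, 600_000)   # left edge zero-padded
assert chunk_range(1) == (360_000, 1_080_000)
# Every chunk spans 15 s and consecutive chunks advance by 10 s of content.
assert chunk_range(0)[1] - chunk_range(0)[0] == 15 * FS
assert chunk_range(1)[0] - chunk_range(0)[0] == CONTENT
```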

Dual-Condition Training

Width degradation is applied stochastically at training time rather than at preprocessing time. For each training step consuming a chunk under the degraded stereo condition, the attenuation factor is independently sampled as

$$ \alpha \sim \mathcal{U}(0, 0.5) $$

and the degraded input is constructed online as $\tilde{x}_s = \alpha\, x_s$. This ensures that the model encounters a different degradation level for the same chunk across training epochs, providing greater diversity than a fixed per-chunk $\alpha$.

Each stored chunk is consumed twice per training epoch:

  • mono condition: $\tilde{x}_s = 0$
  • degraded stereo condition: $\tilde{x}_s = \alpha\, x_s,\quad \alpha \sim \mathcal{U}(0, 0.5)$

This yields an effective dataset size of

$$ N_{\text{eff}} = 2 \times N_{\text{files}} $$

with a corresponding step count per epoch of $S_{\text{epoch}} = \lfloor N_{\text{eff}} / B \rfloor$, where $B$ is the batch size. The dual-condition scheme unifies mono reconstruction and stereo width enhancement within a single data format and model.
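A minimal sketch of how one stored chunk yields two training examples per epoch (function and variable names are illustrative, not from the codebase):

```python
import numpy as np

def make_examples(x_m, x_s, rng):
    """One mono example and one degraded-stereo example per stored chunk."""
    alpha = rng.uniform(0.0, 0.5)          # resampled every step, not per file
    return [
        (x_m, np.zeros_like(x_s), x_s),    # mono condition: side fully absent
        (x_m, alpha * x_s, x_s),           # degraded condition: side attenuated
    ]

rng = np.random.default_rng(0)
x_m = rng.standard_normal(100)
x_s = rng.standard_normal(100)
examples = make_examples(x_m, x_s, rng)
assert len(examples) == 2                  # N_eff = 2 * N_files
assert np.all(examples[0][1] == 0)         # mono input carries no spatial cue
```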

Validation Strategy

Validation is performed twice per epoch under two fixed-input conditions.

Condition Side input $\tilde{x}_s$ Purpose
Mono $0$ Measures reconstruction from complete absence of spatial cues
Degraded $0.25\,x_s$ Measures enhancement at the mean attenuation level of the training distribution

Both conditions use the same target side signal $x_s$. The fixed inputs eliminate epoch-to-epoch randomness from validation, ensuring that changes in validation loss reflect genuine model improvement rather than input variation. The best checkpoint is selected based on the average of the two total losses.

Training Objective

NASER is optimized with a composite objective that combines waveform-domain accuracy, complex spectral consistency, and explicit spatial supervision:

$$ \mathcal{L} = \mathcal{L}_{\text{spec}} + 0.5\,\mathcal{L}_{\text{complex}} + 0.1\,\mathcal{L}_{\text{time}} + 0.5\,\mathcal{L}_{\text{width}} + 0.2\,\mathcal{L}_{\text{ps}} $$

$\mathcal{L}_{\text{spec}}$ — multi-scale spectral magnitude loss, averaged over FFT sizes $\{2048, 1024, 512\}$:

$$ \mathcal{L}_{\text{spec}} = \frac{1}{3} \sum_{n \,\in\, \{2048,\, 1024,\, 512\}} \ell(n) $$

where the per-scale loss is

$$ \ell(n) = \bigl\| |X_n| - |\hat{X}_n| \bigr\|_1 + \bigl\| \log|X_n| - \log|\hat{X}_n| \bigr\|_1 $$

combining linear and log-domain L1 penalties to balance large- and small-magnitude accuracy. Averaging over scales ensures that no individual FFT resolution dominates the spectral objective.

$\mathcal{L}_{\text{complex}}$ — L1 distance on the complex STFT coefficients:

$$ \mathcal{L}_{\text{complex}} = \bigl\| X_s - \hat{X}_s \bigr\|_1 $$

$\mathcal{L}_{\text{time}}$ — waveform-domain L1 loss:

$$ \mathcal{L}_{\text{time}} = \bigl\| x_s - \hat{x}_s \bigr\|_1 $$

$\mathcal{L}_{\text{width}}$ — band-wise side-to-mid magnitude ratio loss, computed over 7 perceptual bands with boundaries at $\{0, 300, 700, 1500, 3000, 6000, 12000, 24000\}$ Hz. For each band $b$, the ratio is

$$ r_b(X) = \frac{\displaystyle\frac{1}{|b|}\sum_{(f,t)\in b} |X_s(f,t)|}{\displaystyle\frac{1}{|b|}\sum_{(f,t)\in b} |X_m(f,t)|} $$

and the loss is $\mathcal{L}_{\text{width}} = \frac{1}{7}\sum_b |r_b(\hat{X}) - r_b(X)|$. This term constrains the reconstructed stereo image to follow the target spatial width distribution across the spectrum.
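The band-ratio loss can be sketched in NumPy. Band boundaries are converted to bin indices using the 48 kHz / 2048-point configuration, and a small epsilon is added to the denominator as an assumed numerical safeguard:

```python
import numpy as np

EDGES_HZ = [0, 300, 700, 1500, 3000, 6000, 12000, 24000]

def band_ratios(mag_s, mag_m, fs=48_000, n_fft=2048, eps=1e-5):
    """Side-to-mid mean-magnitude ratio in each of the 7 perceptual bands."""
    bins = [int(h * n_fft / fs) for h in EDGES_HZ]
    return np.array([
        mag_s[lo:hi].mean() / (mag_m[lo:hi].mean() + eps)
        for lo, hi in zip(bins[:-1], bins[1:])
    ])

def width_loss(mag_s_hat, mag_s, mag_m):
    """Mean absolute difference between predicted and target band ratios."""
    return np.abs(band_ratios(mag_s_hat, mag_m) - band_ratios(mag_s, mag_m)).mean()

rng = np.random.default_rng(0)
mag_m = np.abs(rng.standard_normal((1025, 50)))
mag_s = np.abs(rng.standard_normal((1025, 50)))
# A perfect side prediction incurs zero width loss.
assert width_loss(mag_s, mag_s, mag_m) == 0.0
```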

$\mathcal{L}_{\text{ps}}$ — auxiliary L1 loss on the ILD, ICC, and IPD predictions from the psychoacoustic parameter head, computed against analytically derived targets over the full 1025-bin frequency range:

$$ \mathcal{L}_{\text{ps}} = \frac{1}{3} \sum_{d \,\in\, \{\mathrm{ILD},\, \mathrm{ICC},\, \mathrm{IPD}\}} \bigl\| \hat{p}_d - p_d \bigr\|_1 $$

where $\hat{p}_d(f, t)$ and $p_d(f, t)$ are the predicted and analytically computed values of descriptor $d$ at time-frequency bin $(f, t)$.

The learning rate schedule applies linear warmup for $T_{\mathrm{w}}$ steps followed by cosine annealing:

$$ \eta(t) = \begin{cases} \eta_0 \cdot \dfrac{t}{T_{\mathrm{w}}} & 0 \le t < T_{\mathrm{w}} \\[10pt] \eta_0 \left( r_{\mathrm{min}} + \dfrac{1 - r_{\mathrm{min}}}{2} \left(1 + \cos\left(\pi\,\dfrac{t - T_{\mathrm{w}}}{S - T_{\mathrm{w}}}\right)\right) \right) & t \ge T_{\mathrm{w}} \end{cases} $$

where $t$ is the current optimizer step, $T_{\mathrm{w}}$ the warmup step count, $S$ the total step count, and $r_{\mathrm{min}} = 0.01$ the minimum LR ratio.
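The schedule can be written directly; the warmup and total step counts below are illustrative placeholders, since the actual counts depend on dataset size:

```python
import math

def lr(t, eta0=2e-4, T_w=1000, S=10_000, r_min=0.01):
    """Linear warmup for T_w steps, then cosine annealing to eta0 * r_min."""
    if t < T_w:
        return eta0 * t / T_w
    cos = math.cos(math.pi * (t - T_w) / (S - T_w))
    return eta0 * (r_min + (1 - r_min) / 2 * (1 + cos))

assert lr(0) == 0.0                        # warmup starts from zero
assert math.isclose(lr(1000), 2e-4)        # end of warmup reaches eta0
assert math.isclose(lr(10_000), 2e-6)      # final LR = eta0 * r_min
```

Note that the two branches agree at $t = T_{\mathrm{w}}$: the cosine term is 1 there, so the annealed value collapses to $\eta_0$, making the schedule continuous.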

The training configuration is as follows.

Parameter Value
Optimizer AdamW
Learning rate $2 \times 10^{-4}$
LR schedule Linear warmup (3 epochs) then cosine annealing
Minimum LR ratio $r_{\mathrm{min}} = 0.01$
Batch size 1
Gradient accumulation 16 steps (effective batch = 16)
Gradient clipping L2 norm ≤ 1.0
Precision Mixed precision (AMP), initial scale $2^{10}$, growth interval 100 steps
Total epochs 100

Inference and Deployment

Chunked Inference

During inference, the trained model processes audio of arbitrary length, including full-length legacy recordings or complete music tracks, by segmenting the input into overlapping chunks using the same chunk length (15.0 s) and margin length (5.0 s) as preprocessing. Each chunk advances by a step equal to the content length minus the crossfade length, $10.0\,\text{s} - 2.5\,\text{s} = 7.5\,\text{s}$, which differs from the preprocessing stride of 10.0 s. Adjacent predicted chunks are blended with an equal-power crossfade over the 2.5 s ($N = 120{,}000$ samples at 48 kHz) boundary region to suppress discontinuities. The blend at the $\ell$-th chunk boundary is

$$ y(n) = \sin\left(\frac{\pi n}{2N}\right)\hat{x}_s^{(\ell+1)}(n) \;+\; \cos\left(\frac{\pi n}{2N}\right)\hat{x}_s^{(\ell)}(n), \qquad n = 0, \ldots, N-1 $$

which satisfies $\sin^2\left(\frac{\pi n}{2N}\right) + \cos^2\left(\frac{\pi n}{2N}\right) = 1$, preserving constant power across the transition.
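A NumPy sketch of the blend, with the constant-power property checked inline:

```python
import numpy as np

def equal_power_crossfade(prev_tail, next_head):
    """Blend the overlapping tails of adjacent chunks with sin/cos gains."""
    N = len(prev_tail)
    n = np.arange(N)
    g_in = np.sin(np.pi * n / (2 * N))     # fade-in gain for the next chunk
    g_out = np.cos(np.pi * n / (2 * N))    # fade-out gain for the previous chunk
    assert np.allclose(g_in**2 + g_out**2, 1.0)   # constant-power property
    return g_in * next_head + g_out * prev_tail

# Crossfading a constant signal with itself keeps power constant, though the
# amplitude envelope peaks at sqrt(2) mid-fade (sin + cos != 1 pointwise).
x = np.ones(100)
y = equal_power_crossfade(x, x)
assert np.all(y >= 1.0) and np.max(y) <= np.sqrt(2) + 1e-9
```

An amplitude-preserving linear crossfade ($n/N$ against $1 - n/N$) would instead dip in power mid-fade on decorrelated material, which is why the sin/cos pair is used here.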

flowchart TD
    A[Input Audio] --> B[Chunking with Margin]
    B --> C[Batch-wise Model Forward]
    C --> D[Predicted Side Chunks]
    D --> E[Overlap-Add]
    E --> F[Equal-Power Crossfade]
    F --> G[Full-Length Side Signal]
    G --> H[Stereo Reconstruction]

Export

The export pipeline supports the following formats.

Format Notes
ONNX Opset 17, dynamic batch, validated against onnxruntime
TorchScript Traced, validated by JIT load and forward pass

Each export also produces a metadata JSON file containing the sample rate, chunk size, FFT configuration, parameter count, and export format. The auxiliary psychoacoustic head is excluded from exported inference wrappers.

Installation

The project targets Python 3.12.

Linux & macOS:

```shell
./install.sh
```

Windows:

```powershell
./install.ps1
```

The installation scripts detect uv, create a virtual environment, and synchronize dependencies. Commands can be run either by activating the virtual environment first, or directly via uv run without activation.

Usage

Preprocessing

Segments raw stereo audio into fixed-length .npz chunks and splits them into train and validation sets.

```shell
uv run preprocess --config config/preprocess.yaml
```

| Argument | Short | Default | Description |
| --- | --- | --- | --- |
| `--config` | `-c` | `config/preprocess.yaml` | Path to preprocessing config YAML |

Config fields (config/preprocess.yaml)

| Field | Default | Description |
| --- | --- | --- |
| `datasets.raw` | `datasets/raw` | Directory containing raw stereo audio files |
| `datasets.train` | `datasets/train` | Output directory for training chunks |
| `datasets.valid` | `datasets/valid` | Output directory for validation chunks |
| `workers` | `4` | Number of parallel preprocessing workers |
| `audio_length` | `15.0` | Total chunk length in seconds (content + margin) |
| `margin_length` | `5.0` | Overlap margin in seconds (2.5 s on each side) |
| `valid_ratio` | `0.1` | Fraction of chunks reserved for validation |
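Put together, the defaults above correspond to a config file along these lines (reconstructed from the field table; the shipped file may order or comment the fields differently):

```yaml
datasets:
  raw: datasets/raw       # raw stereo audio in
  train: datasets/train   # training chunks out
  valid: datasets/valid   # validation chunks out
workers: 4                # parallel preprocessing workers
audio_length: 15.0        # chunk length in seconds (content + margin)
margin_length: 5.0        # overlap margin (2.5 s on each side)
valid_ratio: 0.1          # fraction held out for validation
```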

Example

```shell
# Preprocess with default config
uv run preprocess --config config/preprocess.yaml

# Preprocess with a custom config
uv run preprocess --config config/preprocess_large.yaml
```

Training

```shell
uv run train --config config/train.yaml
```

| Argument | Short | Default | Description |
| --- | --- | --- | --- |
| `--config` | `-c` | `config/train.yaml` | Path to training config YAML |
| `--resume` | | | Path to checkpoint `.pt`; behavior depends on whether `--config` is also provided (see below) |

Config fields (config/train.yaml)

| Field | Default | Description |
| --- | --- | --- |
| `datasets.train` | `datasets/train` | Directory of training `.npz` chunks |
| `datasets.valid` | `datasets/valid` | Directory of validation `.npz` chunks |
| `models.output` | `models/` | Directory to save checkpoints |
| `models.name` | `naser-base` | Checkpoint filename prefix |
| `device` | `cuda` | Compute device (`cuda` or `cpu`) |
| `epochs` | `100` | Total number of training epochs |
| `batch_size` | `1` | Per-step batch size |
| `gradient_accumulation_steps` | `16` | Steps before each optimizer update (effective batch = 16) |
| `learning_rate` | `0.0002` | Initial learning rate $\eta_0$ |
| `lr_warmup_epochs` | `3` | Number of warmup epochs before cosine annealing |
| `save_interval` | `10` | Save a numbered checkpoint every N epochs (0 = disabled) |
| `preview_interval` | `1` | Log audio previews to TensorBoard every N epochs (0 = disabled) |
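As with preprocessing, the defaults above assemble into a config file of this shape (reconstructed from the field table, not copied from the repository):

```yaml
datasets:
  train: datasets/train
  valid: datasets/valid
models:
  output: models/
  name: naser-base
device: cuda
epochs: 100
batch_size: 1
gradient_accumulation_steps: 16   # effective batch = 16
learning_rate: 0.0002
lr_warmup_epochs: 3
save_interval: 10                 # numbered checkpoint every 10 epochs
preview_interval: 1               # TensorBoard audio previews every epoch
```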

Resume vs. fine-tune

| Command | Behavior |
| --- | --- |
| `uv run train --resume ckpt.pt` | Full resume: restores model weights, optimizer, scheduler, scaler, and epoch counter |
| `uv run train --config cfg.yaml --resume ckpt.pt` | Fine-tune: loads model weights only, starts a fresh run under the new config |

Examples

```shell
# Start a new training run
uv run train --config config/train.yaml

# Resume an interrupted run
uv run train --resume models/naser-base_last.pt

# Fine-tune from a pretrained checkpoint under a new config
uv run train --config config/train_finetune.yaml --resume models/naser-base_best.pt
```

Inference

Runs the model on an input audio file and writes the enhanced stereo output.

```shell
uv run inference --model models/naser-base_best.pt --input input.wav --output output.wav
```

| Argument | Short | Required | Default | Description |
| --- | --- | --- | --- | --- |
| `--model` | `-m` | Yes | | Path to model checkpoint `.pt` |
| `--input` | `-i` | Yes | | Path to input audio file (any format supported by torchaudio) |
| `--output` | `-o` | No | `{stem}_naser.wav` | Path to output `.wav` file |
| `--batch-size` | | No | `1` | Number of chunks processed per forward pass |
| `--device` | | No | `cuda` | Compute device (`cuda` or `cpu`) |

Examples

```shell
# Basic usage: output saved as input_naser.wav
uv run inference --model models/naser-base_best.pt --input input.wav

# Specify output path
uv run inference --model models/naser-base_best.pt --input input.wav --output enhanced.wav

# Run on CPU with larger batch for throughput
uv run inference --model models/naser-base_best.pt --input input.wav --device cpu --batch-size 4
```

Export

Exports the trained model to a portable inference format. The auxiliary psychoacoustic head is excluded from all exports.

```shell
uv run export --model models/naser-base_best.pt --format onnx
uv run export --model models/naser-base_best.pt --format torchscript
```

| Argument | Short | Required | Default | Description |
| --- | --- | --- | --- | --- |
| `--model` | `-m` | Yes | | Path to model checkpoint `.pt` |
| `--format` | `-f` | Yes | | Export format: `onnx` or `torchscript` |
| `--output` | `-o` | No | Same directory as checkpoint | Output path for the exported file |
| `--opset` | | No | `17` | ONNX opset version (ignored for TorchScript) |

Each export also writes a _meta.json file alongside the model file containing the following fields: sample_rate, chunk_samples, n_fft, hop_length, n_params, format, and opset (ONNX exports only).
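Downstream tooling can read this sidecar to recover the STFT and chunking configuration. A hypothetical helper (only the field names come from the description above; the `{stem}_meta.json` naming convention is our assumption):

```python
import json
from pathlib import Path

def load_export_meta(model_path):
    """Read the `_meta.json` file written next to an exported model."""
    p = Path(model_path)
    meta_path = p.with_name(p.stem + "_meta.json")  # e.g. naser.onnx -> naser_meta.json
    with open(meta_path) as f:
        return json.load(f)  # dict with sample_rate, chunk_samples, n_fft, ...
```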

Examples

```shell
# Export to ONNX (default opset 17)
uv run export --model models/naser-base_best.pt --format onnx

# Export to ONNX with a specific opset
uv run export --model models/naser-base_best.pt --format onnx --opset 18

# Export to TorchScript with a custom output path
uv run export --model models/naser-base_best.pt --format torchscript --output deploy/naser.torchscript
```

License

MIT
