Adaptive Frequency-Domain DSP for Language Model Training

Gradient-Based Multirate Signal Processing with Learnable Hyperparameters for Transformer Optimization

Python 3.8+ PyTorch License: MIT

Abstract

This work presents an adaptive DSP-augmented transformer architecture that applies classical multirate signal processing theory to language model hidden states, achieving significant improvements in both convergence speed and final model quality. The approach decomposes token sequences into coarse and detail frequency bands using analysis-synthesis filterbanks, applies low-frequency oscillator (LFO) routing for temporal gating, and employs channel bottlenecks for regularization. Critically, DSP hyperparameters are treated as learnable meta-parameters optimized via gradient descent during training.

On character-level language modeling benchmarks (enwik8, text8), this method achieves:

  • 19.1% improvement on enwik8 (validation loss: 1.369 vs 1.693)
  • 12.2% improvement on text8 (validation loss: 1.293 vs 1.473)
  • 65-68% faster convergence to equivalent loss targets
  • Statistically significant gains across 5 random seeds (>14σ)

Scope & Limitations

This is an exploratory study on small GPT models and character-level language modeling (enwik8, text8). The results show large relative gains in this setting, but this is not a SOTA claim, and it is unclear how well the same DSP design scales to much larger models or other domains. The goal of this repo is to provide a reproducible reference implementation of multirate + LFO DSP wrappers and adaptive DSP hyperparameters, not to define a new standard architecture.


Table of Contents

  • Key Results
  • Method Overview
  • Adaptive Hyperparameter Learning
  • Mathematical Formulation
  • Installation & Reproducibility
  • Experimental Results
  • Project Structure
  • License
  • Citation Request
  • Contact


Key Results

Multi-Seed Validation Loss Comparison

[Figures: enwik8 and text8 multi-seed comparison plots (figures/enwik8_multi_seed_comparison.png, figures/text8_multi_seed_comparison.png)]

Training Efficiency

[Figure: training efficiency comparison (figures/training_efficiency_comparison.png)]

| Dataset | Baseline Loss | Adaptive Multirate Loss | Improvement | Statistical Significance |
|---------|---------------|-------------------------|-------------|--------------------------|
| enwik8  | 1.693 ± 0.027 | 1.370 ± 0.008 | 0.323 (19.1%) | 14.85σ |
| text8   | 1.473 ± 0.005 | 1.293 ± 0.003 | 0.180 (12.2%) | 38.8σ |

Training Efficiency: Adaptive multirate DSP reaches equivalent loss targets ~65-68% faster than the baseline (measured in steps, tokens, and FLOPs).


Method Overview

The architecture augments standard transformer blocks with three DSP operations inserted at strategic points:

Input → Multirate Decomposition → Self-Attention → LFO Routing → MLP → Channel Bottleneck → Output
       [pre-attention]                           [post-attn]         [post-mlp]

Each DSP block processes hidden states $H \in \mathbb{R}^{B \times T \times D}$ where $B$ is batch size, $T$ is sequence length, and $D$ is model dimension.

1. Multirate Subband Decomposition

Motivation: Token sequences contain information at multiple temporal scales. Attention mechanisms benefit from explicit frequency separation.

Implementation: Apply analysis-synthesis filterbank to decompose hidden states into coarse (low-frequency) and detail (high-frequency) components.

Forward Pass

Given hidden states $H \in \mathbb{R}^{B \times T \times D}$:

  1. Transpose for depthwise convolution:

    $X = H^\top \in \mathbb{R}^{B \times D \times T}$

  2. Analysis: Low-pass filtering + decimation:

    $L = \text{Conv1D}_{\text{low}}(X, \text{stride}=s) \in \mathbb{R}^{B \times D \times T_c}$

    where $T_c = \lceil T/s \rceil$ and $s = \text{downsample}$ (typically 2-4).

  3. Synthesis: Upsampling + reconstruction:

$$L_\uparrow = \mathrm{Upsample}(L) \in \mathbb{R}^{B \times D \times T}, \qquad \tilde{L} = \mathrm{Conv1D}_{\text{recon}}(L_\uparrow)$$
  4. Detail path extraction:

    $R = X - \tilde{L}$ (residual/detail)

    $E = \text{Conv1D}_{\text{detail}}(R)$ (written $E$ to avoid a clash with the model dimension $D$)

  5. Mixing with adaptive strength:

    $X_{proc} = \tilde{L} + \alpha E$, where $\alpha \in [0.5, 1.0]$ is detail_strength

  6. Residual blending:

    $H_{out} = m \cdot H + (1 - m) \cdot X_{proc}^\top$, where $m \in [0.2, 0.6]$ is mix_ratio

Learnable Parameters: mix_ratio, detail_strength (optionally downsample factor).

Intuition: This mimics wavelet decomposition but with learnable filters. Low frequencies capture long-range dependencies; high frequencies preserve fine-grained details.
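
A minimal PyTorch sketch of this block is given below, following the equations above. The depthwise filters, kernel size, and the way the bounded parameters are stored are assumptions for illustration, not the repository's `multirate.py`.

```python
# Sketch only: depthwise analysis/synthesis filterbank with learnable mixing,
# following the equations above (hyperparameter ranges match the text).
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultirateBlock(nn.Module):
    def __init__(self, d_model: int, downsample: int = 2, kernel_size: int = 5):
        super().__init__()
        pad = kernel_size // 2
        # Analysis: depthwise low-pass filter + decimation via the conv stride.
        self.lowpass = nn.Conv1d(d_model, d_model, kernel_size,
                                 stride=downsample, padding=pad, groups=d_model)
        # Synthesis filter applied after upsampling back to length T.
        self.recon = nn.Conv1d(d_model, d_model, kernel_size,
                               padding=pad, groups=d_model)
        # Detail-path filter on the high-frequency residual.
        self.detail = nn.Conv1d(d_model, d_model, kernel_size,
                                padding=pad, groups=d_model)
        # Unconstrained logits mapped to bounded ranges in forward().
        self.mix_logit = nn.Parameter(torch.zeros(()))     # mix_ratio in [0.2, 0.6]
        self.detail_logit = nn.Parameter(torch.zeros(()))  # detail_strength in [0.5, 1.0]

    def forward(self, h: torch.Tensor) -> torch.Tensor:    # h: (B, T, D)
        m = 0.2 + 0.4 * torch.sigmoid(self.mix_logit)
        alpha = 0.5 + 0.5 * torch.sigmoid(self.detail_logit)
        x = h.transpose(1, 2)                               # (B, D, T)
        coarse = self.lowpass(x)                            # (B, D, T_c)
        coarse_up = F.interpolate(coarse, size=x.shape[-1],
                                  mode="linear", align_corners=False)
        l_tilde = self.recon(coarse_up)                     # reconstructed low band
        e = self.detail(x - l_tilde)                        # filtered detail residual
        x_proc = l_tilde + alpha * e
        return m * h + (1 - m) * x_proc.transpose(1, 2)
```

Using `groups=d_model` keeps the filters depthwise in this sketch, so each channel gets its own analysis/synthesis pair at minimal parameter cost.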


2. LFO-Based Temporal Routing

Motivation: Different channels may benefit from different temporal processing patterns. Inspired by analog synthesizer modulation.

Implementation: Split channels into $R$ routes, each with a time-varying gating function modulated by learned low-frequency oscillators (LFOs).

Forward Pass

  1. Reshape into routes:

    $H \in \mathbb{R}^{B \times T \times D} \rightarrow H_{group} \in \mathbb{R}^{B \times T \times R \times d}$, where $D = R \cdot d$

  2. Define LFO modulation (for route $r$, LFO $k$):

    $g_{r,k}(t) = a_{r,k} \sin(2\pi f_{r,k} t + \phi_{r,k})$

    where $f_{r,k} \in [0, f_{\max}]$, and $a_{r,k}, \phi_{r,k}$ are learnable.

  3. Compute gating signal:

    $u_r(t) = \sum_{k=1}^{K} g_{r,k}(t) + b_r$

    $\gamma_r(t) = \sigma\left(\frac{u_r(t)}{T_{gate}}\right)$, where $T_{gate} \in [0.5, 2.0]$ is gate_temperature

  4. Apply grouped temporal convolution:

$$ H_{\mathrm{routed}} = \mathrm{Conv1D}_{\mathrm{route}}(H)_{\mathrm{group}} $$

  5. Time-varying gating:

$$\tilde{H}_{\mathrm{group}}(t, r) = \gamma_r(t)\, H_{\mathrm{routed}}(t, r) + \bigl(1 - \gamma_r(t)\bigr)\, H_{\mathrm{group}}(t, r)$$

  6. Reshape and blend:

$$H_{\mathrm{routed}}^{\mathrm{flat}} = \mathrm{Flatten}\bigl(\tilde{H}_{\mathrm{group}}\bigr) \in \mathbb{R}^{B \times T \times D}$$

$$H_{\mathrm{out}} = \rho \cdot H + (1 - \rho)\, H_{\mathrm{routed}}^{\mathrm{flat}}, \quad \rho \in [0.2, 0.5]$$

where $\rho$ is residual_mix.

Learnable Parameters: residual_mix, gate_temperature, LFO frequencies/phases/amplitudes.

Intuition: Allows the model to dynamically route information through time-varying pathways, creating temporal structure beyond fixed attention patterns.
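
The sketch below shows one way to realise this routing in PyTorch; the number of routes and LFOs, the initialisation, and the bounded-parameter ranges are illustrative assumptions rather than the repository's `lfo_routing.py`.

```python
# Sketch only: grouped temporal conv gated by learned low-frequency oscillators.
import math
import torch
import torch.nn as nn

class LFORouting(nn.Module):
    def __init__(self, d_model: int, num_routes: int = 4, num_lfos: int = 2,
                 kernel_size: int = 3, f_max: float = 0.1):
        super().__init__()
        assert d_model % num_routes == 0
        self.num_routes, self.f_max = num_routes, f_max
        # Grouped temporal convolution: each route has its own filters.
        self.route_conv = nn.Conv1d(d_model, d_model, kernel_size,
                                    padding=kernel_size // 2, groups=num_routes)
        # Per-route, per-LFO frequency / phase / amplitude plus a route bias.
        self.freq = nn.Parameter(torch.rand(num_routes, num_lfos))
        self.phase = nn.Parameter(torch.zeros(num_routes, num_lfos))
        self.amp = nn.Parameter(0.1 * torch.ones(num_routes, num_lfos))
        self.bias = nn.Parameter(torch.zeros(num_routes))
        self.temp_logit = nn.Parameter(torch.zeros(()))  # gate_temperature in [0.5, 2.0]
        self.mix_logit = nn.Parameter(torch.zeros(()))   # residual_mix in [0.2, 0.5]

    def forward(self, h: torch.Tensor) -> torch.Tensor:  # h: (B, T, D)
        B, T, D = h.shape
        R = self.num_routes
        # LFO gate: gamma_r(t) = sigmoid((sum_k a sin(2*pi*f*t + phi) + b_r) / T_gate)
        t = torch.arange(T, device=h.device, dtype=h.dtype)
        f = self.f_max * torch.sigmoid(self.freq)                     # (R, K), f in [0, f_max]
        lfo = self.amp * torch.sin(2 * math.pi * f * t[:, None, None] + self.phase)
        gate_temp = 0.5 + 1.5 * torch.sigmoid(self.temp_logit)
        gamma = torch.sigmoid((lfo.sum(-1) + self.bias) / gate_temp)  # (T, R)

        routed = self.route_conv(h.transpose(1, 2)).transpose(1, 2)   # (B, T, D)
        routed = routed.reshape(B, T, R, D // R)
        grouped = h.reshape(B, T, R, D // R)
        gate = gamma[None, :, :, None]                                # broadcast to (B, T, R, d)
        gated = gate * routed + (1 - gate) * grouped

        rho = 0.2 + 0.3 * torch.sigmoid(self.mix_logit)
        return rho * h + (1 - rho) * gated.reshape(B, T, D)
```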


3. Channel Bottleneck Regularization

Motivation: Prevent overfitting by compressing channel dimensionality through a learned bottleneck.

Implementation: Per-token MLP with reduced hidden dimension.

Forward Pass

  1. Layer normalization:

    $Z = \text{LayerNorm}(H) \in \mathbb{R}^{B \times T \times D}$

  2. Bottleneck projection:

    $D_b = \lfloor \beta_{ratio} \cdot D \rfloor$, where $\beta_{ratio} \in [0.15, 0.35]$ is bottleneck_ratio

    $Z_{bottleneck} = W_2 \, \phi(W_1 Z + b_1) + b_2$

    where $W_1 \in \mathbb{R}^{D_b \times D}$, $W_2 \in \mathbb{R}^{D \times D_b}$, and $\phi = \text{GELU}$.

  3. Residual update:

    $U = w_{res} \cdot Z_{bottleneck}$

    $H_{out} = H + (1 - \rho_{res}) \cdot U$, where $\rho_{res} \in [0.3, 0.7]$ is residual_mix

Learnable Parameters: bottleneck_ratio, residual_mix, residual_weight.

Intuition: Forces information compression, similar to variational autoencoders, encouraging efficient representations.
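
A compact sketch of this block is shown below. For simplicity the bottleneck width is fixed when the module is built, whereas the repository also treats bottleneck_ratio as a learnable meta-parameter; names and ranges follow the description above, not the repository's `bottleneck_channel.py`.

```python
# Sketch only: per-token bottleneck MLP with a gated residual update.
import torch
import torch.nn as nn

class ChannelBottleneck(nn.Module):
    def __init__(self, d_model: int, bottleneck_ratio: float = 0.25):
        super().__init__()
        d_b = max(1, int(bottleneck_ratio * d_model))     # D_b = floor(beta * D)
        self.norm = nn.LayerNorm(d_model)
        self.down = nn.Linear(d_model, d_b)               # W1: D -> D_b
        self.up = nn.Linear(d_b, d_model)                 # W2: D_b -> D
        self.act = nn.GELU()
        self.res_weight = nn.Parameter(torch.ones(()))    # w_res
        self.mix_logit = nn.Parameter(torch.zeros(()))    # residual_mix in [0.3, 0.7]

    def forward(self, h: torch.Tensor) -> torch.Tensor:   # h: (B, T, D)
        rho = 0.3 + 0.4 * torch.sigmoid(self.mix_logit)
        z = self.norm(h)
        u = self.res_weight * self.up(self.act(self.down(z)))
        return h + (1 - rho) * u
```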


Adaptive Hyperparameter Learning

Motivation

Traditional DSP hyperparameters (mixing ratios, filter strengths, bottleneck dimensions) are manually tuned and fixed during training. This work instead treats them as meta-parameters optimized to minimize training loss.

Parameterization

Each hyperparameter $p \in [p_{\min}, p_{\max}]$ is represented via an unconstrained logit $z \in \mathbb{R}$:

$$\hat{p} = \frac{p - p_{\min}}{p_{\max} - p_{\min}} \quad \Rightarrow \quad z = \log\frac{\hat{p}}{1 - \hat{p}} \quad \text{(logit transform)}$$

During forward pass:

$$p = p_{\min} + (p_{\max} - p_{\min}) \cdot \sigma(z)$$

For integer parameters (e.g., downsample), we apply rounding after sigmoid.
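
A small sketch of this parameterization (the `BoundedParam` helper is illustrative, not the repository's API):

```python
# Sketch only: store an unconstrained logit z, recover p in [lo, hi] via a sigmoid.
import math
import torch
import torch.nn as nn

class BoundedParam(nn.Module):
    def __init__(self, init: float, lo: float, hi: float, integer: bool = False):
        super().__init__()
        self.lo, self.hi, self.integer = lo, hi, integer
        p_hat = (init - lo) / (hi - lo)            # init must lie strictly inside (lo, hi)
        self.z = nn.Parameter(torch.tensor(math.log(p_hat / (1.0 - p_hat))))

    def forward(self) -> torch.Tensor:
        p = self.lo + (self.hi - self.lo) * torch.sigmoid(self.z)
        # Rounding for integer parameters (e.g. downsample) blocks the gradient here;
        # how the repository handles that is not shown in this sketch.
        return torch.round(p) if self.integer else p

# Usage: mix_ratio = BoundedParam(0.4, 0.2, 0.6); current_value = mix_ratio()
```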

Meta-Objective

The method maintains exponential moving averages (EMA) of training and validation losses:

$$\hat{\mathcal{L}}_{\text{train}}(t) = \alpha \mathcal{L}_{\text{train}}(t) + (1 - \alpha) \hat{\mathcal{L}}_{\text{train}}(t-1)$$

The meta-loss $\mathcal{L}_{\text{meta}}$ drives hyperparameter updates:

Option 1: Smoothed Loss (default):

$$\mathcal{L}_{\text{meta}} = \begin{cases} \hat{\mathcal{L}}_{\text{val}}(t) & \text{if validation data available} \\\ \hat{\mathcal{L}}_{\text{train}}(t) & \text{otherwise} \end{cases}$$

Option 2: Convergence Rate:

$$r = \frac{\mathcal{L}(t - K) - \mathcal{L}(t)}{K}, \quad \mathcal{L}_{\text{meta}} = -r$$

Here $r > 0$ when the loss is decreasing, so minimizing $-r$ rewards faster convergence.

Option 3: Composite:

$$\mathcal{L}_{\text{meta}} = 0.7 \hat{\mathcal{L}}_{\text{train}}(t) - 0.3 r$$

Update Rule

Every $N$ steps (default $N = 100$):

$$z \leftarrow z - \eta_{\text{meta}} \nabla_z \mathcal{L}_{\text{meta}}$$

where $\eta_{\text{meta}} = 10^{-4}$ is the meta-learning rate. In practice the meta-parameters are updated with an Adam optimizer rather than plain gradient descent.
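
Because the EMA is linear, smoothing the loss and smoothing its gradient give the same meta-gradient, so a practical controller can simply average the logit gradients produced by the ordinary backward pass and take one Adam step every $N$ iterations. The sketch below illustrates that pattern; class and method names are assumptions, not the repository's `adaptive_controller.py`.

```python
# Sketch only: EMA-smoothed meta-gradients for the DSP logits, stepped every N iterations.
import torch

class MetaController:
    def __init__(self, logits, meta_lr=1e-4, update_every=100, ema_decay=0.9):
        self.logits = list(logits)                     # the unconstrained DSP logits z
        self.opt = torch.optim.Adam(self.logits, lr=meta_lr)
        self.update_every, self.ema_decay = update_every, ema_decay
        self.ema_grads = [torch.zeros_like(z) for z in self.logits]

    def accumulate(self):
        # Call after loss.backward() on every training step.
        for g_ema, z in zip(self.ema_grads, self.logits):
            if z.grad is not None:
                g_ema.mul_(self.ema_decay).add_(z.grad, alpha=1 - self.ema_decay)
                z.grad = None                          # keep the base optimizer off these params

    def maybe_step(self, step_idx: int):
        # Outer-loop update: z <- z - eta_meta * grad (via Adam), every `update_every` steps.
        if step_idx % self.update_every != 0:
            return
        for g_ema, z in zip(self.ema_grads, self.logits):
            z.grad = g_ema.clone()
        self.opt.step()
        self.opt.zero_grad()
```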

Configuration Example

adaptive_dsp:
  enabled: true
  meta_lr: 1e-4
  meta_update_every: 100
  meta_objective: smoothed_loss
  ema_decay: 0.9

  learnable_params:
    multirate: [mix_ratio, detail_strength]
    lfo_routing: [residual_mix, gate_temperature]
    bottleneck_channel: [bottleneck_ratio, residual_mix]

Mathematical Formulation

Overall Architecture

For transformer layer $l$, the forward pass is:

$$\begin{align} H^{(l)}_{\text{pre}} &= \text{Multirate}(H^{(l-1)}) \\\ H^{(l)}_{\text{attn}} &= \text{SelfAttention}(H^{(l)}_{\text{pre}}) \\\ H^{(l)}_{\text{post}} &= \text{LFO}(H^{(l)}_{\text{attn}}) \\\ H^{(l)}_{\text{mlp}} &= \text{MLP}(H^{(l)}_{\text{post}}) \\\ H^{(l)} &= \text{Bottleneck}(H^{(l)}_{\text{mlp}}) \end{align}$$
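
Reusing the sketch modules from the Method Overview, the per-layer composition looks like the following (illustrative only; `attn` and `mlp` stand for the standard sub-blocks, which keep their own norms and residual connections):

```python
# Sketch only: DSP wrappers placed around a standard transformer layer.
import torch.nn as nn

class DSPTransformerLayer(nn.Module):
    def __init__(self, attn: nn.Module, mlp: nn.Module, d_model: int):
        super().__init__()
        self.multirate = MultirateBlock(d_model)        # pre-attention
        self.attn = attn
        self.lfo = LFORouting(d_model)                  # post-attention
        self.mlp = mlp
        self.bottleneck = ChannelBottleneck(d_model)    # post-MLP

    def forward(self, h):
        h = self.multirate(h)
        h = self.attn(h)
        h = self.lfo(h)
        h = self.mlp(h)
        return self.bottleneck(h)
```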

Learnable Meta-Parameters

Let $\theta$ denote model weights and $\lambda$ denote DSP hyperparameters:

$$\lambda = \{\text{mix\_ratio}, \text{detail\_strength}, \text{residual\_mix}, \text{gate\_temperature}, \text{bottleneck\_ratio}, \ldots\}$$

Training objective:

$$\min_{\theta, \lambda} \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \mathcal{L}(f_{\theta, \lambda}(x), y) \right]$$

Two-level optimization:

  • Inner loop (every step): Update $\theta$ via standard SGD/Adam
  • Outer loop (every $N$ steps): Update $\lambda$ via meta-gradient descent
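
Sketched together with the `MetaController` above (and assuming, as in many GPT training scripts, that `model(x, y)` returns the cross-entropy loss):

```python
# Sketch only: inner loop updates theta every step, outer loop updates lambda every N steps.
def train(model, loader, base_opt, controller):
    for step, (x, y) in enumerate(loader, start=1):
        loss = model(x, y)                 # forward pass uses the current DSP hyperparameters
        base_opt.zero_grad()
        loss.backward()                    # gradients for theta and for the DSP logits
        controller.accumulate()            # smooth the logit gradients (outer objective)
        base_opt.step()                    # inner loop: update model weights theta
        controller.maybe_step(step)        # outer loop: update DSP hyperparameters lambda
```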

Installation & Reproducibility

Prerequisites

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA 11.0+ (for GPU training)

Setup

git clone https://github.com/eladwf/adaptive-multirate-transformers.git
cd adaptive-multirate-transformers
pip install -r requirements.txt

Dataset Preparation

enwik8

python scripts/prepare_enwik8.py

Downloads and prepares the enwik8 dataset (100MB Wikipedia XML).

text8

bash scripts/setup_text8.sh

WikiText-2

python scripts/download_wikitext.py wikitext-2

Running Experiments

Single Run (Quick Test)

# Baseline GPT (enwik8)
python scripts/run_experiment.py configs/base_gpt_enwik8.yaml

# Adaptive Multirate DSP (enwik8)
python scripts/run_experiment.py configs/dsp_example_multirate_enwik8_adaptive.yaml

Multi-Seed Verification (Reproduce Paper Results)

python scripts/verify_experiment.py \
  --baseline configs/base_gpt_enwik8.yaml \
  --treatment configs/dsp_example_multirate_enwik8_adaptive.yaml \
  --seeds 42 123 456 789 2024 \
  --long-steps 4000 8000 \
  --results-dir results/verification/enwik8_verification

Analyze results:

python scripts/analyze_verification.py results/verification/enwik8_verification
python scripts/analyze_efficiency.py results/verification/enwik8_verification --metric train --auto-targets

Generate Plots

python generate_plots.py

Outputs figures to figures/ directory.


Experimental Results

Training Curves


[Figures: enwik8 and text8 training dynamics (figures/enwik8_training_curves.png, figures/text8_training_curves.png)]

Statistical Analysis

enwik8 (Character-Level Language Modeling)

| Metric | Baseline GPT | Adaptive Multirate | Improvement |
|--------|--------------|--------------------|-------------|
| Multi-seed val loss | 1.693 ± 0.027 | 1.370 ± 0.008 | 0.323 (19.1%) |
| Significance | | | 14.85σ |
| Longer training (20k steps) | 1.223 | 1.132 | 0.092 (7.5%) |
| Training efficiency | Baseline | 68% faster | |

Notes:

  • Multi-seed results averaged over 5 random seeds (42, 123, 456, 789, 2024)
  • All improvements statistically significant (p < 0.001)
  • Training efficiency measured as steps to reach equivalent loss targets

text8 (Character-Level Language Modeling)

| Metric | Baseline GPT | Adaptive Multirate | Improvement |
|--------|--------------|--------------------|-------------|
| Multi-seed val loss | 1.473 ± 0.005 | 1.293 ± 0.003 | 0.180 (12.2%) |
| Significance | | | 38.8σ |
| Longer training (20k steps) | 1.217 | 1.124 | 0.093 (7.6%) |
| Training efficiency | Baseline | 65.6% faster | |

Ablation Studies

Each DSP component was ablated to measure individual contributions:

| Configuration | enwik8 Val Loss | Δ from Full System |
|---------------|-----------------|--------------------|
| Full System | 1.370 | |
| w/o Adaptive Learning | 1.428 | +0.058 |
| w/o Multirate | 1.502 | +0.132 |
| w/o LFO Routing | 1.411 | +0.041 |
| w/o Bottleneck | 1.389 | +0.019 |
| Baseline (no DSP) | 1.693 | +0.323 |

Key Findings:

  • Multirate decomposition contributes most to performance (+0.132)
  • Adaptive learning provides +0.058 improvement over fixed hyperparameters
  • All components contribute positively; full system achieves best results

Project Structure

adaptive-multirate-transformers/
├── README.md                          # This file
├── LICENSE                            # MIT License
├── requirements.txt                   # Python dependencies
├── generate_plots.py                  # Visualization script
│
├── configs/                           # Experiment configurations
│   ├── base_gpt_enwik8.yaml
│   ├── base_gpt_text8.yaml
│   ├── dsp_example_multirate_enwik8_adaptive.yaml
│   ├── dsp_example_multirate_text8_adaptive.yaml
│   └── dsp_example_multirate_adaptive.yaml
│
├── figures/                           # Generated plots
│   ├── enwik8_multi_seed_comparison.png
│   ├── enwik8_training_curves.png
│   ├── text8_multi_seed_comparison.png
│   ├── text8_training_curves.png
│   └── training_efficiency_comparison.png
│
├── llm_dsp_lab/                       # Core library
│   ├── models/                        # Transformer models
│   ├── dsp/                           # DSP blocks
│   │   ├── multirate.py               # Subband decomposition
│   │   ├── lfo_routing.py             # LFO gating
│   │   ├── bottleneck_channel.py      # Channel bottleneck
│   │   └── ...
│   ├── train/                         # Training infrastructure
│   │   ├── train_loop.py
│   │   ├── adaptive_dsp_wrapper.py    # Learnable parameter wrapper
│   │   └── adaptive_controller.py     # Meta-learning controller
│   ├── analysis/                      # Visualization & metrics
│   └── utils/                         # Utilities
│
└── scripts/                           # Experiment scripts
    ├── run_experiment.py
    ├── verify_experiment.py
    ├── analyze_verification.py
    ├── analyze_efficiency.py
    ├── prepare_enwik8.py
    ├── setup_text8.sh
    └── download_wikitext.py

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation Request

While not legally required by the MIT License, if you use this code in your research, please consider citing this work:

@misc{adaptive-multirate-transformers-2025,
  title={Adaptive Frequency-Domain DSP for Language Model Training:
         Gradient-Based Multirate Signal Processing with Learnable Hyperparameters},
  author={Elad Yifee},
  year={2025},
  url={https://github.com/eladwf/adaptive-multirate-transformers}
}

Citations help acknowledge the effort behind this work and enable tracking of research impact. Thank you!


Contact

If you have questions, ideas, or want to collaborate, feel free to reach out.

