Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/metric.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ throughout the training process. We currently offer a restricted range of metric

| Vision Metric | Description |
| ------------- | ---------------------------------------------------------------------------------------------------- |
| A-FINE | Computes the Adaptive Fidelity-Naturalness Evaluator (A-FINE) full-reference perceptual quality metric built on CLIP ViT-B/32 features |
| Dice | Computes the Dice-Sorenson coefficient (DSC) for evaluating overlap between binary masks |
| DISTS | Computes the Deep Image Structure and Texture Similarity (DISTS) metric for image quality assessment |
| FID | Computes the Frechet Inception Distance (FID) for evaluating generative model quality |
Expand Down
250 changes: 250 additions & 0 deletions crates/burn-train/src/metric/vision/afine/calibrators.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
//! Calibrators, adapter, and final-score mapping for A-FINE.
//!
//! Four small modules sit between the heads and the final per-sample
//! quality score:
//!
//! - [`NrCalibrator`] — logistic mapping of the naturalness head's raw
//! output into `(-2, 2)`. Two learnable scalars.
//! - [`FrCalibratorWithLimit`] — logistic mapping of the fidelity head's
//! raw output into `(-2, 2)`, with `yita3` clamped to `[0.05, 0.95]`
//! and `yita4` to `[0.01, 0.70]` on every forward.
//! - [`AfineAdapter`] — `D = exp(softplus(k) * (N_ref - N_dis)) * N_dis + F`.
//! Single learnable scalar `k`.
//! - [`scale_finalscore`] — fixed logistic into `(0, 100)` with the
//! paper-reported constants.
//!
//! All three calibrators implement the same logistic shape:
//! `out = (yita1 - yita2) * sigmoid((x - yita3) / (|yita4| + eps)) + yita2`.
//! This is the algebraic equivalent of PyIQA's two-branch
//! `if exp_pow >= 10` formulation, rewritten as a single expression so
//! it batches correctly. PyIQA's branch only works on 0-D scalar
//! tensors.

use burn_core as burn;

use burn::config::Config;
use burn::module::{Module, Param};
use burn::tensor::Tensor;
use burn::tensor::activation::{sigmoid, softplus};
use burn::tensor::backend::Backend;

// Naturalness (NR) calibrator: fixed logistic asymptotes (yita1/yita2)
// and initial values for the two learnable parameters.
const NR_YITA1: f64 = 2.0;
const NR_YITA2: f64 = -2.0;
const NR_YITA3_INIT: f32 = 4.9592;
const NR_YITA4_INIT: f32 = 21.5968;

// Fidelity (FR) calibrator: fixed asymptotes, learnable-parameter inits,
// and the per-forward clamp bounds applied by `FrCalibratorWithLimit`.
const FR_YITA1: f64 = 2.0;
const FR_YITA2: f64 = -2.0;
const FR_YITA3_INIT: f32 = 0.5;
const FR_YITA4_INIT: f32 = 0.15;
const FR_YITA3_MIN: f32 = 0.05;
const FR_YITA3_MAX: f32 = 0.95;
const FR_YITA4_MIN: f32 = 0.01;
const FR_YITA4_MAX: f32 = 0.70;

// Initial value of the adapter's single learnable scalar `k`.
const ADAPTER_K_INIT: f32 = 5.0;

// Final-score logistic into (0, 100); all four constants are fixed
// (non-learnable) paper-reported defaults used by `scale_finalscore`.
const SCALE_YITA1: f64 = 100.0;
const SCALE_YITA2: f64 = 0.0;
const SCALE_YITA3: f64 = -1.971_0;
const SCALE_YITA4: f64 = -2.373_4;

/// Numerical-stability epsilon in the logistic denominator. Matches
/// PyIQA exactly; do not change without coordinating a parity-test
/// re-capture.
const EPS: f64 = 1e-10;

/// Shared logistic shape used by every calibrator:
/// `(yita1 - yita2) * sigmoid((x - yita3) / (|yita4| + eps)) + yita2`,
/// applied element-wise. The 1-element `yita3`/`yita4_abs` tensors are
/// broadcast over the `[B, 1]` input.
fn logistic_calibrate<B: Backend>(
    x: Tensor<B, 2>,
    yita3: Tensor<B, 1>,
    yita4_abs: Tensor<B, 1>,
    yita1: f64,
    yita2: f64,
) -> Tensor<B, 2> {
    // Reshape the scalar parameters to [1, 1] so they broadcast over x.
    let center = yita3.reshape([1, 1]);
    let spread = yita4_abs.reshape([1, 1]).add_scalar(EPS);
    let gate = sigmoid(x.sub(center).div(spread));
    gate.mul_scalar(yita1 - yita2).add_scalar(yita2)
}

/// Configuration for [`NrCalibrator`].
#[derive(Config, Debug)]
pub(crate) struct NrCalibratorConfig {}

impl NrCalibratorConfig {
pub(crate) fn init<B: Backend>(&self, device: &B::Device) -> NrCalibrator<B> {
NrCalibrator {
yita3: Param::from_tensor(Tensor::from_floats([NR_YITA3_INIT], device)),
yita4: Param::from_tensor(Tensor::from_floats([NR_YITA4_INIT], device)),
}
}
}

/// Naturalness logistic calibrator. Maps `[B, 1]` into `(-2, 2)`.
#[derive(Module, Debug)]
pub(crate) struct NrCalibrator<B: Backend> {
    /// Learnable logistic midpoint.
    pub(crate) yita3: Param<Tensor<B, 1>>,
    /// Learnable logistic spread; `abs` is applied at forward time.
    pub(crate) yita4: Param<Tensor<B, 1>>,
}

impl<B: Backend> NrCalibrator<B> {
    /// Calibrate raw naturalness scores `[B, 1]` into the bounded
    /// range `(-2, 2)`.
    pub(crate) fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let midpoint = self.yita3.val();
        let spread = self.yita4.val().abs();
        logistic_calibrate(x, midpoint, spread, NR_YITA1, NR_YITA2)
    }
}

/// Configuration for [`FrCalibratorWithLimit`].
#[derive(Config, Debug)]
pub(crate) struct FrCalibratorWithLimitConfig {}

impl FrCalibratorWithLimitConfig {
pub(crate) fn init<B: Backend>(&self, device: &B::Device) -> FrCalibratorWithLimit<B> {
FrCalibratorWithLimit {
yita3: Param::from_tensor(Tensor::from_floats([FR_YITA3_INIT], device)),
yita4: Param::from_tensor(Tensor::from_floats([FR_YITA4_INIT], device)),
}
}
}

/// Fidelity logistic calibrator with on-forward clamping of `yita3` and
/// `yita4`. PyIQA clamps the values used in the formula on every call;
/// the stored parameter is unchanged.
#[derive(Module, Debug)]
pub(crate) struct FrCalibratorWithLimit<B: Backend> {
    /// Learnable logistic midpoint, clamped to `[0.05, 0.95]` at forward.
    pub(crate) yita3: Param<Tensor<B, 1>>,
    /// Learnable logistic spread, clamped to `[0.01, 0.70]` at forward.
    pub(crate) yita4: Param<Tensor<B, 1>>,
}

impl<B: Backend> FrCalibratorWithLimit<B> {
    /// Calibrate raw fidelity scores `[B, 1]` into `(-2, 2)`.
    pub(crate) fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        // Match PyIQA semantics exactly: clamp first, then abs. The
        // clamp range is positive so the abs is a no-op for in-range
        // values, but for an out-of-range checkpoint or a parameter
        // that drifts negative during training the order matters.
        let midpoint = self.yita3.val().clamp(FR_YITA3_MIN, FR_YITA3_MAX);
        let spread = self
            .yita4
            .val()
            .clamp(FR_YITA4_MIN, FR_YITA4_MAX)
            .abs();
        logistic_calibrate(x, midpoint, spread, FR_YITA1, FR_YITA2)
    }
}

/// Configuration for [`AfineAdapter`].
#[derive(Config, Debug)]
pub(crate) struct AfineAdapterConfig {}

impl AfineAdapterConfig {
    /// Build an [`AfineAdapter`] with `k` at its paper-default init.
    pub(crate) fn init<B: Backend>(&self, device: &B::Device) -> AfineAdapter<B> {
        let k = Param::from_tensor(Tensor::from_floats([ADAPTER_K_INIT], device));
        AfineAdapter { k }
    }
}

/// Fuses the calibrated naturalness and fidelity scores into a single
/// raw `D` value.
///
/// `D = exp(softplus(k) * (N_ref - N_dis)) * N_dis + F`. The `softplus`
/// wrapper enforces `k > 0` without constraining the stored parameter.
#[derive(Module, Debug)]
pub(crate) struct AfineAdapter<B: Backend> {
    /// Learnable fusion gain; passed through `softplus` at forward.
    pub(crate) k: Param<Tensor<B, 1>>,
}

impl<B: Backend> AfineAdapter<B> {
    /// Fuse distorted naturalness `x_nr`, reference naturalness
    /// `ref_nr`, and fidelity `xref_fr` (all `[B, 1]`) into raw `D`.
    pub(crate) fn forward(
        &self,
        x_nr: Tensor<B, 2>,
        ref_nr: Tensor<B, 2>,
        xref_fr: Tensor<B, 2>,
    ) -> Tensor<B, 2> {
        // softplus keeps the effective gain strictly positive.
        let gain = softplus(self.k.val(), 1.0).reshape([1, 1]);
        let naturalness_gap = ref_nr.sub(x_nr.clone());
        let modulation = naturalness_gap.mul(gain).exp();
        modulation.mul(x_nr).add(xref_fr)
    }
}

/// Map a raw adapter score into `(0, 100)` via a fixed 4-parameter
/// logistic. Constants are the paper-reported defaults.
pub(crate) fn scale_finalscore<B: Backend>(score: Tensor<B, 2>) -> Tensor<B, 2> {
    let range = SCALE_YITA1 - SCALE_YITA2;
    let inner = score
        .sub_scalar(SCALE_YITA3)
        .div_scalar(SCALE_YITA4.abs() + EPS);
    sigmoid(inner).mul_scalar(range).add_scalar(SCALE_YITA2)
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::Flex;

    type TestBackend = Flex;

    #[test]
    fn nr_calibrator_maps_to_bounded_range() {
        let device = Default::default();
        let calibrator = NrCalibratorConfig::new().init::<TestBackend>(&device);

        let extremes = Tensor::<TestBackend, 2>::from_floats([[-1000.0], [0.0], [1000.0]], &device);
        let out = calibrator.forward(extremes);
        let values = out.into_data().to_vec::<f32>().unwrap();

        for v in &values {
            assert!(*v >= -2.0 && *v <= 2.0, "out-of-range value: {v}");
        }
        // A range check alone is satisfied by a constant output; also
        // require that extreme inputs saturate near the asymptotes.
        assert!(
            (values[0] + 2.0).abs() < 1e-3,
            "lower asymptote should be ~-2, got {}",
            values[0]
        );
        assert!(
            (values[2] - 2.0).abs() < 1e-3,
            "upper asymptote should be ~2, got {}",
            values[2]
        );
        // Monotonic increasing.
        assert!(values[0] < values[1]);
        assert!(values[1] < values[2]);
    }

    #[test]
    fn fr_calibrator_clamp_does_not_panic() {
        let device = Default::default();
        let calibrator = FrCalibratorWithLimitConfig::new().init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::from_floats([[0.5], [1.5], [-0.5]], &device);
        let out = calibrator.forward(input);

        assert_eq!(out.dims(), [3, 1]);
        // Clamped parameters must still produce values in the logistic's
        // open range (-2, 2).
        for v in out.into_data().to_vec::<f32>().unwrap() {
            assert!((-2.0..=2.0).contains(&v), "out-of-range value: {v}");
        }
    }

    #[test]
    fn adapter_forward_propagates_shape() {
        let device = Default::default();
        let adapter = AfineAdapterConfig::new().init::<TestBackend>(&device);

        let nr_dis = Tensor::<TestBackend, 2>::from_floats([[0.5], [-0.3]], &device);
        let nr_ref = Tensor::<TestBackend, 2>::from_floats([[0.7], [-0.1]], &device);
        let fr = Tensor::<TestBackend, 2>::from_floats([[0.2], [0.4]], &device);

        let out = adapter.forward(nr_dis, nr_ref, fr);
        assert_eq!(out.dims(), [2, 1]);

        // Pin the actual fusion values, not just the shape:
        // D = exp(softplus(5) * (ref - dis)) * dis + fr with
        // softplus(5) ~= 5.00672, so exp(5.00672 * 0.2) ~= 2.72183.
        // Row 0: 2.72183 * 0.5 + 0.2 ~= 1.5609
        // Row 1: 2.72183 * (-0.3) + 0.4 ~= -0.4165
        let values = out.into_data().to_vec::<f32>().unwrap();
        assert!((values[0] - 1.5609).abs() < 1e-2, "row 0: {}", values[0]);
        assert!((values[1] + 0.4165).abs() < 1e-2, "row 1: {}", values[1]);
    }

    #[test]
    fn scale_finalscore_maps_to_0_100_range() {
        let device = Default::default();

        let scores =
            Tensor::<TestBackend, 2>::from_floats([[-1000.0], [-1.971], [1000.0]], &device);
        let out = scale_finalscore(scores);
        let values = out.into_data().to_vec::<f32>().unwrap();

        // Extremes should saturate at the ends of (0, 100), not merely
        // fall somewhere inside it.
        assert!(
            values[0] >= 0.0 && values[0] < 0.5,
            "lower asymptote should be ~0, got {}",
            values[0]
        );
        assert!(
            values[2] <= 100.0 && values[2] > 99.5,
            "upper asymptote should be ~100, got {}",
            values[2]
        );
        // At yita3 = -1.971 the sigmoid argument is 0, so the output is
        // 100 * 0.5 = 50.
        assert!(
            (values[1] - 50.0).abs() < 0.5,
            "midpoint should be ~50, got {}",
            values[1]
        );
    }
}
137 changes: 137 additions & 0 deletions crates/burn-train/src/metric/vision/afine/clip_attention.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//! Self-attention block matching PyTorch's `nn.MultiheadAttention` wire
//! format.
//!
//! PyTorch stores Q/K/V as a single fused `in_proj_weight` of shape
//! `[3 * d_model, d_model]` and `in_proj_bias` of shape `[3 * d_model]`.
//! Burn's [`burn_nn::attention::MultiHeadAttention`] uses three separate
//! Linear layers, so the CLIP checkpoint cannot map to it without
//! pre-splitting the weights at load time.
//!
//! This module keeps the fused layout: a single `qkv_proj` Linear of
//! shape `(d_model -> 3 * d_model)` and a `chunk(3, -1)` at forward,
//! giving a one-to-one mapping with the checkpoint.
//!
//! No attention mask is supported. CLIP's image encoder runs
//! unconditional self-attention; the text encoder (which uses a causal
//! mask) is not ported.

use burn_core as burn;

use burn::config::Config;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::activation::softmax;
use burn::tensor::backend::Backend;
use burn_nn::{Linear, LinearConfig};

/// Configuration for [`ClipQkvAttention`].
#[derive(Config, Debug)]
pub(crate) struct ClipQkvAttentionConfig {
    /// Embedding dimension. Must be divisible by `n_heads`.
    pub d_model: usize,
    /// Number of attention heads.
    pub n_heads: usize,
}

impl ClipQkvAttentionConfig {
    /// Initialize a [`ClipQkvAttention`] block.
    ///
    /// Panics when `d_model` is not divisible by `n_heads`.
    pub(crate) fn init<B: Backend>(&self, device: &B::Device) -> ClipQkvAttention<B> {
        assert_eq!(
            self.d_model % self.n_heads,
            0,
            "d_model ({}) must be divisible by n_heads ({})",
            self.d_model,
            self.n_heads
        );
        // Both projections carry a bias, matching the CLIP checkpoint.
        let biased_linear = |fan_in: usize, fan_out: usize| {
            LinearConfig::new(fan_in, fan_out).with_bias(true).init(device)
        };
        ClipQkvAttention {
            qkv_proj: biased_linear(self.d_model, 3 * self.d_model),
            out_proj: biased_linear(self.d_model, self.d_model),
            d_model: self.d_model,
            n_heads: self.n_heads,
            head_dim: self.d_model / self.n_heads,
        }
    }
}

/// Self-attention with a fused QKV projection, matching CLIP's checkpoint
/// layout one-to-one.
///
/// See the module docs: PyTorch's `nn.MultiheadAttention` stores Q/K/V as
/// one fused `in_proj_weight`, which maps directly onto `qkv_proj` here.
#[derive(Module, Debug)]
pub(crate) struct ClipQkvAttention<B: Backend> {
    /// Fused projection `d_model -> 3 * d_model` for Q, K, V.
    pub(crate) qkv_proj: Linear<B>,
    /// Output projection `d_model -> d_model`.
    pub(crate) out_proj: Linear<B>,
    /// Embedding dimension.
    pub(crate) d_model: usize,
    /// Number of attention heads.
    pub(crate) n_heads: usize,
    /// Per-head dimension (`d_model / n_heads`).
    pub(crate) head_dim: usize,
}

impl<B: Backend> ClipQkvAttention<B> {
    /// Apply self-attention. Input and output shape: `[batch, seq, d_model]`.
    pub(crate) fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, seq, _] = x.dims();

        // Fused projection, then split the last dim into Q, K, V and
        // fold each into per-head layout [batch, heads, seq, head_dim].
        let per_head = |t: Tensor<B, 3>| {
            t.reshape([batch, seq, self.n_heads, self.head_dim])
                .swap_dims(1, 2)
        };
        let mut qkv = self.qkv_proj.forward(x).chunk(3, 2).into_iter();
        let query = per_head(qkv.next().unwrap());
        let key = per_head(qkv.next().unwrap());
        let value = per_head(qkv.next().unwrap());

        // Scaled dot-product attention, softmax over the key axis.
        let logits = query
            .matmul(key.transpose())
            .div_scalar((self.head_dim as f32).sqrt());
        let attn = softmax(logits, 3);

        // Merge heads back into [batch, seq, d_model] and project out.
        let merged = attn
            .matmul(value)
            .swap_dims(1, 2)
            .reshape([batch, seq, self.d_model]);
        self.out_proj.forward(merged)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Distribution;
    use burn_flex::Flex;

    type TestBackend = Flex;

    // Single-sample, ViT-B/32-sized configuration (d_model 768, 12 heads,
    // 50 tokens = 49 patches + class token): output shape must equal
    // input shape.
    #[test]
    fn clip_qkv_attention_preserves_shape() {
        let device = Default::default();
        let attn = ClipQkvAttentionConfig::new(768, 12).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([1, 50, 768], Distribution::Default, &device);
        let output = attn.forward(input);

        assert_eq!(output.dims(), [1, 50, 768]);
    }

    // Smaller config with batch > 1, checking the batched reshape /
    // swap_dims path keeps the leading batch dimension intact.
    #[test]
    fn clip_qkv_attention_handles_batch() {
        let device = Default::default();
        let attn = ClipQkvAttentionConfig::new(64, 4).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([3, 16, 64], Distribution::Default, &device);
        let output = attn.forward(input);

        assert_eq!(output.dims(), [3, 16, 64]);
    }
}
Loading
Loading