Add A-FINE image quality metric #4894
Implements the Adaptive Fidelity-Naturalness Evaluator (A-FINE) full-reference image quality metric from https://arxiv.org/abs/2503.11221, ported from the PyIQA reference at https://github.com/chaofengc/IQA-PyTorch/blob/main/pyiqa/archs/afine_arch.py. A-FINE feeds both the distorted and reference images through a fine-tuned CLIP ViT-B/32 visual encoder, runs the per-block patch features through two small heads (naturalness and fidelity), then fuses them through learnable logistic calibrators and a softplus-gated adapter into a single quality score in (0, 100).

The implementation lives under `crates/burn-train/src/metric/vision/afine/` and inlines the CLIP backbone rather than building it on top of a shared one. Components:

- `quick_gelu.rs` — `x * sigmoid(1.702 * x)`. CLIP was trained with this, not the erf-based GELU; substituting it silently shifts every MLP activation. The coefficient is pinned at 1.702 as a labeled constant.
- `clip_attention.rs` — self-attention with a fused `qkv_proj: Linear(d_model, 3 * d_model)` and a `chunk(3, -1)` at forward. Matches PyTorch's `nn.MultiheadAttention.in_proj_weight` layout one-to-one, so the checkpoint maps without pre-splitting.
- `clip_vit.rs` — `ClipVisualEncoder`: patch embed, learnable class token, positional embedding, ln_pre, twelve pre-norm transformer blocks, ln_post. Stays in NLD layout end-to-end (PyIQA permutes to LND inside the stack); attention is invariant to that permutation, so the math is identical. `forward_with_features` returns the per-block patch feature maps the heads consume.
- `heads.rs` — `AfineQHead` (naturalness): per-level mean+variance descriptor over thirteen feature levels (raw RGB plus twelve ReLU'd CLIP layers), a shared `proj_feat: Linear(1536, 128)` on the CLIP levels, then a small MLP to a scalar with the erf-based `burn_nn::Gelu` (distinct from the CLIP trunk's QuickGELU). `AfineDHead` (fidelity): SSIM-like luminance and contrast terms weighted by globally-softplus-normalized `alpha`/`beta` parameters and reduced. Epsilon is `1e-10`, not the larger `1e-6` DISTS uses.
- `calibrators.rs` — two logistic calibrators (`NrCalibrator`, `FrCalibratorWithLimit`), the `AfineAdapter` with learnable `k`, and the fixed `scale_finalscore` mapping into `(0, 100)`. PyIQA's reference uses an `if exp_pow >= 10` branch for numerical stability that only works on 0-D scalar tensors; it is rewritten here as a single `(yita1 - yita2) * sigmoid((x - yita3) / (|yita4| + eps)) + yita2` expression so it batches correctly (sketched after this list, together with QuickGELU).
- `weights.rs` — downloads `afine.pth` (~600 MB) from `chaofengc/IQA-PyTorch-Weights` on Hugging Face, caches it under `~/.cache/burn-dataset/afine/`, and loads all six shards (`finetuned_clip`, `natural`, `fidelity`, `natural_scale`, `fidelity_scale`, `adapter`) into the matching submodules. The five 0-D scalar checkpoint values (`yita3`, `yita4`, `k`) are read manually via `PytorchReader` because they're stored with `shape=()` while `Param<Tensor<B, 1>>` expects shape `(1,)`, so `PytorchStore`'s shape check silently drops them. A pre-rename of the text encoder's `positional_embedding` key avoids a HashMap iteration-order collision with the visual encoder's key after remap.
- `metric.rs` — public API surface. `AfineConfig` with `image_size` (default 256, must be divisible by 32) and `normalize_input` (default true). `Afine<B>::forward(distorted, reference)` returns a `Tensor<B, 1>` of per-sample scores. `ModuleDisplay` shows the backbone name, image size, and normalization toggle.
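To make the two formulas above concrete, here is a minimal sketch as free functions on burn tensors. This is not the PR's module code: the function names, the `EPS` value, and the plain-`f64` scalar plumbing are assumptions; only the two expressions come from the description.

```rust
use burn::tensor::{activation::sigmoid, backend::Backend, Tensor};

/// CLIP's QuickGELU coefficient; must stay 1.702 to match the trained trunk.
const QUICK_GELU_COEFF: f64 = 1.702;

/// QuickGELU: x * sigmoid(1.702 * x).
pub fn quick_gelu<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    x.clone() * sigmoid(x.mul_scalar(QUICK_GELU_COEFF))
}

/// The batch-friendly logistic calibrator:
/// (yita1 - yita2) * sigmoid((x - yita3) / (|yita4| + eps)) + yita2.
/// In the metric the yita values are learnable params; plain scalars here.
pub fn logistic_calibrate<B: Backend>(
    x: Tensor<B, 1>,
    yita1: f64,
    yita2: f64,
    yita3: f64,
    yita4: f64,
) -> Tensor<B, 1> {
    const EPS: f64 = 1e-12; // assumed; the PR text doesn't pin this value
    let gate = sigmoid(x.sub_scalar(yita3).div_scalar(yita4.abs() + EPS));
    gate.mul_scalar(yita1 - yita2).add_scalar(yita2)
}
```

Unlike the reference's `if exp_pow >= 10` branch, this expression involves no data-dependent control flow, so it evaluates elementwise over a whole batch.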
Tests (always-on): forward shape, batch processing, finite-on-constant-inputs, asymmetry, image_size validation, and display format. Two `#[ignore = "downloads pre-trained weights"]` tests: `test_afine_pretrained` (loads real weights, asserts finite output that meaningfully differs from the random-init metric, guards against silent partial loads) and `test_afine_pretrained_parity` (matches a captured PyIQA scalar to within ~5e-5 absolute on a deterministic linspace input pair, catching QuickGELU coefficient errors, fused-QKV transpose-direction bugs, and channel-order mistakes that the property tests miss).

Adds a row to `burn-book/src/building-blocks/metric.md`'s vision metric table.

Issue: tracel-ai#4312

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
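As a usage picture for the `metric.rs` surface described above: the PR pins the names `AfineConfig`, `image_size`, `normalize_input`, and `forward`, but the import path and the `new()`/`init()` plumbing below merely follow the usual burn config convention and are assumptions.

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;
// Re-export path assumed from the PR's directory layout.
use burn_train::metric::vision::afine::AfineConfig;

fn score_batch(
    distorted: Tensor<NdArray, 4>, // [batch, 3, H, W]
    reference: Tensor<NdArray, 4>,
) -> Tensor<NdArray, 1> {
    let device = Default::default();
    // image_size defaults to 256 (must be divisible by 32);
    // normalize_input defaults to true.
    let metric = AfineConfig::new().init::<NdArray>(&device);
    // Per-sample scores in (0, 100); note the metric is asymmetric
    // in its arguments, so argument order matters.
    metric.forward(distorted, reference)
}
```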
Codecov Report

❌ Your patch check has failed because the patch coverage (76.17%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##             main    #4894      +/-   ##
==========================================
- Coverage   65.49%   65.34%   -0.15%
==========================================
  Files        1165     1177      +12
  Lines      172277   175662    +3385
==========================================
+ Hits       112830   114790    +1960
- Misses      59447    60872    +1425
```

☔ View full report in Codecov by Sentry.
Just letting you know I am aware of this PR 🙏 but haven't gotten around to it yet; bigger PRs require a bit more context switching 😅 It's still on my stack, will review sometime this week! Thanks for the contribution
laggui left a comment:
Pretty clean implementation! Just one minor comment but otherwise LGTM
```rust
// ln_post is applied only to the class-token output.
let cls = x.slice([0..batch, 0..1, 0..embed]).reshape([batch, embed]);
let cls = self.ln_post.forward(cls);

(cls, features)
}
```
For the metric implementation the cls output is discarded, so we could skip this computation.
We can preserve it for completeness (and you already have a test that covers this), so `forward_with_features` could return an output struct with an optional cls output, e.g.

```rust
pub struct ClipOutput<B: Backend> {
    pub features: Vec<Tensor<B, 3>>,
    pub cls: Option<Tensor<B, 2>>,
}
```

The metric path discarded the class-token output of `forward_with_features`, paying for a slice and a LayerNorm on every forward. Return a `ClipOutput { features, cls: Option<...> }` struct now, with a `return_cls` flag on `forward_with_features` that lets the caller request cls only when they need it. The metric passes `false`; the standalone `forward()` and the test pass `true`.

Per laggui's review on tracel-ai#4894.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
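A sketch of how the tail of `forward_with_features` could implement this, assuming the names visible in the diff (`x` in NLD layout, `features` already collected, `ln_post` a `LayerNorm`); written as a free function so it stands alone:

```rust
use burn::nn::LayerNorm;
use burn::tensor::{backend::Backend, Tensor};

pub struct ClipOutput<B: Backend> {
    pub features: Vec<Tensor<B, 3>>,
    pub cls: Option<Tensor<B, 2>>,
}

fn finish_forward<B: Backend>(
    x: Tensor<B, 3>, // final block output, [batch, 1 + patches, embed]
    features: Vec<Tensor<B, 3>>,
    ln_post: &LayerNorm<B>,
    return_cls: bool,
) -> ClipOutput<B> {
    let [batch, _len, embed] = x.dims();
    // ln_post is applied only to the class-token slice, and only when
    // the caller actually wants cls (the metric path passes false).
    let cls = return_cls.then(|| {
        let cls = x.slice([0..batch, 0..1, 0..embed]).reshape([batch, embed]);
        ln_post.forward(cls)
    });
    ClipOutput { features, cls }
}
```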
hey, I made the changes you requested, took some time so sorry for that; does this look right now?
laggui left a comment:
LGTM!
Your response time was totally fine btw 😄
Exciting! Looking forward to further contributions :)
Pull Request Template

Checklist

- `cargo run-checks` command has been executed.

Related Issues/PRs
Ticks one box on the meta-issue #4312. Original scoping comment + your answers are in that thread.
Changes
A-FINE is a full-reference image quality metric that runs both the distorted and reference images through CLIP ViT-B/32, then through a naturalness head, a fidelity head, two logistic calibrators, and an adapter that fuses them into one (0, 100) score. Paper: https://arxiv.org/abs/2503.11221. PyIQA reference: https://github.com/chaofengc/IQA-PyTorch/blob/main/pyiqa/archs/afine_arch.py.
Lives at `crates/burn-train/src/metric/vision/afine/`. CLIP ViT is inlined under the metric, matching what LPIPS, DISTS, and FID did. ~1830 LOC across 8 Rust files plus the burn-book row.

A few things worth flagging:

- The CLIP trunk uses QuickGELU (`x * sigmoid(1.702 * x)`), not erf-based GELU. Substituting silently shifts every MLP activation, so QuickGELU lives as its own private module here, with the coefficient pinned at 1.702 as a labeled constant.
- Attention uses a fused QKV projection (`qkv_proj: Linear(d_model, 3 * d_model)` + `chunk(3, -1)` at forward) so the CLIP checkpoint maps to it one-to-one. burn-nn's MultiHeadAttention uses separate Q/K/V Linears, which would force a pre-split at load; a sketch follows this list.
- The naturalness head's MLP uses the erf-based `burn_nn::Gelu`, NOT QuickGELU. PyIQA's reference is explicit on this distinction.
- `c1 = c2 = 1e-10` in the fidelity head's SSIM-like ratios. DISTS uses `1e-6`, but A-FINE's CLIP feature magnitudes are unit-scale-ish, so the smaller eps is what the trained checkpoint expects.
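A sketch of the fused projection for illustration (dimensions and names are placeholders; burn's `chunk` takes a concrete axis index, so `chunk(3, 2)` on a rank-3 `[batch, seq, 3 * d_model]` tensor plays the role of PyTorch's `chunk(3, -1)`):

```rust
use burn::nn::Linear;
use burn::tensor::{backend::Backend, Tensor};

/// Split one fused qkv projection into q/k/v, mirroring PyTorch's
/// nn.MultiheadAttention in_proj_weight layout so the checkpoint
/// loads without pre-splitting.
fn project_qkv<B: Backend>(
    qkv_proj: &Linear<B>, // Linear(d_model, 3 * d_model)
    x: Tensor<B, 3>,      // [batch, seq, d_model]
) -> (Tensor<B, 3>, Tensor<B, 3>, Tensor<B, 3>) {
    let qkv = qkv_proj.forward(x); // [batch, seq, 3 * d_model]
    let mut parts = qkv.chunk(3, 2).into_iter();
    let (q, k, v) = (
        parts.next().unwrap(),
        parts.next().unwrap(),
        parts.next().unwrap(),
    );
    (q, k, v)
}
```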
Two non-obvious things in the pretrained loader the parity test caught:

- The five 0-D scalar params (NR/FR `yita3`, `yita4`, adapter `k`) are stored with `shape=()` in the checkpoint, but burn's `Param<Tensor<B, 1>>` expects `(1,)`. PytorchStore silently drops these, leaving them at random init. Loaded manually via `PytorchReader::with_top_level_key(...).get(name)`; see the sketch after this list.
- The CLIP checkpoint has both `positional_embedding` (text encoder, shape `[77, 512]`) and `visual.positional_embedding` (shape `[50, 768]`). After my remap renames the visual one to `positional_embedding`, both keys collide and HashMap iteration order picks which wins. A pre-rename of the text encoder's key first dodges the collision.
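The widening itself is a one-liner once the raw scalar is out of the checkpoint; a sketch, with the read via `PytorchReader` elided because its exact return types aren't shown here:

```rust
use burn::module::Param;
use burn::tensor::{backend::Backend, Tensor};

/// Wrap a 0-D checkpoint scalar (stored as shape `()`) as the
/// shape-`(1,)` tensor that `Param<Tensor<B, 1>>` expects. `value`
/// would come from `PytorchReader::with_top_level_key(...)` as
/// described above.
fn scalar_to_param<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
    Param::from_tensor(Tensor::from_floats([value], device))
}
```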
Question on weight hosting

The URL points at the PyIQA author's personal HF mirror (`chaofengc/IQA-PyTorch-Weights/resolve/main/afine.pth`). I tried pretty hard to find an org-hosted alternative and couldn't.

I noticed LPIPS, DISTS, and FID currently pull weights from personal-account GitHub raw URLs, so personal hosting isn't unprecedented in burn. But I remember you mentioned org hosting was preferred when we scoped this on the issue, so figured I'd flag rather than assume. Happy to go with whatever you prefer: keep this URL, have tracel-ai mirror `afine.pth`, or another option.

Testing
Always-on property tests: forward shape, batch processing, finite output on constant inputs, asymmetry from the adapter's exponent term, `image_size` validation, display format.

Two `#[ignore = "downloads pre-trained weights"]` tests, matching the LPIPS/DISTS/FID convention:

- `test_afine_pretrained` loads real weights and checks the score is finite and meaningfully differs from a random-init metric on the same input. Catches silent partial loads where a regex remap is wrong.
- `test_afine_pretrained_parity` matches a captured PyIQA reference scalar on a deterministic `arange/(N-1)` vs `arange/N` input pair to within ~5e-5 absolute. Catches QuickGELU-vs-GELU swaps, fused-QKV transpose-direction bugs, channel-order mistakes, and `c1=c2=1e-10` epsilon drift that the property tests miss. (Input construction sketched below.)

`cargo run-checks` green: fmt, clippy, audit, typos, doc, full burn-train with `--features vision` (123 passing, 7 ignored).
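For completeness, the deterministic input pair can be reproduced along these lines (batch, channel count, and image size are assumptions; the test's exact shapes aren't stated here):

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

/// `arange/(N-1)` vs `arange/N`: two near-identical ramps in [0, 1],
/// reshaped to an NCHW batch, so the parity scalar is deterministic.
fn parity_inputs<B: Backend>(device: &B::Device) -> (Tensor<B, 4>, Tensor<B, 4>) {
    let n: usize = 3 * 256 * 256; // one 3x256x256 image (assumed shape)
    let ramp = Tensor::<B, 1, Int>::arange(0..n as i64, device).float();
    let distorted = ramp
        .clone()
        .div_scalar((n - 1) as f32)
        .reshape([1, 3, 256, 256]);
    let reference = ramp.div_scalar(n as f32).reshape([1, 3, 256, 256]);
    (distorted, reference)
}
```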