
Add A-FINE image quality metric #4894

Merged
laggui merged 2 commits into tracel-ai:main from Capataina:afine-metric
May 11, 2026

Conversation

@Capataina
Contributor

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Ticks one box on the meta-issue #4312. The original scoping comment and your answers are in that thread.

Changes

A-FINE is a full-reference image quality metric that runs both the distorted and reference images through CLIP ViT-B/32, then through a naturalness head, a fidelity head, two logistic calibrators, and an adapter that fuses them into one (0, 100) score. Paper: https://arxiv.org/abs/2503.11221. PyIQA reference: https://github.com/chaofengc/IQA-PyTorch/blob/main/pyiqa/archs/afine_arch.py.

Lives at crates/burn-train/src/metric/vision/afine/. CLIP ViT is inlined under the metric, matching what LPIPS, DISTS, and FID did. ~1830 LOC across 8 Rust files plus the burn-book row.

A few things worth flagging:

  • The CLIP backbone uses QuickGELU (x * sigmoid(1.702 * x)), not erf-based GELU. Substituting silently shifts every MLP activation, so QuickGELU lives as its own private module here, with the coefficient pinned at 1.702 as a labeled constant. Both this and the next item's fused-QKV split are sketched after this list.
  • The attention block is a custom fused-QKV one (qkv_proj: Linear(d_model, 3 * d_model) + chunk(3, -1) at forward) so the CLIP checkpoint maps to it one-to-one. burn-nn's MultiHeadAttention uses separate Q/K/V Linears, which would force a pre-split at load.
  • The naturalness head's MLP uses the erf-based burn_nn::Gelu, NOT QuickGELU. PyIQA's reference is explicit on this distinction.
  • c1 = c2 = 1e-10 in the fidelity head's SSIM-like ratios. DISTS uses 1e-6 but A-FINE's CLIP feature magnitudes are unit-scale-ish, so the smaller eps is what the trained checkpoint expects.
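
A minimal sketch of the first two items in burn-style Rust (illustrative, not the PR's exact code; the derive-Module unit struct follows burn's usual pattern for stateless activations):

use burn::module::Module;
use burn::nn::Linear;
use burn::tensor::{activation::sigmoid, backend::Backend, Tensor};

/// CLIP's QuickGELU: x * sigmoid(1.702 * x). Pinned as a labeled
/// constant so the coefficient cannot drift from the checkpoint.
const QUICK_GELU_COEFF: f64 = 1.702;

#[derive(Module, Clone, Debug, Default)]
pub struct QuickGelu;

impl QuickGelu {
    pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        x.clone() * sigmoid(x * QUICK_GELU_COEFF)
    }
}

// Fused QKV: one Linear(d_model, 3 * d_model) whose weight matches
// nn.MultiheadAttention's in_proj_weight layout, split at forward.
fn split_qkv<B: Backend>(
    qkv_proj: &Linear<B>,
    x: Tensor<B, 3>, // [batch, tokens, d_model]
) -> (Tensor<B, 3>, Tensor<B, 3>, Tensor<B, 3>) {
    let mut parts = qkv_proj.forward(x).chunk(3, 2).into_iter();
    (
        parts.next().unwrap(), // q
        parts.next().unwrap(), // k
        parts.next().unwrap(), // v
    )
}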

Two non-obvious things in the pretrained loader that the parity test caught:

The five 0-D scalar params (NR/FR yita3, yita4, adapter k) are stored with shape=() in the checkpoint, but burn's Param<Tensor<B, 1>> expects shape (1,), so PytorchStore's shape check silently drops them, leaving them at random init. They're loaded manually via PytorchReader::with_top_level_key(...).get(name).
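
The fix itself is small once the raw value is in hand; a hedged sketch (the f32 stands for whatever the PytorchReader call yields, and scalar_to_rank1 is a hypothetical helper, not the PR's code):

use burn::tensor::{backend::Backend, Tensor};

// The checkpoint stores shape=() but Param<Tensor<B, 1>> wants (1,),
// so lift the 0-D value into a rank-1 tensor before assigning it.
fn scalar_to_rank1<B: Backend>(raw: f32, device: &B::Device) -> Tensor<B, 1> {
    Tensor::<B, 1>::from_floats([raw], device)
}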

The CLIP checkpoint has both positional_embedding (text encoder, shape [77, 512]) and visual.positional_embedding (shape [50, 768]). After my remap renames the visual one to positional_embedding, both keys collide and HashMap iteration order picks which one wins. Pre-renaming the text encoder's key first dodges the collision.
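
A sketch of the ordering fix (the key names are from the checkpoint; the text. prefix and the map payload type are illustrative):

use std::collections::HashMap;

// Park the text encoder's key under a name the visual remap will not
// touch, then let the visual key claim the bare name. Order matters:
// done the other way around, both entries collide on one key.
fn remap_positional<T>(tensors: &mut HashMap<String, T>) {
    if let Some(text_pe) = tensors.remove("positional_embedding") {
        tensors.insert("text.positional_embedding".into(), text_pe);
    }
    if let Some(visual_pe) = tensors.remove("visual.positional_embedding") {
        tensors.insert("positional_embedding".into(), visual_pe);
    }
}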

Question on weight hosting

The URL points at the PyIQA author's personal HF mirror (chaofengc/IQA-PyTorch-Weights/resolve/main/afine.pth). I tried pretty hard to find an org-hosted alternative and couldn't:

  • Not on any HF org I checked (PyIQA-related, q-future, the authors' institutions)
  • Not in IQA-PyTorch's GitHub releases (50+ other metric weights are there, just not afine.pth)
  • Not on Zenodo, figshare, ModelScope, OSF, or Kaggle
  • Authors' canonical release is Google Drive only, which reqwest can't follow cleanly

I noticed LPIPS, DISTS, and FID currently pull weights from personal-account GitHub raw URLs, so personal hosting isn't unprecedented in burn. But I remember you mentioned org hosting was preferred when we scoped this on the issue, so figured I'd flag rather than assume. Happy to go with whatever you prefer: keep this URL, have tracel-ai mirror afine.pth, or another option.

Testing

Always-on property tests: forward shape, batch processing, finite output on constant inputs, asymmetry from the adapter's exponent term, image_size validation, display format.

Two #[ignore = "downloads pre-trained weights"] tests, matching the LPIPS/DISTS/FID convention:

  • test_afine_pretrained loads real weights and checks the score is finite and meaningfully differs from a random-init metric on the same input. Catches silent partial loads where a regex remap is wrong.
  • test_afine_pretrained_parity matches a captured PyIQA reference scalar on a deterministic arange/(N-1) vs arange/N input pair to within ~5e-5 absolute. Catches QuickGELU-vs-GELU swaps, fused-QKV transpose-direction bugs, channel-order mistakes, and c1=c2=1e-10 epsilon drift that the property tests miss.
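
For reference, a sketch of the deterministic input pair (the image shape and which ramp plays distorted vs. reference are assumptions here, not taken from the actual test):

use burn::tensor::{backend::Backend, Int, Tensor};

// Two near-identical ramps, arange/(N-1) vs arange/N: reproducible
// across backends with no random state involved.
fn parity_inputs<B: Backend>(device: &B::Device) -> (Tensor<B, 4>, Tensor<B, 4>) {
    let n: i64 = 3 * 256 * 256; // assumed [1, 3, 256, 256] images
    let ramp = Tensor::<B, 1, Int>::arange(0..n, device).float();
    let a = ramp.clone().div_scalar((n - 1) as f32).reshape([1, 3, 256, 256]);
    let b = ramp.div_scalar(n as f32).reshape([1, 3, 256, 256]);
    (a, b)
}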

cargo run-checks green: fmt, clippy, audit, typos, doc, full burn-train with --features vision (123 passing, 7 ignored).

Implements the Adaptive Fidelity-Naturalness Evaluator (A-FINE) full-
reference image quality metric from
https://arxiv.org/abs/2503.11221, ported from the PyIQA reference at
https://github.com/chaofengc/IQA-PyTorch/blob/main/pyiqa/archs/afine_arch.py.
A-FINE feeds both the distorted and reference images through a
fine-tuned CLIP ViT-B/32 visual encoder, runs the per-block patch
features through two small heads (naturalness and fidelity), then
fuses them through learnable logistic calibrators and a softplus-
gated adapter into a single quality score in (0, 100).

The implementation lives under `crates/burn-train/src/metric/vision/afine/`
and inlines the CLIP backbone rather than building it on top of a
shared one. Components:

- `quick_gelu.rs` — `x * sigmoid(1.702 * x)`. CLIP was trained with
  this, not the erf-based GELU; substituting silently shifts every
  MLP activation. Coefficient pinned at 1.702 as a labeled constant.
- `clip_attention.rs` — self-attention with a fused
  `qkv_proj: Linear(d_model, 3 * d_model)` and a `chunk(3, -1)` at
  forward. Matches PyTorch's `nn.MultiheadAttention.in_proj_weight`
  layout one-to-one so the checkpoint maps without pre-splitting.
- `clip_vit.rs` — `ClipVisualEncoder`: patch embed, learnable class
  token, positional embedding, ln_pre, twelve pre-norm transformer
  blocks, ln_post. Stays in NLD layout end-to-end (PyIQA permutes
  to LND inside the stack); attention is invariant so the math is
  identical. `forward_with_features` returns the per-block patch
  feature maps the heads consume.
- `heads.rs` — `AfineQHead` (naturalness): per-level mean+variance
  descriptor over thirteen feature levels (raw RGB plus twelve
  ReLU'd CLIP layers), shared `proj_feat: Linear(1536, 128)` on the
  CLIP levels, then a small MLP to a scalar with the erf-based
  `burn_nn::Gelu` (distinct from the CLIP trunk's QuickGELU).
  `AfineDHead` (fidelity): SSIM-like luminance and contrast terms
  weighted by globally-softplus-normalized `alpha`/`beta` parameters
  and reduced. Epsilon `1e-10`, not the larger `1e-6` DISTS uses.
- `calibrators.rs` — two logistic calibrators (`NrCalibrator`,
  `FrCalibratorWithLimit`), the `AfineAdapter` with learnable `k`,
  and the fixed `scale_finalscore` mapping into `(0, 100)`. PyIQA's
  reference uses an `if exp_pow >= 10` branch for numerical
  stability that only works on 0-D scalar tensors; rewritten as a
  single `(yita1 - yita2) * sigmoid((x - yita3) / (|yita4| + eps)) + yita2`
  expression so it batches correctly (sketched after this list).
- `weights.rs` — downloads `afine.pth` (~600 MB) from
  `chaofengc/IQA-PyTorch-Weights` on Hugging Face, caches it under
  `~/.cache/burn-dataset/afine/`, and loads all six shards
  (`finetuned_clip`, `natural`, `fidelity`, `natural_scale`,
  `fidelity_scale`, `adapter`) into the matching submodules. The
  five 0-D scalar checkpoint values (`yita3`, `yita4`, `k`) are
  read manually via `PytorchReader` because they're stored with
  `shape=()` and `Param<Tensor<B, 1>>` expects shape `(1,)`, so
  `PytorchStore`'s shape check silently drops them. A pre-rename
  of the text encoder's `positional_embedding` key avoids a HashMap
  iteration-order collision with the visual encoder's after remap.
- `metric.rs` — public API surface. `AfineConfig` with `image_size`
  (default 256, must be divisible by 32) and `normalize_input`
  (default true). `Afine<B>::forward(distorted, reference)` returns
  `Tensor<B, 1>` with per-sample scores. `ModuleDisplay` shows the
  backbone name, image size, and normalization toggle.
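
A batched sketch of the calibrator expression above (parameter names
from the list; the eps value is assumed, the PR pins its own):

use burn::tensor::{activation::sigmoid, backend::Backend, Tensor};

// x: per-sample raw head scores, shape [batch]. The yita params are
// the shape-[1] tensors produced by the rank-1 lift in weights.rs,
// so every op broadcasts and the whole batch calibrates in one pass.
fn calibrate<B: Backend>(
    x: Tensor<B, 1>,
    yita1: Tensor<B, 1>,
    yita2: Tensor<B, 1>,
    yita3: Tensor<B, 1>,
    yita4: Tensor<B, 1>,
) -> Tensor<B, 1> {
    let eps = 1e-8;
    let z = (x - yita3) / (yita4.abs() + eps);
    (yita1 - yita2.clone()) * sigmoid(z) + yita2
}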
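And a usage sketch of the public surface from the `metric.rs` item,
assuming burn's usual `Config::new()` / `init(&device)` convention
(the exact constructor signature is whatever landed in the PR):

use burn::backend::NdArray;
use burn::tensor::Tensor;

type B = NdArray;

let device = Default::default();
let metric = AfineConfig::new().init::<B>(&device);
// Batch of two [3, 256, 256] image pairs; contents are placeholders.
let distorted = Tensor::<B, 4>::zeros([2, 3, 256, 256], &device);
let reference = Tensor::<B, 4>::ones([2, 3, 256, 256], &device);
let scores: Tensor<B, 1> = metric.forward(distorted, reference); // shape [2]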

Tests (always-on): forward shape, batch processing, finite-on-
constant-inputs, asymmetry, image_size validation, and display
format. Two `#[ignore = "downloads pre-trained weights"]` tests:
`test_afine_pretrained` (loads real weights, asserts finite output
that meaningfully differs from the random-init metric, guards
against silent partial loads) and `test_afine_pretrained_parity`
(matches a captured PyIQA scalar to within ~5e-5 absolute on a
deterministic linspace input pair, catches QuickGELU coefficient
errors, fused-QKV transpose direction bugs, and channel-order
mistakes that the property tests miss).

Adds a row to `burn-book/src/building-blocks/metric.md`'s vision
metric table.

Issue: tracel-ai#4312

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Capataina Capataina marked this pull request as ready for review April 28, 2026 22:55
@laggui laggui self-requested a review April 30, 2026 18:19
@codecov

codecov Bot commented Apr 30, 2026

Codecov Report

❌ Patch coverage is 76.17866% with 192 lines in your changes missing coverage. Please review.
✅ Project coverage is 65.34%. Comparing base (cb26f81) to head (d2825c4).
⚠️ Report is 26 commits behind head on main.

Files with missing lines Patch % Lines
...ates/burn-train/src/metric/vision/afine/weights.rs 0.00% 139 Missing ⚠️
...rates/burn-train/src/metric/vision/afine/metric.rs 71.82% 51 Missing ⚠️
.../burn-train/src/metric/vision/afine/calibrators.rs 98.95% 1 Missing ⚠️
crates/burn-train/src/metric/vision/afine/heads.rs 99.45% 1 Missing ⚠️

❌ Your patch check has failed because the patch coverage (76.17%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (65.34%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4894      +/-   ##
==========================================
- Coverage   65.49%   65.34%   -0.15%     
==========================================
  Files        1165     1177      +12     
  Lines      172277   175662    +3385     
==========================================
+ Hits       112830   114790    +1960     
- Misses      59447    60872    +1425     


@laggui
Member

laggui commented May 5, 2026

Just letting you know I am aware of this PR 🙏 but haven't gotten around to it yet; bigger PRs require a bit more context switching 😅

It's still on my stack, will review sometime this week!

Thanks for the contribution

Member

@laggui laggui left a comment


Pretty clean implementation! Just one minor comment but otherwise LGTM

Comment on lines +218 to +223
// ln_post is applied only to the class-token output.
let cls = x.slice([0..batch, 0..1, 0..embed]).reshape([batch, embed]);
let cls = self.ln_post.forward(cls);

(cls, features)
}

For the metric implementation the cls output is discarded, so we could skip this computation.

We can preserve it for completeness (and you already have a test that covers this), so forward_with_features could return an output struct with optional cls output e.g.

pub struct ClipOutput<B: Backend> {
    pub features: Vec<Tensor<B, 3>>,
    pub cls: Option<Tensor<B, 2>>,
}

The metric path discarded the class-token output of
forward_with_features, paying for a slice and a LayerNorm on every
forward. Return a ClipOutput { features, cls: Option<...> } struct
now, with a return_cls flag on forward_with_features that lets the
caller request cls only when they need it. The metric passes false;
the standalone forward() and the test pass true.

Per laggui's review on tracel-ai#4894.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Capataina
Contributor Author

Hey, I made the changes you requested; it took some time, sorry about that. Does this look right now?

Member

@laggui laggui left a comment


LGTM!

Your response time was totally fine btw 😄

@Capataina
Contributor Author

Exciting! Looking forward to further contributions :)

@laggui laggui merged commit 1997b32 into tracel-ai:main May 11, 2026
11 checks passed