Ekfac #13

Closed
LouisYRYJ wants to merge 107 commits into main from ekfac

Conversation

@LouisYRYJ
Contributor

@LouisYRYJ LouisYRYJ commented Jun 22, 2025

This introduces EKFAC computations; everything relevant is in /hessians.

  • Ignore /approx_unrolling for this PR.
  • The gradient_covariance files in hessians are for debugging; ignore these too.
  • I mostly followed the build_index pipeline.
  • Currently, we should not merge this, because it uses the current version of main and not the ultrafast one.
    → As I refactored some parts of the general pipeline, we should probably merge ultrafast first, and then I can make the changes to this PR.

@CLAassistant

CLAassistant commented Jun 22, 2025

CLA assistant check
All committers have signed the CLA.

smarter and others added 18 commits October 18, 2025 16:34
This class was removed in 8232b77 but the
notebook code was not adapted.
This cell contains code from tests/ekfac_tests/test_covariance.py but fails with
file not found errors and isn't actually needed in the rest of the notebook.
This makes it easier to run the script on a small GPU for testing.
This is adapted from the main branch of bergson.
Use the jupytext percent format, which is directly interpretable by VS Code (and
can also be converted into ipynb using `jupytext --to notebook`).

We want compute_ekfac_ground_truth to be:
- Importable from other files. So it should be split into functions, and
  importing it shouldn't have side effects.
- Usable as a script. So it should have a main that parses input arguments and
  runs everything.
- Usable as a notebook. So it should be split into cells where each cell can be
  executed individually and produce some output.

To regain usability as a notebook without compromising the other use cases, we
split the logic that used to be in `main()` into multiple statements guarded by
`if __name__ == "__main__"` at the end of the cell that defines the relevant
function (since each of these guarded statements defines some variable, we
actually need `or TYPE_CHECKING` to ensure they are visible to the type checker).
All covariance-related cells are next to each other, and likewise for all
eigendecomposition-related code.
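The guarded-cell pattern described above can be sketched as follows. This is a minimal illustration, not the actual compute_ekfac_ground_truth code; `compute_covariances` is a hypothetical stand-in for one of the real pipeline steps.

```python
from typing import TYPE_CHECKING


def compute_covariances(seed: int = 0) -> dict:
    """Hypothetical stand-in for one pipeline step; importing this
    module runs nothing."""
    return {"seed": seed}


if __name__ == "__main__" or TYPE_CHECKING:
    # Runs when executed as a script or notebook cell, but not on import.
    # The `or TYPE_CHECKING` branch never executes; it only lets the type
    # checker see `covariances` in later cells that reference it.
    covariances = compute_covariances()
```

This keeps imports side-effect free while each cell still produces output when run directly.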
ekfac: fix compute_ekfac_ground_truth, add minimal CI
* pyproject.toml: Bump minimal transformers version

The keyword parameter `dtype` for AutoModelForCausalLM.from_pretrained does not
exist in version 4.54.1 which was present in uv.lock (this parameter used to be
called `torch_dtype` which is now a deprecated alias).

* compute_ekfac_ground_truth: Add model-name and world-size params

Also rework default handling to avoid specifying default values in multiple places.

* compute_ekfac_ground_truth: adjust token_batch_size based on model

* Add assertions to EKFAC tests (new criterion for test_eigenvalue_correction)

This way when we start using pytest, test failures will be properly reported.

test_eigenvalue_correction had no explicit criterion for success so I made one
up.

* Convert EKFAC tests to pytest

This includes using fixtures for ground truth generation and test configuration,
so that we can just do:

uv run pytest -sv tests/ekfac_tests

and ground truth will be auto-generated.

* Add pre-commit as a dev dependency and run it

Ran `uv run pre-commit run --all-files`, which reads from .pre-commit-config.yaml.

Unfortunately, pre-commit does not respect tool settings in pyproject.toml, so
right now there is conflicting information between pyproject.toml and
.pre-commit-config.yaml, and different settings and tool versions are used
depending on how we run the tools.

* Run EKFAC tests on CPU, to enable them in the CI

test_eigenvalue_corrections had to be disabled due to precision errors:

  h.6.attn.attention.out_proj: max_rel_diff=2.285%
  h.6.mlp.c_proj: max_rel_diff=3.599%
  h.7.attn.attention.out_proj: max_rel_diff=4.041%
  h.7.mlp.c_proj: max_rel_diff=2.204%

* Fix pyright inclusions/exclusions to make the CI green

It seems the working-directory parameter in the CI config is ignored if
pyproject.toml configures pyright, so tweak that instead.

* Emit an error if we're overwriting ground truth with different params

Overwriting is allowed using the --overwrite flag.

* Make gradient computation invariant to batch size

Use loss.sum().backward() to avoid scaling the gradients by 1/B (and the
covariance matrix by 1/B^2).

Without this change, G2/G1 is empirically ~0.2 with the default set of
parameters.
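The 1/B scaling described above can be illustrated with a minimal sketch (not the project's code): `loss.mean().backward()` divides every per-example gradient by the batch size B, so any second-moment (covariance) statistic built from those gradients picks up a 1/B² factor, while `loss.sum().backward()` leaves per-example gradients unscaled.

```python
import torch

B = 4
x = torch.randn(B, 3)
w = torch.zeros(3, requires_grad=True)
per_example_loss = (x @ w - 1.0) ** 2

# Gradient of the summed loss: per-example gradients added up, unscaled.
g_sum = torch.autograd.grad(per_example_loss.sum(), w, retain_graph=True)[0]
# Gradient of the mean loss: the same quantity scaled by 1/B.
g_mean = torch.autograd.grad(per_example_loss.mean(), w)[0]

assert torch.allclose(g_mean * B, g_sum)  # mean introduces exactly the 1/B factor
```

Since the covariance is quadratic in the gradient, the mean-based version underestimates it by 1/B², consistent with the batch-size dependence noted above.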

* Make sure all ekfac_tests are run with deterministic seeds

* Add KFAC FIM accuracy test with toy model

This compares the KFAC approximation against the exact FIM computed on a toy
model. We intentionally restrict test conditions to avoid exercising issues with
padding and last token gradient which are fixed in the next commit.

* Fix covariance estimation errors from invalid positions

When batching sequences of different lengths, we pad shorter sequences. These
padding positions aren't real data and shouldn't contribute to the FIM.
Similarly, the last position of each sequence has no next token to predict.

Invalid positions affected both covariances differently. The activation
covariance A was contaminated with out-of-distribution activations for padding.
The gradient covariance G was underestimated: gradients are zero for invalid
positions, but total_processed included them in the denominator. When
sample=True, there was a third issue: sampled labels didn't preserve -100 for
padding, so G was corrupted with non-zero gradients.

The fix computes valid_masks in pad_and_tensor() and uses it to filter
activations and restrict loss computation to valid positions.
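A minimal sketch of the masking idea, using NumPy for illustration (the names `valid_mask` and the shapes mirror the commit message, not the actual pad_and_tensor API): positions past a sequence's length are padding, and the last real position has no next token, so both are excluded before accumulating the activation covariance.

```python
import numpy as np

B, S, D = 2, 5, 3
acts = np.random.default_rng(0).normal(size=(B, S, D))
lengths = np.array([5, 3])  # second sequence is padded to length 5

# Valid positions: before the sequence end, excluding the last real token
# (which has no next token to predict).
valid_mask = np.arange(S)[None, :] < (lengths[:, None] - 1)

flat = acts[valid_mask]            # (n_valid, D); padding rows dropped
A = flat.T @ flat / flat.shape[0]  # normalize by valid count, not B * S

assert flat.shape[0] == (5 - 1) + (3 - 1)  # 6 valid positions
```

Normalizing by the number of valid positions rather than the padded total is what fixes the underestimated gradient covariance described above.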

* Pass `target_modules` to CovarianceCollector

CovarianceCollector was called without the target_modules parameter, causing it
to hook into all MLP layers instead of just the specified target modules.
LambdaCollector and the ground truth collectors already had this parameter set
correctly.
@LouisYRYJ
Contributor Author

Closing this, as it has now been split into these 3 PRs:

#63
#81
#123

@LouisYRYJ LouisYRYJ closed this Jan 13, 2026
smarter added a commit to smarter/bergson that referenced this pull request Jan 16, 2026
Restore the calls to dist.barrier that existed in
EleutherAI#13; without them the process would
hang when running with world_size > 1.

For testing, we add _allocate_batches_world to compute the batches for the
ground truth.
LouisYRYJ added a commit that referenced this pull request Jan 26, 2026
* ekfac implementation done (untested)

* remove unnecessary squeeze

* add tkfac

* fix claude issues

* shampoo

* minor fix

* Add EKFAC tests and fix a couple of bugs (#125)

* Fix mask bug and add batch size invariance test with toy model

The backward_hook was using g.reshape(-1, O) which includes padding
positions in the covariance computation. This causes incorrect results
when batches have different sequence lengths.

Before this commit, the added test failed with:
> FAILED tests/ekfac_tests/test_batch_size_invariance.py::test_trace_batch_invariant[seq_lengths1-20] - AssertionError: Scalars are not close!
>
> Expected 1.231401894309304 but got 0.8983965093439276.
> Absolute difference: 0.33300538496537635 (up to 1e-4 allowed)
> Relative difference: 0.27042786478102654 (up to 0.01 allowed)
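The hook-side fix can be sketched as follows (illustrative only, not the real backward_hook; the mask layout is an assumption): `g.reshape(-1, O)` flattens every position including padding, whereas indexing with a validity mask keeps only real positions before accumulating the covariance.

```python
import numpy as np

rng = np.random.default_rng(1)
B, S, O = 2, 4, 3
g = rng.normal(size=(B, S, O))
valid = np.array([[1, 1, 1, 0],
                  [1, 1, 0, 0]], dtype=bool)  # padding positions masked out

g_all = g.reshape(-1, O)  # buggy: includes the 3 padding rows
g_valid = g[valid]        # fixed: only valid positions, shape (n_valid, O)

assert g_all.shape == (B * S, O)
assert g_valid.shape == (5, O)
```

With ragged batches the padded rows are not zero-cost noise in the statistics, which is why the trace differed with batch composition before the fix.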

* Fix use_dataset_labels condition and add FIM accuracy test

The condition `if not hessian_cfg.use_dataset_labels:` was inverted,
causing the empirical Fisher (with dataset labels) to use sampled
labels and vice versa.

Add test_fim_accuracy.py which verifies that KFAC approximates the
Fisher Information Matrix within tolerance for both empirical FIM
(dataset labels) and true FIM (sampled labels).

* Add ground truth ekfac tests

This is still missing FSDP support and test_apply_ekfac.py from
#68

Co-Authored-By: LouisYRYJ <louis.yousif@yahoo.de>

* ekfac_tests/test_batch_size_invariance.py: Fix error thresholds when running on CPU

* Cleanup EKFAC tests

- Replace set_all_seeds by existing setup_reproducibility
- Reuse approximate_hessians instead of doing something
  equivalent manually.

* Add --token_batch_size option to EKFAC tests

* Add --n_samples option to EKFAC tests

Allow configuring the number of samples from pile-10k dataset via
pytest command line option instead of hardcoding 100. The dataset
directory is now named dynamically (e.g., pile_100_examples).

* hessians: Fix distributed support and test it

Restore the calls to dist.barrier that existed in
#13, without them the process would
hang when running with world_size > 1.

For testing, we add _allocate_batches_world to compute the batches for the
ground truth. The tests don't pass due to numerical errors, this is handled in
the next commit by changing our comparison logic.

* ekfac_tests: Use appropriate metrics for each comparison

- Eigenvectors: Check |cosine_similarity| ≈ 1 per column, which naturally
  handles sign ambiguity (eigenvectors are only defined up to sign)
- Covariances: Check relative Frobenius norm since values should match exactly
- Eigenvalue corrections: Align signs based on eigenvector orientation, then
  check relative error (λ[i,j] transforms as sign_G[i] * sign_A[j])
  - Also reenable CPU tests which pass after this change.
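The sign-invariant eigenvector check can be sketched like this (assumed setup: columns of `Q` and `Q_ref` are matching unit eigenvectors; this is not the actual test code). Flipping any column's sign leaves |cosine similarity| unchanged.

```python
import numpy as np

rng = np.random.default_rng(2)
# Orthonormal reference eigenvectors (columns), via QR decomposition.
Q_ref = np.linalg.qr(rng.normal(size=(4, 4)))[0]
signs = np.array([1.0, -1.0, 1.0, -1.0])
Q = Q_ref * signs  # same eigenvectors, some with flipped sign

# Per-column |cosine similarity|; columns are unit-norm, so this is just
# the absolute value of the column-wise dot products.
cos = np.abs(np.sum(Q * Q_ref, axis=0))

assert np.allclose(cos, 1.0)  # sign flips do not affect the metric
```

A plain elementwise comparison of `Q` and `Q_ref` would spuriously fail on the sign-flipped columns, which is exactly the ambiguity this metric absorbs.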

* ekfac_tests: Relax thresholds for distributed runs

With world_size > 1, floating-point reduction order differs between ground
truth (single process) and distributed run, causing larger numerical
differences in some layers.

For eigenvectors, use average |cos_sim| instead of minimum - this tolerates
occasional outlier eigenvectors while maintaining a stricter threshold
(1e-3 vs 0.1 that would be needed for min).

For eigenvalue corrections, use atol=0.2 when world_size > 1.

* adjust test + normalize shampoo and tkfac

* minor fixes, correct tensor handling in shampoo and tkfac, introduce apply_hessian (WIP)

---------

Co-authored-by: Guillaume Martres <smarter@ubuntu.com>