Track entropy and MI of routing distribution for topk MoE #188
Conversation
Yes @tscholak, addressed. Using metrics dict instead.
jlamypoirier left a comment:

Looks good, got some comments on the structure.
```python
def loss_defs(self) -> list[LossDef]:
    pass

@property
```
This loss/metric split is way more complicated than needed. How about having a single entry and using an `is_metric` flag in `LossDef` (or a derived class) to distinguish? Then the only change needed is extracting metrics from the context before returning from `run_step`.
This would be nice!
Maybe better to leave it for a separate PR, though? It would make this one larger, since it would also require changing the interfaces of the models' forward functions (which expect losses and metrics) and making sure that metrics are only calculated when `return_metrics` is `True`.
There isn't much change needed, actually: just add `kwargs["return_metrics"]`. I would prefer doing this here so we don't grow `ScheduleRunner` too much.
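A minimal sketch of what that single-list approach could look like. The `LossDef` fields and helper name here are illustrative assumptions, not the actual fast_llm definitions:

```python
import dataclasses


@dataclasses.dataclass
class LossDef:
    # Hypothetical fields; the real fast_llm LossDef may differ.
    name: str
    formatted_name: str
    count: int = 1
    is_metric: bool = False  # Proposed flag: treat this entry as a metric, not a loss.


def split_losses_and_metrics(defs: list[LossDef]) -> tuple[list[LossDef], list[LossDef]]:
    """One list of definitions; metrics are filtered out where needed (e.g. in run_step)."""
    losses = [d for d in defs if not d.is_metric]
    metrics = [d for d in defs if d.is_metric]
    return losses, metrics
```

With this, the runner keeps a single `loss_defs` list and only has to partition it once before returning.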
```python
def _is_reduced_metric(self, metric_name: str) -> bool:
    """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass)."""
    from fast_llm.layers.transformer.config import TransformerReducedMetrics
```
We can't use hard-coded values here. The suggestion above would fix it, or there are a few other ways to get this dynamically.
Simplified the setup so that all metrics that come back from a forward pass are reduced automatically, hence no need for this function.
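A sketch of what that uniform reduction could look like (the function name and structure are illustrative, not the actual fast_llm code):

```python
import torch


def reduce_metrics(metrics_per_step: list[dict[str, torch.Tensor]]) -> dict[str, float]:
    # Every metric returned by a forward pass is reduced (averaged) the same way,
    # so no per-name lookup like _is_reduced_metric is needed.
    return {
        name: torch.stack([step[name] for step in metrics_per_step]).mean().item()
        for name in metrics_per_step[0]
    }
```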
```python
# Store these metrics
if metrics is not None:
```
Given the extra computation involved, this should be enabled through a config parameter
How much compute are we talking about for these metrics? Likely this won't be noticeable.
This is already controlled by the `training.logs.interval` parameter, AFAIU. Do you think we need a separate parameter for MoE stats logging?
Ideally we'd like to calculate the bare minimum by default (and the computation isn't optimized), so I think yes.
If we really want it on by default, I guess we could have a parameter that defaults to `True` but can be disabled for performance, e.g. for benchmarks.
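One way such a toggle could look. The config class and field name here are hypothetical, chosen only to illustrate the default-on, opt-out behaviour being discussed:

```python
import dataclasses


@dataclasses.dataclass
class RoutingStatsConfig:
    # Hypothetical field; on by default, can be disabled e.g. for benchmarks.
    log_routing_stats: bool = True


def maybe_routing_stats(probs, config: RoutingStatsConfig):
    """Compute routing stats only when the config flag is enabled."""
    if not config.log_routing_stats:
        return None
    # Placeholder for the actual entropy / mutual-information computation.
    return {"computed": True}
```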
```python
assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}"

def test_edge_cases():
```
```python
@pytest.fixture
def setup_runner():
```
These don't belong here. How about test_runner.py?
Why don't they belong here? This fixture is only useful for the tests in this suite.
Maybe we can move it to `common.py` in the future; it may be reused by other tests (e.g. the SSM test).
I'm talking about the associated tests, not the fixture. They test a feature of ScheduleRunner and have nothing to do with MoE other than the loss names which aren't relevant to the tests.
I see, makes sense, will move it to a new test file.
```python
if __name__ == "__main__":
    logger = logging.getLogger(__name__)
```
```python
def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
```
Could try @torch.compile on these for a free performance boost.
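For example, a compiled entropy helper might look like this. This is a sketch: the entropy body is a plausible reconstruction, not necessarily the PR's exact implementation:

```python
import torch


def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
    # Entropy over the expert dimension; the clamp avoids log(0) for hard routings.
    p = probs.clamp_min(1e-9)
    return -(p * p.log()).sum(dim=-1)


# torch.compile compiles lazily on first call, so wrapping here is cheap;
# it can fuse the elementwise log/mul/sum into fewer kernels.
calculate_entropy_compiled = torch.compile(calculate_entropy)
```

Wrapping (rather than decorating) keeps the eager version available, which is handy for debugging.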
```python
    average_entropy = entropy_values.mean()  # Average over batch and tokens
    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))

def entropy(probs: torch.Tensor) -> torch.Tensor:
```
@oleksost Are you still working on this?

@jlamypoirier Yes, will address your comments today. Sorry, it was deprioritised in favour of Mamba.

@jlamypoirier I think I addressed all the comments.
```python
    n_experts = probs.size(-1)
    entropy_values = calculate_entropy(probs)
    average_entropy = entropy_values.mean()  # Average over batch and tokens
    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
```
`average_entropy / math.log(n_experts)` (same elsewhere)
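Concretely, the suggested change could look like the sketch below. Since `n_experts` is a plain Python int, `math.log` avoids allocating a scalar tensor on the device just to take its log. The entropy computation is inlined here so the sketch is self-contained; the PR's `calculate_entropy` helper may differ:

```python
import math

import torch


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
    # Per-token entropy over the expert dimension (clamp avoids log(0)).
    p = probs.clamp_min(1e-9)
    entropy_values = -(p * p.log()).sum(dim=-1)
    average_entropy = entropy_values.mean()  # Average over batch and tokens.
    # math.log on the Python int replaces torch.log(torch.tensor(n_experts, ...)).
    n_experts = probs.size(-1)
    return average_entropy / math.log(n_experts)
```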
✨ Description
To better detect potential routing collapse and gain a better understanding of the routing distribution, we can track the average entropy and mutual information of the routing probabilities.
Collapsed routing would show low entropy and low mutual information. A healthy, specialised router would show low entropy and high mutual information, meaning that routing is confident and differs considerably across tokens.
More specifically, mutual information measures the difference between the entropy of the batch-averaged routing distribution and the average per-token routing entropy.
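Both quantities can be sketched as follows, assuming `probs` has shape `(batch, tokens, n_experts)`. The function name and exact normalization are illustrative, not necessarily what the PR implements:

```python
import math

import torch


def routing_stats(probs: torch.Tensor) -> dict[str, torch.Tensor]:
    """Normalized average entropy and mutual information of routing probabilities.

    probs: routing probabilities of shape (batch, tokens, n_experts).
    """
    n_experts = probs.size(-1)
    p = probs.clamp_min(1e-9)  # Avoid log(0) for near-deterministic routings.
    # Per-token entropy, averaged over batch and tokens.
    average_entropy = -(p * p.log()).sum(dim=-1).mean()
    # Entropy of the routing distribution averaged over batch and tokens.
    mean_probs = probs.mean(dim=(0, 1)).clamp_min(1e-9)
    entropy_of_mean = -(mean_probs * mean_probs.log()).sum()
    # MI(expert; token) = H(mean distribution) - mean per-token entropy,
    # normalized by log(n_experts) so both stats live in [0, 1].
    return {
        "normalized_average_entropy": average_entropy / math.log(n_experts),
        "mutual_information": (entropy_of_mean - average_entropy) / math.log(n_experts),
    }
```

Uniform routing gives entropy 1 and MI 0 (collapse onto identical distributions), while confident, token-dependent routing gives entropy near 0 and MI near 1.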
🔍 Type of change

Select all that apply:

📝 Changes

The stats are calculated in `mixture_of_experts.py`; they are calculated only for the `topk` routing type.

✅ Checklist
General
Testing
Performance Impact
📊 Performance Impact Details
I am not 100% sure there is no performance impact, since we calculate the stats at each forward pass through the router.