[feat] Track entropy and MI of routing distribution for topk MoE #188
base: main
Conversation
idea is good, thanks @oleksost.
bit weird that all these metrics are appearing as losses. that name should be reserved for things for which gradients are computed. just call this dict metrics?
Yes @tscholak, addressed. Using a metrics dict instead.
Looks good, got some comments on the structure.
@@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
    @abc.abstractmethod
    def loss_defs(self) -> list[LossDef]:
        pass

    @property
This loss/metric split is way more complicated than needed. How about having a single entry and using an `is_metric` flag in `LossDef` (or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning from `run_step`.
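A minimal sketch of what this single-registry approach could look like; apart from `is_metric`, the field names are assumptions, not the actual `LossDef` definition in fast_llm:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class LossDef:
    # Hypothetical fields; only is_metric is the suggested addition.
    name: str
    formatted_name: str
    count: int = 1
    # Metrics are logged but never contribute to the backward pass.
    is_metric: bool = False


def split_losses_and_metrics(defs: list[LossDef]) -> tuple[list[LossDef], list[LossDef]]:
    """Separate metric definitions from loss definitions, e.g. before returning from run_step."""
    return (
        [d for d in defs if not d.is_metric],
        [d for d in defs if d.is_metric],
    )
```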
This would be nice!
Maybe better to leave it for a separate PR? It would make this one larger, since it would also require changing the interfaces of the models' forward functions (which expect losses and metrics) and making sure that metrics are only calculated when `return_metrics` is True.
There isn't much change needed actually, just need to add `kwargs["return_metrics"]`. I would prefer doing this here so we don't grow `ScheduleRunner` too much.
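A rough sketch of the `kwargs["return_metrics"]` gating being suggested; the flag name comes from the comment, while the layer function and the inline entropy computation are purely illustrative:

```python
import math

import torch


def route(logits: torch.Tensor, metrics: dict | None = None, **kwargs) -> torch.Tensor:
    """Illustrative router forward pass, not the actual fast_llm MoE layer."""
    probs = torch.softmax(logits, dim=-1)
    if kwargs.get("return_metrics") and metrics is not None:
        # Only pay for the extra statistics when the caller asked for them.
        entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
        metrics["normalized_average_entropy"] = entropy.mean() / math.log(probs.size(-1))
    return probs
```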
fast_llm/engine/schedule/runner.py (outdated)
@@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
            for name, reduced_loss in reduced_losses.items()
        }

    def _is_reduced_metric(self, metric_name: str) -> bool:
        """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass)."""
        from fast_llm.layers.transformer.config import TransformerReducedMetrics
We can't use hard-coded values here. Suggestion above would fix it, or there are a few other ways to get this dynamically.
Simplified the setup so that all metrics that come back from a forward pass are reduced automatically, so this function is no longer needed.
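For reference, a hedged sketch of what "reduce all returned metrics automatically" could mean in practice, averaging each metric over data-parallel ranks; the actual `ScheduleRunner` code may differ:

```python
import torch
import torch.distributed as dist


def reduce_metrics(metrics: dict[str, torch.Tensor], group: dist.ProcessGroup | None = None) -> dict[str, float]:
    """Average every metric across data-parallel ranks, with no special-casing per metric name."""
    reduced = {}
    for name, value in metrics.items():
        value = value.detach().clone()
        if dist.is_initialized():
            dist.all_reduce(value, op=dist.ReduceOp.SUM, group=group)
            value = value / dist.get_world_size(group)
        reduced[name] = value.item()
    return reduced
```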
    # Store these metrics
    if metrics is not None:
Given the extra computation involved, this should be enabled through a config parameter.
how much compute are we talking about for these metrics? likely this won't be noticeable.
This is already controlled by the `training.logs.interval` parameter afaiu. Do you think we need a separate parameter for MoE stats logging?
Ideally we'd like to calculate the bare minimum by default (and the computation isn't optimized), so I think yes.
If we really want it on by default, I guess we could have a parameter that defaults to true but can be disabled for performance, e.g. for benchmarks.
addressed
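For illustration, the kind of opt-out flag being discussed, defaulting to enabled; this is a plain dataclass with a hypothetical field name, not the actual fast_llm config API:

```python
import dataclasses


@dataclasses.dataclass
class RouterMetricsConfig:
    # Compute the entropy / mutual-information routing statistics during the
    # forward pass. Can be disabled for performance, e.g. for benchmarks.
    compute_routing_metrics: bool = True
```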
tests/test_moe_metrics.py (outdated)
    assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}"


def test_edge_cases():
More explicit name?
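One possible reading of the naming suggestion, splitting the generic edge-case test into names that state the scenario; the import path of the helper is an assumption:

```python
import pytest
import torch

# Assumed location of the PR's helper; adjust to wherever it actually lives.
from fast_llm.layers.transformer.mixture_of_experts import calculate_normalized_average_entropy


def test_normalized_entropy_is_one_for_uniform_routing():
    # Uniform routing over 4 experts is maximally uncertain, so the normalized entropy is 1.
    probs = torch.full((2, 8, 4), 0.25)
    assert calculate_normalized_average_entropy(probs).item() == pytest.approx(1.0, abs=1e-5)


def test_normalized_entropy_is_near_zero_for_peaked_routing():
    # Nearly one-hot routing is almost deterministic, so the normalized entropy is close to 0.
    probs = torch.tensor([[0.997, 0.001, 0.001, 0.001]])
    assert calculate_normalized_average_entropy(probs).item() < 0.05
```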
tests/test_moe_metrics.py (outdated)
@pytest.fixture
def setup_runner():
These don't belong here. How about `test_runner.py`?
Why don't they belong here? This fixture is only useful for the tests in this suite.
Maybe we can move it to common.py in the future; it may be reused by other tests (e.g. the SSM tests).
I'm talking about the associated tests, not the fixture. They test a feature of `ScheduleRunner` and have nothing to do with MoE other than the loss names, which aren't relevant to the tests.
I see, makes sense, will move it to a new test file.
tests/test_moe_metrics.py (outdated)


if __name__ == "__main__":
Not needed
@@ -26,6 +27,35 @@
logger = logging.getLogger(__name__)


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
Could try `@torch.compile` on these for a free performance boost.
    average_entropy = entropy_values.mean()  # Average over batch and tokens
    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))


def entropy(probs: torch.Tensor) -> torch.Tensor:
`calculate_entropy`
@oleksost Are you still working on this?
@jlamypoirier yes, will address your comments today. Sorry, it was deprioritised in favour of mamba.
@jlamypoirier I think I addressed all the comments.
    n_experts = probs.size(-1)
    entropy_values = calculate_entropy(probs)
    average_entropy = entropy_values.mean()  # Average over batch and tokens
    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
`average_entropy / math.log(n_experts)` (same elsewhere)
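Putting the review suggestions together (the `calculate_*` naming, `math.log` for the constant normalizer, and `@torch.compile`), the helpers could end up looking roughly like this; whether the PR also normalizes the mutual information by `log(n_experts)` is an assumption based on the test's 0-to-1 bound:

```python
import math

import torch


@torch.compile
def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
    """Per-token entropy of routing probabilities, shape (..., n_experts) -> (...)."""
    return -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)


@torch.compile
def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
    """Average per-token entropy, normalized to [0, 1] by log(n_experts)."""
    average_entropy = calculate_entropy(probs).mean()  # average over batch and tokens
    return average_entropy / math.log(probs.size(-1))


@torch.compile
def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
    """Entropy of the batch-averaged routing distribution minus the average per-token entropy."""
    n_experts = probs.size(-1)
    marginal_entropy = calculate_entropy(probs.reshape(-1, n_experts).mean(dim=0))
    average_entropy = calculate_entropy(probs).mean()
    # Normalizing keeps the value in [0, 1]; the PR may or may not do this.
    return (marginal_entropy - average_entropy) / math.log(n_experts)
```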
✨ Description
To better detect potential routing collapse and get a better understanding of the routing distribution, we can track the average entropy and the mutual information of the routing probabilities.
Collapsed routing would have low entropy and low mutual information. A healthy, specialised router would have low entropy and high mutual information, meaning that routing is specialised and differs considerably across tokens.
More specifically, mutual information measures the difference between the entropy of the batch-averaged routing distribution and the average entropy of the per-token routing distributions.
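Concretely, writing $p_{t,e}$ for the routing probability of token $t$ over expert $e$, with $T$ tokens and $E$ experts (the exact normalization in the code may differ slightly):

$$
\bar H_{\text{norm}} = \frac{1}{T \log E}\sum_{t=1}^{T} H(p_{t,\cdot}),
\qquad
I = H\!\left(\frac{1}{T}\sum_{t=1}^{T} p_{t,\cdot}\right) - \frac{1}{T}\sum_{t=1}^{T} H(p_{t,\cdot}),
\qquad
H(p) = -\sum_{e=1}^{E} p_e \log p_e .
$$

Collapsed routing makes both terms of $I$ small and nearly equal (every token prefers the same expert), while specialised routing keeps the second term small but the first one large.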
🔍 Type of change
Select all that apply:
📝 Changes
The entropy and mutual-information metrics are added in `mixture_of_experts.py`; they are calculated only for the topk routing type.
✅ Checklist
General
Testing
Performance Impact
📊 Performance Impact Details
I am not 100% sure there is no performance impact, since we are calculating the stats at each forward pass through the router.