-
Notifications
You must be signed in to change notification settings - Fork 43
[inactive] Track entropy and MI of routing distribution for topk MoE #188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
2a7cf1b
dd85e84
aef18e7
bef39d8
620ec76
7a93aee
eb617e8
440738a
e5f3c4b
27e2a5c
b016d95
7b9ac8c
0577b2c
9e2ec37
efd16bf
1202f5f
9855b82
9c47764
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| from fast_llm.engine.schedule.schedule import Schedule, Step | ||
| from fast_llm.logging import log_memory_usage | ||
| from fast_llm.utils import Assert | ||
| from typing import Callable | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -94,6 +95,7 @@ def __init__( | |
| self._tied_parameters = self._multi_stage.tied_parameters | ||
| self._num_stages = len(self._stages) | ||
| self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} | ||
| self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs} | ||
|
|
||
| def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: | ||
| assert not self._is_setup | ||
|
|
@@ -265,20 +267,41 @@ def run_step( | |
| log_pipeline_parallel_main_rank( | ||
| lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str) | ||
| ) | ||
|
|
||
| return self._reduce_losses(context), update_successful, metrics | ||
| metrics = self._reduce_metrics(context) if return_metrics else metrics | ||
| return ( | ||
| self._reduce_losses(context), | ||
| update_successful, | ||
| metrics, | ||
| ) | ||
|
|
||
| def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: | ||
| return self._reduce_metric_or_loss(context, lambda name: self._loss_defs[name].count, "losses") | ||
|
|
||
| def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]: | ||
| return self._reduce_metric_or_loss( | ||
| context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric | ||
| ) | ||
|
|
||
| def _reduce_metric_or_loss( | ||
| self, | ||
| context: BatchContext, | ||
| check_count: Callable[[str], int], | ||
| reduce_attr: str = "losses", | ||
| check_reduce: Callable[[str], bool] = lambda _: True, | ||
| ) -> dict[str, float | int]: | ||
| reduced_losses = {} | ||
| num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs | ||
| for name, losses in context.losses.items(): | ||
| for name, losses in context.__getattribute__(reduce_attr).items(): | ||
| if not check_reduce(name): | ||
| reduced_losses[name] = losses | ||
| continue | ||
| if losses or self._distributed.pipeline_group: | ||
| if losses: | ||
| reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count | ||
| reduced_loss = torch.stack(losses).sum() / num_inputs / check_count(name) | ||
| if self._distributed.data_group: | ||
| all_reduce(reduced_loss, group=self._distributed.data_group) | ||
| else: | ||
| reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device) | ||
| reduced_loss = torch.zeros([1], dtype=check_count(name).dtype, device=self._distributed.device) | ||
| if self._distributed.pipeline_group: | ||
| all_reduce(reduced_loss, group=self._distributed.pipeline_group) | ||
| else: | ||
|
|
@@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: | |
| for name, reduced_loss in reduced_losses.items() | ||
| } | ||
|
|
||
| def _is_reduced_metric(self, metric_name: str) -> bool: | ||
| """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass).""" | ||
| from fast_llm.layers.transformer.config import TransformerReducedMetrics | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't use hard-coded values here. Suggestion above would fix it, or there are a few other ways to get this dynamically.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplified the setup s.t. all metrics that come back from a forward pass are reduced automatically, hence no need for this function. |
||
|
|
||
| if metric_name not in self._metric_defs: | ||
| return False | ||
| if not hasattr(self, "_reduced_metrics"): | ||
| self._reduced_metrics = set() | ||
| for cls in TransformerReducedMetrics.__subclasses__(): | ||
| for attr_name in dir(cls): | ||
| self._reduced_metrics.add(attr_name) | ||
| return metric_name in self._reduced_metrics | ||
|
|
||
| def _train_step(self, context: BatchContext, step: Step) -> None: | ||
| if step.throttle_event is not None: | ||
| step.throttle_event.record() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| TransformerDimNames, | ||
| TransformerKwargs, | ||
| TransformerLossNames, | ||
| TransformerRoutingMetrics | ||
| ) | ||
| from fast_llm.layers.transformer.mlp import MLPBase | ||
| from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage | ||
|
|
@@ -26,6 +27,35 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could try |
||
| """ | ||
| Calculates routing entropy for each token, then averages over all tokens. | ||
| If low, means a lot of mass is put on a single expert in all tokens, which can indicate collapse or specialization. | ||
| """ | ||
| n_experts = probs.size(-1) | ||
| entropy_values = entropy(probs) | ||
| average_entropy = entropy_values.mean() # Average over batch and tokens | ||
| return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| def entropy(probs: torch.Tensor) -> torch.Tensor: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| probs = torch.clamp(probs, min=1e-9) # Avoid log(0) | ||
| return -torch.sum(probs * torch.log(probs), dim=-1) | ||
|
|
||
|
|
||
| def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Calculates the difference between the entropy of the average routing and | ||
| the average routing entropy, we average across all tokens of all examples in the batch. | ||
| If low, means that routing is not informative. | ||
| """ | ||
| n_experts = probs.size(-1) | ||
| average_routing = torch.mean(probs.view(-1, n_experts), dim=0) # Average over tokens | ||
| entropy_avg_routing = entropy(average_routing) / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) # H[E[X]] | ||
| entropy_routing = calculate_normalized_average_entropy(probs) # E[H[X]] | ||
|
|
||
| return entropy_avg_routing - entropy_routing | ||
|
|
||
|
|
||
| class MixtureOfExpertMLP(MLPBase): | ||
| """ | ||
| MoeLayer following implementation from | ||
|
|
@@ -103,7 +133,7 @@ def forward( | |
|
|
||
| # Routing | ||
| if self._routing_type == RoutingType.topk: | ||
| scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses) | ||
| scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics) | ||
| if self._num_shared_experts > 0: | ||
| scores, top_experts = self._add_shared_experts(top_experts, scores) | ||
| elif self._routing_type == RoutingType.sinkhorn: | ||
|
|
@@ -169,11 +199,27 @@ def _topk_routing( | |
| logits: torch.Tensor, | ||
| grad_scale: float | None = None, | ||
| losses: dict | None = None, | ||
| metrics: dict | None = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) | ||
| scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) | ||
| if losses is not None or (self.training and grad_scale is not None): | ||
| probs = torch.softmax(logits, dim=-1, dtype=torch.float32) | ||
|
|
||
|
|
||
| # Store these metrics | ||
| if metrics is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the extra computation involved, this should be enabled through a config parameter
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how much compute are we talking about for these metrics? likely this won't be noticeable.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is already controlled by
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally we'd like to calculate the bare minimum by default (and the computation isn't optimized), so I think yes. If we really want by default I guess we could have a parameter with that defaults to true but can disabled for performance, eg. for benchmarks.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addressed |
||
| # Calculate and log entropy and mutual information | ||
| entropy = calculate_normalized_average_entropy(probs) | ||
| mutual_info = calculate_mutual_information(probs) | ||
| if TransformerRoutingMetrics.normalized_average_entropy not in metrics: | ||
| metrics[TransformerRoutingMetrics.normalized_average_entropy] = [] | ||
| if TransformerRoutingMetrics.mutual_info not in metrics: | ||
| metrics[TransformerRoutingMetrics.mutual_info] = [] | ||
|
|
||
| metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach()) | ||
| metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach()) | ||
|
|
||
| mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) | ||
| # Auxiliary loss, corresponding to the sum of probabilities for the top experts. | ||
| # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This loss/metric split is way more complicated than needed. How about having a single entry, and using a
is_metricflag inLossDef(or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning fromrun_stepThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be nice!
Maybe better to leave it for a separate pr? It would make this one larger as it would require also changing the interfaces of the models' forward functions (that expect losses and metrics) as well as making sure that metrics are only calculated when
return_metricsis True.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn't much change needed actually, just need to add
kwargs["return_metrics"]. I would prefer doing this here so we don't growScheduleRunnertoo much.