Skip to content
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fast_llm/engine/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
@abc.abstractmethod
def loss_defs(self) -> list[LossDef]:
pass

@property
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loss/metric split is way more complicated than needed. How about having a single entry, and using a is_metric flag in LossDef (or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning from run_step

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be nice!

Maybe better to leave it for a separate pr? It would make this one larger as it would require also changing the interfaces of the models' forward functions (that expect losses and metrics) as well as making sure that metrics are only calculated when return_metrics is True.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't much change needed actually, just need to add kwargs["return_metrics"]. I would prefer doing this here so we don't grow ScheduleRunner too much.

def metric_defs(self) -> list[LossDef]:
return []
46 changes: 41 additions & 5 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fast_llm.engine.schedule.schedule import Schedule, Step
from fast_llm.logging import log_memory_usage
from fast_llm.utils import Assert
from typing import Callable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
self._tied_parameters = self._multi_stage.tied_parameters
self._num_stages = len(self._stages)
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs}

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
Expand Down Expand Up @@ -265,20 +267,41 @@ def run_step(
log_pipeline_parallel_main_rank(
lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str)
)

return self._reduce_losses(context), update_successful, metrics
metrics = self._reduce_metrics(context) if return_metrics else metrics
return (
self._reduce_losses(context),
update_successful,
metrics,
)

def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
return self._reduce_metric_or_loss(context, lambda name: self._loss_defs[name].count, "losses")

def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]:
return self._reduce_metric_or_loss(
context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric
)

def _reduce_metric_or_loss(
self,
context: BatchContext,
check_count: Callable[[str], int],
reduce_attr: str = "losses",
check_reduce: Callable[[str], bool] = lambda _: True,
) -> dict[str, float | int]:
reduced_losses = {}
num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs
for name, losses in context.losses.items():
for name, losses in context.__getattribute__(reduce_attr).items():
if not check_reduce(name):
reduced_losses[name] = losses
continue
if losses or self._distributed.pipeline_group:
if losses:
reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count
reduced_loss = torch.stack(losses).sum() / num_inputs / check_count(name)
if self._distributed.data_group:
all_reduce(reduced_loss, group=self._distributed.data_group)
else:
reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device)
reduced_loss = torch.zeros([1], dtype=check_count(name).dtype, device=self._distributed.device)
if self._distributed.pipeline_group:
all_reduce(reduced_loss, group=self._distributed.pipeline_group)
else:
Expand All @@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
for name, reduced_loss in reduced_losses.items()
}

def _is_reduced_metric(self, metric_name: str) -> bool:
"""Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass)."""
from fast_llm.layers.transformer.config import TransformerReducedMetrics
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't use hard-coded values here. Suggestion above would fix it, or there are a few other ways to get this dynamically.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified the setup s.t. all metrics that come back from a forward pass are reduced automatically, hence no need for this function.


if metric_name not in self._metric_defs:
return False
if not hasattr(self, "_reduced_metrics"):
self._reduced_metrics = set()
for cls in TransformerReducedMetrics.__subclasses__():
for attr_name in dir(cls):
self._reduced_metrics.add(attr_name)
return metric_name in self._reduced_metrics

def _train_step(self, context: BatchContext, step: Step) -> None:
if step.throttle_event is not None:
step.throttle_event.record()
Expand Down
9 changes: 9 additions & 0 deletions fast_llm/layers/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ class TransformerLossNames:
load_balancing_loss = "load_balancing_loss"
router_z_loss = "router_z_loss"

class TransformerReducedMetrics:
"""
Metrics that are reduced in the same way as loss before logging.
"""
pass

class TransformerRoutingMetrics(TransformerReducedMetrics):
normalized_average_entropy = "normalized_average_entropy"
mutual_info = "mutual_info"

class RotaryEmbeddingType(str, enum.Enum):
none = "none"
Expand Down
48 changes: 47 additions & 1 deletion fast_llm/layers/transformer/mixture_of_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TransformerDimNames,
TransformerKwargs,
TransformerLossNames,
TransformerRoutingMetrics
)
from fast_llm.layers.transformer.mlp import MLPBase
from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
Expand All @@ -26,6 +27,35 @@
logger = logging.getLogger(__name__)


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could try @torch.compile on these for a free performance boost.

"""
Calculates routing entropy for each token, then averages over all tokens.
If low, means a lot of mass is put on a single expert in all tokens, which can indicate collapse or specialization.
"""
n_experts = probs.size(-1)
entropy_values = entropy(probs)
average_entropy = entropy_values.mean() # Average over batch and tokens
return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

average_entropy/math.log(n_experts) (same elsewhere)


def entropy(probs: torch.Tensor) -> torch.Tensor:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calculate_entropy

probs = torch.clamp(probs, min=1e-9) # Avoid log(0)
return -torch.sum(probs * torch.log(probs), dim=-1)


def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
"""
Calculates the difference between the entropy of the average routing and
the average routing entropy, we average across all tokens of all examples in the batch.
If low, means that routing is not informative.
"""
n_experts = probs.size(-1)
average_routing = torch.mean(probs.view(-1, n_experts), dim=0) # Average over tokens
entropy_avg_routing = entropy(average_routing) / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) # H[E[X]]
entropy_routing = calculate_normalized_average_entropy(probs) # E[H[X]]

return entropy_avg_routing - entropy_routing


class MixtureOfExpertMLP(MLPBase):
"""
MoeLayer following implementation from
Expand Down Expand Up @@ -103,7 +133,7 @@ def forward(

# Routing
if self._routing_type == RoutingType.topk:
scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses)
scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics)
if self._num_shared_experts > 0:
scores, top_experts = self._add_shared_experts(top_experts, scores)
elif self._routing_type == RoutingType.sinkhorn:
Expand Down Expand Up @@ -169,11 +199,27 @@ def _topk_routing(
logits: torch.Tensor,
grad_scale: float | None = None,
losses: dict | None = None,
metrics: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
if losses is not None or (self.training and grad_scale is not None):
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)


# Store these metrics
if metrics is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the extra computation involved, this should be enabled through a config parameter

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how much compute are we talking about for these metrics? likely this won't be noticeable.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is already controlled by training.logs.interval parameter afaiu. Do you think we need a seperate parameter for MoE stats logging?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we'd like to calculate the bare minimum by default (and the computation isn't optimized), so I think yes.

If we really want by default I guess we could have a parameter with that defaults to true but can disabled for performance, eg. for benchmarks.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed

# Calculate and log entropy and mutual information
entropy = calculate_normalized_average_entropy(probs)
mutual_info = calculate_mutual_information(probs)
if TransformerRoutingMetrics.normalized_average_entropy not in metrics:
metrics[TransformerRoutingMetrics.normalized_average_entropy] = []
if TransformerRoutingMetrics.mutual_info not in metrics:
metrics[TransformerRoutingMetrics.mutual_info] = []

metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach())
metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach())

mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1)
# Auxiliary loss, corresponding to the sum of probabilities for the top experts.
# In the optimal case (uniform distribution), loss = experts_per_token / num_experts.
Expand Down
25 changes: 25 additions & 0 deletions fast_llm/models/gpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TransformerDimNames,
TransformerKwargs,
TransformerLossNames,
TransformerRoutingMetrics,
)
from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor
from fast_llm.layers.transformer.transformer import TransformerLayer
Expand Down Expand Up @@ -308,10 +309,34 @@ def loss_defs(self) -> list[LossDef]:
count=self._config.transformer.num_layers,
)
)

if self._config.logit_z_loss:
LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
return loss_defs

@property
def metric_defs(self) -> list[LossDef]:
metric_defs = []
if (
self._config.transformer.num_experts > 1
and self._config.transformer.expert_routing_type == RoutingType.topk
):
metric_defs.append(
LossDef(
name=TransformerRoutingMetrics.normalized_average_entropy,
formatted_name="Normalized Entropy",
count=self._config.transformer.num_layers,
)
)
metric_defs.append(
LossDef(
name=TransformerRoutingMetrics.mutual_info,
formatted_name="Mutual Information",
count=self._config.transformer.num_layers,
)
)
return metric_defs


class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig
Expand Down
Loading