2 changes: 0 additions & 2 deletions src/speculators/convert/eagle/eagle3_converter.py
@@ -178,9 +178,7 @@ def _build_eagle3_speculator_config(
return Eagle3SpeculatorConfig(
transformer_layer_config=transformer_config,
speculators_config=speculators_config,
draft_vocab_size=eagle_config.get("draft_vocab_size", 32000),
norm_before_residual=norm_before_residual,
target_hidden_size=eagle_config.get("target_hidden_size"),
)

def _create_transformer_config_from_eagle(
111 changes: 63 additions & 48 deletions src/speculators/models/eagle.py
@@ -18,7 +18,13 @@
from typing import Any, ClassVar, Literal, Optional, Union

import torch
from pydantic import Field, field_serializer, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
field_serializer,
field_validator,
model_validator,
)
from torch import nn
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -32,11 +38,66 @@
__all__ = [
"EagleSpeculator",
"EagleSpeculatorConfig",
"TransformerLayerConfigMixin",
]


class TransformerLayerConfigMixin(BaseModel):
transformer_layer_config: PretrainedConfig = Field(
default_factory=LlamaConfig,
description=(
"Configuration object for the transformer layer architecture. "
"Must be a PretrainedConfig instance that matches the requirements "
"of the transformer_layer_architecture. Contains parameters such as "
"hidden_size, num_attention_heads, intermediate_size, vocab_size, "
"and other architecture-specific settings. "
"Additionally, it contains all the necessary information to check and "
"validate compatibility between the speculator and verifier models, "
"such as the vocab_size used for the speculator and the hidden_size "
"used for the speculator's transformer layer, which must match "
"the verifier's hidden_size according to the algorithm's design."
),
)

@field_serializer("transformer_layer_config")
def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
"""
Serialize the transformer_layer_config to a dictionary for JSON storage.

Converts the PretrainedConfig object to its dictionary representation
using to_diff_dict() to only include non-default values.

:param value: The PretrainedConfig instance to serialize
:return: Dictionary representation of the transformer layer configuration
"""
return value.to_diff_dict()

@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
"""
Validate and convert transformer_layer_config to a PretrainedConfig instance.

Accepts either a dictionary that can be converted to a PretrainedConfig
or an existing PretrainedConfig instance.

:param value: The value to validate (dict or PretrainedConfig)
:return: A validated PretrainedConfig instance
:raises ValueError: If the value cannot be converted to a PretrainedConfig
"""
if isinstance(value, dict):
return AutoConfig.for_model(**value)
if isinstance(value, PretrainedConfig):
return value

raise ValueError(
"transformer_layer_config must be a PretrainedConfig instance or a "
"dictionary that can be converted to a PretrainedConfig."
)


@SpeculatorModelConfig.register("eagle")
class EagleSpeculatorConfig(SpeculatorModelConfig):
class EagleSpeculatorConfig(SpeculatorModelConfig, TransformerLayerConfigMixin):
"""
A SpeculatorModelConfig implementation to be used with the EagleSpeculator
for EAGLE and HASS variants for spec decoding:
@@ -91,16 +152,6 @@ class EagleSpeculatorConfig(SpeculatorModelConfig):
"transformer decoder layer class (e.g., 'LlamaDecoderLayer')."
),
)
transformer_layer_config: PretrainedConfig = Field(
default_factory=LlamaConfig,
description=(
"Configuration object for the transformer layer architecture. "
"Must be a PretrainedConfig instance that matches the requirements "
"of the transformer_layer_architecture. Contains parameters such as "
"hidden_size, num_attention_heads, intermediate_size, vocab_size, "
"and other architecture-specific settings."
),
)
layernorms: bool = Field(
default=False,
description=(
@@ -140,42 +191,6 @@ def check_add_architectures(self) -> Self:

return self

@field_serializer("transformer_layer_config")
def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
"""
Serialize the transformer_layer_config to a dictionary for JSON storage.

Converts the PretrainedConfig object to its dictionary representation
using to_diff_dict() to only include non-default values.

:param value: The PretrainedConfig instance to serialize
:return: Dictionary representation of the transformer layer configuration
"""
return value.to_diff_dict()

@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
"""
Validate and convert transformer_layer_config to a PretrainedConfig instance.

Accepts either a dictionary that can be converted to a PretrainedConfig
or an existing PretrainedConfig instance.

:param value: The value to validate (dict or PretrainedConfig)
:return: A validated PretrainedConfig instance
:raises ValueError: If the value cannot be converted to a PretrainedConfig
"""
if isinstance(value, dict):
return AutoConfig.for_model(**value)
if isinstance(value, PretrainedConfig):
return value

raise ValueError(
"transformer_layer_config must be a PretrainedConfig instance or a "
"dictionary that can be converted to a PretrainedConfig."
)


@SpeculatorModel.register("eagle")
class EagleSpeculator(SpeculatorModel):
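
For reference, the round trip that the relocated validator and serializer implement can be exercised directly against transformers. The snippet below is an editor's sketch, not part of this PR; it assumes a recent transformers release, and the config values are placeholders.

from transformers import AutoConfig, LlamaConfig

# Dict form, as it would appear under "transformer_layer_config" in a saved
# config.json; the "model_type" key lets AutoConfig pick the config class.
layer_dict = {"model_type": "llama", "hidden_size": 2048, "num_attention_heads": 16}

# What validate_transformer_layer_config does for dict inputs:
layer_config = AutoConfig.for_model(**layer_dict)
assert isinstance(layer_config, LlamaConfig)

# What serialize_transformer_layer_config does for JSON storage:
# to_diff_dict() keeps only values that differ from the class defaults,
# so the stored form stays compact.
serialized = layer_config.to_diff_dict()
assert serialized["hidden_size"] == 2048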
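
The inheritance change itself, where EagleSpeculatorConfig picks up the field, validator, and serializer from TransformerLayerConfigMixin instead of declaring them locally, can be illustrated with a self-contained toy model. This is a sketch rather than the library's code: the class names are hypothetical, and arbitrary_types_allowed is set explicitly because a standalone pydantic v2 model needs it to accept a PretrainedConfig-typed field.

from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from transformers import AutoConfig, LlamaConfig, PretrainedConfig


class LayerConfigMixinSketch(BaseModel):
    # Assumption for this standalone sketch; the real SpeculatorModelConfig
    # base may already configure how arbitrary types are handled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    transformer_layer_config: PretrainedConfig = Field(default_factory=LlamaConfig)

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def _validate(cls, value: Any) -> PretrainedConfig:
        # Dicts are resolved through AutoConfig via their "model_type" key.
        if isinstance(value, dict):
            return AutoConfig.for_model(**value)
        if isinstance(value, PretrainedConfig):
            return value
        raise ValueError("expected a dict or a PretrainedConfig instance")

    @field_serializer("transformer_layer_config")
    def _serialize(self, value: PretrainedConfig) -> dict:
        # Only non-default values are written out, mirroring to_diff_dict().
        return value.to_diff_dict()


class ToyDraftConfig(LayerConfigMixinSketch):
    # Hypothetical stand-in for EagleSpeculatorConfig: it inherits the field
    # plus its validator and serializer instead of redefining them.
    layernorms: bool = False


cfg = ToyDraftConfig(transformer_layer_config={"model_type": "llama", "hidden_size": 2048})
assert isinstance(cfg.transformer_layer_config, LlamaConfig)
assert cfg.model_dump()["transformer_layer_config"]["hidden_size"] == 2048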