Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/speculators/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,12 @@ def __pydantic_schema_base_type__(cls) -> type["SpeculatorModelConfig"]:
schema_discriminator: ClassVar[str] = "speculators_model_type"

# PretrainedConfig class attributes
model_type: ClassVar[str] = "speculator_model" # type: ignore[misc]
base_config_key: ClassVar[str] = "" # type: ignore[misc]
sub_configs: ClassVar[dict[str, type[PretrainedConfig]]] = {} # type: ignore[misc,assignment]
is_composition: ClassVar[bool] = False # type: ignore[misc]
attribute_map: ClassVar[dict[str, str]] = {} # type: ignore[misc]
base_model_tp_plan: ClassVar[Optional[dict[str, Any]]] = None # type: ignore[misc]
base_model_pp_plan: ClassVar[Optional[dict[str, tuple[list[str]]]]] = None # type: ignore[misc]
_auto_class: ClassVar[Optional[str]] = "" # type: ignore[misc]

# Speculator model instance attributes
speculators_model_type: str = Field(
Expand All @@ -283,6 +281,9 @@ def __init__(self, **kwargs):
# initialize the Pydantic arguments first to set all valid fields
PydanticClassRegistryMixin.__init__(self, **kwargs)

# Set model_type to speculator_model if not already set
self.model_type = kwargs.setdefault("model_type", "speculator_model")

# reset kwargs handled by Pydantic so PretrainedConfig doesn't override
for field in self.__class__.model_fields:
kwargs[field] = getattr(self, field)
Expand All @@ -308,14 +309,12 @@ def to_dict(self) -> dict[str, Any]:
"auto_package",
"registry_auto_discovery",
"schema_discriminator",
"model_type",
"base_config_key",
"sub_configs",
"is_composition",
"attribute_map",
"base_model_tp_plan",
"base_model_pp_plan",
"_auto_class",
):
config_dict.pop(key, None)

Expand Down
3 changes: 2 additions & 1 deletion src/speculators/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .eagle import EagleSpeculator, EagleSpeculatorConfig
from .independent import IndependentSpeculatorConfig
from .independent import IndependentSpeculator, IndependentSpeculatorConfig
from .mlp import MLPSpeculatorConfig

__all__ = [
"EagleSpeculator",
"EagleSpeculatorConfig",
"IndependentSpeculator",
"IndependentSpeculatorConfig",
"MLPSpeculatorConfig",
]
Loading