Skip to content

Commit 83e148b

Browse files
committed
Restructure independent model to patch draft model config
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 7cfc8ae commit 83e148b

File tree

4 files changed

+346
-195
lines changed

4 files changed

+346
-195
lines changed

src/speculators/config.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,12 @@ def __pydantic_schema_base_type__(cls) -> type["SpeculatorModelConfig"]:
253253
schema_discriminator: ClassVar[str] = "speculators_model_type"
254254

255255
# PretrainedConfig class attributes
256-
model_type: ClassVar[str] = "speculator_model" # type: ignore[misc]
257256
base_config_key: ClassVar[str] = "" # type: ignore[misc]
258257
sub_configs: ClassVar[dict[str, PretrainedConfig]] = {} # type: ignore[misc]
259258
is_composition: ClassVar[bool] = False # type: ignore[misc]
260259
attribute_map: ClassVar[dict[str, str]] = {} # type: ignore[misc]
261260
base_model_tp_plan: ClassVar[Optional[dict[str, Any]]] = None # type: ignore[misc]
262261
base_model_pp_plan: ClassVar[Optional[dict[str, tuple[list[str]]]]] = None # type: ignore[misc]
263-
_auto_class: ClassVar[Optional[str]] = "" # type: ignore[misc]
264262

265263
# Speculator model instance attributes
266264
speculators_model_type: str = Field(
@@ -271,7 +269,7 @@ def __pydantic_schema_base_type__(cls) -> type["SpeculatorModelConfig"]:
271269
default=version("speculators"),
272270
description="Version of the speculators library",
273271
)
274-
speculators_config: SpeculatorsConfig = Field( # type: ignore[assignment]
272+
speculators_config: Optional[SpeculatorsConfig] = Field( # type: ignore[assignment]
275273
default=None,
276274
description=(
277275
"The speculators config describing what the model implements and creation. "
@@ -283,6 +281,9 @@ def __init__(self, **kwargs):
283281
# initialize the Pydantic arguments first to set all valid fields
284282
PydanticClassRegistryMixin.__init__(self, **kwargs)
285283

284+
# Set model_type to speculator_model if not already set
285+
self.model_type = kwargs.setdefault("model_type", "speculator_model")
286+
286287
# reset kwargs handled by Pydantic so PretrainedConfig doesn't override
287288
for field in self.__class__.model_fields:
288289
kwargs[field] = getattr(self, field)
@@ -308,14 +309,12 @@ def to_dict(self) -> dict[str, Any]:
308309
"auto_package",
309310
"registry_auto_discovery",
310311
"schema_discriminator",
311-
"model_type",
312312
"base_config_key",
313313
"sub_configs",
314314
"is_composition",
315315
"attribute_map",
316316
"base_model_tp_plan",
317317
"base_model_pp_plan",
318-
"_auto_class",
319318
):
320319
config_dict.pop(key, None)
321320

0 commit comments

Comments
 (0)