diff --git a/src/speculators/config.py b/src/speculators/config.py
index 36acea4c..cf105a9f 100644
--- a/src/speculators/config.py
+++ b/src/speculators/config.py
@@ -253,14 +253,12 @@ def __pydantic_schema_base_type__(cls) -> type["SpeculatorModelConfig"]:
     schema_discriminator: ClassVar[str] = "speculators_model_type"
 
     # PretrainedConfig class attributes
-    model_type: ClassVar[str] = "speculator_model"  # type: ignore[misc]
     base_config_key: ClassVar[str] = ""  # type: ignore[misc]
     sub_configs: ClassVar[dict[str, type[PretrainedConfig]]] = {}  # type: ignore[misc,assignment]
     is_composition: ClassVar[bool] = False  # type: ignore[misc]
     attribute_map: ClassVar[dict[str, str]] = {}  # type: ignore[misc]
     base_model_tp_plan: ClassVar[Optional[dict[str, Any]]] = None  # type: ignore[misc]
     base_model_pp_plan: ClassVar[Optional[dict[str, tuple[list[str]]]]] = None  # type: ignore[misc]
-    _auto_class: ClassVar[Optional[str]] = ""  # type: ignore[misc]
 
     # Speculator model instance attributes
     speculators_model_type: str = Field(
@@ -283,6 +281,9 @@ def __init__(self, **kwargs):
         # initialize the Pydantic arguments first to set all valid fields
         PydanticClassRegistryMixin.__init__(self, **kwargs)
 
+        # Set model_type to speculator_model if not already set
+        self.model_type = kwargs.setdefault("model_type", "speculator_model")
+
         # reset kwargs handled by Pydantic so PretrainedConfig doesn't override
         for field in self.__class__.model_fields:
             kwargs[field] = getattr(self, field)
@@ -308,14 +309,12 @@ def to_dict(self) -> dict[str, Any]:
             "auto_package",
             "registry_auto_discovery",
             "schema_discriminator",
-            "model_type",
             "base_config_key",
             "sub_configs",
             "is_composition",
             "attribute_map",
             "base_model_tp_plan",
             "base_model_pp_plan",
-            "_auto_class",
         ):
             config_dict.pop(key, None)
diff --git a/src/speculators/models/__init__.py b/src/speculators/models/__init__.py
index 660aec56..4731bda7 100644
--- a/src/speculators/models/__init__.py
+++ b/src/speculators/models/__init__.py
@@ -1,10 +1,11 @@
 from .eagle import EagleSpeculator, EagleSpeculatorConfig
-from .independent import IndependentSpeculatorConfig
+from .independent import IndependentSpeculator, IndependentSpeculatorConfig
 from .mlp import MLPSpeculatorConfig
 
 __all__ = [
     "EagleSpeculator",
     "EagleSpeculatorConfig",
+    "IndependentSpeculator",
     "IndependentSpeculatorConfig",
     "MLPSpeculatorConfig",
 ]
diff --git a/src/speculators/models/independent.py b/src/speculators/models/independent.py
index ca35b879..f9f7a124 100644
--- a/src/speculators/models/independent.py
+++ b/src/speculators/models/independent.py
@@ -1,12 +1,20 @@
-from transformers import PretrainedConfig
+import os
+from typing import Any, ClassVar, Literal, Optional, Union
+
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.models.auto.configuration_auto import CONFIG_MAPPING
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
 
 from speculators import SpeculatorModelConfig, SpeculatorsConfig
+from speculators.model import SpeculatorModel
 
-__all__ = ["IndependentSpeculatorConfig"]
+__all__ = ["IndependentSpeculator", "IndependentSpeculatorConfig"]
 
 
 @SpeculatorModelConfig.register("independent")
 class IndependentSpeculatorConfig(SpeculatorModelConfig):
+    speculators_model_type: Literal["independent"] = "independent"
+
     @classmethod
     def from_pretrained_config(
         cls, pretrained_config: PretrainedConfig, speculators_config: SpeculatorsConfig
@@ -16,16 +24,332 @@ def from_pretrained_config(
 
         return cls(**pretrained_dict, speculators_config=speculators_config)
 
-    speculators_model_type: str = "independent"
+    @classmethod
+    def from_dict(
+        cls, config_dict: dict[str, Any], **kwargs
+    ) -> "IndependentSpeculatorConfig":
+        """
+        Create an IndependentSpeculatorConfig from a dictionary, automatically
+        instantiating the correct subclass based on the speculators_model_type field.
+
+        :param config_dict: Dictionary containing the configuration
+        :param kwargs: Additional keyword arguments that override config values
+        :return: An IndependentSpeculatorConfig instance
+        """
+        dict_obj = {**config_dict, **kwargs}
+
+        spec_model_type = dict_obj.setdefault("speculators_model_type", "independent")
+        if spec_model_type != "independent":
+            raise ValueError(
+                f"Wrong speculators_model_type: {spec_model_type} for"
+                " IndependentSpeculatorConfig."
+            )
+
+        if "model_type" not in dict_obj:
+            raise ValueError("Expected model_type in config_dict")
+
+        return cls.model_validate(dict_obj)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        **kwargs,
+    ) -> "IndependentSpeculatorConfig":
+        """
+        Load an IndependentSpeculatorConfig from the name/id of a model on the
+        HuggingFace Hub or from a local directory.
+
+        :param pretrained_model_name_or_path: The name or path to the pretrained model.
+        :param cache_dir: The directory to cache the config in.
+        :param force_download: Whether to force download the config from the Hub.
+        :param local_files_only: Whether to use local files, not download from the Hub.
+        :param token: The token to use for authentication with the Hub.
+        :param revision: The revision of the config to load from the Hub.
+        :param kwargs: Additional keyword arguments to pass to the config.
+        :return: An IndependentSpeculatorConfig object with the loaded parameters.
+        """
+        # Transformers config loading
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            **kwargs,
+        )
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+@SpeculatorModel.register("independent")
+class IndependentSpeculator(SpeculatorModel):
+    config_class: ClassVar[type[IndependentSpeculatorConfig]] = (  # type: ignore[misc]
+        IndependentSpeculatorConfig
+    )
+
+    _independent_speculator_mod_attributes = {
+        "_draft_model",
+        "_draft_model_class",
+        "verifier",
+        "verifier_attachment_mode",
+    }
+
+    def __init__(
+        self,
+        config: IndependentSpeculatorConfig,
+        verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
+        verifier_attachment_mode: Optional[
+            Literal["detached", "full", "train_only"]
+        ] = None,
+    ):
+        if not isinstance(config, IndependentSpeculatorConfig):
+            if not isinstance(config, PretrainedConfig):
+                raise ValueError(
+                    "Attempted to initialize an IndependentSpeculator with a"
+                    f" {type(config)} class as the config class. Please use"
+                    " an IndependentSpeculatorConfig instance or a subclass of"
+                    " PretrainedConfig instead."
+                )
+            if (
+                hasattr(config, "speculators_model_type")
+                and config.speculators_model_type != "independent"
+            ):
+                raise ValueError(
+                    "Attempted to initialize an IndependentSpeculator with a "
+                    f"{config.speculators_model_type} config class. "
+                    "IndependentSpeculator only supports models with "
+                    "speculators_model_type='independent'."
+                )
+            # Subclass of PretrainedConfig but not an IndependentSpeculatorConfig
+            # Convert to IndependentSpeculatorConfig
+            config = IndependentSpeculatorConfig.from_pretrained_config(
+                pretrained_config=config, speculators_config=None
+            )
+
+        self._draft_model = None
+
+        super().__init__(
+            config=config,
+            verifier=verifier,
+            verifier_attachment_mode=verifier_attachment_mode,
+        )
+
+        config_class: type[PretrainedConfig] = CONFIG_MAPPING[config.model_type]
+        self._draft_model_class: type[PreTrainedModel] = MODEL_FOR_CAUSAL_LM_MAPPING[  # type: ignore[assignment]
+            config_class
+        ]
+        self._draft_model = self._draft_model_class(config)  # type: ignore[operator]
+
+        self.post_init()
+
+    def forward(self, *args, **kwargs):
+        if self._draft_model is None:
+            raise ValueError("Draft model is not initialized")
+
+        return self._draft_model(*args, **kwargs)
+
+    @classmethod
+    def from_pretrained(
+        cls: type["IndependentSpeculator"],
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
+        verifier_attachment_mode: Optional[
+            Literal["detached", "full", "train_only"]
+        ] = None,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: Optional[bool] = None,
+        weights_only: bool = True,
+        **kwargs,
+    ) -> "IndependentSpeculator":
+        """
+        Load a pretrained speculator model from the Hugging Face Hub or local directory.
+
+        This method automatically resolves the correct speculator model class based on
+        the configuration type and loads the model with the appropriate weights. If
+        called on the base SpeculatorModel class, it will automatically determine and
+        instantiate the correct subclass based on the model configuration.
+
+        Example:
+        ```python
+        # Load with automatic class resolution
+        model = SpeculatorModel.from_pretrained("RedHatAI/speculator-llama-7b")
+
+        # Load from local directory
+        model = SpeculatorModel.from_pretrained("./my_speculator")
+
+        # Load with custom config
+        config = SpeculatorModelConfig.from_pretrained("RedHatAI/eagle-llama-7b")
+        model = SpeculatorModel.from_pretrained(
+            None, config=config, state_dict=state_dict
+        )
+        ```
+
+        :param pretrained_model_name_or_path: The model identifier on Hugging Face Hub,
+            or path to a local directory containing the model files. Can be None if
+            config is provided as a path.
+        :param model_args: Additional positional arguments passed to the model
+            constructor.
+        :param verifier: Optional verifier model to attach the speculator to.
+            Can be a path to a local model directory, a Hugging Face model identifier,
+            or an instance of PreTrainedModel. If provided, the speculator will use this
+            verifier for speculative decoding. If None, the speculator will load the
+            verifier from the config if specified, or it must be attached later
+            using the `attach_verifier` method.
+        :param verifier_attachment_mode: Optional mode for how the verifier is
+            attached to the speculator. If "detached", any verifier passed in or
+            resolved from the config will be ignored.
+            If "full", the verifier is fully integrated into the
+            speculator's forward pass and generation methods.
+ If "train_only", only the portions of the verifier needed for training + are attached, allowing for better resource utilization during training. + If None and a verifier is provided, it defaults to "full". + If a verifier is not provided and None is found in the config, + this parameter is ignored. + :param config: Optional configuration for the model. Can be a + SpeculatorModelConfig instance, a path to a config file, or None to load + from model directory. + :param cache_dir: Directory to cache downloaded files. If None, uses default + transformers cache directory. + :param ignore_mismatched_sizes: Whether to ignore size mismatches when loading + pretrained weights. Useful for loading models with different architectures. + :param force_download: Whether to force re-download of model files even if + they exist in cache. + :param local_files_only: Whether to avoid downloading files and only use local + cached files. Raises an error if files are not found locally. + :param token: Optional authentication token for accessing private models on + Hugging Face Hub. Can be a string token or True to use saved token. + :param revision: The specific model revision to load (branch name, tag, or + commit hash). Defaults to "main". + :param use_safetensors: Whether to use safetensors format for loading weights. + If None, automatically detects the available format. + :param weights_only: Whether to only load model weights without optimizer + states or other training artifacts. + :param kwargs: Additional keyword arguments passed to the model constructor + and loading process. + :return: A SpeculatorModel instance of the appropriate subclass, loaded with + the pretrained weights and configuration. + """ + if not config: + if not pretrained_model_name_or_path: + raise ValueError( + "Either `config` or `pretrained_model_name_or_path` must be " + "provided to load a SpeculatorModel." + ) + config = cls.config_class.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + + if isinstance(config, PretrainedConfig) and not isinstance( + config, IndependentSpeculatorConfig + ): + # Convert PretrainedConfig to IndependentSpeculatorConfig + config = IndependentSpeculatorConfig.from_dict(config.to_dict()) + + if not isinstance(config, IndependentSpeculatorConfig): + raise ValueError( + f"Expected config to be an instance of IndependentSpeculatorConfig, " + f"got {type(config)}." + ) + + if not pretrained_model_name_or_path and not kwargs.get("state_dict"): + raise ValueError( + "Either `pretrained_model_name_or_path` or `state_dict` must be " + "provided to load a SpeculatorModel." 
+            )
+
+        independent_speculator = cls(
+            config=config,
+            verifier=verifier,
+            verifier_attachment_mode=verifier_attachment_mode,
+        )
+
+        # Load the draft model
+        independent_speculator._draft_model = (
+            independent_speculator._draft_model_class.from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                cache_dir=cache_dir,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                force_download=force_download,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                use_safetensors=use_safetensors,
+                weights_only=weights_only,
+                **kwargs,
+            )
+        )
+
+        return independent_speculator
+
+    def save_pretrained(self, *args, **kwargs):
+        if self._draft_model is None:
+            raise ValueError("Draft model is not initialized")
+        self._draft_model.save_pretrained(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs):
+        if self._draft_model is None:
+            raise ValueError("Draft model is not initialized")
+        self._draft_model.load_state_dict(*args, **kwargs)
+
+    def state_dict(self, *args, **kwargs):
+        if self._draft_model is None:
+            raise ValueError("Draft model is not initialized")
+        return self._draft_model.state_dict(*args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        if name == "_draft_model":
+            return self._modules["_draft_model"]
+        if self._draft_model is None:
+            return super().__getattr__(name)
+
+        if name in IndependentSpeculator._independent_speculator_mod_attributes:
+            return super().__getattr__(name)
+
+        return getattr(self._draft_model, name)
+
+    def __setattr__(self, name: str, val: Any) -> None:
+        # Allow patching over class attributes
+        if hasattr(type(self), name):
+            return super().__setattr__(name, val)
+
+        if name in IndependentSpeculator._independent_speculator_mod_attributes:
+            return super().__setattr__(name, val)
+
+        if self._draft_model is None:
+            return super().__setattr__(name, val)
+
+        return setattr(self._draft_model, name, val)
+
+    def __delattr__(self, name: str) -> None:
+        # This mirrors `__setattr__`
+        if hasattr(type(self), name):
+            return super().__delattr__(name)
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+        if name in IndependentSpeculator._independent_speculator_mod_attributes:
+            return super().__delattr__(name)
-
-        # ensure we set the model_type to the one from the original config
-        self._model_type = kwargs.get("model_type")
+        if self._draft_model is None:
+            return super().__delattr__(name)
-
-    def to_dict(self):
-        config_dict = super().to_dict()
-        config_dict["model_type"] = self._model_type
-        del config_dict["_model_type"]
-        return config_dict
+        return delattr(self._draft_model, name)
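Reviewer note (not part of the patch): the subtle piece above is that `model_type` moves from a `ClassVar` on `SpeculatorModelConfig` to an instance attribute set in `__init__`, and `to_dict()` no longer strips it. That is what lets an independent speculator carry its draft model's original `model_type` through serialization instead of the old `_model_type` workaround. A minimal round-trip sketch under that assumption, with arbitrary small `LlamaConfig` values:

```python
# Sketch of the model_type round-trip enabled by the config.py change above.
# Assumes the elided body of from_pretrained_config() copies the pretrained
# config's to_dict() payload (including model_type) into the new config.
from transformers import LlamaConfig

from speculators.models import IndependentSpeculatorConfig

pretrained = LlamaConfig(hidden_size=64, num_hidden_layers=2)  # arbitrary test values
config = IndependentSpeculatorConfig.from_pretrained_config(
    pretrained_config=pretrained, speculators_config=None
)
assert config.model_type == "llama"  # preserved, not "speculator_model"
assert config.speculators_model_type == "independent"

# to_dict() keeps model_type, so from_dict() can rebuild the same config.
restored = IndependentSpeculatorConfig.from_dict(config.to_dict())
assert restored.model_type == "llama"
```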
+""" + +import pytest +from pydantic import BaseModel, ValidationError + +from speculators import ( + SpeculatorModelConfig, + SpeculatorsConfig, + VerifierConfig, +) +from speculators.models import IndependentSpeculatorConfig +from speculators.proposals import GreedyTokenProposalConfig + +# ===== Fixtures ===== + + +@pytest.fixture +def sample_verifier_config(): + return VerifierConfig( + name_or_path="test/verifier", + architectures=["LlamaForCausalLM"], + ) + + +@pytest.fixture +def sample_token_proposal_config(): + return GreedyTokenProposalConfig( + speculative_tokens=5, + verifier_accept_k=1, + accept_tolerance=0.0, + ) + + +@pytest.fixture +def sample_speculators_config(sample_token_proposal_config, sample_verifier_config): + return SpeculatorsConfig( + algorithm="independent", + proposal_methods=[sample_token_proposal_config], + default_proposal_method="greedy", + verifier=sample_verifier_config, + ) + + +@pytest.fixture +def independent_config_dict(): + return { + "speculators_model_type": "independent", + "speculators_config": { + "algorithm": "independent", + "proposal_methods": [ + { + "proposal_type": "greedy", + "speculative_tokens": 5, + "verifier_accept_k": 1, + "accept_tolerance": 0.0, + } + ], + "default_proposal_method": "greedy", + "verifier": { + "name_or_path": "test/verifier", + "architectures": ["LlamaForCausalLM"], + "hidden_size": 768, + "intermediate_size": 3072, + "vocab_size": 32000, + "max_position_embeddings": 2048, + "bos_token_id": 1, + "eos_token_id": 2, + }, + }, + } + + +# ===== EagleSpeculatorConfig Tests ===== + + +def test_indepentent_speculator_from_pretrained(): + config = IndependentSpeculatorConfig.from_pretrained( + "meta-llama/Llama-3.2-3B-Instruct" + ) + assert config.model_type == "llama" + assert config.speculators_model_type == "independent" + assert config.speculators_config is None + + +@pytest.mark.smoke +def test_independent_speculator_config_initialization(): + """Test default initialization of IndependentSpeculatorConfig.""" + config = IndependentSpeculatorConfig() + + # Verify Independent-specific defaults + assert config.speculators_model_type == "independent" + + # Verify base class defaults + assert config.model_type == "speculator_model" + assert config.speculators_config is None + + +@pytest.mark.smoke +def test_independent_speculator_config_custom_initialization(sample_speculators_config): + """Test custom initialization of IndependentSpeculatorConfig.""" + config = IndependentSpeculatorConfig(speculators_config=sample_speculators_config) + + # Verify custom values + assert config.speculators_model_type == "independent" + assert config.speculators_config == sample_speculators_config + + +@pytest.mark.smoke +def test_independent_speculator_config_base_initialization(sample_speculators_config): + # Create IndependentSpeculatorConfig with custom values + original_config = IndependentSpeculatorConfig( + speculators_config=sample_speculators_config, + ) + + # Convert to dict and validate through base class + config_dict = original_config.model_dump() + recreated_config = SpeculatorModelConfig.model_validate(config_dict) + + # Verify type and values preservation + assert isinstance(recreated_config, IndependentSpeculatorConfig) + assert recreated_config.speculators_model_type == "independent" + assert recreated_config.speculators_config == sample_speculators_config + + +@pytest.mark.regression +def test_independent_speculator_config_nested_initialization(): + class ParentModel(BaseModel): + single_config: IndependentSpeculatorConfig + 
+        config_list: list[IndependentSpeculatorConfig]
+        config_dict: dict[str, IndependentSpeculatorConfig]
+
+    parent = ParentModel(
+        single_config=IndependentSpeculatorConfig(),
+        config_list=[
+            IndependentSpeculatorConfig(),
+            IndependentSpeculatorConfig(),
+        ],
+        config_dict={
+            "draft1": IndependentSpeculatorConfig(),
+            "draft2": IndependentSpeculatorConfig(),
+        },
+    )
+
+    # Verify single config
+    assert isinstance(parent.single_config, IndependentSpeculatorConfig)
+
+    # Verify config list
+    assert len(parent.config_list) == 2
+    assert all(isinstance(c, IndependentSpeculatorConfig) for c in parent.config_list)
+
+    # Verify config dict
+    assert len(parent.config_dict) == 2
+    assert all(
+        isinstance(c, IndependentSpeculatorConfig) for c in parent.config_dict.values()
+    )
+
+
+@pytest.mark.smoke
+def test_independent_speculator_config_invalid_initialization():
+    # Test invalid speculators_model_type
+    with pytest.raises(ValidationError) as exc_info:
+        IndependentSpeculatorConfig(speculators_model_type="invalid")  # type: ignore[arg-type]
+    assert "speculators_model_type" in str(exc_info.value)
+
+
+@pytest.mark.smoke
+def test_independent_speculator_config_auto_registry():
+    registered_classes = SpeculatorModelConfig.registered_classes()
+    class_names = [cls.__name__ for cls in registered_classes]
+
+    # Verify IndependentSpeculatorConfig is registered
+    assert "IndependentSpeculatorConfig" in class_names
+
+    # Verify registry key mapping
+    assert SpeculatorModelConfig.registry is not None
+    assert "independent" in SpeculatorModelConfig.registry
+    assert SpeculatorModelConfig.registry["independent"] == IndependentSpeculatorConfig
+
+
+@pytest.mark.smoke
+def test_independent_speculator_config_marshalling(sample_speculators_config):
+    original_config = IndependentSpeculatorConfig(
+        speculators_config=sample_speculators_config,
+    )
+
+    # Test model_dump()
+    config_dict = original_config.model_dump()
+    assert isinstance(config_dict, dict)
+    assert config_dict["speculators_model_type"] == "independent"
+    assert config_dict["speculators_config"] == sample_speculators_config.model_dump()
+
+    # Test model_validate() on base class
+    recreated_base = SpeculatorModelConfig.model_validate(config_dict)
+    assert isinstance(recreated_base, IndependentSpeculatorConfig)
+    assert recreated_base.speculators_config == sample_speculators_config
+
+    # Test model_validate() on derived class
+    recreated_derived = IndependentSpeculatorConfig.model_validate(config_dict)
+    assert isinstance(recreated_derived, IndependentSpeculatorConfig)
+    assert recreated_derived.speculators_config == sample_speculators_config
diff --git a/tests/unit/models/test_independent_model.py b/tests/unit/models/test_independent_model.py
new file mode 100644
index 00000000..77787e83
--- /dev/null
+++ b/tests/unit/models/test_independent_model.py
@@ -0,0 +1,282 @@
+"""
+Unit tests for the IndependentSpeculator model in the Speculators library.
+""" + +from unittest.mock import patch + +import pytest +import torch +from torch import nn +from transformers import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from speculators import SpeculatorsConfig, VerifierConfig +from speculators.models import ( + EagleSpeculatorConfig, + IndependentSpeculator, + IndependentSpeculatorConfig, +) +from speculators.proposals import GreedyTokenProposalConfig + +# ===== Test Helper Classes ===== + + +class MockDraftModel(PreTrainedModel): + """Mock draft model for testing IndependentSpeculator.""" + + def __init__(self, config): + super().__init__(config) + self.config = config + # Use simple linear layers instead of heavy transformer layers + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, input_ids, **kwargs): + """Simple forward pass for testing.""" + return type( + "MockOutput", + (), + { + "logits": torch.randn( + input_ids.shape[0], input_ids.shape[1], self.config.vocab_size + ) + }, + )() + + +class MockVerifierModel(PreTrainedModel): + """Mock verifier model for testing IndependentSpeculator.""" + + def __init__(self, config): + super().__init__(config) + self.config = config + # Use simple linear layers instead of heavy transformer layers + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, input_ids, **kwargs): + """Simple forward pass for testing.""" + return type( + "MockOutput", + (), + { + "logits": torch.randn( + input_ids.shape[0], input_ids.shape[1], self.config.vocab_size + ) + }, + )() + + +# ===== Fixtures ===== + + +@pytest.fixture +def sample_llama_config(): + """Sample LlamaConfig for testing with small dimensions for speed.""" + return LlamaConfig( + vocab_size=1000, # Much smaller vocab for faster tests + hidden_size=64, # Much smaller hidden size for faster tests + intermediate_size=128, # Much smaller intermediate size + num_hidden_layers=2, # Much fewer layers + num_attention_heads=4, # Fewer attention heads + num_key_value_heads=4, # Fewer key-value heads + hidden_act="silu", + max_position_embeddings=512, # Smaller max position embeddings + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + ) + + +@pytest.fixture +def sample_verifier_config(): + """Sample VerifierConfig for testing.""" + return VerifierConfig( + name_or_path="test/verifier", + architectures=["LlamaForCausalLM"], + ) + + +@pytest.fixture +def sample_token_proposal_config(): + """Sample GreedyTokenProposalConfig for testing.""" + return GreedyTokenProposalConfig() + + +@pytest.fixture +def sample_speculators_config(sample_token_proposal_config, sample_verifier_config): + """Sample SpeculatorsConfig for testing.""" + return SpeculatorsConfig( + algorithm="independent", + proposal_methods=[sample_token_proposal_config], + default_proposal_method="greedy", + verifier=sample_verifier_config, + ) + + +@pytest.fixture +def sample_speculators_config_no_verifier(sample_token_proposal_config): + """Sample SpeculatorsConfig without verifier for testing.""" + return SpeculatorsConfig( + algorithm="independent", + proposal_methods=[sample_token_proposal_config], + default_proposal_method="greedy", + 
+        verifier=VerifierConfig(
+            name_or_path=None,
+            architectures=["LlamaForCausalLM"],
+        ),
+    )
+
+
+@pytest.fixture
+def independent_speculator_config(sample_speculators_config, sample_llama_config):
+    """Sample IndependentSpeculatorConfig for testing."""
+    return IndependentSpeculatorConfig.from_pretrained_config(
+        pretrained_config=sample_llama_config,
+        speculators_config=sample_speculators_config,
+    )
+
+
+@pytest.fixture
+def independent_speculator_config_no_verifier(
+    sample_speculators_config_no_verifier, sample_llama_config
+):
+    """Sample IndependentSpeculatorConfig without verifier for testing."""
+    return IndependentSpeculatorConfig.from_pretrained_config(
+        pretrained_config=sample_llama_config,
+        speculators_config=sample_speculators_config_no_verifier,
+    )
+
+
+@pytest.fixture
+def mock_draft_model(sample_llama_config):
+    """Mock draft model for testing."""
+    return MockDraftModel(sample_llama_config)
+
+
+@pytest.fixture
+def mock_verifier_model(sample_llama_config):
+    """Mock verifier model for testing."""
+    return MockVerifierModel(sample_llama_config)
+
+
+# ===== IndependentSpeculator Instantiation Tests =====
+
+
+@pytest.mark.smoke
+def test_independent_speculator_instantiation_without_verifier(
+    independent_speculator_config_no_verifier, mock_draft_model
+):
+    """Test IndependentSpeculator instantiation without verifier."""
+    model = IndependentSpeculator(
+        config=independent_speculator_config_no_verifier,
+        verifier=None,
+        verifier_attachment_mode="detached",
+    )
+
+    # Verify model was created successfully
+    assert isinstance(model, IndependentSpeculator)
+    assert model.config == independent_speculator_config_no_verifier
+    assert model._draft_model is not None
+    assert model._draft_model.config == independent_speculator_config_no_verifier
+    assert model.verifier is None
+    assert model.verifier_attachment_mode == "detached"
+
+
+@pytest.mark.smoke
+def test_independent_speculator_instantiation_with_verifier_instance(
+    independent_speculator_config, mock_draft_model, mock_verifier_model
+):
+    """Test IndependentSpeculator instantiation with verifier PreTrainedModel."""
+    model = IndependentSpeculator(
+        config=independent_speculator_config,
+        verifier=mock_verifier_model,
+        verifier_attachment_mode="full",
+    )
+
+    # Verify model was created successfully
+    assert isinstance(model, IndependentSpeculator)
+    assert model.config == independent_speculator_config
+    assert isinstance(model._draft_model, LlamaForCausalLM)
+    assert model.verifier == mock_verifier_model
+    assert model.verifier_attachment_mode == "full"
+
+
+@pytest.mark.smoke
+def test_independent_speculator_instantiation_train_only_mode(
+    independent_speculator_config, mock_draft_model, mock_verifier_model
+):
+    """Test IndependentSpeculator instantiation with train_only attachment mode."""
+    model = IndependentSpeculator(
+        config=independent_speculator_config,
+        verifier=mock_verifier_model,
+        verifier_attachment_mode="train_only",
+    )
+
+    # Verify model was created successfully
+    assert isinstance(model, IndependentSpeculator)
+    assert model.verifier_attachment_mode == "train_only"
+
+
+@pytest.mark.smoke
+def test_independent_speculator_instantiation_with_auto_verifier_from_config(
+    independent_speculator_config, mock_verifier_model
+):
+    """Test IndependentSpeculator instantiation with verifier loaded from config."""
+    with patch(
+        "transformers.AutoModelForCausalLM.from_pretrained",
+        side_effect=[mock_verifier_model],
+    ):
+        model = IndependentSpeculator(
+            config=independent_speculator_config,
+            verifier=None,  # Should load from config
+            verifier_attachment_mode="full",
+        )
+
+    # Verify model was created successfully
+    assert isinstance(model, IndependentSpeculator)
+    assert model.config == independent_speculator_config
+    assert model.verifier == mock_verifier_model
+    assert model.verifier_attachment_mode == "full"
+
+
+# ===== IndependentSpeculator Error Cases Tests =====
+
+
+@pytest.mark.sanity
+def test_independent_speculator_instantiation_invalid_config():
+    """Test IndependentSpeculator instantiation with invalid config."""
+    with pytest.raises(
+        ValueError, match="Attempted to initialize an IndependentSpeculator with a"
+    ):
+        IndependentSpeculator(
+            config="invalid_config",  # type: ignore[arg-type]
+            verifier=None,
+            verifier_attachment_mode="detached",
+        )
+
+
+@pytest.mark.sanity
+def test_independent_speculator_instantiation_wrong_config_type(
+    sample_speculators_config,
+):
+    """Test IndependentSpeculator instantiation with wrong config type."""
+
+    eagle_config = EagleSpeculatorConfig(
+        transformer_layer_config=LlamaConfig(),
+        speculators_config=sample_speculators_config,
+    )
+
+    with pytest.raises(
+        ValueError, match="Attempted to initialize an IndependentSpeculator with a"
+    ):
+        IndependentSpeculator(
+            config=eagle_config,  # type: ignore[arg-type]
+            verifier=None,
+            verifier_attachment_mode="detached",
+        )
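Reviewer note (not part of the patch): a minimal end-to-end sketch of the intended `IndependentSpeculator` usage, mirroring the checkpoint id referenced in `test_independent_config.py`. That repo is gated, so treat the id as a placeholder for any accessible causal-LM checkpoint or local directory:

```python
# Hedged usage sketch for the IndependentSpeculator introduced in this diff.
# Assumes Hugging Face Hub access (or a local checkpoint directory).
import torch

from speculators.models import IndependentSpeculator, IndependentSpeculatorConfig

checkpoint = "meta-llama/Llama-3.2-3B-Instruct"  # placeholder checkpoint id

# The config keeps the draft model's own model_type ("llama") while tagging
# it with speculators_model_type="independent".
config = IndependentSpeculatorConfig.from_pretrained(checkpoint)
assert config.speculators_model_type == "independent"

# "detached" skips verifier attachment, so only the draft model is loaded.
speculator = IndependentSpeculator.from_pretrained(
    checkpoint,
    verifier=None,
    verifier_attachment_mode="detached",
)

# forward(), state_dict(), and attribute access delegate to the wrapped
# draft model via the __getattr__/__setattr__ overrides above.
input_ids = torch.tensor([[1, 2, 3]])
with torch.no_grad():
    outputs = speculator(input_ids=input_ids)
print(outputs.logits.shape)
```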