diff --git a/src/speculators/__main__.py b/src/speculators/__main__.py
index 95ea22cd..bdc74f8b 100644
--- a/src/speculators/__main__.py
+++ b/src/speculators/__main__.py
@@ -3,24 +3,22 @@
 This module provides a command-line interface for creating and managing speculative
 decoding models. The CLI is built using Typer and provides commands for model
-conversion, version information, and other utilities.
-
-The CLI can be accessed through the `speculators` command after installation, or by
-running this module directly with `python -m speculators`.
-
-Commands:
-    convert: Convert models from external repos/formats to supported Speculators models
-    version: Display the current version of the Speculators library
-
-Usage:
-    $ speculators --help
-    $ speculators --version
-    $ speculators convert [OPTIONS]
+conversion, version information, and other utilities. It serves as the primary
+entry point for users to interact with the Speculators library from the command line.
+
+Example:
+::
+    speculators --help
+    speculators convert "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" \
+        --algorithm eagle \
+        --verifier "meta-llama/Llama-3.1-8B-Instruct"
 """
 
+from __future__ import annotations
+
 import json
 from importlib.metadata import version as pkg_version
-from typing import Annotated, Any, Optional
+from typing import Annotated, Any, Literal, cast
 
 import click
 import typer  # type: ignore[import-not-found]
@@ -37,16 +35,11 @@
 )
 
 
-def version_callback(value: bool):
+def version_callback(value: bool) -> None:
     """
-    Callback function to print the version of the Speculators package and exit.
-
-    This function is used as a callback for the --version option in the main CLI.
-    When the version option is specified, it prints the version information and
-    exits the application.
+    Print the Speculators package version and exit.
 
-    :param value: Boolean indicating whether the version option was specified.
-        If True, prints version and exits.
+    :param value: Whether the version option was specified
     """
     if value:
         typer.echo(f"speculators version: {pkg_version('speculators')}")
@@ -65,12 +58,8 @@ def speculators(
     """
     Main entry point for the Speculators CLI application.
 
-    This function serves as the root command callback and handles global options
-    such as version display. It is automatically called by Typer when the CLI
-    is invoked.
-
-    :param ctx: The Typer context object containing runtime information.
-    :param version: Boolean option to display version information and exit.
+    :param ctx: Typer context object containing runtime information
+    :param version: Option to display version information and exit
     """
 
 
@@ -79,8 +68,12 @@
 def convert(
     model: Annotated[
         str, typer.Argument(help="Model checkpoint or Hugging Face model ID to convert")
     ],
+    output_path: Annotated[
+        str, typer.Option(help="Directory path where converted model will be saved")
+    ] = "converted",
+    config: str | None = None,
     verifier: Annotated[
-        str,
+        str | None,
        typer.Option(
            "--verifier",
            help=(
@@ -88,22 +81,9 @@
                 "to attach as the verification/base model for speculative decoding"
             ),
         ),
-    ],
-    algorithm: Annotated[
-        str,
-        typer.Option(
-            help=(
-                "The source repo/algorithm to convert from into the matching algorithm "
-                "in Speculators"
-            ),
-            click_type=click.Choice(["eagle", "eagle3"]),
-        ),
-    ],
-    output_path: Annotated[
-        str, typer.Option(help="Directory path where converted model will be saved")
-    ] = "converted",
+    ] = None,
     validate_device: Annotated[
-        Optional[str],
+        str | None,
         typer.Option(
             help=(
                 "Device to validate the model on (e.g. 'cuda:0') "
@@ -111,8 +91,18 @@
             ),
         ),
     ] = None,
+    algorithm: Annotated[
+        str,
+        typer.Option(
+            help=(
+                "The source repo/algorithm to convert from into the matching algorithm "
+                "in Speculators"
+            ),
+            click_type=click.Choice(["auto", "eagle", "eagle2", "hass"]),
+        ),
+    ] = "auto",
     algorithm_kwargs: Annotated[
-        Optional[dict[str, Any]],
+        dict[str, Any] | None,
         typer.Option(
             parser=json.loads,
             help=(
@@ -122,52 +112,59 @@
             ),
         ),
     ] = None,
-):
+    cache_dir: str | None = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+    token: str | None = None,
+    revision: str | None = None,
+) -> None:
     """
-    Convert models from external research repositories or formats
-    into the standardized Speculators format for use within the Speculators
-    framework, Hugging Face model hub compatability, and deployment with vLLM.
-    Supported algorithms, repositories, and examples given below.
-
-    \b
-    algorithm=="eagle":
-        Eagle v1, v2: https://github.com/SafeAILab/EAGLE
-        HASS: https://github.com/HArmonizedSS/HASS
-    ::
-        # general
-        speculators convert "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" \\
-            --algorithm eagle \\
+    Convert models from external research repositories into Speculators format.
+
+    Converts models from research implementations (EAGLE, HASS) into standardized
+    Speculators format for use with Hugging Face, vLLM, and the Speculators framework.
+
+    Supported source repositories: [EAGLE v1, v2](https://github.com/SafeAILab/EAGLE)
+    and [HASS](https://github.com/HArmonizedSS/HASS).
+
+    Example:
+    ::
+        speculators convert "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" \
             --verifier "meta-llama/Llama-3.1-8B-Instruct"
+
         # with layernorms and fusion bias enabled
-        speculators convert "./eagle/checkpoint" \\
-            --algorithm eagle \\
-            --algorithm-kwargs '{"layernorms": true, "fusion_bias": true}' \\
+        speculators convert "./eagle/checkpoint" \
+            --algorithm-kwargs '{"layernorms": true, "fusion_bias": true}' \
             --verifier "meta-llama/Llama-3.1-8B-Instruct"
 
-    \b
-    algorithm=="eagle3":
-        Eagle v3: https://github.com/SafeAILab/EAGLE
-    ::
-        # general
-        speculators convert "./eagle/checkpoint" \\
-            --algorithm eagle3
+        # eagle3 with normalization before the residual
+        speculators convert "./eagle/checkpoint" \
+            --algorithm-kwargs '{"norm_before_residual": true}' \
+            --verifier "meta-llama/Llama-3.1-8B-Instruct"
-        # with normalization before the residual
-        speculators convert "./eagle/checkpoint" \\
-            --algorithm eagle3
-            --algorithm-kwargs '{"norm_before_residual": true}'
-            --verifier "meta-llama/Llama-3.1-8B-Instruct"
-    """
-    if not algorithm_kwargs:
-        algorithm_kwargs = {}
+
+    :param model: Model checkpoint path or Hugging Face model ID to convert
+    :param output_path: Directory path where the converted model will be saved
+    :param config: Optional config path or Hugging Face model ID; inferred from the
+        model checkpoint when omitted
+    :param verifier: Optional verifier model for speculative decoding
+    :param validate_device: Optional device for post-conversion validation
+    :param algorithm: Source algorithm to convert from (auto, eagle, eagle2, hass)
+    :param algorithm_kwargs: Additional conversion algorithm keyword arguments
+    :param cache_dir: Optional directory for caching downloaded model files
+    :param force_download: Force re-downloading files even if cached
+    :param local_files_only: Use only local files without downloading from hub
+    :param token: Optional Hugging Face authentication token for private models
+    :param revision: Optional Git revision for downloading from Hugging Face hub
+    """
     convert_model(
         model=model,
-        verifier=verifier,
         output_path=output_path,
+        config=config,
+        verifier=verifier,
         validate_device=validate_device,
-        algorithm=algorithm,  # type: ignore[arg-type]
-        **algorithm_kwargs,
+        algorithm=cast('Literal["auto", "eagle", "eagle2", "hass"]', algorithm),
+        algorithm_kwargs=algorithm_kwargs or {},
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        token=token,
+        revision=revision,
     )
diff --git a/src/speculators/convert/__init__.py b/src/speculators/convert/__init__.py
index 3e222918..57e35740 100644
--- a/src/speculators/convert/__init__.py
+++ b/src/speculators/convert/__init__.py
@@ -9,6 +9,8 @@
 - HASS: https://github.com/HArmonizedSS/HASS
 """
 
+from .converters import SpeculatorConverter
+from .eagle import Eagle3Converter, EagleConverter
 from .entrypoints import convert_model
 
-__all__ = ["convert_model"]
+__all__ = ["Eagle3Converter", "EagleConverter", "SpeculatorConverter", "convert_model"]
diff --git a/src/speculators/convert/converters/__init__.py b/src/speculators/convert/converters/__init__.py
new file mode 100644
index 00000000..bb1cad78
--- /dev/null
+++ b/src/speculators/convert/converters/__init__.py
@@ -0,0 +1,14 @@
+"""
+Registry-based converter architecture for transforming external checkpoints.
+
+This module provides the converter framework for standardizing external research model
+checkpoints into the Speculators format. The converter system uses a registry pattern
+to automatically detect and instantiate appropriate converters based on algorithm type
+and model characteristics, supporting extensible conversion workflows with validation.
+"""
+
+from __future__ import annotations
+
+from .base import SpeculatorConverter
+
+__all__ = ["SpeculatorConverter"]
diff --git a/src/speculators/convert/converters/base.py b/src/speculators/convert/converters/base.py
new file mode 100644
index 00000000..2b60c9bd
--- /dev/null
+++ b/src/speculators/convert/converters/base.py
@@ -0,0 +1,202 @@
+"""
+Abstract base converter for transforming external checkpoints to Speculators format.
+
+This module provides the registry-based converter architecture for standardizing
+external research model checkpoints into the Speculators format. The converter
+system supports automatic algorithm detection, extensible conversion workflows,
+and validation for various speculative decoding implementations.
+"""
+
+from __future__ import annotations
+
+import os
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Generic, TypeVar
+
+from torch import Tensor, device, nn
+from transformers import PretrainedConfig, PreTrainedModel
+
+from speculators.config import SpeculatorModelConfig
+from speculators.model import SpeculatorModel
+from speculators.utils import RegistryMixin
+
+__all__ = ["ConfigT", "ModelT", "SpeculatorConverter"]
+
+
+ConfigT = TypeVar("ConfigT", bound=SpeculatorModelConfig)
+"""Generic type variable for speculator model configs"""
+ModelT = TypeVar("ModelT", bound=SpeculatorModel)
+"""Generic type variable for speculator models"""
+
+
+class SpeculatorConverter(ABC, RegistryMixin, Generic[ConfigT, ModelT]):
+    """
+    Abstract base converter for transforming external checkpoints to Speculators format.
+
+    Provides a registry-based system for algorithm-specific converters with automatic
+    detection capabilities. The converter handles the complete transformation pipeline
+    including configuration translation, state dict conversion, model instantiation,
+    and validation for various speculative decoding implementations.
+
+    Example:
+    ::
+        # Resolve converter automatically
+        converter_cls = SpeculatorConverter.resolve_converter(
+            algorithm="auto", model="path/to/model", config="path/to/config"
+        )
+
+        # Create and execute conversion
+        converter = converter_cls(model, config, verifier=None)
+        model = converter(output_path="converted_model", validate_device="cuda")
+    """
+
+    @classmethod
+    def resolve_converter(
+        cls,
+        algorithm: str,
+        model: str | Path | PreTrainedModel | nn.Module,
+        config: str | Path | PretrainedConfig | dict,
+        verifier: str | os.PathLike | PreTrainedModel | None = None,
+        **kwargs,
+    ) -> type[SpeculatorConverter]:
+        """
+        Resolve the appropriate converter class for the specified algorithm.
+
+        Supports automatic algorithm detection when algorithm="auto" by testing each
+        registered converter's `is_supported` method against the provided inputs.
+
+        :param algorithm: Conversion algorithm name or "auto" for automatic detection
+        :param model: Model to convert (path, HF model ID, or PreTrainedModel instance)
+        :param config: Model configuration (path, HF model ID, or PretrainedConfig)
+        :param verifier: Optional verifier model for speculative decoding attachment
+        :param kwargs: Additional arguments passed to `is_supported` for auto detection
+        :return: Converter class for the specified or detected algorithm
+        :raises ValueError: If algorithm is not registered or no supported converter
+            found
+        """
+        if cls.registry is None:
+            raise ValueError(
+                "No converters registered. Please ensure that the SpeculatorConverter "
+                "subclass has registered converters using the @register decorator."
+            )
+
+        algorithm = algorithm.lower()
+
+        if algorithm != "auto":
+            if algorithm not in cls.registry:
+                raise ValueError(
+                    f"Algorithm '{algorithm}' is not registered. "
+                    f"Available algorithms: {', '.join(cls.registry.keys())}"
+                )
+            return cls.registry[algorithm]  # type: ignore[return-value]
+
+        for converter in cls.registry.values():
+            if converter.is_supported(model, config, verifier, **kwargs):
+                return converter  # type: ignore[return-value]
+
+        raise ValueError(
+            f"No supported converter found for model {model} with config {config}. "
+            f"Available algorithms: {', '.join(cls.registry.keys())}"
+        )
+
+    @classmethod
+    @abstractmethod
+    def is_supported(
+        cls,
+        model: str | Path | PreTrainedModel | nn.Module,
+        config: str | Path | PretrainedConfig | dict,
+        verifier: str | os.PathLike | PreTrainedModel | None = None,
+        **kwargs,
+    ) -> bool:
+        """
+        Check if this converter supports the given model and configuration.
+
+        :param model: Model to check (path, HF model ID, or PreTrainedModel instance)
+        :param config: Model configuration (path, HF model ID, or PretrainedConfig)
+        :param verifier: Optional verifier model for compatibility validation
+        :param kwargs: Additional arguments for algorithm-specific checks
+        :return: True if the converter supports the model and config
+        """
+        ...
+
+    def __init__(
+        self,
+        model: str | Path | PreTrainedModel | nn.Module,
+        config: str | Path | PretrainedConfig | dict,
+        verifier: str | os.PathLike | PreTrainedModel | None,
+    ):
+        """
+        Initialize the converter with model, configuration, and optional verifier.
+
+        :param model: Model to convert (path, HF model ID, or PreTrainedModel instance)
+        :param config: Model configuration (path, HF model ID, or PretrainedConfig)
+        :param verifier: Optional verifier model for speculative decoding attachment
+        :raises ValueError: If model or config is None or empty
+        """
+        if model is None or config is None or model == "" or config == "":
+            raise ValueError(
+                f"Model and config paths must be provided, got {model}, {config}"
+            )
+
+        self.model = model
+        self.config = config
+        self.verifier = verifier
+
+    def __call__(
+        self,
+        output_path: str | os.PathLike | None = None,
+        validate_device: str | device | int | None = None,
+    ) -> ModelT:
+        """
+        Execute the complete conversion pipeline to Speculators format.
+
+        Converts configuration and state dict, instantiates the model, optionally
+        saves to disk, and validates on the specified device.
+
+        :param output_path: Optional directory path to save the converted model
+        :param validate_device: Optional device for post-conversion validation
+        :return: Converted Speculators model instance
+        """
+        config, state_dict = self.convert_config_state_dict()
+        model: ModelT = SpeculatorModel.from_pretrained(  # type: ignore[assignment]
+            pretrained_model_name_or_path=None,
+            config=config,
+            state_dict=state_dict,
+            verifier=self.verifier,
+            verifier_attachment_mode="full",
+        )
+        if output_path:
+            self.save(model, output_path)
+        if validate_device:
+            self.validate(model, validate_device)
+        return model
+
+    def save(self, model: ModelT, output_path: str | os.PathLike):
+        """
+        Save the converted model to the specified directory.
+
+        :param model: Converted Speculators model to save
+        :param output_path: Directory path where the model will be saved
+        """
+        model.save_pretrained(output_path)  # type: ignore[attr-defined]
+
+    @abstractmethod
+    def convert_config_state_dict(self) -> tuple[ConfigT, dict[str, Tensor]]:
+        """
+        Convert model configuration and state dict to Speculators format.
+
+        :return: Tuple of (converted configuration, converted state dict)
+        """
+        ...
+
+    @abstractmethod
+    def validate(self, model: ModelT, device: str | device | int):
+        """
+        Validate the converted model on the specified device.
+
+        :param model: Converted Speculators model to validate
+        :param device: Device for validation (string, torch.device, or device index)
+        """
+        ...
diff --git a/src/speculators/convert/eagle/eagle3_converter.py b/src/speculators/convert/eagle/eagle3_converter.py
index b700898a..4e0d50db 100644
--- a/src/speculators/convert/eagle/eagle3_converter.py
+++ b/src/speculators/convert/eagle/eagle3_converter.py
@@ -2,14 +2,16 @@
 Eagle-3 checkpoint converter with loguru logging.
 """
 
+import os
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import torch
 from loguru import logger
 from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig
 
 from speculators.config import SpeculatorsConfig, VerifierConfig
+from speculators.convert.converters import SpeculatorConverter
 from speculators.models.eagle3 import Eagle3Speculator, Eagle3SpeculatorConfig
 from speculators.proposals.greedy import GreedyTokenProposalConfig
 from speculators.utils import (
@@ -20,6 +22,46 @@
 
-__all__ = ["Eagle3Converter"]
+__all__ = ["Eagle3Converter", "Eagle3SpeculatorConverter"]
+
+
+@SpeculatorConverter.register(["eagle3"])
+class Eagle3SpeculatorConverter(SpeculatorConverter):
+    """
+    Intermediate patch for the Eagle-3 converter to maintain backward compatibility.
+ """ + + @classmethod + def is_supported(cls, **_kwargs) -> bool: # type: ignore[override] + return False # Disable auto-detection until eagle3 is refactored + + def __init__( + self, + norm_before_residual: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.norm_before_residual = norm_before_residual + + def __call__( + self, + output_path: Union[str, os.PathLike, None] = None, + validate_device: Optional[Union[str, torch.device, int]] = None, + ): + converter = Eagle3Converter() + converter.convert( + input_path=cast("Union[str, Path]", self.model), + output_path=str(output_path) or "./converted_eagle_speculator", + base_model=self.verifier if isinstance(self.verifier, str) else "", + validate=bool(validate_device), + norm_before_residual=self.norm_before_residual, + ) + + def convert_config_state_dict(self): + pass # No-op til eagle3 is refactored + + def validate(self, **_kwargs): # type: ignore[override] + pass # No-op til eagle3 is refactored + + class Eagle3Converter: """ Converter for Eagle3 checkpoints to speculators format. diff --git a/src/speculators/convert/eagle/eagle_converter.py b/src/speculators/convert/eagle/eagle_converter.py index c3ae21ae..4de49936 100644 --- a/src/speculators/convert/eagle/eagle_converter.py +++ b/src/speculators/convert/eagle/eagle_converter.py @@ -2,14 +2,16 @@ Eagle checkpoint converter with loguru logging. """ +import os from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, cast import torch from loguru import logger from transformers import LlamaConfig, PretrainedConfig from speculators.config import SpeculatorsConfig, VerifierConfig +from speculators.convert.converters import SpeculatorConverter from speculators.models.eagle import EagleSpeculator, EagleSpeculatorConfig from speculators.proposals.greedy import GreedyTokenProposalConfig from speculators.utils import ( @@ -17,7 +19,37 @@ load_model_checkpoint_state_dict, ) -__all__ = ["EagleConverter"] +__all__ = ["EagleConverter", "EagleSpeculatorConverter"] + + +@SpeculatorConverter.register(["eagle", "hass", "eagle2"]) +class EagleSpeculatorConverter(SpeculatorConverter): + """ + Intermediate patch for Eagle converter to maintain backward compatibility. + """ + + @classmethod + def is_supported(cls, **_kwargs) -> bool: # type: ignore[override] + return False # Disable auto-detection until eagle is refactored + + def __call__( + self, + output_path: Union[str, os.PathLike, None] = None, + validate_device: Optional[Union[str, torch.device, int]] = None, + ): + converter = EagleConverter() + converter.convert( + input_path=cast("Union[str, Path]", self.model), + output_path=str(output_path) or "./converted_eagle_speculator", + base_model=self.verifier if isinstance(self.verifier, str) else "", + validate=bool(validate_device), + ) + + def convert_config_state_dict(self): + pass # No-op til eagle is refactored + + def validate(self, **_kwargs): # type: ignore[override] + pass # No-op til eagle is refactored def detect_fusion_bias_and_layernorms( diff --git a/src/speculators/convert/entrypoints.py b/src/speculators/convert/entrypoints.py index 7181ab5d..4f9e5b55 100644 --- a/src/speculators/convert/entrypoints.py +++ b/src/speculators/convert/entrypoints.py @@ -1,99 +1,142 @@ """ -Provides the entry points for converting non-speculators model checkpoints to -Speculators model format with the `convert_model` function. 
-
-It supports the following algorithms and conversion from their associated
-research repositories:
-- EAGLE
-- EAGLE2
-- EAGLE3
-- HASS
-
-Functions:
-    convert_model: Converts a model checkpoint to the Speculators format.
+Entry points for converting non-Speculators model checkpoints to Speculators format.
+
+Provides the primary conversion interface through the `convert_model` function, which
+supports various input formats including local checkpoints, Hugging Face model IDs,
+and PyTorch module instances. Converts models from research implementations (EAGLE,
+EAGLE2, HASS) into standardized Speculators format with optional verifier attachment
+and validation capabilities.
 """
 
-from typing import Literal, Optional
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Literal
+
+import torch
+from loguru import logger
+from torch import nn
+from transformers import PretrainedConfig, PreTrainedModel
 
-from speculators.convert.eagle.eagle3_converter import Eagle3Converter
-from speculators.convert.eagle.eagle_converter import EagleConverter
+from speculators.convert.converters import SpeculatorConverter
+from speculators.model import SpeculatorModel
+from speculators.utils import (
+    check_download_model_checkpoint,
+    check_download_model_config,
+)
 
 __all__ = ["convert_model"]
 
 
 def convert_model(
-    model: str,
-    verifier: str,
-    algorithm: Literal["eagle", "eagle3"],
-    output_path: str = "converted",
-    validate_device: Optional[str] = None,
+    model: str | os.PathLike | PreTrainedModel | nn.Module,
+    output_path: str | os.PathLike | None = None,
+    config: str | os.PathLike | PreTrainedModel | PretrainedConfig | dict | None = None,
+    verifier: str | os.PathLike | PreTrainedModel | None = None,
+    validate_device: str | torch.device | int | None = None,
+    algorithm: Literal["auto", "eagle", "eagle2", "hass"] = "auto",
+    algorithm_kwargs: dict | None = None,
+    cache_dir: str | Path | None = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+    token: str | bool | None = None,
+    revision: str | None = None,
    **kwargs,
-):
+) -> SpeculatorModel:
     """
-    Convert a non speculator's model checkpoint to a speculator's model checkpoint
-    for use within the Speculators library, Hugging Face Hub, or vLLM.
-
-    algorithm=="eagle":
-        Eagle v1, v2: https://github.com/SafeAILab/EAGLE
-        HASS: https://github.com/HArmonizedSS/HASS
-    ::
-        # general
-        convert_model(
-            model="yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
-            verifier="meta-llama/Llama-3.1-8B-Instruct",
-            algorithm="eagle",
-        )
-        # with layernorms and fusion bias enabled
-        convert_model(
-            model="./eagle/checkpoint",
-            verifier="meta-llama/Llama-3.1-8B-Instruct",
-            algorithm="eagle",
-            layernorms=True,
-            fusion_bias=True,
-        )
+    Convert a non-Speculators model checkpoint to Speculators format.
 
-    algorithm=="eagle3":
-        Eagle v3: https://github.com/SafeAILab/EAGLE
-    ::
-        # general
-        convert_model(
-            model="./eagle/checkpoint",
-            verifier="meta-llama/Llama-3.1-8B-Instruct",
-            algorithm="eagle3",
-        )
-        # with normalization before the residual
-        convert_model(
-            model="./eagle/checkpoint",
-            verifier="meta-llama/Llama-3.1-8B-Instruct",
-            algorithm="eagle3",
-            norm_before_residual=True,
+    Supports model instances, local Hugging Face checkpoints, and Hugging Face hub
+    model IDs. Optional verifier attachment and validation capabilities are provided
+    for enhanced model functionality.
+
+    Example:
+    ::
+        from speculators.convert import convert_model
+
+        speculator_model = convert_model(
+            model="./my_checkpoint",
+            output_path="./converted_speculator_model",
+            algorithm="eagle",
+            verifier="./my_verifier_checkpoint",
         )
 
-    :param model: Path to the input model checkpoint or Hugging Face model ID.
-    :param verifier: Verifier model checkpoint or Hugging Face model ID
-        to attach as the verification/base model for speculative decoding
-    :param algorithm: The conversion algorithm to use, either "eagle" or "eagle3".
-    :param output_path: Directory path where the converted model will be saved.
-    :param kwargs: Additional keyword arguments for the conversion algorithm.
-        Options for Eagle: {"layernorms": true, "fusion_bias": true}.
-        Options for Eagle3: {"norm_before_residual": true}.
+    :param model: Path to checkpoint directory, Hugging Face model ID, or
+        PreTrainedModel instance to convert
+    :param output_path: Optional path to save the converted model
+    :param config: Optional config path, model ID, or config instance. If not
+        provided, inferred from the model checkpoint
+    :param verifier: Optional verifier checkpoint path, model ID, or instance to
+        attach to the converted model
+    :param validate_device: Optional device for post-conversion validation
+    :param algorithm: Conversion algorithm - "auto", "eagle", "eagle2", or "hass"
+    :param algorithm_kwargs: Optional keyword arguments for the conversion algorithm
+    :param cache_dir: Optional directory for caching downloaded model files
+    :param force_download: Force re-downloading files even if cached
+    :param local_files_only: Use only local files without downloading from hub
+    :param token: Optional Hugging Face authentication token for private models
+    :param revision: Optional Git revision for downloading from Hugging Face hub
+    :param kwargs: Additional keyword arguments for model and config download
+    :return: The converted speculator model instance
+    :raises ValueError: When config is required but not provided for nn.Module input
     """
+    logger.info(f"Converting model {model} to the Speculators format...")
 
-    if algorithm == "eagle":
-        EagleConverter().convert(
-            model,
-            output_path,
-            verifier,
-            validate=validate_device is not None,
-            **kwargs,
-        )
-    elif algorithm == "eagle3":
-        Eagle3Converter().convert(
-            model,
-            output_path,
-            verifier,
-            validate=validate_device is not None,
-            **kwargs,
-        )
-    else:
-        raise ValueError(f"Unsupported algorithm: {algorithm}")
+    model = check_download_model_checkpoint(
+        model,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        token=token,
+        revision=revision,
+        **kwargs,
+    )
+    logger.info(f"Resolved the model checkpoint: {model}")
+
+    if not config:
+        # Fall back to the model checkpoint as the config source
+        if isinstance(model, nn.Module):
+            raise ValueError(
+                "A model config must be provided when converting "
+                "a PyTorch nn.Module instance."
+            )
+        config = model
+
+    config = check_download_model_config(
+        config,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        token=token,
+        revision=revision,
+        **kwargs,
+    )
+    logger.info(f"Resolved the model config: {config}")
+
+    if not algorithm_kwargs:
+        algorithm_kwargs = {}
+
+    converter_class = SpeculatorConverter.resolve_converter(
+        algorithm,
+        model=model,
+        config=config,
+        verifier=verifier,
+        **algorithm_kwargs,
+    )
+    logger.info(f"Beginning conversion with Converter: {converter_class}")
+
+    converter = converter_class(
+        model=model,
+        config=config,
+        verifier=verifier,
+        **algorithm_kwargs,
+    )
+
+    converted = converter(
+        output_path=output_path,
+        validate_device=validate_device,
+    )
+    logger.info(f"Conversion complete: {converted}")
+
+    return converted
diff --git a/tests/unit/convert/converters/__init__.py b/tests/unit/convert/converters/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/convert/converters/test_base.py b/tests/unit/convert/converters/test_base.py
new file mode 100644
index 00000000..9867d83e
--- /dev/null
+++ b/tests/unit/convert/converters/test_base.py
@@ -0,0 +1,397 @@
+"""
+Unit tests for the base converter module in the Speculators library.
+"""
+
+from __future__ import annotations
+
+import os
+import tempfile
+from abc import ABC
+from pathlib import Path
+from typing import Generic, TypeVar
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from torch import Tensor, device, nn
+from transformers import PretrainedConfig, PreTrainedModel
+
+from speculators import SpeculatorModel, SpeculatorModelConfig
+from speculators.convert import SpeculatorConverter
+from speculators.convert.converters.base import ConfigT, ModelT
+from speculators.utils import RegistryMixin
+
+__all__ = ["ConfigT", "ModelT", "SpeculatorConverter"]
+
+
+@pytest.fixture
+def mock_model():
+    """Mock model for testing."""
+    model = MagicMock(spec=PreTrainedModel)
+    model.config = MagicMock(spec=PretrainedConfig)
+    return model
+
+
+@pytest.fixture
+def mock_config():
+    """Mock configuration for testing."""
+    config = MagicMock(spec=PretrainedConfig)
+    config.to_dict.return_value = {"model_type": "test_model"}
+    return config
+
+
+@pytest.fixture
+def mock_verifier():
+    """Mock verifier for testing."""
+    verifier = MagicMock(spec=PreTrainedModel)
+    verifier.config = MagicMock(spec=PretrainedConfig)
+    return verifier
+
+
+@pytest.fixture
+def temp_directory():
+    """Temporary directory for testing file operations."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        yield temp_dir
+
+
+def test_config_type():
+    """Test that ConfigT is configured correctly as a TypeVar."""
+    assert isinstance(ConfigT, type(TypeVar("test")))
+    assert ConfigT.__name__ == "ConfigT"
+    assert ConfigT.__bound__ is SpeculatorModelConfig
+    assert ConfigT.__constraints__ == ()
+
+
+def test_model_type():
+    """Test that ModelT is configured correctly as a TypeVar."""
+    assert isinstance(ModelT, type(TypeVar("test")))
+    assert ModelT.__name__ == "ModelT"
+    assert ModelT.__bound__ is SpeculatorModel
+    assert ModelT.__constraints__ == ()
+
+
+class MockSpeculatorConverterImpl(SpeculatorConverter):
+    """Test implementation of SpeculatorConverter for unit testing."""
+
+    @classmethod
+    def is_supported(
+        cls,
+        model: str | Path | PreTrainedModel | nn.Module,
+        config: str | Path | PretrainedConfig | dict,
+        verifier: str | os.PathLike | PreTrainedModel | None = None,
+        **kwargs,
+    ) -> bool:
+        """Test implementation that always returns True."""
+        return True
+
+    def convert_config_state_dict(
+        self,
+    ) -> tuple[SpeculatorModelConfig, dict[str, Tensor]]:
+        """Test implementation that returns mock config and state dict."""
+        mock_config = MagicMock(spec=SpeculatorModelConfig)
+        mock_state_dict = {"test_param": torch.tensor([1.0, 2.0, 3.0])}
+
+        return mock_config, mock_state_dict
+
+    def validate(self, model: SpeculatorModel, device: str | device | int):
+        """Test implementation that does nothing."""
+
+
+class MockSpeculatorConverterUnsupported(SpeculatorConverter):
+    """Test implementation that is never supported."""
+
+    @classmethod
+    def is_supported(
+        cls,
+        model: str | Path | PreTrainedModel | nn.Module,
+        config: str | Path | PretrainedConfig | dict,
+        verifier: str | os.PathLike | PreTrainedModel | None = None,
+        **kwargs,
+    ) -> bool:
+        """Test implementation that always returns False."""
+        return False
+
+    def convert_config_state_dict(
+        self,
+    ) -> tuple[SpeculatorModelConfig, dict[str, Tensor]]:
+        """Test implementation that returns mock config and state dict."""
+        mock_config = MagicMock(spec=SpeculatorModelConfig)
+        mock_state_dict = {"test_param": torch.tensor([1.0, 2.0, 3.0])}
+        return mock_config, mock_state_dict
+
+    def validate(self, model: SpeculatorModel, device: str | device | int):
+        """Test implementation that does nothing."""
+
+
+class TestSpeculatorConverter:
+    """Test class for SpeculatorConverter functionality."""
+
+    def setup_method(self):
+        """Set up test fixtures before each test method."""
+        # Store the original registry and clear it for this test
+        self._original_registry = SpeculatorConverter.registry  # type: ignore[misc]
+        SpeculatorConverter.registry = None  # type: ignore[misc]
+
+    def teardown_method(self):
+        """Clean up after each test method."""
+        # Restore the original registry
+        SpeculatorConverter.registry = self._original_registry  # type: ignore[misc]
+
+    @pytest.mark.smoke
+    def test_class_signatures(self):
+        """Test SpeculatorConverter inheritance and type relationships."""
+        assert issubclass(SpeculatorConverter, ABC)
+        assert issubclass(SpeculatorConverter, Generic)  # type: ignore[arg-type]
+        assert issubclass(SpeculatorConverter, RegistryMixin)
+
+        # Test class methods
+        assert hasattr(SpeculatorConverter, "resolve_converter")
+        assert callable(SpeculatorConverter.resolve_converter)
+        assert hasattr(SpeculatorConverter, "is_supported")
+
+        # Test instance methods
+        assert hasattr(SpeculatorConverter, "__init__")
+        assert callable(SpeculatorConverter)
+        assert hasattr(SpeculatorConverter, "save")
+        assert hasattr(SpeculatorConverter, "convert_config_state_dict")
+        assert hasattr(SpeculatorConverter, "validate")
+
+        # Test abstract methods can be called on concrete implementations
+        mock_converter = MockSpeculatorConverterImpl(
+            model=MagicMock(), config=MagicMock(), verifier=None
+        )
+
+        # Test is_supported method signature
+        assert (
+            MockSpeculatorConverterImpl.is_supported(
+                model="test", config={}, verifier=None
+            )
+            is True
+        )
+
+        # Test convert_config_state_dict method signature
+        config, state_dict = mock_converter.convert_config_state_dict()
+        assert config is not None
+        assert isinstance(state_dict, dict)
+
+        # Test validate method signature
+        mock_model = MagicMock(spec=SpeculatorModel)
+        mock_converter.validate(mock_model, "cpu")
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize(
+        "verifier",
+        [None, "mock_verifier"],
+    )
+    def test_initialization(self, mock_model, mock_config, mock_verifier, verifier):
+        """Test SpeculatorConverter initialization."""
+        actual_verifier = mock_verifier if verifier == "mock_verifier" else verifier
+        instance = MockSpeculatorConverterImpl(mock_model, mock_config, actual_verifier)
+        assert isinstance(instance, SpeculatorConverter)
+        assert instance.model is mock_model
+        assert instance.config is mock_config
+        assert instance.verifier is actual_verifier
+
+    @pytest.mark.sanity
+    @pytest.mark.parametrize(
+        ("model", "config"),
+        [
+            (None, "valid_config"),
+            ("valid_model", None),
+            ("", "valid_config"),
+            ("valid_model", ""),
+        ],
+    )
+    def test_invalid_initialization_values(
+        self, mock_model, mock_config, model, config
+    ):
+        """Test SpeculatorConverter with invalid field values."""
+        actual_model = mock_model if model == "valid_model" else model
+        actual_config = mock_config if config == "valid_config" else config
+
+        with pytest.raises(ValueError) as exc_info:
+            MockSpeculatorConverterImpl(actual_model, actual_config, None)
+
+        assert "Model and config paths must be provided" in str(exc_info.value)
+
+    @pytest.mark.sanity
+    def test_invalid_initialization_missing(self, mock_config):
+        """Test SpeculatorConverter initialization without required model."""
+        with pytest.raises(TypeError):
+            MockSpeculatorConverterImpl(config=mock_config, verifier=None)  # type: ignore[call-arg]
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize(
+        ("output_path", "validate_device", "should_save", "should_validate"),
+        [
+            (None, None, False, False),
+            ("output", None, True, False),
+            (None, "cuda", False, True),
+            ("output", "cuda", True, True),
+        ],
+    )
+    def test_call_invocation(
+        self,
+        mock_model,
+        mock_config,
+        temp_directory,
+        output_path,
+        validate_device,
+        should_save,
+        should_validate,
+    ):
+        """Test SpeculatorConverter call with various parameter combinations."""
+        converter = MockSpeculatorConverterImpl(mock_model, mock_config, None)
+
+        if output_path:
+            output_path = Path(temp_directory) / output_path
+
+        with patch.object(SpeculatorModel, "from_pretrained") as mock_from_pretrained:
+            mock_speculator = MagicMock(spec=SpeculatorModel)
+            mock_speculator.save_pretrained = MagicMock()
+            mock_from_pretrained.return_value = mock_speculator
+
+            with patch.object(converter, "validate") as mock_validate:
+                result = converter(
+                    output_path=output_path, validate_device=validate_device
+                )
+
+                assert result is mock_speculator
+                mock_from_pretrained.assert_called_once()
+
+                if should_save:
+                    mock_speculator.save_pretrained.assert_called_once_with(output_path)
+                else:
+                    mock_speculator.save_pretrained.assert_not_called()
+
+                if should_validate:
+                    mock_validate.assert_called_once_with(
+                        mock_speculator, validate_device
+                    )
+                else:
+                    mock_validate.assert_not_called()
+
+    @pytest.mark.sanity
+    def test_call_with_none_output_and_device(self, mock_model, mock_config):
+        """Test calling converter with None values for optional parameters."""
+        converter = MockSpeculatorConverterImpl(mock_model, mock_config, None)
+
+        with patch.object(SpeculatorModel, "from_pretrained") as mock_from_pretrained:
+            mock_speculator = MagicMock(spec=SpeculatorModel)
+            mock_from_pretrained.return_value = mock_speculator
+
+            result = converter(output_path=None, validate_device=None)
+            assert result is mock_speculator
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize("path_type", ["Path", "str"])
+    def test_save(self, mock_model, mock_config, temp_directory, path_type):
+        """Test SpeculatorConverter save method with different path types."""
+        converter = MockSpeculatorConverterImpl(mock_model, mock_config, None)
+        mock_speculator = MagicMock(spec=SpeculatorModel)
+        mock_speculator.save_pretrained = MagicMock()
+        output_path: Path | str
+
+        if path_type == "Path":
+            output_path = Path(temp_directory) / "output"
+        else:
+            output_path = str(Path(temp_directory) / "output")
+
+        converter.save(mock_speculator, output_path)
+        mock_speculator.save_pretrained.assert_called_once_with(output_path)
+
+    @pytest.mark.smoke
+    def test_registration(self):
+        SpeculatorConverter.register(["test1", "test1_alt"])(
+            MockSpeculatorConverterImpl
+        )
+        SpeculatorConverter.register("test2")(MockSpeculatorConverterUnsupported)
+
+        assert SpeculatorConverter.registry is not None  # type: ignore[misc]
+        assert "test1" in SpeculatorConverter.registry  # type: ignore[misc]
+        assert "test1_alt" in SpeculatorConverter.registry  # type: ignore[misc]
+        assert SpeculatorConverter.registry["test1"] is MockSpeculatorConverterImpl  # type: ignore[misc]
+        assert SpeculatorConverter.registry["test1_alt"] is MockSpeculatorConverterImpl  # type: ignore[misc]
+        assert "test2" in SpeculatorConverter.registry  # type: ignore[misc]
+        assert (
+            SpeculatorConverter.registry["test2"] is MockSpeculatorConverterUnsupported  # type: ignore[misc]
+        )
+
+        registered = SpeculatorConverter.registered_objects()
+
+        assert isinstance(registered, tuple)
+        assert len(registered) == 3
+        assert MockSpeculatorConverterImpl in registered
+        assert MockSpeculatorConverterUnsupported in registered
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize(
+        "algorithm",
+        ["test_algo", "test_algo_2", "auto"],
+    )
+    def test_resolve(self, mock_model, mock_config, mock_verifier, algorithm):
+        """Test resolve_converter with specific algorithms and auto detection."""
+        # Register test converters
+        SpeculatorConverter.register("test_algo")(MockSpeculatorConverterImpl)
+        SpeculatorConverter.register("test_algo_2")(MockSpeculatorConverterUnsupported)
+
+        expected_cls = (
+            MockSpeculatorConverterImpl
+            if algorithm in ("test_algo", "auto")
+            else MockSpeculatorConverterUnsupported
+        )
+
+        # Test both minimal and full argument scenarios
+        test_scenarios = [
+            {"model": mock_model, "config": mock_config},
+            {
+                "model": mock_model,
+                "config": mock_config,
+                "verifier": mock_verifier,
+                "custom_arg": "test_value",
+            },
+        ]
+
+        for kwargs in test_scenarios:
+            if algorithm == "auto":
+                with patch.object(
+                    MockSpeculatorConverterImpl, "is_supported", return_value=True
+                ) as mock_is_supported:
+                    converter_cls = SpeculatorConverter.resolve_converter(
+                        algorithm=algorithm, **kwargs
+                    )
+                    assert converter_cls is expected_cls
+                    assert mock_is_supported.call_count >= 1
+            else:
+                converter_cls = SpeculatorConverter.resolve_converter(
+                    algorithm=algorithm, **kwargs
+                )
+                assert converter_cls is expected_cls
+
+    @pytest.mark.sanity
+    def test_resolve_failures(self, mock_model, mock_config):
+        """Test resolve_converter failure scenarios."""
+        # Test with no registry
+        with pytest.raises(ValueError) as exc_info:
+            SpeculatorConverter.resolve_converter("test", mock_model, mock_config)
+        assert "No converters registered" in str(exc_info.value)
+
+        # Register test converters for remaining tests
+        SpeculatorConverter.register("test_algo")(MockSpeculatorConverterUnsupported)
+
+        # Test unknown algorithm
+        with pytest.raises(ValueError) as exc_info:
+            SpeculatorConverter.resolve_converter("unknown", mock_model, mock_config)
+        assert "Algorithm 'unknown' is not registered" in str(exc_info.value)
+        assert "Available algorithms: test_algo" in str(exc_info.value)
+
+        # Test auto with no supported converters
+        with patch.object(
+            MockSpeculatorConverterUnsupported, "is_supported", return_value=False
+        ):
+            with pytest.raises(ValueError) as exc_info:
+                SpeculatorConverter.resolve_converter("auto", mock_model, mock_config)
+        assert "No supported converter found" in str(exc_info.value)
+        assert "Available algorithms: test_algo" in str(exc_info.value)
diff --git a/tests/unit/convert/test_entrypoints.py b/tests/unit/convert/test_entrypoints.py
new file mode 100644
index 00000000..b6e6c1c1
--- /dev/null
+++ b/tests/unit/convert/test_entrypoints.py
@@ -0,0 +1,299 @@
+"""
+Unit tests for speculators.convert.entrypoints module.
+
+Tests the convert_model function which provides the primary conversion interface
+for transforming non-Speculators model checkpoints to Speculators format.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+from torch import nn
+
+from speculators.convert.entrypoints import convert_model
+from speculators.model import SpeculatorModel
+
+
+class TestConvertModel:
+    """Test suite for convert_model function."""
+
+    @pytest.fixture
+    def mock_converter_class(self):
+        """Fixture providing a mock SpeculatorConverter class."""
+        converter_class = Mock()
+        converter_instance = Mock()
+        mock_speculator_model = Mock(spec=SpeculatorModel)
+        converter_instance.return_value = mock_speculator_model
+        converter_class.return_value = converter_instance
+        return converter_class
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize(
+        ("model_input", "config_input", "extra_params", "expected_calls"),
+        [
+            # String model path with basic parameters
+            (
+                "/path/to/model",
+                None,
+                {"algorithm": "eagle", "output_path": "/output"},
+                {"config_is_model": True},
+            ),
+            # HuggingFace model ID with explicit config
+            (
+                "huggingface/model-id",
+                "/config/path",
+                {"algorithm": "eagle2", "verifier": "/verifier"},
+                {"config_is_model": False},
+            ),
+            # Model with all parameters
+            (
+                "/model/path",
+                "/config/path",
+                {
+                    "algorithm": "hass",
+                    "algorithm_kwargs": {"param1": "value1"},
+                    "output_path": "/output",
+                    "verifier": "/verifier",
+                    "validate_device": "cuda:0",
+                    "cache_dir": "/cache",
+                    "force_download": True,
+                    "token": "test_token",
+                },
+                {"config_is_model": False},
+            ),
+            # Auto algorithm detection
+            (
+                "/model/path",
+                None,
+                {"algorithm": "auto"},
+                {"config_is_model": True},
+            ),
+            # PathLib paths and torch device
+            (
+                Path("/model/path"),
+                None,
+                {"algorithm": "eagle", "validate_device": torch.device("cuda:0")},
+                {"config_is_model": True},
+            ),
+            # None algorithm_kwargs
+            (
+                "/model/path",
+                None,
+                {"algorithm": "eagle", "algorithm_kwargs": None},
+                {"config_is_model": True},
+            ),
+        ],
+        ids=[
+            "basic_string_model",
+            "hf_model_with_config",
+            "all_parameters",
+            "auto_algorithm",
+            "pathlib_torch_device",
+            "none_algorithm_kwargs",
+        ],
+    )
+    @patch("speculators.convert.entrypoints.check_download_model_checkpoint")
+    @patch("speculators.convert.entrypoints.check_download_model_config")
+    @patch("speculators.convert.entrypoints.SpeculatorConverter.resolve_converter")
+    @patch("speculators.convert.entrypoints.logger")
+    def test_invocation_variations(
+        self,
+        mock_logger,
+        mock_resolve_converter,
+        mock_check_config,
+        mock_check_checkpoint,
+        mock_converter_class,
+        model_input,
+        config_input,
+        extra_params,
+        expected_calls,
+    ):
+        """Test convert_model with various parameter combinations."""
+        # Setup
+        resolved_model = (
+            str(model_input) if isinstance(model_input, Path) else model_input
+        )
+        resolved_config = config_input if config_input else resolved_model
+
+        mock_check_checkpoint.return_value = resolved_model
+        mock_check_config.return_value = resolved_config
+        mock_resolve_converter.return_value = mock_converter_class
+        mock_speculator_model = Mock(spec=SpeculatorModel)
+        mock_converter_class.return_value.return_value = mock_speculator_model
+
+        # Execute
+        result = convert_model(model=model_input, config=config_input, **extra_params)
+
+        # Verify core functionality
+        assert result == mock_speculator_model
+        mock_check_checkpoint.assert_called_once()
+
+        if expected_calls["config_is_model"]:
+            mock_check_config.assert_called_once_with(
+                resolved_model,
+                cache_dir=extra_params.get("cache_dir"),
+                force_download=extra_params.get("force_download", False),
+                local_files_only=extra_params.get("local_files_only", False),
+                token=extra_params.get("token"),
+                revision=extra_params.get("revision"),
+            )
+        else:
+            mock_check_config.assert_called_once()
+
+        mock_resolve_converter.assert_called_once()
+        mock_converter_class.assert_called_once()
+        mock_converter_class.return_value.assert_called_once()
+        assert mock_logger.info.call_count >= 3
+
+    @pytest.mark.smoke
+    @patch("speculators.convert.entrypoints.check_download_model_checkpoint")
+    @patch("speculators.convert.entrypoints.check_download_model_config")
+    @patch("speculators.convert.entrypoints.SpeculatorConverter.resolve_converter")
+    def test_invocation_pretrained_model_instance(
+        self,
+        mock_resolve_converter,
+        mock_check_config,
+        mock_check_checkpoint,
+        mock_converter_class,
+    ):
+        """Test convert_model with PreTrainedModel instance and config inference."""
+        # Setup
+        mock_model_instance = Mock()
+
+        with patch("speculators.convert.entrypoints.isinstance") as mock_isinstance:
+            mock_isinstance.side_effect = lambda obj, cls: cls is not nn.Module
+
+            mock_check_checkpoint.return_value = mock_model_instance
+            mock_check_config.return_value = mock_model_instance
+            mock_resolve_converter.return_value = mock_converter_class
+            mock_speculator_model = Mock(spec=SpeculatorModel)
+            mock_converter_class.return_value.return_value = mock_speculator_model
+
+            # Execute
+            result = convert_model(model=mock_model_instance, algorithm="eagle")
+
+            # Verify
+            assert result == mock_speculator_model
+            mock_check_checkpoint.assert_called_once_with(
+                mock_model_instance,
+                cache_dir=None,
+                force_download=False,
+                local_files_only=False,
+                token=None,
+                revision=None,
+            )
+            mock_check_config.assert_called_once_with(
+                mock_model_instance,
+                cache_dir=None,
+                force_download=False,
+                local_files_only=False,
+                token=None,
+                revision=None,
+            )
+
+    @pytest.mark.sanity
+    @pytest.mark.parametrize(
+        (
+            "error_scenario",
+            "model_input",
+            "config_input",
+            "algorithm",
+            "expected_error",
+        ),
+        [
+            # nn.Module without config
+            (
+                "nn_module_no_config",
+                Mock(spec=nn.Module),
+                None,
+                "eagle",
+                "A model config must be provided",
+            ),
+            # Invalid algorithm
+            (
+                "invalid_algorithm",
+                "/model/path",
+                None,
+                "invalid_algorithm",
+                "Algorithm .* is not registered",
+            ),
+            # Empty algorithm
+            (
+                "empty_algorithm",
+                "/model/path",
+                None,
+                "",
+                "Algorithm .* is not registered",
+            ),
+        ],
+        ids=["nn_module_no_config", "invalid_algorithm", "empty_algorithm"],
+    )
+    @patch("speculators.convert.entrypoints.check_download_model_checkpoint")
+    @patch("speculators.convert.entrypoints.check_download_model_config")
@patch("speculators.convert.entrypoints.SpeculatorConverter.resolve_converter") + def test_invalid_invocations( + self, + mock_resolve_converter, + mock_check_config, + mock_check_checkpoint, + mock_converter_class, + error_scenario, + model_input, + config_input, + algorithm, + expected_error, + ): + """Test convert_model error conditions.""" + # Setup based on scenario + if error_scenario == "nn_module_no_config": + mock_check_checkpoint.return_value = model_input + else: + mock_check_checkpoint.return_value = model_input + mock_check_config.return_value = model_input + mock_resolve_converter.side_effect = ValueError( + f"Algorithm '{algorithm}' is not registered" + ) + + # Execute & Verify + with pytest.raises(ValueError, match=expected_error): + convert_model(model=model_input, config=config_input, algorithm=algorithm) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("model_input", "config_input"), + [ + (None, "/config/path"), + ("", "/config/path"), + ("/model/path", None), + ("/model/path", ""), + ], + ids=["none_model", "empty_model", "none_config", "empty_config"], + ) + @patch("speculators.convert.entrypoints.check_download_model_checkpoint") + @patch("speculators.convert.entrypoints.check_download_model_config") + @patch("speculators.convert.entrypoints.SpeculatorConverter.resolve_converter") + def test_invalid_empty_paths( + self, + mock_resolve_converter, + mock_check_config, + mock_check_checkpoint, + mock_converter_class, + model_input, + config_input, + ): + """Test convert_model with empty/None model or config paths.""" + # Setup + mock_check_checkpoint.return_value = model_input + mock_check_config.return_value = config_input + mock_resolve_converter.return_value = mock_converter_class + mock_converter_class.side_effect = ValueError( + "Model and config paths must be provided" + ) + + # Execute & Verify + with pytest.raises(ValueError, match="Model and config paths must be provided"): + convert_model(model=model_input, config=config_input, algorithm="eagle")