Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ ignore = [
"COM812",
"ISC001",
"TC002",
"TC003", # allow imports outside of type checking blocks
"TC003", # allow imports for typing outside of type checking blocks
"S311", # allow random number generators
"PLW1514", # allow Path.open without encoding
"RET505", # allow `else` blocks
Expand Down
276 changes: 157 additions & 119 deletions src/speculators/config.py

Large diffs are not rendered by default.

447 changes: 168 additions & 279 deletions src/speculators/model.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/speculators/models/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def attach_verifier(
self,
verifier: Union[str, os.PathLike, PreTrainedModel],
mode: Optional[Literal["full", "train_only"]] = None,
add_to_config: bool = True,
) -> PreTrainedModel:
"""
Attach a verifier model to the EagleSpeculator for speculative decoding.
Expand Down Expand Up @@ -344,15 +345,19 @@ def attach_verifier(
model directory, a Hugging Face model identifier, or an instance of
PreTrainedModel. If a path or identifier is provided, the model will be
loaded automatically. If an instance is provided, it will be used directly.
:param mode: The mode for attaching the verifier. Can be "full" or "train_only".
:param mode: The mode for attaching the verifier.
Can be "full" or "train_only".
If None, defaults to "full". In "train_only" mode, only the layers
required for a forward pass are attached, and the speculator cannot
perform generation until a full verifier is attached.
:param add_to_config: Whether to update the speculator's configuration
with details from the attached verifier model.
:return: The PreTrainedModel instance for the verifier that was attached.
"""
verifier = super().attach_verifier(
verifier=verifier,
mode=mode,
add_to_config=add_to_config,
)

# Extract layers from the verifier model
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_verifier_config_from_verifier_config():
cache_dir=tmp_dir,
)

config = VerifierConfig.from_config(
config = VerifierConfig.from_pretrained(
pretrained_config, name_or_path="RedHatAI/Llama-3.1-8B-Instruct"
)
assert config.name_or_path == "RedHatAI/Llama-3.1-8B-Instruct"
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/models/test_eagle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def forward(self, input_ids, **kwargs):
@pytest.fixture
def sample_llama_config():
return LlamaConfig(
name_or_path="test/verifier",
architectures=["LlamaForCausalLM"],
attention_bias=False,
attention_dropout=0.0,
bos_token_id=128000,
Expand Down
Loading