Skip to content

Port LLaVA to new API #7817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
"""

# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import time
from abc import ABC, abstractmethod
@@ -232,6 +233,23 @@ def component_paths(self):
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}

def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
return None

weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default

@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
@@ -359,21 +377,43 @@ def matches(cls, mod: ModelOnDisk) -> bool:
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass

@staticmethod
def cast_overrides(overrides: dict[str, Any]):
"""Casts user overrides from str to Enum"""
if "type" in overrides:
overrides["type"] = ModelType(overrides["type"])

if "format" in overrides:
overrides["format"] = ModelFormat(overrides["format"])

if "base" in overrides:
overrides["base"] = BaseModelType(overrides["base"])

if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")

fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)

type = fields.get("type") or cls.model_fields["type"].default
base = fields.get("base") or cls.model_fields["base"].default

fields["path"] = mod.path.as_posix()
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = mod.name
fields["name"] = name = fields.get("name") or mod.name
fields["hash"] = fields.get("hash") or mod.hash()
fields["key"] = fields.get("key") or uuid_string()
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()

fields.update(overrides)
return cls(**fields)


@@ -625,12 +665,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class LlavaOnevisionConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""

type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
return False

config_path = mod.path / "config.json"
try:
with open(config_path, "r") as file:
config = json.load(file)
except FileNotFoundError:
return False

architectures = config.get("architectures")
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"

@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": BaseModelType.Any,
"variant": ModelVariantType.Normal,
}


def get_model_discriminator_value(v: Any) -> str:
"""
8 changes: 5 additions & 3 deletions tests/test_model_probe.py
Original file line number Diff line number Diff line change
@@ -148,22 +148,24 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
configs_with_tests = set()
model_paths = ModelSearch().search(datadir / "stripped_models")
fake_hash = "abcdefgh" # skip hashing to make test quicker
fake_key = "123" # fixed uuid for comparison

for path in model_paths:
legacy_config = new_config = None

try:
legacy_config = ModelProbe.probe(path, {"hash": fake_hash})
legacy_config = ModelProbe.probe(path, {"hash": fake_hash, "key": fake_key})
except InvalidModelConfigException:
pass

try:
new_config = ModelConfigBase.classify(path, hash=fake_hash)
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
except InvalidModelConfigException:
pass

if legacy_config and new_config:
assert legacy_config == new_config
assert type(legacy_config) is type(new_config)
assert legacy_config.model_dump_json() == new_config.model_dump_json()

elif legacy_config:
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown