Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion swift/llm/model/model/internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from transformers import GenerationMixin
from swift.llm import TemplateType
from ..constant import LLMModelType, MLLMModelType, RMModelType
from ..model_arch import ModelArch
Expand Down Expand Up @@ -139,6 +139,16 @@ def get_model_tokenizer_internvl(model_dir: str,
use_submodel_func(model, 'language_model')
patch_output_clone(model.language_model.get_input_embeddings())

if model is not None:
# fix missing generate method for InternVL-2.5 models when using transformers >= 4.50
llm_part = getattr(model, 'language_model', model)
if not hasattr(llm_part, 'generate'):
print("Detected missing 'generate' method (transformers >= 4.50). Injecting GenerationMixin...")

cls = llm_part.__class__
if GenerationMixin not in cls.__bases__:
cls.__bases__ = cls.__bases__ + (GenerationMixin,)

Comment on lines +142 to +151
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block has some indentation issues and uses print for logging. For consistency with the rest of the codebase and to allow users to control log levels, it's better to use a proper logger.

I've corrected the indentation and replaced print with logger.info.

To make this work, you'll also need to add the following at the top of the file:

from swift.utils import get_logger

logger = get_logger()
Suggested change
if model is not None:
# fix missing generate method for InternVL-2.5 models when using transformers >= 4.50
llm_part = getattr(model, 'language_model', model)
if not hasattr(llm_part, 'generate'):
print("Detected missing 'generate' method (transformers >= 4.50). Injecting GenerationMixin...")
cls = llm_part.__class__
if GenerationMixin not in cls.__bases__:
cls.__bases__ = cls.__bases__ + (GenerationMixin,)
if model is not None:
# fix missing generate method for InternVL-2.5 models when using transformers >= 4.50
llm_part = getattr(model, 'language_model', model)
if not hasattr(llm_part, 'generate'):
logger.info("Detected missing 'generate' method (transformers >= 4.50). Injecting GenerationMixin...")
cls = llm_part.__class__
if GenerationMixin not in cls.__bases__:
cls.__bases__ = cls.__bases__ + (GenerationMixin,)

return model, tokenizer


Expand Down