def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Return the imports of ``filename`` without ``flash_attn`` on non-CUDA hosts.

    Delegates to ``get_imports`` and, when ``torch.cuda.is_available()``
    is false, removes the ``"flash_attn"`` entry from the result if it is
    present. Usable as a drop-in replacement for
    ``transformers.dynamic_module_utils.get_imports`` (for example via
    ``unittest.mock.patch``) while loading remote model code.

    Args:
        filename: Path of the Python source file whose imports are listed.

    Returns:
        The list produced by ``get_imports`` (modified in place), with
        ``"flash_attn"`` removed when CUDA is unavailable.
    """
    module_names = get_imports(filename)

    # With CUDA available the import list is passed through untouched.
    if torch.cuda.is_available():
        return module_names

    # No CUDA: drop the flash_attn entry so it is not reported as an import.
    if "flash_attn" in module_names:
        module_names.remove("flash_attn")
    return module_names