diff --git a/README.md b/README.md index 2c6f5d2..a1a8c5c 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ git clone https://github.com/audiohacking/fp8-mps-metal.git That's it! The patch will automatically load when ComfyUI starts. You'll see a message confirming it's installed. +**Note:** This custom node automatically sets `PYTORCH_ENABLE_MPS_FALLBACK=1` for better compatibility with unsupported operations. + ## Quick Start for Other Users ```bash @@ -296,15 +298,11 @@ git clone https://github.com/comfyanonymous/ComfyUI.git cd ComfyUI pip install -r requirements.txt -# Required environment variables -export PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 -export PYTORCH_ENABLE_MPS_FALLBACK=1 - # For FP8 models (FLUX, SD3.5): -# Install fp8-mps-metal and add to ComfyUI's startup -pip install -e /path/to/fp8-mps-metal -# Add to ComfyUI/main.py or a custom node: -import fp8_mps_patch; fp8_mps_patch.install() +# Install fp8-mps-metal custom node (automatically sets PYTORCH_ENABLE_MPS_FALLBACK=1) +cd custom_nodes +git clone https://github.com/audiohacking/fp8-mps-metal.git +cd .. # Run with memory optimizations python main.py --force-fp16 --use-split-cross-attention @@ -382,6 +380,67 @@ These numbers are from our validated test suite. Your results will vary by chip. - Python 3.10+ - **No Xcode required** (runtime shader compilation) +## Troubleshooting + +### For ComfyUI Users + +This extension automatically selects the best available backend: +- **Native backend** (PyTorch 2.10+) - Zero-copy, fastest +- **C++ extension fallback** (PyTorch 2.4+) - Requires compilation but works with older PyTorch + +When you start ComfyUI, check the console output to see which backend is active. 
+ +#### "No FP8 backend available" Error + +If you see this error, your PyTorch is older than 2.10 and the C++ extension has not been built. Fix it with one of the following options: + +**Option 1: Build the Extension (Recommended for ComfyUI)** +```bash +# Navigate to the custom node directory +cd ComfyUI/custom_nodes/fp8-mps-metal/ + +# Install Xcode Command Line Tools (if not already installed) +xcode-select --install + +# Build and install the C++ extension +pip install -e . + +# Restart ComfyUI +``` + +**Option 2: Upgrade PyTorch (If you manage your own environment)** +```bash +pip install --upgrade torch torchvision +``` + +Note: ComfyUI may manage its own PyTorch installation, so Option 1 is usually better. + +### For Other Users + +#### AttributeError: module 'torch.mps' has no attribute 'compile_shader' + +This error means you have PyTorch < 2.10, which does not provide `torch.mps.compile_shader`. The library automatically falls back to the C++ extension when it is installed; if you still see this error, use one of the solutions below. + +**Solution 1: Upgrade PyTorch (Recommended)** +```bash +pip install --upgrade torch torchvision +``` + +**Solution 2: Build C++ Extension** +```bash +cd /path/to/fp8-mps-metal +pip install -e . +``` + +The C++ extension provides similar functionality but uses metal-cpp and pybind11 instead of the native PyTorch API. 
+ +### Build Requirements for C++ Extension + +- **Xcode Command Line Tools**: `xcode-select --install` +- **metal-cpp**: Auto-downloaded during build +- **PyTorch**: 2.4+ (for `torch._scaled_mm`) +- **Python**: 3.10+ + ## Related Resources - [metalQwen3](https://github.com/Architect2040/metalQwen3) — Custom Metal shaders for Qwen3 transformer inference diff --git a/__init__.py b/__init__.py index bd3ff72..7366d8e 100644 --- a/__init__.py +++ b/__init__.py @@ -13,6 +13,13 @@ import sys import os +# Enable MPS fallback to CPU for unsupported operations +# This must be set BEFORE importing torch +# Use setdefault to allow users to override if needed +os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") + +import torch + # Add current directory to path so we can import fp8_mps_patch current_dir = os.path.dirname(os.path.abspath(__file__)) if current_dir not in sys.path: @@ -21,13 +28,35 @@ # Import and install the patch try: import fp8_mps_patch + import fp8_backend # Install the patch automatically if not fp8_mps_patch.is_installed(): fp8_mps_patch.install() + + # Check which backend is available + backend, backend_name = fp8_backend.get_backend() + print("\n" + "=" * 70) print("✓ FP8 MPS Metal patch installed successfully!") print("=" * 70) + + # Show the actual MPS fallback setting + mps_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "not set") + fallback_status = "Enabled" if mps_fallback == "1" else f"Disabled (={mps_fallback})" + print(f"PYTORCH_ENABLE_MPS_FALLBACK: {fallback_status}") + + if backend_name == "native": + print("Backend: Native (torch.mps.compile_shader)") + print(f"PyTorch version: {torch.__version__}") + elif backend_name == "cpp": + print("Backend: C++ Extension (metal-cpp)") + print("Note: Native backend (PyTorch 2.10+) would be faster") + else: + print("⚠️ WARNING: No FP8 backend available!") + print(fp8_backend.get_error_message()) + + print("=" * 70) print("Float8_e4m3fn operations on MPS are now supported.") print("This 
enables:") print(" • FP8 model weight loading (FLUX/SD3.5)") diff --git a/fp8_backend.py b/fp8_backend.py new file mode 100644 index 0000000..ea21a6e --- /dev/null +++ b/fp8_backend.py @@ -0,0 +1,156 @@ +""" +Backend selection for FP8 operations on MPS. + +This module automatically selects between: +1. fp8_mps_native - PyTorch 2.10+ native torch.mps.compile_shader() (zero-copy, preferred) +2. fp8_metal - C++ extension fallback (requires compilation with pip install -e .) + +For ComfyUI users with PyTorch < 2.10, the C++ extension will be used automatically +if available, otherwise operations will fail with a helpful error message. +""" + +import torch + +_backend = None +_backend_name = None +_init_error = None + + +def _init_backend(): + """Initialize and select the best available backend.""" + global _backend, _backend_name, _init_error + + if _backend is not None: + return + + # Try native implementation first (PyTorch 2.10+) + try: + import fp8_mps_native + if fp8_mps_native.is_available(): + _backend = fp8_mps_native + _backend_name = "native" + return + except Exception as e: + _init_error = f"Failed to load fp8_mps_native: {e}" + + # Fall back to C++ extension + try: + import fp8_metal + _backend = fp8_metal + _backend_name = "cpp" + return + except ImportError as e: + pass # Expected if not compiled + except Exception as e: + _init_error = f"Failed to load fp8_metal: {e}" + + # No backend available + _backend = None + _backend_name = None + + +def get_backend(): + """ + Get the active FP8 backend. + + Returns: + tuple: (backend_module, backend_name) where backend_name is 'native', 'cpp', or None + """ + _init_backend() + return _backend, _backend_name + + +def is_available(): + """Check if any FP8 backend is available.""" + _init_backend() + return _backend is not None + + +def get_error_message(): + """ + Get a helpful error message when no backend is available. 
+ + Returns: + str: Error message with instructions + """ + _init_backend() + + if _backend is not None: + return None + + msg = "No FP8 backend available for MPS.\n\n" + + # Check PyTorch version + msg += f"Your PyTorch version: {torch.__version__}\n\n" + + # Check if torch.mps.compile_shader exists + has_compile_shader = hasattr(torch.mps, 'compile_shader') + + if not has_compile_shader: + msg += "The native backend requires PyTorch 2.10+.\n" + msg += "Solutions:\n\n" + msg += "1. UPGRADE PyTorch (Recommended):\n" + msg += " pip install --upgrade torch torchvision\n\n" + msg += "2. BUILD C++ Extension (For ComfyUI users who can't upgrade):\n" + msg += " # In ComfyUI/custom_nodes/fp8-mps-metal/\n" + msg += " pip install -e .\n" + msg += " # Then restart ComfyUI\n\n" + msg += " Note: This requires:\n" + msg += " - Xcode Command Line Tools: xcode-select --install\n" + msg += " - metal-cpp (auto-downloaded during build)\n\n" + else: + msg += "The native backend is available but failed to initialize.\n" + msg += "Try building the C++ extension fallback:\n" + msg += " cd ComfyUI/custom_nodes/fp8-mps-metal/\n" + msg += " pip install -e .\n\n" + + if _init_error: + msg += f"Debug info: {_init_error}\n" + + return msg + + +def fp8_scaled_mm(A, B, scale_a, scale_b): + """FP8 scaled matrix multiplication - uses best available backend.""" + backend, name = get_backend() + if backend is None: + raise RuntimeError(get_error_message()) + return backend.fp8_scaled_mm(A, B, scale_a, scale_b) + + +def fp8_scaled_mm_auto(A, B, scale_a, scale_b): + """Auto-select best FP8 matmul strategy - uses best available backend.""" + backend, name = get_backend() + if backend is None: + raise RuntimeError(get_error_message()) + + # Only native backend has the _auto variant + if name == "native": + return backend.fp8_scaled_mm_auto(A, B, scale_a, scale_b) + else: + # C++ backend only has the regular version + return backend.fp8_scaled_mm(A, B, scale_a, scale_b) + + +def fp8_dequantize(input, 
scale): + """FP8 to float dequantization - uses best available backend.""" + backend, name = get_backend() + if backend is None: + raise RuntimeError(get_error_message()) + return backend.fp8_dequantize(input, scale) + + +def fp8_encode(input): + """Float to FP8 encoding - uses best available backend.""" + backend, name = get_backend() + if backend is None: + raise RuntimeError(get_error_message()) + return backend.fp8_encode(input) + + +def fp8_quantize(input): + """Float to FP8 quantization with scaling - uses best available backend.""" + backend, name = get_backend() + if backend is None: + raise RuntimeError(get_error_message()) + return backend.fp8_quantize(input) diff --git a/fp8_mps_native.py b/fp8_mps_native.py index ad6af6a..0332064 100644 --- a/fp8_mps_native.py +++ b/fp8_mps_native.py @@ -13,6 +13,22 @@ _lib = None _SHADER_SOURCE = None +_compile_shader_available = None + + +def is_available(): + """ + Check if torch.mps.compile_shader is available. + + Returns: + bool: True if the native implementation can be used (PyTorch 2.10+) + """ + global _compile_shader_available + if _compile_shader_available is not None: + return _compile_shader_available + + _compile_shader_available = hasattr(torch.mps, 'compile_shader') + return _compile_shader_available def _load_shader_source(): @@ -28,11 +44,20 @@ def _load_shader_source(): def _get_lib(): - """Get or create the compiled shader library (singleton).""" + """ + Get or create the compiled shader library (singleton). + + Returns: + The compiled shader library, or None if torch.mps.compile_shader is unavailable. 
+ """ global _lib if _lib is not None: return _lib + # Check if torch.mps.compile_shader is available (requires PyTorch 2.10+) + if not is_available(): + return None + source = _load_shader_source() _lib = torch.mps.compile_shader(source) return _lib diff --git a/fp8_mps_patch.py b/fp8_mps_patch.py index f50c9f6..96a6c4b 100644 --- a/fp8_mps_patch.py +++ b/fp8_mps_patch.py @@ -49,7 +49,7 @@ def _metal_scaled_mm(input, other, *, out_dtype=None, scale_a=None, scale_b=None use_fast_accum=use_fast_accum, ) - import fp8_mps_native + import fp8_backend # Handle FP8 dtype tensors by viewing as uint8 if input.dtype != torch.uint8: @@ -67,7 +67,7 @@ def _metal_scaled_mm(input, other, *, out_dtype=None, scale_a=None, scale_b=None if scale_b is None: scale_b = torch.tensor([1.0], device=input.device) - result = fp8_mps_native.fp8_scaled_mm_auto(input, B, scale_a, scale_b) + result = fp8_backend.fp8_scaled_mm_auto(input, B, scale_a, scale_b) # Apply bias if provided if bias is not None: @@ -154,7 +154,7 @@ def _metal_tensor_to(self, *args, **kwargs): # Scenario 2: Float/other tensor -> FP8 on MPS (quantization) # This handles on-the-fly quantization if target_device_is_mps and target_is_fp8 and not source_is_fp8: - import fp8_mps_native + import fp8_backend # First move to MPS if not already there (using original method with non-FP8 dtype) if self.device.type != "mps": @@ -166,7 +166,7 @@ def _metal_tensor_to(self, *args, **kwargs): # Use fp8_encode to convert to FP8 without scaling # This preserves value semantics (no automatic scaling) - quantized_u8 = fp8_mps_native.fp8_encode(tensor_mps) + quantized_u8 = fp8_backend.fp8_encode(tensor_mps) # View the uint8 as the requested FP8 dtype result = quantized_u8.view(target_fp8_dtype) @@ -185,14 +185,14 @@ def _metal_tensor_to(self, *args, **kwargs): else: # FP8 to non-FP8 conversion (e.g., FP8 to float32/float16) # MPS doesn't support this natively, so we need to dequantize - import fp8_mps_native + import fp8_backend # View as 
uint8 for dequantization self_u8 = self.view(torch.uint8) # Dequantize using scale=1.0 (no scaling, value-preserving) scale = torch.tensor([1.0], device="mps") - dequantized = fp8_mps_native.fp8_dequantize(self_u8, scale) + dequantized = fp8_backend.fp8_dequantize(self_u8, scale) # Convert from float16 (dequantize output) to target dtype if needed if dtype != torch.float16: @@ -244,7 +244,7 @@ def _metal_tensor_copy(self, src, non_blocking=False): # Scenario 2: Non-FP8 source → FP8 destination on MPS # This handles dtype conversion during copy, which MPS doesn't support natively if not source_is_fp8 and dest_is_fp8 and dest_is_mps: - import fp8_mps_native + import fp8_backend # First, move source to MPS if needed (without dtype change) if src.device.type != "mps": @@ -255,7 +255,7 @@ def _metal_tensor_copy(self, src, non_blocking=False): # Encode to FP8 using our Metal kernel (without automatic scaling) # This preserves value semantics - values are clamped to [-448, 448] # but not scaled to use the full FP8 range - quantized_u8 = fp8_mps_native.fp8_encode(src_mps) + quantized_u8 = fp8_backend.fp8_encode(src_mps) # View destination as uint8 for byte-level copy self_u8 = self.view(torch.uint8)