75 changes: 67 additions & 8 deletions README.md
@@ -26,6 +26,8 @@ git clone https://github.com/audiohacking/fp8-mps-metal.git

That's it! The patch will automatically load when ComfyUI starts. You'll see a message confirming it's installed.

**Note:** This custom node automatically sets `PYTORCH_ENABLE_MPS_FALLBACK=1` for better compatibility with unsupported operations.
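Because the node sets the variable with `os.environ.setdefault` (see `__init__.py` below), an explicit value you export in your shell is never clobbered. A minimal sketch of that behavior:

```python
import os

# Hypothetical pre-existing user setting: an explicit "0" exported in the shell.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"

# What the custom node does on import: setdefault only fills the value
# in when it is absent, so the user's explicit choice is preserved.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])  # "0" — the user's override wins
```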

## Quick Start for Other Users

```bash
@@ -296,15 +298,11 @@ git clone https://github.com/comfyanonymous/ComfyUI.git
cd ComfyUI
pip install -r requirements.txt

-# Required environment variables
-export PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
-export PYTORCH_ENABLE_MPS_FALLBACK=1

-# For FP8 models (FLUX, SD3.5):
-# Install fp8-mps-metal and add to ComfyUI's startup
-pip install -e /path/to/fp8-mps-metal
-# Add to ComfyUI/main.py or a custom node:
-import fp8_mps_patch; fp8_mps_patch.install()
+# Install fp8-mps-metal custom node (automatically sets PYTORCH_ENABLE_MPS_FALLBACK=1)
+cd custom_nodes
+git clone https://github.com/audiohacking/fp8-mps-metal.git
+cd ..

# Run with memory optimizations
python main.py --force-fp16 --use-split-cross-attention
@@ -382,6 +380,67 @@ These numbers are from our validated test suite. Your results will vary by chip.
- Python 3.10+
- **No Xcode required** (runtime shader compilation)

## Troubleshooting

### For ComfyUI Users

This extension automatically selects the best available backend:
- **Native backend** (PyTorch 2.10+) - Zero-copy, fastest
- **C++ extension fallback** (PyTorch 2.4+) - Requires compilation but works with older PyTorch

When you start ComfyUI, check the console output to see which backend is active.
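A simplified sketch of that selection order (stdlib module names stand in for the real backend modules here; the actual `fp8_backend.py` additionally calls `is_available()` on the native module before accepting it):

```python
import importlib

def pick_backend(candidates):
    """Try candidate module names in priority order; return (module, name)."""
    for name in candidates:
        try:
            return importlib.import_module(name), name
        except ImportError:
            continue  # missing backend: fall through to the next candidate
    return None, None

# "no_such_native_backend" stands in for the native path being unavailable,
# so the loop falls back to the second entry, mirroring native -> C++.
module, chosen = pick_backend(["no_such_native_backend", "json"])
print(chosen)  # json
```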

#### "No FP8 backend available" Error

If you see this error, your PyTorch is older than 2.10 and the C++ extension fallback has not been built. There are two ways to fix it:

**Option 1: Build the Extension (Recommended for ComfyUI)**
```bash
# Navigate to the custom node directory
cd ComfyUI/custom_nodes/fp8-mps-metal/

# Install Xcode Command Line Tools (if not already installed)
xcode-select --install

# Build and install the C++ extension
pip install -e .

# Restart ComfyUI
```

**Option 2: Upgrade PyTorch (If you manage your own environment)**
```bash
pip install --upgrade torch torchvision
```

Note: ComfyUI may manage its own PyTorch installation (packaged builds can overwrite a manual upgrade), so Option 1 is usually the safer choice.

### For Other Users

#### AttributeError: module 'torch.mps' has no attribute 'compile_shader'

This error means you have PyTorch < 2.10. The library will automatically fall back to the C++ extension if available.
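The library's actual check is `hasattr(torch.mps, 'compile_shader')`, which is more robust than version parsing. For illustration only, here is a plain version comparison, which needs numeric tuples because the string `"2.10"` sorts below `"2.4"`:

```python
def supports_native_backend(torch_version: str) -> bool:
    # Compare (major, minor) numerically: "2.10" >= "2.4" is False as
    # strings, so the version must be parsed into integers first.
    base = torch_version.split("+")[0]  # drop any local suffix, e.g. "2.4.1+cpu"
    major, minor = (int(p) for p in base.split(".")[:2])
    return (major, minor) >= (2, 10)

print(supports_native_backend("2.4.1"))   # False -> needs the C++ extension
print(supports_native_backend("2.10.0"))  # True  -> native backend available
```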

**Solution 1: Upgrade PyTorch (Recommended)**
```bash
pip install --upgrade torch torchvision
```

**Solution 2: Build C++ Extension**
```bash
cd /path/to/fp8-mps-metal
pip install -e .
```

The C++ extension provides similar functionality but uses metal-cpp and pybind11 instead of the native PyTorch API.

### Build Requirements for C++ Extension

- **Xcode Command Line Tools**: `xcode-select --install`
- **metal-cpp**: Auto-downloaded during build
- **PyTorch**: 2.4+ (for `torch._scaled_mm`)
- **Python**: 3.10+
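`torch._scaled_mm`, which the extension builds on, computes in effect `(A * scale_a) @ (B * scale_b)`. A pure-Python sketch of those semantics, with nested lists standing in for FP8 tensors and per-tensor scales assumed:

```python
def scaled_mm_ref(A, B, scale_a, scale_b):
    # Reference semantics of an FP8 scaled matmul: dequantize each
    # operand with its per-tensor scale, then do an ordinary matmul.
    k, n = len(B), len(B[0])
    return [[sum((row[t] * scale_a) * (B[t][j] * scale_b) for t in range(k))
             for j in range(n)]
            for row in A]

print(scaled_mm_ref([[1, 2]], [[3], [4]], 2.0, 0.5))  # [[11.0]]
```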

## Related Resources

- [metalQwen3](https://github.com/Architect2040/metalQwen3) — Custom Metal shaders for Qwen3 transformer inference
29 changes: 29 additions & 0 deletions __init__.py
@@ -13,6 +13,13 @@
import sys
import os

# Enable MPS fallback to CPU for unsupported operations
# This must be set BEFORE importing torch
# Use setdefault to allow users to override if needed
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

# Add current directory to path so we can import fp8_mps_patch
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
@@ -21,13 +28,35 @@
# Import and install the patch
try:
    import fp8_mps_patch
    import fp8_backend

    # Install the patch automatically
    if not fp8_mps_patch.is_installed():
        fp8_mps_patch.install()

    # Check which backend is available
    backend, backend_name = fp8_backend.get_backend()

    print("\n" + "=" * 70)
    print("✓ FP8 MPS Metal patch installed successfully!")
    print("=" * 70)

    # Show the actual MPS fallback setting
    mps_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "not set")
    fallback_status = "Enabled" if mps_fallback == "1" else f"Disabled (={mps_fallback})"
    print(f"PYTORCH_ENABLE_MPS_FALLBACK: {fallback_status}")

    if backend_name == "native":
        print("Backend: Native (torch.mps.compile_shader)")
        print(f"PyTorch version: {torch.__version__}")
    elif backend_name == "cpp":
        print("Backend: C++ Extension (metal-cpp)")
        print("Note: Native backend (PyTorch 2.10+) would be faster")
    else:
        print("⚠️ WARNING: No FP8 backend available!")
        print(fp8_backend.get_error_message())

    print("=" * 70)
    print("Float8_e4m3fn operations on MPS are now supported.")
    print("This enables:")
    print("  • FP8 model weight loading (FLUX/SD3.5)")
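The `fp8_mps_patch.is_installed()` / `install()` pair above keeps the patch idempotent across repeated imports. A minimal sketch of that guard pattern (hypothetical module-level flag, not the real implementation):

```python
_installed = False

def install():
    """Apply the patch exactly once per process; later calls are no-ops."""
    global _installed
    if _installed:
        return
    # ... monkeypatching would happen here ...
    _installed = True

def is_installed():
    return _installed

install()
install()              # safe: the second call returns immediately
print(is_installed())  # True
```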
156 changes: 156 additions & 0 deletions fp8_backend.py
@@ -0,0 +1,156 @@
"""
Backend selection for FP8 operations on MPS.

This module automatically selects between:
1. fp8_mps_native - PyTorch 2.10+ native torch.mps.compile_shader() (zero-copy, preferred)
2. fp8_metal - C++ extension fallback (requires compilation with pip install -e .)

For ComfyUI users with PyTorch < 2.10, the C++ extension will be used automatically
if available, otherwise operations will fail with a helpful error message.
"""

import torch

_backend = None
_backend_name = None
_init_error = None


def _init_backend():
    """Initialize and select the best available backend."""
    global _backend, _backend_name, _init_error

    if _backend is not None:
        return

    # Try native implementation first (PyTorch 2.10+)
    try:
        import fp8_mps_native
        if fp8_mps_native.is_available():
            _backend = fp8_mps_native
            _backend_name = "native"
            return
    except Exception as e:
        _init_error = f"Failed to load fp8_mps_native: {e}"

    # Fall back to C++ extension
    try:
        import fp8_metal
        _backend = fp8_metal
        _backend_name = "cpp"
        return
    except ImportError:
        pass  # Expected if the extension has not been compiled
    except Exception as e:
        _init_error = f"Failed to load fp8_metal: {e}"

    # No backend available
    _backend = None
    _backend_name = None


def get_backend():
    """
    Get the active FP8 backend.

    Returns:
        tuple: (backend_module, backend_name) where backend_name is 'native', 'cpp', or None
    """
    _init_backend()
    return _backend, _backend_name


def is_available():
    """Check if any FP8 backend is available."""
    _init_backend()
    return _backend is not None


def get_error_message():
    """
    Get a helpful error message when no backend is available.

    Returns:
        str: Error message with instructions
    """
    _init_backend()

    if _backend is not None:
        return None

    msg = "No FP8 backend available for MPS.\n\n"

    # Check PyTorch version
    msg += f"Your PyTorch version: {torch.__version__}\n\n"

    # Check if torch.mps.compile_shader exists
    has_compile_shader = hasattr(torch.mps, 'compile_shader')

    if not has_compile_shader:
        msg += "The native backend requires PyTorch 2.10+.\n"
        msg += "Solutions:\n\n"
        msg += "1. UPGRADE PyTorch (Recommended):\n"
        msg += "   pip install --upgrade torch torchvision\n\n"
        msg += "2. BUILD C++ Extension (For ComfyUI users who can't upgrade):\n"
        msg += "   # In ComfyUI/custom_nodes/fp8-mps-metal/\n"
        msg += "   pip install -e .\n"
        msg += "   # Then restart ComfyUI\n\n"
        msg += "   Note: This requires:\n"
        msg += "   - Xcode Command Line Tools: xcode-select --install\n"
        msg += "   - metal-cpp (auto-downloaded during build)\n\n"
    else:
        msg += "The native backend is available but failed to initialize.\n"
        msg += "Try building the C++ extension fallback:\n"
        msg += "  cd ComfyUI/custom_nodes/fp8-mps-metal/\n"
        msg += "  pip install -e .\n\n"

    if _init_error:
        msg += f"Debug info: {_init_error}\n"

    return msg


def fp8_scaled_mm(A, B, scale_a, scale_b):
    """FP8 scaled matrix multiplication - uses best available backend."""
    backend, name = get_backend()
    if backend is None:
        raise RuntimeError(get_error_message())
    return backend.fp8_scaled_mm(A, B, scale_a, scale_b)


def fp8_scaled_mm_auto(A, B, scale_a, scale_b):
    """Auto-select best FP8 matmul strategy - uses best available backend."""
    backend, name = get_backend()
    if backend is None:
        raise RuntimeError(get_error_message())

    # Only the native backend has the _auto variant
    if name == "native":
        return backend.fp8_scaled_mm_auto(A, B, scale_a, scale_b)
    else:
        # C++ backend only has the regular version
        return backend.fp8_scaled_mm(A, B, scale_a, scale_b)


def fp8_dequantize(input, scale):
    """FP8 to float dequantization - uses best available backend."""
    backend, name = get_backend()
    if backend is None:
        raise RuntimeError(get_error_message())
    return backend.fp8_dequantize(input, scale)


def fp8_encode(input):
    """Float to FP8 encoding - uses best available backend."""
    backend, name = get_backend()
    if backend is None:
        raise RuntimeError(get_error_message())
    return backend.fp8_encode(input)


def fp8_quantize(input):
    """Float to FP8 quantization with scaling - uses best available backend."""
    backend, name = get_backend()
    if backend is None:
        raise RuntimeError(get_error_message())
    return backend.fp8_quantize(input)
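All of these helpers traffic in Float8_e4m3fn values: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits, with NaN but no infinities. A pure-Python decoder sketch of that layout:

```python
def decode_e4m3fn(byte: int) -> float:
    # Unpack one float8_e4m3fn byte: sign(1) | exponent(4, bias 7) | mantissa(3).
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                    # the only NaN encoding; no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

print(decode_e4m3fn(0x38))  # 1.0
print(decode_e4m3fn(0x7E))  # 448.0 (the format's largest finite value)
```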
27 changes: 26 additions & 1 deletion fp8_mps_native.py
@@ -13,6 +13,22 @@

_lib = None
_SHADER_SOURCE = None
_compile_shader_available = None


def is_available():
    """
    Check if torch.mps.compile_shader is available.

    Returns:
        bool: True if the native implementation can be used (PyTorch 2.10+)
    """
    global _compile_shader_available
    if _compile_shader_available is not None:
        return _compile_shader_available

    _compile_shader_available = hasattr(torch.mps, 'compile_shader')
    return _compile_shader_available


def _load_shader_source():
@@ -28,11 +44,20 @@ def _load_shader_source():


def _get_lib():
-    """Get or create the compiled shader library (singleton)."""
+    """
+    Get or create the compiled shader library (singleton).
+
+    Returns:
+        The compiled shader library, or None if torch.mps.compile_shader is unavailable.
+    """
    global _lib
    if _lib is not None:
        return _lib

    # Check if torch.mps.compile_shader is available (requires PyTorch 2.10+)
    if not is_available():
        return None

    source = _load_shader_source()
    _lib = torch.mps.compile_shader(source)
    return _lib