refrakt_xai is the explainability hook behind Refrakt's visualization and explainability component. It provides a unified interface to state-of-the-art Explainable AI (XAI) methods, helping researchers and practitioners understand and interpret their machine learning models.
- Unified XAI Interface: Consistent API across all explanation methods
- State-of-the-Art Methods: Implementation of leading XAI techniques
- PyTorch Integration: Seamless integration with PyTorch models
- Extensible Architecture: Easy to add new explanation methods
- Type Safety: Full type annotations and mypy compliance
- Comprehensive Testing: 80%+ test coverage with 68 test cases
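The unified interface and extensible architecture can be pictured with a minimal pure-Python sketch. Every name in it (`BaseXAI`, `register_method`, `METHOD_REGISTRY`, `FiniteDiffSaliency`, `normalize_targets`) is hypothetical. It mirrors the `explain(inputs, target=...)` call style used in the usage examples of this README, but it is not the contents of `base.py` or `registry.py`. It also assumes that an integer target is applied to every sample in a batch.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Sequence, Type, Union

# Hypothetical registry: method name -> explainer class
METHOD_REGISTRY: Dict[str, Type["BaseXAI"]] = {}


def register_method(name: str) -> Callable[[Type["BaseXAI"]], Type["BaseXAI"]]:
    """Hypothetical class decorator that adds an explainer to METHOD_REGISTRY."""

    def decorator(cls: Type["BaseXAI"]) -> Type["BaseXAI"]:
        if name in METHOD_REGISTRY:
            raise ValueError(f"XAI method {name!r} is already registered")
        METHOD_REGISTRY[name] = cls
        return cls

    return decorator


class BaseXAI(ABC):
    """Hypothetical base class: every method is driven through explain()."""

    def __init__(self, model: Callable[[List[float]], List[float]]) -> None:
        self.model = model

    @staticmethod
    def normalize_targets(target: Union[int, Sequence[int]], batch_size: int) -> List[int]:
        # Assumption: an int target applies to every sample in the batch,
        # while a sequence supplies one target per sample.
        if isinstance(target, int):
            return [target] * batch_size
        targets = list(target)
        if len(targets) != batch_size:
            raise ValueError("expected one target per input in the batch")
        return targets

    def explain(
        self, inputs: Sequence[List[float]], target: Union[int, Sequence[int]] = 0
    ) -> List[List[float]]:
        targets = self.normalize_targets(target, len(inputs))
        return [self.attribute(x, t) for x, t in zip(inputs, targets)]

    @abstractmethod
    def attribute(self, x: List[float], target: int) -> List[float]:
        """Attribution for one input and one target class; each method overrides this."""


@register_method("finite_diff_saliency")
class FiniteDiffSaliency(BaseXAI):
    """Toy saliency: |d output[target] / d x_i|, estimated by forward differences."""

    def __init__(self, model: Callable[[List[float]], List[float]], eps: float = 1e-6) -> None:
        super().__init__(model)
        self.eps = eps

    def attribute(self, x: List[float], target: int) -> List[float]:
        base = self.model(x)[target]
        scores = []
        for i in range(len(x)):
            bumped = list(x)
            bumped[i] += self.eps
            scores.append(abs(self.model(bumped)[target] - base) / self.eps)
        return scores


# Usage: a 2-class linear "model" over 3 input features
W = [[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]]


def linear_model(x: List[float]) -> List[float]:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


explainer = METHOD_REGISTRY["finite_diff_saliency"](linear_model)
attributions = explainer.explain([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], target=[0, 1])
```

With this design, adding a method means subclassing the base class and registering it under a name; callers look methods up by name and always use the same `explain` call.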
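refrakt_xai is part of the Refrakt ecosystem and is installed from source. Clone the repository, create a virtual environment using one of the options below, then install the dependencies: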
```bash
# Clone the repository
git clone https://github.com/refrakt-hub/refrakt_xai.git
cd refrakt_xai
```

```bash
# Option A: Using uv (recommended)
uv venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Option B: Using venv
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Option C: Using conda
conda create -n refrakt_xai python=3.10
conda activate refrakt_xai
```

```bash
# Option A (with uv)
uv pip install -r pyproject.toml

# Option B (with pip)
pip install -r requirements.txt
```

Runtime Dependencies:
- `torch`: PyTorch deep learning framework
- `captum`: Model interpretability library
Development Dependencies:
- `pytest`: Testing framework
- `pytest-cov`: Coverage reporting
- `coverage`: Coverage measurement
- `isort`: Import sorting
- `black`: Code formatting
- `pylint`: Code linting
- `ruff`: Fast Python linter
- `radon`: Code complexity analysis
- `lizard`: Code complexity analysis
- `mypy`: Type checking
- `pre-commit`: Git hooks
```
refrakt_xai/
├── methods/                            # XAI method implementations
│   ├── saliency.py                     # Gradient-based saliency maps
│   ├── integrated_gradients.py         # Integrated gradients
│   ├── layer_gradcam.py                # Layer-wise GradCAM
│   ├── occlusion.py                    # Occlusion sensitivity
│   ├── deeplift.py                     # DeepLift attribution
│   ├── tcav.py                         # Testing with Concept Activation Vectors
│   └── reconstruction_attribution.py   # Reconstruction-based attribution
├── utils/                              # Utility functions
│   ├── model_utils.py                  # Model validation and processing
│   ├── layer_detection.py              # Automatic layer detection
│   ├── layer_resolvers.py              # Layer path resolution
│   └── concept_utils.py                # Concept-based utilities
├── tests/                              # Comprehensive test suite
│   ├── methods/                        # Method-specific tests
│   └── utils/                          # Utility function tests
├── base.py                             # Base XAI class interface
├── registry.py                         # Method registration system
└── __init__.py                         # Package initialization
```
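Layer paths such as `"layer4.1.conv2"` name a submodule by walking attribute names, where numeric parts address children of containers like `nn.Sequential`. The sketch below shows that resolution on plain Python objects so it runs without PyTorch. `resolve_layer`, `Block`, and `ToyNet` are hypothetical and are not the code in `layer_resolvers.py`. In PyTorch itself, `torch.nn.Module.get_submodule("layer4.1.conv2")` performs the same lookup for registered submodules.

```python
from typing import Any


def resolve_layer(root: Any, path: str) -> Any:
    """Hypothetical resolver: follow a dotted path such as "layer4.1.conv2".

    Each part is tried as an attribute first; purely numeric parts fall back
    to indexing, the way children of an nn.Sequential are addressed.
    """
    current = root
    for part in path.split("."):
        if hasattr(current, part):
            current = getattr(current, part)
        elif part.isdigit():
            try:
                current = current[int(part)]
            except (IndexError, KeyError, TypeError) as exc:
                raise AttributeError(f"cannot index {part!r} in path {path!r}") from exc
        else:
            raise AttributeError(f"no layer named {part!r} in path {path!r}")
    return current


class Block:
    """Stand-in for a residual block with two conv layers (strings here)."""

    def __init__(self, name: str) -> None:
        self.conv1 = f"{name}.conv1"
        self.conv2 = f"{name}.conv2"


class ToyNet:
    def __init__(self) -> None:
        # A list stands in for nn.Sequential: children are addressed by index.
        self.layer4 = [Block("block0"), Block("block1")]


net = ToyNet()
print(resolve_layer(net, "layer4.1.conv2"))  # block1.conv2
```

Raising `AttributeError` with the full path on failure makes a mistyped layer name easy to spot before any attribution runs.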
| Method | Description | Use Case |
|---|---|---|
| SaliencyXAI | Gradient-based attribution maps | General model interpretation |
| IntegratedGradientsXAI | Path-integrated gradients | Robust attribution analysis |
| LayerGradCAMXAI | Layer-wise GradCAM | CNN visualization |
| OcclusionXAI | Occlusion sensitivity | Feature importance analysis |
| DeepLiftXAI | DeepLift attribution | Deep network interpretation |
| TCAVXAI | Concept activation vectors | Concept-based explanations |
| ReconstructionAttributionXAI | Reconstruction-based attribution | Autoencoder interpretation |
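To make one row of the table concrete: integrated gradients attributes feature i as `(x_i - x'_i)` times the average gradient of the model output along the straight path from a baseline `x'` to the input `x`, i.e. `IG_i(x) = (x_i - x'_i) * integral over a in [0, 1] of df(x' + a(x - x'))/dx_i`. The sketch below approximates that integral with a midpoint Riemann sum in plain Python. It illustrates the technique only and is not the `IntegratedGradientsXAI` implementation (which, per the acknowledgments, builds on Captum). The toy function `f`, its hand-written gradient, and the all-zero baseline are assumptions chosen for illustration.

```python
from typing import Callable, List

Vector = List[float]


def f(x: Vector) -> float:
    # Toy model output: f(x) = x0 * x1 + x2 ** 2
    return x[0] * x[1] + x[2] ** 2


def grad_f(x: Vector) -> Vector:
    # Analytic gradient of f
    return [x[1], x[0], 2.0 * x[2]]


def integrated_gradients(
    grad: Callable[[Vector], Vector], x: Vector, baseline: Vector, steps: int = 50
) -> Vector:
    """Midpoint Riemann-sum approximation of integrated gradients."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point)):
            total[i] += g
    # Scale the average path gradient by the input's distance from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


x = [2.0, 3.0, 1.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(grad_f, x, baseline)
print([round(a, 4) for a in attributions])  # [3.0, 3.0, 1.0]
# Completeness: attributions sum to f(x) - f(baseline)
print(round(sum(attributions), 4), f(x) - f(baseline))  # 7.0 7.0
```

The completeness check (attributions summing to `f(x) - f(baseline)`) is a quick sanity test for any integrated-gradients setup; with a real network the sum only matches approximately, and the gap shrinks as `steps` grows.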
```python
import torch
import torch.nn as nn

from refrakt_xai import IntegratedGradientsXAI, LayerGradCAMXAI, SaliencyXAI


# Define a simple model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


# Initialize model and input
model = SimpleCNN()
input_tensor = torch.randn(1, 3, 32, 32, requires_grad=True)

# Create XAI explanations
saliency = SaliencyXAI(model)
saliency_attributions = saliency.explain(input_tensor, target=0)

ig = IntegratedGradientsXAI(model)
ig_attributions = ig.explain(input_tensor, target=0)

gradcam = LayerGradCAMXAI(model, layer="conv")
gradcam_attributions = gradcam.explain(input_tensor, target=0)
```

```python
# Continues from the example above (reuses `model` and `input_tensor`)
from refrakt_xai import DeepLiftXAI, LayerGradCAMXAI, OcclusionXAI

# Occlusion analysis
occlusion = OcclusionXAI(model, window_size=8)
occlusion_attributions = occlusion.explain(input_tensor, target=0)

# DeepLift attribution
deeplift = DeepLiftXAI(model)
deeplift_attributions = deeplift.explain(input_tensor, target=0)

# Auto-detection of layers
auto_gradcam = LayerGradCAMXAI(model, layer="auto")
auto_attributions = auto_gradcam.explain(input_tensor, target=0)
```

```python
# Process multiple inputs (reuses `saliency` from the first example)
batch_input = torch.randn(4, 3, 32, 32, requires_grad=True)
batch_targets = [0, 1, 2, 3]

# Batch processing with individual targets
batch_attributions = saliency.explain(batch_input, target=batch_targets)

# Single target for entire batch
single_target_attributions = saliency.explain(batch_input, target=0)
```

```python
# Works with any PyTorch model
import torch
import torchvision.models as models

from refrakt_xai import LayerGradCAMXAI

# torchvision >= 0.13 replaces `pretrained=True` with the `weights` argument
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet.eval()

# The pretrained weights were trained on 224x224 ImageNet images
resnet_input = torch.randn(1, 3, 224, 224, requires_grad=True)

# Layer-specific analysis
layer_gradcam = LayerGradCAMXAI(resnet, layer="layer4.1.conv2")
attributions = layer_gradcam.explain(resnet_input, target=0)
```

We welcome contributions! Please see CONTRIBUTORS.md for detailed guidelines on:
- Setting up the development environment
- Code style and conventions
- Testing requirements
- Pull request process
- Adding new XAI methods
refrakt_xai is designed as a core component of the Refrakt ecosystem, providing:
- Natural Language Interface: XAI methods can be invoked through Refrakt's NL orchestrator
- Visualization Pipeline: Attributions are automatically integrated with Refrakt's visualization system
- Workflow Integration: Seamless integration with Refrakt's ML/DL workflow orchestration
- Scalability: Methods are optimized for large-scale model analysis
This project is licensed under the same license as the main Refrakt project. See LICENSE for details.
- Built on top of Captum for robust XAI implementations
- Inspired by the XAI research community
- Part of the Refrakt ecosystem for scalable ML/DL workflows
Part of the Refrakt ecosystem: a natural-language orchestrator for scalable ML/DL workflows. [COMING SOON]