
Multi-Modal Explainability (XAI) for Medical Imaging with Captum

Abstract

Medical AI models increasingly fuse multiple imaging modalities — CT scans, MRI sequences, and clinical records — to make high-stakes decisions. This workshop teaches you to explain such multi-modal models using Captum, PyTorch’s interpretability library. You will build three complete XAI pipelines covering Classification (Image + Tabular), Regression (Image + Image), and Segmentation (Image + Image → Mask), all on procedurally generated data with known ground-truth signal locations so explanations can be visually verified.

Keywords: medical imaging, deep learning, MICCAI, AI tutorials, retina classification, multiple instance learning

Open In Colab · Open in Binder · View on GitHub

# [RUN ONCE] Install dependencies
# Force-reinstall numpy first to avoid ABI binary incompatibility errors,
# then install the remaining packages against the fresh numpy build.
import subprocess, sys

def pip(*args):
    subprocess.check_call([sys.executable, "-m", "pip", *args])

pip("install", "--upgrade", "numpy", "--quiet")
pip("install", "--upgrade", "captum", "torchvision", "seaborn", "scikit-learn", "scipy", "--quiet")

print("All dependencies installed. Please restart the kernel/runtime, then re-run from cell 2 onward.")
All dependencies installed. Please restart the kernel/runtime, then re-run from cell 2 onward.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

from captum.attr import (
    IntegratedGradients,
    DeepLift,
    GradientShap,
    Saliency,
    Occlusion,
    FeatureAblation,
    LayerGradCam,
    LayerAttribution,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {DEVICE}")
Using device: cpu

Section 2 — XAI Theory & Taxonomy

Overview: This section introduces Explainable AI (XAI) concepts, families of methods, and why they matter in medical imaging.

Figure: Comprehensive introduction to XAI (Parts 1 and 2) — the black box problem, interpretability vs explainability, why XAI is critical in medicine, and the three families of XAI methods we’ll explore in this tutorial.


2.0 — Interpretability vs Explainability

These two terms are often confused but are fundamentally different (see figure above):

| Concept | Definition | Example |
|---|---|---|
| Interpretability | The model is inherently understandable by design | Linear regression, decision tree — you can read the coefficients |
| Explainability | A post-hoc method approximates why a black-box model predicted X | IG on a CNN — explains a single prediction after the fact |

Deep networks used in medical imaging are not interpretable. We use XAI methods to explain their behavior.
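To make the distinction concrete, here is a minimal sketch of "interpretable by design": an ordinary-least-squares fit whose learned coefficients are themselves the explanation. (Toy data, not part of the workshop pipelines.)

```python
import numpy as np

# Toy "interpretable by design" model: ordinary least squares.
# The fitted coefficients ARE the explanation -- no post-hoc method needed.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))                  # three synthetic clinical features
true_w = np.array([2.0, 0.0, -1.0])            # feature 2 is irrelevant by design
y = X @ true_w + rng.normal(scale=0.05, size=200)

w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)  # closed-form OLS fit
print(np.round(w_hat, 2))                      # coefficients recovered: ~[2, 0, -1]
```

Reading off `w_hat` tells you exactly how each feature moves the prediction — something no amount of coefficient-reading can do for a deep CNN, which is why the rest of this tutorial relies on post-hoc XAI.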


2.1 — Why XAI in Medicine?

  • Clinical Safety: Detect model errors before harm — verify the model is not relying on spurious correlations

  • Regulatory Compliance: EU AI Act and FDA guidance mandate explainability for high-risk AI systems

  • Trust Building: Clinicians need to understand and verify the model’s reasoning before making clinical decisions

  • Scientific Discovery: XAI can reveal new biomarkers or disease patterns not previously known


2.2 — Family 1: Gradient-Based Methods

Core intuition: “Which pixels, if perturbed slightly, would change the model output the most?”

Mathematically, given a model $f$ and input $x$, the saliency is:

$$S(x) = \left| \frac{\partial f(x)}{\partial x} \right|$$

This is simply the gradient of the output with respect to the input. A large gradient at pixel $(i,j)$ means that pixel has a strong influence on the prediction.

Unimodal case

| Method | Idea | Formula |
|---|---|---|
| Saliency | Raw gradient magnitude | $\lvert \nabla_x f(x) \rvert$ |
| Integrated Gradients (IG) | Integrate gradients from baseline $x'$ to input $x$ — satisfies completeness and sensitivity axioms | $\int_0^1 \nabla_x f(x' + \alpha(x-x')) \cdot (x-x') \, d\alpha$ |
| DeepLIFT | Propagate the difference from a reference activation back to the inputs | $C_{\Delta x} = \frac{\delta f}{\delta \Delta x}$, where $\Delta x = x - x_{\text{ref}}$ |
| GradCAM | Gradient w.r.t. the last conv feature map; coarse but class-discriminative | $\text{CAM}^c = \text{ReLU}\!\left(\sum_k \alpha_k^c A^k\right)$, $\alpha_k^c = \frac{1}{HW}\sum_{i,j} \frac{\partial y^c}{\partial A_{ij}^k}$ |
| GradientSHAP | Approximates SHAP values using gradients of noisy samples × (input − baseline) | Stochastic IG variant |

Why IG over raw gradients? Raw gradients suffer from gradient saturation — near a flat region, gradients vanish even if the feature is important. IG fixes this by integrating the path from baseline to input.
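The saturation problem is easy to demonstrate numerically. Below is a small self-contained sketch (illustrative only — Captum’s `IntegratedGradients` does this properly) comparing the raw gradient with a Riemann-sum approximation of IG on a saturating function:

```python
import torch

# f saturates for large inputs, so the raw gradient at x underestimates the
# feature's importance; averaging gradients along the baseline->input path
# (a Riemann-sum version of IG) recovers it.
def f(x):
    return torch.sigmoid(4.0 * x).sum()

x = torch.tensor([3.0], requires_grad=True)   # deep in the saturated region
baseline = torch.zeros_like(x)

f(x).backward()
raw_grad = x.grad.detach().clone()            # Saliency: nearly zero here

n_steps = 100
grad_sum = torch.zeros_like(x)
for alpha in torch.linspace(0.0, 1.0, n_steps):
    xi = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
    f(xi).backward()
    grad_sum += xi.grad
ig = grad_sum / n_steps * (x - baseline).detach()

print(f"raw gradient  : {raw_grad.item():.6f}")   # ~0 (saturated)
print(f"IG attribution: {ig.item():.4f}")         # ~0.5 = f(x) - f(baseline)
```

By IG’s completeness axiom the attribution sums to $f(x) - f(x')$, which the raw gradient misses entirely in the flat region.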

Extending to multi-modal (Captum tuple-input API)

When the model takes two inputs $(x_1, x_2)$, gradients flow through each stream independently:

$$\text{attr}(x_1) = \int_0^1 \nabla_{x_1} f\big(x_1' + \alpha \Delta x_1,\; x_2' + \alpha \Delta x_2\big) \cdot \Delta x_1 \, d\alpha$$

$$\text{attr}(x_2) = \int_0^1 \nabla_{x_2} f\big(x_1' + \alpha \Delta x_1,\; x_2' + \alpha \Delta x_2\big) \cdot \Delta x_2 \, d\alpha$$

Captum handles this automatically via the tuple-input API — you get one attribution tensor per modality:

attr_x1, attr_x2 = ig.attribute(
    inputs=(x1, x2),
    baselines=(baseline1, baseline2),
    target=class_idx,
)

You can then compute modality importance as the mean absolute attribution per modality:

$$\text{importance}(x_k) = \frac{\mathbb{E}\big[|\text{attr}(x_k)|\big]}{\sum_j \mathbb{E}\big[|\text{attr}(x_j)|\big]}$$
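In code, this modality-importance ratio reduces to a few lines over the attribution tensors Captum returns — a sketch (the helper name is ours, not Captum’s):

```python
import torch

def modality_importance(attrs):
    """Share of total mean-absolute attribution per modality (sums to 1).
    `attrs` is the tuple of attribution tensors, one per modality."""
    scores = torch.stack([a.abs().mean() for a in attrs])
    return (scores / scores.sum()).tolist()

# Toy attributions: the image stream carries 3x the signal of the tabular one
attr_img = torch.full((1, 1, 4, 4), 0.3)
attr_tab = torch.full((1, 12), 0.1)
shares = modality_importance((attr_img, attr_tab))
print([round(s, 2) for s in shares])   # [0.75, 0.25]
```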

2.3 — Family 2: Perturbation-Based Methods

Core intuition: “What happens to the model output when I hide or change parts of the input?”

Instead of computing gradients, these methods probe the model by systematically removing or replacing regions and measuring the output change.

Unimodal case

| Method | Idea |
|---|---|
| Occlusion | Slide a grey/black patch over the image; a large output drop marks an important region |
| Feature Ablation | Zero out entire features (e.g., a tabular column, or a whole modality) |
| LIME | Fit a local linear model on perturbed samples around the input |
| SHAP | Use game-theoretic Shapley values: the average marginal contribution of each feature |

Trade-off vs gradient-based: Perturbation methods are model-agnostic (no backprop needed) but much slower — every probe requires a forward pass. Occlusion on a 224×224 image with an 8×8 patch and stride 4 requires 55×55 = 3025 forward passes per sample.
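The probe count above follows from the standard sliding-window formula; a quick sanity check:

```python
def occlusion_passes(img_size, patch, stride):
    """Forward passes needed by sliding-window occlusion on a square image:
    one probe per valid patch position along each spatial axis."""
    per_axis = (img_size - patch) // stride + 1
    return per_axis * per_axis

print(occlusion_passes(224, 8, 4))   # 3025 probes for the example above
```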

Extending to multi-modal

Perturbation-based methods extend naturally: you can ablate at the modality level by zeroing an entire input tensor.

# Feature Ablation — assign each modality its own group ID
mask_t1 = torch.ones_like(t1_xai)          # group 1
mask_t2 = torch.ones_like(t2_xai) * 2      # group 2
attr_t1, attr_t2 = fa.attribute(
    inputs=(t1_xai, t2_xai),
    feature_mask=(mask_t1, mask_t2),
)

This gives a single importance score per modality — directly answering “which MRI sequence matters more?”


2.4 — Family 3: Attention-Based Methods

Core intuition: “Where did the model’s Transformer attention heads look?”

Attention-based methods are only applicable when the model contains attention mechanisms (Transformers, Vision Transformers, nnFormer, etc.).

Unimodal case

| Method | Idea |
|---|---|
| Attention maps | Directly visualize multi-head attention weights for the [CLS] token |
| Attention Rollout | Propagate attention matrices through layers (accounting for residual connections) |
| DINO | Self-supervised attention captures semantic object boundaries without labels |
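For reference, attention rollout is only a few lines. A minimal sketch on toy head-averaged attention matrices, following the usual recipe of averaging each layer’s attention with the identity to model the residual stream:

```python
import torch

def attention_rollout(attn_layers):
    """Attention rollout: chain per-layer (head-averaged) attention matrices,
    mixing in the identity to account for residual connections."""
    n = attn_layers[0].shape[0]
    rollout = torch.eye(n)
    for A in attn_layers:
        A_res = 0.5 * (A + torch.eye(n))                 # residual connection
        A_res = A_res / A_res.sum(dim=-1, keepdim=True)  # keep rows stochastic
        rollout = A_res @ rollout
    return rollout

# Toy example: 2 layers, 3 tokens, uniform attention everywhere
layers = [torch.full((3, 3), 1.0 / 3) for _ in range(2)]
R = attention_rollout(layers)
print(R.sum(dim=-1))   # rows remain stochastic (each sums to 1)
```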

Extending to multi-modal

For multi-modal Transformers (e.g., MedFuse, BioViL-T), cross-attention layers connect tokens from different modalities. The cross-attention matrix $A^{(\text{cross})}_{ij}$ tells you: “how much did modality-B token $j$ influence modality-A token $i$?”

Note: The models in this workshop use CNNs, not Transformers — so attention-based methods do not directly apply here. We cover this family for completeness. GradCAM is often used as an attention-like proxy for CNN models.
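Since GradCAM serves as our attention-like proxy, here is the CAM formula from Section 2.2 written out by hand on a tiny hypothetical CNN (the notebook itself uses Captum’s `LayerGradCam` instead):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# GradCAM by hand: weight each conv channel A^k by its spatially averaged
# gradient alpha_k, sum over channels, apply ReLU, then upsample to input size.
conv = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)   # 64x64 -> 32x32
head = nn.Sequential(nn.AdaptiveAvgPool2d(7), nn.Flatten(),
                     nn.Linear(8 * 7 * 7, 2))

x = torch.rand(1, 1, 64, 64)
fmap = conv(x)                       # A^k: (1, 8, 32, 32)
fmap.retain_grad()                   # non-leaf tensor: keep its gradient
logits = head(F.relu(fmap))
logits[0, 1].backward()              # d(class-1 score)/d(feature map)

alpha = fmap.grad.mean(dim=(2, 3), keepdim=True)          # (1, 8, 1, 1)
cam = F.relu((alpha * fmap).sum(dim=1, keepdim=True))     # (1, 1, 32, 32)
heatmap = F.interpolate(cam, size=(64, 64), mode="bilinear",
                        align_corners=False)
print(heatmap.shape)                 # torch.Size([1, 1, 64, 64])
```

Note the map is computed at the (coarse) conv resolution and only interpolated afterwards — this is why GradCAM heatmaps look blocky compared with pixel-level IG.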


2.5 — Multi-Modal Challenge & Captum API

Single-input XAI tells you where a model looked in one image. Multi-modal XAI must also answer:

“How much did each modality (CT vs. EHR, T1 vs. T2, FLAIR vs. T1ce) contribute?”

Captum’s tuple-input API elegantly solves this by treating each modality as a separate input with independent gradient flow:

Captum API Cheatsheet — Part 1: Core Concepts

Figure 5a: Captum’s multi-modal API pattern. Each modality gets its own attribution tensor. The inputs tuple structure is mirrored in the outputs — clean and consistent!

Captum API Cheatsheet — Part 2: Method Reference

Figure 5b: Quick reference for common Captum methods in multi-modal scenarios. Shows method signatures, when to use each approach, and key parameters (baselines, targets, n_steps).

Core API pattern:

attr1, attr2 = method.attribute(
    inputs=(x1, x2),
    baselines=(b1, b2),
    target=class_idx,   # or None for regression, or 0 for single output
)

Key principles:

  • Both inputs must have requires_grad=True

  • Baselines should be meaningful references (zeros, dataset mean, blur-masked image)

  • The target parameter specifies which output neuron to explain:

    • Classification: class index (0, 1, 2, ...)

    • Regression: output neuron index (usually 0 for single-output models)

    • Segmentation: use a wrapper to reduce spatial output to scalar (e.g., mean probability)
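A minimal sketch of such a segmentation reduction wrapper — a hypothetical stand-in for the `SegReductionWrapper` imported later from `xai_utils` (the real helper may differ):

```python
import torch
import torch.nn as nn

class MeanProbWrapper(nn.Module):
    """Hypothetical reduction wrapper: collapse a segmentation map to one
    scalar per sample so attribution methods have a single differentiable
    target (same idea as the notebook's SegReductionWrapper)."""
    def __init__(self, seg_model):
        super().__init__()
        self.seg_model = seg_model

    def forward(self, *inputs):
        logits = self.seg_model(*inputs)               # (B, 1, H, W)
        return torch.sigmoid(logits).mean(dim=(1, 2, 3)).unsqueeze(1)  # (B, 1)

# Stand-in "segmentation model": a 1x1 conv that keeps the spatial shape
seg = nn.Conv2d(2, 1, kernel_size=1)
wrapped = MeanProbWrapper(seg)
out = wrapped(torch.rand(4, 2, 32, 32))
print(out.shape)   # torch.Size([4, 1]) -- attribute with target=0
```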


Section 3 — Utility Functions & Model Wrappers

All plotting helpers and Captum-compatibility wrappers live in the standalone module xai_utils.py (in this folder). This keeps the notebook cells focused on the XAI logic while making the utilities reusable in scripts and tests.

# ── Import shared utilities from xai_utils.py ─────────────────────────────────
import sys, pathlib
sys.path.insert(0, str(pathlib.Path(".").resolve()))

from xai_utils import (
    normalize_attr,
    plot_img_attr,
    plot_tabular_attr,
    plot_modality_importance,
    overlay_attribution,
    plot_dual_attr,
    MultiInputWrapper,
    SegReductionWrapper,
    EHR_FEATURE_NAMES,
    IMG_SIZE,
)
print("xai_utils loaded — utility functions and model wrappers ready.")
xai_utils loaded — utility functions and model wrappers ready.

Section 4 — Realistic Synthetic Data Generation & Loading

All data is procedurally generated by the standalone script generate_dataset.py to faithfully mimic the physical and clinical properties of real medical imaging data, while embedding known ground-truth signals for XAI verification. This mirrors a real-world workflow where data preparation and model training live in separate stages.

| Task | Modalities | Realism Highlights | Planted Signal |
|---|---|---|---|
| Classification | CT (224×224) + 12 EHR features | Hounsfield-inspired tissue layers, anatomical noise, realistic circular/spiculated nodule | Nodule in left lobe (label=0/benign) or right lobe (label=1/malignant + spiculation); SUVmax, age, pack-years discriminative |
| Regression | T1 MRI + T2 MRI (224×224) | Tissue-contrast T1/T2 relaxation ratios (WM bright T1, CSF dark T1/bright T2), bandpass spatial texture, smooth tissue gradients | WM/CSF T1 signal variation tracks EDSS-like score; periventricular T2 hyperintensity scales with score |
| Segmentation | FLAIR + T1ce (224×224) | Brain parenchyma with sulci/gyri folding texture, realistic ring-enhancing lesion with necrotic core + peritumoral edema | FLAIR: edema halo bright; T1ce: ring-enhancing rim bright, necrotic core dark; mask = full tumor |
import subprocess, sys, pathlib

DATA_DIR = pathlib.Path("data")

if not DATA_DIR.exists() or len(list(DATA_DIR.glob("*.npz"))) < 9:
    print("Running generate_dataset.py  (may take ~2 min on first run) ...")
    subprocess.check_call(
        [sys.executable, "generate_dataset.py",
         "--n", "300", "--seed", "42",
         "--xai_n", "30", "--xai_seed", "999",
         "--out_dir", str(DATA_DIR)],
    )
else:
    print(f"Data directory already populated ({len(list(DATA_DIR.glob('*')))} files). Skipping generation.")

# Verify files exist
expected = [
    "cls_images.npz", "cls_tabular.npz", "cls_metadata.csv",
    "cls_xai_images.npz", "cls_xai_tabular.npz",
    "reg_t1.npz", "reg_t2.npz", "reg_scores.npz",
    "reg_xai_t1.npz", "reg_xai_t2.npz", "reg_xai_scores.npz",
    "seg_flair.npz", "seg_t1ce.npz", "seg_masks.npz",
    "seg_xai_flair.npz", "seg_xai_t1ce.npz", "seg_xai_masks.npz",
]
for f in expected:
    assert (DATA_DIR / f).exists(), f"Missing: {f}"
print(f"All {len(expected)} data files verified in {DATA_DIR}/")
Data directory already populated (17 files). Skipping generation.
All 17 data files verified in data/

4.2 — Constants & helper functions (used throughout the notebook)

from scipy.ndimage import gaussian_filter, label as ndlabel

# -- Tissue & texture helpers (used by EDA visualisation cells) ----------------

def _smooth_noise(rng, shape, scale=0.04, sigma=2.0):
    """Low-frequency perlin-like noise via blurred random field."""
    raw = rng.normal(0, scale, shape).astype(np.float32)
    return gaussian_filter(raw, sigma=sigma)


def _circular_mask(img_size, cx, cy, r):
    y, x = np.ogrid[:img_size, :img_size]
    return ((x - cx)**2 + (y - cy)**2) <= r**2


def _ellipse_mask(img_size, cx, cy, rx, ry, angle_deg=0):
    """Rotated ellipse mask."""
    y, x = np.ogrid[:img_size, :img_size]
    cos_a = np.cos(np.radians(angle_deg))
    sin_a = np.sin(np.radians(angle_deg))
    xr = cos_a * (x - cx) + sin_a * (y - cy)
    yr = -sin_a * (x - cx) + cos_a * (y - cy)
    return (xr**2 / rx**2 + yr**2 / ry**2) <= 1.0


EHR_FEATURE_NAMES = [
    "Age", "PackYears", "SUVmax",
    "NoduleDiam_mm", "SpiculationIdx", "PleuraTether",
    "FEV1_pct", "BMI", "CRP_mgL",
    "FamilyHx", "SmokingStatus", "Gender"
]

IMG_SIZE = 224

print("Helper functions & constants defined.")
Helper functions & constants defined.

Section 5 — Dataset & DataLoader Builders

Each Dataset loads pre-generated data from disk (.npz + .csv files produced by generate_dataset.py) and exposes the standard PyTorch interface.
metadata DataFrames are attached to ClsDataset and RegDataset for EDA convenience.

from torch.utils.data import Dataset, DataLoader
import pandas as pd

N_SAMPLES = 300
IMG_SIZE  = 224


class ClsDataset(Dataset):
    """CT + 12-feature EHR binary classification dataset (loaded from disk).

    self.metadata — pd.DataFrame with all EHR features + label for EDA.
    """
    def __init__(self, npz_prefix="cls", data_dir=DATA_DIR):
        img_data = np.load(data_dir / f"{npz_prefix}_images.npz")
        tab_data = np.load(data_dir / f"{npz_prefix}_tabular.npz")
        self.images  = torch.tensor(img_data["images"])           # (N, 1, H, W)
        self.tabular = torch.tensor(tab_data["tabular"])           # (N, 12)
        self.labels  = torch.tensor(tab_data["labels"])            # (N,)
        meta_path = data_dir / f"{npz_prefix}_metadata.csv"
        if meta_path.exists():
            self.metadata = pd.read_csv(meta_path)
        else:
            rows = []
            for i in range(len(self.labels)):
                row = dict(zip(EHR_FEATURE_NAMES, self.tabular[i].numpy()))
                row["Label"] = int(self.labels[i])
                row["LabelName"] = "Malignant" if row["Label"] == 1 else "Benign"
                rows.append(row)
            self.metadata = pd.DataFrame(rows)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.tabular[idx], self.labels[idx]


class RegDataset(Dataset):
    """T1 + T2 MRI regression dataset (loaded from disk).

    self.scores — list of float scores for EDA.
    """
    def __init__(self, npz_prefix="reg", data_dir=DATA_DIR):
        t1_data   = np.load(data_dir / f"{npz_prefix}_t1.npz")
        t2_data   = np.load(data_dir / f"{npz_prefix}_t2.npz")
        sc_data   = np.load(data_dir / f"{npz_prefix}_scores.npz")
        self.t1     = torch.tensor(t1_data["t1"])             # (N, 1, H, W)
        self.t2     = torch.tensor(t2_data["t2"])             # (N, 1, H, W)
        self.score_tensor = torch.tensor(sc_data["scores"]).float().unsqueeze(1)  # (N, 1)
        self.scores = sc_data["scores"].tolist()

    def __len__(self):
        return len(self.scores)

    def __getitem__(self, idx):
        return self.t1[idx], self.t2[idx], self.score_tensor[idx]


class SegDataset(Dataset):
    """FLAIR + T1ce + lesion mask segmentation dataset (loaded from disk)."""
    def __init__(self, npz_prefix="seg", data_dir=DATA_DIR):
        fl_data = np.load(data_dir / f"{npz_prefix}_flair.npz")
        tc_data = np.load(data_dir / f"{npz_prefix}_t1ce.npz")
        mk_data = np.load(data_dir / f"{npz_prefix}_masks.npz")
        self.flair = torch.tensor(fl_data["flair"])           # (N, 1, H, W)
        self.t1ce  = torch.tensor(tc_data["t1ce"])            # (N, 1, H, W)
        self.masks = torch.tensor(mk_data["masks"])           # (N, 1, H, W)

    def __len__(self):
        return self.flair.shape[0]

    def __getitem__(self, idx):
        return self.flair[idx], self.t1ce[idx], self.masks[idx]


# ── Build datasets from saved files ───────────────────────────────────────────
print("Loading datasets from disk …")
cls_ds  = ClsDataset("cls")
reg_ds  = RegDataset("reg")
seg_ds  = SegDataset("seg")

cls_loader  = DataLoader(cls_ds, batch_size=16, shuffle=True,  num_workers=0)
reg_loader  = DataLoader(reg_ds, batch_size=16, shuffle=True,  num_workers=0)
seg_loader  = DataLoader(seg_ds, batch_size=8,  shuffle=True,  num_workers=0)

# XAI dataloaders (loaded from the separate xai subset files)
cls_xai_ds = ClsDataset("cls_xai")
reg_xai_ds = RegDataset("reg_xai")
seg_xai_ds = SegDataset("seg_xai")
cls_xai_loader = DataLoader(cls_xai_ds, batch_size=1, shuffle=False)
reg_xai_loader = DataLoader(reg_xai_ds, batch_size=1, shuffle=False)
seg_xai_loader = DataLoader(seg_xai_ds, batch_size=1, shuffle=False)

print(f"Datasets ready  —  cls: {len(cls_ds)} | reg: {len(reg_ds)} | seg: {len(seg_ds)} samples")
print(f"XAI subsets     —  cls: {len(cls_xai_ds)} | reg: {len(reg_xai_ds)} | seg: {len(seg_xai_ds)} samples")
print(f"Image resolution: {IMG_SIZE}x{IMG_SIZE}  |  EHR features: {len(EHR_FEATURE_NAMES)}")
Loading datasets from disk …
Datasets ready  —  cls: 1000 | reg: 1000 | seg: 1000 samples
XAI subsets     —  cls: 100 | reg: 100 | seg: 100 samples
Image resolution: 224x224  |  EHR features: 12

Section 6 — Exploratory Data Analysis (EDA)

Before modelling we inspect the three datasets across four lenses:

  1. Image montage — visual quality & modality contrast

  2. Intensity distributions — histogram per tissue/class

  3. Tabular EDA — class-stratified statistics, correlation heatmap

  4. Dataset-level statistics — label balance, lesion size distribution, score distribution


6.1 — Classification: CT + EHR

import matplotlib.patches as mpatches

# ── 6.1a  Image montage ─────────────────────────────────────────────────────
fig, axes = plt.subplots(2, 6, figsize=(18, 6))
# Gather 6 benign and 6 malignant samples from the dataset
benign_idxs  = [i for i in range(len(cls_ds)) if int(cls_ds.labels[i]) == 0][:6]
malig_idxs   = [i for i in range(len(cls_ds)) if int(cls_ds.labels[i]) == 1][:6]
for row, (idxs, lname) in enumerate([(benign_idxs, "Benign"), (malig_idxs, "Malignant")]):
    for k, idx in enumerate(idxs):
        img_t, _, _ = cls_ds[idx]
        axes[row, k].imshow(img_t.squeeze().numpy(), cmap="gray", vmin=0, vmax=1)
        axes[row, k].axis("off")
        if k == 0:
            axes[row, k].set_title(f"CT — {lname}", loc="left", fontweight="bold",
                                    fontsize=10, pad=4)

plt.suptitle("Task A · CT Image Montage  (top: Benign, bottom: Malignant)", fontsize=13, y=1.01)
plt.tight_layout()
plt.show()
<Figure size 1800x600 with 12 Axes>
# ── 6.1b  Intensity histograms per class ────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(13, 4))

all_pixels = {0: [], 1: []}
for idx in range(min(60, len(cls_ds))):
    img, _, lbl = cls_ds[idx]
    all_pixels[int(lbl.item())].append(img.numpy().ravel())

colors = {0: "#3498db", 1: "#e74c3c"}
labels = {0: "Benign", 1: "Malignant"}

for lbl, pix_list in all_pixels.items():
    arr = np.concatenate(pix_list)
    axes[0].hist(arr, bins=80, alpha=0.55, color=colors[lbl],
                 label=labels[lbl], density=True)

axes[0].set_xlabel("Normalised CT intensity")
axes[0].set_ylabel("Density")
axes[0].set_title("CT Pixel Intensity Distribution by Class")
axes[0].legend()

# ── 6.1c  EHR class statistics table ──────────────────────────────────────
meta = cls_ds.metadata
stats = meta.groupby("LabelName")[EHR_FEATURE_NAMES[:9]].agg(["mean", "std"])
stats.columns = [f"{f}_{s}" for f, s in stats.columns]
stats_T = stats.T
print(stats_T.to_string())

# Radar / bar comparison for top discriminative features
top_feats = ["Age", "PackYears", "SUVmax", "NoduleDiam_mm", "SpiculationIdx", "FEV1_pct"]
benign_means    = meta[meta.Label == 0][top_feats].mean()
malignant_means = meta[meta.Label == 1][top_feats].mean()

x = np.arange(len(top_feats))
w = 0.35
axes[1].bar(x - w/2, benign_means.values,    w, label="Benign",    color="#3498db", alpha=0.80)
axes[1].bar(x + w/2, malignant_means.values, w, label="Malignant", color="#e74c3c", alpha=0.80)
axes[1].set_xticks(x)
axes[1].set_xticklabels(top_feats, rotation=30, ha="right", fontsize=9)
axes[1].set_ylabel("Mean value")
axes[1].set_title("Key EHR Feature Means by Class")
axes[1].legend()

plt.tight_layout()
plt.show()
LabelName               Benign  Malignant
Age_mean             54.868460  64.990940
Age_std               7.974023   8.074204
PackYears_mean       20.361829  39.832493
PackYears_std         9.970879  12.313439
SUVmax_mean           2.496799   6.418143
SUVmax_std            0.818092   1.471868
NoduleDiam_mm_mean    7.091237  14.750829
NoduleDiam_mm_std     2.337710   4.899538
SpiculationIdx_mean   0.176652   0.722602
SpiculationIdx_std    0.128126   0.157995
PleuraTether_mean     0.090000   0.402000
PleuraTether_std      0.286468   0.490793
FEV1_pct_mean        82.224398  67.134077
FEV1_pct_std         12.241022  13.691634
BMI_mean             26.393236  24.011906
BMI_std               4.075355   3.994324
CRP_mgL_mean          2.923163   7.568775
CRP_mgL_std           2.915512   7.707013
<Figure size 1300x400 with 2 Axes>
# ── 6.1d  EHR correlation heatmap ───────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

for ax, lbl, lname, cmap in zip(axes, [0, 1], ["Benign", "Malignant"],
                                  ["Blues", "Reds"]):
    subset = meta[meta.Label == lbl][EHR_FEATURE_NAMES]
    corr = subset.corr()
    mask = np.triu(np.ones_like(corr, dtype=bool))
    sns.heatmap(corr, ax=ax, mask=mask, cmap=cmap, vmin=-1, vmax=1,
                center=0, annot=True, fmt=".2f", linewidths=0.4,
                annot_kws={"size": 7})
    ax.set_title(f"EHR Feature Correlation — {lname}", fontsize=11)
    ax.tick_params(axis="x", rotation=45, labelsize=8)
    ax.tick_params(axis="y", labelsize=8)

plt.suptitle("Task A · EHR Intra-Class Correlations", fontsize=13)
plt.tight_layout()
plt.show()
<Figure size 1600x500 with 4 Axes>
# ── 6.1e  Per-feature violin plots ──────────────────────────────────────────
plot_feats = ["Age", "PackYears", "SUVmax", "NoduleDiam_mm",
              "SpiculationIdx", "FEV1_pct", "CRP_mgL", "BMI"]
fig, axes = plt.subplots(2, 4, figsize=(18, 7))

for ax, feat in zip(axes.ravel(), plot_feats):
    data_b = meta[meta.Label == 0][feat].values
    data_m = meta[meta.Label == 1][feat].values
    parts  = ax.violinplot([data_b, data_m], positions=[0, 1],
                            showmedians=True, showextrema=True)
    for pc, color in zip(parts["bodies"], ["#3498db", "#e74c3c"]):
        pc.set_facecolor(color)
        pc.set_alpha(0.65)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Benign", "Malignant"], fontsize=9)
    ax.set_title(feat, fontsize=10, fontweight="bold")
    ax.set_ylabel("Value", fontsize=8)

plt.suptitle("Task A · EHR Feature Distributions (Benign vs Malignant)", fontsize=13)
plt.tight_layout()
plt.show()
<Figure size 1800x700 with 8 Axes>

6.2 — Regression: T1 + T2 MRI (EDSS Score)

# ── 6.2a  Multi-score image montage ─────────────────────────────────────────
# Pick samples spanning the score range
sorted_idxs = np.argsort(reg_ds.scores)
pick_positions = np.linspace(0, len(sorted_idxs) - 1, 6, dtype=int)
pick_idxs = [sorted_idxs[p] for p in pick_positions]

fig, axes = plt.subplots(2, len(pick_idxs), figsize=(18, 6))

for col, idx in enumerate(pick_idxs):
    t1, t2, sc = reg_ds[idx]
    score = reg_ds.scores[idx]
    axes[0, col].imshow(t1.squeeze().numpy(), cmap="gray", vmin=0, vmax=1)
    axes[0, col].set_title(f"T1\nscore={score:.2f}", fontsize=9)
    axes[0, col].axis("off")
    axes[1, col].imshow(t2.squeeze().numpy(), cmap="gray", vmin=0, vmax=1)
    axes[1, col].set_title(f"T2\nscore={score:.2f}", fontsize=9)
    axes[1, col].axis("off")

plt.suptitle("Task B · T1 (top) & T2 (bottom) across EDSS-like score levels\n"
             "Expected: T1 periventricular hypointensities grow; T2 WMH burden increases",
             fontsize=11, y=1.02)
plt.tight_layout()
plt.show()
<Figure size 1800x600 with 12 Axes>
# ── 6.2b  Score distribution ────────────────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

scores_arr = np.array(reg_ds.scores)
axes[0].hist(scores_arr, bins=30, color="#2ecc71", edgecolor="white", linewidth=0.5)
axes[0].set_xlabel("EDSS-like score")
axes[0].set_ylabel("Count")
axes[0].set_title(f"Score Distribution\nMean={scores_arr.mean():.3f}  Std={scores_arr.std():.3f}")

# ── 6.2c  Mean T1/T2 intensity vs score (lesion area proxy) ─────────────────
n_probe = min(80, len(reg_ds))
probe_scores, mean_t1, mean_t2 = [], [], []
for i in range(n_probe):
    t1_img, t2_img, sc = reg_ds[i]
    probe_scores.append(sc.item())
    # Proxy: mean intensity in upper-central ROI (periventricular for T1)
    h, w = t1_img.shape[-2], t1_img.shape[-1]
    peri_t1 = t1_img[0, h//4: h//2, w//4: 3*w//4].mean().item()
    deep_t2 = t2_img[0, h//2: 3*h//4, w//4: 3*w//4].mean().item()
    mean_t1.append(peri_t1)
    mean_t2.append(deep_t2)

axes[1].scatter(probe_scores, mean_t1, alpha=0.55, s=18, color="#e74c3c", label="T1 periventr.")
z1 = np.polyfit(probe_scores, mean_t1, 1)
xs = np.linspace(0, 1, 50)
axes[1].plot(xs, np.polyval(z1, xs), color="#e74c3c", lw=2)
axes[1].set_xlabel("EDSS score"); axes[1].set_ylabel("ROI mean intensity")
axes[1].set_title("T1 Periventricular Intensity vs Score\n(expected: decreasing — T1 hypointensities)")
axes[1].legend()

axes[2].scatter(probe_scores, mean_t2, alpha=0.55, s=18, color="#3498db", label="T2 deep WM")
z2 = np.polyfit(probe_scores, mean_t2, 1)
axes[2].plot(xs, np.polyval(z2, xs), color="#3498db", lw=2)
axes[2].set_xlabel("EDSS score"); axes[2].set_ylabel("ROI mean intensity")
axes[2].set_title("T2 Deep WM Intensity vs Score\n(expected: increasing — WMH burden)")
axes[2].legend()

plt.tight_layout()
plt.show()
<Figure size 1500x400 with 3 Axes>

6.3 — Segmentation: FLAIR + T1ce (GBM-like Lesion)

# ── 6.3a  Multi-sample panel ─────────────────────────────────────────────────
n_show = 5
fig, axes = plt.subplots(4, n_show, figsize=(18, 12))
row_labels = ["FLAIR", "T1ce", "GT Mask", "Overlay"]

for col in range(n_show):
    flair, t1ce, mask = seg_ds[col]
    imgs = [flair.squeeze(), t1ce.squeeze(), mask.squeeze()]
    cmaps = ["gray", "gray", "Reds"]
    for row in range(3):
        axes[row, col].imshow(imgs[row].numpy(), cmap=cmaps[row], vmin=0, vmax=1)
        axes[row, col].axis("off")
        if col == 0:
            axes[row, col].set_ylabel(row_labels[row], fontsize=9, rotation=90,
                                       labelpad=4, fontweight="bold")
    # Overlay row
    axes[3, col].imshow(flair.squeeze().numpy(), cmap="gray", vmin=0, vmax=1)
    axes[3, col].imshow(mask.squeeze().numpy(), cmap="Reds", alpha=0.45, vmin=0, vmax=1)
    axes[3, col].axis("off")
    axes[3, col].set_title(f"Sample {col+1}", fontsize=9)
    if col == 0:
        axes[3, col].set_ylabel(row_labels[3], fontsize=9, rotation=90,
                                 labelpad=4, fontweight="bold")

plt.suptitle("Task C · FLAIR / T1ce / GT Mask / Overlay\n"
             "Note: ring-enhancing lesion (bright rim, dark core) + surrounding edema",
             fontsize=11, y=1.01)
plt.tight_layout()
plt.show()
<Figure size 1800x1200 with 20 Axes>
# ── 6.3b  Lesion statistics across dataset ──────────────────────────────────
n_stats = min(100, len(seg_ds))
lesion_sizes, flair_intensities, t1ce_ring_intensities = [], [], []

for i in range(n_stats):
    fl, t1, msk = seg_ds[i]
    msk_np = msk.squeeze().numpy()
    fl_np  = fl.squeeze().numpy()
    t1_np  = t1.squeeze().numpy()

    lesion_area = msk_np.sum()
    lesion_sizes.append(lesion_area)
    if lesion_area > 0:
        flair_intensities.append(fl_np[msk_np > 0.5].mean())
        # T1ce ring: pixels with high intensity inside mask
        t1ce_in_mask = t1_np[msk_np > 0.5]
        t1ce_ring_intensities.append(t1ce_in_mask[t1ce_in_mask > np.percentile(t1ce_in_mask, 60)].mean()
                                     if len(t1ce_in_mask) > 5 else np.nan)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(lesion_sizes, bins=25, color="#9b59b6", edgecolor="white", linewidth=0.5)
axes[0].set_xlabel("Lesion area (pixels)")
axes[0].set_ylabel("Count")
axes[0].set_title(f"GT Lesion Size Distribution\nMean={np.mean(lesion_sizes):.0f} px  "
                  f"Std={np.std(lesion_sizes):.0f} px")

axes[1].hist(flair_intensities, bins=25, color="#e67e22", edgecolor="white", linewidth=0.5)
axes[1].axvline(np.mean(flair_intensities), color="black", lw=1.5, linestyle="--",
                label=f"mean={np.mean(flair_intensities):.3f}")
axes[1].set_xlabel("Mean FLAIR intensity inside mask")
axes[1].set_title("FLAIR Lesion Intensity\n(expected: bright ≈ 0.80)")
axes[1].legend()

valid_ring = [v for v in t1ce_ring_intensities if not np.isnan(v)]
axes[2].hist(valid_ring, bins=25, color="#1abc9c", edgecolor="white", linewidth=0.5)
axes[2].axvline(np.mean(valid_ring), color="black", lw=1.5, linestyle="--",
                label=f"mean={np.mean(valid_ring):.3f}")
axes[2].set_xlabel("T1ce ring intensity (top 40% within mask)")
axes[2].set_title("T1ce Ring-Enhancement Intensity\n(expected: bright ≈ 0.85)")
axes[2].legend()

plt.suptitle("Task C · Lesion Statistics Across 100 Samples", fontsize=13)
plt.tight_layout()
plt.show()
<Figure size 1500x400 with 3 Axes>
# ── 6.3c  Multi-compartment anatomy for one sample ──────────────────────────
# Use a specific sample from the dataset for detailed compartment analysis
flair_ex, t1ce_ex, mask_ex = seg_ds[42]

flair_np = flair_ex.squeeze().numpy()
t1ce_np  = t1ce_ex.squeeze().numpy()
mask_np  = mask_ex.squeeze().numpy()

# Build sub-compartment masks from intensity thresholds within GT mask
inside = mask_np > 0.5
if inside.sum() > 0:
    t1ce_inside = t1ce_np[inside]
    thresh_ring   = np.percentile(t1ce_inside, 60)
    thresh_necro  = np.percentile(t1ce_inside, 20)
    ring_px       = inside & (t1ce_np >= thresh_ring)
    necrosis_px   = inside & (t1ce_np <= thresh_necro)
    edema_px      = inside & ~ring_px & ~necrosis_px
else:
    ring_px = necrosis_px = edema_px = np.zeros_like(mask_np, dtype=bool)

# Color overlay: edema=yellow, ring=red, necrosis=blue
overlay = np.stack([flair_np, flair_np, flair_np], axis=-1)
overlay[edema_px]    = [0.90, 0.85, 0.10]   # yellow
overlay[ring_px]     = [0.90, 0.15, 0.15]   # red
overlay[necrosis_px] = [0.15, 0.35, 0.90]   # blue

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
axes[0].imshow(flair_np,  cmap="gray");  axes[0].set_title("FLAIR");     axes[0].axis("off")
axes[1].imshow(t1ce_np,   cmap="gray");  axes[1].set_title("T1ce");      axes[1].axis("off")
axes[2].imshow(mask_np,   cmap="Reds");  axes[2].set_title("GT Mask");   axes[2].axis("off")
axes[3].imshow(overlay);                  axes[3].set_title("Compartments\nyellow=edema  red=rim  blue=necrosis")
axes[3].axis("off")

# Intensity profiles: horizontal line through lesion center
cy_les = int(np.where(mask_np.sum(axis=1) > 0)[0].mean()) if mask_np.sum() > 0 else IMG_SIZE // 2
axes[4].plot(flair_np[cy_les, :], color="#e67e22", label="FLAIR", lw=1.8)
axes[4].plot(t1ce_np[cy_les, :],  color="#3498db", label="T1ce",  lw=1.8)
axes[4].axhspan(0, 1, alpha=0.08, color="gray")
axes[4].fill_between(range(IMG_SIZE),
                      [0.0] * IMG_SIZE, [1.0] * IMG_SIZE,
                      where=mask_np[cy_les, :] > 0.5,
                      alpha=0.18, color="#e74c3c", label="Lesion region")
axes[4].set_xlabel("Pixel position (x)")
axes[4].set_ylabel("Normalised intensity")
axes[4].set_title(f"Intensity Profile — row {cy_les}")
axes[4].legend(fontsize=8)
axes[4].set_ylim(-0.05, 1.05)

plt.suptitle("Task C · Single-Sample Anatomy: Multi-compartment GBM-like Lesion",
             fontsize=12, y=1.02)
plt.tight_layout()
plt.show()
<Figure size 2000x400 with 5 Axes>

6.4 — Cross-Dataset Summary

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# ── Label balance (Classification) ──────────────────────────────────────────
label_counts = cls_ds.metadata["LabelName"].value_counts()
axes[0].bar(label_counts.index, label_counts.values,
             color=["#3498db", "#e74c3c"], alpha=0.85, edgecolor="white")
for i, (lbl, cnt) in enumerate(label_counts.items()):
    axes[0].text(i, cnt + 2, str(cnt), ha="center", fontsize=11, fontweight="bold")
axes[0].set_title("Task A · Class Balance")
axes[0].set_ylabel("Number of samples")
axes[0].set_ylim(0, label_counts.max() + 20)

# ── Score distribution (Regression) ─────────────────────────────────────────
axes[1].hist(reg_ds.scores, bins=30, color="#2ecc71", edgecolor="white", linewidth=0.5)
axes[1].set_xlabel("EDSS-like score (0=minimal, 1=severe)")
axes[1].set_ylabel("Count")
axes[1].set_title("Task B · Score Distribution\n(Beta(2,5) — skewed toward low disability)")

# ── Lesion area distribution (Segmentation) ──────────────────────────────────
n_eda = min(150, len(seg_ds))
areas = []
for i in range(n_eda):
    _, _, msk = seg_ds[i]
    areas.append(msk.squeeze().numpy().sum())
axes[2].hist(areas, bins=30, color="#9b59b6", edgecolor="white", linewidth=0.5)
axes[2].set_xlabel("Total lesion area (pixels)")
axes[2].set_ylabel("Count")
axes[2].set_title(f"Task C · GT Lesion Area Distribution\n(n={n_eda} samples)")

plt.suptitle("Cross-Dataset EDA Summary", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

# ── Print compact dataset card ───────────────────────────────────────────────
print("=" * 60)
print("DATASET CARD")
print("=" * 60)
print(f"Task A  CT + EHR Classification")
print(f"  Samples : {len(cls_ds)}  (50% benign / 50% malignant)")
print(f"  CT size : {IMG_SIZE}×{IMG_SIZE} px  |  EHR features: {len(EHR_FEATURE_NAMES)}")
print(f"  EHR mean (benign)   : age={cls_ds.metadata[cls_ds.metadata.Label==0]['Age'].mean():.1f}, "
      f"SUVmax={cls_ds.metadata[cls_ds.metadata.Label==0]['SUVmax'].mean():.2f}")
print(f"  EHR mean (malignant): age={cls_ds.metadata[cls_ds.metadata.Label==1]['Age'].mean():.1f}, "
      f"SUVmax={cls_ds.metadata[cls_ds.metadata.Label==1]['SUVmax'].mean():.2f}")
print()
print(f"Task B  T1+T2 MRI Regression")
print(f"  Samples : {len(reg_ds)}  |  Score: Beta(2,5)")
print(f"  Score stats: mean={np.mean(reg_ds.scores):.3f} ± {np.std(reg_ds.scores):.3f}  "
      f"[{np.min(reg_ds.scores):.3f}, {np.max(reg_ds.scores):.3f}]")
print()
print(f"Task C  FLAIR+T1ce Segmentation")
print(f"  Samples : {len(seg_ds)}")
print(f"  Lesion area: mean={np.mean(areas):.0f} ± {np.std(areas):.0f} px")
print("=" * 60)
<Figure size 1500x500 with 3 Axes>
============================================================
DATASET CARD
============================================================
Task A  CT + EHR Classification
  Samples : 1000  (50% benign / 50% malignant)
  CT size : 224×224 px  |  EHR features: 12
  EHR mean (benign)   : age=54.9, SUVmax=2.50
  EHR mean (malignant): age=65.0, SUVmax=6.42

Task B  T1+T2 MRI Regression
  Samples : 1000  |  Score: Beta(2,5)
  Score stats: mean=0.292 ± 0.157  [0.006, 0.844]

Task C  FLAIR+T1ce Segmentation
  Samples : 1000
  Lesion area: mean=676 ± 260 px
============================================================

TASK A — Multi-Modal Classification (Image + Tabular)

Clinical analogy: Fusing a CT scan with structured EHR data to predict pathology presence.

Architecture

Image Stream: Conv(1→16)→ReLU→MaxPool → Conv(16→32)→ReLU→MaxPool → AdaptiveAvgPool → Linear(512→64)→ReLU
Tabular Stream: Linear(12→32)→ReLU → Linear(32→32)→ReLU
Fusion: cat[64+32] → Linear(96→64)→ReLU → Linear(64→2)

Architecture:

Task A: FusionClassifier Architecture

Figure: Multi-modal classification architecture fusing CT images (CNN) and electronic health records (MLP). The diagram shows which XAI methods apply to each component: Integrated Gradients and Saliency for feature attribution, GradCAM for convolutional layers, and Occlusion for perturbation-based analysis at the fusion level.

class FusionClassifier(nn.Module):
    def __init__(self, n_tabular=12):
        super().__init__()
        # Image stream — store as sequential for LayerGradCAM access
        self.image_encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 224→112
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 112→56
        )
        self.img_head = nn.Sequential(
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 64), nn.ReLU(),
        )
        # Tabular stream
        self.tab_encoder = nn.Sequential(
            nn.Linear(n_tabular, 32), nn.ReLU(),
            nn.Linear(32, 32),        nn.ReLU(),
        )
        # Fusion head
        self.fusion = nn.Sequential(
            nn.Linear(64 + 32, 64), nn.ReLU(),
            nn.Linear(64, 2),
        )

    def forward(self, image, tabular):
        img_feat = self.img_head(self.image_encoder(image))
        tab_feat = self.tab_encoder(tabular)
        return self.fusion(torch.cat([img_feat, tab_feat], dim=1))


cls_model = FusionClassifier(n_tabular=len(EHR_FEATURE_NAMES)).to(DEVICE)
print("FusionClassifier:", sum(p.numel() for p in cls_model.parameters()), "parameters")
FusionClassifier: 45442 parameters
# [PRE-RUN] Training loop — Classification
from pathlib import Path

cls_model_path = Path("cls_model.pt")
if cls_model_path.exists():
    print(f"✓ Found {cls_model_path} — loading pre-trained weights (skipping training)")
    cls_model.load_state_dict(torch.load(cls_model_path, map_location=DEVICE))
else:
    print("Training Classification model (80 epochs)...")
    optimizer = torch.optim.Adam(cls_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=80)

    cls_model.train()
    for epoch in range(80):
        total_loss, correct, total = 0.0, 0, 0
        for img, tab, label in cls_loader:
            img, tab, label = img.to(DEVICE), tab.to(DEVICE), label.to(DEVICE)
            optimizer.zero_grad()
            logits = cls_model(img, tab)
            loss = criterion(logits, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(label)
            correct += (logits.argmax(1) == label).sum().item()
            total += len(label)
        scheduler.step()
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d} | loss={total_loss/total:.4f} | acc={correct/total:.1%} | lr={scheduler.get_last_lr()[0]:.5f}")

    torch.save(cls_model.state_dict(), "cls_model.pt")
    print("Saved cls_model.pt")
✓ Found cls_model.pt — loading pre-trained weights (skipping training)
# [LIVE] Integrated Gradients — Classification (Image + Tabular)
cls_model.eval()

# Grab one sample
img_xai, tab_xai, label_xai = next(iter(cls_xai_loader))
img_xai = img_xai.to(DEVICE).requires_grad_(True)
tab_xai = tab_xai.to(DEVICE).requires_grad_(True)

with torch.no_grad():
    pred_class = cls_model(img_xai, tab_xai).argmax(1).item()

ig = IntegratedGradients(cls_model)
attr_img_ig, attr_tab_ig = ig.attribute(
    inputs=(img_xai, tab_xai),
    baselines=(torch.zeros_like(img_xai), torch.zeros_like(tab_xai)),
    target=pred_class,
    n_steps=50,
    return_convergence_delta=False,
)

print(f"Predicted class: {'Malignant' if pred_class==1 else 'Benign'}")
print(f"True label:      {'Malignant' if label_xai.item()==1 else 'Benign'}")

# Spatial heatmap
plot_img_attr(img_xai, attr_img_ig, title=f"IG · CT Attribution (class={pred_class})")
plt.show()

# Tabular bar chart — mark the clinically most discriminative features with ★;
# PackYears, SUVmax, SpiculationIdx are expected to dominate
FEAT_NAMES = [
    f"{n}" + (" ★" if n in ("PackYears", "SUVmax", "SpiculationIdx") else "")
    for n in EHR_FEATURE_NAMES
]
plot_tabular_attr(attr_tab_ig, feature_names=FEAT_NAMES)
plt.show()
Predicted class: Benign
True label:      Benign
<Figure size 1400x400 with 5 Axes>
<Figure size 800x300 with 1 Axes>
# [LIVE] DeepLIFT — Classification
dl = DeepLift(cls_model)
attr_img_dl, attr_tab_dl = dl.attribute(
    inputs=(img_xai, tab_xai),
    baselines=(torch.zeros_like(img_xai), torch.zeros_like(tab_xai)),
    target=pred_class,
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(normalize_attr(attr_img_ig), cmap="hot"); axes[0].set_title("IG · Image"); axes[0].axis("off")
axes[1].imshow(normalize_attr(attr_img_dl), cmap="hot"); axes[1].set_title("DeepLIFT · Image"); axes[1].axis("off")
plt.suptitle("IG vs DeepLIFT — Spatial Attribution")
plt.tight_layout(); plt.show()

# Side-by-side tabular
fig, axes = plt.subplots(1, 2, figsize=(14, 3))
for ax, attr, name in zip(axes, [attr_tab_ig, attr_tab_dl], ["IG", "DeepLIFT"]):
    vals = attr.squeeze().detach().cpu().numpy()
    colors = ["#e74c3c" if v > 0 else "#3498db" for v in vals]
    ax.barh(FEAT_NAMES, vals, color=colors)
    ax.axvline(0, color="black", lw=0.8)
    ax.set_title(f"{name} · Tabular Attribution")
plt.tight_layout(); plt.show()
<Figure size 1000x400 with 2 Axes>
<Figure size 1400x300 with 2 Axes>
# [LIVE] GradientSHAP — Classification
gs = GradientShap(cls_model)
# Build stochastic baseline distribution (10 random samples per modality)
baseline_img_dist = torch.randn(10, *img_xai.shape[1:]).to(DEVICE)
baseline_tab_dist = torch.randn(10, *tab_xai.shape[1:]).to(DEVICE)

attr_img_gs, attr_tab_gs = gs.attribute(
    inputs=(img_xai, tab_xai),
    baselines=(baseline_img_dist, baseline_tab_dist),
    target=pred_class,
    n_samples=10,
    stdevs=0.1,
)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, attr, name in zip(axes, [attr_img_ig, attr_img_dl, attr_img_gs], ["IG", "DeepLIFT", "GradSHAP"]):
    ax.imshow(normalize_attr(attr), cmap="hot")
    ax.set_title(name); ax.axis("off")
plt.suptitle("Image Attribution Comparison (Classification)")
plt.tight_layout(); plt.show()
<Figure size 1500x400 with 3 Axes>
# [LIVE] LayerGradCAM — Image Stream Only
# GradCAM targets the last conv layer of the image encoder

# Wrapper that passes only image (tabular is fixed via closure)
class ImageOnlyWrapper(nn.Module):
    def __init__(self, model, tab):
        super().__init__()
        self.model = model
        self.tab = tab
    def forward(self, image):
        return self.model(image, self.tab)

img_only_model = ImageOnlyWrapper(cls_model, tab_xai).to(DEVICE)
target_layer = cls_model.image_encoder[3]  # last Conv2d layer of the image encoder

gcam = LayerGradCam(img_only_model, target_layer)
cam = gcam.attribute(img_xai, target=pred_class)
cam_upsampled = LayerAttribution.interpolate(cam, img_xai.shape[2:])

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, attr, name in zip(axes, [attr_img_ig, attr_img_dl, attr_img_gs, cam_upsampled],
                           ["IG", "DeepLIFT", "GradSHAP", "GradCAM"]):
    ax.imshow(img_xai.squeeze().detach().cpu().numpy(), cmap="gray")
    ax.imshow(normalize_attr(attr), cmap="hot", alpha=0.5)
    ax.set_title(name); ax.axis("off")
plt.suptitle("Four-Method Attribution Comparison — Classification (Overlay)")
plt.tight_layout(); plt.show()

print("Note: GradCAM only attributes the image stream. Pair with tabular IG for full picture.")
<Figure size 1600x400 with 4 Axes>
Note: GradCAM only attributes the image stream. Pair with tabular IG for full picture.
# [LIVE] Occlusion Sensitivity — Classification
cls_wrapper = MultiInputWrapper(cls_model)

occ = Occlusion(cls_wrapper)
attr_img_occ, attr_tab_occ = occ.attribute(
    inputs=(img_xai, tab_xai),
    sliding_window_shapes=((1, 8, 8), (1,)),
    target=pred_class,
    strides=((1, 4, 4), (1,)),
    baselines=(0.0, 0.0),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(img_xai.squeeze().detach().cpu().numpy(), cmap="gray")
axes[0].imshow(normalize_attr(attr_img_occ), cmap="hot", alpha=0.5)
axes[0].set_title("Occlusion · Image"); axes[0].axis("off")

vals_occ = attr_tab_occ.squeeze().detach().cpu().numpy()
colors = ["#e74c3c" if v > 0 else "#3498db" for v in vals_occ]
axes[1].barh(FEAT_NAMES, vals_occ, color=colors)
axes[1].axvline(0, color="black", lw=0.8)
axes[1].set_title("Occlusion · Tabular")
plt.suptitle("Occlusion Sensitivity (Perturbation Baseline)")
plt.tight_layout(); plt.show()
<Figure size 1000x400 with 2 Axes>
# [LIVE] Task A — XAI Gallery: Benign vs Malignant grid (IG attribution)
# Shows 4 benign and 4 malignant examples so class-specific patterns are visible.

cls_model.eval()
ig_gallery = IntegratedGradients(cls_model)

n_each = 4
benign_xai_idx  = [i for i in range(len(cls_xai_ds)) if int(cls_xai_ds.labels[i]) == 0][:n_each]
malig_xai_idx   = [i for i in range(len(cls_xai_ds)) if int(cls_xai_ds.labels[i]) == 1][:n_each]
all_gallery_idx = benign_xai_idx + malig_xai_idx
row_labels      = ["Benign"] * n_each + ["Malignant"] * n_each
border_colors   = ["#3498db"] * n_each + ["#e74c3c"] * n_each

fig, axes = plt.subplots(len(all_gallery_idx), 3, figsize=(13, 3 * len(all_gallery_idx)))
col_titles = ["CT Input", "IG Attribution", "Overlay (IG)"]
for col, ttl in enumerate(col_titles):
    axes[0, col].set_title(ttl, fontsize=11, fontweight="bold")

for row, (idx, lname, bcolor) in enumerate(zip(all_gallery_idx, row_labels, border_colors)):
    img_g, tab_g, lbl_g = cls_xai_ds[idx]
    img_g = img_g.unsqueeze(0).to(DEVICE).requires_grad_(True)
    tab_g = tab_g.unsqueeze(0).to(DEVICE).requires_grad_(True)

    with torch.no_grad():
        pred_g = cls_model(img_g, tab_g).argmax(1).item()

    attr_img_g, _ = ig_gallery.attribute(
        inputs=(img_g, tab_g),
        baselines=(torch.zeros_like(img_g), torch.zeros_like(tab_g)),
        target=pred_g, n_steps=50, return_convergence_delta=False,
    )
    img_np  = img_g.squeeze().detach().cpu().numpy()
    attr_np = normalize_attr(attr_img_g)
    pred_lbl = "Malignant" if pred_g == 1 else "Benign"
    true_lbl = "Malignant" if lbl_g.item() == 1 else "Benign"

    axes[row, 0].imshow(img_np, cmap="gray")
    # Hide ticks but keep the axis on so the row label and colored spines render
    axes[row, 0].set_xticks([]); axes[row, 0].set_yticks([])
    axes[row, 0].set_ylabel(f"{lname}\n(true={true_lbl}, pred={pred_lbl})",
                             fontsize=8, rotation=0, labelpad=90, va="center",
                             color=bcolor, fontweight="bold")
    for spine in axes[row, 0].spines.values():
        spine.set_edgecolor(bcolor); spine.set_linewidth(2)

    im_attr = axes[row, 1].imshow(attr_np, cmap="hot", vmin=0, vmax=1)
    axes[row, 1].axis("off")
    fig.colorbar(im_attr, ax=axes[row, 1], fraction=0.046, pad=0.04).set_label("Importance", fontsize=7)

    axes[row, 2].imshow(img_np, cmap="gray")
    im_ov = axes[row, 2].imshow(attr_np, cmap="hot", alpha=0.55, vmin=0, vmax=1)
    axes[row, 2].axis("off")
    fig.colorbar(im_ov, ax=axes[row, 2], fraction=0.046, pad=0.04).set_label("Importance", fontsize=7)

fig.suptitle(
    "Task A · IG Attribution Gallery — Top 4 Benign (blue) / Bottom 4 Malignant (red)\n"
    "Expected: malignant samples show nodule-region focus; benign samples show diffuse or absent signal",
    fontsize=12, y=1.01,
)
plt.tight_layout()
plt.show()
<Figure size 1300x2400 with 40 Axes>

TASK B — Multi-Modal Regression (Image + Image)

Clinical analogy: Combining T1 and T2 MRI sequences to predict a continuous biomarker (e.g., tumor volume, EDSS score).

  • T1 top-left quadrant signal scales with score → IG should highlight top-left in T1

  • T2 bottom-right quadrant signal scales with score → IG should highlight bottom-right in T2

For regression, pass target=0 so Captum selects the model's single scalar output neuron.

Architecture:

Task B: FusionRegressor Architecture

Figure: Dual-stream regression architecture fusing T1 and T2 MRI sequences to predict disease severity scores. Both encoders share identical CNN structure. The diagram annotates applicable XAI methods: Integrated Gradients and DeepLIFT for gradient-based attribution, and Feature Ablation for modality importance quantification.

class FusionRegressor(nn.Module):
    def __init__(self):
        super().__init__()
        def _stream():
            return nn.Sequential(
                nn.Conv2d(1, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
                nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
                nn.Linear(32, 64), nn.ReLU(),
            )

        self.t1_stream = _stream()
        self.t2_stream = _stream()
        self.fusion = nn.Sequential(
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, t1, t2):
        return self.fusion(torch.cat([self.t1_stream(t1), self.t2_stream(t2)], dim=1))


reg_model = FusionRegressor().to(DEVICE)
print("FusionRegressor:", sum(p.numel() for p in reg_model.parameters()), "parameters")
FusionRegressor: 22337 parameters
# [PRE-RUN] Training loop — Regression
reg_model_path = Path("reg_model.pt")
if reg_model_path.exists():
    print(f"✓ Found {reg_model_path} — loading pre-trained weights (skipping training)")
    reg_model.load_state_dict(torch.load(reg_model_path, map_location=DEVICE))
else:
    print("Training Regression model (100 epochs)...")
    reg_optimizer = torch.optim.Adam(reg_model.parameters(), lr=1e-3)
    reg_criterion = nn.MSELoss()
    reg_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(reg_optimizer, T_max=100)

    reg_model.train()
    for epoch in range(100):
        total_loss, total = 0.0, 0
        for t1, t2, score in reg_loader:
            t1, t2, score = t1.to(DEVICE), t2.to(DEVICE), score.to(DEVICE)
            reg_optimizer.zero_grad()
            pred = reg_model(t1, t2)
            loss = reg_criterion(pred, score)
            loss.backward()
            reg_optimizer.step()
            total_loss += loss.item() * len(score)
            total += len(score)
        reg_scheduler.step()
        if (epoch + 1) % 20 == 0:
            rmse = (total_loss / total) ** 0.5
            print(f"Epoch {epoch+1:3d} | RMSE={rmse:.4f} | lr={reg_scheduler.get_last_lr()[0]:.5f}")

    torch.save(reg_model.state_dict(), "reg_model.pt")
    print("Saved reg_model.pt")
✓ Found reg_model.pt — loading pre-trained weights (skipping training)
# [LIVE] IG + Modality Importance — Regression
reg_model.eval()

t1_xai, t2_xai, score_xai = next(iter(reg_xai_loader))
t1_xai = t1_xai.to(DEVICE).requires_grad_(True)
t2_xai = t2_xai.to(DEVICE).requires_grad_(True)

ig_reg = IntegratedGradients(reg_model)
attr_t1, attr_t2 = ig_reg.attribute(
    inputs=(t1_xai, t2_xai),
    baselines=(torch.zeros_like(t1_xai), torch.zeros_like(t2_xai)),
    target=0,   # single scalar output → neuron index 0
    n_steps=50,
)

# Verify spatial regions
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
axes[0, 0].imshow(t1_xai.squeeze().detach().cpu().numpy(), cmap="gray")
axes[0, 0].set_title("T1 input"); axes[0, 0].axis("off")
axes[0, 1].imshow(normalize_attr(attr_t1), cmap="hot")
axes[0, 1].set_title("IG · T1 (expected: top-left)"); axes[0, 1].axis("off")
axes[0, 2].imshow(t1_xai.squeeze().detach().cpu().numpy(), cmap="gray")
axes[0, 2].imshow(normalize_attr(attr_t1), cmap="hot", alpha=0.5)
axes[0, 2].set_title("T1 overlay"); axes[0, 2].axis("off")

axes[1, 0].imshow(t2_xai.squeeze().detach().cpu().numpy(), cmap="gray")
axes[1, 0].set_title("T2 input"); axes[1, 0].axis("off")
axes[1, 1].imshow(normalize_attr(attr_t2), cmap="hot")
axes[1, 1].set_title("IG · T2 (expected: bottom-right)"); axes[1, 1].axis("off")
axes[1, 2].imshow(t2_xai.squeeze().detach().cpu().numpy(), cmap="gray")
axes[1, 2].imshow(normalize_attr(attr_t2), cmap="hot", alpha=0.5)
axes[1, 2].set_title("T2 overlay"); axes[1, 2].axis("off")

plt.suptitle(f"IG Regression (score={score_xai.item():.3f})")
plt.tight_layout(); plt.show()

# Modality importance
imp_t1 = attr_t1.abs().mean().item()
imp_t2 = attr_t2.abs().mean().item()
total_imp = imp_t1 + imp_t2
print(f"T1 contribution: {imp_t1/total_imp:.1%}  |  T2 contribution: {imp_t2/total_imp:.1%}")
plot_modality_importance({"T1 MRI": imp_t1, "T2 MRI": imp_t2})
plt.show()
<Figure size 1400x800 with 6 Axes>
T1 contribution: 21.4%  |  T2 contribution: 78.6%
<Figure size 500x500 with 1 Axes>
# [LIVE] DeepLIFT & Saliency comparison — Regression
dl_reg  = DeepLift(reg_model)
sal_reg = Saliency(reg_model)

attr_t1_dl, attr_t2_dl = dl_reg.attribute(
    inputs=(t1_xai, t2_xai),
    baselines=(torch.zeros_like(t1_xai), torch.zeros_like(t2_xai)),
    target=0,
)
attr_t1_sal, attr_t2_sal = sal_reg.attribute(
    inputs=(t1_xai, t2_xai),
    target=0,
    abs=False,
)

methods = ["IG", "DeepLIFT", "Saliency"]
attrs_t1 = [attr_t1, attr_t1_dl, attr_t1_sal]
attrs_t2 = [attr_t2, attr_t2_dl, attr_t2_sal]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
for col, (name, a1, a2) in enumerate(zip(methods, attrs_t1, attrs_t2)):
    axes[0, col].imshow(normalize_attr(a1), cmap="hot")
    axes[0, col].set_title(f"{name} · T1"); axes[0, col].axis("off")
    axes[1, col].imshow(normalize_attr(a2), cmap="hot")
    axes[1, col].set_title(f"{name} · T2"); axes[1, col].axis("off")
plt.suptitle("IG / DeepLIFT / Saliency — T1 (top) vs T2 (bottom)")
plt.tight_layout(); plt.show()
<Figure size 1400x800 with 6 Axes>
# [LIVE] Feature Ablation — Modality-Level Importance
reg_wrapper = MultiInputWrapper(reg_model)
fa = FeatureAblation(reg_wrapper)

# Assign all T1 pixels to group 1, all T2 pixels to group 2
mask_t1 = torch.ones_like(t1_xai, dtype=torch.long)          # group 1
mask_t2 = torch.ones_like(t2_xai, dtype=torch.long) * 2      # group 2

fa_attr_t1, fa_attr_t2 = fa.attribute(
    inputs=(t1_xai, t2_xai),
    feature_mask=(mask_t1, mask_t2),
    baselines=(torch.zeros_like(t1_xai), torch.zeros_like(t2_xai)),
    target=0,
)

with torch.no_grad():
    full_pred  = reg_model(t1_xai, t2_xai).item()
    zero_t1    = reg_model(torch.zeros_like(t1_xai), t2_xai).item()
    zero_t2    = reg_model(t1_xai, torch.zeros_like(t2_xai)).item()

print(f"Full prediction:    {full_pred:.4f}")
print(f"Ablate T1 → zero:  {zero_t1:.4f}  (drop: {full_pred - zero_t1:+.4f})")
print(f"Ablate T2 → zero:  {zero_t2:.4f}  (drop: {full_pred - zero_t2:+.4f})")

# Modality importance via ablation
drop_t1 = abs(full_pred - zero_t1)
drop_t2 = abs(full_pred - zero_t2)
plot_modality_importance({"T1 (ablated)": drop_t1, "T2 (ablated)": drop_t2},
                         title="Modality Importance via Feature Ablation")
plt.show()
Full prediction:    0.5219
Ablate T1 → zero:  0.5134  (drop: +0.0085)
Ablate T2 → zero:  -0.1432  (drop: +0.6651)
<Figure size 500x500 with 1 Axes>
# [LIVE] Task B — XAI Gallery: IG across score spectrum
# Pick 6 samples spanning low→high EDSS-like scores and show T1/T2 IG attributions.

reg_model.eval()
ig_reg_gallery = IntegratedGradients(reg_model)

# Sort XAI subset by predicted score and pick 6 evenly spaced samples
reg_xai_scores_list = []
for i in range(len(reg_xai_ds)):
    t1_tmp, t2_tmp, sc_tmp = reg_xai_ds[i]
    reg_xai_scores_list.append((i, sc_tmp.item()))

reg_xai_scores_list.sort(key=lambda x: x[1])
pick_pos    = np.linspace(0, len(reg_xai_scores_list) - 1, 6, dtype=int)
gallery_reg = [reg_xai_scores_list[p] for p in pick_pos]

fig, axes = plt.subplots(6, 6, figsize=(22, 22))
col_titles = ["T1 Input", "T1 IG", "T1 Overlay", "T2 Input", "T2 IG", "T2 Overlay"]
for col, ttl in enumerate(col_titles):
    axes[0, col].set_title(ttl, fontsize=10, fontweight="bold")

for row, (idx, score_val) in enumerate(gallery_reg):
    t1_g, t2_g, sc_g = reg_xai_ds[idx]
    t1_g = t1_g.unsqueeze(0).to(DEVICE).requires_grad_(True)
    t2_g = t2_g.unsqueeze(0).to(DEVICE).requires_grad_(True)

    attr_t1_g, attr_t2_g = ig_reg_gallery.attribute(
        inputs=(t1_g, t2_g),
        baselines=(torch.zeros_like(t1_g), torch.zeros_like(t2_g)),
        target=0, n_steps=50,
    )

    t1_np   = t1_g.squeeze().detach().cpu().numpy()
    t2_np   = t2_g.squeeze().detach().cpu().numpy()
    attr_t1_np = normalize_attr(attr_t1_g)
    attr_t2_np = normalize_attr(attr_t2_g)

    # T1 columns (hide ticks but keep the axis on so the row label renders)
    axes[row, 0].imshow(t1_np, cmap="gray")
    axes[row, 0].set_xticks([]); axes[row, 0].set_yticks([])
    axes[row, 0].set_ylabel(f"score={score_val:.3f}", fontsize=9, rotation=0, labelpad=68, va="center")
    im1 = axes[row, 1].imshow(attr_t1_np, cmap="hot", vmin=0, vmax=1); axes[row, 1].axis("off")
    fig.colorbar(im1, ax=axes[row, 1], fraction=0.046, pad=0.04).set_label("Imp.", fontsize=6)
    axes[row, 2].imshow(t1_np, cmap="gray")
    im12 = axes[row, 2].imshow(attr_t1_np, cmap="hot", alpha=0.55, vmin=0, vmax=1); axes[row, 2].axis("off")
    fig.colorbar(im12, ax=axes[row, 2], fraction=0.046, pad=0.04).set_label("Imp.", fontsize=6)

    # T2 columns
    axes[row, 3].imshow(t2_np, cmap="gray"); axes[row, 3].axis("off")
    im2 = axes[row, 4].imshow(attr_t2_np, cmap="hot", vmin=0, vmax=1); axes[row, 4].axis("off")
    fig.colorbar(im2, ax=axes[row, 4], fraction=0.046, pad=0.04).set_label("Imp.", fontsize=6)
    axes[row, 5].imshow(t2_np, cmap="gray")
    im22 = axes[row, 5].imshow(attr_t2_np, cmap="hot", alpha=0.55, vmin=0, vmax=1); axes[row, 5].axis("off")
    fig.colorbar(im22, ax=axes[row, 5], fraction=0.046, pad=0.04).set_label("Imp.", fontsize=6)

fig.suptitle(
    "Task B · IG Gallery — T1 (cols 1–3) and T2 (cols 4–6) across EDSS-like score range\n"
    "Expected: low scores → sparse attribution; high scores → T1 top-left + T2 bottom-right focus",
    fontsize=12, y=1.01,
)
plt.tight_layout()
plt.show()
<Figure size 2200x2200 with 60 Axes>

TASK C — Multi-Modal Segmentation (Image + Image → Pixel Mask)

Clinical analogy: Using FLAIR + T1ce sequences to segment brain tumor lesions (BraTS-inspired).

Key XAI challenge for segmentation: Captum needs a scalar output. We use SegXAIWrapper to collapse (B, 1, H, W) → (B, 1) via a spatial mean, enabling gradient flow through a single target neuron.

Expected result: IG attribution for FLAIR should concentrate on the bright elliptical lesion region; T1ce attribution should show the same lesion at lower intensity.

Architecture:

Task C: FusionUNet Architecture

Figure: Dual-encoder U-Net architecture for multi-modal segmentation, fusing FLAIR and T1ce MRI sequences. Features from both encoders merge at the bottleneck, then decode with skip connections from the FLAIR pathway. The diagram shows applicable XAI methods: GradCAM for branch-specific feature visualization, and Integrated Gradients (via SegXAIWrapper) for pixel-level attribution.

class FusionUNet(nn.Module):
    """Lightweight dual-encoder U-Net for two-modality segmentation."""

    def __init__(self):
        super().__init__()
        # FLAIR encoder
        self.flair_enc1 = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU())
        self.flair_pool1 = nn.MaxPool2d(2)
        self.flair_enc2 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.flair_pool2 = nn.MaxPool2d(2)
        self.flair_last_conv = self.flair_enc2  # reference for GradCAM

        # T1ce encoder (identical structure)
        self.t1ce_enc1 = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU())
        self.t1ce_pool1 = nn.MaxPool2d(2)
        self.t1ce_enc2 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.t1ce_pool2 = nn.MaxPool2d(2)
        self.t1ce_last_conv = self.t1ce_enc2  # reference for GradCAM

        # Bottleneck fusion
        self.bottleneck = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())

        # Decoder
        self.up1 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(32, 16, 3, padding=1), nn.ReLU())  # skip from flair_enc1
        self.up2 = nn.ConvTranspose2d(16, 8, 2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())
        self.out_conv = nn.Conv2d(8, 1, 1)   # logits (no sigmoid — use BCEWithLogitsLoss)

    def forward(self, flair, t1ce):
        # FLAIR branch
        f1 = self.flair_enc1(flair)
        f2 = self.flair_enc2(self.flair_pool1(f1))
        fb = self.flair_pool2(f2)

        # T1ce branch
        t1 = self.t1ce_enc1(t1ce)
        t2 = self.t1ce_enc2(self.t1ce_pool1(t1))
        tb = self.t1ce_pool2(t2)

        # Fusion at bottleneck
        x = self.bottleneck(torch.cat([fb, tb], dim=1))

        # Decoder with skip from flair encoder level 1
        x = self.up1(x)
        x = self.dec1(torch.cat([x, f2], dim=1))
        x = self.up2(x)
        x = self.dec2(x)
        return self.out_conv(x)   # (B, 1, H, W) logits

seg_model = FusionUNet().to(DEVICE)

seg_model_path = Path("seg_model.pt")
if seg_model_path.exists():
    print(f"✓ Found {seg_model_path} — loading pre-trained weights (skipping training)")
    seg_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
else:
    print("Training Segmentation model (80 epochs)...")
    seg_optimizer = torch.optim.Adam(seg_model.parameters(), lr=1e-3)
    seg_criterion = nn.BCEWithLogitsLoss()
    seg_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(seg_optimizer, T_max=80)

    seg_model.train()
    for epoch in range(80):
        total_loss, total = 0.0, 0
        for flair_b, t1ce_b, mask_b in seg_loader:
            flair_b, t1ce_b, mask_b = flair_b.to(DEVICE), t1ce_b.to(DEVICE), mask_b.to(DEVICE)
            seg_optimizer.zero_grad()
            logits = seg_model(flair_b, t1ce_b)
            loss = seg_criterion(logits, mask_b)
            loss.backward()
            seg_optimizer.step()
            total_loss += loss.item() * len(mask_b)
            total += len(mask_b)
        seg_scheduler.step()
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d} | loss={total_loss/total:.4f} | lr={seg_scheduler.get_last_lr()[0]:.5f}")

    # Dice coefficient on training set
    seg_model.eval()
    dice_scores = []
    with torch.no_grad():
        for flair_b, t1ce_b, mask_b in seg_loader:
            flair_b, t1ce_b, mask_b = flair_b.to(DEVICE), t1ce_b.to(DEVICE), mask_b.to(DEVICE)
            pred = torch.sigmoid(seg_model(flair_b, t1ce_b)) > 0.5
            intersection = (pred * mask_b.bool()).sum().float()
            dice = 2 * intersection / (pred.sum() + mask_b.sum() + 1e-8)
            dice_scores.append(dice.item())
    print(f"Mean Dice: {np.mean(dice_scores):.4f}")

    torch.save(seg_model.state_dict(), "seg_model.pt")
    print("Saved seg_model.pt")
✓ Found seg_model.pt — loading pre-trained weights (skipping training)
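The training loop above uses BCE alone; segmentation pipelines often add a soft Dice term to counter foreground/background class imbalance. A minimal sketch of such a loss — `SoftDiceLoss` is illustrative here, not part of the notebook's model:

```python
import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    """1 - soft Dice, computed on sigmoid probabilities, averaged over the batch."""
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        probs = torch.sigmoid(logits)
        # Per-sample intersection and denominator over (C, H, W)
        inter = (probs * target).sum(dim=(1, 2, 3))
        denom = probs.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
        return (1 - (2 * inter + self.eps) / (denom + self.eps)).mean()

logits = torch.randn(2, 1, 16, 16)
target = (torch.rand(2, 1, 16, 16) > 0.5).float()
loss = SoftDiceLoss()(logits, target)
print(float(loss))  # a value in [0, 1)
```

A common recipe is `total_loss = bce + dice`, weighting the two terms equally.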
# SegXAIWrapper & lesion centroid helper
class SegXAIWrapper(nn.Module):
    """Reduces (B,1,H,W) segmentation output to (B,1) scalar for Captum."""
    def __init__(self, model, reduction="mean"):
        super().__init__()
        self.model = model
        self.reduction = reduction

    def forward(self, flair, t1ce):
        out = torch.sigmoid(self.model(flair, t1ce))   # (B, 1, H, W)
        if self.reduction == "mean":
            return out.mean(dim=(2, 3))                # (B, 1)
        return out.sum(dim=(2, 3))


def get_lesion_centroid(pred_mask):
    """Return (row, col) centroid of predicted lesion, or None if empty."""
    coords = pred_mask.squeeze().nonzero(as_tuple=False).float()
    if len(coords) == 0:
        return None
    return coords.mean(0).long().tolist()


seg_xai_model = SegXAIWrapper(seg_model).to(DEVICE)
print("SegXAIWrapper and centroid helper ready.")
SegXAIWrapper and centroid helper ready.
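The `mean` reduction weights every pixel equally, which can dilute the attribution signal for small lesions. A hedged variant — `MaskedSegWrapper` is a name introduced here, not a Captum class — restricts the reduction to a region of interest such as the predicted lesion:

```python
import torch
import torch.nn as nn

class MaskedSegWrapper(nn.Module):
    """Reduce (B,1,H,W) logits to (B,1) by averaging sigmoid probs inside a fixed ROI mask."""
    def __init__(self, model, roi_mask):
        super().__init__()
        self.model = model
        self.register_buffer("roi", roi_mask.float())  # (1,1,H,W), 1 inside ROI

    def forward(self, flair, t1ce):
        probs = torch.sigmoid(self.model(flair, t1ce))  # (B,1,H,W)
        # Mean probability over ROI pixels only — gradients flow only through the ROI
        return (probs * self.roi).sum(dim=(2, 3)) / self.roi.sum().clamp(min=1)

# Tiny stand-in model for a shape check (the real seg_model has the same call interface)
class ToySeg(nn.Module):
    def forward(self, a, b):
        return a + b  # (B,1,H,W) "logits"

roi = torch.zeros(1, 1, 8, 8)
roi[:, :, 2:5, 2:5] = 1
wrapped = MaskedSegWrapper(ToySeg(), roi)
out = wrapped(torch.randn(2, 1, 8, 8), torch.randn(2, 1, 8, 8))
print(out.shape)  # torch.Size([2, 1])
```

With the ROI set to `pred_mask`, attributions answer "what drove the prediction *inside* the lesion" rather than averaging over mostly-background pixels.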
# [LIVE] IG — Segmentation (spatial attribution per modality)
seg_model.eval()

flair_xai, t1ce_xai, mask_xai = next(iter(seg_xai_loader))
flair_xai = flair_xai.to(DEVICE).requires_grad_(True)
t1ce_xai  = t1ce_xai.to(DEVICE).requires_grad_(True)

ig_seg = IntegratedGradients(seg_xai_model)
attr_flair, attr_t1ce = ig_seg.attribute(
    inputs=(flair_xai, t1ce_xai),
    baselines=(torch.zeros_like(flair_xai), torch.zeros_like(t1ce_xai)),
    target=0,
    n_steps=50,
)

# Predicted mask
with torch.no_grad():
    pred_mask = (torch.sigmoid(seg_model(flair_xai, t1ce_xai)) > 0.5).float()

centroid = get_lesion_centroid(pred_mask)
print(f"Predicted lesion centroid: {centroid}")
# Note: argmax on the 2-D attribution map returns a *flat* index (row * W + col)
print(f"FLAIR attr max region: {normalize_attr(attr_flair).argmax()}")

# Modality importance for segmentation
imp_flair = attr_flair.abs().mean().item()
imp_t1ce  = attr_t1ce.abs().mean().item()
total_seg = imp_flair + imp_t1ce
print(f"FLAIR contribution: {imp_flair/total_seg:.1%}  |  T1ce contribution: {imp_t1ce/total_seg:.1%}")
Predicted lesion centroid: [121, 81]
FLAIR attr max region: 27629
FLAIR contribution: 25.7%  |  T1ce contribution: 74.3%
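The argmax printed above is a flat index into the H×W attribution map; `np.unravel_index` converts it back to a (row, col) pair. An illustrative example on a synthetic 128×128 map (the notebook's actual image size may differ):

```python
import numpy as np

attr = np.zeros((128, 128), dtype=np.float32)
attr[121, 81] = 1.0                       # pretend the peak sits at the lesion centroid
flat_idx = attr.argmax()                  # flat index: 121 * 128 + 81 = 15569
row, col = np.unravel_index(flat_idx, attr.shape)
print(flat_idx, (row, col))               # 15569 (121, 81)
```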
# [LIVE] LayerGradCAM — Dual Encoder (FLAIR + T1ce branches)
#
# seg_model outputs (B, 1, H, W) — LayerGradCam needs a scalar target.
# We pass seg_xai_model (SegXAIWrapper → output=(B,1)) and target=0.

# FLAIR encoder GradCAM — feed both modalities in the model's (flair, t1ce) order
gcam_flair = LayerGradCam(seg_xai_model, seg_model.flair_last_conv)
cam_flair = gcam_flair.attribute((flair_xai, t1ce_xai), target=0)
cam_flair_up = LayerAttribution.interpolate(cam_flair, flair_xai.shape[2:])

# T1ce encoder GradCAM — same input order; only the hooked layer changes.
# (Passing t1ce_xai first with flair_xai as additional_forward_args would swap
#  the modalities fed to forward(flair, t1ce).)
gcam_t1ce = LayerGradCam(seg_xai_model, seg_model.t1ce_last_conv)
cam_t1ce = gcam_t1ce.attribute((flair_xai, t1ce_xai), target=0)
cam_t1ce_up = LayerAttribution.interpolate(cam_t1ce, t1ce_xai.shape[2:])

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
ims = [flair_xai, t1ce_xai, cam_flair_up, cam_t1ce_up]
titles = ["FLAIR input", "T1ce input", "GradCAM · FLAIR enc", "GradCAM · T1ce enc"]
for ax, im, title in zip(axes, ims, titles):
    base = im.squeeze().detach().cpu().abs().numpy()
    # Show raw inputs in gray, CAM heatmaps in "hot"
    cmap = "gray" if "input" in title else "hot"
    ax.imshow(base, cmap=cmap); ax.set_title(title); ax.axis("off")
plt.suptitle("LayerGradCAM — Dual Encoder Attribution")
plt.tight_layout(); plt.show()
<Figure size 1600x400 with 4 Axes>
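`LayerAttribution.interpolate` upsamples the coarse layer-resolution CAM to input resolution; to my understanding it wraps `torch.nn.functional.interpolate` with `mode="nearest"` by default, so plain nearest-neighbour upsampling reproduces the effect. A tiny standalone illustration:

```python
import torch
import torch.nn.functional as F

cam = torch.tensor([[0., 1.], [2., 3.]]).reshape(1, 1, 2, 2)  # coarse 2×2 CAM
up = F.interpolate(cam, size=(4, 4), mode="nearest")          # each value repeated 2×2
print(up.squeeze())
```

For smoother overlays, `interpolate_mode="bilinear"` can be passed to `LayerAttribution.interpolate` instead.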
# [LIVE] Positive/Negative Attribution Decomposition

pos_attr_flair = torch.clamp(attr_flair, min=0)
neg_attr_flair = torch.clamp(attr_flair, max=0).abs()

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(flair_xai.squeeze().detach().cpu().numpy(), cmap="gray")
axes[0].set_title("FLAIR input"); axes[0].axis("off")

axes[1].imshow(normalize_attr(pos_attr_flair), cmap="Reds")
axes[1].set_title("Positive attr (FOR lesion)"); axes[1].axis("off")

axes[2].imshow(normalize_attr(neg_attr_flair), cmap="Blues")
axes[2].set_title("Negative attr (AGAINST lesion)"); axes[2].axis("off")

# Combined: RdBu diverging
combined = attr_flair.squeeze().detach().cpu().numpy()
vmax = np.abs(combined).max()
axes[3].imshow(combined, cmap="RdBu", vmin=-vmax, vmax=vmax)
axes[3].set_title("Combined (RdBu)"); axes[3].axis("off")

plt.suptitle("Positive vs Negative Attribution Decomposition — FLAIR")
plt.tight_layout(); plt.show()
<Figure size 1600x400 with 4 Axes>
# [LIVE] Cross-Modal Consistency — Dice Overlap (XAI mask vs GT mask)

attr_norm = normalize_attr(attr_flair)
# Threshold at 50% of max
xai_mask = (attr_norm > 0.5 * attr_norm.max()).astype(np.float32)
gt_mask  = mask_xai.squeeze().numpy()

intersection = (xai_mask * gt_mask).sum()
dice_xai = 2 * intersection / (xai_mask.sum() + gt_mask.sum() + 1e-8)
print(f"Dice (IG attribution mask vs GT lesion mask): {dice_xai:.4f}")
print("Interpretation: higher Dice → attribution map clinically plausible")

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(gt_mask, cmap="Reds");  axes[0].set_title("GT Lesion Mask"); axes[0].axis("off")
axes[1].imshow(xai_mask, cmap="Blues"); axes[1].set_title("IG Attribution Mask (threshold 0.5)"); axes[1].axis("off")
overlap = np.stack([gt_mask, xai_mask, np.zeros_like(gt_mask)], axis=-1)
axes[2].imshow(overlap); axes[2].set_title(f"Overlap (Dice={dice_xai:.3f})\nRed=GT  Green=XAI"); axes[2].axis("off")
plt.tight_layout(); plt.show()
Dice (IG attribution mask vs GT lesion mask): 0.0856
Interpretation: higher Dice → attribution map clinically plausible
<Figure size 1200x400 with 3 Axes>
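The 0.5 × max cutoff used above is arbitrary, and the Dice overlap can vary substantially with it. A small self-contained sweep on synthetic masks (toy data, not the notebook's) shows how to probe that sensitivity:

```python
import numpy as np

def dice(a, b, eps=1e-8):
    """Dice coefficient between two binary masks."""
    return 2 * (a * b).sum() / (a.sum() + b.sum() + eps)

rng = np.random.default_rng(0)
gt = np.zeros((64, 64))
gt[20:40, 20:40] = 1                                   # synthetic lesion
# Synthetic attribution: strong inside the lesion, weak noise outside
attr = gt * rng.uniform(0.4, 1.0, gt.shape) + 0.2 * rng.uniform(size=gt.shape)
attr /= attr.max()

for t in (0.3, 0.5, 0.7):
    m = (attr > t).astype(float)
    print(f"threshold {t:.1f} → Dice {dice(m, gt):.3f}")
```

Reporting Dice at several thresholds (or the best-threshold Dice) gives a fairer picture than a single fixed cutoff.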
# Full segmentation visualization panel
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Row 1 — inputs + masks
inputs_row = [
    (flair_xai.squeeze().detach().cpu().numpy(), "gray", "FLAIR"),
    (t1ce_xai.squeeze().detach().cpu().numpy(),  "gray", "T1ce"),
    (mask_xai.squeeze().numpy(),                  "Reds", "GT Mask"),
    (pred_mask.squeeze().detach().cpu().numpy(),  "Reds", "Predicted Mask"),
]
for ax, (im, cmap, title) in zip(axes[0], inputs_row):
    ax.imshow(im, cmap=cmap); ax.set_title(title); ax.axis("off")

# Row 2 — attribution maps
attr_row = [
    (flair_xai, attr_flair, "IG · FLAIR"),
    (t1ce_xai,  attr_t1ce,  "IG · T1ce"),
    (flair_xai, cam_flair_up, "GradCAM · FLAIR enc"),
    (t1ce_xai,  cam_t1ce_up,  "GradCAM · T1ce enc"),
]
for ax, (base_img, attr, title) in zip(axes[1], attr_row):
    ax.imshow(base_img.squeeze().detach().cpu().numpy(), cmap="gray")
    ax.imshow(normalize_attr(attr), cmap="hot", alpha=0.5)
    ax.set_title(title); ax.axis("off")

plt.suptitle("Segmentation XAI Panel — Row 1: Inputs | Row 2: Attributions", fontsize=13)
plt.tight_layout(); plt.show()
<Figure size 1600x800 with 8 Axes>
# [LIVE] Task C — XAI Gallery: IG across 6 segmentation cases
# Columns: FLAIR | FLAIR IG | T1ce | T1ce IG | GT Mask | Pred Mask

seg_model.eval()
ig_seg_gallery = IntegratedGradients(seg_xai_model)

n_gallery_seg = 6

fig, axes = plt.subplots(n_gallery_seg, 6, figsize=(22, 4 * n_gallery_seg))
col_titles = ["FLAIR", "FLAIR IG", "T1ce", "T1ce IG", "GT Mask", "Pred Mask"]
for col, ttl in enumerate(col_titles):
    axes[0, col].set_title(ttl, fontsize=10, fontweight="bold")

for row in range(n_gallery_seg):
    fl_g, t1_g, msk_g = seg_xai_ds[row]
    fl_g  = fl_g.unsqueeze(0).to(DEVICE).requires_grad_(True)
    t1_g  = t1_g.unsqueeze(0).to(DEVICE).requires_grad_(True)

    attr_fl_g, attr_t1_g = ig_seg_gallery.attribute(
        inputs=(fl_g, t1_g),
        baselines=(torch.zeros_like(fl_g), torch.zeros_like(t1_g)),
        target=0, n_steps=50,
    )
    with torch.no_grad():
        pred_msk_g = (torch.sigmoid(seg_model(fl_g, t1_g)) > 0.5).float()

    fl_np   = fl_g.squeeze().detach().cpu().numpy()
    t1_np   = t1_g.squeeze().detach().cpu().numpy()
    attr_fl = normalize_attr(attr_fl_g)
    attr_t1 = normalize_attr(attr_t1_g)
    msk_np  = msk_g.squeeze().numpy()
    pred_np = pred_msk_g.squeeze().detach().cpu().numpy()

    # Dice for this sample
    inter = (attr_fl > 0.5).astype(float) * msk_np
    dice_g = 2 * inter.sum() / ((attr_fl > 0.5).sum() + msk_np.sum() + 1e-8)

    axes[row, 0].imshow(fl_np, cmap="gray"); axes[row, 0].axis("off")
    # axis("off") hides ylabels, so annotate the row with a text label instead
    axes[row, 0].text(-0.15, 0.5, f"Case {row+1}\nDice={dice_g:.2f}", fontsize=8,
                      transform=axes[row, 0].transAxes, va="center", ha="right")

    im_fl = axes[row, 1].imshow(attr_fl, cmap="hot", vmin=0, vmax=1); axes[row, 1].axis("off")
    fig.colorbar(im_fl, ax=axes[row, 1], fraction=0.046, pad=0.04).set_label("Imp.", fontsize=6)

    axes[row, 2].imshow(t1_np, cmap="gray"); axes[row, 2].axis("off")

    im_t1 = axes[row, 3].imshow(attr_t1, cmap="hot", vmin=0, vmax=1); axes[row, 3].axis("off")
    fig.colorbar(im_t1, ax=axes[row, 3], fraction=0.046, pad=0.04).set_label("Imp.", fontsize=6)

    axes[row, 4].imshow(msk_np, cmap="Reds"); axes[row, 4].axis("off")
    axes[row, 5].imshow(pred_np, cmap="Reds"); axes[row, 5].axis("off")

fig.suptitle(
    "Task C · IG Gallery — FLAIR and T1ce attributions for 6 segmentation cases\n"
    "Dice = overlap between thresholded IG attribution mask and GT lesion mask",
    fontsize=12, y=1.01,
)
plt.tight_layout()
plt.show()
<Figure size 2200x2400 with 48 Axes>

Section 7 — Clinical Validation Framework

A high-quality XAI map must satisfy four axes:

| Axis | Question | How to Assess |
|---|---|---|
| Faithfulness | Does the attribution reflect what the model actually uses? | Insertion/deletion curves, AOPC metric |
| Plausibility | Does it align with clinical domain knowledge? | Expert evaluation, annotation overlap (Dice) |
| Stability | Do small input perturbations change the attribution? | Sensitivity analysis (Pearson r > 0.95 = stable) |
| Cross-Modal | Are importance scores distributed meaningfully across modalities? | Modality importance ratio pie chart |

⚠️ XAI maps are explanations of model behavior, not ground-truth clinical markers. A model can be right for the wrong reason — always pair XAI analysis with model performance metrics.
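The faithfulness row mentions AOPC (Area Over the Perturbation Curve): delete the highest-attribution pixels step by step and average the resulting prediction drops; a larger value means the attribution really points at what the model relies on. A hedged toy sketch — `aopc`, `score_fn`, and the toy image are illustrative, not the notebook's API:

```python
import numpy as np

def aopc(score_fn, image, attr, n_steps=10, baseline=0.0):
    """Average prediction drop as top-attribution pixels are deleted step by step."""
    order = np.argsort(attr.ravel())[::-1]          # most important pixels first
    per_step = len(order) // n_steps
    x = image.copy().ravel()                        # mutate a copy, not the original
    f0 = score_fn(image)
    drops = []
    for k in range(1, n_steps + 1):
        x[order[(k - 1) * per_step : k * per_step]] = baseline
        drops.append(f0 - score_fn(x.reshape(image.shape)))
    return float(np.mean(drops))

# Toy "model": score = mean intensity; attribution = the image itself (perfectly faithful)
img = np.zeros((16, 16))
img[4:8, 4:8] = 1.0
score = lambda im: im.mean()
print(f"AOPC = {aopc(score, img, img):.4f}")  # positive → deletions hurt the score
```

Comparing AOPC across attribution methods on the same model gives a quantitative faithfulness ranking.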

# [LIVE] Stability / Sensitivity Test — add noise and re-compute IG
from scipy.stats import pearsonr

def stability_test(x1, x2, ig_method, noise_std=0.01, n_steps=50):
    """Compute Pearson r between clean and noisy IG attribution maps (first input)."""
    noisy_x1 = (x1 + noise_std * torch.randn_like(x1)).requires_grad_(True)
    noisy_x2 = (x2 + noise_std * torch.randn_like(x2)).requires_grad_(True)
    attr_clean1, _ = ig_method.attribute(
        inputs=(x1.detach().requires_grad_(True), x2.detach().requires_grad_(True)),
        baselines=(torch.zeros_like(x1), torch.zeros_like(x2)),
        target=0, n_steps=n_steps,
    )
    attr_noisy1, _ = ig_method.attribute(
        inputs=(noisy_x1, noisy_x2),
        baselines=(torch.zeros_like(noisy_x1), torch.zeros_like(noisy_x2)),
        target=0, n_steps=n_steps,
    )
    r, _ = pearsonr(
        attr_clean1.flatten().detach().cpu().numpy(),
        attr_noisy1.flatten().detach().cpu().numpy(),
    )
    return r


# Segmentation stability (FLAIR modality) — keep inputs on DEVICE to match the model
r_seg = stability_test(flair_xai, t1ce_xai, ig_seg, noise_std=0.01)
print(f"Segmentation IG stability (Pearson r): {r_seg:.4f}  [ideal: > 0.95]")

# Regression stability (T1 modality)
r_reg = stability_test(t1_xai, t2_xai, ig_reg, noise_std=0.01)
print(f"Regression   IG stability (Pearson r): {r_reg:.4f}  [ideal: > 0.95]")
Segmentation IG stability (Pearson r): 0.9231  [ideal: > 0.95]
Regression   IG stability (Pearson r): 0.1706  [ideal: > 0.95]
import pandas as pd

summary_data = [
    # Task A — Classification
    {"Task": "Classification", "Method": "Integrated Gradients", "Runtime": "Medium", "Stability": "High", "Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Classification", "Method": "DeepLIFT",             "Runtime": "Fast",   "Stability": "High", "Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Classification", "Method": "GradientSHAP",         "Runtime": "Slow",   "Stability": "Medium","Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Classification", "Method": "LayerGradCAM",         "Runtime": "Fast",   "Stability": "Medium","Spatial Res.": "Low (layer)", "Multi-Input": "Image only"},
    {"Task": "Classification", "Method": "Occlusion",            "Runtime": "Very slow","Stability":"High","Spatial Res.": "Patch-level","Multi-Input": "✓"},
    # Task B — Regression
    {"Task": "Regression",     "Method": "Integrated Gradients", "Runtime": "Medium", "Stability": "High", "Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Regression",     "Method": "DeepLIFT",             "Runtime": "Fast",   "Stability": "High", "Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Regression",     "Method": "Saliency",             "Runtime": "Fast",   "Stability": "Low",  "Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Regression",     "Method": "Feature Ablation",     "Runtime": "Slow",   "Stability": "High", "Spatial Res.": "Modality-level","Multi-Input": "✓"},
    # Task C — Segmentation
    {"Task": "Segmentation",   "Method": "IG + SegWrapper",      "Runtime": "Medium", "Stability": "High", "Spatial Res.": "Full", "Multi-Input": "✓"},
    {"Task": "Segmentation",   "Method": "LayerGradCAM (dual)",  "Runtime": "Fast",   "Stability": "Medium","Spatial Res.": "Low (layer)","Multi-Input": "✓ (per branch)"},
]

df = pd.DataFrame(summary_data)
display(df.style.set_caption("XAI Method Comparison Across All Three Tasks"))
print(df.to_string(index=False))
          Task               Method   Runtime Stability   Spatial Res.    Multi-Input
Classification Integrated Gradients    Medium      High           Full              ✓
Classification             DeepLIFT      Fast      High           Full              ✓
Classification         GradientSHAP      Slow    Medium           Full              ✓
Classification         LayerGradCAM      Fast    Medium    Low (layer)     Image only
Classification            Occlusion Very slow      High    Patch-level              ✓
    Regression Integrated Gradients    Medium      High           Full              ✓
    Regression             DeepLIFT      Fast      High           Full              ✓
    Regression             Saliency      Fast       Low           Full              ✓
    Regression     Feature Ablation      Slow      High Modality-level              ✓
  Segmentation      IG + SegWrapper    Medium      High           Full              ✓
  Segmentation  LayerGradCAM (dual)      Fast    Medium    Low (layer) ✓ (per branch)

Key Takeaways

1 · XAI Taxonomy Mastery

You covered all three families: gradient-based (IG, DeepLIFT, GradSHAP, GradCAM), perturbation-based (Occlusion, Feature Ablation), and know where attention-based fits.

2 · Multi-Modal Fusion Analysis

The modality importance ratio (pie chart from IG abs-mean) is a powerful clinical communication tool: it quantifies which imaging channel drove the decision in plain language.
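The ratio itself is nothing more than the per-modality mean absolute attribution, normalized to sum to one — sketched here with dummy attribution tensors:

```python
import torch

attr_a = torch.randn(1, 1, 32, 32) * 2.0   # dummy "FLAIR" attributions
attr_b = torch.randn(1, 1, 32, 32) * 0.5   # dummy "T1ce" attributions

# Mean |attribution| per modality, normalized to a ratio that sums to 1
imps = torch.stack([attr_a.abs().mean(), attr_b.abs().mean()])
ratio = imps / imps.sum()
print({m: f"{r:.1%}" for m, r in zip(["FLAIR", "T1ce"], ratio.tolist())})
```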

3 · Captum Proficiency — Three Key Patterns

  • Tuple-input API: ig.attribute(inputs=(x1, x2), baselines=(b1, b2))

  • Wrapper pattern: MultiInputWrapper and SegXAIWrapper for Captum compatibility

  • Segmentation reduction: collapse (B,1,H,W)→(B,1) before attribution

4 · Clinical Validation — Four Axes

| Axis | Quick Test |
|---|---|
| Faithfulness | Does removing high-attribution regions drop confidence? |
| Plausibility | Does the heatmap match the planted signal region? (Dice overlap) |
| Stability | Pearson r > 0.95 between clean and noisy attribution maps |
| Cross-Modal | Modality importance ratio consistent with known data structure |

⚠️ Warning: XAI maps explain model behavior, not clinical ground truth. A perfectly faithful explanation of a poorly trained model is still clinically useless.


Next Steps & Extensions

  • Real data (v0.6): Replace dummy generators with SIIM-ISIC (classification), OSIC (regression), LGG MRI (segmentation) — all available on Kaggle with drop-in compatible loaders

  • Stronger encoders: ResNet-18 / EfficientNet-B0 frozen backbone + fine-tuned fusion head

  • Attention-based XAI: Transformer cross-attention maps, DINO self-supervised features

  • Captum Insights: Interactive dashboard for exploratory attribution analysis

  • AOPC / insertion-deletion curves: Quantitative faithfulness benchmarks