Technical Approach: Methods and Implementation
Detailed technical methods for measuring, analyzing, and mitigating FSF and EFG in medical vision-language models
Core Notation and Definitions
Mathematical Framework
Let $I$ denote a chest radiograph, $q$ a base question, and $\mathcal{P}(q) = \{q_1, \ldots, q_n\}$ the set of semantically equivalent paraphrases of $q$.
Model Output: $\hat{y} = f(I, q) \in \mathcal{A}$, where $\mathcal{A}$ is the answer space
Ground Truth: $y^* \in \mathcal{A}$
Attention Map: $A(I, q) \in \mathbb{R}^{H \times W}$, the aggregated cross-attention over image regions
FSF Formalization
A paraphrase pair $(q_i, q_j)$ exhibits FSF when the answer flips while the attention map stays stable:
$\mathrm{FSF}(q_i, q_j) = \mathbb{1}\left[ f(I, q_i) \neq f(I, q_j) \right] \cdot \mathbb{1}\left[ \mathrm{SSIM}\big(A(I, q_i), A(I, q_j)\big) > \tau \right]$
where $\tau = 0.85$ based on empirical calibration against semantically distinct questions.
EFG Formalization
For each faithfulness metric $\phi \in$ {Deletion AUC, Insertion AUC, Sufficiency, Necessity}:
$\mathrm{EFG}_\phi = \mathbb{E}\left[ \phi \mid \hat{y} \neq y^* \right] - \mathbb{E}\left[ \phi \mid \hat{y} = y^* \right]$
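Both definitions map directly onto code; a minimal sketch, assuming 2-D numpy attention maps and skimage's SSIM (the function names here are illustrative, not part of the evaluation suite below):

import numpy as np
from skimage.metrics import structural_similarity

def is_fsf_event(ans_i, ans_j, attn_i, attn_j, tau=0.85):
    """FSF predicate: the answer flips while attention maps stay stable (SSIM > tau)."""
    flipped = ans_i != ans_j
    data_range = float(max(attn_i.max(), attn_j.max()) - min(attn_i.min(), attn_j.min()))
    stable = structural_similarity(attn_i, attn_j, data_range=data_range) > tau
    return flipped and stable

def efg(phi_correct, phi_incorrect):
    """EFG for one metric: mean faithfulness on errors minus mean on correct answers."""
    return float(np.mean(phi_incorrect)) - float(np.mean(phi_correct))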
Model Architectures and Configuration
MedGemma-4b-it
class MedGemmaConfig:
# Vision encoder
vision_encoder = "CLIP-ViT-L/14" # 307M parameters
vision_dim = 1024
patch_size = 14
image_size = 224
# Language model
hidden_dim = 2560
n_layers = 18
n_heads = 32
vocab_size = 32000 # SentencePiece
# Cross-attention
cross_attention_layers = [8, 12, 16]
projection_dim = 2048
projection_layers = 2 # MLP depth
# Training details
checkpoint = "google/medgemma-4b-it"
    medical_pretrain = True  # MIMIC-CXR + general medical
LLaVA-Rad
class LLaVARadConfig:
# Vision encoder (same as MedGemma)
vision_encoder = "CLIP-ViT-L/14"
vision_dim = 1024
# Language model (Vicuna-7B)
hidden_dim = 4096
n_layers = 32
n_heads = 32
vocab_size = 32000 # LLaMA tokenizer
# Cross-attention
cross_attention_layers = [10, 15, 20, 25]
projection_dim = 4096
projection_layers = 1 # Single linear
# Training details
checkpoint = "microsoft/llava-rad"
    radiology_finetune = True
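For reference, the MedGemma checkpoint loads through Hugging Face transformers; a sketch following the model card's image-text-to-text interface (the LLaVA-Rad loading path differs and is not shown):

from transformers import AutoModelForImageTextToText, AutoProcessor

model = AutoModelForImageTextToText.from_pretrained("google/medgemma-4b-it", torch_dtype="auto")
processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")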
Attention Extraction Pipeline
Multi-Layer Aggregation
def extract_attention_maps(model, image, question):
    """
    Extract and aggregate cross-attention across layers.
    """
    attention_maps = []
    # Forward pass with attention hooks
    hooks = []
    for layer_idx in model.cross_attention_layers:
        hook = model.layers[layer_idx].cross_attention.register_forward_hook(
            lambda m, i, o: attention_maps.append(o.attention_weights)
        )
        hooks.append(hook)
    # Get model output
    output = model(image, question)
    # Aggregate attention maps
    aggregated = torch.zeros(224, 224)
    for idx, attn in enumerate(attention_maps):
        # Average across heads: [batch, n_heads, seq_len, n_patches] -> [batch, seq_len, n_patches]
        attn_avg = attn.mean(dim=1)
        # Take the final (answer) token's attention over the 256 image patches
        attn_final = attn_avg[:, -1, :]
        # Reshape to the spatial patch grid: 224 / 14 = 16 patches per side
        attn_spatial = attn_final.view(1, 16, 16)
        # Upsample to original resolution
        attn_upsampled = F.interpolate(
            attn_spatial.unsqueeze(0),
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        )
        # Weight by layer index (later layers are more task-relevant)
        weight = (idx + 1) / len(attention_maps)
        aggregated += weight * attn_upsampled.squeeze()
    # Normalize to a distribution over pixels
    aggregated = aggregated / aggregated.sum()
    # Clean up hooks
    for hook in hooks:
        hook.remove()
    return aggregated, output
Deterministic Inference
def setup_deterministic_inference(model):
    """
    Ensure reproducible results across runs.
    """
    # Set all random seeds
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic operations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Disable dropout
    def disable_dropout(module):
        if isinstance(module, nn.Dropout):
            module.p = 0
    model.apply(disable_dropout)
    # Greedy decoding; sampling parameters (temperature, top_p, top_k) are
    # ignored when do_sample=False, so they are omitted here
    generation_config = {
        'do_sample': False,
        'num_beams': 1
    }
    return generation_config
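A short usage sketch under the document's model(image, question) interface; the generate wrapper is an assumption:

gen_config = setup_deterministic_inference(model)
answer_1 = model.generate(image, question, **gen_config)  # assumed wrapper
answer_2 = model.generate(image, question, **gen_config)
assert answer_1 == answer_2  # greedy decoding should be run-to-run stable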
Paraphrase Generation and Validation
Linguistic Variation Framework
class ParaphraseGenerator:
def __init__(self):
self.strategies = {
'lexical': self.lexical_substitution,
'syntactic': self.syntactic_transformation,
'pragmatic': self.pragmatic_modulation,
'negation': self.negation_inversion,
'scope': self.scope_modification
}
# Medical synonym dictionary
self.medical_synonyms = {
'cardiomegaly': ['enlarged heart', 'cardiac enlargement'],
'pleural effusion': ['fluid in pleural space', 'pleural fluid'],
'pneumothorax': ['collapsed lung', 'air in pleural space'],
'consolidation': ['airspace opacity', 'lung opacity'],
'infiltrate': ['pulmonary infiltrate', 'lung infiltrate']
}
def lexical_substitution(self, question):
"""Replace medical terms with synonyms."""
for term, synonyms in self.medical_synonyms.items():
if term in question.lower():
synonym = random.choice(synonyms)
return question.lower().replace(term, synonym)
return question
    def syntactic_transformation(self, question):
        """Transform syntactic structure."""
        # Parse with spaCy; `nlp` is a loaded pipeline, e.g. spacy.load("en_core_web_sm")
        doc = nlp(question)
        transformed = question  # default: unchanged if no rule applies
        # ... parse-tree transformation logic ...
        return transformed
    def negation_inversion(self, question):
        """Add or remove negation."""
        if any(neg in question.lower() for neg in ['no', 'not', 'without']):
            # Remove negation
            return self.remove_negation(question)
        else:
            # Add negation forms around the extracted finding term
            finding = self.extract_finding(question)
            patterns = [
                f"No {finding} is seen, correct?",
                f"There isn't any {finding}, right?",
                f"Absence of {finding}?"
            ]
            return random.choice(patterns)
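The generator calls remove_negation and a finding extractor that are not shown; a minimal sketch of both as ParaphraseGenerator methods (the token-dropping approach is an illustrative simplification):

    def extract_finding(self, question):
        """Return the first known finding term mentioned in the question."""
        for term in self.medical_synonyms:
            if term in question.lower():
                return term
        return "the finding"  # fallback when no vocabulary term matches
    def remove_negation(self, question):
        """Drop simple negation markers; a production version would re-parse the sentence."""
        kept = [tok for tok in question.split() if tok.lower() not in {'no', 'not', 'without'}]
        return ' '.join(kept)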
Clinical Validation Protocol
class ClinicalValidator:
def __init__(self, validator_credentials):
self.validator = validator_credentials
self.rejection_reasons = []
def validate_semantic_equivalence(self, original, paraphrase):
"""
Ensure paraphrases maintain clinical meaning.
"""
checks = {
'anatomical_consistency': self.check_anatomy,
'finding_consistency': self.check_findings,
'clinical_naturalism': self.check_naturalism,
'unambiguous': self.check_ambiguity
}
for check_name, check_func in checks.items():
if not check_func(original, paraphrase):
self.rejection_reasons.append(check_name)
return False
return True
    def check_anatomy(self, original, paraphrase):
        """Ensure the same anatomical regions are referenced (see the helper sketch below)."""
        original_anatomy = extract_anatomy_terms(original)
        paraphrase_anatomy = extract_anatomy_terms(paraphrase)
        return original_anatomy == paraphrase_anatomy
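extract_anatomy_terms is referenced but not defined here; a minimal module-level sketch with an illustrative (not clinically complete) vocabulary:

ANATOMY_TERMS = {'lung', 'heart', 'cardiac', 'pleural', 'mediastinum',
                 'diaphragm', 'rib', 'costophrenic', 'hilar'}

def extract_anatomy_terms(text):
    """Return the set of known anatomy terms mentioned in the text."""
    lowered = text.lower()
    return {term for term in ANATOMY_TERMS if term in lowered}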
Causal Analysis Methods
Layer-wise Representation Analysis
def analyze_representation_similarity(model, image, paraphrases):
    """
    Track how representations diverge through layers.
    """
    similarities = defaultdict(list)
    # Hook to capture hidden states
    hidden_states = {}
    def hook_fn(name):
        def hook(module, input, output):
            # Transformer layers often return tuples; keep the hidden-state tensor
            out = output[0] if isinstance(output, tuple) else output
            hidden_states[name] = out.detach()
        return hook
    # Register hooks for all layers (keep handles so they can be removed)
    hooks = []
    for idx, layer in enumerate(model.layers):
        hooks.append(layer.register_forward_hook(hook_fn(f'layer_{idx}')))
    # Process each paraphrase
    representations = []
    for p in paraphrases:
        _ = model(image, p)
        representations.append({k: v.clone() for k, v in hidden_states.items()})
    # Clean up hooks
    for hook in hooks:
        hook.remove()
    # Compute pairwise similarities
    for i in range(len(paraphrases)):
        for j in range(i+1, len(paraphrases)):
            for layer_name in representations[0].keys():
                rep_i = representations[i][layer_name]
                rep_j = representations[j][layer_name]
                # Cosine similarity between flattened hidden states
                sim = F.cosine_similarity(
                    rep_i.flatten(),
                    rep_j.flatten(),
                    dim=0
                ).item()
                similarities[layer_name].append(sim)
    return similarities
Cross-Attention Intervention
class CrossAttentionIntervention:
def __init__(self, model):
self.model = model
self.saved_attention = {}
def save_attention_state(self, image, question, name='reference'):
"""Save attention patterns from a forward pass."""
def save_hook(module, input, output):
self.saved_attention[name] = {
'queries': output.queries.detach(),
'keys': output.keys.detach(),
'values': output.values.detach(),
'weights': output.attention_weights.detach()
}
# Attach to cross-attention layers
hooks = []
for layer_idx in self.model.cross_attention_layers:
hook = self.model.layers[layer_idx].cross_attention.register_forward_hook(save_hook)
hooks.append(hook)
# Forward pass
_ = self.model(image, question)
# Clean up
for hook in hooks:
hook.remove()
def interchange_experiment(self, image, q1, q2):
"""Test effect of swapping attention components."""
# Save attention from both questions
self.save_attention_state(image, q1, 'q1')
self.save_attention_state(image, q2, 'q2')
results = {}
# Test different swapping strategies
swap_configs = [
('swap_kv', {'keys': 'q2', 'values': 'q2', 'queries': 'q1'}),
('swap_q', {'keys': 'q1', 'values': 'q1', 'queries': 'q2'}),
('swap_all', {'keys': 'q2', 'values': 'q2', 'queries': 'q2'})
]
        for config_name, config in swap_configs:
            # forward_with_intervention (not shown) re-runs the model while patching
            # the saved Q/K/V tensors named in `config` into the cross-attention layers
            output = self.forward_with_intervention(image, q1, config)
            results[config_name] = output
        return results
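A usage sketch; the interpretation heuristic in the comment is the intended diagnostic reading, not an established result:

intervention = CrossAttentionIntervention(model)
results = intervention.interchange_experiment(image, q1, q2)
# If 'swap_kv' flips the answer while 'swap_q' does not, the question-specific
# signal is carried mainly by the keys/values of the cross-attention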
Mitigation Strategies
LoRA Configuration for Targeted Adaptation
from peft import LoraConfig, get_peft_model
def create_targeted_lora_model(base_model, causal_analysis_results):
"""
Create LoRA model targeting causally-identified layers.
"""
# Identify target modules from causal analysis
target_layers = causal_analysis_results['high_impact_layers'] # e.g., [12, 13, 14, 15, 16]
# Configure LoRA for specific attention projections
target_modules = []
for layer_idx in target_layers:
target_modules.extend([
f"layers.{layer_idx}.self_attn.q_proj",
f"layers.{layer_idx}.self_attn.v_proj",
f"layers.{layer_idx}.cross_attn.q_proj",
f"layers.{layer_idx}.cross_attn.v_proj"
])
    lora_config = LoraConfig(
        r=16,  # Rank
        lora_alpha=32,  # Scaling factor
        target_modules=target_modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"  # peft defines no VQA-specific task type; the LM head is causal
    )
    # Create PEFT model
    model = get_peft_model(base_model, lora_config)
    # Verify parameter efficiency
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
    return model
Multi-Objective Training
class ParaphraseRobustTrainer:
def __init__(self, model, config):
self.model = model
self.config = config
def compute_consistency_loss(self, image, paraphrases, model_outputs):
"""
KL divergence between paraphrase outputs.
"""
consistency_loss = 0
n_pairs = 0
for i in range(len(paraphrases)):
for j in range(i+1, len(paraphrases)):
# Get logits for both paraphrases
logits_i = model_outputs[i].logits
logits_j = model_outputs[j].logits
# Convert to probabilities
probs_i = F.softmax(logits_i, dim=-1)
probs_j = F.softmax(logits_j, dim=-1)
# KL divergence (symmetrized)
kl_ij = F.kl_div(probs_i.log(), probs_j, reduction='batchmean')
kl_ji = F.kl_div(probs_j.log(), probs_i, reduction='batchmean')
consistency_loss += (kl_ij + kl_ji) / 2
n_pairs += 1
return consistency_loss / n_pairs
def compute_attention_stability_loss(self, attention_maps):
"""
Encourage stable attention across paraphrases.
"""
stability_loss = 0
n_pairs = 0
for i in range(len(attention_maps)):
for j in range(i+1, len(attention_maps)):
# Compute SSIM loss (1 - SSIM for minimization)
ssim_loss = 1 - structural_similarity(
attention_maps[i].cpu().numpy(),
attention_maps[j].cpu().numpy()
)
stability_loss += ssim_loss
n_pairs += 1
return stability_loss / n_pairs
def training_step(self, batch):
"""
Single training iteration with multi-objective loss.
"""
images, questions, paraphrases, labels = batch
# Forward pass for all paraphrases
outputs = []
attention_maps = []
for para_set in paraphrases:
para_outputs = []
para_attention = []
for p in para_set:
out, attn = self.model(images, p, return_attention=True)
para_outputs.append(out)
para_attention.append(attn)
outputs.append(para_outputs)
attention_maps.append(para_attention)
# Task loss (standard VQA)
task_loss = sum(
F.cross_entropy(out[0].logits, labels)
for out in outputs
) / len(outputs)
        # Consistency loss (kept as torch tensors to preserve the autograd graph)
        consistency_loss = torch.stack([
            self.compute_consistency_loss(img, para, out)
            for img, para, out in zip(images, paraphrases, outputs)
        ]).mean()
        # Attention stability loss; note SSIM is computed on detached numpy arrays,
        # so this term acts as a monitor unless a differentiable SSIM is substituted
        attention_loss = float(np.mean([
            self.compute_attention_stability_loss(attn)
            for attn in attention_maps
        ]))
        # Combined loss
        total_loss = (
            self.config.lambda_task * task_loss +
            self.config.lambda_consistency * consistency_loss +
            self.config.lambda_attention * attention_loss
        )
        return {
            'loss': total_loss,
            'task_loss': task_loss.item(),
            'consistency_loss': consistency_loss.item(),
            'attention_loss': attention_loss
        }
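A hedged training-loop sketch; the optimizer, learning rate, and train_loader are assumptions:

trainer = ParaphraseRobustTrainer(model, config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # assumed hyperparameters
for batch in train_loader:
    metrics = trainer.training_step(batch)
    optimizer.zero_grad()
    metrics['loss'].backward()
    optimizer.step()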
Deployment Framework
Uncertainty Quantification
class UncertaintyEstimator:
def __init__(self, model, n_samples=10):
self.model = model
self.n_samples = n_samples
def paraphrase_uncertainty(self, image, paraphrases):
"""
Uncertainty from paraphrase ensemble variance.
"""
predictions = []
for p in paraphrases:
output = self.model(image, p)
pred_probs = F.softmax(output.logits, dim=-1)
predictions.append(pred_probs)
# Stack predictions
pred_stack = torch.stack(predictions)
# Compute variance
pred_mean = pred_stack.mean(dim=0)
pred_var = pred_stack.var(dim=0)
# Entropy of mean prediction
entropy = -torch.sum(pred_mean * torch.log(pred_mean + 1e-8), dim=-1)
return {
'mean_prediction': pred_mean,
'variance': pred_var,
'entropy': entropy,
'disagreement': (pred_stack.argmax(dim=-1) != pred_mean.argmax(dim=-1)).float().mean()
}
    def mc_dropout_uncertainty(self, image, question):
        """
        Monte Carlo dropout uncertainty.
        """
        # train() re-enables dropout, but also puts norm layers in training mode;
        # a stricter version would toggle only the Dropout modules
        self.model.train()
        predictions = []
        for _ in range(self.n_samples):
            with torch.no_grad():
                output = self.model(image, question)
                predictions.append(F.softmax(output.logits, dim=-1))
        # Back to eval mode
        self.model.eval()
# Compute statistics
pred_stack = torch.stack(predictions)
return {
'mean': pred_stack.mean(dim=0),
'std': pred_stack.std(dim=0),
'entropy': self.predictive_entropy(pred_stack)
}Selective Triage System
Selective Triage System
class SelectiveTriageSystem:
def __init__(self, model, thresholds):
self.model = model
self.uncertainty_estimator = UncertaintyEstimator(model)
self.thresholds = thresholds
# Critical findings that always need review
self.critical_findings = [
'pneumothorax', 'tension pneumothorax',
'large pleural effusion', 'mediastinal shift',
'free air', 'aortic dissection'
]
    def triage_decision(self, image, question, paraphrases=None):
        """
        Make a triage decision with safety guarantees. Relies on helpers not
        shown here: generate_paraphrases, decode_prediction, is_normal_finding.
        """
        # Generate paraphrases if not provided
        if paraphrases is None:
            paraphrases = self.generate_paraphrases(question, n=5)
# Get uncertainty estimates
para_uncertainty = self.uncertainty_estimator.paraphrase_uncertainty(
image, paraphrases
)
# Get primary prediction
primary_output = self.model(image, question)
primary_pred = primary_output.logits.argmax(dim=-1)
# Decision logic
decision_path = []
# Check paraphrase agreement
if para_uncertainty['disagreement'] > self.thresholds['max_disagreement']:
decision_path.append('high_paraphrase_disagreement')
return self.create_decision('defer', 'paraphrase_uncertainty', decision_path)
# Check confidence
confidence = para_uncertainty['mean_prediction'].max().item()
if confidence < self.thresholds['min_confidence']:
decision_path.append('low_confidence')
return self.create_decision('defer', 'low_confidence', decision_path)
# Check for critical findings
predicted_text = self.decode_prediction(primary_pred)
if any(critical in predicted_text.lower() for critical in self.critical_findings):
decision_path.append('critical_finding_detected')
return self.create_decision('urgent_review', 'critical_finding', decision_path)
# Check if normal with high confidence
if self.is_normal_finding(predicted_text) and confidence > self.thresholds['normal_confidence']:
decision_path.append('high_confidence_normal')
return self.create_decision('auto_clear', 'normal', decision_path)
# Default to radiologist review
decision_path.append('standard_review')
return self.create_decision('review', 'standard', decision_path)
def create_decision(self, action, reason, path):
"""
Create structured triage decision.
"""
return {
'action': action,
'reason': reason,
'decision_path': path,
'timestamp': datetime.now().isoformat(),
'model_version': self.model.config.version,
'thresholds_used': self.thresholds
        }
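A usage sketch; the threshold values are illustrative placeholders, not calibrated operating points:

triage = SelectiveTriageSystem(model, thresholds={
    'max_disagreement': 0.2,   # placeholder: defer if over 20% of paraphrases disagree
    'min_confidence': 0.70,    # placeholder
    'normal_confidence': 0.95  # placeholder
})
decision = triage.triage_decision(image, "Is there a pneumothorax?")
print(decision['action'], decision['decision_path'])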
Evaluation Metrics
Comprehensive FSF and EFG Computation
class RobustnessEvaluator:
def __init__(self, model, dataset, tau=0.85):
self.model = model
self.dataset = dataset
self.tau = tau
def compute_fsf_index(self, results):
"""
Compute FSF index from evaluation results.
"""
flips_with_stable_attention = 0
total_flips = 0
for item in results:
paraphrase_answers = item['paraphrase_answers']
attention_maps = item['attention_maps']
# Check all pairs
for i in range(len(paraphrase_answers)):
for j in range(i+1, len(paraphrase_answers)):
# Check if answers differ
if paraphrase_answers[i] != paraphrase_answers[j]:
total_flips += 1
# Check attention similarity
ssim = structural_similarity(
attention_maps[i],
attention_maps[j]
)
if ssim > self.tau:
flips_with_stable_attention += 1
fsf_index = flips_with_stable_attention / total_flips if total_flips > 0 else 0
return fsf_index, flips_with_stable_attention, total_flips
def compute_efg(self, results):
"""
Compute error-faithfulness gap.
"""
correct_faithfulness = []
incorrect_faithfulness = []
for item in results:
prediction = item['prediction']
ground_truth = item['ground_truth']
faithfulness_scores = item['faithfulness_scores']
if prediction == ground_truth:
correct_faithfulness.append(faithfulness_scores)
else:
incorrect_faithfulness.append(faithfulness_scores)
        # Compute means per metric (see the compute_cohens_d sketch below)
        efg_results = {}
        for metric in ['deletion_auc', 'insertion_auc', 'sufficiency', 'necessity']:
            correct_mean = np.mean([s[metric] for s in correct_faithfulness])
            incorrect_mean = np.mean([s[metric] for s in incorrect_faithfulness])
            efg_results[metric] = {
                'correct_mean': correct_mean,
                'incorrect_mean': incorrect_mean,
                'gap': incorrect_mean - correct_mean,
                'cohens_d': self.compute_cohens_d(
                    [s[metric] for s in correct_faithfulness],
                    [s[metric] for s in incorrect_faithfulness]
                )
            }
        return efg_results
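compute_cohens_d is called above but not defined; a standard pooled-variance sketch as a method of RobustnessEvaluator:

    def compute_cohens_d(self, group_a, group_b):
        """Cohen's d between two score groups using the pooled standard deviation."""
        a = np.asarray(group_a, dtype=float)
        b = np.asarray(group_b, dtype=float)
        pooled_var = ((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / (len(a) + len(b) - 2)
        return (b.mean() - a.mean()) / np.sqrt(pooled_var)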
Computational Optimization
Efficient Batch Processing
class EfficientBatchProcessor:
def __init__(self, model, batch_size=32):
self.model = model
self.batch_size = batch_size
# Compile model for faster inference (PyTorch 2.0+)
self.model = torch.compile(self.model, mode="max-autotune")
def process_dataset(self, dataset, use_amp=True):
"""
Process dataset with optimizations.
"""
dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=4,
pin_memory=True,
prefetch_factor=2
)
results = []
# Mixed precision for faster processing
        with torch.autocast("cuda", enabled=use_amp):
for batch in tqdm(dataloader):
# Move to GPU
batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                # Inference only: disable gradient tracking
                with torch.no_grad():
outputs = self.model(**batch)
                results.append({
                    'logits': outputs.logits.cpu(),  # assumes the model output exposes logits
                    'batch_indices': batch['indices']
                })
        return results
This technical approach provides a comprehensive implementation framework for measuring, analyzing, and mitigating FSF and EFG phenomena in medical VLMs, enabling safe clinical deployment through rigorous engineering and evaluation.