MLLMGuard: Comprehensive Safety Framework for Medical Multimodal LLMs
A multi-layered defense system for medical Vision-Language Models, providing robust protection against adversarial attacks, prompt injections, and hallucinations while maintaining clinical utility
Overview
Note: This page reflects a broader safety wrapper developed earlier. In the current dissertation, the core contribution is a paraphrase‑robustness baselining framework (MedPhr‑Rad) with selective conformal triage. MLLMGuard remains a supportive, optional deployment layer.
MLLMGuard (Medical Large Language Model Guard) is a comprehensive safety framework designed specifically for deploying Vision-Language Models in clinical settings. It addresses challenges specific to medical AI, including prompt injection attacks, hallucination prevention, appropriate abstention, and evidence-based response generation.
Core Components
1. Input Protection Layer
Prompt Injection Defense
```python
import re


class PromptInjectionDetector:
    def __init__(self):
        self.malicious_patterns = [
            r"ignore previous instructions",
            r"system prompt override",
            r"reveal your instructions",
            r"act as a different assistant",
        ]
        self.medical_context_validator = MedicalContextValidator()

    def detect(self, prompt):
        # Pattern matching against known injection phrasings
        for pattern in self.malicious_patterns:
            if re.search(pattern, prompt.lower()):
                return True, "Potential prompt injection detected"
        # Semantic check: the query must stay within a medical context
        if not self.medical_context_validator.is_medical(prompt):
            return True, "Non-medical query in medical context"
        return False, None
```
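The detector assumes a `MedicalContextValidator`, which is not specified here. A minimal keyword-based sketch is shown below purely for illustration; the vocabulary, threshold, and class body are assumptions, not part of the framework.

```python
# Illustrative sketch only: a production validator would use a clinical-domain
# classifier or embedding similarity rather than this keyword heuristic.
class MedicalContextValidator:
    MEDICAL_TERMS = {
        "x-ray", "ct", "mri", "radiograph", "lesion", "opacity",
        "effusion", "fracture", "nodule", "patient", "scan", "finding",
    }

    def is_medical(self, prompt, min_hits=1):
        tokens = set(prompt.lower().split())
        return len(tokens & self.MEDICAL_TERMS) >= min_hits
```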
Image Manipulation Detection
```python
class AdversarialImageDetector:
    def __init__(self, model):
        self.detector = RobustFeatureExtractor(model)
        self.baseline_stats = self.load_baseline_statistics()

    def detect_adversarial(self, image):
        features = self.detector.extract(image)
        # Statistical anomaly detection against baseline feature statistics
        anomaly_score = self.compute_anomaly_score(features)
        # Frequency-domain analysis for perturbation energy
        freq_anomaly = self.frequency_analysis(image)
        # Localized adversarial patch detection
        patch_score = self.detect_adversarial_patches(image)
        # Flag the image if any individual signal exceeds the threshold
        combined_score = max(anomaly_score, freq_anomaly, patch_score)
        return {
            "is_adversarial": combined_score > 0.8,
            "confidence": combined_score,
            "frequency_anomaly": freq_anomaly,
            "patch_score": patch_score,
            "attack_type": self.classify_attack(features),
        }
```
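The frequency-domain check is referenced but not spelled out. One simple way to approximate it is to measure how much spectral energy sits outside a low-frequency band, since many pixel-level perturbations concentrate energy at high spatial frequencies. The sketch below is a hedged illustration; the radius and normalization are placeholder choices, not the framework's actual implementation.

```python
import numpy as np


# Illustrative frequency-domain anomaly score: the fraction of spectral energy
# outside a low-frequency radius. All thresholds here are placeholders.
def frequency_analysis(image, low_freq_radius=0.1):
    gray = image.mean(axis=-1) if image.ndim == 3 else image
    spectrum = np.abs(np.fft.fftshift(np.fft.fft2(gray))) ** 2
    h, w = spectrum.shape
    yy, xx = np.ogrid[:h, :w]
    dist = np.sqrt((yy - h / 2) ** 2 + (xx - w / 2) ** 2)
    low_mask = dist <= low_freq_radius * min(h, w)
    high_energy = spectrum[~low_mask].sum()
    return float(high_energy / (spectrum.sum() + 1e-12))
```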
2. Context Scrubbing
Medical Context Isolation
```python
class MedicalContextScrubber:
    def process(self, image, text):
        # Remove non-medical overlays from the image
        cleaned_image = self.remove_text_overlays(image)
        cleaned_image = self.remove_non_anatomical_regions(cleaned_image)
        # Extract only the medical content from the text
        medical_text = self.extract_medical_content(text)
        # Validate that both modalities remain medically relevant
        if not self.is_medical_context(cleaned_image, medical_text):
            raise NonMedicalContentError()
        return cleaned_image, medical_text

    def remove_text_overlays(self, image):
        # OCR to detect text regions
        text_regions = self.ocr_detector.detect(image)
        # Inpaint any region that is not a legitimate medical annotation
        for region in text_regions:
            if not self.is_medical_annotation(region):
                image = self.inpaint_region(image, region)
        return image
```
3. Guarding Strategies
Instruction Prefixing
```python
MEDICAL_SAFETY_PREFIX = """
[MEDICAL AI SAFETY INSTRUCTIONS]
1. You are a medical AI assistant providing decision support
2. Never provide definitive diagnoses - only observations and considerations
3. Always recommend professional medical consultation
4. Refuse non-medical queries politely
5. Cite evidence when making observations
6. Express uncertainty appropriately
7. Never recommend specific medications or dosages
8. Escalate emergency symptoms immediately
[MEDICAL CONTEXT]
"""


def create_safe_prompt(user_input, image_context):
    return (
        f"{MEDICAL_SAFETY_PREFIX}\n"
        f"Image: {image_context}\n"
        f"Query: {user_input}\n"
        f"Response (following safety guidelines):"
    )
```
Answer-Only with Abstention
```python
class SafeMedicalResponder:
    def __init__(self, model, confidence_threshold=0.7):
        self.model = model
        self.confidence_threshold = confidence_threshold
        self.abstention_phrases = [
            "I cannot provide guidance on this query.",
            "This requires professional medical evaluation.",
            "I don't have sufficient confidence to assist with this.",
        ]

    def generate_response(self, image, query):
        # Generate with confidence scores
        response, confidence = self.model.generate_with_confidence(image, query)
        # Check abstention conditions
        if self.should_abstain(query, response, confidence):
            return self.select_abstention_response(query)
        # Format safe response
        return self.format_medical_response(response, confidence)

    def should_abstain(self, query, response, confidence):
        # Low confidence
        if confidence < self.confidence_threshold:
            return True
        # High-risk queries
        if self.is_high_risk_query(query):
            return True
        # Potentially harmful response content
        if self.contains_harmful_content(response):
            return True
        return False
```
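`is_high_risk_query` and `contains_harmful_content` are left abstract above. A minimal keyword-based sketch of the former is given below; the pattern list is illustrative only and a deployed system would rely on a trained classifier and a curated clinical lexicon.

```python
import re

# Illustrative high-risk query screen; patterns are placeholders, not a
# clinically validated lexicon.
HIGH_RISK_PATTERNS = [
    r"\bdosage\b", r"\bdose\b", r"\bprescri\w*\b",
    r"\bstop taking\b", r"\bsuicid\w*\b", r"\boverdose\b",
]


def is_high_risk_query(query):
    q = query.lower()
    return any(re.search(p, q) for p in HIGH_RISK_PATTERNS)
```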
Reflection Prompts
```python
class ReflectiveResponder:
    def generate_with_reflection(self, image, query):
        # Initial response
        initial = self.model.generate(image, query)
        # Reflection prompt asking the model to critique its own answer
        reflection_prompt = f"""
        Review your response for:
        1. Medical accuracy and appropriateness
        2. Proper uncertainty expression
        3. Clear recommendation for professional consultation
        4. Absence of definitive diagnoses
        5. Evidence-based observations
        Original response: {initial}
        Provide an improved, safer response:
        """
        # Refined response, re-grounded on the same image
        refined = self.model.generate(image, reflection_prompt)
        return refined
```
4. Evidence Citation
Evidence-Grounded Generation
```python
class EvidenceBasedResponder:
    def __init__(self, model, knowledge_base):
        self.model = model
        self.kb = knowledge_base
        self.citation_formatter = CitationFormatter()

    def generate_with_evidence(self, image, query):
        # Extract visual findings from the image
        findings = self.extract_findings(image)
        # Retrieve relevant evidence from the knowledge base
        evidence = self.kb.retrieve(findings, query)
        # Generate a response that cites the retrieved evidence
        response = self.model.generate(
            image=image,
            query=query,
            evidence=evidence,
            instruction="Cite evidence using [1], [2] format",
        )
        # Format with proper citations
        return self.citation_formatter.format(response, evidence)
```
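`CitationFormatter` is referenced but not defined. A minimal sketch that appends a numbered reference list is shown below; the `source` attribute on evidence items is an assumed interface, not part of the framework.

```python
# Minimal sketch: appends a numbered reference list so that inline [1], [2]
# markers in the response resolve to the retrieved evidence. The `source`
# attribute on evidence items is an assumption.
class CitationFormatter:
    def format(self, response, evidence):
        references = "\n".join(
            f"[{i}] {item.source}" for i, item in enumerate(evidence, start=1)
        )
        return f"{response}\n\nReferences:\n{references}"
```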
5. Evaluation Integration
Paraphrase‑Dispersion Risk Scoring (from prior VSF‑Med‑VQA)
```python
class VSFMedVQAScorer:
    def __init__(self):
        self.risk_dimensions = {
            "hallucination_risk": 0.3,
            "misdiagnosis_risk": 0.4,
            "harmful_advice_risk": 0.3,
        }

    def compute_risk_score(self, response, ground_truth):
        scores = {}
        # Hallucination detection
        scores["hallucination_risk"] = self.detect_hallucination(
            response, ground_truth
        )
        # Clinical accuracy, inverted so lower accuracy means higher misdiagnosis risk
        scores["misdiagnosis_risk"] = 1.0 - self.assess_clinical_accuracy(
            response, ground_truth
        )
        # Harm potential
        scores["harmful_advice_risk"] = self.assess_harm_potential(response)
        # Weighted risk score across the three dimensions
        risk_score = sum(
            self.risk_dimensions[dim] * scores[dim]
            for dim in scores
        )
        return risk_score, scores
```
Implementation Architecture
```python
class MLLMGuard:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        # Initialize components
        self.injection_detector = PromptInjectionDetector()
        self.image_detector = AdversarialImageDetector(model)
        self.context_scrubber = MedicalContextScrubber()
        self.safe_responder = SafeMedicalResponder(model)
        self.evidence_responder = EvidenceBasedResponder(model, config.kb)
        self.risk_scorer = VSFMedVQAScorer()  # prior naming retained for code continuity

    def process_request(self, image, text):
        try:
            # Input validation
            if self.injection_detector.detect(text)[0]:
                return self.safe_rejection("Invalid query detected")
            if self.image_detector.detect_adversarial(image)["is_adversarial"]:
                return self.safe_rejection("Image manipulation detected")
            # Context scrubbing
            clean_image, clean_text = self.context_scrubber.process(image, text)
            # Generate safe response
            response = self.safe_responder.generate_response(clean_image, clean_text)
            # Regenerate through the evidence-grounded path when citations are required
            if self.config.require_evidence:
                response = self.evidence_responder.generate_with_evidence(
                    clean_image, clean_text
                )
            # Risk assessment (no ground truth is available in production)
            risk_score, risk_breakdown = self.risk_scorer.compute_risk_score(
                response, None
            )
            if risk_score > self.config.risk_threshold:
                return self.escalate_to_human(response, risk_breakdown)
            return response
        except Exception as e:
            self.log_error(e)
            return self.safe_error_response()
```
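A minimal wiring sketch follows. `load_medical_vlm`, `GuardConfig`, `knowledge_base`, and `chest_xray` are illustrative placeholders for whatever model loader, configuration object, and inputs a deployment actually uses.

```python
# Illustrative wiring only: placeholder names stand in for a real model loader,
# config object, knowledge base, and image.
model = load_medical_vlm("medical-vlm-checkpoint")
config = GuardConfig(kb=knowledge_base, require_evidence=True, risk_threshold=0.3)

guard = MLLMGuard(model, config)
answer = guard.process_request(image=chest_xray, text="Is there a pleural effusion?")
print(answer)
```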
Deployment Guidelines
Configuration
```yaml
mllmguard_config:
  # Input protection
  prompt_injection_detection: true
  adversarial_image_detection: true
  max_prompt_length: 500

  # Context scrubbing
  remove_text_overlays: true
  medical_context_validation: strict

  # Response generation
  confidence_threshold: 0.75
  require_evidence: true
  abstention_on_uncertainty: true

  # Risk management
  risk_threshold: 0.3
  escalation_enabled: true
  audit_logging: comprehensive
```
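One way to load this file into the config object that `MLLMGuard` expects is sketched below. The `SimpleNamespace` wrapper, the file name, and the separately supplied knowledge base are illustrative assumptions.

```python
import yaml
from types import SimpleNamespace

# Illustrative loader: flattens the YAML block into attribute access
# (config.risk_threshold, config.require_evidence, ...). The knowledge base is
# passed separately because it is not part of the YAML file.
def load_guard_config(path, knowledge_base):
    with open(path) as f:
        raw = yaml.safe_load(f)["mllmguard_config"]
    return SimpleNamespace(kb=knowledge_base, **raw)


config = load_guard_config("mllmguard.yaml", knowledge_base=my_kb)
```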
Integration with Clinical Systems
```python
# EHR integration: attach patient context from the electronic health record
mllm_guard.set_patient_context(patient_id, ehr_data)

# PACS integration: attach imaging study metadata
mllm_guard.set_imaging_context(study_id, pacs_metadata)

# Clinical decision support query
response = mllm_guard.process_clinical_query(
    image=dicom_image,
    query=physician_query,
    urgency=urgency_level,
)
```
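`process_clinical_query` is not defined in the core class above. A thin wrapper over `process_request` that also honors the urgency flag might look like the sketch below; the urgency labels and escalation payload are purely illustrative.

```python
# Illustrative wrapper: routes urgent studies to human escalation before any
# automated answer is returned. Urgency levels are assumed to be strings.
def process_clinical_query(guard, image, query, urgency="routine"):
    if urgency in {"emergent", "stat"}:
        return guard.escalate_to_human(
            response=None, risk_breakdown={"reason": f"urgency={urgency}"}
        )
    return guard.process_request(image, query)
```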
Performance Metrics
Safety Metrics
- Prompt injection block rate: >99.5%
- Adversarial detection accuracy: >95%
- Hallucination rate: <2%
- Harmful response rate: <0.1%
Utility Metrics
- Clinical relevance: >90%
- Response usefulness: >85%
- Abstention rate: <10%
- Evidence citation rate: >80%
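These figures are measured offline against labeled test sets. The sketch below shows, under stated assumptions, how two of them (injection block rate and abstention rate) could be computed; the `cases` schema and the string matching against rejection and abstention messages are assumptions about the guard's output format, not its actual API.

```python
# Illustrative evaluation harness: `cases` is assumed to be a list of dicts
# with "image", "text", and a boolean "is_injection" label, and responses are
# assumed to surface the rejection / abstention phrases as plain strings.
def evaluate_safety(guard, cases):
    blocked = abstained = injections = 0
    for case in cases:
        result = str(guard.process_request(case["image"], case["text"]))
        if case["is_injection"]:
            injections += 1
            if "Invalid query detected" in result:
                blocked += 1
        if any(p in result for p in guard.safe_responder.abstention_phrases):
            abstained += 1
    return {
        "injection_block_rate": blocked / max(injections, 1),
        "abstention_rate": abstained / len(cases),
    }
```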
Related Resources
- Understanding VLM Attacks
- Safety Evaluation Methods
- PhD Research Plan — RQ3: Safety mechanisms