"""
Script to generate captions for images using the VLM model.
This script runs in the RobustMMFMEnv conda environment.
"""
import argparse
import sys
import os
import warnings
warnings.filterwarnings('ignore')
# Add the parent directory to the path to import vlm_eval modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
def generate_caption(image_path, epsilon, sparsity, attack_algo, num_iters, model_name="open_flamingo", num_shots=0, targeted=False):
"""
Generate caption for a single image.
Args:
image_path: Path to the image file
model_name: Name of the model to use
num_shots: Number of shots for few-shot learning
Returns:
str: Generated caption
"""
try:
# Import required modules
from PIL import Image
import torch
import numpy as np
import tempfile
from open_flamingo.eval.models.of_eval_model_adv import EvalModelAdv
from open_flamingo.eval.coco_metric import postprocess_captioning_generation
from vlm_eval.attacks.apgd import APGD
from vlm_eval.attacks.saif import SAIF
# Model arguments
model_args = {
"lm_path": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
"lm_tokenizer_path": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
"vision_encoder_path": "ViT-L-14",
"vision_encoder_pretrained": "openai",
"checkpoint_path": "/home/kc/.cache/huggingface/hub/models--openflamingo--OpenFlamingo-4B-vitl-rpj3b/snapshots/df8d3f7e75bcf891ce2fbf5253a12f524692d9c2/checkpoint.pt",
"cross_attn_every_n_layers": "2",
"precision": "float16",
}
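        # NOTE: checkpoint_path above is machine-specific; point it at your own
        # local copy of the OpenFlamingo-4B checkpoint before running.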
eval_model = EvalModelAdv(model_args, adversarial=True)
eval_model.set_device(0 if torch.cuda.is_available() else -1)
image = Image.open(image_path).convert("RGB")
image = eval_model._prepare_images([[image]])
prompt = eval_model.get_caption_prompt()
# Generate original caption
orig_caption = eval_model.get_outputs(
batch_images=image,
batch_text=[prompt], # Note: wrapped in list
min_generation_length=0,
max_generation_length=20,
num_beams=3,
length_penalty=-2.0,
)
        orig_caption = [
            postprocess_captioning_generation(out).replace('"', "") for out in orig_caption
        ]
        # Build the attack-label prompt: the clean caption for an untargeted
        # attack, or a fixed target string for a targeted one.
        target_str = "a dog"  # target caption used when targeted=True
        adv_caption = orig_caption[0] if not targeted else target_str
        prompt_adv = eval_model.get_caption_prompt(adv_caption)
        # Register the label text with the model; the attacks below compute
        # their loss against these inputs.
        eval_model.set_inputs(
            batch_text=[prompt_adv],  # use the attack-label prompt
            past_key_values=None,
            to_device=True,
        )
# Now run the attack
        if attack_algo == "APGD":
            # APGD: L-infinity attack with an adaptive step size. For a targeted
            # attack the loss is negated so the optimizer moves toward the target.
            attack = APGD(
                eval_model if not targeted else lambda x: -eval_model(x),
                norm="linf",
                eps=epsilon / 255.0,  # convert the 0-255 budget to [0, 1] pixel scale
                mask_out=None,
                initial_stepsize=1.0,
            )
adv_image = attack.perturb(
image.to(eval_model.device, dtype=eval_model.cast_dtype),
iterations=num_iters,
pert_init=None,
verbose=False,
)
        elif attack_algo == "SAIF":
            # SAIF: sparse attack; k caps the number of perturbed pixels.
            attack = SAIF(
                model=eval_model,
                targeted=targeted,
                img_range=(0, 1),
                steps=num_iters,
                mask_out=None,
                eps=epsilon / 255.0,
                k=sparsity,
                ver=False,
            )
adv_image, _ = attack(
x=image.to(eval_model.device, dtype=eval_model.cast_dtype),
)
else:
raise ValueError(f"Unsupported attack algorithm: {attack_algo}")
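        # Detach from the autograd graph and move to CPU before re-captioning.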
adv_image = adv_image.detach().cpu()
# Generate adversarial caption
adv_caption_output = eval_model.get_outputs(
batch_images=adv_image,
batch_text=[prompt], # Use clean prompt for generation
min_generation_length=0,
max_generation_length=20,
num_beams=3,
length_penalty=-2.0,
)
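        # Trim each raw generation down to caption text, mirroring the
        # postprocessing applied to the clean caption above.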
new_predictions = [
postprocess_captioning_generation(out).replace('"', "") for out in adv_caption_output
]
        # Convert the image tensors to HWC numpy arrays for visualization.
        orig_img_np = image.view(1, 3, 224, 224).squeeze(0).cpu().permute(1, 2, 0).numpy()
        adv_img_np = adv_image.view(1, 3, 224, 224).squeeze(0).cpu().permute(1, 2, 0).numpy()
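        # (A 224x224 input resolution for the ViT-L/14 encoder is assumed by the
        # view() calls above; adjust if a different vision encoder is used.)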
# Calculate perturbation (difference between adversarial and original)
perturbation = adv_img_np - orig_img_np
# Magnify by 10x for visualization
perturbation_magnified = perturbation * 10
# Normalize to [0, 255] for display
orig_img_np = ((orig_img_np - orig_img_np.min()) / (orig_img_np.max() - orig_img_np.min()) * 255).astype(np.uint8)
adv_img_np = ((adv_img_np - adv_img_np.min()) / (adv_img_np.max() - adv_img_np.min()) * 255).astype(np.uint8)
# Normalize perturbation to [0, 255] for visualization
pert_img_np = ((perturbation_magnified - perturbation_magnified.min()) /
(perturbation_magnified.max() - perturbation_magnified.min()) * 255).astype(np.uint8)
        # Save images to temporary files
with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
orig_img_path = f.name
Image.fromarray(orig_img_np).save(orig_img_path)
with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
adv_img_path = f.name
Image.fromarray(adv_img_np).save(adv_img_path)
with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
pert_img_path = f.name
Image.fromarray(pert_img_np).save(pert_img_path)
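        # delete=False keeps the temp files on disk so the caller (e.g. the
        # Gradio app) can read them; nothing here cleans them up afterwards.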
results = {
"original_caption": orig_caption[0],
"adversarial_caption": new_predictions[0],
"original_image_path": orig_img_path, # Return file paths
"adversarial_image_path": adv_img_path,
"perturbation_image_path": pert_img_path
}
return results
except Exception as e:
import traceback
error_msg = f"Error in caption generation: {str(e)}\n{traceback.format_exc()}"
print(error_msg, file=sys.stderr, flush=True)
# Return dict with error information
return {
"original_caption": f"Error: {str(e)}",
"adversarial_caption": "",
"original_image_path": None,
"adversarial_image_path": None,
"perturbation_image_path": None
}
def main():
parser = argparse.ArgumentParser(description="Generate caption for an image")
parser.add_argument("--image_path", type=str, required=True, help="Path to the image")
parser.add_argument("--model", type=str, default="open_flamingo", help="Model to use")
parser.add_argument("--shots", type=int, default=0, help="Number of shots")
parser.add_argument("--epsilon", type=float, default=8.0, help="Epsilon for adversarial attack")
parser.add_argument("--sparsity", type=int, default=0, help="Sparsity for SAIF attack")
parser.add_argument("--attack_algo", type=str, default="APGD", help="Adversarial attack algorithm (APGD or SAIF)")
parser.add_argument("--num_iters", type=int, default=100, help="Number of iterations for adversarial attack")
args = parser.parse_args()
    # Run caption generation and the adversarial attack
    results = generate_caption(
        args.image_path, args.epsilon, args.sparsity,
        args.attack_algo, args.num_iters, args.model, args.shots,
    )
    # generate_caption always returns a dict; image paths of None signal an
    # error (the traceback is printed to stderr inside the function).
    print(results)
    sys.exit(1 if results.get("original_image_path") is None else 0)
if __name__ == "__main__":
main()
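
# Example invocation (the image path is illustrative):
#   python run_caption.py --image_path ./example.jpg --epsilon 8.0 \
#       --attack_algo APGD --num_iters 100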