Spaces:
Runtime error
Runtime error
| """ | |
| Script to generate captions for images using the VLM model. | |
| This script runs in the RobustMMFMEnv conda environment. | |
| """ | |
| import argparse | |
| import sys | |
| import os | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # Add the parent directory to the path to import vlm_eval modules | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) | |
# Checkpoint used when the caller does not supply ``checkpoint_path``.
_DEFAULT_CHECKPOINT_PATH = (
    "/home/kc/.cache/huggingface/hub/models--openflamingo--OpenFlamingo-4B-vitl-rpj3b"
    "/snapshots/df8d3f7e75bcf891ce2fbf5253a12f524692d9c2/checkpoint.pt"
)

# Attack names accepted by ``generate_caption``.
_SUPPORTED_ATTACKS = ("APGD", "SAIF")


def _normalize_to_uint8(array):
    """Min-max scale *array* to the 0..255 range and return it as uint8."""
    import numpy as np

    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # A constant array (e.g. an all-zero perturbation) would otherwise
        # divide by zero and cast NaN to uint8, which is undefined.
        return np.zeros(array.shape, dtype=np.uint8)
    return ((array - lowest) / span * 255).astype(np.uint8)


def _save_temp_png(pixels):
    """Write the uint8 HxWx3 array *pixels* to a new temporary PNG; return its path."""
    import tempfile
    from PIL import Image

    # delete=False: the caller receives the path and owns the file's lifetime.
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".png", delete=False) as handle:
        # Write through the open handle instead of re-opening the path, which
        # is not possible on every platform while the handle is still open.
        Image.fromarray(pixels).save(handle, format="PNG")
    return handle.name


def generate_caption(image_path, epsilon, sparsity, attack_algo, num_iters,
                     model_name="open_flamingo", num_shots=0, targeted=False,
                     target_str="a dog", checkpoint_path=None):
    """
    Caption an image, attack it adversarially, and caption the attacked image.

    Args:
        image_path: Path to the image file.
        epsilon: Perturbation budget on the 0-255 scale (divided by 255
            before being handed to the attack).
        sparsity: Passed to SAIF as ``k``; not used by APGD.
        attack_algo: "APGD" or "SAIF".
        num_iters: Number of attack iterations/steps.
        model_name: Currently unused; kept for interface compatibility.
        num_shots: Currently unused; kept for interface compatibility.
        targeted: When True, attack towards ``target_str`` instead of away
            from the model's original caption.
        target_str: Target caption used when ``targeted`` is True.
        checkpoint_path: OpenFlamingo checkpoint file; defaults to
            ``_DEFAULT_CHECKPOINT_PATH``.
    Returns:
        dict: Keys "original_caption", "adversarial_caption",
        "original_image_path", "adversarial_image_path" and
        "perturbation_image_path". On any failure the error is printed to
        stderr, "original_caption" holds "Error: <message>",
        "adversarial_caption" is empty and the three paths are None.
    """
    try:
        # Reject a bad algorithm name before the expensive model load.
        if attack_algo not in _SUPPORTED_ATTACKS:
            raise ValueError(f"Unsupported attack algorithm: {attack_algo}")

        # Heavy / project-local dependencies are imported lazily so an
        # import failure is reported through the error dict below.
        import torch
        from PIL import Image
        from open_flamingo.eval.models.of_eval_model_adv import EvalModelAdv
        from open_flamingo.eval.coco_metric import postprocess_captioning_generation
        from vlm_eval.attacks.apgd import APGD
        from vlm_eval.attacks.saif import SAIF

        model_args = {
            "lm_path": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
            "lm_tokenizer_path": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
            "vision_encoder_path": "ViT-L-14",
            "vision_encoder_pretrained": "openai",
            "checkpoint_path": (
                _DEFAULT_CHECKPOINT_PATH if checkpoint_path is None else checkpoint_path
            ),
            "cross_attn_every_n_layers": "2",
            "precision": "float16",
        }
        eval_model = EvalModelAdv(model_args, adversarial=True)
        eval_model.set_device(0 if torch.cuda.is_available() else -1)

        image = Image.open(image_path).convert("RGB")
        image = eval_model._prepare_images([[image]])
        prompt = eval_model.get_caption_prompt()

        # Same decoding settings for the clean and the adversarial caption.
        generation_kwargs = {
            "min_generation_length": 0,
            "max_generation_length": 20,
            "num_beams": 3,
            "length_penalty": -2.0,
        }

        # NOTE(review): the original caption is returned raw, while the
        # adversarial one below is post-processed -- confirm this is intended.
        orig_caption = eval_model.get_outputs(
            batch_images=image,
            batch_text=[prompt],
            **generation_kwargs,
        )

        # Caption text the attack optimises against: the requested target for
        # a targeted attack, otherwise the model's own clean caption.
        adv_caption = target_str if targeted else orig_caption[0]
        prompt_adv = eval_model.get_caption_prompt(adv_caption)

        # Must run before the attack; presumably this fixes the text the
        # model scores images against when called -- TODO confirm.
        eval_model.set_inputs(
            batch_text=[prompt_adv],
            past_key_values=None,
            to_device=True,
        )

        attack_input = image.to(eval_model.device, dtype=eval_model.cast_dtype)
        if attack_algo == "APGD":
            attack = APGD(
                # Negate the model output for a targeted attack.
                eval_model if not targeted else lambda x: -eval_model(x),
                norm="linf",
                eps=epsilon / 255.0,
                mask_out=None,
                initial_stepsize=1.0,
            )
            adv_image = attack.perturb(
                attack_input,
                iterations=num_iters,
                pert_init=None,
                verbose=False,
            )
        else:  # "SAIF" -- the name was validated above
            attack = SAIF(
                model=eval_model,
                targeted=targeted,
                img_range=(0, 1),
                steps=num_iters,
                mask_out=None,
                eps=epsilon / 255.0,
                k=sparsity,
                ver=False,
            )
            adv_image, _ = attack(x=attack_input)
        adv_image = adv_image.detach().cpu()

        # Caption the attacked image with the clean prompt.
        adv_caption_output = eval_model.get_outputs(
            batch_images=adv_image,
            batch_text=[prompt],
            **generation_kwargs,
        )
        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "")
            for out in adv_caption_output
        ]

        # Collapse to a single 3x224x224 image and move channels last.
        orig_img_np = image.view(1, 3, 224, 224).squeeze(0).cpu().permute(1, 2, 0).numpy()
        adv_img_np = adv_image.view(1, 3, 224, 224).squeeze(0).cpu().permute(1, 2, 0).numpy()

        # Difference between adversarial and original, magnified 10x. The
        # min-max scaling below is what makes it visible.
        perturbation_magnified = (adv_img_np - orig_img_np) * 10

        orig_img_path = _save_temp_png(_normalize_to_uint8(orig_img_np))
        adv_img_path = _save_temp_png(_normalize_to_uint8(adv_img_np))
        pert_img_path = _save_temp_png(_normalize_to_uint8(perturbation_magnified))

        return {
            "original_caption": orig_caption[0],
            "adversarial_caption": new_predictions[0],
            "original_image_path": orig_img_path,
            "adversarial_image_path": adv_img_path,
            "perturbation_image_path": pert_img_path,
        }
    except Exception as e:
        import traceback
        error_msg = f"Error in caption generation: {str(e)}\n{traceback.format_exc()}"
        print(error_msg, file=sys.stderr, flush=True)
        # Same keys as a successful result so callers can read it uniformly.
        return {
            "original_caption": f"Error: {str(e)}",
            "adversarial_caption": "",
            "original_image_path": None,
            "adversarial_image_path": None,
            "perturbation_image_path": None,
        }
def main():
    """
    Command-line entry point: parse arguments and run ``generate_caption``.

    The result dict is printed to stdout. The process exits with status 0 on
    success and 1 on failure. ``generate_caption`` reports failure by
    returning a dict whose "original_image_path" is None, so that dict is
    still printed before exiting with status 1.
    """
    parser = argparse.ArgumentParser(description="Generate caption for an image")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the image")
    parser.add_argument("--model", type=str, default="open_flamingo", help="Model to use")
    parser.add_argument("--shots", type=int, default=0, help="Number of shots")
    parser.add_argument("--epsilon", type=float, default=8.0, help="Epsilon for adversarial attack")
    parser.add_argument("--sparsity", type=int, default=0, help="Sparsity for SAIF attack")
    parser.add_argument("--attack_algo", type=str, default="APGD", help="Adversarial attack algorithm (APGD or SAIF)")
    parser.add_argument("--num_iters", type=int, default=100, help="Number of iterations for adversarial attack")
    args = parser.parse_args()

    caption = generate_caption(
        args.image_path,
        args.epsilon,
        args.sparsity,
        args.attack_algo,
        args.num_iters,
        args.model,
        args.shots,
    )

    if not caption:
        print("Failed to generate caption", file=sys.stderr)
        sys.exit(1)

    # Print the dict even on failure so a consumer reading stdout still
    # receives the "Error: ..." caption.
    print(caption)

    # The error dict is non-empty (truthy), so failure must be detected from
    # its contents; otherwise every failure would exit with status 0.
    if caption.get("original_image_path") is None:
        print("Failed to generate caption", file=sys.stderr)
        sys.exit(1)
    sys.exit(0)


if __name__ == "__main__":
    main()