| import torch |
| import os |
| import sys |
| import traceback |
| import gradio as gr |
| from PIL import Image |
| from transformers import AutoModel, CLIPImageProcessor |
|
|
# Startup banner, printed once whenever this module is executed or imported.
print("=" * 50, "INTERNVIT-6B MODEL LOADING TEST (NO FLASH-ATTN)", "=" * 50, sep="\n")

# Interpreter and framework versions, to correlate failures with the runtime.
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    print("CUDA is not available. This is a critical issue for model loading.")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Memory figures are shown in decimal gigabytes (bytes / 1e9). The total
    # is read from device 0 only; the allocated/reserved queries are called
    # without a device argument.
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
|
|
| |
def _install_flash_attn_stub():
    """Register a placeholder ``flash_attn`` module in ``sys.modules``.

    Any existing ``sys.modules["flash_attn"]`` entry is replaced. The
    placeholder only exposes ``__version__`` ("0.0.0-disabled") and is not a
    package, so ``import flash_attn`` succeeds while imports of its
    submodules still fail.

    Returns:
        The placeholder module object that was registered.
    """
    import importlib.machinery
    import types

    stub = types.ModuleType("flash_attn")
    stub.__version__ = "0.0.0-disabled"
    # A bare ModuleType has ``__spec__ = None``, and importlib.util.find_spec()
    # raises ValueError for such a sys.modules entry. Attach a minimal spec so
    # availability probes built on find_spec() get an answer instead of an
    # exception.
    stub.__spec__ = importlib.machinery.ModuleSpec("flash_attn", loader=None)
    sys.modules["flash_attn"] = stub
    return stub


def load_and_test_model() -> str:
    """Load InternViT-6B, run one forward pass on a synthetic image, and report.

    Steps, in order: register a placeholder ``flash_attn`` module, load
    ``OpenGVLab/InternViT-6B-224px`` in bfloat16, load its
    ``CLIPImageProcessor``, preprocess a solid red 224x224 image, and run the
    model under ``torch.no_grad()``. The model and the input tensor are moved
    to CUDA when ``torch.cuda.is_available()`` is true. Progress messages are
    printed to stdout.

    Returns:
        A "SUCCESS: ..." message containing the parameter count and the shape
        of ``last_hidden_state``, or a "FAILED: ..." message containing the
        exception text. Any ``Exception`` raised along the way is caught and
        its traceback printed rather than propagated.
    """
    try:
        _install_flash_attn_stub()

        print("\nNOTE: Created dummy flash_attn module to avoid dependency error")
        print("This is just for testing basic model loading - some functionality may be disabled")

        print("\nLoading model with bfloat16 precision and low_cpu_mem_usage=True...")
        # NOTE: trust_remote_code=True executes Python code shipped in the
        # model repository.
        model = AutoModel.from_pretrained(
            "OpenGVLab/InternViT-6B-224px",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            print("Moving model to CUDA...")
            model = model.cuda()

        model.eval()
        print("β Model loaded successfully!")

        print("\nLoading image processor...")
        image_processor = CLIPImageProcessor.from_pretrained("OpenGVLab/InternViT-6B-224px")
        print("β Image processor loaded successfully!")

        print("\nCreating test image...")
        test_image = Image.new("RGB", (224, 224), color="red")

        print("Processing test image...")
        pixel_values = image_processor(images=test_image, return_tensors="pt").pixel_values

        # The input is cast to bfloat16 so that it matches the dtype the model
        # was loaded with.
        print("Converting image tensor to bfloat16 to match model dtype...")
        pixel_values = pixel_values.to(torch.bfloat16)

        if use_cuda:
            print("Moving image tensor to CUDA...")
            pixel_values = pixel_values.cuda()

        params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {params:,}")

        print("Running forward pass...")
        with torch.no_grad():
            outputs = model(pixel_values)

        output_shape = outputs.last_hidden_state.shape
        print("β Forward pass successful!")
        print(f"Output shape: {output_shape}")

        return (
            "SUCCESS: Model loaded and test passed!\n"
            f"Parameters: {params:,}\n"
            f"Output shape: {output_shape}"
        )

    except Exception as e:
        # Report the failure to the caller (the UI textbox) instead of raising;
        # the full traceback goes to stderr for debugging.
        print(f"\nβ ERROR: {e}")
        traceback.print_exc()
        return f"FAILED: Error loading model or processing image\nError: {e}"
|
|
| |
def create_interface():
    """Build the Gradio UI for the model loading test.

    Returns:
        The ``gr.Blocks`` app. It contains two Markdown headings and a row
        holding a "Test Model Loading" button and a 10-line "Test Results"
        textbox; clicking the button calls ``load_and_test_model`` and shows
        its return value in the textbox.
    """
    with gr.Blocks(title="InternViT-6B Test") as app:
        gr.Markdown("# InternViT-6B Model Loading Test (without Flash Attention)")
        gr.Markdown("### This version uses a dummy flash-attn implementation to avoid compilation issues")

        with gr.Row():
            run_button = gr.Button("Test Model Loading")
            results_box = gr.Textbox(label="Test Results", lines=10)

        # The test takes no inputs; its status string fills the textbox.
        run_button.click(fn=load_and_test_model, inputs=[], outputs=results_box)

    return app
|
|
| |
if __name__ == "__main__":
    # Dump the environment settings that influence GPU visibility, model
    # caching and the CUDA allocator, as they were before this script touches
    # them.
    print("\nEnvironment variables:")
    for env_name in (
        "CUDA_VISIBLE_DEVICES",
        "NVIDIA_VISIBLE_DEVICES",
        "TRANSFORMERS_CACHE",
        "HF_HOME",
        "PYTORCH_CUDA_ALLOC_CONF",
    ):
        print(f"{env_name}: {os.getenv(env_name, 'Not set')}")

    # NOTE(review): this is assigned after torch has been imported and after
    # the module-level CUDA queries have run; confirm the allocator still
    # picks the setting up at this point.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

    # NOTE(review): server_name="0.0.0.0" listens on all network interfaces;
    # confirm that exposure is intended.
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", share=False)