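"""Gradio demo: extract text from images with the GOT-OCR 2.0 model.

Uploaded images are saved to a temporary folder, run through the
stepfun-ai/GOT-OCR-2.0-hf processor and model, and the decoded text is
returned to the UI.
"""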
import os
import uuid
import shutil
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.image_utils import load_image
# Load the GOT-OCR 2.0 processor and model once at startup
model_name = "stepfun-ai/GOT-OCR-2.0-hf"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(
    model_name, low_cpu_mem_usage=True, device_map=device
)
model = model.eval().to(device)

# Temporary folder for uploads and the model's end-of-generation marker
UPLOAD_FOLDER = "./uploads"
stop_str = "<|im_end|>"

if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)
@spaces.GPU()
def process_ocr(image):
    """Run GOT-OCR 2.0 on an uploaded image and return the extracted text."""
    if image is None:
        return "⚠️ Please upload an image first"

    # Save the upload under a unique name so concurrent requests don't collide
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")

    try:
        # Handle different image formats
        if isinstance(image, np.ndarray):
            # Gradio provides RGB arrays; OpenCV writes BGR to disk
            cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        elif isinstance(image, str):
            # Already a file path (gr.Image with type="filepath")
            shutil.copy(image, image_path)
        else:
            return "⚠️ Unsupported image format"

        image = load_image(image_path)

        # Process with OCR
        inputs = processor([image], return_tensors="pt").to(device)
        generate_ids = model.generate(
            **inputs,
            do_sample=False,
            tokenizer=processor.tokenizer,
            stop_strings=stop_str,
            max_new_tokens=4096,
        )
        # Decode only the newly generated tokens, skipping the prompt
        result = processor.decode(
            generate_ids[0, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        return result
    except Exception as e:
        return f"❌ Error: {str(e)}"
    finally:
        # Always clean up the temporary copy
        if os.path.exists(image_path):
            os.remove(image_path)
# Custom CSS for modern, minimal design
custom_css = """
#header {
text-align: center;
padding: 2rem 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 12px;
margin-bottom: 2rem;
}
#header h1 {
margin: 0;
font-size: 2.5rem;
font-weight: 700;
letter-spacing: -0.5px;
}
#header p {
margin: 0.5rem 0 0 0;
font-size: 1.1rem;
opacity: 0.95;
}
.main-container {
max-width: 1200px;
margin: 0 auto;
}
#image_input {
border: 2px dashed #667eea !important;
border-radius: 12px !important;
transition: all 0.3s ease;
}
#image_input:hover {
border-color: #764ba2 !important;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.15);
}
#process_btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
font-size: 1.1rem !important;
font-weight: 600 !important;
padding: 0.75rem 2rem !important;
border-radius: 8px !important;
transition: all 0.3s ease !important;
}
#process_btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.3) !important;
}
#output_text {
border-radius: 12px !important;
font-family: 'Monaco', 'Courier New', monospace !important;
font-size: 0.95rem !important;
line-height: 1.6 !important;
}
.input-section, .output-section {
background: white;
padding: 1.5rem;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
footer {
text-align: center;
padding: 2rem 0;
color: #666;
font-size: 0.9rem;
}
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div id="header">
            <h1>✨ GOT-OCR 2.0</h1>
            <p>Extract text from images with AI-powered OCR</p>
        </div>
    """)

    with gr.Row(elem_classes="main-container"):
        with gr.Column(scale=1, elem_classes="input-section"):
            image_input = gr.Image(
                type="filepath",
                label="📸 Upload Image",
                elem_id="image_input",
                height=400,
            )
            process_btn = gr.Button(
                "🚀 Extract Text",
                elem_id="process_btn",
                size="lg",
            )
        with gr.Column(scale=1, elem_classes="output-section"):
            output_text = gr.Textbox(
                label="📝 Extracted Text",
                elem_id="output_text",
                lines=20,
                placeholder="Your extracted text will appear here...",
                show_copy_button=True,
            )

    gr.HTML("""
        <footer>
            <p>Powered by GOT-OCR-2.0-hf | Built with Gradio</p>
        </footer>
    """)

    # Connect the button to the processing function
    process_btn.click(
        fn=process_ocr,
        inputs=[image_input],
        outputs=[output_text],
    )
if __name__ == "__main__":
    demo.launch()
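    # Note: demo.launch(share=True) would also create a temporary public
    # link, which can help when testing outside the Hugging Face Space.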