Jiaqi-hkust committed on
Commit 6f55415 · verified · 1 Parent(s): 157d909

Upload folder using huggingface_hub

Files changed (7)
  1. .gitattributes +1 -0
  2. .gitignore +30 -0
  3. README.md +12 -12
  4. app.py +305 -4
  5. assets/1.jpg +0 -0
  6. assets/2.jpg +3 -0
  7. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/2.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,30 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ venv/
+ ENV/
+ .venv
+
+ # Gradio
+ .gradio_temp/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Model files (if storing locally)
+ *.bin
+ *.safetensors
+ *.pt
+ *.pth
+
README.md CHANGED
@@ -1,12 +1,12 @@
- ---
- title: Robust R1
- emoji: 📚
- colorFrom: gray
- colorTo: blue
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Robust-R1
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.0.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
app.py CHANGED
@@ -1,7 +1,308 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
+ import torch
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ import html
+
+ sys_prompt = """First output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
+ and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
+ then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
+ and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
+ provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""
+
+ project_dir = os.path.dirname(os.path.abspath(__file__))
+
+ is_spaces = os.getenv("SPACE_ID") is not None
+ if not is_spaces:
+     temp_dir = os.path.join(project_dir, ".gradio_temp")
+     os.makedirs(temp_dir, exist_ok=True)
+     os.environ["GRADIO_TEMP_DIR"] = temp_dir
+
+ MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1")
+
+ print(f"==========================================")
+ print(f"Initializing application...")
+ print(f"==========================================")
+
+ class ModelHandler:
+     def __init__(self, model_path):
+         self.model_path = model_path
+         self.model = None
+         self.processor = None
+         self._load_model()
+
+     def _load_model(self):
+         try:
+             print(f"⏳ Loading model weights, this may take a few minutes...")
+
+             self.processor = AutoProcessor.from_pretrained(self.model_path)
+
+             if torch.cuda.is_available():
+                 device_capability = torch.cuda.get_device_capability()
+                 use_flash_attention = device_capability[0] >= 8
+                 print(f"🔧 CUDA available, device capability: {device_capability}")
+             else:
+                 use_flash_attention = False
+                 print(f"🔧 Using CPU or non-CUDA device")
+
+             self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                 self.model_path,
+                 torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+                 device_map="auto",
+                 attn_implementation="flash_attention_2" if use_flash_attention else "eager",
+                 trust_remote_code=True
+             )
+             print("✅ Model loaded successfully!")
+         except Exception as e:
+             print(f"❌ Model loading failed: {e}")
+             raise e
+
+     def predict(self, message_dict, history, temperature, max_tokens):
+         text = message_dict.get("text", "")
+         files = message_dict.get("files", [])
+
+         messages = []
+
+         if history:
+             print(f"Processing {len(history)} previous messages from history")
+             for msg in history:
+                 role = msg.get("role", "")
+                 content = msg.get("content", "")
+
+                 if role == "user":
+                     user_content = []
+
+                     if isinstance(content, list):
+                         for item in content:
+                             if isinstance(item, str):
+                                 if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']):
+                                     user_content.append({"type": "image", "image": item})
+                                 else:
+                                     user_content.append({"type": "text", "text": item})
+                             elif isinstance(item, dict):
+                                 user_content.append(item)
+                     elif isinstance(content, str):
+                         if content:
+                             user_content.append({"type": "text", "text": content})
+
+                     if user_content:
+                         messages.append({"role": "user", "content": user_content})
+
+                 elif role == "assistant":
+                     if isinstance(content, str) and content:
+                         messages.append({"role": "assistant", "content": content})
+
+         current_content = []
+         if files:
+             for file_path in files:
+                 current_content.append({"type": "image", "image": file_path})
+
+         if text:
+             sys_prompt_formatted = " ".join(sys_prompt.split())
+             full_text = f"{text}\n{sys_prompt_formatted}"
+             current_content.append({"type": "text", "text": full_text})
+
+         if current_content:
+             messages.append({"role": "user", "content": current_content})
+
+         print(f"Total messages for model: {len(messages)}")
+         print(f"Message roles: {[m['role'] for m in messages]}")
+
+         text_prompt = self.processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+
+         image_inputs, video_inputs = process_vision_info(messages)
+
+         inputs = self.processor(
+             text=[text_prompt],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt"
+         )
+
+         inputs = inputs.to(self.model.device)
+
+         generation_kwargs = dict(
+             **inputs,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             do_sample=True if temperature > 0 else False,
+         )
+
+         try:
+             print("Starting model generation...")
+             with torch.no_grad():
+                 generated_ids = self.model.generate(**generation_kwargs)
+
+             input_length = inputs['input_ids'].shape[1]
+             generated_ids = generated_ids[0][input_length:]
+
+             print(f"Input length: {input_length}, Generated token count: {len(generated_ids)}")
+
+             generated_text = self.processor.tokenizer.decode(
+                 generated_ids,
+                 skip_special_tokens=True
+             )
+
+             print(f"Generation completed. Output length: {len(generated_text)}, Content preview: {repr(generated_text[:200])}")
+
+             if generated_text and generated_text.strip():
+                 print(f"Yielding generated text: {generated_text[:100]}...")
+                 yield generated_text
+             else:
+                 warning_msg = "⚠️ No output generated. The model may not have produced any response."
+                 print(warning_msg)
+                 yield warning_msg
+
+         except Exception as e:
+             import traceback
+             error_details = traceback.format_exc()
+             print(f"Error in model.generate: {error_details}")
+             yield f"❌ Generation error: {str(e)}"
+             return
+
+ model_handler = None
+
+ def get_model_handler():
+     """Get model handler with lazy loading"""
+     global model_handler
+     if model_handler is None:
+         print("🔄 Initializing model handler...")
+         model_handler = ModelHandler(MODEL_PATH)
+     return model_handler
+
+ def create_chat_ui():
+     custom_css = """
+     .gradio-container { font-family: 'Inter', sans-serif; }
+     #chatbot { height: 650px !important; overflow-y: auto; }
+     """
+
+     with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Robust-R1") as demo:
+
+         with gr.Row():
+             gr.Markdown("# 🤖Robust-R1:Degradation-Aware Reasoning for Robust Visual Understanding")
+
+         with gr.Row():
+             with gr.Column(scale=4):
+                 chatbot = gr.Chatbot(
+                     elem_id="chatbot",
+                     label="Chat",
+                     type="messages",
+                     avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
+                     height=650
+                 )
+
+                 chat_input = gr.MultimodalTextbox(
+                     interactive=True,
+                     file_types=["image"],
+                     placeholder="Enter your question or upload an image...",
+                     show_label=False
+                 )
+
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     gr.Markdown("### ⚙️ Generation Config")
+                     temperature = gr.Slider(
+                         minimum=0.01, maximum=1.0, value=0.6, step=0.05,
+                         label="Temperature"
+                     )
+                     max_tokens = gr.Slider(
+                         minimum=128, maximum=4096, value=1024, step=128,
+                         label="Max New Tokens"
+                     )
+
+                 clear_btn = gr.Button("🗑️ Clear Context", variant="stop")
+
+                 gr.Markdown("---")
+                 gr.Markdown("### 📚 Examples")
+                 gr.Markdown("Click the examples below to quickly fill the input box and start a conversation")
+
+                 example_images_dir = os.path.join(project_dir, "assets")
+
+                 examples_config = [
+                     ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")),
+                     ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", os.path.join(example_images_dir, "2.jpg")),
+                 ]
+
+                 example_data = []
+                 for text, img_path in examples_config:
+                     if os.path.exists(img_path):
+                         example_data.append({"text": text, "files": [img_path]})
+
+                 if example_data:
+                     gr.Examples(
+                         examples=example_data,
+                         inputs=chat_input,
+                         label="",
+                         examples_per_page=3
+                     )
+                 else:
+                     gr.Markdown("*No example images available, please manually upload images for testing*")
+
+         async def respond(user_msg, history, temp, tokens):
+             text = user_msg.get("text", "").strip()
+             files = user_msg.get("files", [])
+             user_content = list(files)
+             if text: user_content.append(text)
+
+             if not files and text: user_message = {"role": "user", "content": text}
+             else: user_message = {"role": "user", "content": user_content}
+
+             history.append(user_message)
+             yield history, gr.MultimodalTextbox(value=None, interactive=False)
+
+             history.append({"role": "assistant", "content": ""})
+
+             try:
+                 previous_history = history[:-2] if len(history) >= 2 else []
+
+                 handler = get_model_handler()
+                 generated_text = ""
+                 for chunk in handler.predict(user_msg, previous_history, temp, tokens):
+                     generated_text = chunk
+
+                     safe_text = generated_text.replace("<", "&lt;").replace(">", "&gt;")
+
+                     history[-1]["content"] = safe_text
+                     yield history, gr.MultimodalTextbox(interactive=False)
+
+             except Exception as e:
+                 import traceback
+                 traceback.print_exc()
+                 history[-1]["content"] = f"❌ Inference error: {str(e)}"
+                 yield history, gr.MultimodalTextbox(interactive=True)
+
+             yield history, gr.MultimodalTextbox(value=None, interactive=True)
+
+         chat_input.submit(
+             respond,
+             inputs=[chat_input, chatbot, temperature, max_tokens],
+             outputs=[chatbot, chat_input]
+         )
+
+         def clear_history(): return [], None
+         clear_btn.click(clear_history, outputs=[chatbot, chat_input])
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_chat_ui()
+
+     if is_spaces:
+         print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
+         demo.launch(
+             show_error=True,
+             allowed_paths=[project_dir] if project_dir else None
+         )
+     else:
+         print(f"🚀 Service is starting, please visit: http://localhost:7860")
+         demo.launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             share=False,
+             show_error=True,
+             allowed_paths=[project_dir]
+         )
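
A minimal sketch, not part of this commit, of how a caller might split the tagged output requested by sys_prompt into sections. The tag names (TYPE, INFLUENCE, REASONING, CONCLUSION, ANSWER) come from sys_prompt in app.py above; the helper name extract_sections is hypothetical.

```python
import re

def extract_sections(generated_text: str) -> dict:
    """Split a Robust-R1 response into the sections requested by sys_prompt."""
    sections = {}
    for tag in ("TYPE", "INFLUENCE", "REASONING", "CONCLUSION", "ANSWER"):
        # Each section is delimited as <TAG> ... <TAG_END>; missing sections map to None.
        match = re.search(rf"<{tag}>(.*?)<{tag}_END>", generated_text, re.DOTALL)
        sections[tag.lower()] = match.group(1).strip() if match else None
    return sections
```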
assets/1.jpg ADDED
assets/2.jpg ADDED

Git LFS Details

  • SHA256: b5a734f22c5bf40c8e70c223ce9196d5ddb6f290650b8370fedfa685333d822c
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio>=4.0.0
+ torch>=2.0.0
+ transformers>=4.37.0
+ qwen-vl-utils
+ accelerate
+ sentencepiece
+ protobuf
+ pillow
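
For reference, a standalone inference sketch assembled from the same calls app.py makes with these dependencies (load the Jiaqi-hkust/Robust-R1 checkpoint, build a multimodal message, generate, decode). It is not part of the commit; the image path and question are placeholders taken from the bundled examples.

```python
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

model_id = "Jiaqi-hkust/Robust-R1"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)

# Placeholder image and question; any (possibly degraded) image plus a VQA-style prompt works the same way.
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "assets/1.jpg"},
        {"type": "text", "text": "What type of vehicles are the people riding?"},
    ],
}]

prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[prompt], images=image_inputs, videos=video_inputs,
                   padding=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=1024)

# Drop the prompt tokens and decode only the newly generated ones.
answer = processor.tokenizer.decode(
    output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(answer)
```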