Corin1998 committed on
Commit
d229d04
·
verified ·
1 Parent(s): 5621a82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -20
app.py CHANGED
@@ -2,28 +2,427 @@ import os
2
  import io
3
  import time
4
  import sys
 
 
 
5
  import gradio as gr
6
 
7
- # --- Ensure we can import modules whether it's a package or a plain folder ---
8
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9
- MODULES_DIR = os.path.join(BASE_DIR, "modules")
10
- if os.path.isdir(MODULES_DIR) and MODULES_DIR not in sys.path:
11
- sys.path.insert(0, MODULES_DIR)
 
 
 
 
 
 
12
 
 
13
  try:
14
- # Prefer package-style if modules/__init__.py exists
15
- from modules.text_processing import process_text
16
- from modules.pptx_builder import build_presentation
17
- from modules.utils import safe_hex_to_rgb, ensure_tmpdir
18
- except ModuleNotFoundError:
19
- # Fallback: flat imports from ./modules added to sys.path
20
- from text_processing import process_text
21
- from pptx_builder import build_presentation
22
- from utils import safe_hex_to_rgb, ensure_tmpdir
23
 
24
  APP_NAME = "Auto-PPT Generator"
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def generate_pptx(long_text: str,
28
  title: str,
29
  theme_hex: str,
@@ -52,7 +451,6 @@ def generate_pptx(long_text: str,
52
  except Exception:
53
  logo_bytes = None
54
 
55
- # Step 1–3: NLP pipeline (summary, sections, bullets, tables, chart data)
56
  result = process_text(
57
  text=long_text,
58
  use_inference_api=use_inference_api,
@@ -64,7 +462,6 @@ def generate_pptx(long_text: str,
64
  max_summary_words=max_summary_words,
65
  )
66
 
67
- # Step 4: Build PPTX
68
  ensure_tmpdir()
69
  timestamp = time.strftime('%Y%m%d-%H%M%S')
70
  out_path = f"/tmp/auto_ppt_{timestamp}.pptx"
@@ -80,11 +477,8 @@ def generate_pptx(long_text: str,
80
  tables=result.get("tables", []),
81
  charts=result.get("charts", []),
82
  )
83
-
84
- # Return file path for download
85
  return out_path
86
 
87
-
88
  def ui():
89
  with gr.Blocks(title=APP_NAME) as demo:
90
  gr.Markdown(f"# {APP_NAME}\n長文→要約→セクション分割→箇条書き/表/図→**PPTX出力** まで自動化")
@@ -122,7 +516,6 @@ def ui():
122
  """)
123
  return demo
124
 
125
-
126
  if __name__ == "__main__":
127
  demo = ui()
128
  # Spaces は自動でバインドされますが、ローカル互換のため指定可能
 
2
  import io
3
  import time
4
  import sys
5
+ import re
6
+ from typing import Optional, List, Tuple, Dict, Any
7
+
8
  import gradio as gr
9
 
10
+ # 安全のため、GUI不要の描画バックエンドを指定
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pyplot as plt
14
+
15
+ from pptx import Presentation
16
+ from pptx.util import Inches, Pt
17
+ from pptx.enum.text import PP_ALIGN
18
+ from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE
19
+ from pptx.dml.color import RGBColor
20
+ from PIL import Image
21
 
22
+ # transformers は任意(未インストールでも動作させる)
23
  try:
24
+ from transformers import pipeline
25
+ except Exception:
26
+ pipeline = None
27
+
28
+ import requests # Inference API を使う場合にのみ利用
 
 
 
 
29
 
30
  APP_NAME = "Auto-PPT Generator"
31
 
32
+ # =========================
33
+ # utils
34
+ # =========================
35
def safe_hex_to_rgb(hex_color: str):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints.

    Accepts ``"#RRGGBB"`` and the CSS shorthand ``"#RGB"``; the leading
    ``#`` is optional and surrounding whitespace is ignored.

    Returns:
        The parsed colour, or the default blue ``(59, 130, 246)`` when the
        input is empty, not a string, or not a valid hex colour.
    """
    default_rgb = (59, 130, 246)  # default blue
    if not hex_color or not isinstance(hex_color, str):
        return default_rgb

    digits = hex_color.strip()
    if digits.startswith("#"):
        digits = digits[1:]

    # Expand the 3-digit shorthand ("fa0" -> "ffaa00") before validating.
    if re.fullmatch(r"[0-9A-Fa-f]{3}", digits):
        digits = "".join(ch * 2 for ch in digits)

    if not re.fullmatch(r"[0-9A-Fa-f]{6}", digits):
        return default_rgb
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
47
+
48
def ensure_tmpdir():
    """Create the ``/tmp`` directory if it is missing; do nothing otherwise."""
    tmp_root = "/tmp"
    os.makedirs(tmp_root, exist_ok=True)
50
+
51
+ # =========================
52
+ # LLM client (local / HF API)
53
+ # =========================
54
class LLMClient:
    """Small text client with three tiers: HF Inference API, local pipeline, rules.

    ``summarize`` tries the Hugging Face Inference API (when enabled), then a
    local ``transformers`` pipeline (when the module-level ``pipeline`` import
    succeeded), and finally a rule-based sentence picker.  Failures of the
    first two tiers are best-effort: they are logged and the next tier is
    tried, so callers always get a string back.
    """

    def __init__(self, use_inference_api: bool = False):
        """Store the API preference and read ``HF_TOKEN`` from the environment."""
        self.use_inference_api = use_inference_api
        self.hf_token = os.getenv("HF_TOKEN", None)
        # Cache of local pipelines keyed by (task, model).
        self._local_pipes = {}

    @staticmethod
    def _warn(message: str, *args) -> None:
        """Log a swallowed failure (with traceback) instead of dropping it silently."""
        import logging  # local import keeps the class free of new file-level deps

        logging.getLogger(__name__).warning(message, *args, exc_info=True)

    # ---------- Inference API helpers ----------
    def _hf_headers(self):
        """Return the Authorization header; raise RuntimeError if no token is set."""
        if not self.hf_token:
            raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
        return {"Authorization": f"Bearer {self.hf_token}"}

    def _hf_textgen(self, model: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.3) -> str:
        """POST ``prompt`` to the Inference API for ``model`` and return the text.

        Raises whatever ``requests`` raises on transport or HTTP errors
        (``raise_for_status`` is called).  If the JSON answer carries neither
        ``generated_text`` nor ``summary_text``, its ``str()`` is returned.
        """
        url = f"https://api-inference.huggingface.co/models/{model}"
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "return_full_text": False,
            },
        }
        r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
        r.raise_for_status()
        data = r.json()
        # The answer is either a dict or a list whose first item is a dict.
        first = data[0] if isinstance(data, list) and data else data
        if isinstance(first, dict):
            # Summarisation models answer with 'summary_text' instead.
            for key in ("generated_text", "summary_text"):
                if key in first:
                    return first[key]
        return str(data)

    def _get_local_pipe(self, task: str, model: str):
        """Return a cached local pipeline for ``(task, model)``, creating it once.

        Raises RuntimeError when ``transformers`` could not be imported.
        """
        key = (task, model)
        if key in self._local_pipes:
            return self._local_pipes[key]
        if pipeline is None:
            raise RuntimeError("transformers is not available")
        pipe = pipeline(task=task, model=model)
        self._local_pipes[key] = pipe
        return pipe

    @staticmethod
    def _fallback_summary(text: str, max_words: int) -> str:
        """Rule-based summary: leading sentences until ``max_words * 6`` chars are exceeded."""
        picked = []
        for sentence in re.split(r"[。\.!?]\s*", text):
            sentence = sentence.strip()
            if sentence:
                picked.append(sentence)
            if len(" ".join(picked)) > max_words * 6:
                break
        return "。".join(picked)

    # ---------- Public ----------
    def summarize(self, text: str, model: str, max_words: int = 200) -> str:
        """Summarise ``text`` (only its first 6000 characters reach a model).

        Args:
            text: Source text; ``None`` is treated as empty.
            model: Model id for the API / local pipeline; when empty, only the
                rule-based fallback runs.
            max_words: Rough size budget; models get ``max_words * 2`` as
                their length limit.

        Returns:
            The summary string (possibly empty).  Never raises for model
            failures; those are logged and the next tier is used.
        """
        text = text or ""

        # Inference API first.
        if self.use_inference_api and model:
            try:
                return self._hf_textgen(model, text[:6000], max_new_tokens=max_words * 2).strip()
            except Exception:
                self._warn("Inference API summarisation with %r failed; trying next tier", model)

        # Then a local transformers pipeline.
        if pipeline is not None and model:
            try:
                if "t5" in model.lower():
                    pipe = self._get_local_pipe("text2text-generation", model)
                    prompt = f"要約: {text[:6000]}"
                    res = pipe(prompt, max_length=max_words * 2, do_sample=False)
                    return res[0]["generated_text"].strip()
                pipe = self._get_local_pipe("summarization", model)
                res = pipe(text[:6000], max_length=max_words * 2, min_length=max_words // 2, do_sample=False)
                return res[0]["summary_text"].strip()
            except Exception:
                self._warn("Local summarisation with %r failed; using rule-based fallback", model)

        # Last resort: just line up the leading sentences.
        return self._fallback_summary(text, max_words)

    def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
        """Generate text through the Inference API; return ``""`` when unavailable.

        An empty string is returned when the API is disabled, no model is
        given, or the request fails (the failure is logged).
        """
        if self.use_inference_api and model:
            try:
                return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
            except Exception:
                self._warn("Inference API generation with %r failed", model)
                return ""
        return ""  # this implementation otherwise relies on the rule-based helpers
140
+
141
+
142
+ # =========================
143
+ # text processing
144
+ # =========================
145
# Line-level patterns shared by the rule-based text-processing helpers.
# NOTE(review): the colon character classes below list ':' twice; presumably
# one of them was meant to be a full-width colon — confirm against the repo.

# Heading line: leading '#'s, "N." or "N)", then the title text (group 2).
HEADER = re.compile(r"^(#+|\d+\.|\d+\))\s*(.+)$")

# List item: '-', '*', '•', '・', "N." or "N)" followed by whitespace; group 1 is the item text.
LIST_BULLET = re.compile(r"^(?:[-*•・]|\d+\.|\d+\))\s+(.*)")

# "label: number" line; group 2 is a signed integer or decimal.
LABEL_NUM = re.compile(r"^\s*([^::]+?)\s*[::]\s*([+-]?\d+(?:\.\d+)?)\s*$")

# "key: value" line; group 1 is the key, group 2 the value.
KEYVAL_LINE = re.compile(r"^\s*([^::]+?)\s*[::]\s*([^\n]+?)\s*$")
149
+
150
def naive_section_split(text: str, target_chars: int = 1200) -> List[Tuple[str, str]]:
    """Split ``text`` into ``(title, content)`` sections by headings or by size.

    A line matching ``HEADER`` closes the current section and supplies the
    title of the next one.  Independently, a section is closed as soon as its
    buffered lines exceed ``target_chars`` characters; the section that
    follows is titled ``セクション<N>``.

    Chunks whose content is empty after stripping (for example blank lines in
    front of a heading) are dropped rather than emitted as empty sections.

    Returns:
        The sections in document order; ``[("本文", text)]`` when nothing
        else was produced.
    """
    sections: List[Tuple[str, str]] = []
    title = "セクション"
    buffered: List[str] = []
    buffered_chars = 0  # running total, so the size check is O(1) per line

    def flush() -> None:
        """Emit the buffered lines as a section (if non-blank) and reset the buffer."""
        nonlocal buffered, buffered_chars
        content = "\n".join(buffered).strip()
        if content:
            sections.append((title, content))
        buffered = []
        buffered_chars = 0

    for line in text.splitlines():
        heading = HEADER.match(line.strip())
        if heading:
            flush()
            title = heading.group(2).strip()
            continue
        buffered.append(line)
        buffered_chars += len(line)
        if buffered_chars > target_chars:
            flush()
            title = f"セクション{len(sections) + 1}"
    flush()

    if not sections:
        sections = [("本文", text)]
    return sections
178
+
179
def extract_bullets(section_text: str, max_items: int = 8) -> List[str]:
    """Return up to ``max_items`` bullet strings for a section.

    Lines matching ``LIST_BULLET`` are used when there are any.  Otherwise the
    text is split into sentences and those between 8 and 120 characters long
    are taken in order.
    """
    matches = (LIST_BULLET.match(raw.strip()) for raw in section_text.splitlines())
    marked = [hit.group(1).strip() for hit in matches if hit]
    if marked:
        return marked[:max_items]

    # No explicit list items: fall back to medium-length sentences.
    picked: List[str] = []
    for sentence in re.split(r"[。\.!?]\s*", section_text):
        sentence = sentence.strip()
        if not 8 <= len(sentence) <= 120:
            continue
        picked.append(sentence)
        if len(picked) >= max_items:
            break
    return picked[:max_items]
194
+
195
def extract_keyval_table(section_text: str) -> List[Tuple[str, str]]:
    """Collect ``(key, value)`` pairs from lines matching ``KEYVAL_LINE``.

    Pairs are returned in line order; a pair is skipped when either side is
    empty after stripping.
    """
    rows: List[Tuple[str, str]] = []
    for raw in section_text.splitlines():
        hit = KEYVAL_LINE.match(raw)
        if not hit:
            continue
        key = hit.group(1).strip()
        value = hit.group(2).strip()
        if key and value:
            rows.append((key, value))
    return rows
205
+
206
def extract_chart_data(section_text: str, top_k: int = 10) -> List[Tuple[str, float]]:
    """Extract ``(label, number)`` pairs from lines matching ``LABEL_NUM``.

    When a label occurs more than once its last value wins.  The result is
    ordered by absolute value, largest first, and cut to ``top_k`` items.
    """
    latest = {}
    for raw in section_text.splitlines():
        hit = LABEL_NUM.match(raw)
        if not hit:
            continue
        try:
            number = float(hit.group(2))
        except ValueError:
            continue
        # Re-assigning keeps the label's first position but its newest value.
        latest[hit.group(1).strip()] = number

    ranked = sorted(latest.items(), key=lambda pair: abs(pair[1]), reverse=True)
    return ranked[:top_k]
223
 
224
def process_text(text: str,
                 use_inference_api: bool,
                 summarizer_model: str,
                 generator_model: str,
                 want_summary: bool,
                 want_tables: bool,
                 want_charts: bool,
                 max_summary_words: int = 200) -> Dict[str, Any]:
    """Run the text pipeline and return everything the slide builder needs.

    Returns a dict with:
        ``summary``  - summary string, or ``None`` when ``want_summary`` is false
        ``sections`` - list of ``(title, body)`` from ``naive_section_split``
        ``bullets``  - dict mapping section index to its bullet list
        ``tables``   - list of ``{"title", "pairs"}`` dicts (only if ``want_tables``)
        ``charts``   - list of ``{"title", "series"}`` dicts (only if ``want_charts``)

    ``generator_model`` is accepted but not used by this function.
    """
    llm = LLMClient(use_inference_api=use_inference_api)

    if want_summary:
        summary = llm.summarize(text, model=summarizer_model, max_words=max_summary_words)
    else:
        summary = None

    sections = naive_section_split(text)

    bullets_by_section: Dict[int, List[str]] = {}
    tables: List[Dict[str, Any]] = []
    charts: List[Dict[str, Any]] = []

    for position, section in enumerate(sections):
        heading, body = section
        bullets_by_section[position] = extract_bullets(body)

        pairs = extract_keyval_table(body) if want_tables else None
        if pairs:
            tables.append({"title": f"{heading} — 表", "pairs": pairs})

        series = extract_chart_data(body) if want_charts else None
        if series:
            charts.append({"title": f"{heading} — チャート", "series": series})

    return {
        "summary": summary,
        "sections": sections,
        "bullets": bullets_by_section,
        "tables": tables,
        "charts": charts,
    }
264
+
265
+ # =========================
266
+ # pptx builder
267
+ # =========================
268
def _add_logo(prs: Presentation, slide, logo_bytes: Optional[bytes]):
    """Place the logo near the top-right corner of ``slide``.

    The picture is scaled to fit a 2.0in x 1.0in box while keeping its aspect
    ratio, and is right-aligned 0.5in from the slide edge.  Nothing happens
    when ``logo_bytes`` is empty.

    The display size is given to ``add_picture`` in EMU.  Slide lengths
    (``Inches(...)``) are EMU values while ``img.size`` is in pixels, so the
    two must not be mixed when resizing the bitmap: doing so asks Pillow for
    an image millions of pixels wide.
    """
    if not logo_bytes:
        return
    img = Image.open(io.BytesIO(logo_bytes)).convert("RGBA")

    # Shrink (never enlarge) oversized bitmaps: 600x300 px is 300 dpi for the
    # 2in x 1in box, which keeps the saved file small.
    img.thumbnail((600, 300))
    px_w, px_h = img.size
    if px_w < 1 or px_h < 1:
        return

    max_w, max_h = Inches(2.0), Inches(1.0)
    # EMU per pixel that makes the picture fit inside the box.
    emu_per_px = min(max_w / px_w, max_h / px_h)
    pic_w = max(1, int(px_w * emu_per_px))
    pic_h = max(1, int(px_h * emu_per_px))

    png = io.BytesIO()
    img.save(png, format="PNG")
    png.seek(0)

    left = prs.slide_width - pic_w - Inches(0.5)
    top = Inches(0.2)
    slide.shapes.add_picture(png, left, top, width=pic_w, height=pic_h)
283
+
284
def _apply_theme_bg(slide, rgb):
    """Give ``slide`` a solid background fill of colour ``rgb`` (an r, g, b sequence)."""
    background_fill = slide.background.fill
    background_fill.solid()
    theme_color = RGBColor(*rgb)
    background_fill.fore_color.rgb = theme_color
288
+
289
def _title_slide(prs, title_text: str, theme_rgb, logo_bytes):
    """Add the opening slide: themed background, white card, title, subtitle, logo.

    The slide uses layout 0.  Its title and subtitle placeholders are moved
    onto a white rounded rectangle ("card") drawn over the themed background.
    The card is sent to the back of the z-order; otherwise, being added after
    the placeholders, it would be drawn on top of them and hide the text.
    """
    slide = prs.slides.add_slide(prs.slide_layouts[0])
    title = slide.shapes.title
    subtitle = slide.placeholders[1]
    title.text = title_text
    subtitle.text = "自動生成プレゼンテーション"
    _apply_theme_bg(slide, theme_rgb)

    # White card behind the title block.
    left = Inches(0.6)
    top = Inches(1.8)
    width = prs.slide_width - Inches(1.2)
    height = Inches(2.2)
    box = slide.shapes.add_shape(MSO_AUTO_SHAPE_TYPE.ROUNDED_RECTANGLE, left, top, width, height)
    box.fill.solid()
    box.fill.fore_color.rgb = RGBColor(255, 255, 255)
    box.line.color.rgb = RGBColor(0, 0, 0)
    # NOTE(review): python-pptx's LineFormat does not appear to expose a
    # `transparency` property, so this assignment probably has no visible
    # effect — confirm before relying on it.
    box.line.transparency = 0.8

    # Shapes are drawn in document order, so move the card in front of every
    # other shape element.  Index 2 skips the two leading non-shape children
    # of the shape tree (nvGrpSpPr, grpSpPr).
    sp_tree = slide.shapes._spTree
    sp_tree.remove(box._element)
    sp_tree.insert(2, box._element)

    # Title: upper part of the card.
    title.left = left + Inches(0.3)
    title.top = top + Inches(0.3)
    title.width = width - Inches(0.6)
    title.height = Inches(1.4)
    for p in title.text_frame.paragraphs:
        p.font.size = Pt(40)
        p.font.bold = True

    # Subtitle: lower part of the card.
    subtitle.left = left + Inches(0.3)
    subtitle.top = top + Inches(1.6)
    subtitle.width = width - Inches(0.6)
    subtitle.height = Inches(0.8)
    for p in subtitle.text_frame.paragraphs:
        p.font.size = Pt(16)
        p.font.bold = False

    _add_logo(prs, slide, logo_bytes)
321
+
322
+ def _summary_slide(prs, summary: str):
323
+ if not summary:
324
+ return
325
+ slide = prs.slides.add_slide(prs.slide_layouts[1])
326
+ slide.shapes.title.text = "エグゼクティブサマリー"
327
+ tf = slide.placeholders[1].text_frame
328
+ tf.clear()
329
+ lines = [ln.strip() for ln in summary.splitlines() if ln.strip()]
330
+ if not lines:
331
+ lines = [summary]
332
+ for i, ln in enumerate(lines):
333
+ p = tf.add_paragraph() if i > 0 else tf.paragraphs[0]
334
+ p.text = ln
335
+ p.level = 0
336
+
337
def _section_slide(prs, title: str, bullets: List[str]):
    """Add a section slide (layout 1) with up to 12 level-0 bullets.

    The title is cut to 90 characters.  A placeholder bullet is shown when
    ``bullets`` is empty.
    """
    slide = prs.slides.add_slide(prs.slide_layouts[1])
    slide.shapes.title.text = title[:90]
    body = slide.placeholders[1].text_frame
    body.clear()

    shown = (bullets or ["(要点なし)"])[:12]
    for position, bullet_text in enumerate(shown):
        # clear() leaves one paragraph behind; reuse it for the first bullet.
        if position == 0:
            para = body.paragraphs[0]
        else:
            para = body.add_paragraph()
        para.text = bullet_text
        para.level = 0
348
+
349
def _table_slide(prs, title: str, pairs: List[tuple]):
    """Add a slide (layout 5) holding a two-column table of ``pairs``.

    Row 0 is the header; every pair fills one further row, with both cells
    converted via ``str``.
    """
    slide = prs.slides.add_slide(prs.slide_layouts[5])
    slide.shapes.title.text = title

    n_rows, n_cols = len(pairs) + 1, 2
    x = Inches(0.5)
    y = Inches(1.8)
    w = prs.slide_width - Inches(1.0)
    h = prs.slide_height - Inches(2.6)
    grid = slide.shapes.add_table(n_rows, n_cols, x, y, w, h).table

    # Header row.
    for col_idx, caption in enumerate(("項目", "値")):
        grid.cell(0, col_idx).text = caption

    # Data rows start right below the header.
    row_idx = 1
    for key, value in pairs:
        grid.cell(row_idx, 0).text = str(key)
        grid.cell(row_idx, 1).text = str(value)
        row_idx += 1
364
+
365
def _chart_slide(prs, title: str, series: List[tuple]):
    """Add a slide (layout 5) with a bar chart of ``series`` rendered as a PNG.

    ``series`` is a list of ``(label, value)`` items.  The chart is drawn
    through the figure/axes objects rather than pyplot's implicit "current
    figure", so a concurrent call cannot draw into this figure.  The figure
    is closed even if rendering fails.  The picture is scaled to fit the
    area below the title without changing the figure's aspect ratio.
    """
    slide = prs.slides.add_slide(prs.slide_layouts[5])
    slide.shapes.title.text = title

    labels = [item[0] for item in series]
    values = [item[1] for item in series]
    positions = list(range(len(values)))

    fig_w_in, fig_h_in = 8, 4.5
    fig, ax = plt.subplots(figsize=(fig_w_in, fig_h_in))
    buf = io.BytesIO()
    try:
        ax.bar(positions, values)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=20, ha='right')
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=200)
    finally:
        # Always release the figure; pyplot keeps it alive until closed.
        plt.close(fig)
    buf.seek(0)

    # Area available below the title.
    left = Inches(0.5)
    top = Inches(1.6)
    box_w = prs.slide_width - Inches(1.0)
    box_h = prs.slide_height - Inches(2.2)

    # Largest size with the figure's aspect ratio that fits the area,
    # centred horizontally.
    emu_per_inch = min(box_w / fig_w_in, box_h / fig_h_in)
    pic_w = int(fig_w_in * emu_per_inch)
    pic_h = int(fig_h_in * emu_per_inch)
    pic_left = left + (box_w - pic_w) // 2
    slide.shapes.add_picture(buf, pic_left, top, width=pic_w, height=pic_h)
383
+
384
def _add_footer(prs, theme_rgb):
    """Draw a theme-coloured footer bar and a 1-based page number on every slide."""
    bar_left = Inches(0.3)
    bar_height = Inches(0.3)
    number_width = Inches(0.8)

    page_no = 0
    for slide in prs.slides:
        page_no += 1
        bar_top = prs.slide_height - Inches(0.4)
        bar_width = prs.slide_width - Inches(0.6)

        # Borderless bar filled with the theme colour.
        bar = slide.shapes.add_shape(
            MSO_AUTO_SHAPE_TYPE.RECTANGLE, bar_left, bar_top, bar_width, bar_height
        )
        bar.fill.solid()
        bar.fill.fore_color.rgb = RGBColor(*theme_rgb)
        bar.line.fill.background()

        # Right-aligned page number over the right end of the bar.
        number_box = slide.shapes.add_textbox(
            prs.slide_width - Inches(1.0),
            bar_top - Inches(0.05),
            number_width,
            Inches(0.3),
        )
        number_para = number_box.text_frame.paragraphs[0]
        number_para.text = f"{page_no}"
        number_para.font.size = Pt(10)
        number_para.alignment = PP_ALIGN.RIGHT
400
+
401
def build_presentation(output_path: str,
                       title: str,
                       theme_rgb: tuple,
                       logo_bytes: Optional[bytes],
                       executive_summary: Optional[str],
                       sections: List[Tuple[str, str]],
                       bullets_by_section: Dict[int, List[str]],
                       tables: List[Dict[str, Any]],
                       charts: List[Dict[str, Any]]):
    """Assemble the deck and save it to ``output_path``.

    Slide order: title, executive summary (skipped when empty), one slide per
    section, then all table slides, then all chart slides.  Footers are added
    last so that every slide gets one.
    """
    deck = Presentation()

    _title_slide(deck, title, theme_rgb, logo_bytes)
    _summary_slide(deck, executive_summary)

    # Section bodies are not rendered directly; only their bullets are.
    position = 0
    for heading, _body in sections:
        _section_slide(deck, heading, bullets_by_section.get(position, []))
        position += 1

    for table_spec in tables:
        table_title = table_spec.get("title", "表")
        _table_slide(deck, table_title, table_spec.get("pairs", []))

    for chart_spec in charts:
        chart_title = chart_spec.get("title", "チャート")
        _chart_slide(deck, chart_title, chart_spec.get("series", []))

    _add_footer(deck, theme_rgb)
    deck.save(output_path)
422
+
423
+ # =========================
424
+ # Gradio App
425
+ # =========================
426
  def generate_pptx(long_text: str,
427
  title: str,
428
  theme_hex: str,
 
451
  except Exception:
452
  logo_bytes = None
453
 
 
454
  result = process_text(
455
  text=long_text,
456
  use_inference_api=use_inference_api,
 
462
  max_summary_words=max_summary_words,
463
  )
464
 
 
465
  ensure_tmpdir()
466
  timestamp = time.strftime('%Y%m%d-%H%M%S')
467
  out_path = f"/tmp/auto_ppt_{timestamp}.pptx"
 
477
  tables=result.get("tables", []),
478
  charts=result.get("charts", []),
479
  )
 
 
480
  return out_path
481
 
 
482
  def ui():
483
  with gr.Blocks(title=APP_NAME) as demo:
484
  gr.Markdown(f"# {APP_NAME}\n長文→要約→セクション分割→箇条書き/表/図→**PPTX出力** まで自動化")
 
516
  """)
517
  return demo
518
 
 
519
  if __name__ == "__main__":
520
  demo = ui()
521
  # Spaces は自動でバインドされますが、ローカル互換のため指定可能