Corin1998 committed on
Commit
e2a2f35
·
verified ·
1 Parent(s): d229d04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -53
app.py CHANGED
@@ -7,10 +7,11 @@ from typing import Optional, List, Tuple, Dict, Any
7
 
8
  import gradio as gr
9
 
10
- # 安全のため、GUI不要の描画バックエンドを指定
11
  import matplotlib
12
  matplotlib.use("Agg")
13
  import matplotlib.pyplot as plt
 
14
 
15
  from pptx import Presentation
16
  from pptx.util import Inches, Pt
@@ -19,19 +20,73 @@ from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE
19
  from pptx.dml.color import RGBColor
20
  from PIL import Image
21
 
22
- # transformers は任意(未インストールでも動作させる)
23
  try:
24
  from transformers import pipeline
25
  except Exception:
26
  pipeline = None
27
 
28
- import requests # Inference API を使う場合にのみ利用
29
 
30
  APP_NAME = "Auto-PPT Generator"
31
 
32
- # =========================
33
  # utils
34
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def safe_hex_to_rgb(hex_color: str):
36
  if not hex_color:
37
  return (59, 130, 246) # default blue
@@ -45,19 +100,22 @@ def safe_hex_to_rgb(hex_color: str):
45
  return (r, g, b)
46
  return (59, 130, 246)
47
 
 
48
  def ensure_tmpdir():
49
  os.makedirs("/tmp", exist_ok=True)
50
 
51
- # =========================
52
- # LLM client (local / HF API)
53
- # =========================
 
 
54
  class LLMClient:
55
  def __init__(self, use_inference_api: bool = False):
56
  self.use_inference_api = use_inference_api
57
  self.hf_token = os.getenv("HF_TOKEN", None)
58
  self._local_pipes = {}
59
 
60
- # ---------- Inference API helpers ----------
61
  def _hf_headers(self):
62
  if not self.hf_token:
63
  raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
@@ -80,11 +138,11 @@ class LLMClient:
80
  return data[0]["generated_text"]
81
  if isinstance(data, dict) and "generated_text" in data:
82
  return data["generated_text"]
83
- # summarization系モデルは list[0]['summary_text'] の場合も
84
  if isinstance(data, list) and data and "summary_text" in data[0]:
85
  return data[0]["summary_text"]
86
  return str(data)
87
 
 
88
  def _get_local_pipe(self, task: str, model: str):
89
  key = (task, model)
90
  if key in self._local_pipes:
@@ -104,8 +162,8 @@ class LLMClient:
104
  except Exception:
105
  pass
106
 
107
- # ローカル(transformers)試行
108
- if pipeline is not None:
109
  try:
110
  if "t5" in model.lower():
111
  pipe = self._get_local_pipe("text2text-generation", model)
@@ -119,7 +177,7 @@ class LLMClient:
119
  except Exception:
120
  pass
121
 
122
- # フォールバック:先頭の短文を並べるだけ
123
  sents = re.split(r"[。\.!?]\s*", text)
124
  out = []
125
  for s in sents:
@@ -136,12 +194,13 @@ class LLMClient:
136
  return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
137
  except Exception:
138
  return ""
139
- return "" # 本実装ではルールベースに依存
 
140
 
 
 
 
141
 
142
- # =========================
143
- # text processing
144
- # =========================
145
  LIST_BULLET = re.compile(r"^(?:[-*•・]|\d+\.|\d+\))\s+(.*)")
146
  KEYVAL_LINE = re.compile(r"^\s*([^::]+?)\s*[::]\s*([^\n]+?)\s*$")
147
  LABEL_NUM = re.compile(r"^\s*([^::]+?)\s*[::]\s*([+-]?\d+(?:\.\d+)?)\s*$")
@@ -176,7 +235,7 @@ def naive_section_split(text: str, target_chars: int = 1200) -> List[Tuple[str,
176
  sections = [("本文", text)]
177
  return sections
178
 
179
- def extract_bullets(section_text: str, max_items: int = 8) -> List[str]:
180
  bullets: List[str] = []
181
  for line in section_text.splitlines():
182
  m = LIST_BULLET.match(line.strip())
@@ -203,7 +262,7 @@ def extract_keyval_table(section_text: str) -> List[Tuple[str, str]]:
203
  pairs.append((k, v))
204
  return pairs
205
 
206
- def extract_chart_data(section_text: str, top_k: int = 10) -> List[Tuple[str, float]]:
207
  data: List[Tuple[str, float]] = []
208
  for line in section_text.splitlines():
209
  m = LABEL_NUM.match(line)
@@ -262,9 +321,11 @@ def process_text(text: str,
262
  "charts": charts,
263
  }
264
 
265
- # =========================
266
- # pptx builder
267
- # =========================
 
 
268
  def _add_logo(prs: Presentation, slide, logo_bytes: Optional[bytes]):
269
  if not logo_bytes:
270
  return
@@ -322,17 +383,22 @@ def _title_slide(prs, title_text: str, theme_rgb, logo_bytes):
322
  def _summary_slide(prs, summary: str):
323
  if not summary:
324
  return
325
- slide = prs.slides.add_slide(prs.slide_layouts[1])
326
  slide.shapes.title.text = "エグゼクティブサマリー"
327
  tf = slide.placeholders[1].text_frame
328
  tf.clear()
329
  lines = [ln.strip() for ln in summary.splitlines() if ln.strip()]
330
  if not lines:
331
- lines = [summary]
 
 
 
332
  for i, ln in enumerate(lines):
333
  p = tf.add_paragraph() if i > 0 else tf.paragraphs[0]
334
  p.text = ln
335
  p.level = 0
 
 
336
 
337
  def _section_slide(prs, title: str, bullets: List[str]):
338
  slide = prs.slides.add_slide(prs.slide_layouts[1])
@@ -341,45 +407,85 @@ def _section_slide(prs, title: str, bullets: List[str]):
341
  tf.clear()
342
  if not bullets:
343
  bullets = ["(要点なし)"]
344
- for i, b in enumerate(bullets[:12]):
 
 
345
  p = tf.add_paragraph() if i > 0 else tf.paragraphs[0]
346
  p.text = b
347
  p.level = 0
 
 
348
 
349
  def _table_slide(prs, title: str, pairs: List[tuple]):
350
- slide = prs.slides.add_slide(prs.slide_layouts[5])
351
- slide.shapes.title.text = title
352
- rows = len(pairs) + 1
353
- cols = 2
354
- left = Inches(0.5)
355
- top = Inches(1.8)
356
- width = prs.slide_width - Inches(1.0)
357
- height = prs.slide_height - Inches(2.6)
358
- table = slide.shapes.add_table(rows, cols, left, top, width, height).table
359
- table.cell(0, 0).text = "項目"
360
- table.cell(0, 1).text = "値"
361
- for r, (k, v) in enumerate(pairs, start=1):
362
- table.cell(r, 0).text = str(k)
363
- table.cell(r, 1).text = str(v)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  def _chart_slide(prs, title: str, series: List[tuple]):
366
- slide = prs.slides.add_slide(prs.slide_layouts[5])
367
- slide.shapes.title.text = title
368
- labels = [x[0] for x in series]
369
- values = [x[1] for x in series]
370
- fig = plt.figure(figsize=(8, 4.5))
371
- plt.bar(range(len(values)), values)
372
- plt.xticks(range(len(labels)), labels, rotation=20, ha='right')
373
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  buf = io.BytesIO()
375
- fig.savefig(buf, format='png', dpi=200)
376
  plt.close(fig)
377
  buf.seek(0)
 
 
 
 
378
  left = Inches(0.5)
379
  top = Inches(1.6)
380
  width = prs.slide_width - Inches(1.0)
381
- height = prs.slide_height - Inches(2.2)
382
- slide.shapes.add_picture(buf, left, top, width=width, height=height)
383
 
384
  def _add_footer(prs, theme_rgb):
385
  for idx, slide in enumerate(prs.slides, start=1):
@@ -420,9 +526,11 @@ def build_presentation(output_path: str,
420
  _add_footer(prs, theme_rgb)
421
  prs.save(output_path)
422
 
423
- # =========================
 
424
  # Gradio App
425
- # =========================
 
426
  def generate_pptx(long_text: str,
427
  title: str,
428
  theme_hex: str,
@@ -512,7 +620,7 @@ def ui():
512
  **Tips**
513
  - 日本語要約には `sonoisa/t5-base-japanese` を推奨(`text2text-generation`)。
514
  - Inference API を使う場合は、Space の Secrets に `HF_TOKEN` を設定してください。
515
- - チャートは `Label: 123` 形式の行を自動検出して棒グラフを作成します。
516
  """)
517
  return demo
518
 
 
7
 
8
  import gradio as gr
9
 
10
+ # ---- Matplotlib をGUI非依存で動作させる(必ず pyplot より先に実行)----
11
  import matplotlib
12
  matplotlib.use("Agg")
13
  import matplotlib.pyplot as plt
14
+ from matplotlib import font_manager
15
 
16
  from pptx import Presentation
17
  from pptx.util import Inches, Pt
 
20
  from pptx.dml.color import RGBColor
21
  from PIL import Image
22
 
23
+ # transformers は任意(未インストールでも動作可)
24
  try:
25
  from transformers import pipeline
26
  except Exception:
27
  pipeline = None
28
 
29
+ import requests # Inference API を使う場合のみ実使用
30
 
31
  APP_NAME = "Auto-PPT Generator"
32
 
33
+ # ======================================================
34
  # utils
35
+ # ======================================================
36
+
37
+ FALLBACK_FONT_PATH = os.getenv("JP_FONT_PATH", "./assets/fonts/IPAexGothic.ttf")
38
+
39
def set_jp_font():
    """Configure Matplotlib with a font that can render Japanese labels.

    Without a CJK-capable font, Japanese tick labels render as empty boxes.
    The lookup proceeds in two steps:

    1. Search the fonts Matplotlib already knows about for one of the
       preferred Japanese families.
    2. If none is found and ``FALLBACK_FONT_PATH`` exists, register that
       bundled font file and use its family name.

    If neither step finds a font, ``font.family`` is left untouched.
    ``axes.unicode_minus`` is always disabled so that minus signs are drawn
    with the ASCII hyphen.
    """
    candidates = (
        "IPAexGothic", "Noto Sans CJK JP", "Noto Sans JP",
        "Source Han Sans", "源ノ角ゴシック", "Yu Gothic", "Hiragino Sans",
    )
    installed = {f.name for f in font_manager.fontManager.ttflist}

    chosen = None
    for name in candidates:
        if name in installed:
            chosen = name
            break
        # A candidate may only be a prefix/substring of the real family name
        # (e.g. "Source Han Sans" vs. "Source Han Sans JP").  Matplotlib needs
        # the exact installed name, so use the matching family itself; sort
        # so that the pick does not depend on set iteration order.
        partial = sorted(fam for fam in installed if name in fam)
        if partial:
            chosen = partial[0]
            break

    if not chosen and os.path.exists(FALLBACK_FONT_PATH):
        try:
            font_manager.fontManager.addfont(FALLBACK_FONT_PATH)
            chosen = font_manager.FontProperties(fname=FALLBACK_FONT_PATH).get_name()
        except Exception:
            # Best effort: an unreadable font file must not break chart
            # rendering, so fall through with Matplotlib's default font.
            chosen = None

    if chosen:
        plt.rcParams["font.family"] = chosen
    matplotlib.rcParams["axes.unicode_minus"] = False
64
+
65
+
66
def wrap_label(s: str, width: int = 6, max_lines: int = 2) -> str:
    """Wrap a long label onto several lines and truncate it with an ellipsis.

    The label is cut into fixed-size pieces of ``width`` characters (suited
    to Japanese text, which has no spaces to break on) and joined with
    newlines.  If more than ``max_lines`` pieces result, the extra pieces are
    dropped and "…" is appended to the last kept line.

    Args:
        s: Label to format; non-string values are converted with ``str``.
        width: Maximum characters per line.  Values below 1 are treated as 1.
        max_lines: Maximum number of lines.  Values below 1 are treated as 1.

    Returns:
        The label unchanged when it fits on one line, otherwise the wrapped
        (and possibly truncated) label.
    """
    s = str(s)
    # Clamp instead of failing: width=0 would make range() raise ValueError
    # and max_lines=0 would make the ellipsis step index an empty list.
    width = max(1, width)
    max_lines = max(1, max_lines)
    if len(s) <= width:
        return s
    chunks = [s[i:i + width] for i in range(0, len(s), width)]
    if len(chunks) > max_lines:
        chunks = chunks[:max_lines]
        chunks[-1] += "…"
    return "\n".join(chunks)
76
+
77
+
78
def chunked(seq, n):
    """Yield consecutive lists of up to ``n`` items taken from ``seq``.

    Every yielded list has exactly ``n`` items except possibly the last one,
    which holds whatever is left over.  Nothing is yielded for an empty
    ``seq``.
    """
    pending = []
    for item in seq:
        pending.append(item)
        if len(pending) != n:
            continue
        # Hand out the full batch and start collecting a fresh one.
        full, pending = pending, []
        yield full
    if pending:
        yield pending
88
+
89
+
90
  def safe_hex_to_rgb(hex_color: str):
91
  if not hex_color:
92
  return (59, 130, 246) # default blue
 
100
  return (r, g, b)
101
  return (59, 130, 246)
102
 
103
+
104
  def ensure_tmpdir():
105
  os.makedirs("/tmp", exist_ok=True)
106
 
107
+
108
+ # ======================================================
109
+ # LLM client (local / HF Inference API)
110
+ # ======================================================
111
+
112
  class LLMClient:
113
  def __init__(self, use_inference_api: bool = False):
114
  self.use_inference_api = use_inference_api
115
  self.hf_token = os.getenv("HF_TOKEN", None)
116
  self._local_pipes = {}
117
 
118
+ # ---------- Inference API ----------
119
  def _hf_headers(self):
120
  if not self.hf_token:
121
  raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
 
138
  return data[0]["generated_text"]
139
  if isinstance(data, dict) and "generated_text" in data:
140
  return data["generated_text"]
 
141
  if isinstance(data, list) and data and "summary_text" in data[0]:
142
  return data[0]["summary_text"]
143
  return str(data)
144
 
145
+ # ---------- Local transformers ----------
146
  def _get_local_pipe(self, task: str, model: str):
147
  key = (task, model)
148
  if key in self._local_pipes:
 
162
  except Exception:
163
  pass
164
 
165
+ # ローカル(transformers)
166
+ if pipeline is not None and model:
167
  try:
168
  if "t5" in model.lower():
169
  pipe = self._get_local_pipe("text2text-generation", model)
 
177
  except Exception:
178
  pass
179
 
180
+ # フォールバック:先頭の短文をつなぐ
181
  sents = re.split(r"[。\.!?]\s*", text)
182
  out = []
183
  for s in sents:
 
194
  return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
195
  except Exception:
196
  return ""
197
+ return "" # 今回はルールベース中心
198
+
199
 
200
+ # ======================================================
201
+ # Text processing
202
+ # ======================================================
203
 
 
 
 
204
  LIST_BULLET = re.compile(r"^(?:[-*•・]|\d+\.|\d+\))\s+(.*)")
205
  KEYVAL_LINE = re.compile(r"^\s*([^::]+?)\s*[::]\s*([^\n]+?)\s*$")
206
  LABEL_NUM = re.compile(r"^\s*([^::]+?)\s*[::]\s*([+-]?\d+(?:\.\d+)?)\s*$")
 
235
  sections = [("本文", text)]
236
  return sections
237
 
238
+ def extract_bullets(section_text: str, max_items: int = 12) -> List[str]:
239
  bullets: List[str] = []
240
  for line in section_text.splitlines():
241
  m = LIST_BULLET.match(line.strip())
 
262
  pairs.append((k, v))
263
  return pairs
264
 
265
+ def extract_chart_data(section_text: str, top_k: int = 16) -> List[Tuple[str, float]]:
266
  data: List[Tuple[str, float]] = []
267
  for line in section_text.splitlines():
268
  m = LABEL_NUM.match(line)
 
321
  "charts": charts,
322
  }
323
 
324
+
325
+ # ======================================================
326
+ # PPTX builder
327
+ # ======================================================
328
+
329
  def _add_logo(prs: Presentation, slide, logo_bytes: Optional[bytes]):
330
  if not logo_bytes:
331
  return
 
383
  def _summary_slide(prs, summary: str):
384
  if not summary:
385
  return
386
+ slide = prs.slides.add_slide(prs.slide_layouts[1]) # Title and Content
387
  slide.shapes.title.text = "エグゼクティブサマリー"
388
  tf = slide.placeholders[1].text_frame
389
  tf.clear()
390
  lines = [ln.strip() for ln in summary.splitlines() if ln.strip()]
391
  if not lines:
392
+ lines = [summary.strip()]
393
+ # 行が多い場合はフォント縮小
394
+ MAX_LINES = 12
395
+ lines = lines[:MAX_LINES]
396
  for i, ln in enumerate(lines):
397
  p = tf.add_paragraph() if i > 0 else tf.paragraphs[0]
398
  p.text = ln
399
  p.level = 0
400
+ for run in p.runs:
401
+ run.font.size = Pt(14 if len(lines) <= 8 else 12)
402
 
403
  def _section_slide(prs, title: str, bullets: List[str]):
404
  slide = prs.slides.add_slide(prs.slide_layouts[1])
 
407
  tf.clear()
408
  if not bullets:
409
  bullets = ["(要点なし)"]
410
+ MAX_ITEMS = 12
411
+ bullets = bullets[:MAX_ITEMS]
412
+ for i, b in enumerate(bullets):
413
  p = tf.add_paragraph() if i > 0 else tf.paragraphs[0]
414
  p.text = b
415
  p.level = 0
416
+ for run in p.runs:
417
+ run.font.size = Pt(18 if len(bullets) <= 8 else 14)
418
 
419
  def _table_slide(prs, title: str, pairs: List[tuple]):
420
+ MAX_ROWS_PER_SLIDE = 12 # 見出し1行 + データ最大12行/枚
421
+ if not pairs:
422
+ pairs = [("(データなし)", "-")]
423
+
424
+ for i, chunk in enumerate(chunked(pairs, MAX_ROWS_PER_SLIDE)):
425
+ slide = prs.slides.add_slide(prs.slide_layouts[5]) # Title Only
426
+ page_title = title if i == 0 else f"{title}(続き)"
427
+ slide.shapes.title.text = page_title
428
+
429
+ rows = len(chunk) + 1
430
+ cols = 2
431
+ left = Inches(0.5)
432
+ top = Inches(1.8)
433
+ width = prs.slide_width - Inches(1.0)
434
+ height = prs.slide_height - Inches(2.6)
435
+ table = slide.shapes.add_table(rows, cols, left, top, width, height).table
436
+
437
+ table.cell(0, 0).text = "項目"
438
+ table.cell(0, 1).text = "値"
439
+
440
+ for r, (k, v) in enumerate(chunk, start=1):
441
+ table.cell(r, 0).text = str(k)
442
+ table.cell(r, 1).text = str(v)
443
+
444
+ # 文字サイズと折返し
445
+ for r in range(rows):
446
+ for c in range(cols):
447
+ cell = table.cell(r, c)
448
+ tf = cell.text_frame
449
+ tf.word_wrap = True
450
+ for p in tf.paragraphs:
451
+ for run in p.runs:
452
+ run.font.size = Pt(12)
453
 
454
  def _chart_slide(prs, title: str, series: List[tuple]):
455
+ # 日本語フォント設定
456
+ set_jp_font()
457
+
458
+ # ラベル整形(改行+省略)
459
+ raw_labels = [str(x[0]) for x in series]
460
+ labels = [wrap_label(lbl, width=6, max_lines=2) for lbl in raw_labels]
461
+ values = [float(x[1]) for x in series]
462
+
463
+ # ラベル長に応じて図の高さと下余白を調整
464
+ max_label_len = max((len(l) for l in raw_labels), default=0)
465
+ base_h = 4.2
466
+ fig_h = max(4.0, min(7.0, base_h + 0.10 * max_label_len)) # 4.0〜7.0 inch
467
+ bottom_margin = min(0.35, 0.18 + 0.012 * max_label_len)
468
+
469
+ fig = plt.figure(figsize=(8, fig_h))
470
+ ax = fig.add_subplot(111)
471
+ ax.bar(range(len(values)), values)
472
+ ax.set_xticks(range(len(labels)))
473
+ ax.set_xticklabels(labels, rotation=0, ha='center')
474
+ fig.subplots_adjust(bottom=bottom_margin, left=0.10, right=0.98, top=0.90)
475
+ ax.set_title(title)
476
+
477
  buf = io.BytesIO()
478
+ fig.savefig(buf, format='png', dpi=200, bbox_inches='tight')
479
  plt.close(fig)
480
  buf.seek(0)
481
+
482
+ # 画像はアスペクト維持で幅フィット(高さは自動比率)
483
+ slide = prs.slides.add_slide(prs.slide_layouts[5]) # Title Only
484
+ slide.shapes.title.text = title
485
  left = Inches(0.5)
486
  top = Inches(1.6)
487
  width = prs.slide_width - Inches(1.0)
488
+ slide.shapes.add_picture(buf, left, top, width=width) # heightは指定しない(比率維持)
 
489
 
490
  def _add_footer(prs, theme_rgb):
491
  for idx, slide in enumerate(prs.slides, start=1):
 
526
  _add_footer(prs, theme_rgb)
527
  prs.save(output_path)
528
 
529
+
530
+ # ======================================================
531
  # Gradio App
532
+ # ======================================================
533
+
534
  def generate_pptx(long_text: str,
535
  title: str,
536
  theme_hex: str,
 
620
  **Tips**
621
  - 日本語要約には `sonoisa/t5-base-japanese` を推奨(`text2text-generation`)。
622
  - Inference API を使う場合は、Space の Secrets に `HF_TOKEN` を設定してください。
623
+ - チャートは `ラベル: 数値` 形式の行を自動検出して棒グラフを作成します。
624
  """)
625
  return demo
626