Corin1998's picture
Update app.py
e2a2f35 verified
import io
import logging
import os
import re
import sys
import tempfile
import time
from typing import Optional, List, Tuple, Dict, Any

import gradio as gr
# ---- Matplotlib をGUI非依存で動作させる(必ず pyplot より先に実行)----
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import font_manager
from pptx import Presentation
from pptx.util import Inches, Pt
from pptx.enum.text import PP_ALIGN
from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE
from pptx.dml.color import RGBColor
from PIL import Image
# transformers は任意(未インストールでも動作可)
try:
    from transformers import pipeline
except Exception:
    pipeline = None
import requests # Inference API を使う場合のみ実使用
# Display name used as the Gradio page title and in the page heading (see ui()).
APP_NAME = "Auto-PPT Generator"
# ======================================================
# utils
# ======================================================
FALLBACK_FONT_PATH = os.getenv("JP_FONT_PATH", "./assets/fonts/IPAexGothic.ttf")


def set_jp_font():
    """Configure matplotlib so Japanese chart labels do not render as tofu (□).

    1) Look for a known Japanese font among the families matplotlib has
       registered. The *installed* family name is what gets configured, so a
       partial match (e.g. "Source Han Sans" -> "Source Han Sans JP") still
       resolves.
    2) Otherwise register the bundled font at FALLBACK_FONT_PATH (overridable
       via the JP_FONT_PATH environment variable) and use it.
    If neither works, the font family is left unchanged.
    """
    candidates = [
        "IPAexGothic", "Noto Sans CJK JP", "Noto Sans JP",
        "Source Han Sans", "源ノ角ゴシック", "Yu Gothic", "Hiragino Sans",
    ]
    # Sorted so that the partial-match choice below is deterministic.
    installed = sorted({f.name for f in font_manager.fontManager.ttflist})
    chosen = None
    for name in candidates:
        if name in installed:
            chosen = name
            break
        # rcParams needs a family name matplotlib can resolve, so use the real
        # installed name rather than the candidate substring.
        partial = next((fam for fam in installed if name in fam), None)
        if partial:
            chosen = partial
            break
    if not chosen and os.path.exists(FALLBACK_FONT_PATH):
        try:
            font_manager.fontManager.addfont(FALLBACK_FONT_PATH)
            chosen = font_manager.FontProperties(fname=FALLBACK_FONT_PATH).get_name()
        except Exception:
            # Best effort: charts still render, just without a Japanese font.
            logging.getLogger(__name__).warning(
                "Could not register fallback font %s", FALLBACK_FONT_PATH, exc_info=True
            )
            chosen = None
    if chosen:
        plt.rcParams["font.family"] = chosen
    # Use ASCII "-" instead of U+2212 for minus signs (presumably to avoid
    # missing-glyph boxes with CJK fonts).
    matplotlib.rcParams["axes.unicode_minus"] = False
def wrap_label(s: str, width: int = 6, max_lines: int = 2) -> str:
    """Break a long label into fixed-width lines so it does not overflow sideways.

    The label is cut every *width* characters. If that yields more than
    *max_lines* lines, only the first *max_lines* are kept and "…" is appended
    to the last one. Labels no longer than *width* are returned unchanged.
    """
    text = str(s)
    if len(text) > width:
        pieces = [text[pos:pos + width] for pos in range(0, len(text), width)]
        if len(pieces) > max_lines:
            pieces = pieces[:max_lines]
            pieces[-1] += "…"
        return "\n".join(pieces)
    return text
def chunked(seq, n):
    """Yield successive lists of *n* items from *seq*; the last list may be shorter."""
    pending = []
    for item in seq:
        pending.append(item)
        if len(pending) != n:
            continue
        yield pending
        # Start a fresh list so the one just yielded is never mutated afterwards.
        pending = []
    if pending:
        yield pending
def safe_hex_to_rgb(hex_color: str):
    """Parse a "#RRGGBB" (or "RRGGBB") string into an (r, g, b) tuple of ints.

    Surrounding whitespace is ignored. Empty/None or malformed input yields the
    default blue (59, 130, 246).
    """
    default_rgb = (59, 130, 246)  # default blue
    if not hex_color:
        return default_rgb
    candidate = hex_color.strip()
    if not candidate.startswith("#"):
        candidate = "#" + candidate
    if not re.fullmatch(r"#[0-9A-Fa-f]{6}", candidate):
        return default_rgb
    return tuple(int(candidate[pos:pos + 2], 16) for pos in (1, 3, 5))
def ensure_tmpdir():
    """Create the output directory "/tmp" if it is missing (no error if it exists)."""
    target_dir = "/tmp"
    os.makedirs(target_dir, exist_ok=True)
# ======================================================
# LLM client (local / HF Inference API)
# ======================================================
class LLMClient:
    """Summarization / generation client with graceful fallbacks.

    ``summarize`` tries, in order:
      1. the Hugging Face Inference API (``use_inference_api`` set and a model given),
      2. a local ``transformers`` pipeline (package importable and a model given),
      3. a rule-based fallback that joins the leading sentences.
    ``generate`` only uses the Inference API and otherwise returns "".
    """

    _log = logging.getLogger(__name__)

    def __init__(self, use_inference_api: bool = False):
        """
        Args:
            use_inference_api: If True, try the HF Inference API before local models.
        """
        self.use_inference_api = use_inference_api
        # Only needed for Inference API calls (see _hf_headers).
        self.hf_token = os.getenv("HF_TOKEN", None)
        # (task, model) -> transformers pipeline, so each model is loaded once per client.
        self._local_pipes = {}

    # ---------- Inference API ----------
    def _hf_headers(self):
        """Return the bearer-token header.

        Raises:
            RuntimeError: if HF_TOKEN is not set.
        """
        if not self.hf_token:
            raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
        return {"Authorization": f"Bearer {self.hf_token}"}

    @staticmethod
    def _extract_text(data: Any) -> str:
        """Return ``generated_text`` / ``summary_text`` from an Inference API JSON body.

        Accepts a dict, or a non-empty list whose first item is a dict.

        Raises:
            RuntimeError: if no such field exists (e.g. an ``{"error": ...}`` body),
                so callers fall back instead of putting raw JSON into the slides.
        """
        item = data[0] if isinstance(data, list) and data else data
        if isinstance(item, dict):
            for key in ("generated_text", "summary_text"):
                if key in item:
                    return item[key]
        raise RuntimeError(f"Unexpected Inference API response: {str(data)[:200]}")

    def _hf_textgen(self, model: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.3) -> str:
        """POST *prompt* to the HF Inference API for *model* and return the produced text.

        Raises:
            RuntimeError: missing HF_TOKEN or an unrecognized response body.
            requests.RequestException: network errors or non-2xx HTTP status.
        """
        url = f"https://api-inference.huggingface.co/models/{model}"
        payload = {
            "inputs": prompt,
            "parameters": {
                # Send an integer token budget; callers may pass floats (slider values).
                "max_new_tokens": int(max_new_tokens),
                "temperature": temperature,
                "return_full_text": False,
            },
        }
        r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
        r.raise_for_status()
        return self._extract_text(r.json())

    # ---------- Local transformers ----------
    def _get_local_pipe(self, task: str, model: str):
        """Return a cached transformers pipeline for (task, model), creating it on first use.

        Raises:
            RuntimeError: if transformers could not be imported.
        """
        key = (task, model)
        if key in self._local_pipes:
            return self._local_pipes[key]
        if pipeline is None:
            raise RuntimeError("transformers is not available")
        pipe = pipeline(task=task, model=model)
        self._local_pipes[key] = pipe
        return pipe

    # ---------- Public ----------
    def summarize(self, text: str, model: str, max_words: int = 200) -> str:
        """Summarize *text*; only its first 6000 characters are sent to models.

        Args:
            text: Source text.
            model: HF model id; an empty value skips both model backends.
            max_words: Rough length budget. Models get ``max_words * 2`` as their
                token limit; the rule-based fallback stops after roughly
                ``max_words * 6`` characters.

        Returns:
            The summary. Backend failures are logged and fall through to the
            next strategy rather than raising.
        """
        # Gradio sliders can deliver floats; token limits must be integers.
        max_words = int(max_words)
        head = text[:6000]

        # Inference API first.
        if self.use_inference_api and model:
            try:
                return self._hf_textgen(model, head, max_new_tokens=max_words * 2).strip()
            except Exception:
                self._log.warning("Inference API summarization failed; trying next backend", exc_info=True)

        # Local transformers. `model` is checked first so an empty model name
        # never touches the transformers backend.
        if model and pipeline is not None:
            try:
                if "t5" in model.lower():
                    pipe = self._get_local_pipe("text2text-generation", model)
                    # truncation=True clips inputs longer than the model's maximum
                    # length instead of failing on them.
                    res = pipe(f"要約: {head}", max_length=max_words * 2, do_sample=False, truncation=True)
                    return res[0]["generated_text"].strip()
                pipe = self._get_local_pipe("summarization", model)
                res = pipe(head, max_length=max_words * 2, min_length=max_words // 2,
                           do_sample=False, truncation=True)
                return res[0]["summary_text"].strip()
            except Exception:
                self._log.warning("Local summarization failed; using rule-based fallback", exc_info=True)

        # Fallback: join leading sentences until about max_words * 6 characters.
        sents = re.split(r"[。\.!?]\s*", text)
        out = []
        for s in sents:
            s = s.strip()
            if s:
                out.append(s)
                if len(" ".join(out)) > max_words * 6:
                    break
        return "。".join(out)

    def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
        """Generate text through the Inference API; returns "" when disabled or on failure."""
        if self.use_inference_api and model:
            try:
                return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
            except Exception:
                self._log.warning("Inference API generation failed", exc_info=True)
                return ""
        return ""  # The app is rule-based by default; no local generation.
# ======================================================
# Text processing
# ======================================================
# Bullet line: "-", "*", "•", "・", "1." or "1)" followed by whitespace; group(1) is the item text.
LIST_BULLET = re.compile(r"^(?:[-*•・]|\d+\.|\d+\))\s+(.*)")
# "key: value" line. Both ASCII ":" and full-width "：" separators are accepted,
# since Japanese text commonly uses the full-width colon.
KEYVAL_LINE = re.compile(r"^\s*([^:：]+?)\s*[:：]\s*([^\n]+?)\s*$")
# "label: number" line (optional sign, optional decimal part); same separators as KEYVAL_LINE.
LABEL_NUM = re.compile(r"^\s*([^:：]+?)\s*[:：]\s*([+-]?\d+(?:\.\d+)?)\s*$")
# Heading line: markdown "#"s, "1." or "1)"; group(2) is the heading text.
# The (?!\d) lookahead keeps decimals/dates such as "1.5倍…" or "2024.01…" from
# being mistaken for numbered headings.
HEADER = re.compile(r"^(#+|\d+\.(?!\d)|\d+\))\s*(.+)$")
def naive_section_split(text: str, target_chars: int = 1200) -> List[Tuple[str, str]]:
    """Split *text* into ``(title, content)`` sections using headings or size.

    A line matching ``HEADER`` closes the current section and starts a new one
    titled with the heading text. Independently, once the buffered lines of the
    current section exceed *target_chars* characters (newlines not counted), the
    section is closed and subsequent lines go to an auto-titled "セクションN"
    section. Content is stripped of surrounding whitespace.

    Returns ``[("本文", text)]`` when no section was produced (e.g. empty input).
    """
    sections: List[Tuple[str, str]] = []
    cur_title = "セクション"
    cur_buf: List[str] = []
    # Running character count of cur_buf. Re-summing the buffer on every line was
    # quadratic, because blank lines add 0 characters and never trigger a flush.
    cur_chars = 0

    def flush():
        nonlocal cur_buf, cur_chars
        if cur_buf:
            sections.append((cur_title, "\n".join(cur_buf).strip()))
        cur_buf = []
        cur_chars = 0

    for ln in text.splitlines():
        m = HEADER.match(ln.strip())
        if m:
            flush()
            cur_title = m.group(2).strip()
            continue
        cur_buf.append(ln)
        cur_chars += len(ln)
        if cur_chars > target_chars:
            flush()
            cur_title = f"セクション{len(sections)+1}"
    flush()
    if not sections:
        sections = [("本文", text)]
    return sections
def extract_bullets(section_text: str, max_items: int = 12) -> List[str]:
    """Return up to *max_items* bullet strings for a section.

    Explicit list-marker lines (matched by ``LIST_BULLET``) win. Only when there
    are none, the text is split into sentences on 。 . ! ? and sentences of
    8-120 characters are used instead.
    """
    stripped_lines = [ln.strip() for ln in section_text.splitlines()]
    found = [hit.group(1).strip() for hit in map(LIST_BULLET.match, stripped_lines) if hit]
    if found:
        return found[:max_items]

    for sentence in re.split(r"[。\.!?]\s*", section_text):
        sentence = sentence.strip()
        if not 8 <= len(sentence) <= 120:
            continue
        found.append(sentence)
        if len(found) >= max_items:
            break
    return found[:max_items]
def extract_keyval_table(section_text: str) -> List[Tuple[str, str]]:
    """Collect ``(key, value)`` pairs from lines matching ``KEYVAL_LINE``.

    Pairs are returned in line order. Pairs whose key or value is empty after
    stripping are skipped.
    """
    result: List[Tuple[str, str]] = []
    matches = (KEYVAL_LINE.match(line) for line in section_text.splitlines())
    for hit in matches:
        if hit is None:
            continue
        key, value = hit.group(1).strip(), hit.group(2).strip()
        if key and value:
            result.append((key, value))
    return result
def extract_chart_data(section_text: str, top_k: int = 16) -> List[Tuple[str, float]]:
    """Collect ``(label, value)`` points from "label: number" lines (``LABEL_NUM``).

    A repeated label keeps the value from its last occurrence. Points are ordered
    by absolute value, largest first (stable for ties), and cut to *top_k*.
    """
    latest: Dict[str, float] = {}
    for raw in section_text.splitlines():
        hit = LABEL_NUM.match(raw)
        if hit is None:
            continue
        try:
            number = float(hit.group(2))
        except ValueError:
            continue
        latest[hit.group(1).strip()] = number
    ranked = sorted(latest.items(), key=lambda kv: abs(kv[1]), reverse=True)
    return ranked[:top_k]
def process_text(text: str,
                 use_inference_api: bool,
                 summarizer_model: str,
                 generator_model: str,
                 want_summary: bool,
                 want_tables: bool,
                 want_charts: bool,
                 max_summary_words: int = 200) -> Dict[str, Any]:
    """Turn raw text into the content model consumed by build_presentation.

    Args:
        text: Input document.
        use_inference_api: Passed to LLMClient for summarization.
        summarizer_model: Model id used for the summary.
        generator_model: Currently unused by this function.
        want_summary / want_tables / want_charts: Feature toggles.
        max_summary_words: Length budget forwarded to LLMClient.summarize.

    Returns:
        Dict with "summary" (str or None), "sections" (list of (title, body)),
        "bullets" (section index -> bullet list), "tables" and "charts" (lists of
        {"title", "pairs"} / {"title", "series"} dicts).
    """
    llm = LLMClient(use_inference_api=use_inference_api)
    summary = (
        llm.summarize(text, model=summarizer_model, max_words=max_summary_words)
        if want_summary
        else None
    )

    sections = naive_section_split(text)
    bullets_by_section: Dict[int, List[str]] = {
        idx: extract_bullets(body) for idx, (_, body) in enumerate(sections)
    }

    tables: List[Dict[str, Any]] = []
    charts: List[Dict[str, Any]] = []
    for sec_title, sec_body in sections:
        if want_tables:
            kv_pairs = extract_keyval_table(sec_body)
            if kv_pairs:
                tables.append({"title": f"{sec_title} — 表", "pairs": kv_pairs})
        if want_charts:
            points = extract_chart_data(sec_body)
            if points:
                charts.append({"title": f"{sec_title} — チャート", "series": points})

    return {
        "summary": summary,
        "sections": sections,
        "bullets": bullets_by_section,
        "tables": tables,
        "charts": charts,
    }
# ======================================================
# PPTX builder
# ======================================================
def _add_logo(prs: Presentation, slide, logo_bytes: Optional[bytes]):
    """Place the logo in the top-right corner, fitted inside a 2in x 1in box.

    The aspect ratio is preserved. Image data that cannot be decoded is logged
    and skipped, so a bad logo does not abort deck generation.
    """
    if not logo_bytes:
        return
    try:
        with Image.open(io.BytesIO(logo_bytes)) as src:
            img = src.convert("RGBA")
    except Exception:
        logging.getLogger(__name__).warning("Skipping unreadable logo image", exc_info=True)
        return

    max_w, max_h = Inches(2.0), Inches(1.0)  # EMU lengths, not pixels
    px_w, px_h = img.size
    # Compute the *display* size in EMU and hand it to add_picture. Resizing the
    # bitmap by this EMU-per-pixel ratio would inflate it to millions of pixels.
    scale = min(max_w / max(px_w, 1), max_h / max(px_h, 1))
    disp_w = max(1, int(px_w * scale))
    disp_h = max(1, int(px_h * scale))

    # Cap the embedded bitmap at about 300 dpi for the 2in x 1in box to keep the file small.
    img.thumbnail((600, 300))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)

    left = prs.slide_width - disp_w - Inches(0.5)  # right-aligned, 0.5in margin
    top = Inches(0.2)
    slide.shapes.add_picture(buf, left, top, width=disp_w, height=disp_h)
def _apply_theme_bg(slide, rgb):
    """Give *slide* a solid background of color *rgb*, an (r, g, b) triple."""
    bg_fill = slide.background.fill
    bg_fill.solid()  # use a solid fill type, then set its foreground color
    bg_fill.fore_color.rgb = RGBColor(*rgb)
def _title_slide(prs, title_text: str, theme_rgb, logo_bytes):
    """Add the title slide.

    The slide gets a theme-colored background and a white rounded "card" behind
    the title and subtitle placeholders (which are repositioned onto it). The
    optional logo goes in the top-right corner.
    """
    slide_layout = prs.slide_layouts[0]  # Title Slide
    slide = prs.slides.add_slide(slide_layout)
    title = slide.shapes.title
    subtitle = slide.placeholders[1]
    title.text = title_text
    subtitle.text = "自動生成プレゼンテーション"
    _apply_theme_bg(slide, theme_rgb)

    left = Inches(0.6)
    top = Inches(1.8)
    width = prs.slide_width - Inches(1.2)
    height = Inches(2.2)
    box = slide.shapes.add_shape(MSO_AUTO_SHAPE_TYPE.ROUNDED_RECTANGLE, left, top, width, height)
    box.fill.solid()
    box.fill.fore_color.rgb = RGBColor(255, 255, 255)
    # python-pptx's LineFormat has no transparency setting. Use the light gray
    # that black at 80% transparency over white would produce.
    box.line.color.rgb = RGBColor(204, 204, 204)
    # add_shape appends to the end of the shape tree, i.e. above the
    # placeholders. Move the card to the back (right after nvGrpSpPr/grpSpPr)
    # so it does not hide the title text.
    sp_tree = slide.shapes._spTree
    sp_tree.remove(box._element)
    sp_tree.insert(2, box._element)

    def _set_font(shape, size_pt, bold):
        # Set paragraph defaults and existing runs, so the size/weight also applies
        # to runs created by the `.text` assignments above.
        for p in shape.text_frame.paragraphs:
            p.font.size = Pt(size_pt)
            p.font.bold = bold
            for run in p.runs:
                run.font.size = Pt(size_pt)
                run.font.bold = bold

    title.left = left + Inches(0.3)
    title.top = top + Inches(0.3)
    title.width = width - Inches(0.6)
    title.height = Inches(1.4)
    _set_font(title, 40, True)

    subtitle.left = left + Inches(0.3)
    subtitle.top = top + Inches(1.6)
    subtitle.width = width - Inches(0.6)
    subtitle.height = Inches(0.8)
    _set_font(subtitle, 16, False)

    _add_logo(prs, slide, logo_bytes)
def _summary_slide(prs, summary: str):
    """Add an "executive summary" slide; does nothing when *summary* is empty.

    Each non-blank line of *summary* becomes a paragraph, at most 12 of them.
    Text is 14pt for up to 8 lines and 12pt beyond that.
    """
    if not summary:
        return
    slide = prs.slides.add_slide(prs.slide_layouts[1])  # Title and Content
    slide.shapes.title.text = "エグゼクティブサマリー"
    body = slide.placeholders[1].text_frame
    body.clear()

    MAX_LINES = 12
    lines = [ln.strip() for ln in summary.splitlines() if ln.strip()] or [summary.strip()]
    lines = lines[:MAX_LINES]
    # Smaller font when there are many lines.
    font_size = Pt(14) if len(lines) <= 8 else Pt(12)
    for pos, text_line in enumerate(lines):
        para = body.paragraphs[0] if pos == 0 else body.add_paragraph()
        para.text = text_line
        para.level = 0
        for run in para.runs:
            run.font.size = font_size
def _section_slide(prs, title: str, bullets: List[str]):
    """Add a Title-and-Content slide listing up to 12 bullets.

    The title is cut to 90 characters. An empty bullet list shows a single
    placeholder line. Text is 18pt for up to 8 bullets and 14pt beyond that.
    """
    slide = prs.slides.add_slide(prs.slide_layouts[1])
    slide.shapes.title.text = title[:90]
    frame = slide.placeholders[1].text_frame
    frame.clear()

    MAX_ITEMS = 12
    items = (bullets or ["(要点なし)"])[:MAX_ITEMS]
    font_size = Pt(18) if len(items) <= 8 else Pt(14)
    for n, item in enumerate(items):
        para = frame.add_paragraph() if n else frame.paragraphs[0]
        para.text = item
        para.level = 0
        for run in para.runs:
            run.font.size = font_size
def _table_slide(prs, title: str, pairs: List[tuple]):
    """Render ``(key, value)`` *pairs* as 2-column tables.

    Each slide holds one header row plus at most 12 data rows. Extra rows spill
    onto continuation slides whose title gets a "(続き)" suffix. Empty input
    produces a one-row "(データなし)" table.
    """
    MAX_ROWS_PER_SLIDE = 12  # data rows per slide, excluding the header row
    rows_data = pairs if pairs else [("(データなし)", "-")]

    for page_no, page_rows in enumerate(chunked(rows_data, MAX_ROWS_PER_SLIDE)):
        slide = prs.slides.add_slide(prs.slide_layouts[5])  # Title Only
        slide.shapes.title.text = f"{title}(続き)" if page_no else title

        n_rows, n_cols = len(page_rows) + 1, 2
        shape = slide.shapes.add_table(
            n_rows, n_cols,
            Inches(0.5), Inches(1.8),
            prs.slide_width - Inches(1.0), prs.slide_height - Inches(2.6),
        )
        table = shape.table

        for col, label in enumerate(("項目", "値")):
            table.cell(0, col).text = label
        for row_idx, (key, value) in enumerate(page_rows, start=1):
            table.cell(row_idx, 0).text = str(key)
            table.cell(row_idx, 1).text = str(value)

        # Uniform 12pt text with word wrap in every cell.
        for row_idx in range(n_rows):
            for col in range(n_cols):
                frame = table.cell(row_idx, col).text_frame
                frame.word_wrap = True
                for para in frame.paragraphs:
                    for run in para.runs:
                        run.font.size = Pt(12)
def _chart_slide(prs, title: str, series: List[tuple]):
    """Add a Title Only slide containing a matplotlib bar chart of *series*.

    Args:
        prs: Presentation to append to.
        title: Used both as the slide title and as the chart title.
        series: ``(label, value)`` pairs; labels are wrapped/truncated for the x axis.

    The PNG keeps its aspect ratio. It is scaled to fit both the slide width
    (0.5in side margins) and the space between the title area and the footer
    band, then centered horizontally.
    """
    set_jp_font()

    raw_labels = [str(x[0]) for x in series]
    labels = [wrap_label(lbl, width=6, max_lines=2) for lbl in raw_labels]
    values = [float(x[1]) for x in series]

    # Longer labels get a taller figure and a larger bottom margin.
    max_label_len = max((len(lbl) for lbl in raw_labels), default=0)
    base_h = 4.2
    fig_h = max(4.0, min(7.0, base_h + 0.10 * max_label_len))  # 4.0-7.0 inch
    bottom_margin = min(0.35, 0.18 + 0.012 * max_label_len)

    fig = plt.figure(figsize=(8, fig_h))
    buf = io.BytesIO()
    try:
        ax = fig.add_subplot(111)
        ax.bar(range(len(values)), values)
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=0, ha="center")
        fig.subplots_adjust(bottom=bottom_margin, left=0.10, right=0.98, top=0.90)
        ax.set_title(title)
        fig.savefig(buf, format="png", dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as img:
        px_w, px_h = img.size
    buf.seek(0)

    slide = prs.slides.add_slide(prs.slide_layouts[5])  # Title Only
    slide.shapes.title.text = title

    top = Inches(1.6)
    max_w = prs.slide_width - Inches(1.0)
    # Keep 0.6in free at the bottom: the footer bar occupies the last 0.4in.
    # Fitting by width alone let tall charts run off the slide.
    max_h = prs.slide_height - top - Inches(0.6)
    scale = min(max_w / px_w, max_h / px_h)
    pic_w = int(px_w * scale)
    pic_h = int(px_h * scale)
    left = (prs.slide_width - pic_w) // 2
    slide.shapes.add_picture(buf, left, top, width=pic_w, height=pic_h)
def _add_footer(prs, theme_rgb):
    """Stamp every slide with a theme-colored footer bar and a right-aligned page number.

    Page numbers start at 1 and follow slide order.
    """
    for page_no, slide in enumerate(prs.slides, start=1):
        bar_top = prs.slide_height - Inches(0.4)
        bar = slide.shapes.add_shape(
            MSO_AUTO_SHAPE_TYPE.RECTANGLE,
            Inches(0.3),
            bar_top,
            prs.slide_width - Inches(0.6),
            Inches(0.3),
        )
        bar.fill.solid()
        bar.fill.fore_color.rgb = RGBColor(*theme_rgb)
        bar.line.fill.background()  # transparent outline

        number_box = slide.shapes.add_textbox(
            prs.slide_width - Inches(1.0), bar_top - Inches(0.05), Inches(0.8), Inches(0.3)
        )
        para = number_box.text_frame.paragraphs[0]
        para.text = f"{page_no}"
        para.font.size = Pt(10)
        para.alignment = PP_ALIGN.RIGHT
def build_presentation(output_path: str,
                       title: str,
                       theme_rgb: tuple,
                       logo_bytes: Optional[bytes],
                       executive_summary: Optional[str],
                       sections: List[Tuple[str, str]],
                       bullets_by_section: Dict[int, List[str]],
                       tables: List[Dict[str, Any]],
                       charts: List[Dict[str, Any]]):
    """Assemble the deck and save it to *output_path*.

    Slide order is: title, summary (only if non-empty), one slide per section,
    table slides, chart slides. A footer with page numbers is then added to
    every slide.
    """
    prs = Presentation()
    _title_slide(prs, title, theme_rgb, logo_bytes)
    _summary_slide(prs, executive_summary)

    for sec_idx, (heading, _) in enumerate(sections):
        _section_slide(prs, heading, bullets_by_section.get(sec_idx, []))

    for table_spec in tables:
        _table_slide(prs, table_spec.get("title", "表"), table_spec.get("pairs", []))

    for chart_spec in charts:
        _chart_slide(prs, chart_spec.get("title", "チャート"), chart_spec.get("series", []))

    _add_footer(prs, theme_rgb)
    prs.save(output_path)
# ======================================================
# Gradio App
# ======================================================
def _read_logo_bytes(logo_file) -> Optional[bytes]:
    """Best-effort read of the uploaded logo; returns None if absent or unreadable.

    Handles the value shapes gr.File can deliver depending on Gradio version and
    ``type``:
      - raw bytes
      - a filepath (str / PathLike), the Gradio 4 default
      - a dict carrying "path"/"name"
      - a file-like object with ``read()``
      - an object exposing ``.name`` (Gradio 3 tempfile wrapper)
    """
    if logo_file is None:
        return None
    try:
        if isinstance(logo_file, (bytes, bytearray)):
            return bytes(logo_file)
        if isinstance(logo_file, dict):
            path = logo_file.get("path") or logo_file.get("name")
        elif isinstance(logo_file, (str, os.PathLike)):
            path = logo_file
        elif hasattr(logo_file, "read"):
            return logo_file.read()
        else:
            path = getattr(logo_file, "name", None)
        if not path:
            return None
        with open(path, "rb") as fh:
            return fh.read()
    except Exception:
        # The logo is optional: log the problem and continue without it.
        logging.getLogger(__name__).warning("Could not read logo file; continuing without it", exc_info=True)
        return None


def generate_pptx(long_text: str,
                  title: str,
                  theme_hex: str,
                  logo_file,
                  add_summary: bool,
                  add_tables: bool,
                  add_charts: bool,
                  use_inference_api: bool,
                  summarizer_model: str,
                  generator_model: str,
                  max_summary_words: int):
    """Gradio callback: build a .pptx from *long_text* and return its file path.

    Raises:
        gr.Error: if *long_text* is empty or whitespace-only.
    """
    if not long_text or not long_text.strip():
        raise gr.Error("入力テキストが空です。長文を貼り付けてください。")
    theme_rgb = safe_hex_to_rgb(theme_hex or "#3B82F6")

    # Read logo (optional).
    logo_bytes = _read_logo_bytes(logo_file)

    result = process_text(
        text=long_text,
        use_inference_api=use_inference_api,
        summarizer_model=summarizer_model,
        generator_model=generator_model,
        want_summary=add_summary,
        want_tables=add_tables,
        want_charts=add_charts,
        # Slider values may arrive as floats; the summary budget is a count.
        max_summary_words=int(max_summary_words),
    )

    ensure_tmpdir()
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    # mkstemp guarantees a unique name, so two requests in the same second
    # cannot overwrite each other's output.
    fd, out_path = tempfile.mkstemp(prefix=f"auto_ppt_{timestamp}_", suffix=".pptx", dir="/tmp")
    os.close(fd)
    build_presentation(
        output_path=out_path,
        title=(title or "Auto-PPT"),
        theme_rgb=theme_rgb,
        logo_bytes=logo_bytes,
        executive_summary=result.get("summary"),
        sections=result.get("sections", []),
        bullets_by_section=result.get("bullets", {}),
        tables=result.get("tables", []),
        charts=result.get("charts", []),
    )
    return out_path
def ui():
    """Build and return the Gradio Blocks interface (not launched)."""
    with gr.Blocks(title=APP_NAME) as demo:
        gr.Markdown(f"# {APP_NAME}\n長文→要約→セクション分割→箇条書き/表/図→**PPTX出力** まで自動化")
        with gr.Row():
            # Left: input text and deck options.
            with gr.Column(scale=2):
                text_in = gr.Textbox(label="長文テキスト (貼り付け)", lines=20, placeholder="ここに文章を貼り付け…")
                title_in = gr.Textbox(label="タイトル", value="自動生成スライド")
                color_in = gr.Textbox(label="ブランドカラー HEX", value="#3465A4")
                logo_in = gr.File(label="ロゴ (任意, PNG/JPG)")
                with gr.Row():
                    summary_cb = gr.Checkbox(value=True, label="要約スライドを追加")
                    tables_cb = gr.Checkbox(value=True, label="表を抽出して追加")
                    charts_cb = gr.Checkbox(value=True, label="チャートを生成して追加")
            # Right: model settings, run button and download.
            with gr.Column(scale=1):
                gr.Markdown("### モデル設定")
                api_cb = gr.Checkbox(value=False, label="Hugging Face Inference API を使用")
                summarizer_in = gr.Textbox(label="要約モデル (local or API)", value="sshleifer/distilbart-cnn-12-6")
                generator_in = gr.Textbox(label="生成モデル (API推奨, 任意)", value="")
                words_slider = gr.Slider(50, 600, value=200, step=10, label="要約の最大語数(目安)")
                run_btn = gr.Button("PPTXを生成", variant="primary")
                download = gr.File(label="ダウンロード")
        # Input order must match generate_pptx's positional parameters.
        callback_inputs = [
            text_in, title_in, color_in, logo_in, summary_cb, tables_cb, charts_cb,
            api_cb, summarizer_in, generator_in, words_slider,
        ]
        run_btn.click(
            fn=generate_pptx,
            inputs=callback_inputs,
            outputs=[download],
        )
        gr.Markdown("""
**Tips**
- 日本語要約には `sonoisa/t5-base-japanese` を推奨(`text2text-generation`)。
- Inference API を使う場合は、Space の Secrets に `HF_TOKEN` を設定してください。
- チャートは `ラベル: 数値` 形式の行を自動検出して棒グラフを作成します。
""")
    return demo
if __name__ == "__main__":
    demo = ui()
    # Spaces binds automatically; the explicit host/port keep local runs working too.
    port = int(os.getenv("PORT", "7860"))
    demo.queue().launch(server_name="0.0.0.0", server_port=port)