# PresentAgent / pptagent/multimodal.py
# sjw712's picture
# Upload 208 files
# d961e88 verified
import asyncio
from typing import Optional
import PIL.Image
from pptagent.llms import LLM, AsyncLLM
from pptagent.presentation import Picture, Presentation
from pptagent.utils import Config, get_logger, package_join, pbasename, pjoin
# Module-level logger obtained from the project's get_logger helper, keyed by
# this module's name; used below for debug output while captioning images.
logger = get_logger(__name__)
class ImageLabler:
    """
    Collect per-image statistics for a presentation and caption its pictures.

    Images are keyed by the basename of their file path. For each distinct
    image the labeler stores, in ``image_stats``:

    - ``size``: the size reported by PIL for the image file.
    - ``appear_times``: how many picture shapes reference the image.
    - ``slide_numbers``: sorted 1-based numbers of the slides it appears on.
    - ``relative_area``: area of the first shape using the image, as a
      percentage of the slide area.
    - ``top_ranges_str``: up to three of the longest consecutive slide ranges,
      formatted like ``"1-3, 7, 9-10"``.
    - ``caption``: added by ``caption_images`` / ``caption_images_async``.

    Pictures whose file is ``pic_placeholder.png`` are not tracked.
    """

    def __init__(self, presentation: Presentation, config: Config):
        """
        Initialize the ImageLabler and immediately collect image statistics.

        Args:
            presentation (Presentation): The presentation object.
            config (Config): The configuration object; its ``IMAGE_DIR`` is
                where image files are looked up.
        """
        self.presentation = presentation
        self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt
        self.image_stats = {}
        self.config = config
        self.collect_images()

    @staticmethod
    def _load_caption_prompt() -> str:
        """Read the captioning prompt shipped with the package."""
        # The context manager closes the file deterministically; the explicit
        # encoding avoids depending on the platform default.
        with open(package_join("prompts", "caption.txt"), encoding="utf-8") as f:
            return f.read()

    def apply_stats(self, image_stats: Optional[dict[str, dict]] = None):
        """
        Apply image captions to the pictures of the presentation.

        Only pictures whose ``caption`` is ``None`` are updated. When a caption
        spans several lines, its longest line is used.

        Args:
            image_stats (Optional[dict[str, dict]]): Stats keyed by image
                basename. Defaults to the stats collected by this instance.
        """
        if image_stats is None:
            image_stats = self.image_stats

        for slide in self.presentation.slides:
            for shape in slide.shape_filter(Picture):
                if shape.caption is not None:
                    continue
                image_name = pbasename(shape.img_path)
                # Placeholder pictures are never collected, and caller-supplied
                # stats may be partial, so a missing entry is skipped rather
                # than raising KeyError.
                caption = image_stats.get(image_name, {}).get("caption")
                if caption is None:
                    logger.debug("no caption available for %s, skipping", image_name)
                    continue
                shape.caption = max(caption.split("\n"), key=len)

    async def caption_images_async(self, vision_model: AsyncLLM):
        """
        Generate captions for images in the presentation asynchronously.

        All images that do not have a caption yet are captioned concurrently.
        Captions of tasks that finish successfully are stored even if another
        task fails; in that case the task group re-raises the failure and the
        captions are not applied to the presentation.

        Args:
            vision_model (AsyncLLM): The async vision model to use for captioning.

        Returns:
            dict: Dictionary containing image stats with captions.
        """
        assert isinstance(
            vision_model, AsyncLLM
        ), "vision_model must be an AsyncLLM instance"
        caption_prompt = self._load_caption_prompt()

        async def caption_one(image: str) -> None:
            # Storing the result inside the task (instead of in a done-callback
            # calling ``t.result()``) means failed or cancelled tasks do not
            # raise a second time from within a callback.
            caption = await vision_model(
                caption_prompt,
                pjoin(self.config.IMAGE_DIR, image),
            )
            self.image_stats[image]["caption"] = caption
            logger.debug("captioned %s: %s", image, caption)

        async with asyncio.TaskGroup() as tg:
            for image, stats in self.image_stats.items():
                if "caption" not in stats:
                    tg.create_task(caption_one(image))

        self.apply_stats()
        return self.image_stats

    def caption_images(self, vision_model: LLM):
        """
        Generate captions for images in the presentation.

        Images that already have a caption are left untouched.

        Args:
            vision_model (LLM): The vision model to use for captioning.

        Returns:
            dict: Dictionary containing image stats with captions.
        """
        assert isinstance(vision_model, LLM), "vision_model must be an LLM instance"
        caption_prompt = self._load_caption_prompt()
        for image, stats in self.image_stats.items():
            if "caption" not in stats:
                stats["caption"] = vision_model(
                    caption_prompt, pjoin(self.config.IMAGE_DIR, image)
                )
                logger.debug("captioned %s: %s", image, stats["caption"])
        self.apply_stats()
        return self.image_stats

    def collect_images(self):
        """
        Collect images from the presentation and gather other information.

        Populates ``image_stats`` with size, appearance count, slide numbers,
        relative area and a summary of the longest consecutive slide ranges.
        Called once from ``__init__``; it is not designed to be re-run on an
        already populated ``image_stats``.
        """
        for slide_number, slide in enumerate(self.presentation.slides, start=1):
            for shape in slide.shape_filter(Picture):
                image_path = pbasename(shape.img_path)
                if image_path == "pic_placeholder.png":
                    continue
                if image_path not in self.image_stats:
                    # Close the image file right after reading its size so
                    # handles are not leaked on large decks.
                    with PIL.Image.open(
                        pjoin(self.config.IMAGE_DIR, image_path)
                    ) as img:
                        size = img.size
                    self.image_stats[image_path] = {
                        "size": size,
                        "appear_times": 0,
                        "slide_numbers": set(),
                        # Percentage of the slide covered by the first shape
                        # that uses this image.
                        "relative_area": shape.area / self.slide_area * 100,
                    }
                self.image_stats[image_path]["appear_times"] += 1
                self.image_stats[image_path]["slide_numbers"].add(slide_number)

        for stats in self.image_stats.values():
            stats["slide_numbers"] = sorted(stats["slide_numbers"])
            ranges = self._find_ranges(stats["slide_numbers"])
            # Keep the three longest runs; the sort is stable, so equally long
            # runs stay in slide order.
            top_ranges = sorted(ranges, key=lambda x: x[1] - x[0], reverse=True)[:3]
            stats["top_ranges_str"] = ", ".join(
                f"{r[0]}-{r[1]}" if r[0] != r[1] else f"{r[0]}" for r in top_ranges
            )

    def _find_ranges(self, numbers):
        """
        Find consecutive ranges in a list of numbers.

        Args:
            numbers: Numbers in ascending order.

        Returns:
            list[tuple]: ``(start, end)`` pairs, inclusive on both ends, one
            per run of consecutive numbers. Empty input yields an empty list.
        """
        if not numbers:
            return []

        ranges = []
        start = end = numbers[0]
        for num in numbers[1:]:
            if num == end + 1:
                end = num
            else:
                ranges.append((start, end))
                start = end = num
        ranges.append((start, end))
        return ranges