Spaces:
Runtime error
Runtime error
| import asyncio | |
| from typing import Optional | |
| import PIL.Image | |
| from pptagent.llms import LLM, AsyncLLM | |
| from pptagent.presentation import Picture, Presentation | |
| from pptagent.utils import Config, get_logger, package_join, pbasename, pjoin | |
| logger = get_logger(__name__) | |
| class ImageLabler: | |
| """ | |
| A class to extract images information, including caption, size, and appearance times in a presentation. | |
| """ | |
| def __init__(self, presentation: Presentation, config: Config): | |
| """ | |
| Initialize the ImageLabler. | |
| Args: | |
| presentation (Presentation): The presentation object. | |
| config (Config): The configuration object. | |
| """ | |
| self.presentation = presentation | |
| self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt | |
| self.image_stats = {} | |
| self.config = config | |
| self.collect_images() | |
| def apply_stats(self, image_stats: Optional[dict[str, dict]] = None): | |
| """ | |
| Apply image captions to the presentation. | |
| """ | |
| if image_stats is None: | |
| image_stats = self.image_stats | |
| for slide in self.presentation.slides: | |
| for shape in slide.shape_filter(Picture): | |
| if shape.caption is None: | |
| caption = image_stats[pbasename(shape.img_path)]["caption"] | |
| shape.caption = max(caption.split("\n"), key=len) | |
| async def caption_images_async(self, vision_model: AsyncLLM): | |
| """ | |
| Generate captions for images in the presentation asynchronously. | |
| Args: | |
| vision_model (AsyncLLM): The async vision model to use for captioning. | |
| Returns: | |
| dict: Dictionary containing image stats with captions. | |
| """ | |
| assert isinstance( | |
| vision_model, AsyncLLM | |
| ), "vision_model must be an AsyncLLM instance" | |
| caption_prompt = open(package_join("prompts", "caption.txt")).read() | |
| async with asyncio.TaskGroup() as tg: | |
| for image, stats in self.image_stats.items(): | |
| if "caption" not in stats: | |
| task = tg.create_task( | |
| vision_model( | |
| caption_prompt, | |
| pjoin(self.config.IMAGE_DIR, image), | |
| ) | |
| ) | |
| task.add_done_callback( | |
| lambda t, image=image: ( | |
| self.image_stats[image].update({"caption": t.result()}), | |
| logger.debug("captioned %s: %s", image, t.result()), | |
| ) | |
| ) | |
| self.apply_stats() | |
| return self.image_stats | |
| def caption_images(self, vision_model: LLM): | |
| """ | |
| Generate captions for images in the presentation. | |
| Args: | |
| vision_model (LLM): The vision model to use for captioning. | |
| Returns: | |
| dict: Dictionary containing image stats with captions. | |
| """ | |
| assert isinstance(vision_model, LLM), "vision_model must be an LLM instance" | |
| caption_prompt = open(package_join("prompts", "caption.txt")).read() | |
| for image, stats in self.image_stats.items(): | |
| if "caption" not in stats: | |
| stats["caption"] = vision_model( | |
| caption_prompt, pjoin(self.config.IMAGE_DIR, image) | |
| ) | |
| logger.debug("captioned %s: %s", image, stats["caption"]) | |
| self.apply_stats() | |
| return self.image_stats | |
| def collect_images(self): | |
| """ | |
| Collect images from the presentation and gather other information. | |
| """ | |
| for slide_index, slide in enumerate(self.presentation.slides): | |
| for shape in slide.shape_filter(Picture): | |
| image_path = pbasename(shape.img_path) | |
| if image_path == "pic_placeholder.png": | |
| continue | |
| if image_path not in self.image_stats: | |
| size = PIL.Image.open(pjoin(self.config.IMAGE_DIR, image_path)).size | |
| self.image_stats[image_path] = { | |
| "size": size, | |
| "appear_times": 0, | |
| "slide_numbers": set(), | |
| "relative_area": shape.area / self.slide_area * 100, | |
| } | |
| self.image_stats[image_path]["appear_times"] += 1 | |
| self.image_stats[image_path]["slide_numbers"].add(slide_index + 1) | |
| for image_path, stats in self.image_stats.items(): | |
| stats["slide_numbers"] = sorted(list(stats["slide_numbers"])) | |
| ranges = self._find_ranges(stats["slide_numbers"]) | |
| top_ranges = sorted(ranges, key=lambda x: x[1] - x[0], reverse=True)[:3] | |
| top_ranges_str = ", ".join( | |
| [f"{r[0]}-{r[1]}" if r[0] != r[1] else f"{r[0]}" for r in top_ranges] | |
| ) | |
| stats["top_ranges_str"] = top_ranges_str | |
| def _find_ranges(self, numbers): | |
| """ | |
| Find consecutive ranges in a list of numbers. | |
| """ | |
| ranges = [] | |
| start = numbers[0] | |
| end = numbers[0] | |
| for num in numbers[1:]: | |
| if num == end + 1: | |
| end = num | |
| else: | |
| ranges.append((start, end)) | |
| start = num | |
| end = num | |
| ranges.append((start, end)) | |
| return ranges | |