diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b9bf211235b6d4ec294a64aa837252adfa59534b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +MegaTTS3/assets/Chinese_prompt.wav filter=lfs diff=lfs merge=lfs -text +MegaTTS3/assets/English_prompt.wav filter=lfs diff=lfs merge=lfs -text +MegaTTS3/assets/fig/Hi.gif filter=lfs diff=lfs merge=lfs -text +MegaTTS3/assets/fig/table_tts.png filter=lfs diff=lfs merge=lfs -text +MegaTTS3/assets/fig/table_wavvae.png filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.pdf filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_4_Figure_0.jpeg filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.pdf filter=lfs diff=lfs merge=lfs -text +pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/output.mp4 filter=lfs diff=lfs merge=lfs -text +pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/output.mp4 filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/203e2300314026057b7257a3c105a8d2fad5183e.png filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/35639ff12c3127b2ba9419b7c784b212753ff628.png filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/5452ff4f227c6ba1d7ad666974203486e642daf6.png filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/a1c98d25e5c2a3059235733edc58ea6984e75dc9.png filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/fee3a1e81ae1678f114d5799e440cc2b7d740aa1.png filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/source.pptx filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template.pptx filter=lfs diff=lfs merge=lfs -text +pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/source.pptx filter=lfs diff=lfs merge=lfs -text +templates/previews/Template1.jpg filter=lfs diff=lfs merge=lfs -text +templates/Template1.pptx filter=lfs diff=lfs merge=lfs -text +templates/Template2.pptx filter=lfs diff=lfs merge=lfs -text +templates/Template3.pptx filter=lfs diff=lfs merge=lfs -text diff --git a/MegaTTS3/CODE_OF_CONDUCT.md b/MegaTTS3/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..d28f7f91013e84b601278006135c437ab1e0e7f5 --- /dev/null +++ b/MegaTTS3/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/MegaTTS3/Dockerfile b/MegaTTS3/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6211679afef38a510d4927ae7e0b43606cad8e81 --- /dev/null +++ b/MegaTTS3/Dockerfile @@ -0,0 +1,18 @@ +FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + curl \ + python3 \ + python3-pip \ + ffmpeg \ + && apt-get clean + +COPY requirements.txt /app/ + +RUN pip install --no-cache-dir -r requirements.txt + +COPY . /app/ + +CMD ["python", "-m", "tts.gradio_api"] diff --git a/MegaTTS3/LICENSE b/MegaTTS3/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..315a3e8c3992b63922a9524fe7f48a4735d46a68 --- /dev/null +++ b/MegaTTS3/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2025] ByteDance + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/MegaTTS3/assets/Chinese_prompt.npy b/MegaTTS3/assets/Chinese_prompt.npy new file mode 100644 index 0000000000000000000000000000000000000000..b96373ba725154fce995657388f14eb3216f108f --- /dev/null +++ b/MegaTTS3/assets/Chinese_prompt.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd32490cefd81aecd7edb36b13dea4e8746ae9365d6aac065e6408477df7e70c +size 27264 diff --git a/MegaTTS3/assets/Chinese_prompt.wav b/MegaTTS3/assets/Chinese_prompt.wav new file mode 100644 index 0000000000000000000000000000000000000000..d3cad0bdfb31866408c6899bda685a20731a587e --- /dev/null +++ b/MegaTTS3/assets/Chinese_prompt.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbb0860cb4dd7c7003b6f0406299fc7c0febc5c6a990e1c670d29b763e84e7ed +size 384046 diff --git a/MegaTTS3/assets/English_prompt.npy b/MegaTTS3/assets/English_prompt.npy new file mode 100644 index 0000000000000000000000000000000000000000..c275020d7afce01af8b91adc7147867a09e8bdd1 --- /dev/null +++ b/MegaTTS3/assets/English_prompt.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a00a5c6bacedeee141d135e38e8d3c297a2326fcbfd9def9a657d4ac603c8b84 +size 22400 diff --git a/MegaTTS3/assets/English_prompt.wav b/MegaTTS3/assets/English_prompt.wav new file mode 100644 index 0000000000000000000000000000000000000000..cabacb48060d5e7219edc84bce1163178c02f5c8 --- /dev/null +++ b/MegaTTS3/assets/English_prompt.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5152e43ef1b2f72c95d64f216179b52d0b68d754785bb85b69ed9111036aa43 +size 317214 diff --git a/MegaTTS3/assets/fig/Hi.gif b/MegaTTS3/assets/fig/Hi.gif new file mode 100644 index 0000000000000000000000000000000000000000..bc5b1b70e07b809fc9e04a1b4d958b1a04b16327 --- /dev/null +++ b/MegaTTS3/assets/fig/Hi.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af4378fa510e5bc152501a0444e312eba7deb35d726fa4bd9567a7a504fb6df8 +size 291627 diff --git a/MegaTTS3/assets/fig/table_tts.png b/MegaTTS3/assets/fig/table_tts.png new file mode 100644 index 0000000000000000000000000000000000000000..26cf13c7055aaf8dcd9151d7e1eea776ad736850 --- /dev/null +++ b/MegaTTS3/assets/fig/table_tts.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7cb8a1aee90bcc4de9b90e9883e93733fbfe7d006b018becaf0f58444a45a399 +size 104317 diff --git a/MegaTTS3/assets/fig/table_wavvae.png b/MegaTTS3/assets/fig/table_wavvae.png new file mode 100644 index 0000000000000000000000000000000000000000..fe41ed5984e6177237969001da44098b6ee0e6f8 --- /dev/null +++ b/MegaTTS3/assets/fig/table_wavvae.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:472a8a339388a3f1bd8bed1bfb8c3211a0cbf63974553dabafbcd5ca21178710 +size 152975 diff --git a/MegaTTS3/readme.md b/MegaTTS3/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..1d34a6319501024ccee8dfacedceac50f3361a32 --- /dev/null +++ b/MegaTTS3/readme.md @@ -0,0 +1,194 @@ +
+# MegaTTS 3
+
+**Official PyTorch Implementation**
+
+<!-- badges: Hugging Face · version · python · mit -->
+
+## Key features
+- 🚀**Lightweight and Efficient:** The backbone of the TTS Diffusion Transformer has only 0.45B parameters.
+- 🎧**Ultra High-Quality Voice Cloning:** You can try our model at the [Huggingface Demo](https://huggingface.co/spaces/ByteDance/MegaTTS3)🎉. The .wav and .npy files can be found at [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing). Submit a sample (.wav format, < 24s, with no spaces in the filename) on [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) to receive the .npy voice latents you can use locally.
+- 🌍**Bilingual Support:** Supports both Chinese and English, as well as code-switching.
+- ✍️**Controllable:** Supports accent intensity control ✅ and fine-grained pronunciation/duration adjustment (coming soon).
+
+[MegaTTS 3 Demo Video](https://github.com/user-attachments/assets/0174c111-f392-4376-a34b-0b5b8164aacc)
+
+ +
+## 🎯Roadmap
+
+- **[2025-03-22]** Our project has been released!
+
+
+## Installation
+``` sh
+# Clone the repository
+git clone https://github.com/bytedance/MegaTTS3
+cd MegaTTS3
+```
+**Requirements (for Linux)**
+``` sh
+# Create a python 3.10 conda env (you could also use virtualenv)
+conda create -n megatts3-env python=3.10
+conda activate megatts3-env
+pip install -r requirements.txt
+
+# Set the root directory
+export PYTHONPATH="/path/to/MegaTTS3:$PYTHONPATH"
+
+# [Optional] Set GPU
+export CUDA_VISIBLE_DEVICES=0
+
+# If you encounter pydantic errors during inference, check that your pydantic and gradio versions are compatible.
+# [Note] If you encounter httpx-related errors, check whether your environment variable "no_proxy" contains patterns like "::"
+```
+
+**Requirements (for Windows)**
+``` sh
+# [The Windows version is currently under testing]
+# Comment out the dependency below in requirements.txt:
+# # WeTextProcessing==1.0.4.1
+
+# Create a python 3.10 conda env (you could also use virtualenv)
+conda create -n megatts3-env python=3.10
+conda activate megatts3-env
+pip install -r requirements.txt
+conda install -y -c conda-forge pynini==2.1.5
+pip install WeTextProcessing==1.0.3
+
+# [Optional] For GPU inference, you may need to install a specific version of PyTorch for your GPU from https://pytorch.org/.
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
+
+# [Note] If you encounter errors related to `ffprobe` or `ffmpeg`, you can install them via `conda install -c conda-forge ffmpeg`
+
+# Set environment variable for root directory
+set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%"                       # Windows
+$env:PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%"                      # Powershell on Windows
+conda env config vars set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # For conda users
+
+# [Optional] Set GPU
+set CUDA_VISIBLE_DEVICES=0   # Windows
+$env:CUDA_VISIBLE_DEVICES=0  # Powershell on Windows
+```
+
+**Requirements (for Docker)**
+``` sh
+# [The Docker version is currently under testing]
+# ! Download the pretrained checkpoints before running the following command
+docker build . -t megatts3:latest
+
+# For GPU inference
+docker run -it -p 7929:7929 --gpus all -e CUDA_VISIBLE_DEVICES=0 megatts3:latest
+# For CPU inference
+docker run -it -p 7929:7929 megatts3:latest
+
+# Visit http://0.0.0.0:7929/ for the gradio UI.
+```
+
+
+**Model Download**
+
+The pretrained checkpoints can be found at [Google Drive](https://drive.google.com/drive/folders/1CidiSqtHgJTBDAHQ746_on_YR0boHDYB?usp=sharing) or [Huggingface](https://huggingface.co/ByteDance/MegaTTS3). Please download them and place them under ``./checkpoints/xxx``.
+
+> [!IMPORTANT]
+> For security reasons, we do not upload the parameters of the WaveVAE encoder to the above links. You can only use the pre-extracted latents from [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) for inference. If you want to synthesize speech for speaker A, you need "A.wav" and "A.npy" in the same directory. If you have any questions or suggestions for our model, please email us.
+>
+> This project is primarily intended for academic purposes. For academic datasets requiring evaluation, you may upload them to the voice request queue in [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) (within 24s for each clip).
+> After verifying that your uploaded voices are free from safety issues, we will upload their latent files to [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) as soon as possible.
+>
+> In the coming days, we will also prepare and release the latent representations for some common TTS benchmarks.
+
+## Inference
+
+**Command-Line Usage (Standard)**
+``` bash
+# p_w (intelligibility weight), t_w (similarity weight). Typically, a noisier prompt requires higher p_w and t_w.
+python tts/infer_cli.py --input_wav 'assets/Chinese_prompt.wav' --input_text "另一边的桌上,一位读书人嗤之以鼻道,'佛子三藏,神子燕小鱼是什么样的人物,李家的那个李子夜如何与他们相提并论?'" --output_dir ./gen
+
+# As long as audio volume and pronunciation are appropriate, increasing --t_w within a reasonable range (2.0~5.0)
+# will increase the generated speech's expressiveness and similarity (especially for emotional cases).
+python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text 'As his long promised tariff threat turned into reality this week, top human advisers began fielding a wave of calls from business leaders, particularly in the automotive sector, along with lawmakers who were sounding the alarm.' --output_dir ./gen --p_w 2.0 --t_w 3.0
+```
+**Command-Line Usage (for TTS with Accents)**
+``` bash
+# When p_w (intelligibility weight) ≈ 1.0, the generated audio closely retains the speaker’s original accent. As p_w increases, it shifts toward standard pronunciation.
+# t_w (similarity weight) is typically set 0–3 points higher than p_w for optimal results.
+# Useful for accented TTS or for solving accent problems in cross-lingual TTS.
+python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这是一条有口音的音频。' --output_dir ./gen --p_w 1.0 --t_w 3.0
+
+python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这条音频的发音标准一些了吗?' --output_dir ./gen --p_w 2.5 --t_w 2.5
+```
+
+**Web UI Usage**
+``` bash
+# We also support CPU inference, but it may take about 30 seconds (for 10 inference steps).
+python tts/gradio_api.py
+```
+
+## Submodules
+> [!TIP]
+> In addition to TTS, some submodules in this project may also have additional usages.
+> See ``./tts/frontend_function.py`` and ``./tts/infer_cli.py`` for example code.
+
+### Aligner
+**Description:** a robust speech-text aligner model trained using pseudo-labels generated by a large number of MFA expert models.
+
+**Usage**: 1) Prepare the finetuning dataset for our model; 2) Filter large-scale speech datasets (if the aligner fails to align a certain speech clip, it is likely to be noisy); 3) Phoneme recognition; 4) Speech segmentation.
+
+### Grapheme-to-Phoneme Model
+**Description:** a Qwen2.5-0.5B model finetuned for robust grapheme-to-phoneme conversion.
+
+**Usage**: Grapheme-to-phoneme conversion.
+
+### WaveVAE
+**Description:** a strong waveform VAE that can compress 24 kHz speech into a 25 Hz acoustic latent and reconstruct the original waveform almost losslessly.
+
+**Usage:** 1) Acoustic latents provide a more compact and discriminative training target for speech synthesis models than mel-spectrograms, accelerating convergence; 2) Used as acoustic latents for voice conversion; 3) High-quality vocoder.
+
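+## Programmatic Usage (sketch)
+
+The CLI above wraps a small Python pipeline (`MegaTTS3DiTInfer` in ``./tts/infer_cli.py``). The snippet below is a minimal sketch of calling it directly, mirroring the ``get_tts`` helper in ``test.py``; the example text and output path are placeholders, and it assumes the checkpoints are already downloaded (the default ``ckpt_root`` in this PR is ``MegaTTS3/checkpoints``).
+
+``` python
+import os
+from tts.infer_cli import MegaTTS3DiTInfer
+from tts.utils.audio_utils.io import save_wav
+
+# Builds the duration LM, DiT, aligner, G2P model, and WavVAE decoder.
+infer = MegaTTS3DiTInfer()  # auto-selects 'cuda' if available, otherwise 'cpu'
+
+# Read the prompt audio; its pre-extracted latent is the neighbouring .npy file.
+with open('assets/English_prompt.wav', 'rb') as f:
+    audio_bytes = f.read()
+resource_context = infer.preprocess(audio_bytes, latent_file='assets/English_prompt.npy')
+
+# p_w = intelligibility weight, t_w = similarity weight (see the CLI examples above).
+wav_bytes = infer.forward(resource_context, 'Hello from MegaTTS 3.',
+                          time_step=32, p_w=1.6, t_w=2.5)
+
+os.makedirs('gen', exist_ok=True)
+save_wav(wav_bytes, 'gen/output.wav')
+```
+
+In the default decoder-only WavVAE distribution, the ``.npy`` latent is required by ``preprocess``; see the note under **Model Download**.
+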
+ +
+ + +## Security +If you discover a potential security issue in this project, or think you may +have discovered a security issue, we ask that you notify Bytedance Security via our [security center](https://security.bytedance.com/src) or [sec@bytedance.com](sec@bytedance.com). + +Please do **not** create a public GitHub issue. + +## License +This project is licensed under the [Apache-2.0 License](LICENSE). + +## Citation +This repo contains forced-align version of `Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis` and the WavVAE is mainly based on `Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling`. Compared to the model described in paper, the repository includes additional models. These models not only enhance the stability and cloning capabilities of the algorithm but can also be independently utilized to serve a wider range of scenarios. +``` +@article{jiang2025sparse, + title={Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis}, + author={Jiang, Ziyue and Ren, Yi and Li, Ruiqi and Ji, Shengpeng and Ye, Zhenhui and Zhang, Chen and Jionghao, Bai and Yang, Xiaoda and Zuo, Jialong and Zhang, Yu and others}, + journal={arXiv preprint arXiv:2502.18924}, + year={2025} +} + +@article{ji2024wavtokenizer, + title={Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling}, + author={Ji, Shengpeng and Jiang, Ziyue and Wang, Wen and Chen, Yifu and Fang, Minghui and Zuo, Jialong and Yang, Qian and Cheng, Xize and Wang, Zehan and Li, Ruiqi and others}, + journal={arXiv preprint arXiv:2408.16532}, + year={2024} +} +``` diff --git a/MegaTTS3/requirements.txt b/MegaTTS3/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2b1af6aa63f90faf740b911128828b50b585cdab --- /dev/null +++ b/MegaTTS3/requirements.txt @@ -0,0 +1,15 @@ +numpy<2 +setproctitle==1.3.3 +attrdict==2.0.1 +librosa==0.10.2.post1 +langdetect==1.0.9 +pydub==0.25.1 +pyloudnorm==0.1.1 +modelscope==1.22.2 +transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' +transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' +x-transformers==1.44.4 +torchdiffeq==0.2.5 +openai-whisper==20240930 +httpx==0.28.1 +gradio==5.23.1 diff --git a/MegaTTS3/run.sh b/MegaTTS3/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..8594f017965a8eb6d7980b4f9d251ef1037b7b35 --- /dev/null +++ b/MegaTTS3/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Define the root directory +root_dir="../../result/claude-3.7-sonnet" + +# Use find to search for all .pptx files in the directory and subdirectories +find "$root_dir" -type f -name "*.pptx" | while IFS= read -r pptx; do + echo "Running python test.py on \"$pptx\"" + python test.py --pptx "$pptx" +done diff --git a/MegaTTS3/test.py b/MegaTTS3/test.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7bfc4925f4c2c1ebf4ac745c1584027bbf3ec0 --- /dev/null +++ b/MegaTTS3/test.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +import os +import tempfile +import argparse +from subprocess import call +import subprocess +from pdf2image import convert_from_path +from pptx import Presentation +# from gtts import gTTS + + +__author__ = ['chaonan99'] + + +## Sometimes ffmpeg is avconv +FFMPEG_NAME = 'ffmpeg' +# FFMPEG_NAME = 'avconv' + +import os +from typing import Optional + +from tts.infer_cli import MegaTTS3DiTInfer # adjust import path as needed +from 
tts.utils.audio_utils.io import save_wav + + +def get_tts( + input_wav_path: str, + input_text: str, + output_path: str, + time_step: int = 32, + p_w: float = 1.6, + t_w: float = 2.5, + device: Optional[str] = None, +) -> str: + """ + Generate TTS audio from an input WAV file and text prompt. + + Args: + input_wav_path: Path to the input WAV (prompt) file. + input_text: Text to synthesize. + output_path: Path to the output audio file. + time_step: Diffusion inference steps. + p_w: Intelligibility weight. + t_w: Similarity weight. + device: Device specifier (e.g., 'cuda' or 'cpu'). If None, auto-selected. + + Returns: + The full path to the generated WAV file. + """ + # Initialize the inference model + infer = MegaTTS3DiTInfer(device=device) + + # Read prompt audio + with open(input_wav_path, 'rb') as f: + audio_bytes = f.read() + + # Locate corresponding latent file if available + latent_file = None + potential_npy = os.path.splitext(input_wav_path)[0] + '.npy' + if os.path.isfile(potential_npy): + latent_file = potential_npy + + # Preprocess: extract features and durations + resource_context = infer.preprocess(audio_bytes, latent_file=latent_file) + + # Synthesize speech + wav_bytes = infer.forward( + resource_context, + input_text, + time_step=time_step, + p_w=p_w, + t_w=t_w + ) + + # Ensure output directory exists and save + save_wav(wav_bytes, output_path) + + return output_path + + + + + +def ppt_presenter(pptx_path): + cmd = ['libreoffice', '--headless', '--convert-to', 'pdf', pptx_path, '--outdir', os.path.dirname(pptx_path)] + result = subprocess.run(cmd, capture_output=True, text=True) + + pdf_path = os.path.splitext(pptx_path)[0] + '.pdf' + output_path = os.path.splitext(pptx_path)[0] + '.mp4' + with tempfile.TemporaryDirectory() as temp_path: + images_from_path = convert_from_path(pdf_path) + prs = Presentation(pptx_path) + assert len(images_from_path) == len(prs.slides) + for i, (slide, image) in enumerate(zip(prs.slides, images_from_path)): + if slide.has_notes_slide: + notes = slide.notes_slide.notes_text_frame.text + + # tts = gTTS(text=notes, lang='en') + image_path = os.path.join(temp_path, 'frame_{}.jpg'.format(i)) + audio_path = os.path.join(temp_path, 'frame_{}.mp3'.format(i)) + + image.save(image_path) + get_tts("assets/English_prompt.wav", notes, audio_path) + # tts.save(audio_path) + + ffmpeg_call(image_path, audio_path, temp_path, i) + + video_list = [os.path.join(temp_path, 'frame_{}.ts'.format(i)) \ + for i in range(len(images_from_path))] + video_list_str = 'concat:' + '|'.join(video_list) + ffmpeg_concat(video_list_str, output_path) + + +def ffmpeg_call(image_path, audio_path, temp_path, i): + out_path_mp4 = os.path.join(temp_path, 'frame_{}.mp4'.format(i)) + out_path_ts = os.path.join(temp_path, 'frame_{}.ts'.format(i)) + call([FFMPEG_NAME, '-loop', '1', '-y', '-i', image_path, '-i', audio_path, + '-vf', 'scale=2666:1500', '-c:v', 'libx264', '-tune', 'stillimage', '-c:a', 'aac', + '-b:a', '192k', '-pix_fmt', 'yuv420p', '-shortest', out_path_mp4]) + + call([FFMPEG_NAME, '-y', '-i', out_path_mp4, '-c', 'copy', + '-bsf:v', 'h264_mp4toannexb', '-f', 'mpegts', out_path_ts]) + + +def ffmpeg_concat(video_list_str, out_path): + call([FFMPEG_NAME, '-y', '-f', 'mpegts', '-i', '{}'.format(video_list_str), + '-c', 'copy', '-bsf:a', 'aac_adtstoasc', out_path]) + + +def main(): + parser = argparse.ArgumentParser(description='PPT Presenter help.') + parser.add_argument('--pptx', default='../../ppagent_2025-06-29_152592d9-df14-48d0-b6de-99fa7fe4fdac.pptx', help='input pptx 
path') + args = parser.parse_args() + ppt_presenter(args.pptx) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MegaTTS3/test2.py b/MegaTTS3/test2.py new file mode 100644 index 0000000000000000000000000000000000000000..eeceb8c07beb761dd76d03f54639de31e17218f6 --- /dev/null +++ b/MegaTTS3/test2.py @@ -0,0 +1,30 @@ +import os +import subprocess + +path = "../../new_result" + +# Loop through directories in the specified path +for dirs in os.listdir(path): + new_path = os.path.join(path, dirs) + + # Skip system files like .DS_Store + if dirs == ".DS_Store": + continue + + # Loop through files in each directory + for filename in os.listdir(new_path): + file = os.path.join(new_path, filename) + + # Process only .pptx files + if file.endswith(".pptx"): + file2 = file.replace(".pptx", ".mp4") + + # Skip if .mp4 already exists + if os.path.exists(file2): + continue + + # Log the processing of the file + print(f"Processing {file}") + + # Call the external script to convert the pptx to mp4 + subprocess.call(['python', 'test.py', '--pptx', file]) diff --git a/MegaTTS3/tts/frontend_function.py b/MegaTTS3/tts/frontend_function.py new file mode 100644 index 0000000000000000000000000000000000000000..46a09b08392549a5676568b92782b7db6caf449a --- /dev/null +++ b/MegaTTS3/tts/frontend_function.py @@ -0,0 +1,181 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
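+
+# Frontend helper functions for MegaTTS3DiTInfer: grapheme-to-phoneme conversion
+# (g2p), prompt-speech alignment (align), duration prompting and prediction
+# (make_dur_prompt, dur_pred), and DiT input preparation (prepare_inputs_for_dit).
+# Each function takes the MegaTTS3DiTInfer instance as its first argument (`self`)
+# and is invoked from tts/infer_cli.py.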
+ +import torch +import torch.nn.functional as F +import whisper +import librosa +from copy import deepcopy +from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph +from tts.utils.audio_utils.align import mel2token_to_dur + +''' Graphme to phoneme function ''' +def g2p(self, text_inp): + # prepare inputs + txt_token = self.g2p_tokenizer('' + text_inp + '')['input_ids'] + input_ids = torch.LongTensor([txt_token+[145+self.speech_start_idx]]).to(self.device) + + # model forward + with torch.cuda.amp.autocast(dtype=self.precision, enabled=True): + outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800+1+self.speech_start_idx) + + # process outputs + ph_tokens = outputs[:, len(txt_token):-1]-self.speech_start_idx + ph_pred, tone_pred = split_ph(ph_tokens[0]) + ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device) + return ph_pred, tone_pred + +''' Get phoneme2mel align of prompt speech ''' +def align(self, wav): + with torch.inference_mode(): + whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000) + mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2) + prompt_max_frame = mel.size(2) // self.fm * self.fm + mel = mel[:, :, :prompt_max_frame] + token = torch.LongTensor([[798]]).to(self.device) + audio_features = self.aligner_lm.embed_audio(mel) + for i in range(768): + with torch.cuda.amp.autocast(dtype=self.precision, enabled=True): + logits = self.aligner_lm.logits(token, audio_features, None) + token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None] + token = torch.cat([token, token_pred], dim=1) + if token_pred[0] == 799: + break + alignment_tokens = token + + ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1]) + ph_ref = torch.Tensor(ph_ref)[None].to(self.device) + tone_ref = torch.Tensor(tone_ref)[None].to(self.device) + if dur_ref.sum() < prompt_max_frame: + dur_ref[-1] += prompt_max_frame - dur_ref.sum() + elif dur_ref.sum() > prompt_max_frame: + len_diff = dur_ref.sum() - prompt_max_frame + while True: + for i in range(len(dur_ref)): + dur_ref[i] -= 1 + len_diff -= 1 + if len_diff == 0: + break + if len_diff == 0: + break + mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device) + mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm] + return ph_ref, tone_ref, mel2ph_ref + +''' Duration Prompting ''' +def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref): + dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp( + max=self.hp_dur_model['dur_code_size'] - 1) + 1 + + ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device) + txt_tokens_flat_ = ph_ref.flatten(0, 1) + ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None] + + last_dur_pos_prompt = ctx_dur_tokens.shape[1] + dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt) + dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device) + with torch.cuda.amp.autocast(dtype=self.precision, enabled=True): + _, incremental_state_dur_prompt = self.dur_model.infer( + ph_ref, {'tone': tone_ref}, None, None, None, + ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True) + return incremental_state_dur_prompt, ctx_dur_tokens + +''' Duration Prediction ''' +def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final): + last_dur_token = 
ctx_dur_tokens[:, -1:] + last_dur_pos_prompt = ctx_dur_tokens.shape[1] + incremental_state_dur = deepcopy(incremental_state_dur_prompt) + txt_len = ph_pred.shape[1] + dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len) + dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device) + last_dur_pos_prompt = last_dur_pos_prompt + txt_len + + with torch.cuda.amp.autocast(dtype=self.precision, enabled=True): + dur_pred = self.dur_model.infer( + ph_pred, {'tone': tone_pred}, None, None, None, + incremental_state=incremental_state_dur, + first_decoder_inp=last_dur_token, + spk_pos_ids_flat=dur_spk_pos_ids_flat, + ) + + dur_pred = dur_pred - 1 + dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1) + # if is_final: + # dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128) + # else: + # dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128) + # if seg_i > 0: + # dur_pred[:, 0] = 0 + # ['。', '!', '?', 'sil'] + # for sil_token in [148, 153, 166, 145]: + # dur_pred[ph_pred==sil_token].clamp_min(32) + # # [',', ';'] + # for sil_token in [163, 165]: + # dur_pred[ph_pred==sil_token].clamp_min(16) + if not is_final: + # add 0.32ms for crossfade + dur_pred[:, -1] = dur_pred[:, -1] + 32 + else: + dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128) + + ''' DiT target speech generation ''' + dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float() + dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb + dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \ + dur_pred / dur_disturb_r * (1 - dur_disturb_choice) + dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127) + # ['。', '!', '?', 'sil'] + for sil_token in [148, 153, 166, 145]: + dur_pred[ph_pred==sil_token] = dur_pred[ph_pred==sil_token].clamp_min(64) + # [',', ';'] + for sil_token in [163, 165]: + dur_pred[ph_pred==sil_token] = dur_pred[ph_pred==sil_token].clamp_min(32) + if is_first: + dur_pred[:, 0] = 8 + + dur_sum = dur_pred.sum() + npad = self.fm - dur_sum % self.fm + if npad < self.fm: + dur_pred[:, -1] += npad + mel2ph_pred = self.length_regulator(dur_pred).to(self.device) + return mel2ph_pred + +def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent): + # Prepare duration token + mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred+ph_ref.size(1)), dim=1) + mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1)//self.fm*self.fm].repeat(3, 1) + # Prepare phone and tone token + ph_pred = torch.cat((ph_ref, ph_pred), dim=1) + tone_pred = torch.cat((tone_ref, tone_pred), dim=1) + # Disable the English tone (set them to 3)""" + en_tone_idx = ~((tone_pred == 4) | ( (11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0)) + tone_pred[en_tone_idx] = 3 + + # Prepare cfg inputs + ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0) + tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0) + target_size = mel2ph_pred.size(1)//self.vae_stride + vae_latent_ = vae_latent.repeat(3, 1, 1) + ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1]) + vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0) + vae_latent_[1:] = 0.0 + ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0) + + return { + 'phone': ph_seq, + 'tone': tone_seq, + "lat_ctx": vae_latent_ * ctx_mask, + "ctx_mask": ctx_mask, + "dur": mel2ph_pred, + } 
diff --git a/MegaTTS3/tts/gradio_api.py b/MegaTTS3/tts/gradio_api.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f6be6fa54e4cc23a0e316a6f84817344adc023 --- /dev/null +++ b/MegaTTS3/tts/gradio_api.py @@ -0,0 +1,93 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing as mp +import torch +import os +from functools import partial +import gradio as gr +import traceback +from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav + + +def model_worker(input_queue, output_queue, device_id): + device = None + if device_id is not None: + device = torch.device(f'cuda:{device_id}') + infer_pipe = MegaTTS3DiTInfer(device=device) + + while True: + task = input_queue.get() + inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task + try: + convert_to_wav(inp_audio_path) + wav_path = os.path.splitext(inp_audio_path)[0] + '.wav' + cut_wav(wav_path, max_len=28) + with open(wav_path, 'rb') as file: + file_content = file.read() + resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path) + wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w) + output_queue.put(wav_bytes) + except Exception as e: + traceback.print_exc() + print(task, str(e)) + output_queue.put(None) + + +def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue): + print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w) + input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)) + res = output_queue.get() + if res is not None: + return res + else: + print("") + return None + + +if __name__ == '__main__': + mp.set_start_method('spawn', force=True) + mp_manager = mp.Manager() + + devices = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if devices != '': + devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",") + else: + devices = None + + num_workers = 1 + input_queue = mp_manager.Queue() + output_queue = mp_manager.Queue() + processes = [] + + print("Start open workers") + for i in range(num_workers): + p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None)) + p.start() + processes.append(p) + + api_interface = gr.Interface(fn= + partial(main, processes=processes, input_queue=input_queue, + output_queue=output_queue), + inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text", + gr.Number(label="infer timestep", value=32), + gr.Number(label="Intelligibility Weight", value=1.4), + gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")], + title="MegaTTS3", + description="Upload a speech clip as a reference for timbre, " + + "upload the pre-extracted latent file, "+ + "input the target text, and receive the cloned voice.", concurrency_limit=1) + api_interface.launch(server_name='0.0.0.0', 
server_port=7929, debug=True) + for p in processes: + p.join() diff --git a/MegaTTS3/tts/infer_cli.py b/MegaTTS3/tts/infer_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..dec996bfe3c6865e34652cc081dff81edeafcba8 --- /dev/null +++ b/MegaTTS3/tts/infer_cli.py @@ -0,0 +1,279 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import argparse +import librosa +import numpy as np +import torch + +from tn.chinese.normalizer import Normalizer as ZhNormalizer +from tn.english.normalizer import Normalizer as EnNormalizer +from langdetect import detect as classify_language +from pydub import AudioSegment +import pyloudnorm as pyln + +from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator +from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit +from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments +from tts.utils.commons.ckpt_utils import load_ckpt +from tts.utils.commons.hparams import set_hparams, hparams +from tts.utils.text_utils.text_encoder import TokenTextEncoder +from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english, chunk_text_chinesev2 +from tts.utils.commons.hparams import hparams, set_hparams + + +if "TOKENIZERS_PARALLELISM" not in os.environ: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + +def convert_to_wav(wav_path): + # Check if the file exists + if not os.path.exists(wav_path): + print(f"The file '{wav_path}' does not exist.") + return + + # Check if the file already has a .wav extension + if not wav_path.endswith(".wav"): + # Define the output path with a .wav extension + out_path = os.path.splitext(wav_path)[0] + ".wav" + + # Load the audio file using pydub and convert it to WAV + audio = AudioSegment.from_file(wav_path) + audio.export(out_path, format="wav") + + print(f"Converted '{wav_path}' to '{out_path}'") + + +def cut_wav(wav_path, max_len=28): + audio = AudioSegment.from_file(wav_path) + audio = audio[:int(max_len * 1000)] + audio.export(wav_path, format="wav") + +class MegaTTS3DiTInfer(): + def __init__( + self, + device=None, + ckpt_root='MegaTTS3/checkpoints', + dit_exp_name='diffusion_transformer', + frontend_exp_name='aligner_lm', + wavvae_exp_name='wavvae', + dur_ckpt_path='duration_lm', + g2p_exp_name='g2p', + precision=torch.float16, + **kwargs + ): + self.sr = 24000 + self.fm = 8 + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.precision = precision + + # build models + self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name) + self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name) + self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name) + self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path) + self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name) + self.build_model(self.device) + + # init text normalizer + self.zh_normalizer = 
ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False) + self.en_normalizer = EnNormalizer(overwrite_cache=False) + # loudness meter + self.loudness_meter = pyln.Meter(self.sr) + + def build_model(self, device): + set_hparams(exp_name=self.dit_exp_name, print_hparams=False) + + ''' Load Dict ''' + current_dir = os.path.dirname(os.path.abspath(__file__)) + ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig')) + self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='') for k in ['phone', 'tone']} + self.token_encoder = token_encoder = self.ling_dict['phone'] + ph_dict_size = len(token_encoder) + + ''' Load Duration LM ''' + from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor + hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False) + hp_dur_model['frames_multiple'] = hparams['frames_multiple'] + self.dur_model = ARDurPredictor( + hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'], + hp_dur_model['dur_model_layers'], ph_dict_size, + hp_dur_model['dur_code_size'], + use_rot_embed=hp_dur_model.get('use_rot_embed', False)) + self.length_regulator = LengthRegulator() + load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model') + self.dur_model.eval() + self.dur_model.to(device) + + ''' Load Diffusion Transformer ''' + from tts.modules.llm_dit.dit import Diffusion + self.dit = Diffusion() + load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False) + self.dit.eval() + self.dit.to(device) + self.cfg_mask_token_phone = 302 - 1 + self.cfg_mask_token_tone = 32 - 1 + + ''' Load Frontend LM ''' + from tts.modules.aligner.whisper_small import Whisper + self.aligner_lm = Whisper() + load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model') + self.aligner_lm.eval() + self.aligner_lm.to(device) + self.kv_cache = None + self.hooks = None + + ''' Load G2P LM''' + from transformers import AutoTokenizer, AutoModelForCausalLM + g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right") + g2p_tokenizer.padding_side = "right" + self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device) + self.g2p_tokenizer = g2p_tokenizer + self.speech_start_idx = g2p_tokenizer.encode('')[0] + + ''' Wav VAE ''' + self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False) + from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3 + self.wavvae = WavVAE_V3(hparams=hp_wavvae) + if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'): + load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True) + self.has_vae_encoder = True + else: + load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False) + self.has_vae_encoder = False + self.wavvae.eval() + self.wavvae.to(device) + self.vae_stride = hp_wavvae.get('vae_stride', 4) + self.hop_size = hp_wavvae.get('hop_size', 4) + + def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs): + wav_bytes = convert_to_wav_bytes(audio_bytes) + + ''' Load wav ''' + wav, _ = librosa.core.load(wav_bytes, sr=self.sr) + # Pad wav if necessary + ws = hparams['win_size'] + if len(wav) % ws < ws - 1: + wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32) + wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32) + self.loudness_prompt 
= self.loudness_meter.integrated_loudness(wav.astype(float)) + + ''' obtain alignments with aligner_lm ''' + ph_ref, tone_ref, mel2ph_ref = align(self, wav) + + with torch.inference_mode(): + ''' Forward WaveVAE to obtain: prompt latent ''' + if self.has_vae_encoder: + wav = torch.FloatTensor(wav)[None].to(self.device) + vae_latent = self.wavvae.encode_latent(wav) + vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4] + else: + assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode" + vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device) + vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4] + + ''' Duration Prompting ''' + self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None + incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref) + + return { + 'ph_ref': ph_ref, + 'tone_ref': tone_ref, + 'mel2ph_ref': mel2ph_ref, + 'vae_latent': vae_latent, + 'incremental_state_dur_prompt': incremental_state_dur_prompt, + 'ctx_dur_tokens': ctx_dur_tokens, + } + + def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs): + device = self.device + + ph_ref = resource_context['ph_ref'].to(device) + tone_ref = resource_context['tone_ref'].to(device) + mel2ph_ref = resource_context['mel2ph_ref'].to(device) + vae_latent = resource_context['vae_latent'].to(device) + ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device) + incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt'] + + with torch.inference_mode(): + ''' Generating ''' + wav_pred_ = [] + language_type = classify_language(input_text) + if language_type == 'en': + input_text = self.en_normalizer.normalize(input_text) + text_segs = chunk_text_english(input_text, max_chars=130) + else: + input_text = self.zh_normalizer.normalize(input_text) + text_segs = chunk_text_chinesev2(input_text, limit=60) + + for seg_i, text in enumerate(text_segs): + ''' G2P ''' + ph_pred, tone_pred = g2p(self, text) + + ''' Duration Prediction ''' + mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1) + + inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent) + # Speech dit inference + with torch.cuda.amp.autocast(dtype=self.precision, enabled=True): + x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float() + + # WavVAE decode + x[:, :vae_latent.size(1)] = vae_latent + wav_pred = self.wavvae.decode(x)[0,0].to(torch.float32) + + ''' Post-processing ''' + # Trim prompt wav + wav_pred = wav_pred[vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy() + # Norm generated wav to prompt wav's level + meter = pyln.Meter(self.sr) # create BS.1770 meter + loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float)) + wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt) + if np.abs(wav_pred).max() >= 1: + wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95 + + # Apply hamming window + wav_pred_.append(wav_pred) + + wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float) + return to_wav_bytes(wav_pred, self.sr) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input_wav', type=str) + parser.add_argument('--input_text', type=str) + parser.add_argument('--output_dir', type=str) + 
    parser.add_argument('--output_file_name', type=str)
+    parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
+    parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
+    parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
+    args = parser.parse_args()
+    wav_path, input_text, out_path, time_step, p_w, t_w, file_name = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w, args.output_file_name
+
+    infer_ins = MegaTTS3DiTInfer()
+
+    with open(wav_path, 'rb') as file:
+        file_content = file.read()
+
+    print(f"| Start processing {wav_path}+{input_text}")
+    resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
+    wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)
+
+    print(f"| Saving results to {out_path}/{file_name}.wav")
+    os.makedirs(out_path, exist_ok=True)
+    save_wav(wav_bytes, f'{out_path}/{file_name}.wav')
\ No newline at end of file
diff --git a/MegaTTS3/tts/modules/aligner/whisper_small.py b/MegaTTS3/tts/modules/aligner/whisper_small.py
new file mode 100644
index 0000000000000000000000000000000000000000..87e71a6bb995b96d4a52dd8569097eede293c2f6
--- /dev/null
+++ b/MegaTTS3/tts/modules/aligner/whisper_small.py
@@ -0,0 +1,318 @@
+# MIT License
+
+# Copyright (c) 2022 OpenAI
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# Copyright (c) [2022] [OpenAI]
+# Copyright (c) [2025] [Ziyue Jiang]
+# SPDX-License-Identifier: MIT
+# This file has been modified by Ziyue Jiang on 2025/03/19
+# Original file was released under MIT, with the full license text
+# available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
+# This modified file is released under the same license.
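+
+# Note: this module defines a compact Whisper-style encoder-decoder
+# (80-mel input, 512 hidden units, 8 attention heads, 6 encoder and 6 decoder
+# layers, 6800-token vocabulary). MegaTTS3DiTInfer loads it as `aligner_lm`
+# and uses it to align the prompt speech with phoneme/tone tokens before
+# duration prompting.
+#
+# A minimal usage sketch (shapes are assumptions read off the code below:
+# `mel` is [batch, 80, frames], `tokens` is [batch, seq_len]):
+#
+#     model = Whisper()
+#     logits = model(mel, mel_len, tokens, token_len)  # [batch, seq_len, 6800]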
+ +from contextlib import contextmanager +from typing import Dict, Iterable, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from torch.nn.functional import scaled_dot_product_attention +SDPA_AVAILABLE = True + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward( + self, x: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +@contextmanager +def disable_sdpa(): + prev_state = MultiHeadAttention.use_sdpa + try: + MultiHeadAttention.use_sdpa = False + yield + finally: + MultiHeadAttention.use_sdpa = prev_state + + +class MultiHeadAttention(nn.Module): + use_sdpa = True + + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + casual: Optional[bool] = None + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
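+            # `kv_cache` is keyed by the key/value Linear modules themselves; the
+            # cached tensors are written by the forward hooks installed in
+            # Whisper.install_kv_cache_hooks(), so they can simply be read back here.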
+ k = kv_cache[self.key] + v = kv_cache[self.value] + + wv = self.qkv_attention(q, k, v, mask, casual) + return self.out(wv) + + def qkv_attention( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + a = scaled_dot_product_attention( + q, k, v, is_causal=casual and n_ctx > 1, attn_mask=mask[:, None, None, :] if mask is not None else None + ) + out = a.permute(0, 2, 1, 3).flatten(start_dim=2) + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = ( + MultiHeadAttention(n_state, n_head) if cross_attention else None + ) + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + ) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + casual: Optional[bool] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual) + if self.cross_attn: + # TODO: Cross attention mask + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False) + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor, attn_mask: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding[:x.size(1)]).to(x.dtype) + + for block in self.blocks: + x = block(x, mask=attn_mask, casual=False) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + def __init__( + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ + ResidualAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ] + ) + self.ln = LayerNorm(n_state) + + self.out_proj = nn.Linear(n_state, n_vocab) + + def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = 
(batch_size, n_audio_ctx, n_audio_state) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = ( + self.token_embedding(x) + + self.positional_embedding[offset : offset + x.shape[-1]] + ) + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, casual=True) + + x = self.ln(x) + # logits = ( + # x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + # ).float() + logits = self.out_proj(x) + + return logits + + +class Whisper(nn.Module): + def __init__(self): + super().__init__() + self.n_vocab = 6800 + self.n_text_layer = 6 + self.n_text_head = 8 + self.n_text_ctx = 2048 + + self.encoder = AudioEncoder( + n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6, + ) + self.decoder = TextDecoder( + n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6, + ) + + def embed_audio(self, mel: torch.Tensor): + return self.encoder(mel, None) + + def logits(self, tokens, audio_features, kv_cache=None): + return self.decoder(tokens, None, audio_features, kv_cache=kv_cache) + + def forward( + self, mel, mel_len, token, token_len + ) -> Dict[str, torch.Tensor]: + attn_mask_enc = self.sequence_mask(mel_len//2, device=mel.device) > 0 + attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0 + return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc)) + + @property + def device(self): + return next(self.parameters()).device + + def install_kv_cache_hooks(self, cache: Optional[dict] = None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value + tensors calculated for the previous positions. This method returns a dictionary that stores + all caches, and the necessary hooks for the key and value projection modules that save the + intermediate tensors to be reused during later calculations. + + Returns + ------- + cache : Dict[nn.Module, torch.Tensor] + A dictionary object mapping the key/value projection modules to its cache + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects to stop the hooks to be called + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[1] > self.n_text_ctx: + # save as-is, for the first token or cross attention + cache[module] = output + else: + cache[module] = torch.cat([cache[module], output], dim=1).detach() + return cache[module] + + def install_hooks(layer: nn.Module): + if isinstance(layer, MultiHeadAttention): + hooks.append(layer.key.register_forward_hook(save_to_cache)) + hooks.append(layer.value.register_forward_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks + + def sequence_mask(self, seq_lens, max_len=None, device='cpu'): + b = seq_lens.shape[0] + if max_len is None: + max_len = seq_lens.max() + mask = torch.arange(max_len).unsqueeze(0).to(device) # [1, t] + mask = mask < (seq_lens.unsqueeze(1)) # [1, t] + [b, 1] = [b, t] + mask = mask.float() + return mask diff --git a/MegaTTS3/tts/modules/ar_dur/ar_dur_predictor.py b/MegaTTS3/tts/modules/ar_dur/ar_dur_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..3221fcaf750dca82ef6c72b30a4f1360e0836a40 --- /dev/null +++ b/MegaTTS3/tts/modules/ar_dur/ar_dur_predictor.py @@ -0,0 +1,362 @@ +# Copyright 2025 ByteDance and/or its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from copy import deepcopy + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Linear +from tqdm import tqdm + +from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm +from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb +from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer +from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding +from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder + +FS_ENCODERS = { + 'rel_fft': lambda hp, dict_size: RelTransformerEncoder( + dict_size, hp['hidden_size'], hp['hidden_size'], + hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'], + hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']), +} + +def fill_with_neg_inf2(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(-1e8).type_as(t) + +def expand_states(h, mel2token): + h = F.pad(h, [0, 0, 1, 0]) + mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]]) + h = torch.gather(h, 1, mel2token_) # [B, T, H] + return h + + +class CodePredictor(nn.Module): + def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size): + super().__init__() + self.hparams = deepcopy(hparams) + self.hparams['hidden_size'] = hidden_size + self.hidden_size = hidden_size + char_dict_size = hparams.get('char_dict_size', 4000) + if not hparams.get('lm_use_enc'): + self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0) + if hparams.get('mega_use_char', True): + self.char_encoder = nn.Embedding(char_dict_size, + self.hidden_size, padding_idx=0) + else: + self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size) + if hparams.get('mega_use_char', True): + self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size) + if hparams['use_ph_pos_embed']: + self.ph_pos_embed = PosEmb(self.hidden_size) + + self.char_empty_embed = nn.Embedding(1, self.hidden_size) + if hparams.get('use_bert_input'): + self.bert_input_proj = nn.Linear(768, self.hidden_size) + self.ling_label_embed_layers = nn.ModuleDict() + for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']): + self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0) + + self.dec_hidden_size = dec_hidden_size + self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size) + self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0) + self.use_pos_embed = hparams.get('use_pos_embed', False) + if self.use_pos_embed: + self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024) + self.use_post_ln = hparams.get('use_post_ln', False) + self.layers = None + if not self.use_post_ln: + self.layer_norm = LayerNorm(dec_hidden_size) + self.code_size = code_size + self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True) + + def 
forward_ling_encoder( + self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre): + ph_tokens = txt_tokens + hparams = self.hparams + ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1] + x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre) + + # enc_ph + if not hparams.get('lm_use_enc'): + x_ph = self.encoder(ph_tokens) + x_ph = x_ph + sum( + [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \ + if len(hparams['ling_labels']) > 0 else 0 + x_ph = x_ph + x_spk + else: + # enc_ph + ph_enc_oembed = sum( + [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \ + if len(hparams['ling_labels']) > 0 else 0 + ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed( + torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device)) + ph_enc_oembed = ph_enc_oembed + x_spk + ph_enc_oembed = ph_enc_oembed * ph_nonpadding + x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed) + + # enc_char + if char_tokens is not None and ph2char is not None: + char_nonpadding = (char_tokens > 0).float()[:, :, None] + x_char = self.char_encoder(char_tokens) + empty_char = (ph2char > 100000).long() + ph2char = ph2char * (1 - empty_char) + x_char_phlevel = \ + expand_states(x_char * char_nonpadding, ph2char) \ + * (1 - empty_char)[..., None] + \ + self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None] + else: + x_char_phlevel = 0 + # x_ling + x_ling = x_ph + x_char_phlevel + x_ling = x_ling * ph_nonpadding + x_ling = self.enc_proj(x_ling) + return x_ling + + def sample_one_step(self, vq_pred): + hparams = self.hparams + if hparams.get('infer_top_k'): + top_k = hparams.get('infer_top_k') + temperature = hparams.get('infer_temperature', 1) + vq_pred = vq_pred[:, -1] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1))) + vq_pred[vq_pred < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(vq_pred, dim=-1) + # sample from the distribution + vq_pred = torch.multinomial(probs, num_samples=1) + else: + vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1) + return vq_pred + + def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None): + # add spk embed + style_embed = 0 + if self.hparams['use_spk_embed']: + style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :] + if self.hparams['use_spk_id']: + style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :] + if self.hparams['use_spk_enc']: + style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :] + return style_embed + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, '_future_mask') + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1) + return self._future_mask[:dim, :dim] + + +class ARDurPredictor(CodePredictor): + def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True, + op_version=1): + super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size) + self.use_rot_embed = use_rot_embed + bias = hparams.get('lm_bias', True) + if self.use_rot_embed: + self.layers = nn.ModuleList([]) + self.layers.extend([ + RotTransformerDecoderLayer( + dec_hidden_size, 
0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4, + post_ln=self.use_post_ln, op_version=op_version, bias=bias) + for _ in range(lm_num_layers) + ]) + if hparams['dur_model_type'] == 'ar_mse': + self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus()) + else: + self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1) + + def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, + prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None, + incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None, + prompt_length=None, cache_size=20, streaming=False): + x = self.code_emb(prev_code) + if x_ling is None: + x_ling = self.forward_ling_encoder( + txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre) + x_ling = x_ling.flatten(0, 1) + txt_tokens = txt_tokens.flatten(0, 1) + x_ling = x_ling[txt_tokens > 0][None] + + # run decoder + self_attn_padding_mask = None + if self.use_pos_embed: + positions = self.embed_positions( + prev_code, + incremental_state=incremental_state + ) + if incremental_state is not None: + x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]] + if spk_pos_ids_flat is not None: + spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]] + x = x[:, -1:] + if self.use_pos_embed: + positions = positions[:, -1:] + if streaming: + # Shift Pos: query pos is min(cache_size, idx) + spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device), + spk_pos_ids_flat) + + # # B x T x C -> T x B x C + if self.use_pos_embed: + x = x + positions + x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous() + T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1]) + x_ling = x_ling.reshape(-1, T, x_ling.shape[-1]) + x = x + x_ling + x = x.transpose(0, 1) + + for idx, layer in enumerate(self.layers): + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + if attn_mask is not None: + self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8 + self_attn_mask = self_attn_mask.clamp_min(-1e8) + else: + self_attn_mask = None + + x, attn_weights = layer( + x, + incremental_state=incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + spk_pos_ids_flat=spk_pos_ids_flat + ) + + if streaming and incremental_state != {}: + for k, v in incremental_state.items(): + if 'attn_state' in k: + prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] + cur_length = prev_key.shape[2] + if cur_length - prompt_length > cache_size: + prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2) + prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]), + dim=2) + incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value + + if not self.use_post_ln: + x = self.layer_norm(x) + # T x B x C -> B x T x C + x = x.transpose(0, 1) + x = self.project_out_dim(x) + return x + + def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, + spk_id=None, spk_embed=None, mels_timbre=None, + incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False, + first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs): + if incremental_state is None: + incremental_state = {} + x_ling = self.forward_ling_encoder( + txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, + spk_id, spk_embed, 
mels_timbre) + x_ling = x_ling.flatten(0, 1) + txt_tokens_ori = txt_tokens + txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1) + x_ling = x_ling[txt_tokens > 0][None] + txt_tokens = txt_tokens[txt_tokens > 0][None] + + decoded = torch.zeros_like(txt_tokens) + decoded = F.pad(decoded, [1, 0], value=self.code_size + 1) + if incremental_state != {}: + if first_decoder_inp is None: + assert ctx_vqcodes is not None + decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes + ctx_vqcodes = None + else: + decoded[:, :1] = first_decoder_inp + probs = [] + for step in range(decoded.shape[1] - 1): + vq_pred = self(txt_tokens, None, None, None, None, + decoded[:, :step + 1], None, None, None, + incremental_state=incremental_state, x_ling=x_ling, + spk_pos_ids_flat=spk_pos_ids_flat, **kwargs) + probs.append(vq_pred.cpu()) + if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]: + if self.hparams['dur_model_type'] == 'ar_mse': + d = vq_pred[:, -1, 0] + if dur_disturb > 0 and step >= 1: + if random.random() > 0.5: + d = d * (1 + random.random() * dur_disturb) + else: + d = d / (1 + random.random() * dur_disturb) + d = torch.clamp_max(d, self.code_size - 1) + vq_pred = torch.round(d).long() + else: + vq_pred = self.sample_one_step(vq_pred) + decoded[:, step + 1] = torch.clamp_min(vq_pred, 1) + if step == 0: + decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min) + else: + decoded[:, step + 1] = ctx_vqcodes[:, step] + decoded = decoded[:, 1:] + decoded_2d = torch.zeros_like(txt_tokens_ori) + decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded + if return_state: + return decoded_2d, incremental_state + if return_probs: + return decoded_2d, torch.cat(probs, 1) + return decoded_2d + + def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, + spk_id=None, spk_embed=None, mels_timbre=None, + incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False, + **kwargs): + if incremental_state is None: + incremental_state = {} + x_ling = self.forward_ling_encoder( + txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, + spk_id, spk_embed, mels_timbre) + x_ling = x_ling.flatten(0, 1) + txt_tokens_ori = txt_tokens + txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1) + x_ling = x_ling[txt_tokens > 0][None] + txt_tokens = txt_tokens[txt_tokens > 0][None] + + vq_decoded = torch.zeros_like(txt_tokens) + vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1) + if incremental_state != {}: + assert ctx_vqcodes is not None + vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes + ctx_vqcodes = None + prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2] + for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'): + vq_pred = self(txt_tokens, None, None, None, None, + vq_decoded[:, :step + 1], None, None, None, + incremental_state=incremental_state, x_ling=x_ling, + spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs) + if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]: + if self.hparams['dur_model_type'] == 'ar_mse': + vq_pred = torch.round(vq_pred[:, -1, 0]).long() + else: + vq_pred = self.sample_one_step(vq_pred) + vq_decoded[:, step + 1] = vq_pred + else: + vq_decoded[:, step + 1] = ctx_vqcodes[:, step] + vq_decoded = vq_decoded[:, 1:] + vq_decoded_2d = torch.zeros_like(txt_tokens_ori) + vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded + if return_state: + return vq_decoded_2d, incremental_state + return vq_decoded_2d \ No 
newline at end of file diff --git a/MegaTTS3/tts/modules/ar_dur/commons/layers.py b/MegaTTS3/tts/modules/ar_dur/commons/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8aa03ab5cfccc09812cfa06ac69af14ac75dc2 --- /dev/null +++ b/MegaTTS3/tts/modules/ar_dur/commons/layers.py @@ -0,0 +1,64 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + + +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1, eps=1e-5): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=eps) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) + + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + return x.view(self.shape) + + +class Permute(nn.Module): + def __init__(self, *args): + super(Permute, self).__init__() + self.args = args + + def forward(self, x): + return x.permute(self.args) + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + return m diff --git a/MegaTTS3/tts/modules/ar_dur/commons/nar_tts_modules.py b/MegaTTS3/tts/modules/ar_dur/commons/nar_tts_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f2aa6f061ee8c8fb3d733dfc4978816f8168b883 --- /dev/null +++ b/MegaTTS3/tts/modules/ar_dur/commons/nar_tts_modules.py @@ -0,0 +1,73 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from torch import nn + +import torch.nn.functional as F + + +class LengthRegulator(torch.nn.Module): + def __init__(self, pad_value=0.0): + super(LengthRegulator, self).__init__() + self.pad_value = pad_value + + def forward(self, dur, dur_padding=None, alpha=1.0): + """ + Example (no batch dim version): + 1. dur = [2,2,3] + 2. 
token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
+            3. token_mask = [[1,1,0,0,0,0,0],
+                             [0,0,1,1,0,0,0],
+                             [0,0,0,0,1,1,1]]
+            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
+                                         [0,0,2,2,0,0,0],
+                                         [0,0,0,0,3,3,3]]
+            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
+
+        :param dur: Batch of durations of each frame (B, T_txt)
+        :param dur_padding: Batch of padding of each frame (B, T_txt)
+        :param alpha: duration rescale coefficient
+        :return:
+            mel2ph (B, T_speech)
+        """
+        assert alpha > 0
+        dur = torch.round(dur.float() * alpha).long()
+        if dur_padding is not None:
+            dur = dur * (1 - dur_padding.long())
+        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
+        dur_cumsum = torch.cumsum(dur, 1)
+        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
+
+        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
+        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
+        mel2token = (token_idx * token_mask.long()).sum(1)
+        return mel2token
+
+
+class PosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim) * -emb)
+        self.emb = emb  # TODO
+
+    def forward(self, x):
+        emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
diff --git a/MegaTTS3/tts/modules/ar_dur/commons/rel_transformer.py b/MegaTTS3/tts/modules/ar_dur/commons/rel_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..13a2a76559473fe6ecfc12dbd525ea1fae391b72
--- /dev/null
+++ b/MegaTTS3/tts/modules/ar_dur/commons/rel_transformer.py
@@ -0,0 +1,403 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
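+
+# Note: this file implements a relative-position Transformer encoder
+# (window-limited relative attention plus a convolutional prenet). It is
+# registered as FS_ENCODERS['rel_fft'] in ar_dur_predictor.py and serves as
+# the phoneme/character encoder for the autoregressive duration predictor.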
+ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from tts.modules.ar_dur.commons.layers import Embedding + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., + window_size=None, block_length=None, pre_ln=False, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + self.block_length = block_length + self.pre_ln = pre_ln + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size, + p_dropout=p_dropout, block_length=block_length)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + if pre_ln: + self.last_ln = LayerNorm(hidden_channels) + + def forward(self, x, x_mask, attn_mask=1): + if isinstance(attn_mask, torch.Tensor): + attn_mask = attn_mask[:, None] + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask + for i in range(self.n_layers): + x = x * x_mask + x_ = x + if self.pre_ln: + x = self.norm_layers_1[i](x) + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = x_ + y + if not self.pre_ln: + x = self.norm_layers_1[i](x) + + x_ = x + if self.pre_ln: + x = self.norm_layers_2[i](x) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = x_ + y + if not self.pre_ln: + x = self.norm_layers_2[i](x) + if self.pre_ln: + x = self.last_ln(x) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0., + block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels ** -0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.conv_o = 
nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores * block_mask + -1e4 * (1 - block_mask) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. 
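+        # The learned table covers relative offsets [-window_size, window_size]
+        # (2 * window_size + 1 rows); a length-L sequence needs offsets
+        # [-(L - 1), L - 1] (2 * L - 1 rows), so the table is padded symmetrically
+        # and then sliced down to that width.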
+ pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, -1]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. 
+ Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class RelTransformerEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout=0.0, + window_size=4, + block_length=None, + in_channels=None, + prenet=True, + pre_ln=True, + ): + + super().__init__() + + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + self.block_length = block_length + self.prenet = prenet + if n_vocab > 0: + self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0) + + if prenet: + if in_channels is None: + in_channels = hidden_channels + self.pre = ConvReluNorm(in_channels, in_channels, in_channels, + kernel_size=5, n_layers=3, p_dropout=0) + if in_channels is not None and in_channels != hidden_channels: + self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1) + self.encoder = Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + window_size=window_size, + block_length=block_length, + pre_ln=pre_ln, + ) + + def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1): + if self.n_vocab > 0: + x_lengths = (x > 0).long().sum(-1) + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + else: + x_lengths = (x.abs().sum(-1) > 0).long().sum(-1) + x = x + other_embeds + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + if self.prenet: + x = self.pre(x, x_mask) + self.prenet_out = x.transpose(1, 2) + if hasattr(self, 'encoder_inp_proj'): + x = self.encoder_inp_proj(x) * x_mask + x = self.encoder(x, x_mask, attn_mask) + return x.transpose(1, 2) diff --git a/MegaTTS3/tts/modules/ar_dur/commons/rot_transformer.py b/MegaTTS3/tts/modules/ar_dur/commons/rot_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..da9a10dc5bc7e1f1e0aa711d364da0896db420cb --- /dev/null +++ b/MegaTTS3/tts/modules/ar_dur/commons/rot_transformer.py @@ -0,0 +1,649 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +from typing import Optional, Tuple +from torch import nn +from torch.nn import Parameter, Linear +from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding +from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention +from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions +import torch.nn.functional as F + +DEFAULT_MAX_SOURCE_POSITIONS = 3000 +DEFAULT_MAX_TARGET_POSITIONS = 3000 + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + embedding_dim, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.embedding_dim, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = make_positions(input, self.padding_idx) if positions is None else positions + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e5) # an arbitrary large number + + +class RotaryEmbeddings(nn.Module): + cos: torch.Tensor + sin: torch.Tensor + theta: torch.Tensor + + def __init__( + self, + width: int, + *, + seq_len: int = 40000, + base: int = 10000, + device: Optional[torch.device] = None, + ): + """Rotary embeddings (Su et al., 2021) layer. The rotary embedding + will be precomputed for up to 'seq _len' positions. The embedding + will be recomputed when a longer sequence is found in the input. + + :param width: + Rotary embedding dimensionality, must be even. + :param seq_len: + Number of positons to initially precompute. 
+ :param base: + The base used for Θ_i, determines the cycle length of the + embeddings. + :param device: Device on which the module is to be initialized. + """ + super().__init__() + + if width % 2: + raise ValueError(f"Width of rotary embeddings must be even, was: {width}") + + # Ignore allocations on the meta device as we don't persist our buffer, + # i.e., we don't expect the backing tensor to be replaced with pretrained weights. + if device is not None and device.type == "meta": + device = None + # Θ_i = 10000^(-2(i-1)/d) + theta = torch.pow( + base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width + ) + self.register_buffer("theta", theta, persistent=False) + + self._create_rotary_embed(width=width, length=seq_len) + + def _create_rotary_embed(self, *, width: int, length: int): + # mΘ + position = torch.arange(length, device=self.theta.device).unsqueeze(1) + m_theta = position * self.theta.unsqueeze(0) + + # We apply both sin and cos twice (see Eq 15, 34), but the ordering + # is changed for compatibility with most common implementations. + m_theta = torch.cat([m_theta, m_theta], dim=-1) + + re_cos = m_theta.cos().view([length, width]) + re_sin = m_theta.sin().view([length, width]) + + self.register_buffer("cos", re_cos, persistent=False) + self.register_buffer("sin", re_sin, persistent=False) + + def _rotate(self, input: torch.Tensor): + """Rotate the input tensor by half of its innermost width. + + input (Tensor): array to rotate. + RETURNS (Tensor): rotated array. + + Shapes: + input - (..., width) + output - (..., width) + """ + half_idx = input.shape[-1] // 2 + input_1 = -input[..., half_idx:] + input_2 = input[..., :half_idx] + return torch.cat([input_1, input_2], dim=-1) + + def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None): + """ + Apply rotary embeddings to an array. + + :param input: Array to apply the rotary embeddings to. + :param positions: positions of the inputs. If no positions are + provided, they are assumed to be [0, seq_len). + :return: Array with the rotary embeddings applied. + + Shapes: + input - (batch_size, num_heads, seq_len, width_per_head) + positions - (batch_size, seq_len) + output - (batch_size, num_heads, seq_len, width_per_head) + """ + batch_size, _, seq_len, width = input.shape + + if positions is None: + # Fastpath: positions from [0..seq_len), avoid indexing. + if self.cos.size(-2) < seq_len: + self._create_rotary_embed(width=width, length=seq_len) + rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width) + rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width) + else: + max_len = int(positions.max()) + 1 + if self.cos.size(-2) < max_len: + self._create_rotary_embed(width=width, length=max_len) + + # Flatten positions to index cos/sin arrays, then unflatten. + # + # Example shapes: + # + # positions_flat - (batch_size * seq_len) + # self.cos - (max_len, width) + # rot_cos - (batch_size, seq_len, width) + positions_flat = positions.view(-1) + rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width) + rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width) + + # Eq 34 with ordering changed for compatibility. 
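+        # i.e. out = cos(m*theta) * input + sin(m*theta) * rotate_half(input),
+        # broadcast over the batch and head dimensions.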
+ return rot_cos * input + rot_sin * self._rotate(input) + + +class RotMultiheadAttention(MultiheadAttention): + def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, + add_bias_kv=False, add_zero_attn=False, self_attention=False, + encoder_decoder_attention=False): + super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias, + add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention, + encoder_decoder_attention=encoder_decoder_attention) + self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads) + + def forward( + self, + query, key, value, + spk_pos_ids_flat=None, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + q = q * self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + # Apply rot embedding and store incremental_state + q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0] + if saved_state is not None: + # saved states are stored with shape (bsz, 
num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + k = torch.cat((prev_key, k), dim=1) + if 'prev_value' in saved_state: + prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + v = torch.cat((prev_value, v), dim=1) + saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view( + bsz, self.num_heads, -1, self.head_dim) + self._set_input_buffer(incremental_state, saved_state) + if incremental_state is not None: + key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0) + else: + key_pos = spk_pos_ids_flat + k = self.rotary_embeds(k[None, :], positions=key_pos)[0] + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape( + bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights + attn_mask + + if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + enc_dec_attn_constraint_mask.unsqueeze(2).bool(), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + + if reset_attn_weight is not None: + if reset_attn_weight: + self.last_attn_probs = attn_probs.detach() + else: + assert self.last_attn_probs is not None + attn_probs = self.last_attn_probs + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) 
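+        # attn: [tgt_len, bsz, embed_dim]. attn_logits keeps the masked,
+        # pre-softmax scores reshaped to [bsz, num_heads, tgt_len, src_len];
+        # the softmax weights returned below are averaged over heads unless
+        # need_head_weights is set.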
+ + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + else: + attn_weights = None + + return attn, (attn_weights, attn_logits) + + +class RotMultiheadAttention2(MultiheadAttention): + def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, + add_bias_kv=False, add_zero_attn=False, self_attention=False, + encoder_decoder_attention=False): + super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias, + add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention, + encoder_decoder_attention=encoder_decoder_attention) + self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads) + + def forward( + self, + query, key, value, + spk_pos_ids_flat=None, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
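+
+        Note:
+            This variant computes attention with
+            torch.nn.functional.scaled_dot_product_attention and returns
+            None for both the attention weights and the attention logits.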
+ """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + # Apply rot embedding and store incremental_state + q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0] + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + k = torch.cat((prev_key, k), dim=1) + if 'prev_value' in saved_state: + prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + v = torch.cat((prev_value, v), dim=1) + saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view( + bsz, self.num_heads, -1, self.head_dim) + self._set_input_buffer(incremental_state, saved_state) + key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0) + k = self.rotary_embeds(k[None, :], positions=key_pos)[0] + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
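+        #
+        # Note: on this scaled_dot_product_attention path key_padding_mask is
+        # only shape-checked below; it is not folded into attn_mask, so padding
+        # positions are assumed to be excluded by the caller via attn_mask.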
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape( + bsz * self.num_heads, tgt_len, src_len) + attn = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_logits = None + attn_weights = None + return attn, (attn_weights, attn_logits) + + +class RotDecSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, + kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True): + super().__init__() + self.c = c + self.dropout = dropout + self.layer_norm1 = LayerNorm(c) + self.self_attn = RotMultiheadAttention( + c, num_heads, self_attention=True, dropout=attention_dropout, bias=False + ) + self.layer_norm2 = LayerNorm(c) + self.ffn = TransformerFFNLayer( + c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, + dropout=relu_dropout, act=act, bias=bias) + self.post_ln = post_ln + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + self_attn_mask=None, + self_attn_padding_mask=None, + attn_out=None, + reset_attn_weight=None, + spk_pos_ids_flat=None, + **kwargs, + ): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + residual = x + if not self.post_ln: + x = self.layer_norm1(x) + + x, (attn_weights, _) = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + attn_mask=self_attn_mask, + spk_pos_ids_flat=spk_pos_ids_flat + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.layer_norm1(x) + + residual = x + if not self.post_ln: + x = self.layer_norm2(x) + x = self.ffn(x, incremental_state=incremental_state) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.layer_norm2(x) + return x, attn_weights + + def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None): + self.encoder_attn.clear_buffer(incremental_state) + self.ffn.clear_buffer(incremental_state) + + def set_buffer(self, name, tensor, incremental_state): + return set_incremental_state(self, incremental_state, name, tensor) + + +class RotDecSALayer2(RotDecSALayer): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, + ffn_hidden_size=1024, act='gelu', post_ln=False): + super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act, + post_ln) + self.self_attn = RotMultiheadAttention2( + c, num_heads, self_attention=True, dropout=attention_dropout, bias=False + ) + + +class RotTransformerDecoderLayer(nn.Module): + def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False, + op_version=1, bias=True): + super().__init__() + 
self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + if op_version == 1: + self.op = RotDecSALayer( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size, + post_ln=post_ln, bias=bias) + else: + self.op = RotDecSALayer2( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size, + post_ln=post_ln) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + def clear_buffer(self, *args): + return self.op.clear_buffer(*args) + + def set_buffer(self, *args): + return self.op.set_buffer(*args) diff --git a/MegaTTS3/tts/modules/ar_dur/commons/seq_utils.py b/MegaTTS3/tts/modules/ar_dur/commons/seq_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdddaccab9bc56e41b28ce0265952e82e3d4092b --- /dev/null +++ b/MegaTTS3/tts/modules/ar_dur/commons/seq_utils.py @@ -0,0 +1,342 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import torch +import torch.nn.functional as F + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return ( + torch.cumsum(mask, dim=1).type_as(mask) * mask + ).long() + padding_idx + + +def softmax(x, dim): + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def sequence_mask(lengths, maxlen=None, dtype=torch.bool): + if maxlen is None: + maxlen = lengths.max() + mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t() + mask.type(dtype) + return mask + + +def weights_nonzero_speech(target): + # target : B x T x mel + # Assign weight 1.0 to all labels except for padding (id=0). 
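+    # A frame counts as padding only when every mel bin is exactly zero; the
+    # resulting 0/1 weights are broadcast back to the full [B, T, mel] shape
+    # of `target`.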
+ dim = target.size(-1) + return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim) + + +INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) + + +def _get_full_incremental_state_key(module_instance, key): + module_name = module_instance.__class__.__name__ + + # assign a unique ID to each module instance, so that incremental state is + # not shared across module instances + if not hasattr(module_instance, '_instance_id'): + INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 + module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] + + return '{}.{}.{}'.format(module_name, module_instance._instance_id, key) + + +def get_incremental_state(module, incremental_state, key): + """Helper for getting incremental state for an nn.Module.""" + full_key = _get_full_incremental_state_key(module, key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + +def set_incremental_state(module, incremental_state, key, value): + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = _get_full_incremental_state_key(module, key) + incremental_state[full_key] = value + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float('-inf')).type_as(t) + + +def fill_with_neg_inf2(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(-1e8).type_as(t) + + +def select_attn(attn_logits, type='best'): + """ + + :param attn_logits: [n_layers, B, n_head, T_sp, T_txt] + :return: + """ + encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2) + # [n_layers * n_head, B, T_sp, T_txt] + encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1) + if type == 'best': + indices = encdec_attn.max(-1).values.sum(-1).argmax(0) + encdec_attn = encdec_attn.gather( + 0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0] + return encdec_attn + elif type == 'mean': + return encdec_attn.mean(0) + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + Examples: + With only lengths. + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + With the reference tensor. + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + With the reference tensor and dimension indicator. 
+ >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + Examples: + With only lengths. + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + With the reference tensor. + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + With the reference tensor and dimension indicator. 
+ >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def get_mask_from_lengths(lengths): + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len).to(lengths.device) + mask = (ids < lengths.unsqueeze(1)).bool() + return mask + + +def group_hidden_by_segs(h, seg_ids, max_len): + """ + + :param h: [B, T, H] + :param seg_ids: [B, T] + :return: h_ph: [B, T_ph, H] + """ + B, T, H = h.shape + h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h) + all_ones = h.new_ones(h.shape[:2]) + cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous() + h_gby_segs = h_gby_segs[:, 1:] + cnt_gby_segs = cnt_gby_segs[:, 1:] + h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1) + return h_gby_segs, cnt_gby_segs + +def expand_by_repeat_times(source_encoding, lengths): + """ + source_encoding: [T, C] + lengths, list of int, [T,], how many times each token should repeat + return: + expanded_encoding: [T_expand, C] + """ + hid_dim = source_encoding.shape[1] + out2source = [] + for i, length in enumerate(lengths): + out2source += [i for _ in range(length)] + out2source = torch.LongTensor(out2source).to(source_encoding.device) + out2source_ = out2source[:, None].repeat([1, hid_dim]) + expanded_encoding = torch.gather(source_encoding, 0, out2source_) # [B, T, H] + return expanded_encoding + + +def expand_word2ph(word_encoding, ph2word): + word_encoding = F.pad(word_encoding,[0,0,1,0]) + ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]]) + out = torch.gather(word_encoding, 1, ph2word_) # [B, T, H] + return out diff --git a/MegaTTS3/tts/modules/ar_dur/commons/transformer.py b/MegaTTS3/tts/modules/ar_dur/commons/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c44ed1b044ca7bcd4cf0021cb6c3718e01548df --- /dev/null +++ b/MegaTTS3/tts/modules/ar_dur/commons/transformer.py @@ -0,0 +1,767 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
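+#
+# This module collects the fairseq-style Transformer building blocks shared by
+# the ar_dur components: sinusoidal positional embeddings, multi-head attention
+# with incremental (cached) decoding state, convolutional FFN layers, and the
+# FFT-block encoder used by FastSpeechEncoder.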
+ +import math +import torch +from torch import nn +from torch.nn import Parameter, Linear +from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding +from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions +import torch.nn.functional as F + +DEFAULT_MAX_SOURCE_POSITIONS = 3000 +DEFAULT_MAX_TARGET_POSITIONS = 3000 + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + embedding_dim, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.embedding_dim, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = make_positions(input, self.padding_idx) if positions is None else positions + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e5) # an arbitrary large number + + +class TransformerFFNLayer(nn.Module): + def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True): + super().__init__() + self.kernel_size = kernel_size + self.dropout = dropout + self.act = act + if padding == 'SAME': + self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, + padding=kernel_size // 2, bias=bias) + elif padding == 'LEFT': + self.ffn_1 = nn.Sequential( + nn.ConstantPad1d((kernel_size - 1, 0), 0.0), + nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias) + ) + self.ffn_2 = Linear(filter_size, hidden_size, bias=bias) + + def forward(self, x, incremental_state=None): + # x: T x B x C + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_input' in saved_state: + prev_input = saved_state['prev_input'] + x = 
torch.cat((prev_input, x), dim=0) + x = x[-self.kernel_size:] + saved_state['prev_input'] = x + self._set_input_buffer(incremental_state, saved_state) + + x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1) + x = x * self.kernel_size ** -0.5 + + if incremental_state is not None: + x = x[-1:] + if self.act == 'gelu': + x = F.gelu(x) + if self.act == 'relu': + x = F.relu(x) + x = F.dropout(x, self.dropout, training=self.training) + x = self.ffn_2(x) + return x + + def _get_input_buffer(self, incremental_state): + return get_incremental_state( + self, + incremental_state, + 'f', + ) or {} + + def _set_input_buffer(self, incremental_state, buffer): + set_incremental_state( + self, + incremental_state, + 'f', + buffer, + ) + + def clear_buffer(self, incremental_state): + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_input' in saved_state: + del saved_state['prev_input'] + self._set_input_buffer(incremental_state, saved_state) + + +class MultiheadAttention(nn.Module): + def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, + add_bias_kv=False, add_zero_attn=False, self_attention=False, + encoder_decoder_attention=False): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \ + 'value to be of the same size' + + if self.qkv_same_dim: + self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) + else: + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + + if bias: + self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.enable_torch_version = False + self.last_attn_probs = None + + def reset_parameters(self): + if self.qkv_same_dim: + nn.init.xavier_uniform_(self.in_proj_weight) + else: + nn.init.xavier_uniform_(self.k_proj_weight) + nn.init.xavier_uniform_(self.v_proj_weight) + nn.init.xavier_uniform_(self.q_proj_weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.) + nn.init.constant_(self.out_proj.bias, 0.) 
+ if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, key, value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None: + if self.qkv_same_dim: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + self.in_proj_weight, + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask) + else: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + torch.empty([0]), + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + q = q * self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k 
is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + k = torch.cat((prev_key, k), dim=1) + if 'prev_value' in saved_state: + prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + v = torch.cat((prev_value, v), dim=1) + if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None: + prev_key_padding_mask = saved_state['prev_key_padding_mask'] + if static_kv: + key_padding_mask = prev_key_padding_mask + else: + key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1) + + saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state['prev_key_padding_mask'] = key_padding_mask + + self._set_input_buffer(incremental_state, saved_state) + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape( + bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights + attn_mask + + if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + enc_dec_attn_constraint_mask.unsqueeze(2).bool(), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = 
F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + + if reset_attn_weight is not None: + if reset_attn_weight: + self.last_attn_probs = attn_probs.detach() + else: + assert self.last_attn_probs is not None + attn_probs = self.last_attn_probs + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + else: + attn_weights = None + + return attn, (attn_weights, attn_logits) + + def in_proj_qkv(self, query): + return self._in_proj(query).chunk(3, dim=-1) + + def in_proj_q(self, query): + if self.qkv_same_dim: + return self._in_proj(query, end=self.embed_dim) + else: + bias = self.in_proj_bias + if bias is not None: + bias = bias[:self.embed_dim] + return F.linear(query, self.q_proj_weight, bias) + + def in_proj_k(self, key): + if self.qkv_same_dim: + return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) + else: + weight = self.k_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[self.embed_dim:2 * self.embed_dim] + return F.linear(key, weight, bias) + + def in_proj_v(self, value): + if self.qkv_same_dim: + return self._in_proj(value, start=2 * self.embed_dim) + else: + weight = self.v_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[2 * self.embed_dim:] + return F.linear(value, weight, bias) + + def _in_proj(self, input, start=0, end=None): + weight = self.in_proj_weight + bias = self.in_proj_bias + weight = weight[start:end, :] + if bias is not None: + bias = bias[start:end] + return F.linear(input, weight, bias) + + def _get_input_buffer(self, incremental_state): + return get_incremental_state( + self, + incremental_state, + 'attn_state', + ) or {} + + def _set_input_buffer(self, incremental_state, buffer): + set_incremental_state( + self, + incremental_state, + 'attn_state', + buffer, + ) + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + return attn_weights + + def clear_buffer(self, incremental_state=None): + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + del saved_state['prev_key'] + if 'prev_value' in saved_state: + del saved_state['prev_value'] + self._set_input_buffer(incremental_state, saved_state) + + +class EncSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, + relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu', + ffn_hidden_size=1024): + super().__init__() + self.c = c + self.dropout = dropout + self.num_heads = num_heads + if num_heads > 0: + self.layer_norm1 = LayerNorm(c) + self.self_attn = MultiheadAttention( + self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False) + self.layer_norm2 = LayerNorm(c) + self.ffn = TransformerFFNLayer( + c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act) + + def forward(self, x, encoder_padding_mask=None, **kwargs): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + if self.num_heads > 0: + residual = 
x + x = self.layer_norm1(x) + x, _, = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + + residual = x + x = self.layer_norm2(x) + x = self.ffn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + return x + + +class DecSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, + kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False): + super().__init__() + self.c = c + self.dropout = dropout + self.layer_norm1 = LayerNorm(c) + self.self_attn = MultiheadAttention( + c, num_heads, self_attention=True, dropout=attention_dropout, bias=False + ) + self.layer_norm2 = LayerNorm(c) + self.encoder_attn = MultiheadAttention( + c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False, + ) + self.layer_norm3 = LayerNorm(c) + self.ffn = TransformerFFNLayer( + c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act) + self.post_ln = post_ln + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + self_attn_mask=None, + self_attn_padding_mask=None, + attn_out=None, + reset_attn_weight=None, + **kwargs, + ): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + self.layer_norm3.training = layer_norm_training + residual = x + if not self.post_ln: + x = self.layer_norm1(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + attn_mask=self_attn_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.layer_norm1(x) + + attn_logits = None + if encoder_out is not None or attn_out is not None: + residual = x + if not self.post_ln: + x = self.layer_norm2(x) + if encoder_out is not None: + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state, + 'enc_dec_attn_constraint_mask'), + reset_attn_weight=reset_attn_weight + ) + attn_logits = attn[1] + elif attn_out is not None: + x = self.encoder_attn.in_proj_v(attn_out) + if encoder_out is not None or attn_out is not None: + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.layer_norm2(x) + + residual = x + if not self.post_ln: + x = self.layer_norm3(x) + x = self.ffn(x, incremental_state=incremental_state) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.layer_norm3(x) + return x, attn_logits + + def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None): + self.encoder_attn.clear_buffer(incremental_state) + self.ffn.clear_buffer(incremental_state) + + def set_buffer(self, name, tensor, incremental_state): + return set_incremental_state(self, incremental_state, name, tensor) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, hidden_size, dropout, kernel_size=9, 
num_heads=2, ffn_hidden_size=1024): + super().__init__() + self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + self.op = EncSALayer( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False): + super().__init__() + self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + self.op = DecSALayer( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size, + post_ln=post_ln) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + def clear_buffer(self, *args): + return self.op.clear_buffer(*args) + + def set_buffer(self, *args): + return self.op.set_buffer(*args) + + +class FFTBlocks(nn.Module): + def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0, + num_heads=2, use_pos_embed=True, use_last_norm=True, + use_pos_embed_alpha=True, ffn_hidden_size=1024): + super().__init__() + self.num_layers = num_layers + embed_dim = self.hidden_size = hidden_size + self.dropout = dropout + self.use_pos_embed = use_pos_embed + self.use_last_norm = use_last_norm + if use_pos_embed: + self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS + self.padding_idx = 0 + self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1 + self.embed_positions = SinusoidalPositionalEmbedding( + embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + self.layers = nn.ModuleList([]) + self.layers.extend([ + TransformerEncoderLayer(self.hidden_size, self.dropout, + kernel_size=ffn_kernel_size, num_heads=num_heads, + ffn_hidden_size=ffn_hidden_size) + for _ in range(self.num_layers) + ]) + if self.use_last_norm: + self.layer_norm = nn.LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False): + """ + :param x: [B, T, C] + :param padding_mask: [B, T] + :return: [B, T, C] or [L, B, T, C] + """ + padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask + nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1] + if self.use_pos_embed: + positions = self.pos_embed_alpha * self.embed_positions(x[..., 0]) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) * nonpadding_mask_TB + hiddens = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB + hiddens.append(x) + if self.use_last_norm: + x = self.layer_norm(x) * nonpadding_mask_TB + if return_hiddens: + x = torch.stack(hiddens, 0) # [L, T, B, C] + x = x.transpose(1, 2) # [L, B, T, C] + else: + x = x.transpose(0, 1) # [B, T, C] + return x + + +class FastSpeechEncoder(FFTBlocks): + def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9, + dropout=0.0, num_heads=2, ffn_hidden_size=1024): + super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads, + use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size) + self.embed_tokens = Embedding(dict_size, hidden_size, 0) + self.embed_scale = math.sqrt(hidden_size) + 
self.padding_idx = 0 + self.embed_positions = SinusoidalPositionalEmbedding( + hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + def forward(self, txt_tokens, attn_mask=None, other_embeds=0): + """ + + :param txt_tokens: [B, T] + :return: { + 'encoder_out': [B x T x C] + } + """ + encoder_padding_mask = txt_tokens.eq(self.padding_idx).data + x = self.forward_embedding(txt_tokens) + other_embeds # [B, T, H] + if self.num_layers > 0: + x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask) + return x + + def forward_embedding(self, txt_tokens): + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(txt_tokens) + if self.use_pos_embed: + positions = self.embed_positions(txt_tokens) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + return x diff --git a/MegaTTS3/tts/modules/llm_dit/cfm.py b/MegaTTS3/tts/modules/llm_dit/cfm.py new file mode 100644 index 0000000000000000000000000000000000000000..bb01732a8c270ae16a6157f34e421a5dbe88dd58 --- /dev/null +++ b/MegaTTS3/tts/modules/llm_dit/cfm.py @@ -0,0 +1,309 @@ +# MIT License + +# Copyright (c) 2023 Alexander Tong + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) [2023] [Alexander Tong] +# Copyright (c) [2025] [Ziyue Jiang] +# SPDX-License-Identifier: MIT +# This file has been modified by Ziyue Jiang on 2025/03/19 +# Original file was released under MIT, with the full license text # available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE. +# This modified file is released under the same license. + +import math +import torch +from typing import Union +from torch.distributions import LogisticNormal + + +class LogitNormalTrainingTimesteps: + def __init__(self, T=1000.0, loc=0.0, scale=1.0): + assert T > 0 + self.T = T + self.dist = LogisticNormal(loc, scale) + + def sample(self, size, device): + t = self.dist.sample(size)[..., 0].to(device) + return t + + +def pad_t_like_x(t, x): + """Function to reshape the time vector t by the number of dimensions of x. 
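+    If t is a plain Python float or int it is returned unchanged instead of
+    being reshaped.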
+ + Parameters + ---------- + x : Tensor, shape (bs, *dim) + represents the source minibatch + t : FloatTensor, shape (bs) + + Returns + ------- + t : Tensor, shape (bs, number of x dimensions) + + Example + ------- + x: Tensor (bs, C, W, H) + t: Vector (bs) + pad_t_like_x(t, x): Tensor (bs, 1, 1, 1) + """ + if isinstance(t, (float, int)): + return t + return t.reshape(-1, *([1] * (x.dim() - 1))) + + +class ConditionalFlowMatcher: + """Base class for conditional flow matching methods. This class implements the independent + conditional flow matching methods from [1] and serves as a parent class for all other flow + matching methods. + + It implements: + - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function + - conditional flow matching ut(x1|x0) = x1 - x0 + - score function $\nabla log p_t(x|x0, x1)$ + """ + + def __init__(self, sigma: Union[float, int] = 0.0): + r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. + + Parameters + ---------- + sigma : Union[float, int] + """ + self.sigma = sigma + self.time_sampler = LogitNormalTrainingTimesteps() + + def compute_mu_t(self, x0, x1, t): + """ + Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]. + + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the target minibatch + t : FloatTensor, shape (bs) + + Returns + ------- + mean mu_t: t * x1 + (1 - t) * x0 + + References + ---------- + [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + """ + t = pad_t_like_x(t, x0) + return t * x1 + (1 - t) * x0 + + def compute_sigma_t(self, t): + """ + Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]. + + Parameters + ---------- + t : FloatTensor, shape (bs) + + Returns + ------- + standard deviation sigma + + References + ---------- + [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + """ + del t + return self.sigma + + def sample_xt(self, x0, x1, t, epsilon): + """ + Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]. + + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the target minibatch + t : FloatTensor, shape (bs) + epsilon : Tensor, shape (bs, *dim) + noise sample from N(0, 1) + + Returns + ------- + xt : Tensor, shape (bs, *dim) + + References + ---------- + [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + """ + mu_t = self.compute_mu_t(x0, x1, t) + sigma_t = self.compute_sigma_t(t) + sigma_t = pad_t_like_x(sigma_t, x0) + return mu_t + sigma_t * epsilon + + def compute_conditional_flow(self, x0, x1, t, xt): + """ + Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]. + + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the target minibatch + t : FloatTensor, shape (bs) + xt : Tensor, shape (bs, *dim) + represents the samples drawn from probability path pt + + Returns + ------- + ut : conditional vector field ut(x1|x0) = x1 - x0 + + References + ---------- + [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. 
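+
+        Note that for the independent coupling used here the conditional field
+        does not depend on t or xt, which is why both arguments are discarded.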
+ """ + del t, xt + return x1 - x0 + + def sample_noise_like(self, x): + return torch.randn_like(x) + + def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False): + """ + Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) + and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]. + + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the target minibatch + (optionally) t : Tensor, shape (bs) + represents the time levels + if None, drawn from uniform [0,1] + return_noise : bool + return the noise sample epsilon + + + Returns + ------- + t : FloatTensor, shape (bs) + xt : Tensor, shape (bs, *dim) + represents the samples drawn from probability path pt + ut : conditional vector field ut(x1|x0) = x1 - x0 + (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon + + References + ---------- + [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + """ + if t is None: + # t = torch.rand(x0.shape[0]).type_as(x0) + t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0) + + assert len(t) == x0.shape[0], "t has to have batch size dimension" + + eps = self.sample_noise_like(x0) + xt = self.sample_xt(x0, x1, t, eps) + ut = self.compute_conditional_flow(x0, x1, t, xt) + if return_noise: + return t, xt, ut, eps + else: + return t, xt, ut + + def compute_lambda(self, t): + """Compute the lambda function, see Eq.(23) [3]. + + Parameters + ---------- + t : FloatTensor, shape (bs) + + Returns + ------- + lambda : score weighting function + + References + ---------- + [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al. + """ + sigma_t = self.compute_sigma_t(t) + return 2 * sigma_t / (self.sigma**2 + 1e-8) + + +class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher): + """Albergo et al. 2023 trigonometric interpolants class. This class inherits the + ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in + order to compute [3]'s trigonometric interpolants. + + [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. + """ + + def compute_mu_t(self, x0, x1, t): + r"""Compute the mean of the probability path (Eq.5) from [3]. + + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the target minibatch + t : FloatTensor, shape (bs) + + Returns + ------- + mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1 + + References + ---------- + [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. + """ + t = pad_t_like_x(t, x0) + return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1 + + def compute_conditional_flow(self, x0, x1, t, xt): + r"""Compute the conditional vector field similar to [3]. + + ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0), + see Eq.(21) [3]. 
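+
+        This is simply the time derivative of the trigonometric interpolant
+        mu_t = cos(pi*t/2) x0 + sin(pi*t/2) x1 defined in compute_mu_t above.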
+ + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the target minibatch + t : FloatTensor, shape (bs) + xt : Tensor, shape (bs, *dim) + represents the samples drawn from probability path pt + + Returns + ------- + ut : conditional vector field + ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(\pi*t/2) x0) + + References + ---------- + [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. + """ + del xt + t = pad_t_like_x(t, x0) + return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0) diff --git a/MegaTTS3/tts/modules/llm_dit/dit.py b/MegaTTS3/tts/modules/llm_dit/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f30e5e9deeaa661ea1904bf489f38d95d965b3 --- /dev/null +++ b/MegaTTS3/tts/modules/llm_dit/dit.py @@ -0,0 +1,180 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from tts.modules.llm_dit.cfm import ConditionalFlowMatcher +from tts.modules.ar_dur.commons.layers import Embedding +from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb +from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder +from tts.modules.ar_dur.ar_dur_predictor import expand_states +from tts.modules.llm_dit.transformer import Transformer +from tts.modules.llm_dit.time_embedding import TimestepEmbedding + + +class Diffusion(nn.Module): + def __init__(self): + super().__init__() + # Hparams + # cond dim + self.local_cond_dim = 512 + self.ctx_mask_dim = 16 + self.in_channels = 32 + self.out_channels = 32 + # LLM + self.encoder_dim = 1024 + self.encoder_n_layers = 24 + self.encoder_n_heads = 16 + self.max_seq_len = 16384 + self.multiple_of = 256 + + self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim) + self.local_cond_project = nn.Linear( + self.out_channels + self.ctx_mask_dim, self.local_cond_dim) + + self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len) + + self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim) + self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim) + self.postnet = nn.Linear(self.encoder_dim, self.out_channels) + + self.flow_matcher = ConditionalFlowMatcher(sigma=0.0) + # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS), + # which is licensed under the MIT License. 
+ self.f5_time_embed = TimestepEmbedding(self.encoder_dim) + + # text encoder + self.ph_encoder = RelTransformerEncoder( + 302, self.encoder_dim, self.encoder_dim, + self.encoder_dim * 2, 4, 6, + 3, 0.0, prenet=True, pre_ln=True) + self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0) + self.ph_pos_embed = PosEmb(self.encoder_dim) + self.ling_pre_net = torch.nn.Sequential(*[ + torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2) + for i, s in enumerate([2, 2]) + ]) + + def forward(self, inputs, sigmas=None, x_noisy=None): + ctx_mask = inputs['ctx_mask'] + ctx_feature = inputs['lat_ctx'] * ctx_mask + + """ local conditioning (prompt_latent + spk_embed) """ + ctx_mask_emb = self.ctx_mask_proj(ctx_mask) + # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None]) + local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1) + local_cond = self.local_cond_project(local_cond) + + """ diffusion target latent """ + x = inputs['lat'] + + # Here, x is x1 in CFM + x0 = torch.randn_like(x) + t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x) + + # define noisy_input and target + t = t.bfloat16() + x_noisy = (xt * (1 - ctx_mask)).bfloat16() + target = ut + + # concat condition. + x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"]) + x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2) + x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling + encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False) + pred = self.postnet(encoder_out) + + return pred, target + + def forward_ling_encoder(self, txt_tokens, tone_tokens): + ph_tokens = txt_tokens + ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1] + + # enc_ph + ph_enc_oembed = self.tone_embed(tone_tokens) + ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed( + torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device)) + ph_enc_oembed = ph_enc_oembed + ph_enc_oembed = ph_enc_oembed * ph_nonpadding + x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding + return x_ling + + def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]): + """ When we use torchdiffeq, we need to include the CFG process inside _forward() """ + x = x * (1 - ctx_mask) + x = self.x_prenet(x) + self.prenet(local_cond) + x_ling + pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device)) + pred = self.postnet(pred_v) + + """ Perform multi-cond CFG """ + cond_spk_txt, cond_txt, uncond = pred.chunk(3) + pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt) + return pred + + @torch.no_grad() + def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs): + # txt embedding + x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"]) + x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2) + + # speaker embedding + ctx_feature = inputs['lat_ctx'] + ctx_feature[1:, :, :] = 0 # prefix spk cfg + ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask']) + + # local conditioning. 
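+        # (The rows of this batch carry the CFG branches: row 0 keeps the speaker
+        #  prompt and the text, the remaining rows had the prompt zeroed above for
+        #  the "prefix spk cfg"; the text condition of the last row is presumably
+        #  dropped upstream. _forward() later chunks the prediction back into
+        #  (cond_spk_txt, cond_txt, uncond) and blends them with
+        #  seq_cfg_w = [text_guidance_weight, speaker_guidance_weight].)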
+ local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1) + local_cond = self.local_cond_project(local_cond) + + ''' Euler ODE solver ''' + bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1)) + # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS), + # which is licensed under the MIT License. + sway_sampling_coef = -1.0 + t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype) + if sway_sampling_coef is not None: + t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule) + + # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415) + def amo_sampling(z_t, t, t_next, v): + # Upcast to avoid precision issues when computing prev_sample + z_t = z_t.to(torch.float32) + + # Constant definition in Algorithm 1 + s = t_next + c = 3 + + # Line 7 in Algorithm 1 + o = min(t_next + c * (t_next - t), 1) + pred_z_o = z_t + (o - t) * v + + # Line 11 in Algorithm 1 + a = s / o + b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5 + noise_i = torch.randn(size=z_t.shape, device=z_t.device) + z_t_next = a * pred_z_o + b * noise_i + return z_t_next.to(v.dtype) + + x = torch.randn([1, frm_len, self.out_channels], device=device) + for step_index in range(timesteps): + x = x.to(torch.float32) + sigma = t_schedule[step_index].to(x_ling.dtype) + sigma_next = t_schedule[step_index + 1] + model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w) + x = amo_sampling(x, sigma, sigma_next, model_out) + # Cast sample back to model compatible dtype + x = x.to(model_out.dtype) + + return x diff --git a/MegaTTS3/tts/modules/llm_dit/time_embedding.py b/MegaTTS3/tts/modules/llm_dit/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..98d32c22f55345a1ef171072c0afb17ce11f371f --- /dev/null +++ b/MegaTTS3/tts/modules/llm_dit/time_embedding.py @@ -0,0 +1,44 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
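As a quick orientation before the module below, a minimal shape sketch of the timestep embedding as the Diffusion model above uses it (the 1024-dim width matches Diffusion.encoder_dim; the batch size is arbitrary):

import torch
from tts.modules.llm_dit.time_embedding import TimestepEmbedding

emb = TimestepEmbedding(dim=1024)   # matches Diffusion.encoder_dim above
t = torch.rand(4)                   # one flow-matching time per batch element
cond = emb(t)                       # sinusoidal features -> 2-layer MLP
print(cond.shape)                   # torch.Size([4, 1024])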
+ +import math +import torch +from torch import nn + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time \ No newline at end of file diff --git a/MegaTTS3/tts/modules/llm_dit/transformer.py b/MegaTTS3/tts/modules/llm_dit/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..33fb1e5b3f7ab2acec035c7e935f0bd727651343 --- /dev/null +++ b/MegaTTS3/tts/modules/llm_dit/transformer.py @@ -0,0 +1,230 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
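A small shape check for the rotary-embedding helpers defined below (head_dim 64 = 1024 / 16 matches the encoder settings in dit.py; batch and sequence sizes are arbitrary):

import torch
from tts.modules.llm_dit.transformer import precompute_freqs_cis, apply_rotary_emb

bs, seq, n_heads, head_dim = 2, 128, 16, 64
freqs_cis = precompute_freqs_cis(head_dim, 16384)[:seq]   # (seq, head_dim // 2) complex values
xq = torch.randn(bs, seq, n_heads, head_dim)
xk = torch.randn(bs, seq, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
print(xq_rot.shape, xk_rot.shape)                         # both (2, 128, 16, 64)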
+ + +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class AdaLNZero(nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLNZero_Out(nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class Attention(nn.Module): + def __init__(self, encoder_dim, encoder_n_heads, max_seq_len): + super().__init__() + self.encoder_n_kv_heads = encoder_n_heads + model_parallel_size = 1 + self.n_local_heads = encoder_n_heads // model_parallel_size + self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = encoder_dim // encoder_n_heads + + self.wq = nn.Linear( + encoder_dim, + encoder_n_heads * self.head_dim, + ) + self.wk = nn.Linear( + encoder_dim, + self.encoder_n_kv_heads * self.head_dim, + ) + self.wv = nn.Linear( + encoder_dim, + self.encoder_n_kv_heads * self.head_dim, + ) + self.wo = nn.Linear( + encoder_n_heads * self.head_dim, + encoder_dim, + ) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = xk.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = xv.transpose(1, 2) # (bs, 
n_local_heads, cache_len + seqlen, head_dim) + + output = F.scaled_dot_product_attention(xq, keys, values, mask[:, None, None, :], is_causal=False) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear( + dim, hidden_dim + ) + self.w2 = nn.Linear( + hidden_dim, dim + ) + + def forward(self, x): + return self.w2(F.silu(self.w1(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, encoder_dim, encoder_n_heads, max_seq_len): + super().__init__() + self.encoder_n_heads = encoder_n_heads + self.encoder_dim = encoder_dim + self.head_dim = encoder_dim // encoder_n_heads + self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len) + self.feed_forward = FeedForward( + dim=encoder_dim, + hidden_dim=2 * encoder_dim, + multiple_of=256, + ffn_dim_multiplier=None, + ) + self.attention_norm = AdaLNZero(encoder_dim) + self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for attention caching. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. 
+ + """ + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=t) + + # attention + attn_output = self.attention(norm, start_pos, freqs_cis, mask=mask) + + # process attention output for input x + h = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ffn_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.feed_forward(norm) + out = h + gate_mlp.unsqueeze(1) * ff_output + + return out + + +class Transformer(nn.Module): + def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len): + super().__init__() + # Decoder + self.layers = torch.nn.ModuleList() + for _ in range(encoder_n_layers): + self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len)) + + self.norm = AdaLNZero_Out(encoder_dim) + self.out_proj = nn.Linear(encoder_dim, encoder_dim) + + # Rope embedding + freqs_cis = precompute_freqs_cis( + encoder_dim // encoder_n_heads, max_seq_len + ) + self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False) + + def forward(self, x, t, attn_mask, start_pos=0): + freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)] + for i, layer in enumerate(self.layers): + x = layer(x, t, start_pos, freqs_cis, attn_mask) + x = self.norm(x, t) + x = self.out_proj(x) + return x \ No newline at end of file diff --git a/MegaTTS3/tts/modules/wavvae/decoder/diag_gaussian.py b/MegaTTS3/tts/modules/wavvae/decoder/diag_gaussian.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae4050f0a2f7677a3579a020d6100c13372b720 --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/decoder/diag_gaussian.py @@ -0,0 +1,67 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import numpy as np + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator=None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = torch.randn( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar + else: + return 0.5 * ( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar + ) + + def nll(self, sample, dims) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean \ No newline at end of file diff --git a/MegaTTS3/tts/modules/wavvae/decoder/hifigan_modules.py b/MegaTTS3/tts/modules/wavvae/decoder/hifigan_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9df833a330d8a99559399e2b9e9b22a1bb5501fd --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/decoder/hifigan_modules.py @@ -0,0 +1,283 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
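A shape sketch for the Upsample block defined below, in the configuration wavvae_v3.py uses for its latent_upsampler (320 channels, 4x upsampling); the two parallel paths (nearest-upsample conv and transposed conv) are summed, halving the channels and quadrupling the length:

import torch
from tts.modules.wavvae.decoder.hifigan_modules import Upsample

up = Upsample(mult=320, r=4)        # same arguments as WavVAE_V3.latent_upsampler
x = torch.randn(1, 320, 25)         # one second of 25 Hz latent frames
y = up(x)
print(y.shape)                      # torch.Size([1, 160, 100])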
+ +import torch.nn as nn +import torch.nn.functional as F +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from torch.nn.utils import weight_norm, remove_weight_norm +from torch.nn import Conv1d +import numpy as np + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +class Upsample(nn.Module): + def __init__(self, mult, r): + super(Upsample, self).__init__() + self.r = r + self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r), + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(3), + nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1)) + ) + r_kernel = r if r >= 5 else 5 + self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2), + nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2, + kernel_size=r_kernel * 2, stride=r, + padding=r_kernel - r // 2, + output_padding=r % 2) + )) + + def forward(self, x): + x = torch.sin(x) + x + out1 = self.upsample(x) + out2 = self.trans_upsample(x) + return out1 + out2 + + +class Downsample(nn.Module): + def __init__(self, mult, r): + super(Downsample, self).__init__() + self.r = r + r_kernel = r if r >= 5 else 5 + self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2), + nn.utils.weight_norm(nn.Conv1d(mult, mult * 2, + kernel_size=r_kernel * 2, stride=r, + padding=r_kernel - r // 2) + )) + + def forward(self, x): + out = self.trans_downsample(x) + return out + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def weights_zero_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.fill_(0.0) + m.bias.data.fill_(0.0) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class Audio2Mel(nn.Module): + def __init__( + self, + hop_length=300, + sampling_rate=24000, + n_mel_channels=80, + mel_fmin=0., + mel_fmax=None, + frame_size=0.05, + device='cpu' + ): + super().__init__() + ############################################## + # FFT Parameters # + ############################################## + + self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2)))) + window = torch.hann_window(int(sampling_rate * frame_size)).float() + mel_basis = librosa_mel_fn( + sampling_rate, self.n_fft, n_mel_channels, mel_fmin, mel_fmax + ) # Mel filter (by librosa) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.register_buffer("window", window) + + self.hop_length = hop_length + self.win_length = int(sampling_rate * frame_size) + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + + def forward(self, audio): + fft = torch.stft( + audio.squeeze(1), + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=True, + ) + real_part, imag_part = fft.unbind(-1) + magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5)) + mel_output = torch.matmul(self.mel_basis, magnitude) + + log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20 + norm_mel = (log_mel_spec + 
115.) / 115. + mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.) + + return mel_comp + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dilation=1, dim_in=None): + super().__init__() + if dim_in is None: + dim_in = dim + + self.block = nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation), + nn.LeakyReLU(0.2), + WNConv1d(dim, dim, kernel_size=1), + ) + self.shortcut = WNConv1d(dim_in, dim, kernel_size=1) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +''' +参照hifigan(https://arxiv.org/pdf/2010.05646.pdf)v2结构 +多尺度主要是kernel_size不同,3组并行卷积模块,每个卷积模块内部采用不同的串行dilation size,且中间交叉正常无dilation卷积层 +''' + + +class ResBlockMRFV2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlockMRFV2, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, 0.2) + xt = c1(xt) + xt = F.leaky_relu(xt, 0.2) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlockMRFV2Inter(torch.nn.Module): + def __init__(self, channels, kernel_size=3): + super(ResBlockMRFV2Inter, self).__init__() + self.block1 = ResBlockMRFV2(channels) + self.block2 = ResBlockMRFV2(channels, 7) + self.block3 = ResBlockMRFV2(channels, 11) + + def forward(self, x): + xs = self.block1(x) + xs += self.block2(x) + xs += self.block3(x) + x = xs / 3 + return x + + +class Generator(nn.Module): + def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False, + device='cpu'): + super().__init__() + self.hop_length = args.frame_shift + self.args = args + self.onnx_export = onnx_export + + # ------------- Define upsample layers ---------------- + mult = int(2 ** len(ratios)) + model_up = [] + input_size = input_size_ + model_up += [ + nn.ReflectionPad1d(3), + WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0), + ] + + # Upsample to raw audio scale + for i, r in enumerate(ratios): + model_up += [Upsample(mult * ngf, r)] + model_up += [ResBlockMRFV2Inter(mult * ngf // 2)] + mult //= 2 + + model_up += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(3), + WNConv1d(ngf, num_band, kernel_size=7, padding=0), + nn.Tanh(), + ] + if not args.use_tanh: + model_up[-1] = nn.Conv1d(num_band, num_band, 1) + model_up[-2].apply(weights_zero_init) + + self.model_up = nn.Sequential(*model_up) + + self.apply(weights_init) + + def forward(self, mel, step=None): + # mel input: (batch_size, seq_num, 80) + if self.onnx_export: + mel = mel.transpose(1, 2) 
+ # on onnx, for engineering, mel input: (batch_size, 80, seq_num) + + # Between Down and up + x = mel + + # Upsample pipline + cnt_after_upsample = 0 + + for i, m in enumerate(self.model_up): + x = m(x) + + if type(m) == Upsample: + cnt_after_upsample += 1 + + return x \ No newline at end of file diff --git a/MegaTTS3/tts/modules/wavvae/decoder/seanet_encoder.py b/MegaTTS3/tts/modules/wavvae/decoder/seanet_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..02da485a83bbad5a42a9ecedd4f791f769b4b0cf --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/decoder/seanet_encoder.py @@ -0,0 +1,38 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +from torch import nn +from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder + +class Encoder(nn.Module): + def __init__( + self, + dowmsamples: List[int] = [6, 5, 5, 4, 2], + ): + super().__init__() + + # breakpoint() + self.frame_rate = 25 # not use + self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2, + dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU', + kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2, + true_skip=False, compress=2) + + def forward(self, audio: torch.Tensor): + audio = audio.unsqueeze(1) # audio(16,24000) + emb = self.encoder(audio) + return emb diff --git a/MegaTTS3/tts/modules/wavvae/decoder/wavvae_v3.py b/MegaTTS3/tts/modules/wavvae/decoder/wavvae_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3aa955db357d2d1c5a8bca5bc6740f8bf14920 --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/decoder/wavvae_v3.py @@ -0,0 +1,60 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
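A hedged usage sketch for the WavVAE defined below. The melgan_config keys are assumptions (frame_shift and use_tanh are the only ones the Generator above visibly reads); with the [6, 5, 4, 4, 2] encoder strides, one second of 24 kHz audio maps to a 25 Hz, 32-channel latent.

import torch
from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3

hparams = {'melgan_config': {'frame_shift': 960, 'use_tanh': True}}   # hypothetical values
wavvae = WavVAE_V3(hparams=hparams)
audio = torch.randn(1, 24000)                 # one second at 24 kHz
latent = wavvae.encode_latent(audio)          # (1, 25, 32): 25 Hz latent frames
recon, posterior = wavvae(audio)              # recon: (1, 1, 24000) waveform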
+ +import argparse +import torch +from torch import nn +import torch.nn.functional as F + +from tts.modules.wavvae.decoder.seanet_encoder import Encoder +from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution +from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample + + +class WavVAE_V3(nn.Module): + def __init__(self, hparams=None): + super().__init__() + self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2]) + self.proj_to_z = nn.Linear(512, 64) + self.proj_to_decoder = nn.Linear(32, 320) + + config_path = hparams['melgan_config'] + args = argparse.Namespace() + args.__dict__.update(config_path) + self.latent_upsampler = Upsample(320, 4) + self.decoder = Generator( + input_size_=160, ngf=128, n_residual_layers=4, + num_band=1, args=args, ratios=[5,4,4,3]) + + ''' encode waveform into 25 hz latent representation ''' + def encode_latent(self, audio): + posterior = self.encode(audio) + latent = posterior.sample().permute(0, 2, 1) # (b,t,latent_channel) + return latent + + def encode(self, audio): + x = self.encoder(audio).permute(0, 2, 1) + x = self.proj_to_z(x).permute(0, 2, 1) + poseterior = DiagonalGaussianDistribution(x) + return poseterior + + def decode(self, latent): + latent = self.proj_to_decoder(latent).permute(0, 2, 1) + return self.decoder(self.latent_upsampler(latent)) + + def forward(self, audio): + posterior = self.encode(audio) + latent = posterior.sample().permute(0, 2, 1) # (b, t, latent_channel) + recon_wav = self.decode(latent) + return recon_wav, posterior \ No newline at end of file diff --git a/MegaTTS3/tts/modules/wavvae/encoder/common_modules/conv.py b/MegaTTS3/tts/modules/wavvae/encoder/common_modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..792f1fb51a183d6e0b1f8aba9319145b8874ba98 --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/encoder/common_modules/conv.py @@ -0,0 +1,154 @@ +# MIT License + +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.] +# Copyright (c) [2025] [Ziyue Jiang] +# SPDX-License-Identifier: MIT +# This file has been modified by Ziyue Jiang on 2025/03/19 +# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE. +# This modified file is released under the same license. 
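A quick check of the stride-aware padding implemented below: in the non-causal branch, get_extra_padding_for_conv1d pads the input so that a strided SConv1d keeps ceil(T / stride) output frames, which is what keeps the SEANet downsampling ratios exact. Sizes here are illustrative.

import torch
from tts.modules.wavvae.encoder.common_modules.conv import SConv1d

conv = SConv1d(1, 32, kernel_size=7, stride=5, norm='weight_norm')
for T in (24000, 24003):
    y = conv(torch.randn(1, 1, T))
    print(T, y.shape[-1])           # 24000 -> 4800, 24003 -> 4801 (= ceil(T / 5))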
+ +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings +import einops + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class ConvLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, 'b ... t -> b t ...') + x = super().forward(x) + x = einops.rearrange(x, 'b t ... -> b ... 
t') + return + + +class NormConv1d(nn.Module): + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) \ No newline at end of file diff --git a/MegaTTS3/tts/modules/wavvae/encoder/common_modules/lstm.py b/MegaTTS3/tts/modules/wavvae/encoder/common_modules/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..18c4173d073ec60fca56187c7d0580ac5150c491 --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/encoder/common_modules/lstm.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.] 
+# Copyright (c) [2025] [Ziyue Jiang] +# SPDX-License-Identifier: MIT +# This file has been modified by Ziyue Jiang on 2025/03/19 +# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE. +# This modified file is released under the same license. + +"""LSTM layers module.""" +from torch import nn + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + # 修改transpose顺序 + def forward(self, x): + x1 = x.permute(2, 0, 1) + y, _ = self.lstm(x1) + y = y.permute(1, 2, 0) + if self.skip: + y = y + x + return y diff --git a/MegaTTS3/tts/modules/wavvae/encoder/common_modules/seanet.py b/MegaTTS3/tts/modules/wavvae/encoder/common_modules/seanet.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e48073c8562391506f9ddf0ab83393ad268112 --- /dev/null +++ b/MegaTTS3/tts/modules/wavvae/encoder/common_modules/seanet.py @@ -0,0 +1,126 @@ +# MIT License + +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.] +# Copyright (c) [2025] [Ziyue Jiang] +# SPDX-License-Identifier: MIT +# This file has been modified by Ziyue Jiang on 2025/03/19 +# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE. +# This modified file is released under the same license. 
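A shape sketch for the SEANet encoder defined below, with the stride configuration WavVAE_V3 passes in (6 * 5 * 4 * 4 * 2 = 960, i.e. 25 latent frames per second of 24 kHz audio):

import torch
from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder

enc = SEANetEncoder(channels=1, dimension=512, n_filters=32,
                    n_residual_layers=1, ratios=[6, 5, 4, 4, 2], lstm=2)
wav = torch.randn(1, 1, 24000)      # (batch, channels, samples), one second at 24 kHz
emb = enc(wav)
print(enc.hop_length, emb.shape)    # 960, torch.Size([1, 512, 25])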
+ +"""Encodec SEANet-based encoder and decoder implementation.""" + +import typing as tp + +import numpy as np +import torch.nn as nn + +from .conv import SConv1d +from .lstm import SLSTM + + +class SEANetResnetBlock(nn.Module): + def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], + activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, + pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + super().__init__() + assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + norm=norm, norm_params=norm_params, + activation=activation, activation_params=activation_params, + causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + # Add downsampling layers + model += [ + act(**activation_params), + SConv1d(mult * n_filters, mult * n_filters * 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + mult *= 2 + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/MegaTTS3/tts/utils/audio_utils/align.py 
b/MegaTTS3/tts/utils/audio_utils/align.py new file mode 100644 index 0000000000000000000000000000000000000000..568c6776a77887699260a8f4ba09c5bb487e325e --- /dev/null +++ b/MegaTTS3/tts/utils/audio_utils/align.py @@ -0,0 +1,36 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +def mel2token_to_dur(mel2token, T_txt=None, max_dur=None): + is_torch = isinstance(mel2token, torch.Tensor) + has_batch_dim = True + if not is_torch: + mel2token = torch.LongTensor(mel2token) + if T_txt is None: + T_txt = mel2token.max() + if len(mel2token.shape) == 1: + mel2token = mel2token[None, ...] + has_batch_dim = False + B, _ = mel2token.shape + dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token)) + dur = dur[:, 1:] + if max_dur is not None: + dur = dur.clamp(max=max_dur) + if not is_torch: + dur = dur.numpy() + if not has_batch_dim: + dur = dur[0] + return dur diff --git a/MegaTTS3/tts/utils/audio_utils/io.py b/MegaTTS3/tts/utils/audio_utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..0860b697fde16da781c3fa4e36a918cbecc4f6fd --- /dev/null +++ b/MegaTTS3/tts/utils/audio_utils/io.py @@ -0,0 +1,95 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
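A worked example for mel2token_to_dur from align.py above: mel2token assigns each mel frame a 1-indexed token id, and the scatter_add counts the frames per token.

import torch
from tts.utils.audio_utils.align import mel2token_to_dur

mel2token = torch.LongTensor([[1, 1, 2, 2, 2, 3]])   # 6 frames over 3 tokens
dur = mel2token_to_dur(mel2token, T_txt=3)
print(dur)                                           # tensor([[2, 3, 1]])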
+ +import io +import os +import subprocess + +import numpy as np +from scipy.io import wavfile +import pyloudnorm as pyln +from pydub import AudioSegment + + +def to_wav_bytes(wav, sr, norm=False): + wav = wav.astype(float) + if norm: + meter = pyln.Meter(sr) # create BS.1770 meter + loudness = meter.integrated_loudness(wav) + wav = pyln.normalize.loudness(wav, loudness, -18.0) + if np.abs(wav).max() >= 1: + wav = wav / np.abs(wav).max() * 0.95 + wav = wav * 32767 + bytes_io = io.BytesIO() + wavfile.write(bytes_io, sr, wav.astype(np.int16)) + return bytes_io.getvalue() + + +def save_wav(wav_bytes, path): + with open(path[:-4] + '.wav', 'wb') as file: + file.write(wav_bytes) + if path[-4:] == '.mp3': + to_mp3(path[:-4]) + + +def to_mp3(out_path): + if out_path[-4:] == '.wav': + out_path = out_path[:-4] + subprocess.check_call( + f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"', + shell=True, stdin=subprocess.PIPE) + subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True) + + +def convert_to_wav(wav_path): + # Check if the file exists + if not os.path.exists(wav_path): + print(f"The file '{wav_path}' does not exist.") + return + + # Check if the file already has a .wav extension + if not wav_path.endswith(".wav"): + # Define the output path with a .wav extension + out_path = os.path.splitext(wav_path)[0] + ".wav" + + # Load the audio file using pydub and convert it to WAV + audio = AudioSegment.from_file(wav_path) + audio.export(out_path, format="wav") + + print(f"Converted '{wav_path}' to '{out_path}'") + + +def convert_to_wav_bytes(audio_binary): + # Load the audio binary using pydub and convert it to WAV + audio = AudioSegment.from_file(io.BytesIO(audio_binary)) + wav_bytes = io.BytesIO() + audio.export(wav_bytes, format="wav") + wav_bytes.seek(0) + return wav_bytes + + +''' Smoothly combine audio segments using crossfade transitions." ''' +def combine_audio_segments(segments, crossfade_duration=0.16, sr=24000): + window_length = int(sr * crossfade_duration) + hanning_window = np.hanning(2 * window_length) + # Combine + for i, segment in enumerate(segments): + if i == 0: + combined_audio = segment + else: + overlap = combined_audio[-window_length:] * hanning_window[window_length:] + segment[:window_length] * hanning_window[:window_length] + combined_audio = np.concatenate( + [combined_audio[:-window_length], overlap, segment[window_length:]] + ) + return combined_audio \ No newline at end of file diff --git a/MegaTTS3/tts/utils/audio_utils/plot.py b/MegaTTS3/tts/utils/audio_utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..c16a5d55bdc1a58ba9fbc860907ede53f01d9be8 --- /dev/null +++ b/MegaTTS3/tts/utils/audio_utils/plot.py @@ -0,0 +1,90 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
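A small sketch of the crossfade combiner from io.py above: consecutive segments are overlap-added over a 0.16 s Hann-windowed region, so each joint removes one window length from the total (segment lengths here are arbitrary).

import numpy as np
from tts.utils.audio_utils.io import combine_audio_segments

sr = 24000
seg_a, seg_b = np.random.randn(2 * sr), np.random.randn(2 * sr)
out = combine_audio_segments([seg_a, seg_b], crossfade_duration=0.16, sr=sr)
print(len(out), len(seg_a) + len(seg_b) - int(sr * 0.16))   # both 92160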
+ +import matplotlib + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import numpy as np +import torch + +LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy'] + + +def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None, figsize=(12, 6)): + if isinstance(spec, torch.Tensor): + spec = spec.cpu().numpy() + H = spec.shape[1] // 2 + fig = plt.figure(figsize=figsize) + plt.title(title) + plt.pcolor(spec.T, vmin=vmin, vmax=vmax) + + if dur_info is not None: + assert isinstance(dur_info, dict) + txt = dur_info['txt'] + dur_gt = dur_info['dur_gt'] + if isinstance(dur_gt, torch.Tensor): + dur_gt = dur_gt.cpu().numpy() + dur_gt = np.cumsum(dur_gt).astype(int) + for i in range(len(dur_gt)): + shift = (i % 8) + 1 + plt.text(dur_gt[i], shift * 4, txt[i]) + plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt + plt.xlim(0, dur_gt[-1]) + if 'dur_pred' in dur_info: + dur_pred = dur_info['dur_pred'] + if isinstance(dur_pred, torch.Tensor): + dur_pred = dur_pred.cpu().numpy() + dur_pred = np.cumsum(dur_pred).astype(int) + for i in range(len(dur_pred)): + shift = (i % 8) + 1 + plt.text(dur_pred[i], H + shift * 4, txt[i]) + plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred + plt.xlim(0, max(dur_gt[-1], dur_pred[-1])) + if f0s is not None: + ax = plt.gca() + ax2 = ax.twinx() + # ax.set_xticks() + + if not isinstance(f0s, dict): + f0s = {'f0': f0s} + for i, (k, f0) in enumerate(f0s.items()): + if f0 is not None: + if isinstance(f0, torch.Tensor): + f0 = f0.cpu().numpy() + ax2.plot( + np.arange(len(f0)) + 0.5, f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5) + ax2.set_ylim(0, 1000) + ax2.legend() + return fig + + +def align_to_figure(align, dur_info): + if isinstance(align, torch.Tensor): + align = align.cpu().numpy() + H = align.shape[1] + fig = plt.figure(figsize=(12, 6)) + plt.pcolor(align.T, vmin=0, vmax=1) + if dur_info is not None: + assert isinstance(dur_info, dict) + txt = dur_info['txt'] + dur_gt = dur_info['dur_gt'] + if isinstance(dur_gt, torch.Tensor): + dur_gt = dur_gt.cpu().numpy() + dur_gt = np.cumsum(dur_gt).astype(int) // 2 + for i in range(len(dur_gt)): + plt.text(dur_gt[i], i, txt[i], color='red') + plt.vlines(dur_gt[i], 0, H, colors='b') # blue is gt + # plt.xlim(0, dur_gt[-1]) + return fig diff --git a/MegaTTS3/tts/utils/commons/ckpt_utils.py b/MegaTTS3/tts/utils/commons/ckpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e66e27fc18860e414a725b4e07e5d50248f26e2 --- /dev/null +++ b/MegaTTS3/tts/utils/commons/ckpt_utils.py @@ -0,0 +1,171 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
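A hedged sketch of how the checkpoint helpers defined below are typically called; the work_dir path and the nn.Linear stand-in are made up, and the directory is expected to hold model_ckpt_steps_*.ckpt files as get_all_ckpts assumes.

from torch import nn
from tts.utils.commons.ckpt_utils import get_last_checkpoint, load_ckpt

work_dir = './checkpoints/my_exp'                 # hypothetical experiment directory
ckpt, ckpt_path = get_last_checkpoint(work_dir)   # newest model_ckpt_steps_*.ckpt, else (None, None)

model = nn.Linear(4, 4)                           # stand-in for the real module
# Loads checkpoint['state_dict']['model']; with strict=False, size-mismatched keys are dropped.
load_ckpt(model, work_dir, model_name='model', strict=False, force=False)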
+ +import contextlib +import glob +import os +import re +import subprocess +import traceback + +import torch +from torch.nn.parallel import DistributedDataParallel +import torch.distributed as dist + + +@contextlib.contextmanager +def dist_load(path): + if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'): + yield path + else: + from tts.utils.commons.hparams import hparams + from tts.utils.commons.trainer import LOCAL_RANK + tmpdir = '/dev/shm' + assert len(os.path.basename(path)) > 0 + shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}' + if LOCAL_RANK == 0: + subprocess.check_call( + f'mkdir -p {os.path.dirname(shm_ckpt_path)}; ' + f'cp -Lr {path} {shm_ckpt_path}', shell=True) + dist.barrier() + yield shm_ckpt_path + dist.barrier() + if LOCAL_RANK == 0: + subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True) + + +def torch_load_dist(path, map_location='cpu'): + with dist_load(path) as tmp_path: + checkpoint = torch.load(tmp_path, map_location=map_location) + return checkpoint + + +def get_last_checkpoint(work_dir, steps=None): + checkpoint = None + last_ckpt_path = None + ckpt_paths = get_all_ckpts(work_dir, steps) + if len(ckpt_paths) > 0: + last_ckpt_path = ckpt_paths[0] + checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu') + return checkpoint, last_ckpt_path + + +def get_all_ckpts(work_dir, steps=None): + if steps is None or steps == 0: + ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt' + else: + ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' + return sorted(glob.glob(ckpt_path_pattern), + key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) + + +def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True, + silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True): + if checkpoint is None: + if os.path.isfile(ckpt_base_dir): + base_dir = os.path.dirname(ckpt_base_dir) + ckpt_path = ckpt_base_dir + checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu') + else: + base_dir = ckpt_base_dir + if load_opt: + checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps) + else: + ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt' + if os.path.exists(ckpt_path): + checkpoint = torch_load_dist(ckpt_path, map_location='cpu') + else: + checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps) + if checkpoint is not None: + state_dict_all = { + k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()} + if not isinstance(cur_model, list): + cur_models = [cur_model] + model_names = [model_name] + else: + cur_models = cur_model + model_names = model_name + for model_name, cur_model in zip(model_names, cur_models): + if isinstance(cur_model, DistributedDataParallel): + cur_model = cur_model.module + device = next(cur_model.parameters()).device + if '.' 
not in model_name: + state_dict = state_dict_all[model_name] + else: + base_model_name = model_name.split('.')[0] + rest_model_name = model_name[len(base_model_name) + 1:] + state_dict = { + k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items() + if k.startswith(f'{rest_model_name}.')} + state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()} + if not strict and delete_unmatch: + try: + cur_model.load_state_dict(state_dict, strict=True) + if not silent: + print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.") + except: + cur_model_state_dict = cur_model.state_dict() + cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in + cur_model_state_dict.items()} + unmatched_keys = [] + for key, param in state_dict.items(): + if key in cur_model_state_dict: + new_param = cur_model_state_dict[key] + if new_param.shape != param.shape: + unmatched_keys.append(key) + print("| Unmatched keys: ", key, "cur model: ", new_param.shape, + "ckpt model: ", param.shape) + for key in unmatched_keys: + del state_dict[key] + load_results = cur_model.load_state_dict(state_dict, strict=strict) + cur_model.to(device) + if not silent: + print(f"| loaded '{model_name}' from '{ckpt_path}'.") + missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys + print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}") + if load_opt: + optimizer_states = checkpoint['optimizer_states'] + assert len(opts) == len(optimizer_states) + for optimizer, opt_state in zip(opts, optimizer_states): + opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()} + if optimizer is None: + return + try: + optimizer.load_state_dict(opt_state) + for i, state in enumerate(optimizer.state.values()): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) + except ValueError: + print(f"| WARMING: optimizer {optimizer} parameters not match !!!") + return checkpoint.get('global_step', 0) + else: + e_msg = f"| ckpt not found in {base_dir}." + if force: + assert False, e_msg + else: + print(e_msg) + + +def load_with_size_mismatch(model, state_dict, prefix=""): + current_model_dict = model.state_dict() + cm_keys = current_model_dict.keys() + mismatch_keys = {k.replace(prefix, "") for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() != current_model_dict[k.replace(prefix, "")].size()} + new_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() == current_model_dict[k.replace(prefix, "")].size()} + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + print(f"| mismatch keys: ", mismatch_keys) + if len(missing_keys) > 0: + print(f"| missing_keys in dit: {missing_keys}") + if len(unexpected_keys) > 0: + print(f"| unexpected_keys in dit: {unexpected_keys}") diff --git a/MegaTTS3/tts/utils/commons/hparams.py b/MegaTTS3/tts/utils/commons/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d528388842e66696444ca56c7a59d40aab6d16 --- /dev/null +++ b/MegaTTS3/tts/utils/commons/hparams.py @@ -0,0 +1,215 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
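A hedged sketch of restoring weights with `load_ckpt` above; the stand-in module, the work directory name, and the checkpoint layout (a top-level `model` entry in `state_dict`) are assumptions.

```python
import torch.nn as nn

from tts.utils.commons.ckpt_utils import get_all_ckpts, load_ckpt

model = nn.Linear(80, 80)            # stand-in for a real TTS sub-module
work_dir = "checkpoints/tts_exp"     # assumed to contain model_ckpt_steps_*.ckpt files

print(get_all_ckpts(work_dir))       # newest checkpoint listed first

# strict=False with delete_unmatch=True first tries a strict load, then drops
# (and prints) any tensors whose shapes changed instead of raising; force=False
# only warns when no checkpoint is found.
load_ckpt(model, work_dir, model_name="model", force=False, strict=False)
```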
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import re + +import yaml + +global_print_hparams = True +hparams = {} + + +class Args: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + self.__setattr__(k, v) + + +def override_config(old_config: dict, new_config: dict): + if new_config.get('__replace', False): + old_config.clear() + for k, v in new_config.items(): + if isinstance(v, dict) and k in old_config: + override_config(old_config[k], new_config[k]) + else: + old_config[k] = v + + +def traverse_dict(d, func, ctx): + for k in list(d.keys()): + v = d[k] + if isinstance(v, dict): + traverse_dict(v, func, ctx) + else: + d[k] = func(v, ctx) + + +def parse_config(v, context=None): + if context is None: + context = {} + + if isinstance(v, str): + if v.startswith('^'): + return load_config(v[1:], [], set()) + + match = re.match(r"\${(.*)}", v) + if match: + expression = match.group(1) + return eval(expression, {}, context) + return v + + +def remove_meta_key(d): + for k in list(d.keys()): + v = d[k] + if isinstance(v, dict): + remove_meta_key(v) + else: + if k[:2] == '__': + del d[k] + + +def load_config(config_fn, config_chains, loaded_configs): + # deep first inheritance and avoid the second visit of one node + if not os.path.exists(config_fn): + print(f"| WARN: {config_fn} not exist.", ) + return {} + with open(config_fn) as f: + hparams_ = yaml.safe_load(f) + loaded_configs.add(config_fn) + + if 'base_config' in hparams_: + ret_hparams = {} + if not isinstance(hparams_['base_config'], list): + hparams_['base_config'] = [hparams_['base_config']] + for c in hparams_['base_config']: + if c.startswith('.'): + c = f'{os.path.dirname(config_fn)}/{c}' + c = os.path.normpath(c) + if c not in loaded_configs: + override_config(ret_hparams, load_config(c, config_chains, loaded_configs)) + override_config(ret_hparams, hparams_) + else: + ret_hparams = hparams_ + + config_chains.append(config_fn) + return ret_hparams + + +def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): + if config == '' and exp_name == '': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--config', type=str, default='', + help='location of the data corpus') + parser.add_argument('--exp_name', type=str, default='', help='exp_name') + parser.add_argument('-hp', '--hparams', type=str, default='', + help='location of the data corpus') + parser.add_argument('--infer', action='store_true', help='infer') + parser.add_argument('--validate', action='store_true', help='validate') + parser.add_argument('--reset', action='store_true', help='reset hparams') + parser.add_argument('--remove', action='store_true', help='remove old ckpt') + parser.add_argument('--debug', action='store_true', help='debug') + parser.add_argument('--start_rank', type=int, default=-1, + help='the start rank id for DDP, keep 0 when single-machine multi-GPU') + parser.add_argument('--world_size', type=int, default=-1, + help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU') + parser.add_argument('--init_method', type=str, default='tcp', 
help='method to init ddp, use tcp or file') + parser.add_argument('--master_addr', type=str, default='', help='') + parser.add_argument('--ddp_dir', type=str, default='', help='') + + args, unknown = parser.parse_known_args() + if print_hparams: + print("| set_hparams Unknow hparams: ", unknown) + else: + args = Args(config=config, exp_name=exp_name, hparams=hparams_str, + infer=False, validate=False, reset=False, debug=False, remove=False, + start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='') + global hparams + assert args.config != '' or args.exp_name != '' + if args.config != '': + assert os.path.exists(args.config), f"{args.config} not exists" + + saved_hparams = {} + args_work_dir = '' + if args.exp_name != '': + args_work_dir = f'{args.exp_name}' + ckpt_config_path = f'{args_work_dir}/config.yaml' + if os.path.exists(ckpt_config_path): + with open(ckpt_config_path) as f: + saved_hparams_ = yaml.safe_load(f) + if saved_hparams_ is not None: + saved_hparams.update(saved_hparams_) + hparams_ = {} + config_chains = [] + if args.config != '': + hparams_.update(load_config(args.config, config_chains, set())) + if len(config_chains) > 1 and print_hparams: + print('| Hparams chains: ', config_chains) + if not args.reset: + hparams_.update(saved_hparams) + traverse_dict(hparams_, parse_config, hparams_) + hparams_['work_dir'] = args_work_dir + + # Support config overriding in command line. Support list type config overriding. + # Examples: --hparams="a=1,b.c=2,d=[1 1 1]" + if args.hparams != "": + for new_hparam in args.hparams.split(","): + k, v = new_hparam.split("=") + v = v.strip("\'\" ") + config_node = hparams_ + for k_ in k.split(".")[:-1]: + config_node = config_node[k_] + k = k.split(".")[-1] + if k in config_node: + if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]: + if type(config_node[k]) == list: + v = v.replace(" ", ",").replace('^', "\"") + if '|' in v: + tp = type(config_node[k][0]) if len(config_node[k]) else str + config_node[k] = [tp(x) for x in v.split("|") if x != ''] + continue + config_node[k] = eval(v) + else: + config_node[k] = type(config_node[k])(v) + else: + config_node[k] = v + try: + config_node[k] = float(v) + except: + pass + try: + config_node[k] = int(v) + except: + pass + if v.lower() in ['false', 'true']: + config_node[k] = v.lower() == 'true' + + if args_work_dir != '' and not args.infer: + os.makedirs(hparams_['work_dir'], exist_ok=True) + + hparams_['infer'] = args.infer + hparams_['debug'] = args.debug + hparams_['validate'] = args.validate + hparams_['exp_name'] = args.exp_name + + hparams_['start_rank'] = args.start_rank # useful for multi-machine training + hparams_['world_size'] = args.world_size + hparams_['init_method'] = args.init_method + hparams_['ddp_dir'] = args.ddp_dir + hparams_['master_addr'] = args.master_addr + + remove_meta_key(hparams_) + global global_print_hparams + if global_hparams: + hparams.clear() + hparams.update(hparams_) + if print_hparams and global_print_hparams and global_hparams: + print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True)) + # for i, (k, v) in enumerate(sorted(hparams_.items())): + # print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") + global_print_hparams = False + return hparams_ \ No newline at end of file diff --git a/MegaTTS3/tts/utils/text_utils/dict.json b/MegaTTS3/tts/utils/text_utils/dict.json new file mode 100644 index 0000000000000000000000000000000000000000..15738682f241f124f9026fec7db405b15b3607c9 --- /dev/null +++ 
b/MegaTTS3/tts/utils/text_utils/dict.json @@ -0,0 +1 @@ +{"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", """, "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]} diff --git a/MegaTTS3/tts/utils/text_utils/ph_tone_convert.py b/MegaTTS3/tts/utils/text_utils/ph_tone_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..47446b752c2a54e857c3e4c1de399c606225978b --- /dev/null +++ b/MegaTTS3/tts/utils/text_utils/ph_tone_convert.py @@ -0,0 +1,94 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + +def map_phone_to_tokendict(item, pad_bos_eos=True): + # Merge Chinese phone and tone (Original dict ends at 173, i.e., ph_dict_size=173). 146~173 is punctuations. 
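+    # In other words: after the tone remapping below, each Chinese phone id p (3..100)
+    # and its tone t (1..6) are folded into a single token id (p - 3) * 6 + 200 + t,
+    # i.e. six tone slots per phone inside the 200~788 band; 798/799 are appended as
+    # BOS/EOS when pad_bos_eos is True. split_ph / split_ph_timestamp invert this mapping.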
+ phone = item['txt_token'].clone() + merged_phone = item['txt_token'].clone() + tone_tmp = item['tone'].clone() + # In tone_dict, tone_1 is 4, tone_2 is 11, tone_3 is 12, tone_4 is 13, tone_5 is 14, tone_6 is 15 + tone_tmp[tone_tmp==4] = 1 + tone_tmp[tone_tmp==11] = 2 + tone_tmp[tone_tmp==12] = 3 + tone_tmp[tone_tmp==13] = 4 + tone_tmp[tone_tmp==14] = 5 + tone_tmp[tone_tmp==15] = 6 + # Chinese phones lie in 3~100 in the phone_dict, we map them to 200~788 + ch_phone_idx = (phone >= 3) & (phone <= 100) + merged_phone[ch_phone_idx] = (merged_phone[ch_phone_idx] - 3) * 6 + 200 + tone_tmp[ch_phone_idx] + + if pad_bos_eos: + merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=798) + merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=799) + return merged_phone + +def split_ph_timestamp(ph_timestamp): + ''' Input: ph_timestamp, shape [T] ''' + + # Map the timestamp of each phone back to its original frame-level lengths + ph_timestamp[ph_timestamp >= 800] -= 800 + + ph_list = [] + tone_list = [] + dur_list = [] + cur_timestamp = 0 + for idx, item in enumerate(ph_timestamp): + if idx % 2 == 0: + # Map Chinese phones back to its original phone_dict + if (200 <= item <= 788): + ph = (item - 200 - 1) // 6 + 3 + tone = (item - 200 - 1) % 6 + 1 + if tone == 1: + tone = 4 + else: + tone = tone + 9 + # Set English tone to '3' + else: + ph = item + tone = 3 + ph_list.append(ph) + tone_list.append(tone) + else: + dur_list.append((item - cur_timestamp)) + cur_timestamp = item + assert len(ph_list) == len(dur_list), f"{len(ph_list)}, {len(dur_list)}" + ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list) + return ph_seq, tone_seq, dur_seq, ph_timestamp[-1] + +def split_ph(ph_seq): + ''' Input: ph_timestamp, shape [T] ''' + ph_list = [] + tone_list = [] + for idx, item in enumerate(ph_seq): + # Map Chinese phones back to its original phone_dict + if (200 <= item <= 788): + ph = (item - 200 - 1) // 6 + 3 + tone = (item - 200 - 1) % 6 + 1 + if tone == 1: + tone = 4 + else: + tone = tone + 9 + # Set English tone to '3' + else: + ph = item + tone = 3 + ph_list.append(ph) + tone_list.append(tone) + + assert len(ph_list) == len(tone_list) + ph_seq, tone_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list) + return ph_seq, tone_seq \ No newline at end of file diff --git a/MegaTTS3/tts/utils/text_utils/split_text.py b/MegaTTS3/tts/utils/text_utils/split_text.py new file mode 100644 index 0000000000000000000000000000000000000000..436639a77f6cd6bb7abeabee78860e16a310d291 --- /dev/null +++ b/MegaTTS3/tts/utils/text_utils/split_text.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
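A small round-trip sketch of the phone/tone merge implemented above; the token ids are invented for illustration, and the module is imported the same way the rest of the MegaTTS3 code does (with the `MegaTTS3` directory on `sys.path`).

```python
import torch

from tts.utils.text_utils.ph_tone_convert import map_phone_to_tokendict, split_ph

# Hypothetical ids: three Chinese phones (range 3..100) with tone-dict ids 4/12/15,
# i.e. tones 1, 3 and 6 after the remapping inside map_phone_to_tokendict.
item = {
    "txt_token": torch.tensor([5, 42, 99]),
    "tone": torch.tensor([4, 12, 15]),
}

merged = map_phone_to_tokendict(item)       # -> [798, 213, 437, 782, 799] with BOS/EOS
ph, tone = split_ph(merged[1:-1].tolist())  # strip BOS/EOS; plain ints for clarity
print(ph.tolist(), tone.tolist())           # -> [5, 42, 99] [4, 12, 15]
```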
+ +import re + +def chunk_text_chinese(text, limit=60): + # 中文字符匹配 + chinese_pattern = re.compile(r'[\u4e00-\u9fff]') + # 标点符号匹配 + punctuation = ",。!?;:,\.!?;" + + result = [] # 存储断句结果 + current_chunk = [] # 当前片段 + chinese_count = 0 # 中文字符计数 + + i = 0 + while i < len(text): + char = text[i] + current_chunk.append(char) + if chinese_pattern.match(char): + chinese_count += 1 + + if chinese_count >= limit: # 达到限制字符数 + # 从当前位置往前找最近的标点符号 + for j in range(len(current_chunk) - 1, -1, -1): + if current_chunk[j] in punctuation: + result.append(''.join(current_chunk[:j + 1])) + current_chunk = current_chunk[j + 1:] + chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c)) + break + else: + # 如果前面没有标点符号,则继续找后面的标点符号 + for k in range(i + 1, len(text)): + if text[k] in punctuation: + result.append(''.join(current_chunk)+text[i+1:k+1]) + current_chunk = [] + chinese_count = 0 + i = k + break + i+=1 + + # 添加最后剩余的部分 + if current_chunk: + result.append(''.join(current_chunk)) + + return result + +def chunk_text_english(text, max_chars=130): + """ + Splits the input text into chunks, each with a maximum number of characters. + + Args: + text (str): The text to be split. + max_chars (int): The maximum number of characters per chunk. + + Returns: + List[str]: A list of text chunks. + """ + chunks = [] + current_chunk = "" + # Split the text into sentences based on punctuation followed by whitespace + sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text) + + for sentence in sentences: + if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars: + current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence + else: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence + + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def chunk_text_chinesev2(text, limit=60, look_ahead_limit=30): + """ + 将中文文本分成多个块,优先确保每个块以句号、感叹号或问号结尾, + 其次考虑逗号等其他标点符号,避免在无标点处断句 + + 参数: + text: 要分块的文本 + limit: 每个块的中文字符数限制 + look_ahead_limit: 向后查找的最大字符数限制 + + 返回: + 分块后的文本列表 + """ + # 中文字符匹配 + chinese_pattern = re.compile(r'[\u4e00-\u9fff]') + + # 分级定义标点符号(优先级从高到低) + primary_end_marks = "。.!!??" # 首选:句号、感叹号、问号 + secondary_end_marks = ",,;;:" # 次选:逗号、分号、冒号 + tertiary_end_marks = "、…—-~~" # 再次:顿号、省略号、破折号等 + + result = [] # 存储断句结果 + current_chunk = [] # 当前片段 + chinese_count = 0 # 中文字符计数 + + i = 0 + while i < len(text): + char = text[i] + current_chunk.append(char) + + if chinese_pattern.match(char): + chinese_count += 1 + + if chinese_count >= limit: # 达到字符数限制,需要寻找断句点 + found_end = False + + # 依次尝试不同优先级的断句策略 + + # 1. 向后查找首选标点 + for k in range(1, min(look_ahead_limit, len(text) - i)): + next_char = text[i + k] + if next_char in primary_end_marks: + result.append(''.join(current_chunk) + text[i+1:i+k+1]) + current_chunk = [] + chinese_count = 0 + i = i + k + found_end = True + break + + if not found_end: + # 2. 向前查找首选标点 + for j in range(len(current_chunk) - 1, -1, -1): + if current_chunk[j] in primary_end_marks: + result.append(''.join(current_chunk[:j + 1])) + current_chunk = current_chunk[j + 1:] + chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c)) + found_end = True + break + + if not found_end: + # 3. 
向后查找次选标点 + for k in range(1, min(look_ahead_limit, len(text) - i)): + next_char = text[i + k] + if next_char in secondary_end_marks: + result.append(''.join(current_chunk) + text[i+1:i+k+1]) + current_chunk = [] + chinese_count = 0 + i = i + k + found_end = True + break + + if not found_end: + # 4. 向前查找次选标点 + for j in range(len(current_chunk) - 1, -1, -1): + if current_chunk[j] in secondary_end_marks: + result.append(''.join(current_chunk[:j + 1])) + current_chunk = current_chunk[j + 1:] + chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c)) + found_end = True + break + + if not found_end: + # 5. 向后查找三级标点 + for k in range(1, min(look_ahead_limit, len(text) - i)): + next_char = text[i + k] + if next_char in tertiary_end_marks: + result.append(''.join(current_chunk) + text[i+1:i+k+1]) + current_chunk = [] + chinese_count = 0 + i = i + k + found_end = True + break + + if not found_end: + # 6. 向前查找三级标点 + for j in range(len(current_chunk) - 1, -1, -1): + if current_chunk[j] in tertiary_end_marks: + result.append(''.join(current_chunk[:j + 1])) + current_chunk = current_chunk[j + 1:] + chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c)) + found_end = True + break + + if not found_end: + # 万不得已,在此处断句(这种情况很少见,因为汉语文本中通常会有标点) + result.append(''.join(current_chunk)) + current_chunk = [] + chinese_count = 0 + + i += 1 + + # 添加最后剩余的部分 + if current_chunk: + result.append(''.join(current_chunk)) + + # 英文标点替换为中文标点 + punctuation_map = { + '.': '。', + ',': ',', + '!': '!', + '?': '?', + ';': ';', + ':': ':' + } + + for i in range(len(result)): + for eng_punc, cn_punc in punctuation_map.items(): + result[i] = result[i].replace(eng_punc, cn_punc) + + return result + +if __name__ == '__main__': + print(chunk_text_chinese("哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。")) + print(chunk_text_english("Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day.")) + text = "欢迎收听《TED Talks Daily》,在这里,我们每天为您带来新思想,激发您的好奇心。我是您的主持人,Elise Hugh。当我们去看医生时,医生会评估我们的身体健康状况,检查我们的生命体征,可能还会关注我们的胆固醇水平,确保我们整体处于健康状态。医生可能还会通过一系列问题来检查我们的心理健康。然而,人际交往专家Casley Killam指出,我们在理解健康时忽略了一个关键指标,那就是我们的社会健康。在2024年的演讲中,她解释了为什么人际关系如此重要,以及忽视它可能带来的代价。几年前,我认识的一位女士,我们暂且称她为Maya,在短时间内经历了许多重大变化。她结婚了,和丈夫因工作搬到了一个陌生的城市,在那里她谁也不认识。她开始了一份在家办公的新工作,同时还要应对父亲新确诊的痴呆症。为了应对这些变化带来的压力,Maya加倍关注自己的身心健康。她几乎每天都锻炼,吃健康的食物,每周去看一次心理医生。这些措施确实有帮助,她的身体变得更加强壮,心理也更具韧性,但效果有限。她仍然感到困扰,经常在半夜失眠,白天感到注意力不集中,缺乏动力。Maya做了医生通常建议我们做的所有事情来保持身心健康,但似乎还缺少些什么。如果我告诉你,Maya所缺少的东西,也是全球数十亿人所缺少的,甚至可能也是你所缺少的呢?如果我告诉你,缺乏它会削弱我们为保持健康所做的其他努力,甚至可能缩短你的寿命呢?我研究这个问题已经超过十年,我发现,我们传统上对健康的理解是不完整的。通过将健康主要视为身体和心理的健康,我们忽略了我认为是我们这个时代最大的挑战和机遇——社会健康。身体健康关乎我们的身体,心理健康关乎我们的思想,而社会健康则关乎我们的人际关系。如果你以前没有听说过这个词,那是因为它还没有进入主流词汇,但它同样重要。Maya在她的新家还没有归属感。她不再亲自见到她的家人、朋友或同事,她经常一连几周只和丈夫共度时光。她的故事告诉我们,如果我们只照顾身体和心理,而不关注人际关系,我们就无法完全健康,无法真正茁壮成长。与Maya类似,全球有数亿人连续几周不与任何朋友或家人交谈。全球范围内,有四分之一的人感到孤独。20%的成年人觉得他们没有任何人可以求助。想想看,你遇到的每五个人中,可能有一个人觉得自己孤立无援。这不仅令人心碎,也是一场公共卫生危机。" + for res in chunk_text_chinesev2(text): + print(res) diff --git a/MegaTTS3/tts/utils/text_utils/text_encoder.py b/MegaTTS3/tts/utils/text_utils/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5931e98cb0b275ebb412a2c038362f8d72937c6e --- /dev/null 
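For reference, a short sketch of the byte-based budget used by `chunk_text_english` above; the sample paragraph is invented.

```python
from tts.utils.text_utils.split_text import chunk_text_english

# max_chars is a UTF-8 byte budget per chunk, so ASCII text packs roughly 130
# characters while CJK characters (3 bytes each) fill a chunk about three times faster.
paragraph = (
    "This sentence exists only to exercise the chunker. "
    "It is followed by another one of roughly the same length. "
) * 4

for i, chunk in enumerate(chunk_text_english(paragraph, max_chars=130)):
    print(i, len(chunk.encode("utf-8")), repr(chunk[:40]))
```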
+++ b/MegaTTS3/tts/utils/text_utils/text_encoder.py @@ -0,0 +1,280 @@ +# Copyright 2025 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +import six +from six.moves import range # pylint: disable=redefined-builtin + +PAD = "" +EOS = "" +UNK = "" +SEG = "|" +PUNCS = '!,.?;:' +RESERVED_TOKENS = [PAD, EOS, UNK] +NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) +PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 +EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 +UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2 + +if six.PY2: + RESERVED_TOKENS_BYTES = RESERVED_TOKENS +else: + RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] + +# Regular expression for unescaping token strings. +# '\u' is converted to '_' +# '\\' is converted to '\' +# '\213;' is converted to unichr(213) +_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") +_ESCAPE_CHARS = set(u"\\_u;0123456789") + + +def strip_ids(ids, ids_to_strip): + """Strip ids_to_strip from the end ids.""" + ids = list(ids) + while ids and ids[-1] in ids_to_strip: + ids.pop() + return ids + + +class TextEncoder(object): + """Base class for converting from ints to/from human readable strings.""" + + def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): + self._num_reserved_ids = num_reserved_ids + + @property + def num_reserved_ids(self): + return self._num_reserved_ids + + def encode(self, s): + """Transform a human-readable string into a sequence of int ids. + + The ids should be in the range [num_reserved_ids, vocab_size). Ids [0, + num_reserved_ids) are reserved. + + EOS is not appended. + + Args: + s: human-readable string to be converted. + + Returns: + ids: list of integers + """ + return [int(w) + self._num_reserved_ids for w in s.split()] + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int ids into a human-readable string. + + EOS is not expected in ids. + + Args: + ids: list of integers to be converted. + strip_extraneous: bool, whether to strip off extraneous tokens + (EOS and PAD). + + Returns: + s: human-readable string. + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + """Transform a sequence of int ids into a their string versions. + + This method supports transforming individual input/output ids to their + string versions so that sequence to/from text conversions can be visualized + in a human readable format. + + Args: + ids: list of integers to be converted. + + Returns: + strs: list of human-readable string. 
+ """ + decoded_ids = [] + for id_ in ids: + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) + else: + decoded_ids.append(id_ - self._num_reserved_ids) + return [str(d) for d in decoded_ids] + + @property + def vocab_size(self): + raise NotImplementedError() + + +class TokenTextEncoder(TextEncoder): + """Encoder based on a user-supplied vocabulary (file or list).""" + + def __init__(self, + vocab_filename, + reverse=False, + vocab_list=None, + replace_oov=None, + num_reserved_ids=NUM_RESERVED_TOKENS): + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. + + Args: + vocab_filename: If not None, the full filename to read vocab from. If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when + encoding will be replaced by this string (which must be in vocab). + num_reserved_ids: Number of IDs to save for reserved tokens like . + """ + super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + self._reverse = reverse + self._replace_oov = replace_oov + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) + self.pad_index = self.token_to_id[PAD] + self.eos_index = self.token_to_id[EOS] + self.unk_index = self.token_to_id[UNK] + self.seg_index = self.token_to_id[SEG] if SEG in self.token_to_id else self.eos_index + + def encode(self, s): + """Converts a space-separated string of tokens to a list of ids.""" + if isinstance(s, str): + sentence = s + tokens = sentence.strip().split() + else: + tokens = s + if self._replace_oov is not None: + tokens = [t if t in self.token_to_id else self._replace_oov + for t in tokens] + ret = [self.token_to_id[tok] for tok in tokens] + return ret[::-1] if self._reverse else ret + + def decode(self, ids, strip_eos=False, strip_padding=False): + if strip_padding and self.pad() in list(ids): + pad_pos = list(ids).index(self.pad()) + ids = ids[:pad_pos] + if strip_eos and self.eos() in list(ids): + eos_pos = list(ids).index(self.eos()) + ids = ids[:eos_pos] + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + seq = reversed(ids) if self._reverse else ids + return [self._safe_id_to_token(i) for i in seq] + + @property + def vocab_size(self): + return len(self.id_to_token) + + def __len__(self): + return self.vocab_size + + def _safe_id_to_token(self, idx): + return self.id_to_token.get(idx, "ID_%d" % idx) + + def _init_vocab_from_file(self, filename): + """Load vocab from a file. + + Args: + filename: The file to load vocabulary from. + """ + with open(filename) as f: + tokens = [token.strip() for token in f.readlines()] + + def token_gen(): + for token in tokens: + yield token + + self._init_vocab(token_gen(), add_reserved_tokens=False) + + def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. + + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. 
+ + Args: + vocab_list: A list of tokens. + """ + + def token_gen(): + for token in vocab_list: + if token not in RESERVED_TOKENS: + yield token + + self._init_vocab(token_gen()) + + def _init_vocab(self, token_generator, add_reserved_tokens=True): + """Initialize vocabulary with tokens from token_generator.""" + + self.id_to_token = {} + non_reserved_start_index = 0 + + if add_reserved_tokens: + self.id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) + + self.id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index)) + + # _token_to_id is the reverse of _id_to_token + self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token)) + + def pad(self): + return self.pad_index + + def eos(self): + return self.eos_index + + def unk(self): + return self.unk_index + + def seg(self): + return self.seg_index + + def store_to_file(self, filename): + """Write vocab file to disk. + + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. + + Args: + filename: Full path of the file to store the vocab to. + """ + with open(filename, "w") as f: + for i in range(len(self.id_to_token)): + f.write(self.id_to_token[i] + "\n") + + def sil_phonemes(self): + return [p for p in self.id_to_token.values() if is_sil_phoneme(p)] + + +def build_token_encoder(token_list_file): + token_list = json.load(open(token_list_file)) + return TokenTextEncoder(None, vocab_list=token_list, replace_oov='') + + +def is_sil_phoneme(p): + return p == '' or not p[0].isalpha() or p == 'sil' or p == 'sp' or p == 'XX' diff --git a/README.md b/README.md index cd62fcaad3068d715538c610cb391b330674faf4..1ff4e8c4c661e45f7b1fd47267fddf8c44ac8191 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ ---- -title: PresentAgent -emoji: 💻 -colorFrom: green -colorTo: indigo -sdk: gradio -sdk_version: 5.36.2 -python_version: 3.11 -app_file: app.py -pinned: false -license: mit ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: PresentAgent +emoji: 💻 +colorFrom: red +colorTo: indigo +sdk: gradio +sdk_version: 5.35.0 +python_version: 3.11 +app_file: app.py +pinned: false +license: mit +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0910b542a955c30ea790f82ce3fbc19c3ce742ce --- /dev/null +++ b/app.py @@ -0,0 +1,910 @@ +import asyncio +import functools +import hashlib +import importlib +import json +import os +import shutil +import tempfile +import sys +import traceback +import uuid +import subprocess +import threading +from contextlib import asynccontextmanager +from copy import deepcopy +from datetime import datetime +from typing import Optional, Tuple +import time +from concurrent.futures import ThreadPoolExecutor + +import gradio as gr +from pdf2image import convert_from_path +from pptx import Presentation as PptxPresentation + +sys.path.append("./") +import pptagent.induct as induct +import pptagent.pptgen as pptgen +from pptagent.document import Document +from pptagent.model_utils import ModelManager, parse_pdf +from pptagent.multimodal import ImageLabler +from pptagent.presentation import Presentation +from pptagent.utils import Config, get_logger, package_join, pjoin, ppt_to_images_async + + +async def run_blocking(func, *args, **kw): + loop = asyncio.get_running_loop() + return 
await loop.run_in_executor(None, functools.partial(func, *args, **kw)) + + +async def run_cmd(cmd: list[str]): + proc = await asyncio.create_subprocess_exec( + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"{' '.join(cmd)}\n{stderr.decode()}") + return stdout + + +# Constants +DEBUG = True if len(sys.argv) == 1 else False +RUNS_DIR = package_join("runs") +STAGES = ["PPT Parsing", "PDF Parsing", "PPT Analysis", "PPT Generation", "Success!"] + +# Create a temp directory for Gradio outputs +GRADIO_TEMP_DIR = os.path.join(tempfile.gettempdir(), "gradio_ppt_agent") +os.makedirs(GRADIO_TEMP_DIR, exist_ok=True) + +# Global variables +ppt_video_progress_store: dict[str, dict] = {} +progress_store: dict[str, dict] = {} +models = None # Initialize as None, will be set in main thread +logger = get_logger(__name__) +executor = ThreadPoolExecutor(max_workers=2) + +# 在文件顶部添加默认模板配置 +DEFAULT_TEMPLATES = [ + { + "name": "Template1", + "path": "templates/Template1.pptx", + "preview": "templates/previews/Template1.jpg" + }, + { + "name": "Template2", + "path": "templates/Template2.pptx", + "preview": "templates/previews/Template2.jpg" + }, + { + "name": "Template3", + "path": "templates/Template3.pptx", + "preview": "templates/previews/Template3.jpg" + }, +] + + +# 新增函数:获取默认模板列表 +def get_default_templates(): + """获取可用的默认模板""" + available_templates = [] + base_dir = os.path.dirname(__file__) + + for template in DEFAULT_TEMPLATES: + template_path = os.path.join(base_dir, template["path"]) + preview_path = os.path.join(base_dir, template["preview"]) + + if os.path.exists(template_path) and os.path.exists(preview_path): + available_templates.append({ + "name": template["name"], + "path": template_path, + "preview": preview_path + }) + + return available_templates + + +# 新增函数:模板选择回调 +def select_template(selected_template_name): + """选择默认模板""" + if selected_template_name == "Upload Custom": + return gr.update(visible=True), gr.update(visible=False), None + else: + templates = get_default_templates() + for template in templates: + if template["name"] == selected_template_name: + return gr.update(visible=False), gr.update(visible=True), template["path"] + return gr.update(visible=True), gr.update(visible=False), None + + +# 新增函数:创建模板选择界面 +def create_template_selection(): + """创建模板选择界面""" + templates = get_default_templates() + + with gr.Row(): + with gr.Column(): + gr.Markdown("### Choose Template") + + # 创建模板选择按钮 + template_choices = [template["name"] for template in templates] + ["Upload Custom"] + template_radio = gr.Radio( + choices=template_choices, + value="Upload Custom", + label="Select Template Type" + ) + + # 默认模板预览 + template_preview = gr.Gallery( + value=[[template["preview"], template["name"]] for template in templates], + label="Template Previews", + columns=2, + rows=2, + height="auto", + visible=True + ) + + # 自定义上传区域 + custom_upload = gr.File( + label="Upload Custom PPT Template", + file_types=[".pptx"], + type="filepath", + visible=True + ) + + # 显示选中的模板 + selected_template_display = gr.Textbox( + label="Selected Template", + interactive=False, + visible=False + ) + + return template_radio, template_preview, custom_upload, selected_template_display + + +def copy_to_gradio_safe_path(source_path: str, filename: str = None) -> str: + """ + Copy file to a Gradio-safe location (temp directory) + + Args: + source_path: Path to source file + filename: Optional custom filename, defaults 
to original filename + + Returns: + Path to copied file in temp directory + """ + if not os.path.exists(source_path): + raise FileNotFoundError(f"Source file not found: {source_path}") + + if filename is None: + filename = os.path.basename(source_path) + + # Create unique filename to avoid conflicts + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + name, ext = os.path.splitext(filename) + safe_filename = f"{name}_{timestamp}{ext}" + + safe_path = os.path.join(GRADIO_TEMP_DIR, safe_filename) + shutil.copy2(source_path, safe_path) + + return safe_path + + +# Initialize models with custom configuration +def init_models(api_key: str = None, api_base: str = None, language_model: str = None, + vision_model: str = None, text_model: str = None): + """Initialize models with custom configuration""" + global models + try: + # Set environment variables if provided + if api_key: + os.environ["OPENAI_API_KEY"] = api_key + if api_base: + os.environ["API_BASE"] = api_base + if language_model: + os.environ["LANGUAGE_MODEL"] = language_model + if vision_model: + os.environ["VISION_MODEL"] = vision_model + if text_model: + os.environ["TEXT_MODEL"] = text_model + + # Initialize models + models = ModelManager( + api_base=api_base, + api_key=api_key, + language_model_name=language_model, + vision_model_name=vision_model, + text_model_name=text_model + ) + + # Test connections in main thread + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(models.test_connections()) + assert result, "Model connection test failed" + + logger.info("Models initialized successfully") + return "✅ Models initialized successfully" + except Exception as e: + error_msg = f"❌ Model initialization failed: {e}" + logger.error(error_msg) + return error_msg + + +class GradioProgressManager: + def __init__(self, task_id: str, stages: list[str], progress_callback=None): + self.task_id = task_id + self.stages = stages + self.progress_callback = progress_callback + self.failed = False + self.current_stage = 0 + self.total_stages = len(stages) + + async def report_progress(self): + self.current_stage += 1 + progress = int((self.current_stage / self.total_stages) * 100) + status = f"Stage: {self.stages[self.current_stage - 1]}" + if self.progress_callback: + self.progress_callback(progress, status) + return progress, status + + async def fail_stage(self, error_message: str): + error_status = f"{self.stages[self.current_stage]} Error: {error_message}" + if self.progress_callback: + self.progress_callback(100, error_status) + self.failed = True + logger.error(f"{self.task_id}: {error_status}") + return error_status + + +def generate_ppt(pptx_file, pdf_file, num_pages, progress=gr.Progress()): + """Generate PPT from template and PDF""" + try: + # Make sure models are initialized + if models is None: + return None, "❌ Please initialize models first in the Configuration tab" + + # Create task ID + task_id = datetime.now().strftime("20%y-%m-%d") + "/" + str(uuid.uuid4()) + logger.info(f"PPT generation task created: {task_id}") + + # Create directories + os.makedirs(pjoin(RUNS_DIR, task_id), exist_ok=True) + + task = { + "numberOfPages": num_pages, + "pptx": "default_template", + } + + # Handle PPT template + if pptx_file is not None: + pptx_blob = open(pptx_file, "rb").read() + pptx_md5 = hashlib.md5(pptx_blob).hexdigest() + task["pptx"] = pptx_md5 + pptx_dir = pjoin(RUNS_DIR, "pptx", pptx_md5) + if not os.path.exists(pptx_dir): + 
os.makedirs(pptx_dir, exist_ok=True) + with open(pjoin(pptx_dir, "source.pptx"), "wb") as f: + f.write(pptx_blob) + + # Handle PDF + if pdf_file is not None: + pdf_blob = open(pdf_file, "rb").read() + pdf_md5 = hashlib.md5(pdf_blob).hexdigest() + task["pdf"] = pdf_md5 + pdf_dir = pjoin(RUNS_DIR, "pdf", pdf_md5) + if not os.path.exists(pdf_dir): + os.makedirs(pdf_dir, exist_ok=True) + with open(pjoin(pdf_dir, "source.pdf"), "wb") as f: + f.write(pdf_blob) + + else: + return None, "❌ Please provide a PDF file" + + progress_store[task_id] = task + + # Progress callback + def update_progress(prog, status): + progress(prog / 100, desc=status) + + # Run PPT generation directly in main thread event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + final_ppt_path = loop.run_until_complete(ppt_gen_async(task_id, update_progress)) + + if final_ppt_path and os.path.exists(final_ppt_path): + # Copy to Gradio-safe location + safe_path = copy_to_gradio_safe_path(final_ppt_path, "generated_presentation.pptx") + return safe_path, f"✅ PPT generated successfully! Task ID: {task_id}" + else: + return None, "❌ PPT generation failed" + + except Exception as e: + logger.error(f"PPT generation error: {str(e)}") + traceback.print_exc() + return None, f"❌ Error: {str(e)}" + +def generate_ppt_with_template_selection(selected_template_name, custom_pptx_file, pdf_file, num_pages, progress=gr.Progress()): + """Generate PPT with template selection""" + try: + # Make sure models are initialized + if models is None: + return None, "❌ Please initialize models first in the Configuration tab" + + # 确定使用的模板文件 + pptx_file = None + if selected_template_name and selected_template_name != "Upload Custom": + # 使用默认模板 + templates = get_default_templates() + for template in templates: + if template["name"] == selected_template_name: + pptx_file = template["path"] + break + else: + # 使用上传的自定义模板 + pptx_file = custom_pptx_file + + # 调用原有的generate_ppt函数 + return generate_ppt(pptx_file, pdf_file, num_pages, progress) + + except Exception as e: + logger.error(f"PPT generation with template selection error: {str(e)}") + return None, f"❌ Error: {str(e)}" + + +async def ppt_gen_async(task_id: str, progress_callback=None): + """Async PPT generation function""" + try: + if DEBUG: + importlib.reload(induct) + importlib.reload(pptgen) + + task = progress_store[task_id] + pptx_md5 = task["pptx"] + pdf_md5 = task["pdf"] + generation_config = Config(pjoin(RUNS_DIR, task_id)) + pptx_config = Config(pjoin(RUNS_DIR, "pptx", pptx_md5)) + json.dump(task, open(pjoin(generation_config.RUN_DIR, "task.json"), "w")) + + progress_manager = GradioProgressManager(task_id, STAGES, progress_callback) + parsedpdf_dir = pjoin(RUNS_DIR, "pdf", pdf_md5) + ppt_image_folder = pjoin(pptx_config.RUN_DIR, "slide_images") + + if progress_callback: + progress_callback(10, "Task initialized successfully") + + # PPT parsing + presentation = Presentation.from_file( + pjoin(pptx_config.RUN_DIR, "source.pptx"), pptx_config + ) + + if not os.path.exists(ppt_image_folder) or len(os.listdir(ppt_image_folder)) != len(presentation): + await ppt_to_images_async( + pjoin(pptx_config.RUN_DIR, "source.pptx"), ppt_image_folder + ) + + # Handle error slides + for err_idx, _ in presentation.error_history: + error_file = pjoin(ppt_image_folder, f"slide_{err_idx:04d}.jpg") + if os.path.exists(error_file): + os.remove(error_file) + + # Rename slides + for i, slide in enumerate(presentation.slides, 1): + 
slide.slide_idx = i + old_path = pjoin(ppt_image_folder, f"slide_{slide.real_idx:04d}.jpg") + new_path = pjoin(ppt_image_folder, f"slide_{slide.slide_idx:04d}.jpg") + if os.path.exists(old_path): + os.rename(old_path, new_path) + + # Image labeling + labler = ImageLabler(presentation, pptx_config) + stats_file = pjoin(pptx_config.RUN_DIR, "image_stats.json") + if os.path.exists(stats_file): + image_stats = json.load(open(stats_file, encoding="utf-8")) + labler.apply_stats(image_stats) + else: + await labler.caption_images_async(models.vision_model) + json.dump( + labler.image_stats, + open(stats_file, "w", encoding="utf-8"), + ensure_ascii=False, + indent=4, + ) + await progress_manager.report_progress() + + # PDF parsing + source_md_path = pjoin(parsedpdf_dir, "source.md") + if not os.path.exists(source_md_path): + # Check if we have a PDF file + pdf_file_path = pjoin(RUNS_DIR, "pdf", pdf_md5, "source.pdf") + if os.path.exists(pdf_file_path): + text_content = parse_pdf( + pdf_file_path, + parsedpdf_dir, + models.marker_model, + ) + else: + raise ValueError("No PDF file found") + else: + text_content = open(source_md_path, encoding="utf-8").read() + await progress_manager.report_progress() + + # Document refine + refined_doc_path = pjoin(parsedpdf_dir, "refined_doc.json") + if not os.path.exists(refined_doc_path): + source_doc = await Document.from_markdown_async( + text_content, + models.language_model, + models.vision_model, + parsedpdf_dir, + ) + json.dump( + source_doc.to_dict(), + open(refined_doc_path, "w"), + ensure_ascii=False, + indent=4, + ) + else: + source_doc_dict = json.load(open(refined_doc_path)) + source_doc = Document.from_dict(source_doc_dict, parsedpdf_dir) + await progress_manager.report_progress() + + # Slide Induction + slide_induction_path = pjoin(pptx_config.RUN_DIR, "slide_induction.json") + if not os.path.exists(slide_induction_path): + deepcopy(presentation).save( + pjoin(pptx_config.RUN_DIR, "template.pptx"), layout_only=True + ) + await ppt_to_images_async( + pjoin(pptx_config.RUN_DIR, "template.pptx"), + pjoin(pptx_config.RUN_DIR, "template_images"), + ) + slide_inducter = induct.SlideInducterAsync( + presentation, + ppt_image_folder, + pjoin(pptx_config.RUN_DIR, "template_images"), + pptx_config, + models.image_model, + models.language_model, + models.vision_model, + ) + layout_induction = await slide_inducter.layout_induct() + slide_induction = await slide_inducter.content_induct(layout_induction) + json.dump( + slide_induction, + open(slide_induction_path, "w", encoding="utf-8"), + ensure_ascii=False, + indent=4, + ) + else: + slide_induction = json.load(open(slide_induction_path, encoding="utf-8")) + await progress_manager.report_progress() + + # PPT Generation + ppt_agent = pptgen.PPTAgentAsync( + models.text_model, + models.language_model, + models.vision_model, + error_exit=False, + retry_times=5, + ) + ppt_agent.set_reference( + config=generation_config, + slide_induction=slide_induction, + presentation=presentation, + ) + + prs, _ = await ppt_agent.generate_pres( + source_doc=source_doc, + num_slides=task["numberOfPages"], + ) + + final_path = pjoin(generation_config.RUN_DIR, "final.pptx") + prs.save(final_path) + logger.info(f"{task_id}: generation finished") + await progress_manager.report_progress() + + return final_path + + except Exception as e: + logger.error(f"PPT generation failed: {str(e)}") + traceback.print_exc() + return None + + +def ppt_to_video(ppt_file, progress=gr.Progress()): + """Convert PPT to video presentation""" + try: + 
task_id = str(uuid.uuid4()) + logger.info(f"PPT2Video task created: {task_id}") + + task_dir = pjoin(RUNS_DIR, "ppt_video", task_id) + os.makedirs(task_dir, exist_ok=True) + + # Copy PPT file + ppt_blob = open(ppt_file, "rb").read() + ppt_path = pjoin(task_dir, "source.pptx") + with open(ppt_path, "wb") as f: + f.write(ppt_blob) + + # Initialize progress + ppt_video_progress_store[task_id] = { + "status": "processing", + "current_step": 1, + "current_slide": 0, + "total_slides": 0, + "progress_percentage": 0, + "task_dir": task_dir, + "ppt_path": ppt_path + } + + # Progress callback + def update_progress(prog, status): + progress(prog, desc=status) + + # Run PPT to video conversion directly in main thread event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + video_path = loop.run_until_complete(process_ppt_to_video_async(task_id, update_progress)) + + if video_path and os.path.exists(video_path): + # Copy to Gradio-safe location + safe_path = copy_to_gradio_safe_path(video_path, "generated_video.mp4") + return safe_path, f"✅ Video generated successfully! Task ID: {task_id}" + else: + return None, "❌ Video generation failed" + + except Exception as e: + logger.error(f"PPT to video error: {str(e)}") + return None, f"❌ Error: {str(e)}" + + +async def process_ppt_to_video_async(task_id: str, progress_callback): + """Process PPT to video asynchronously""" + try: + task_dir = ppt_video_progress_store[task_id]["task_dir"] + ppt_path = ppt_video_progress_store[task_id]["ppt_path"] + + progress_callback(0.1, "Converting PPT to PDF...") + + # Convert PPT to PDF + pdf_path = pjoin(task_dir, "source.pdf") + await run_cmd([ + "libreoffice", "--headless", "--convert-to", "pdf", + ppt_path, "--outdir", task_dir + ]) + + # Convert PDF to images + images_from_path = await run_blocking(convert_from_path, pdf_path) + prs = await run_blocking(PptxPresentation, ppt_path) + + if len(images_from_path) != len(prs.slides): + raise Exception("PPT页数与生成的图片数量不匹配") + + progress_callback(0.2, "Extracting slides...") + + # Generate video segments + video_segments = [] + with tempfile.TemporaryDirectory() as temp_path: + total_slides = len(prs.slides) + for i, (slide, image) in enumerate(zip(prs.slides, images_from_path)): + slide_progress = 0.3 + (i / total_slides) * 0.4 + progress_callback(slide_progress, f"Processing slide {i + 1}/{total_slides}") + + # Get notes + notes = "" + if slide.has_notes_slide: + notes = slide.notes_slide.notes_text_frame.text + if not notes.strip(): + notes = f"This is slide {i + 1}" + + # Save image + image_path = pjoin(temp_path, f"frame_{i}.jpg") + image.save(image_path) + + # Generate audio + audio_path = pjoin(temp_path, f"frame_{i}.wav") + await generate_tts_audio(notes, audio_path) + + # Create video segment + video_segment_path = await create_video_segment( + image_path, audio_path, temp_path, i + ) + video_segments.append(video_segment_path) + + progress_callback(0.8, "Merging video segments...") + + # Merge video segments + output_video_path = pjoin(task_dir, "output.mp4") + await merge_video_segments(video_segments, output_video_path) + + progress_callback(1.0, "Video generation completed!") + + ppt_video_progress_store[task_id]["status"] = "completed" + return output_video_path + + except Exception as e: + logger.error(f"PPT2Video processing failed {task_id}: {e}") + ppt_video_progress_store[task_id]["status"] = "failed" + return None + + +async def generate_tts_audio(text: str, output_path: 
str): + """Generate TTS audio""" + try: + # Try to use MegaTTS3 if available + sys.path.append(pjoin(os.path.dirname(__file__), "MegaTTS3")) + from tts.infer_cli import MegaTTS3DiTInfer + from tts.utils.audio_utils.io import save_wav + + infer = MegaTTS3DiTInfer(ckpt_root=pjoin(os.path.dirname(__file__), "MegaTTS3", "checkpoints")) + + prompt_audio_path = pjoin(os.path.dirname(__file__), "MegaTTS3", "assets", "English_prompt.wav") + + with open(prompt_audio_path, 'rb') as f: + audio_bytes = f.read() + latent_file = None + potential_npy = os.path.splitext(prompt_audio_path)[0] + '.npy' + if os.path.isfile(potential_npy): + latent_file = potential_npy + resource_context = infer.preprocess(audio_bytes, latent_file) + + wav_bytes = infer.forward( + resource_context, + text, + time_step=32, + p_w=1.6, + t_w=2.5 + ) + + save_wav(wav_bytes, output_path) + + except Exception as e: + logger.error(f"TTS failed: {str(e)}") + # Fallback: create silent audio + import numpy as np + import wave + + sample_rate = 22050 + duration = 3.0 + samples = np.zeros(int(sample_rate * duration), dtype=np.int16) + + with wave.open(output_path, 'w') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(sample_rate) + wav_file.writeframes(samples.tobytes()) + + +async def create_video_segment(image_path: str, audio_path: str, temp_path: str, index: int): + """Create video segment from image and audio""" + output_path = pjoin(temp_path, f"segment_{index}.mp4") + await run_cmd([ + "ffmpeg", "-y", "-loop", "1", "-i", image_path, "-i", audio_path, + "-vf", "scale=1920:1080", "-c:v", "libx264", "-tune", "stillimage", + "-c:a", "aac", "-b:a", "192k", "-pix_fmt", "yuv420p", "-shortest", + output_path + ]) + return output_path + + +async def merge_video_segments(video_segments: list[str], output_path: str): + """Merge video segments""" + list_file_path = output_path.replace('.mp4', '_list.txt') + with open(list_file_path, "w") as f: + for seg in video_segments: + f.write(f"file '{seg}'\n") + + await run_cmd([ + "ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file_path, + "-c", "copy", output_path + ]) + os.remove(list_file_path) + + +def cleanup_temp_files(): + """Clean up old temporary files""" + try: + import glob + import time + + # Remove files older than 1 hour + cutoff_time = time.time() - 3600 + for file_path in glob.glob(os.path.join(GRADIO_TEMP_DIR, "*")): + if os.path.getctime(file_path) < cutoff_time: + try: + os.remove(file_path) + logger.info(f"Cleaned up old temp file: {file_path}") + except Exception as e: + logger.warning(f"Failed to remove temp file {file_path}: {e}") + except Exception as e: + logger.warning(f"Cleanup failed: {e}") + + +# Gradio interface +def create_gradio_interface(): + """Create Gradio interface""" + + with gr.Blocks(title="PresentAgent", theme=gr.themes.Soft()) as demo: + gr.Markdown("# PresentAgent - PowerPoint Generation and Presentation Creation") + + with gr.Tabs(): + # Model Configuration Tab + with gr.TabItem("🔧 Configuration"): + gr.Markdown("## Model Configuration") + gr.Markdown( + "Configure your API settings and model parameters before using the PPT generation features.") + + with gr.Row(): + with gr.Column(): + api_key_input = gr.Textbox( + label="API Key", + type="password", + placeholder="Enter your OpenAI API key", + value="" + ) + api_base_input = gr.Textbox( + label="API Base URL", + placeholder="https://api.openai.com/v1", + value="" + ) + + with gr.Row(): + language_model_input = gr.Textbox( + label="Language Model", + 
placeholder="Model for text generation", + value="gpt-4o" + ) + vision_model_input = gr.Textbox( + label="Vision Model", + placeholder="Model for image processing", + value="gpt-4o" + ) + text_model_input = gr.Textbox( + label="Text Embedding Model", + placeholder="Model for text embeddings", + value="text-embedding-3-small" + ) + + init_btn = gr.Button("Initialize Models", variant="primary", size="lg") + + with gr.Column(): + init_status = gr.Textbox( + label="Initialization Status", + interactive=False, + lines=3 + ) + + gr.Markdown(""" + ### Instructions: + 1. Enter your API key and base URL + 2. Configure model names (defaults are recommended) + 3. Click "Initialize Models" to test the connection + 4. Once initialized, you can use the PPT generation features + """) + + init_btn.click( + fn=init_models, + inputs=[api_key_input, api_base_input, language_model_input, vision_model_input, text_model_input], + outputs=[init_status] + ) + + with gr.TabItem("📊 PPT Generation"): + gr.Markdown("## Generate PowerPoint from Template and PDF") + + with gr.Row(): + with gr.Column(): + # 模板选择区域 + template_radio, template_preview, custom_upload, selected_template_display = create_template_selection() + + # PDF输入 + pdf_input = gr.File( + label="PDF Document", + file_types=[".pdf"], + type="filepath" + ) + + # 页数选择 + num_pages_input = gr.Slider( + minimum=1, + maximum=50, + value=10, + step=1, + label="Number of Slides" + ) + + generate_btn = gr.Button("Generate PPT", variant="primary", size="lg") + + with gr.Column(): + ppt_output = gr.File(label="Generated PPT") + ppt_status = gr.Textbox(label="Status", interactive=False, lines=3) + + # 绑定模板选择事件 + template_radio.change( + fn=select_template, + inputs=[template_radio], + outputs=[custom_upload, selected_template_display, gr.State()] + ) + + # 绑定生成按钮 + generate_btn.click( + fn=generate_ppt_with_template_selection, + inputs=[template_radio, custom_upload, pdf_input, num_pages_input], + outputs=[ppt_output, ppt_status] + ) + + # PPT to Video Tab + with gr.TabItem("🎬 PPT to Presentation"): + gr.Markdown("## Convert PowerPoint to Video Presentation") + + with gr.Row(): + with gr.Column(): + ppt_video_input = gr.File( + label="PowerPoint File", + file_types=[".pptx"], + type="filepath" + ) + + video_btn = gr.Button("Convert to Video", variant="primary", size="lg") + + with gr.Column(): + video_output = gr.File(label="Generated Video") + video_status = gr.Textbox(label="Status", interactive=False, lines=3) + + video_btn.click( + fn=ppt_to_video, + inputs=[ppt_video_input], + outputs=[video_output, video_status] + ) + + return demo + + +def setup_template_directories(): + """设置模板目录结构""" + base_dir = os.path.dirname(__file__) + template_dir = os.path.join(base_dir, "templates") + preview_dir = os.path.join(template_dir, "previews") + + os.makedirs(template_dir, exist_ok=True) + os.makedirs(preview_dir, exist_ok=True) + + logger.info(f"Template directories created at: {template_dir}") + logger.info("Please place your default template files in the templates directory") + logger.info("Please place corresponding preview images in the templates/previews directory") + + +# Main function +if __name__ == "__main__": + # Create runs directory + os.makedirs(RUNS_DIR, exist_ok=True) + os.makedirs(pjoin(RUNS_DIR, "feedback"), exist_ok=True) + + setup_template_directories() + + # Clean up old temp files + cleanup_temp_files() + + # Create and launch Gradio interface + demo = create_gradio_interface() + + # Launch with allowed paths + demo.queue().launch( + 
server_name="0.0.0.0", + server_port=7860, + share=True, + show_error=True, + allowed_paths=[RUNS_DIR] + ) \ No newline at end of file diff --git a/pptagent/__init__.py b/pptagent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d222c7182fabb7be5468645f5d8a37e3a0ad8a0e --- /dev/null +++ b/pptagent/__init__.py @@ -0,0 +1,56 @@ +"""PPTAgent: Generating and Evaluating Presentations Beyond Text-to-Slides. + +This package provides tools to automatically generate presentations from documents, +following a two-phase approach of Analysis and Generation. + +For more information, visit: https://github.com/icip-cas/PPTAgent +""" + +__version__ = "0.1.0" +__author__ = "Hao Zheng" +__email__ = "wszh712811@gmail.com" + + +# Check the version of python and python-pptx +import sys + +if sys.version_info < (3, 11): + raise ImportError("You should use Python 3.11 or higher for this project.") + +from packaging.version import Version +from pptx import __version__ as PPTXVersion + +try: + PPTXVersion, Mark = PPTXVersion.split("+") + assert Version(PPTXVersion) >= Version("1.0.4") and Mark == "PPTAgent" +except: + raise ImportError( + "You should install the customized `python-pptx` for this project: Force1ess/python-pptx, but got %s." + % PPTXVersion + ) + +# Import main modules to make them directly accessible when importing the package +from .agent import * +from .apis import * +from .document import * +from .induct import * +from .llms import * +from .model_utils import * +from .multimodal import * +from .pptgen import * +from .presentation import * +from .utils import * + +# Define the top-level exports +__all__ = [ + "agent", + "pptgen", + "document", + "llms", + "presentation", + "utils", + "apis", + "model_utils", + "multimodal", + "induct", +] diff --git a/pptagent/agent.py b/pptagent/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..abab2d4ead3d7bb183713458ad69e923f63b7af6 --- /dev/null +++ b/pptagent/agent.py @@ -0,0 +1,401 @@ +from dataclasses import asdict, dataclass +from functools import partial +from math import ceil +from typing import Optional + +import tiktoken +import yaml +from jinja2 import Environment, StrictUndefined, Template +from PIL import Image +from torch import Tensor, cosine_similarity + +from pptagent.llms import LLM, AsyncLLM +from pptagent.utils import get_json_from_response, package_join + +ENCODING = tiktoken.encoding_for_model("gpt-4o") + + +@dataclass +class Turn: + """ + A class to represent a turn in a conversation. + """ + + id: int + prompt: str + response: str + message: list + retry: int = -1 + images: list[str] = None + input_tokens: int = 0 + output_tokens: int = 0 + embedding: Tensor = None + + def to_dict(self): + return {k: v for k, v in asdict(self).items() if k != "embedding"} + + def calc_token(self): + """ + Calculate the number of tokens for the turn. + """ + if self.images is not None: + self.input_tokens += calc_image_tokens(self.images) + self.input_tokens += len(ENCODING.encode(self.prompt)) + self.output_tokens = len(ENCODING.encode(self.response)) + + def __eq__(self, other): + return self is other + + +class Agent: + """ + An agent, defined by its instruction template and model. + """ + + def __init__( + self, + name: str, + llm_mapping: dict[str, LLM | AsyncLLM], + text_model: Optional[LLM | AsyncLLM] = None, + record_cost: bool = False, + config: Optional[dict] = None, + env: Optional[Environment] = None, + ): + """ + Initialize the Agent. + + Args: + name (str): The name of the role. 
+ env (Environment): The Jinja2 environment. + record_cost (bool): Whether to record the token cost. + llm (LLM): The language model. + config (dict): The configuration. + text_model (LLM): The text embedding model. + """ + self.name = name + self.config = config + if self.config is None: + with open(package_join("roles", f"{name}.yaml"), encoding="utf-8") as f: + self.config = yaml.safe_load(f) + assert isinstance(self.config, dict), "Agent config must be a dict" + self.llm_mapping = llm_mapping + self.llm = self.llm_mapping[self.config["use_model"]] + self.model = self.llm.model + self.record_cost = record_cost + self.text_model = text_model + self.return_json = self.config.get("return_json", False) + self.system_message = self.config["system_prompt"] + self.prompt_args = set(self.config["jinja_args"]) + self.env = env + if self.env is None: + self.env = Environment(undefined=StrictUndefined) + self.template = self.env.from_string(self.config["template"]) + self.retry_template = Template( + """The previous output is invalid, please carefully analyze the traceback and feedback information, correct errors happened before. + feedback: + {{feedback}} + traceback: + {{traceback}} + Give your corrected output in the same format without including the previous output: + """ + ) + self.input_tokens = 0 + self.output_tokens = 0 + self._history: list[Turn] = [] + run_args = self.config.get("run_args", {}) + self.llm.__call__ = partial(self.llm.__call__, **run_args) + self.system_tokens = len(ENCODING.encode(self.system_message)) + + def calc_cost(self, turns: list[Turn]): + """ + Calculate the cost of a list of turns. + """ + for turn in turns[:-1]: + self.input_tokens += turn.input_tokens + self.input_tokens += turn.output_tokens + self.input_tokens += turns[-1].input_tokens + self.output_tokens += turns[-1].output_tokens + self.input_tokens += self.system_tokens + + def get_history(self, similar: int, recent: int, prompt: str): + """ + Get the conversation history. + """ + history = self._history[-recent:] if recent > 0 else [] + if similar > 0: + assert isinstance(self.text_model, LLM), "text_model must be a LLM" + embedding = self.text_model.get_embedding(prompt) + history.sort(key=lambda x: cosine_similarity(embedding, x.embedding)) + for turn in history: + if len(history) > similar + recent: + break + if turn not in history: + history.append(turn) + history.sort(key=lambda x: x.id) + return history + + def retry(self, feedback: str, traceback: str, turn_id: int, error_idx: int): + """ + Retry a failed turn with feedback and traceback. + """ + assert error_idx > 0, "error_idx must be greater than 0" + prompt = self.retry_template.render(feedback=feedback, traceback=traceback) + history = [t for t in self._history if t.id == turn_id] + history_msg = [] + for turn in history: + history_msg.extend(turn.message) + response, message = self.llm( + prompt, + history=history_msg, + return_message=True, + ) + turn = Turn( + id=turn_id, + prompt=prompt, + response=response, + message=message, + retry=error_idx, + ) + return self.__post_process__(response, history, turn) + + def to_sync(self): + """ + Convert the agent to a synchronous agent. + """ + return Agent( + self.name, + self.llm_mapping, + self.text_model, + self.record_cost, + self.config, + self.env, + ) + + def to_async(self): + """ + Convert the agent to an asynchronous agent. 
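+
+ Note: the returned AsyncAgent shares this agent's config, llm_mapping and text_model, but starts with an empty conversation history.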
+ """ + return AsyncAgent( + self.name, + self.llm_mapping, + self.text_model, + self.record_cost, + self.config, + self.env, + ) + + @property + def next_turn_id(self): + if len(self._history) == 0: + return 0 + return max(t.id for t in self._history) + 1 + + @property + def history(self): + return sorted(self._history, key=lambda x: (x.id, x.retry)) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name}, model={self.model})" + + def __call__( + self, + images: list[str] = None, + recent: int = 0, + similar: int = 0, + **jinja_args, + ): + """ + Call the agent with prompt arguments. + + Args: + images (list[str]): A list of image file paths. + recent (int): The number of recent turns to include. + similar (int): The number of similar turns to include. + **jinja_args: Additional arguments for the Jinja2 template. + + Returns: + The response from the role. + """ + if isinstance(images, str): + images = [images] + assert self.prompt_args == set( + jinja_args.keys() + ), f"Invalid arguments, expected: {self.prompt_args}, got: {jinja_args.keys()}" + prompt = self.template.render(**jinja_args) + history = self.get_history(similar, recent, prompt) + history_msg = [] + for turn in history: + history_msg.extend(turn.message) + + response, message = self.llm( + prompt, + system_message=self.system_message, + history=history_msg, + images=images, + return_message=True, + ) + turn = Turn( + id=self.next_turn_id, + prompt=prompt, + response=response, + message=message, + images=images, + ) + return turn.id, self.__post_process__(response, history, turn, similar) + + def __post_process__( + self, response: str, history: list[Turn], turn: Turn, similar: int = 0 + ) -> str | dict: + """ + Post-process the response from the agent. + """ + self._history.append(turn) + if similar > 0: + turn.embedding = self.text_model.get_embedding(turn.prompt) + if self.record_cost: + turn.calc_token() + self.calc_cost(history + [turn]) + if self.return_json: + response = get_json_from_response(response) + return response + + +class AsyncAgent(Agent): + """ + An agent, defined by its instruction template and model. + """ + + def __init__( + self, + name: str, + llm_mapping: dict[str, AsyncLLM], + text_model: Optional[AsyncLLM] = None, + record_cost: bool = False, + config: Optional[dict] = None, + env: Optional[Environment] = None, + ): + super().__init__(name, llm_mapping, text_model, record_cost, config, env) + self.llm = self.llm.to_async() + + async def retry(self, feedback: str, traceback: str, turn_id: int, error_idx: int): + """ + Retry a failed turn with feedback and traceback. + """ + assert error_idx > 0, "error_idx must be greater than 0" + prompt = self.retry_template.render(feedback=feedback, traceback=traceback) + history = [t for t in self._history if t.id == turn_id] + history_msg = [] + for turn in history: + history_msg.extend(turn.message) + response, message = await self.llm( + prompt, + history=history_msg, + return_message=True, + ) + turn = Turn( + id=turn_id, + prompt=prompt, + response=response, + message=message, + retry=error_idx, + ) + return await self.__post_process__(response, history, turn) + + async def __call__( + self, + images: list[str] = None, + recent: int = 0, + similar: int = 0, + **jinja_args, + ): + """ + Call the agent with prompt arguments. + + Args: + images (list[str]): A list of image file paths. + recent (int): The number of recent turns to include. + similar (int): The number of similar turns to include. 
+ **jinja_args: Additional arguments for the Jinja2 template. + + Returns: + The response from the role. + """ + if isinstance(images, str): + images = [images] + assert self.prompt_args == set( + jinja_args.keys() + ), f"Invalid arguments, expected: {self.prompt_args}, got: {jinja_args.keys()}" + prompt = self.template.render(**jinja_args) + history = await self.get_history(similar, recent, prompt) + history_msg = [] + for turn in history: + history_msg.extend(turn.message) + + response, message = await self.llm( + prompt, + system_message=self.system_message, + history=history_msg, + images=images, + return_message=True, + ) + turn = Turn( + id=self.next_turn_id, + prompt=prompt, + response=response, + message=message, + images=images, + ) + return turn.id, await self.__post_process__(response, history, turn, similar) + + async def get_history(self, similar: int, recent: int, prompt: str): + """ + Get the conversation history. + """ + history = self._history[-recent:] if recent > 0 else [] + if similar > 0: + embedding = await self.text_model.get_embedding(prompt) + history.sort(key=lambda x: cosine_similarity(embedding, x.embedding)) + for turn in history: + if len(history) > similar + recent: + break + if turn not in history: + history.append(turn) + history.sort(key=lambda x: x.id) + return history + + async def __post_process__( + self, response: str, history: list[Turn], turn: Turn, similar: int = 0 + ): + """ + Post-process the response from the agent. + """ + self._history.append(turn) + if similar > 0: + turn.embedding = await self.text_model.get_embedding(turn.prompt) + if self.record_cost: + turn.calc_token() + self.calc_cost(history + [turn]) + if self.return_json: + response = get_json_from_response(response) + return response + + +def calc_image_tokens(images: list[str]): + """ + Calculate the number of tokens for a list of images. + """ + tokens = 0 + for image in images: + with open(image, "rb") as f: + width, height = Image.open(f).size + if width > 1024 or height > 1024: + if width > height: + height = int(height * 1024 / width) + width = 1024 + else: + width = int(width * 1024 / height) + height = 1024 + h = ceil(height / 512) + w = ceil(width / 512) + tokens += 85 + 170 * h * w + return tokens diff --git a/pptagent/apis.py b/pptagent/apis.py new file mode 100644 index 0000000000000000000000000000000000000000..d55f4bfaf6e4362b1fe7c1be5bad8ffdc187cb00 --- /dev/null +++ b/pptagent/apis.py @@ -0,0 +1,549 @@ +import inspect +import os +import re +import traceback +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import Any, Optional, Union + +import PIL +from bs4 import BeautifulSoup +from mistune import HTMLRenderer, create_markdown +from pptx.enum.text import PP_ALIGN +from pptx.oxml import parse_xml +from pptx.shapes.base import BaseShape +from pptx.shapes.graphfrm import GraphicFrame as PPTXGraphicFrame +from pptx.text.text import _Run +from pptx.util import Pt + +from pptagent.document import Document +from pptagent.presentation import Closure, ClosureType, Picture, ShapeElement, SlidePage +from pptagent.utils import get_logger, runs_merge + +logger = get_logger(__name__) +TABLE_REGEX = re.compile(r".*table_[0-9a-fA-F]{4}\.png$") + + +class SlideRenderer(HTMLRenderer): + """ + A renderer that does not render lists. 
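+ Both list() and list_item() return their text unchanged, so markdown bullets are emitted as plain text rather than <ul>/<li> markup.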
+ """ + + def list(self, text: str, ordered: bool, **attrs: Any) -> str: + return text + + def list_item(self, text: str) -> str: + return text + + +markdown = create_markdown(renderer=SlideRenderer(), plugins=["strikethrough"]) + + +class SlideEditError(Exception): + """ + Exception raised when an edit operation fails. + """ + + +@dataclass +class HistoryMark: + """ + Mark the execution status of the API call, comment and a line of code. + """ + + API_CALL_ERROR = "api_call_error" + API_CALL_CORRECT = "api_call_correct" + COMMENT_CORRECT = "comment_correct" + COMMENT_ERROR = "comment_error" + CODE_RUN_ERROR = "code_run_error" + CODE_RUN_CORRECT = "code_run_correct" + + +class CodeExecutor: + """ + Execute code actions and manage API call history, and providing error feedback. + """ + + def __init__(self, retry_times: int): + """ + Initialize the CodeExecutor. + + Args: + retry_times (int): The number of times to retry failed actions. + """ + self.api_history = [] + self.command_history = [] + self.code_history = [] + self.retry_times = retry_times + self.registered_functions = API_TYPES.all_funcs() + self.function_regex = re.compile(r"^[a-z]+_[a-z_]+\(.+\)") + + @classmethod + def get_apis_docs( + cls, + funcs: list[callable], + show_doc: bool = True, + show_return: bool = True, + ignore_keys: Optional[list[str]] = None, + ) -> str: + """ + Get the documentation for a list of API functions. + + Args: + funcs (list[callable]): A list of functions to document. + show_example (bool): Whether to show examples in the documentation. + + Returns: + str: The formatted API documentation. + """ + if ignore_keys is None: + ignore_keys = {"slide", "self", "doc"} + api_doc = [] + for func in funcs: + sig = inspect.signature(func) + params = [] + for name, param in sig.parameters.items(): + if name in ignore_keys: + continue + param_str = name + if param.annotation != inspect.Parameter.empty: + param_str += f": {param.annotation.__name__}" + if param.default != inspect.Parameter.empty: + param_str += f" = {repr(param.default)}" + params.append(param_str) + signature = f"def {func.__name__}({', '.join(params)})" + if show_return and sig.return_annotation != inspect.Parameter.empty: + signature += f" -> {sig.return_annotation.__name__}" + if show_doc and inspect.getdoc(func) is not None: + doc = "\t" + inspect.getdoc(func) + else: + doc = "" + signature += f"\n{doc}" + api_doc.append(signature) + return "\n".join(api_doc) + + def execute_actions( + self, + actions: str, + edit_slide: SlidePage, + doc: Document, + found_code: bool = False, + ) -> Union[tuple[str, str], None]: + """ + Execute a series of actions on a slide. + + Args: + actions (str): The actions to execute. + edit_slide (SlidePage): The slide to edit. + found_code (bool): Whether code was found in the actions. + + Returns: + tuple: The API lines and traceback if an error occurs. + None: If no error occurs. + """ + api_calls = actions.strip().split("\n") + self.api_history.append( + [HistoryMark.API_CALL_ERROR, edit_slide.slide_idx, actions] + ) + for line_idx, line in enumerate(api_calls): + try: + if line_idx == len(api_calls) - 1 and not found_code: + raise SlideEditError( + "No code block found in the output, please output the api calls without any prefix." 
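+ # Reached while processing the final line if no executable API call has been
+ # matched yet; the exception is caught below and its traceback is returned
+ # as feedback for a retry.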
+ ) + if line.startswith("def"): + raise SlideEditError("The function definition were not allowed.") + if line.startswith("#"): + if len(self.command_history) != 0: + self.command_history[-1][0] = HistoryMark.COMMENT_CORRECT + self.command_history.append([HistoryMark.COMMENT_ERROR, line, None]) + continue + if not self.function_regex.match(line): + continue + found_code = True + func = line.split("(")[0] + if func not in self.registered_functions: + raise SlideEditError(f"The function {func} is not defined.") + # only one of clone and del can be used in a row + if func.startswith("clone") or func.startswith("del"): + tag = func.split("_")[0] + if ( + self.command_history[-1][-1] is None + or self.command_history[-1][-1] == tag + ): + self.command_history[-1][-1] = tag + else: + raise SlideEditError( + "Invalid command: Both 'clone_paragraph' and 'del_paragraph'/'del_image' are used within a single command. " + "Each command must only perform one of these operations based on the quantity_change." + ) + self.code_history.append([HistoryMark.CODE_RUN_ERROR, line, None]) + partial_func = partial(self.registered_functions[func], edit_slide) + if func == "replace_image": + partial_func = partial(partial_func, doc) + eval(line, {}, {func: partial_func}) + self.code_history[-1][0] = HistoryMark.CODE_RUN_CORRECT + except Exception as e: + if not isinstance(e, SlideEditError): + logger.warning(f"Encountered unknown error: {e}") + + trace_msg = traceback.format_exc() + if len(self.code_history) != 0: + self.code_history[-1][-1] = trace_msg + api_lines = ( + "\n".join(api_calls[: line_idx - 1]) + + f"\n--> Error Line: {line}\n" + + "\n".join(api_calls[line_idx + 1 :]) + ) + return api_lines, trace_msg + if len(self.command_history) != 0: + self.command_history[-1][0] = HistoryMark.COMMENT_CORRECT + self.api_history[-1][0] = HistoryMark.API_CALL_CORRECT + + def __add__(self, other): + self.api_history.extend(other.api_history) + self.command_history.extend(other.command_history) + self.code_history.extend(other.code_history) + return self + + +# supporting functions +def element_index(slide: SlidePage, element_id: int) -> ShapeElement: + """ + Find the an element in a slide. + + Args: + slide (SlidePage): The slide + element_id (int): The ID of the element. + + Returns: + ShapeElement: The shape corresponding to the element ID. + + Raises: + SlideEditError: If the element is not found. + """ + for shape in slide: + if shape.shape_idx == element_id: + return shape + raise SlideEditError( + f"Cannot find element {element_id}, is it deleted or not exist?" 
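+ # element_index is the shared lookup used by every edit API below
+ # (del_paragraph, del_image, replace_paragraph, replace_image, clone_paragraph);
+ # lookup failures surface as SlideEditError so CodeExecutor.execute_actions
+ # can return the traceback as retry feedback.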
+ ) + + +@dataclass +class TextBlock: + text: str + bold: bool = False + italic: bool = False + code: bool = False + strikethrough: bool = False + href: str = None + + def build_run(self, run: _Run): + if self.bold: + run.font.bold = True + if self.italic: + run.font.italic = True + if self.code: + run.font.name = "Consolas" + if self.strikethrough: + run.font.strikethrough = True + if self.href is not None: + run.hyperlink.address = self.href + + run.text = self.text + + +MARKDOWN_STYLES = { + "strong": "bold", + "em": "italic", + "code": "code", + "del": "strikethrough", +} + + +def process_element(element, styles=None) -> list[TextBlock]: + if styles is None: + styles = {} + + result = [] + + if isinstance(element, str): + result.append(TextBlock(element, **styles)) + else: + if element.name == "a": + new_styles = styles.copy() + for child in element.children: + blocks = process_element(child, new_styles) + for block in blocks: + block.href = element.get("href") + result.extend(blocks) + elif MARKDOWN_STYLES.get(element.name): + new_styles = styles.copy() + new_styles[MARKDOWN_STYLES[element.name]] = True + for child in element.children: + result.extend(process_element(child, new_styles)) + else: + for child in element.children: + result.extend(process_element(child, styles)) + + return result + + +def replace_para(paragraph_id: int, new_text: str, shape: BaseShape): + """ + Replace the text of a paragraph in a shape. + """ + para = shape.text_frame.paragraphs[paragraph_id] + html = markdown(new_text).strip() + soup = BeautifulSoup(html, "html.parser") + blocks = process_element(soup) + + empty_run = runs_merge(para) + empty_run.text = "" + for _ in range(len(blocks) - 1): + empty_run._r.addnext(parse_xml(empty_run._r.xml)) + for block, run in zip(blocks, para.runs): + block.build_run(run) + + +def clone_para(paragraph_id: int, shape: BaseShape): + """ + Clone a paragraph in a shape. + """ + para = shape.text_frame.paragraphs[paragraph_id] + shape.text_frame.paragraphs[-1]._element.addnext(parse_xml(para._element.xml)) + + +def del_para(paragraph_id: int, shape: BaseShape): + """ + Delete a paragraph from a shape. + """ + para = shape.text_frame.paragraphs[paragraph_id] + para._element.getparent().remove(para._element) + + +def add_table(table_data: list[list[str]], table: PPTXGraphicFrame): + rows = len(table_data) + cols = len(table_data[0]) + + max_lengths = [max(len(row[j]) for row in table_data) for j in range(cols)] + total_length = sum(max_lengths) + for j in range(cols): + col_width = int((max_lengths[j] / total_length) * table.width) + table.table.columns[j].width = col_width + + for i in range(rows): + for j in range(cols): + table.table.cell(i, j).text = table_data[i][j] + + +def merge_cells(merge_area: list[tuple[int, int, int, int]], table: PPTXGraphicFrame): + if merge_area is None or len(merge_area) == 0: + return + for y1, x1, y2, x2 in merge_area: + try: + table.table.cell(x1, y1).merge(table.table.cell(x2, y2)) + for x, y in zip(range(x1, x2 + 1), range(y1, y2 + 1)): + tf = table.table.cell(x, y).text_frame + for p in tf.paragraphs: + p.alignment = PP_ALIGN.CENTER + except Exception as e: + logger.warning(f"Failed to merge cells: {e}") + + +# api functions +def del_paragraph(slide: SlidePage, div_id: int, paragraph_id: int): + """ + Delete a paragraph from a slide. + + Args: + slide (SlidePage): The slide containing the paragraph. + div_id (int): The ID of the division containing the paragraph. + paragraph_id (int): The ID of the paragraph to delete. 
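+
+ The paragraph is removed from the in-memory text frame immediately, and the corresponding edit to the underlying pptx shape is recorded as a Closure under ClosureType.DELETE.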
+ + Raises: + SlideEditError: If the paragraph is not found. + """ + shape = element_index(slide, div_id) + if not shape.text_frame.is_textframe: + raise SlideEditError( + f"The element {shape.shape_idx} of slide {slide.slide_idx} does not have a text frame, please check the element id and type of element." + ) + for para in shape.text_frame.paragraphs: + if para.idx == paragraph_id: + shape.text_frame.paragraphs.remove(para) + shape._closures[ClosureType.DELETE].append( + Closure(partial(del_para, para.real_idx), para.real_idx) + ) + return + else: + raise SlideEditError( + f"Cannot find the paragraph {paragraph_id} of the element {div_id}," + "may refer to a non-existed paragraph or you haven't cloned enough paragraphs beforehand." + ) + + +def del_image(slide: SlidePage, figure_id: int): + """ + Delete an image from a slide. + + Args: + slide (SlidePage): The slide containing the image. + figure_id (int): The ID of the image to delete. + """ + shape = element_index(slide, figure_id) + if not isinstance(shape, Picture): + raise SlideEditError( + f"The element {shape.shape_idx} of slide {slide.slide_idx} is not a Picture." + ) + slide.shapes.remove(shape) + + +def replace_paragraph(slide: SlidePage, div_id: int, paragraph_id: int, text: str): + """ + Replace the text of a paragraph in a slide. + + Args: + slide (SlidePage): The slide containing the paragraph. + div_id (int): The ID of the division containing the paragraph. + paragraph_id (int): The ID of the paragraph to replace. + text (str): The new text to replace with. + + Raises: + SlideEditError: If the paragraph is not found. + """ + shape = element_index(slide, div_id) + if not shape.text_frame.is_textframe: + raise SlideEditError( + f"The element {shape.shape_idx} of slide {slide.slide_idx} does not have a text frame, please check the element id and type of element." + ) + for para in shape.text_frame.paragraphs: + if para.idx == paragraph_id: + para.text = text + shape._closures[ClosureType.REPLACE].append( + Closure( + partial(replace_para, para.real_idx, text), + para.real_idx, + ) + ) + return + else: + raise SlideEditError( + f"Cannot find the paragraph {paragraph_id} of the element {div_id}," + "Please: " + "1. check if you refer to a non-existed paragraph." + "2. check if you already deleted it." + ) + + +def replace_image(slide: SlidePage, doc: Document, img_id: int, image_path: str): + """ + Replace an image in a slide. + + Args: + slide (SlidePage): The slide containing the image. + img_id (int): The ID of the image to replace. + image_path (str): The path to the new image. + + Raises: + SlideEditError: If the image path does not exist. + """ + if not os.path.exists(image_path): + raise SlideEditError( + f"The image {image_path} does not exist, consider use del_image if image_path in the given command is faked" + ) + shape = element_index(slide, img_id) + if not isinstance(shape, Picture): + raise SlideEditError( + f"The element {shape.shape_idx} of slide {slide.slide_idx} is not a Picture." + ) + + try: + if TABLE_REGEX.match(image_path): + return replace_image_with_table(shape, doc, image_path) + except Exception as e: + logger.warning( + f"Failed to replace image with table element: {e}, fallback to use image directly." 
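+ # Best-effort path: if the table-specific replacement fails for any reason,
+ # execution falls through to the plain image replacement below.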
+ ) + + img_size = PIL.Image.open(image_path).size + r = min(shape.width / img_size[0], shape.height / img_size[1]) + new_width = img_size[0] * r + new_height = img_size[1] * r + shape.top = Pt(shape.top + (shape.height - new_height) / 2) + shape.width = Pt(new_width) + shape.height = Pt(new_height) + shape.img_path = image_path + + +def clone_paragraph(slide: SlidePage, div_id: int, paragraph_id: int): + """ + Clone a paragraph in a slide. + + Args: + slide (SlidePage): The slide containing the paragraph. + div_id (int): The ID of the division containing the paragraph. + paragraph_id (int): The ID of the paragraph to clone. + + Raises: + SlideEditError: If the paragraph is not found. + + Mention: the cloned paragraph will have a paragraph_id one greater than the current maximum in the parent element. + """ + shape = element_index(slide, div_id) + if not shape.text_frame.is_textframe: + raise SlideEditError( + f"The element {shape.shape_idx} of slide {slide.slide_idx} does not have a text frame, please check the element id and type of element." + ) + max_idx = max([para.idx for para in shape.text_frame.paragraphs]) + for para in shape.text_frame.paragraphs: + if para.idx != paragraph_id: + continue + shape.text_frame.paragraphs.append(deepcopy(para)) + shape.text_frame.paragraphs[-1].idx = max_idx + 1 + shape.text_frame.paragraphs[-1].real_idx = len(shape.text_frame.paragraphs) - 1 + shape._closures[ClosureType.CLONE].append( + Closure( + partial(clone_para, para.real_idx), + para.real_idx, + ) + ) + return + raise SlideEditError( + f"Cannot find the paragraph {paragraph_id} of the element {div_id}, may refer to a non-existed paragraph." + ) + + +def replace_image_with_table(shape: Picture, doc: Document, image_path: str): + table = doc.get_table(image_path) + shape.is_table = True + shape.grid = (len(table.cells), len(table.cells[0])) + shape._closures[ClosureType.REPLACE].append( + Closure(partial(add_table, table.cells)) + ) + shape._closures[ClosureType.MERGE].append( + Closure(partial(merge_cells, table.merge_area)) + ) + return + + +class API_TYPES(Enum): + Agent = [ + replace_image, + del_image, + clone_paragraph, + replace_paragraph, + del_paragraph, + ] + + @classmethod + def all_funcs(cls) -> dict[str, callable]: + funcs = {} + for attr in dir(cls): + if attr.startswith("__"): + continue + funcs |= {func.__name__: func for func in getattr(cls, attr).value} + return funcs diff --git a/pptagent/document/__init__.py b/pptagent/document/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2a4aceae9a4e7e86fa2407f677c36859483c30 --- /dev/null +++ b/pptagent/document/__init__.py @@ -0,0 +1,11 @@ +from .document import Document, OutlineItem +from .element import Media, Section, SubSection, Table + +__all__ = [ + "Document", + "OutlineItem", + "Media", + "Section", + "SubSection", + "Table", +] diff --git a/pptagent/document/__pycache__/__init__.cpython-312.pyc b/pptagent/document/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60191e2c99501570c440dcd44242fd44db64472d Binary files /dev/null and b/pptagent/document/__pycache__/__init__.cpython-312.pyc differ diff --git a/pptagent/document/__pycache__/document.cpython-312.pyc b/pptagent/document/__pycache__/document.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ae6119dae97277b34074da508ae23694d50c2f5 Binary files /dev/null and b/pptagent/document/__pycache__/document.cpython-312.pyc differ diff --git 
a/pptagent/document/__pycache__/element.cpython-312.pyc b/pptagent/document/__pycache__/element.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb21c502d064d162524ba87a172c88c44cd604b6 Binary files /dev/null and b/pptagent/document/__pycache__/element.cpython-312.pyc differ diff --git a/pptagent/document/document.py b/pptagent/document/document.py new file mode 100644 index 0000000000000000000000000000000000000000..441e518efaa5f282761e36803cf5e1c7d263b689 --- /dev/null +++ b/pptagent/document/document.py @@ -0,0 +1,548 @@ +import asyncio +import re +import traceback +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Any, Optional + +from jinja2 import Environment, StrictUndefined +from torch import cosine_similarity + +from pptagent.agent import Agent, AsyncAgent +from pptagent.llms import LLM, AsyncLLM +from pptagent.utils import edit_distance, get_logger, package_join, pexists + +from .element import Section, SubSection, Table, link_medias + +logger = get_logger(__name__) + +env = Environment(undefined=StrictUndefined) + +MERGE_METADATA_PROMPT = env.from_string( + open(package_join("prompts", "merge_metadata.txt")).read() +) +HEADING_EXTRACT_PROMPT = env.from_string( + open(package_join("prompts", "heading_extract.txt")).read() +) +SECTION_SUMMARY_PROMPT = env.from_string( + open(package_join("prompts", "section_summary.txt")).read() +) + +MARKDOWN_IMAGE_REGEX = re.compile(r"!\[.*\]\(.*\)") +MARKDOWN_TABLE_REGEX = re.compile(r"\|.*\|") + + +def split_markdown_by_headings( + markdown_content: str, + headings: list[str], + adjusted_headings: list[str], + min_chunk_size: int = 64, +) -> list[str]: + """ + Split markdown content using headings as separators without regex. 
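+ Each entry of `adjusted_headings` is first snapped back to the closest literal heading by edit-distance similarity, and chunks shorter than `min_chunk_size` are merged with a neighbouring chunk.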
+ + Args: + markdown_content (str): The markdown content to split + headings (list[str]): List of heading strings to split by + + Returns: + list[str]: List of content sections + """ + adjusted_headings = [ + max(headings, key=lambda x: edit_distance(x, ah)) for ah in adjusted_headings + ] + sections = [] + current_section = [] + + for line in markdown_content.splitlines(): + if any(line.strip().startswith(h) for h in adjusted_headings): + if len(current_section) != 0: + sections.append("\n".join(current_section).strip()) + current_section = [line] + else: + current_section.append(line) + + if len(current_section) != 0: + sections.append("\n".join(current_section).strip()) + + # if an chunk is too small, merge it with the previous chunk + for i in reversed(range(1, len(sections))): + if len(sections[i]) < min_chunk_size: + sections[i - 1] += sections[i] + sections.pop(i) + + if len(sections[0]) < min_chunk_size: + sections[0] += sections[1] + sections.pop(1) + + return sections + + +def to_paragraphs(original_text: str, max_chunk_size: int = 256): + paragraphs = [] + medias = [] + for i, para in enumerate(original_text.split("\n\n")): + para = para.strip() + if not para: + continue + paragraph = {"markdown_content": para, "index": i} + if MARKDOWN_TABLE_REGEX.match(para): + paragraph["type"] = "table" + medias.append(paragraph) + elif MARKDOWN_IMAGE_REGEX.match(para): + paragraph["type"] = "image" + medias.append(paragraph) + else: + paragraphs.append(paragraph) + for media in medias: + pre_chunk = "" + after_chunk = "" + for chunk in reversed(paragraphs): + if chunk["index"] < media["index"]: + pre_chunk += chunk["markdown_content"] + "\n\n" + if len(pre_chunk) > max_chunk_size: + break + for chunk in paragraphs: + if chunk["index"] > media["index"]: + after_chunk += chunk["markdown_content"] + "\n\n" + if len(after_chunk) > max_chunk_size: + break + media["near_chunks"] = (pre_chunk, after_chunk) + return medias + + +@dataclass +class Document: + image_dir: str + sections: list[Section] + metadata: dict[str, str] + + def __post_init__(self): + self.metadata["presentation-date"] = datetime.now().strftime("%Y-%m-%d") + + def iter_medias(self): + for section in self.sections: + yield from section.iter_medias() + + def get_table(self, image_path: str): + for media in self.iter_medias(): + if media.path == image_path and isinstance(media, Table): + return media + raise ValueError(f"table not found: {image_path}") + + @classmethod + def from_dict( + cls, data: dict[str, Any], image_dir: str, require_caption: bool = True + ): + assert ( + "sections" in data + ), f"'sections' key is required in data dictionary but was not found. Input keys: {list(data.keys())}" + assert ( + "metadata" in data + ), f"'metadata' key is required in data dictionary but was not found. 
Input keys: {list(data.keys())}" + assert pexists(image_dir), f"image directory is not found: {image_dir}" + document = cls( + image_dir=image_dir, + sections=[Section.from_dict(section) for section in data["sections"]], + metadata=data["metadata"], + ) + for section in document.sections: + section.validate_medias(image_dir, require_caption) + return document + + @classmethod + def _parse_chunk( + cls, + extractor: Agent, + language_model: LLM, + vision_model: LLM, + table_model: LLM, + metadata: Optional[dict[str, Any]], + section: Optional[dict[str, Any]], + image_dir: str, + turn_id: int = None, + retry: int = 0, + medias: Optional[list[dict]] = None, + ): + if retry == 0: + medias = to_paragraphs(section) + turn_id, section = extractor(markdown_document=section) + metadata = section.pop("metadata", {}) + try: + section["subsections"] = link_medias(medias, section["subsections"]) + section = Section.from_dict(section) + for media in section.iter_medias(): + media.parse(table_model, image_dir) + if isinstance(media, Table): + media.get_caption(language_model) + else: + media.get_caption(vision_model) + section.validate_medias(image_dir, False) + except Exception as e: + if retry < 3: + logger.info("Retry section with error: %s", str(e)) + new_section = extractor.retry( + str(e), traceback.format_exc(), turn_id, retry + 1 + ) + return cls._parse_chunk( + extractor, + language_model, + vision_model, + table_model, + metadata, + new_section, + image_dir, + turn_id, + retry + 1, + medias, + ) + else: + logger.error( + "Failed to extract section, tried %d times", + retry, + exc_info=e, + ) + raise e + return metadata, section + + @classmethod + async def _parse_chunk_async( + cls, + extractor: AsyncAgent, + language_model: AsyncLLM, + vision_model: AsyncLLM, + table_model: Optional[AsyncLLM], + metadata: Optional[dict[str, Any]], + section: Optional[dict[str, Any]], + image_dir: str, + turn_id: int = None, + retry: int = 0, + medias: Optional[list[dict]] = None, + ): + if retry == 0: + medias = to_paragraphs(section) + turn_id, section = await extractor(markdown_document=section) + metadata = section.pop("metadata", {}) + try: + section["subsections"] = link_medias(medias, section["subsections"]) + section = Section.from_dict(section) + for media in section.iter_medias(): + await media.parse_async(table_model, image_dir) + if isinstance(media, Table): + await media.get_caption_async(language_model) + else: + await media.get_caption_async(vision_model) + section.validate_medias(image_dir, False) + except Exception as e: + if retry < 3: + logger.info("Retry section with error: %s", str(e)) + new_section = await extractor.retry( + str(e), traceback.format_exc(), turn_id, retry + 1 + ) + return await cls._parse_chunk_async( + extractor, + language_model, + vision_model, + table_model, + metadata, + new_section, + image_dir, + turn_id, + retry + 1, + medias, + ) + else: + logger.error( + "Failed to extract section, tried %d times", + retry, + exc_info=e, + ) + raise e + return metadata, section + + @classmethod + def from_markdown( + cls, + markdown_content: str, + language_model: LLM, + vision_model: LLM, + image_dir: str, + table_model: Optional[LLM] = None, + ): + """ + Create a Document from markdown content. + + Args: + markdown_content (str): The markdown content. + language_model (LLM): The language model. + vision_model (LLM): The vision model. + image_dir (str): The directory containing images. + + Returns: + Document: The created document. 
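+
+ Note: when `table_model` is None, markdown tables are still rendered to images and their cells extracted, but no merged-cell layout is inferred.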
+ """ + doc_extractor = Agent( + "doc_extractor", + llm_mapping={"language": language_model, "vision": vision_model}, + ) + + metadata_list = [] + sections = [] + + headings = re.findall(r"^#+\s+.*", markdown_content, re.MULTILINE) + adjusted_headings = language_model( + HEADING_EXTRACT_PROMPT.render(headings=headings), return_json=True + ) + + for chunk in split_markdown_by_headings( + markdown_content, headings, adjusted_headings + ): + metadata, section = cls._parse_chunk( + doc_extractor, + language_model, + vision_model, + table_model, + None, + chunk, + image_dir, + ) + section.summary = language_model( + SECTION_SUMMARY_PROMPT.render(section_content=chunk), + ) + metadata_list.append(metadata) + sections.append(section) + + merged_metadata = language_model( + MERGE_METADATA_PROMPT.render(metadata=metadata_list), return_json=True + ) + return Document( + image_dir=image_dir, metadata=merged_metadata, sections=sections + ) + + @classmethod + async def from_markdown_async( + cls, + markdown_content: str, + language_model: AsyncLLM, + vision_model: AsyncLLM, + image_dir: str, + table_model: Optional[AsyncLLM] = None, + ): + doc_extractor = AsyncAgent( + "doc_extractor", + llm_mapping={"language": language_model, "vision": vision_model}, + ) + + headings = re.findall(r"^#+\s+.*", markdown_content, re.MULTILINE) + adjusted_headings = await language_model( + HEADING_EXTRACT_PROMPT.render(headings=headings), return_json=True + ) + metadata = [] + sections = [] + tasks = [] + + async with asyncio.TaskGroup() as tg: + for chunk in split_markdown_by_headings( + markdown_content, headings, adjusted_headings + ): + task1 = tg.create_task( + cls._parse_chunk_async( + doc_extractor, + language_model, + vision_model, + table_model, + None, + chunk, + image_dir, + ) + ) + task2 = tg.create_task( + language_model( + SECTION_SUMMARY_PROMPT.render(section_content=chunk), + ) + ) + tasks.append((task1, task2)) + + # Process results in order + for task1, task2 in tasks: + meta, section = task1.result() + metadata.append(meta) + sections.append(section) + for section in sections: + section.summary = task2.result() + + merged_metadata = await language_model( + MERGE_METADATA_PROMPT.render(metadata=metadata), return_json=True + ) + return Document( + image_dir=image_dir, metadata=merged_metadata, sections=sections + ) + + def __contains__(self, key: str): + for section in self.sections: + if section.title == key: + return True + return False + + def __getitem__(self, key: str): + for section in self.sections: + if section.title == key: + return section + raise KeyError( + f"section not found: {key}, available sections: {[section.title for section in self.sections]}" + ) + + def to_dict(self): + return asdict(self) + + def retrieve( + self, + indexs: dict[str, list[str]], + ) -> list[SubSection]: + assert isinstance( + indexs, dict + ), "subsection_keys for index must be a dict, follow a two-level structure" + subsecs = [] + for sec_key, subsec_keys in indexs.items(): + section = self[sec_key] + for subsec_key in subsec_keys: + subsecs.append(section[subsec_key]) + return subsecs + + def find_caption(self, caption: str): + for media in self.iter_medias(): + if media.caption == caption: + return media.path + raise ValueError(f"Image caption not found: {caption}") + + def get_overview(self, include_summary: bool = False): + overview = "" + for section in self.sections: + overview += f"Section: {section.title}\n" + if include_summary: + overview += f"\tSummary: {section.summary}\n" + for subsection in 
section.subsections: + overview += f"\tSubsection: {subsection.title}\n" + for media in subsection.medias: + overview += f"\t\tMedia: {media.caption}\n" + overview += "\n" + return overview + + @property + def metainfo(self): + return "\n".join([f"{k}: {v}" for k, v in self.metadata.items()]) + + @property + def subsections(self): + return [subsec for section in self.sections for subsec in section.subsections] + + +@dataclass +class OutlineItem: + purpose: str + section: str + indexs: dict[str, list[str]] | str + images: list[str] + + @classmethod + def from_dict(cls, data: dict[str, Any]): + assert ( + "purpose" in data and "section" in data + ), "purpose and section of outline item are required" + return cls( + purpose=data["purpose"], + section=data["section"], + indexs=data.get("indexs", {}), + images=data.get("images", []), + ) + + def retrieve(self, slide_idx: int, document: Document): + subsections = document.retrieve(self.indexs) + header = f"Slide-{slide_idx+1}: {self.purpose}\n" + content = "" + for subsection in subsections: + content += f"Paragraph: {subsection.title}\nContent: {subsection.content}\n" + images = [ + f"Image: {document.find_caption(caption)}\nCaption: {caption}" + for caption in self.images + ] + return header, content, images + + def check_retrieve(self, document: Document, sim_bound: float): + for sec_key, subsec_keys in list(self.indexs.items()): + section = max( + document.sections, key=lambda x: edit_distance(x.title, sec_key) + ) + self.indexs[section.title] = self.indexs.pop(sec_key) + if edit_distance(section.title, sec_key) < sim_bound: + logger.warning( + f"section not found: {sec_key}, available sections: {[section.title for section in document.sections]}.", + ) + raise ValueError( + f"section not found: {sec_key}, available sections: {[section.title for section in document.sections]}." + ) + for idx in range(len(subsec_keys)): + subsection = max( + section.subsections, + key=lambda x: edit_distance(x.title, subsec_keys[idx]), + ) + self.indexs[section.title][idx] = subsection.title + if edit_distance(subsection.title, subsec_keys[idx]) < sim_bound: + raise ValueError( + f"subsection {subsec_keys[idx]} not found in section {section.title}, available subsections: {[subsection.title for subsection in section.subsections]}." + ) + + def check_images(self, document: Document, text_model: LLM, sim_bound: float): + doc_images = list(document.iter_medias()) + image_embeddings = [] + for idx, image in enumerate(self.images): + if len(doc_images) == 0: + raise ValueError("Document does not contain any images.") + similar = max(doc_images, key=lambda x: edit_distance(x.caption, image)) + if edit_distance(similar.caption, image) > sim_bound: + self.images[idx] = similar.caption + continue + if len(image_embeddings) == 0: + image_embeddings.extend( + [text_model.get_embedding(image) for image in self.images] + ) + + embedding = text_model.get_embedding(image) + similar = max( + range(len(image_embeddings)), + key=lambda x: cosine_similarity(embedding, image_embeddings[x]), + ) + if cosine_similarity(embedding, image_embeddings[similar]) > sim_bound: + self.images[idx] = doc_images[similar].caption + else: + logger.warning( + f"image not found: {image}, available images: {[image.caption for image in doc_images]}.", + ) + raise ValueError( + f"image not found: {image}, available images: \n{[image.caption for image in doc_images]}\nPlease ensure the caption is exactly matched." 
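+ # Reached only when neither caption edit-distance nor embedding similarity
+ # clears sim_bound; the message lists the available captions so an exactly
+ # matching caption can be supplied on retry.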
+ ) + + async def check_images_async( + self, document: Document, text_model: AsyncLLM, sim_bound: float + ): + doc_images = list(document.iter_medias()) + image_embeddings = [] + for idx, image in enumerate(self.images): + if len(doc_images) == 0: + raise ValueError("Document does not contain any images.") + similar = max(doc_images, key=lambda x: edit_distance(x.caption, image)) + if edit_distance(similar.caption, image) > sim_bound: + self.images[idx] = similar.caption + continue + if len(image_embeddings) == 0: + image_embeddings = await asyncio.gather( + *[text_model.get_embedding(image) for image in self.images] + ) + + embedding = await text_model.get_embedding(image) + similar = max( + range(len(image_embeddings)), + key=lambda x: cosine_similarity(embedding, image_embeddings[x]), + ) + if cosine_similarity(embedding, image_embeddings[similar]) > sim_bound: + self.images[idx] = doc_images[similar].caption diff --git a/pptagent/document/element.py b/pptagent/document/element.py new file mode 100644 index 0000000000000000000000000000000000000000..e42ab36d06fe68393c5ae7507eed72227a390394 --- /dev/null +++ b/pptagent/document/element.py @@ -0,0 +1,322 @@ +import hashlib +import re +from dataclasses import dataclass +from typing import Any, Optional, abstractmethod + +from bs4 import BeautifulSoup +from jinja2 import Environment, StrictUndefined +from mistune import html as markdown +from PIL import Image + +from pptagent.llms import LLM, AsyncLLM +from pptagent.utils import ( + edit_distance, + get_logger, + markdown_table_to_image, + package_join, + pbasename, + pexists, + pjoin, +) + +env = Environment(undefined=StrictUndefined) + +IMAGE_PARSING_REGEX = re.compile(r"\((.*?)\)") +TABLE_PARSING_PROMPT = env.from_string( + open(package_join("prompts", "table_parsing.txt")).read() +) +TABLE_CAPTION_PROMPT = env.from_string( + open(package_join("prompts", "markdown_table_caption.txt")).read() +) +IMAGE_CAPTION_PROMPT = env.from_string( + open(package_join("prompts", "markdown_image_caption.txt")).read() +) + +logger = get_logger(__name__) + + +@dataclass +class Media: + markdown_content: str + near_chunks: tuple[str, str] + path: Optional[str] = None + caption: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict[str, Any]): + assert ( + "markdown_content" in data and "near_chunks" in data + ), f"'markdown_content' and 'near_chunks' keys are required in data dictionary but were not found. Input keys: {list(data.keys())}" + return cls( + markdown_content=data["markdown_content"], + near_chunks=data["near_chunks"], + path=data.get("path", None), + caption=data.get("caption", None), + ) + + @property + def size(self): + assert self.path is not None, "Path is required to get size" + return Image.open(self.path).size + + @abstractmethod + def parse(self, _: Optional[LLM], image_dir: str): + """ + Parse the markdown content to extract image path and alt text. 
+ Format expected: ![alt text](image.png) + """ + match = IMAGE_PARSING_REGEX.search(self.markdown_content) + if match is None: + raise ValueError("No image found in the markdown content") + image_path = match.group(1) + if not pexists(image_path): + image_path = pjoin(image_dir, image_path) + assert pexists(image_path), f"image file not found: {image_path}" + self.path = image_path + + async def parse_async(self, language_model: Optional[AsyncLLM], image_dir: str): + self.parse(language_model, image_dir) + + def get_caption(self, vision_model: LLM): + assert self.path is not None, "Path is required to get caption" + if self.caption is None: + self.caption = vision_model( + IMAGE_CAPTION_PROMPT.render( + markdown_caption=self.near_chunks, + ), + self.path, + ) + logger.debug(f"Caption: {self.caption}") + + async def get_caption_async(self, vision_model: AsyncLLM): + assert self.path is not None, "Path is required to get caption" + if self.caption is None: + self.caption = await vision_model( + IMAGE_CAPTION_PROMPT.render( + markdown_caption=self.near_chunks, + ), + self.path, + ) + logger.debug(f"Caption: {self.caption}") + + +@dataclass +class Table(Media): + cells: Optional[list[list[str]]] = None + merge_area: Optional[list[tuple[int, int, int, int]]] = None + + @classmethod + def from_dict(cls, data: dict[str, Any]): + assert ( + "markdown_content" in data and "near_chunks" in data + ), f"'markdown_content' and 'near_chunks' keys are required in data dictionary but were not found. Input keys: {list(data.keys())}" + return cls( + markdown_content=data["markdown_content"], + near_chunks=data["near_chunks"], + path=data.get("path", None), + caption=data.get("caption", None), + cells=data.get("cells", None), + merge_area=data.get("merge_area", None), + ) + + def parse_table(self, image_dir: str): + html = markdown(self.markdown_content) + soup = BeautifulSoup(html, "html.parser") + table = soup.find("table") + self.cells = [] + for row in table.find_all("tr"): + self.cells.append( + [cell.text for cell in row.find_all("td") + row.find_all("th")] + ) + for i in range(len(self.cells)): + row = self.cells[i] + unstacked = row[0].split("\n") + if len(unstacked) == len(row) and all( + cell.strip() == "" for cell in row[1:] + ): + self.cells[i] = unstacked + + if self.path is None: + self.path = pjoin( + image_dir, + f"table_{hashlib.md5(str(self.cells).encode()).hexdigest()[:4]}.png", + ) + markdown_table_to_image(self.markdown_content, self.path) + + def parse(self, table_model: Optional[LLM], image_dir: str): + self.parse_table(image_dir) + if table_model is None: + return + result = table_model( + TABLE_PARSING_PROMPT.render(cells=self.cells, caption=self.caption), + return_json=True, + ) + self.merge_area = result["merge_area"] + table = [row for row in result["table_data"]] + if ( + all(len(row) == len(table[0]) for row in table) + and len(table) == len(self.cells) + and len(table[0]) == len(self.cells[0]) + ): + self.cells = table + + async def parse_async(self, table_model: Optional[AsyncLLM], image_dir: str): + self.parse_table(image_dir) + if table_model is None: + return + result = await table_model( + TABLE_PARSING_PROMPT.render(cells=self.cells, caption=self.caption), + return_json=True, + ) + self.merge_area = result["merge_area"] + table = [row for row in result["table_data"]] + if ( + all(len(row) == len(table[0]) for row in table) + and len(table) == len(self.cells) + and len(table[0]) == len(self.cells[0]) + ): + self.cells = table + + def get_caption(self, language_model: 
LLM): + if self.caption is None: + self.caption = language_model( + TABLE_CAPTION_PROMPT.render( + markdown_content=self.markdown_content, + markdown_caption=self.near_chunks, + ) + ) + logger.debug(f"Caption: {self.caption}") + + async def get_caption_async(self, language_model: AsyncLLM): + if self.caption is None: + self.caption = await language_model( + TABLE_CAPTION_PROMPT.render( + markdown_content=self.markdown_content, + markdown_caption=self.near_chunks, + ) + ) + logger.debug(f"Caption: {self.caption}") + + +@dataclass +class SubSection: + title: str + content: str + medias: list[Media] + + @classmethod + def from_dict(cls, data: dict[str, Any]): + assert ( + "title" in data and "content" in data + ), f"'title' and 'content' keys are required in data dictionary but were not found. Input keys: {list(data.keys())}" + medias = [] + for chunk in data.get("medias", []): + if ( + chunk.get("type", None) == "table" + or chunk.get("cells", None) is not None + ): + medias.append(Table.from_dict(chunk)) + else: + medias.append(Media.from_dict(chunk)) + return cls( + title=data["title"], + content=data["content"], + medias=medias, + ) + + def iter_medias(self): + yield from self.medias + + +@dataclass +class Section: + title: str + summary: Optional[str] + subsections: list[SubSection] + markdown_content: str + + @classmethod + def from_dict(cls, data: dict[str, Any], markdown_content: str = None): + assert ( + "title" in data and "subsections" in data + ), f"'title' and 'subsections' keys are required in data dictionary but were not found. Input keys: {list(data.keys())}" + subsections = [ + SubSection.from_dict(subsection) for subsection in data["subsections"] + ] + assert len(subsections) != 0, "subsections is empty" + return cls( + title=data["title"], + subsections=subsections, + summary=data.get("summary", None), + markdown_content=data.get("markdown_content", markdown_content), + ) + + def __contains__(self, key: str): + for subsection in self.subsections: + if subsection.title == key: + return True + return False + + def __getitem__(self, key: str): + for subsection in self.subsections: + if subsection.title == key: + return subsection + sim_subsec = max(self.subsections, key=lambda x: edit_distance(x.title, key)) + if edit_distance(sim_subsec.title, key) > 0.8: + return sim_subsec + raise KeyError( + f"subsection not found: {key}, available subsections of {self.title} are: {[subsection.title for subsection in self.subsections]}" + ) + + def iter_medias(self): + for subsection in self.subsections: + yield from subsection.iter_medias() + + def validate_medias(self, image_dir: str, require_caption: bool = True): + for media in self.iter_medias(): + if not pexists(media.path): + basename = pbasename(media.path) + if pexists(pjoin(image_dir, basename)): + media.path = pjoin(image_dir, basename) + else: + raise FileNotFoundError(f"image file not found: {media.path}") + assert ( + media.caption is not None or not require_caption + ), f"caption is required for media: {media.path}" + + +def link_medias( + medias: list[dict], + rewritten_paragraphs: list[dict[str, Any]], + max_chunk_size: int = 256, +) -> dict[str, Any]: + """ + Link media elements to the most relevant paragraphs based on content proximity. 
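+ A media block whose preceding text is shorter than `max_chunk_size` is attached to the first rewritten paragraph; otherwise it is attached to the rewritten paragraph whose content is most similar (by edit distance) to that preceding text.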
+ + Args: + medias: List of media dictionaries (tables, images) + original_paragraphs: List of original paragraph dictionaries + rewritten_paragraphs: List of rewritten paragraph dictionaries + max_chunk_size: Maximum size of text chunk to consider for matching + + Returns: + The rewritten paragraphs with medias linked to appropriate sections + """ + # Process each media element + assert len(rewritten_paragraphs) != 0, "rewritten_paragraphs is empty" + for media in medias: + if len(media["near_chunks"][0]) < max_chunk_size: + link_paragraph = rewritten_paragraphs[0] + else: + link_paragraph = max( + rewritten_paragraphs, + key=lambda x: edit_distance( + media["near_chunks"][0], x.get("markdown_content", "") + ), + ) + + if "medias" not in link_paragraph: + link_paragraph["medias"] = [] + link_paragraph["medias"].append(media) + + return rewritten_paragraphs diff --git a/pptagent/induct.py b/pptagent/induct.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e4ab9eb8d79fa28fffc8fd4ea85c6abb51e0a4 --- /dev/null +++ b/pptagent/induct.py @@ -0,0 +1,430 @@ +import asyncio +import os +import traceback +from collections import defaultdict +from collections.abc import Coroutine +from typing import Any + +from jinja2 import Template + +from pptagent.agent import Agent +from pptagent.llms import LLM, AsyncLLM +from pptagent.model_utils import ( + get_cluster, + get_image_embedding, + images_cosine_similarity, +) +from pptagent.presentation import Picture, Presentation, SlidePage +from pptagent.utils import ( + Config, + edit_distance, + get_logger, + is_image_path, + package_join, + pjoin, +) + +logger = get_logger(__name__) + +CATEGORY_SPLIT_TEMPLATE = Template( + open(package_join("prompts", "category_split.txt")).read() +) +ASK_CATEGORY_PROMPT = open(package_join("prompts", "ask_category.txt")).read() + + +def check_schema(schema: dict | Any, slide: SlidePage): + if not isinstance(schema, dict): + raise ValueError( + f"Output schema should be a dict, but got {type(schema)}: {schema}\n", + """ { + "element_name": { + "description": "purpose of this element", # do not mention any detail, just purpose + "type": "text" or "image", + "data": ["text1", "text2"] or ["logo:...", "logo:..."] + } + }""", + ) + + similar_ele = None + max_similarity = -1 + for el_name, element in schema.items(): + if "data" not in element or len(element["data"]) == 0: + raise ValueError( + f"Empty element is not allowed, but got {el_name}: {element}. Content of each element should be in the `data` field.\n", + "If this infered to an empty or unexpected element, remove it from the schema.", + ) + if not isinstance(element["data"], list): + logger.debug("Converting single text element to list: %s", element["data"]) + element["data"] = [element["data"]] + if element["type"] == "text": + + for item in element["data"]: + for para in slide.iter_paragraphs(): + similarity = edit_distance(para.text, item) + if similarity > 0.8: + break + if similarity > max_similarity: + max_similarity = similarity + similar_ele = para.text + else: + raise ValueError( + f"Text element `{el_name}` contains text `{item}` that does not match any single paragraph
in the current slide. The most similar paragraph found was `{similar_ele}`.", + "This error typically occurs when either: 1) multiple paragraphs are incorrectly merged into a single element, or 2) a single paragraph is incorrectly split into multiple items.", + ) + + elif element["type"] == "image": + + for caption in element["data"]: + for shape in slide.shape_filter(Picture): + similarity = edit_distance(shape.caption, caption) + if similarity > 0.8: + break + if similarity > max_similarity: + max_similarity = similarity + similar_ele = shape.caption + else: + raise ValueError( + f"Image caption of {el_name}: {caption} not found in the `alt` attribute of elements of current slide, the most similar caption is {similar_ele}" + ) + + else: + raise ValueError( + f"Unknown type of {el_name}: {element['type']}, should be one of ['text', 'image']" + ) + + +class SlideInducter: + """ + Stage I: Presentation Analysis. + This stage is to analyze the presentation: cluster slides into different layouts, and extract content schema for each layout. + """ + + def __init__( + self, + prs: Presentation, + ppt_image_folder: str, + template_image_folder: str, + config: Config, + image_models: list, + language_model: LLM, + vision_model: LLM, + use_assert: bool = True, + ): + """ + Initialize the SlideInducter. + + Args: + prs (Presentation): The presentation object. + ppt_image_folder (str): The folder containing PPT images. + template_image_folder (str): The folder containing normalized slide images. + config (Config): The configuration object. + image_models (list): A list of image models. + """ + self.prs = prs + self.config = config + self.ppt_image_folder = ppt_image_folder + self.template_image_folder = template_image_folder + self.language_model = language_model + self.vision_model = vision_model + self.image_models = image_models + self.schema_extractor = Agent( + "schema_extractor", + { + "language": language_model, + }, + ) + if not use_assert: + return + + num_template_images = sum( + is_image_path(f) for f in os.listdir(template_image_folder) + ) + num_ppt_images = sum(is_image_path(f) for f in os.listdir(ppt_image_folder)) + num_slides = len(prs.slides) + + if not (num_template_images == num_ppt_images == num_slides): + raise ValueError( + f"Slide count mismatch detected:\n" + f"- Presentation slides: {num_slides}\n" + f"- Template images: {num_template_images} ({template_image_folder})\n" + f"- PPT images: {num_ppt_images} ({ppt_image_folder})\n" + f"All counts must be equal." + ) + + def layout_induct(self) -> dict: + """ + Perform layout induction for the presentation, should be called before content induction. + Return a dict representing layouts, each layout is a dict with keys: + - key: the layout name, e.g. "Title and Content:text" + - `template_id`: the id of the template slide + - `slides`: the list of slide ids + Moreover, the dict has a key `functional_keys`, which is a list of functional keys. 
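As a rough illustration of the structure described above (layout names and slide ids are invented, not taken from a real run), the returned dict might look like:

    {
        "opening": {"template_id": 1, "slides": [1]},
        "Bulleted Content:text": {"template_id": 4, "slides": [3, 4, 7]},
        "Figure with Caption:image": {"template_id": 5, "slides": [5, 6]},
        "functional_keys": ["opening"],
    }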
+ """ + layout_induction = defaultdict(lambda: defaultdict(list)) + content_slides_index, functional_cluster = self.category_split() + for layout_name, cluster in functional_cluster.items(): + layout_induction[layout_name]["slides"] = cluster + layout_induction[layout_name]["template_id"] = cluster[0] + + functional_keys = list(layout_induction.keys()) + function_slides_index = set() + for layout_name, cluster in layout_induction.items(): + function_slides_index.update(cluster["slides"]) + used_slides_index = function_slides_index.union(content_slides_index) + for i in range(len(self.prs.slides)): + if i + 1 not in used_slides_index: + content_slides_index.add(i + 1) + self.layout_split(content_slides_index, layout_induction) + layout_induction["functional_keys"] = functional_keys + return layout_induction + + def category_split(self): + """ + Split slides into categories based on their functional purpose. + """ + functional_cluster = self.language_model( + CATEGORY_SPLIT_TEMPLATE.render(slides=self.prs.to_text()), + return_json=True, + ) + assert isinstance(functional_cluster, dict) and all( + isinstance(k, str) and isinstance(v, list) + for k, v in functional_cluster.items() + ), "Functional cluster must be a dictionary with string keys and list values" + functional_slides = set(sum(functional_cluster.values(), [])) + content_slides_index = set(range(1, len(self.prs) + 1)) - functional_slides + + return content_slides_index, functional_cluster + + def layout_split(self, content_slides_index: set[int], layout_induction: dict): + """ + Cluster slides into different layouts. + """ + embeddings = get_image_embedding(self.template_image_folder, *self.image_models) + assert len(embeddings) == len(self.prs) + content_split = defaultdict(list) + for slide_idx in content_slides_index: + slide = self.prs.slides[slide_idx - 1] + content_type = slide.get_content_type() + layout_name = slide.slide_layout_name + content_split[(layout_name, content_type)].append(slide_idx) + + for (layout_name, content_type), slides in content_split.items(): + sub_embeddings = [ + embeddings[f"slide_{slide_idx:04d}.jpg"] for slide_idx in slides + ] + similarity = images_cosine_similarity(sub_embeddings) + for cluster in get_cluster(similarity): + slide_indexs = [slides[i] for i in cluster] + template_id = max( + slide_indexs, + key=lambda x: len(self.prs.slides[x - 1].shapes), + ) + cluster_name = ( + self.vision_model( + ASK_CATEGORY_PROMPT, + pjoin(self.ppt_image_folder, f"slide_{template_id:04d}.jpg"), + ) + + ":" + + content_type + ) + layout_induction[cluster_name]["template_id"] = template_id + layout_induction[cluster_name]["slides"] = slide_indexs + + def content_induct(self, layout_induction: dict): + """ + Perform content schema extraction for the presentation. + """ + for layout_name, cluster in layout_induction.items(): + if layout_name == "functional_keys" or "content_schema" in cluster: + continue + slide = self.prs.slides[cluster["template_id"] - 1] + turn_id, schema = self.schema_extractor(slide=slide.to_html()) + schema = self._fix_schema(schema, slide, turn_id) + layout_induction[layout_name]["content_schema"] = schema + + return layout_induction + + def _fix_schema( + self, + schema: dict, + slide: SlidePage, + turn_id: int = None, + retry: int = 0, + ) -> dict: + """ + Fix schema by checking and retrying if necessary. 
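The retry flow here is a plain validate-then-ask-again loop; a minimal standalone sketch of the same pattern, with simplified helper names that are illustrative only:

    def fix_with_retries(candidate, validate, ask_again, max_retries=3):
        """Validate `candidate`; on failure, request a corrected one up to `max_retries` times."""
        for attempt in range(1, max_retries + 1):
            try:
                validate(candidate)            # raises ValueError on a bad schema
                return candidate
            except ValueError as exc:
                if attempt == max_retries:
                    raise                      # give up after the last attempt
                candidate = ask_again(exc, attempt)   # e.g. agent.retry(exc, ..., attempt)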
+ """ + try: + check_schema(schema, slide) + except ValueError as e: + retry += 1 + logger.debug("Failed at schema extraction: %s", e) + if retry == 3: + logger.error("Failed to extract schema for slide-%s: %s", turn_id, e) + raise e + schema = self.schema_extractor.retry( + e, traceback.format_exc(), turn_id, retry + ) + return self._fix_schema(schema, slide, turn_id, retry) + return schema + + +class SlideInducterAsync(SlideInducter): + def __init__( + self, + prs: Presentation, + ppt_image_folder: str, + template_image_folder: str, + config: Config, + image_models: list, + language_model: AsyncLLM, + vision_model: AsyncLLM, + ): + """ + Initialize the SlideInducterAsync with async models. + + Args: + prs (Presentation): The presentation object. + ppt_image_folder (str): The folder containing PPT images. + template_image_folder (str): The folder containing normalized slide images. + config (Config): The configuration object. + image_models (list): A list of image models. + language_model (AsyncLLM): The async language model. + vision_model (AsyncLLM): The async vision model. + """ + super().__init__( + prs, + ppt_image_folder, + template_image_folder, + config, + image_models, + language_model, + vision_model, + ) + self.language_model = self.language_model.to_async() + self.vision_model = self.vision_model.to_async() + self.schema_extractor = self.schema_extractor.to_async() + + async def category_split(self): + """ + Async version: Split slides into categories based on their functional purpose. + """ + functional_cluster = await self.language_model( + CATEGORY_SPLIT_TEMPLATE.render(slides=self.prs.to_text()), + return_json=True, + ) + assert isinstance(functional_cluster, dict) and all( + isinstance(k, str) and isinstance(v, list) + for k, v in functional_cluster.items() + ), "Functional cluster must be a dictionary with string keys and list values" + functional_slides = set(sum(functional_cluster.values(), [])) + content_slides_index = set(range(1, len(self.prs) + 1)) - functional_slides + + return content_slides_index, functional_cluster + + async def layout_split( + self, content_slides_index: set[int], layout_induction: dict + ): + """ + Async version: Cluster slides into different layouts. + """ + embeddings = get_image_embedding(self.template_image_folder, *self.image_models) + assert len(embeddings) == len(self.prs) + content_split = defaultdict(list) + for slide_idx in content_slides_index: + slide = self.prs.slides[slide_idx - 1] + content_type = slide.get_content_type() + layout_name = slide.slide_layout_name + content_split[(layout_name, content_type)].append(slide_idx) + + async with asyncio.TaskGroup() as tg: + for (layout_name, content_type), slides in content_split.items(): + sub_embeddings = [ + embeddings[f"slide_{slide_idx:04d}.jpg"] for slide_idx in slides + ] + similarity = images_cosine_similarity(sub_embeddings) + for cluster in get_cluster(similarity): + slide_indexs = [slides[i] for i in cluster] + template_id = max( + slide_indexs, + key=lambda x: len(self.prs.slides[x - 1].shapes), + ) + + tg.create_task( + self.vision_model( + ASK_CATEGORY_PROMPT, + pjoin( + self.ppt_image_folder, f"slide_{template_id:04d}.jpg" + ), + ) + ).add_done_callback( + lambda f, tid=template_id, sidxs=slide_indexs, ctype=content_type: layout_induction[ + f.result() + ":" + ctype + ].update( + {"template_id": tid, "slides": sidxs} + ) + ) + + async def layout_induct(self): + """ + Async version: Perform layout induction for the presentation. 
+ """ + layout_induction = defaultdict(lambda: defaultdict(list)) + content_slides_index, functional_cluster = await self.category_split() + for layout_name, cluster in functional_cluster.items(): + layout_induction[layout_name]["slides"] = cluster + layout_induction[layout_name]["template_id"] = cluster[0] + + functional_keys = list(layout_induction.keys()) + function_slides_index = set() + for layout_name, cluster in layout_induction.items(): + function_slides_index.update(cluster["slides"]) + used_slides_index = function_slides_index.union(content_slides_index) + for i in range(len(self.prs.slides)): + if i + 1 not in used_slides_index: + content_slides_index.add(i + 1) + await self.layout_split(content_slides_index, layout_induction) + layout_induction["functional_keys"] = functional_keys + return layout_induction + + async def content_induct(self, layout_induction: dict): + """ + Async version: Perform content schema extraction for the presentation. + """ + async with asyncio.TaskGroup() as tg: + for layout_name, cluster in layout_induction.items(): + if layout_name == "functional_keys" or "content_schema" in cluster: + continue + slide = self.prs.slides[cluster["template_id"] - 1] + coro = self.schema_extractor(slide=slide.to_html()) + + tg.create_task(self._fix_schema(coro, slide)).add_done_callback( + lambda f, key=layout_name: layout_induction[key].update( + {"content_schema": f.result()} + ) + ) + + return layout_induction + + async def _fix_schema( + self, + schema: dict | Coroutine[dict, None, None], + slide: SlidePage, + turn_id: int = None, + retry: int = 0, + ): + if retry == 0: + turn_id, schema = await schema + try: + check_schema(schema, slide) + except ValueError as e: + retry += 1 + logger.debug("Failed at schema extraction: %s", e) + if retry == 3: + logger.error("Failed to extract schema for slide-%s: %s", turn_id, e) + raise e + schema = await self.schema_extractor.retry( + e, traceback.format_exc(), turn_id, retry + ) + return await self._fix_schema(schema, slide, turn_id, retry) + return schema diff --git a/pptagent/llms.py b/pptagent/llms.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8e660e29c3eb9075736b92050d502790cc0a2a --- /dev/null +++ b/pptagent/llms.py @@ -0,0 +1,421 @@ +import base64 +import re +import threading +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from oaib import Auto +from openai import AsyncOpenAI, OpenAI +from openai.types.chat import ChatCompletion + +from pptagent.utils import get_json_from_response, get_logger, tenacity_decorator + +logger = get_logger(__name__) + + +@dataclass +class LLM: + """ + A wrapper class to interact with a language model. + """ + + model: str + base_url: Optional[str] = None + api_key: Optional[str] = None + timeout: int = 360 + + def __post_init__(self): + self.client = OpenAI( + base_url=self.base_url, api_key=self.api_key, timeout=self.timeout + ) + + @tenacity_decorator + def __call__( + self, + content: str, + images: Optional[Union[str, list[str]]] = None, + system_message: Optional[str] = None, + history: Optional[list] = None, + return_json: bool = False, + return_message: bool = False, + **client_kwargs, + ) -> Union[str, dict, list, tuple]: + """ + Call the language model with a prompt and optional images. + + Args: + content (str): The prompt content. + images (str or list[str]): An image file path or list of image file paths. + system_message (str): The system message. + history (list): The conversation history. 
+ return_json (bool): Whether to return the response as JSON. + return_message (bool): Whether to return the message. + **client_kwargs: Additional keyword arguments to pass to the client. + + Returns: + Union[str, Dict, List, Tuple]: The response from the model. + """ + if history is None: + history = [] + system, message = self.format_message(content, images, system_message) + try: + completion = self.client.chat.completions.create( + model=self.model, messages=system + history + message, **client_kwargs + ) + except Exception as e: + logger.warning("Error in LLM call: %s", e) + raise e + response = completion.choices[0].message.content + message.append({"role": "assistant", "content": response}) + return self.__post_process__(response, message, return_json, return_message) + + def __post_process__( + self, + response: str, + message: list, + return_json: bool = False, + return_message: bool = False, + ) -> Union[str, dict, tuple]: + """ + Process the response based on return options. + + Args: + response (str): The raw response from the model. + message (List): The message history. + return_json (bool): Whether to return the response as JSON. + return_message (bool): Whether to return the message. + + Returns: + Union[str, Dict, Tuple]: Processed response. + """ + response = response.strip() + if return_json: + response = get_json_from_response(response) + if return_message: + response = (response, message) + return response + + def __repr__(self) -> str: + repr_str = f"{self.__class__.__name__}(model={self.model}" + if self.base_url is not None: + repr_str += f", base_url={self.base_url}" + return repr_str + ")" + + def test_connection(self) -> bool: + """ + Test the connection to the LLM. + + Returns: + bool: True if connection is successful, False otherwise. + """ + try: + self.client.models.list() + return True + except Exception as e: + logger.warning( + "Connection test failed: %s\nLLM: %s: %s, %s", + e, + self.model, + self.base_url, + self.api_key, + ) + return False + + def format_message( + self, + content: str, + images: Optional[Union[str, list[str]]] = None, + system_message: Optional[str] = None, + ) -> tuple[list, list]: + """ + Format messages for OpenAI server call. + + Args: + content (str): The prompt content. + images (str or list[str]): An image file path or list of image file paths. + system_message (str): The system message. + + Returns: + Tuple[List, List]: Formatted system and user messages. + """ + if isinstance(images, str): + images = [images] + if system_message is None: + if content.startswith("You are"): + system_message, content = content.split("\n", 1) + else: + system_message = "You are a helpful assistant" + system = [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + } + ] + message = [{"role": "user", "content": [{"type": "text", "text": content}]}] + if images is not None: + for image in images: + try: + with open(image, "rb") as f: + message[0]["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64.b64encode(f.read()).decode('utf-8')}" + }, + } + ) + except Exception as e: + logger.error("Failed to load image %s: %s", image, e) + return system, message + + def gen_image(self, prompt: str, n: int = 1, **kwargs) -> str: + """ + Generate an image from a prompt. 
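A minimal usage sketch for the image generation helper; the model name is an assumption, and the client must already be configured (OPENAI_API_KEY, optionally base_url):

    import base64

    llm = LLM(model="dall-e-3")   # assumed model; OPENAI_API_KEY must be set
    b64 = llm.gen_image(
        "A clean flow chart of a two-stage pipeline",
        response_format="b64_json",  # forwarded via **kwargs; the async variant passes this itself
    )
    with open("diagram.png", "wb") as f:
        f.write(base64.b64decode(b64))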
+ """ + return ( + self.client.images.generate(model=self.model, prompt=prompt, n=n, **kwargs) + .data[0] + .b64_json + ) + + def get_embedding( + self, + text: str, + encoding_format: str = "float", + to_tensor: bool = True, + **kwargs, + ) -> torch.Tensor | list[float]: + """ + Get the embedding of a text. + """ + result = self.client.embeddings.create( + model=self.model, input=text, encoding_format=encoding_format, **kwargs + ) + embeddings = [embedding.embedding for embedding in result.data] + if to_tensor: + embeddings = torch.tensor(embeddings) + return embeddings + + def to_async(self) -> "AsyncLLM": + """ + Convert the LLM to an asynchronous LLM. + """ + return AsyncLLM( + model=self.model, + base_url=self.base_url, + api_key=self.api_key, + timeout=self.timeout, + ) + + +@dataclass +class AsyncLLM(LLM): + use_batch: bool = False + """ + Asynchronous wrapper class for language model interaction. + """ + + def __post_init__(self): + """ + Initialize the AsyncLLM. + + Args: + model (str): The model name. + base_url (str): The base URL for the API. + api_key (str): API key for authentication. Defaults to environment variable. + """ + self.client = AsyncOpenAI( + base_url=self.base_url, + api_key=self.api_key, + timeout=self.timeout, + ) + if threading.current_thread() == threading.main_thread(): + self.batch = Auto( + base_url=self.base_url, + api_key=self.api_key, + timeout=self.timeout, + loglevel=0, + ) + else: + logger.warning("Auto initialization skipped because it's not the main thread.") + + @tenacity_decorator + async def __call__( + self, + content: str, + images: Optional[Union[str, list[str]]] = None, + system_message: Optional[str] = None, + history: Optional[list] = None, + return_json: bool = False, + return_message: bool = False, + **client_kwargs, + ) -> Union[str, dict, tuple]: + """ + Asynchronously call the language model with a prompt and optional images. + + Args: + content (str): The prompt content. + images (str or list[str]): An image file path or list of image file paths. + system_message (str): The system message. + history (list): The conversation history. + return_json (bool): Whether to return the response as JSON. + return_message (bool): Whether to return the message. + **client_kwargs: Additional keyword arguments to pass to the client. + + Returns: + Union[str, Dict, List, Tuple]: The response from the model. + """ + if self.use_batch and threading.current_thread() is threading.main_thread(): + self.batch = Auto( + base_url=self.base_url, + api_key=self.api_key, + timeout=self.timeout, + loglevel=0, + ) + elif self.use_batch: + logger.warning( + "Warning: AsyncLLM is not running in the main thread, may cause race condition." + ) + if history is None: + history = [] + system, message = self.format_message(content, images, system_message) + try: + if self.use_batch: + await self.batch.add( + "chat.completions.create", + model=self.model, + messages=system + history + message, + **client_kwargs, + ) + completion = await self.batch.run() + if "result" not in completion or len(completion["result"]) != 1: + raise ValueError( + f"The length of completion result should be 1, but got {completion}.\nRace condition may have occurred if multiple values are returned.\nOr, there was an error in the LLM call, use the synchronous version to check." 
+ ) + completion = ChatCompletion(**completion["result"][0]) + else: + completion = await self.client.chat.completions.create( + model=self.model, + messages=system + history + message, + **client_kwargs, + ) + + except Exception as e: + logger.warning("Error in AsyncLLM call: %s", e) + raise e + response = completion.choices[0].message.content + message.append({"role": "assistant", "content": response}) + return self.__post_process__(response, message, return_json, return_message) + + def __getstate__(self): + state = self.__dict__.copy() + state["client"] = None + state["batch"] = None + return state + + def __setstate__(self, state: dict): + self.__dict__.update(state) + self.client = AsyncOpenAI( + base_url=self.base_url, + api_key=self.api_key, + timeout=self.timeout, + ) + self.batch = Auto( + base_url=self.base_url, + api_key=self.api_key, + timeout=self.timeout, + loglevel=0, + ) + + async def test_connection(self) -> bool: + """ + Test the connection to the LLM asynchronously. + + Returns: + bool: True if connection is successful, False otherwise. + """ + try: + await self.client.models.list() + return True + except Exception as e: + logger.warning( + "Async connection test failed: %s\nLLM: %s: %s, %s", + e, + self.model, + self.base_url, + self.api_key, + ) + return False + + async def gen_image(self, prompt: str, n: int = 1, **kwargs) -> str: + """ + Generate an image from a prompt asynchronously. + + Args: + prompt (str): The text prompt to generate an image from. + n (int): Number of images to generate. + **kwargs: Additional keyword arguments for image generation. + + Returns: + str: Base64-encoded image data. + """ + response = await self.client.images.generate( + model=self.model, prompt=prompt, n=n, response_format="b64_json", **kwargs + ) + return response.data[0].b64_json + + async def get_embedding( + self, + text: str, + to_tensor: bool = True, + **kwargs, + ) -> torch.Tensor | list[float]: + """ + Get the embedding of a text asynchronously. + + Args: + text (str): The text to get embeddings for. + **kwargs: Additional keyword arguments. + + Returns: + List[float]: The embedding vector. + """ + response = await self.client.embeddings.create( + model=self.model, + input=text, + encoding_format="float", + **kwargs, + ) + embeddings = [embedding.embedding for embedding in response.data] + if to_tensor: + embeddings = torch.tensor(embeddings) + return embeddings + + def to_sync(self) -> LLM: + """ + Convert the AsyncLLM to a synchronous LLM. + """ + return LLM(model=self.model, base_url=self.base_url, api_key=self.api_key) + + +def get_model_abbr(llms: Union[LLM, list[LLM]]) -> str: + """ + Get abbreviated model names from LLM instances. + + Args: + llms: A single LLM instance or a list of LLM instances. + + Returns: + str: Abbreviated model names joined with '+'. 
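For example (model names chosen purely for illustration, and assuming OPENAI_API_KEY is set so the LLM client can be constructed): the date-style suffix is stripped when the pattern matches, and the fallback returns the full names as soon as any model does not match:

    get_model_abbr(LLM("gpt-4o-2024-08-06"))
    # -> "gpt-4o"
    get_model_abbr([LLM("gpt-4o-2024-08-06"), LLM("deepseek-chat")])
    # -> "gpt-4o-2024-08-06+deepseek-chat"  ("deepseek-chat" has no -NN suffix, so the fallback is used)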
+ """ + # Convert single LLM to list for consistent handling + if isinstance(llms, LLM): + llms = [llms] + + try: + # Attempt to extract model names before version numbers + return "+".join(re.search(r"^(.*?)-\d{2}", llm.model).group(1) for llm in llms) + except Exception: + # Fallback: return full model names if pattern matching fails + return "+".join(llm.model for llm in llms) diff --git a/pptagent/model_utils.py b/pptagent/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84c07b58dcc6d44d5f868cd99bcadd23b6a84b3f --- /dev/null +++ b/pptagent/model_utils.py @@ -0,0 +1,317 @@ +import json +import os +from copy import deepcopy +from typing import Optional + +import numpy as np +import torch +import torchvision.transforms as T +from marker.config.parser import ConfigParser +from marker.converters.pdf import PdfConverter +from marker.models import create_model_dict +from marker.output import text_from_rendered +from PIL import Image +from transformers import AutoModel, AutoProcessor + +from pptagent.llms import LLM, AsyncLLM +from pptagent.presentation import Presentation, SlidePage +from pptagent.utils import get_logger, is_image_path, pjoin + +logger = get_logger(__name__) + + +class ModelManager: + """ + A class to manage models. + """ + + def __init__( + self, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + language_model_name: Optional[str] = None, + vision_model_name: Optional[str] = None, + text_model_name: Optional[str] = None, + ): + """Initialize models from environment variables after instance creation""" + if api_base is None: + api_base = os.environ.get("API_BASE", None) + if api_key is None: + api_key = os.environ.get("OPENAI_API_KEY", None) + if language_model_name is None: + language_model_name = os.environ.get("LANGUAGE_MODEL", "gpt-4.1") + if vision_model_name is None: + vision_model_name = os.environ.get("VISION_MODEL", "gpt-4.1") + if text_model_name is None: + text_model_name = os.environ.get("TEXT_MODEL", "text-embedding-3-small") + self.api_base = api_base + self.api_key = api_key + self._image_model = None + self._marker_model = None + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + self.language_model = AsyncLLM(language_model_name, api_base, api_key=api_key) + self.vision_model = AsyncLLM(vision_model_name, api_base, api_key=api_key) + self.text_model = AsyncLLM(text_model_name, api_base, api_key=api_key) + + @property + def image_model(self): + if self._image_model is None: + self._image_model = get_image_model(device=self.device) + return self._image_model + + @property + def marker_model(self): + if self._marker_model is None: + self._marker_model = create_model_dict( + device=self.device, dtype=torch.float16 + ) + return self._marker_model + + async def test_connections(self) -> bool: + """Test connections for all LLM models + + Returns: + bool: True if all connections are successful, False otherwise + """ + try: + assert await self.language_model.test_connection() + assert await self.vision_model.test_connection() + assert await self.text_model.test_connection() + except: + return False + return True + + +def prs_dedup( + presentation: Presentation, + model: LLM, + threshold: float = 0.8, +) -> list[SlidePage]: + """ + Deduplicate slides in a presentation based on text similarity. + + Args: + presentation (Presentation): The presentation object containing slides. + model: The model used for generating text embeddings. + batchsize (int): The batch size for processing slides. 
+ threshold (float): The similarity threshold for deduplication. + + Returns: + list: A list of removed duplicate slides. + """ + text_embeddings = model.get_embedding([i.to_text() for i in presentation.slides]) + pre_embedding = text_embeddings[0] + slide_idx = 1 + duplicates = [] + while slide_idx < len(presentation): + cur_embedding = text_embeddings[slide_idx] + if torch.cosine_similarity(pre_embedding, cur_embedding, -1) > threshold: + duplicates.append(slide_idx - 1) + slide_idx += 1 + pre_embedding = cur_embedding + return [presentation.slides.pop(i) for i in reversed(duplicates)] + + +def get_image_model(device: str = None): + """ + Initialize and return an image model and its feature extractor. + + Args: + device (str): The device to run the model on. + + Returns: + tuple: A tuple containing the feature extractor and the image model. + """ + model_base = "google/vit-base-patch16-224-in21k" + return ( + AutoProcessor.from_pretrained( + model_base, + torch_dtype=torch.float16, + device_map=device, + use_fast=True, + ), + AutoModel.from_pretrained( + model_base, + torch_dtype=torch.float16, + device_map=device, + ).eval(), + ) + + +def parse_pdf( + pdf_path: str, + output_path: str, + model_lst: list, +) -> str: + """ + Parse a PDF file and extract text and images. + + Args: + pdf_path (str): The path to the PDF file. + output_path (str): The directory to save the extracted content. + model_lst (list): A list of models for processing the PDF. + + Returns: + str: The full text extracted from the PDF. + """ + os.makedirs(output_path, exist_ok=True) + config_parser = ConfigParser( + { + "output_format": "markdown", + } + ) + converter = PdfConverter( + config=config_parser.generate_config_dict(), + artifact_dict=model_lst, + processor_list=config_parser.get_processors(), + renderer=config_parser.get_renderer(), + ) + rendered = converter(pdf_path) + full_text, _, images = text_from_rendered(rendered) + with open(pjoin(output_path, "source.md"), "w+", encoding="utf-8") as f: + f.write(full_text) + for filename, image in images.items(): + image_filepath = os.path.join(output_path, filename) + image.save(image_filepath, "JPEG") + with open(pjoin(output_path, "meta.json"), "w+", encoding="utf-8") as f: + f.write(json.dumps(rendered.metadata, indent=4)) + + return full_text + + +def get_image_embedding( + image_dir: str, extractor, model, batchsize: int = 16 +) -> dict[str, torch.Tensor]: + """ + Generate image embeddings for images in a directory. + + Args: + image_dir (str): The directory containing images. + extractor: The feature extractor for images. + model: The model used for generating embeddings. + batchsize (int): The batch size for processing images. + + Returns: + dict: A dictionary mapping image filenames to their embeddings. 
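A short usage sketch, assuming the directory holds rendered slide images as produced elsewhere in the pipeline (the path is illustrative):

    extractor, vit = get_image_model(device="cuda")            # ViT processor + encoder
    embeddings = get_image_embedding("runs/slide_images", extractor, vit)
    # e.g. embeddings["slide_0001.jpg"] is a flattened last_hidden_state tensor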
+ """ + transform = T.Compose( + [ + T.Resize(int((256 / 224) * extractor.size["height"])), + T.CenterCrop(extractor.size["height"]), + T.ToTensor(), + T.Normalize(mean=extractor.image_mean, std=extractor.image_std), + ] + ) + + inputs = [] + embeddings = [] + images = [i for i in sorted(os.listdir(image_dir)) if is_image_path(i)] + for file in images: + image = Image.open(pjoin(image_dir, file)).convert("RGB") + inputs.append(transform(image)) + if len(inputs) % batchsize == 0 or file == images[-1]: + batch = {"pixel_values": torch.stack(inputs).to(model.device)} + embeddings.extend(model(**batch).last_hidden_state.detach()) + inputs.clear() + return {image: embedding.flatten() for image, embedding in zip(images, embeddings)} + + +def images_cosine_similarity(embeddings: list[torch.Tensor]) -> torch.Tensor: + """ + Calculate the cosine similarity matrix for a list of embeddings. + Args: + embeddings (list[torch.Tensor]): A list of image embeddings. + + Returns: + torch.Tensor: A NxN similarity matrix. + """ + embeddings = [embedding for embedding in embeddings] + sim_matrix = torch.zeros((len(embeddings), len(embeddings))) + for i in range(len(embeddings)): + for j in range(i + 1, len(embeddings)): + sim_matrix[i, j] = sim_matrix[j, i] = torch.cosine_similarity( + embeddings[i], embeddings[j], -1 + ) + return sim_matrix + + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def average_distance( + similarity: torch.Tensor, idx: int, cluster_idx: list[int] +) -> float: + """ + Calculate the average distance between a point (idx) and a cluster (cluster_idx). + + Args: + similarity (np.ndarray): The similarity matrix. + idx (int): The index of the point. + cluster_idx (list): The indices of the cluster. + + Returns: + float: The average distance. + """ + if idx in cluster_idx: + return 0 + total_similarity = 0 + for idx_in_cluster in cluster_idx: + total_similarity += similarity[idx, idx_in_cluster] + return total_similarity / len(cluster_idx) + + +def get_cluster(similarity: np.ndarray, sim_bound: float = 0.65): + """ + Cluster points based on similarity. + + Args: + similarity (np.ndarray): The similarity matrix. + sim_bound (float): The similarity threshold for clustering. + + Returns: + list: A list of clusters. 
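A rough worked example of the greedy clustering (similarity values are made up): the most similar pair seeds a cluster, points whose average similarity to a cluster exceeds `sim_bound` join it, and leftovers become singleton clusters.

    import numpy as np

    sim = np.array([
        [0.0, 0.9, 0.1],
        [0.9, 0.0, 0.1],
        [0.1, 0.1, 0.0],
    ])
    get_cluster(sim, sim_bound=0.65)
    # -> [[0, 1], [2]]   (items 0 and 1 cluster; item 2 is left on its own)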
+ """ + sim_copy = deepcopy(similarity) + num_points = sim_copy.shape[0] + clusters = [] + added = [False] * num_points + + while True: + max_avg_dist = sim_bound + best_cluster = None + best_point = None + + for c in clusters: + for point_idx in range(num_points): + if added[point_idx]: + continue + avg_dist = average_distance(sim_copy, point_idx, c) + if avg_dist > max_avg_dist: + max_avg_dist = avg_dist + best_cluster = c + best_point = point_idx + + if best_point is not None: + best_cluster.append(best_point) + added[best_point] = True + sim_copy[best_point, :] = 0 + sim_copy[:, best_point] = 0 + else: + if sim_copy.max() < sim_bound: + # append the remaining points invididual cluster + for i in range(num_points): + if not added[i]: + clusters.append([i]) + break + i, j = np.unravel_index(np.argmax(sim_copy), sim_copy.shape) + clusters.append([int(i), int(j)]) + added[i] = True + added[j] = True + sim_copy[i, :] = 0 + sim_copy[:, i] = 0 + sim_copy[j, :] = 0 + sim_copy[:, j] = 0 + + return clusters diff --git a/pptagent/multimodal.py b/pptagent/multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..d605bde9a9fb22a41e3f249b135d9981493e0025 --- /dev/null +++ b/pptagent/multimodal.py @@ -0,0 +1,143 @@ +import asyncio +from typing import Optional + +import PIL.Image + +from pptagent.llms import LLM, AsyncLLM +from pptagent.presentation import Picture, Presentation +from pptagent.utils import Config, get_logger, package_join, pbasename, pjoin + +logger = get_logger(__name__) + + +class ImageLabler: + """ + A class to extract images information, including caption, size, and appearance times in a presentation. + """ + + def __init__(self, presentation: Presentation, config: Config): + """ + Initialize the ImageLabler. + + Args: + presentation (Presentation): The presentation object. + config (Config): The configuration object. + """ + self.presentation = presentation + self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt + self.image_stats = {} + self.config = config + self.collect_images() + + def apply_stats(self, image_stats: Optional[dict[str, dict]] = None): + """ + Apply image captions to the presentation. + """ + if image_stats is None: + image_stats = self.image_stats + + for slide in self.presentation.slides: + for shape in slide.shape_filter(Picture): + if shape.caption is None: + caption = image_stats[pbasename(shape.img_path)]["caption"] + shape.caption = max(caption.split("\n"), key=len) + + async def caption_images_async(self, vision_model: AsyncLLM): + """ + Generate captions for images in the presentation asynchronously. + + Args: + vision_model (AsyncLLM): The async vision model to use for captioning. + + Returns: + dict: Dictionary containing image stats with captions. + """ + assert isinstance( + vision_model, AsyncLLM + ), "vision_model must be an AsyncLLM instance" + caption_prompt = open(package_join("prompts", "caption.txt")).read() + + async with asyncio.TaskGroup() as tg: + for image, stats in self.image_stats.items(): + if "caption" not in stats: + task = tg.create_task( + vision_model( + caption_prompt, + pjoin(self.config.IMAGE_DIR, image), + ) + ) + task.add_done_callback( + lambda t, image=image: ( + self.image_stats[image].update({"caption": t.result()}), + logger.debug("captioned %s: %s", image, t.result()), + ) + ) + + self.apply_stats() + return self.image_stats + + def caption_images(self, vision_model: LLM): + """ + Generate captions for images in the presentation. 
+ + Args: + vision_model (LLM): The vision model to use for captioning. + + Returns: + dict: Dictionary containing image stats with captions. + """ + assert isinstance(vision_model, LLM), "vision_model must be an LLM instance" + caption_prompt = open(package_join("prompts", "caption.txt")).read() + for image, stats in self.image_stats.items(): + if "caption" not in stats: + stats["caption"] = vision_model( + caption_prompt, pjoin(self.config.IMAGE_DIR, image) + ) + logger.debug("captioned %s: %s", image, stats["caption"]) + self.apply_stats() + return self.image_stats + + def collect_images(self): + """ + Collect images from the presentation and gather other information. + """ + for slide_index, slide in enumerate(self.presentation.slides): + for shape in slide.shape_filter(Picture): + image_path = pbasename(shape.img_path) + if image_path == "pic_placeholder.png": + continue + if image_path not in self.image_stats: + size = PIL.Image.open(pjoin(self.config.IMAGE_DIR, image_path)).size + self.image_stats[image_path] = { + "size": size, + "appear_times": 0, + "slide_numbers": set(), + "relative_area": shape.area / self.slide_area * 100, + } + self.image_stats[image_path]["appear_times"] += 1 + self.image_stats[image_path]["slide_numbers"].add(slide_index + 1) + for image_path, stats in self.image_stats.items(): + stats["slide_numbers"] = sorted(list(stats["slide_numbers"])) + ranges = self._find_ranges(stats["slide_numbers"]) + top_ranges = sorted(ranges, key=lambda x: x[1] - x[0], reverse=True)[:3] + top_ranges_str = ", ".join( + [f"{r[0]}-{r[1]}" if r[0] != r[1] else f"{r[0]}" for r in top_ranges] + ) + stats["top_ranges_str"] = top_ranges_str + + def _find_ranges(self, numbers): + """ + Find consecutive ranges in a list of numbers. + """ + ranges = [] + start = numbers[0] + end = numbers[0] + for num in numbers[1:]: + if num == end + 1: + end = num + else: + ranges.append((start, end)) + start = num + end = num + ranges.append((start, end)) + return ranges diff --git a/pptagent/pptgen.py b/pptagent/pptgen.py new file mode 100644 index 0000000000000000000000000000000000000000..5007dd13c204b09153eff1186cb5ae4a7b1f1bb7 --- /dev/null +++ b/pptagent/pptgen.py @@ -0,0 +1,963 @@ +import asyncio +import json +import traceback +from abc import ABC, abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from pptagent.agent import Agent +from pptagent.apis import API_TYPES, CodeExecutor +from pptagent.document import Document, OutlineItem +from pptagent.llms import LLM, AsyncLLM +from pptagent.presentation import Layout, Picture, Presentation, SlidePage, StyleArg +from pptagent.utils import Config, edit_distance, get_logger, tenacity_decorator + +logger = get_logger(__name__) + +style = StyleArg.all_true() +style.area = False + + +class FunctionalLayouts(Enum): + OPENING = "opening" + TOC = "table of contents" + SECTION_OUTLINE = "section outline" + ENDING = "ending" + + +FunctionalContent = { + FunctionalLayouts.OPENING.value: "This slide is a presentation opening, presenting available meta information, like title, author, date, etc.", + FunctionalLayouts.TOC.value: "This slide is the Table of Contents, outlining the presentation's sections. 
Please use the given Table of Contents, and remove numbering to generate the slide content.", + FunctionalLayouts.SECTION_OUTLINE.value: "This slide is a section start , briefly presenting the section title, and optionally the section summary.", + FunctionalLayouts.ENDING.value: "This slide is an *ending slide*, simply express your gratitude like 'Thank you!' or '谢谢' as the main title and *do not* include other meta information if not specified.", +} + + +@dataclass +class PPTGen(ABC): + """ + Stage II: Presentation Generation + An abstract base class for generating PowerPoint presentations. + It accepts a reference presentation as input, then generates a presentation outline and slides. + """ + + roles = [] + text_embedder: LLM | AsyncLLM + language_model: LLM | AsyncLLM + vision_model: LLM | AsyncLLM + retry_times: int = 3 + sim_bound: float = 0.5 + force_pages: bool = False + error_exit: bool = False + record_cost: bool = False + length_factor: float | None = None + _initialized: bool = False + + def __post_init__(self): + self._initialized = False + self._hire_staffs(self.record_cost, self.language_model, self.vision_model) + assert ( + self.length_factor is None or self.length_factor > 0 + ), "length_factor must be positive or None" + + def set_reference( + self, + config: Config, + slide_induction: dict, + presentation: Presentation, + hide_small_pic_ratio: Optional[float] = 0.2, + keep_in_background: bool = True, + ): + """ + Set the reference presentation and extracted presentation information. + + Args: + presentation (Presentation): The presentation object. + slide_induction (dict): The slide induction data. + + Returns: + PPTGen: The updated PPTGen object. + """ + self.config = config + self.presentation = presentation + + self.functional_layouts = slide_induction.pop("functional_keys") + self.text_layouts = [ + k + for k in slide_induction + if k.endswith("text") and k not in self.functional_layouts + ] + self.multimodal_layouts = [ + k + for k in slide_induction + if not k.endswith("text") and k not in self.functional_layouts + ] + if len(self.text_layouts) == 0: + self.text_layouts = self.multimodal_layouts + if len(self.multimodal_layouts) == 0: + self.multimodal_layouts = self.text_layouts + + self.layouts = {k: Layout.from_dict(k, v) for k, v in slide_induction.items()} + self.empty_prs = deepcopy(self.presentation) + assert ( + hide_small_pic_ratio is None or hide_small_pic_ratio > 0 + ), "hide_small_pic_ratio must be positive or None" + if hide_small_pic_ratio is not None: + self._hide_small_pics(hide_small_pic_ratio, keep_in_background) + self._initialized = True + return self + + def generate_pres( + self, + source_doc: Document, + num_slides: Optional[int] = None, + outline: Optional[list[OutlineItem]] = None, + ): + """ + Generate a PowerPoint presentation. + + Args: + source_doc (Document): The source document. + num_slides (Optional[int]): The number of slides to generate. + outline (Optional[List[OutlineItem]]): The outline of the presentation. + + Returns: + dict: A dictionary containing the presentation data and history. + + Raise: + ValueError: if failed to generate presentation outline. 
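A condensed usage sketch of the synchronous flow, assuming the reference presentation, slide induction, parsed document, and models have already been prepared as elsewhere in the pipeline (variable names are illustrative):

    agent = PPTAgent(
        text_embedder=text_model,
        language_model=language_model,
        vision_model=vision_model,
    )
    agent.set_reference(config, slide_induction, reference_prs)
    prs, history = agent.generate_pres(source_doc, num_slides=12)
    # prs is a Presentation object on success (None on failure); persistence is handled elsewhere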
+ """ + assert self._initialized, "PPTGen not initialized, call `set_reference` first" + self.source_doc = source_doc + succ_flag = True + if outline is None: + self.outline = self.generate_outline(num_slides, source_doc) + else: + self.outline = outline + self.simple_outline = "\n".join( + [ + f"Slide {slide_idx+1}: {item.purpose}" + for slide_idx, item in enumerate(self.outline) + ] + ) + generated_slides = [] + code_executors = [] + for slide_idx, outline_item in enumerate(self.outline): + if self.force_pages and slide_idx == num_slides: + break + try: + slide, code_executor = self.generate_slide(slide_idx, outline_item) + generated_slides.append(slide) + code_executors.append(code_executor) + except Exception as e: + logger.warning( + "Failed to generate slide, error_exit=%s, error: %s", + self.error_exit, + str(e), + ) + traceback.print_exc() + if self.error_exit: + succ_flag = False + break + + # Collect history data + history = self._collect_history( + sum(code_executors, start=CodeExecutor(self.retry_times)) + ) + + if succ_flag: + self.empty_prs.slides = generated_slides + prs = self.empty_prs + else: + prs = None + + self.empty_prs = deepcopy(self.presentation) + return prs, history + + def generate_outline( + self, + num_slides: int, + source_doc: Document, + ): + """ + Generate an outline for the presentation. + + Args: + num_slides (int): The number of slides to generate. + + Returns: + dict: The generated outline. + """ + assert self._initialized, "PPTGen not initialized, call `set_reference` first" + turn_id, outline = self.staffs["planner"]( + num_slides=num_slides, + document_overview=source_doc.get_overview(), + ) + if num_slides == 1 and isinstance(outline, dict): + outline = [outline] + outline = self._fix_outline(outline, source_doc, turn_id) + return self._add_functional_layouts(outline) + + @abstractmethod + def generate_slide( + self, slide_idx: int, outline_item: OutlineItem + ) -> tuple[SlidePage, CodeExecutor]: + """ + Generate a slide from the outline item. + """ + raise NotImplementedError("Subclass must implement this method") + + def _add_functional_layouts(self, outline: list[OutlineItem]): + """ + Add functional layouts to the outline. 
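The net effect is an outline of the form [opening, table of contents, section outline, ...content..., ending], where each functional slide is inserted only if a sufficiently similar functional layout exists in the template. The gate is the same fuzzy match used elsewhere, with `edit_distance` treated as a similarity score (higher is closer); layout names below are assumptions:

    functional_layouts = ["opening:text", "table of contents:text", "ending:text"]   # assumed names
    best = max(functional_layouts, key=lambda x: edit_distance(x.lower(), "table of contents"))
    # best == "table of contents:text"; the slide is inserted only when
    # edit_distance(best, title) > 0.7, i.e. the match is close enough.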
+ """ + toc = [] + for item in outline: + if item.section not in toc and item.section != "Functional": + toc.append(item.section) + self.toc = "\n".join(toc) + + fixed_functional_slides = [ + (FunctionalLayouts.TOC.value, 0), # toc should be inserted before opening + (FunctionalLayouts.OPENING.value, 0), + (FunctionalLayouts.ENDING.value, 999999), # append to the end + ] + for title, pos in fixed_functional_slides: + layout = max( + self.functional_layouts, + key=lambda x: edit_distance(x.lower(), title), + ) + if edit_distance(layout, title) > 0.7: + outline.insert(pos, OutlineItem(title, "Functional", {}, [])) + + section_outline = max( + self.functional_layouts, + key=lambda x: edit_distance(x, FunctionalLayouts.SECTION_OUTLINE.value), + ) + if ( + not edit_distance(section_outline, FunctionalLayouts.SECTION_OUTLINE.value) + > 0.7 + ): + return outline + full_outline = [] + pre_section = None + for item in outline: + if item.section == "Functional": + full_outline.append(item) + continue + if item.section != pre_section: + new_item = OutlineItem( + FunctionalLayouts.SECTION_OUTLINE.value, + "Functional", + item.section, + [], + ) + full_outline.append(new_item) + full_outline.append(item) + pre_section = item.section + return full_outline + + def _hide_small_pics(self, area_ratio: float, keep_in_background: bool): + for layout in self.layouts.values(): + template_slide = self.presentation.slides[layout.template_id - 1] + pictures = list(template_slide.shape_filter(Picture, return_father=True)) + if len(pictures) == 0: + continue + for father, pic in pictures: + if pic.area / pic.slide_area < area_ratio: + if keep_in_background: + father.shapes.remove(pic) + else: + father.shapes.remove(pic) + father.backgrounds.append(pic) + layout.remove_item(pic.caption.strip()) + + if len(list(template_slide.shape_filter(Picture))) == 0: + logger.debug( + "All pictures in layout %s are too small, set to pure text layout", + layout.title, + ) + layout.title = layout.title.replace(":image", ":text") + + def _fix_outline( + self, outline: list[dict], source_doc: Document, turn_id: int, retry: int = 0 + ) -> list[OutlineItem]: + """ + Validate the generated outline. + + Raises: + ValueError: If the outline is invalid. + """ + try: + outline_items = [ + OutlineItem.from_dict(outline_item) for outline_item in outline + ] + for outline_item in outline_items: + outline_item.check_retrieve(source_doc, self.sim_bound) + outline_item.check_images( + source_doc, self.text_embedder, self.sim_bound + ) + return outline_items + except Exception as e: + retry += 1 + logger.info( + "Failed to generate outline, tried %d/%d times, error: %s", + retry, + self.retry_times, + str(e), + ) + logger.debug(traceback.format_exc()) + if retry < self.retry_times: + new_outline = self.staffs["planner"].retry( + str(e), traceback.format_exc(), turn_id, retry + ) + return self._fix_outline(new_outline, source_doc, turn_id, retry) + else: + raise ValueError("Failed to generate outline, tried too many times") + + def _collect_history(self, code_executor: CodeExecutor): + """ + Collect the history of code execution, API calls and agent steps. + + Returns: + dict: The collected history data. 
+ """ + history = { + "agents": {}, + "code_history": code_executor.code_history, + "api_history": code_executor.api_history, + } + + for role_name, role in self.staffs.items(): + history["agents"][role_name] = role.history + role._history = [] + + return history + + def _hire_staffs( + self, + record_cost: bool, + language_model: LLM | AsyncLLM, + vision_model: LLM | AsyncLLM, + ) -> dict[str, Agent]: + """ + Initialize agent roles and their models + """ + llm_mapping = { + "language": language_model, + "vision": vision_model, + } + self.staffs = { + role: Agent( + role, + record_cost=record_cost, + text_model=self.text_embedder, + llm_mapping=llm_mapping, + ) + for role in ["planner"] + self.roles + } + + +@dataclass +class PPTGenAsync(PPTGen): + """ + Asynchronous base class for generating PowerPoint presentations. + Extends PPTGen with async functionality. + """ + + def __post_init__(self): + super().__post_init__() + for k in list(self.staffs.keys()): + self.staffs[k] = self.staffs[k].to_async() + + async def generate_pres( + self, + source_doc: Document, + num_slides: Optional[int] = None, + outline: Optional[list[OutlineItem]] = None, + ): + """ + Asynchronously generate a PowerPoint presentation. + """ + assert ( + self._initialized + ), "AsyncPPTAgent not initialized, call `set_reference` first" + self.source_doc = source_doc + succ_flag = True + if outline is None: + self.outline = await self.generate_outline(num_slides, source_doc) + else: + self.outline = outline + self.simple_outline = "\n".join( + [ + f"Slide {slide_idx+1}: {item.purpose}" + for slide_idx, item in enumerate(self.outline) + ] + ) + + slide_tasks = [] + for slide_idx, outline_item in enumerate(self.outline): + if self.force_pages and slide_idx == num_slides: + break + slide_tasks.append(self.generate_slide(slide_idx, outline_item)) + + slide_results = await asyncio.gather(*slide_tasks, return_exceptions=True) + + generated_slides = [] + code_executors = [] + for result in slide_results: + if isinstance(result, Exception): + if self.error_exit: + succ_flag = False + break + continue + if result is not None: + slide, code_executor = result + generated_slides.append(slide) + code_executors.append(code_executor) + + history = self._collect_history( + sum(code_executors, start=CodeExecutor(self.retry_times)) + ) + + if succ_flag: + self.empty_prs.slides = generated_slides + prs = self.empty_prs + else: + prs = None + + self.empty_prs = deepcopy(self.presentation) + return prs, history + + async def generate_outline( + self, + num_slides: int, + source_doc: Document, + ): + """ + Asynchronously generate an outline for the presentation. + """ + assert ( + self._initialized + ), "AsyncPPTAgent not initialized, call `set_reference` first" + + turn_id, outline = await self.staffs["planner"]( + num_slides=num_slides, + document_overview=source_doc.get_overview(), + ) + if num_slides == 1 and isinstance(outline, dict): + outline = [outline] + outline = await self._fix_outline(outline, source_doc, turn_id) + return self._add_functional_layouts(outline) + + @abstractmethod + async def generate_slide( + self, slide_idx: int, outline_item: OutlineItem + ) -> tuple[SlidePage, CodeExecutor]: + """ + Asynchronously generate a slide from the outline item. + """ + raise NotImplementedError("Subclass must implement this method") + + async def _fix_outline( + self, outline: list[dict], source_doc: Document, turn_id: int, retry: int = 0 + ) -> list[OutlineItem]: + """ + Asynchronously validate the generated outline. 
+ """ + try: + outline_items = [ + OutlineItem.from_dict(outline_item) for outline_item in outline + ] + async with asyncio.TaskGroup() as tg: + for outline_item in outline_items: + outline_item.check_retrieve(source_doc, self.sim_bound) + tg.create_task( + outline_item.check_images_async( + source_doc, self.text_embedder, self.sim_bound + ) + ) + return outline_items + except Exception as e: + retry += 1 + logger.info( + "Failed to generate outline, tried %d/%d times, error: %s", + retry, + self.retry_times, + str(e), + ) + logger.debug(traceback.format_exc()) + if retry < self.retry_times: + new_outline = await self.staffs["planner"].retry( + str(e), traceback.format_exc(), turn_id, retry + ) + return await self._fix_outline(new_outline, source_doc, turn_id, retry) + else: + raise ValueError("Failed to generate outline, tried too many times") + + +class PPTAgent(PPTGen): + """ + A class to generate PowerPoint presentations with a crew of agents. + """ + + roles: list[str] = [ + "editor", + "coder", + "content_organizer", + "layout_selector", + "notes_generator", + ] + + def generate_slide( + self, slide_idx: int, outline_item: OutlineItem + ) -> tuple[SlidePage, CodeExecutor]: + """ + Generate a slide from the outline item. + """ + if outline_item.section == "Functional": + layout = self.layouts[ + max( + self.functional_layouts, + key=lambda x: edit_distance(x, outline_item.purpose), + ) + ] + slide_desc = FunctionalContent[outline_item.purpose] + if outline_item.purpose == FunctionalLayouts.SECTION_OUTLINE.value: + outline_item.purpose = f"Section Outline of {outline_item.indexs}" + outline_item.indexs = {} + slide_content = ( + "Overview of the Document:\n" + + self.source_doc.get_overview(include_summary=True) + ) + elif outline_item.purpose == FunctionalLayouts.TOC.value: + slide_content = "Table of Contents:\n" + self.toc + else: + slide_content = "This slide is a functional layout, please follow the slide description and content schema to generate the slide content." + header, _, _ = outline_item.retrieve(slide_idx, self.source_doc) + header += slide_desc + else: + layout, header, slide_content = self._select_layout(slide_idx, outline_item) + command_list, template_id = self._generate_content( + layout, slide_content, header + ) + notes = self._generate_notes(slide_content, header) + slide, code_executor = self._edit_slide(command_list, template_id, notes) + slide.slide_notes = self._generate_notes(slide_content, header) + return slide, code_executor + + @tenacity_decorator + def _select_layout( + self, slide_idx: int, outline_item: OutlineItem + ) -> tuple[Layout, str, str]: + """ + Select a layout for the slide. 
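Layout selection trusts the model's free-text choice only up to fuzzy matching; a small sketch of that step, with illustrative layout keys and `edit_distance` acting as a normalized similarity:

    available = ["Title and Bullets:text", "Figure with Caption:image"]   # illustrative keys
    choice = "title and bullets"          # free-text pick returned by the layout_selector
    picked = max(available, key=lambda k: edit_distance(k, choice))
    # picked == "Title and Bullets:text" — the closest known layout, even when the
    # model's wording does not exactly match a key in self.layouts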
+ """ + header, content_source, images = outline_item.retrieve( + slide_idx, self.source_doc + ) + if len(content_source) == 0: + key_points = [] + else: + _, key_points = self.staffs["content_organizer"]( + content_source=content_source + ) + slide_content = json.dumps(key_points, indent=2, ensure_ascii=False) + layouts = self.text_layouts + if len(images) > 0: + slide_content += "\nImages:\n" + "\n".join(images) + layouts = self.multimodal_layouts + + _, layout_selection = self.staffs["layout_selector"]( + outline=self.simple_outline, + slide_description=header, + slide_content=slide_content, + available_layouts=layouts, + ) + layout = max( + self.layouts.keys(), + key=lambda x: edit_distance(x, layout_selection["layout"]), + ) + if "image" in layout and len(images) == 0: + logger.debug( + f"An image layout: {layout} is selected, but no images are provided, please check the parsed document and outline item:\n {outline_item}" + ) + elif "image" not in layout and len(images) > 0: + logger.debug( + f"A pure text layout: {layout} is selected, but images are provided, please check the parsed document and outline item:\n {outline_item}\n Set images to empty list." + ) + slide_content = slide_content[: slide_content.rfind("\nImages:\n")] + return self.layouts[layout], header, slide_content + + def _generate_content( + self, + layout: Layout, + slide_content: str, + slide_description: str, + ) -> tuple[list, int]: + """ + Synergize Agents to generate a slide. + + Args: + layout (Layout): The layout data. + slide_content (str): The slide content. + slide_description (str): The description of the slide. + + Returns: + tuple[list, int]: The generated command list and template id. + """ + turn_id, editor_output = self.staffs["editor"]( + outline=self.simple_outline, + metadata=self.source_doc.metainfo, + slide_description=slide_description, + slide_content=slide_content, + schema=layout.content_schema, + ) + command_list, template_id = self._generate_commands( + editor_output, layout, turn_id + ) + return command_list, template_id + + def _generate_notes( + self, + slide_content: str, + slide_description: str, + ) -> str: + """ + Generate speaker notes for a slide. 
+ """ + _, notes = self.staffs["notes_generator"]( + slide_content=slide_content, + slide_description=slide_description, + ) + return notes + def _edit_slide( + self, command_list: list, template_id: int, notes: str + ) -> tuple[SlidePage, CodeExecutor]: + code_executor = CodeExecutor(self.retry_times) + turn_id, edit_actions = self.staffs["coder"]( + api_docs=code_executor.get_apis_docs(API_TYPES.Agent.value), + edit_target=self.presentation.slides[template_id - 1].to_html(), + command_list="\n".join([str(i) for i in command_list]), + ) + for error_idx in range(self.retry_times): + edit_slide: SlidePage = deepcopy(self.presentation.slides[template_id - 1]) + feedback = code_executor.execute_actions( + edit_actions, edit_slide, self.source_doc + ) + if feedback is None: + break + logger.info( + "Failed to generate slide, tried %d/%d times, error: %s", + error_idx + 1, + self.retry_times, + str(feedback[1]), + ) + logger.debug(traceback.format_exc()) + if error_idx == self.retry_times: + raise Exception( + f"Failed to generate slide, tried too many times at editing\ntraceback: {feedback[1]}" + ) + edit_actions = self.staffs["coder"].retry( + feedback[0], feedback[1], turn_id, error_idx + 1 + ) + self.empty_prs.build_slide(edit_slide) + return edit_slide, code_executor + + def _generate_commands( + self, editor_output: dict, layout: Layout, turn_id: int, retry: int = 0 + ): + """ + Generate commands for editing the slide content. + """ + command_list = [] + try: + layout.validate(editor_output, self.source_doc.image_dir) + if self.length_factor is not None: + layout.validate_length( + editor_output, self.length_factor, self.language_model + ) + old_data = layout.get_old_data(editor_output) + template_id = layout.get_slide_id(editor_output) + except Exception as e: + if retry < self.retry_times: + new_output = self.staffs["editor"].retry( + e, + traceback.format_exc(), + turn_id, + retry + 1, + ) + return self._generate_commands(new_output, layout, turn_id, retry + 1) + else: + raise Exception( + f"Failed to generate commands, tried too many times at editing\ntraceback: {e}" + ) + + for el_name, old_content in old_data.items(): + if not isinstance(old_content, list): + old_content = [old_content] + + new_content = editor_output.get(el_name, {"data": []})["data"] + if not isinstance(new_content, list): + new_content = [new_content] + new_content = [i for i in new_content if i] + quantity_change = len(new_content) - len(old_content) + command_list.append( + ( + el_name, + layout[el_name].el_type, + f"quantity_change: {quantity_change}", + old_content, + new_content, + ) + ) + + assert len(command_list) > 0, "No commands generated" + return command_list, template_id + + +class PPTAgentAsync(PPTGenAsync): + """ + Asynchronous version of PPTAgent that uses AsyncAgent for concurrent processing. + """ + + roles: list[str] = [ + "editor", + "coder", + "content_organizer", + "layout_selector", + "notes_generator", + ] + + async def generate_slide( + self, slide_idx: int, outline_item: OutlineItem + ) -> tuple[SlidePage, CodeExecutor]: + """ + Asynchronously generate a slide from the outline item. 
+ """ + if outline_item.section == "Functional": + layout = self.layouts[ + max( + self.functional_layouts, + key=lambda x: edit_distance(x.lower(), outline_item.purpose), + ) + ] + slide_desc = FunctionalContent[outline_item.purpose] + if outline_item.purpose == FunctionalLayouts.SECTION_OUTLINE.value: + outline_item.purpose = f"Section Outline of {outline_item.indexs}" + outline_item.indexs = {} + slide_content = ( + "Overview of the Document:\n" + + self.source_doc.get_overview(include_summary=True) + ) + elif outline_item.purpose == FunctionalLayouts.TOC.value: + slide_content = "Table of Contents:\n" + self.toc + else: + slide_content = "This slide is a functional layout, please follow the slide description and content schema to generate the slide content." + header, _, _ = outline_item.retrieve(slide_idx, self.source_doc) + header += slide_desc + else: + layout, header, slide_content = await self._select_layout( + slide_idx, outline_item + ) + try: + command_list, template_id = await self._generate_content( + layout, slide_content, header + ) + notes = await self._generate_notes(slide_content, header) + slide, code_executor = await self._edit_slide(command_list, template_id, notes) + slide.slide_notes = await self._generate_notes(slide_content, header) + except Exception as e: + logger.error(f"Failed to generate slide {slide_idx}, error: {e}") + traceback.print_exc() + raise e + return slide, code_executor + + @tenacity_decorator + async def _select_layout( + self, slide_idx: int, outline_item: OutlineItem + ) -> tuple[Layout, str, str]: + """ + Asynchronously select a layout for the slide. + """ + header, content_source, images = outline_item.retrieve( + slide_idx, self.source_doc + ) + if len(content_source) == 0: + key_points = [] + else: + _, key_points = await self.staffs["content_organizer"]( + content_source=content_source + ) + slide_content = json.dumps(key_points, indent=2, ensure_ascii=False) + layouts = self.text_layouts + if len(images) > 0: + slide_content += "\nImages:\n" + "\n".join(images) + layouts = self.multimodal_layouts + + _, layout_selection = await self.staffs["layout_selector"]( + outline=self.simple_outline, + slide_description=header, + slide_content=slide_content, + available_layouts=layouts, + ) + layout = max( + self.layouts.keys(), + key=lambda x: edit_distance(x, layout_selection["layout"]), + ) + if "image" in layout and len(images) == 0: + logger.debug( + f"An image layout: {layout} is selected, but no images are provided, please check the parsed document and outline item:\n {outline_item}" + ) + elif "image" not in layout and len(images) > 0: + logger.debug( + f"A pure text layout: {layout} is selected, but images are provided, please check the parsed document and outline item:\n {outline_item}\n Set images to empty list." + ) + slide_content = slide_content[: slide_content.rfind("\nImages:\n")] + return self.layouts[layout], header, slide_content + + async def _generate_content( + self, + layout: Layout, + slide_content: str, + slide_description: str, + ) -> tuple[list, int]: + """ + Synergize Agents to generate a slide. + + Args: + layout (Layout): The layout data. + slide_content (str): The slide content. + slide_description (str): The description of the slide. + + Returns: + tuple[list, int]: The generated command list and template id. 
+ """ + turn_id, editor_output = await self.staffs["editor"]( + outline=self.simple_outline, + metadata=self.source_doc.metainfo, + slide_description=slide_description, + slide_content=slide_content, + schema=layout.content_schema, + ) + command_list, template_id = await self._generate_commands( + editor_output, layout, turn_id + ) + return command_list, template_id + + async def _generate_notes( + self, + slide_content: str, + slide_description: str, + ) -> str: + """ + Generate speaker notes for a slide. + """ + _, notes = await self.staffs["notes_generator"]( + slide_content=slide_content, + slide_description=slide_description, + ) + return notes + + async def _edit_slide( + self, command_list: list, template_id: int, notes: str + ) -> tuple[SlidePage, CodeExecutor]: + """ + Asynchronously edit the slide. + """ + code_executor = CodeExecutor(self.retry_times) + turn_id, edit_actions = await self.staffs["coder"]( + api_docs=code_executor.get_apis_docs(API_TYPES.Agent.value), + edit_target=self.presentation.slides[template_id - 1].to_html(), + command_list="\n".join([str(i) for i in command_list]), + ) + + for error_idx in range(self.retry_times): + edit_slide: SlidePage = deepcopy(self.presentation.slides[template_id - 1]) + feedback = code_executor.execute_actions( + edit_actions, edit_slide, self.source_doc + ) + if feedback is None: + break + logger.info( + "Failed to generate slide, tried %d/%d times, error: %s", + error_idx + 1, + self.retry_times, + str(feedback[1]), + ) + if error_idx == self.retry_times: + raise Exception( + f"Failed to generate slide, tried too many times at editing\ntraceback: {feedback[1]}" + ) + edit_actions = await self.staffs["coder"].retry( + feedback[0], feedback[1], turn_id, error_idx + 1 + ) + self.empty_prs.build_slide(edit_slide) + return edit_slide, code_executor + + async def _generate_commands( + self, editor_output: dict, layout: Layout, turn_id: int, retry: int = 0 + ): + """ + Asynchronously generate commands for editing the slide content. + + Args: + editor_output (dict): The editor output. + layout (Layout): The layout object containing content schema. + turn_id (int): The turn ID for retrying. + retry (int, optional): The number of retries. Defaults to 0. + + Returns: + list: A list of commands. + + Raises: + Exception: If command generation fails. 
+ """ + command_list = [] + try: + layout.validate(editor_output, self.source_doc.image_dir) + if self.length_factor is not None: + await layout.validate_length_async( + editor_output, self.length_factor, self.language_model + ) + old_data = layout.get_old_data(editor_output) + template_id = layout.get_slide_id(editor_output) + except Exception as e: + if retry < self.retry_times: + new_output = await self.staffs["editor"].retry( + e, + traceback.format_exc(), + turn_id, + retry + 1, + ) + return await self._generate_commands( + new_output, layout, turn_id, retry + 1 + ) + else: + raise Exception( + f"Failed to generate commands, tried too many times at editing\ntraceback: {e}" + ) + + for el_name, old_content in old_data.items(): + if not isinstance(old_content, list): + old_content = [old_content] + + new_content = editor_output.get(el_name, {"data": []})["data"] + if not isinstance(new_content, list): + new_content = [new_content] + new_content = [i for i in new_content if i] + quantity_change = len(new_content) - len(old_content) + command_list.append( + ( + el_name, + layout[el_name].el_type, + f"quantity_change: {quantity_change}", + old_content, + new_content, + ) + ) + + assert len(command_list) > 0, "No commands generated" + return command_list, template_id diff --git a/pptagent/presentation/__init__.py b/pptagent/presentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f9742fba1ada4051c0b5826e5adc87571f9ff4 --- /dev/null +++ b/pptagent/presentation/__init__.py @@ -0,0 +1,46 @@ +from .layout import Layout +from .presentation import Presentation, SlidePage +from .shapes import ( + SHAPECAST, + Background, + Closure, + ClosureType, + Fill, + Font, + FreeShape, + GroupShape, + Line, + Paragraph, + Picture, + Placeholder, + SemanticPicture, + ShapeElement, + StyleArg, + TextBox, + TextFrame, + UnsupportedShape, +) + +__all__ = [ + "Presentation", + "SlidePage", + "SHAPECAST", + "Background", + "Closure", + "ClosureType", + "Fill", + "Font", + "FreeShape", + "GroupShape", + "Layout", + "Line", + "Paragraph", + "Picture", + "Placeholder", + "SemanticPicture", + "ShapeElement", + "StyleArg", + "TextBox", + "TextFrame", + "UnsupportedShape", +] diff --git a/pptagent/presentation/__pycache__/__init__.cpython-312.pyc b/pptagent/presentation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dfecb3d67b46a83d8f1bb102cd556d4193fab9d Binary files /dev/null and b/pptagent/presentation/__pycache__/__init__.cpython-312.pyc differ diff --git a/pptagent/presentation/__pycache__/layout.cpython-312.pyc b/pptagent/presentation/__pycache__/layout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..659730edc25271c9edabde6fdb92da7d64be86e2 Binary files /dev/null and b/pptagent/presentation/__pycache__/layout.cpython-312.pyc differ diff --git a/pptagent/presentation/__pycache__/presentation.cpython-312.pyc b/pptagent/presentation/__pycache__/presentation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0a55d6045a3b3fc78350f504ea16ee48193f72 Binary files /dev/null and b/pptagent/presentation/__pycache__/presentation.cpython-312.pyc differ diff --git a/pptagent/presentation/__pycache__/shapes.cpython-312.pyc b/pptagent/presentation/__pycache__/shapes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..735acd7eee42d92d8eb4b736cf7e2441fd7837d3 Binary files /dev/null and 
b/pptagent/presentation/__pycache__/shapes.cpython-312.pyc differ diff --git a/pptagent/presentation/layout.py b/pptagent/presentation/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..f944fc2df2318c45d7740e6a12dfd957d6334f1b --- /dev/null +++ b/pptagent/presentation/layout.py @@ -0,0 +1,234 @@ +import asyncio +from dataclasses import dataclass +from typing import Literal, Optional + +from jinja2 import StrictUndefined, Template + +from pptagent.llms import LLM, AsyncLLM +from pptagent.utils import get_logger, package_join, pbasename, pexists, pjoin + +logger = get_logger(__name__) + +LENGTHY_REWRITE_PROMPT = Template( + open(package_join("prompts", "lengthy_rewrite.txt")).read(), + undefined=StrictUndefined, +) + + +@dataclass +class Element: + el_name: str + content: list[str] + description: str + el_type: Literal["text", "image"] + suggested_characters: int | None + variable_length: tuple[int, int] | None + variable_data: dict[str, list[str]] | None + + def get_schema(self): + schema = f"Element: {self.el_name}\n" + base_attrs = ["description", "el_type"] + for attr in base_attrs: + schema += f"\t{attr}: {getattr(self, attr)}\n" + if self.el_type == "text": + schema += f"\tsuggested_characters: {self.suggested_characters}\n" + if self.variable_length is not None: + schema += f"\tThe length of the element can vary between {self.variable_length[0]} and {self.variable_length[1]}\n" + schema += f"\tThe default quantity of the element is {len(self.content)}\n" + return schema + + @classmethod + def from_dict(cls, el_name: str, data: dict): + if not isinstance(data["data"], list): + data["data"] = [data["data"]] + if data["type"] == "text": + suggested_characters = max(len(i) for i in data["data"]) + elif data["type"] == "image": + suggested_characters = None + return cls( + el_name=el_name, + el_type=data["type"], + content=data["data"], + description=data["description"], + variable_length=data.get("variableLength", None), + variable_data=data.get("variableData", None), + suggested_characters=suggested_characters, + ) + + +@dataclass +class Layout: + title: str + template_id: int + slides: list[int] + elements: list[Element] + vary_mapping: dict[int, int] | None # mapping for variable elements + + @classmethod + def from_dict(cls, title: str, data: dict): + elements = [ + Element.from_dict(el_name, data["content_schema"][el_name]) + for el_name in data["content_schema"] + ] + num_vary_elements = sum((el.variable_length is not None) for el in elements) + if num_vary_elements > 1: + raise ValueError("Only one variable element is allowed") + return cls( + title=title, + template_id=data["template_id"], + slides=data["slides"], + elements=elements, + vary_mapping=data.get("vary_mapping", None), + ) + + def get_slide_id(self, data: dict): + for el in self.elements: + if el.variable_length is not None: + num_vary = len(data[el.el_name]["data"]) + if num_vary < el.variable_length[0]: + raise ValueError( + f"The length of {el.el_name}: {num_vary} is less than the minimum length: {el.variable_length[0]}" + ) + if num_vary > el.variable_length[1]: + raise ValueError( + f"The length of {el.el_name}: {num_vary} is greater than the maximum length: {el.variable_length[1]}" + ) + return self.vary_mapping[str(num_vary)] + return self.template_id + + def get_old_data(self, editor_output: Optional[dict] = None): + if editor_output is None: + return {el.el_name: el.content for el in self.elements} + old_data = {} + for el in self.elements: + if el.variable_length is not None: + key 
= str(len(editor_output[el.el_name]["data"])) + assert ( + key in el.variable_data + ), f"The length of element {el.el_name} varies between {el.variable_length[0]} and {el.variable_length[1]}, but got data of length {key} which is not supported" + old_data[el.el_name] = el.variable_data[key] + else: + old_data[el.el_name] = el.content + return old_data + + def validate(self, editor_output: dict, image_dir: str): + for el_name, el_data in editor_output.items(): + assert ( + "data" in el_data + ), """key `data` not found in output + please give your output as a dict like + { + "element1": { + "data": ["text1", "text2"] for text elements + or ["/path/to/image", "..."] for image elements + }, + }""" + assert ( + el_name in self + ), f"Element {el_name} is not a valid element, supported elements are {[el.el_name for el in self.elements]}" + if self[el_name].el_type == "image": + for i in range(len(el_data["data"])): + if pexists(pjoin(image_dir, el_data["data"][i])): + el_data["data"][i] = pjoin(image_dir, el_data["data"][i]) + if not pexists(el_data["data"][i]): + basename = pbasename(el_data["data"][i]) + if pexists(pjoin(image_dir, basename)): + el_data["data"][i] = pjoin(image_dir, basename) + else: + raise ValueError( + f"Image {el_data['data'][i]} not found\n" + "Please check the image path and use only existing images\n" + "Or, leave a blank list for this element" + ) + + def validate_length( + self, editor_output: dict, length_factor: float, language_model: LLM + ): + for el_name, el_data in editor_output.items(): + if self[el_name].el_type == "text": + charater_counts = [len(i) for i in el_data["data"]] + if ( + max(charater_counts) + > self[el_name].suggested_characters * length_factor + ): + el_data["data"] = language_model( + LENGTHY_REWRITE_PROMPT.render( + el_name=el_name, + content=el_data["data"], + suggested_characters=f"{self[el_name].suggested_characters} characters", + ), + return_json=True, + ) + assert isinstance( + el_data["data"], list + ), f"Generated data is lengthy, expect {self[el_name].suggested_characters} characters, but got {len(el_data['data'])} characters for element {el_name}" + + async def validate_length_async( + self, editor_output: dict, length_factor: float, language_model: AsyncLLM + ): + async with asyncio.TaskGroup() as tg: + tasks = {} + for el_name, el_data in editor_output.items(): + if self[el_name].el_type == "text": + charater_counts = [len(i) for i in el_data["data"]] + if ( + max(charater_counts) + > self[el_name].suggested_characters * length_factor + ): + task = tg.create_task( + language_model( + LENGTHY_REWRITE_PROMPT.render( + el_name=el_name, + content=el_data["data"], + suggested_characters=f"{self[el_name].suggested_characters} characters", + ), + return_json=True, + ) + ) + tasks[el_name] = task + + for el_name, task in tasks.items(): + assert isinstance( + editor_output[el_name]["data"], list + ), f"Generated data is lengthy, expect {self[el_name].suggested_characters} characters, but got {len(editor_output[el_name]['data'])} characters for element {el_name}" + new_data = await task + logger.debug( + f"Lengthy rewrite for {el_name}:\n From {editor_output[el_name]['data']}\n To {new_data}" + ) + editor_output[el_name]["data"] = new_data + + @property + def content_schema(self): + return "\n".join([el.get_schema() for el in self.elements]) + + def remove_item(self, item: str): + for el in self.elements: + if item in el.content: + el.content.remove(item) + if len(el.content) == 0: + self.elements.remove(el) + return + else: + raise 
ValueError(f"Item {item} not found in layout {self.title}") + + def __contains__(self, key: str | int): + if isinstance(key, int): + return key in self.slides + elif isinstance(key, str): + for el in self.elements: + if el.el_name == key: + return True + return False + raise ValueError(f"Invalid key type: {type(key)}, should be str or int") + + def __getitem__(self, key: str): + for el in self.elements: + if el.el_name == key: + return el + raise ValueError(f"Element {key} not found") + + def __iter__(self): + return iter(self.elements) + + def __len__(self): + return len(self.elements) diff --git a/pptagent/presentation/presentation.py b/pptagent/presentation/presentation.py new file mode 100644 index 0000000000000000000000000000000000000000..a1cd825101e2663e6b758e4e02710a12e98e5f0c --- /dev/null +++ b/pptagent/presentation/presentation.py @@ -0,0 +1,481 @@ +import traceback +from collections.abc import Generator +from typing import Literal, Optional + +from pptx import Presentation as load_prs +from pptx.enum.shapes import MSO_SHAPE_TYPE +from pptx.shapes.base import BaseShape +from pptx.shapes.group import GroupShape as PPTXGroupShape +from pptx.slide import Slide as PPTXSlide + +from pptagent.utils import Config, get_logger, package_join + +from .shapes import ( + Background, + GroupShape, + Paragraph, + Picture, + ShapeElement, + StyleArg, +) + +# Type variable for ShapeElement subclasses + +logger = get_logger(__name__) + + +class SlidePage: + """ + A class to represent a slide page in a presentation. + """ + + def __init__( + self, + shapes: list[ShapeElement], + backgrounds: list[Background], + slide_idx: int, + real_idx: int, + slide_notes: Optional[str], + slide_layout_name: Optional[str], + slide_title: Optional[str], + slide_width: int, + slide_height: int, + ): + """ + Initialize a SlidePage. + + Args: + shapes (List[ShapeElement]): The shapes in the slide. + backgrounds (List[Background]): The backgrounds of the slide. + slide_idx (int): The index of the slide. + real_idx (int): The real index of the slide. + slide_notes (Optional[str]): The notes of the slide. + slide_layout_name (Optional[str]): The layout name of the slide. + slide_title (Optional[str]): The title of the slide. + slide_width (int): The width of the slide. + slide_height (int): The height of the slide. + """ + self.shapes = shapes + self.backgrounds = backgrounds + self.slide_idx = slide_idx + self.real_idx = real_idx + self.slide_notes = slide_notes + self.slide_layout_name = slide_layout_name + self.slide_title = slide_title + self.slide_width = slide_width + self.slide_height = slide_height + + # Assign group labels to group shapes + groups_shapes_labels = [] + for shape in self.shape_filter(GroupShape): + for group_shape in groups_shapes_labels: + if group_shape == shape: + shape.group_label = group_shape.group_label + continue + groups_shapes_labels.append(shape) + shape.group_label = f"group_{len(groups_shapes_labels)}" + + @classmethod + def from_slide( + cls, + slide: PPTXSlide, + slide_idx: int, + real_idx: int, + slide_width: int, + slide_height: int, + config: Config, + shape_cast: dict[MSO_SHAPE_TYPE, type[ShapeElement] | None], + ) -> "SlidePage": + """ + Create a SlidePage from a PPTXSlide. + + Args: + slide (PPTXSlide): The slide object. + slide_idx (int): The index of the slide. + real_idx (int): The real index of the slide. + slide_width (int): The width of the slide. + slide_height (int): The height of the slide. + config (Config): The configuration object. 
+ shape_cast (dict[MSO_SHAPE_TYPE, type[ShapeElement] | None]): Mapping of shape types to their corresponding ShapeElement classes. + Set the value to None for any MSO_SHAPE_TYPE to exclude that shape type from processing. + Returns: + SlidePage: The created SlidePage. + """ + backgrounds = [Background.from_slide(slide, config)] + shapes = [] + for i, shape in enumerate(slide.shapes): + if not shape.visible: + continue + if shape_cast.get(shape.shape_type, -1) is None: + continue + shapes.append( + ShapeElement.from_shape( + slide_idx, i, shape, config, slide_width * slide_height, shape_cast + ) + ) + for i, s in enumerate(shapes): + if isinstance(s, Picture) and s.area / s.slide_area > 0.95: + backgrounds.append(shapes.pop(i)) + + slide_layout_name = slide.slide_layout.name if slide.slide_layout else None + slide_title = slide.shapes.title.text if slide.shapes.title else None + slide_notes = ( + slide.notes_slide.notes_text_frame.text + if slide.has_notes_slide and slide.notes_slide.notes_text_frame + else None + ) + + return cls( + shapes, + backgrounds, + slide_idx, + real_idx, + slide_notes, + slide_layout_name, + slide_title, + slide_width, + slide_height, + ) + + def build(self, slide: PPTXSlide) -> PPTXSlide: + """ + Build the slide page in a slide. + + Args: + slide (PPTXSlide): The slide to build the slide page in. + + Returns: + PPTXSlide: The built slide. + """ + # Remove existing placeholders + for ph in slide.placeholders: + ph.element.getparent().remove(ph.element) + + # Build backgrounds, shapes and apply closures + for shape in sorted(self.backgrounds + self.shapes, key=lambda x: x.shape_idx): + build_shape = shape.build(slide) + for closure in shape.closures: + try: + closure.apply(build_shape) + except Exception as e: + raise ValueError(f"Failed to apply closures to slides: {e}") + return slide + + def iter_paragraphs(self) -> Generator[Paragraph, None, None]: + for shape in self: # this considered the group shapes + if not shape.text_frame.is_textframe: + continue + for para in shape.text_frame.paragraphs: + if para.idx != -1: + yield para + + def shape_filter( + self, + shape_type: type[ShapeElement], + from_groupshape: bool = True, + return_father: bool = False, + ) -> ( + Generator[ShapeElement, None, None] + | Generator[tuple["SlidePage", ShapeElement], None, None] + ): + """ + Filter shapes in the slide by type. + + Args: + shape_type (Type[ShapeElement]): The type of shapes to filter. + shapes (Optional[List[ShapeElement]]): The shapes to filter. + + Yields: + ShapeElement: The filtered shapes. + """ + for shape in self.shapes: + if isinstance(shape, shape_type): + if return_father: + yield (self, shape) + else: + yield shape + elif from_groupshape and isinstance(shape, GroupShape): + yield from shape.shape_filter(shape_type, return_father) + + def get_content_type(self) -> Literal["text", "image"]: + """ + Get the content type of the slide. + + Returns: + Literal["text", "image"]: The content type of the slide. + """ + if len(list(self.shape_filter(Picture))) == 0: + return "text" + return "image" + + def to_html(self, style_args: Optional[StyleArg] = None, **kwargs) -> str: + """ + Represent the slide page in HTML. + + Args: + style_args (Optional[StyleArg]): The style arguments for HTML conversion. + **kwargs: Additional arguments. + + Returns: + str: The HTML representation of the slide page. 
+ """ + if style_args is None: + style_args = StyleArg(**kwargs) + shapes_html = [shape.to_html(style_args) for shape in self.shapes] + shapes_html = [html for html in shapes_html if html] + return "".join( + [ + "\n\n", + (f"{self.slide_title}\n" if self.slide_title else ""), + f'\n', + "\n".join(shapes_html), + "\n\n", + ] + ) + + def to_text(self, show_image: bool = False) -> str: + """ + Represent the slide page in text. + + Args: + show_image (bool): Whether to show image captions. + + Returns: + str: The text representation of the slide page. + + Raises: + ValueError: If an image caption is not found. + """ + text_content = "" + for para in self.iter_paragraphs(): + if not para.text: + continue + if para.bullet: + text_content += para.bullet + text_content += para.text + "\n" + if show_image: + for image in self.shape_filter(Picture): + text_content += "\n" + "Image: " + image.caption + return text_content + + def __iter__(self): + """ + Iterate over all shapes in the slide page. + + Yields: + ShapeElement: Each shape in the slide page. + """ + for shape in self.shapes: + if isinstance(shape, GroupShape): + yield from shape + else: + yield shape + + def __len__(self) -> int: + """ + Get the number of shapes in the slide page. + + Returns: + int: The number of shapes. + """ + return len(self.shapes) + + +class Presentation: + """ + PPTAgent's representation of a presentation. + Aiming at a more readable and editable interface. + """ + + def __init__( + self, + slides: list[SlidePage], + error_history: list[tuple[int, str]], + slide_width: float, + slide_height: float, + file_path: str, + num_pages: int, + ) -> None: + """ + Initialize the Presentation. + + Args: + slides (List[SlidePage]): The slides in the presentation. + error_history (List[Tuple[int, str]]): The error history. + slide_width (float): The width of the slides. + slide_height (float): The height of the slides. + file_path (str): The path to the presentation file. + num_pages (int): The number of pages in the presentation. + """ + self.slides = slides + self.error_history = error_history + self.slide_width = slide_width + self.slide_height = slide_height + self.num_pages = num_pages + self.source_file = file_path + self.prs = load_prs(self.source_file) + self.layout_mapping = {layout.name: layout for layout in self.prs.slide_layouts} + self.prs.core_properties.last_modified_by = "PPTAgent" + + @classmethod + def from_file( + cls, + file_path: str, + config: Config, + shape_cast: Optional[dict[MSO_SHAPE_TYPE, type[ShapeElement]]] = None, + ) -> "Presentation": + """ + Parse a Presentation from a file. + + Args: + file_path (str): The path to the presentation file. + config (Config): The configuration object. + shape_cast (dict[MSO_SHAPE_TYPE, type[ShapeElement]] | None): Optional mapping of shape types to their corresponding ShapeElement classes. + Set the value to None for any MSO_SHAPE_TYPE to exclude that shape type from processing. + Returns: + Presentation: The parsed Presentation. 
+ """ + prs = load_prs(file_path) + slide_width = prs.slide_width + slide_height = prs.slide_height + slides = [] + error_history = [] + slide_idx = 0 + layouts = [layout.name for layout in prs.slide_layouts] + num_pages = len(prs.slides) + + if shape_cast is None: + shape_cast = {} + + for slide in prs.slides: + # Skip slides that won't be printed to PDF, as they are invisible + if slide._element.get("show", 1) == "0": + continue + + slide_idx += 1 + try: + if slide.slide_layout.name not in layouts: + raise ValueError( + f"Slide layout {slide.slide_layout.name} not found" + ) + slides.append( + SlidePage.from_slide( + slide, + slide_idx - len(error_history), + slide_idx, + slide_width.pt, + slide_height.pt, + config, + shape_cast, + ) + ) + except Exception as e: + error_history.append((slide_idx, str(e))) + logger.warning( + "Fail to parse slide %d of %s: %s", + slide_idx, + file_path, + e, + ) + logger.warning(traceback.format_exc()) + + return cls( + slides, error_history, slide_width, slide_height, file_path, num_pages + ) + + def save(self, file_path: str, layout_only: bool = False) -> None: + """ + Save the presentation to a file. + + Args: + file_path (str): The path to save the presentation to. + layout_only (bool): Whether to save only the layout. + """ + self.clear_slides() + for slide in self.slides: + if layout_only: + self.clear_images(slide.shapes) + pptx_slide = self.build_slide(slide) + if layout_only: + self.clear_text(pptx_slide.shapes) + self.prs.save(file_path) + + def build_slide(self, slide_page: SlidePage) -> PPTXSlide: + """Create a pptx slide from our in-memory SlidePage copy, including notes.""" + # create a blank slide with the same layout + pptx_slide = self.prs.slides.add_slide( + self.layout_mapping[slide_page.slide_layout_name] + ) + + # draw backgrounds & shapes + pptx_slide = slide_page.build(pptx_slide) + + # -------- NEW SECTION: copy speaker notes -------- + if slide_page.slide_notes: # we captured notes on load + notes_slide = pptx_slide.notes_slide # auto-creates if missing + tf = notes_slide.notes_text_frame + tf.clear() # optional: start fresh + tf.text = slide_page.slide_notes + # ------------------------------------------------- + + return pptx_slide + + def clear_slides(self): + """ + Delete all slides from the presentation. + """ + while len(self.prs.slides) != 0: + rId = self.prs.slides._sldIdLst[0].rId + self.prs.part.drop_rel(rId) + del self.prs.slides._sldIdLst[0] + + def clear_images(self, shapes: list[ShapeElement]): + for shape in shapes: + if isinstance(shape, GroupShape): + self.clear_images(shape.shapes) + elif isinstance(shape, Picture): + shape.img_path = package_join("resource", "pic_placeholder.png") + + def clear_text(self, shapes: list[BaseShape]): + for shape in shapes: + if isinstance(shape, PPTXGroupShape): + self.clear_text(shape.shapes) + elif shape.has_text_frame: + for para in shape.text_frame.paragraphs: + for run in para.runs: + run.text = "a" * len(run.text) + + def to_text(self, show_image: bool = False) -> str: + """ + Represent the presentation in text. + """ + return "\n----\n".join( + [ + ( + f"Slide {slide.slide_idx} of {len(self.slides)}\n" + + (f"Title:{slide.slide_title}\n" if slide.slide_title else "") + + slide.to_text(show_image) + ) + for slide in self.slides + ] + ) + + def __iter__(self): + yield from self.slides + + def __len__(self) -> int: + """ + Get the number of slides in the presentation. 
+ """ + return len(self.slides) + + def __getstate__(self) -> object: + state = self.__dict__.copy() + state["prs"] = None + state["layout_mapping"] = None + return state + + def __setstate__(self, state: object): + self.__dict__.update(state) + self.prs = load_prs(self.source_file) + self.layout_mapping = {layout.name: layout for layout in self.prs.slide_layouts} diff --git a/pptagent/presentation/shapes.py b/pptagent/presentation/shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f5abf64f770debd0b2c4e633830e26415548cf74 --- /dev/null +++ b/pptagent/presentation/shapes.py @@ -0,0 +1,1267 @@ +import re +from dataclasses import dataclass +from enum import Enum, auto +from types import MappingProxyType +from typing import Callable, ClassVar, Optional, Union + +from lxml import etree +from pptx.dml.fill import FillFormat +from pptx.dml.line import LineFormat +from pptx.enum.dml import MSO_FILL_TYPE +from pptx.enum.shapes import MSO_SHAPE_TYPE +from pptx.oxml import parse_xml +from pptx.parts.slide import SlidePart +from pptx.shapes.base import BaseShape +from pptx.shapes.group import GroupShape as PPTXGroupShape +from pptx.shapes.picture import Picture as PPTXPicture +from pptx.shapes.placeholder import PlaceholderPicture, SlidePlaceholder +from pptx.slide import Slide as PPTXSlide +from pptx.slide import _Background +from pptx.text.text import _Paragraph +from pptx.util import Length + +from pptagent.utils import ( + Config, + dict_to_object, + package_join, + parse_groupshape, + parsing_image, + pjoin, + runs_merge, +) + +INDENT = "\t" + + +def shape_normalize(shape: BaseShape): + """ + This function is used to filter out those malfunctioned attrs. + """ + if not shape.has_text_frame: + return + for para in shape.text_frame.paragraphs: + for run in para.runs: + run.hyperlink.address = None + + +class ClosureType(Enum): + CLONE = auto() + REPLACE = auto() + DELETE = auto() + STYLE = auto() + MERGE = auto() + + def __str__(self): + return self.name.lower() + + @classmethod + def to_default_dict(cls): + return {key: [] for key in cls} + + +@dataclass +class StyleArg: + """ + A class to represent style arguments for HTML conversion. + """ + + paragraph_id: bool = True + element_id: bool = True + font_style: bool = True + fill_style: bool = True + area: bool = False + size: bool = False + geometry: bool = False + show_name: bool = False + show_image: bool = True + show_empty: bool = False + show_content: bool = True + show_semantic_name: bool = False + + @classmethod + def all_true(cls) -> "StyleArg": + """ + Create a StyleArg instance with all options enabled. + + Returns: + StyleArg: A StyleArg instance with all options enabled. + """ + return cls( + area=True, + size=True, + geometry=True, + show_semantic_name=True, + ) + + +class Fill: + """ + A class to represent a fill. 
+ """ + + def __init__( + self, + fill_type: MSO_FILL_TYPE, + fill_str: str, + fill_xml: str, + image_path: Optional[str] = None, + ): + self.fill_type = fill_type + self.fill_str = fill_str + self.fill_xml = fill_xml + self.image_path = image_path + + @classmethod + def from_shape(cls, fill: Optional[FillFormat], part: SlidePart, config: Config): + if fill is None or fill.type is None or fill.type == MSO_FILL_TYPE.BACKGROUND: + return cls(MSO_FILL_TYPE.BACKGROUND, "", None) + + fill_str = "Fill: " + str(fill.value) + fill_xml = fill._xPr.xml + fill_type = fill.type + image_path = None + if fill_type == MSO_FILL_TYPE.PICTURE: + image = part.get_image(fill._fill.rId) + image_path = pjoin(config.IMAGE_DIR, f"{image.sha1}.{image.ext}") + image_path = parsing_image(image, image_path) + return cls(fill_type, fill_str, fill_xml, image_path) + + # We pass an element with fill attribute instead of a fill object because `python-pptx` automatically creates a fill object when accessing this attribute, which would cause inconsistency + def build( + self, fill_ele: LineFormat | _Background | BaseShape, part: SlidePart + ) -> None: + """ + Build the fill in a shape. + Args: + shape (BaseShape): The shape to apply fill to. + fill_xml (Optional[str]): The fill XML to apply. + """ + if self.fill_type == MSO_FILL_TYPE.BACKGROUND: + return + fill = fill_ele.fill + if self.fill_type == MSO_FILL_TYPE.PICTURE: + fill.blip() + _, rId = part.get_or_add_image_part(self.image_path) + fill.rId = rId + else: + new_element = etree.fromstring(self.fill_xml) + fill._xPr.getparent().replace(fill._xPr, new_element) + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the fill to HTML. + """ + + +class Line: + """ + A class to represent a line. + """ + + def __init__(self, fill: Fill, line_width: float, line_dash_style: str): + self.fill = fill + self.line_width = line_width + self.line_dash_style = line_dash_style + + @classmethod + def from_shape(cls, line: Optional[LineFormat], part: SlidePart, config: Config): + line_fill = getattr(line, "fill", None) + if line_fill is None: + return cls(Fill(MSO_FILL_TYPE.BACKGROUND, "", None), 0, "") + fill = Fill.from_shape(line_fill, part, config) + line_width = line.width + line_dash_style = line.dash_style + return cls(fill, line_width, line_dash_style) + + def build(self, line: LineFormat, part: SlidePart) -> None: + """ + Build the line in a shape. + """ + if self.fill.fill_type == MSO_FILL_TYPE.BACKGROUND: + return + self.fill.build(line, part) + line.width = self.line_width + line.dash_style = self.line_dash_style + + +class Background(Fill): + """ + A class to represent a slide background. + """ + + shape_idx: int = -1 + + @classmethod + def from_slide(cls, slide: PPTXSlide, config: Config) -> "Background": + """ + Build the background in a slide. + + Args: + slide (PPTXSlide): The slide to build the background in. + """ + background = slide.background + return cls.from_shape(background.fill, slide.part, config) + + def build(self, slide: PPTXSlide) -> None: + """ + Build the background in a slide. + """ + super().build(slide.background, slide.part) + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the background to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the background. + """ + raise NotImplementedError("Background to HTML conversion is not implemented") + + @property + def closures(self) -> list: + """ + Get the closure for the background. 
+ """ + return [] + + +@dataclass +class Closure: + """ + A class to represent a closure that can be applied to a shape. + """ + + closure: Callable[[BaseShape], None] + paragraph_id: int = -1 + + def apply(self, shape: BaseShape) -> None: + """ + Apply the closure to a shape. + + Args: + shape (BaseShape): The shape to apply the closure to. + """ + self.closure(shape) + + def __gt__(self, other: "Closure") -> bool: + """ + Compare closures based on paragraph_id. + + Args: + other (Closure): Another closure to compare with. + + Returns: + bool: True if this closure's paragraph_id is greater than the other's. + """ + if self.paragraph_id != other.paragraph_id: + return self.paragraph_id > other.paragraph_id + + +@dataclass +class Font: + name: str + color: str + size: Length + bold: bool + italic: bool + underline: bool + strikethrough: bool + + def update(self, other: "Font"): + """ + Merge a list of fonts into a single font. + """ + for key, value in other.__dict__.items(): + if getattr(self, key) is None: + setattr(self, key, value) + + def override(self, other: "Font"): + """ + Merge a list of fonts into a single font. + """ + for key, value in other.__dict__.items(): + if value is not None: + setattr(self, key, value) + + def unify(self, others: list["Font"], clear_others: bool = False): + """ + Merge a list of fonts into a single font. + """ + if len(others) == 0: + return + for key in list(self.__dict__.keys()): + values = [d.__dict__[key] for d in others] + if not all(value == values[0] for value in values): + continue + setattr(self, key, values[0]) + if not clear_others: + continue + for d in others: + setattr(d, key, None) + + def to_style(self) -> str: + """ + Convert a font dictionary to a CSS style string. + + Args: + font (Dict[str, Any]): The font dictionary. + + Returns: + str: The CSS style string. + """ + styles = [] + if self.size: + styles.append(f"font-size: {self.size}pt") + + if self.color: + styles.append(f"color: #{self.color}") + + if self.bold: + styles.append("font-weight: bold") + + if self.italic: + styles.append("font-style: italic") + + return "; ".join(styles) + + +class Paragraph: + """ + A class to represent a paragraph in a text frame. + """ + + def __init__(self, paragraph: _Paragraph, idx: int): + """ + Initialize a Paragraph. + + Args: + paragraph (_Paragraph): The paragraph object. + idx (int): The index of the paragraph. + """ + run = runs_merge(paragraph) + self.idx = idx + self.real_idx = idx + self.bullet = paragraph.bullet + if run is None: + self.idx = -1 + return + self.font = Font(**paragraph.font.get_attrs()) + self.font.override(Font(**run.font.get_attrs())) + self.text = re.sub(r"(_x000B_|\\x0b)", " ", paragraph.text) + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the paragraph to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the paragraph. + + Raises: + ValueError: If the paragraph is not valid. + """ + if self.idx == -1: + raise ValueError(f"paragraph {self.idx} is not valid") + tag = "li" if self.bullet else "p" + id_str = f" id='{self.idx}'" if style_args.paragraph_id else "" + font_style = self.font.to_style() + style_str = ( + f" style='{font_style}'" if style_args.font_style and font_style else "" + ) + if self.bullet: + style_str += f" bullet-type='{self.bullet}'" + return f"<{tag}{id_str}{style_str}>{self.text}" + + def __repr__(self) -> str: + """ + Get a string representation of the paragraph. 
+ + Returns: + str: A string representation of the paragraph. + """ + return f"Paragraph-{self.idx}: {self.text}" + + +class TextFrame: + """ + A class to represent a text frame in a shape. + """ + + def __init__(self, shape: BaseShape, level: int): + """ + Initialize a TextFrame. + + Args: + shape (BaseShape): The shape containing the text frame. + level (int): The indentation level. + """ + if not shape.has_text_frame: + self.is_textframe = False + return + self.paragraphs = [ + Paragraph(paragraph, idx) + for idx, paragraph in enumerate(shape.text_frame.paragraphs) + ] + para_offset = 0 + for para in self.paragraphs: + if para.idx == -1: + para_offset += 1 + else: + para.idx = para.idx - para_offset + if len(self.paragraphs) == 0: + self.is_textframe = False + return + self.level = level + self.text = shape.text + self.is_textframe = True + self.extents = shape.text_frame._extents + self.font = Font(**shape.text_frame.font.get_attrs()) + self.font.unify([para.font for para in self.paragraphs if para.idx != -1]) + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the text frame to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the text frame. + """ + if not self.is_textframe: + return "" + repr_list = [ + para.to_html(style_args) for para in self.paragraphs if para.idx != -1 + ] + return "\n".join([INDENT * self.level + repr for repr in repr_list]) + + def __repr__(self) -> str: + """ + Get a string representation of the text frame. + + Returns: + str: A string representation of the text frame. + """ + if not self.is_textframe: + return "TextFrame: null" + return f"TextFrame: {self.paragraphs}" + + def __len__(self) -> int: + """ + Get the length of the text in the text frame. + + Returns: + int: The length of the text. + """ + if not self.is_textframe: + return 0 + return len(self.text) + + +@dataclass +class ShapeElement: + """ + Base class for shape elements in a presentation. + """ + + config: Config + slide_idx: int + shape_idx: int + style: dict + data: list + text_frame: TextFrame + level: int + slide_area: float + xml: str + fill: Fill + line: Line + shape: BaseShape + _closures: dict[ClosureType, list[Closure]] + + @classmethod + def from_shape( + cls: type["ShapeElement"], + slide_idx: int, + shape_idx: int, + shape: BaseShape, + config: Config, + slide_area: float, + shape_cast: dict[MSO_SHAPE_TYPE, type["ShapeElement"] | None], + level: int = 0, + ) -> "ShapeElement": + """ + Create a ShapeElement from a BaseShape. + + Args: + slide_idx (int): The index of the slide. + shape_idx (int): The index of the shape. + shape (BaseShape): The shape object. + config (Config): The configuration object. + slide_area (float): The area of the slide. + level (int): The indentation level. + shape_cast (dict[MSO_SHAPE_TYPE, type[ShapeElement]] | None): Optional mapping of shape types to their corresponding ShapeElement classes. + Set the value to None for any MSO_SHAPE_TYPE to exclude that shape type from processing. + Returns: + ShapeElement: The created ShapeElement. + + Raises: + ValueError: If nested group shapes are not allowed. 
+ """ + if shape_idx > 100 and isinstance(shape, PPTXGroupShape): + raise ValueError("Nested group shapes are not allowed") + + shape_normalize(shape) + + # Create style dictionary + style = { + "shape_bounds": { + "width": shape.width, + "height": shape.height, + "left": shape.left, + "top": shape.top, + }, + "shape_type": str(shape.shape_type).split("(")[0].lower(), + "rotation": shape.rotation, + "name": shape.name, + } + + # Determine semantic name + try: + # For auto shapes (rectangle, oval, triangle, star...) + autoshape = shape.auto_shape_type + assert autoshape is not None + style["semantic_name"] = str(autoshape).split()[0].lower().strip() + except Exception: + # For other shapes (freeform, connector, table, chart...) + style["semantic_name"] = str(shape.shape_type).split("(")[0].lower().strip() + + # Create text frame + text_frame = TextFrame(shape, level + 1) + + # Create appropriate shape element based on shape type + shape_class = shape_cast.get(shape.shape_type, UnsupportedShape) + if shape_class is UnsupportedShape: + shape_class = SHAPECAST.get(shape.shape_type, UnsupportedShape) + + if shape_class == Placeholder: + shape_class = Placeholder.from_shape + + if shape_class == GroupShape: + shape_class = GroupShape.with_shape_cast(shape_cast) + + return shape_class( + config=config, + slide_idx=slide_idx, + shape_idx=shape_idx, + style=style, + data=[], + text_frame=text_frame, + level=level, + slide_area=slide_area, + xml=shape._element.xml, + fill=Fill.from_shape(getattr(shape, "fill", None), shape.part, config), + line=Line.from_shape(getattr(shape, "line", None), shape.part, config), + shape=shape, + _closures=ClosureType.to_default_dict(), + ) + + def build(self, slide: PPTXSlide) -> BaseShape: + """ + Build the shape element in a slide. + + Args: + slide (PPTXSlide): The slide to build the shape in. + + Returns: + BaseShape: The built shape. + """ + shape = slide.shapes._shape_factory( + slide.shapes._spTree.insert_element_before(parse_xml(self.xml), "p:extLst") + ) + if getattr(shape, "fill", None) is not None: + self.fill.build(shape, shape.part) + if getattr(shape, "line", None) is not None: + self.line.build(shape.line, shape.part) + return shape + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the shape element to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the shape element. + + Raises: + NotImplementedError: If not implemented in a subclass. + """ + raise NotImplementedError( + f"to_html not implemented for {self.__class__.__name__}" + ) + + @property + def text(self) -> str: + """ + Get the text of the shape element. + """ + if self.text_frame.is_textframe: + return self.text_frame.text + return "" + + def __getstate__(self) -> object: + state = self.__dict__.copy() + state["shape"] = None + return state + + def __repr__(self) -> str: + """ + Get a string representation of the shape element. + + Returns: + str: A string representation of the shape element. + """ + return f"{self.__class__.__name__}: shape {self.shape_idx} of slide {self.slide_idx}" + + @property + def closures(self) -> list[Closure]: + """ + Get the closures associated with the shape element. + + Returns: + List[Closure]: A list of closures. 
+ """ + closures = [] + closures.extend(sorted(self._closures[ClosureType.CLONE])) + closures.extend( + self._closures[ClosureType.REPLACE] + self._closures[ClosureType.STYLE] + ) + closures.extend(sorted(self._closures[ClosureType.DELETE], reverse=True)) + closures.extend(self._closures[ClosureType.MERGE]) + return closures + + @property + def indent(self) -> str: + """ + Get the indentation string for the shape element. + + Returns: + str: The indentation string. + """ + return "\t" * self.level + + @property + def left(self) -> float: + """ + Get the left position of the shape element. + + Returns: + float: The left position in points. + """ + return self.style["shape_bounds"]["left"].pt + + @left.setter + def left(self, value: float) -> None: + """ + Set the left position of the shape element. + + Args: + value (float): The left position in points. + """ + self.style["shape_bounds"]["left"] = value + + @property + def top(self) -> float: + """ + Get the top position of the shape element. + + Returns: + float: The top position in points. + """ + return self.style["shape_bounds"]["top"].pt + + @top.setter + def top(self, value: float) -> None: + """ + Set the top position of the shape element. + + Args: + value (float): The top position in points. + """ + self.style["shape_bounds"]["top"] = value + + @property + def width(self) -> float: + """ + Get the width of the shape element. + + Returns: + float: The width in points. + """ + return self.style["shape_bounds"]["width"].pt + + @width.setter + def width(self, value: float) -> None: + """ + Set the width of the shape element. + + Args: + value (float): The width in points. + """ + self.style["shape_bounds"]["width"] = value + + @property + def height(self) -> float: + """ + Get the height of the shape element. + + Returns: + float: The height in points. + """ + return self.style["shape_bounds"]["height"].pt + + @height.setter + def height(self, value: float) -> None: + """ + Set the height of the shape element. + + Args: + value (float): The height in points. + """ + self.style["shape_bounds"]["height"] = value + + @property + def area(self) -> float: + """ + Get the area of the shape element. + + Returns: + float: The area in square points. + """ + return self.width * self.height + + @property + def semantic_name(self) -> Optional[str]: + """ + Get the semantic name of the shape element. + + Returns: + Optional[str]: The semantic name, or None if not set. + """ + return self.style.get("semantic_name", None) + + @semantic_name.setter + def semantic_name(self, value: str) -> None: + """ + Set the semantic name of the shape element. + + Args: + value (str): The semantic name. + """ + self.style["semantic_name"] = value + + def get_inline_style(self, style_args: StyleArg) -> str: + """ + Get the inline style for the shape element. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The inline style string. 
+ """ + id_str = f" id='{self.shape_idx}'" if style_args.element_id else "" + data_attrs = [] + styles = [] + + # Add data attributes + if style_args.area: + data_attrs.append( + f"data-relative-area={self.area*100/self.slide_area:.2f}%;" + ) + if style_args.show_name: + data_attrs.append(f"data-shapeName='{self.style['name']}'") + if style_args.show_semantic_name and self.semantic_name is not None: + data_attrs.append(f"data-semanticName='{self.semantic_name}'") + + # Add style attributes + if style_args.size: + styles.append(f"width: {self.width}pt; height: {self.height}pt;") + if style_args.geometry: + styles.append(f"left: {self.left}pt; top: {self.top}pt;") + if style_args.font_style and self.text_frame.is_textframe: + font_style = self.text_frame.font.to_style() + if font_style: + styles.append(font_style) + + # Combine attributes + if len(styles) != 0: + id_str += " style='" + " ".join(styles) + "'" + if len(data_attrs) != 0: + id_str += " " + " ".join(data_attrs) + + return id_str + + +@dataclass +class UnsupportedShape(ShapeElement): + def __post_init__(self) -> None: + """ + Initialize an UnsupportedShape. + + Raises: + ValueError: Always, as the shape is unsupported. + """ + raise ValueError(f"Unsupported shape {self.shape.shape_type}") + + +class TextBox(ShapeElement): + """ + A class to represent a text box shape. + """ + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the text box to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the text box. + """ + content = self.text_frame.to_html(style_args) + if not style_args.show_content: + content = "" + if not content and not style_args.show_empty: + return "" + return ( + f"{self.indent}\n" + + content + + f"\n{self.indent}\n" + ) + + +@dataclass +class Picture(ShapeElement): + """ + A class to represent a picture shape. + """ + + def __post_init__(self): + """ + Create a Picture from a PPTXPicture. + + Returns: + Picture: The created Picture. + + Raises: + ValueError: If the image type is unsupported. + """ + img_path = pjoin( + self.config.IMAGE_DIR, + f"{self.shape.image.sha1}.{self.shape.image.ext}", + ) + img_path = parsing_image(self.shape.image, img_path) + + # Add image style information + self.style["img_style"] = { + "crop_bottom": self.shape.crop_bottom, + "crop_top": self.shape.crop_top, + "crop_left": self.shape.crop_left, + "crop_right": self.shape.crop_right, + } + self.data.extend([img_path, self.shape.name, None]) # [img_path, name, caption] + + def build(self, slide: PPTXSlide) -> PPTXPicture: + """ + Build the picture in a slide. + + Args: + slide (PPTXSlide): The slide to build the picture in. + + Returns: + PPTXPicture: The built picture. 
+ """ + # Add picture to slide + if self.is_table: + return slide.shapes.add_table( + self.row, self.col, **self.style["shape_bounds"] + ) + + shape = slide.shapes.add_picture( + self.img_path, + **self.style["shape_bounds"], + ) + + # Set properties + shape.name = self.style["name"] + dict_to_object(self.style["img_style"], shape.image) + + # Apply shape bounds and rotation + dict_to_object(self.style["shape_bounds"], shape) + if hasattr(shape, "rotation"): + shape.rotation = self.style["rotation"] + + return shape + + @property + def is_table(self) -> bool: + return self.style.get("is_table", False) + + @is_table.setter + def is_table(self, value: bool) -> None: + self.style["is_table"] = value + + @property + def grid(self) -> tuple[int, int]: + assert self.is_table, "The shape is not a table." + return self.row, self.col + + @grid.setter + def grid(self, value: tuple[int, int]) -> None: + assert self.is_table, "The shape is not a table." + self.row, self.col = value + + @property + def img_path(self) -> str: + """ + Get the image path. + + Returns: + str: The image path. + """ + return self.data[0] + + @img_path.setter + def img_path(self, img_path: str) -> None: + """ + Set the image path. + + Args: + img_path (str): The image path. + """ + self.data[0] = img_path + + @property + def caption(self) -> Optional[str]: + """ + Get the caption. + + Returns: + Optional[str]: The caption, or None if not set. + """ + return self.data[2] + + @caption.setter + def caption(self, caption: str) -> None: + """ + Set the caption. + + Args: + caption (str): The caption. + """ + self.data[2] = caption + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the picture to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the picture. + + Raises: + ValueError: If the caption is not found. + """ + if not style_args.show_image: + return "" + if self.caption is None: + raise ValueError( + f"Caption not found for picture {self.shape_idx} of slide {self.slide_idx}" + ) + return ( + self.indent + + f"{self.caption}" + ) + + +@dataclass +class GroupShape(ShapeElement): + """ + A class to represent a group shape. + """ + + shape_cast: ClassVar[dict[MSO_SHAPE_TYPE, type[ShapeElement]]] = {} + + @classmethod + def with_shape_cast(cls, shape_cast: dict[MSO_SHAPE_TYPE, type[ShapeElement]]): + """ + Dynamically create a subclass of GroupShape with an isolated shape_cast. + """ + new_cls = type(f"{cls.__name__}_Isolated_{id(shape_cast)}", (cls,), {}) + new_cls.shape_cast = MappingProxyType(shape_cast) + return new_cls + + def __post_init__(self) -> None: + """ + Initialize a GroupShape. + """ + # Create shape elements for each shape in the group + self.data = [ + ShapeElement.from_shape( + self.slide_idx, + (self.shape_idx + 1) * 100 + i, + sub_shape, + self.config, + self.slide_area, + self.shape_cast, + level=self.level + 1, + ) + for i, sub_shape in enumerate(self.shape.shapes) + if self.shape_cast.get(sub_shape.shape_type, -1) is not None + and sub_shape.visible + ] + + # Apply shape bounds to each shape in the group + for idx, shape_bounds in enumerate(parse_groupshape(self.shape)): + if not self.shape.shapes[idx].visible: + continue + if self.shape_cast.get(self.shape.shapes[idx].shape_type, -1) is None: + continue + self.data[idx].style["shape_bounds"] = shape_bounds + + def build(self, slide: PPTXSlide) -> PPTXSlide: + """ + Build the group shape in a slide. 
+ + Args: + slide (PPTXSlide): The slide to build the group shape in. + + Returns: + PPTXSlide: The slide with the built group shape. + """ + for shape in self.data: + shape.build(slide) + return slide + + def shape_filter( + self, shape_type: type["ShapeElement"], return_father: bool = False + ): + """ + Iterate over all shapes in the group. + + Yields: + ShapeElement: Each shape in the group. + """ + for shape in self.data: + if isinstance(shape, shape_type): + if return_father: + yield (self, shape) + else: + yield shape + + @property + def shapes(self): + return self.data + + def __eq__(self, __value: object) -> bool: + """ + Check if two group shapes are equal. + + Args: + __value (object): The object to compare with. + + Returns: + bool: True if the group shapes are equal, False otherwise. + """ + if not isinstance(__value, GroupShape) or len(self.data) != len(__value.data): + return False + for shape1, shape2 in zip(self.data, __value.data): + if isinstance(shape1, type(shape2)): + return False + return True + + def __repr__(self) -> str: + """ + Get a string representation of the group shape. + + Returns: + str: A string representation of the group shape. + """ + return f"{self.__class__.__name__}: {self.data}" + + def __iter__(self): + return iter(self.data) + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the group shape to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the group shape. + """ + content = "\n".join([shape.to_html(style_args) for shape in self.data]) + if not style_args.show_content: + content = "" + return ( + self.indent + + f"
<div class='{self.group_label}'>\n" + content + "\n" + self.indent + "</div>
\n" + ) + + @property + def group_label(self) -> str: + """ + Get the group label. + + Returns: + str: The group label. + """ + return getattr(self, "_group_label", f"group_{self.shape_idx}") + + @group_label.setter + def group_label(self, value: str) -> None: + """ + Set the group label. + + Args: + value (str): The group label. + """ + self._group_label = value + + +class FreeShape(ShapeElement): + """ + A class to represent a free shape. + """ + + def to_html(self, style_args: StyleArg) -> str: + """ + Convert the free shape to HTML. + + Args: + style_args (StyleArg): The style arguments for HTML conversion. + + Returns: + str: The HTML representation of the free shape. + """ + content = self.text_frame.to_html(style_args) + if not content and not style_args.show_empty: + return "" + return ( + f"{self.indent}
" + + f"\n{content}" + + f"\n{self.indent}
" + ) + + +@dataclass +class SemanticPicture(Picture): + """ + A class to represent a semantic picture (table, chart, etc.). + """ + + def __post_init__(self): + shape_type = str(self.shape.shape_type).split()[0] + self.style["img_style"] = {} + self.data = [ + package_join("resource", "pic_placeholder.png"), + self.shape.name, + f"This is a picture of {shape_type}", + ] + self.semantic_name = shape_type + + +class Placeholder: + """ + A class to represent a placeholder shape. + """ + + @classmethod + def from_shape( + cls, + config: Config, + slide_idx: int, + shape_idx: int, + shape: SlidePlaceholder, + **kwargs, + ) -> Union[Picture, TextBox]: + """ + Create a Placeholder from a SlidePlaceholder. + + Returns: + Union[Picture, TextBox]: The created shape element. + + Raises: + ValueError: If the placeholder type is unsupported. + AssertionError: If the placeholder has multiple types. + """ + # Ensure placeholder has only one type + assert ( + sum( + [ + shape.has_text_frame, + shape.has_chart, + shape.has_table, + isinstance(shape, PlaceholderPicture), + ] + ) + == 1 + ), "Placeholder should have only one type" + + # Create appropriate shape based on placeholder type + if isinstance(shape, PlaceholderPicture): + return Picture( + config=config, + slide_idx=slide_idx, + shape_idx=shape_idx, + shape=shape, + **kwargs, + ) + elif shape.has_text_frame: + return TextBox( + config=config, + slide_idx=slide_idx, + shape_idx=shape_idx, + shape=shape, + **kwargs, + ) + else: + raise ValueError(f"Unsupported placeholder {shape.placeholder_type}") + + +# Define shape type mapping +SHAPECAST = { + MSO_SHAPE_TYPE.AUTO_SHAPE: FreeShape, + MSO_SHAPE_TYPE.LINE: FreeShape, + MSO_SHAPE_TYPE.PICTURE: Picture, + MSO_SHAPE_TYPE.PLACEHOLDER: Placeholder, + MSO_SHAPE_TYPE.GROUP: GroupShape, + MSO_SHAPE_TYPE.TEXT_BOX: TextBox, + MSO_SHAPE_TYPE.MEDIA: SemanticPicture, + MSO_SHAPE_TYPE.TABLE: SemanticPicture, + MSO_SHAPE_TYPE.CHART: SemanticPicture, + MSO_SHAPE_TYPE.LINKED_PICTURE: SemanticPicture, + MSO_SHAPE_TYPE.EMBEDDED_OLE_OBJECT: SemanticPicture, + MSO_SHAPE_TYPE.LINKED_OLE_OBJECT: SemanticPicture, + MSO_SHAPE_TYPE.DIAGRAM: SemanticPicture, + MSO_SHAPE_TYPE.CANVAS: SemanticPicture, + MSO_SHAPE_TYPE.INK: SemanticPicture, + MSO_SHAPE_TYPE.IGX_GRAPHIC: SemanticPicture, + MSO_SHAPE_TYPE.WEB_VIDEO: SemanticPicture, +} diff --git a/pptagent/prompts/ask_category.txt b/pptagent/prompts/ask_category.txt new file mode 100644 index 0000000000000000000000000000000000000000..ef4076e18014a6302bf390c122d2358dd0a10c3d --- /dev/null +++ b/pptagent/prompts/ask_category.txt @@ -0,0 +1,15 @@ +Analyze the content layout and media types in the provided slide images. +Your objective is to provide a concise, descriptive title that captures purely the layout pattern. +Requirements: +Focus on HOW content is structured and presented, not WHAT the content is +Describe the number, visual arrangement, and interaction between different elements (text, images, diagrams, etc.) + +Avoid: +Specific topics or subjects, and detailed content descriptions + +Example Outputs: +One Central Square Chart with a explanatory paragraph +Picture and three illustrative key points +Two Landscape Images with Descriptive Text Below Each + +Output: Provide a one-line layout pattern description, without line breaks or other formatting. 
diff --git a/pptagent/prompts/caption.txt b/pptagent/prompts/caption.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f33c6f993d73dbac9349452444148cc07c656dc --- /dev/null +++ b/pptagent/prompts/caption.txt @@ -0,0 +1,7 @@ +Describe the main content of the image in less than 50 words, avoiding unnecessary details or comments. +Additionally, classify the image as 'Table', 'Chart', 'Diagram', 'Banner', 'Background', 'Icon', 'Logo', etc. or 'Picture' if it cannot be classified as one of the above. +Give your answer in the following format: +: +Example Output: +Chart: Bar graph showing quarterly revenue growth over five years. Color-coded bars represent different product lines. Notable spike in Q4 of the most recent year, with a dotted line indicating industry average for comparison. +Now give your answer in one sentence only, without line breaks: diff --git a/pptagent/prompts/category_split.txt b/pptagent/prompts/category_split.txt new file mode 100644 index 0000000000000000000000000000000000000000..c47ae36b3861c0ed668f392510552f588cd34fc4 --- /dev/null +++ b/pptagent/prompts/category_split.txt @@ -0,0 +1,40 @@ +You are an expert presentation analyst specializing in categorizing PowerPoint slides, focusing on structural slides (Opening, Table of Contents, Section Outline, and Ending) that guide the presentation's flow. + +Instructions: +1. Analyze the provided slides and identify the existence four categories of structural slides: Opening, Table of Contents, Section Outline, and Ending, based on their content and position. +2. Include only categories present in the slides in the output, with their corresponding slide numbers. Do not include absent categories or force matches. +3. Structural Characteristics: + - Position and Quantity: + - Opening, Table of Contents, and Ending: Typically appear as single slides at the start (first and second slides) and end of the presentation respectively. + - Section Outline: Section Outline slides **must be multiple slides**, each interleave several slides to detail the content of the section. + + - Content: + - Opening: Minimal content, often meta-information (e.g., title, presenter, or context). + - Table of Contents: Lists the title of each section, optionally use numbers or bullets to indicate the order, acting as a presentation roadmap. + - Ending: Minimal content, often meta-information (e.g., "Thank You" or contact details). + - Section Outline: + - Contains a **section title** that **strictly matches** the Table of Contents (identical wording), if present. + - Concise content including a section title (matching the table of contents, if present), with optional section number or brief introduction. + - Content is concise, optionally including a section number or one-line brief overview. + +Example Outputs: +- Complete categorization: +```json +{ + "opening": [1], + "table of contents": [2], + "section outline": [3, 7], + "ending": [12] +} +``` +- Missing Table of Contents and Section Outline: +```json +{ + "opening": [1], + "ending": [10] +} +``` + +Input: {{slides}} + +Output: diff --git a/pptagent/prompts/heading_extract.txt b/pptagent/prompts/heading_extract.txt new file mode 100644 index 0000000000000000000000000000000000000000..eb83df35b14bd5db28323af1a732d55961ba9f52 --- /dev/null +++ b/pptagent/prompts/heading_extract.txt @@ -0,0 +1,16 @@ +You are a Markdown formatting assistant. I'll provide you with a Markdown text containing headings with potentially inconsistent levels (#, ##, etc.). 
Your task is to identify and extract the top-level headings based on their semantic and logical structure. + +Your task: +1. Analyze the logical structure of the headings. +2. Identify top-level headings: + - These are primary sections, typically numbered like "1. Introduction", "2. Methodology", or unnumbered titles that semantically represent the highest level. + - Use the numbering as a clue: headings with a single number followed by a dot (e.g., "1.", "2.") are considered top-level. + - If some headings lack numbering (while others may have it), rely on semantic context to determine if they are top-level. + - Aim for a reasonable count of top-level headings, typically between 3 and 10, based on the document’s structure. +3. Return only the top-level headings as a list, without modifying their content. + +Here's the heading list to process: + +{{headings}} + +Output: Please provide the extracted top-level headings in JSON format as a list: diff --git a/pptagent/prompts/lengthy_rewrite.txt b/pptagent/prompts/lengthy_rewrite.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b486736efb9d6037af1eb130a226fd6771282e9 --- /dev/null +++ b/pptagent/prompts/lengthy_rewrite.txt @@ -0,0 +1,18 @@ +You are a presentation writing expert tasked with rewriting a given presentation element to be concise and align with presentation style. + + +Provide your output as a JSON array of strings, e.g., +```json +[ + "text1", + "text2" +] +``` + +Rewrite or summarize the content of the element {{ el_name }} to approximately {{ suggested_characters }}. + +The current content is: + +{{ content }} + +Output: diff --git a/pptagent/prompts/markdown_image_caption.txt b/pptagent/prompts/markdown_image_caption.txt new file mode 100644 index 0000000000000000000000000000000000000000..195a9135b69d449206dd8df536b50e54929017c2 --- /dev/null +++ b/pptagent/prompts/markdown_image_caption.txt @@ -0,0 +1,14 @@ +Describe the main content of the image in less than 50 words, avoiding unnecessary details. Additionally, consider nearby chunks from a document, which may be of unreliable quality (e.g., vague, incomplete, or misleading), and prioritize the image content over the nearby chunks when crafting the description. Classify the image as 'Table,' 'Chart,' 'Diagram,' 'Banner,' 'Background,' 'Icon,' 'Logo,' etc., or 'Picture' if it cannot be classified as one of the above. Give your answer in the following format: + +: + +Example Output: + +Chart: Bar graph showing quarterly revenue growth over five years. Color-coded bars represent different product lines. Notable spike in Q4 of the most recent year, with a dotted line indicating industry average for comparison. + +Input: + +Nearby chunks: +{{markdown_caption}} + +Output: diff --git a/pptagent/prompts/markdown_table_caption.txt b/pptagent/prompts/markdown_table_caption.txt new file mode 100644 index 0000000000000000000000000000000000000000..d68df1bc6c17577cc6b05e6c41ad467f7864e4ee --- /dev/null +++ b/pptagent/prompts/markdown_table_caption.txt @@ -0,0 +1,10 @@ +Given a Markdown table and nearby chunks from a document, write a new, clear, and concise caption that accurately describes the table's content. The nearby chunks may be of unreliable quality (e.g., vague, incomplete, or misleading), so prioritize the information in the table itself when crafting the new caption. + +Input: +- Markdown table: +{{markdown_content}} + +- Nearby chunks: +{{markdown_caption}} + +Provide only the new caption as your response, start with "Table:" and less than 50 words. 
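The markdown caption prompts above fill `{{markdown_content}}` / `{{markdown_caption}}` placeholders and expect a single-line reply of the form `<category>: <description>` (for example `Table: ...` or `Chart: ...`). Below is a minimal sketch of rendering and parsing that exchange; using jinja2 for the placeholders and a `call_llm` client are assumptions, not details taken from this diff.

```python
# Sketch: fill a caption prompt template and split the "<category>: <description>" reply.
from jinja2 import Template


def build_table_caption_prompt(template_text: str, table_md: str, nearby_chunks: str) -> str:
    # markdown_table_caption.txt exposes {{markdown_content}} and {{markdown_caption}}.
    return Template(template_text).render(
        markdown_content=table_md,
        markdown_caption=nearby_chunks,
    )


def parse_caption_reply(reply: str) -> tuple[str, str]:
    """'Table: Quarterly revenue by region.' -> ('Table', 'Quarterly revenue by region.')."""
    category, _, description = reply.partition(":")
    return category.strip(), description.strip()


# Usage with a hypothetical call_llm(prompt) -> str client:
# prompt = build_table_caption_prompt(
#     open("pptagent/prompts/markdown_table_caption.txt").read(), table_md, chunks
# )
# category, caption = parse_caption_reply(call_llm(prompt))
```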
diff --git a/pptagent/prompts/merge_metadata.txt b/pptagent/prompts/merge_metadata.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9c9109b8b786b60dad080b71874ba53edd39b94 --- /dev/null +++ b/pptagent/prompts/merge_metadata.txt @@ -0,0 +1,7 @@ +Given a list of dictionaries containing metadata extracted from different sections of a document in the order of those sections, +your task is to merge and refine this metadata into a single, coherent dictionary. +The provided metadata may be unreliable (e.g., containing meaningless, redundant, or inconsistent entries), and you should resolve these issues during the merging process. + +{{metadata}} + +Output the merged and refined metadata as a single dictionary in JSON format. diff --git a/pptagent/prompts/notes_generator.txt b/pptagent/prompts/notes_generator.txt new file mode 100644 index 0000000000000000000000000000000000000000..86cc8f616c6e6c38fa5bbcc670e9bfbaad21a344 --- /dev/null +++ b/pptagent/prompts/notes_generator.txt @@ -0,0 +1,20 @@ +You are a helpful assistant that generates concise and informative speaker notes for a presentation slide. + +Here is the slide content: +``` +{{slide_content}} +``` + +Here is the slide description: +``` +{{slide_description}} +``` + +**Task** +Using the slide content (Slide Content) and its description (Slide Description) below, craft a single-paragraph speech script that can be read aloud as-is. + +**Requirements** +1. Provide exactly one continuous paragraph—no bullet points or line breaks—so the delivery is smooth and natural; +2. Keep the language concise, engaging, and logically structured, adding brief context or explanations where helpful to aid audience understanding; +3. Aim for a length that around 30 to 40 words; +4. Output plain text only—do not use bullet points, numbers, or any Markdown formatting symbols. \ No newline at end of file diff --git a/pptagent/prompts/ppteval_coherence.txt b/pptagent/prompts/ppteval_coherence.txt new file mode 100644 index 0000000000000000000000000000000000000000..63615e72cf53c8d4016f5946a06a353a82ceeafd --- /dev/null +++ b/pptagent/prompts/ppteval_coherence.txt @@ -0,0 +1,28 @@ +You are an unbiased presentation analysis judge responsible for evaluating the coherence of the presentation. Please carefully review the provided summary of the presentation, assessing its logical flow and contextual information. Each score level requires that all evaluation criteria meet the standards of that level. +Scoring Criteria (Five-Point Scale) + +1 Point: +The logical structure is chaotic, making it difficult for the audience to understand. + +2 Points: +The logical structure is generally reasonable, with minor issues in transitions. + +3 Points: +The presentation demonstrates a clear and logical structure, with smooth transitions between sections. However, it lacks essential background information. + +4 Points: +The presentation features a well-organized logical flow and includes basic background information (e.g., speaker, date, or institution). + +5 Points: +The narrative structure is engaging and meticulously organized with detailed and comprehensive background information(speaker/institution, date, and acknowledgments/conclusion) included. + +Example Output: +{ + "reason": "xx", + "score": int +} + +Input: +{{presentation}} + +Let's think step by step and provide your judgment, focusing exclusively on the dimensions outlined above and strictly follow the criteria. 
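ppteval_coherence.txt above, and the content and style judges that follow, all request a JSON object holding a free-text `reason` and an integer `score` on a five-point scale. A small validator for such replies is sketched below; the regex fallback for prose-wrapped JSON is an assumption about model behaviour, not something the prompts guarantee.

```python
# Sketch: validate a PPTEval-style judge reply of the form {"reason": str, "score": 1-5}.
import json
import re


def parse_judge_reply(reply: str) -> dict:
    match = re.search(r"\{.*\}", reply, re.DOTALL)  # tolerate prose around the JSON object
    if match is None:
        raise ValueError("no JSON object found in judge reply")
    parsed = json.loads(match.group(0))
    score = int(parsed["score"])
    if not 1 <= score <= 5:
        raise ValueError(f"score {score} is outside the five-point scale")
    return {"reason": str(parsed.get("reason", "")), "score": score}


print(parse_judge_reply('{"reason": "clear flow, but no speaker or date given", "score": 3}'))
```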
diff --git a/pptagent/prompts/ppteval_content.txt b/pptagent/prompts/ppteval_content.txt new file mode 100644 index 0000000000000000000000000000000000000000..5c22b9766b62c066d3a6ba8fc21e74afb3db06e7 --- /dev/null +++ b/pptagent/prompts/ppteval_content.txt @@ -0,0 +1,27 @@ +You are an unbiased presentation analysis judge responsible for evaluating the quality of slide content. Please carefully review the provided description of the slide, assessing its content, and provide your judgement in a JSON object containing the reason and score. Each score level requires that all evaluation criteria meet the standards of that level. + +Scoring Criteria (Five-Point Scale): + +1 Point (Poor): +The text on the slides contains significant grammatical errors or is poorly structured, making it difficult to understand. + +2 Points (Below Average): +The slides lack a clear focus, the text is awkwardly phrased, and the overall organization is weak, making it hard to engage the audience. + +3 Points (Average): +The slide content is clear and complete but lacks visual aids, resulting in insufficient overall appeal. + +4 Points (Good): +The slide content is clear and well-developed, but the images have weak relevance to the theme, limiting the effectiveness of the presentation. + +5 Points (Excellent): +The slides are well-developed with a clear focus, and the images and text effectively complement each other to convey the information successfully. + +Example Output: +{ + "reason": "xx", + "score": int +} +Input: +{{descr}} +Please evaluate the slide step by step, ensuring your judgment strictly adheres to the scoring criteria. diff --git a/pptagent/prompts/ppteval_describe_content.txt b/pptagent/prompts/ppteval_describe_content.txt new file mode 100644 index 0000000000000000000000000000000000000000..317c55d361cddff4739e2d56d961af8bda587f02 --- /dev/null +++ b/pptagent/prompts/ppteval_describe_content.txt @@ -0,0 +1,9 @@ +Please describe the input slide based on the following three dimensions: +1. Information Density +Whether the slide conveys too lengthy or too little information, resulting in a large white space without colors or images. +2. Content Quality +Check if there are any grammatical errors or unclear expressions of textual content. +3. Images and Relevance +Assess the use of visual aids such as images or icons, their presence, and how well they relate to the theme and content of the slides. + +Provide an objective and concise description, focusing solely on the specified dimensions. diff --git a/pptagent/prompts/ppteval_describe_style.txt b/pptagent/prompts/ppteval_describe_style.txt new file mode 100644 index 0000000000000000000000000000000000000000..9305ad4bb9d8ff8507853c88ae668260e9ec8cb7 --- /dev/null +++ b/pptagent/prompts/ppteval_describe_style.txt @@ -0,0 +1,9 @@ +Describe the input slide based on the following dimensions: +1. Visual Consistency +Evaluate if any stylistic choices affect readability, such as overlapping elements or low contrast. +2. Color Scheme +Identify the colors used in the slide and determine if the design is monochromatic (black and white) or colorful (including gray). +3. Use of Visual Elements +Assess the presence of supporting visual elements, such as backgrounds, textures, patterns, or geometric shapes (e.g., rectangles, circles, bold dividers). + +Provide an objective and concise description, focusing solely on the specified dimensions. 
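The two describe prompts above produce per-slide descriptions that the scoring prompts consume through their `{{descr}}` placeholder, giving a describe-then-judge pipeline. The sketch below wires those stages together; `describe_image` and `call_llm` are hypothetical client callables, and rendering the placeholder with jinja2 is an assumption about how the templates are filled.

```python
# Sketch of the two-stage PPTEval flow: describe a slide image, then judge the description.
from pathlib import Path

from jinja2 import Template

PROMPT_DIR = Path("pptagent/prompts")


def evaluate_slide(image_path: str, dimension: str, describe_image, call_llm) -> dict:
    if dimension not in ("content", "style"):
        raise ValueError("dimension must be 'content' or 'style'")
    describe_prompt = (PROMPT_DIR / f"ppteval_describe_{dimension}.txt").read_text()
    judge_template = Template((PROMPT_DIR / f"ppteval_{dimension}.txt").read_text())

    descr = describe_image(describe_prompt, image_path)       # vision model -> slide description
    judgement = call_llm(judge_template.render(descr=descr))  # language model -> 1-5 judgement
    return {"dimension": dimension, "description": descr, "raw_judgement": judgement}
```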
diff --git a/pptagent/prompts/ppteval_extract.txt b/pptagent/prompts/ppteval_extract.txt new file mode 100644 index 0000000000000000000000000000000000000000..b4fab949c05cdcad91ff7f886872c200b203019f --- /dev/null +++ b/pptagent/prompts/ppteval_extract.txt @@ -0,0 +1,20 @@ +You are an expert presentation content extractor responsible for analyzing and summarizing key elements and metadata of presentations. Your task is to extract and provide the following information: +1. Slide Descriptions: conclude the purpose of each slide. +2. Presentation Metadata: Identify explicit background information (presented as a standalone paragraph and not embedded within other paragraphs), such as the author, speaker, date, and other directly stated details from the opening and closing slides. + +Example Output: +{ + "slide_descriptions": [ + "This slide introduces the xx, xx.", + "...", + ], + "background": { + "speaker": "speaker x", + "date": "date x" + } +} + +Input: +{{presentation}} + +Output: diff --git a/pptagent/prompts/ppteval_style.txt b/pptagent/prompts/ppteval_style.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e9548362a1ac58ffdc6986076fdf323d357d511 --- /dev/null +++ b/pptagent/prompts/ppteval_style.txt @@ -0,0 +1,28 @@ +You are an unbiased presentation analysis judge responsible for evaluating the visual appeal of slides. Please carefully review the provided description of the slide, assessing their aesthetics only, and provide your judgment in a JSON object containing the reason and score. Each score level requires that all evaluation criteria meet the standards of that level. + +Scoring Criteria (Five-point scale): + +1 Point (Poor): +There is a conflict between slide styles, making the content difficult to read. + +2 Points (Fair): +The slide uses monotonous colors(black and white), ensuring readability while lacking visual appeal. + +3 Points (Average): +The slide employs a basic color scheme; however, it lacks supplementary visual elements such as icons, backgrounds, images, or geometric shapes(like rectangles), making it look plain. + +4 Points (Good): +The slide uses a harmonious color scheme and contains some visual elements(like icons, backgrounds, images, or geometric shapes); however, minor flaws may exist in the overall design. + +5 Points (Excellent): +The style of the slide is harmonious and engaging, the use of supplementary visual elements like images and geometric shapes enhances the slide’s overall visual appeal. + +Example Output: +{ + "reason": "xx", + "score": int +} + +Input: +{{descr}} +Please evaluate the slide step by step, ensuring your judgment strictly adheres to the scoring criteria. diff --git a/pptagent/prompts/section_summary.txt b/pptagent/prompts/section_summary.txt new file mode 100644 index 0000000000000000000000000000000000000000..e470cd20503ba948c4bebfd078266b98886b70eb --- /dev/null +++ b/pptagent/prompts/section_summary.txt @@ -0,0 +1,6 @@ +Please summarize the content of the document section into a concise paragraph less than 100 words. + +Input: +{{section_content}} + +Output: diff --git a/pptagent/prompts/table_parsing.txt b/pptagent/prompts/table_parsing.txt new file mode 100644 index 0000000000000000000000000000000000000000..9d86a4f2677800ec3f4ae176e162024dd41b0bd1 --- /dev/null +++ b/pptagent/prompts/table_parsing.txt @@ -0,0 +1,41 @@ +You are an AI assistant tasked with refining tabular data. The table may contain ambiguities or errors in labeling. +Your job is to interpret the given table and its caption to: +1. 
Identify issues in the table structure (e.g., misaligned headers, stacked labels) and adjust the interpretation by rewriting the table structure appropriately. +2. Ensure adherence to basic table conventions, such as: + - The empty top-left cell is commonly used. + - Each cell generally contains a single word or number +3. Infer which cells should be merged based on the table’s hierarchical layout, where top-level headers may span multiple rows or columns. Consider the semantic relationship between adjacent cells (above, below, or beside) to ensure merged cells reflect logical groupings. +4. Provide the final output in the following format: + - `table_data`: A refined 2D array of the table’s contents. + - `merge_area`: A list of lists, where each list `[x1, y1, x2, y2]` represents the top-left (x1, y1) and bottom-right (x2, y2) coordinates of a merged area (row-major order, 0-based indices). + +Example Input: +Caption: "Fruit Prices and Stock Levels" +[ + ["Fruit", "Attribute", ""], + ["", "Price Stock", ""], + ["Apple", "4", "8"], + ["Pear", "7", "6"] +] + +Example Output: +```json +{ + "table_data": [ + ["Fruit", "", ""], + ["", "Price", "Stock"], # split stacked header based on caption intent + ["Apple", "4", "8"], + ["Pear", "7", "6"] + ], + "merge_area": [ + [0, 1, 0, 2] # merge the second and third cells of the first row + ] +} +``` + +Tabular Data: +{{cells}} +Table Caption +{{caption}} + +Output: give your final output in json format wrapped in ```json``` diff --git a/pptagent/resource/pic_placeholder.png b/pptagent/resource/pic_placeholder.png new file mode 100644 index 0000000000000000000000000000000000000000..db8065a59734417f7441b740d3919a9970d46a56 Binary files /dev/null and b/pptagent/resource/pic_placeholder.png differ diff --git a/pptagent/roles/agent.yaml b/pptagent/roles/agent.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2263c13026af32877abd22863ca0b127b957248a --- /dev/null +++ b/pptagent/roles/agent.yaml @@ -0,0 +1,72 @@ +system_prompt: | + You are a multifunctional content processing and code-generation assistant specializing in parsing HTML structures and content frameworks. Your task is to convert slide content and editing requirements into accurate API call sequences. You must strictly follow the rules to ensure precision and consistency with the input logic. You must use English. +template: | + Task Description: + Generate an API call sequence based on the input slide code and available content to replace the existing slide content. Follow the rules below: + + Available API Functions: + {{api_docs}} + + Rules and Requirements + 1. Slide Content Generation Rules + • Use the structure defined in the schema to generate content. + • Extract key information from text and images_info to generate core elements (e.g., slide_title, main_content). Ensure the content is semantically consistent and concise. + • Supportive elements (e.g., logo) should only be generated if relevant information is provided. + + 2. HTML Structure Rules + • Determine the parent-child relationship of elements based on indentation in the HTML structure. + • Ensure all and elements are fully processed, avoiding omissions. + + 3. Quantity Adjustment Rules + • Adding Elements: + • Use clone_paragraph to add paragraphs. The new paragraph’s paragraph_id should be the current maximum paragraph_id + 1, while keeping the span_id unchanged. + • Prioritize cloning paragraphs with existing special styles (e.g., bold, color). 
+ • Removing Elements: + • Use del_span or del_image to reduce content. Always remove elements from the end of the paragraph first. + • Maintaining Quantity: + • If the quantity remains unchanged, only replace the content without cloning or deleting elements. + + 4. Content Replacement Rules + • Text Replacement: + • Use replace_span to replace content within a paragraph. Add styles (e.g., bold, color) where needed. + • Image Replacement: + • Use replace_image to replace image paths, ensuring the images match the input. + + 5. Operation Restrictions + • Each API call must perform only one type of operation, either clone or del, but not both. + • Ensure the generated API call sequence strictly follows the input logic and avoids generating unrelated content. + + Example Output: + # Replace title content + replace_span(0, 0, 0, "New Slide Title") + + # Add a new main content paragraph + clone_paragraph(1, 0) # The new paragraph's paragraph_id is 2, based on the current max paragraph_id of 1 + replace_span(1, 2, 0, "Generated content based on the input text") + + # Delete unnecessary content from the paragraph + del_span(1, 1, 0) + + # Replace project logo + replace_image(2, "images/new_logo.png") + + Input: + - Schema: {{schema}} + - Outline: {{outline}} + - Metadata: {{metadata}} + - Reference Text: {{text}} + - Image Information: {{images_info}} + - Current Slide Content: {{edit_target}} + + Output: Output only the API call sequence. Add comments for each API call, explaining the purpose of the operation and the corresponding element. + +jinja_args: + - schema + - outline + - metadata + - text + - images_info + - edit_target + - api_docs +use_model: language +return_json: false diff --git a/pptagent/roles/coder.yaml b/pptagent/roles/coder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3dd7365c53a723fcab790a01b39339c4d7266515 --- /dev/null +++ b/pptagent/roles/coder.yaml @@ -0,0 +1,80 @@ +system_prompt: You are a Code Generator agent specializing in slide manipulation. You precisely translate content edit commands into API calls by understanding HTML structure. You must use English. +template: | + Generate API calls based on the provided commands, ensuring compliance with the specified rules and precise execution. + You must determine the parent-child relationships of elements based on indentation and ensure that all

and elements are modified. + + Each command follows this format: (element_class, type, quantity_change: int, old_data, new_data). + + Available APIs + + {{api_docs}} + + Steps + 1. Quantity Adjustment: + - If quantity_change = 0, modify the content only. + - If quantity_change > 0, use clone_paragraph to add the specified number of paragraphs from the same element_class. The paragraph_id for newly cloned paragraphs should be the current maximum paragraph_id of the parent element plus 1. + - If quantity_change < 0, use del_paragraph or del_image to remove the specified number of tail elements. + - Each command’s API call group must exclusively use either clone_paragraph or del_paragraph/del_image based on the `quantity_change` + 2. Content Modification: + - Text Content: Use replace_paragraph to modify the content. + - Image Content: Use replace_image to replace image resources. + 3. Output Format: + - Add comments to each API call group, explaining the intent of the original command and the associated element_class. + - For cloning operations, annotate the paragraph_id of the newly created paragraphs. + + Example Input: + +

+

+ WorldFAIR: Global cooperation on FAIR data policy and practice +

+
+ +
+
    +
  • + Two-year project to advance implementation... +
  • +
  • + Funded by the European Union... +
  • +
+
+ + logo: project of xx + + [ + ("title", "text", "quantity_change: 0", ["WorldFAIR: Global cooperation on FAIR data policy and practice"], ["New Title"]), + ("project_description", "text", "quantity_change: 1", ["Two-year project to advance implementation of the FAIR principles"], ["New project description1", "New project description2"]), + ("funding_info", "text", "quantity_change: -1", ["Funded by the European Union"], []), + ("project_logo", "image", "quantity_change: 0", ["logo: project of xx"], ["new_logo.png"]) + ] + + Example Output + # ("title", "text", "quantity_change: 0", ["WorldFAIR: Global cooperation on FAIR data policy and practice"], ["New Title"]) + replace_paragraph(0, 0, "New Title") + + # ("project_description", "text", "quantity_change: 1", ["Two-year project to advance implementation of the FAIR principles"], ["New project description1", "New project description2"]) + clone_paragraph(1, 0) # New cloned paragraph_id is 2 as the current max paragraph_id is 1 + replace_paragraph(1, 0, "New project description1") + replace_paragraph(1, 2, "New project description2") + + # ("funding_info", "text", "quantity_change: -1", ["Funded by the European Union"], []) + del_paragraph(1, 1) + + # ("project_logo", "image", "quantity_change: 0", ["logo: project of xx"], ["new_logo.png"]) + replace_image(2, "new_logo.png") + + Current Slide Content: + {{edit_target}} + + Command List: + {{command_list}} + + Please output only the API call sequence, one call per line, wrapped in ```python and ```, with comments for corresponding commands. +jinja_args: + - api_docs + - edit_target + - command_list +use_model: language +return_json: false diff --git a/pptagent/roles/content_organizer.yaml b/pptagent/roles/content_organizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0eb27eeded57b8f6e498b272e38fa067195672f7 --- /dev/null +++ b/pptagent/roles/content_organizer.yaml @@ -0,0 +1,45 @@ +system_prompt: You are an intelligent assistant tasked with extracting "key points" from the given content source. Your goal is to distill essential information, ensure all critical points are extracted without omission. +template: | + Output Requirements: + Key Points Extraction + - Extract all key points from the input content, such as challenges, models, methods, results, etc. + - Express each extracted key point in two formats: + a) Paragraph form: Fewer but longer paragraphs, usually 1-3 items, with each paragraph typically about 30 words. + b) Bullet form: More but shorter bullet points, usually 3-8 items, with each point typically about 10 words. + - If no content is provided, leave the key points an empty list. + Example Output: + + ```json + [ + { + "pointName": "Challenges", + "paragraphForm": "One of the main challenges in this domain is the ability to scale efficiently while maintaining accuracy. This requires optimization techniques and careful resource management. Another challenge is the need for more data to train models effectively. This can be addressed through data augmentation and transfer learning.", + "bulletForm": [ + "Scaling while maintaining accuracy is difficult.", + "Requires optimization techniques.", + "Careful resource management is necessary.", + "More data is needed for effective model training.", + "Data augmentation and transfer learning can address this." 
+ ] + }, + { + "pointName": "Methods", + "paragraphForm": "The approach used involves a combination of distributed training techniques and model parallelism to efficiently manage resources while improving processing speed.", + "bulletForm": [ + "Uses distributed training techniques.", + "Employs model parallelism.", + "Aims to improve processing speed." + ] + } + ] + ``` + + Input: + {{content_source}} + + Output: give your output in JSON format, you must use English. + +jinja_args: + - content_source +use_model: language +return_json: true diff --git a/pptagent/roles/doc_extractor.yaml b/pptagent/roles/doc_extractor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c18a9893a7a4fd7c7448117aa1594f866293000 --- /dev/null +++ b/pptagent/roles/doc_extractor.yaml @@ -0,0 +1,44 @@ +system_prompt: | + You are a document content extractor specialist, expert in losslessly extracting content from a single section of various types of Markdown documents, and reorganizing it into a structured format. You must use English. +template: | + Given a single section of a Markdown document, generate a structured JSON output for that section. + Step-by-Step Instructions: + 1. Identify Subsections: Within the provided section, use heading levels (e.g., H2, H3) and logical relationships to identify subsections. + 2. Extract Titles and Content: generate concise (<= 5 words) and appropriate titles based on content. Ensure the content is complete and not truncated. + 3. Process Content: + - Retain all original text as provided, and ensure the most important content is retained without truncation. + 4. Extract Available Metadata: Extract available metadata (e.g., title, author, publish date, organization, etc.) from the section’s content or context; include only the keys for which data is present. + + Example Output: + { + "metadata": { + `key`: `value` // leave it empty if no metadata is present + }, + "title": "Section 1", + "subsections": [ + { + "title": "Subsection 1.1", + "content": "content of subsection 1.1" + }, + { + "title": "Subsection 1.2", + "content": "content" + }, + { + "title": "Subsection 1.3", + "content": "content" + } + ] + } + + Input: + + Markdown Document: + {{ markdown_document }} + + Output: Give your output in JSON format, use the same language as the input document, make sure all valid text is retained. + +jinja_args: + - markdown_document +use_model: language +return_json: true diff --git a/pptagent/roles/editor.yaml b/pptagent/roles/editor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af2074470ef1a8d3b33e045de23727714bef7977 --- /dev/null +++ b/pptagent/roles/editor.yaml @@ -0,0 +1,50 @@ +system_prompt: You are an expert Editor agent. Transform reference text and images into slide content, following schema rules and using only provided materials. Ensure the content is engaging and within the character limit. Always generate content in the same language as the reference text. You must use English. +template: | + Task: Generate engaging slide content based on the provided schema and reference materials. + + Requirements: + 1. Content Generation Rules: + - Follow default_quantity for elements, adjust when necessary and it is not 1, you can properly rewrite the content to fit the schema + - Ensure text content meets character limits + - Generated text should use concise and impactful presentation style + + 2. 
Core Elements: + - Must extract essential content from reference text (e.g., slide_title, main_content) and maintain semantic consistency + - Must include images that support the main content (e.g., diagrams for explanations, visuals directly discussed in text) + + 3. Supporting Elements (e.g., presenters, logo images): + - Generate only when relevant content exists in reference materials + + Generate content for each element and output in the following format: + { + "element1": { + "data": ["text1", "text2"] for text elements + or ["/path/to/image", "..."] for image elements + }, + } + + Input: + Presentation Outline: + {{outline}} + + Current Slide Description: + {{slide_description}} + + Metadata of Presentation: + {{metadata}} + + Slide Content Source: + {{slide_content}} + + Schema: + {{schema}} + + Output: Ensure the generated content strictly adheres to the schema specifications, follows the slide style, reader-friendly, and use the same language as the reference text. +jinja_args: + - outline + - slide_description + - metadata + - slide_content + - schema +use_model: language +return_json: true diff --git a/pptagent/roles/layout_selector.yaml b/pptagent/roles/layout_selector.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b9b72520f80e9388423e8fd092f51151b03fe87 --- /dev/null +++ b/pptagent/roles/layout_selector.yaml @@ -0,0 +1,45 @@ +system_prompt: | + You are an intelligent assistant tasked with selecting the most suitable layout from a set of predefined options based on the provided slide information and providing detailed reasoning. You must use English. +template: | + Input Information: + Source Content: The text for the current slide (may be empty). + Image Information: Images and their captions. + Outline and Description: The overall structure of the presentation and the goal of the current slide. + Layout Options: Including name (e.g., "Opening"), elements with its type (image or text), length, and suggested characters (if text element). + + Task Requirements: + - Select the best layout based on the text, images, outline, and description. + - Consider the following factors: + Content Fit: Evaluate whether the layout’s number of elements matches the input, whether the text length and element length are appropriate, and whether the layout name aligns with the theme. + Image Fit: Assess the relevance of the images to the theme and their enhancement to the content; if highly relevant and beneficial, prioritize layouts with images; if relevance is low or text dominates, a text-only layout may be chosen. + If no images are provided, use the text-only layout. + - Output: + - Layout name. + - Detailed Reasoning: Analyze the fit between the layout and content (element count, text length, theme alignment) and the fit between images and content (relevance and enhancement), explaining why this layout was chosen and whether images are used. + + Example Output: + { + "reasoning": "The current slide is Slide 2, themed \"team introduction,\" with the goal of showcasing team members and their backgrounds. The text (50 characters) is concise and close to the middle of the Image-Text layout's character range (30-100), making it suitable for summarizing team details. The provided team photo is highly relevant to the theme, offering a visual representation of the team that significantly enhances audience understanding and engagement, aligning with the rule to prioritize image-inclusive layouts when applicable. 
The Image-Text layout, with 1 image slot and 1 text slot, perfectly matches the input needs. In contrast, Opening:Text (100-300 characters) is better suited for a text-heavy opening slide, while Stage Analysis (2 images, 1 text) is excessive for a single image and short text, making Image-Text the optimal choice.", + "layout": "Image-Text" + } + + Input: + Outline: {{ outline }} + + Current Slide Description: + {{ slide_description }} + + Slide Content Source: + {{ slide_content }} + + Layout Options: {{ available_layouts }} + + Output: give your anwser in json format + +jinja_args: + - outline + - slide_description + - slide_content + - available_layouts +use_model: language +return_json: true diff --git a/pptagent/roles/notes_generator.yaml b/pptagent/roles/notes_generator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f831d16de6f0ab701686ac369b9fb820dbd1210f --- /dev/null +++ b/pptagent/roles/notes_generator.yaml @@ -0,0 +1,27 @@ +system_prompt: | + You are a helpful assistant that generates concise and informative speaker notes for a presentation slide. You must use English. + +template: | + Here is the slide content: + ``` + {{slide_content}} + ``` + + Here is the slide description: + ``` + {{slide_description}} + ``` + + **Task** + Using the slide content (Slide Content) and its description (Slide Description) below, craft a single-paragraph speech script that can be read aloud as-is. + + **Requirements** + 1. Provide exactly one continuous paragraph—no bullet points or line breaks—so the delivery is smooth and natural; + 2. Keep the language concise, engaging, and logically structured, adding brief context or explanations where helpful to aid audience understanding; + 3. Aim for a length that around 30 to 40 words; + 4. Output plain text only—do not use bullet points, numbers, or any Markdown formatting symbols. +jinja_args: + - slide_content + - slide_description +use_model: language +return_json: false \ No newline at end of file diff --git a/pptagent/roles/planner.yaml b/pptagent/roles/planner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..217f0fa50ac78caf82f74cd4da4e2443d0236377 --- /dev/null +++ b/pptagent/roles/planner.yaml @@ -0,0 +1,50 @@ +system_prompt: | + You are a skilled presentation designer tasked with crafting engaging and structured presentation outlines based on provided document overviews, ensuring accurate indexing, prioritizing important sections, and aligning with specified slide requirements. You must use English. +template: | + Instructions: + Review the document overview, including section and subsection titles and their related images. Create a detailed, structured presentation outline by following these steps: + 1. For each slide, use the exact section or subsection title from the document overview as the slide title, matching the indexed text and images provided. Distribute multiple images to different slides, each slide can only present a single image directly associated with a specific subsection, and do not repeat the same image in different slides. + 2. Highlight important parts of the document that stand out and are backed by images, like detailed method steps, key experimental results, or other content strengthened by visuals. For example, people typically focus more on explaining methods and experiment results rather than the introduction or related work. + 3. Ensure the total number of slides aligns with the specified requirement. 
+ + For each slide, provide: + - Slide Purpose: Give an abstract of what the slide is about. Do not include excessive information (e.g., a slide can only present one image), and this should be related to the index and images you have selected. + - Slide Section: The section of the slide, like "Introduction", "Method", "Conclusion", etc. It must be the same language as the document overview, and without numbering (e.g. "Introduction" instead of "1. Introduction"). + - Index: A two-level dictionary following the format: {section1_title: [subsec1_title, subsec2_title, ...]}}. Use exact section and subsection titles as they appear without any modification or formatting, e.g., "1. Introduction" instead of "Introduction", "2 Method" instead of "2 Method". + - Images: A list of images that are related to the slide, select the most relevant images from the document overview, each image should be a string exactly matching the caption of the image. + + Example Output: + [ + { + "purpose": "introduce ...", + "section": "Introduction", + "indexs": {"Section 1": ["Section 1.1", "Section 1.2"]}, + "images": [], + }, + { + "purpose": "detail ...", + "section": "Method", + "indexs": {"Section 2": ["Section 2.1", "Section 2.2"]}, + "images": ["workflow of the method..."], + }, + { + "purpose": "illustrate the ...", + "section": "Experiment", + "indexs": {"Section 3": ["Section 3.1", "Section 3.2"]}, + "images": ["experiment results..."], + }, + ..., + ] + + Input: + Required Number of Slides: {{ num_slides }} + + Document Overview: + {{ document_overview }} + + Output: the `indexs` should be an exact match of the title of the section and subsection. +jinja_args: + - num_slides + - document_overview +use_model: language +return_json: true diff --git a/pptagent/roles/schema_extractor.yaml b/pptagent/roles/schema_extractor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9889a026e1d1d0e7427653498f45dcb70a6c2612 --- /dev/null +++ b/pptagent/roles/schema_extractor.yaml @@ -0,0 +1,63 @@ +system_prompt: You are an expert in extracting slide schemas for slides. Your task is to analyze slide HTML and create a JSON schema that captures the slide elements and their relationships. Your output should be a dictionary in JSON format, where each key-value pair strictly corresponds to the text content in a `
<p>
` element or the `alt` attribute of an `` element. You must use English. +template: | + Please analyze the slide elements and create a structured slide schema in JSON format. You should: + + 1. Understand the html representation of the slide, especially the style, layout, and the logical relationship between elements + + 2. For each element, extract the following information: + - "name": The role and feature of the element, such as "section/main/sub title", "content bullets/paragraph", "portrait/landscape/square image", "presenters", "dates", "acknowledgments", etc. + - "description": A clear description of the element's purpose, style feature(e.g. red bold bullets, small circular images), do not mention any content detail + - "type": Literal["text", "image"] + - "data": List[str] + * For text elements: The content of each paragraph, defined as text within `
<p>
` tags, must be treated as a single, distinct item. + - Preserve newlines (`\n`) within paragraphs, ensuring they remain intact and are not split into multiple items. Only `
<p>
` tags serve as separators. + - Do not combine multiple `
<p>
` tags into a single item, regardless of logical or narrative connections. Each `
<p>
` tag represents a separate item. + * For image elements: Use the `alt` attribute of the tag as the data of the image + + 3. Do not include any empty elements, only given elements should be included + + Example Input: +

+

text0

+
+
+

text1\ntext2\ntext3

+
+ caption of image + + Example Output: + { + "main title": { + "description": "main title of the slide", + "type": "text", + "data": ["text0"] + }, + "content bullets": { + "description": "content bullets of the slide", + "type": "text", + "data": ["text1\ntext2\ntext3"] + }, + "main image": { + "description": "main image of the slide", + "type": "image", + "data": ["caption of image"] + } + } + Example format: + { + "element_name": { + "description": "purpose of this element", # do not mention any detail, just purpose + "type": "text" or "image", + "data": ["text1", "text2"] or ["logo:...", "logo:..."] + } + } + Input: + {{slide}} + + Output: Please provide the slide schema in a dict of JSON format + +jinja_args: + - slide + +use_model: language +return_json: True diff --git a/pptagent/runs/2025-07-05/165c18a8-8169-480d-a388-615ab6f420b0/task.json b/pptagent/runs/2025-07-05/165c18a8-8169-480d-a388-615ab6f420b0/task.json new file mode 100644 index 0000000000000000000000000000000000000000..2b75c02249a3cb62e833b05003738f5099889347 --- /dev/null +++ b/pptagent/runs/2025-07-05/165c18a8-8169-480d-a388-615ab6f420b0/task.json @@ -0,0 +1 @@ +{"numberOfPages": 5, "pptx": "c1eb4d337b2aa71bec0b0bda89322db2", "pdf": "37fd83b93256101767cb27322fba795f"} \ No newline at end of file diff --git a/pptagent/runs/2025-07-05/170a771a-eb46-4f03-bba2-c77dae7dc110/final.pptx b/pptagent/runs/2025-07-05/170a771a-eb46-4f03-bba2-c77dae7dc110/final.pptx new file mode 100644 index 0000000000000000000000000000000000000000..9482af8a2cba66312d26a07f871d416b19d55174 Binary files /dev/null and b/pptagent/runs/2025-07-05/170a771a-eb46-4f03-bba2-c77dae7dc110/final.pptx differ diff --git a/pptagent/runs/2025-07-05/170a771a-eb46-4f03-bba2-c77dae7dc110/task.json b/pptagent/runs/2025-07-05/170a771a-eb46-4f03-bba2-c77dae7dc110/task.json new file mode 100644 index 0000000000000000000000000000000000000000..2b75c02249a3cb62e833b05003738f5099889347 --- /dev/null +++ b/pptagent/runs/2025-07-05/170a771a-eb46-4f03-bba2-c77dae7dc110/task.json @@ -0,0 +1 @@ +{"numberOfPages": 5, "pptx": "c1eb4d337b2aa71bec0b0bda89322db2", "pdf": "37fd83b93256101767cb27322fba795f"} \ No newline at end of file diff --git a/pptagent/runs/2025-07-05/5284ef26-24fd-43c2-b435-5abad5de0cf8/task.json b/pptagent/runs/2025-07-05/5284ef26-24fd-43c2-b435-5abad5de0cf8/task.json new file mode 100644 index 0000000000000000000000000000000000000000..c9520f954e0142a748a3a0b367fe1f0fbd73783d --- /dev/null +++ b/pptagent/runs/2025-07-05/5284ef26-24fd-43c2-b435-5abad5de0cf8/task.json @@ -0,0 +1 @@ +{"numberOfPages": 3, "pptx": "0210ff6b414902fa05857e734dd5bcee", "pdf": "9145dbfce1296e2b0603293042aa883e"} \ No newline at end of file diff --git a/pptagent/runs/2025-07-05/a5c87f48-ac8c-4577-9a99-e5ba2e268bfc/task.json b/pptagent/runs/2025-07-05/a5c87f48-ac8c-4577-9a99-e5ba2e268bfc/task.json new file mode 100644 index 0000000000000000000000000000000000000000..d8ab43227c50540420208cfe18312696b4098c04 --- /dev/null +++ b/pptagent/runs/2025-07-05/a5c87f48-ac8c-4577-9a99-e5ba2e268bfc/task.json @@ -0,0 +1 @@ +{"numberOfPages": 4, "pptx": "c1eb4d337b2aa71bec0b0bda89322db2", "pdf": "37fd83b93256101767cb27322fba795f"} \ No newline at end of file diff --git a/pptagent/runs/2025-07-05/b43539a3-11a0-42bd-8f7e-ca33253615a1/task.json b/pptagent/runs/2025-07-05/b43539a3-11a0-42bd-8f7e-ca33253615a1/task.json new file mode 100644 index 0000000000000000000000000000000000000000..d8ab43227c50540420208cfe18312696b4098c04 --- /dev/null +++ 
b/pptagent/runs/2025-07-05/b43539a3-11a0-42bd-8f7e-ca33253615a1/task.json @@ -0,0 +1 @@ +{"numberOfPages": 4, "pptx": "c1eb4d337b2aa71bec0b0bda89322db2", "pdf": "37fd83b93256101767cb27322fba795f"} \ No newline at end of file diff --git a/pptagent/runs/2025-07-05/da363260-400b-4a00-b426-0282f3069046/final.pptx b/pptagent/runs/2025-07-05/da363260-400b-4a00-b426-0282f3069046/final.pptx new file mode 100644 index 0000000000000000000000000000000000000000..d363ad7656b7bd0c87f202adf28763fb6772d8da Binary files /dev/null and b/pptagent/runs/2025-07-05/da363260-400b-4a00-b426-0282f3069046/final.pptx differ diff --git a/pptagent/runs/2025-07-05/da363260-400b-4a00-b426-0282f3069046/task.json b/pptagent/runs/2025-07-05/da363260-400b-4a00-b426-0282f3069046/task.json new file mode 100644 index 0000000000000000000000000000000000000000..2b75c02249a3cb62e833b05003738f5099889347 --- /dev/null +++ b/pptagent/runs/2025-07-05/da363260-400b-4a00-b426-0282f3069046/task.json @@ -0,0 +1 @@ +{"numberOfPages": 5, "pptx": "c1eb4d337b2aa71bec0b0bda89322db2", "pdf": "37fd83b93256101767cb27322fba795f"} \ No newline at end of file diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_27.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_27.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..c0634c7a006e30b50a6461037ce77313de502e60 Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_27.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_34.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_34.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..8148c7fa84fc604bf8ed33a29dbba0db60e943fb Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_34.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_42.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_42.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..21a880ea5935b16630192986ad78d24b1612fe1a Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_42.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_66.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_66.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..ad1698636602d6222a8cf4eb210ad1f2f551261f Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_66.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_74.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_74.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..eb3b15ab2894b1494d615ca8650e3e19d3b5b005 Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_74.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_85.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_85.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..5ca14e9f7b0eabfc3cfaadb23315df39ebfd48d2 Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_85.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Picture_2.jpeg 
b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Picture_2.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..c789b1c562ca98e1352eafbac9f9db677e106b7e Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Picture_2.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_1_Picture_52.jpeg b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_1_Picture_52.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..548b94a4c322997566b552c745db4d1cf7d0ab66 Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_1_Picture_52.jpeg differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/meta.json b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/meta.json new file mode 100644 index 0000000000000000000000000000000000000000..5f44434f7cdd689321edb60651a142cf159b2f5e --- /dev/null +++ b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/meta.json @@ -0,0 +1,693 @@ +{ + "table_of_contents": [ + { + "title": "Building effective agents", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 515.60107421875, + 129.55579376220703 + ], + [ + 1456.3255615234375, + 129.55579376220703 + ], + [ + 1456.3255615234375, + 216.3658676147461 + ], + [ + 515.60107421875, + 216.3658676147461 + ] + ] + }, + { + "title": "What are agents?", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 623.978759765625, + 960.8446655273438 + ], + [ + 956.5999755859375, + 960.8446655273438 + ], + [ + 956.5999755859375, + 997.7664184570312 + ], + [ + 623.978759765625, + 997.7664184570312 + ] + ] + }, + { + "title": "When (and when not) to use\nagents", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 617.9833984375, + 1626.3153076171875 + ], + [ + 1136.0601806640625, + 1626.3153076171875 + ], + [ + 1136.0601806640625, + 1704.6533203125 + ], + [ + 617.9833984375, + 1704.6533203125 + ] + ] + }, + { + "title": "When and how to use frameworks", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 620.75048828125, + 2129.1544189453125 + ], + [ + 1259.02587890625, + 2129.1544189453125 + ], + [ + 1259.02587890625, + 2171.3507080078125 + ], + [ + 620.75048828125, + 2171.3507080078125 + ] + ] + }, + { + "title": "Building blocks, workflows, and\nagents", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 615.21630859375, + 2918.6533203125 + ], + [ + 1202.76171875, + 2918.6533203125 + ], + [ + 1202.76171875, + 2997.6947021484375 + ], + [ + 615.21630859375, + 2997.6947021484375 + ] + ] + }, + { + "title": "Building block: The augmented LLM", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 618.444580078125, + 3215.7088623046875 + ], + [ + 1140.041015625, + 3215.7088623046875 + ], + [ + 1140.041015625, + 3250.8724365234375 + ], + [ + 618.444580078125, + 3250.8724365234375 + ] + ] + }, + { + "title": "The augmented LLM", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 464.5, + 3851.2904663085938 + ], + [ + 609.220947265625, + 3851.2904663085938 + ], + [ + 609.220947265625, + 3882.05859375 + ], + [ + 464.5, + 3882.05859375 + ] + ] + }, + { + "title": "Workflow: Prompt chaining", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 621.211669921875, + 4257.429748535156 + ], + [ + 997.9970703125, + 4257.429748535156 + ], + [ + 997.9970703125, + 4285.560607910156 + ], + [ + 621.211669921875, + 4285.560607910156 + ] + ] + }, + { + "title": "Workflow: Routing", + "heading_level": null, + "page_id": 0, + 
"polygon": [ + [ + 624.5, + 5309.69970703125 + ], + [ + 878.0260009765625, + 5309.69970703125 + ], + [ + 878.0260009765625, + 5337.83056640625 + ], + [ + 624.5, + 5337.83056640625 + ] + ] + }, + { + "title": "Workflow: Parallelization", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 623.517578125, + 6389.80517578125 + ], + [ + 977.705078125, + 6389.80517578125 + ], + [ + 977.705078125, + 6422.6268310546875 + ], + [ + 623.517578125, + 6422.6268310546875 + ] + ] + }, + { + "title": "Sectioning:", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 640.581298828125, + 7428.964370727539 + ], + [ + 786.314697265625, + 7428.964370727539 + ], + [ + 786.314697265625, + 7457.974319458008 + ], + [ + 640.581298828125, + 7457.974319458008 + ] + ] + }, + { + "title": "Workflow: Orchestrator-workers", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 622.59521484375, + 7972.80517578125 + ], + [ + 1076.2110595703125, + 7972.80517578125 + ], + [ + 1076.2110595703125, + 8002.350402832031 + ], + [ + 622.59521484375, + 8002.350402832031 + ] + ] + }, + { + "title": "Workflow: Evaluator-optimizer", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 611.52685546875, + 9054.620361328125 + ], + [ + 1053.3388671875, + 9054.620361328125 + ], + [ + 1053.3388671875, + 9086.267578125 + ], + [ + 611.52685546875, + 9086.267578125 + ] + ] + }, + { + "title": "The evaluator-optimizer workflow", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 461.642822265625, + 9604.930297851562 + ], + [ + 712.525634765625, + 9604.930297851562 + ], + [ + 712.525634765625, + 9633.061157226562 + ], + [ + 461.642822265625, + 9633.061157226562 + ] + ] + }, + { + "title": "Examples where evaluator-optimizer is useful:", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 623.517578125, + 9940.742431640625 + ], + [ + 1072.70849609375, + 9940.742431640625 + ], + [ + 1072.70849609375, + 9972.3896484375 + ], + [ + 623.517578125, + 9972.3896484375 + ] + ] + }, + { + "title": "Agents", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 614.755126953125, + 10197.4365234375 + ], + [ + 722.4810180664062, + 10197.4365234375 + ], + [ + 722.4810180664062, + 10226.80517578125 + ], + [ + 614.755126953125, + 10226.80517578125 + ] + ] + }, + { + "title": "Combining and customizing these\npatterns", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 623.517578125, + 12692.47412109375 + ], + [ + 1251.64697265625, + 12692.47412109375 + ], + [ + 1251.64697265625, + 12780.309814453125 + ], + [ + 623.517578125, + 12780.309814453125 + ] + ] + }, + { + "title": "Summary", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 615.21630859375, + 13029.76318359375 + ], + [ + 808.6976318359375, + 13029.76318359375 + ], + [ + 808.6976318359375, + 13068.410888671875 + ], + [ + 615.21630859375, + 13068.410888671875 + ] + ] + }, + { + "title": "Acknowledgements", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 618.90576171875, + 13679.747314453125 + ], + [ + 892.84765625, + 13679.747314453125 + ], + [ + 892.84765625, + 13707.854736328125 + ], + [ + 618.90576171875, + 13707.854736328125 + ] + ] + }, + { + "title": "Appendix 1: Agents in practice", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 620.75048828125, + 13875.653076171875 + ], + [ + 1182.4697265625, + 13875.653076171875 + ], + [ + 1182.4697265625, + 13911.653106689453 + ], + [ + 620.75048828125, + 13911.653106689453 + ] + ] + }, + { + "title": "A. 
Customer support", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 616.138671875, + 14162.843627929688 + ], + [ + 909.199951171875, + 14162.843627929688 + ], + [ + 909.199951171875, + 14194.464477539062 + ], + [ + 616.138671875, + 14194.464477539062 + ] + ] + }, + { + "title": "B. Coding agents", + "heading_level": null, + "page_id": 1, + "polygon": [ + [ + 620.75048828125, + 286.9424473044454 + ], + [ + 856.87548828125, + 286.9424473044454 + ], + [ + 856.87548828125, + 316.3837091258816 + ], + [ + 620.75048828125, + 316.3837091258816 + ] + ] + }, + { + "title": "Appendix 2: Prompt engineering\nyour tools", + "heading_level": null, + "page_id": 1, + "polygon": [ + [ + 624.5, + 824.3553310002137 + ], + [ + 1221.208984375, + 824.3553310002137 + ], + [ + 1221.208984375, + 902.653076171875 + ], + [ + 624.5, + 902.653076171875 + ] + ] + }, + { + "title": "", + "heading_level": null, + "page_id": 1, + "polygon": [ + [ + 26.791770935058594, + 2699.918776865774 + ], + [ + 81.629150390625, + 2699.918776865774 + ], + [ + 81.629150390625, + 2728.3545980826484 + ], + [ + 26.791770935058594, + 2728.3545980826484 + ] + ] + }, + { + "title": "Product", + "heading_level": null, + "page_id": 1, + "polygon": [ + [ + 295.617431640625, + 2695.9569091796875 + ], + [ + 383.241943359375, + 2695.9569091796875 + ], + [ + 383.241943359375, + 2720.6498857772513 + ], + [ + 295.617431640625, + 2720.6498857772513 + ] + ] + }, + { + "title": "API Platform", + "heading_level": null, + "page_id": 1, + "polygon": [ + [ + 317.6874694824219, + 3033.876074500511 + ], + [ + 433.971923828125, + 3033.876074500511 + ], + [ + 433.971923828125, + 3054.026860530011 + ], + [ + 317.6874694824219, + 3054.026860530011 + ] + ] + } + ], + "page_stats": [ + { + "page_id": 0, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 696 + ], + [ + "Line", + 268 + ], + [ + "Text", + 49 + ], + [ + "ListItem", + 31 + ], + [ + "SectionHeader", + 21 + ], + [ + "ListGroup", + 11 + ], + [ + "Figure", + 6 + ], + [ + "Form", + 3 + ], + [ + "PageHeader", + 2 + ], + [ + "Picture", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "Caption", + 1 + ] + ] + }, + { + "page_id": 1, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 304 + ], + [ + "Line", + 136 + ], + [ + "Text", + 36 + ], + [ + "ListItem", + 15 + ], + [ + "SectionHeader", + 5 + ], + [ + "ListGroup", + 4 + ], + [ + "Picture", + 1 + ] + ] + } + ], + "debug_data_path": "debug_data/source" +} \ No newline at end of file diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/refined_doc.json b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/refined_doc.json new file mode 100644 index 0000000000000000000000000000000000000000..c000871814e4360633849e6bde9e6f8bf5430dec --- /dev/null +++ b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/refined_doc.json @@ -0,0 +1,295 @@ +{ + "image_dir": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f", + "sections": [ + { + "title": "Building effective agents", + "summary": "Appendix 2 emphasizes the importance of prompt engineering when defining tools for agentic systems like Claude, which utilize external services and APIs. It recommends considering multiple action specifications, ensuring accessibility, and minimizing formatting overhead to facilitate model understanding. 
Key strategies include clear tool definitions, including examples, refining parameter descriptions for clarity, testing model interactions with tools, and implementing design changes to reduce errors. The section asserts that optimizing tool specifications can significantly influence the agent's performance, often requiring more attention than overall prompt design.", + "subsections": [ + { + "title": "What are agents?", + "content": "\"Agent\" can be defined in several ways. Some customers define agents as fully autonomous systems that operate independently over extended periods, using various tools to accomplish complex tasks. Others use the term to describe more prescriptive implementations that follow predefined workflows. At Anthropic, we categorize all these variations as agentic systems, but draw an important architectural distinction between workflows and agents: Workflows are systems where LLMs and tools are orchestrated through predefined code paths. Agents, on the other hand, are systems where LLMs dynamically direct their own processes and tool usage, maintaining control over how they accomplish tasks. Below, we will explore both types of agentic systems in detail. In Appendix 1 (\"Agents in Practice\"), we describe two domains where customers have found particular value in using these kinds of systems.", + "medias": [] + }, + { + "title": "When (and when not) to use agents", + "content": "When building applications with LLMs, we recommend finding the simplest solution possible, and only increasing complexity when needed. This might mean not building agentic systems at all. Agentic systems often trade latency and cost for better task performance, and you should consider when this tradeoff makes sense. When more complexity is warranted, workflows offer predictability and consistency for well-defined tasks, whereas agents are the better option when flexibility and model-driven decision-making are needed at scale. For many applications, however, optimizing single LLM calls with retrieval and in-context examples is usually enough.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "When and how to use frameworks", + "summary": "Appendix 2 emphasizes the importance of prompt engineering when defining tools for agentic systems like Claude, which utilize external services and APIs. It recommends considering multiple action specifications, ensuring accessibility, and minimizing formatting overhead to facilitate model understanding. Key strategies include clear tool definitions, including examples, refining parameter descriptions for clarity, testing model interactions with tools, and implementing design changes to reduce errors. The section asserts that optimizing tool specifications can significantly influence the agent's performance, often requiring more attention than overall prompt design.", + "subsections": [ + { + "title": "Frameworks Overview", + "content": "There are many frameworks that make agentic systems easier to implement, including: LangGraph from LangChain; Amazon Bedrock's AI Agent framework; Rivet, a drag and drop GUI LLM workflow builder; and Vellum, another GUI tool for building and testing complex workflows. These frameworks make it easy to get started by simplifying standard low-level tasks like calling LLMs, defining and parsing tools, and chaining calls together. However, they often create extra layers of abstraction that can obscure the underlying prompts and responses, making them harder to debug. 
They can also make it tempting to add complexity when a simpler setup would suffice. We suggest that developers start by using LLM APIs directly: many patterns can be implemented in a few lines of code. If you do use a framework, ensure you understand the underlying code. Incorrect assumptions about what's under the hood are a common source of customer error. See our cookbook for some sample implementations.", + "medias": [ + { + "markdown_content": "![](_page_0_Figure_27.jpeg)", + "near_chunks": [ + "The basic building block of agentic systems is an LLM enhanced with augmentations such as retrieval, tools, and memory. Our current models can actively use these capabilities—generating their own search queries, selecting appropriate tools, and determining what information to retain.\n\n", + "#### The augmented LLM\n\nThe prompt chaining workflow\n\nWe recommend focusing on two key aspects of the implementation: tailoring these capabilities to your specific use case and ensuring they provide an easy, well-documented interface for your LLM. While there are many ways to implement these augmentations, one approach is through our recently released Model Context Protocol, which allows developers to integrate with a growing ecosystem of third-party tools with a simple client implementation.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_27.jpeg", + "caption": "Diagram: Flowchart illustrating the interactions in an augmented LLM system. It shows input and output paths along with components like Retrieval, Tools, and Memory, detailing how the system processes queries and responses." + }, + { + "markdown_content": "![](_page_0_Figure_34.jpeg)", + "near_chunks": [ + "Prompt chaining decomposes a task into a sequence of steps, where each LLM call processes the output of the previous one. You can add programmatic checks (see \"gate\" in the diagram below) on any intermediate steps to ensure that the process is still on track.\n\n", + "When to use this workflow: This workflow is ideal for situations where the task can be easily and cleanly decomposed into fixed subtasks. The main goal is to trade off latency for higher accuracy, by making each LLM call an easier task.\n\nExamples where prompt chaining is useful:\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_34.jpeg", + "caption": "Diagram: Flowchart illustrating a prompt chaining workflow. It shows the sequence of steps from input through several LLM calls, with a gate determining whether to pass to the next step or exit." + }, + { + "markdown_content": "![](_page_0_Figure_42.jpeg)", + "near_chunks": [ + "Routing classifies an input and directs it to a specialized followup task. This workflow allows for separation of concerns, and building more specialized prompts. Without this workflow, optimizing for one kind of input can hurt performance on other inputs.\n\n", + "The routing workflow\n\nWhen to use this workflow: Routing works well for complex tasks where there are distinct categories that are better handled separately, and where classification can be handled accurately, either by an LLM or a more traditional classification model/algorithm.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_42.jpeg", + "caption": "Diagram: Flowchart illustrating the routing process for LLM calls. 
Shows an input leading to an LLM call router, which directs to multiple LLM calls and outputs a result, highlighting the separation of tasks in complex workflows." + }, + { + "markdown_content": "| 7 | LLM Call 1 | 7 | | |\n| --- | --- | --- | --- | --- |\n| > In | LLM Call 2 | > | Aggregator | Out 1 |\n| 1 | LLM Call 3 | 기 | | |", + "near_chunks": [ + "- Sectioning: Breaking a task into independent subtasks run in parallel.\n- Voting: Running the same task multiple times to get diverse outputs.\n\nLLMs can sometimes work simultaneously on a task and have their outputs aggregated programmatically. This workflow, parallelization, manifests in two key variations:\n\n", + "The parallelization workflow\n\ndivided subtasks can be parallelized for speed, or when multiple perspectives or attempts are needed for higher confidence results. For complex tasks with multiple considerations, LLMs generally perform better when each consideration is handled by a separate LLM call, allowing focused attention on each specific aspect.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_9f02.png", + "caption": "Table: This table outlines different LLM calls and their outputs, illustrating a parallelization workflow where tasks are divided and processed separately to enhance efficiency and output diversity.", + "cells": [ + [ + "7", + "LLM Call 1", + "7", + "", + "" + ], + [ + "> In", + "LLM Call 2", + ">", + "Aggregator", + "Out 1" + ], + [ + "1", + "LLM Call 3", + "기", + "", + "" + ] + ], + "merge_area": null + }, + { + "markdown_content": "![](_page_0_Figure_66.jpeg)", + "near_chunks": [ + "In the orchestrator-workers workflow, a central LLM dynamically breaks down tasks, delegates them to worker LLMs, and synthesizes their results.\n\n#### Workflow: Orchestrator-workers\n\n- Implementing guardrails where one model instance processes user queries while another screens them for inappropriate content or requests. This tends to perform better than having the same LLM call handle both guardrails and the core response.\n- Automating evals for evaluating LLM performance, where each LLM call evaluates a different aspect of the model's performance on a given prompt.\n- Voting:\n- Reviewing a piece of code for vulnerabilities, where several different prompts review and flag the code if they find a problem.\n- Evaluating whether a given piece of content is inappropriate, with multiple prompts evaluating different aspects or requiring different vote thresholds to balance false positives and negatives.\n\n", + "The orchestrator-workers workflow\n\nWhen to use this workflow: This workflow is well-suited for complex tasks where you can't predict the subtasks needed (in coding, for example, the number of files that need to be changed and the nature of the change in each file likely depend on the task). Whereas it's topographically similar, the key difference from parallelization is its flexibility—subtasks aren't pre-defined, but determined by the orchestrator based on the specific input.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_66.jpeg", + "caption": "Diagram: Workflow illustration showing the process flow of an orchestrator managing multiple LLM calls. Inputs are directed to the orchestrator, which delegates tasks to LLM Call 1, Call 2, and Call 3, before synthesizing results and producing an output." 
+ }, + { + "markdown_content": "![](_page_0_Figure_74.jpeg)", + "near_chunks": [ + "In the evaluator-optimizer workflow, one LLM call generates a response while another provides evaluation and feedback in a loop.\n\n#### Workflow: Evaluator-optimizer\n\n- Coding products that make complex changes to multiple files each time.\n- Search tasks that involve gathering and analyzing information from multiple sources for possible relevant information.\n\n", + "#### The evaluator-optimizer workflow\n\nWhen to use this workflow: This workflow is particularly effective when we have clear evaluation criteria, and when iterative\n\nrefinement provides measurable value. The two signs of good fit are, first, that LLM responses can be demonstrably improved when a human articulates their feedback; and second, that the LLM can provide such feedback. This is analogous to the iterative writing process a human writer might go through when producing a polished document.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_74.jpeg", + "caption": "Diagram: Flowchart illustrating the evaluator-optimizer workflow. It shows input leading to an LLM Call Generator, which outputs a solution. An LLM Call Evaluator processes the solution, providing feedback, leading to either acceptance or rejection." + }, + { + "markdown_content": "| (\"Prompt Engineering your Tools\"). |\n| --- |", + "near_chunks": [ + "Agents can handle sophisticated tasks, but their implementation is often straightforward. They are typically just LLMs using tools based on environmental feedback in a loop. It is therefore crucial to design toolsets and their documentation clearly and thoughtfully. We expand on best practices for tool development in Appendix 2\n\n", + "When to use agents: Agents can be used for open-ended problems where it's difficult or impossible to predict the required number of steps, and where you can't hardcode a fixed path. The LLM will potentially operate for many turns, and you must have some level of trust in its decision-making. Agents' autonomy makes them ideal for scaling tasks in trusted environments.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_ef93.png", + "caption": "Table: Best practices and considerations for prompt engineering tools in the context of developing agents that utilize large language models (LLMs) for effective and reliable task execution.", + "cells": [ + [ + "(\"Prompt Engineering your Tools\")." + ], + [ + "" + ] + ], + "merge_area": null + }, + { + "markdown_content": "![](_page_0_Figure_85.jpeg)", + "near_chunks": [ + "Agents can handle sophisticated tasks, but their implementation is often straightforward. They are typically just LLMs using tools based on environmental feedback in a loop. It is therefore crucial to design toolsets and their documentation clearly and thoughtfully. We expand on best practices for tool development in Appendix 2\n\n", + "When to use agents: Agents can be used for open-ended problems where it's difficult or impossible to predict the required number of steps, and where you can't hardcode a fixed path. The LLM will potentially operate for many turns, and you must have some level of trust in its decision-making. 
Agents' autonomy makes them ideal for scaling tasks in trusted environments.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_0_Figure_85.jpeg", + "caption": "Diagram: Flowchart illustrating the interaction between a human, an LLM call, and the environment. It depicts a feedback loop with an action path, indicating the process of feedback and stopping mechanisms." + } + ] + }, + { + "title": "Building blocks, workflows, and agents", + "content": "In this section, we'll explore the common patterns for agentic systems we've seen in production. We'll start with our foundational building block—the augmented LLM—and progressively increase complexity, from simple compositional workflows to autonomous agents.", + "medias": [] + }, + { + "title": "Building block: The augmented LLM", + "content": "The basic building block of agentic systems is an LLM enhanced with augmentations such as retrieval, tools, and memory. Our current models can actively use these capabilities—generating their own search queries, selecting appropriate tools, and determining what information to retain.", + "medias": [] + }, + { + "title": "The prompt chaining workflow", + "content": "We recommend focusing on two key aspects of the implementation: tailoring these capabilities to your specific use case and ensuring they provide an easy, well-documented interface for your LLM. While there are many ways to implement these augmentations, one approach is through our recently released Model Context Protocol, which allows developers to integrate with a growing ecosystem of third-party tools with a simple client implementation. For the remainder of this post, we'll assume each LLM call has access to these augmented capabilities.", + "medias": [] + }, + { + "title": "Workflow: Prompt chaining", + "content": "Prompt chaining decomposes a task into a sequence of steps, where each LLM call processes the output of the previous one. You can add programmatic checks (see 'gate' in the diagram below) on any intermediate steps to ensure that the process is still on track. When to use this workflow: This workflow is ideal for situations where the task can be easily and cleanly decomposed into fixed subtasks. The main goal is to trade off latency for higher accuracy, by making each LLM call an easier task. Examples where prompt chaining is useful: Generating Marketing copy, then translating it into a different language. Writing an outline of a document, checking that the outline meets certain criteria, then writing the document based on the outline.", + "medias": [] + }, + { + "title": "Workflow: Routing", + "content": "Routing classifies an input and directs it to a specialized followup task. This workflow allows for separation of concerns, and building more specialized prompts. Without this workflow, optimizing for one kind of input can hurt performance on other inputs. When to use this workflow: Routing works well for complex tasks where there are distinct categories that are better handled separately, and where classification can be handled accurately, either by an LLM or a more traditional classification model/algorithm. Examples where routing is useful: Directing different types of customer service queries (general questions, refund requests, technical support) into different downstream processes, prompts, and tools. 
Routing easy/common questions to smaller models like Claude 3.5 Haiku and hard/unusual questions to more capable models like Claude 3.5 Sonnet to optimize cost and speed.", + "medias": [] + }, + { + "title": "Workflow: Parallelization", + "content": "LLMs can sometimes work simultaneously on a task and have their outputs aggregated programmatically. This workflow, parallelization, manifests in two key variations: Sectioning: Breaking a task into independent subtasks run in parallel and Voting: Running the same task multiple times to get diverse outputs. Divided subtasks can be parallelized for speed, or when multiple perspectives or attempts are needed for higher confidence results. For complex tasks with multiple considerations, LLMs generally perform better when each consideration is handled by a separate LLM call, allowing focused attention on each specific aspect. Examples where parallelization is useful: Implementing guardrails where one model instance processes user queries while another screens them for inappropriate content or requests. Automating evals for evaluating LLM performance, where each LLM call evaluates a different aspect of the model's performance on a given prompt.", + "medias": [] + }, + { + "title": "Workflow: Orchestrator-workers", + "content": "In the orchestrator-workers workflow, a central LLM dynamically breaks down tasks, delegates them to worker LLMs, and synthesizes their results. When to use this workflow: This workflow is well-suited for complex tasks where you can't predict the subtasks needed. Whereas it's topographically similar, the key difference from parallelization is its flexibility—subtasks aren't pre-defined, but determined by the orchestrator based on the specific input. Example where orchestrator-workers is useful: Coding products that make complex changes to multiple files each time. Search tasks that involve gathering and analyzing information from multiple sources for possible relevant information.", + "medias": [] + }, + { + "title": "Workflow: Evaluator-optimizer", + "content": "In the evaluator-optimizer workflow, one LLM call generates a response while another provides evaluation and feedback in a loop. When to use this workflow: This workflow is particularly effective when we have clear evaluation criteria, and when iterative refinement provides measurable value. The two signs of good fit are, first, that LLM responses can be demonstrably improved when a human articulates their feedback; and second, that the LLM can provide such feedback. This is analogous to the iterative writing process a human writer might go through when producing a polished document.", + "medias": [] + }, + { + "title": "Agents", + "content": "Agents are emerging in production as LLMs mature in key capabilities—understanding complex inputs, engaging in reasoning and planning, using tools reliably, and recovering from errors. Agents begin their work with either a command from, or interactive discussion with, the human user. Once the task is clear, agents plan and operate independently, potentially returning to the human for further information or judgement. During execution, it's crucial for the agents to gain 'ground truth' from the environment at each step (such as tool call results or code execution) to assess its progress. Agents can then pause for human feedback at checkpoints or when encountering blockers. The task often terminates upon completion, but it's also common to include stopping conditions (such as a maximum number of iterations) to maintain control. 
Examples where agents are useful: A coding Agent to resolve SWE-bench tasks, which involve edits to many files based on a task description; Our 'computer use' reference implementation, where Claude uses a computer to accomplish tasks.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Combining and customizing these patterns", + "summary": "Appendix 2 emphasizes the importance of prompt engineering when defining tools for agentic systems like Claude, which utilize external services and APIs. It recommends considering multiple action specifications, ensuring accessibility, and minimizing formatting overhead to facilitate model understanding. Key strategies include clear tool definitions, including examples, refining parameter descriptions for clarity, testing model interactions with tools, and implementing design changes to reduce errors. The section asserts that optimizing tool specifications can significantly influence the agent's performance, often requiring more attention than overall prompt design.", + "subsections": [ + { + "title": "Summary", + "content": "Success in the LLM space isn't about building the most sophisticated system. It's about building the *right* system for your needs. Start with simple prompts, optimize them with comprehensive evaluation, and add multi-step agentic systems only when simpler solutions fall short.\n\nWhen implementing agents, we try to follow three core principles:\n\n- Maintain simplicity in your agent's design. 2. Prioritize transparency by explicitly showing the agent's planning steps.\n- Carefully craft your agent-computer interface (ACI) through thorough tool documentation and testing.\n\nFrameworks can help you get started quickly, but don't hesitate to reduce abstraction layers and build with basic components as you move to production. By following these principles, you can create agents that are not only powerful but also reliable, maintainable, and trusted by their users.", + "medias": [] + }, + { + "title": "Acknowledgements", + "content": "Written by Erik Schluntz and Barry Zhang. This work draws upon our experiences building agents at Anthropic and the valuable insights shared by our customers, for which we're deeply grateful.", + "medias": [] + }, + { + "title": "Appendix 1: Agents in practice", + "content": "Our work with customers has revealed two particularly promising applications for AI agents that demonstrate the practical value of the patterns discussed above. Both applications illustrate how agents add the most value for tasks that require both conversation and action, have clear success criteria, enable feedback loops, and integrate meaningful human oversight.", + "medias": [] + }, + { + "title": "A. Customer support", + "content": "Customer support combines familiar chatbot interfaces with enhanced capabilities through tool integration. This is a natural fit for more open-ended agents because:\n\n- Support interactions naturally follow a conversation flow while requiring access to external information and actions; Tools can be integrated to pull customer data, order history, and knowledge base articles; Actions such as issuing refunds or updating tickets can be handled programmatically; and Success can be clearly measured through user-defined resolutions.\n\nSeveral companies have demonstrated the viability of this approach through usage-based pricing models that charge only for successful resolutions, showing confidence in their agents' effectiveness.", + "medias": [] + }, + { + "title": "B. 
Coding agents", + "content": "The software development space has shown remarkable potential for LLM features, with capabilities evolving from code completion to autonomous problem-solving. Agents are particularly effective because:\n\n- Code solutions are verifiable through automated tests;\n- Agents can iterate on solutions using test results as feedback;\n- The problem space is well-defined and structured; and\n- Output quality can be measured objectively.\n\nIn our own implementation, agents can now solve real GitHub issues in the SWE-bench Verified benchmark based on the pull request description alone. However, whereas automated testing helps verify functionality, human review remains crucial for ensuring solutions align with broader system requirements.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Appendix 2: Prompt engineering your tools", + "summary": "Appendix 2 emphasizes the importance of prompt engineering when defining tools for agentic systems like Claude, which utilize external services and APIs. It recommends considering multiple action specifications, ensuring accessibility, and minimizing formatting overhead to facilitate model understanding. Key strategies include clear tool definitions, including examples, refining parameter descriptions for clarity, testing model interactions with tools, and implementing design changes to reduce errors. The section asserts that optimizing tool specifications can significantly influence the agent's performance, often requiring more attention than overall prompt design.", + "subsections": [ + { + "title": "Introduction to Tools", + "content": "No matter which agentic system you're building, tools will likely be an important part of your agent. Tools enable Claude to interact with external services and APIs by specifying their exact structure and definition in our API. When Claude responds, it will include a tool use block in the API response if it plans to invoke a tool. Tool definitions and specifications should be given just as much prompt engineering attention as your overall prompts. In this brief appendix, we describe how to prompt engineer your tools.", + "medias": [ + { + "markdown_content": "![](_page_1_Picture_52.jpeg)", + "near_chunks": [ + "© 2025 Anthropic PBC\n\nUsage policy\n\nTerms of service commercial\n\nTerms of service consumer\n\nResponsible disclosure policy\n\nPrivacy policy\n\nPrivacy choices\n\nTerms and policies\n\nSupport center\n\nHelp and security Status Availability\n\nStartups program\n\nEvents News\n\n", + "" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/_page_1_Picture_52.jpeg", + "caption": "Icon: Three social media platform icons displayed in grayscale: YouTube, LinkedIn, and X. The icons are positioned side by side against a dark background." 
+ } + ] + }, + { + "title": "Guidelines for Tool Formats", + "content": "Our suggestions for deciding on tool formats are the following:\n\n- Give the model enough tokens to \"think\" before it writes itself into a corner.\n- Keep the format close to what the model has seen naturally occurring in text on the internet.\n- Make sure there's no formatting \"overhead\" such as having to keep an accurate count of thousands of lines of code, or string escaping any code it writes.", + "medias": [] + }, + { + "title": "Human-Computer Interfaces", + "content": "One rule of thumb is to think about how much effort goes into human-computer interfaces (HCI), and plan to invest just as much effort in creating good *agent*-computer interfaces (ACI). Here are some thoughts on how to do so:\n\n- Put yourself in the model's shoes. Is it obvious how to use this tool, based on the description and parameters, or would you need to think carefully about it? If so, then it's probably also true for the model. A good tool definition often includes example usage, edge cases, input format requirements, and clear boundaries from other tools.\n- How can you change parameter names or descriptions to make things more obvious? Think of this as writing a great docstring for a junior developer on your team. This is especially important when using many similar tools.\n- Test how the model uses your tools: Run many example inputs in our workbench to see what mistakes the model makes, and iterate.\n- Poka-yoke your tools. Change the arguments so that it is harder to make mistakes.", + "medias": [] + }, + { + "title": "Tool Optimization", + "content": "While building our agent for SWE-bench, we actually spent more time optimizing our tools than the overall prompt. For example, we found that the model would make mistakes with tools using relative filepaths after the agent had moved out of the root directory. 
To fix this, we changed the tool to always require absolute filepaths—and we found that the model used this method flawlessly.", + "medias": [] + }, + { + "title": "Product Overview", + "content": "Claude overview Claude Code Claude team plan Claude enterprise plan Claude education plan Download Claude apps Claude.ai pricing plans Claude.ai login", + "medias": [] + }, + { + "title": "API Platform", + "content": "API overview Developer docs Claude in Amazon Bedrock Claude on Google Cloud's Vertex AI Pricing", + "medias": [] + }, + { + "title": "Console and Research", + "content": "Console login\n\nResearch\n\nResearch overview Economic Index\n\nClaude models Claude Opus 4 Claude Sonnet 4\n\nClaude Haiku 3.5\n\nCommitments Transparency Responsible scaling policy Security and compliance", + "medias": [] + }, + { + "title": "Solutions and Learn", + "content": "Solutions AI agents\n\nCoding\n\nCustomer support\n\nLearn\n\nAnthropic Academy Customer stories Engineering at Anthropic MCP Integrations", + "medias": [] + }, + { + "title": "Corporate Information", + "content": "Explore About us Become a partner Careers\n\nEvents News\n\nStartups program\n\nHelp and security Status Availability\n\nSupport center\n\nTerms and policies\n\nPrivacy choices\n\nPrivacy policy\n\nResponsible disclosure policy\n\nTerms of service consumer\n\nTerms of service commercial\n\nUsage policy", + "medias": [] + }, + { + "title": "Copyright Notice", + "content": "© 2025 Anthropic PBC\n\n![](_page_1_Picture_52.jpeg)", + "medias": [] + } + ], + "markdown_content": null + } + ], + "metadata": { + "title": "Building effective agents", + "publish_date": "Dec 19, 2024", + "authors": [ + "Erik Schluntz", + "Barry Zhang" + ], + "organization": "Anthropic PBC", + "year": "2025", + "presentation-date": "2025-07-05" + } +} \ No newline at end of file diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.md b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.md new file mode 100644 index 0000000000000000000000000000000000000000..51219594d1d14138bf8f1a7e6bb06a959dedad6e --- /dev/null +++ b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.md @@ -0,0 +1,305 @@ +# Building effective agents + +![](_page_0_Picture_2.jpeg) + +Engineering at Anthropic + +Published Dec 19, 2024 We've worked with dozens of teams building LLM agents across industries. Consistently, the most successful implementations use simple, composable patterns rather than complex frameworks. + +> Over the past year, we've worked with dozens of teams building large language model (LLM) agents across industries. Consistently, the most successful implementations weren't using complex frameworks or specialized libraries. Instead, they were building with simple, composable patterns. + +In this post, we share what we've learned from working with our customers and building agents ourselves, and give practical advice for developers on building effective agents. + +#### What are agents? + +"Agent" can be defined in several ways. Some customers define agents as fully autonomous systems that operate independently over extended periods, using various tools to accomplish complex tasks. Others use the term to describe more prescriptive implementations that follow predefined workflows. At Anthropic, we categorize all these variations as agentic systems, but draw an important architectural distinction between workflows and agents: + +- Workflows are systems where LLMs and tools are orchestrated +- through predefined code paths. 
Agents, on the other hand, are systems where LLMs dynamically direct their own processes and tool usage, maintaining control over how they accomplish tasks. + +Below, we will explore both types of agentic systems in detail. In Appendix 1 ("Agents in Practice"), we describe two domains where customers have found particular value in using these kinds of systems. + +### When (and when not) to use agents + +When building applications with LLMs, we recommend finding the simplest solution possible, and only increasing complexity when needed. This might mean not building agentic systems at all. Agentic systems often trade latency and cost for better task performance, and you should consider when this tradeoff makes sense. + +When more complexity is warranted, workflows offer predictability and consistency for well-defined tasks, whereas agents are the better option when flexibility and model-driven decision-making are needed at scale. For many applications, however, optimizing single LLM calls with retrieval and in-context examples is usually enough. + +## When and how to use frameworks + +There are many frameworks that make agentic systems easier to implement, including: + +- LangGraph from LangChain; +- Amazon Bedrock's AI Agent framework; +- Rivet, a drag and drop GUI LLM workflow builder; and +- Vellum, another GUI tool for building and testing complex workflows. + +These frameworks make it easy to get started by simplifying standard low-level tasks like calling LLMs, defining and parsing tools, and chaining calls together. However, they often create extra layers of abstraction that can obscure the underlying prompts and responses, making them harder to debug. They can also make it tempting to add complexity when a simpler setup would suffice. + +We suggest that developers start by using LLM APIs directly: many patterns can be implemented in a few lines of code. If you do use a framework, ensure you understand the underlying code. Incorrect assumptions about what's under the hood are a common source of customer error. + +See our cookbook for some sample implementations. + +### Building blocks, workflows, and agents + +In this section, we'll explore the common patterns for agentic systems we've seen in production. We'll start with our foundational building block—the augmented LLM—and progressively increase complexity, from simple compositional workflows to autonomous agents. + +#### Building block: The augmented LLM + +The basic building block of agentic systems is an LLM enhanced with augmentations such as retrieval, tools, and memory. Our current models can actively use these capabilities—generating their own search queries, selecting appropriate tools, and determining what information to retain. + +![](_page_0_Figure_27.jpeg) + +#### The augmented LLM + +The prompt chaining workflow + +We recommend focusing on two key aspects of the implementation: tailoring these capabilities to your specific use case and ensuring they provide an easy, well-documented interface for your LLM. While there are many ways to implement these augmentations, one approach is through our recently released Model Context Protocol, which allows developers to integrate with a growing ecosystem of third-party tools with a simple client implementation. + +For the remainder of this post, we'll assume each LLM call has access to these augmented capabilities. + +#### Workflow: Prompt chaining + +Prompt chaining decomposes a task into a sequence of steps, where each LLM call processes the output of the previous one. 
You can add programmatic checks (see "gate" in the diagram below) on any intermediate steps to ensure that the process is still on track. + +![](_page_0_Figure_34.jpeg) + +When to use this workflow: This workflow is ideal for situations where the task can be easily and cleanly decomposed into fixed subtasks. The main goal is to trade off latency for higher accuracy, by making each LLM call an easier task. + +Examples where prompt chaining is useful: + +Generating Marketing copy, then translating it into a different + +- language. Writing an outline of a document, checking that the outline +- meets certain criteria, then writing the document based on the outline. + +#### Workflow: Routing + +Routing classifies an input and directs it to a specialized followup task. This workflow allows for separation of concerns, and building more specialized prompts. Without this workflow, optimizing for one kind of input can hurt performance on other inputs. + +![](_page_0_Figure_42.jpeg) + +The routing workflow + +When to use this workflow: Routing works well for complex tasks where there are distinct categories that are better handled separately, and where classification can be handled accurately, either by an LLM or a more traditional classification model/algorithm. + +Examples where routing is useful: + +- Directing different types of customer service queries (general questions, refund requests, technical support) into different +- downstream processes, prompts, and tools. Routing easy/common questions to smaller models like Claude 3.5 Haiku and hard/unusual questions to more capable models +- like Claude 3.5 Sonnet to optimize cost and speed. + +#### Workflow: Parallelization + +LLMs can sometimes work simultaneously on a task and have their outputs aggregated programmatically. This workflow, parallelization, manifests in two key variations: + +- Sectioning: Breaking a task into independent subtasks run in parallel. +- Voting: Running the same task multiple times to get diverse outputs. + +| 7 | LLM Call 1 | 7 | | | +| --- | --- | --- | --- | --- | +| > In | LLM Call 2 | > | Aggregator | Out 1 | +| 1 | LLM Call 3 | 기 | | | + +The parallelization workflow + +divided subtasks can be parallelized for speed, or when multiple perspectives or attempts are needed for higher confidence results. For complex tasks with multiple considerations, LLMs generally perform better when each consideration is handled by a separate LLM call, allowing focused attention on each specific aspect. + +Examples where parallelization is useful: + +#### Sectioning: + +- Implementing guardrails where one model instance processes user queries while another screens them for inappropriate content or requests. This tends to perform better than having the same LLM call handle both guardrails and the core response. +- Automating evals for evaluating LLM performance, where each LLM call evaluates a different aspect of the model's performance on a given prompt. +- Voting: +- Reviewing a piece of code for vulnerabilities, where several different prompts review and flag the code if they find a problem. +- Evaluating whether a given piece of content is inappropriate, with multiple prompts evaluating different aspects or requiring different vote thresholds to balance false positives and negatives. + +#### Workflow: Orchestrator-workers + +In the orchestrator-workers workflow, a central LLM dynamically breaks down tasks, delegates them to worker LLMs, and synthesizes their results. 
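The orchestrator-workers flow described above can be condensed into a few lines of Python. This is a minimal sketch rather than anything from the post or the pptagent codebase: `call_llm` is a hypothetical stand-in for whatever model client is actually used, and the sketch assumes the orchestrator step returns a well-formed JSON list of subtasks.

```python
import json
from typing import Callable

def call_llm(prompt: str) -> str:
    """Hypothetical stand-in for a real LLM client call; wire this to your model provider."""
    raise NotImplementedError

def orchestrate(task: str, call: Callable[[str], str] = call_llm) -> str:
    # 1. Orchestrator: ask the model to break the task into subtasks (assumed returned as a JSON list).
    plan = call(f"Split this task into independent subtasks, as a JSON list of strings:\n{task}")
    subtasks = json.loads(plan)  # in practice, validate and retry on malformed output

    # 2. Workers: one focused LLM call per subtask.
    results = [call(f"Task: {task}\nSubtask: {sub}\nComplete only this subtask.") for sub in subtasks]

    # 3. Synthesizer: merge the worker outputs into a single answer.
    return call("Combine these partial results into one coherent answer:\n" + "\n---\n".join(results))
```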
+ +![](_page_0_Figure_66.jpeg) + +The orchestrator-workers workflow + +When to use this workflow: This workflow is well-suited for complex tasks where you can't predict the subtasks needed (in coding, for example, the number of files that need to be changed and the nature of the change in each file likely depend on the task). Whereas it's topographically similar, the key difference from parallelization is its flexibility—subtasks aren't pre-defined, but determined by the orchestrator based on the specific input. + +Example where orchestrator-workers is useful: + +- Coding products that make complex changes to multiple files each time. +- Search tasks that involve gathering and analyzing information from multiple sources for possible relevant information. + +#### Workflow: Evaluator-optimizer + +In the evaluator-optimizer workflow, one LLM call generates a response while another provides evaluation and feedback in a loop. + +![](_page_0_Figure_74.jpeg) + +#### The evaluator-optimizer workflow + +When to use this workflow: This workflow is particularly effective when we have clear evaluation criteria, and when iterative + +refinement provides measurable value. The two signs of good fit are, first, that LLM responses can be demonstrably improved when a human articulates their feedback; and second, that the LLM can provide such feedback. This is analogous to the iterative writing process a human writer might go through when producing a polished document. + +#### Examples where evaluator-optimizer is useful: + +- Literary translation where there are nuances that the translator LLM might not capture initially, but where an evaluator LLM can provide useful critiques. +- Complex search tasks that require multiple rounds of searching and analysis to gather comprehensive information, where the evaluator decides whether further searches are warranted. + +#### Agents + +Agents are emerging in production as LLMs mature in key capabilities—understanding complex inputs, engaging in reasoning and planning, using tools reliably, and recovering from errors. Agents begin their work with either a command from, or interactive discussion with, the human user. Once the task is clear, agents plan and operate independently, potentially returning to the human for further information or judgement. During execution, it's crucial for the agents to gain "ground truth" from the environment at each step (such as tool call results or code execution) to assess its progress. Agents can then pause for human feedback at checkpoints or when encountering blockers. The task often terminates upon completion, but it's also common to include stopping conditions (such as a maximum number of iterations) to maintain control. + +Agents can handle sophisticated tasks, but their implementation is often straightforward. They are typically just LLMs using tools based on environmental feedback in a loop. It is therefore crucial to design toolsets and their documentation clearly and thoughtfully. We expand on best practices for tool development in Appendix 2 + +| ("Prompt Engineering your Tools"). | +| --- | + +![](_page_0_Figure_85.jpeg) + +When to use agents: Agents can be used for open-ended problems where it's difficult or impossible to predict the required number of steps, and where you can't hardcode a fixed path. The LLM will potentially operate for many turns, and you must have some level of trust in its decision-making. Agents' autonomy makes them ideal for scaling tasks in trusted environments. 
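As a rough illustration of the loop described above (tool use driven by environmental feedback, with a stopping condition such as a maximum number of iterations), here is a minimal sketch. It is not an API from the post: `call_llm`, the `tools` mapping, and the expected response shape are all hypothetical placeholders.

```python
from typing import Callable, Dict

def run_agent(goal: str,
              call_llm: Callable[[str], dict],
              tools: Dict[str, Callable[[str], str]],
              max_iterations: int = 10) -> str:
    """Minimal agent loop: at each step the model either calls a tool or returns a final answer."""
    transcript = f"Goal: {goal}\n"
    for _ in range(max_iterations):      # stopping condition to maintain control
        step = call_llm(transcript)      # assumed to return {"tool": ..., "input": ...} or {"answer": ...}
        if "answer" in step:             # task complete
            return step["answer"]
        # Execute the requested tool and feed its result ("ground truth") back into the context.
        observation = tools[step["tool"]](step["input"])
        transcript += f"\nTool {step['tool']} -> {observation}\n"
    return "Stopped: iteration limit reached without a final answer."
```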
+ +The autonomous nature of agents means higher costs, and the potential for compounding errors. We recommend extensive testing in sandboxed environments, along with the appropriate guardrails. + +Examples where agents are useful: + +- The following examples are from our own implementations: + - A coding Agent to resolve SWE-bench tasks, which involve edits to many files based on a task description; + - Our "computer use" reference implementation, where Claude uses a computer to accomplish tasks. + +| +| | + +High-level flow of a coding agent + +Autonomous agent + +# Combining and customizing these patterns + +These building blocks aren't prescriptive. They're common patterns that developers can shape and combine to fit different use cases. The key to success, as with any LLM features, is measuring performance and iterating on implementations. To repeat: you should consider adding complexity *only* when it demonstrably improves outcomes. + +#### Summary + +Success in the LLM space isn't about building the most sophisticated system. It's about building the *right* system for your needs. Start with simple prompts, optimize them with comprehensive evaluation, and add multi-step agentic systems only when simpler solutions fall short. + +When implementing agents, we try to follow three core principles: + +- 1. Maintain simplicity in your agent's design. 2. Prioritize transparency by explicitly showing the agent's +- planning steps. +- 3. Carefully craft your agent-computer interface (ACI) through thorough tool documentation and testing. + +Frameworks can help you get started quickly, but don't hesitate to reduce abstraction layers and build with basic components as you move to production. By following these principles, you can create agents that are not only powerful but also reliable, maintainable, and trusted by their users. + +#### Acknowledgements + +Written by Erik Schluntz and Barry Zhang. This work draws upon our experiences building agents at Anthropic and the valuable insights shared by our customers, for which we're deeply grateful. + +#### Appendix 1: Agents in practice + +Our work with customers has revealed two particularly promising applications for AI agents that demonstrate the practical value of the patterns discussed above. Both applications illustrate how agents add the most value for tasks that require both conversation and action, have clear success criteria, enable feedback loops, and integrate meaningful human oversight. + +#### A. Customer support + +Customer support combines familiar chatbot interfaces with enhanced capabilities through tool integration. This is a natural fit for more open-ended agents because: + +- Support interactions naturally follow a conversation flow while +- requiring access to external information and actions; Tools can be integrated to pull customer data, order history, and +- knowledge base articles; +- knowledge base articles; +- Actions such as issuing refunds or updating tickets can be handled programmatically; and +- Success can be clearly measured through user-defined resolutions. + +Several companies have demonstrated the viability of this approach through usage-based pricing models that charge only for successful resolutions, showing confidence in their agents' effectiveness. + +#### B. Coding agents + +The software development space has shown remarkable potential for LLM features, with capabilities evolving from code completion to autonomous problem-solving. 
Agents are particularly effective because: + +- Code solutions are verifiable through automated tests; +- Agents can iterate on solutions using test results as feedback; +- The problem space is well-defined and structured; and +- Output quality can be measured objectively. + +In our own implementation, agents can now solve real GitHub issues in the SWE-bench Verified benchmark based on the pull request description alone. However, whereas automated testing helps verify functionality, human review remains crucial for ensuring solutions align with broader system requirements. + +### Appendix 2: Prompt engineering your tools + +No matter which agentic system you're building, tools will likely be an important part of your agent. Tools enable Claude to interact with external services and APIs by specifying their exact structure and definition in our API. When Claude responds, it will include a tool use block in the API response if it plans to invoke a tool. Tool definitions and specifications should be given just as much prompt engineering attention as your overall prompts. In this brief appendix, we describe how to prompt engineer your tools. + +There are often several ways to specify the same action. For instance, you can specify a file edit by writing a diff, or by rewriting the entire file. For structured output, you can return code inside markdown or inside JSON. In software engineering, differences like these are cosmetic and can be converted losslessly from one to the other. However, some formats are much more difficult for an LLM to write than others. Writing a diff requires knowing how many lines are changing in the chunk header before the new code is written. Writing code inside JSON (compared to markdown) requires extra escaping of newlines and quotes. + +Our suggestions for deciding on tool formats are the following: + +- Give the model enough tokens to "think" before it writes itself into a corner. +- Keep the format close to what the model has seen naturally occurring in text on the internet. +- Make sure there's no formatting "overhead" such as having to keep an accurate count of thousands of lines of code, or stringescaping any code it writes. +- + +One rule of thumb is to think about how much effort goes into human-computer interfaces (HCI), and plan to invest just as much effort in creating good *agent*-computer interfaces (ACI). Here are some thoughts on how to do so: + +- Put yourself in the model's shoes. Is it obvious how to use this tool, based on the description and parameters, or would you need to think carefully about it? If so, then it's probably also true for the model. A good tool definition often includes example usage, edge cases, input format requirements, and clear boundaries from other tools. +- How can you change parameter names or descriptions to make things more obvious? Think of this as writing a great docstring for a junior developer on your team. This is especially important when using many similar tools. +- Test how the model uses your tools: Run many example inputs in our workbench to see what mistakes the model makes, and iterate. +- Poka-yoke your tools. Change the arguments so that it is harder to make mistakes. + +While building our agent for SWE-bench, we actually spent more time optimizing our tools than the overall prompt. For example, we found that the model would make mistakes with tools using relative filepaths after the agent had moved out of the root directory. 
diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.pdf b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.pdf new file mode 100644 index 0000000000000000000000000000000000000000..57434d59b293976c698c908316a453cb5bdf18e0 --- /dev/null +++ b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/source.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0353b4c9a809722eba156d0f9678ed537d40c4c06caf6de420136970757e470 +size 1708153 diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_9f02.png b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_9f02.png new file mode 100644 index 0000000000000000000000000000000000000000..1059dbdf43e638bf261b7c2cce9bedd7f99b8152 Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_9f02.png differ diff --git a/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_ef93.png b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_ef93.png new file mode 100644 index 0000000000000000000000000000000000000000..c22cdb0e1060d47c7c05a6d933009576109a20b3 Binary files /dev/null and b/pptagent/runs/pdf/37fd83b93256101767cb27322fba795f/table_ef93.png differ diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_0_Figure_10.jpeg b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_0_Figure_10.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..c1f66605e8b152d2f25b2cb58feec13c8252cff3 Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_0_Figure_10.jpeg differ diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_11_Figure_0.jpeg b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_11_Figure_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..dca4560ece93e2963ce63f809041d5cd2875614f Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_11_Figure_0.jpeg differ diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_2_Figure_0.jpeg b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_2_Figure_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..204af99cec932b8dbc5a66fb9d2c3d69504a568d Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_2_Figure_0.jpeg differ diff --git
a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_4_Figure_0.jpeg b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_4_Figure_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..32e9a54e4384e36d5985e01f34c2390429c2331f --- /dev/null +++ b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_4_Figure_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b32012c8bcc8dd44500aa3c84be88660de94f20b1b3fca67a83d972b45ac6ae3 +size 164516 diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_5_Figure_2.jpeg b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_5_Figure_2.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..0c6446c3b4192308b0e7c3e3b93c08d4ae716262 Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_5_Figure_2.jpeg differ diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/meta.json b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/meta.json new file mode 100644 index 0000000000000000000000000000000000000000..22751f308cf0a41303071cf4b5e9d30d7f70d788 --- /dev/null +++ b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/meta.json @@ -0,0 +1,1081 @@ +{ + "table_of_contents": [ + { + "title": "PresentAgent: Multimodal Agent for Presentation Video Generation", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 73.7724609375, + 77.4471435546875 + ], + [ + 523.24609375, + 77.4471435546875 + ], + [ + 523.24609375, + 92.09521484375 + ], + [ + 73.7724609375, + 92.09521484375 + ] + ] + }, + { + "title": "Abstract", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 157.75796508789062, + 221.8001708984375 + ], + [ + 204.875, + 221.8001708984375 + ], + [ + 204.875, + 233.75537109375 + ], + [ + 157.75796508789062, + 233.75537109375 + ] + ] + }, + { + "title": "1 Introduction", + "heading_level": null, + "page_id": 0, + "polygon": [ + [ + 70.753173828125, + 648.191162109375 + ], + [ + 154.8203125, + 648.191162109375 + ], + [ + 154.8203125, + 660.279296875 + ], + [ + 70.753173828125, + 660.279296875 + ] + ] + }, + { + "title": "2 Presentation Benchmark", + "heading_level": null, + "page_id": 1, + "polygon": [ + [ + 304.693359375, + 362.5001525878906 + ], + [ + 451.365234375, + 362.5001525878906 + ], + [ + 451.365234375, + 374.4553527832031 + ], + [ + 304.693359375, + 374.4553527832031 + ] + ] + }, + { + "title": "2.1 Doc2Present Dataset", + "heading_level": null, + "page_id": 2, + "polygon": [ + [ + 69.880126953125, + 327.6728515625 + ], + [ + 192.361328125, + 327.6728515625 + ], + [ + 192.361328125, + 339.4931945800781 + ], + [ + 69.880126953125, + 339.4931945800781 + ] + ] + }, + { + "title": "2.2 PresentEval", + "heading_level": null, + "page_id": 2, + "polygon": [ + [ + 70.389404296875, + 543.1064453125 + ], + [ + 152.6376953125, + 543.1064453125 + ], + [ + 152.6376953125, + 554.7662048339844 + ], + [ + 70.389404296875, + 554.7662048339844 + ] + ] + }, + { + "title": "3 PresentAgent", + "heading_level": null, + "page_id": 3, + "polygon": [ + [ + 70.607666015625, + 137.81512451171875 + ], + [ + 158.603515625, + 137.81512451171875 + ], + [ + 158.603515625, + 149.77032470703125 + ], + [ + 70.607666015625, + 149.77032470703125 + ] + ] + }, + { + "title": "3.1 Problem Formulation", + "heading_level": null, + "page_id": 3, + "polygon": [ + [ + 70.243896484375, + 345.55712890625 + ], + [ + 196.45156860351562, + 345.55712890625 + ], + [ + 196.45156860351562, + 356.65771484375 + ], + [ + 70.243896484375, + 356.65771484375 + 
] + ] + }, + { + "title": "3.2 Slide Planning and Composition", + "heading_level": null, + "page_id": 3, + "polygon": [ + [ + 305.56640625, + 123.43804931640625 + ], + [ + 480.73126220703125, + 123.43804931640625 + ], + [ + 480.73126220703125, + 134.954345703125 + ], + [ + 305.56640625, + 134.954345703125 + ] + ] + }, + { + "title": "3.3 Narration and Audio Synthesis", + "heading_level": null, + "page_id": 3, + "polygon": [ + [ + 306.14202880859375, + 516.3828125 + ], + [ + 475.228515625, + 516.3828125 + ], + [ + 475.228515625, + 527.3302001953125 + ], + [ + 306.14202880859375, + 527.3302001953125 + ] + ] + }, + { + "title": "3.4 Video Assembly", + "heading_level": null, + "page_id": 4, + "polygon": [ + [ + 70.389404296875, + 454.712890625 + ], + [ + 169.85516357421875, + 454.712890625 + ], + [ + 169.85516357421875, + 466.224609375 + ], + [ + 70.389404296875, + 466.224609375 + ] + ] + }, + { + "title": "4 Experiments", + "heading_level": null, + "page_id": 4, + "polygon": [ + [ + 70.2802734375, + 673.435546875 + ], + [ + 154.529296875, + 673.435546875 + ], + [ + 154.529296875, + 685.9183502197266 + ], + [ + 70.2802734375, + 685.9183502197266 + ] + ] + }, + { + "title": "4.1 Main Results", + "heading_level": null, + "page_id": 4, + "polygon": [ + [ + 304.984375, + 524.1943359375 + ], + [ + 392.20391845703125, + 524.1943359375 + ], + [ + 392.20391845703125, + 535.7060546875 + ], + [ + 304.984375, + 535.7060546875 + ] + ] + }, + { + "title": "4.2 Analysis", + "heading_level": null, + "page_id": 5, + "polygon": [ + [ + 70.607666015625, + 661.1015625 + ], + [ + 135.4677734375, + 661.1015625 + ], + [ + 135.4677734375, + 673.0631561279297 + ], + [ + 70.607666015625, + 673.0631561279297 + ] + ] + }, + { + "title": "5 Conclusion", + "heading_level": null, + "page_id": 5, + "polygon": [ + [ + 303.529296875, + 534.0615234375 + ], + [ + 381.208740234375, + 534.0615234375 + ], + [ + 381.208740234375, + 546.3853454589844 + ], + [ + 303.529296875, + 546.3853454589844 + ] + ] + }, + { + "title": "References", + "heading_level": null, + "page_id": 6, + "polygon": [ + [ + 70.06201171875, + 72.56494140625 + ], + [ + 126.95556640625, + 72.56494140625 + ], + [ + 126.95556640625, + 84.71435546875 + ], + [ + 70.06201171875, + 84.71435546875 + ] + ] + }, + { + "title": "A Related Work", + "heading_level": null, + "page_id": 9, + "polygon": [ + [ + 70.243896484375, + 71.845458984375 + ], + [ + 162.61021423339844, + 71.845458984375 + ], + [ + 162.61021423339844, + 84.71435546875 + ], + [ + 70.243896484375, + 84.71435546875 + ] + ] + }, + { + "title": "A.1 Document-to-Multimodal Generation", + "heading_level": null, + "page_id": 9, + "polygon": [ + [ + 70.2802734375, + 93.73828125 + ], + [ + 272.6625061035156, + 93.73828125 + ], + [ + 272.6625061035156, + 106.1451416015625 + ], + [ + 70.2802734375, + 106.1451416015625 + ] + ] + }, + { + "title": "A.2 Vision-Language Agents", + "heading_level": null, + "page_id": 9, + "polygon": [ + [ + 70.098388671875, + 393.24853515625 + ], + [ + 213.4599609375, + 393.24853515625 + ], + [ + 213.4599609375, + 405.17138671875 + ], + [ + 70.098388671875, + 405.17138671875 + ] + ] + }, + { + "title": "B Implementation Details", + "heading_level": null, + "page_id": 9, + "polygon": [ + [ + 306.14202880859375, + 398.18212890625 + ], + [ + 447.58203125, + 398.18212890625 + ], + [ + 447.58203125, + 411.3153381347656 + ], + [ + 306.14202880859375, + 411.3153381347656 + ] + ] + }, + { + "title": "C Discussion", + "heading_level": null, + "page_id": 10, + "polygon": [ + [ + 70.35302734375, + 
150.988525390625 + ], + [ + 146.59912109375, + 150.988525390625 + ], + [ + 146.59912109375, + 164.32232666015625 + ], + [ + 70.35302734375, + 164.32232666015625 + ] + ] + }, + { + "title": "D Limitations", + "heading_level": null, + "page_id": 10, + "polygon": [ + [ + 70.316650390625, + 631.5 + ], + [ + 151.4736328125, + 631.5 + ], + [ + 151.4736328125, + 645.5033721923828 + ], + [ + 70.316650390625, + 645.5033721923828 + ] + ] + }, + { + "title": "E Evaluation Benchmark", + "heading_level": null, + "page_id": 10, + "polygon": [ + [ + 305.56640625, + 149.960693359375 + ], + [ + 445.8359375, + 149.960693359375 + ], + [ + 445.8359375, + 163.23040771484375 + ], + [ + 305.56640625, + 163.23040771484375 + ] + ] + }, + { + "title": "F Doc2Present Dataset Details", + "heading_level": null, + "page_id": 10, + "polygon": [ + [ + 305.56640625, + 358.0966796875 + ], + [ + 470.86328125, + 358.0966796875 + ], + [ + 470.86328125, + 371.4473571777344 + ], + [ + 305.56640625, + 371.4473571777344 + ] + ] + }, + { + "title": "G PresentEval", + "heading_level": null, + "page_id": 10, + "polygon": [ + [ + 306.14202880859375, + 654.9345703125 + ], + [ + 390.833984375, + 654.9345703125 + ], + [ + 390.833984375, + 668.8763580322266 + ], + [ + 306.14202880859375, + 668.8763580322266 + ] + ] + }, + { + "title": "G.1 Prompts of Objective Quiz Evaluation", + "heading_level": null, + "page_id": 10, + "polygon": [ + [ + 304.693359375, + 676.724609375 + ], + [ + 513.642578125, + 676.724609375 + ], + [ + 513.642578125, + 689.9411468505859 + ], + [ + 304.693359375, + 689.9411468505859 + ] + ] + }, + { + "title": "G.2 Prompts of Subjective Scoring", + "heading_level": null, + "page_id": 11, + "polygon": [ + [ + 70.35302734375, + 677.1357421875 + ], + [ + 240.9609375, + 677.1357421875 + ], + [ + 240.9609375, + 689.4697265625 + ], + [ + 70.35302734375, + 689.4697265625 + ] + ] + }, + { + "title": "H Evaluation Setup", + "heading_level": null, + "page_id": 11, + "polygon": [ + [ + 304.984375, + 682.0693359375 + ], + [ + 416.734375, + 682.0693359375 + ], + [ + 416.734375, + 696.0478515625 + ], + [ + 304.984375, + 696.0478515625 + ] + ] + } + ], + "page_stats": [ + { + "page_id": 0, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 258 + ], + [ + "Line", + 94 + ], + [ + "Text", + 10 + ], + [ + "SectionHeader", + 3 + ], + [ + "Figure", + 1 + ], + [ + "Caption", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "FigureGroup", + 1 + ] + ] + }, + { + "page_id": 1, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 205 + ], + [ + "Line", + 101 + ], + [ + "Text", + 8 + ], + [ + "ListItem", + 4 + ], + [ + "SectionHeader", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "ListGroup", + 1 + ] + ] + }, + { + "page_id": 2, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 205 + ], + [ + "Line", + 103 + ], + [ + "Text", + 6 + ], + [ + "SectionHeader", + 2 + ], + [ + "Figure", + 1 + ], + [ + "Caption", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "FigureGroup", + 1 + ] + ] + }, + { + "page_id": 3, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 273 + ], + [ + "Line", + 97 + ], + [ + "Text", + 11 + ], + [ + "SectionHeader", + 4 + ], + [ + "Equation", + 2 + ], + [ + "TextInlineMath", + 1 + ], + [ + "PageFooter", + 1 + ] + ] + }, + { + "page_id": 4, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 274 + ], + [ + "Line", + 121 + ], + [ + "Text", + 10 + ], + [ + "SectionHeader", + 3 + ], + [ + "Figure", + 1 + ], + [ + "Caption", + 
1 + ], + [ + "PageFooter", + 1 + ], + [ + "FigureGroup", + 1 + ] + ] + }, + { + "page_id": 5, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 211 + ], + [ + "Line", + 77 + ], + [ + "Text", + 5 + ], + [ + "SectionHeader", + 2 + ], + [ + "Table", + 1 + ], + [ + "Figure", + 1 + ], + [ + "Caption", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "FigureGroup", + 1 + ] + ] + }, + { + "page_id": 6, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 257 + ], + [ + "Line", + 110 + ], + [ + "ListItem", + 22 + ], + [ + "ListGroup", + 2 + ], + [ + "SectionHeader", + 1 + ], + [ + "PageFooter", + 1 + ] + ] + }, + { + "page_id": 7, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 261 + ], + [ + "Line", + 111 + ], + [ + "ListItem", + 22 + ], + [ + "ListGroup", + 2 + ], + [ + "PageFooter", + 1 + ] + ] + }, + { + "page_id": 8, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 81 + ], + [ + "Line", + 32 + ], + [ + "ListItem", + 6 + ], + [ + "Text", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "ListGroup", + 1 + ] + ] + }, + { + "page_id": 9, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 235 + ], + [ + "Line", + 101 + ], + [ + "Text", + 9 + ], + [ + "SectionHeader", + 4 + ], + [ + "PageFooter", + 1 + ] + ] + }, + { + "page_id": 10, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 180 + ], + [ + "Line", + 97 + ], + [ + "Text", + 11 + ], + [ + "SectionHeader", + 6 + ], + [ + "PageFooter", + 1 + ] + ] + }, + { + "page_id": 11, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 112 + ], + [ + "Line", + 60 + ], + [ + "Text", + 5 + ], + [ + "SectionHeader", + 2 + ], + [ + "Figure", + 1 + ], + [ + "Caption", + 1 + ], + [ + "Table", + 1 + ], + [ + "PageFooter", + 1 + ], + [ + "FigureGroup", + 1 + ] + ] + }, + { + "page_id": 12, + "text_extraction_method": "pdftext", + "block_counts": [ + [ + "Span", + 77 + ], + [ + "Line", + 36 + ], + [ + "Text", + 5 + ], + [ + "Table", + 1 + ], + [ + "PageFooter", + 1 + ] + ] + } + ], + "debug_data_path": "debug_data/source" +} \ No newline at end of file diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/refined_doc.json b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/refined_doc.json new file mode 100644 index 0000000000000000000000000000000000000000..0678ac20d75ea6ed034be840eb09be2e775c2aeb --- /dev/null +++ b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/refined_doc.json @@ -0,0 +1,608 @@ +{ + "image_dir": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e", + "sections": [ + { + "title": "PresentAgent: Multimodal Agent for Presentation Video Generation", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. 
This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Authors and Affiliations", + "content": "Jingwei Shi1∗ Zeyu Zhang1∗† Biao Wu2∗ Yanjie Liang1∗\n\nMeng Fang3 Ling Chen2 Yang Zhao4‡\n\n1AI Geeks, Australia\n\n2Australian Artificial Intelligence Institute, Australia 3University of Liverpool, United Kingdom 4La Trobe University, Australia\n\n∗Equal contribution. † Project lead. ‡Corresponding author: y.zhao2@latrobe.edu.au.", + "medias": [] + }, + { + "title": "Abstract", + "content": "We present PresentAgent, a multimodal agent that transforms long-form documents into narrated presentation videos. While existing approaches are limited to generating static slides or text summaries, our method advances beyond these limitations by producing fully synchronized visual and spoken content that closely mimics human-style presentations. To achieve this integration, PresentAgent employs a modular pipeline that systematically segments the input document, plans and renders slide-style visual frames, generates contextual spoken narration with large language models and Text-to-Speech models, and seamlessly composes the final video with precise audiovisual alignment. Given the complexity of evaluating such multimodal outputs, we introduce PresentEval, a unified assessment framework powered by Vision-Language Models that comprehensively scores videos across three critical dimensions: content fidelity, visual clarity, and audience comprehension through prompt-based evaluation. Our experimental validation on a curated dataset of 30 document–presentation pairs demonstrates that PresentAgent approaches human-level quality across all evaluation metrics. These results highlight the significant potential of controllable multimodal agents in transforming static textual materials into dynamic, effective, and accessible presentation formats. Code will be available at https://github.com/ AIGeeksGroup/PresentAgent.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Introduction", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Presentation importance", + "content": "Presentations are a widely used and effective medium for conveying complex ideas. By combining visual elements, structured narration, and spoken explanations, they enable information to unfold progressively and be more easily understood by diverse audiences (Fu et al., 2022). 
Despite their proven effectiveness, creating high-quality presentation videos from long-form documents—such as business reports, technical manuals, policy briefs, or academic papers—typically requires considerable manual effort (Li et al., 2023). This process involves identifying key content, designing slide layouts, writing scripts, recording narration, and aligning all elements into a coherent multimodal output.", + "medias": [ + { + "markdown_content": "![](_page_0_Figure_10.jpeg)", + "near_chunks": [ + "Presentations are a widely used and effective medium for conveying complex ideas. By combining visual elements, structured narration, and spoken explanations, they enable information to unfold progressively and be more easily understood by diverse audiences (Fu et al., 2022). Despite their proven effectiveness, creating high-quality presentation videos from long-form documents—such as\n\n", + "Figure 1: Overview of PresentAgent. It takes documents (e.g., web pages) as input and follows a generation pipeline: (1) document processing, (2) structured slide generation, (3) synchronized caption creation, and (4) audio synthesis. The final output is a presentation video combining visual slides with aligned narration. The purple-highlighted middle results emphasize the system's key transitional outputs during generation.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_0_Figure_10.jpeg", + "caption": "Diagram: Overview of PresentAgent system showing a four-step pipeline that transforms documents into presentation videos, with input documents on the left, middle processing results (slides, captions, audio) in purple, and final presentation video output on the right." + } + ] + }, + { + "title": "Current AI limitations", + "content": "Although recent advancements in AI have enabled progress in related areas such as documentto-slide generation (Fu et al., 2022; Zheng et al., 2025a; Pang et al., 2025; Zhang et al., 2024) and text-to-video synthesis (Yang et al., 2024c; Li et al., 2023; Xue et al., 2025; Khachatryan et al., 2023; He et al., 2023; Solanki and Khublani, 2024), a critical gap remains: these methods either produce static visual summaries or generic video clips without structured narration, limiting their effectiveness for structured communication tasks like presentations.", + "medias": [] + }, + { + "title": "Document-to-Presentation Video Generation", + "content": "To bridge this gap, we introduce the task of Document-to-Presentation Video Generation, which aims to automatically convert a structured or unstructured document into a narrated video presentation composed of synchronized slides and speech. This task presents unique challenges as it goes beyond traditional summarization (Lewis et al., 2019; Beltagy et al., 2020; Chen and Yang, 2021; Wang et al., 2024a) or text-to-speech (Tachibana et al., 2018; Ren et al., 2019; Popov et al., 2021; Ni et al., 2022) pipelines by requiring selective content abstraction, layout-aware planning (Wang et al., 2025), and precise multimodal alignment (Li et al., 2024) between visuals and narration. 
In contrast to prior work that focuses on either static slide and image generation (Zheng et al., 2025a; Deng et al., 2025; Xie et al., 2024) or audio summarization in isolation, our objective is to produce a fully integrated, viewer-ready video experience that closely mimics how human presenters deliver information in real-world scenarios.", + "medias": [] + }, + { + "title": "PresentAgent framework", + "content": "To tackle these challenges, we propose a modular generation framework named PresentAgent. Given an input document, the system first segments it into semantic blocks through outline planning, then generates layout-guided slide visuals for each block and rewrites the key message into oral-style narration. Subsequently, these are then synthesized into audio and combined with the slide visuals to produce a time-aligned presentation video. Importantly, our pipeline is designed to be domainadaptable and controllable, enabling broad applicability across document types and presentation styles.", + "medias": [] + }, + { + "title": "Evaluation approach", + "content": "Recognizing the need for rigorous evaluation of such complex multimodal outputs, we curate a test set of 30 human-authored document-video pairs spanning diverse domains, including education, finance, policy, and scientific communication. To comprehensively assess system performance, we further introduce a two-path evaluation strategy that combines fact-based comprehension assessment (via fixed multiple-choice quizzes) and preference-based scoring using vision-language models. This dual-pronged approach captures both objective correctness and subjective quality in video delivery.", + "medias": [] + }, + { + "title": "Results and findings", + "content": "Experiment results demonstrate that our method produces fluent, well-structured, and informative presentation videos, approaching human-level performance in both content delivery and viewer comprehension. These findings highlight the potential of combining language models, layout generation, and multimodal synthesis for creating explainable and scalable presentation systems from raw documents.", + "medias": [] + }, + { + "title": "Key contributions", + "content": "In general, our contributions are summarized as follows:\n\n- We formulate and address the novel task of document-to-presentation video generation, which aims to produce narrated, slide-structured videos from long-form documents across diverse domains.\n\n- We propose PresentAgent, a modular generation framework that integrates document parsing, layout-aware slide composition, narration planning, and audio-visual synchronization, enabling controllable and interpretable generation.\n- We introduce PresentEval, a multi-dimensional evaluation framework powered by Vision-Language Models (VLMs), which scores videos along content, visual, and comprehension dimensions via prompt-based judging.\n- We create a test set of 30 real-world document–presentation pairs and demonstrate through experiments and ablations that PresentAgent approaches human-level performance and significantly outperforms competitive variants.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Presentation Benchmark", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. 
The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Benchmark Overview", + "content": "The benchmark supports evaluation not only of fluency and fidelity, but also of downstream comprehension. Following the methodology introduced in Paper2Poster (Pang et al., 2025), we construct a quiz-style evaluation protocol (§5), where vision-language models are asked to answer factual content questions using only the generated video (slides + narration), simulating an audience's understanding. Human-authored videos are used as reference standards for both score calibration and upperbound comparison. As shown in Figure 5, our benchmark encompasses four representative document types (academic papers, web pages, technical blogs, and slides) paired with human-authored videos, covering diverse real-world domains like education, research, and business reports.\n\nWe adopt a unified, model-based evaluation framework to assess the generated presentation videos. All evaluations are conducted using a vision-language model, guided by dimensionspecific prompts tailored to different assessment objectives. The framework consists of two complementary components: (1) objective quiz evaluation, which measures factual accuracy through multiplechoice question answering; and (2) subjective scoring, which rates Content Quality, Visual or Audio Quality, and Comprehension Clarity on a 1–5 scale. Together, these metrics provide a comprehensive assessment of both the quality and informativeness of the generated videos.", + "medias": [ + { + "markdown_content": "![](_page_2_Figure_0.jpeg)", + "near_chunks": [ + "We adopt a unified, model-based evaluation framework to assess the generated presentation videos. All evaluations are conducted using a vision-language model, guided by dimensionspecific prompts tailored to different assessment objectives. The framework consists of two complementary components: (1) objective quiz evaluation, which measures factual accuracy through multiplechoice question answering; and (2) subjective scoring, which rates Content Quality, Visual or Audio Quality, and Comprehension Clarity on a 1–5 scale. Together, these metrics provide a comprehensive assessment of both the quality and informativeness\n\n", + "Figure 2: Overview of our framework. Our approach addresses the full pipeline of document-to-presentation video generation and evaluation. Left: Given diverse input documents—including papers, websites, blogs, slides, and PDFs—PresentAgent generates narrated presentation videos by producing synchronized slide decks with audio. 
Right: To evaluate these videos, we introduce PresentEval, a two-part evaluation framework: (1) Objective Quiz Evaluation (top), which measures factual comprehension using Qwen-VL; and (2) Subjective Scoring (bottom), which uses vision-language models to rate content quality, visual design, and audio comprehension across predefined dimensions.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_2_Figure_0.jpeg", + "caption": "Diagram: Framework overview showing PresentAgent's document-to-video pipeline. Left shows input documents (papers, websites, blogs), center displays the video creation process, and right illustrates PresentEval's two-part evaluation system with objective quiz evaluation and subjective scoring metrics." + } + ] + }, + { + "title": "Doc2Present Dataset", + "content": "To support the evaluation of document to presentation video generation, we curate the Doc2Present Benchmark, a diverse dataset of document–presentation video pairs spanning multiple domains. Unlike prior benchmarks focused on research abstracts or slide generation, our dataset includes documents such as business reports, product manuals, policy briefs, and instructional texts, each paired with a human-crafted presentation video.We collect 30 high-quality video samples from public platforms, educational repositories, and professional presentation archives, further details regarding the data sources and statistical information of the dataset can be found in the appendix F.", + "medias": [] + }, + { + "title": "PresentEval", + "content": "To assess the quality of generated presentation videos, we adopt two complementary evaluation strategies: Objective Quiz Evaluation and Subjective Scoring. For each video, we provide the visionlanguage model with the complete set of slide images and the full narration transcript as a unified input—simulating how a real viewer would experience the presentation. In Objective Quiz Evaluation, the model answers a fixed set of factual questions to determine whether the video accurately conveys the key information from the source content. In Subjective Scoring, the model evaluates the video along three dimensions: the coherence of the narration, the clarity and design of the visuals, and the overall ease of understanding. All evaluations are conducted without ground-truth references and rely entirely on the model's interpretation of the presented content.\n\nObjective Quiz Evaluation To evaluate whether a generated presentation video effectively conveys the core content of its source document, we use a fixed-question comprehension evaluation protocol. Specifically, we manually design five multiplechoice questions for each document, tailored to its content. These questions focus on key aspects such as topic recognition, structural understanding, and main argument extraction. As shown in Table 2, during evaluation, a vision-language model is given the video, including both visual frames and audio transcript, and asked to answer the five questions. Each question has four options, with one correct answer, annotated based on a human-created reference video. The final comprehension score (ranging from 0 to 5) reflects how many questions the model answered correctly, serving as a direct measure of how well the video communicates the original document.\n\nSubjective Scoring To evaluate the quality of generated presentation videos, we adopt a promptbased assessment using vision-language models. 
Instead of relying on human references or fixed metrics, we ask the model to evaluate each video from a viewer's perspective, using its own reasoning and preferences. The evaluation focuses on three aspects: coherence of narration, clarity and aesthetics of visuals, and overall ease of understanding. The model is shown the video and audio, and gives a score (1–5) with a brief explanation for each aspect. This enables scalable, consistent, and human-aligned evaluation without manual references. As shown in Table 3, we design different prompts for different modalities and tasks to ensure targeted and effective assessment.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "PresentAgent", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Overview", + "content": "To convert a long-form document into a narrated presentation video, we design a multi-stage generation framework that mirrors how human presenters prepare slides and talk tracks. Our method proceeds in four steps: segmenting the document into semantic units, composing slides with layout-aware structures, generating oral-style narration for each slide and assembling the visual and audio components into a synchronized video. This modular design supports controllability, interpretability, and multimodal alignment, enabling both high-quality generation and fine-grained evaluation. The following sections describe each component in detail.", + "medias": [ + { + "markdown_content": "![](_page_4_Figure_0.jpeg)", + "near_chunks": [ + "For each content block corresponding to a slide, we prompt a language model to generate a concise, oral-style narration. The model is instructed to rewrite the key message of the slide into natural spoken language, avoiding dense text or technical jargon. We apply length control to ensure each narration falls within a target duration, typically between 30 and 150 seconds. Once the narration script is obtained, we synthesize the corresponding audio using a text-to-speech system. Each narration audio is paired with its slide and timestamped, forming the basis for synchronized video rendering in the next stage.\n\n", + "Figure 3: Overview of the PresentAgent framework. Our system takes diverse documents (e.g., papers, websites, PDFs) as input and follows a modular generation pipeline. It first performs outline generation (Step 1) and retrieves the most suitable template (Step 2), then generates slides and narration notes via a vision-language model (Step 3). The notes are converted into audio via TTS and composed into a presentation video (Step 4). 
To evaluate video quality, we design multiple prompts (Step 5) and feed them into a VLM-based scoring pipeline (Step 6) that outputs dimension-specific metrics.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_4_Figure_0.jpeg", + "caption": "Diagram: Comprehensive flowchart of the PresentAgent framework showing a six-step process from input documents through generation pipeline (outline creation, template selection, slides/notes generation, video production) to evaluation metrics with visual icons representing each component." + } + ] + }, + { + "title": "Problem Formulation", + "content": "Our method is designed to transform a long-form document into a structured presentation video through a multi-stage generation pipeline. We provide a formal description to highlight the key difference between our approach and conventional slide-based methods.\n\nConventional approaches often focus on generating slide elements S directly from a document chunk C, as in Equation 1, where each element includes text or image content, layout attributes, and visual style:\n\n$$S=\\{e_{1},e_{2},...,e_{n}\\}=f(C)\\qquad\\quad(1)$$\n\nIn contrast, we treat the entire document D as a globally structured input and generate a presentation in three steps: (1) a sequence of semantic segments {C1, ..., CK} via outline planning, (2) a set of slides {S1, ..., SK}, each paired with a narrated audio track Tk generated by first producing a slide-specific script and then converting it to speech, and (3) a video V composed of visual and audio content aligned over time. This is defined as:\n\n$V=$ **Compose($\\{(S_{1},T_{1}),...,(S_{K},T_{K})\\})=g(D)$**\n\nRather than editing predefined templates or layouts, our system first identifies high-level structure in the document and then generates slide visuals and narration from scratch. This pipeline\n\nsupports controllability, modular evaluation, and multimodal alignment for downstream comprehension and quality assessment.", + "medias": [] + }, + { + "title": "Slide Planning and Composition", + "content": "Our slide generation module is inspired by the editing-based paradigm proposed in PPTAgent (Zheng et al., 2025b), which formulates presentation construction as a structured editing process over HTML-like layouts. While PPTAgent focuses on producing editable .pptx slides, our goal is to generate visually coherent, narrationready slide frames for downstream video synthesis. We re-implement the core idea in a self-contained pipeline tailored to multimodal synchronization.\n\nWe begin by segmenting the input document into coherent content blocks using a lightweight LLM-based parser. Each block is assigned a corresponding slide type such as bullet slide, figuredescription, or title-intro, and matched with a predefined layout schema encoded in HTML. Unlike retrieval-based template matching, our system uses semantic and structural cues to map content to layout patterns in a rule-guided manner.\n\nTo populate the slide, we define a set of editable operations such as replace_text, insert_image, and add_list, which are applied to the layout structure. These instructions are generated by prompting a language model with the content block and layout constraints. 
Slides are then rendered into static visual frames using python-pptx or HTML-based renderers.", + "medias": [] + }, + { + "title": "Narration and Audio Synthesis", + "content": "To transform the static slides into an engaging presentation, we generate a spoken narration for each slide and synthesize it into audio. The process involves two components: narration script generation and text-to-speech synthesis.\n\nFor each content block corresponding to a slide, we prompt a language model to generate a concise, oral-style narration. The model is instructed to rewrite the key message of the slide into natural spoken language, avoiding dense text or technical jargon. We apply length control to ensure each narration falls within a target duration, typically between 30 and 150 seconds. Once the narration script is obtained, we synthesize the corresponding audio using a text-to-speech system. Each narration audio is paired with its slide and timestamped, forming the basis for synchronized video rendering in the next stage.\n\n![](_page_4_Figure_0.jpeg)\n\nFigure 3: Overview of the PresentAgent framework. Our system takes diverse documents (e.g., papers, websites, PDFs) as input and follows a modular generation pipeline. It first performs outline generation (Step 1) and retrieves the most suitable template (Step 2), then generates slides and narration notes via a vision-language model (Step 3). The notes are converted into audio via TTS and composed into a presentation video (Step 4). To evaluate video quality, we design multiple prompts (Step 5) and feed them into a VLM-based scoring pipeline (Step 6) that outputs dimension-specific metrics.", + "medias": [] + }, + { + "title": "Video Assembly", + "content": "In the final stage, we assemble the slide images and narration audio into a coherent, time-aligned presentation video. Each slide frame is displayed for the duration of its corresponding audio segment, with optional transitions between segments. We use video processing libraries such as ffmpeg to compose the visual and audio tracks. Each slide is rendered as a static frame, and the narration is added as synchronized voiceover audio. The output is a fully rendered video file in standard formats such as .mp4, suitable for presentation, sharing, or further editing. This stage completes the transformation from a raw document into a narrated, structured presentation video.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Experiments", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Introduction to Experiments", + "content": "We conduct experiments to evaluate the effectiveness of our proposed system in generating highquality, narrated presentation videos. 
Given the novelty of the task, our focus is not on competing with existing baselines, but rather on assessing the performance of our full system relative to human-created presentations. Comprehension accuracy is determined based on performance in the PresentEval task. Evaluation setup can be found in appendix H.\n\n**Via Fixed Quiz**\n\n...\n\n**Question1 : XXX Question2 : XXX**\n\n**Question5 : XXX**\n\n**1. Content Quality 2. Visual Quality 3. Comprehension Accuracy**", + "medias": [ + { + "markdown_content": "| Method | Model | Quiz Accuracy | | Video Score | | | | Audio Score | | |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| | | | Content | Visual | Comp. | Mean | Content | Audio | Comp. | Mean |\n| Human | Human | 0.56 | 4.0 | 4.6 | 4.8 | 4.47 | 4.8 | 4.6 | 5.0 | 4.80 |\n| PresentAgent | Claude-3.7-sonnet | 0.64 | 4.0 | 4.0 | 4.0 | 4.00 | 4.2 | 4.6 | 4.8 | 4.53 |\n| PresentAgent | Qwen-VL-Max | 0.52 | 4.2 | 4.8 | 4.4 | 4.47 | 4.6 | 4.2 | 5.0 | 4.60 |\n| PresentAgent | Gemini-2.5-pro | 0.52 | 4.2 | 4.4 | 4.4 | 4.33 | 4.2 | 4.0 | 4.8 | 4.33 |\n| PresentAgent | Gemini-2.5-flash | 0.52 | 4.2 | 5.0 | 3.8 | 4.33 | 4.2 | 4.2 | 4.8 | 4.40 |\n| PresentAgent | GPT-4o-Mini | 0.64 | 4.8 | 4.6 | 4.6 | 4.67 | 4.0 | 4.4 | 4.8 | 4.40 |\n| PresentAgent | GPT-4o | 0.56 | 4.0 | 4.2 | 3.6 | 3.93 | 4.2 | 4.4 | 4.8 | 4.47 |", + "near_chunks": [ + "In terms of subjective quality, human-created presentations still lead with the highest video and audio scores overall. However, several PresentAgent variants show competitive performance.\n\nTable 1 presents evaluation results, covering both factual comprehension (Quiz Accuracy) and preference-based quality scores for video and audio outputs. In terms of quiz accuracy, most PresentAgent variants perform comparably to or better than the human reference (0.56), with Claude-3.7 sonnet (Anthropic, 2024) achieving the highest accuracy at 0.64, suggesting strong alignment between the generated content and the source document. Other models such as Qwen-VL-Max (Bai et al., 2025) and Gemini-2.5-flash (DeepMind, 2024) scored slightly lower (0.52), indicating room for improvement in factual grounding.\n\n", + "Table 1: Detailed evaluation results on the 5-document test set. Fact-based evaluation includes accuracy on five fixed quiz questions (Q1–Q5). Preference-based evaluation includes 1–5 scale scores for content fidelity, visual design, and overall clarity. 
Each Quality Score group has a calculated mean column.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_f5f7.png", + "caption": "Table: Comparison of presentation quality between human presenters and PresentAgent variants using different LLMs, showing quiz accuracy and subjective scores for video (content, visual, comprehensibility) and audio (content, audio quality, comprehensibility) components.", + "cells": [ + [ + "Method", + "Model", + "Quiz Accuracy", + "", + "Video Score", + "", + "", + "", + "Audio Score", + "", + "" + ], + [ + "", + "", + "", + "Content", + "Visual", + "Comp.", + "Mean", + "Content", + "Audio", + "Comp.", + "Mean" + ], + [ + "Human", + "Human", + "0.56", + "4.0", + "4.6", + "4.8", + "4.47", + "4.8", + "4.6", + "5.0", + "4.80" + ], + [ + "PresentAgent", + "Claude-3.7-sonnet", + "0.64", + "4.0", + "4.0", + "4.0", + "4.00", + "4.2", + "4.6", + "4.8", + "4.53" + ], + [ + "PresentAgent", + "Qwen-VL-Max", + "0.52", + "4.2", + "4.8", + "4.4", + "4.47", + "4.6", + "4.2", + "5.0", + "4.60" + ], + [ + "PresentAgent", + "Gemini-2.5-pro", + "0.52", + "4.2", + "4.4", + "4.4", + "4.33", + "4.2", + "4.0", + "4.8", + "4.33" + ], + [ + "PresentAgent", + "Gemini-2.5-flash", + "0.52", + "4.2", + "5.0", + "3.8", + "4.33", + "4.2", + "4.2", + "4.8", + "4.40" + ], + [ + "PresentAgent", + "GPT-4o-Mini", + "0.64", + "4.8", + "4.6", + "4.6", + "4.67", + "4.0", + "4.4", + "4.8", + "4.40" + ], + [ + "PresentAgent", + "GPT-4o", + "0.56", + "4.0", + "4.2", + "3.6", + "3.93", + "4.2", + "4.4", + "4.8", + "4.47" + ] + ], + "merge_area": null + }, + { + "markdown_content": "![](_page_5_Figure_2.jpeg)", + "near_chunks": [ + "Table 1: Detailed evaluation results on the 5-document test set. Fact-based evaluation includes accuracy on five fixed quiz questions (Q1–Q5). Preference-based evaluation includes 1–5 scale scores for content fidelity, visual design, and overall clarity. Each Quality Score group has a calculated mean column.\n\n", + "Figure 4: PresentAgent Demo. Automatically generates academic-style slides and narrated videos from research papers, streamlining the transformation from written content to engaging visual presentations.\n\nFor example, GPT-4o-Mini (Achiam et al., 2023) achieves top scores in video content and visual appeal (both at or near 4.8), while Claude-3.7 sonnet (Anthropic, 2024) delivers the most balanced audio quality (mean 4.53). Interestingly, Gemini-2.5-flash (DeepMind, 2024) scores highest in visual quality (5.0) but lower in comprehension, reflecting a trade-off between aesthetics and clarity. These results highlight the effectiveness of our modular pipeline and the usefulness of our unified PresentEval framework in capturing diverse aspects of presentation quality.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_5_Figure_2.jpeg", + "caption": "Figure: PresentAgent demo showing three sections: Technical Blogs (explaining agent concepts), Slides with captions (demonstrating parallelization and evaluator-optimizer workflows), and Videos (featuring augmented LLM presentations with explanatory text)." + } + ] + }, + { + "title": "Main Results", + "content": "Table 1 presents evaluation results, covering both factual comprehension (Quiz Accuracy) and preference-based quality scores for video and audio outputs. 
In terms of quiz accuracy, most PresentAgent variants perform comparably to or better than the human reference (0.56), with Claude-3.7 sonnet (Anthropic, 2024) achieving the highest accuracy at 0.64, suggesting strong alignment between the generated content and the source document. Other models such as Qwen-VL-Max (Bai et al., 2025) and Gemini-2.5-flash (DeepMind, 2024) scored slightly lower (0.52), indicating room for improvement in factual grounding.\n\nIn terms of subjective quality, human-created presentations still lead with the highest video and audio scores overall. However, several PresentAgent variants show competitive performance.\n\n| Method | Model | Quiz Accuracy | | Video Score | | | | Audio Score | | |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| | | | Content | Visual | Comp. | Mean | Content | Audio | Comp. | Mean |\n| Human | Human | 0.56 | 4.0 | 4.6 | 4.8 | 4.47 | 4.8 | 4.6 | 5.0 | 4.80 |\n| PresentAgent | Claude-3.7-sonnet | 0.64 | 4.0 | 4.0 | 4.0 | 4.00 | 4.2 | 4.6 | 4.8 | 4.53 |\n| PresentAgent | Qwen-VL-Max | 0.52 | 4.2 | 4.8 | 4.4 | 4.47 | 4.6 | 4.2 | 5.0 | 4.60 |\n| PresentAgent | Gemini-2.5-pro | 0.52 | 4.2 | 4.4 | 4.4 | 4.33 | 4.2 | 4.0 | 4.8 | 4.33 |\n| PresentAgent | Gemini-2.5-flash | 0.52 | 4.2 | 5.0 | 3.8 | 4.33 | 4.2 | 4.2 | 4.8 | 4.40 |\n| PresentAgent | GPT-4o-Mini | 0.64 | 4.8 | 4.6 | 4.6 | 4.67 | 4.0 | 4.4 | 4.8 | 4.40 |\n| PresentAgent | GPT-4o | 0.56 | 4.0 | 4.2 | 3.6 | 3.93 | 4.2 | 4.4 | 4.8 | 4.47 |\n\nTable 1: Detailed evaluation results on the 5-document test set. Fact-based evaluation includes accuracy on five fixed quiz questions (Q1–Q5). Preference-based evaluation includes 1–5 scale scores for content fidelity, visual design, and overall clarity. Each Quality Score group has a calculated mean column.", + "medias": [] + }, + { + "title": "Analysis", + "content": "Figure 4 Presents a full example of a PresentAgentauto-generated presentation video, showing a technical blog turned into a narrated presentation. The system identifies structural segments (e.g., introduction, technical explanations) and generates slides with oral-style captions and synchronized speech, covering topics like \"parallelization workflow\" and \"agent system architecture\" to demonstrate its ability to keep technical accuracy while delivering content clearly and conversationally.\n\n![](_page_5_Figure_2.jpeg)\n\nFigure 4: PresentAgent Demo. Automatically generates academic-style slides and narrated videos from research papers, streamlining the transformation from written content to engaging visual presentations.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Conclusion", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. 
This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "PresentAgent System Summary", + "content": "In conclusion, we presented PresentAgent, a modular system for transforming long-form documents into narrated presentation videos. By addressing the challenges of slide planning, narration synthesis, and synchronized rendering, PresentAgent enables structured, controllable, and reusable multimodal outputs.", + "medias": [] + }, + { + "title": "Evaluation Approach", + "content": "To evaluate this novel task, we introduced a diverse benchmark and proposed complementary factual and preference-based metrics.", + "medias": [] + }, + { + "title": "Results and Impact", + "content": "Experimental results show that PresentAgent generates coherent, engaging, and informative presentations, approaching human quality. This work lays the groundwork for automated, explainable content generation and opens new directions for research in multimodal communication across education, business, and accessibility.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "References", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Academic Citations", + "content": "- Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, and 1 others. 2023. Gpt-4 technical report. *arXiv preprint arXiv:2303.08774*.\n- Rie Kubota Ando and Tong Zhang. 2005. A framework for learning predictive structures from multiple tasks and unlabeled data. *Journal of Machine Learning Research*, 6:1817–1853.\n- Galen Andrew and Jianfeng Gao. 2007. Scalable training of L1-regularized log-linear models. In *Proceedings of the 24th International Conference on Machine Learning*, pages 33–40.\n- Anthropic. 2024. Claude 3 technical overview. https://www.anthropic.com/news/claude-3. Accessed: 2025-06-30.\n- Shuai Bai, Keqin Chen, Xuejing Liu, Jialin Wang, Wenbin Ge, Sibo Song, Kai Dang, Peng Wang, Shijie Wang, Jun Tang, and 1 others. 2025. Qwen2. 5-vl technical report. *arXiv preprint arXiv:2502.13923*.\n- Iz Beltagy, Matthew E Peters, and Arman Cohan. 2020. Longformer: The long-document transformer. *arXiv preprint arXiv:2004.05150*.\n- Jiaao Chen and Diyi Yang. 2021. Structure-aware abstractive conversation summarization via discourse and action graphs. *arXiv preprint arXiv:2104.08400*.\n- Google DeepMind. 2024. Gemini 2.5: Pushing the frontier with advanced reasoning, multimodality, long context, and next generation agentic capabilities. 
https://deepmind.google/technologies/ gemini/. Accessed: 2025-06-30.\n- Chaorui Deng, Deyao Zhu, Kunchang Li, Chenhui Gou, Feng Li, Zeyu Wang, Shu Zhong, Weihao Yu, Xiaonan Nie, Ziang Song, and 1 others. 2025. Emerging properties in unified multimodal pretraining. *arXiv preprint arXiv:2505.14683*.\n- Tsu-Jui Fu, William Yang Wang, Daniel McDuff, and Yale Song. 2022. Doc2ppt: Automatic presentation slides generation from scientific documents. In *Proceedings of the AAAI Conference on Artificial Intelligence*, volume 36, pages 634–642.\n- Jiaxin Ge, Zora Zhiruo Wang, Xuhui Zhou, Yi-Hao Peng, Sanjay Subramanian, Qinyue Tan, Maarten Sap, Alane Suhr, Daniel Fried, Graham Neubig, and Trevor Darrell. 2025. Autopresent: Designing structured visuals from scratch. *arXiv preprint arXiv:2501.00912*.\n- Yingqing He, Menghan Xia, Haoxin Chen, Xiaodong Cun, Yuan Gong, Jinbo Xing, Yong Zhang, Xintao Wang, Chao Weng, Ying Shan, and 1 others. 2023. Animate-a-story: Storytelling with retrieval-augmented video generation. *arXiv preprint arXiv:2307.06940*.\n- Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, Zhangyang Wang, Shant Navasardyan, and Humphrey Shi. 2023. Text2video-zero: Text-to-image diffusion models are zero-shot video generators. In *Proceedings of the IEEE/CVF International Conference on Computer Vision*, pages 15954–15964.\n- Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, and Luke Zettlemoyer. 2019. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. *arXiv preprint arXiv:1910.13461*.\n- Bo Li, Yuanhan Zhang, Dong Guo, Renrui Zhang, Feng Li, Hao Zhang, Kaichen Zhang, Peiyuan Zhang, Yanwei Li, Ziwei Liu, and 1 others. 2024. Llavaonevision: Easy visual task transfer. *arXiv preprint arXiv:2408.03326*.\n- Xin Li, Wenqing Chu, Ye Wu, Weihang Yuan, Fanglong Liu, Qi Zhang, Fu Li, Haocheng Feng, Errui Ding, and Jingdong Wang. 2023. Videogen: A reference-guided latent diffusion approach for high definition text-to-video generation. *arXiv preprint arXiv:2309.00398*.\n- Kevin Qinghong Lin, Linjie Li, Difei Gao, Qinchen Wu, Mingyi Yan, Zhengyuan Yang, Lijuan Wang, and Mike Zheng Shou. 2024a. Videogui: A benchmark for gui automation from instructional videos. *arXiv preprint arXiv:2406.10227*.\n- Kevin Qinghong Lin, Linjie Li, Difei Gao, Zhengyuan Yang, Shiwei Wu, Zechen Bai, Weixian Lei, Lijuan Wang, and Mike Zheng Shou. 2024b. Showui: One vision-language-action model for gui visual agent. *arXiv preprint arXiv:2411.17465*.\n- Pan Lu, Bowen Chen, Sheng Liu, Rahul Thapa, Joseph Boen, and James Zou. 2025. Octotools: An agentic framework with extensible tools for complex reasoning. *arXiv preprint arXiv:2502.11271*.\n- Shravan Nayak, Xiangru Jian, Kevin Qinghong Lin, Juan A. Rodriguez, Montek Kalsi, Rabiul Awal, Nicolas Chapados, M. Tamer Özsu, Aishwarya Agrawal, David Vazquez, Christopher Pal, Perouz Taslakian, Spandana Gella, and Sai Rajeswar. 2025. Ui-vision: A desktop-centric gui benchmark for visual perception and interaction. *arXiv preprint arXiv:2503.15661*.\n- Junrui Ni, Liming Wang, Heting Gao, Kaizhi Qian, Yang Zhang, Shiyu Chang, and Mark Hasegawa-Johnson. 2022. Unsupervised text-to-speech synthesis by unsupervised automatic speech recognition. *arXiv preprint arXiv:2203.15796*.\n- Wei Pang, Kevin Qinghong Lin, Xiangru Jian, Xi He, and Philip Torr. 2025. Paper2poster: Towards multimodal poster automation from scientific papers. 
*arXiv preprint arXiv:2505.21497*.\n- Vadim Popov, Ivan Vovk, Vladimir Gogoryan, Tasnima Sadekova, and Mikhail Kudinov. 2021. Grad-tts: A diffusion probabilistic model for text-to-speech. In *International conference on machine learning*, pages 8599–8608. PMLR.\n- Yujia Qin, Yining Ye, Junjie Fang, Haoming Wang, Shihao Liang, Shizuo Tian, Junda Zhang, Jiahao Li, Yunxin Li, Shijue Huang, and 1 others. 2025. Uitars: Pioneering automated gui interaction with native agents. *arXiv preprint arXiv:2501.12326*.\n- Mohammad Sadegh Rasooli and Joel R. Tetreault. 2015. Yara parser: A fast and accurate dependency parser. *Computing Research Repository*, arXiv:1503.06733. Version 2.\n- Yi Ren, Yangjun Ruan, Xu Tan, Tao Qin, Sheng Zhao, Zhou Zhao, and Tie-Yan Liu. 2019. Fastspeech: Fast, robust and controllable text to speech. *Advances in neural information processing systems*, 32.\n- Timo Schick, Jane Dwivedi-Yu, Roberto Dessì, and et al. 2023. Toolformer: Language models can teach themselves to use tools. *arXiv preprint arXiv:2302.04761*.\n- Shivam R Solanki and Drupad K Khublani. 2024. From script to screen: Unveiling text-to-video generation. In *Generative Artificial Intelligence: Exploring the Power and Potential of Generative AI*, pages 81–112. Springer.\n- Qiushi Sun, Kanzhi Cheng, Zichen Ding, Chuanyang Jin, Yian Wang, Fangzhi Xu, Zhenyu Wu, Chengyou Jia, Liheng Chen, Zhoumianze Liu, and 1 others. 2024. Os-genesis: Automating gui agent trajectory construction via reverse task synthesis. *arXiv preprint arXiv:2412.19723*.\n- Hideyuki Tachibana, Katsuya Uenoyama, and Shunsuke Aihara. 2018. Efficiently trainable text-to-speech system based on deep convolutional networks with guided attention. In *2018 IEEE international conference on acoustics, speech and signal processing (ICASSP)*, pages 4784–4788. IEEE.\n- Baode Wang, Biao Wu, Weizhen Li, Meng Fang, Yanjie Liang, Zuming Huang, Haozhe Wang, Jun Huang, Ling Chen, Wei Chu, and 1 others. 2025. Infinity parser: Layout aware reinforcement learning for scanned document parsing. *arXiv preprint arXiv:2506.03197*.\n- Guanghua Wang, Priyanshi Garg, and Weili Wu. 2024a. Segmented summarization and refinement: A pipeline for long-document analysis on social media. *Journal of Social Computing*, 5(2):132–144.\n- Peng Wang, Shuai Bai, Sinan Tan, Shijie Wang, Zhihao Fan, Jinze Bai, Keqin Chen, Xuejing Liu, Jialin Wang, Wenbin Ge, and 1 others. 2024b. Qwen2 vl: Enhancing vision-language model's perception of the world at any resolution. *arXiv preprint arXiv:2409.12191*.\n- Xingyao Wang, Boxuan Li, Yufan Song, Frank F Xu, Xiangru Tang, Mingchen Zhuge, Jiayi Pan, Yueqi Song, Bowen Li, Jaskirat Singh, and 1 others. 2024c. Opendevin: An open platform for ai software developers as generalist agents. *arXiv preprint arXiv:2407.16741*.\n- Yuan Wang, Di Huang, Yaqi Zhang, Wanli Ouyang, Jile Jiao, Xuetao Feng, Yan Zhou, Pengfei Wan, Shixiang Tang, and Dan Xu. 2024d. Motiongpt-2: A general-purpose motion-language model for motion generation and understanding. *arXiv preprint arXiv:2410.21747*.\n- Biao Wu, Yanda Li, Meng Fang, Zirui Song, Zhiwei Zhang, Yunchao Wei, and Ling Chen. 2024. Foundations and recent trends in multimodal mobile agents: A survey. *arXiv preprint arXiv:2411.02006*.\n- Jinheng Xie, Weijia Mao, Zechen Bai, David Junhao Zhang, Weihao Wang, Kevin Qinghong Lin, Yuchao Gu, Zhijie Chen, Zhenheng Yang, and Mike Zheng Shou. 2024. Show-o: One single transformer to unify multimodal understanding and generation. 
*arXiv preprint arXiv:2408.12528*.\n- Jin Xu, Zhifang Guo, Jinzheng He, Hangrui Hu, Ting He, Shuai Bai, Keqin Chen, Jialin Wang, Yang Fan, Kai Dang, and 1 others. 2025. Qwen2. 5-omni technical report. *arXiv preprint arXiv:2503.20215*.\n- Qiyao Xue, Xiangyu Yin, Boyuan Yang, and Wei Gao. 2025. Phyt2v: Llm-guided iterative self-refinement for physics-grounded text-to-video generation. In *Proceedings of the Computer Vision and Pattern Recognition Conference*, pages 18826–18836.\n- John Yang, Carlos Jimenez, Alexander Wettig, Kilian Lieret, Shunyu Yao, Karthik Narasimhan, and Ofir Press. 2024a. Swe-agent: Agent-computer interfaces enable automated software engineering. *Advances in Neural Information Processing Systems*, 37:50528– 50652.\n- Ke Yang, Jiateng Liu, John Wu, Chaoqi Yang, Yi R Fung, Sha Li, Zixuan Huang, Xu Cao, Xingyao Wang, Yiquan Wang, and 1 others. 2024b. If llm is the wizard, then code is the wand: A survey on how code empowers large language models to serve as intelligent agents. *arXiv preprint arXiv:2401.00812*.\n- Rui Yang, Lin Song, Yanwei Li, Sijie Zhao, Yixiao Ge, Xiu Li, and Ying Shan. 2023a. Gpt4tools: Teaching large language model to use tools via self-instruction. *Advances in Neural Information Processing Systems*, 36:71995–72007.\n- Zhengyuan Yang, Linjie Li, Jianfeng Wang, Kevin Lin, Ehsan Azarnasab, Faisal Ahmed, Zicheng Liu, Ce Liu, Michael Zeng, and Lijuan Wang. 2023b. Mm-react: Prompting chatgpt for multimodal reasoning and action. *arXiv preprint arXiv:2303.11381*.\n- Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, and 1 others. 2024c. Cogvideox: Text-to-video diffusion models with an expert transformer. *arXiv preprint arXiv:2408.06072*.\n- Shunyu Yao, Jeffrey Zhao, Dian Yu, Nan Du, Izhak Shafran, Karthik R Narasimhan, and Yuan Cao. 2023. React: Synergizing reasoning and acting in language models. In *The Eleventh International Conference on Learning Representations*.\n- Murong Yue, Wenlin Yao, Haitao Mi, Dian Yu, Ziyu Yao, and Dong Yu. 2024. Dots: Learning to reason dynamically in llms via optimal reasoning trajectories search. *arXiv preprint arXiv:2410.03864*.\n- Zeyu Zhang, Yiran Wang, Biao Wu, Shuo Chen, Zhiyuan Zhang, Shiya Huang, Wenbo Zhang, Meng Fang, Ling Chen, and Yang Zhao. 2024. Motion avatar: Generate human and animal avatars with arbitrary motion. *arXiv preprint arXiv:2405.11286*.\n- Hao Zheng, Xinyan Guan, Hao Kong, Jia Zheng, Weixiang Zhou, Hongyu Lin, Yaojie Lu, Ben He, Xianpei Han, and Le Sun. 2025a. Pptagent: Generating and evaluating presentations beyond text-to-slides. *arXiv preprint arXiv:2501.03936*.\n- Hao Zheng, Xinyan Guan, Hao Kong, Jia Zheng, Weixiang Zhou, Hongyu Lin, Yaojie Lu, Ben He, Xianpei Han, and Le Sun. 2025b. Pptagent: Generating and evaluating presentations beyond text-to-slides. *arXiv preprint arXiv:2501.03936*.\n- Zixiang Zhou, Yu Wan, and Baoyuan Wang. 2024. Avatargpt: All-in-one framework for motion understanding planning generation and beyond. In *Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition*, pages 1357–1366.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Related Work", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. 
The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Document-to-Multimodal Generation", + "content": "Recent advances in large language models (LLMs) and multimodal generation have sparked growing interest in converting documents into diverse output formats, such as slides, posters, or audio summaries (Xu et al., 2025; Wang et al., 2025; Pang et al., 2025; Sun et al., 2024). Systems like PP-TAgent (Zheng et al., 2025b) and Doc2PPT (Fu et al., 2022) treat document-to-slide generation as a structured summarization problem, focusing on layout-aware slide construction. Other works, such as Paper2Poster (Pang et al., 2025) extend this idea by producing single-page visual summaries using layout planning and visual feedback. However, these systems typically generate static outputs and do not model time-dependent delivery such as narration or slide progression. Our work builds upon these foundations, but further introduces temporal planning and audio-visual synchronization, enabling the generation of fully narrated presentation videos.", + "medias": [] + }, + { + "title": "Vision-Language Agents", + "content": "Recent advances have highlighted the expanding capabilities of vision language models (VLMs) beyond traditional language understanding. Techniques such as ReAct (Yao et al., 2023; Yang et al., 2023b; Yue et al., 2024) have shown that LLMs can operate as autonomous agents, capable of stepby-step reasoning and dynamic interaction through code execution (Wang et al., 2024c; Yang et al., 2024a,b), API function calls (Schick et al., 2023; Lu et al., 2025; Yang et al., 2023a), user interface manipulation (Lin et al., 2024b; Qin et al., 2025; Nayak et al., 2025; Wu et al., 2024), and motion generation (Zhang et al., 2024; Zhou et al., 2024; Wang et al., 2024d). Despite these developments, general-purpose agents still struggle with professional tasks that demand accuracy, domainspecific knowledge, and reliable interaction (Lin et al., 2024a). A closely related area is slide automation (Ge et al., 2025; Zheng et al., 2025a), which agents translate short text prompts into executable Python code to render presentation slides. In contrast, our proposed presentation video generation task is significantly more challenging: instead of taking a short prompt as input, the system processes an entire long-form document—such as a research paper, product manual, or technical report—and produces a well-structured presentation video with oral-style narration. This task imposes higher demands on content understanding, multimodal alignment, speech generation, and video synthesis. 
To address these challenges, we design a generation pipeline along with an automatic evaluation framework to systematically assess the generated videos in terms of information delivery, visual quality, and overall comprehensibility.", + "medias": [] + }, + { + "title": "Implementation Details", + "content": "PresentAgent is implemented using a modular architecture that integrates LLMs, VLMs, and text-to-speech (TTS) systems. Our primary models include LLMs (GPT-4o, GPT-4o-mini, Claude-3.7-sonnet) and VLMs (Qwen-VL-Max, Gemini-2.5-Flash, Gemini-2.5- Pro). For TTS systems, we choose the MegaTTS3 model for better performance. For visual and multimodal evaluation, we use Qwen-VL-2.5-3B-Instruct as VLM. In our experimental pipeline, any input document is automatically transformed into a Power-Point deck, paired with a generated audio narration, and then composited into a synchronized video presentation.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Implementation Details", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Multimodal Architecture", + "content": "PresentAgent adopts a highly modular multimodalgeneration architecture. At the languageunderstanding and generation layer, we run six primary LLM back ends in parallel—GPT-4o, GPT-4o-mini, Qwen-VL-Max, Gemini-2.5-Flash, Gemini-2.5-Pro, and Claude-3.7-Sonnet—and select or ensemble them on-the-fly with a dynamic routing policy that weighs input length, conversational complexity, and latency budget. For visuallanguage evaluation, we introduce the lightweight VLM Qwen-VL-2.5-3B-Instruct to score slide layout, chart readability, and cross-modal consistency, feeding its self-critique back into generation. Speech synthesis is unified on MegaTTS3, which outputs 24 kHz, 16-bit high-fidelity narration and supports prosody-tag controls for fine-grained rate, pitch, and emotion adjustment.", + "medias": [] + }, + { + "title": "Experimental Pipeline", + "content": "The experimental pipeline converts any input document—PDF, Markdown, DOCX, or web snapshot through three automated stages:\n\n1. Structured parsing & re-ordering that maps content to a hierarchical topic–subtopic tree.\n\n2. Per-slide generation with the chosen LLM, producing a PowerPoint deck containing titles, bullet points, graphic placeholders, and Alt-Text, while retrieving and inserting relevant images for key nouns.\n\n3. 
Synchronized narration generation with MegaTTS3 in Chinese or English, followed by an FFmpeg script that assembles a 1080 p video with fade-in/out transitions and optional captions.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "C Discussion", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Current Work Synthesis", + "content": "In this work, we synthesized presentation-style videos that integrate visual slides, textual narration, and spoken audio, simulating realistic multimodal communication scenarios. While our current evaluation focuses on the individual quality of each modality—such as visual clarity, textual relevance, and audio intelligibility—these dimensions are treated independently. However, in real-world applications, the effectiveness of communication often hinges on the semantic and temporal coherence across modalities.", + "medias": [] + }, + { + "title": "Future Research Direction", + "content": "Future research should thus move beyond isolated assessments and aim toward fusion-aware understanding and evaluation. This entails not only modeling the interactions and alignment among image, audio, and text modalities, but also enabling the system to reason over their combined meaning. Existing models like ImageBind offer a unified embedding space for multiple modalities, but lack the capacity for high-level inference and semantic comprehension.", + "medias": [] + }, + { + "title": "Multimodal Reasoning Integration", + "content": "A promising direction lies in bridging representation alignment with multimodal reasoning, by integrating aligned modality encoders with powerful language models. This would allow the system to jointly perceive, interpret, and respond to complex multimodal inputs—such as explaining a visual concept based on both audio narration and visual cues, or identifying inconsistencies across modalities. Developing such reasoning-capable, fusion-aware models will be critical for advancing robust, coherent multimodal understanding in real-world applications.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Limitations", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. 
This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Computational Cost Constraints", + "content": "Our work faces two key constraints: (1) Due to the high computational costs of commercial LLM/VLM APIs (e.g., GPT-4o and Gemini-2.5- Pro), evaluation was limited to five academic papers, potentially underrepresenting the document diversity shown in our benchmark (Figure 5);", + "medias": [] + }, + { + "title": "Static Slide Limitations", + "content": "(2) PresentAgent currently generates static slides without dynamic animations/effects due to architectural constraints in video synthesis and trade-offs between generation speed and visual quality, as noted in ChronoMagic-Bench's temporal coherence studies. Future improvements could involve lightweight distillation models and physics-aware rendering engines.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Evaluation Benchmark", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Document Types and Content", + "content": "As Shown in Figure 5, we showcase four of the representative document types in our benchmark: academic papers, web pages, technical blogs, and presentation slides. These documents cover a broad spectrum of real-world content domains, such as educational tutorials, research briefs, product manuals, scientific articles, news commentary, and business reports. Each document is paired with a manually authored presentation video, providing a diverse and realistic testbed for evaluating documentto-video generation systems in terms of multimodal coherence, content preservation, and presentation quality.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Doc2Present Dataset Details", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. 
This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Data Source", + "content": "We collect 30 high-quality video samples from public platforms, educational repositories, and professional presentation archives. Each video follows a structured narration format, combining slide-based visuals with synchronized voiceover. We manually align each video with its source document and ensure the following conditions are met: (1) the content structure of the video follows that of the document; (2) the visuals convey document information in a compact, structured form; and (3) the narration and slides are well-aligned temporally.", + "medias": [] + }, + { + "title": "Data Statistics", + "content": "The average document length is 3,000–8,000 words, while the corresponding videos range from 1 to 2 minutes and contain 5-10 slides. This setting highlights the core challenge of the task: transforming dense, domain-specific documents into effective and digestible multimodal presentations.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "PresentEval Evaluation Methods", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Objective Quiz Evaluation", + "content": "## G.1 Prompts of Objective Quiz Evaluation\n\nTable 2 presents the prompting content for the evaluation method utilizing objective quiz-based assessment. Each set of questions included in this evaluation is crafted manually, with its creation firmly rooted in the actual content of the relevant documents. The formulation of these questions\n\n![](_page_11_Figure_0.jpeg)\n\nFigure 5: Document Diversity in Our Evaluation Benchmark.\n\n| Prensentation of Web Pages | What is the main feature highlighted in the iPhone's promotional webpage? |\n| --- | --- |\n| A. | A more powerful chip for faster performance |\n| B. | A brighter and more vibrant display |\n| C. | An upgraded camera system with better lenses |\n| D. | A longer-lasting and more efficient battery |\n| Prensentation of Academic Paper | What primary research gap did the authors aim to address by introducing the FineGym dataset? |\n| A. | Lack of low-resolution sports footage for compression studies |\n| B. | Need for fine-grained action understanding that goes beyond coarse categories |\n| C. | Absence of synthetic data to replace human annotations |\n| D. | Shortage of benchmarks for background context recognition |\n\nTable 2: Prompt of evaluation via Objective Quiz Evaluation. 
Each question set is manually created based on the actual document content, with a focus on topic recognition, structural understanding, and key argument identification. These questions evaluate how well the generated video communicates the source material.\n\nplaces a distinct emphasis on three key aspects: topic recognition, which involves the ability to accurately identify and grasp the central themes of the source material; structural understanding, referring to the comprehension of the organizational framework and logical arrangement of the document; and key argument identification, focusing on the capacity to pinpoint the core viewpoints and supporting arguments within the content. These carefully designed questions serve as a means to evaluate the extent to which the generated video successfully conveys the essential information, core ideas, and structural logic of the original source material, thereby assessing the effectiveness of the video in communicating the source content.", + "medias": [ + { + "markdown_content": "![](_page_11_Figure_0.jpeg)", + "near_chunks": [ + "Table 2 presents the prompting content for the evaluation method utilizing objective quiz-based assessment. Each set of questions included in this evaluation is crafted manually, with its creation firmly rooted in the actual content of the relevant documents. The formulation of these questions\n\n", + "Figure 5: Document Diversity in Our Evaluation Benchmark.\n\nTable 2: Prompt of evaluation via Objective Quiz Evaluation. Each question set is manually created based on the actual document content, with a focus on topic recognition, structural understanding, and key argument identification. These questions evaluate how well the generated video communicates the source material.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/_page_11_Figure_0.jpeg", + "caption": "Figure: Document diversity visualization showing four types of input documents: Academic Papers, Web Pages, Technical Blogs, and Slides, with representative examples of each format displayed in a grid layout." + }, + { + "markdown_content": "| Prensentation of Web Pages | What is the main feature highlighted in the iPhone's promotional webpage? |\n| --- | --- |\n| A. | A more powerful chip for faster performance |\n| B. | A brighter and more vibrant display |\n| C. | An upgraded camera system with better lenses |\n| D. | A longer-lasting and more efficient battery |\n| Prensentation of Academic Paper | What primary research gap did the authors aim to address by introducing the FineGym dataset? |\n| A. | Lack of low-resolution sports footage for compression studies |\n| B. | Need for fine-grained action understanding that goes beyond coarse categories |\n| C. | Absence of synthetic data to replace human annotations |\n| D. | Shortage of benchmarks for background context recognition |", + "near_chunks": [ + "Figure 5: Document Diversity in Our Evaluation Benchmark.\n\nTable 2 presents the prompting content for the evaluation method utilizing objective quiz-based assessment. Each set of questions included in this evaluation is crafted manually, with its creation firmly rooted in the actual content of the relevant documents. The formulation of these questions\n\n", + "Table 2: Prompt of evaluation via Objective Quiz Evaluation. 
Each question set is manually created based on the actual document content, with a focus on topic recognition, structural understanding, and key argument identification. These questions evaluate how well the generated video communicates the source material.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_21d2.png", + "caption": "Table: Sample quiz questions evaluating understanding of web page and academic paper content, with multiple-choice options addressing iPhone features and FineGym dataset research contributions.", + "cells": [ + [ + "Prensentation of Web Pages", + "What is the main feature highlighted in the iPhone's promotional webpage?" + ], + [ + "A.", + "A more powerful chip for faster performance" + ], + [ + "B.", + "A brighter and more vibrant display" + ], + [ + "C.", + "An upgraded camera system with better lenses" + ], + [ + "D.", + "A longer-lasting and more efficient battery" + ], + [ + "Prensentation of Academic Paper", + "What primary research gap did the authors aim to address by introducing the FineGym dataset?" + ], + [ + "A.", + "Lack of low-resolution sports footage for compression studies" + ], + [ + "B.", + "Need for fine-grained action understanding that goes beyond coarse categories" + ], + [ + "C.", + "Absence of synthetic data to replace human annotations" + ], + [ + "D.", + "Shortage of benchmarks for background context recognition" + ] + ], + "merge_area": null + } + ] + }, + { + "title": "Subjective Scoring", + "content": "### G.2 Prompts of Subjective Scoring\n\nPrompt of evaluation via subjective scoring is shown in table 3. This table showcases the prompting content employed in the subjective scoringbased evaluation approach. Each individual prompt within this set is precisely targeted at a specific evaluative dimension. These dimensions encompass\n\nnarrative coherence, which pertains to the logical flow and consistency of the storytelling; visual appeal and audio appeal, focusing on the attractiveness and engaging nature of the visual elements and audio components respectively; and comprehension difficulty, referring to the level of ease or challenge in understanding the presented content. These prompts are meticulously designed to serve as a guiding framework for vision-language models, enabling them to assess presentations from a human-centric perspective. This means that the evaluation aligns with human perceptions, preferences, and ways of understanding, ensuring that the assessment results are more in line with how humans would judge the quality of the presentations.", + "medias": [] + } + ], + "markdown_content": null + }, + { + "title": "Evaluation Setup", + "summary": "The evaluation setup consists of 30 long-form documents with human-created reference videos spanning diverse topics. Each document is processed through the authors' generation pipeline to create a two-minute presentation video. The evaluation framework, PresentEval, employs a split strategy: Qwen-VL-2.5-3B conducts objective assessment via multiple-choice comprehension questions on entire videos, while Qwen-Omni-7B performs subjective scoring on shorter segments. Evaluation dimensions include narrative coherence, visual/audio appeal, and comprehension difficulty, guided by specific prompts. 
This approach addresses the current limitation of multimodal models in processing longer videos while maintaining comprehensive assessment across content quality, visual quality, and comprehension accuracy.", + "subsections": [ + { + "title": "Test Set Construction", + "content": "We construct a test set consisting of 30 long-form documents, each paired with a manually created presentation video that serves as a human-level reference. These documents span a diverse range of topics, including education, product explanation, research overviews, and policy briefings. For each document, we generate a corresponding presentation video using our full generation pipeline.", + "medias": [ + { + "markdown_content": "| Video | Scoring Prompt |\n| --- | --- |\n| Narr. Coh. | \"How coherent is the narration across the video? Are the ideas logically connected and easy to follow?\" |\n| Visual Appeal | \"How would you rate the visual design of the slides in terms of layout, aesthetics, and overall quality?\" |\n| Comp. Diff. | \"How easy is it to understand the presentation as a viewer? Were there any confusing or contradictory parts?\" |\n| Audio | Scoring Prompt |\n| Narr. Coh. | \"How coherent is the narration throughout the audio? Are the ideas logically structured and easy to follow?\" |\n| Audio Appeal | \"How pleasant and engaging is the narrator's voice in terms of tone, pacing, and delivery?\" |\n| Comp. Diff. | \"How easy is it to understand the spoken content? Were there any unclear or confusing parts in the audio?\" |", + "near_chunks": [ + "We construct a test set consisting of 30 long-form documents, each paired with a manually created presentation video that serves as a human-level reference. These documents span a diverse range of topics, including education, product explanation,\n\n# H Evaluation Setup\n\n", + "Table 3: Prompt of evaluation via Subjective Scoring. Each prompt targets a specific dimension—narrative coherence, visual/audio appeal, or comprehension difficulty—and is designed to guide vision-language models in assessing presentations from a human-centric perspective. Abbreviations: Narr. Coh. = Narrative Coherence; Comp. Diff. = Comprehension Difficulty.\n\n" + ], + "path": "/Users/shijingwei/Desktop/PresentAgent/presentagent/../pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_efca.png", + "caption": "Table: Subjective Scoring Prompts for Video and Audio Evaluation. The table outlines specific questions used to assess narrative coherence, visual/audio appeal, and comprehension difficulty when evaluating presentation quality from both visual and audio perspectives.", + "cells": [ + [ + "Video", + "Scoring Prompt" + ], + [ + "Narr. Coh.", + "\"How coherent is the narration across the video? Are the ideas logically connected and easy to follow?\"" + ], + [ + "Visual Appeal", + "\"How would you rate the visual design of the slides in terms of layout, aesthetics, and overall quality?\"" + ], + [ + "Comp. Diff.", + "\"How easy is it to understand the presentation as a viewer? Were there any confusing or contradictory parts?\"" + ], + [ + "Audio", + "Scoring Prompt" + ], + [ + "Narr. Coh.", + "\"How coherent is the narration throughout the audio? Are the ideas logically structured and easy to follow?\"" + ], + [ + "Audio Appeal", + "\"How pleasant and engaging is the narrator's voice in terms of tone, pacing, and delivery?\"" + ], + [ + "Comp. Diff.", + "\"How easy is it to understand the spoken content? 
Were there any unclear or confusing parts in the audio?\"" + ] + ], + "merge_area": null + } + ] + }, + { + "title": "Evaluation Framework", + "content": "All videos, both human-created and machinegenerated, are evaluated using our unified evaluation framework, PresentEval. Each synthesized video is approximately two minutes in length. However, due to the current lack of a single multimodal model capable of jointly assessing visual and audio quality for videos longer than two minutes, we adopt a split evaluation strategy.", + "medias": [] + }, + { + "title": "Evaluation Stages", + "content": "In the Objective Quiz stage, we use Qwen-VL-2.5-3B (Wang et al., 2024b) to evaluate the accuracy of the entire video using a fixed set of multiplechoice comprehension questions. In the Subjective Scoring stage, we extract short video/audio segments and evaluate them individually to assess quality in a more focused and scalable manner, using Qwen-Omni-7B (Xu et al., 2025).\n\nBoth models are guided by dimension-specific prompts and score each video or audio sample along three axes: Content Quality, Visual Quality, and Comprehension Accuracy.", + "medias": [] + }, + { + "title": "Scoring Prompts", + "content": "| Video | Scoring Prompt |\n| --- | --- |\n| Narr. Coh. | \"How coherent is the narration across the video? Are the ideas logically connected and easy to follow?\" |\n| Visual Appeal | \"How would you rate the visual design of the slides in terms of layout, aesthetics, and overall quality?\" |\n| Comp. Diff. | \"How easy is it to understand the presentation as a viewer? Were there any confusing or contradictory parts?\" |\n| Audio | Scoring Prompt |\n| Narr. Coh. | \"How coherent is the narration throughout the audio? Are the ideas logically structured and easy to follow?\" |\n| Audio Appeal | \"How pleasant and engaging is the narrator's voice in terms of tone, pacing, and delivery?\" |\n| Comp. Diff. | \"How easy is it to understand the spoken content? Were there any unclear or confusing parts in the audio?\" |\n\nTable 3: Prompt of evaluation via Subjective Scoring. Each prompt targets a specific dimension—narrative coherence, visual/audio appeal, or comprehension difficulty—and is designed to guide vision-language models in assessing presentations from a human-centric perspective. Abbreviations: Narr. Coh. = Narrative Coherence; Comp. Diff. 
= Comprehension Difficulty.", + "medias": [] + } + ], + "markdown_content": null + } + ], + "metadata": { + "title": "PresentAgent: Multimodal Agent for Presentation Video Generation", + "authors": "Jingwei Shi, Zeyu Zhang, Biao Wu, Yanjie Liang, Meng Fang, Ling Chen, Yang Zhao", + "affiliations": "AI Geeks, Australia; Australian Artificial Intelligence Institute, Australia; University of Liverpool, United Kingdom; La Trobe University, Australia", + "corresponding_author": "y.zhao2@latrobe.edu.au", + "presentation-date": "2025-07-05" + } +} \ No newline at end of file diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.md b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.md new file mode 100644 index 0000000000000000000000000000000000000000..9ab2bcc12c1399062ec608b51d4232ef07664349 --- /dev/null +++ b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.md @@ -0,0 +1,332 @@ +# PresentAgent: Multimodal Agent for Presentation Video Generation + +Jingwei Shi1∗ Zeyu Zhang1∗† Biao Wu2∗ Yanjie Liang1∗ + +Meng Fang3 Ling Chen2 Yang Zhao4‡ + +1AI Geeks, Australia + +2Australian Artificial Intelligence Institute, Australia 3University of Liverpool, United Kingdom 4La Trobe University, Australia + +∗Equal contribution. † Project lead. ‡Corresponding author: y.zhao2@latrobe.edu.au. + +#### Abstract + +We present PresentAgent, a multimodal agent that transforms long-form documents into narrated presentation videos. While existing approaches are limited to generating static slides or text summaries, our method advances beyond these limitations by producing fully synchronized visual and spoken content that closely mimics human-style presentations. To achieve this integration, PresentAgent employs a modular pipeline that systematically segments the input document, plans and renders slide-style visual frames, generates contextual spoken narration with large language models and Text-to-Speech models, and seamlessly composes the final video with precise audiovisual alignment. Given the complexity of evaluating such multimodal outputs, we introduce PresentEval, a unified assessment framework powered by Vision-Language Models that comprehensively scores videos across three critical dimensions: content fidelity, visual clarity, and audience comprehension through prompt-based evaluation. Our experimental validation on a curated dataset of 30 document–presentation pairs demonstrates that PresentAgent approaches human-level quality across all evaluation metrics. These results highlight the significant potential of controllable multimodal agents in transforming static textual materials into dynamic, effective, and accessible presentation formats. Code will be available at https://github.com/ AIGeeksGroup/PresentAgent. + +### 1 Introduction + +Presentations are a widely used and effective medium for conveying complex ideas. By combining visual elements, structured narration, and spoken explanations, they enable information to unfold progressively and be more easily understood by diverse audiences (Fu et al., 2022). Despite their proven effectiveness, creating high-quality presentation videos from long-form documents—such as + +![](_page_0_Figure_10.jpeg) + +Figure 1: Overview of PresentAgent. It takes documents (e.g., web pages) as input and follows a generation pipeline: (1) document processing, (2) structured slide generation, (3) synchronized caption creation, and (4) audio synthesis. The final output is a presentation video combining visual slides with aligned narration. 
The purple-highlighted middle results emphasize the system's key transitional outputs during generation. + +business reports, technical manuals, policy briefs, or academic papers—typically requires considerable manual effort (Li et al., 2023). This process involves identifying key content, designing slide layouts, writing scripts, recording narration, and aligning all elements into a coherent multimodal output. + +Although recent advancements in AI have enabled progress in related areas such as documentto-slide generation (Fu et al., 2022; Zheng et al., 2025a; Pang et al., 2025; Zhang et al., 2024) and text-to-video synthesis (Yang et al., 2024c; Li et al., 2023; Xue et al., 2025; Khachatryan et al., 2023; He et al., 2023; Solanki and Khublani, 2024), a critical gap remains: these methods either produce static visual summaries or generic video clips without structured narration, limiting their effectiveness for structured communication tasks like presentations. + +To bridge this gap, we introduce the task of Document-to-Presentation Video Generation, which aims to automatically convert a structured or unstructured document into a narrated video presentation composed of synchronized slides and speech. This task presents unique challenges as it goes beyond traditional summarization (Lewis et al., 2019; Beltagy et al., 2020; Chen and Yang, 2021; Wang + +et al., 2024a) or text-to-speech (Tachibana et al., 2018; Ren et al., 2019; Popov et al., 2021; Ni et al., 2022) pipelines by requiring selective content abstraction, layout-aware planning (Wang et al., 2025), and precise multimodal alignment (Li et al., 2024) between visuals and narration. In contrast to prior work that focuses on either static slide and image generation (Zheng et al., 2025a; Deng et al., 2025; Xie et al., 2024) or audio summarization in isolation, our objective is to produce a fully integrated, viewer-ready video experience that closely mimics how human presenters deliver information in real-world scenarios. + +To tackle these challenges, we propose a modular generation framework named PresentAgent. Given an input document, the system first segments it into semantic blocks through outline planning, then generates layout-guided slide visuals for each block and rewrites the key message into oral-style narration. Subsequently, these are then synthesized into audio and combined with the slide visuals to produce a time-aligned presentation video. Importantly, our pipeline is designed to be domainadaptable and controllable, enabling broad applicability across document types and presentation styles. + +Recognizing the need for rigorous evaluation of such complex multimodal outputs, we curate a test set of 30 human-authored document-video pairs spanning diverse domains, including education, finance, policy, and scientific communication. To comprehensively assess system performance, we further introduce a two-path evaluation strategy that combines fact-based comprehension assessment (via fixed multiple-choice quizzes) and preference-based scoring using vision-language models. This dual-pronged approach captures both objective correctness and subjective quality in video delivery. + +Experiment results demonstrate that our method produces fluent, well-structured, and informative presentation videos, approaching human-level performance in both content delivery and viewer comprehension. 
These findings highlight the potential of combining language models, layout generation, and multimodal synthesis for creating explainable and scalable presentation systems from raw documents. + +In general, our contributions are summarized as follows: + +- We formulate and address the novel task +of document-to-presentation video generation, which aims to produce narrated, slide-structured videos from long-form documents across diverse domains. + +- We propose PresentAgent, a modular generation framework that integrates document parsing, layout-aware slide composition, narration planning, and audio-visual synchronization, enabling controllable and interpretable generation. +- We introduce PresentEval, a multi-dimensional evaluation framework powered by Vision-Language Models (VLMs), which scores videos along content, visual, and comprehension dimensions via prompt-based judging. +- We create a test set of 30 real-world document–presentation pairs and demonstrate through experiments and ablations that PresentAgent approaches human-level performance and significantly outperforms competitive variants. + +#### 2 Presentation Benchmark + +The benchmark supports evaluation not only of fluency and fidelity, but also of downstream comprehension. Following the methodology introduced in Paper2Poster (Pang et al., 2025), we construct a quiz-style evaluation protocol (§5), where vision-language models are asked to answer factual content questions using only the generated video (slides + narration), simulating an audience's understanding. Human-authored videos are used as reference standards for both score calibration and upperbound comparison. As shown in Figure 5, our benchmark encompasses four representative document types (academic papers, web pages, technical blogs, and slides) paired with human-authored videos, covering diverse real-world domains like education, research, and business reports. + +We adopt a unified, model-based evaluation framework to assess the generated presentation videos. All evaluations are conducted using a vision-language model, guided by dimensionspecific prompts tailored to different assessment objectives. The framework consists of two complementary components: (1) objective quiz evaluation, which measures factual accuracy through multiplechoice question answering; and (2) subjective scoring, which rates Content Quality, Visual or Audio Quality, and Comprehension Clarity on a 1–5 scale. Together, these metrics provide a comprehensive assessment of both the quality and informativeness + +![](_page_2_Figure_0.jpeg) + +Figure 2: Overview of our framework. Our approach addresses the full pipeline of document-to-presentation video generation and evaluation. Left: Given diverse input documents—including papers, websites, blogs, slides, and PDFs—PresentAgent generates narrated presentation videos by producing synchronized slide decks with audio. Right: To evaluate these videos, we introduce PresentEval, a two-part evaluation framework: (1) Objective Quiz Evaluation (top), which measures factual comprehension using Qwen-VL; and (2) Subjective Scoring (bottom), which uses vision-language models to rate content quality, visual design, and audio comprehension across predefined dimensions. + +of the generated videos. + +#### 2.1 Doc2Present Dataset + +To support the evaluation of document to presentation video generation, we curate the Doc2Present Benchmark, a diverse dataset of document–presentation video pairs spanning multiple domains. 
Unlike prior benchmarks focused on research abstracts or slide generation, our dataset includes documents such as business reports, product manuals, policy briefs, and instructional texts, each paired with a human-crafted presentation video.We collect 30 high-quality video samples from public platforms, educational repositories, and professional presentation archives, further details regarding the data sources and statistical information of the dataset can be found in the appendix F. + +#### 2.2 PresentEval + +To assess the quality of generated presentation videos, we adopt two complementary evaluation strategies: Objective Quiz Evaluation and Subjective Scoring. For each video, we provide the visionlanguage model with the complete set of slide images and the full narration transcript as a unified input—simulating how a real viewer would experience the presentation. In Objective Quiz Evaluation, the model answers a fixed set of factual questions to determine whether the video accurately conveys the key information from the source content. In Subjective Scoring, the model evaluates the video along three dimensions: the coherence of the narration, the clarity and design of the visuals, and the overall ease of understanding. All evaluations are conducted without ground-truth references and + +rely entirely on the model's interpretation of the presented content. + +Objective Quiz Evaluation To evaluate whether a generated presentation video effectively conveys the core content of its source document, we use a fixed-question comprehension evaluation protocol. Specifically, we manually design five multiplechoice questions for each document, tailored to its content. These questions focus on key aspects such as topic recognition, structural understanding, and main argument extraction. As shown in Table 2, during evaluation, a vision-language model is given the video, including both visual frames and audio transcript, and asked to answer the five questions. Each question has four options, with one correct answer, annotated based on a human-created reference video. The final comprehension score (ranging from 0 to 5) reflects how many questions the model answered correctly, serving as a direct measure of how well the video communicates the original document. + +Subjective Scoring To evaluate the quality of generated presentation videos, we adopt a promptbased assessment using vision-language models. Instead of relying on human references or fixed metrics, we ask the model to evaluate each video from a viewer's perspective, using its own reasoning and preferences. The evaluation focuses on three aspects: coherence of narration, clarity and aesthetics of visuals, and overall ease of understanding. The model is shown the video and audio, and gives a score (1–5) with a brief explanation for each aspect. This enables scalable, consistent, + +and human-aligned evaluation without manual references. As shown in Table 3, we design different prompts for different modalities and tasks to ensure targeted and effective assessment. + +#### 3 PresentAgent + +To convert a long-form document into a narrated presentation video, we design a multi-stage generation framework that mirrors how human presenters prepare slides and talk tracks. Our method proceeds in four steps: segmenting the document into semantic units, composing slides with layout-aware structures, generating oral-style narration for each slide and assembling the visual and audio components into a synchronized video. 
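As a rough illustration of how these four stages fit together, the sketch below expresses the pipeline as plain Python functions. All names here (`segment_document`, `compose_slide`, `write_narration`, `assemble_video`) are illustrative placeholders rather than the actual PresentAgent interfaces, and each body is stubbed where a real system would call an LLM, a layout renderer, a TTS model, or ffmpeg.

```python
# Minimal sketch of the four-stage pipeline described above.
# All function names and bodies are illustrative placeholders,
# not the actual PresentAgent implementation.
from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    title: str
    text: str


@dataclass
class Slide:
    image_path: str   # rendered slide frame
    narration: str    # oral-style script paired with this frame


def segment_document(document: str) -> List[Segment]:
    """Step 1: split the document into semantic units via outline planning."""
    # Placeholder: a real system would prompt an LLM for an outline.
    return [Segment(title=p.split("\n", 1)[0], text=p)
            for p in document.split("\n\n") if p.strip()]


def compose_slide(segment: Segment, index: int) -> str:
    """Step 2: render a layout-aware slide frame and return its image path."""
    # Placeholder: a real system would fill an HTML/pptx layout and rasterize it.
    return f"slide_{index:02d}.png"


def write_narration(segment: Segment) -> str:
    """Step 3: rewrite the segment's key message as spoken-style narration."""
    # Placeholder: a real system would prompt an LLM, then run TTS on the script.
    return f"In this part, we look at {segment.title.lower()}."


def assemble_video(slides: List[Slide], output: str = "presentation.mp4") -> str:
    """Step 4: pair each slide frame with its narration audio and concatenate."""
    # Placeholder: a real system would synthesize audio and invoke ffmpeg here.
    print(f"Would compose {len(slides)} timed segments into {output}")
    return output


def document_to_video(document: str) -> str:
    segments = segment_document(document)
    slides = [Slide(image_path=compose_slide(seg, i),
                    narration=write_narration(seg))
              for i, seg in enumerate(segments)]
    return assemble_video(slides)
```

The skeleton is only meant to make the data flow explicit: each semantic segment yields exactly one slide frame paired with one narration track, which is what later enables precise audio-visual alignment.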
This modular design supports controllability, interpretability, and multimodal alignment, enabling both high-quality generation and fine-grained evaluation. The following sections describe each component in detail. + +#### 3.1 Problem Formulation + +Our method is designed to transform a long-form document into a structured presentation video through a multi-stage generation pipeline. We provide a formal description to highlight the key difference between our approach and conventional slide-based methods. + +Conventional approaches often focus on generating slide elements S directly from a document chunk C, as in Equation 1, where each element includes text or image content, layout attributes, and visual style: + +$$S=\{e_{1},e_{2},...,e_{n}\}=f(C)\qquad\quad(1)$$ + +In contrast, we treat the entire document D as a globally structured input and generate a presentation in three steps: (1) a sequence of semantic segments {C1, ..., CK} via outline planning, (2) a set of slides {S1, ..., SK}, each paired with a narrated audio track Tk generated by first producing a slide-specific script and then converting it to speech, and (3) a video V composed of visual and audio content aligned over time. This is defined as: + +$V=$ **Compose($\{(S_{1},T_{1}),...,(S_{K},T_{K})\})=g(D)$** + +Rather than editing predefined templates or layouts, our system first identifies high-level structure in the document and then generates slide visuals and narration from scratch. This pipeline + +supports controllability, modular evaluation, and multimodal alignment for downstream comprehension and quality assessment. + +#### 3.2 Slide Planning and Composition + +Our slide generation module is inspired by the editing-based paradigm proposed in PPTAgent (Zheng et al., 2025b), which formulates presentation construction as a structured editing process over HTML-like layouts. While PPTAgent focuses on producing editable .pptx slides, our goal is to generate visually coherent, narrationready slide frames for downstream video synthesis. We re-implement the core idea in a self-contained pipeline tailored to multimodal synchronization. + +We begin by segmenting the input document into coherent content blocks using a lightweight LLM-based parser. Each block is assigned a corresponding slide type such as bullet slide, figuredescription, or title-intro, and matched with a predefined layout schema encoded in HTML. Unlike retrieval-based template matching, our system uses semantic and structural cues to map content to layout patterns in a rule-guided manner. + +To populate the slide, we define a set of editable operations such as replace_text, insert_image, and add_list, which are applied to the layout structure. These instructions are generated by prompting a language model with the content block and layout constraints. Slides are then rendered into static visual frames using python-pptx or HTML-based renderers. + +#### 3.3 Narration and Audio Synthesis + +To transform the static slides into an engaging presentation, we generate a spoken narration for each slide and synthesize it into audio. The process involves two components: narration script generation and text-to-speech synthesis. + +For each content block corresponding to a slide, we prompt a language model to generate a concise, oral-style narration. The model is instructed to rewrite the key message of the slide into natural spoken language, avoiding dense text or technical jargon. 
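As a concrete, deliberately simplified example of this prompting step, the snippet below shows one way such an oral-style rewriting request could be assembled, with a crude word budget standing in for the length control discussed next; the `call_llm` argument and the assumed speaking rate of 140 words per minute are placeholders, not part of the actual system.

```python
# Illustrative sketch of oral-style narration prompting with a word budget.
# `call_llm` is a placeholder for whichever LLM backend is configured.
WORDS_PER_MINUTE = 140  # assumed average speaking rate


def narration_prompt(slide_text: str, target_seconds: int = 60) -> str:
    word_budget = int(WORDS_PER_MINUTE * target_seconds / 60)
    return (
        "Rewrite the following slide content as a natural, spoken-style "
        f"narration of at most {word_budget} words. Avoid dense text and "
        "technical jargon; speak as a presenter addressing an audience.\n\n"
        f"Slide content:\n{slide_text}"
    )


def generate_narration(slide_text: str, call_llm, target_seconds: int = 60) -> str:
    script = call_llm(narration_prompt(slide_text, target_seconds))
    # Crude length control: truncate if the model overshoots the budget.
    budget = int(WORDS_PER_MINUTE * target_seconds / 60)
    return " ".join(script.split()[:budget])
```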
We apply length control to ensure each narration falls within a target duration, typically between 30 and 150 seconds. Once the narration script is obtained, we synthesize the corresponding audio using a text-to-speech system. Each narration audio clip is paired with its slide and timestamped, forming the basis for synchronized video rendering in the next stage.

![](_page_4_Figure_0.jpeg)

Figure 3: Overview of the PresentAgent framework. Our system takes diverse documents (e.g., papers, websites, PDFs) as input and follows a modular generation pipeline. It first performs outline generation (Step 1) and retrieves the most suitable template (Step 2), then generates slides and narration notes via a vision-language model (Step 3). The notes are converted into audio via TTS and composed into a presentation video (Step 4). To evaluate video quality, we design multiple prompts (Step 5) and feed them into a VLM-based scoring pipeline (Step 6) that outputs dimension-specific metrics.

#### 3.4 Video Assembly

In the final stage, we assemble the slide images and narration audio into a coherent, time-aligned presentation video. Each slide frame is displayed for the duration of its corresponding audio segment, with optional transitions between segments. We use video processing libraries such as ffmpeg to compose the visual and audio tracks. Each slide is rendered as a static frame, and the narration is added as a synchronized voiceover. The output is a fully rendered video file in a standard format such as .mp4, suitable for presentation, sharing, or further editing. This stage completes the transformation from a raw document into a narrated, structured presentation video.
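As a concrete but illustrative sketch of this step, the following assumes each slide has already been rendered to a PNG and each narration synthesized to a WAV, and it shells out to the standard ffmpeg command-line tool; the exact invocation used by PresentAgent is not specified here.

```python
import subprocess
import tempfile
from pathlib import Path
from typing import List, Tuple

def assemble_video(slides_and_audio: List[Tuple[str, str]], out_path: str = "presentation.mp4") -> str:
    """Pair each slide image (PNG) with its narration (WAV), then concatenate the segments."""
    workdir = Path(tempfile.mkdtemp())
    segment_files = []
    for i, (slide_png, narration_wav) in enumerate(slides_and_audio):
        segment = workdir / f"segment_{i:03d}.mp4"
        # Hold the slide as a still frame for exactly the duration of its narration.
        subprocess.run([
            "ffmpeg", "-y",
            "-loop", "1", "-i", slide_png,   # repeat the still image
            "-i", narration_wav,             # narration track
            "-c:v", "libx264", "-tune", "stillimage",
            "-c:a", "aac", "-pix_fmt", "yuv420p",
            "-shortest",                     # stop when the audio ends
            str(segment),
        ], check=True)
        segment_files.append(segment)

    # Concatenate the per-slide segments into the final presentation video.
    concat_list = workdir / "segments.txt"
    concat_list.write_text("".join(f"file '{p}'\n" for p in segment_files))
    subprocess.run([
        "ffmpeg", "-y", "-f", "concat", "-safe", "0",
        "-i", str(concat_list), "-c", "copy", out_path,
    ], check=True)
    return out_path
```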
### 4 Experiments

We conduct experiments to evaluate the effectiveness of our proposed system in generating high-quality, narrated presentation videos. Given the novelty of the task, our focus is not on competing with existing baselines, but rather on assessing the performance of our full system relative to human-created presentations. Comprehension accuracy is determined by performance on the PresentEval quiz. The evaluation setup can be found in Appendix H.

#### 4.1 Main Results

Table 1 presents the evaluation results, covering both factual comprehension (Quiz Accuracy) and preference-based quality scores for video and audio outputs. In terms of quiz accuracy, most PresentAgent variants perform comparably to or better than the human reference (0.56), with Claude-3.7-sonnet (Anthropic, 2024) achieving the highest accuracy at 0.64, suggesting strong alignment between the generated content and the source document. Other models such as Qwen-VL-Max (Bai et al., 2025) and Gemini-2.5-flash (DeepMind, 2024) scored slightly lower (0.52), indicating room for improvement in factual grounding.

In terms of subjective quality, human-created presentations still lead with the highest video and audio scores overall. However, several PresentAgent variants show competitive performance. For example, GPT-4o-Mini (Achiam et al., 2023) achieves top scores in video content and visual appeal (both at or near 4.8), while Claude-3.7-sonnet (Anthropic, 2024) delivers the most balanced audio quality (mean 4.53). Interestingly, Gemini-2.5-flash (DeepMind, 2024) scores highest in visual quality (5.0) but lower in comprehension, reflecting a trade-off between aesthetics and clarity. These results highlight the effectiveness of our modular pipeline and the usefulness of our unified PresentEval framework in capturing diverse aspects of presentation quality.

| Method | Model | Quiz Accuracy | Content | Visual | Comp. | Mean | Content | Audio | Comp. | Mean |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Human | Human | 0.56 | 4.0 | 4.6 | 4.8 | 4.47 | 4.8 | 4.6 | 5.0 | 4.80 |
| PresentAgent | Claude-3.7-sonnet | 0.64 | 4.0 | 4.0 | 4.0 | 4.00 | 4.2 | 4.6 | 4.8 | 4.53 |
| PresentAgent | Qwen-VL-Max | 0.52 | 4.2 | 4.8 | 4.4 | 4.47 | 4.6 | 4.2 | 5.0 | 4.60 |
| PresentAgent | Gemini-2.5-pro | 0.52 | 4.2 | 4.4 | 4.4 | 4.33 | 4.2 | 4.0 | 4.8 | 4.33 |
| PresentAgent | Gemini-2.5-flash | 0.52 | 4.2 | 5.0 | 3.8 | 4.33 | 4.2 | 4.2 | 4.8 | 4.40 |
| PresentAgent | GPT-4o-Mini | 0.64 | 4.8 | 4.6 | 4.6 | 4.67 | 4.0 | 4.4 | 4.8 | 4.40 |
| PresentAgent | GPT-4o | 0.56 | 4.0 | 4.2 | 3.6 | 3.93 | 4.2 | 4.4 | 4.8 | 4.47 |

Table 1: Detailed evaluation results on the 5-document test set. Fact-based evaluation reports accuracy on five fixed quiz questions (Q1–Q5). Preference-based evaluation reports 1–5 scale scores for content fidelity, visual design, and overall clarity. The first Content/Visual/Comp./Mean group reports the Video Score and the second Content/Audio/Comp./Mean group reports the Audio Score; each group includes a calculated mean column.

![](_page_5_Figure_2.jpeg)

Figure 4: PresentAgent demo. The system automatically generates academic-style slides and narrated videos from research papers, streamlining the transformation from written content to engaging visual presentations.

#### 4.2 Analysis

Figure 4 presents a full example of a presentation video automatically generated by PresentAgent, showing a technical blog turned into a narrated presentation. The system identifies structural segments (e.g., introduction, technical explanations) and generates slides with oral-style captions and synchronized speech, covering topics like "parallelization workflow" and "agent system architecture", demonstrating its ability to preserve technical accuracy while delivering content clearly and conversationally.

### 5 Conclusion

In conclusion, we presented PresentAgent, a modular system for transforming long-form documents into narrated presentation videos. By addressing the challenges of slide planning, narration synthesis, and synchronized rendering, PresentAgent enables structured, controllable, and reusable multimodal outputs. To evaluate this novel task, we introduced a diverse benchmark and proposed complementary factual and preference-based metrics. Experimental results show that PresentAgent generates coherent, engaging, and informative presentations, approaching human quality. This work lays the groundwork for automated, explainable content generation and opens new directions for research in multimodal communication across education, business, and accessibility.

### References

- Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, and 1 others. 2023. Gpt-4 technical report. *arXiv preprint arXiv:2303.08774*.
+- Rie Kubota Ando and Tong Zhang. 2005. A framework for learning predictive structures from multiple tasks and unlabeled data. *Journal of Machine Learning Research*, 6:1817–1853.
+- Galen Andrew and Jianfeng Gao. 2007.
Scalable training of L1-regularized log-linear models. In *Proceedings of the 24th International Conference on Machine Learning*, pages 33–40. +- Anthropic. 2024. Claude 3 technical overview. https://www.anthropic.com/news/claude-3. Accessed: 2025-06-30. +- Shuai Bai, Keqin Chen, Xuejing Liu, Jialin Wang, Wenbin Ge, Sibo Song, Kai Dang, Peng Wang, Shijie Wang, Jun Tang, and 1 others. 2025. Qwen2. 5-vl technical report. *arXiv preprint arXiv:2502.13923*. +- Iz Beltagy, Matthew E Peters, and Arman Cohan. 2020. Longformer: The long-document transformer. *arXiv preprint arXiv:2004.05150*. +- Jiaao Chen and Diyi Yang. 2021. Structure-aware abstractive conversation summarization via discourse and action graphs. *arXiv preprint arXiv:2104.08400*. +- Google DeepMind. 2024. Gemini 2.5: Pushing the frontier with advanced reasoning, multimodality, long context, and next generation agentic capabilities. https://deepmind.google/technologies/ gemini/. Accessed: 2025-06-30. +- Chaorui Deng, Deyao Zhu, Kunchang Li, Chenhui Gou, Feng Li, Zeyu Wang, Shu Zhong, Weihao Yu, Xiaonan Nie, Ziang Song, and 1 others. 2025. Emerging properties in unified multimodal pretraining. *arXiv preprint arXiv:2505.14683*. +- Tsu-Jui Fu, William Yang Wang, Daniel McDuff, and Yale Song. 2022. Doc2ppt: Automatic presentation slides generation from scientific documents. In *Proceedings of the AAAI Conference on Artificial Intelligence*, volume 36, pages 634–642. +- Jiaxin Ge, Zora Zhiruo Wang, Xuhui Zhou, Yi-Hao Peng, Sanjay Subramanian, Qinyue Tan, Maarten Sap, Alane Suhr, Daniel Fried, Graham Neubig, and Trevor Darrell. 2025. Autopresent: Designing structured visuals from scratch. *arXiv preprint arXiv:2501.00912*. +- Yingqing He, Menghan Xia, Haoxin Chen, Xiaodong Cun, Yuan Gong, Jinbo Xing, Yong Zhang, Xintao Wang, Chao Weng, Ying Shan, and 1 others. 2023. Animate-a-story: Storytelling with retrieval-augmented video generation. *arXiv preprint arXiv:2307.06940*. +- Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, Zhangyang Wang, Shant Navasardyan, and Humphrey Shi. 2023. Text2video-zero: Text-to-image diffusion models are zero-shot video generators. In *Proceedings of the IEEE/CVF International Conference on Computer Vision*, pages 15954–15964. +- Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, and Luke Zettlemoyer. 2019. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. *arXiv preprint arXiv:1910.13461*. +- Bo Li, Yuanhan Zhang, Dong Guo, Renrui Zhang, Feng Li, Hao Zhang, Kaichen Zhang, Peiyuan Zhang, Yanwei Li, Ziwei Liu, and 1 others. 2024. Llavaonevision: Easy visual task transfer. *arXiv preprint arXiv:2408.03326*. +- Xin Li, Wenqing Chu, Ye Wu, Weihang Yuan, Fanglong Liu, Qi Zhang, Fu Li, Haocheng Feng, Errui Ding, and Jingdong Wang. 2023. Videogen: A reference-guided latent diffusion approach for high definition text-to-video generation. *arXiv preprint arXiv:2309.00398*. +- Kevin Qinghong Lin, Linjie Li, Difei Gao, Qinchen Wu, Mingyi Yan, Zhengyuan Yang, Lijuan Wang, and Mike Zheng Shou. 2024a. Videogui: A benchmark for gui automation from instructional videos. *arXiv preprint arXiv:2406.10227*. +- Kevin Qinghong Lin, Linjie Li, Difei Gao, Zhengyuan Yang, Shiwei Wu, Zechen Bai, Weixian Lei, Lijuan Wang, and Mike Zheng Shou. 2024b. Showui: One vision-language-action model for gui visual agent. *arXiv preprint arXiv:2411.17465*. 
+- Pan Lu, Bowen Chen, Sheng Liu, Rahul Thapa, Joseph Boen, and James Zou. 2025. Octotools: An agentic framework with extensible tools for complex reasoning. *arXiv preprint arXiv:2502.11271*. +- Shravan Nayak, Xiangru Jian, Kevin Qinghong Lin, Juan A. Rodriguez, Montek Kalsi, Rabiul Awal, Nicolas Chapados, M. Tamer Özsu, Aishwarya Agrawal, David Vazquez, Christopher Pal, Perouz Taslakian, Spandana Gella, and Sai Rajeswar. 2025. Ui-vision: A desktop-centric gui benchmark for visual perception and interaction. *arXiv preprint arXiv:2503.15661*. +- Junrui Ni, Liming Wang, Heting Gao, Kaizhi Qian, Yang Zhang, Shiyu Chang, and Mark Hasegawa-Johnson. 2022. Unsupervised text-to-speech synthesis by unsupervised automatic speech recognition. *arXiv preprint arXiv:2203.15796*. +- Wei Pang, Kevin Qinghong Lin, Xiangru Jian, Xi He, and Philip Torr. 2025. Paper2poster: Towards multimodal poster automation from scientific papers. *arXiv preprint arXiv:2505.21497*. +- Vadim Popov, Ivan Vovk, Vladimir Gogoryan, Tasnima Sadekova, and Mikhail Kudinov. 2021. Grad-tts: A diffusion probabilistic model for text-to-speech. In *International conference on machine learning*, pages 8599–8608. PMLR. +- Yujia Qin, Yining Ye, Junjie Fang, Haoming Wang, Shihao Liang, Shizuo Tian, Junda Zhang, Jiahao Li, Yunxin Li, Shijue Huang, and 1 others. 2025. Uitars: Pioneering automated gui interaction with native agents. *arXiv preprint arXiv:2501.12326*. +- Mohammad Sadegh Rasooli and Joel R. Tetreault. 2015. Yara parser: A fast and accurate dependency parser. *Computing Research Repository*, arXiv:1503.06733. Version 2. +- Yi Ren, Yangjun Ruan, Xu Tan, Tao Qin, Sheng Zhao, Zhou Zhao, and Tie-Yan Liu. 2019. Fastspeech: Fast, robust and controllable text to speech. *Advances in neural information processing systems*, 32. +- Timo Schick, Jane Dwivedi-Yu, Roberto Dessì, and et al. 2023. Toolformer: Language models can teach themselves to use tools. *arXiv preprint arXiv:2302.04761*. +- Shivam R Solanki and Drupad K Khublani. 2024. From script to screen: Unveiling text-to-video generation. In *Generative Artificial Intelligence: Exploring the Power and Potential of Generative AI*, pages 81–112. Springer. +- Qiushi Sun, Kanzhi Cheng, Zichen Ding, Chuanyang Jin, Yian Wang, Fangzhi Xu, Zhenyu Wu, Chengyou Jia, Liheng Chen, Zhoumianze Liu, and 1 others. 2024. Os-genesis: Automating gui agent trajectory construction via reverse task synthesis. *arXiv preprint arXiv:2412.19723*. +- Hideyuki Tachibana, Katsuya Uenoyama, and Shunsuke Aihara. 2018. Efficiently trainable text-to-speech system based on deep convolutional networks with guided attention. In *2018 IEEE international conference on acoustics, speech and signal processing (ICASSP)*, pages 4784–4788. IEEE. +- Baode Wang, Biao Wu, Weizhen Li, Meng Fang, Yanjie Liang, Zuming Huang, Haozhe Wang, Jun Huang, Ling Chen, Wei Chu, and 1 others. 2025. Infinity parser: Layout aware reinforcement learning for scanned document parsing. *arXiv preprint arXiv:2506.03197*. +- Guanghua Wang, Priyanshi Garg, and Weili Wu. 2024a. Segmented summarization and refinement: A pipeline for long-document analysis on social media. *Journal of Social Computing*, 5(2):132–144. +- Peng Wang, Shuai Bai, Sinan Tan, Shijie Wang, Zhihao Fan, Jinze Bai, Keqin Chen, Xuejing Liu, Jialin Wang, Wenbin Ge, and 1 others. 2024b. Qwen2 vl: Enhancing vision-language model's perception of the world at any resolution. *arXiv preprint arXiv:2409.12191*. 
+- Xingyao Wang, Boxuan Li, Yufan Song, Frank F Xu, Xiangru Tang, Mingchen Zhuge, Jiayi Pan, Yueqi Song, Bowen Li, Jaskirat Singh, and 1 others. 2024c. Opendevin: An open platform for ai software developers as generalist agents. *arXiv preprint arXiv:2407.16741*. +- Yuan Wang, Di Huang, Yaqi Zhang, Wanli Ouyang, Jile Jiao, Xuetao Feng, Yan Zhou, Pengfei Wan, Shixiang Tang, and Dan Xu. 2024d. Motiongpt-2: A general-purpose motion-language model for motion generation and understanding. *arXiv preprint arXiv:2410.21747*. +- Biao Wu, Yanda Li, Meng Fang, Zirui Song, Zhiwei Zhang, Yunchao Wei, and Ling Chen. 2024. Foundations and recent trends in multimodal mobile agents: A survey. *arXiv preprint arXiv:2411.02006*. +- Jinheng Xie, Weijia Mao, Zechen Bai, David Junhao Zhang, Weihao Wang, Kevin Qinghong Lin, Yuchao Gu, Zhijie Chen, Zhenheng Yang, and Mike Zheng Shou. 2024. Show-o: One single transformer to unify multimodal understanding and generation. *arXiv preprint arXiv:2408.12528*. +- Jin Xu, Zhifang Guo, Jinzheng He, Hangrui Hu, Ting He, Shuai Bai, Keqin Chen, Jialin Wang, Yang Fan, Kai Dang, and 1 others. 2025. Qwen2. 5-omni technical report. *arXiv preprint arXiv:2503.20215*. +- Qiyao Xue, Xiangyu Yin, Boyuan Yang, and Wei Gao. 2025. Phyt2v: Llm-guided iterative self-refinement for physics-grounded text-to-video generation. In *Proceedings of the Computer Vision and Pattern Recognition Conference*, pages 18826–18836. +- John Yang, Carlos Jimenez, Alexander Wettig, Kilian Lieret, Shunyu Yao, Karthik Narasimhan, and Ofir Press. 2024a. Swe-agent: Agent-computer interfaces enable automated software engineering. *Advances in Neural Information Processing Systems*, 37:50528– 50652. +- Ke Yang, Jiateng Liu, John Wu, Chaoqi Yang, Yi R Fung, Sha Li, Zixuan Huang, Xu Cao, Xingyao Wang, Yiquan Wang, and 1 others. 2024b. If llm is the wizard, then code is the wand: A survey on how code empowers large language models to serve as intelligent agents. *arXiv preprint arXiv:2401.00812*. +- Rui Yang, Lin Song, Yanwei Li, Sijie Zhao, Yixiao Ge, Xiu Li, and Ying Shan. 2023a. Gpt4tools: Teaching large language model to use tools via self-instruction. *Advances in Neural Information Processing Systems*, 36:71995–72007. +- Zhengyuan Yang, Linjie Li, Jianfeng Wang, Kevin Lin, Ehsan Azarnasab, Faisal Ahmed, Zicheng Liu, Ce Liu, Michael Zeng, and Lijuan Wang. 2023b. Mm-react: Prompting chatgpt for multimodal reasoning and action. *arXiv preprint arXiv:2303.11381*. +- Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, and 1 others. 2024c. Cogvideox: Text-to-video diffusion + +models with an expert transformer. *arXiv preprint arXiv:2408.06072*. + +- Shunyu Yao, Jeffrey Zhao, Dian Yu, Nan Du, Izhak Shafran, Karthik R Narasimhan, and Yuan Cao. 2023. React: Synergizing reasoning and acting in language models. In *The Eleventh International Conference on Learning Representations*. +- Murong Yue, Wenlin Yao, Haitao Mi, Dian Yu, Ziyu Yao, and Dong Yu. 2024. Dots: Learning to reason dynamically in llms via optimal reasoning trajectories search. *arXiv preprint arXiv:2410.03864*. +- Zeyu Zhang, Yiran Wang, Biao Wu, Shuo Chen, Zhiyuan Zhang, Shiya Huang, Wenbo Zhang, Meng Fang, Ling Chen, and Yang Zhao. 2024. Motion avatar: Generate human and animal avatars with arbitrary motion. *arXiv preprint arXiv:2405.11286*. +- Hao Zheng, Xinyan Guan, Hao Kong, Jia Zheng, Weixiang Zhou, Hongyu Lin, Yaojie Lu, Ben He, Xianpei Han, and Le Sun. 2025a. 
Pptagent: Generating and evaluating presentations beyond text-to-slides. *arXiv preprint arXiv:2501.03936*.
+- Hao Zheng, Xinyan Guan, Hao Kong, Jia Zheng, Weixiang Zhou, Hongyu Lin, Yaojie Lu, Ben He, Xianpei Han, and Le Sun. 2025b. Pptagent: Generating and evaluating presentations beyond text-to-slides. *arXiv preprint arXiv:2501.03936*.
+- Zixiang Zhou, Yu Wan, and Baoyuan Wang. 2024. Avatargpt: All-in-one framework for motion understanding planning generation and beyond. In *Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition*, pages 1357–1366.

## A Related Work

### A.1 Document-to-Multimodal Generation

Recent advances in large language models (LLMs) and multimodal generation have sparked growing interest in converting documents into diverse output formats, such as slides, posters, or audio summaries (Xu et al., 2025; Wang et al., 2025; Pang et al., 2025; Sun et al., 2024). Systems like PPTAgent (Zheng et al., 2025b) and Doc2PPT (Fu et al., 2022) treat document-to-slide generation as a structured summarization problem, focusing on layout-aware slide construction. Other works, such as Paper2Poster (Pang et al., 2025), extend this idea by producing single-page visual summaries using layout planning and visual feedback. However, these systems typically generate static outputs and do not model time-dependent delivery such as narration or slide progression. Our work builds upon these foundations, but further introduces temporal planning and audio-visual synchronization, enabling the generation of fully narrated presentation videos.

### A.2 Vision-Language Agents

Recent advances have highlighted the expanding capabilities of vision-language models (VLMs) beyond traditional language understanding. Techniques such as ReAct (Yao et al., 2023; Yang et al., 2023b; Yue et al., 2024) have shown that LLMs can operate as autonomous agents, capable of step-by-step reasoning and dynamic interaction through code execution (Wang et al., 2024c; Yang et al., 2024a,b), API function calls (Schick et al., 2023; Lu et al., 2025; Yang et al., 2023a), user interface manipulation (Lin et al., 2024b; Qin et al., 2025; Nayak et al., 2025; Wu et al., 2024), and motion generation (Zhang et al., 2024; Zhou et al., 2024; Wang et al., 2024d). Despite these developments, general-purpose agents still struggle with professional tasks that demand accuracy, domain-specific knowledge, and reliable interaction (Lin et al., 2024a). A closely related area is slide automation (Ge et al., 2025; Zheng et al., 2025a), in which agents translate short text prompts into executable Python code to render presentation slides. In contrast, our proposed presentation video generation task is significantly more challenging: instead of taking a short prompt as input, the system processes an entire long-form document, such as a research paper, product manual, or technical report, and produces a well-structured presentation video with oral-style narration. This task imposes higher demands on content understanding, multimodal alignment, speech generation, and video synthesis. To address these challenges, we design a generation pipeline along with an automatic evaluation framework to systematically assess the generated videos in terms of information delivery, visual quality, and overall comprehensibility.

## Implementation Details

PresentAgent is implemented using a modular architecture that integrates LLMs, VLMs, and text-to-speech (TTS) systems.
Our primary models include LLMs (GPT-4o, GPT-4o-mini, Claude-3.7-sonnet) and VLMs (Qwen-VL-Max, Gemini-2.5-Flash, Gemini-2.5-Pro). For TTS, we choose the MegaTTS3 model for its better performance. For visual and multimodal evaluation, we use Qwen-VL-2.5-3B-Instruct as the VLM.

In our experimental pipeline, any input document is automatically transformed into a PowerPoint deck, paired with a generated audio narration, and then composited into a synchronized video presentation.

## B Implementation Details

PresentAgent adopts a highly modular multimodal generation architecture. At the language understanding and generation layer, we run six primary LLM back ends in parallel (GPT-4o, GPT-4o-mini, Qwen-VL-Max, Gemini-2.5-Flash, Gemini-2.5-Pro, and Claude-3.7-Sonnet) and select or ensemble them on the fly with a dynamic routing policy that weighs input length, conversational complexity, and latency budget. For vision-language evaluation, we introduce the lightweight VLM Qwen-VL-2.5-3B-Instruct to score slide layout, chart readability, and cross-modal consistency, feeding its self-critique back into generation. Speech synthesis is unified on MegaTTS3, which outputs 24 kHz, 16-bit high-fidelity narration and supports prosody-tag controls for fine-grained rate, pitch, and emotion adjustment.

The experimental pipeline converts any input document (PDF, Markdown, DOCX, or web snapshot) through three automated stages:

1. Structured parsing and re-ordering that maps content to a hierarchical topic–subtopic tree.

2. Per-slide generation with the chosen LLM, producing a PowerPoint deck containing titles, bullet points, graphic placeholders, and alt text, while retrieving and inserting relevant images for key nouns.

3. Synchronized narration generation with MegaTTS3 in Chinese or English, followed by an FFmpeg script that assembles a 1080p video with fade-in/out transitions and optional captions.

## C Discussion

In this work, we synthesized presentation-style videos that integrate visual slides, textual narration, and spoken audio, simulating realistic multimodal communication scenarios. While our current evaluation focuses on the individual quality of each modality, such as visual clarity, textual relevance, and audio intelligibility, these dimensions are treated independently. However, in real-world applications, the effectiveness of communication often hinges on the semantic and temporal coherence across modalities.

Future research should thus move beyond isolated assessments and aim toward fusion-aware understanding and evaluation. This entails not only modeling the interactions and alignment among image, audio, and text modalities, but also enabling the system to reason over their combined meaning. Existing models like ImageBind offer a unified embedding space for multiple modalities, but lack the capacity for high-level inference and semantic comprehension.

A promising direction lies in bridging representation alignment with multimodal reasoning, by integrating aligned modality encoders with powerful language models. This would allow the system to jointly perceive, interpret, and respond to complex multimodal inputs, such as explaining a visual concept based on both audio narration and visual cues, or identifying inconsistencies across modalities. Developing such reasoning-capable, fusion-aware models will be critical for advancing robust, coherent multimodal understanding in real-world applications.
## D Limitations

Our work faces two key constraints: (1) Due to the high computational costs of commercial LLM/VLM APIs (e.g., GPT-4o and Gemini-2.5-Pro), evaluation was limited to five academic papers, potentially underrepresenting the document diversity shown in our benchmark (Figure 5); (2) PresentAgent currently generates static slides without dynamic animations or effects, due to architectural constraints in video synthesis and trade-offs between generation speed and visual quality, as noted in ChronoMagic-Bench's temporal coherence studies. Future improvements could involve lightweight distillation models and physics-aware rendering engines.

## E Evaluation Benchmark

As shown in Figure 5, we showcase four representative document types in our benchmark: academic papers, web pages, technical blogs, and presentation slides. These documents cover a broad spectrum of real-world content domains, such as educational tutorials, research briefs, product manuals, scientific articles, news commentary, and business reports. Each document is paired with a manually authored presentation video, providing a diverse and realistic testbed for evaluating document-to-video generation systems in terms of multimodal coherence, content preservation, and presentation quality.

## F Doc2Present Dataset Details

**Data Source.** We collect 30 high-quality video samples from public platforms, educational repositories, and professional presentation archives. Each video follows a structured narration format, combining slide-based visuals with synchronized voiceover. We manually align each video with its source document and ensure the following conditions are met: (1) the content structure of the video follows that of the document; (2) the visuals convey document information in a compact, structured form; and (3) the narration and slides are well-aligned temporally.

**Data Statistics.** The average document length is 3,000–8,000 words, while the corresponding videos range from 1 to 2 minutes and contain 5–10 slides. This setting highlights the core challenge of the task: transforming dense, domain-specific documents into effective and digestible multimodal presentations.

## G PresentEval

### G.1 Prompts of Objective Quiz Evaluation

Table 2 shows the prompts used for objective quiz-based evaluation. Each question set is crafted manually and grounded in the actual content of the corresponding document.

![](_page_11_Figure_0.jpeg)

Figure 5: Document Diversity in Our Evaluation Benchmark.

| Presentation of Web Pages | What is the main feature highlighted in the iPhone's promotional webpage? |
| --- | --- |
| A. | A more powerful chip for faster performance |
| B. | A brighter and more vibrant display |
| C. | An upgraded camera system with better lenses |
| D. | A longer-lasting and more efficient battery |
| Presentation of Academic Paper | What primary research gap did the authors aim to address by introducing the FineGym dataset? |
| A. | Lack of low-resolution sports footage for compression studies |
| B. | Need for fine-grained action understanding that goes beyond coarse categories |
| C. | Absence of synthetic data to replace human annotations |
| D. | Shortage of benchmarks for background context recognition |

Table 2: Prompt of evaluation via Objective Quiz Evaluation. Each question set is manually created based on the actual document content, with a focus on topic recognition, structural understanding, and key argument identification. These questions evaluate how well the generated video communicates the source material.
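To make the scoring procedure concrete, the sketch below shows one way the five-question quiz could be scored automatically. The `QuizQuestion` structure and the `ask_vlm` callable (assumed to be already conditioned on the slide frames and narration transcript) are illustrative assumptions, not the benchmark's actual code.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

@dataclass
class QuizQuestion:
    question: str
    options: Dict[str, str]  # e.g. {"A": "...", "B": "...", "C": "...", "D": "..."}
    answer: str              # correct option key, annotated from the human reference video

def quiz_accuracy(
    questions: List[QuizQuestion],
    ask_vlm: Callable[[str], str],  # VLM already given the slide frames and transcript
) -> float:
    """Fraction of the fixed multiple-choice questions answered correctly (0.0 to 1.0)."""
    correct = 0
    for q in questions:
        options = "\n".join(f"{key}. {text}" for key, text in sorted(q.options.items()))
        prompt = (
            "Based on the presentation video you were shown, answer the question "
            "with a single option letter (A, B, C, or D).\n\n"
            f"{q.question}\n{options}\nAnswer:"
        )
        reply = ask_vlm(prompt).strip().upper()
        if reply[:1] == q.answer.upper():
            correct += 1
    return correct / len(questions) if questions else 0.0
```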
The formulation of these questions places a distinct emphasis on three key aspects: topic recognition, the ability to accurately identify the central themes of the source material; structural understanding, the comprehension of the document's organizational framework and logical arrangement; and key argument identification, the capacity to pinpoint the core viewpoints and supporting arguments within the content. Together, these questions assess the extent to which the generated video conveys the essential information, core ideas, and structural logic of the original source material.

### G.2 Prompts of Subjective Scoring

The prompts used for subjective scoring are shown in Table 3. Each prompt targets a specific evaluative dimension: narrative coherence, the logical flow and consistency of the storytelling; visual and audio appeal, the attractiveness and engagement of the visual elements and audio components; and comprehension difficulty, how easy or hard the presented content is to understand. These prompts guide vision-language models to assess presentations from a human-centric perspective, so that the resulting scores align with human perceptions, preferences, and ways of understanding.

| Video | Scoring Prompt |
| --- | --- |
| Narr. Coh. | "How coherent is the narration across the video? Are the ideas logically connected and easy to follow?" |
| Visual Appeal | "How would you rate the visual design of the slides in terms of layout, aesthetics, and overall quality?" |
| Comp. Diff. | "How easy is it to understand the presentation as a viewer? Were there any confusing or contradictory parts?" |
| Audio | Scoring Prompt |
| Narr. Coh. | "How coherent is the narration throughout the audio? Are the ideas logically structured and easy to follow?" |
| Audio Appeal | "How pleasant and engaging is the narrator's voice in terms of tone, pacing, and delivery?" |
| Comp. Diff. | "How easy is it to understand the spoken content? Were there any unclear or confusing parts in the audio?" |

Table 3: Prompt of evaluation via Subjective Scoring. Each prompt targets a specific dimension (narrative coherence, visual/audio appeal, or comprehension difficulty) and is designed to guide vision-language models in assessing presentations from a human-centric perspective. Abbreviations: Narr. Coh. = Narrative Coherence; Comp. Diff. = Comprehension Difficulty.

## H Evaluation Setup

We construct a test set consisting of 30 long-form documents, each paired with a manually created presentation video that serves as a human-level reference. These documents span a diverse range of topics, including education, product explanation,
+ +research overviews, and policy briefings. For each document, we generate a corresponding presentation video using our full generation pipeline. + +All videos, both human-created and machinegenerated, are evaluated using our unified evaluation framework, PresentEval. Each synthesized video is approximately two minutes in length. However, due to the current lack of a single multimodal model capable of jointly assessing visual and audio quality for videos longer than two minutes, we adopt a split evaluation strategy. + +In the Objective Quiz stage, we use Qwen-VL-2.5-3B (Wang et al., 2024b) to evaluate the accuracy of the entire video using a fixed set of multiplechoice comprehension questions. In the Subjective Scoring stage, we extract short video/audio segments and evaluate them individually to assess quality in a more focused and scalable manner, using Qwen-Omni-7B (Xu et al., 2025). + +Both models are guided by dimension-specific prompts and score each video or audio sample along three axes: Content Quality, Visual Quality, and Comprehension Accuracy. + diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.pdf b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.pdf new file mode 100644 index 0000000000000000000000000000000000000000..2884a9ae9161d6fcd7ed6c8463d9127e17a8c786 --- /dev/null +++ b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/source.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d30fa77088faf2e5e9a1ea32d787b218e7a5a455f07ed9c192694ed5174871d +size 1829248 diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_21d2.png b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_21d2.png new file mode 100644 index 0000000000000000000000000000000000000000..64f93310f0f4dbce6c6a07ce5c5680a0f0a52c69 Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_21d2.png differ diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_efca.png b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_efca.png new file mode 100644 index 0000000000000000000000000000000000000000..3dadb52863c962d427a31e338201c7734c30ae0b Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_efca.png differ diff --git a/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_f5f7.png b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_f5f7.png new file mode 100644 index 0000000000000000000000000000000000000000..f547e4a71602bd2677bd542efa6f579a9b5aa058 Binary files /dev/null and b/pptagent/runs/pdf/9145dbfce1296e2b0603293042aa883e/table_f5f7.png differ diff --git a/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/output.mp4 b/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/output.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cff93b20629d09c45ed457acd32e831a9bad3c60 --- /dev/null +++ b/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/output.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33ae87a36e52d48b8ace4eae461d382ea9a42125b550d223a4ffeb213e06fc36 +size 484166 diff --git a/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/source.pdf b/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/source.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d431cc5a975e8ae84f9453bc2d604c6b8a525f32 Binary files /dev/null and b/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/source.pdf differ diff --git 
a/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/source.pptx b/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/source.pptx new file mode 100644 index 0000000000000000000000000000000000000000..d363ad7656b7bd0c87f202adf28763fb6772d8da Binary files /dev/null and b/pptagent/runs/ppt_video/ca046385-ac3d-4240-9284-a96c57d934d3/source.pptx differ diff --git a/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/output.mp4 b/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/output.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..be481a7a96bdfa7d91bccefc0374794b032864db --- /dev/null +++ b/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/output.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5eee4df12ff7a893d12cdce87dee3e2df826ada26b50807090a63af7a3f31d92 +size 4414794 diff --git a/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/source.pdf b/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/source.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4d96791f6a89cb0c762e387bacef328280181c91 Binary files /dev/null and b/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/source.pdf differ diff --git a/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/source.pptx b/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/source.pptx new file mode 100644 index 0000000000000000000000000000000000000000..d363ad7656b7bd0c87f202adf28763fb6772d8da Binary files /dev/null and b/pptagent/runs/ppt_video/e88b9f32-6b97-4096-abd6-9bee103524b6/source.pptx differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/image_stats.json b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/image_stats.json new file mode 100644 index 0000000000000000000000000000000000000000..ed2e58521de66dd3f511511ed0c69b6a11765779 --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/image_stats.json @@ -0,0 +1,93 @@ +{ + "a1c98d25e5c2a3059235733edc58ea6984e75dc9.png": { + "size": [ + 2873, + 1069 + ], + "appear_times": 1, + "slide_numbers": [ + 1 + ], + "relative_area": 6.404320987654321, + "top_ranges_str": "1", + "caption": "Logo: BMVC 2024 logo featuring a blue circuit-like pattern resembling a camera or technological device on the left side with \"BMVC 2024\" text in dark blue on the right." + }, + "83d1124da2030bef8f40da55db202923268685e2.png": { + "size": [ + 296, + 296 + ], + "appear_times": 1, + "slide_numbers": [ + 1 + ], + "relative_area": 2.3909465020576133, + "top_ranges_str": "1", + "caption": "Picture: A blank white image with no visible content or elements to describe." + }, + "5452ff4f227c6ba1d7ad666974203486e642daf6.png": { + "size": [ + 1672, + 703 + ], + "appear_times": 1, + "slide_numbers": [ + 6 + ], + "relative_area": 74.5679012345679, + "top_ranges_str": "6", + "caption": "Diagram: An illustration of wolf motion synthesis showing original actions (Howl, Walk, Attack, Die) and expanded combined motions through a SinMDM model, with a workflow demonstrating how text prompts about wolf behaviors are refined through AI models to generate detailed motion descriptions." 
+ }, + "35639ff12c3127b2ba9419b7c784b212753ff628.png": { + "size": [ + 1221, + 524 + ], + "appear_times": 1, + "slide_numbers": [ + 9 + ], + "relative_area": 66.79149519890261, + "top_ranges_str": "9", + "caption": "Diagram: A comprehensive AI pipeline showing the process of generating 3D animated jaguar models from text prompts, including SDXL for image creation, TripoSR for 3D mesh conversion, and MoMASK for motion sequence generation." + }, + "203e2300314026057b7257a3c105a8d2fad5183e.png": { + "size": [ + 900, + 506 + ], + "appear_times": 1, + "slide_numbers": [ + 10 + ], + "relative_area": 22.124485596707817, + "top_ranges_str": "10", + "caption": "Diagram: A collection of sequential animation frames showing various character movements including people, wolves, horses, and jaguars in different actions like walking, running, jumping, and attacking." + }, + "4bbdd852ecafe7b9f1c65dfdbba4a04a5de91642.png": { + "size": [ + 1233, + 225 + ], + "appear_times": 1, + "slide_numbers": [ + 13 + ], + "relative_area": 24.963991769547324, + "top_ranges_str": "13", + "caption": "Table: Comparison of model performance metrics showing LLM Planner outperforming LLaMA-7B with significantly higher accuracy scores across animal, motion, and overall categories, with green numbers indicating percentage improvements." + }, + "fee3a1e81ae1678f114d5799e440cc2b7d740aa1.png": { + "size": [ + 1636, + 988 + ], + "appear_times": 1, + "slide_numbers": [ + 14 + ], + "relative_area": 39.7136488340192, + "top_ranges_str": "14", + "caption": "Diagram: Collection of 3D character models showing various animated creatures including fantasy characters (demon and dragon-themed anime figures, yellow Yacuruna), animals (bear, cats, dogs, horses) and a cobra snake, each displayed from multiple angles with identifying labels." 
+ } +} \ No newline at end of file diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/203e2300314026057b7257a3c105a8d2fad5183e.png b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/203e2300314026057b7257a3c105a8d2fad5183e.png new file mode 100644 index 0000000000000000000000000000000000000000..b9d3f1d64409d4393873a9f52c029ceb45ef80f5 --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/203e2300314026057b7257a3c105a8d2fad5183e.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e36ae7d668c0382125d2a205bbb18d16d9c3df01e81d4a26c5320f94dbf5e793 +size 333197 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/35639ff12c3127b2ba9419b7c784b212753ff628.png b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/35639ff12c3127b2ba9419b7c784b212753ff628.png new file mode 100644 index 0000000000000000000000000000000000000000..b3f8f6a0a6dff29ae09a12161aa51f09f99cbef6 --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/35639ff12c3127b2ba9419b7c784b212753ff628.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b5266de03220e670712288c7f1273ac56cd458a23999edfe11151336c7d7b96 +size 269874 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/4bbdd852ecafe7b9f1c65dfdbba4a04a5de91642.png b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/4bbdd852ecafe7b9f1c65dfdbba4a04a5de91642.png new file mode 100644 index 0000000000000000000000000000000000000000..72c3f9e8fe459f6a58b1dc0e9ca108d6a50f71ef Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/4bbdd852ecafe7b9f1c65dfdbba4a04a5de91642.png differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/5452ff4f227c6ba1d7ad666974203486e642daf6.png b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/5452ff4f227c6ba1d7ad666974203486e642daf6.png new file mode 100644 index 0000000000000000000000000000000000000000..8c1d3ae54c72b64e086c466e53235cf0f8b22f5d --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/5452ff4f227c6ba1d7ad666974203486e642daf6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:950b236ee7927e02c41e1c9f4d7101e283a5c3cc6e53a700b8e0a7f9326fa288 +size 657026 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/83d1124da2030bef8f40da55db202923268685e2.png b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/83d1124da2030bef8f40da55db202923268685e2.png new file mode 100644 index 0000000000000000000000000000000000000000..69d457b3cb268f271026f2ec3d22c76d720c59ac Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/83d1124da2030bef8f40da55db202923268685e2.png differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/a1c98d25e5c2a3059235733edc58ea6984e75dc9.png b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/a1c98d25e5c2a3059235733edc58ea6984e75dc9.png new file mode 100644 index 0000000000000000000000000000000000000000..7373b9902d892abf7ecdbb3ed19b9c21d23d5414 --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/a1c98d25e5c2a3059235733edc58ea6984e75dc9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e39c960001e890ba64a2fad3c7b2d2773228ab5ad894c2c381d15900b8f57041 +size 277111 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/fee3a1e81ae1678f114d5799e440cc2b7d740aa1.png 
b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/fee3a1e81ae1678f114d5799e440cc2b7d740aa1.png new file mode 100644 index 0000000000000000000000000000000000000000..773a1db61b92e367a5ee707bc9077efba7e713ed --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/images/fee3a1e81ae1678f114d5799e440cc2b7d740aa1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:837401a779f30e28375fe8ee0305bab849fd6e6aef33d70d11034e7a5a8cf7c4 +size 915138 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0001.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4afa82fede49c5107d112f68a58b5b55a8cd39f7 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0001.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0002.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5594e72fdb6f3e6b47237747790de9fc0853477d Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0002.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0003.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f01013c0bb2cb9d517840f827d64be22a078619f Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0003.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0004.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1a10ab19edf051efd52d339f94c97ecc4c29ee9a Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0004.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0005.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c9823521b8ec5ca93c3410f5955d293e4f8758a Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0005.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0006.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d30412cc7786f3a120ea604492b86102014fd47d Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0006.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0007.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a58df1eaac9841f46cfe3eb349f2a0b0a81cf64b Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0007.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0008.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0008.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..62750ef341f5e7c2caccaa3a537eb0830de81a61 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0008.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0009.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e7792617b3b8e34737d1ce32179c66575295628 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0009.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0010.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af5d9d959e75b7b381335d3968425df993f72afd Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0010.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0011.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a014b542668a8e07414e8e20dbdee602958af7e Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0011.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0012.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c256d90012f3c7a06930099de89dbbeec677d102 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0012.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0013.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24b11ee44f85eb8edff22b7b9b12fc1caf3f9edb Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0013.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0014.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d20a57198eed7c0204a45fbdecd2e884d000f3e Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0014.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0015.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..da2054cbca0978c7b4f580645592e50dc1a332fc Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0015.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0016.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b765f3c0e4d59ae686e25be8f204e36ee69f9b1 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0016.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0017.jpg 
b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd435d0673f1eaee304fa9a89da0be7dcd57b0c9 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0017.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0018.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..453616f7ce223a545578f0dde092269851bfebd9 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/slide_images/slide_0018.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/source.pptx b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/source.pptx new file mode 100644 index 0000000000000000000000000000000000000000..a4663336820e0f6db49b6bd7f26e3bce03ee668c --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/source.pptx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3795ad16408515f6f3041fe35547e1cddc3a070e348b5660160f09d63563d343 +size 2877057 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template.pptx b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template.pptx new file mode 100644 index 0000000000000000000000000000000000000000..29005f046a563309b51a38951a33a25dfc8735f2 --- /dev/null +++ b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template.pptx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b41577e048ab4f6d384bc9629f53c45b1311f3a23e4e644d55af441edb478eb4 +size 368779 diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0001.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8993b2a0c9ab32813b06c9157bb4e0fd010b0ac1 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0001.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0002.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7193af890545add528dd1a82c7ecd50f092c2ebc Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0002.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0003.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1db26b05439dd60c93f69584856a34cfc86fb1a4 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0003.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0004.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..787b0ec2e0443ac6112120f42d5f51027826394c Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0004.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0005.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0005.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..c80354cdd62da17414de89e58b99d92fa7e683a7 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0005.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0006.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..43174fd628998fb699d6f3380a6e35d99fb8af2d Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0006.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0007.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20b624eab11a2a6b79a0bada36e3839f4ec37c11 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0007.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0008.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..743da6e90377da82b9bcd28e2c31f629671c0741 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0008.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0009.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4ebc7d3a91ae9d35b3dc0410fc75b06e59bd623f Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0009.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0010.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f913ef0e517bbbc69b0218f340f6b4ceb97b7b4e Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0010.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0011.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ae8b35c1683650390abacd4ad97a2ed905cd46d8 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0011.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0012.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..743da6e90377da82b9bcd28e2c31f629671c0741 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0012.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0013.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e828b66e145420cb63be3b033141a4372aaa236c Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0013.jpg differ diff --git 
a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0014.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6896c8410069799bc6bda64b99fdb24823b8232a Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0014.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0015.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a17dff3fcd38b1d53252c210660f3eeadd11bcaa Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0015.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0016.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc4fd5eb06952701d677761b861959b12f3fc18f Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0016.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0017.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ffc264bb4945818793d417be091b33a3f8af160a Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0017.jpg differ diff --git a/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0018.jpg b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2ff955d6f36bada4fa79a22b76c51d081d11c4b0 Binary files /dev/null and b/pptagent/runs/pptx/0210ff6b414902fa05857e734dd5bcee/template_images/slide_0018.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/image_stats.json b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/image_stats.json new file mode 100644 index 0000000000000000000000000000000000000000..6fdc28a7f53337d1986aeae4aa5d91aa0f7a17d6 --- /dev/null +++ b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/image_stats.json @@ -0,0 +1,15 @@ +{ + "1df5f510a94dff77293458473e5407d97a31bdfe.png": { + "size": [ + 784, + 878 + ], + "appear_times": 1, + "slide_numbers": [ + 3 + ], + "relative_area": 33.9647633744856, + "top_ranges_str": "3", + "caption": "Diagram: Illustration showing an inclined plane with an angle \\( \\theta \\), a reference point \\( C \\), and various vectors labeled \\( \\mathbf{e} \\), \\( \\mathbf{e_z} \\), and \\( \\mathbf{C_{ptmext}} \\) indicating direction and measurement." 
+ } +} \ No newline at end of file diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/images/1df5f510a94dff77293458473e5407d97a31bdfe.png b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/images/1df5f510a94dff77293458473e5407d97a31bdfe.png new file mode 100644 index 0000000000000000000000000000000000000000..d339e4716273eeb6a74d85e1daa5d54555e46cac Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/images/1df5f510a94dff77293458473e5407d97a31bdfe.png differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0001.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6171ae480ea54e615bd8af9ce5fdc32b2c14a4d5 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0001.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0002.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3dcd2e296b9d78c514a722ae77a2cdea3da821d6 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0002.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0003.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..42cb03677628089dcafadaf00193415b1ae55ed4 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0003.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0004.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..651baca97938b5867121a95571bb96a6b3e2a12e Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0004.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0005.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2bc65109c1a60ec7b8082e5e4b3523867447748a Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0005.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0006.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f3748a58373a7a8b6072484b276bee3b5c1666f0 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_images/slide_0006.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_induction.json b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_induction.json new file mode 100644 index 0000000000000000000000000000000000000000..f109f08cd0b84c79d7d6bfe63ceec83284a0f032 --- /dev/null +++ b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/slide_induction.json @@ -0,0 +1,122 @@ +{ + "opening": { + "slides": [ + 1 + ], + "template_id": 1, + "content_schema": { + "presenters": { + "description": "name of the presenter", + "type": "text", + "data": [ + "潘伟洲(josephpan)" + ] + }, + "affiliation": { + "description": "presenter's 
affiliation or department", + "type": "text", + "data": [ + "SNG-社交平台部-空间运营中心" + ] + }, + "presentation date": { + "description": "date of the presentation", + "type": "text", + "data": [ + "2025/4/29" + ] + }, + "main title": { + "description": "main title of the presentation", + "type": "text", + "data": [ + "移动客户端通道面试陈述" + ] + } + } + }, + "table of contents": { + "slides": [ + 2 + ], + "template_id": 2, + "content_schema": { + "main title": { + "description": "main title of the slide", + "type": "text", + "data": [ + "Table of Contents" + ] + }, + "content bullets": { + "description": "content bullets of the slide", + "type": "text", + "data": [ + "个人经历", + "项目经验 ", + "技术影响力", + "专业领域优势" + ] + } + } + }, + "section outline": { + "slides": [ + 3, + 4, + 5 + ], + "template_id": 3, + "content_schema": { + "main title": { + "description": "main title of the slide", + "type": "text", + "data": [ + "个人经历" + ] + }, + "content paragraph": { + "description": "content paragraph of the slide", + "type": "text", + "data": [ + "这张图展示了一个倾斜圆盘在空间中的几何关系。圆盘的法向量为 \\vec{e},与竖直方向单位向量 \\vec{e}z 之间夹角为 \\theta,表示圆盘的倾斜角度。圆盘中心为点 C,红色箭头 \\vec{C}{ptmext} 表示作用在该点上的外力或外力矩。该图常用于描述刚体在三维空间中的姿态与受力关系。" + ] + }, + "main image": { + "description": "main image of the slide", + "type": "image", + "data": [ + "Diagram: Illustration showing an inclined plane with an angle \\( \\theta \\), a reference point \\( C \\), and various vectors labeled \\( \\mathbf{e} \\), \\( \\mathbf{e_z} \\), and \\( \\mathbf{C_{ptmext}} \\) indicating direction and measurement." + ] + } + } + }, + "ending": { + "slides": [ + 6 + ], + "template_id": 6, + "content_schema": { + "main title": { + "description": "main title of the slide", + "type": "text", + "data": [ + "本次报告到此结束" + ] + }, + "content paragraph": { + "description": "additional content or closing remarks of the slide", + "type": "text", + "data": [ + "欢迎批评指正!" 
+ ] + } + } + }, + "functional_keys": [ + "opening", + "table of contents", + "section outline", + "ending" + ] +} \ No newline at end of file diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/source.pptx b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/source.pptx new file mode 100644 index 0000000000000000000000000000000000000000..78d2eeb927d7f939becf52c190525b5f41520165 --- /dev/null +++ b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/source.pptx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec684256c1f6edbc172b67432eb48c4fc2b68c7111630615a5719eb6e9117b7c +size 129909 diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template.pptx b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template.pptx new file mode 100644 index 0000000000000000000000000000000000000000..c7f9759e4fd18b653649b3c56cec459bfaf0d711 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template.pptx differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0001.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9ba7d58564fc4bfe93b778572cec6720e0bb9955 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0001.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0002.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7b85d91751058b26613bb4f6400f0bbcf5b2dc61 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0002.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0003.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b366275bb58556d2bbf997920b0b423f2faba64c Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0003.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0004.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f44d8046f8448dd38df987799864c42c6ad4c3b0 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0004.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0005.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94c2472fa24bf1f0d7524d43cfda09ee3798e370 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0005.jpg differ diff --git a/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0006.jpg b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6e40e8d132a9d5bf4fa8eb3e7d1773f1de409da3 Binary files /dev/null and b/pptagent/runs/pptx/c1eb4d337b2aa71bec0b0bda89322db2/template_images/slide_0006.jpg differ diff --git a/pptagent/utils.py b/pptagent/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..3a6f00408155d8c4ae435f0a05e1cda36f2f4195 --- /dev/null +++ b/pptagent/utils.py @@ -0,0 +1,664 @@ +import asyncio +import json +import logging +import os +import shutil +import subprocess +import tempfile +import traceback +from itertools import product +from shutil import which +from time import sleep, time +from typing import Any, Optional + +import json_repair +import Levenshtein +from html2image import Html2Image +from mistune import html as markdown +from pdf2image import convert_from_path +from PIL import Image as PILImage +from pptx.dml.color import RGBColor +from pptx.oxml import parse_xml +from pptx.parts.image import Image +from pptx.shapes.group import GroupShape +from pptx.text.text import _Paragraph, _Run +from pptx.util import Length, Pt +from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed + + +def get_logger(name="pptagent", level=None): + """ + Get a logger with the specified name and level. + + Args: + name (str): The name of the logger. + level (int): The logging level (default: logging.INFO). + + Returns: + logging.Logger: A configured logger instance. + """ + if level is None: + level = int(os.environ.get("LOG_LEVEL", logging.INFO)) + + logger = logging.getLogger(name) + logger.setLevel(level) + + # Check if the logger already has handlers to avoid duplicates + if not logger.handlers: + # Create console handler and set level + console_handler = logging.StreamHandler() + console_handler.setLevel(level) + + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" + ) + console_handler.setFormatter(formatter) + + # Add handler to logger + logger.addHandler(console_handler) + + return logger + + +logger = get_logger(__name__) + +if which("soffice") is None: + logging.warning("soffice is not installed, pptx to images conversion will not work") + +# Set of supported image extensions +IMAGE_EXTENSIONS: set[str] = { + "bmp", + "jpg", + "jpeg", + "pgm", + "png", + "ppm", + "tif", + "tiff", + "webp", +} + +# Common colors and measurements +BLACK = RGBColor(0, 0, 0) +YELLOW = RGBColor(255, 255, 0) +BLUE = RGBColor(0, 0, 255) +BORDER_LEN = Pt(2) +BORDER_OFFSET = Pt(2) +LABEL_LEN = Pt(24) +FONT_LEN = Pt(20) + + +def is_image_path(file: str) -> bool: + """ + Check if a file path is an image based on its extension. + + Args: + file (str): The file path to check. + + Returns: + bool: True if the file is an image, False otherwise. + """ + return file.split(".")[-1].lower() in IMAGE_EXTENSIONS + + +def runs_merge(paragraph: _Paragraph) -> Optional[_Run]: + """ + Merge all runs in a paragraph into a single run. + + Args: + paragraph (_Paragraph): The paragraph to merge runs in. + + Returns: + Optional[_Run]: The merged run, or None if there are no runs. + """ + runs = paragraph.runs + + # Handle field codes + if len(runs) == 0: + runs = [ + _Run(r, paragraph) + for r in parse_xml(paragraph._element.xml.replace("fld", "r")).r_lst + ] + if len(runs) == 1: + return runs[0] + if len(runs) == 0: + return None + + # Find the run with the most text + run = max(runs, key=lambda x: len(x.text)) + run.text = paragraph.text + + # Remove other runs + for r in runs: + if r != run: + r._r.getparent().remove(r._r) + return run + + +def older_than(filepath: str, seconds: int = 10, wait: bool = False) -> bool: + """ + Check if a file is older than a specified number of seconds. + + Args: + filepath (str): The path to the file. 
+ seconds (int): The number of seconds to check against. + wait (bool): Whether to wait for the file to exist. + + Returns: + bool: True if the file is older than the specified number of seconds, False otherwise. + """ + if not os.path.exists(filepath): + while wait: + logger.info("waiting for: %s", filepath) + sleep(1) + if os.path.exists(filepath): + sleep(seconds) + return True + return False + file_creation_time = os.path.getctime(filepath) + current_time = time() + return seconds < (current_time - file_creation_time) + + +def edit_distance(text1: str, text2: str) -> float: + """ + Calculate the normalized edit distance between two strings. + + Args: + text1 (str): The first string. + text2 (str): The second string. + + Returns: + float: The normalized edit distance (0.0 to 1.0, where 1.0 means identical). + """ + if not text1 and not text2: + return 1.0 + return 1 - Levenshtein.distance(text1, text2) / max(len(text1), len(text2)) + + +def tenacity_log(retry_state: RetryCallState) -> None: + """ + Log function for tenacity retries. + + Args: + retry_state (RetryCallState): The retry state. + """ + logger.warning("tenacity retry: %s", retry_state) + traceback.print_tb(retry_state.outcome.exception().__traceback__) + + +def get_json_from_response(response: str) -> dict[str, Any]: + """ + Extract JSON from a text response. + + Args: + response (str): The response text. + + Returns: + Dict[str, Any]: The extracted JSON. + + Raises: + Exception: If JSON cannot be extracted from the response. + """ + response = response.strip() + + try: + return json.loads(response) + except Exception: + pass + + # Try to extract JSON from markdown code blocks + l, r = response.rfind("```json"), response.rfind("```") + if l != -1 and r != -1: + json_obj = json_repair.loads(response[l + 7 : r].strip()) + if isinstance(json_obj, (dict, list)): + return json_obj + + # Try to find JSON by looking for matching braces + open_braces = [] + close_braces = [] + + for i, char in enumerate(response): + if char == "{" or char == "[": + open_braces.append(i) + elif char == "}" or char == "]": + close_braces.append(i) + + for i, j in product(open_braces, reversed(close_braces)): + if i > j: + continue + try: + json_obj = json_repair.loads(response[i : j + 1]) + if isinstance(json_obj, (dict, list)): + return json_obj + except Exception: + pass + + raise Exception("JSON not found in the given output", response) + + +# Create a tenacity decorator with custom settings +def tenacity_decorator(_func=None, *, wait: int = 3, stop: int = 5): + def decorator(func): + return retry(wait=wait_fixed(wait), stop=stop_after_attempt(stop))(func) + + if _func is None: + # Called with arguments + return decorator + else: + # Called without arguments + return decorator(_func) + + +TABLE_CSS = """ +table { + border-collapse: collapse; /* Merge borders */ + width: auto; /* Width adapts to content */ + font-family: SimHei, Arial, sans-serif; /* Font supporting Chinese characters */ + background: white; +} +th, td { + border: 1px solid black; /* Add borders */ + padding: 8px; /* Cell padding */ + text-align: center; /* Center text */ +} +th { + background-color: #f2f2f2; /* Header background color */ +} +""" + + +# Convert Markdown to HTML +# def markdown_table_to_image(markdown_text: str, output_path: str): +# """ +# Convert a Markdown table to a cropped image + +# Args: +# markdown_text (str): Markdown text containing a table +# output_path (str): Output image path, defaults to 'table_cropped.png' + +# Returns: +# str: The path of the 
generated image
+#     """
+#     html = markdown(markdown_text)
+#     assert "table" in html, "Failed to find table in markdown"
+
+#     parent_dir, basename = os.path.split(output_path)
+#     hti = Html2Image(
+#         disable_logging=True,
+#         output_path=parent_dir,
+#         custom_flags=["--no-sandbox", "--headless"],
+#     )
+#     hti.browser.use_new_headless = None
+#     hti.screenshot(html_str=html, css_str=TABLE_CSS, save_as=basename)
+
+#     img = PILImage.open(output_path).convert("RGB")
+#     bbox = img.getbbox()
+#     assert (
+#         bbox is not None
+#     ), "Failed to capture the bbox, may be markdown table conversion failed"
+#     bbox = (0, 0, bbox[2] + 10, bbox[3] + 10)
+#     img.crop(bbox).save(output_path)
+#     return output_path
+import pandas as pd
+import matplotlib.pyplot as plt
+from io import StringIO
+def markdown_table_to_image(markdown_text: str, output_path: str) -> str:
+    """
+    Convert a Markdown table to an image using pandas and matplotlib.
+
+    Args:
+        markdown_text (str): Markdown text containing a table.
+        output_path (str): Path to save the output image.
+
+    Returns:
+        str: The path of the generated image.
+    """
+    # Read markdown into DataFrame
+    df = pd.read_csv(
+        StringIO(markdown_text),
+        sep=r'\|',
+        engine='python',
+        skipinitialspace=True
+    )
+    # Remove empty/unnamed columns that result from leading/trailing pipes
+    mask = [col.strip() != '' and not col.startswith('Unnamed') for col in df.columns]
+    df = df.loc[:, mask]
+    # Strip cell whitespace and drop the |---|---| header separator row
+    df = df.astype(str).apply(lambda col: col.str.strip())
+    df = df[~df.apply(lambda row: set(''.join(row)) <= set('-: '), axis=1)]
+
+    # Create figure and axis
+    fig, ax = plt.subplots()
+    ax.axis('off')
+
+    # Create table
+    table = ax.table(
+        cellText=df.values,
+        colLabels=df.columns.str.strip(),
+        cellLoc='center',
+        loc='center'
+    )
+    table.auto_set_font_size(False)
+    table.set_fontsize(12)
+    table.scale(1, 1.5)
+
+    # Save figure
+    fig.savefig(output_path, bbox_inches='tight', dpi=150)
+    plt.close(fig)
+    return output_path
+
+
+@tenacity_decorator
+def ppt_to_images(file: str, output_dir: str):
+    assert pexists(file), f"File {file} does not exist"
+    if pexists(output_dir):
+        logger.warning(f"ppt2images: {output_dir} already exists")
+    os.makedirs(output_dir, exist_ok=True)
+    with tempfile.TemporaryDirectory() as temp_dir:
+        command_list = [
+            "soffice",
+            "--headless",
+            "--convert-to",
+            "pdf",
+            file,
+            "--outdir",
+            temp_dir,
+        ]
+        process = subprocess.Popen(
+            command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        out, err = process.communicate()
+        if process.returncode != 0:
+            raise RuntimeError(f"soffice failed with error: {err.decode()}")
+
+        for f in os.listdir(temp_dir):
+            if not f.endswith(".pdf"):
+                continue
+            temp_pdf = pjoin(temp_dir, f)
+            images = convert_from_path(temp_pdf, dpi=72)
+            for i, img in enumerate(images):
+                img.save(pjoin(output_dir, f"slide_{i+1:04d}.jpg"))
+            return
+
+        raise RuntimeError(
+            f"No PDF file was created in the temporary directory: {file}\n"
+            f"Output: {out.decode()}\n"
+            f"Error: {err.decode()}"
+        )
+
+
+@tenacity_decorator
+async def ppt_to_images_async(file: str, output_dir: str):
+    assert pexists(file), f"File {file} does not exist"
+    if pexists(output_dir):
+        logger.debug(f"ppt2images: {output_dir} already exists")
+    os.makedirs(output_dir, exist_ok=True)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        command_list = [
+            "soffice",
+            "--headless",
+            "--convert-to",
+            "pdf",
+            file,
+            "--outdir",
+            temp_dir,
+        ]
+
+        process = await asyncio.create_subprocess_exec(
+            *command_list,
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+        )
+        stdout, stderr = await process.communicate()
+        if process.returncode != 0:
+            raise RuntimeError(f"soffice failed with error: {stderr.decode()}")
+
for f in os.listdir(temp_dir): + if not f.endswith(".pdf"): + continue + temp_pdf = pjoin(temp_dir, f) + images = convert_from_path(temp_pdf, dpi=72) + for i, img in enumerate(images): + img.save(pjoin(output_dir, f"slide_{i+1:04d}.jpg")) + return + + raise RuntimeError( + f"No PDF file was created in the temporary directory: {file}\n" + f"Output: {stdout.decode()}\n" + f"Error: {stderr.decode()}" + ) + + +def parsing_image(image: Image, image_path: str) -> str: + # Handle WMF images (PDFs) + if image.ext == "wmf": + image_path = image_path.replace(".wmf", ".jpg") + if not pexists(image_path): + wmf_to_images(image.blob, image_path) + # Check for supported image types + elif image.ext not in IMAGE_EXTENSIONS: + raise ValueError(f"Unsupported image type {image.ext}") + + # Save image if it doesn't exist + if not pexists(image_path): + with open(image_path, "wb") as f: + f.write(image.blob) + return image_path + + +@tenacity_decorator +def wmf_to_images(blob: bytes, filepath: str): + if not filepath.endswith(".jpg"): + raise ValueError("filepath must end with .jpg") + dirname = os.path.dirname(filepath) + basename = os.path.basename(filepath).removesuffix(".jpg") + with tempfile.TemporaryDirectory() as temp_dir: + with open(pjoin(temp_dir, f"{basename}.wmf"), "wb") as f: + f.write(blob) + command_list = [ + "soffice", + "--headless", + "--convert-to", + "jpg", + pjoin(temp_dir, f"{basename}.wmf"), + "--outdir", + dirname, + ] + subprocess.run(command_list, check=True, stdout=subprocess.DEVNULL) + + assert pexists(filepath), f"File {filepath} does not exist" + + +def parse_groupshape(groupshape: GroupShape) -> list[dict[str, Length]]: + """ + Parse a group shape to get the bounds of its child shapes. + + Args: + groupshape (GroupShape): The group shape to parse. + + Returns: + List[Dict[str, Length]]: The bounds of the child shapes. + + Raises: + AssertionError: If the input is not a GroupShape. + """ + assert isinstance(groupshape, GroupShape), "Input must be a GroupShape" + + # Get group bounds + group_top_left_x = groupshape.left + group_top_left_y = groupshape.top + group_width = groupshape.width + group_height = groupshape.height + + # Get shape bounds + shape_top_left_x = min([sp.left for sp in groupshape.shapes]) + shape_top_left_y = min([sp.top for sp in groupshape.shapes]) + shape_width = ( + max([sp.left + sp.width for sp in groupshape.shapes]) - shape_top_left_x + ) + shape_height = ( + max([sp.top + sp.height for sp in groupshape.shapes]) - shape_top_left_y + ) + + # Calculate bounds for each shape in the group + group_shape_xy = [] + for sp in groupshape.shapes: + group_shape_left = ( + sp.left - shape_top_left_x + ) * group_width / shape_width + group_top_left_x + group_shape_top = ( + sp.top - shape_top_left_y + ) * group_height / shape_height + group_top_left_y + group_shape_width = sp.width * group_width / shape_width + group_shape_height = sp.height * group_height / shape_height + + group_shape_xy.append( + { + "left": Length(group_shape_left), + "top": Length(group_shape_top), + "width": Length(group_shape_width), + "height": Length(group_shape_height), + } + ) + + return group_shape_xy + + +def is_primitive(obj: Any) -> bool: + """ + Check if an object is a primitive type or a collection of primitive types. + + Args: + obj (Any): The object to check. + + Returns: + bool: True if the object is a primitive type or a collection of primitive types, False otherwise. 
+ """ + if isinstance(obj, (list, tuple, set, frozenset)): + return all(is_primitive(item) for item in obj) + + return isinstance( + obj, (int, float, complex, bool, str, bytes, bytearray, type(None)) + ) + + +DEFAULT_EXCLUDE: set[str] = {"element", "language_id", "ln", "placeholder_format"} + + +def dict_to_object( + dict_obj: dict[str, Any], obj: Any, exclude: Optional[set[str]] = None +) -> None: + """ + Apply dictionary values to an object. + + Args: + dict_obj (Dict[str, Any]): The dictionary with values to apply. + obj (Any): The object to apply values to. + exclude (Optional[Set[str]]): The keys to exclude. + """ + if exclude is None: + exclude = set() + + for key, value in dict_obj.items(): + if key not in exclude and value is not None: + setattr(obj, key, value) + + +def package_join(*paths: str) -> str: + """ + Join paths with the appropriate separator for the platform. + + Args: + *paths: The paths to join. + + Returns: + str: The joined path. + """ + _dir = pdirname(__file__) + return pjoin(_dir, *paths) + + +class Config: + """ + Configuration class for the application. + """ + + def __init__( + self, + rundir: Optional[str] = None, + session_id: Optional[str] = None, + ): + """ + Initialize the configuration. + + Args: + rundir (Optional[str]): The run directory. + session_id (Optional[str]): The session ID. + debug (bool): Whether to enable debug mode. + """ + if rundir is not None: + self.set_rundir(rundir) + elif session_id is not None: + self.set_session(session_id) + else: + raise ValueError("No session ID or run directory provided") + + def set_session(self, session_id: str) -> None: + """ + Set the session ID and update the run directory. + + Args: + session_id (str): The session ID. + """ + self.session_id = session_id + self.set_rundir(f"./runs/{session_id}") + + def set_rundir(self, rundir: str) -> None: + """ + Set the run directory and create necessary subdirectories. + + Args: + rundir (str): The run directory. + """ + self.RUN_DIR = rundir + self.IMAGE_DIR = pjoin(self.RUN_DIR, "images") + + for the_dir in [self.RUN_DIR, self.IMAGE_DIR]: + os.makedirs(the_dir, exist_ok=True) + + def set_debug(self, debug: bool) -> None: + """ + Set the debug mode. + + Args: + debug (bool): Whether to enable debug mode. + """ + self.DEBUG = debug + + def remove_rundir(self) -> None: + """ + Remove the run directory and its subdirectories. + """ + if pexists(self.RUN_DIR): + shutil.rmtree(self.RUN_DIR) + if pexists(self.IMAGE_DIR): + shutil.rmtree(self.IMAGE_DIR) + + def __repr__(self) -> str: + """ + Get a string representation of the configuration. + + Returns: + str: A string representation of the configuration. 
+ """ + attrs = [] + for attr in dir(self): + if not attr.startswith("_") and not callable(getattr(self, attr)): + attrs.append(f"{attr}={getattr(self, attr)}") + return f"Config({', '.join(attrs)})" + + +# Path utility functions +pjoin = os.path.join +pexists = os.path.exists +pbasename = os.path.basename +pdirname = os.path.dirname diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7a68dd7b00bf04de02107921ad790a5fca9c75d4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,42 @@ +beautifulsoup4 +fastapi +einops +func_argparse +html2image +jinja2 +json_repair +jsonlines +lxml +mistune +marker-pdf==1.1.0 +oaib +openai +opencv-python-headless +pandas +pdf2image +peft +pillow +PyPDF2 +python-Levenshtein +python-multipart +python-pptx @ git+https://github.com/Force1ess/python-pptx@219513d7d81a61961fc541578c1857d08b43aa2a +rich +socksio +tenacity +tiktoken +timm +uvicorn +numpy<2 +setproctitle==1.3.3 +attrdict==2.0.1 +librosa==0.10.2.post1 +langdetect==1.0.9 +pydub==0.25.1 +pyloudnorm==0.1.1 +modelscope==1.22.2 +transformers==4.49.0 +x-transformers==1.44.4 +torchdiffeq==0.2.5 +openai-whisper==20240930 +httpx==0.28.1 +gradio==5.23.1 \ No newline at end of file diff --git a/templates/Template1.pptx b/templates/Template1.pptx new file mode 100644 index 0000000000000000000000000000000000000000..a4663336820e0f6db49b6bd7f26e3bce03ee668c --- /dev/null +++ b/templates/Template1.pptx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3795ad16408515f6f3041fe35547e1cddc3a070e348b5660160f09d63563d343 +size 2877057 diff --git a/templates/Template2.pptx b/templates/Template2.pptx new file mode 100644 index 0000000000000000000000000000000000000000..62938aa70593fc1dc347ceb3575d44cab8338bf9 --- /dev/null +++ b/templates/Template2.pptx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05b93d432ee3a347e571786e74cac0373d1d45803ad87dcdf079572c25d83f0a +size 125824 diff --git a/templates/Template3.pptx b/templates/Template3.pptx new file mode 100644 index 0000000000000000000000000000000000000000..bd348177e489acd8862b243e61de8c86fda8c938 --- /dev/null +++ b/templates/Template3.pptx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96084fe919cdbdc91a66d340b0f9594b8f31b318dab24188d846dd2e1514ce84 +size 2472949 diff --git a/templates/previews/Template1.jpg b/templates/previews/Template1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5d2572025dcccdedf2fddedf2f9db9e3fe660b2 --- /dev/null +++ b/templates/previews/Template1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c081af9017bfdc9f4fed547de9e5040ef60b9279ee0464f7ab34b9a7facb726 +size 376407 diff --git a/templates/previews/Template2.jpg b/templates/previews/Template2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db9abfd707567e3e9502beb27e087a03fbc9cbee Binary files /dev/null and b/templates/previews/Template2.jpg differ diff --git a/templates/previews/Template3.jpg b/templates/previews/Template3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dbb603031b4cd4221435831aae75071392796338 Binary files /dev/null and b/templates/previews/Template3.jpg differ