#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from functools import lru_cache, partial
from pathlib import Path
import shutil
import tempfile
import time
import zipfile

import gradio as gr
import torch

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
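
# Note: project_settings and toolbox.* are local modules that ship with this
# repository, not pip-installable packages.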

# Cache loaded models so repeated inferences do not re-extract the zip each time
# (lru_cache is imported above; Path arguments are hashable, so this is safe).
@lru_cache(maxsize=10)
def load_model(model_file: Path):
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    tgt_path = out_root / model_file.stem
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    with open(jit_model_file.as_posix(), "rb") as f:
        model = torch.jit.load(f)
    model.eval()

    # The extracted files are no longer needed once the model is in memory.
    shutil.rmtree(tgt_path)

    d = {
        "model": model,
        "vocabulary": vocabulary,
    }
    return d
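
# Based on the paths above, each model zip is expected to unpack to:
#   <model_name>/trace_model.zip   - TorchScript model (torch.jit.save output)
#   <model_name>/vocabulary/       - serialized Vocabulary directory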

def when_click_cls_button(audio_t,
                          model_name: str,
                          ground_true: str) -> str:
    # ground_true is wired in from the UI for reference only;
    # it does not affect inference.
    sample_rate, signal = audio_t

    model_file = project_path / f"trained_models/{model_name}.zip"
    d = load_model(model_file)
    model = d["model"]
    vocabulary = d["vocabulary"]

    # gr.Audio yields int16 PCM; scale to [-1, 1] floats and add a batch dimension.
    inputs = signal / (1 << 15)
    inputs = torch.tensor(inputs, dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)

    time_begin = time.time()
    with torch.no_grad():
        logits = model.forward(inputs)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)
    time_cost = time.time() - time_begin

    label_idx = label_idx.cpu().numpy()[0]
    prob = probs.cpu().numpy()[0][label_idx]
    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")

    result = {
        "label": label_str,
        "prob": round(float(prob), 4),
        "time_cost": round(time_cost, 4),
    }
    result = json.dumps(result, ensure_ascii=False, indent=4)
    return result
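
# The returned string is the serialized result, e.g. (illustrative values):
# {
#     "label": "<a label from the model's vocabulary>",
#     "prob": 0.9123,
#     "time_cost": 0.0456
# }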

def when_model_name_change(model_name: str, cls_trained_model_dir: Path):
    # When the selected model changes, reload its vocabulary and
    # repopulate the label dropdown with that model's label set.
    m = load_model(
        model_file=(cls_trained_model_dir / f"{model_name}.zip")
    )
    token_to_index: dict = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
    label_choices = list(token_to_index.keys())

    split_label = gr.Dropdown(choices=label_choices, value=label_choices[0], label="label")
    return split_label
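
# In Gradio 4.x, returning a freshly constructed component from an event
# handler updates the existing one in place; 3.x used gr.Dropdown.update().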

def get_cls_tab(examples_dir: str, trained_model_dir: str):
    """Build the classification tab; must be called inside a gr.Blocks context."""
    cls_examples_dir = Path(examples_dir)
    cls_trained_model_dir = Path(trained_model_dir)

    # models: every *.zip in the trained-model directory except examples.zip
    cls_model_choices = list()
    for filename in cls_trained_model_dir.glob("*.zip"):
        model_name = filename.stem
        if model_name == "examples":
            continue
        cls_model_choices.append(model_name)
    model_choices = list(sorted(cls_model_choices))

    # model_labels_choices: the label set of the first (default) model
    m = load_model(
        model_file=(cls_trained_model_dir / f"{model_choices[0]}.zip")
    )
    token_to_index = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
    model_labels_choices = list(token_to_index.keys())

    # examples zip: unpack the bundled example audio into the examples directory
    cls_example_zip_file = cls_trained_model_dir / "examples.zip"
    with zipfile.ZipFile(cls_example_zip_file.as_posix(), "r") as f_zip:
        out_root = cls_examples_dir
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # examples: one row per wav file; the parent directory name is the label
    cls_examples = list()
    for filename in cls_examples_dir.glob("**/*/*.wav"):
        label = filename.parts[-2]
        cls_examples.append([
            filename.as_posix(),
            model_choices[0],
            label,
        ])
    with gr.TabItem("cls"):
        with gr.Row():
            with gr.Column(scale=3):
                cls_audio = gr.Audio(label="audio")
                with gr.Row():
                    cls_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    cls_label = gr.Dropdown(choices=model_labels_choices, value=model_labels_choices[0], label="label")
                with gr.Row():
                    cls_ground_true = gr.Textbox(label="ground_true")
                    cls_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                cls_outputs = gr.Textbox(label="outputs")

        gr.Examples(
            cls_examples,
            inputs=[cls_audio, cls_model_name, cls_ground_true],
            outputs=[cls_outputs],
            fn=when_click_cls_button,
            examples_per_page=5,
        )

        cls_model_name.change(
            partial(when_model_name_change, cls_trained_model_dir=cls_trained_model_dir),
            inputs=[cls_model_name],
            outputs=[cls_label],
        )
        cls_button.click(
            when_click_cls_button,
            inputs=[cls_audio, cls_model_name, cls_ground_true],
            outputs=[cls_outputs],
        )

    return locals()
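
# A minimal launcher sketch so the module can run standalone. The default
# directories below ("data/examples" and "trained_models" under project_path)
# are assumptions about this project's layout, not confirmed paths; the
# trained_models default matches the hard-coded path in when_click_cls_button.
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    with gr.Blocks() as blocks:
        with gr.Tabs():
            get_cls_tab(
                examples_dir=args.examples_dir,
                trained_model_dir=args.trained_model_dir,
            )

    blocks.queue().launch()
    return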

if __name__ == "__main__":
    main()