#!/usr/bin/python3
# -*- coding: utf-8 -*-
import json
from functools import lru_cache, partial
from pathlib import Path
import shutil
import tempfile
import zipfile

import gradio as gr
import torch

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
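
# Each trained model ships as a zip archive that unpacks to a directory holding
# a TorchScript trace ("trace_model.zip") and a serialized Vocabulary.
# load_model extracts the archive to a scratch directory, loads both pieces,
# and caches the result per model file.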
@lru_cache(maxsize=100)
def load_model(model_file: Path):
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    tgt_path = out_root / model_file.stem
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    with open(jit_model_file.as_posix(), "rb") as f:
        model = torch.jit.load(f)
    model.eval()

    shutil.rmtree(tgt_path)

    d = {
        "model": model,
        "vocabulary": vocabulary,
    }
    return d
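
# Sliding-window classification: the callback scores consecutive windows of
# win_size seconds, advancing win_step seconds at a time, over at most
# max_duration seconds of audio, and returns the per-window labels and
# probabilities as a JSON string.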
def when_click_event_button(audio_t,
                            model_name: str, target_label: str,
                            win_size: float, win_step: float,
                            max_duration: float,
                            ) -> str:
    # target_label is wired in from the UI but is currently unused by the scoring loop.
    sample_rate, signal = audio_t

    model_file = project_path / f"trained_models/{model_name}.zip"
    d = load_model(model_file)
    model = d["model"]
    vocabulary = d["vocabulary"]

    # gr.Audio yields (sample_rate, int16 samples); rescale to floats in [-1, 1].
    inputs = signal / (1 << 15)
    inputs = torch.tensor(inputs, dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)
    # inputs shape: (1, num_samples)

    # convert the window parameters from seconds to samples.
    win_size = int(win_size * sample_rate)
    win_step = int(win_step * sample_rate)
    max_duration = int(max_duration * sample_rate)

    outputs = list()
    with torch.no_grad():
        for begin in range(0, (max_duration - win_size + 1), win_step):
            end = begin + win_size
            sub_inputs = inputs[:, begin:end]
            if sub_inputs.shape[-1] < win_size:
                # the remaining signal is shorter than one window; stop.
                break
            logits = model.forward(sub_inputs)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            label_idx = torch.argmax(probs, dim=-1)

            label_idx = label_idx.cpu()
            probs = probs.cpu()

            label_idx = label_idx.numpy()[0]
            prob = probs.numpy()[0][label_idx]
            prob: float = round(float(prob), 4)
            label_str: str = vocabulary.get_token_from_index(label_idx, namespace="labels")

            outputs.append({
                "label": label_str,
                "prob": prob,
            })

    outputs = json.dumps(outputs, ensure_ascii=False, indent=4)
    return outputs
def when_model_name_change(model_name: str, event_trained_model_dir: Path):
    m = load_model(
        model_file=(event_trained_model_dir / f"{model_name}.zip")
    )
    token_to_index: dict = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
    label_choices = list(token_to_index.keys())

    split_label = gr.Dropdown(choices=label_choices, value=label_choices[0], label="label")
    return split_label
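
# get_event_tab builds the "event" tab: it discovers model archives in
# trained_model_dir, unpacks the bundled examples.zip into examples_dir,
# lays out the Gradio widgets, and wires up the callbacks above.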
def get_event_tab(examples_dir: str, trained_model_dir: str):
    event_examples_dir = Path(examples_dir)
    event_trained_model_dir = Path(trained_model_dir)

    # models
    event_model_choices = list()
    for filename in event_trained_model_dir.glob("*.zip"):
        model_name = filename.stem
        if model_name == "examples":
            continue
        event_model_choices.append(model_name)
    model_choices = list(sorted(event_model_choices))

    # model_labels_choices
    m = load_model(
        model_file=(event_trained_model_dir / f"{model_choices[0]}.zip")
    )
    token_to_index = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
    model_labels_choices = list(token_to_index.keys())

    # examples zip
    event_example_zip_file = event_trained_model_dir / "examples.zip"
    with zipfile.ZipFile(event_example_zip_file.as_posix(), "r") as f_zip:
        out_root = event_examples_dir
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # examples
    event_examples = list()
    for filename in event_examples_dir.glob("**/*/*.wav"):
        label = filename.parts[-2]
        event_examples.append([
            filename.as_posix(),
            model_choices[0],
            label,
        ])

    with gr.TabItem("event"):
        with gr.Row():
            with gr.Column(scale=3):
                event_audio = gr.Audio(label="audio")
                with gr.Row():
                    event_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    event_label = gr.Dropdown(choices=model_labels_choices, value=model_labels_choices[0], label="label")
                with gr.Row():
                    event_win_size = gr.Number(value=2.0, minimum=0, maximum=5, step=0.05, label="win_size")
                    event_win_step = gr.Number(value=2.0, minimum=0, maximum=5, step=0.05, label="win_step")
                    event_max_duration = gr.Number(value=8, minimum=0, maximum=15, step=1, label="max_duration")
                event_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                event_outputs = gr.Textbox(label="outputs")

        event_model_name.change(
            partial(when_model_name_change, event_trained_model_dir=event_trained_model_dir),
            inputs=[event_model_name],
            outputs=[event_label],
        )
        event_button.click(
            when_click_event_button,
            inputs=[event_audio, event_model_name, event_label, event_win_size, event_win_step, event_max_duration],
            outputs=[event_outputs],
        )
    return locals()
if __name__ == "__main__":
    pass
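
# A minimal usage sketch (an assumption, not part of the original file): mount
# this tab in a standalone Blocks app. The "examples" and "trained_models"
# directory names are hypothetical and should match the project layout.
#
#     with gr.Blocks() as demo:
#         get_event_tab(examples_dir="examples", trained_model_dir="trained_models")
#     demo.launch()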