Spaces:
Running
Running
File size: 4,263 Bytes
6af2871 af2b6f4 6af2871 a92b815 af2b6f4 459dab4 a92b815 6af2871 055797a 6af2871 055797a 6af2871 a92b815 6af2871 a92b815 459dab4 af2b6f4 a92b815 6af2871 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
docker build -t cc_audio_8:v20250828_1343 .
docker stop cc_audio_8_7864 && docker rm cc_audio_8_7864
docker run -itd \
--name cc_audio_8_7864 \
--restart=always \
--network host \
-e server_port=7865 \
cc_audio_8:v20250828_1343 /bin/bash
docker run -itd \
--name cc_audio_8_7864 \
--network host \
--gpus all \
--privileged \
--ipc=host \
python:3.12 /bin/bash
nohup python3 main.py --server_port 7864 --hf_token hf_coRVvzwA****jLmZHwJobEX &
"""
import argparse
from functools import lru_cache
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile
from typing import Tuple
import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
import torch
from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
from tabs.cls_tab import get_cls_tab
from tabs.split_tab import get_split_tab
from tabs.event_tab import get_event_tab
from tabs.shell_tab import get_shell_tab
def get_args():
    """Declare the demo server's command-line options and parse ``sys.argv``.

    Options (flag -> default):
        --examples_dir       -> ``<project_path>/data/examples``
        --models_repo_id     -> ``"qgyd2021/cc_audio_8"``
        --trained_model_dir  -> ``<project_path>/trained_models``
        --hf_token           -> ``environment.get("hf_token")``
        --server_port        -> ``environment.get("server_port", 7860)`` (int)

    Returns:
        argparse.Namespace with one attribute per option above.
    """
    # (flag, default, converter) -- registered in this exact order.
    option_table = (
        ("--examples_dir", (project_path / "data/examples").as_posix(), str),
        ("--models_repo_id", "qgyd2021/cc_audio_8", str),
        ("--trained_model_dir", (project_path / "trained_models").as_posix(), str),
        ("--hf_token", environment.get("hf_token"), str),
        # argparse runs string defaults through `type`, so a port read from
        # the environment as text still ends up as an int.
        ("--server_port", environment.get("server_port", 7860), int),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, converter in option_table:
        parser.add_argument(flag, default=default_value, type=converter)
    return parser.parse_args()
@lru_cache(maxsize=100)
def load_model(model_file: Path):
    """Load a packaged TorchScript model and its vocabulary from a zip archive.

    The archive must contain a top-level directory named after the archive's
    stem, holding ``trace_model.zip`` (opened with ``torch.jit.load``) and a
    ``vocabulary`` directory (read with ``Vocabulary.from_files``).

    Results are memoised per ``model_file`` argument by ``lru_cache``, so
    repeated calls with the same path return the same dict object; callers
    should not mutate it.

    Args:
        model_file: path to the zip archive (a ``str`` path is also accepted).

    Returns:
        dict with keys ``"model"`` (the TorchScript module, put in eval mode)
        and ``"vocabulary"`` (the loaded ``Vocabulary``).
    """
    model_file = Path(model_file)

    # Extract into a private temp dir that is removed automatically, even if
    # loading raises. A fixed shared directory that is wiped on every call
    # would let concurrent loads delete each other's files, and it would leak
    # the extracted files whenever loading fails part-way.
    with tempfile.TemporaryDirectory(prefix="cc_audio_8_") as tmp_dir:
        with zipfile.ZipFile(model_file, "r") as f_zip:
            f_zip.extractall(path=tmp_dir)

        tgt_path = Path(tmp_dir) / model_file.stem
        jit_model_file = tgt_path / "trace_model.zip"
        vocab_path = tgt_path / "vocabulary"

        # Both loads must finish before the temp dir is removed on exiting
        # the `with` block.
        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
        with open(jit_model_file.as_posix(), "rb") as f:
            model = torch.jit.load(f)
        model.eval()

    return {
        "model": model,
        "vocabulary": vocabulary,
    }
def _download_models(repo_id: str, local_dir: Path, token) -> None:
    """Download the model repo into ``local_dir`` unless that directory already exists."""
    if local_dir.exists():
        return
    local_dir.mkdir(parents=True, exist_ok=True)
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir.as_posix(),
            token=token,
        )
    except BaseException:
        # "Already downloaded" is judged only by the directory existing, so a
        # failed or interrupted (Ctrl-C) download must not leave it behind;
        # otherwise every later start-up would skip the download for good.
        shutil.rmtree(local_dir.as_posix(), ignore_errors=True)
        raise


def _extract_examples(examples_dir: Path, trained_model_dir: Path) -> None:
    """Unpack ``<trained_model_dir>/examples.zip`` into ``examples_dir`` if it is missing."""
    if examples_dir.exists():
        return
    example_zip_file = trained_model_dir / "examples.zip"
    with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
        examples_dir.mkdir(parents=True, exist_ok=True)
        try:
            f_zip.extractall(path=examples_dir)
        except BaseException:
            # Same reasoning as for the models: never leave a partial directory
            # that would be mistaken for a complete one on the next run.
            shutil.rmtree(examples_dir.as_posix(), ignore_errors=True)
            raise


def _build_ui(examples_dir: str, trained_model_dir: str):
    """Assemble the gr.Blocks app from the tabs in tabs.cls_tab/event_tab/split_tab/shell_tab."""
    with gr.Blocks() as blocks:
        with gr.Tabs():
            get_cls_tab(
                examples_dir=examples_dir,
                trained_model_dir=trained_model_dir,
            )
            get_event_tab(
                examples_dir=examples_dir,
                trained_model_dir=trained_model_dir,
            )
            get_split_tab(
                examples_dir=examples_dir,
                trained_model_dir=trained_model_dir,
            )
            get_shell_tab()
    return blocks


def main():
    """Entry point: fetch models and examples if needed, then serve the Gradio UI."""
    args = get_args()
    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    _download_models(args.models_repo_id, trained_model_dir, args.hf_token)
    _extract_examples(examples_dir, trained_model_dir)

    blocks = _build_ui(args.examples_dir, args.trained_model_dir)

    # Served at http://<host>:<args.server_port>/ ; on Windows the server binds
    # to loopback only, elsewhere to all interfaces. Public sharing is always
    # off (the old per-platform ternary returned False on both branches).
    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port,
    )


if __name__ == "__main__":
    main()
|