# GestureLSM / models/config.py
# Source: Hugging Face file view, uploaded by Tharun156 ("Upload 149 files"),
# commit f7400bf (verified), 2.23 kB.
import os
import importlib
from typing import Type, TypeVar
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
def get_module_config(cfg_model: DictConfig, paths: list[str], cfg_root: str) -> DictConfig:
    """Merge per-module YAML files into ``cfg_model`` in place and return it.

    Each entry ``p`` of ``paths`` is resolved to ``<cfg_root>/modules/<p>.yaml``.
    The files are merged into ``cfg_model`` in the order given, via
    ``DictConfig.merge_with``, so later files can override keys from earlier ones.

    Args:
        cfg_model: Config node to merge into; it is mutated in place.
        paths: Module config names without the ``.yaml`` suffix, relative to
            ``<cfg_root>/modules``. A single string is treated as one name.
        cfg_root: Directory that contains the ``modules`` sub-directory.

    Returns:
        ``cfg_model`` itself, after all merges.

    Raises:
        FileNotFoundError: If any module file is missing. Every file is checked
            before anything is merged, so ``cfg_model`` is left untouched.
    """
    if isinstance(paths, str):
        # Iterating a bare string would yield one "path" per character.
        paths = [paths]
    files = [os.path.join(cfg_root, 'modules', p + '.yaml') for p in paths]

    # Explicit check instead of `assert` (which disappears under `python -O`),
    # done up front so a missing file cannot leave cfg_model half-merged.
    missing = [file for file in files if not os.path.exists(file)]
    if missing:
        raise FileNotFoundError(f'Module config file(s) not found: {", ".join(missing)}')

    for file in files:
        with open(file, 'r', encoding='utf-8') as f:
            cfg_model.merge_with(OmegaConf.load(f))
    return cfg_model
def get_obj_from_str(string: str, reload: bool = False) -> Type:
    """Resolve a dotted path such as ``"package.module.Name"`` to that attribute.

    Everything before the last dot is imported as a module; the final component
    is looked up on it with ``getattr``.

    Args:
        string: Fully qualified dotted path. It must contain at least one dot;
            otherwise unpacking the split raises ``ValueError``.
        reload: When True, the module is passed through ``importlib.reload``
            before the attribute is fetched.

    Returns:
        The attribute named by the last path component.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config: DictConfig) -> object:
    """Build the object described by a ``target`` / ``params`` config entry.

    ``config["target"]`` is a dotted import path resolved by ``get_obj_from_str``.
    The resolved callable is invoked with ``config["params"]`` expanded as
    keyword arguments, or with no arguments when ``params`` is absent or null.

    Args:
        config: Mapping with a ``"target"`` key and an optional ``"params"``
            mapping of keyword arguments.

    Returns:
        Whatever calling the resolved target returns (an instance when the
        target is a class).
    """
    target = get_obj_from_str(config["target"])
    params = config.get("params", dict())
    # A YAML `params:` line with no value loads as None; treat it like an
    # absent key instead of crashing on `**None`.
    if params is None:
        params = {}
    return target(**params)
def parse_args(argv: "list[str] | None" = None) -> DictConfig:
    """Parse command-line options and return the merged run configuration.

    Loads the main YAML file given by ``--cfg``, merges the module configs listed
    in its ``model.target`` entry (looked up under ``<dir of --cfg>/modules``),
    then copies the remaining command-line options onto the result.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a list
            lets tests and notebooks call this directly.

    Returns:
        The merged config. The attributes ``example``, ``example_hint``,
        ``no_plot``, ``replication``, ``vis`` and ``optimize`` are set from the
        parsed options.
    """
    parser = ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True, help="The main config file")
    parser.add_argument('--example', type=str, required=False, help="The input texts and lengths with txt format")
    parser.add_argument('--example_hint', type=str, required=False, help="The input hint ids and lengths with txt format")
    parser.add_argument('--no-plot', action="store_true", required=False, help="Disable plotting of the skeleton-based motion")
    parser.add_argument('--replication', type=int, default=1, help="The number of replications of sampling")
    parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="The visualization backends: tensorboard or swanlab")
    parser.add_argument('--optimize', action='store_true', help="Enable optimization for motion control")
    args = parser.parse_args(argv)

    cfg = OmegaConf.load(args.cfg)
    cfg_root = os.path.dirname(args.cfg)
    # get_module_config merges the module YAMLs into cfg.model in place and
    # returns that same node; OmegaConf.merge then also merges it over the
    # config root. NOTE(review): the root-level copy of the model keys looks
    # deliberate but is easy to miss; confirm it is intended.
    cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
    cfg = OmegaConf.merge(cfg, cfg_model)

    # Copy the CLI-only options onto the config object.
    cfg.example = args.example
    cfg.example_hint = args.example_hint
    cfg.no_plot = args.no_plot
    cfg.replication = args.replication
    cfg.vis = args.vis
    cfg.optimize = args.optimize
    return cfg