|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import logging |
|
|
import os |
|
|
import shutil |
|
|
from collections import OrderedDict, abc |
|
|
from datetime import datetime |
|
|
|
|
|
# Shared module logger; handlers and level are configured by the application entry point.
LOGGER = logging.getLogger("main")
|
|
|
|
|
|
|
|
def construct_path(output_dir: str, exp_name: str) -> dict:
    """Build the directory/file layout for a new experiment run.

    Scans ``<output_dir>/<exp_name>`` for existing ``exp_<i>`` folders and
    picks the first unused index, so earlier runs are never overwritten.
    No directories are created here — see ``pre_mkdir``.

    :param output_dir: root directory holding all experiment outputs
    :param exp_name: experiment (project) name
    :return: dict mapping logical names to filesystem paths
    """
    proj_root = os.path.join(output_dir, exp_name)

    # Find the first free exp_<idx> slot under the project root.
    exp_idx = 0
    exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
    while os.path.exists(exp_output_dir):
        exp_idx += 1
        exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")

    tb_path = os.path.join(exp_output_dir, "tb")
    save_path = os.path.join(exp_output_dir, "pre")
    pth_path = os.path.join(exp_output_dir, "pth")

    final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth")
    final_state_path = os.path.join(pth_path, "state_final.pth")

    # Log file is date-stamped (YYYY-MM-DD) so restarts on the same day append
    # to the same file.
    log_path = os.path.join(exp_output_dir, f"log_{datetime.now():%Y-%m-%d}.txt")
    cfg_copy_path = os.path.join(exp_output_dir, "config.py")
    trainer_copy_path = os.path.join(exp_output_dir, "trainer.txt")
    excel_path = os.path.join(exp_output_dir, "results.xlsx")

    path_config = {
        "output_dir": output_dir,
        "pth_log": exp_output_dir,
        "tb": tb_path,
        "save": save_path,
        "pth": pth_path,
        "final_full_net": final_full_model_path,
        "final_state_net": final_state_path,
        "log": log_path,
        "cfg_copy": cfg_copy_path,
        "excel": excel_path,
        "trainer_copy": trainer_copy_path,
    }

    return path_config
|
|
|
|
|
|
|
|
def construct_exp_name(model_name: str, cfg: dict):
    """Derive a compact experiment name from the model name and config.

    Selected config entries (batch size, lr, epochs/iters, input shape,
    optimizer/scheduler settings, AMP) are abbreviated and joined with
    underscores; entries whose formatted value is "false" are omitted.

    :param model_name: base name of the model
    :param cfg: nested experiment config (attribute-style access on ``.train``
        is assumed — presumably an EasyDict-like type; confirm with callers)
    :return: underscore-joined experiment name
    """
    abbreviations = OrderedDict(
        {
            "train/batch_size": "bs",
            "train/lr": "lr",
            "train/num_epochs": "e",
            "train/num_iters": "i",
            "train/data/shape/h": "h",
            "train/data/shape/w": "w",
            "train/optimizer/mode": "opm",
            "train/optimizer/group_mode": "opgm",
            "train/scheduler/mode": "sc",
            "train/scheduler/warmup/num_iters": "wu",
            "train/use_amp": "amp",
        }
    )
    config = copy.deepcopy(cfg)

    def _stringify(value):
        # Booleans first: bool is an int subclass and must not hit the
        # numeric branch below.
        if isinstance(value, bool):
            return "" if value else "false"
        if isinstance(value, (int, float)):
            return "false" if value == 0 else value
        if isinstance(value, (list, tuple)):
            return "" if value else "false"
        if isinstance(value, str):
            # NOTE: only strings containing "_" are lower-cased, matching the
            # original behavior.
            return value.replace("_", "").lower() if "_" in value else value
        if value is None:
            return "none"
        return value

    # Keep exactly one of the epoch/iteration counters in the name.
    epoch_based = config.train.get("epoch_based", None)
    if epoch_based is not None and not epoch_based:
        abbreviations.pop("train/num_epochs")
    else:
        abbreviations.pop("train/num_iters")

    parts = [model_name]
    for path, alias in abbreviations.items():
        raw = get_value_recurse(keys=path.split("/"), info=config)
        shown = _stringify(raw)
        # A "false" marker means the setting is disabled — drop it entirely.
        if shown == "false":
            continue
        parts.append(f"{alias.upper()}{shown}")

    info = config.get("info", None)
    if info:
        parts.append(f"INFO{info.lower()}")

    return "_".join(parts)
|
|
|
|
|
|
|
|
def pre_mkdir(path_config):
    """Create the experiment directories and open the run's log file."""
    # The root experiment folder must exist before the log file can be written.
    check_mkdir(path_config["pth_log"])
    make_log(path_config["log"], f"=== log {datetime.now()} ===")

    # Sub-folders for predictions and checkpoints.
    for key in ("save", "pth"):
        check_mkdir(path_config[key])
|
|
|
|
|
|
|
|
def check_mkdir(dir_name, delete_if_exists=False):
    """Ensure ``dir_name`` exists; optionally wipe and recreate it.

    :param dir_name: directory path to create
    :param delete_if_exists: if True and the directory already exists, remove
        it (and all contents) before recreating it
    """
    if delete_if_exists and os.path.exists(dir_name):
        print(f"{dir_name} will be re-created!!!")
        shutil.rmtree(dir_name)
    # exist_ok removes the TOCTOU race between a separate existence check and
    # the creation call.
    os.makedirs(dir_name, exist_ok=True)
|
|
|
|
|
|
|
|
def make_log(path, context):
    """Append ``context`` as one line to the log file at ``path``.

    :param path: log file path (created on first write)
    :param context: text to append; a trailing newline is added
    """
    # Explicit UTF-8 keeps behavior consistent across platforms (the default
    # encoding is locale-dependent, e.g. cp1252 on Windows).
    with open(path, "a", encoding="utf-8") as log:
        log.write(f"{context}\n")
|
|
|
|
|
|
|
|
def iterate_nested_sequence(nested_sequence):
    """Lazily flatten arbitrarily nested list/tuple/range containers of numbers.

    Supports multi-level nesting of ``list``/``tuple``/``range`` around
    ``int``/``float`` leaves. Avoid very deep nesting: recursion depth is
    bounded by Python's default recursion limit.

    Example::

        for x in iterate_nested_sequence([[1, (2, 3)], range(3, 10), 0]):
            print(x)

        1
        2
        3
        3
        4
        5
        6
        7
        8
        9
        0

    :param nested_sequence: possibly nested sequence of numbers
    :return: generator yielding the scalar leaves in order
    :raises NotImplementedError: on any element that is neither numeric nor a
        supported container
    """
    for element in nested_sequence:
        if isinstance(element, (list, tuple, range)):
            yield from iterate_nested_sequence(element)
        elif isinstance(element, (int, float)):
            yield element
        else:
            raise NotImplementedError
|
|
|
|
|
|
|
|
# Sentinel distinguishing "key absent" from any real stored value (the previous
# string sentinel "NoKey" collided with a legitimate value of "NoKey").
_MISSING = object()


def get_value_recurse(keys: list, info: dict):
    """Walk nested mappings along ``keys`` and return the value at the leaf.

    :param keys: path components, e.g. ``["train", "lr"]``
    :param info: nested mapping to descend into
    :return: the value stored under the full key path
    :raises KeyError: if any path component is missing
    """
    curr_key, sub_keys = keys[0], keys[1:]

    sub_info = info.get(curr_key, _MISSING)
    if sub_info is _MISSING:
        raise KeyError(f"{curr_key} must be contained in {info}")

    if sub_keys:
        return get_value_recurse(keys=sub_keys, info=sub_info)
    return sub_info
|
|
|
|
|
|
|
|
def mapping_to_str(mapping: abc.Mapping, *, prefix: str = " ", lvl: int = 0, max_lvl: int = 1) -> str:
    """Render a (possibly nested) mapping as an indented multi-line string.

    Nested mappings deeper than ``max_lvl`` are rendered via their plain
    ``str()`` form instead of being expanded structurally.

    :param mapping: the mapping to render
    :param prefix: indentation unit, repeated once per nesting level
    :param lvl: current nesting level (used by the recursive calls)
    :param max_lvl: deepest level rendered structurally
    :return: the formatted string
    """
    # Past the expansion limit, fall back to the flat repr.
    if lvl == max_lvl:
        return str(mapping)

    child_lvl = lvl + 1
    indent = prefix * child_lvl

    lines = ["{"]
    for key, value in mapping.items():
        if isinstance(value, abc.Mapping):
            rendered = mapping_to_str(value, prefix=prefix, lvl=child_lvl, max_lvl=max_lvl)
        else:
            rendered = str(value)
        lines.append(f"{indent}{key}: {rendered}")
    lines.append(prefix * lvl + "}")
    return "\n".join(lines)
|
|
|