| import argparse |
| import copy |
|
|
| import warnings |
| import tensorflow as tf |
| tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) |
| import warnings |
| warnings.filterwarnings('ignore', category=FutureWarning) |
| warnings.filterwarnings('ignore', category=DeprecationWarning) |
| import sys, getopt, os |
|
|
| import numpy as np |
| import dnnlib |
| from dnnlib import EasyDict |
| import dnnlib.tflib as tflib |
| from dnnlib.tflib import tfutil |
| from dnnlib.tflib.autosummary import autosummary |
|
|
| from training import misc |
| import pickle |
| import argparse |
|
|
def create_model(config_id: str = 'config-f', gamma = None, height: int = 512, width: int = 512, cond = None, label_size: int = 0) -> str:
    """Submit a diagnostic dnnlib run that builds an initial network pickle.

    Builds EasyDict configs for the generator (``G``) and discriminator (``D``)
    from ``config_id`` and the requested resolution. It then calls
    ``dnnlib.submit_diagnostic`` with
    ``run_func_name='training.diagnostic.create_initial_pkl'``.

    Args:
        config_id: Configuration name. Effects:
            - Any value other than ``'config-f'`` sets ``fmap_base`` to
              ``8 << 10`` (8192) on G and D.
            - ``'config-e*'`` sets ``D_loss.gamma = 100``. It also picks the
              G/D architecture from the substrings ``Gorig``/``Gskip``/
              ``Gresnet`` and ``Dorig``/``Dskip``/``Dresnet``.
            - ``'config-a'``..``'config-d'`` switch to the revised StyleGAN
              synthesis and ``D_stylegan``.
            - ``'config-a'``/``'config-b'`` disable lazy regularization.
            - ``'config-a'`` replaces G/D with ``networks_stylegan``
              ``G_style``/``D_basic``.
        gamma: If not None, overrides ``D_loss.gamma``.
            NOTE(review): ``D_loss`` is never passed to the submitted run,
            so this appears to have no effect here. Confirm intent.
        height: Stored as ``resolution_h`` on G, D and the run kwargs.
        width: Stored as ``resolution_w`` on G, D and the run kwargs.
        cond: If truthy, appends ``'-cond'`` to the run description and sets
            ``dataset_args.max_label_size = 'full'``.
            NOTE(review): ``dataset_args`` is never passed on either.
        label_size: Forwarded to the run as ``label_size``. It is also
            embedded in the returned filename.

    Returns:
        ``f'network-initial-config-f-{height}x{width}-{label_size}.pkl'``.
        NOTE(review): this hard-codes ``config-f`` regardless of
        ``config_id``. Verify it matches what ``create_initial_pkl`` writes.
    """
    # Base configs. `train` keys become top-level kwargs of the submitted run.
    train = EasyDict(run_func_name='training.diagnostic.create_initial_pkl')
    G = EasyDict(func_name='training.networks_stylegan2.G_main')
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')
    # NOTE(review): D_loss and sched are configured below but never passed to
    # dnnlib.submit_diagnostic. They look like leftovers from a training script.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1')
    sched = EasyDict()
    sc = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}  # fixed NumPy seed for the run

    sched.minibatch_size_base = 192
    sched.minibatch_gpu_base = 3
    D_loss.gamma = 10
    desc = 'stylegan2'

    # NOTE(review): dataset_args is never used after this point.
    dataset_args = EasyDict()

    if cond:
        desc += '-cond'; dataset_args.max_label_size = 'full'

    desc += '-' + config_id

    # Every config except config-f uses fmap_base = 8 << 10 (8192).
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # config-e*: higher R1 gamma, and architecture chosen by name substrings.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'

    # config-a..d: progressive-growing schedule (sched is unused, see above)
    # plus the revised StyleGAN synthesis network and D_stylegan discriminator.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # NOTE(review): G_loss is assigned here but never used.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Passed through to the run via kwargs = EasyDict(train) below.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # config-a replaces G and D outright. This discards fmap_base,
    # synthesis_func and D.func_name set by the branches above.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    # Applied after any config-a replacement, so the resolution always sticks.
    G.update(resolution_h=height)
    G.update(resolution_w=width)
    D.update(resolution_h=height)
    D.update(resolution_w=width)

    sc.submit_target = dnnlib.SubmitTarget.DIAGNOSTIC
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)

    kwargs.update(G_args=G, D_args=D, tf_config=tf_config, config_id=config_id,
                  resolution_h=height, resolution_w=width, label_size = label_size)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_diagnostic(**kwargs)
    return f'network-initial-config-f-{height}x{width}-{label_size}.pkl'
|
|
def _str_to_bool(v):
    """Convert a command-line style boolean spelling into a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: A ``bool``, which is returned unchanged, or any value whose string
            form is one of ``yes/true/t/y/1`` or ``no/false/f/n/0``.
            Matching is case-insensitive and ignores surrounding whitespace,
            so ints such as ``0``/``1`` are accepted too.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling. This includes ``None`` and other non-string values
            that do not stringify to a known spelling.
    """
    if isinstance(v, bool):
        return v
    # Normalise once. str() lets non-string values (e.g. an int default=)
    # reach the ArgumentTypeError path instead of crashing on .lower().
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
|
def _parse_comma_sep(s):
    """Return the comma-separated fields of ``s`` as a list of strings.

    ``None``, ``''`` and ``'none'`` (any letter case) all mean "no items" and
    yield an empty list. Any other string is split on every comma, and the
    pieces are returned verbatim: no whitespace stripping, empty pieces kept.
    """
    no_items = s is None or s.lower() in ('none', '')
    return [] if no_items else s.split(',')
|
|
def copy_weights(source_pkl, target_pkl, output_pkl):
    """Copy compatible trainable weights from one network pickle into another.

    Loads a ``(G, D, Gs)`` tuple from each of ``source_pkl`` and
    ``target_pkl``. For each of the three pairs it calls
    ``target.copy_compatible_trainables_from(source)``, then pickles the
    updated target tuple to ``output_pkl``.

    Args:
        source_pkl: Path of the pickle supplying the weights.
        target_pkl: Path of the pickle whose networks receive the weights.
        output_pkl: Output path, joined onto ``'./'``. ``os.path.join``
            keeps an absolute ``output_pkl`` unchanged.

    Security:
        ``pickle.load`` can execute arbitrary code. Only pass pickle files
        from trusted sources.
    """
    tflib.init_tf()

    with tf.Session():
        with tf.device('/gpu:0'):
            # Use context managers so both input file handles are closed
            # (the previous pickle.load(open(...)) form leaked them).
            with open(source_pkl, 'rb') as source_file:
                sourceG, sourceD, sourceGs = pickle.load(source_file)
            with open(target_pkl, 'rb') as target_file:
                targetG, targetD, targetGs = pickle.load(target_file)

            targetG.copy_compatible_trainables_from(sourceG)
            targetD.copy_compatible_trainables_from(sourceD)
            targetGs.copy_compatible_trainables_from(sourceGs)

            # Kept inside the session context. Presumably pickling the tflib
            # networks reads variable values through the default session,
            # which is the one opened above. Confirm before moving this out.
            with open(os.path.join('./', output_pkl), 'wb') as file:
                pickle.dump((targetG, targetD, targetGs), file, protocol=pickle.HIGHEST_PROTOCOL)