import torch
from diffusers import StableDiffusionPipeline
import numpy as np
import abc
import time_utils
import copy
import os
from train_funcs import TRAIN_FUNC_DICT

## get arguments for our script
with_to_k = True
with_augs = True
train_func = "train_closed_form"

### load model
LOW_RESOURCE = True
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
tokenizer = ldm_stable.tokenizer

### get layers
ca_layers = []

def append_ca(net_):
    # recursively collect every CrossAttention module under net_
    if net_.__class__.__name__ == 'CrossAttention':
        ca_layers.append(net_)
    elif hasattr(net_, 'children'):
        for net__ in net_.children():
            append_ca(net__)

sub_nets = ldm_stable.unet.named_children()
for net in sub_nets:
    if "down" in net[0]:
        append_ca(net[1])
    elif "up" in net[0]:
        append_ca(net[1])
    elif "mid" in net[0]:
        append_ca(net[1])
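
# For reference, a flat traversal collects the same modules (a sketch; it
# assumes, like the code above, that this diffusers version names the
# attention module class 'CrossAttention'):
# ca_layers = [m for m in ldm_stable.unet.modules()
#              if m.__class__.__name__ == 'CrossAttention']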

### get projection matrices
# keep only the layers that attend to the CLIP text embeddings (768-dim in
# SD v1.4); self-attention layers have different in_features
ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768]
projection_matrices = [l.to_v for l in ca_clip_layers]
og_matrices = [copy.deepcopy(l.to_v) for l in ca_clip_layers]
if with_to_k:
    projection_matrices = projection_matrices + [l.to_k for l in ca_clip_layers]
    og_matrices = og_matrices + [copy.deepcopy(l.to_k) for l in ca_clip_layers]

def edit_model(old_text_, new_text_, lamb=0.1):
    #### restore the original LDM parameters so successive edits do not accumulate
    num_ca_clip_layers = len(ca_clip_layers)
    for idx_, l in enumerate(ca_clip_layers):
        l.to_v = copy.deepcopy(og_matrices[idx_])
        projection_matrices[idx_] = l.to_v
        if with_to_k:
            l.to_k = copy.deepcopy(og_matrices[num_ca_clip_layers + idx_])
            projection_matrices[num_ca_clip_layers + idx_] = l.to_k

    try:
        #### set up sentences
        old_texts = [old_text_]
        new_texts = [new_text_]
        if with_augs:
            # lowercase a leading "A" so the prompt reads naturally after a prefix
            base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:]
            old_texts.append("A photo of " + base)
            old_texts.append("An image of " + base)
            old_texts.append("A picture of " + base)
            base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:]
            new_texts.append("A photo of " + base)
            new_texts.append("An image of " + base)
            new_texts.append("A picture of " + base)

        #### prepare input k* and v*
        old_embs, new_embs = [], []
        for old_text, new_text in zip(old_texts, new_texts):
            text_input = ldm_stable.tokenizer(
                [old_text, new_text],
                padding="max_length",
                max_length=ldm_stable.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = ldm_stable.text_encoder(text_input.input_ids.to(ldm_stable.device))[0]
            old_emb, new_emb = text_embeddings
            old_embs.append(old_emb)
            new_embs.append(new_emb)

        #### identify corresponding destinations for each token in old_emb
        idxs_replaces = []
        for old_text, new_text in zip(old_texts, new_texts):
            tokens_a = tokenizer(old_text).input_ids
            tokens_b = tokenizer(new_text).input_ids
            # map 'an' to the token for 'a' so the two sequences align
            tokens_a = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_a]
            tokens_b = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_b]
            num_orig_tokens = len(tokens_a)
            num_new_tokens = len(tokens_b)
            idxs_replace = []
            j = 0
            # assumes every old token appears, in order, in the new prompt;
            # otherwise the except below reports an error
            for i in range(num_orig_tokens):
                curr_token = tokens_a[i]
                while tokens_b[j] != curr_token:
                    j += 1
                idxs_replace.append(j)
                j += 1
            while j < 77:
                idxs_replace.append(j)
                j += 1
            while len(idxs_replace) < 77:
                idxs_replace.append(76)
            idxs_replaces.append(idxs_replace)
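        # Illustration of the matching above with toy token ids (not real CLIP
        # ids): tokens_a = [0, 5, 9, 1] and tokens_b = [0, 5, 7, 9, 1] yield
        # idxs_replace = [0, 1, 3, 4, 5, 6, ..., 76]: each old token is mapped
        # to its next occurrence in the new sequence, leftover new-token
        # positions pass through unchanged, and the list is padded with 76 up
        # to the 77-token CLIP context length.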

        #### prepare batch: for each pair of sentences, old context and new values
        contexts, valuess = [], []
        for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
            context = old_emb.detach()
            values = []
            with torch.no_grad():
                for layer in projection_matrices:
                    values.append(layer(new_emb[idxs_replace]).detach())
            contexts.append(context)
            valuess.append(values)

        #### define training function
        train = TRAIN_FUNC_DICT[train_func]

        #### train the model
        train(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, lamb=lamb)
        return f"<b>Current model status:</b> Edited \"{old_text_}\" into \"{new_text_}\""
    except Exception:
        return "<b>Current model status:</b> An error occurred"

def generate_for_text(test_text):
    g = torch.Generator(device='cpu')
    g.seed()  # re-seed from a nondeterministic source so each call varies
    images = time_utils.text2image_ldm_stable(ldm_stable, [test_text], latent=None, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=g, low_resource=LOW_RESOURCE)
    return time_utils.view_images(images)
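
# Example usage (illustrative; the Space presumably wires these two functions
# into its UI):
# print(edit_model("A pack of roses", "A pack of blue roses"))
# grid = generate_for_text("A pack of roses")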