Commit 95e7ff7 · Parent(s): db5f6ff
Update models/unet.py

1 file changed: models/unet.py (+3 -105)
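This commit cleans up UNet3DConditionModel.from_pretrained_2d in models/unet.py: it drops the commented-out dump of the 2D UNet config, the commented-out weight-loading / use_concat channel-inflation code, and the module-level __main__ smoke test. The method now just instantiates the 3D UNet from the 2D config and returns it.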
--- a/models/unet.py
+++ b/models/unet.py
@@ -610,112 +610,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         # config["num_class_embeds"] = 100
 
         from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
-
-        # {'_class_name': 'UNet3DConditionModel',
-        # '_diffusers_version': '0.2.2',
-        # 'act_fn': 'silu',
-        # 'attention_head_dim': 8,
-        # 'block_out_channels': [320, 640, 1280, 1280],
-        # 'center_input_sample': False,
-        # 'cross_attention_dim': 768,
-        # 'down_block_types':
-        # ['CrossAttnDownBlock3D',
-        #  'CrossAttnDownBlock3D',
-        #  'CrossAttnDownBlock3D',
-        #  'DownBlock3D'],
-        # 'downsample_padding': 1,
-        # 'flip_sin_to_cos': True,
-        # 'freq_shift': 0,
-        # 'in_channels': 4,
-        # 'layers_per_block': 2,
-        # 'mid_block_scale_factor': 1,
-        # 'norm_eps': 1e-05,
-        # 'norm_num_groups': 32,
-        # 'out_channels': 4,
-        # 'sample_size': 64,
-        # 'up_block_types':
-        # ['UpBlock3D',
-        #  'CrossAttnUpBlock3D',
-        #  'CrossAttnUpBlock3D',
-        #  'CrossAttnUpBlock3D']}
+
 
         model = cls.from_config(config)
-        # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
-        # if not os.path.isfile(model_file):
-        #     raise RuntimeError(f"{model_file} does not exist")
-        # state_dict = torch.load(model_file, map_location="cpu")
-
-        # if use_concat:
-        #     new_state_dict = {}
-        #     conv_in_weight = state_dict["conv_in.weight"]
-        #     new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
-
-        #     for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
-        #         new_conv_weight[:, j] = conv_in_weight[:, i]
-        #     new_state_dict["conv_in.weight"] = new_conv_weight
-        #     new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
-        #     for k, v in model.state_dict().items():
-        #         # print(k)
-        #         if '_temp.' in k:
-        #             new_state_dict.update({k: v})
-        #         if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
-        #             k = k.replace('attn_fcross', 'attn1')
-        #             state_dict.update({k: state_dict[k]})
-        #         if 'norm_fcross' in k:
-        #             k = k.replace('norm_fcross', 'norm1')
-        #             state_dict.update({k: state_dict[k]})
-
-        #         if 'conv_in' in k:
-        #             continue
-        #         else:
-        #             new_state_dict[k] = v
-        #     # # tmp
-        #     # if 'class_embedding' in k:
-        #     #     state_dict.update({k: v})
-        #     # breakpoint()
-        #     model.load_state_dict(new_state_dict)
-        # else:
-        #     for k, v in model.state_dict().items():
-        #         # print(k)
-        #         if '_temp' in k:
-        #             state_dict.update({k: v})
-        #         if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
-        #             k = k.replace('attn_fcross', 'attn1')
-        #             state_dict.update({k: state_dict[k]})
-        #         if 'norm_fcross' in k:
-        #             k = k.replace('norm_fcross', 'norm1')
-        #             state_dict.update({k: state_dict[k]})
-
-        #     model.load_state_dict(state_dict)
 
-        return model
-
-if __name__ == '__main__':
-    import torch
-    # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base/" # p cluster
-    pretrained_model_path = "/mnt/petrelfs/share_data/zhanglingjun/stable-diffusion-v1-4/" # p cluster
-    unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
-    # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
-    unet.enable_xformers_memory_efficient_attention()
-    unet.enable_gradient_checkpointing()
-
-    unet.train()
-
-    use_image_num = 5
-    noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
-    bsz = noisy_latents.shape[0]
-    timesteps = torch.randint(0, 1000, (bsz,)).to(device)
-    timesteps = timesteps.long()
-    encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
-    # class_labels = torch.randn((bsz, )).to(device)
-
-
-    model_pred = unet(noisy_latents, timesteps,
-                      encoder_hidden_states=encoder_hidden_states,
-                      class_labels=None,
-                      use_image_num=use_image_num).sample
-    print(model_pred.shape)
+
+        return model
+
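For reference, the deleted (already commented-out) use_concat branch inflated the pretrained 2D UNet's 4-channel conv_in weight to 9 input channels before loading. A minimal sketch of that idea, assuming a state dict from a 2D Stable Diffusion UNet; the helper name inflate_conv_in is hypothetical, not code from this repo:

# Hypothetical helper illustrating the removed use_concat logic: copy the
# pretrained input channels into the first slots of a wider conv_in weight
# and leave the extra channels zero. (The original zip([0, 1, 2, 3], range(9))
# pairs only the first four indices, so channels 4-8 stayed zero there too.)
import torch

def inflate_conv_in(conv_in_weight: torch.Tensor, new_in_channels: int = 9) -> torch.Tensor:
    out_ch, in_ch, *spatial = conv_in_weight.shape
    new_weight = torch.zeros((out_ch, new_in_channels, *spatial), dtype=conv_in_weight.dtype)
    new_weight[:, :in_ch] = conv_in_weight  # pretrained channels first, zeros after
    return new_weight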
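Note that the weight-loading code was already commented out before this commit, so runtime behavior is unchanged: from_pretrained_2d still only reads the 2D config from the given path and returns a randomly initialized 3D UNet. A minimal usage sketch, assuming the repo root is on PYTHONPATH and with a placeholder model path:

import torch
from models.unet import UNet3DConditionModel  # this repo's models/unet.py

device = "cuda" if torch.cuda.is_available() else "cpu"
# Builds the 3D UNet from the 2D config only; no pretrained weights are loaded here.
unet = UNet3DConditionModel.from_pretrained_2d("path/to/stable-diffusion-v1-4", subfolder="unet").to(device)
unet.enable_gradient_checkpointing()  # optional, as in the removed smoke test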