Spaces:
Runtime error
Runtime error
Update models/unet.py
Browse files- models/unet.py +45 -45
models/unet.py
CHANGED
|
@@ -640,53 +640,53 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|
| 640 |
# 'CrossAttnUpBlock3D']}
|
| 641 |
|
| 642 |
model = cls.from_config(config)
|
| 643 |
-
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 644 |
-
if not os.path.isfile(model_file):
|
| 645 |
-
|
| 646 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
| 647 |
-
|
| 648 |
-
if use_concat:
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
else:
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
|
| 691 |
return model
|
| 692 |
|
|
|
|
| 640 |
# 'CrossAttnUpBlock3D']}
|
| 641 |
|
| 642 |
model = cls.from_config(config)
|
| 643 |
+
# model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 644 |
+
# if not os.path.isfile(model_file):
|
| 645 |
+
# raise RuntimeError(f"{model_file} does not exist")
|
| 646 |
+
# state_dict = torch.load(model_file, map_location="cpu")
|
| 647 |
+
|
| 648 |
+
# if use_concat:
|
| 649 |
+
# new_state_dict = {}
|
| 650 |
+
# conv_in_weight = state_dict["conv_in.weight"]
|
| 651 |
+
# new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
|
| 652 |
|
| 653 |
+
# for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
|
| 654 |
+
# new_conv_weight[:, j] = conv_in_weight[:, i]
|
| 655 |
+
# new_state_dict["conv_in.weight"] = new_conv_weight
|
| 656 |
+
# new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
|
| 657 |
+
# for k, v in model.state_dict().items():
|
| 658 |
+
# # print(k)
|
| 659 |
+
# if '_temp.' in k:
|
| 660 |
+
# new_state_dict.update({k: v})
|
| 661 |
+
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
| 662 |
+
# k = k.replace('attn_fcross', 'attn1')
|
| 663 |
+
# state_dict.update({k: state_dict[k]})
|
| 664 |
+
# if 'norm_fcross' in k:
|
| 665 |
+
# k = k.replace('norm_fcross', 'norm1')
|
| 666 |
+
# state_dict.update({k: state_dict[k]})
|
| 667 |
|
| 668 |
+
# if 'conv_in' in k:
|
| 669 |
+
# continue
|
| 670 |
+
# else:
|
| 671 |
+
# new_state_dict[k] = v
|
| 672 |
+
# # # tmp
|
| 673 |
+
# # if 'class_embedding' in k:
|
| 674 |
+
# # state_dict.update({k: v})
|
| 675 |
+
# # breakpoint()
|
| 676 |
+
# model.load_state_dict(new_state_dict)
|
| 677 |
+
# else:
|
| 678 |
+
# for k, v in model.state_dict().items():
|
| 679 |
+
# # print(k)
|
| 680 |
+
# if '_temp' in k:
|
| 681 |
+
# state_dict.update({k: v})
|
| 682 |
+
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
| 683 |
+
# k = k.replace('attn_fcross', 'attn1')
|
| 684 |
+
# state_dict.update({k: state_dict[k]})
|
| 685 |
+
# if 'norm_fcross' in k:
|
| 686 |
+
# k = k.replace('norm_fcross', 'norm1')
|
| 687 |
+
# state_dict.update({k: state_dict[k]})
|
| 688 |
+
|
| 689 |
+
# model.load_state_dict(state_dict)
|
| 690 |
|
| 691 |
return model
|
| 692 |
|