# FOFPred/transformer/transformer_omnigen2.py
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from einops import rearrange
from ...utils.import_utils import is_triton_available
from ...utils.teacache_util import TeaCacheParams
from ..attention_processor import (
OmniGen2AttnProcessor,
OmniGen2AttnProcessorFlash2Varlen,
)
from .block_lumina2 import (
Lumina2CombinedTimestepCaptionEmbedding,
LuminaFeedForward,
LuminaLayerNormContinuous,
LuminaRMSNormZero,
)
from .repo import OmniGen2RotaryPosEmbed
if is_triton_available():
from ...ops.triton.layer_norm import RMSNorm
else:
from torch.nn import RMSNorm
from ...cache_functions import cal_type
from ...taylorseer_utils import (
derivative_approximation,
taylor_cache_init,
taylor_formula,
)
logger = logging.get_logger(__name__)
class OmniGen2TransformerBlock(nn.Module):
"""
Transformer block for OmniGen2 model.
This block implements a transformer layer with:
- Multi-head attention with flash attention
- Feed-forward network with SwiGLU activation
- RMS normalization
- Optional modulation for conditional generation
Args:
dim: Dimension of the input and output tensors
num_attention_heads: Number of attention heads
num_kv_heads: Number of key-value heads
multiple_of: Multiple of which the hidden dimension should be
ffn_dim_multiplier: Multiplier for the feed-forward network dimension
norm_eps: Epsilon value for normalization layers
modulation: Whether to use modulation for conditional generation
use_fused_rms_norm: Whether to use fused RMS normalization
use_fused_swiglu: Whether to use fused SwiGLU activation
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
num_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
) -> None:
"""Initialize the transformer block."""
super().__init__()
self.head_dim = dim // num_attention_heads
self.modulation = modulation
try:
processor = OmniGen2AttnProcessorFlash2Varlen()
except ImportError:
processor = OmniGen2AttnProcessor()
# Initialize attention layer
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
qk_norm="rms_norm",
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=processor,
)
# Initialize feed-forward network
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
# Initialize normalization layers
if modulation:
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True
)
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
self.initialize_weights()
def initialize_weights(self) -> None:
"""
Initialize the weights of the transformer block.
Uses Xavier uniform initialization for linear layers and zero initialization for biases.
"""
nn.init.xavier_uniform_(self.attn.to_q.weight)
nn.init.xavier_uniform_(self.attn.to_k.weight)
nn.init.xavier_uniform_(self.attn.to_v.weight)
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
if self.modulation:
nn.init.zeros_(self.norm1.linear.weight)
nn.init.zeros_(self.norm1.linear.bias)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the transformer block.
Args:
hidden_states: Input hidden states tensor
attention_mask: Attention mask tensor
image_rotary_emb: Rotary embeddings for image tokens
temb: Optional timestep embedding tensor
Returns:
torch.Tensor: Output hidden states after transformer block processing
"""
enable_taylorseer = getattr(self, "enable_taylorseer", False)
if enable_taylorseer:
if self.modulation:
if temb is None:
raise ValueError("temb must be provided when modulation is enabled")
if self.current["type"] == "full":
self.current["module"] = "total"
taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(
hidden_states, temb
)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(
1
).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(
self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))
)
hidden_states = hidden_states + gate_mlp.unsqueeze(
1
).tanh() * self.ffn_norm2(mlp_output)
derivative_approximation(
cache_dic=self.cache_dic,
current=self.current,
feature=hidden_states,
)
elif self.current["type"] == "Taylor":
self.current["module"] = "total"
hidden_states = taylor_formula(
cache_dic=self.cache_dic, current=self.current
)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
else:
if self.modulation:
if temb is None:
raise ValueError("temb must be provided when modulation is enabled")
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(
hidden_states, temb
)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(
1
).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(
self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))
)
hidden_states = hidden_states + gate_mlp.unsqueeze(
1
).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2Transformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
"""
OmniGen2 Transformer 3D Model (modified to output frame sequences).
A transformer-based diffusion model for image generation with:
- Patch-based image processing
- Rotary position embeddings
- Multi-head attention
- Conditional generation support
Args:
patch_size: Size of image patches
in_channels: Number of input channels
out_channels: Number of output channels (defaults to in_channels)
hidden_size: Size of hidden layers
num_layers: Number of transformer layers
num_refiner_layers: Number of refiner layers
num_attention_heads: Number of attention heads
num_kv_heads: Number of key-value heads
        multiple_of: Round the feed-forward hidden dimension up to a multiple of this value
ffn_dim_multiplier: Multiplier for feed-forward network dimension
norm_eps: Epsilon value for normalization layers
axes_dim_rope: Dimensions for rotary position embeddings
axes_lens: Lengths for rotary position embeddings
text_feat_dim: Dimension of text features
timestep_scale: Scale factor for timestep embeddings
"""
_supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGen2TransformerBlock"]
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0,
) -> None:
"""Initialize the OmniGen2 transformer model."""
super().__init__()
# Validate configuration
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
raise ValueError(
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
)
self.out_channels = out_channels or in_channels
# Initialize embeddings
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=hidden_size,
)
self.ref_image_patch_embedder = nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=hidden_size,
)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale,
)
# Initialize transformer blocks
self.noise_refiner = nn.ModuleList(
[
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True,
)
for _ in range(num_refiner_layers)
]
)
self.ref_image_refiner = nn.ModuleList(
[
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True,
)
for _ in range(num_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=False,
)
for _ in range(num_refiner_layers)
]
)
# 3. Transformer blocks
self.layers = nn.ModuleList(
[
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True,
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
# Add learnable embeddings to distinguish different images
self.image_index_embedding = nn.Parameter(
torch.randn(5, hidden_size)
) # support max 5 ref images
self.gradient_checkpointing = False
self.initialize_weights()
# TeaCache settings
self.enable_teacache = False
self.teacache_rel_l1_thresh = 0.05
self.teacache_params = TeaCacheParams()
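        # Polynomial used to rescale the raw relative L1 change of the modulated input
        # before it is accumulated and compared against `teacache_rel_l1_thresh`.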
coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
self.rescale_func = np.poly1d(coefficients)
def initialize_weights(self) -> None:
"""
Initialize the weights of the model.
Uses Xavier uniform initialization for linear layers.
"""
nn.init.xavier_uniform_(self.x_embedder.weight)
nn.init.constant_(self.x_embedder.bias, 0.0)
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
nn.init.zeros_(self.norm_out.linear_1.weight)
nn.init.zeros_(self.norm_out.linear_1.bias)
nn.init.zeros_(self.norm_out.linear_2.weight)
nn.init.zeros_(self.norm_out.linear_2.bias)
nn.init.normal_(self.image_index_embedding, std=0.02)
def img_patch_embed_and_refine(
self,
hidden_states,
ref_image_hidden_states,
padded_img_mask,
padded_ref_img_mask,
noise_rotary_emb,
ref_img_rotary_emb,
l_effective_ref_img_len,
l_effective_img_len,
temb,
):
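        """
        Patch-embed the noise and reference-image tokens, refine each stream with its
        dedicated refiner blocks, then concatenate them per sample as
        [ref-image tokens | noise tokens].
        """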
batch_size = len(hidden_states)
if isinstance(l_effective_img_len[0], list):
l_effective_img_len_summed = [sum(ln) for ln in l_effective_img_len]
else:
l_effective_img_len_summed = l_effective_img_len
max_combined_img_len = max(
[
img_len + sum(ref_img_len)
for img_len, ref_img_len in zip(
l_effective_img_len_summed, l_effective_ref_img_len
)
]
)
hidden_states = self.x_embedder(hidden_states)
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
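        # Tag each reference image's tokens with a learned index embedding so tokens
        # from different reference images stay distinguishable after concatenation.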
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift : shift + ref_img_len, :] = (
ref_image_hidden_states[i, shift : shift + ref_img_len, :]
+ self.image_index_embedding[j]
)
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(
hidden_states, padded_img_mask, noise_rotary_emb, temb
)
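        # Flatten the per-sample reference-image lengths so that every reference image
        # becomes its own entry in a temporary batch for the ref-image refiner.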
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
num_ref_images = len(flat_l_effective_ref_img_len)
max_ref_img_len = max(flat_l_effective_ref_img_len)
batch_ref_img_mask = ref_image_hidden_states.new_zeros(
num_ref_images, max_ref_img_len, dtype=torch.bool
)
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(
num_ref_images, max_ref_img_len, self.config.hidden_size
)
batch_ref_img_rotary_emb = hidden_states.new_zeros(
num_ref_images,
max_ref_img_len,
ref_img_rotary_emb.shape[-1],
dtype=ref_img_rotary_emb.dtype,
)
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
# sequence of ref imgs to batch
idx = 0
for i in range(batch_size):
shift = 0
for ref_img_len in l_effective_ref_img_len[i]:
batch_ref_img_mask[idx, :ref_img_len] = True
batch_ref_image_hidden_states[idx, :ref_img_len] = (
ref_image_hidden_states[i, shift : shift + ref_img_len]
)
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[
i, shift : shift + ref_img_len
]
batch_temb[idx] = temb[i]
shift += ref_img_len
idx += 1
# refine ref imgs separately
for layer in self.ref_image_refiner:
batch_ref_image_hidden_states = layer(
batch_ref_image_hidden_states,
batch_ref_img_mask,
batch_ref_img_rotary_emb,
batch_temb,
)
# batch of ref imgs to sequence
idx = 0
for i in range(batch_size):
shift = 0
for ref_img_len in l_effective_ref_img_len[i]:
ref_image_hidden_states[i, shift : shift + ref_img_len] = (
batch_ref_image_hidden_states[idx, :ref_img_len]
)
shift += ref_img_len
idx += 1
combined_img_hidden_states = hidden_states.new_zeros(
batch_size, max_combined_img_len, self.config.hidden_size
)
for i, (ref_img_len, img_len) in enumerate(
zip(l_effective_ref_img_len, l_effective_img_len_summed)
):
combined_img_hidden_states[i, : sum(ref_img_len)] = ref_image_hidden_states[
i, : sum(ref_img_len)
]
combined_img_hidden_states[
i, sum(ref_img_len) : sum(ref_img_len) + img_len
] = hidden_states[i, :img_len]
return combined_img_hidden_states
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.config.patch_size
device = hidden_states[0].device
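        # Each sample is either a single latent (C, H, W) or a stack of frame latents
        # (each (C, H, W)); for frame sequences, token counts are tracked per frame.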
if len(hidden_states[0].shape) == 3:
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
else:
img_sizes = [
[(img.size(1), img.size(2)) for img in imgs] for imgs in hidden_states
]
l_effective_img_len = [
[(H // p) * (W // p) for (H, W) in _img_sizes]
for _img_sizes in img_sizes
]
if ref_image_hidden_states is not None:
ref_img_sizes = [
[(img.size(1), img.size(2)) for img in imgs]
if imgs is not None
else None
for imgs in ref_image_hidden_states
]
l_effective_ref_img_len = [
[
(ref_img_size[0] // p) * (ref_img_size[1] // p)
for ref_img_size in _ref_img_sizes
]
if _ref_img_sizes is not None
else [0]
for _ref_img_sizes in ref_img_sizes
]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
max_ref_img_len = max(
[sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]
)
if len(hidden_states[0].shape) == 4:
max_img_len = max([sum(img_len) for img_len in l_effective_img_len])
else:
max_img_len = max(l_effective_img_len)
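        # Patchify: a latent of shape (C, H, W) becomes ((H // p) * (W // p), p * p * C)
        # token rows, e.g. (16, 64, 64) with p=2 -> (1024, 64).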
# ref image patch embeddings
flat_ref_img_hidden_states = []
for i in range(batch_size):
if ref_img_sizes[i] is not None:
imgs = []
for ref_img in ref_image_hidden_states[i]:
C, H, W = ref_img.size()
ref_img = rearrange(
ref_img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p
)
imgs.append(ref_img)
img = torch.cat(imgs, dim=0)
flat_ref_img_hidden_states.append(img)
else:
flat_ref_img_hidden_states.append(None)
# image patch embeddings
flat_hidden_states = []
if len(hidden_states[0].shape) == 4: # New case
for i in range(batch_size):
# Process each time step and concatenate
batch_img_patches = []
for img in hidden_states[i]:
C, H, W = img.size()
img = rearrange(
img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p
)
batch_img_patches.append(img)
# Concatenate patches for the current batch item across time
flat_hidden_states.append(torch.cat(batch_img_patches, dim=0))
else: # Default
for i in range(batch_size):
img = hidden_states[i]
C, H, W = img.size()
img = rearrange(img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p)
flat_hidden_states.append(img)
padded_ref_img_hidden_states = torch.zeros(
batch_size,
max_ref_img_len,
flat_hidden_states[0].shape[-1],
device=device,
dtype=flat_hidden_states[0].dtype,
)
padded_ref_img_mask = torch.zeros(
batch_size, max_ref_img_len, dtype=torch.bool, device=device
)
for i in range(batch_size):
if ref_img_sizes[i] is not None:
padded_ref_img_hidden_states[i, : sum(l_effective_ref_img_len[i])] = (
flat_ref_img_hidden_states[i]
)
padded_ref_img_mask[i, : sum(l_effective_ref_img_len[i])] = True
padded_hidden_states = torch.zeros(
batch_size,
max_img_len,
flat_hidden_states[0].shape[-1],
device=device,
dtype=flat_hidden_states[0].dtype,
)
padded_img_mask = torch.zeros(
batch_size, max_img_len, dtype=torch.bool, device=device
)
for i in range(batch_size):
if len(hidden_states[0].shape) == 4: # New case
padded_hidden_states[i, : sum(l_effective_img_len[i])] = (
flat_hidden_states[i]
)
padded_img_mask[i, : sum(l_effective_img_len[i])] = True
else:
padded_hidden_states[i, : l_effective_img_len[i]] = flat_hidden_states[
i
]
padded_img_mask[i, : l_effective_img_len[i]] = True
return (
padded_hidden_states,
padded_ref_img_hidden_states,
padded_img_mask,
padded_ref_img_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
)
def forward(
self,
hidden_states: Union[torch.Tensor, List[torch.Tensor]],
timestep: torch.Tensor,
text_hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
text_attention_mask: torch.Tensor,
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
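        """
        Forward pass of the OmniGen2 transformer.

        Args:
            hidden_states: Latent(s) to denoise; a tensor of shape (B, C, H, W) or a list
                of per-sample latents (each a single latent or a stack of frame latents).
            timestep: Diffusion timestep(s) used for conditioning.
            text_hidden_states: Text encoder features for caption conditioning.
            freqs_cis: Precomputed rotary frequency table consumed by the RoPE embedder.
            text_attention_mask: Boolean mask over the text tokens.
            ref_image_hidden_states: Optional per-sample lists of reference-image latents.
            attention_kwargs: Optional extra kwargs (e.g. a LoRA `scale` under the PEFT backend).
            return_dict: If True, wrap the output in a Transformer2DModelOutput.
        """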
enable_taylorseer = getattr(self, "enable_taylorseer", False)
if enable_taylorseer:
cal_type(self.cache_dic, self.current)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if (
attention_kwargs is not None
and attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# 1. Condition, positional & patch embedding
batch_size = len(hidden_states)
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
if is_hidden_states_tensor:
assert hidden_states.ndim == 4
hidden_states = [_hidden_states for _hidden_states in hidden_states]
device = hidden_states[0].device
temb, text_hidden_states = self.time_caption_embed(
timestep, text_hidden_states, hidden_states[0].dtype
)
(
hidden_states,
ref_image_hidden_states,
img_mask,
ref_img_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb,
ref_img_rotary_emb,
noise_rotary_emb,
rotary_emb,
encoder_seq_lengths,
seq_lengths,
) = self.rope_embedder(
freqs_cis,
text_attention_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
device,
)
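        # The RoPE embedder returns separate rotary embeddings for the text, reference-image
        # and noise tokens, plus per-sample text and total sequence lengths used below to
        # assemble the joint sequence.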
# 2. Context refinement
for layer in self.context_refiner:
text_hidden_states = layer(
text_hidden_states, text_attention_mask, context_rotary_emb
)
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states,
ref_image_hidden_states,
img_mask,
ref_img_mask,
noise_rotary_emb,
ref_img_rotary_emb,
l_effective_ref_img_len,
l_effective_img_len,
temb,
)
# 3. Joint Transformer blocks
max_seq_len = max(seq_lengths)
attention_mask = hidden_states.new_zeros(
batch_size, max_seq_len, dtype=torch.bool
)
joint_hidden_states = hidden_states.new_zeros(
batch_size, max_seq_len, self.config.hidden_size
)
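        # Per sample, the joint sequence is [text tokens | ref-image + noise tokens],
        # left-aligned; padded positions remain False in the attention mask.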
for i, (encoder_seq_len, seq_len) in enumerate(
zip(encoder_seq_lengths, seq_lengths)
):
attention_mask[i, :seq_len] = True
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[
i, :encoder_seq_len
]
joint_hidden_states[i, encoder_seq_len:seq_len] = (
combined_img_hidden_states[i, : seq_len - encoder_seq_len]
)
hidden_states = joint_hidden_states
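        # TeaCache: track how much the modulated input changes between steps; when the
        # accumulated (rescaled) relative L1 change stays below the threshold, skip the
        # transformer stack and reuse the cached residual instead.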
if self.enable_teacache:
teacache_hidden_states = hidden_states.clone()
teacache_temb = temb.clone()
modulated_inp, _, _, _ = self.layers[0].norm1(
teacache_hidden_states, teacache_temb
)
if self.teacache_params.is_first_or_last_step:
should_calc = True
self.teacache_params.accumulated_rel_l1_distance = 0
else:
self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
(
(modulated_inp - self.teacache_params.previous_modulated_inp)
.abs()
.mean()
/ self.teacache_params.previous_modulated_inp.abs().mean()
)
.cpu()
.item()
)
if (
self.teacache_params.accumulated_rel_l1_distance
< self.teacache_rel_l1_thresh
):
should_calc = False
else:
should_calc = True
self.teacache_params.accumulated_rel_l1_distance = 0
self.teacache_params.previous_modulated_inp = modulated_inp
if self.enable_teacache:
if not should_calc:
hidden_states += self.teacache_params.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for layer_idx, layer in enumerate(self.layers):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer, hidden_states, attention_mask, rotary_emb, temb
)
else:
hidden_states = layer(
hidden_states, attention_mask, rotary_emb, temb
)
self.teacache_params.previous_residual = (
hidden_states - ori_hidden_states
)
else:
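            # TaylorSeer: on "full" steps each block recomputes and caches its features;
            # on "Taylor" steps the cached features are extrapolated instead.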
if enable_taylorseer:
self.current["stream"] = "layers_stream"
for layer_idx, layer in enumerate(self.layers):
if enable_taylorseer:
layer.current = self.current
layer.cache_dic = self.cache_dic
layer.enable_taylorseer = True
self.current["layer"] = layer_idx
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer, hidden_states, attention_mask, rotary_emb, temb
)
else:
hidden_states = layer(
hidden_states, attention_mask, rotary_emb, temb
)
# 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
p = self.config.patch_size
output = []
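        # Un-patchify: for each sample, take its trailing image tokens and fold them back
        # to (C, H, W); frame sequences are unpacked frame by frame and stacked.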
for i, (img_size, img_len, seq_len) in enumerate(
zip(img_sizes, l_effective_img_len, seq_lengths)
):
if isinstance(img_len, list):
batch_output = []
cur_st = seq_len - sum(img_len)
for j in range(len(img_len)):
height, width = img_size[j]
cur_len = img_len[j]
batch_output.append(
rearrange(
hidden_states[i][cur_st : cur_st + cur_len],
"(h w) (p1 p2 c) -> c (h p1) (w p2)",
h=height // p,
w=width // p,
p1=p,
p2=p,
)
)
cur_st += cur_len
output.append(torch.stack(batch_output, dim=0))
else:
height, width = img_size
output.append(
rearrange(
hidden_states[i][seq_len - img_len : seq_len],
"(h w) (p1 p2 c) -> c (h p1) (w p2)",
h=height // p,
w=width // p,
p1=p,
p2=p,
)
)
if is_hidden_states_tensor:
output = torch.stack(output, dim=0)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if enable_taylorseer:
self.current["step"] += 1
if not return_dict:
return output
return Transformer2DModelOutput(sample=output)
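

if __name__ == "__main__":
    # Minimal sanity-check sketch (not used by the pipeline): illustrates the
    # patchify / un-patchify token layout used by `flat_and_pad_to_seq` and the
    # output projection. The shapes below are hypothetical examples.
    p = 2
    latent = torch.randn(16, 64, 64)  # (C, H, W) latent
    tokens = rearrange(latent, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p)
    assert tokens.shape == (32 * 32, p * p * 16)
    restored = rearrange(
        tokens, "(h w) (p1 p2 c) -> c (h p1) (w p2)", h=64 // p, w=64 // p, p1=p, p2=p
    )
    assert torch.equal(latent, restored)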