"""
Standalone KORMo MTP Model Wrapper for Hugging Face Hub
All required components are embedded in this file so that it works without any external kormo dependencies.
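It defines KORMoConfig, the base KORMoModel / KORMoForCausalLM classes, and KORMoForCausalLMWithMTP,
which adds a multi-token-prediction (MTP) head and an optional speculative-decoding generation loop
(see the usage sketch at the end of the file).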
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, List, Optional, Tuple, Union, Any
import math
import logging
import os
from pathlib import Path
# Transformers imports
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, rope_config_validation
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, logging as transformers_logging
from transformers.utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal
# Flex attention imports (with safe fallback)
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import create_block_mask, BlockMask, and_masks, or_masks
else:
BlockMask = torch.Tensor
def create_block_mask(*args, **kwargs):
return None
def and_masks(*args, **kwargs):
return None
def or_masks(*args, **kwargs):
return None
logger = transformers_logging.get_logger(__name__)
def print_once(message: str) -> None:
if not getattr(print_once, "_has_printed", False):
print(message)
print_once._has_printed = True
# ==============================================================================
# Configuration Class
# ==============================================================================
class KORMoConfig(PretrainedConfig):
model_type = "kormo"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
vocab_size=112576,
hidden_size=6144,
intermediate_size=21504,
num_hidden_layers=48,
num_attention_heads=40,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=131072,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=500000.0,
attention_bias=False,
attention_dropout=0.0,
rope_scaling=None,
mlp_bias=False,
head_dim=128,
sliding_window=None,
post_ln_layer_end_idx=8,
mtp_depth=0,
mtp_loss_lambda=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.sliding_window = sliding_window
self.post_ln_layer_end_idx = post_ln_layer_end_idx
self.mtp_depth = mtp_depth
self.mtp_loss_lambda = mtp_loss_lambda
self.mask_type = None
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# ==============================================================================
# Custom Mask Functions
# ==============================================================================
def generate_sliding_window(sliding_window):
def inner_mask(b, h, q_idx, kv_idx):
return kv_idx > q_idx - sliding_window
return inner_mask
def generate_doc_mask(input_ids, bos_token_id):
is_bos = (input_ids.flatten() == bos_token_id)
flat_doc_ids = torch.cumsum(is_bos, 0)
doc_ids = flat_doc_ids.view_as(input_ids)
def inner_mask(b, h, q_idx, kv_idx):
same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
return same_doc
return inner_mask
def generate_bos_mask(input_ids, bos_token_id):
is_bos_table = input_ids == bos_token_id
def inner_mask(b, h, q_idx, kv_idx):
is_bos = is_bos_table[b, kv_idx]
return is_bos
return inner_mask
def create_causal_mask(
config: KORMoConfig,
input_embeds: torch.Tensor,
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
and_mask_function: Optional[Callable] = None,
or_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
"""Create causal mask for flex attention"""
if config._attn_implementation != "flex_attention":
return None
if and_mask_function is None and or_mask_function is None:
return None
# This is a simplified implementation - full implementation would need
# proper flex attention block mask creation
return None
# ==============================================================================
# Core Model Components
# ==============================================================================
class RMSNorm(nn.Module):
"""KORMoRMSNorm is equivalent to T5LayerNorm"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
ALL_LAYERNORM_LAYERS.append(RMSNorm)
class RotaryEmbedding(nn.Module):
def __init__(self, config: KORMoConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return output
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(q.dtype), k_embed.to(k.dtype)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: KORMoConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class PreNormDecoderLayer(nn.Module):
def __init__(self, config: KORMoConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Attention(config=config, layer_idx=layer_idx)
self.mlp = MLP(config)
self.pre_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
# Self Attention
hidden_states = self.pre_attention_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# MLP layer
residual = hidden_states
hidden_states = self.pre_mlp_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class PostNormDecoderLayer(nn.Module):
def __init__(self, config: KORMoConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Attention(config=config, layer_idx=layer_idx)
self.mlp = MLP(config)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# MLP layer
residual = hidden_states
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.post_mlp_layernorm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
# ==============================================================================
# MTP Components
# ==============================================================================
class MTPModule(nn.Module):
"""
The k-th MTP module: combines the representation from the previous step with the embedding of the
future token, then passes the result through a projection and the TRM_k Transformer block to produce a new representation.
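Shapes: hidden_prev and emb_future are both (batch, seq_len, hidden_size); their RMS-normalized
concatenation (2 * hidden_size) is projected back to hidden_size before entering the Transformer block.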
"""
def __init__(self, config: KORMoConfig, module_index: int):
super().__init__()
d = config.hidden_size
self.rmsnorm_prev = RMSNorm(d, eps=config.rms_norm_eps)
self.rmsnorm_emb = RMSNorm(d, eps=config.rms_norm_eps)
self.proj = nn.Linear(2 * d, d, bias=False)
# More stable initialization for the projection weights
with torch.no_grad():
self.proj.weight.normal_(mean=0.0, std=0.02 / (2 * d) ** 0.5)
# Guarantee a unique layer_idx for each MTP module (no overlap with the main model's layers)
mtp_layer_idx = config.num_hidden_layers + module_index
if mtp_layer_idx < config.post_ln_layer_end_idx:
self.trm = PostNormDecoderLayer(config, mtp_layer_idx)
else:
self.trm = PreNormDecoderLayer(config, mtp_layer_idx)
self.rotary_emb = RotaryEmbedding(config)
def forward(self, hidden_prev: torch.Tensor, emb_future: torch.Tensor) -> torch.Tensor:
# 1) Normalize previous hidden state and future token embedding via RMSNorm
h1 = self.rmsnorm_prev(hidden_prev)
h2 = self.rmsnorm_emb(emb_future.to(hidden_prev.dtype))
# 2) Concatenate normalized vectors (size 2d) and project back to hidden dimension d
x = torch.cat([h1, h2], dim=-1)
proj_dtype = self.proj.weight.dtype
x = self.proj(x.to(proj_dtype))
# 3) Generate rotary positional embeddings for the sequence
batch_size, seq_len, _ = x.size()
position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
cos, sin = self.rotary_emb(x, position_ids)
# 4) Pass through the k-th Transformer block with rotary embeddings
output = self.trm(
x,
position_ids=position_ids,
position_embeddings=(cos, sin),
)
# Extract hidden tensor from layer output (tuple for some implementations)
hidden = output[0] if isinstance(output, tuple) else output
# prevent NaN/Inf propagation: clamp any nan/inf to zero
hidden = torch.nan_to_num(hidden, nan=0.0, posinf=0.0, neginf=0.0)
return hidden
class MTP(nn.Module):
"""
Full MTP head: chains D sequential MTPModules to perform multi-token prediction.
"""
def __init__(self, config: KORMoConfig):
super().__init__()
self.config = config
self.mtp_modules = nn.ModuleList([
MTPModule(config, k) for k in range(config.mtp_depth)
])
def forward(self, hidden_states: torch.Tensor, future_embs: list[torch.Tensor]) -> list[torch.Tensor]:
# Sequentially apply each MTPModule:
# Module k takes h^{k-1} and Emb(t_{i+k}) to produce h^k
outputs = []
h = hidden_states # h^0 from main model
for k, mtp_mod in enumerate(self.mtp_modules):
h = mtp_mod(h, future_embs[k]) # h^{k-1} -> h^k
outputs.append(h)
return outputs
class MTPLoss(nn.Module):
"""
Computes the MTP loss: averages the per-module cross-entropy losses and scales the result by mtp_loss_lambda.
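Expects mtp_labels of shape (batch, mtp_depth, seq_len): row k-1 holds the targets for the k-th MTP module,
and positions equal to the pad/ignore id do not contribute to the loss.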
"""
def __init__(self, config: KORMoConfig):
super().__init__()
pad_id = config.pad_token_id if config.pad_token_id is not None else (config.eos_token_id if config.eos_token_id is not None else 0)
self.ce = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='none')
self.lambda_mtp = config.mtp_loss_lambda
def forward(
self,
mtp_logits: list[torch.Tensor],
target_ids: torch.LongTensor,
) -> torch.Tensor:
vocab_size = mtp_logits[0].size(-1) if mtp_logits else 0
ignore_idx = self.ce.ignore_index
target_ids = target_ids.clone()
target_ids = torch.where(target_ids < 0, ignore_idx, target_ids)
target_ids = torch.where(target_ids >= vocab_size, ignore_idx, target_ids)
losses: list[torch.Tensor] = []
total_valid_tokens = 0
for k, logits in enumerate(mtp_logits, start=1):
try:
# The k-th MTP module predicts the k-th future token (target_ids[:, k-1, :])
labels_k = target_ids[:, k-1, : logits.size(1)]
# Safety check: guard against NaN/Inf in the labels
if torch.isnan(labels_k).any() or torch.isinf(labels_k).any():
logger.warning(f"NaN/Inf detected in labels_k for MTP module {k}")
continue
# Compute the valid-token mask
labels_k_flat = labels_k.reshape(-1)
# Positions equal to ignore_index do not contribute to the loss
mask = (labels_k_flat != self.ce.ignore_index)
# Sanity-check the mask as well
if torch.isnan(mask.float()).any():
logger.warning(f"NaN detected in mask for MTP module {k}")
continue
num_valid = mask.float().sum()
total_valid_tokens += num_valid
if num_valid > 0:
# Compute the loss only when at least one valid token exists
logits_flat = logits.reshape(-1, vocab_size)
# Guard against NaN/Inf in the logits
if torch.isnan(logits_flat).any() or torch.isinf(logits_flat).any():
logger.warning(f"NaN/Inf detected in logits for MTP module {k}")
continue
loss_k = self.ce(logits_flat, labels_k_flat)
# Guard against NaN/Inf in loss_k
if torch.isnan(loss_k).any() or torch.isinf(loss_k).any():
logger.warning(f"NaN/Inf detected in loss_k for MTP module {k}")
continue
losses.append((loss_k * mask).sum() / num_valid)
# If there are no valid tokens, skip this module's loss
except Exception as e:
logger.error(f"Error processing MTP module {k}: {e}")
continue
if losses and total_valid_tokens > 0:
loss = sum(losses) / len(losses)
loss = loss * self.lambda_mtp
# Prevent NaN/Inf from propagating
if torch.isnan(loss) or torch.isinf(loss):
loss = torch.tensor(0.0, device=target_ids.device, dtype=loss.dtype)
else:
# If every token is padding, set the MTP loss to 0 (only the main loss is used)
loss = torch.tensor(0.0, device=target_ids.device, dtype=torch.float32)
return loss
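# A minimal sketch of how mtp_labels could be built for training (an assumption about the intended
# setup, mirroring how KORMoForCausalLMWithMTP.forward constructs future_embs below):
#
#   mtp_labels = torch.stack(
#       [torch.cat([labels[:, k:], labels.new_full((labels.size(0), k), pad_id)], dim=1)
#        for k in range(1, mtp_depth + 1)],
#       dim=1,
#   )  # -> (batch, mtp_depth, seq_len)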
# ==============================================================================
# Main Model Classes
# ==============================================================================
class KORMoPreTrainedModel(PreTrainedModel):
config_class = KORMoConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PreNormDecoderLayer", "PostNormDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class KORMoModel(KORMoPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers: the first *config.post_ln_layer_end_idx* layers are [`PostNormDecoderLayer`]s and the remaining layers are [`PreNormDecoderLayer`]s.
Args:
config: KORMoConfig
"""
def __init__(self, config: KORMoConfig):
super().__init__(config)
post_ln_index = config.post_ln_layer_end_idx
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[PostNormDecoderLayer(config, layer_idx) for layer_idx in range(post_ln_index)] +
[PreNormDecoderLayer(config, layer_idx) for layer_idx in range(post_ln_index, config.num_hidden_layers)]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# Simplified causal mask creation (removed complex mask logic for standalone version)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# Simplified causal mask implementation
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right', which "
"may lead to unexpected behaviour with the Flash Attention version of KORMo. Make "
"sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: KORMoConfig,
past_key_values: Cache,
):
if attention_mask is not None and attention_mask.dim() == 4:
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
pass
class KORMoForCausalLM(KORMoPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
self.model = KORMoModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# ==============================================================================
# MTP Wrapper Class
# ==============================================================================
class KORMoForCausalLMWithMTP(KORMoForCausalLM):
def save_pretrained(self, save_directory, **kwargs):
"""
Save the MTP model, including the base model and the MTP head.
"""
# Ensure the config reflects the MTP class
self.config.architectures = [self.__class__.__name__]
# Call the parent's save method to save the base model and config
super().save_pretrained(save_directory, **kwargs)
# Save the MTP head's state dictionary
if self.mtp_head is not None:
mtp_head_path = os.path.join(save_directory, "mtp_head.pt")
torch.save(self.mtp_head.state_dict(), mtp_head_path)
print(f"✅ MTP head saved to {mtp_head_path}")
@classmethod
def from_pretrained(cls, model_path, **kwargs):
"""
Load an MTP model from a checkpoint: restores the base model weights and, if present, the MTP head.
"""
# Load config first
config = KORMoConfig.from_pretrained(model_path)
# Create an instance of the correct model class
model = cls(config)
# Load the base model weights
base_model_state_dict = KORMoForCausalLM.from_pretrained(model_path, **kwargs).state_dict()
model.load_state_dict(base_model_state_dict, strict=False)
# Load the MTP head's state dictionary if it exists
mtp_head_path = os.path.join(model_path, "mtp_head.pt")
if os.path.exists(mtp_head_path) and model.mtp_head is not None:
model.mtp_head.load_state_dict(torch.load(mtp_head_path, map_location=model.device))
print(f"✅ MTP head loaded from {mtp_head_path}")
return model
def __init__(self, config):
super().__init__(config)
if getattr(config, 'mtp_depth', 0) > 0:
self.mtp_head = MTP(config)
self.mtp_loss_fn = MTPLoss(config)
else:
self.mtp_head = None
self.mtp_loss_fn = None
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
position_ids: torch.LongTensor = None,
past_key_values: Any = None,
inputs_embeds: torch.FloatTensor = None,
labels: torch.LongTensor = None,
mtp_labels: torch.LongTensor = None,
use_cache: bool = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
return_dict: bool = None,
**kwargs,
):
# Device placement: ensure all inputs on model device
device = self.lm_head.weight.device
if input_ids is not None:
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
if position_ids is not None:
position_ids = position_ids.to(device)
if inputs_embeds is not None:
inputs_embeds = inputs_embeds.to(device)
if labels is not None:
labels = labels.to(device)
if mtp_labels is not None:
mtp_labels = mtp_labels.to(device)
if getattr(self.config, 'mtp_depth', 0) <= 0 or self.mtp_head is None:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=False,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=True,
**kwargs,
)
# Extract final hidden states from main LM for MTP
hidden_states = outputs.hidden_states[-1]
batch_size, seq_len, _ = hidden_states.size()
pad_id = self.config.pad_token_id if self.config.pad_token_id is not None else (self.config.eos_token_id if self.config.eos_token_id is not None else 0)
future_embs = []
# Prepare future token embeddings for each MTP step using shared embedding layer
for k in range(1, self.config.mtp_depth + 1):
tid = torch.cat(
[input_ids[:, k:], input_ids.new_full((batch_size, k), pad_id)],
dim=1,
)
future_embs.append(self.get_input_embeddings()(tid))
# Forward through MTP head modules and compute logits via shared LM head
mtp_hidden = self.mtp_head(hidden_states, future_embs)
# Sanity-check the MTP hidden states
mtp_logits = []
for i, h in enumerate(mtp_hidden):
# Check for NaN/Inf and sanitize
if torch.isnan(h).any() or torch.isinf(h).any():
h = torch.nan_to_num(h, nan=0.0, posinf=1e6, neginf=-1e6)
logits = self.lm_head(h)
# Sanitize the logits as well
if torch.isnan(logits).any() or torch.isinf(logits).any():
logits = torch.nan_to_num(logits, nan=0.0, posinf=1e6, neginf=-1e6)
mtp_logits.append(logits)
mtp_loss = None
if mtp_labels is not None:
mtp_loss = self.mtp_loss_fn(mtp_logits, mtp_labels)
# Guard the MTP loss against NaN/Inf
if torch.isnan(mtp_loss) or torch.isinf(mtp_loss):
mtp_loss = torch.tensor(0.0, device=hidden_states.device, dtype=mtp_loss.dtype)
base_loss = outputs.loss if outputs.loss is not None else torch.tensor(0.0, device=hidden_states.device)
if mtp_loss is not None:
loss = base_loss + mtp_loss
else:
loss = base_loss
if not return_dict:
# outputs.past_key_values may be None when use_cache=False
logits = outputs.logits
past_kv = outputs.past_key_values or ()
# For tuple format, we can't easily include MTP logits
return (loss, logits, *past_kv)
# Create custom output object that includes MTP logits
class CausalLMOutputWithMTP(CausalLMOutputWithPast):
def __init__(self, loss=None, logits=None, past_key_values=None,
hidden_states=None, attentions=None, mtp_logits=None):
super().__init__(loss=loss, logits=logits, past_key_values=past_key_values,
hidden_states=hidden_states, attentions=attentions)
self.mtp_logits = mtp_logits
# Stack MTP logits for inference
mtp_logits_tensor = None
if mtp_logits:
# Stack along depth dimension: [batch, depth, seq_len, vocab]
mtp_logits_tensor = torch.stack(mtp_logits, dim=1)
return CausalLMOutputWithMTP(
loss=loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
mtp_logits=mtp_logits_tensor,
)
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 50,
temperature: float = 1.0,
top_p: float = 0.9,
top_k: int = 50,
do_sample: bool = True,
speculative_decode: bool = False,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
**kwargs
) -> torch.Tensor:
"""
Generate text with the MTP model, optionally using the MTP head for speculative decoding.
"""
input_ids = input_ids.to(self.device)
if pad_token_id is None:
pad_token_id = getattr(self.config, 'pad_token_id', None)
if eos_token_id is None:
eos_token_id = getattr(self.config, 'eos_token_id', None)
# Check if we should use speculative decoding
mtp_depth = getattr(self.config, 'mtp_depth', 0)
if speculative_decode and mtp_depth > 0 and self.mtp_head is not None:
return self._generate_speculative(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
)
else:
# Use standard generation
return self._generate_standard(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
)
def _generate_standard(
self,
input_ids: torch.Tensor,
max_new_tokens: int,
temperature: float,
top_p: float,
top_k: int,
do_sample: bool,
pad_token_id: int,
eos_token_id: int,
) -> torch.Tensor:
"""Standard autoregressive generation."""
current_ids = input_ids.clone()
for _ in range(max_new_tokens):
with torch.no_grad():
outputs = self(current_ids, return_dict=True)
logits = outputs.logits[:, -1, :] # Last token logits
# Apply temperature
if temperature != 1.0:
logits = logits / temperature
# Sample next token
if do_sample:
next_token = self._sample_token(logits, top_p, top_k)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
# Append new token
current_ids = torch.cat([current_ids, next_token], dim=-1)
# Check for EOS
if eos_token_id is not None and next_token.item() == eos_token_id:
break
return current_ids
def _generate_speculative(
self,
input_ids: torch.Tensor,
max_new_tokens: int,
temperature: float,
top_p: float,
top_k: int,
do_sample: bool,
pad_token_id: int,
eos_token_id: int,
) -> torch.Tensor:
"""Speculative decoding using the MTP head as a draft model and the main model as the verifier."""
current_ids = input_ids.clone()
mtp_depth = getattr(self.config, 'mtp_depth', 0)
remaining_tokens = max_new_tokens
while remaining_tokens > 0:
with torch.no_grad():
# 1. Draft Generation: MTP head proposes k candidate tokens
draft_outputs = self(current_ids, return_dict=True)
main_logits = draft_outputs.logits[:, -1, :]
mtp_logits = getattr(draft_outputs, 'mtp_logits', None)
if mtp_logits is None:
# Fallback to standard generation if no MTP head
return self._generate_standard(
current_ids, remaining_tokens, temperature, top_p, top_k,
do_sample, pad_token_id, eos_token_id
)
# Combine main and MTP logits for drafting
# Shape of mtp_logits: [batch, depth, seq_len, vocab] -> [batch, depth, vocab]
mtp_preds = mtp_logits[:, :, -1, :]
draft_logits = torch.cat([main_logits.unsqueeze(1), mtp_preds], dim=1)
# Generate draft tokens
if do_sample:
# Apply temperature
if temperature != 1.0:
draft_logits = draft_logits / temperature
draft_probs = F.softmax(draft_logits, dim=-1)
draft_tokens_indices = torch.multinomial(draft_probs.view(-1, self.config.vocab_size), num_samples=1)
draft_tokens = draft_tokens_indices.view(current_ids.size(0), -1)
else:
draft_tokens = torch.argmax(draft_logits, dim=-1)
num_draft_tokens = draft_tokens.shape[1]
# 2. Verification: Main model verifies the draft tokens in a single forward pass
candidate_ids = torch.cat([current_ids, draft_tokens], dim=-1)
verify_outputs = self(candidate_ids, return_dict=True)
# Get verification logits for the newly generated tokens
verify_logits = verify_outputs.logits[:, current_ids.shape[-1]-1:-1, :]
# 3. Acceptance/Rejection
# Greedily check which tokens from the draft match the main model's predictions
verifier_tokens = torch.argmax(verify_logits, dim=-1)
# Find, for each batch element, the length of the draft prefix the verifier agrees with
matches = (draft_tokens == verifier_tokens)
mismatch_indices = torch.where(~matches, 1, 0).argmax(dim=1)
# argmax returns 0 when a row has no mismatch, so substitute the full draft length there
all_matches = matches.all(dim=1)
prefix_lens = torch.where(
all_matches,
torch.full_like(mismatch_indices, num_draft_tokens),
mismatch_indices,
)
# Accept only the prefix length that every batch element agrees on
accepted_len = int(prefix_lens.min().item())
accepted_tokens = draft_tokens[:, :accepted_len]
current_ids = torch.cat([current_ids, accepted_tokens], dim=-1)
remaining_tokens -= accepted_len
# If not all tokens were accepted, sample one more token from the verifier at the mismatch position
if accepted_len < num_draft_tokens:
# Sample from the verifier's distribution at the point of mismatch
final_logits = verify_logits[:, accepted_len, :]
if do_sample:
if temperature != 1.0:
final_logits = final_logits / temperature
next_token = self._sample_token(final_logits, top_p, top_k)
else:
next_token = torch.argmax(final_logits, dim=-1, keepdim=True)
current_ids = torch.cat([current_ids, next_token], dim=-1)
remaining_tokens -= 1
# If all draft tokens were accepted, we can sample one more from the last position
else:
final_logits = verify_outputs.logits[:, -1, :]
if do_sample:
if temperature != 1.0:
final_logits = final_logits / temperature
next_token = self._sample_token(final_logits, top_p, top_k)
else:
next_token = torch.argmax(final_logits, dim=-1, keepdim=True)
current_ids = torch.cat([current_ids, next_token], dim=-1)
remaining_tokens -= 1
if eos_token_id is not None and (current_ids[:, -1] == eos_token_id).any():
break
return current_ids
def _sample_token(self, logits: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
"""Sample token using top-p and top-k filtering."""
# Top-k filtering
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float('-inf')
# Top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = float('-inf')
# Sample from the filtered distribution
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
return next_token
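# ------------------------------------------------------------------------------
# Usage sketch (only runs when this file is executed directly). The repo id below is a
# hypothetical placeholder; substitute the Hub repository that ships this file via `auto_map`.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "path/to/kormo-mtp-checkpoint"  # hypothetical placeholder, not a real repo
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    inputs = tokenizer("Hello, KORMo!", return_tensors="pt").to(model.device)
    # speculative_decode=True uses the MTP head as a draft model when mtp_depth > 0
    output_ids = model.generate(
        inputs["input_ids"],
        max_new_tokens=32,
        do_sample=False,
        speculative_decode=True,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))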