from typing import Any, Callable, Dict, List, Optional, Set, Union

import torch

import vllm.model_executor.layers.fused_moe  # noqa: F401 (op registration)
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, check_moe_marlin_supports_layer,
    marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
    verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

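    # Note on the scalar types above (defined in vllm.scalar_type):
    # uint4b8 is a 4-bit unsigned integer stored with a bias of 8, and
    # uint8b128 is an 8-bit unsigned integer with a bias of 128. This is
    # how symmetric GPTQ checkpoints represent signed weights in unsigned
    # storage; only symmetric (sym=True) checkpoints are supported here.
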
    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool,
                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
                 full_config: Dict[str, Any]) -> None:
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order ==
            # False (since we have only one group per output channel).
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization configs, so each module can be individually
        # optimized. The format is Dict[str, Dict[str, Union[int, bool]]]:
        # each key is a regex matched against the module name, and each
        # value maps config fields (e.g. "bits", "group_size") to override
        # values. A "+:"-prefixed (or unprefixed) key is a positive match
        # that overrides the base config for that module; a "-:"-prefixed
        # key is a negative match that skips quantization for the module
        # entirely. More details and examples can be found at:
        # https://github.com/ModelCloud/GPTQModel
        #
        # Example:
        # dynamic = {
        #     # positive match: use 8 bits for layers 10-15
        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8},
        #     # negative match: skip quantization for all `moe` layers
        #     r"-:.*\.moe\..*": {},
        # }
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized}, "
                f"dynamic={self.dynamic})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized, dynamic, config)

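    # For reference, a minimal `quantize_config.json` that `from_config`
    # can parse might look like the following (values are illustrative):
    #
    #   {
    #       "bits": 4,
    #       "group_size": 128,
    #       "desc_act": true,
    #       "sym": true,
    #       "lm_head": false
    #   }
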
    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                               or user_quant == "gptq_marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
            return None

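    # A minimal usage sketch (model name illustrative): for a compatible
    # GPTQ checkpoint, `override_quantization_method` upgrades plain
    # "gptq" to this backend automatically, and it can also be requested
    # explicitly:
    #
    #   from vllm import LLM
    #   llm = LLM(model="TheBloke/Llama-2-7B-GPTQ",
    #             quantization="gptq_marlin")
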
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, FusedMoE):
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning(
                    f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                return MoeWNA16Config.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(self, layer, prefix,
                                       GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        if not current_platform.is_cuda():
            return False

        if quant_method != "gptq":
            return False

        # If we cannot find the info we need in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                      group_size=group_size)


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    # Kernel backends already reported, so each one is only logged once.
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(quant_type=self.quant_config.quant_type,
                                group_size=self.quant_config.group_size)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQMarlinLinearMethod",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                             self.quant_config.group_size,
                                             is_row_parallel):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Activation order
        g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        qzeros_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="qweight",
                                  w_s_param_name="scales",
                                  w_zp_param_name="qzeros",
                                  w_gidx_param_name="g_idx")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)


class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization."""

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full")

        # `is_k_full` tells the Marlin kernel whether each rank holds the
        # full k (input) dim of w2; with act_order (desc_act) this is only
        # true when the intermediate dim is not sharded across ranks.
        self.is_k_full = (not self.quant_config.desc_act) or (
            intermediate_size_per_partition == intermediate_size_full)

        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
            w2_scales_size = (intermediate_size_full
                              if self.quant_config.desc_act else
                              intermediate_size_per_partition)
            scales_size2 = (w2_scales_size // self.quant_config.group_size)
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": True
        })
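        # GPTQ checkpoints store these tensors transposed relative to the
        # unquantized layout (hence "is_transposed" above): qweight is
        # [k // pack_factor, n] packed int32 and scales/qzeros are
        # [k // group_size, ...], with a leading num_experts dim here.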
        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.quant_config.pack_factor,
                2 * intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition //
                self.quant_config.pack_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        # up_proj scales
        w13_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)
        # down_proj scales
        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)
        # don't shard the w2 scales when running act order
        set_weight_attrs(w2_scales,
                         {"load_full_w2": self.quant_config.desc_act})
        # up_proj zero points
        w13_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj zero points
        w2_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size // self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # don't shard the w2 zero points when running act order
        set_weight_attrs(w2_qzeros,
                         {"load_full_w2": self.quant_config.desc_act})
        # Activation-order group indices and their sort indices
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)
        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        # Marlin kernels use this workspace for global synchronization
        # across thread blocks, sized by the device's SM count.
        device = layer.w13_qweight.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        layer.workspace = torch.zeros((sms * 4, ),
                                      dtype=torch.int,
                                      device=device,
                                      requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            num_experts = layer.w13_g_idx.shape[0]
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(
                    layer.w13_g_idx[e]).to(torch.int32)
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
                    torch.int32)
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][
                    w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                    w2_g_idx_sort_indices[e]]
            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices",
                              w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices",
                              w2_g_idx_sort_indices)
        else:
            # Reset g_idx-related tensors to empty placeholders
            num_experts = layer.w13_g_idx.shape[0]
            device = layer.w13_g_idx.device
            layer.w13_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
        # Repack weights into the Marlin format
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Permute scales to match the repacked weights
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] *
            (self.quant_config.group_size if self.quant_config.group_size != -1
             else self.quant_config.pack_factor),
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

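    # Routing happens in FusedMoE.select_experts below; the fused Marlin
    # MoE kernel then performs both expert GEMMs on the selected experts.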
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused Marlin MoE method.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.quant_type.size_bits,
            workspace=layer.workspace,
            is_k_full=self.is_k_full)