from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class ReLUState:
    """Per-layer state holding the zeros tensor used as the ReLU threshold."""

    zeros: torch.Tensor


class StatefulReLU(nn.Module):
    # Flags advertising that this layer works under torch.compile and
    # provides a backward pass.
    can_torch_compile = True
    has_backward = True

    # Expected to be set on the layer instance before state creation.
    hidden_size: int

    @staticmethod
    def create_state(device: torch.device, layer: nn.Module) -> ReLUState:
        # Allocate a zeros tensor sized to the layer's hidden dimension on
        # the requested device.
        zeros = torch.zeros(layer.hidden_size, device=device)
        return ReLUState(zeros=zeros)

    def forward_with_state(self, state: ReLUState, input: torch.Tensor) -> torch.Tensor:
        # Elementwise maximum against the zeros tensor, i.e. ReLU(input).
        return torch.maximum(input, state.zeros)


__all__ = ["StatefulReLU"]
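

# Minimal usage sketch (assumption: the caller sets `hidden_size` on the
# instance before building state, since the class only declares the
# attribute). The state-passing convention mirrors the methods above:
# `create_state` allocates the per-device zeros tensor, and
# `forward_with_state` consumes it alongside the input.
if __name__ == "__main__":
    layer = StatefulReLU()
    layer.hidden_size = 8  # hypothetical size chosen for illustration

    state = StatefulReLU.create_state(torch.device("cpu"), layer)
    x = torch.randn(2, 8)
    y = layer.forward_with_state(state, x)

    # The result matches a plain elementwise ReLU.
    assert torch.equal(y, torch.relu(x))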