| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | import torch |
| | import triton |
| | import triton.language as tl |
| |
|
| |
|
| | @triton.jit |
| | def _fwd_kernel( |
| | Q, |
| | K, |
| | V, |
| | Out, |
| | S, |
| | stride_qz, |
| | stride_qh, |
| | stride_qm, |
| | stride_qk, |
| | stride_kz, |
| | stride_kh, |
| | stride_kn, |
| | stride_kk, |
| | stride_vz, |
| | stride_vh, |
| | stride_vn, |
| | stride_ve, |
| | stride_oz, |
| | stride_oh, |
| | stride_om, |
| | stride_oe, |
| | stride_sh, |
| | Z, |
| | H, |
| | N_CTX, |
| | BLOCK_M: tl.constexpr, |
| | BLOCK_DMODEL_QK: tl.constexpr, |
| | BLOCK_N: tl.constexpr, |
| | BLOCK_DMODEL_V: tl.constexpr, |
| | IS_CAUSAL: tl.constexpr, |
| | USE_DECAY: tl.constexpr, |
| | ): |
| | start_m = tl.program_id(0) |
| | off_hz = tl.program_id(1) |
| | off_h = off_hz % H |
| | |
| | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) |
| | offs_n = tl.arange(0, BLOCK_N) |
| | offs_k = tl.arange(0, BLOCK_DMODEL_QK) |
| | offs_e = tl.arange(0, BLOCK_DMODEL_V) |
| | |
| | off_q = (off_hz * stride_qh + offs_m[:, None] * stride_qm + |
| | offs_k[None, :] * stride_qk) |
| | off_k = (off_hz * stride_kh + offs_n[:, None] * stride_kn + |
| | offs_k[None, :] * stride_kk) |
| | off_v = (off_hz * stride_vh + offs_n[:, None] * stride_vn + |
| | offs_e[None, :] * stride_ve) |
| | off_o = (off_hz * stride_oh + offs_m[:, None] * stride_om + |
| | offs_e[None, :] * stride_oe) |
| |
|
| | |
| | q_ptrs = Q + off_q |
| | k_ptrs = K + off_k |
| | v_ptrs = V + off_v |
| |
|
| | |
| | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32) |
| | |
| | q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0) |
| | |
| | lo = 0 |
| | |
| | hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX |
| | for start_n in range(lo, hi, BLOCK_N): |
| | |
| | k = tl.load( |
| | k_ptrs + start_n * stride_kn, |
| | mask=(start_n + offs_n)[:, None] < N_CTX, |
| | other=0.0, |
| | ) |
| | v = tl.load( |
| | v_ptrs + start_n * stride_vn, |
| | mask=(start_n + offs_n)[:, None] < N_CTX, |
| | other=0.0, |
| | ) |
| | |
| | |
| | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) |
| | |
| | qk += tl.dot(q, tl.trans(k)) |
| | if IS_CAUSAL: |
| | index = offs_m[:, None] - (start_n + offs_n[None, :]) |
| | if USE_DECAY: |
| | S_block_ptr = S + off_h * stride_sh |
| | s = tl.load(S_block_ptr) |
| | s_index = s * index |
| | s_index = tl.where(s_index >= 0, -s_index, float("-inf")) |
| | qk = tl.exp(s_index) * qk |
| | else: |
| | qk = tl.where(index >= 0, qk, 0) |
| | acc += tl.dot(qk, v.to(qk.dtype)) |
| |
|
| | out_ptrs = Out + off_o |
| | tl.store(out_ptrs, acc.to(q.dtype), mask=offs_m[:, None] < N_CTX) |
| |
|
| |
|
| | @triton.jit |
| | def _bwd_kernel_kv( |
| | Q, |
| | K, |
| | V, |
| | S, |
| | DO, |
| | DQ, |
| | DK, |
| | DV, |
| | stride_qz, |
| | stride_qh, |
| | stride_qm, |
| | stride_qk, |
| | stride_kz, |
| | stride_kh, |
| | stride_kn, |
| | stride_kk, |
| | stride_vz, |
| | stride_vh, |
| | stride_vn, |
| | stride_ve, |
| | stride_oz, |
| | stride_oh, |
| | stride_om, |
| | stride_oe, |
| | stride_sh, |
| | Z, |
| | H, |
| | N_CTX, |
| | num_block, |
| | BLOCK_M: tl.constexpr, |
| | BLOCK_DMODEL_QK: tl.constexpr, |
| | BLOCK_N: tl.constexpr, |
| | BLOCK_DMODEL_V: tl.constexpr, |
| | CAUSAL: tl.constexpr, |
| | USE_DECAY: tl.constexpr, |
| | ): |
| | start_n = tl.program_id(0) |
| | off_hz = tl.program_id(1) |
| |
|
| | off_z = off_hz // H |
| | off_h = off_hz % H |
| | |
| | Q += off_z * stride_qz + off_h * stride_qh |
| | K += off_z * stride_kz + off_h * stride_kh |
| | V += off_z * stride_vz + off_h * stride_vh |
| | DO += off_z * stride_oz + off_h * stride_oh |
| | DQ += off_z * stride_qz + off_h * stride_qh |
| | DK += off_z * stride_kz + off_h * stride_kh |
| | DV += off_z * stride_vz + off_h * stride_vh |
| |
|
| | |
| | if CAUSAL: |
| | lo = start_n * BLOCK_M |
| | else: |
| | lo = 0 |
| | |
| | |
| | offs_qm = lo + tl.arange(0, BLOCK_M) |
| | offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) |
| | |
| | offs_qkk = tl.arange(0, BLOCK_DMODEL_QK) |
| | offs_ve = tl.arange(0, BLOCK_DMODEL_V) |
| | |
| | offs_m = tl.arange(0, BLOCK_M) |
| | |
| | q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_qkk[None, :] * stride_qk) |
| | k_ptrs = K + (offs_kvn[:, None] * stride_kn + |
| | offs_qkk[None, :] * stride_kk) |
| | v_ptrs = V + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve) |
| | do_ptrs = DO + (offs_qm[:, None] * stride_om + |
| | offs_ve[None, :] * stride_oe) |
| | dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + |
| | offs_qkk[None, :] * stride_qk) |
| | |
| | dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32) |
| | dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32) |
| | |
| | k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0) |
| | v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0) |
| | |
| | for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): |
| | offs_m_curr = start_m + offs_m |
| | |
| | q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0) |
| | qk = tl.dot(q, tl.trans(k)) |
| | |
| | if CAUSAL: |
| | index = offs_m_curr[:, None] - offs_kvn[None, :] |
| | if USE_DECAY: |
| | S_block_ptr = S + off_h * stride_sh |
| | s = tl.load(S_block_ptr) |
| | s_index = s * index |
| | s_index = tl.where(s_index >= 0, -s_index, float("-inf")) |
| | s = tl.exp(s_index) |
| | qk = qk * s |
| | else: |
| | qk = tl.where(index >= 0, qk, 0) |
| |
|
| | p = qk |
| | |
| | do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0) |
| | dv += tl.dot(tl.trans(p.to(do.dtype)), do) |
| | dp = tl.dot(do, tl.trans(v).to(do.dtype)) |
| | if CAUSAL: |
| | if USE_DECAY: |
| | dp = dp * s |
| | else: |
| | dp = tl.where(index >= 0, dp, 0) |
| |
|
| | dk += tl.dot(tl.trans(dp.to(q.dtype)), q).to(tl.float32) |
| |
|
| | |
| | q_ptrs += BLOCK_M * stride_qm |
| | do_ptrs += BLOCK_M * stride_om |
| | |
| | dv_ptrs = DV + (offs_kvn[:, None] * stride_vn + |
| | offs_ve[None, :] * stride_ve) |
| | dk_ptrs = DK + (offs_kvn[:, None] * stride_kn + |
| | offs_qkk[None, :] * stride_kk) |
| | tl.store(dv_ptrs, dv, mask=offs_kvn[:, None] < N_CTX) |
| | tl.store(dk_ptrs, dk, mask=offs_kvn[:, None] < N_CTX) |
| |
|
| |
|
| | @triton.jit |
| | def _bwd_kernel_q( |
| | Q, |
| | K, |
| | V, |
| | S, |
| | DO, |
| | DQ, |
| | DK, |
| | DV, |
| | stride_qz, |
| | stride_qh, |
| | stride_qm, |
| | stride_qk, |
| | stride_kz, |
| | stride_kh, |
| | stride_kn, |
| | stride_kk, |
| | stride_vz, |
| | stride_vh, |
| | stride_vn, |
| | stride_ve, |
| | stride_oz, |
| | stride_oh, |
| | stride_om, |
| | stride_oe, |
| | stride_sh, |
| | Z, |
| | H, |
| | N_CTX, |
| | num_block, |
| | BLOCK_M: tl.constexpr, |
| | BLOCK_DMODEL_QK: tl.constexpr, |
| | BLOCK_N: tl.constexpr, |
| | BLOCK_DMODEL_V: tl.constexpr, |
| | CAUSAL: tl.constexpr, |
| | USE_DECAY: tl.constexpr, |
| | ): |
| | start_m = tl.program_id(0) |
| | off_hz = tl.program_id(1) |
| | off_z = off_hz // H |
| | off_h = off_hz % H |
| | |
| | K += off_z * stride_kz + off_h * stride_kh |
| | V += off_z * stride_vz + off_h * stride_vh |
| | DO += off_z * stride_oz + off_h * stride_oh |
| | DQ += off_z * stride_qz + off_h * stride_qh |
| | |
| | offs_qkk = tl.arange(0, BLOCK_DMODEL_QK) |
| | offs_ve = tl.arange(0, BLOCK_DMODEL_V) |
| | |
| | offs_m = tl.arange(0, BLOCK_M) |
| | |
| | offs_qm = start_m * BLOCK_M + tl.arange(0, BLOCK_M) |
| | |
| | do_ptrs = DO + (offs_qm[:, None] * stride_om + |
| | offs_ve[None, :] * stride_oe) |
| | dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + |
| | offs_qkk[None, :] * stride_qk) |
| |
|
| | do = tl.load(do_ptrs, mask=offs_qm[:, None] < N_CTX, other=0.0) |
| |
|
| | dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32) |
| | lo = 0 |
| | hi = (start_m + 1) * BLOCK_M if CAUSAL else N_CTX |
| |
|
| | offs_m_curr = start_m * BLOCK_M + offs_m |
| |
|
| | for start_n in range(0, num_block): |
| | offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) |
| | k_ptrs = K + (offs_kvn[:, None] * stride_kn + |
| | offs_qkk[None, :] * stride_kk) |
| | v_ptrs = V + (offs_kvn[:, None] * stride_vn + |
| | offs_ve[None, :] * stride_ve) |
| | |
| | k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0) |
| | v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0) |
| | |
| | dp = tl.dot(do, tl.trans(v).to(do.dtype)) |
| | if CAUSAL: |
| | index = offs_m_curr[:, None] - offs_kvn[None, :] |
| | if USE_DECAY: |
| | S_block_ptr = S + off_h * stride_sh |
| | s = tl.load(S_block_ptr) |
| | s_index = s * index |
| | s_index = tl.where(s_index >= 0, -s_index, float("-inf")) |
| | s = tl.exp(s_index) |
| | dp = dp * s |
| | else: |
| | dp = tl.where(index >= 0, dp, 0) |
| | |
| | dq += tl.dot(dp.to(k.dtype), k) |
| |
|
| | tl.store(dq_ptrs, dq, mask=offs_qm[:, None] < N_CTX) |
| |
|
| |
|
| | class _attention(torch.autograd.Function): |
| |
|
| | @staticmethod |
| | def forward(ctx, q, k, v, causal, s): |
| | q = q.contiguous() |
| | k = k.contiguous() |
| | v = v.contiguous() |
| | s = s.contiguous() |
| | |
| | capability = torch.cuda.get_device_capability() |
| | if capability[0] < 8: |
| | raise RuntimeError( |
| | "Lightning attention currently only supported for compute capability >= 80" |
| | ) |
| | |
| | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] |
| | |
| | o = torch.empty( |
| | (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]), |
| | dtype=q.dtype, |
| | device=q.device, |
| | ) |
| |
|
| | BLOCK_M = 128 |
| | BLOCK_N = 64 |
| | num_warps = 4 if Lk <= 64 else 8 |
| | num_stages = 1 |
| |
|
| | grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) |
| | use_decay = s.shape[0] > 0 |
| | _fwd_kernel[grid]( |
| | q, |
| | k, |
| | v, |
| | o, |
| | s, |
| | q.stride(0), |
| | q.stride(1), |
| | q.stride(2), |
| | q.stride(3), |
| | k.stride(0), |
| | k.stride(1), |
| | k.stride(2), |
| | k.stride(3), |
| | v.stride(0), |
| | v.stride(1), |
| | v.stride(2), |
| | v.stride(3), |
| | o.stride(0), |
| | o.stride(1), |
| | o.stride(2), |
| | o.stride(3), |
| | s.stride(0), |
| | q.shape[0], |
| | q.shape[1], |
| | q.shape[2], |
| | BLOCK_M=BLOCK_M, |
| | BLOCK_DMODEL_QK=Lk, |
| | BLOCK_N=BLOCK_N, |
| | BLOCK_DMODEL_V=Lv, |
| | IS_CAUSAL=causal, |
| | USE_DECAY=use_decay, |
| | num_warps=num_warps, |
| | num_stages=num_stages, |
| | ) |
| |
|
| | ctx.save_for_backward(q, k, v, s) |
| | ctx.grid = grid |
| | ctx.BLOCK_M = BLOCK_M |
| | ctx.BLOCK_DMODEL_QK = Lk |
| | ctx.BLOCK_N = BLOCK_N |
| | ctx.BLOCK_DMODEL_V = Lv |
| | ctx.causal = causal |
| | ctx.use_decay = use_decay |
| | return o |
| |
|
| | @staticmethod |
| | def backward(ctx, do): |
| | q, k, v, s = ctx.saved_tensors |
| | BLOCK_M = 32 |
| | BLOCK_N = 32 |
| | num_warps = 4 |
| | num_stages = 1 |
| |
|
| | do = do.contiguous() |
| | dq = torch.zeros_like(q, dtype=torch.float32) |
| | dk = torch.empty_like(k) |
| | dv = torch.empty_like(v) |
| |
|
| | grid_kv = (triton.cdiv(k.shape[2], |
| | BLOCK_N), k.shape[0] * k.shape[1], 1) |
| | _bwd_kernel_kv[grid_kv]( |
| | q, |
| | k, |
| | v, |
| | s, |
| | do, |
| | dq, |
| | dk, |
| | dv, |
| | q.stride(0), |
| | q.stride(1), |
| | q.stride(2), |
| | q.stride(3), |
| | k.stride(0), |
| | k.stride(1), |
| | k.stride(2), |
| | k.stride(3), |
| | v.stride(0), |
| | v.stride(1), |
| | v.stride(2), |
| | v.stride(3), |
| | do.stride(0), |
| | do.stride(1), |
| | do.stride(2), |
| | do.stride(3), |
| | s.stride(0), |
| | q.shape[0], |
| | q.shape[1], |
| | q.shape[2], |
| | grid_kv[0], |
| | BLOCK_M=BLOCK_M, |
| | BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK, |
| | BLOCK_N=BLOCK_N, |
| | BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V, |
| | CAUSAL=ctx.causal, |
| | USE_DECAY=ctx.use_decay, |
| | num_warps=num_warps, |
| | num_stages=num_stages, |
| | ) |
| |
|
| | grid_q = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) |
| |
|
| | _bwd_kernel_q[grid_q]( |
| | q, |
| | k, |
| | v, |
| | s, |
| | do, |
| | dq, |
| | dk, |
| | dv, |
| | q.stride(0), |
| | q.stride(1), |
| | q.stride(2), |
| | q.stride(3), |
| | k.stride(0), |
| | k.stride(1), |
| | k.stride(2), |
| | k.stride(3), |
| | v.stride(0), |
| | v.stride(1), |
| | v.stride(2), |
| | v.stride(3), |
| | do.stride(0), |
| | do.stride(1), |
| | do.stride(2), |
| | do.stride(3), |
| | s.stride(0), |
| | q.shape[0], |
| | q.shape[1], |
| | q.shape[2], |
| | grid_q[0], |
| | BLOCK_M=BLOCK_M, |
| | BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK, |
| | BLOCK_N=BLOCK_N, |
| | BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V, |
| | CAUSAL=ctx.causal, |
| | USE_DECAY=ctx.use_decay, |
| | num_warps=num_warps, |
| | num_stages=num_stages, |
| | ) |
| |
|
| | return dq.to(q.dtype), dk, dv, None, None |
| |
|
| |
|
| | attention = _attention.apply |
| |
|
| |
|
| | def lightning_attention(q, k, v, causal, ed): |
| | d = q.shape[-1] |
| | e = v.shape[-1] |
| | |
| | if d >= 128: |
| | m = 128 |
| | else: |
| | m = 64 |
| | arr = [m * i for i in range(d // m + 1)] |
| | if arr[-1] != d: |
| | arr.append(d) |
| | n = len(arr) |
| | output = 0 |
| | for i in range(n - 1): |
| | s = arr[i] |
| | e = arr[i + 1] |
| | q1 = q[..., s:e] |
| | k1 = k[..., s:e] |
| | o = attention(q1, k1, v, causal, ed) |
| | output = output + o |
| |
|
| | return output |
| |
|