Kernels · sae

elephantmipt committed (verified)
Commit a262a48 · 1 Parent(s): 25069e6

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,3 @@
+ *.DS_Store
+ *__pycache__
+ *__MACOSX
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright 2025 T-Tech
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
NOTICE ADDED
@@ -0,0 +1,16 @@
+ Flex SAE Kernels
+ Copyright (c) 2025 T-Tech
+
+ This project is distributed under the Apache License, Version 2.0. The following
+ third-party component is redistributed under its original terms:
+
+ - Portions of `torch-ext/flex_sae/topk_kernels.py` are adapted from the Facebook
+   Research project "memory" (https://github.com/facebookresearch/memory).
+   That source is licensed under the Creative Commons Attribution-NonCommercial
+   4.0 International License (CC BY-NC 4.0). Any use of the adapted code must
+   comply with the non-commercial requirements described at
+   https://creativecommons.org/licenses/by-nc/4.0/.
+
+ Where the Apache 2.0 license and CC BY-NC 4.0 differ, the more restrictive
+ requirements apply to the adapted code. All other files are provided under the
+ Apache License, Version 2.0.
README.md CHANGED
@@ -1,3 +1,196 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ tags:
+ - kernel
+ - sae
+ ---
+ # Flex SAE Kernels
+
+ [![ArXiv](https://img.shields.io/badge/arXiv-2505.24473-b31b1b.svg)](https://arxiv.org/abs/2505.24473)
+
+ Fused Triton implementations of the TopK and HierarchicalTopK sparse autoencoder (SAE) decoder losses described in *Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy*.
+
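+ Concretely, writing $b$ for the decoder bias, $w_i$ for decoder rows, $v$ for the selected activations, and $x$ for the target (the notation of the reference implementations below), the two losses are
+
+ $$\mathcal{L}_{\text{TopK}} = \frac{1}{BD}\sum_{n=1}^{B}\Big\lVert b + \sum_{k=1}^{K} v_{n,k}\, w_{i_{n,k}} - x_n \Big\rVert_2^2, \qquad \mathcal{L}_{\text{HierTopK}} = \frac{1}{BKD}\sum_{n=1}^{B}\sum_{t=1}^{K}\Big\lVert b + \sum_{k=1}^{t} v_{n,k}\, w_{i_{n,k}} - x_n \Big\rVert_2^2.$$
+
+ The hierarchical loss penalizes every prefix of the Top-$K$ reconstruction, so one decoder stays accurate across all sparsity budgets $t \le K$.
+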
+ **This work has been accepted to [EMNLP 2025](https://2025.emnlp.org/).**
+
+ ## What is released?
+
+ - A fast TopK kernel for SAEs (a slightly modified version of the xformers kernel): `torch-ext/flex_sae/topk_kernels.py`
+ - Fast HierarchicalTopK kernels (see our [paper](https://arxiv.org/abs/2505.24473)): `torch-ext/flex_sae/hierarchical_kernels.py`
+
+
+ ## Quickstart
+
+ The kernels can be loaded from the Hub and share the following signature:
+ ```python
+ import torch
+ from kernels import get_kernel
+
+ flex = get_kernel('t-tech/flex-sae')
+
+ top_k_kernel = flex.triton_topk_sae_loss
+ hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss
+
+ # B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim
+
+ loss: torch.Tensor = top_k_kernel(
+     indices,  # [B, K]
+     weight,   # [F, D]
+     vals,     # [B, K]
+     bias,     # [D]
+     target,   # [B, D]
+ )
+
+ loss: torch.Tensor = hierarchical_top_k_kernel(
+     indices,  # [B, K]
+     weight,   # [F, D]
+     vals,     # [B, K]
+     bias,     # [D]
+     target,   # [B, D]
+ )
+ ```
+
+ ## Overview
+ - `torch-ext/flex_sae/` contains the Triton kernels alongside torch reference implementations.
+ - `tests/` hosts CUDA-backed property tests that ensure numerical parity across dtypes and kernels.
+ - `build.toml` and `flake.nix` integrate the project with the [Hugging Face kernel-builder](https://github.com/huggingface/kernel-builder).
+
+ The Triton kernels target CUDA GPUs and focus on reducing the latency gap between TopK and HierarchicalTopK decoders while keeping memory usage flat.
+
+ ## Example
+
+ You can find example usage in [example.py](https://huggingface.co/t-tech/flex-sae/blob/main/example.py).
+ ```python
+ # /// script
+ # dependencies = [
+ #   "torch",
+ #   "numpy",
+ #   "kernels",
+ # ]
+ # ///
+
+ import torch
+ import numpy as np
+ from kernels import get_kernel
+
+ flex = get_kernel("t-tech/flex-sae")  # fast kernels from the Hub
+
+ @torch.compile(fullgraph=True)
+ def hierarchical_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> torch.Tensor:
+     emb = weight[indices].to(torch.float32)  # [B, K, D]
+     recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
+     diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
+     loss = diff.pow(2).mean()
+     return loss
+
+
+ B = 2048
+ K = 256
+ F = 1024 * 128
+ D = 1024
+ warmup = 5
+ dtype = torch.float32
+
+ vals = None
+ decoder = None
+ bias = None
+ target = None
+ indices = None
+
+
+ def init_parameters():
+     global vals, decoder, bias, target, indices
+     vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
+     decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
+     bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
+     target = torch.randn(B, D, dtype=dtype, device="cuda")
+     indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
+
+
+ timing_kernel = []
+ timing_vanilla = []
+ torch.cuda.reset_peak_memory_stats()
+ loss_kernel_list = torch.zeros((100,))
+ loss_vanilla_list = torch.zeros((100,))
+
+
+ def zero_grad():
+     vals.grad = None
+     decoder.grad = None
+     bias.grad = None
+     torch.cuda.empty_cache()
+
+
+ for i in range(100 + warmup):
+     init_parameters()
+     start_kernel = torch.cuda.Event(enable_timing=True)
+     end_kernel = torch.cuda.Event(enable_timing=True)
+     start_vanilla = torch.cuda.Event(enable_timing=True)
+     end_vanilla = torch.cuda.Event(enable_timing=True)
+
+     start_kernel.record()
+     loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
+     loss_kernel.backward()
+     end_kernel.record()
+
+     zero_grad()
+     start_vanilla.record()
+     loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
+     loss_vanilla.backward()
+     end_vanilla.record()
+     if i >= warmup:
+         torch.cuda.synchronize()
+         timing_kernel.append(start_kernel.elapsed_time(end_kernel))
+         timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
+         loss_kernel_list[i - warmup] = loss_kernel.detach()
+         loss_vanilla_list[i - warmup] = loss_vanilla.detach()
+     zero_grad()
+
+ if torch.allclose(loss_kernel, loss_vanilla):
+     print("✅ Outputs are close! Everything is good! 🎉")
+ else:
+     print("❌ Outputs mismatch... ⚠️🤔")
+
+
+ print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
+ print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
+ print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
+ ```
+
+ Run it with `uv run https://huggingface.co/t-tech/flex-sae/resolve/main/example.py`.
+
+ ## Performance
+ Benchmarks were collected on a workload with dictionary size $F = 65536$, embedding dimension $D = 2304$, and sparsity budgets $K \in \{32, 64, 128\}$. Latency is reported as time per training step (milliseconds) and memory as peak device usage (GiB).
+
+ | Decoder backend | K=32 (ms / GiB) | K=64 (ms / GiB) | K=128 (ms / GiB) |
+ | --- | --- | --- | --- |
+ | **Pure torch-compiled** | | | |
+ | TopK | 8.787 / 2.92 | 11.746 / 2.92 | 18.877 / 2.93 |
+ | HierarchicalTopK | 12.824 / 6.29 | 23.379 / 10.79 | 43.851 / 19.80 |
+ | **Triton kernels** | | | |
+ | TopK | 5.576 / 2.92 | 6.339 / 2.92 | 7.961 / 2.93 |
+ | HierarchicalTopK | **6.696 / 2.92** | **7.995 / 2.92** | **10.609 / 2.93** |
+
+ Across the evaluated sparsity budgets, the fused Triton HierarchicalTopK kernel matches the TopK kernels on memory use while remaining consistently faster than the reference torch implementation. The torch baseline's growing footprint comes from materializing the full `[B, K, D]` tensor of cumulative reconstructions, which the fused kernel never forms.
+
+ ## License & Attribution
+ - All files except `torch-ext/flex_sae/topk_kernels.py` are released under the [Apache License 2.0](LICENSE).
+ - `torch-ext/flex_sae/topk_kernels.py` includes code adapted from Facebook Research's [memory](https://github.com/facebookresearch/memory) project, originally published under the Creative Commons Attribution-NonCommercial 4.0 International License. That component therefore remains available for non-commercial use only; see [NOTICE](NOTICE) for details.
+
+ ## Citation
+ ```bibtex
+ @misc{balagansky2025trainsparseautoencodermultiple,
+       title={Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy},
+       author={Nikita Balagansky and Yaroslav Aksenov and Daniil Laptev and Vadim Kurochkin and Gleb Gerasimov and Nikita Koryagin and Daniil Gavrilov},
+       year={2025},
+       eprint={2505.24473},
+       archivePrefix={arXiv},
+       primaryClass={cs.LG},
+       url={https://arxiv.org/abs/2505.24473},
+ }
+ ```
build.toml ADDED
@@ -0,0 +1,3 @@
+ [general]
+ name = "flex_sae"
+ universal = true
build/torch-universal/flex_sae/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # TopK and HierarchicalTopK SAE decoder Triton kernels
+ # Copyright 2025 T-Tech
+
+
+ from .topk_kernels import triton_topk_sae_loss, topk_sae_loss
+ from .hierarchical_kernels import triton_hierarchical_sae_loss, hierarchical_sae_loss
+
+ __kernel_metadata__ = {
+     "license": "Apache-2.0 (with CC-BY-NC-4.0 component; see NOTICE)",
+ }
+
+ __all__ = [
+     "__kernel_metadata__",
+     "topk_sae_loss",
+     "triton_topk_sae_loss",
+     "hierarchical_sae_loss",
+     "triton_hierarchical_sae_loss",
+ ]
build/torch-universal/flex_sae/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._flex_sae_20250924130857
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flex_sae_20250924130857::{op_name}"
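A minimal sketch of what this helper produces, assuming the package is importable as `flex_sae` (the op name `topk_sae_loss` here is hypothetical; the module only guarantees the namespace string):

```python
from flex_sae._ops import add_op_namespace_prefix

# fully qualified name for torch.ops lookup under the versioned namespace
qualified = add_op_namespace_prefix("topk_sae_loss")
assert qualified == "_flex_sae_20250924130857::topk_sae_loss"
```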
build/torch-universal/flex_sae/hierarchical_kernels.py ADDED
@@ -0,0 +1,323 @@
+ # HierarchicalTopK SAE decoder Triton kernels
+ # Copyright 2025 T-Tech
+
+
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def hierarchical_sae_forward_kernel(
+     loss_per_batch_ptr,  # [B]
+     final_recon_ptr,     # [B, D]
+     indices_ptr,         # [B, K]
+     weight_ptr,          # [F, D]
+     bias_ptr,            # [D]
+     vals_ptr,            # [B, K]
+     target_ptr,          # [B, D]
+     B: tl.constexpr,
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+     LOOP_NUM_STAGES: tl.constexpr,
+     BLOCK_B: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((B % BLOCK_B) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+     tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
+
+     pid_b = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
+     batch_offsets = batch_offsets.to(tl.int64)
+     tl.multiple_of(batch_offsets, BLOCK_B)
+
+     offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+     offset_d = offset_d.to(tl.int64)
+
+     tl.multiple_of(offset_d, BLOCK_D)
+     tl.max_contiguous(offset_d, BLOCK_D)
+
+     batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
+
+     bias_tile = tl.load(bias_ptr + offset_d).to(tl.float32)
+
+     recon = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+     recon += bias_tile[None, :]
+
+     target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
+
+     loss_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+
+     row_idx_ptr = indices_ptr + batch_offsets * K
+     row_val_ptr = vals_ptr + batch_offsets * K
+
+     idx = tl.load(row_idx_ptr).to(tl.int64)
+     val = tl.load(row_val_ptr).to(tl.float32)
+     val = val[:, None]
+     weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+     for t in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
+         recon += weight_tile * val
+         diff = recon - target
+         loss_accum += diff * diff
+
+         # software pipelining: prefetch the next decoder row while this one is used
+         if t + 1 < K:
+             idx_next = tl.load(row_idx_ptr + (t + 1)).to(tl.int64)
+             val_next = tl.load(row_val_ptr + (t + 1)).to(tl.float32)
+             weight_next = tl.load(weight_ptr + idx_next[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+             idx = idx_next
+             val = val_next[:, None]
+             weight_tile = weight_next
+
+     loss_tile = tl.sum(loss_accum, axis=1)
+     tl.atomic_add(
+         loss_per_batch_ptr + batch_offsets,
+         loss_tile,
+         sem="relaxed",
+     )
+     tl.store(
+         final_recon_ptr + batch_d_offset,
+         recon,
+     )
+
+
+ @triton.jit
+ def hierarchical_sae_backward_kernel(
+     weight_grad_ptr,   # [F, D]
+     vals_grad_ptr,     # [B, K]
+     bias_grad_ptr,     # [D]
+     final_recon_ptr,   # [B, D]
+     indices_ptr,       # [B, K]
+     weight_ptr,        # [F, D]
+     vals_ptr,          # [B, K]
+     target_ptr,        # [B, D]
+     B: tl.constexpr,
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+     LOOP_NUM_STAGES: tl.constexpr,
+     BLOCK_B: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((B % BLOCK_B) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+     tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
+
+     pid_b = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
+     batch_offsets = batch_offsets.to(tl.int64)
+     tl.multiple_of(batch_offsets, BLOCK_B)
+
+     offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+     offset_d = offset_d.to(tl.int64)
+
+     tl.multiple_of(offset_d, BLOCK_D)
+     tl.max_contiguous(offset_d, BLOCK_D)
+
+     batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
+
+     recon = tl.load(final_recon_ptr + batch_d_offset).to(tl.float32)
+     target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
+
+     suffix = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+     bias_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+     scale = tl.full((), 2.0 / (B * K * D), dtype=tl.float32)
+
+     row_idx_ptr = indices_ptr + batch_offsets * K
+     row_val_ptr = vals_ptr + batch_offsets * K
+     k_offsets = tl.arange(0, K)
+     val_grad_tile = tl.zeros([BLOCK_B, K], dtype=tl.float32)
+
+     step = K - 1
+     idx = tl.load(row_idx_ptr + step).to(tl.int64)
+     val = tl.load(row_val_ptr + step).to(tl.float32)
+     weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+     # walk k from K - 1 down to 0, maintaining the suffix sum of upstream gradients
+     for _ in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
+         curr_step = step
+
+         diff = recon - target
+         grad_curr = diff * scale
+         suffix += grad_curr
+         bias_accum += grad_curr
+
+         val_broadcast = val[:, None]
+         contrib = suffix * val_broadcast
+         tl.atomic_add(
+             weight_grad_ptr + idx[:, None] * D + offset_d[None, :],
+             contrib,
+             sem="relaxed",
+         )
+
+         dot_partial = tl.sum(weight_tile * suffix, axis=1)
+         mask_curr = k_offsets[None, :] == curr_step
+         val_grad_tile = tl.where(mask_curr, dot_partial[:, None], val_grad_tile)
+
+         # peel the current term off the cumulative reconstruction
+         recon -= weight_tile * val_broadcast
+
+         if curr_step > 0:
+             step = curr_step - 1
+             idx = tl.load(row_idx_ptr + step).to(tl.int64)
+             val = tl.load(row_val_ptr + step).to(tl.float32)
+             weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+     bias_grad_tile = tl.sum(bias_accum, axis=0)
+     tl.atomic_add(
+         bias_grad_ptr + offset_d,
+         bias_grad_tile,
+         sem="relaxed",
+     )
+
+     row_val_grad_ptr = vals_grad_ptr + batch_offsets[:, None] * K + k_offsets[None, :]
+     tl.atomic_add(
+         row_val_grad_ptr,
+         val_grad_tile,
+         sem="relaxed",
+     )
+
+
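+ # For one batch row, with r_t = b + sum_{k<=t} v_k * w_{i_k} the cumulative
+ # reconstruction and g_t = 2/(B*K*D) * (r_t - x) the per-step gradient, the
+ # backward kernel above implements
+ #     dL/dw_{i_k} += v_k * sum_{t>=k} g_t
+ #     dL/dv_k      = <w_{i_k}, sum_{t>=k} g_t>
+ #     dL/db        = sum_t g_t
+ # which is why it walks k from K-1 down to 0: `suffix` accumulates the suffix
+ # sum of g_t, and subtracting v_k * w_{i_k} from `recon` recovers r_{k-1}
+ # without storing the intermediate reconstructions.
+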
+ def _hierarchical_sae_forward(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     B, K = indices.shape
+     F, D = weight.shape
+
+     loss_per_batch = torch.zeros((B,), dtype=torch.float32, device=weight.device)
+     final_recon = torch.empty((B, D), dtype=torch.float32, device=weight.device)
+
+     def _forward_grid(meta):
+         return (
+             B // meta["BLOCK_B"],
+             D // meta["BLOCK_D"],
+         )
+
+     hierarchical_sae_forward_kernel[_forward_grid](
+         loss_per_batch,
+         final_recon,
+         indices,
+         weight,
+         bias,
+         vals,
+         target,
+         B=B,
+         D=D,
+         K=K,
+         BLOCK_D=64,
+         LOOP_NUM_STAGES=4,
+         BLOCK_B=1,
+         num_warps=2,
+         num_stages=2,
+     )
+     # equals diff.pow(2).mean() over the [B, K, D] cumulative-diff tensor
+     loss = loss_per_batch.sum() / (B * K * D)
+     return loss, final_recon
+
+
+ def _hierarchical_sae_backward(
+     indices: torch.Tensor,      # [B, K]
+     weight: torch.Tensor,       # [F, D]
+     vals: torch.Tensor,         # [B, K]
+     target: torch.Tensor,       # [B, D]
+     final_recon: torch.Tensor,  # [B, D]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     device = weight.device
+     B, K = indices.shape
+     F, D = weight.shape
+
+     dW = torch.zeros((F, D), dtype=torch.float32, device=device)
+     dVals = torch.zeros((B, K), dtype=torch.float32, device=device)
+     db = torch.zeros((D,), dtype=torch.float32, device=device)
+
+     def _backward_grid(meta):
+         return (
+             B // meta["BLOCK_B"],
+             D // meta["BLOCK_D"],
+         )
+
+     hierarchical_sae_backward_kernel[_backward_grid](
+         dW,
+         dVals,
+         db,
+         final_recon,
+         indices,
+         weight,
+         vals,
+         target,
+         B=B,
+         D=D,
+         K=K,
+         BLOCK_D=32,
+         LOOP_NUM_STAGES=16,
+         BLOCK_B=16,
+         num_warps=8,
+         num_stages=8,
+     )
+
+     return dW, dVals, db
+
+
+ class HierarchicalSAELossFunction(torch.autograd.Function):
+     @staticmethod
+     @torch.amp.custom_fwd(device_type="cuda")
+     def forward(
+         ctx,
+         indices: torch.Tensor,  # [B, K]
+         weight: torch.Tensor,   # [F, D]
+         vals: torch.Tensor,     # [B, K]
+         bias: torch.Tensor,     # [D]
+         target: torch.Tensor,   # [B, D]
+     ):
+         loss, final_recon = _hierarchical_sae_forward(indices, weight, vals, bias, target)
+         ctx.save_for_backward(indices, weight, vals, target, final_recon)
+         return loss
+
+     @staticmethod
+     @torch.amp.custom_bwd(device_type="cuda")
+     def backward(ctx, grad):
+         indices, weight, vals, target, final_recon = ctx.saved_tensors
+         dW, dVals, db = _hierarchical_sae_backward(indices, weight, vals, target, final_recon)
+
+         if grad is not None:
+             dW.mul_(grad)
+             dVals.mul_(grad)
+             db.mul_(grad)
+
+         return None, dW, dVals, db, None
+
+
+ def triton_hierarchical_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> torch.Tensor:
+     return HierarchicalSAELossFunction.apply(indices, weight, vals, bias, target)
+
+
+ def hierarchical_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> torch.Tensor:
+     emb = weight[indices].to(torch.float32)  # [B, K, D]
+     recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
+     diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
+     loss = diff.pow(2).mean()
+     return loss
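As a quick sanity check on the two reference definitions (a standalone sketch, not part of the package, assuming `flex_sae` and its `triton` dependency are importable): for $K = 1$ the cumulative sum has a single step, so the hierarchical and TopK reference losses coincide. The reference functions are pure torch, so this runs on CPU.

```python
import torch
from flex_sae import hierarchical_sae_loss, topk_sae_loss

B, K, F, D = 4, 1, 32, 16  # K = 1: one cumulative step
indices = torch.randint(0, F, (B, K))
weight = torch.randn(F, D)
vals = torch.randn(B, K).abs()
bias = torch.randn(D)
target = torch.randn(B, D)

# both reduce to the mean squared error of the single-step reconstruction
assert torch.allclose(
    hierarchical_sae_loss(indices, weight, vals, bias, target),
    topk_sae_loss(indices, weight, vals, bias, target),
)
```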
build/torch-universal/flex_sae/topk_kernels.py ADDED
@@ -0,0 +1,237 @@
+ # TopK SAE decoder Triton kernels
+ # Copyright 2025 T-Tech
+ # This code is adapted from Facebook Research under the
+ # Creative Commons Attribution-NonCommercial 4.0 International License.
+ # Original code can be found at: https://github.com/facebookresearch/memory
+
+
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def embedding_bag_forward_kernel(
+     out_ptr,      # [B, D]
+     indices_ptr,  # [B, K]
+     weight_ptr,   # [F, D]
+     vals_ptr,     # [B, K]
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+
+     b = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+
+     out_value = tl.zeros([BLOCK_D], dtype=tl.float32)
+     for i in tl.range(K):
+         my_index = tl.load(indices_ptr + b * K + i).to(tl.int64)
+         my_scaling = tl.load(vals_ptr + b * K + i)
+         w_tile = tl.load(weight_ptr + my_index * D + off_d).to(tl.float32)
+         out_value += w_tile * my_scaling
+
+     tl.store(out_ptr + b * D + off_d, out_value)
+
+
+ def embedding_bag_forward(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+ ) -> torch.Tensor:
+     B, K = indices.shape
+     D = weight.shape[1]
+
+     trt_out = torch.empty([B, D], dtype=weight.dtype, device=weight.device)
+
+     def _forward_grid(meta):
+         return (B, D // meta["BLOCK_D"])
+
+     embedding_bag_forward_kernel[_forward_grid](
+         trt_out,
+         indices,
+         weight,
+         vals,
+         D=D,
+         K=K,
+         BLOCK_D=64,
+         num_warps=1,
+         num_stages=1,
+     )
+     return trt_out
+
+
+ @triton.jit
+ def count_per_embedding_kernel(
+     count_per_emb_ptr,  # [F + 1]
+     indices_ptr,        # [B, K]
+     K: tl.constexpr,
+ ):
+     batch_id = tl.program_id(axis=0).to(tl.int64)
+     for t in tl.range(K):
+         embedding_id = tl.load(indices_ptr + batch_id * K + t)
+         tl.atomic_add(count_per_emb_ptr + embedding_id + 1, 1, sem="relaxed")
+
+
+ @triton.jit
+ def map_embeddings_and_outputs_kernel(
+     reverse_mapping_ptr,    # [B * K]
+     mapping_write_pos_ptr,  # [F]
+     indices_ptr,            # [B, K]
+     K: tl.constexpr,
+ ):
+     batch_id = tl.program_id(axis=0).to(tl.int64)
+     for t in tl.range(K):
+         embedding_id = tl.load(indices_ptr + batch_id * K + t)
+         write_pos = tl.atomic_add(mapping_write_pos_ptr + embedding_id, 1, sem="relaxed")
+         tl.store(reverse_mapping_ptr + write_pos, batch_id * K + t)
+
+
+ @triton.jit
+ def aggregate_gradient_for_embedding_kernel(
+     weight_grad_ptr,      # [F, D]
+     vals_grad_ptr,        # [B, K]
+     weight_ptr,           # [F, D]
+     emb_begin_pos_ptr,    # [F + 1]
+     reverse_mapping_ptr,  # [B * K]
+     vals_ptr,             # [B, K]
+     gradient_ptr,         # [B, D]
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+
+     e = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+
+     begin = tl.load(emb_begin_pos_ptr + e)
+     end = tl.load(emb_begin_pos_ptr + e + 1)
+
+     w_row_tile = tl.load(weight_ptr + e * D + off_d).to(tl.float32)
+     w_grad_tile = tl.zeros([BLOCK_D], dtype=tl.float32)
+
+     for idx in tl.range(begin, end):
+         out_linear = tl.load(reverse_mapping_ptr + idx).to(tl.int64)
+         b = out_linear // K
+
+         psw = tl.load(vals_ptr + out_linear)
+         g_tile = tl.load(gradient_ptr + b * D + off_d).to(tl.float32)
+
+         w_grad_tile += psw * g_tile
+
+         psw_grad_partial = tl.sum(g_tile * w_row_tile)
+         tl.atomic_add(vals_grad_ptr + out_linear, psw_grad_partial, sem="relaxed")
+
+     tl.store(weight_grad_ptr + e * D + off_d, w_grad_tile)
+
+
+ def embedding_bag_backward(
+     indices: torch.Tensor,   # [B, K]
+     weight: torch.Tensor,    # [F, D]
+     vals: torch.Tensor,      # [B, K]
+     gradient: torch.Tensor,  # [B, D]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     F, D = weight.shape
+     B, K = indices.shape
+
+     # CSR-style grouping: count how many (b, k) slots used each embedding ...
+     count_per_emb = torch.zeros((F + 1,), dtype=torch.uint32, device=indices.device)
+     count_per_embedding_kernel[(B,)](count_per_emb, indices, K=K, num_warps=1)
+
+     emb_begin_pos = count_per_emb.cumsum(0)  # [F + 1]
+
+     # ... then scatter the flat slot ids so each embedding's users are contiguous
+     reverse_mapping = torch.empty([B * K], dtype=torch.uint32, device=indices.device)
+     assert B * K <= 2 ** (reverse_mapping.dtype.itemsize * 8) - 1
+
+     map_embeddings_and_outputs_kernel[(B,)](
+         reverse_mapping_ptr=reverse_mapping,
+         mapping_write_pos_ptr=emb_begin_pos.clone(),
+         indices_ptr=indices,
+         K=K,
+         num_warps=1,
+     )
+
+     weight_grad = torch.empty_like(weight, dtype=torch.float32)  # [F, D]
+     vals_grad = torch.zeros_like(vals, dtype=torch.float32)      # [B, K]
+
+     def _forward_grid(meta):
+         return (F, D // meta["BLOCK_D"])
+
+     aggregate_gradient_for_embedding_kernel[_forward_grid](
+         weight_grad_ptr=weight_grad,
+         vals_grad_ptr=vals_grad,
+         weight_ptr=weight,
+         emb_begin_pos_ptr=emb_begin_pos,
+         reverse_mapping_ptr=reverse_mapping,
+         vals_ptr=vals,
+         gradient_ptr=gradient,
+         D=D,
+         K=K,
+         BLOCK_D=256,
+         num_warps=1,
+         num_stages=2,
+     )
+     return weight_grad, vals_grad
+
+
+ class xFormersEmbeddingBag(torch.autograd.Function):
+     @staticmethod
+     @torch.amp.custom_fwd(device_type="cuda")
+     def forward(
+         ctx,
+         indices: torch.Tensor,  # [B, K]
+         weight: torch.Tensor,   # [F, D]
+         vals: torch.Tensor,     # [B, K]
+     ) -> torch.Tensor:
+         ctx.save_for_backward(indices, weight, vals)
+         return embedding_bag_forward(indices, weight, vals)  # [B, D]
+
+     @staticmethod
+     @torch.amp.custom_bwd(device_type="cuda")
+     def backward(ctx, gradient):
+         indices, weight, vals = ctx.saved_tensors
+         weight_g, vals_g = embedding_bag_backward(
+             indices,
+             weight,
+             vals,
+             gradient,
+         )
+         return None, weight_g, vals_g
+
+
+ def triton_topk_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> torch.Tensor:
+     recon = bias.to(torch.float32) + xFormersEmbeddingBag.apply(indices, weight, vals)
+     diff = recon.to(torch.float32) - target.to(torch.float32)
+     loss = diff.pow(2).mean()
+     return loss
+
+
+ def topk_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> torch.Tensor:
+     emb = weight[indices].to(torch.float32)  # [B, K, D]
+     recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).sum(dim=1)
+     diff = recon.to(torch.float32) - target.to(torch.float32)
+     loss = diff.pow(2).mean()
+     return loss
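For intuition, a small pure-torch sketch (illustrative only, not the kernel code) of the bookkeeping that `count_per_embedding_kernel`, the cumulative sum, and `map_embeddings_and_outputs_kernel` build: per-embedding counts, begin positions, and a reverse mapping from each embedding to the flat `b * K + k` slots that selected it.

```python
import torch

B, K, F = 4, 3, 8
indices = torch.randint(0, F, (B, K))

# count_per_emb[e + 1] = number of (b, k) slots that selected embedding e
count_per_emb = torch.zeros(F + 1, dtype=torch.long)
count_per_emb.scatter_add_(0, indices.flatten() + 1, torch.ones(B * K, dtype=torch.long))

# begin position of each embedding's contiguous slot group
emb_begin_pos = count_per_emb.cumsum(0)  # [F + 1]

# flat slot ids b * K + k, grouped by embedding id (stable sort keeps slot order)
reverse_mapping = indices.flatten().argsort(stable=True)

# all slots that used embedding e sit in one contiguous slice
e = int(indices[0, 0])
slots = reverse_mapping[emb_begin_pos[e] : emb_begin_pos[e + 1]]
assert torch.all(indices.flatten()[slots] == e)
```

The aggregation kernel then launches one program per embedding row and walks exactly this slice, so each decoder row's gradient is accumulated with a plain store on `weight_grad`; atomics are only needed for `vals_grad`.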
example.py ADDED
@@ -0,0 +1,100 @@
+ # /// script
+ # dependencies = [
+ #   "torch",
+ #   "numpy",
+ #   "kernels",
+ # ]
+ # ///
+
+ import torch
+ import numpy as np
+ from kernels import get_kernel
+
+ flex = get_kernel("t-tech/flex-sae")  # fast kernels from the Hub
+
+ @torch.compile(fullgraph=True)
+ def hierarchical_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,   # [F, D]
+     vals: torch.Tensor,     # [B, K]
+     bias: torch.Tensor,     # [D]
+     target: torch.Tensor,   # [B, D]
+ ) -> torch.Tensor:
+     emb = weight[indices].to(torch.float32)  # [B, K, D]
+     recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
+     diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
+     loss = diff.pow(2).mean()
+     return loss
+
+
+ B = 2048
+ K = 256
+ F = 1024 * 128
+ D = 1024
+ warmup = 5
+ dtype = torch.float32
+
+ vals = None
+ decoder = None
+ bias = None
+ target = None
+ indices = None
+
+
+ def init_parameters():
+     global vals, decoder, bias, target, indices
+     vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
+     decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
+     bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
+     target = torch.randn(B, D, dtype=dtype, device="cuda")
+     indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
+
+
+ timing_kernel = []
+ timing_vanilla = []
+ torch.cuda.reset_peak_memory_stats()
+ loss_kernel_list = torch.zeros((100,))
+ loss_vanilla_list = torch.zeros((100,))
+
+
+ def zero_grad():
+     vals.grad = None
+     decoder.grad = None
+     bias.grad = None
+     torch.cuda.empty_cache()
+
+
+ for i in range(100 + warmup):
+     init_parameters()
+     start_kernel = torch.cuda.Event(enable_timing=True)
+     end_kernel = torch.cuda.Event(enable_timing=True)
+     start_vanilla = torch.cuda.Event(enable_timing=True)
+     end_vanilla = torch.cuda.Event(enable_timing=True)
+
+     start_kernel.record()
+     loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
+     loss_kernel.backward()
+     end_kernel.record()
+
+     zero_grad()
+     start_vanilla.record()
+     loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
+     loss_vanilla.backward()
+     end_vanilla.record()
+     if i >= warmup:
+         torch.cuda.synchronize()
+         timing_kernel.append(start_kernel.elapsed_time(end_kernel))
+         timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
+         loss_kernel_list[i - warmup] = loss_kernel.detach()
+         loss_vanilla_list[i - warmup] = loss_vanilla.detach()
+     zero_grad()
+
+ if torch.allclose(loss_kernel, loss_vanilla):
+     print("✅ Outputs are close! Everything is good! 🎉")
+ else:
+     print("❌ Outputs mismatch... ⚠️🤔")
+
+
+ print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
+ print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
+ print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1747046372,
21
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1757675377,
77
+ "narHash": "sha256-JQKZOI1ZYO4faJnanuoTXziSmqzXe5rEFSGliWDWqWw=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "faf3354403a7381958d08e826c15fe30f6986a4f",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1758713083,
102
+ "narHash": "sha256-C7yob+hU6/IL7NDX0GVBxKKY3GPVNOwX9OU+LRCCVrk=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "051fbc3dfe6afdbe01a6f15197b440d0333090cd",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1755963616,
117
+ "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
118
+ "owner": "nixos",
119
+ "repo": "nixpkgs",
120
+ "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "nixos",
125
+ "ref": "nixos-unstable-small",
126
+ "repo": "nixpkgs",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
flake.nix ADDED
@@ -0,0 +1,18 @@
+ {
+   description = "Flake for TopK and HierarchicalTopK SAE Triton kernels";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+       doGetKernelCheck = true;
+     };
+ }
tests/__init__.py ADDED
File without changes
tests/test_all_kernels.py ADDED
@@ -0,0 +1,66 @@
+ from typing import Callable
+ import pytest
+ import torch
+
+ pytest.importorskip("torch.cuda")
+ from .test_setup import DTYPES, DTYPE_TO_TOLS, PARAMS, SEED
+ from flex_sae import (
+     triton_hierarchical_sae_loss,
+     hierarchical_sae_loss,
+     triton_topk_sae_loss,
+     topk_sae_loss,
+ )
+
+
+ @pytest.fixture(autouse=True)
+ def _set_cuda_default_device():
+     torch.set_default_device("cuda")
+
+
+ def run_funcs(B, K, F, D, dtype, *, kernel_foo: Callable, ref_foo: Callable):
+     if dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
+         pytest.skip("BF16 not supported on this GPU")
+
+     torch.manual_seed(SEED)
+
+     indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
+
+     vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
+     decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
+     bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
+     target = torch.randn(B, D, dtype=dtype, device="cuda")
+
+     sv_ref = vals.clone().detach().requires_grad_()
+     dec_ref = decoder.clone().detach().requires_grad_()
+     bias_ref = bias.clone().detach().requires_grad_()
+
+     loss_f = kernel_foo(indices, decoder, vals, bias, target)
+     loss_r = ref_foo(indices, dec_ref, sv_ref, bias_ref, target)
+
+     torch.testing.assert_close(loss_f, loss_r, **DTYPE_TO_TOLS[dtype])
+
+     grad_out = torch.randn((), device="cuda", dtype=torch.float32)
+     loss_f.backward(grad_out)
+     loss_r.backward(grad_out.clone())
+
+     torch.testing.assert_close(vals.grad, sv_ref.grad, **DTYPE_TO_TOLS[dtype])
+     torch.testing.assert_close(decoder.grad, dec_ref.grad, **DTYPE_TO_TOLS[dtype])
+     torch.testing.assert_close(bias.grad, bias_ref.grad, **DTYPE_TO_TOLS[dtype])
+
+     assert indices.grad is None
+
+
+ @pytest.mark.parametrize("B, K, F, D", PARAMS)
+ @pytest.mark.parametrize("dtype", DTYPES)
+ def test_triton_hierarchical_sae_loss_and_grads(B, K, F, D, dtype):
+     run_funcs(B, K, F, D, dtype, kernel_foo=triton_hierarchical_sae_loss, ref_foo=hierarchical_sae_loss)
+     torch.cuda.empty_cache()
+
+
+ @pytest.mark.parametrize("B, K, F, D", PARAMS)
+ @pytest.mark.parametrize("dtype", DTYPES)
+ def test_topk_sae_loss_and_grads(B, K, F, D, dtype):
+     run_funcs(
+         B, K, F, D, dtype, kernel_foo=triton_topk_sae_loss, ref_foo=topk_sae_loss
+     )
+     torch.cuda.empty_cache()
tests/test_setup.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+
+ SEED = 1234
+
+ PARAMS = [
+     (16, 16, 64, 512),
+     (16, 32, 96, 768),
+     (16, 64, 128, 1024),
+     (32, 16, 128, 512),
+     (32, 32, 160, 768),
+     (32, 64, 192, 1024),
+     (48, 32, 176, 1024),
+     (48, 64, 224, 1280),
+     (64, 16, 192, 768),
+     (64, 32, 224, 1024),
+     (64, 128, 256, 2048),
+     (80, 32, 240, 1280),
+     (80, 64, 256, 1536),
+     (96, 32, 256, 1536),
+     (96, 64, 288, 2048),
+     (96, 128, 320, 3072),
+     (112, 64, 320, 2048),
+     (112, 128, 352, 2560),
+     (128, 32, 256, 1024),
+     (128, 64, 320, 1536),
+     (128, 128, 384, 3072),
+     (160, 64, 320, 1536),
+     (160, 128, 384, 2560),
+     (192, 64, 384, 2048),
+     (192, 128, 448, 3072),
+     (192, 256, 512, 4096),
+ ]
+
+
+ DTYPE_TO_TOLS = {
+     torch.float32: {"atol": 1e-4, "rtol": 1e-3},
+     torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
+     torch.float16: {"atol": 1e-3, "rtol": 1e-3},
+ }
+
+ DTYPES = list(DTYPE_TO_TOLS.keys())
torch-ext/flex_sae/__init__.py ADDED
@@ -0,0 +1,18 @@
Contents are identical to `build/torch-universal/flex_sae/__init__.py` above.
torch-ext/flex_sae/hierarchical_kernels.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # HierarchicalTopK SAE decoder Triton kernels
+ # Copyright 2025 T-Tech
+
+
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def hierarchical_sae_forward_kernel(
+     loss_per_batch_ptr,  # [B]
+     final_recon_ptr,  # [B, D]
+     indices_ptr,  # [B, K]
+     weight_ptr,  # [F, D]
+     bias_ptr,  # [D]
+     vals_ptr,  # [B, K]
+     target_ptr,  # [B, D]
+     B: tl.constexpr,
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+     LOOP_NUM_STAGES: tl.constexpr,
+     BLOCK_B: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((B % BLOCK_B) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+     tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
+
+     pid_b = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
+     batch_offsets = batch_offsets.to(tl.int64)
+     tl.multiple_of(batch_offsets, BLOCK_B)
+
+     offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+     offset_d = offset_d.to(tl.int64)
+
+     tl.multiple_of(offset_d, BLOCK_D)
+     tl.max_contiguous(offset_d, BLOCK_D)
+
+     batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
+
+     bias_tile = tl.load(bias_ptr + offset_d).to(tl.float32)
+
+     recon = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+     recon += bias_tile[None, :]
+
+     target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
+
+     loss_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+
+     row_idx_ptr = indices_ptr + batch_offsets * K
+     row_val_ptr = vals_ptr + batch_offsets * K
+
+     # Software pipelining: level t + 1's index, value, and decoder row are
+     # prefetched while level t is being accumulated.
+     idx = tl.load(row_idx_ptr).to(tl.int64)
+     val = tl.load(row_val_ptr).to(tl.float32)
+     val = val[:, None]
+     weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+     for t in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
+         recon += weight_tile * val
+         diff = recon - target
+         loss_accum += diff * diff
+
+         if t + 1 < K:
+             idx_next = tl.load(row_idx_ptr + (t + 1)).to(tl.int64)
+             val_next = tl.load(row_val_ptr + (t + 1)).to(tl.float32)
+             weight_next = tl.load(weight_ptr + idx_next[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+             idx = idx_next
+             val = val_next[:, None]
+             weight_tile = weight_next
+
+     loss_tile = tl.sum(loss_accum, axis=1)
+     tl.atomic_add(
+         loss_per_batch_ptr + batch_offsets,
+         loss_tile,
+         sem="relaxed",
+     )
+     tl.store(
+         final_recon_ptr + batch_d_offset,
+         recon,
+     )
+
+
+ @triton.jit
+ def hierarchical_sae_backward_kernel(
+     weight_grad_ptr,  # [F, D]
+     vals_grad_ptr,  # [B, K]
+     bias_grad_ptr,  # [D]
+     final_recon_ptr,  # [B, D]
+     indices_ptr,  # [B, K]
+     weight_ptr,  # [F, D]
+     vals_ptr,  # [B, K]
+     target_ptr,  # [B, D]
+     B: tl.constexpr,
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+     LOOP_NUM_STAGES: tl.constexpr,
+     BLOCK_B: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((B % BLOCK_B) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+     tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
+
+     pid_b = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
+     batch_offsets = batch_offsets.to(tl.int64)
+     tl.multiple_of(batch_offsets, BLOCK_B)
+
+     offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+     offset_d = offset_d.to(tl.int64)
+
+     tl.multiple_of(offset_d, BLOCK_D)
+     tl.max_contiguous(offset_d, BLOCK_D)
+
+     batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
+
+     recon = tl.load(final_recon_ptr + batch_d_offset).to(tl.float32)
+     target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
+
+     suffix = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+     bias_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
+     scale = tl.full((), 2.0 / (B * K * D), dtype=tl.float32)
+
+     row_idx_ptr = indices_ptr + batch_offsets * K
+     row_val_ptr = vals_ptr + batch_offsets * K
+     k_offsets = tl.arange(0, K)
+     val_grad_tile = tl.zeros([BLOCK_B, K], dtype=tl.float32)
+
+     # Sweep the levels from K - 1 down to 0, peeling one decoder contribution
+     # off the final reconstruction per step; `suffix` accumulates the loss
+     # gradient of every level >= the current one.
+     step = K - 1
+     idx = tl.load(row_idx_ptr + step).to(tl.int64)
+     val = tl.load(row_val_ptr + step).to(tl.float32)
+     weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+     for _ in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
+         curr_step = step
+
+         diff = recon - target
+         grad_curr = diff * scale
+         suffix += grad_curr
+         bias_accum += grad_curr
+
+         val_broadcast = val[:, None]
+         contrib = suffix * val_broadcast
+         tl.atomic_add(
+             weight_grad_ptr + idx[:, None] * D + offset_d[None, :],
+             contrib,
+             sem="relaxed",
+         )
+
+         dot_partial = tl.sum(weight_tile * suffix, axis=1)
+         mask_curr = k_offsets[None, :] == curr_step
+         val_grad_tile = tl.where(mask_curr, dot_partial[:, None], val_grad_tile)
+
+         recon -= weight_tile * val_broadcast
+
+         if curr_step > 0:
+             step = curr_step - 1
+             idx = tl.load(row_idx_ptr + step).to(tl.int64)
+             val = tl.load(row_val_ptr + step).to(tl.float32)
+             weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
+
+     bias_grad_tile = tl.sum(bias_accum, axis=0)
+     tl.atomic_add(
+         bias_grad_ptr + offset_d,
+         bias_grad_tile,
+         sem="relaxed",
+     )
+
+     row_val_grad_ptr = vals_grad_ptr + batch_offsets[:, None] * K + k_offsets[None, :]
+     tl.atomic_add(
+         row_val_grad_ptr,
+         val_grad_tile,
+         sem="relaxed",
+     )
+
+
+ def _hierarchical_sae_forward(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     bias: torch.Tensor,  # [D]
+     target: torch.Tensor,  # [B, D]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     B, K = indices.shape
+     F, D = weight.shape
+
+     loss_per_batch = torch.zeros((B,), dtype=torch.float32, device=weight.device)
+     final_recon = torch.empty((B, D), dtype=torch.float32, device=weight.device)
+
+     def _forward_grid(meta):
+         return (
+             B // meta["BLOCK_B"],
+             D // meta["BLOCK_D"],
+         )
+
+     hierarchical_sae_forward_kernel[_forward_grid](
+         loss_per_batch,
+         final_recon,
+         indices,
+         weight,
+         bias,
+         vals,
+         target,
+         B=B,
+         D=D,
+         K=K,
+         BLOCK_D=64,
+         LOOP_NUM_STAGES=4,
+         BLOCK_B=1,
+         num_warps=2,
+         num_stages=2,
+     )
+     loss = loss_per_batch.sum() / (B * K * D)
+     return loss, final_recon
+
+
+ def _hierarchical_sae_backward(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     target: torch.Tensor,  # [B, D]
+     final_recon: torch.Tensor,  # [B, D]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     device = weight.device
+     B, K = indices.shape
+     F, D = weight.shape
+
+     dW = torch.zeros((F, D), dtype=torch.float32, device=device)
+     dVals = torch.zeros((B, K), dtype=torch.float32, device=device)
+     db = torch.zeros((D,), dtype=torch.float32, device=device)
+
+     def _backward_grid(meta):
+         return (
+             B // meta["BLOCK_B"],
+             D // meta["BLOCK_D"],
+         )
+
+     hierarchical_sae_backward_kernel[_backward_grid](
+         dW,
+         dVals,
+         db,
+         final_recon,
+         indices,
+         weight,
+         vals,
+         target,
+         B=B,
+         D=D,
+         K=K,
+         BLOCK_D=32,
+         LOOP_NUM_STAGES=16,
+         BLOCK_B=16,
+         num_warps=8,
+         num_stages=8,
+     )
+
+     return dW, dVals, db
+
+
+ class HierarchicalSAELossFunction(torch.autograd.Function):
+     @staticmethod
+     @torch.amp.custom_fwd(device_type="cuda")
+     def forward(
+         ctx,
+         indices: torch.Tensor,  # [B, K]
+         weight: torch.Tensor,  # [F, D]
+         vals: torch.Tensor,  # [B, K]
+         bias: torch.Tensor,  # [D]
+         target: torch.Tensor,  # [B, D]
+     ):
+         loss, final_recon = _hierarchical_sae_forward(indices, weight, vals, bias, target)
+         ctx.save_for_backward(indices, weight, vals, target, final_recon)
+         return loss
+
+     @staticmethod
+     @torch.amp.custom_bwd(device_type="cuda")
+     def backward(ctx, grad):
+         indices, weight, vals, target, final_recon = ctx.saved_tensors
+         dW, dVals, db = _hierarchical_sae_backward(indices, weight, vals, target, final_recon)
+
+         if grad is not None:
+             dW.mul_(grad)
+             dVals.mul_(grad)
+             db.mul_(grad)
+
+         return None, dW, dVals, db, None
+
+
+ def triton_hierarchical_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     bias: torch.Tensor,  # [D]
+     target: torch.Tensor,  # [B, D]
+ ) -> torch.Tensor:
+     return HierarchicalSAELossFunction.apply(indices, weight, vals, bias, target)
+
+
+ def hierarchical_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     bias: torch.Tensor,  # [D]
+     target: torch.Tensor,  # [B, D]
+ ) -> torch.Tensor:
+     # Pure-PyTorch reference: squared error over all K cumulative (prefix) reconstructions.
+     emb = weight[indices].to(torch.float32)  # [B, K, D]
+     recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)  # [B, K, D]
+     diff = recon_cum - target.to(torch.float32).unsqueeze(1)
+     loss = diff.pow(2).mean()
+     return loss
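
The reference `hierarchical_sae_loss` above pins down the objective the fused kernels implement: the squared error is taken against every prefix reconstruction `bias + sum_{t <= k} vals[:, t] * weight[indices[:, t]]` for `k = 0..K-1` and averaged over all `B * K * D` terms, which is why the backward kernel sweeps the levels in reverse (each decoder row receives the suffix of level gradients) and scales by `2 / (B * K * D)`. A quick equivalence check against the Triton path, as a sketch (illustrative sizes chosen to satisfy the block-size asserts; assumes a CUDA device and that the package imports as `flex_sae`):

    import torch
    from flex_sae import hierarchical_sae_loss, triton_hierarchical_sae_loss

    torch.manual_seed(0)
    B, K, F, D = 16, 16, 1024, 256
    weight = torch.randn(F, D, device="cuda")
    bias = torch.randn(D, device="cuda")
    indices = torch.randint(0, F, (B, K), device="cuda")
    vals = torch.randn(B, K, device="cuda")
    target = torch.randn(B, D, device="cuda")

    ref = hierarchical_sae_loss(indices, weight, vals, bias, target)
    fused = triton_hierarchical_sae_loss(indices, weight, vals, bias, target)
    torch.testing.assert_close(fused, ref, rtol=1e-4, atol=1e-4)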
torch-ext/flex_sae/topk_kernels.py ADDED
@@ -0,0 +1,237 @@
+ # TopK SAE decoder Triton kernels
+ # Copyright 2025 T-Tech
+ # This code is adapted from Facebook Research under the
+ # Creative Commons Attribution-NonCommercial 4.0 International License.
+ # Original code can be found at: https://github.com/facebookresearch/memory
+
+
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def embedding_bag_forward_kernel(
+     out_ptr,  # [B, D]
+     indices_ptr,  # [B, K]
+     weight_ptr,  # [F, D]
+     vals_ptr,  # [B, K]
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+
+     b = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+
+     # out[b] = sum_k vals[b, k] * weight[indices[b, k]]
+     out_value = tl.zeros([BLOCK_D], dtype=tl.float32)
+     for i in tl.range(K):
+         my_index = tl.load(indices_ptr + b * K + i).to(tl.int64)
+         my_scaling = tl.load(vals_ptr + b * K + i)
+         w_tile = tl.load(weight_ptr + my_index * D + off_d).to(tl.float32)
+         out_value += w_tile * my_scaling
+
+     tl.store(out_ptr + b * D + off_d, out_value)
+
+
+ def embedding_bag_forward(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+ ) -> torch.Tensor:
+     B, K = indices.shape
+     D = weight.shape[1]
+
+     trt_out = torch.empty([B, D], dtype=weight.dtype, device=weight.device)
+
+     def _forward_grid(meta):
+         return (B, D // meta["BLOCK_D"])
+
+     embedding_bag_forward_kernel[_forward_grid](
+         trt_out,
+         indices,
+         weight,
+         vals,
+         D=D,
+         K=K,
+         BLOCK_D=64,
+         num_warps=1,
+         num_stages=1,
+     )
+     return trt_out
+
+
+ @triton.jit
+ def count_per_embedding_kernel(
+     count_per_emb_ptr,  # [F + 1]
+     indices_ptr,  # [B, K]
+     K: tl.constexpr,
+ ):
+     # Histogram of hits per feature id, shifted by one slot so that a later
+     # cumulative sum yields the per-feature segment starts directly.
+     batch_id = tl.program_id(axis=0).to(tl.int64)
+     for t in tl.range(K):
+         embedding_id = tl.load(indices_ptr + batch_id * K + t)
+         tl.atomic_add(count_per_emb_ptr + embedding_id + 1, 1, sem="relaxed")
+
+
+ @triton.jit
+ def map_embeddings_and_outputs_kernel(
+     reverse_mapping_ptr,  # [B * K]
+     mapping_write_pos_ptr,  # [F]
+     indices_ptr,  # [B, K]
+     K: tl.constexpr,
+ ):
+     # Scatter each linear slot b * K + t into its feature's segment of
+     # reverse_mapping, advancing that feature's write cursor atomically.
+     batch_id = tl.program_id(axis=0).to(tl.int64)
+     for t in tl.range(K):
+         embedding_id = tl.load(indices_ptr + batch_id * K + t)
+         write_pos = tl.atomic_add(mapping_write_pos_ptr + embedding_id, 1, sem="relaxed")
+         tl.store(reverse_mapping_ptr + write_pos, batch_id * K + t)
+
+
+ @triton.jit
+ def aggregate_gradient_for_embedding_kernel(
+     weight_grad_ptr,  # [F, D]
+     vals_grad_ptr,  # [B, K]
+     weight_ptr,  # [F, D]
+     emb_begin_pos_ptr,  # [F + 1]
+     reverse_mapping_ptr,  # [B * K]
+     vals_ptr,  # [B, K]
+     gradient_ptr,  # [B, D]
+     D: tl.constexpr,
+     K: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     tl.static_assert((D % BLOCK_D) == 0)
+     tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
+     tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
+
+     e = tl.program_id(axis=0).to(tl.int64)
+     pid_d = tl.program_id(axis=1).to(tl.int64)
+
+     off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
+
+     begin = tl.load(emb_begin_pos_ptr + e)
+     end = tl.load(emb_begin_pos_ptr + e + 1)
+
+     w_row_tile = tl.load(weight_ptr + e * D + off_d).to(tl.float32)
+     w_grad_tile = tl.zeros([BLOCK_D], dtype=tl.float32)
+
+     for idx in tl.range(begin, end):
+         out_linear = tl.load(reverse_mapping_ptr + idx).to(tl.int64)
+         b = out_linear // K
+
+         psw = tl.load(vals_ptr + out_linear)
+         g_tile = tl.load(gradient_ptr + b * D + off_d).to(tl.float32)
+
+         w_grad_tile += psw * g_tile
+
+         psw_grad_partial = tl.sum(g_tile * w_row_tile)
+         tl.atomic_add(vals_grad_ptr + out_linear, psw_grad_partial, sem="relaxed")
+
+     tl.store(weight_grad_ptr + e * D + off_d, w_grad_tile)
+
+
+ def embedding_bag_backward(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     gradient: torch.Tensor,  # [B, D]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     F, D = weight.shape
+     B, K = indices.shape
+
+     # Counting sort by feature id: histogram, prefix-sum, then scatter the
+     # B * K slots into per-feature segments of reverse_mapping.
+     count_per_emb = torch.zeros((F + 1,), dtype=torch.uint32, device=indices.device)
+     count_per_embedding_kernel[(B,)](count_per_emb, indices, K=K, num_warps=1)
+
+     emb_begin_pos = count_per_emb.cumsum(0)  # [F + 1]
+
+     reverse_mapping = torch.empty([B * K], dtype=torch.uint32, device=indices.device)
+     assert B * K <= 2 ** (reverse_mapping.dtype.itemsize * 8) - 1
+
+     map_embeddings_and_outputs_kernel[(B,)](
+         reverse_mapping_ptr=reverse_mapping,
+         mapping_write_pos_ptr=emb_begin_pos.clone(),  # mutated in-kernel as write cursors
+         indices_ptr=indices,
+         K=K,
+         num_warps=1,
+     )
+
+     weight_grad = torch.empty_like(weight, dtype=torch.float32)  # [F, D]
+     vals_grad = torch.zeros_like(vals, dtype=torch.float32)  # [B, K]
+
+     def _backward_grid(meta):
+         return (F, D // meta["BLOCK_D"])
+
+     aggregate_gradient_for_embedding_kernel[_backward_grid](
+         weight_grad_ptr=weight_grad,
+         vals_grad_ptr=vals_grad,
+         weight_ptr=weight,
+         emb_begin_pos_ptr=emb_begin_pos,
+         reverse_mapping_ptr=reverse_mapping,
+         vals_ptr=vals,
+         gradient_ptr=gradient,
+         D=D,
+         K=K,
+         BLOCK_D=256,
+         num_warps=1,
+         num_stages=2,
+     )
+     return weight_grad, vals_grad
+
+
+ class xFormersEmbeddingBag(torch.autograd.Function):
+     @staticmethod
+     @torch.amp.custom_fwd(device_type="cuda")
+     def forward(
+         ctx,
+         indices: torch.Tensor,  # [B, K]
+         weight: torch.Tensor,  # [F, D]
+         vals: torch.Tensor,  # [B, K]
+     ) -> torch.Tensor:
+         ctx.save_for_backward(indices, weight, vals)
+         return embedding_bag_forward(indices, weight, vals)  # [B, D]
+
+     @staticmethod
+     @torch.amp.custom_bwd(device_type="cuda")
+     def backward(ctx, gradient):
+         indices, weight, vals = ctx.saved_tensors
+         weight_g, vals_g = embedding_bag_backward(
+             indices,
+             weight,
+             vals,
+             gradient,
+         )
+         return None, weight_g, vals_g
+
+
+ def triton_topk_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     bias: torch.Tensor,  # [D]
+     target: torch.Tensor,  # [B, D]
+ ) -> torch.Tensor:
+     recon = bias.to(torch.float32) + xFormersEmbeddingBag.apply(indices, weight, vals)
+     diff = recon.to(torch.float32) - target.to(torch.float32)
+     loss = diff.pow(2).mean()
+     return loss
+
+
+ def topk_sae_loss(
+     indices: torch.Tensor,  # [B, K]
+     weight: torch.Tensor,  # [F, D]
+     vals: torch.Tensor,  # [B, K]
+     bias: torch.Tensor,  # [D]
+     target: torch.Tensor,  # [B, D]
+ ) -> torch.Tensor:
+     # Pure-PyTorch reference implementation.
+     emb = weight[indices].to(torch.float32)  # [B, K, D]
+     recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).sum(dim=1)
+     diff = recon.to(torch.float32) - target.to(torch.float32)
+     loss = diff.pow(2).mean()
+     return loss
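
The TopK backward avoids atomics on `weight_grad` by bucketing the `B * K` (batch, slot) pairs by feature id before aggregation: `count_per_embedding_kernel` histograms hits per feature into a shifted `[F + 1]` buffer, its cumulative sum yields the segment starts `emb_begin_pos`, and `map_embeddings_and_outputs_kernel` scatters each linear slot `b * K + k` into its feature's segment of `reverse_mapping`, so `aggregate_gradient_for_embedding_kernel` can reduce one feature row per program with plain stores. A host-side recreation of that layout, as an illustrative sketch (not part of the upload):

    import torch

    # Toy problem: B=2, K=2, F=3 features.
    indices = torch.tensor([[2, 0], [2, 1]])
    flat = indices.reshape(-1)  # feature id for each linear slot b * K + k
    counts = torch.bincount(flat, minlength=3)  # hits per feature
    emb_begin_pos = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
    reverse_mapping = torch.argsort(flat, stable=True)  # slots grouped by feature id

    # Feature 2 owns reverse_mapping[emb_begin_pos[2]:emb_begin_pos[3]] == [0, 2],
    # i.e. slots (b=0, k=0) and (b=1, k=0); its gradient row sums vals at those
    # slots times the matching rows of the upstream [B, D] gradient.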