Upload folder using huggingface_hub
Browse files- .gitignore +3 -0
- LICENSE +201 -0
- NOTICE +16 -0
- README.md +196 -3
- build.toml +3 -0
- build/torch-universal/flex_sae/__init__.py +18 -0
- build/torch-universal/flex_sae/_ops.py +8 -0
- build/torch-universal/flex_sae/hierarchical_kernels.py +323 -0
- build/torch-universal/flex_sae/topk_kernels.py +237 -0
- example.py +100 -0
- flake.lock +168 -0
- flake.nix +18 -0
- tests/__init__.py +0 -0
- tests/test_all_kernels.py +66 -0
- tests/test_setup.py +41 -0
- torch-ext/flex_sae/__init__.py +18 -0
- torch-ext/flex_sae/hierarchical_kernels.py +323 -0
- torch-ext/flex_sae/topk_kernels.py +237 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.DS_Store
|
| 2 |
+
*__pycache__
|
| 3 |
+
*__MACOSX
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2025 T-Tech
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
NOTICE
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Flex SAE Kernels
|
| 2 |
+
Copyright (c) 2025 T-Tech
|
| 3 |
+
|
| 4 |
+
This project is distributed under the Apache License, Version 2.0. The following
|
| 5 |
+
third-party component is redistributed under its original terms:
|
| 6 |
+
|
| 7 |
+
- Portions of `torch-ext/flex_sae/topk_kernels.py` are adapted from the Facebook
|
| 8 |
+
Research project "memory" (https://github.com/facebookresearch/memory).
|
| 9 |
+
That source is licensed under the Creative Commons Attribution-NonCommercial
|
| 10 |
+
4.0 International License (CC BY-NC 4.0). Any use of the adapted code must
|
| 11 |
+
comply with the non-commercial requirements described at
|
| 12 |
+
https://creativecommons.org/licenses/by-nc/4.0/.
|
| 13 |
+
|
| 14 |
+
Where the Apache 2.0 license and CC BY-NC 4.0 differ, the more restrictive
|
| 15 |
+
requirements apply to the adapted code. All other files are provided under the
|
| 16 |
+
Apache License, Version 2.0.
|
README.md
CHANGED
|
@@ -1,3 +1,196 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- kernel
|
| 5 |
+
- sae
|
| 6 |
+
---
|
| 7 |
+
# Flex SAE Kernels
|
| 8 |
+
|
| 9 |
+
[](https://arxiv.org/abs/2505.24473)
|
| 10 |
+
|
| 11 |
+
Fused Triton implementations of the TopK and HierarchicalTopK sparse autoencoder (SAE) decoder losses described in *Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy*.
|
| 12 |
+
|
| 13 |
+
**This work has been accepted to [EMNLP 2025](https://2025.emnlp.org/).**
|
| 14 |
+
|
| 15 |
+
## What is released?
|
| 16 |
+
|
| 17 |
+
- Fast TopK kernel for SAE (slightly modified version from xformers) `torch-ext/flex_sae/topk_kernels.py`
|
| 18 |
+
- Fast HierarchicalTopK kernels (see our [paper](https://arxiv.org/abs/2505.24473)) `torch-ext/flex_sae/hierarchical_kernels.py`.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
## Quickstart
|
| 22 |
+
|
| 23 |
+
Kernels are available via loading from hub, they have the following signature:
|
| 24 |
+
```python
|
| 25 |
+
from kernels import get_kernel
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
flex = get_kernel('t-tech/flex-sae')
|
| 29 |
+
|
| 30 |
+
top_k_kernel = flex.triton_topk_sae_loss
|
| 31 |
+
hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss
|
| 32 |
+
|
| 33 |
+
"B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim"
|
| 34 |
+
|
| 35 |
+
loss: torch.Tensor = top_k_kernel(
|
| 36 |
+
indices: torch.Tensor, # [B, K]
|
| 37 |
+
weight: torch.Tensor, # [F, D]
|
| 38 |
+
vals: torch.Tensor, # [B, K]
|
| 39 |
+
bias: torch.Tensor, # [D]
|
| 40 |
+
target: torch.Tensor, # [B, D]
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
loss: torch.Tensor = hierarchical_top_k_kernel(
|
| 44 |
+
indices: torch.Tensor, # [B, K]
|
| 45 |
+
weight: torch.Tensor, # [F, D]
|
| 46 |
+
vals: torch.Tensor, # [B, K]
|
| 47 |
+
bias: torch.Tensor, # [D]
|
| 48 |
+
target: torch.Tensor, # [B, D]
|
| 49 |
+
)
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Overview
|
| 53 |
+
- `torch-ext/flex_sae/` contains the Triton kernels alongside torch reference implementations.
|
| 54 |
+
- `tests/` hosts CUDA-backed property tests that ensure numerical parity across dtypes and kernels.
|
| 55 |
+
- `build.toml`, `flake.nix` integrate the project with [Hugging Face kernel-builder](https://github.com/huggingface/kernel-builder).
|
| 56 |
+
|
| 57 |
+
The Triton kernels target CUDA GPUs and focus on reducing the latency gap between TopK and HierarchicalTopK decoders while keeping memory usage flat.
|
| 58 |
+
|
| 59 |
+
## Example
|
| 60 |
+
|
| 61 |
+
You can find example usage in [example.py](https://huggingface.co/t-tech/flex-sae/blob/main/example.py).
|
| 62 |
+
```python
|
| 63 |
+
# /// script
|
| 64 |
+
# dependencies = [
|
| 65 |
+
# "torch",
|
| 66 |
+
# "numpy",
|
| 67 |
+
# "kernels",
|
| 68 |
+
# ]
|
| 69 |
+
# ///
|
| 70 |
+
|
| 71 |
+
import torch
|
| 72 |
+
import numpy as np
|
| 73 |
+
from kernels import get_kernel
|
| 74 |
+
|
| 75 |
+
flex = get_kernel("t-tech/flex-sae") #Fast Kernels
|
| 76 |
+
|
| 77 |
+
@torch.compile(fullgraph=True)
|
| 78 |
+
def hierarchical_sae_loss(
|
| 79 |
+
indices: torch.Tensor, # [B, K]
|
| 80 |
+
weight: torch.Tensor, # [F, D]
|
| 81 |
+
vals: torch.Tensor, # [B, K]
|
| 82 |
+
bias: torch.Tensor, # [D]
|
| 83 |
+
target: torch.Tensor, # [B, D]
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
emb = weight[indices].to(torch.float32) # [K, D]
|
| 86 |
+
recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
|
| 87 |
+
diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
|
| 88 |
+
loss = diff.pow(2).mean()
|
| 89 |
+
return loss
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
B = 2048
|
| 93 |
+
K = 256
|
| 94 |
+
F = 1024 * 128
|
| 95 |
+
D = 1024
|
| 96 |
+
warmup = 5
|
| 97 |
+
dtype = torch.float32
|
| 98 |
+
|
| 99 |
+
vals = None
|
| 100 |
+
decoder = None
|
| 101 |
+
bias = None
|
| 102 |
+
target = None
|
| 103 |
+
indices = None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def init_parameters():
|
| 107 |
+
global vals, decoder, bias, target, indices
|
| 108 |
+
vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
|
| 109 |
+
decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
|
| 110 |
+
bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
|
| 111 |
+
target = torch.randn(B, D, dtype=dtype, device="cuda")
|
| 112 |
+
indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
timing_kernel = []
|
| 116 |
+
timing_vanilla = []
|
| 117 |
+
torch.cuda.reset_peak_memory_stats()
|
| 118 |
+
loss_kernel_list = torch.zeros((100,))
|
| 119 |
+
loss_vanilla_list = torch.zeros((100,))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def zero_grad():
|
| 123 |
+
vals.grad = None
|
| 124 |
+
decoder.grad = None
|
| 125 |
+
bias.grad = None
|
| 126 |
+
torch.cuda.empty_cache()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
for i in range(100 + warmup):
|
| 130 |
+
init_parameters()
|
| 131 |
+
start_kernel = torch.cuda.Event(enable_timing=True)
|
| 132 |
+
end_kernel = torch.cuda.Event(enable_timing=True)
|
| 133 |
+
start_vanilla = torch.cuda.Event(enable_timing=True)
|
| 134 |
+
end_vanilla = torch.cuda.Event(enable_timing=True)
|
| 135 |
+
|
| 136 |
+
start_kernel.record()
|
| 137 |
+
loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 138 |
+
loss_kernel.backward()
|
| 139 |
+
end_kernel.record()
|
| 140 |
+
|
| 141 |
+
zero_grad()
|
| 142 |
+
start_vanilla.record()
|
| 143 |
+
loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 144 |
+
loss_vanilla.backward()
|
| 145 |
+
end_vanilla.record()
|
| 146 |
+
if i >= warmup:
|
| 147 |
+
torch.cuda.synchronize()
|
| 148 |
+
timing_kernel.append(start_kernel.elapsed_time(end_kernel))
|
| 149 |
+
timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
|
| 150 |
+
loss_kernel_list[i-warmup] = loss_kernel.detach()
|
| 151 |
+
loss_vanilla_list[i-warmup] = loss_vanilla.detach()
|
| 152 |
+
zero_grad()
|
| 153 |
+
|
| 154 |
+
if torch.allclose(loss_kernel, loss_vanilla):
|
| 155 |
+
print("✅ Outputs are close! Everything is good! 🎉")
|
| 156 |
+
else:
|
| 157 |
+
print("❌ Outputs mismatch... ⚠️🤔")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
|
| 161 |
+
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
|
| 162 |
+
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
Run it with `uv run https://huggingface.co/t-tech/flex-sae/resolve/main/example.py`.
|
| 166 |
+
|
| 167 |
+
## Performance
|
| 168 |
+
Benchmarks were collected on a workload with dictionary size $F = 65 536$, embedding dimension $D = 2304$, and sparsity budgets $K \in \{32, 64, 128\}$. Latency is reported as time per training step (milliseconds) and memory as peak device usage (GiB).
|
| 169 |
+
|
| 170 |
+
| Decoder backend | K=32 (ms / GiB) | K=64 (ms / GiB) | K=128 (ms / GiB) |
|
| 171 |
+
| --- | --- | --- | --- |
|
| 172 |
+
| **Pure torch-compiled** | | | |
|
| 173 |
+
| TopK | 8.787 / 2.92 | 11.746 / 2.92 | 18.877 / 2.93 |
|
| 174 |
+
| HierarchicalTopK | 12.824 / 6.29 | 23.379 / 10.79 | 43.851 / 19.80 |
|
| 175 |
+
| **Triton kernels** | | | |
|
| 176 |
+
| TopK | 5.576 / 2.92 | 6.339 / 2.92 | 7.961 / 2.93 |
|
| 177 |
+
| HierarchicalTopK | **6.696 / 2.92** | **7.995 / 2.92** | **10.609 / 2.93** |
|
| 178 |
+
|
| 179 |
+
Across the evaluated sparsity budgets the fused Triton HierarchicalTopK kernel matches TopK kernels on memory use while remaining consistently faster than the reference torch implementation.
|
| 180 |
+
|
| 181 |
+
## License & Attribution
|
| 182 |
+
- All files except `torch-ext/flex_sae/topk_kernels.py` are released under the [Apache License 2.0](LICENSE).
|
| 183 |
+
- `torch-ext/flex_sae/topk_kernels.py` includes code adapted from Facebook Research's [memory](https://github.com/facebookresearch/memory) project, originally published under the Creative Commons Attribution-NonCommercial 4.0 International License. That component therefore remains available for non-commercial use only; see [NOTICE](NOTICE) for details.
|
| 184 |
+
|
| 185 |
+
## Citation
|
| 186 |
+
```bibtex
|
| 187 |
+
@misc{balagansky2025trainsparseautoencodermultiple,
|
| 188 |
+
title={Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy},
|
| 189 |
+
author={Nikita Balagansky and Yaroslav Aksenov and Daniil Laptev and Vadim Kurochkin and Gleb Gerasimov and Nikita Koryagin and Daniil Gavrilov},
|
| 190 |
+
year={2025},
|
| 191 |
+
eprint={2505.24473},
|
| 192 |
+
archivePrefix={arXiv},
|
| 193 |
+
primaryClass={cs.LG},
|
| 194 |
+
url={https://arxiv.org/abs/2505.24473},
|
| 195 |
+
}
|
| 196 |
+
```
|
build.toml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "flex_sae"
|
| 3 |
+
universal = true
|
build/torch-universal/flex_sae/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TopK and HierarchicalTopK SAE decoder Triton kernels
|
| 2 |
+
# Copyright 2025 T-Tech
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from .topk_kernels import triton_topk_sae_loss, topk_sae_loss
|
| 6 |
+
from .hierarchical_kernels import triton_hierarchical_sae_loss, hierarchical_sae_loss
|
| 7 |
+
|
| 8 |
+
__kernel_metadata__ = {
|
| 9 |
+
"license": "Apache-2.0 (with CC-BY-NC-4.0 component; see NOTICE)",
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"__kernel_metadata__",
|
| 14 |
+
"topk_sae_loss",
|
| 15 |
+
"triton_topk_sae_loss",
|
| 16 |
+
"hierarchical_sae_loss",
|
| 17 |
+
"triton_hierarchical_sae_loss",
|
| 18 |
+
]
|
build/torch-universal/flex_sae/_ops.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
ops = torch.ops._flex_sae_20250924130857
|
| 3 |
+
|
| 4 |
+
def add_op_namespace_prefix(op_name: str):
|
| 5 |
+
"""
|
| 6 |
+
Prefix op by namespace.
|
| 7 |
+
"""
|
| 8 |
+
return f"_flex_sae_20250924130857::{op_name}"
|
build/torch-universal/flex_sae/hierarchical_kernels.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HierarchicalTopK SAE decoder Triton kernels
|
| 2 |
+
# Copyright 2025 T-Tech
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@triton.jit
|
| 13 |
+
def hierarchical_sae_forward_kernel(
|
| 14 |
+
loss_per_batch_ptr, # [B]
|
| 15 |
+
final_recon_ptr, # [B, D]
|
| 16 |
+
indices_ptr, # [B, K]
|
| 17 |
+
weight_ptr, # [F, D]
|
| 18 |
+
bias_ptr, # [D]
|
| 19 |
+
vals_ptr, # [B, K]
|
| 20 |
+
target_ptr, # [B, D]
|
| 21 |
+
B: tl.constexpr,
|
| 22 |
+
D: tl.constexpr,
|
| 23 |
+
K: tl.constexpr,
|
| 24 |
+
BLOCK_D: tl.constexpr,
|
| 25 |
+
LOOP_NUM_STAGES: tl.constexpr,
|
| 26 |
+
BLOCK_B: tl.constexpr,
|
| 27 |
+
):
|
| 28 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 29 |
+
tl.static_assert((B % BLOCK_B) == 0)
|
| 30 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 31 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 32 |
+
tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
|
| 33 |
+
|
| 34 |
+
pid_b = tl.program_id(axis=0).to(tl.int64)
|
| 35 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 36 |
+
|
| 37 |
+
batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 38 |
+
batch_offsets = batch_offsets.to(tl.int64)
|
| 39 |
+
tl.multiple_of(batch_offsets, BLOCK_B)
|
| 40 |
+
|
| 41 |
+
offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 42 |
+
offset_d = offset_d.to(tl.int64)
|
| 43 |
+
|
| 44 |
+
tl.multiple_of(offset_d, BLOCK_D)
|
| 45 |
+
tl.max_contiguous(offset_d, BLOCK_D)
|
| 46 |
+
|
| 47 |
+
batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
|
| 48 |
+
|
| 49 |
+
bias_tile = tl.load(bias_ptr + offset_d).to(tl.float32)
|
| 50 |
+
|
| 51 |
+
recon = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 52 |
+
recon += bias_tile[None, :]
|
| 53 |
+
|
| 54 |
+
target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
|
| 55 |
+
|
| 56 |
+
loss_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 57 |
+
|
| 58 |
+
row_idx_ptr = indices_ptr + batch_offsets * K
|
| 59 |
+
row_val_ptr = vals_ptr + batch_offsets * K
|
| 60 |
+
|
| 61 |
+
idx = tl.load(row_idx_ptr).to(tl.int64)
|
| 62 |
+
val = tl.load(row_val_ptr).to(tl.float32)
|
| 63 |
+
val = val[:, None]
|
| 64 |
+
weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 65 |
+
|
| 66 |
+
for t in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
|
| 67 |
+
recon += weight_tile * val
|
| 68 |
+
diff = recon - target
|
| 69 |
+
loss_accum += diff * diff
|
| 70 |
+
|
| 71 |
+
if t + 1 < K:
|
| 72 |
+
idx_next = tl.load(row_idx_ptr + (t + 1)).to(tl.int64)
|
| 73 |
+
val_next = tl.load(row_val_ptr + (t + 1)).to(tl.float32)
|
| 74 |
+
weight_next = tl.load(weight_ptr + idx_next[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 75 |
+
|
| 76 |
+
idx = idx_next
|
| 77 |
+
val = val_next[:, None]
|
| 78 |
+
weight_tile = weight_next
|
| 79 |
+
|
| 80 |
+
loss_tile = tl.sum(loss_accum, axis=1)
|
| 81 |
+
tl.atomic_add(
|
| 82 |
+
loss_per_batch_ptr + batch_offsets,
|
| 83 |
+
loss_tile,
|
| 84 |
+
sem="relaxed",
|
| 85 |
+
)
|
| 86 |
+
tl.store(
|
| 87 |
+
final_recon_ptr + batch_d_offset,
|
| 88 |
+
recon,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@triton.jit
|
| 93 |
+
def hierarchical_sae_backward_kernel(
|
| 94 |
+
weight_grad_ptr, # [F, D]
|
| 95 |
+
vals_grad_ptr, # [B, K]
|
| 96 |
+
bias_grad_ptr, # [D]
|
| 97 |
+
final_recon_ptr, # [B, D]
|
| 98 |
+
indices_ptr, # [B, K]
|
| 99 |
+
weight_ptr, # [F, D]
|
| 100 |
+
vals_ptr, # [B, K]
|
| 101 |
+
target_ptr, # [B, D]
|
| 102 |
+
B: tl.constexpr,
|
| 103 |
+
D: tl.constexpr,
|
| 104 |
+
K: tl.constexpr,
|
| 105 |
+
BLOCK_D: tl.constexpr,
|
| 106 |
+
LOOP_NUM_STAGES: tl.constexpr,
|
| 107 |
+
BLOCK_B: tl.constexpr,
|
| 108 |
+
):
|
| 109 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 110 |
+
tl.static_assert((B % BLOCK_B) == 0)
|
| 111 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 112 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 113 |
+
tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
|
| 114 |
+
|
| 115 |
+
pid_b = tl.program_id(axis=0).to(tl.int64)
|
| 116 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 117 |
+
|
| 118 |
+
batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 119 |
+
batch_offsets = batch_offsets.to(tl.int64)
|
| 120 |
+
tl.multiple_of(batch_offsets, BLOCK_B)
|
| 121 |
+
|
| 122 |
+
offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 123 |
+
offset_d = offset_d.to(tl.int64)
|
| 124 |
+
|
| 125 |
+
tl.multiple_of(offset_d, BLOCK_D)
|
| 126 |
+
tl.max_contiguous(offset_d, BLOCK_D)
|
| 127 |
+
|
| 128 |
+
batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
|
| 129 |
+
|
| 130 |
+
recon = tl.load(final_recon_ptr + batch_d_offset).to(tl.float32)
|
| 131 |
+
target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
|
| 132 |
+
|
| 133 |
+
suffix = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 134 |
+
bias_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 135 |
+
scale = tl.full((), 2.0 / (B * K * D), dtype=tl.float32)
|
| 136 |
+
|
| 137 |
+
row_idx_ptr = indices_ptr + batch_offsets * K
|
| 138 |
+
row_val_ptr = vals_ptr + batch_offsets * K
|
| 139 |
+
k_offsets = tl.arange(0, K)
|
| 140 |
+
val_grad_tile = tl.zeros([BLOCK_B, K], dtype=tl.float32)
|
| 141 |
+
|
| 142 |
+
step = K - 1
|
| 143 |
+
idx = tl.load(row_idx_ptr + step).to(tl.int64)
|
| 144 |
+
val = tl.load(row_val_ptr + step).to(tl.float32)
|
| 145 |
+
weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 146 |
+
|
| 147 |
+
for _ in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
|
| 148 |
+
curr_step = step
|
| 149 |
+
|
| 150 |
+
diff = recon - target
|
| 151 |
+
grad_curr = diff * scale
|
| 152 |
+
suffix += grad_curr
|
| 153 |
+
bias_accum += grad_curr
|
| 154 |
+
|
| 155 |
+
val_broadcast = val[:, None]
|
| 156 |
+
contrib = suffix * val_broadcast
|
| 157 |
+
tl.atomic_add(
|
| 158 |
+
weight_grad_ptr + idx[:, None] * D + offset_d[None, :],
|
| 159 |
+
contrib,
|
| 160 |
+
sem="relaxed",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
dot_partial = tl.sum(weight_tile * suffix, axis=1)
|
| 164 |
+
mask_curr = k_offsets[None, :] == curr_step
|
| 165 |
+
val_grad_tile = tl.where(mask_curr, dot_partial[:, None], val_grad_tile)
|
| 166 |
+
|
| 167 |
+
recon -= weight_tile * val_broadcast
|
| 168 |
+
|
| 169 |
+
if curr_step > 0:
|
| 170 |
+
step = curr_step - 1
|
| 171 |
+
idx = tl.load(row_idx_ptr + step).to(tl.int64)
|
| 172 |
+
val = tl.load(row_val_ptr + step).to(tl.float32)
|
| 173 |
+
weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 174 |
+
|
| 175 |
+
bias_grad_tile = tl.sum(bias_accum, axis=0)
|
| 176 |
+
tl.atomic_add(
|
| 177 |
+
bias_grad_ptr + offset_d,
|
| 178 |
+
bias_grad_tile,
|
| 179 |
+
sem="relaxed",
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
row_val_grad_ptr = vals_grad_ptr + batch_offsets[:, None] * K + k_offsets[None, :]
|
| 183 |
+
tl.atomic_add(
|
| 184 |
+
row_val_grad_ptr,
|
| 185 |
+
val_grad_tile,
|
| 186 |
+
sem="relaxed",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _hierarchical_sae_forward(
|
| 191 |
+
indices: torch.Tensor, # [B, K]
|
| 192 |
+
weight: torch.Tensor, # [F, D]
|
| 193 |
+
vals: torch.Tensor, # [B, K]
|
| 194 |
+
bias: torch.Tensor, # [D]
|
| 195 |
+
target: torch.Tensor, # [B, D]
|
| 196 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 197 |
+
B, K = indices.shape
|
| 198 |
+
F, D = weight.shape
|
| 199 |
+
|
| 200 |
+
loss_per_batch = torch.zeros((B,), dtype=torch.float32, device=weight.device)
|
| 201 |
+
final_recon = torch.empty((B, D), dtype=torch.float32, device=weight.device)
|
| 202 |
+
|
| 203 |
+
def _forward_grid(meta):
|
| 204 |
+
return (
|
| 205 |
+
B // meta["BLOCK_B"],
|
| 206 |
+
D // meta["BLOCK_D"],
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
hierarchical_sae_forward_kernel[_forward_grid](
|
| 210 |
+
loss_per_batch,
|
| 211 |
+
final_recon,
|
| 212 |
+
indices,
|
| 213 |
+
weight,
|
| 214 |
+
bias,
|
| 215 |
+
vals,
|
| 216 |
+
target,
|
| 217 |
+
B=B,
|
| 218 |
+
D=D,
|
| 219 |
+
K=K,
|
| 220 |
+
BLOCK_D=64,
|
| 221 |
+
LOOP_NUM_STAGES=4,
|
| 222 |
+
BLOCK_B=1,
|
| 223 |
+
num_warps=2,
|
| 224 |
+
num_stages=2,
|
| 225 |
+
)
|
| 226 |
+
loss = loss_per_batch.sum() / (B * K * D)
|
| 227 |
+
return loss, final_recon
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _hierarchical_sae_backward(
|
| 231 |
+
indices: torch.Tensor, # [B, K]
|
| 232 |
+
weight: torch.Tensor, # [F, D]
|
| 233 |
+
vals: torch.Tensor, # [B, K]
|
| 234 |
+
target: torch.Tensor, # [B, D]
|
| 235 |
+
final_recon: torch.Tensor, # [B, D]
|
| 236 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 237 |
+
device = weight.device
|
| 238 |
+
B, K = indices.shape
|
| 239 |
+
F, D = weight.shape
|
| 240 |
+
|
| 241 |
+
dW = torch.zeros((F, D), dtype=torch.float32, device=device)
|
| 242 |
+
dVals = torch.zeros((B, K), dtype=torch.float32, device=device)
|
| 243 |
+
db = torch.zeros((D,), dtype=torch.float32, device=device)
|
| 244 |
+
|
| 245 |
+
def _backward_grid(meta):
|
| 246 |
+
return (
|
| 247 |
+
B // meta["BLOCK_B"],
|
| 248 |
+
D // meta["BLOCK_D"],
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
hierarchical_sae_backward_kernel[_backward_grid](
|
| 252 |
+
dW,
|
| 253 |
+
dVals,
|
| 254 |
+
db,
|
| 255 |
+
final_recon,
|
| 256 |
+
indices,
|
| 257 |
+
weight,
|
| 258 |
+
vals,
|
| 259 |
+
target,
|
| 260 |
+
B=B,
|
| 261 |
+
D=D,
|
| 262 |
+
K=K,
|
| 263 |
+
BLOCK_D=32,
|
| 264 |
+
LOOP_NUM_STAGES=16,
|
| 265 |
+
BLOCK_B=16,
|
| 266 |
+
num_warps=8,
|
| 267 |
+
num_stages=8,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return dW, dVals, db
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class HierarchicalSAELossFunction(torch.autograd.Function):
|
| 274 |
+
@staticmethod
|
| 275 |
+
@torch.amp.custom_fwd(device_type="cuda")
|
| 276 |
+
def forward(
|
| 277 |
+
ctx,
|
| 278 |
+
indices: torch.Tensor, # [B, K]
|
| 279 |
+
weight: torch.Tensor, # [F, D]
|
| 280 |
+
vals: torch.Tensor, # [B, K]
|
| 281 |
+
bias: torch.Tensor, # [D]
|
| 282 |
+
target: torch.Tensor, # [B, D]
|
| 283 |
+
):
|
| 284 |
+
loss, final_recon = _hierarchical_sae_forward(indices, weight, vals, bias, target)
|
| 285 |
+
ctx.save_for_backward(indices, weight, vals, target, final_recon)
|
| 286 |
+
return loss
|
| 287 |
+
|
| 288 |
+
@staticmethod
|
| 289 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
| 290 |
+
def backward(ctx, grad):
|
| 291 |
+
indices, weight, vals, target, final_recon = ctx.saved_tensors
|
| 292 |
+
dW, dVals, db = _hierarchical_sae_backward(indices, weight, vals, target, final_recon)
|
| 293 |
+
|
| 294 |
+
if grad is not None:
|
| 295 |
+
dW.mul_(grad)
|
| 296 |
+
dVals.mul_(grad)
|
| 297 |
+
db.mul_(grad)
|
| 298 |
+
|
| 299 |
+
return None, dW, dVals, db, None
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def triton_hierarchical_sae_loss(
|
| 303 |
+
indices: torch.Tensor, # [B, K]
|
| 304 |
+
weight: torch.Tensor, # [F, D]
|
| 305 |
+
vals: torch.Tensor, # [B, K]
|
| 306 |
+
bias: torch.Tensor, # [D]
|
| 307 |
+
target: torch.Tensor, # [B, D]
|
| 308 |
+
) -> torch.Tensor:
|
| 309 |
+
return HierarchicalSAELossFunction.apply(indices, weight, vals, bias, target)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def hierarchical_sae_loss(
|
| 313 |
+
indices: torch.Tensor, # [B, K]
|
| 314 |
+
weight: torch.Tensor, # [F, D]
|
| 315 |
+
vals: torch.Tensor, # [B, K]
|
| 316 |
+
bias: torch.Tensor, # [D]
|
| 317 |
+
target: torch.Tensor, # [B, D]
|
| 318 |
+
) -> torch.Tensor:
|
| 319 |
+
emb = weight[indices].to(torch.float32) # [K, D]
|
| 320 |
+
recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
|
| 321 |
+
diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
|
| 322 |
+
loss = diff.pow(2).mean()
|
| 323 |
+
return loss
|
build/torch-universal/flex_sae/topk_kernels.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TopK SAE decoder Triton kernels
|
| 2 |
+
# Copyright 2025 T-Tech
|
| 3 |
+
# This code is adapted from Facebook Research under the
|
| 4 |
+
# Creative Commons Attribution-NonCommercial 4.0 International License.
|
| 5 |
+
# Original code can be found at: https://github.com/facebookresearch/memory
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import triton
|
| 12 |
+
import triton.language as tl
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@triton.jit
|
| 16 |
+
def embedding_bag_forward_kernel(
|
| 17 |
+
out_ptr, # [B, D]
|
| 18 |
+
indices_ptr, # [B, K]
|
| 19 |
+
weight_ptr, # [F, D]
|
| 20 |
+
vals_ptr, # [B, K]
|
| 21 |
+
D: tl.constexpr,
|
| 22 |
+
K: tl.constexpr,
|
| 23 |
+
BLOCK_D: tl.constexpr,
|
| 24 |
+
):
|
| 25 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 26 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 27 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 28 |
+
|
| 29 |
+
b = tl.program_id(axis=0).to(tl.int64)
|
| 30 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 31 |
+
|
| 32 |
+
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 33 |
+
|
| 34 |
+
out_value = tl.zeros([BLOCK_D], dtype=tl.float32)
|
| 35 |
+
for i in tl.range(K):
|
| 36 |
+
my_index = tl.load(indices_ptr + b * K + i).to(tl.int64)
|
| 37 |
+
my_scaling = tl.load(vals_ptr + b * K + i)
|
| 38 |
+
w_tile = tl.load(weight_ptr + my_index * D + off_d).to(tl.float32)
|
| 39 |
+
out_value += w_tile * my_scaling
|
| 40 |
+
|
| 41 |
+
tl.store(out_ptr + b * D + off_d, out_value)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def embedding_bag_forward(
|
| 45 |
+
indices: torch.Tensor, # [B, K]
|
| 46 |
+
weight: torch.Tensor, # [F, D]
|
| 47 |
+
vals: torch.Tensor, # [B, K]
|
| 48 |
+
) -> torch.Tensor:
|
| 49 |
+
B, K = indices.shape
|
| 50 |
+
D = weight.shape[1]
|
| 51 |
+
|
| 52 |
+
trt_out = torch.empty([B, D], dtype=weight.dtype, device=weight.device)
|
| 53 |
+
|
| 54 |
+
def _forward_grid(meta):
|
| 55 |
+
return (B, D // meta["BLOCK_D"])
|
| 56 |
+
|
| 57 |
+
embedding_bag_forward_kernel[_forward_grid](
|
| 58 |
+
trt_out,
|
| 59 |
+
indices,
|
| 60 |
+
weight,
|
| 61 |
+
vals,
|
| 62 |
+
D=D,
|
| 63 |
+
K=K,
|
| 64 |
+
BLOCK_D=64,
|
| 65 |
+
num_warps=1,
|
| 66 |
+
num_stages=1,
|
| 67 |
+
)
|
| 68 |
+
return trt_out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@triton.jit
|
| 72 |
+
def count_per_embedding_kernel(
|
| 73 |
+
count_per_emb_ptr, # [F + 1]
|
| 74 |
+
indices_ptr, # [B, K]
|
| 75 |
+
K: tl.constexpr,
|
| 76 |
+
):
|
| 77 |
+
batch_id = tl.program_id(axis=0).to(tl.int64)
|
| 78 |
+
for t in tl.range(K):
|
| 79 |
+
embedding_id = tl.load(indices_ptr + batch_id * K + t)
|
| 80 |
+
tl.atomic_add(count_per_emb_ptr + embedding_id + 1, 1, sem="relaxed")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@triton.jit
|
| 84 |
+
def map_embeddings_and_outputs_kernel(
|
| 85 |
+
reverse_mapping_ptr, # [B * K]
|
| 86 |
+
mapping_write_pos_ptr, # [F]
|
| 87 |
+
indices_ptr, # [B, K]
|
| 88 |
+
K: tl.constexpr,
|
| 89 |
+
):
|
| 90 |
+
batch_id = tl.program_id(axis=0).to(tl.int64)
|
| 91 |
+
for t in tl.range(K):
|
| 92 |
+
embedding_id = tl.load(indices_ptr + batch_id * K + t)
|
| 93 |
+
write_pos = tl.atomic_add(mapping_write_pos_ptr + embedding_id, 1, sem="relaxed")
|
| 94 |
+
tl.store(reverse_mapping_ptr + write_pos, batch_id * K + t)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@triton.jit
|
| 98 |
+
def aggregate_gradient_for_embedding_kernel(
|
| 99 |
+
weight_grad_ptr, # [F, D]
|
| 100 |
+
vals_grad_ptr, # [B, K]
|
| 101 |
+
weight_ptr, # [F, D]
|
| 102 |
+
emb_begin_pos_ptr, # [F + 1]
|
| 103 |
+
reverse_mapping_ptr, # [B * K]
|
| 104 |
+
vals_ptr, # [B, K]
|
| 105 |
+
gradient_ptr, # [B, D]
|
| 106 |
+
D: tl.constexpr,
|
| 107 |
+
K: tl.constexpr,
|
| 108 |
+
BLOCK_D: tl.constexpr,
|
| 109 |
+
):
|
| 110 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 111 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 112 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 113 |
+
|
| 114 |
+
e = tl.program_id(axis=0).to(tl.int64)
|
| 115 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 116 |
+
|
| 117 |
+
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 118 |
+
|
| 119 |
+
begin = tl.load(emb_begin_pos_ptr + e)
|
| 120 |
+
end = tl.load(emb_begin_pos_ptr + e + 1)
|
| 121 |
+
|
| 122 |
+
w_row_tile = tl.load(weight_ptr + e * D + off_d).to(tl.float32)
|
| 123 |
+
w_grad_tile = tl.zeros([BLOCK_D], dtype=tl.float32)
|
| 124 |
+
|
| 125 |
+
for idx in tl.range(begin, end):
|
| 126 |
+
out_linear = tl.load(reverse_mapping_ptr + idx).to(tl.int64)
|
| 127 |
+
b = out_linear // K
|
| 128 |
+
|
| 129 |
+
psw = tl.load(vals_ptr + out_linear)
|
| 130 |
+
g_tile = tl.load(gradient_ptr + b * D + off_d).to(tl.float32)
|
| 131 |
+
|
| 132 |
+
w_grad_tile += psw * g_tile
|
| 133 |
+
|
| 134 |
+
psw_grad_partial = tl.sum(g_tile * w_row_tile)
|
| 135 |
+
tl.atomic_add(vals_grad_ptr + out_linear, psw_grad_partial, sem="relaxed")
|
| 136 |
+
|
| 137 |
+
tl.store(weight_grad_ptr + e * D + off_d, w_grad_tile)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def embedding_bag_backward(
|
| 141 |
+
indices: torch.Tensor, # [B, K]
|
| 142 |
+
weight: torch.Tensor, # [F, D]
|
| 143 |
+
vals: torch.Tensor, # [B, K]
|
| 144 |
+
gradient: torch.Tensor, # [B, D]
|
| 145 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 146 |
+
F, D = weight.shape
|
| 147 |
+
B, K = indices.shape
|
| 148 |
+
|
| 149 |
+
count_per_emb = torch.zeros((F + 1,), dtype=torch.uint32, device=indices.device)
|
| 150 |
+
count_per_embedding_kernel[(B,)](count_per_emb, indices, K=K, num_warps=1)
|
| 151 |
+
|
| 152 |
+
emb_begin_pos = count_per_emb.cumsum(0) # [F + 1]
|
| 153 |
+
|
| 154 |
+
reverse_mapping = torch.empty([B * K], dtype=torch.uint32, device=indices.device)
|
| 155 |
+
assert B * K <= 2 ** (reverse_mapping.dtype.itemsize * 8) - 1
|
| 156 |
+
|
| 157 |
+
map_embeddings_and_outputs_kernel[(B,)](
|
| 158 |
+
reverse_mapping_ptr=reverse_mapping,
|
| 159 |
+
mapping_write_pos_ptr=emb_begin_pos.clone(),
|
| 160 |
+
indices_ptr=indices,
|
| 161 |
+
K=K,
|
| 162 |
+
num_warps=1,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
weight_grad = torch.empty_like(weight, dtype=torch.float32) # [F, D]
|
| 166 |
+
vals_grad = torch.zeros_like(vals, dtype=torch.float32) # [B, K]
|
| 167 |
+
|
| 168 |
+
def _forward_grid(meta):
|
| 169 |
+
return (F, D // meta["BLOCK_D"])
|
| 170 |
+
|
| 171 |
+
aggregate_gradient_for_embedding_kernel[_forward_grid](
|
| 172 |
+
weight_grad_ptr=weight_grad,
|
| 173 |
+
vals_grad_ptr=vals_grad,
|
| 174 |
+
weight_ptr=weight,
|
| 175 |
+
emb_begin_pos_ptr=emb_begin_pos,
|
| 176 |
+
reverse_mapping_ptr=reverse_mapping,
|
| 177 |
+
vals_ptr=vals,
|
| 178 |
+
gradient_ptr=gradient,
|
| 179 |
+
D=D,
|
| 180 |
+
K=K,
|
| 181 |
+
BLOCK_D=256,
|
| 182 |
+
num_warps=1,
|
| 183 |
+
num_stages=2,
|
| 184 |
+
)
|
| 185 |
+
return weight_grad, vals_grad
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class xFormersEmbeddingBag(torch.autograd.Function):
|
| 189 |
+
@staticmethod
|
| 190 |
+
@torch.amp.custom_fwd(device_type="cuda")
|
| 191 |
+
def forward(
|
| 192 |
+
ctx,
|
| 193 |
+
indices: torch.Tensor, # [B, K]
|
| 194 |
+
weight: torch.Tensor, # [F, D]
|
| 195 |
+
vals: torch.Tensor, # [B, K]
|
| 196 |
+
) -> torch.Tensor:
|
| 197 |
+
ctx.save_for_backward(indices, weight, vals)
|
| 198 |
+
return embedding_bag_forward(indices, weight, vals) # [B, D]
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
| 202 |
+
def backward(ctx, gradient):
|
| 203 |
+
indices, weight, vals = ctx.saved_tensors
|
| 204 |
+
weight_g, vals_g = embedding_bag_backward(
|
| 205 |
+
indices,
|
| 206 |
+
weight,
|
| 207 |
+
vals,
|
| 208 |
+
gradient,
|
| 209 |
+
)
|
| 210 |
+
return None, weight_g, vals_g
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def triton_topk_sae_loss(
|
| 214 |
+
indices: torch.Tensor, # [B, K]
|
| 215 |
+
weight: torch.Tensor, # [F, D]
|
| 216 |
+
vals: torch.Tensor, # [B, K]
|
| 217 |
+
bias: torch.Tensor, # [D]
|
| 218 |
+
target: torch.Tensor, # [B, D]
|
| 219 |
+
) -> torch.Tensor:
|
| 220 |
+
recon = bias.to(torch.float32) + xFormersEmbeddingBag.apply(indices, weight, vals)
|
| 221 |
+
diff = recon.to(torch.float32) - target.to(torch.float32)
|
| 222 |
+
loss = diff.pow(2).mean()
|
| 223 |
+
return loss
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def topk_sae_loss(
|
| 227 |
+
indices: torch.Tensor, # [B, K]
|
| 228 |
+
weight: torch.Tensor, # [F, D]
|
| 229 |
+
vals: torch.Tensor, # [B, K]
|
| 230 |
+
bias: torch.Tensor, # [D]
|
| 231 |
+
target: torch.Tensor, # [B, D]
|
| 232 |
+
) -> torch.Tensor:
|
| 233 |
+
emb = weight[indices].to(torch.float32) # [K, D]
|
| 234 |
+
recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).sum(dim=1)
|
| 235 |
+
diff = recon.to(torch.float32) - target.to(torch.float32)
|
| 236 |
+
loss = diff.pow(2).mean()
|
| 237 |
+
return loss
|
example.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "torch",
|
| 4 |
+
# "numpy",
|
| 5 |
+
# "kernels",
|
| 6 |
+
# ]
|
| 7 |
+
# ///
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from kernels import get_kernel
|
| 12 |
+
|
| 13 |
+
flex = get_kernel("t-tech/flex-sae") #Fast Kernels
|
| 14 |
+
|
| 15 |
+
@torch.compile(fullgraph=True)
|
| 16 |
+
def hierarchical_sae_loss(
|
| 17 |
+
indices: torch.Tensor, # [B, K]
|
| 18 |
+
weight: torch.Tensor, # [F, D]
|
| 19 |
+
vals: torch.Tensor, # [B, K]
|
| 20 |
+
bias: torch.Tensor, # [D]
|
| 21 |
+
target: torch.Tensor, # [B, D]
|
| 22 |
+
) -> torch.Tensor:
|
| 23 |
+
emb = weight[indices].to(torch.float32) # [K, D]
|
| 24 |
+
recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
|
| 25 |
+
diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
|
| 26 |
+
loss = diff.pow(2).mean()
|
| 27 |
+
return loss
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
B = 2048
|
| 31 |
+
K = 256
|
| 32 |
+
F = 1024 * 128
|
| 33 |
+
D = 1024
|
| 34 |
+
warmup = 5
|
| 35 |
+
dtype = torch.float32
|
| 36 |
+
|
| 37 |
+
vals = None
|
| 38 |
+
decoder = None
|
| 39 |
+
bias = None
|
| 40 |
+
target = None
|
| 41 |
+
indices = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def init_parameters():
|
| 45 |
+
global vals, decoder, bias, target, indices
|
| 46 |
+
vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
|
| 47 |
+
decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
|
| 48 |
+
bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
|
| 49 |
+
target = torch.randn(B, D, dtype=dtype, device="cuda")
|
| 50 |
+
indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
timing_kernel = []
|
| 54 |
+
timing_vanilla = []
|
| 55 |
+
torch.cuda.reset_peak_memory_stats()
|
| 56 |
+
loss_kernel_list = torch.zeros((100,))
|
| 57 |
+
loss_vanilla_list = torch.zeros((100,))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def zero_grad():
|
| 61 |
+
vals.grad = None
|
| 62 |
+
decoder.grad = None
|
| 63 |
+
bias.grad = None
|
| 64 |
+
torch.cuda.empty_cache()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
for i in range(100 + warmup):
|
| 68 |
+
init_parameters()
|
| 69 |
+
start_kernel = torch.cuda.Event(enable_timing=True)
|
| 70 |
+
end_kernel = torch.cuda.Event(enable_timing=True)
|
| 71 |
+
start_vanilla = torch.cuda.Event(enable_timing=True)
|
| 72 |
+
end_vanilla = torch.cuda.Event(enable_timing=True)
|
| 73 |
+
|
| 74 |
+
start_kernel.record()
|
| 75 |
+
loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 76 |
+
loss_kernel.backward()
|
| 77 |
+
end_kernel.record()
|
| 78 |
+
|
| 79 |
+
zero_grad()
|
| 80 |
+
start_vanilla.record()
|
| 81 |
+
loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 82 |
+
loss_vanilla.backward()
|
| 83 |
+
end_vanilla.record()
|
| 84 |
+
if i >= warmup:
|
| 85 |
+
torch.cuda.synchronize()
|
| 86 |
+
timing_kernel.append(start_kernel.elapsed_time(end_kernel))
|
| 87 |
+
timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
|
| 88 |
+
loss_kernel_list[i-warmup] = loss_kernel.detach()
|
| 89 |
+
loss_vanilla_list[i-warmup] = loss_vanilla.detach()
|
| 90 |
+
zero_grad()
|
| 91 |
+
|
| 92 |
+
if torch.allclose(loss_kernel, loss_vanilla):
|
| 93 |
+
print("✅ Outputs are close! Everything is good! 🎉")
|
| 94 |
+
else:
|
| 95 |
+
print("❌ Outputs mismatch... ⚠️🤔")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
|
| 99 |
+
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
|
| 100 |
+
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
|
flake.lock
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1747046372,
|
| 6 |
+
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-compat_2": {
|
| 19 |
+
"locked": {
|
| 20 |
+
"lastModified": 1747046372,
|
| 21 |
+
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 22 |
+
"owner": "edolstra",
|
| 23 |
+
"repo": "flake-compat",
|
| 24 |
+
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 25 |
+
"type": "github"
|
| 26 |
+
},
|
| 27 |
+
"original": {
|
| 28 |
+
"owner": "edolstra",
|
| 29 |
+
"repo": "flake-compat",
|
| 30 |
+
"type": "github"
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"flake-utils": {
|
| 34 |
+
"inputs": {
|
| 35 |
+
"systems": "systems"
|
| 36 |
+
},
|
| 37 |
+
"locked": {
|
| 38 |
+
"lastModified": 1731533236,
|
| 39 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 40 |
+
"owner": "numtide",
|
| 41 |
+
"repo": "flake-utils",
|
| 42 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 43 |
+
"type": "github"
|
| 44 |
+
},
|
| 45 |
+
"original": {
|
| 46 |
+
"owner": "numtide",
|
| 47 |
+
"repo": "flake-utils",
|
| 48 |
+
"type": "github"
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"flake-utils_2": {
|
| 52 |
+
"inputs": {
|
| 53 |
+
"systems": "systems_2"
|
| 54 |
+
},
|
| 55 |
+
"locked": {
|
| 56 |
+
"lastModified": 1731533236,
|
| 57 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 58 |
+
"owner": "numtide",
|
| 59 |
+
"repo": "flake-utils",
|
| 60 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 61 |
+
"type": "github"
|
| 62 |
+
},
|
| 63 |
+
"original": {
|
| 64 |
+
"owner": "numtide",
|
| 65 |
+
"repo": "flake-utils",
|
| 66 |
+
"type": "github"
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"hf-nix": {
|
| 70 |
+
"inputs": {
|
| 71 |
+
"flake-compat": "flake-compat_2",
|
| 72 |
+
"flake-utils": "flake-utils_2",
|
| 73 |
+
"nixpkgs": "nixpkgs"
|
| 74 |
+
},
|
| 75 |
+
"locked": {
|
| 76 |
+
"lastModified": 1757675377,
|
| 77 |
+
"narHash": "sha256-JQKZOI1ZYO4faJnanuoTXziSmqzXe5rEFSGliWDWqWw=",
|
| 78 |
+
"owner": "huggingface",
|
| 79 |
+
"repo": "hf-nix",
|
| 80 |
+
"rev": "faf3354403a7381958d08e826c15fe30f6986a4f",
|
| 81 |
+
"type": "github"
|
| 82 |
+
},
|
| 83 |
+
"original": {
|
| 84 |
+
"owner": "huggingface",
|
| 85 |
+
"repo": "hf-nix",
|
| 86 |
+
"type": "github"
|
| 87 |
+
}
|
| 88 |
+
},
|
| 89 |
+
"kernel-builder": {
|
| 90 |
+
"inputs": {
|
| 91 |
+
"flake-compat": "flake-compat",
|
| 92 |
+
"flake-utils": "flake-utils",
|
| 93 |
+
"hf-nix": "hf-nix",
|
| 94 |
+
"nixpkgs": [
|
| 95 |
+
"kernel-builder",
|
| 96 |
+
"hf-nix",
|
| 97 |
+
"nixpkgs"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
"locked": {
|
| 101 |
+
"lastModified": 1758713083,
|
| 102 |
+
"narHash": "sha256-C7yob+hU6/IL7NDX0GVBxKKY3GPVNOwX9OU+LRCCVrk=",
|
| 103 |
+
"owner": "huggingface",
|
| 104 |
+
"repo": "kernel-builder",
|
| 105 |
+
"rev": "051fbc3dfe6afdbe01a6f15197b440d0333090cd",
|
| 106 |
+
"type": "github"
|
| 107 |
+
},
|
| 108 |
+
"original": {
|
| 109 |
+
"owner": "huggingface",
|
| 110 |
+
"repo": "kernel-builder",
|
| 111 |
+
"type": "github"
|
| 112 |
+
}
|
| 113 |
+
},
|
| 114 |
+
"nixpkgs": {
|
| 115 |
+
"locked": {
|
| 116 |
+
"lastModified": 1755963616,
|
| 117 |
+
"narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
|
| 118 |
+
"owner": "nixos",
|
| 119 |
+
"repo": "nixpkgs",
|
| 120 |
+
"rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
|
| 121 |
+
"type": "github"
|
| 122 |
+
},
|
| 123 |
+
"original": {
|
| 124 |
+
"owner": "nixos",
|
| 125 |
+
"ref": "nixos-unstable-small",
|
| 126 |
+
"repo": "nixpkgs",
|
| 127 |
+
"type": "github"
|
| 128 |
+
}
|
| 129 |
+
},
|
| 130 |
+
"root": {
|
| 131 |
+
"inputs": {
|
| 132 |
+
"kernel-builder": "kernel-builder"
|
| 133 |
+
}
|
| 134 |
+
},
|
| 135 |
+
"systems": {
|
| 136 |
+
"locked": {
|
| 137 |
+
"lastModified": 1681028828,
|
| 138 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 139 |
+
"owner": "nix-systems",
|
| 140 |
+
"repo": "default",
|
| 141 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 142 |
+
"type": "github"
|
| 143 |
+
},
|
| 144 |
+
"original": {
|
| 145 |
+
"owner": "nix-systems",
|
| 146 |
+
"repo": "default",
|
| 147 |
+
"type": "github"
|
| 148 |
+
}
|
| 149 |
+
},
|
| 150 |
+
"systems_2": {
|
| 151 |
+
"locked": {
|
| 152 |
+
"lastModified": 1681028828,
|
| 153 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 154 |
+
"owner": "nix-systems",
|
| 155 |
+
"repo": "default",
|
| 156 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 157 |
+
"type": "github"
|
| 158 |
+
},
|
| 159 |
+
"original": {
|
| 160 |
+
"owner": "nix-systems",
|
| 161 |
+
"repo": "default",
|
| 162 |
+
"type": "github"
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
},
|
| 166 |
+
"root": "root",
|
| 167 |
+
"version": 7
|
| 168 |
+
}
|
flake.nix
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
description = "Flake for TopK and HierarchicaTopK SAE Triton kernels";
|
| 3 |
+
|
| 4 |
+
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
+
};
|
| 7 |
+
|
| 8 |
+
outputs =
|
| 9 |
+
{
|
| 10 |
+
self,
|
| 11 |
+
kernel-builder,
|
| 12 |
+
}:
|
| 13 |
+
kernel-builder.lib.genFlakeOutputs {
|
| 14 |
+
path = ./.;
|
| 15 |
+
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
|
| 16 |
+
doGetKernelCheck = true;
|
| 17 |
+
};
|
| 18 |
+
}
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_all_kernels.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
import pytest
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
pytest.importorskip("torch.cuda")
|
| 6 |
+
from .test_setup import DTYPES, DTYPE_TO_TOLS, PARAMS, SEED
|
| 7 |
+
from flex_sae import (
|
| 8 |
+
triton_hierarchical_sae_loss,
|
| 9 |
+
hierarchical_sae_loss,
|
| 10 |
+
triton_topk_sae_loss,
|
| 11 |
+
topk_sae_loss,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@pytest.fixture(autouse=True)
|
| 16 |
+
def _set_cuda_default_device():
|
| 17 |
+
torch.set_default_device("cuda")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_funcs(B, K, F, D, dtype, *, kernel_foo: Callable, ref_foo: Callable):
|
| 21 |
+
if dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
|
| 22 |
+
pytest.skip("BF16 not supported on this GPU")
|
| 23 |
+
|
| 24 |
+
torch.manual_seed(SEED)
|
| 25 |
+
|
| 26 |
+
indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
|
| 27 |
+
|
| 28 |
+
vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
|
| 29 |
+
decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
|
| 30 |
+
bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
|
| 31 |
+
target = torch.randn(B, D, dtype=dtype, device="cuda")
|
| 32 |
+
|
| 33 |
+
sv_ref = vals.clone().detach().requires_grad_()
|
| 34 |
+
dec_ref = decoder.clone().detach().requires_grad_()
|
| 35 |
+
bias_ref = bias.clone().detach().requires_grad_()
|
| 36 |
+
|
| 37 |
+
loss_f = kernel_foo(indices, decoder, vals, bias, target)
|
| 38 |
+
loss_r = ref_foo(indices, dec_ref, sv_ref, bias_ref, target)
|
| 39 |
+
|
| 40 |
+
torch.testing.assert_close(loss_f, loss_r, **DTYPE_TO_TOLS[dtype])
|
| 41 |
+
|
| 42 |
+
grad_out = torch.randn((), device="cuda", dtype=torch.float32)
|
| 43 |
+
loss_f.backward(grad_out)
|
| 44 |
+
loss_r.backward(grad_out.clone())
|
| 45 |
+
|
| 46 |
+
torch.testing.assert_close(vals.grad, sv_ref.grad, **DTYPE_TO_TOLS[dtype])
|
| 47 |
+
torch.testing.assert_close(decoder.grad, dec_ref.grad, **DTYPE_TO_TOLS[dtype])
|
| 48 |
+
torch.testing.assert_close(bias.grad, bias_ref.grad, **DTYPE_TO_TOLS[dtype])
|
| 49 |
+
|
| 50 |
+
assert indices.grad is None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@pytest.mark.parametrize("B, K, F, D", PARAMS)
|
| 54 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
| 55 |
+
def test_triton_hierarchical_sae_loss_and_grads(B, K, F, D, dtype):
|
| 56 |
+
run_funcs(B, K, F, D, dtype, kernel_foo=triton_hierarchical_sae_loss, ref_foo=hierarchical_sae_loss)
|
| 57 |
+
torch.cuda.empty_cache()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@pytest.mark.parametrize("B, K, F, D", PARAMS)
|
| 61 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
| 62 |
+
def test_topk_sae_loss_and_grads(B, K, F, D, dtype):
|
| 63 |
+
run_funcs(
|
| 64 |
+
B, K, F, D, dtype, kernel_foo=triton_topk_sae_loss, ref_foo=topk_sae_loss
|
| 65 |
+
)
|
| 66 |
+
torch.cuda.empty_cache()
|
tests/test_setup.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
SEED = 1234
|
| 4 |
+
|
| 5 |
+
PARAMS = [
|
| 6 |
+
(16, 16, 64, 512),
|
| 7 |
+
(16, 32, 96, 768),
|
| 8 |
+
(16, 64, 128, 1024),
|
| 9 |
+
(32, 16, 128, 512),
|
| 10 |
+
(32, 32, 160, 768),
|
| 11 |
+
(32, 64, 192, 1024),
|
| 12 |
+
(48, 32, 176, 1024),
|
| 13 |
+
(48, 64, 224, 1280),
|
| 14 |
+
(64, 16, 192, 768),
|
| 15 |
+
(64, 32, 224, 1024),
|
| 16 |
+
(64, 128, 256, 2048),
|
| 17 |
+
(80, 32, 240, 1280),
|
| 18 |
+
(80, 64, 256, 1536),
|
| 19 |
+
(96, 32, 256, 1536),
|
| 20 |
+
(96, 64, 288, 2048),
|
| 21 |
+
(96, 128, 320, 3072),
|
| 22 |
+
(112, 64, 320, 2048),
|
| 23 |
+
(112, 128, 352, 2560),
|
| 24 |
+
(128, 32, 256, 1024),
|
| 25 |
+
(128, 64, 320, 1536),
|
| 26 |
+
(128, 128, 384, 3072),
|
| 27 |
+
(160, 64, 320, 1536),
|
| 28 |
+
(160, 128, 384, 2560),
|
| 29 |
+
(192, 64, 384, 2048),
|
| 30 |
+
(192, 128, 448, 3072),
|
| 31 |
+
(192, 256, 512, 4096),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
DTYPE_TO_TOLS = {
|
| 36 |
+
torch.float32: {"atol": 1e-4, "rtol": 1e-3},
|
| 37 |
+
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
|
| 38 |
+
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
DTYPES = list(DTYPE_TO_TOLS.keys())
|
torch-ext/flex_sae/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TopK and HierarchicalTopK SAE decoder Triton kernels
|
| 2 |
+
# Copyright 2025 T-Tech
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from .topk_kernels import triton_topk_sae_loss, topk_sae_loss
|
| 6 |
+
from .hierarchical_kernels import triton_hierarchical_sae_loss, hierarchical_sae_loss
|
| 7 |
+
|
| 8 |
+
__kernel_metadata__ = {
|
| 9 |
+
"license": "Apache-2.0 (with CC-BY-NC-4.0 component; see NOTICE)",
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"__kernel_metadata__",
|
| 14 |
+
"topk_sae_loss",
|
| 15 |
+
"triton_topk_sae_loss",
|
| 16 |
+
"hierarchical_sae_loss",
|
| 17 |
+
"triton_hierarchical_sae_loss",
|
| 18 |
+
]
|
torch-ext/flex_sae/hierarchical_kernels.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HierarchicalTopK SAE decoder Triton kernels
|
| 2 |
+
# Copyright 2025 T-Tech
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@triton.jit
|
| 13 |
+
def hierarchical_sae_forward_kernel(
|
| 14 |
+
loss_per_batch_ptr, # [B]
|
| 15 |
+
final_recon_ptr, # [B, D]
|
| 16 |
+
indices_ptr, # [B, K]
|
| 17 |
+
weight_ptr, # [F, D]
|
| 18 |
+
bias_ptr, # [D]
|
| 19 |
+
vals_ptr, # [B, K]
|
| 20 |
+
target_ptr, # [B, D]
|
| 21 |
+
B: tl.constexpr,
|
| 22 |
+
D: tl.constexpr,
|
| 23 |
+
K: tl.constexpr,
|
| 24 |
+
BLOCK_D: tl.constexpr,
|
| 25 |
+
LOOP_NUM_STAGES: tl.constexpr,
|
| 26 |
+
BLOCK_B: tl.constexpr,
|
| 27 |
+
):
|
| 28 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 29 |
+
tl.static_assert((B % BLOCK_B) == 0)
|
| 30 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 31 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 32 |
+
tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
|
| 33 |
+
|
| 34 |
+
pid_b = tl.program_id(axis=0).to(tl.int64)
|
| 35 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 36 |
+
|
| 37 |
+
batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 38 |
+
batch_offsets = batch_offsets.to(tl.int64)
|
| 39 |
+
tl.multiple_of(batch_offsets, BLOCK_B)
|
| 40 |
+
|
| 41 |
+
offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 42 |
+
offset_d = offset_d.to(tl.int64)
|
| 43 |
+
|
| 44 |
+
tl.multiple_of(offset_d, BLOCK_D)
|
| 45 |
+
tl.max_contiguous(offset_d, BLOCK_D)
|
| 46 |
+
|
| 47 |
+
batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
|
| 48 |
+
|
| 49 |
+
bias_tile = tl.load(bias_ptr + offset_d).to(tl.float32)
|
| 50 |
+
|
| 51 |
+
recon = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 52 |
+
recon += bias_tile[None, :]
|
| 53 |
+
|
| 54 |
+
target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
|
| 55 |
+
|
| 56 |
+
loss_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 57 |
+
|
| 58 |
+
row_idx_ptr = indices_ptr + batch_offsets * K
|
| 59 |
+
row_val_ptr = vals_ptr + batch_offsets * K
|
| 60 |
+
|
| 61 |
+
idx = tl.load(row_idx_ptr).to(tl.int64)
|
| 62 |
+
val = tl.load(row_val_ptr).to(tl.float32)
|
| 63 |
+
val = val[:, None]
|
| 64 |
+
weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 65 |
+
|
| 66 |
+
for t in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
|
| 67 |
+
recon += weight_tile * val
|
| 68 |
+
diff = recon - target
|
| 69 |
+
loss_accum += diff * diff
|
| 70 |
+
|
| 71 |
+
if t + 1 < K:
|
| 72 |
+
idx_next = tl.load(row_idx_ptr + (t + 1)).to(tl.int64)
|
| 73 |
+
val_next = tl.load(row_val_ptr + (t + 1)).to(tl.float32)
|
| 74 |
+
weight_next = tl.load(weight_ptr + idx_next[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 75 |
+
|
| 76 |
+
idx = idx_next
|
| 77 |
+
val = val_next[:, None]
|
| 78 |
+
weight_tile = weight_next
|
| 79 |
+
|
| 80 |
+
loss_tile = tl.sum(loss_accum, axis=1)
|
| 81 |
+
tl.atomic_add(
|
| 82 |
+
loss_per_batch_ptr + batch_offsets,
|
| 83 |
+
loss_tile,
|
| 84 |
+
sem="relaxed",
|
| 85 |
+
)
|
| 86 |
+
tl.store(
|
| 87 |
+
final_recon_ptr + batch_d_offset,
|
| 88 |
+
recon,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@triton.jit
|
| 93 |
+
def hierarchical_sae_backward_kernel(
|
| 94 |
+
weight_grad_ptr, # [F, D]
|
| 95 |
+
vals_grad_ptr, # [B, K]
|
| 96 |
+
bias_grad_ptr, # [D]
|
| 97 |
+
final_recon_ptr, # [B, D]
|
| 98 |
+
indices_ptr, # [B, K]
|
| 99 |
+
weight_ptr, # [F, D]
|
| 100 |
+
vals_ptr, # [B, K]
|
| 101 |
+
target_ptr, # [B, D]
|
| 102 |
+
B: tl.constexpr,
|
| 103 |
+
D: tl.constexpr,
|
| 104 |
+
K: tl.constexpr,
|
| 105 |
+
BLOCK_D: tl.constexpr,
|
| 106 |
+
LOOP_NUM_STAGES: tl.constexpr,
|
| 107 |
+
BLOCK_B: tl.constexpr,
|
| 108 |
+
):
|
| 109 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 110 |
+
tl.static_assert((B % BLOCK_B) == 0)
|
| 111 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 112 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 113 |
+
tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")
|
| 114 |
+
|
| 115 |
+
pid_b = tl.program_id(axis=0).to(tl.int64)
|
| 116 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 117 |
+
|
| 118 |
+
batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 119 |
+
batch_offsets = batch_offsets.to(tl.int64)
|
| 120 |
+
tl.multiple_of(batch_offsets, BLOCK_B)
|
| 121 |
+
|
| 122 |
+
offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 123 |
+
offset_d = offset_d.to(tl.int64)
|
| 124 |
+
|
| 125 |
+
tl.multiple_of(offset_d, BLOCK_D)
|
| 126 |
+
tl.max_contiguous(offset_d, BLOCK_D)
|
| 127 |
+
|
| 128 |
+
batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]
|
| 129 |
+
|
| 130 |
+
recon = tl.load(final_recon_ptr + batch_d_offset).to(tl.float32)
|
| 131 |
+
target = tl.load(target_ptr + batch_d_offset).to(tl.float32)
|
| 132 |
+
|
| 133 |
+
suffix = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 134 |
+
bias_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
|
| 135 |
+
scale = tl.full((), 2.0 / (B * K * D), dtype=tl.float32)
|
| 136 |
+
|
| 137 |
+
row_idx_ptr = indices_ptr + batch_offsets * K
|
| 138 |
+
row_val_ptr = vals_ptr + batch_offsets * K
|
| 139 |
+
k_offsets = tl.arange(0, K)
|
| 140 |
+
val_grad_tile = tl.zeros([BLOCK_B, K], dtype=tl.float32)
|
| 141 |
+
|
| 142 |
+
step = K - 1
|
| 143 |
+
idx = tl.load(row_idx_ptr + step).to(tl.int64)
|
| 144 |
+
val = tl.load(row_val_ptr + step).to(tl.float32)
|
| 145 |
+
weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 146 |
+
|
| 147 |
+
for _ in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
|
| 148 |
+
curr_step = step
|
| 149 |
+
|
| 150 |
+
diff = recon - target
|
| 151 |
+
grad_curr = diff * scale
|
| 152 |
+
suffix += grad_curr
|
| 153 |
+
bias_accum += grad_curr
|
| 154 |
+
|
| 155 |
+
val_broadcast = val[:, None]
|
| 156 |
+
contrib = suffix * val_broadcast
|
| 157 |
+
tl.atomic_add(
|
| 158 |
+
weight_grad_ptr + idx[:, None] * D + offset_d[None, :],
|
| 159 |
+
contrib,
|
| 160 |
+
sem="relaxed",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
dot_partial = tl.sum(weight_tile * suffix, axis=1)
|
| 164 |
+
mask_curr = k_offsets[None, :] == curr_step
|
| 165 |
+
val_grad_tile = tl.where(mask_curr, dot_partial[:, None], val_grad_tile)
|
| 166 |
+
|
| 167 |
+
recon -= weight_tile * val_broadcast
|
| 168 |
+
|
| 169 |
+
if curr_step > 0:
|
| 170 |
+
step = curr_step - 1
|
| 171 |
+
idx = tl.load(row_idx_ptr + step).to(tl.int64)
|
| 172 |
+
val = tl.load(row_val_ptr + step).to(tl.float32)
|
| 173 |
+
weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)
|
| 174 |
+
|
| 175 |
+
bias_grad_tile = tl.sum(bias_accum, axis=0)
|
| 176 |
+
tl.atomic_add(
|
| 177 |
+
bias_grad_ptr + offset_d,
|
| 178 |
+
bias_grad_tile,
|
| 179 |
+
sem="relaxed",
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
row_val_grad_ptr = vals_grad_ptr + batch_offsets[:, None] * K + k_offsets[None, :]
|
| 183 |
+
tl.atomic_add(
|
| 184 |
+
row_val_grad_ptr,
|
| 185 |
+
val_grad_tile,
|
| 186 |
+
sem="relaxed",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _hierarchical_sae_forward(
|
| 191 |
+
indices: torch.Tensor, # [B, K]
|
| 192 |
+
weight: torch.Tensor, # [F, D]
|
| 193 |
+
vals: torch.Tensor, # [B, K]
|
| 194 |
+
bias: torch.Tensor, # [D]
|
| 195 |
+
target: torch.Tensor, # [B, D]
|
| 196 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 197 |
+
B, K = indices.shape
|
| 198 |
+
F, D = weight.shape
|
| 199 |
+
|
| 200 |
+
loss_per_batch = torch.zeros((B,), dtype=torch.float32, device=weight.device)
|
| 201 |
+
final_recon = torch.empty((B, D), dtype=torch.float32, device=weight.device)
|
| 202 |
+
|
| 203 |
+
def _forward_grid(meta):
|
| 204 |
+
return (
|
| 205 |
+
B // meta["BLOCK_B"],
|
| 206 |
+
D // meta["BLOCK_D"],
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
hierarchical_sae_forward_kernel[_forward_grid](
|
| 210 |
+
loss_per_batch,
|
| 211 |
+
final_recon,
|
| 212 |
+
indices,
|
| 213 |
+
weight,
|
| 214 |
+
bias,
|
| 215 |
+
vals,
|
| 216 |
+
target,
|
| 217 |
+
B=B,
|
| 218 |
+
D=D,
|
| 219 |
+
K=K,
|
| 220 |
+
BLOCK_D=64,
|
| 221 |
+
LOOP_NUM_STAGES=4,
|
| 222 |
+
BLOCK_B=1,
|
| 223 |
+
num_warps=2,
|
| 224 |
+
num_stages=2,
|
| 225 |
+
)
|
| 226 |
+
loss = loss_per_batch.sum() / (B * K * D)
|
| 227 |
+
return loss, final_recon
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _hierarchical_sae_backward(
|
| 231 |
+
indices: torch.Tensor, # [B, K]
|
| 232 |
+
weight: torch.Tensor, # [F, D]
|
| 233 |
+
vals: torch.Tensor, # [B, K]
|
| 234 |
+
target: torch.Tensor, # [B, D]
|
| 235 |
+
final_recon: torch.Tensor, # [B, D]
|
| 236 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 237 |
+
device = weight.device
|
| 238 |
+
B, K = indices.shape
|
| 239 |
+
F, D = weight.shape
|
| 240 |
+
|
| 241 |
+
dW = torch.zeros((F, D), dtype=torch.float32, device=device)
|
| 242 |
+
dVals = torch.zeros((B, K), dtype=torch.float32, device=device)
|
| 243 |
+
db = torch.zeros((D,), dtype=torch.float32, device=device)
|
| 244 |
+
|
| 245 |
+
def _backward_grid(meta):
|
| 246 |
+
return (
|
| 247 |
+
B // meta["BLOCK_B"],
|
| 248 |
+
D // meta["BLOCK_D"],
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
hierarchical_sae_backward_kernel[_backward_grid](
|
| 252 |
+
dW,
|
| 253 |
+
dVals,
|
| 254 |
+
db,
|
| 255 |
+
final_recon,
|
| 256 |
+
indices,
|
| 257 |
+
weight,
|
| 258 |
+
vals,
|
| 259 |
+
target,
|
| 260 |
+
B=B,
|
| 261 |
+
D=D,
|
| 262 |
+
K=K,
|
| 263 |
+
BLOCK_D=32,
|
| 264 |
+
LOOP_NUM_STAGES=16,
|
| 265 |
+
BLOCK_B=16,
|
| 266 |
+
num_warps=8,
|
| 267 |
+
num_stages=8,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return dW, dVals, db
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class HierarchicalSAELossFunction(torch.autograd.Function):
|
| 274 |
+
@staticmethod
|
| 275 |
+
@torch.amp.custom_fwd(device_type="cuda")
|
| 276 |
+
def forward(
|
| 277 |
+
ctx,
|
| 278 |
+
indices: torch.Tensor, # [B, K]
|
| 279 |
+
weight: torch.Tensor, # [F, D]
|
| 280 |
+
vals: torch.Tensor, # [B, K]
|
| 281 |
+
bias: torch.Tensor, # [D]
|
| 282 |
+
target: torch.Tensor, # [B, D]
|
| 283 |
+
):
|
| 284 |
+
loss, final_recon = _hierarchical_sae_forward(indices, weight, vals, bias, target)
|
| 285 |
+
ctx.save_for_backward(indices, weight, vals, target, final_recon)
|
| 286 |
+
return loss
|
| 287 |
+
|
| 288 |
+
@staticmethod
|
| 289 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
| 290 |
+
def backward(ctx, grad):
|
| 291 |
+
indices, weight, vals, target, final_recon = ctx.saved_tensors
|
| 292 |
+
dW, dVals, db = _hierarchical_sae_backward(indices, weight, vals, target, final_recon)
|
| 293 |
+
|
| 294 |
+
if grad is not None:
|
| 295 |
+
dW.mul_(grad)
|
| 296 |
+
dVals.mul_(grad)
|
| 297 |
+
db.mul_(grad)
|
| 298 |
+
|
| 299 |
+
return None, dW, dVals, db, None
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def triton_hierarchical_sae_loss(
|
| 303 |
+
indices: torch.Tensor, # [B, K]
|
| 304 |
+
weight: torch.Tensor, # [F, D]
|
| 305 |
+
vals: torch.Tensor, # [B, K]
|
| 306 |
+
bias: torch.Tensor, # [D]
|
| 307 |
+
target: torch.Tensor, # [B, D]
|
| 308 |
+
) -> torch.Tensor:
|
| 309 |
+
return HierarchicalSAELossFunction.apply(indices, weight, vals, bias, target)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def hierarchical_sae_loss(
|
| 313 |
+
indices: torch.Tensor, # [B, K]
|
| 314 |
+
weight: torch.Tensor, # [F, D]
|
| 315 |
+
vals: torch.Tensor, # [B, K]
|
| 316 |
+
bias: torch.Tensor, # [D]
|
| 317 |
+
target: torch.Tensor, # [B, D]
|
| 318 |
+
) -> torch.Tensor:
|
| 319 |
+
emb = weight[indices].to(torch.float32) # [K, D]
|
| 320 |
+
recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
|
| 321 |
+
diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
|
| 322 |
+
loss = diff.pow(2).mean()
|
| 323 |
+
return loss
|
torch-ext/flex_sae/topk_kernels.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TopK SAE decoder Triton kernels
|
| 2 |
+
# Copyright 2025 T-Tech
|
| 3 |
+
# This code is adapted from Facebook Research under the
|
| 4 |
+
# Creative Commons Attribution-NonCommercial 4.0 International License.
|
| 5 |
+
# Original code can be found at: https://github.com/facebookresearch/memory
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import triton
|
| 12 |
+
import triton.language as tl
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@triton.jit
|
| 16 |
+
def embedding_bag_forward_kernel(
|
| 17 |
+
out_ptr, # [B, D]
|
| 18 |
+
indices_ptr, # [B, K]
|
| 19 |
+
weight_ptr, # [F, D]
|
| 20 |
+
vals_ptr, # [B, K]
|
| 21 |
+
D: tl.constexpr,
|
| 22 |
+
K: tl.constexpr,
|
| 23 |
+
BLOCK_D: tl.constexpr,
|
| 24 |
+
):
|
| 25 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 26 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 27 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 28 |
+
|
| 29 |
+
b = tl.program_id(axis=0).to(tl.int64)
|
| 30 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 31 |
+
|
| 32 |
+
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 33 |
+
|
| 34 |
+
out_value = tl.zeros([BLOCK_D], dtype=tl.float32)
|
| 35 |
+
for i in tl.range(K):
|
| 36 |
+
my_index = tl.load(indices_ptr + b * K + i).to(tl.int64)
|
| 37 |
+
my_scaling = tl.load(vals_ptr + b * K + i)
|
| 38 |
+
w_tile = tl.load(weight_ptr + my_index * D + off_d).to(tl.float32)
|
| 39 |
+
out_value += w_tile * my_scaling
|
| 40 |
+
|
| 41 |
+
tl.store(out_ptr + b * D + off_d, out_value)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def embedding_bag_forward(
|
| 45 |
+
indices: torch.Tensor, # [B, K]
|
| 46 |
+
weight: torch.Tensor, # [F, D]
|
| 47 |
+
vals: torch.Tensor, # [B, K]
|
| 48 |
+
) -> torch.Tensor:
|
| 49 |
+
B, K = indices.shape
|
| 50 |
+
D = weight.shape[1]
|
| 51 |
+
|
| 52 |
+
trt_out = torch.empty([B, D], dtype=weight.dtype, device=weight.device)
|
| 53 |
+
|
| 54 |
+
def _forward_grid(meta):
|
| 55 |
+
return (B, D // meta["BLOCK_D"])
|
| 56 |
+
|
| 57 |
+
embedding_bag_forward_kernel[_forward_grid](
|
| 58 |
+
trt_out,
|
| 59 |
+
indices,
|
| 60 |
+
weight,
|
| 61 |
+
vals,
|
| 62 |
+
D=D,
|
| 63 |
+
K=K,
|
| 64 |
+
BLOCK_D=64,
|
| 65 |
+
num_warps=1,
|
| 66 |
+
num_stages=1,
|
| 67 |
+
)
|
| 68 |
+
return trt_out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@triton.jit
|
| 72 |
+
def count_per_embedding_kernel(
|
| 73 |
+
count_per_emb_ptr, # [F + 1]
|
| 74 |
+
indices_ptr, # [B, K]
|
| 75 |
+
K: tl.constexpr,
|
| 76 |
+
):
|
| 77 |
+
batch_id = tl.program_id(axis=0).to(tl.int64)
|
| 78 |
+
for t in tl.range(K):
|
| 79 |
+
embedding_id = tl.load(indices_ptr + batch_id * K + t)
|
| 80 |
+
tl.atomic_add(count_per_emb_ptr + embedding_id + 1, 1, sem="relaxed")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@triton.jit
|
| 84 |
+
def map_embeddings_and_outputs_kernel(
|
| 85 |
+
reverse_mapping_ptr, # [B * K]
|
| 86 |
+
mapping_write_pos_ptr, # [F]
|
| 87 |
+
indices_ptr, # [B, K]
|
| 88 |
+
K: tl.constexpr,
|
| 89 |
+
):
|
| 90 |
+
batch_id = tl.program_id(axis=0).to(tl.int64)
|
| 91 |
+
for t in tl.range(K):
|
| 92 |
+
embedding_id = tl.load(indices_ptr + batch_id * K + t)
|
| 93 |
+
write_pos = tl.atomic_add(mapping_write_pos_ptr + embedding_id, 1, sem="relaxed")
|
| 94 |
+
tl.store(reverse_mapping_ptr + write_pos, batch_id * K + t)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@triton.jit
|
| 98 |
+
def aggregate_gradient_for_embedding_kernel(
|
| 99 |
+
weight_grad_ptr, # [F, D]
|
| 100 |
+
vals_grad_ptr, # [B, K]
|
| 101 |
+
weight_ptr, # [F, D]
|
| 102 |
+
emb_begin_pos_ptr, # [F + 1]
|
| 103 |
+
reverse_mapping_ptr, # [B * K]
|
| 104 |
+
vals_ptr, # [B, K]
|
| 105 |
+
gradient_ptr, # [B, D]
|
| 106 |
+
D: tl.constexpr,
|
| 107 |
+
K: tl.constexpr,
|
| 108 |
+
BLOCK_D: tl.constexpr,
|
| 109 |
+
):
|
| 110 |
+
tl.static_assert((D % BLOCK_D) == 0)
|
| 111 |
+
tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
|
| 112 |
+
tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
|
| 113 |
+
|
| 114 |
+
e = tl.program_id(axis=0).to(tl.int64)
|
| 115 |
+
pid_d = tl.program_id(axis=1).to(tl.int64)
|
| 116 |
+
|
| 117 |
+
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
|
| 118 |
+
|
| 119 |
+
begin = tl.load(emb_begin_pos_ptr + e)
|
| 120 |
+
end = tl.load(emb_begin_pos_ptr + e + 1)
|
| 121 |
+
|
| 122 |
+
w_row_tile = tl.load(weight_ptr + e * D + off_d).to(tl.float32)
|
| 123 |
+
w_grad_tile = tl.zeros([BLOCK_D], dtype=tl.float32)
|
| 124 |
+
|
| 125 |
+
for idx in tl.range(begin, end):
|
| 126 |
+
out_linear = tl.load(reverse_mapping_ptr + idx).to(tl.int64)
|
| 127 |
+
b = out_linear // K
|
| 128 |
+
|
| 129 |
+
psw = tl.load(vals_ptr + out_linear)
|
| 130 |
+
g_tile = tl.load(gradient_ptr + b * D + off_d).to(tl.float32)
|
| 131 |
+
|
| 132 |
+
w_grad_tile += psw * g_tile
|
| 133 |
+
|
| 134 |
+
psw_grad_partial = tl.sum(g_tile * w_row_tile)
|
| 135 |
+
tl.atomic_add(vals_grad_ptr + out_linear, psw_grad_partial, sem="relaxed")
|
| 136 |
+
|
| 137 |
+
tl.store(weight_grad_ptr + e * D + off_d, w_grad_tile)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def embedding_bag_backward(
|
| 141 |
+
indices: torch.Tensor, # [B, K]
|
| 142 |
+
weight: torch.Tensor, # [F, D]
|
| 143 |
+
vals: torch.Tensor, # [B, K]
|
| 144 |
+
gradient: torch.Tensor, # [B, D]
|
| 145 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 146 |
+
F, D = weight.shape
|
| 147 |
+
B, K = indices.shape
|
| 148 |
+
|
| 149 |
+
count_per_emb = torch.zeros((F + 1,), dtype=torch.uint32, device=indices.device)
|
| 150 |
+
count_per_embedding_kernel[(B,)](count_per_emb, indices, K=K, num_warps=1)
|
| 151 |
+
|
| 152 |
+
emb_begin_pos = count_per_emb.cumsum(0) # [F + 1]
|
| 153 |
+
|
| 154 |
+
reverse_mapping = torch.empty([B * K], dtype=torch.uint32, device=indices.device)
|
| 155 |
+
assert B * K <= 2 ** (reverse_mapping.dtype.itemsize * 8) - 1
|
| 156 |
+
|
| 157 |
+
map_embeddings_and_outputs_kernel[(B,)](
|
| 158 |
+
reverse_mapping_ptr=reverse_mapping,
|
| 159 |
+
mapping_write_pos_ptr=emb_begin_pos.clone(),
|
| 160 |
+
indices_ptr=indices,
|
| 161 |
+
K=K,
|
| 162 |
+
num_warps=1,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
weight_grad = torch.empty_like(weight, dtype=torch.float32) # [F, D]
|
| 166 |
+
vals_grad = torch.zeros_like(vals, dtype=torch.float32) # [B, K]
|
| 167 |
+
|
| 168 |
+
def _forward_grid(meta):
|
| 169 |
+
return (F, D // meta["BLOCK_D"])
|
| 170 |
+
|
| 171 |
+
aggregate_gradient_for_embedding_kernel[_forward_grid](
|
| 172 |
+
weight_grad_ptr=weight_grad,
|
| 173 |
+
vals_grad_ptr=vals_grad,
|
| 174 |
+
weight_ptr=weight,
|
| 175 |
+
emb_begin_pos_ptr=emb_begin_pos,
|
| 176 |
+
reverse_mapping_ptr=reverse_mapping,
|
| 177 |
+
vals_ptr=vals,
|
| 178 |
+
gradient_ptr=gradient,
|
| 179 |
+
D=D,
|
| 180 |
+
K=K,
|
| 181 |
+
BLOCK_D=256,
|
| 182 |
+
num_warps=1,
|
| 183 |
+
num_stages=2,
|
| 184 |
+
)
|
| 185 |
+
return weight_grad, vals_grad
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class xFormersEmbeddingBag(torch.autograd.Function):
|
| 189 |
+
@staticmethod
|
| 190 |
+
@torch.amp.custom_fwd(device_type="cuda")
|
| 191 |
+
def forward(
|
| 192 |
+
ctx,
|
| 193 |
+
indices: torch.Tensor, # [B, K]
|
| 194 |
+
weight: torch.Tensor, # [F, D]
|
| 195 |
+
vals: torch.Tensor, # [B, K]
|
| 196 |
+
) -> torch.Tensor:
|
| 197 |
+
ctx.save_for_backward(indices, weight, vals)
|
| 198 |
+
return embedding_bag_forward(indices, weight, vals) # [B, D]
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
| 202 |
+
def backward(ctx, gradient):
|
| 203 |
+
indices, weight, vals = ctx.saved_tensors
|
| 204 |
+
weight_g, vals_g = embedding_bag_backward(
|
| 205 |
+
indices,
|
| 206 |
+
weight,
|
| 207 |
+
vals,
|
| 208 |
+
gradient,
|
| 209 |
+
)
|
| 210 |
+
return None, weight_g, vals_g
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def triton_topk_sae_loss(
|
| 214 |
+
indices: torch.Tensor, # [B, K]
|
| 215 |
+
weight: torch.Tensor, # [F, D]
|
| 216 |
+
vals: torch.Tensor, # [B, K]
|
| 217 |
+
bias: torch.Tensor, # [D]
|
| 218 |
+
target: torch.Tensor, # [B, D]
|
| 219 |
+
) -> torch.Tensor:
|
| 220 |
+
recon = bias.to(torch.float32) + xFormersEmbeddingBag.apply(indices, weight, vals)
|
| 221 |
+
diff = recon.to(torch.float32) - target.to(torch.float32)
|
| 222 |
+
loss = diff.pow(2).mean()
|
| 223 |
+
return loss
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def topk_sae_loss(
|
| 227 |
+
indices: torch.Tensor, # [B, K]
|
| 228 |
+
weight: torch.Tensor, # [F, D]
|
| 229 |
+
vals: torch.Tensor, # [B, K]
|
| 230 |
+
bias: torch.Tensor, # [D]
|
| 231 |
+
target: torch.Tensor, # [B, D]
|
| 232 |
+
) -> torch.Tensor:
|
| 233 |
+
emb = weight[indices].to(torch.float32) # [K, D]
|
| 234 |
+
recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).sum(dim=1)
|
| 235 |
+
diff = recon.to(torch.float32) - target.to(torch.float32)
|
| 236 |
+
loss = diff.pow(2).mean()
|
| 237 |
+
return loss
|