# Package-level settings.
[general]
name = "batch_invariant"
# NOTE(review): presumably false because this package ships compiled
# C++/CUDA sources rather than pure Python; confirm against the build
# tool's documentation.
universal = false

# Defines the C++ files that bind to PyTorch.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

# Defines the CUDA kernels.
[kernel.batch_invariant_matmul]
backend = "cuda"
depends = ["torch"]
src = [
    "csrc/batch_invariant.cu",
]