# Build manifest for the "batch_invariant" package: general metadata, the
# PyTorch binding sources, and one CUDA kernel definition.
[general]
name = "batch_invariant"
# NOTE(review): presumably false because this package ships compiled
# sources (see the tables below) — confirm against the build tool's docs.
universal = false

# Defines the C++ files that bind to PyTorch
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

# Defines the CUDA kernels
[kernel.batch_invariant_matmul]
backend = "cuda"
depends = ["torch"]
src = [
    "csrc/batch_invariant.cu",
]