Problem statement
The BLAS level-3 gemm kernel (general matrix-matrix multiply) is not implemented for the PyTorch backend in kernel_course.pytorch_ops. The README BLAS table lists no PyTorch support for gemm, even though GEMM is one of the most important operations in linear algebra and deep learning.
Without a PyTorch gemm kernel:
- there is no unified kernel API for $C = \alpha A B + \beta C$ on PyTorch tensors,
- cross-backend comparisons for GEMM are not possible within kernel-course,
- higher-level modules must rely directly on raw PyTorch matmul/linear primitives instead of a standard kernel abstraction.
Proposed solution
Add a PyTorch implementation of gemm under kernel_course.pytorch_ops that matches the Python reference semantics and follows existing backend patterns.
Concretely:
- Introduce `kernel_course/pytorch_ops/gemm.py` implementing $C = \alpha A B + \beta C$ for 2D `torch.Tensor` inputs.
- Align the function signature and behaviour with the other PyTorch kernels (argument order, dtype/device handling).
- Implement the operation using idiomatic PyTorch (e.g. `torch.matmul` or the `@` operator plus a scaled add) while maintaining BLAS-like semantics; a minimal sketch follows this list.
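A minimal sketch of what the kernel body could look like, assuming a `gemm(alpha, A, B, beta, C)` argument order and an out-of-place result; the actual signature and in-place behaviour should follow whatever convention the existing `pytorch_ops` kernels use:

```python
import torch


def gemm(alpha: float, A: torch.Tensor, B: torch.Tensor,
         beta: float, C: torch.Tensor) -> torch.Tensor:
    """Return alpha * A @ B + beta * C for 2D tensors (BLAS GEMM semantics)."""
    if A.ndim != 2 or B.ndim != 2 or C.ndim != 2:
        raise ValueError("gemm expects 2D tensors")
    m, k = A.shape
    kb, n = B.shape
    if kb != k or C.shape != (m, n):
        raise ValueError(f"incompatible shapes: A={tuple(A.shape)}, "
                         f"B={tuple(B.shape)}, C={tuple(C.shape)}")
    # Matrix product via torch.matmul, then BLAS-style scaled accumulation with C.
    return alpha * torch.matmul(A, B) + beta * C
```

An equivalent single-call form is `torch.addmm(C, A, B, beta=beta, alpha=alpha)`, which maps directly onto the same semantics.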
Alternatives considered
Relying entirely on raw PyTorch matmul in higher-level code would:
- weaken the uniform kernel abstraction across backends,
- complicate benchmarking GEMM across Triton and CuTe,
- reduce the pedagogical clarity of mapping mathematical GEMM to backend implementations.
Implementation details
- Add `kernel_course/pytorch_ops/gemm.py` with a public `gemm` function.
- Accept scalars `alpha`, `beta` and matrices `A`, `B`, and `C`.
- Validate shapes and ensure compatibility with BLAS-style semantics.
- Update `kernel_course/pytorch_ops/__init__.py` to expose `gemm`; a sketch of the export and a consistency check follows this list.
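The export could be a one-line re-export, assuming the package exposes kernels by name as the other backends do:

```python
# kernel_course/pytorch_ops/__init__.py (hypothetical layout)
from .gemm import gemm  # noqa: F401
```

With that in place, a quick consistency check against raw PyTorch (using the `gemm(alpha, A, B, beta, C)` signature assumed above) might look like:

```python
import torch
from kernel_course.pytorch_ops import gemm  # hypothetical import path

A = torch.randn(64, 32)
B = torch.randn(32, 128)
C = torch.randn(64, 128)

out = gemm(1.5, A, B, 0.5, C)
# The kernel should agree with the raw PyTorch primitives it wraps.
torch.testing.assert_close(out, 1.5 * A @ B + 0.5 * C)
```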
Use case
The PyTorch gemm kernel will:
- provide a standard GEMM primitive through a unified kernel interface,
- act as a benchmark target for Triton and CuTe implementations (see the timing sketch after this list),
- support construction of higher-level modules such as attention and MLP layers.
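As a rough illustration of the benchmarking use case, a PyTorch-only timing harness could look like the sketch below (import path and signature as assumed above); Triton and CuTe kernels would slot into the same `Timer` pattern:

```python
import torch
from torch.utils import benchmark

from kernel_course.pytorch_ops import gemm  # hypothetical import path

A = torch.randn(1024, 1024)
B = torch.randn(1024, 1024)
C = torch.randn(1024, 1024)

# Time the unified-kernel call; other backends can be dropped into the same stmt.
timer = benchmark.Timer(
    stmt="gemm(1.0, A, B, 0.0, C)",
    globals={"gemm": gemm, "A": A, "B": B, "C": C},
)
print(timer.blocked_autorange())
```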
Related work
- Existing PyTorch kernels: `pytorch_ops.copy`, `pytorch_ops.swap`.
- PyTorch matmul and linear layer implementations.
Additional context
This issue contributes to completing the gemm row for the PyTorch column in the README BLAS table.