[FEATURE REQUEST] gemm PyTorch kernel implementation #31

@LoserCheems

Description

Problem statement

The BLAS level-3 gemm kernel (general matrix-matrix multiply) is not implemented for the PyTorch backend in kernel_course.pytorch_ops. The README BLAS table lists no PyTorch support for gemm, even though GEMM is one of the most important operations in linear algebra and deep learning.

Without a PyTorch gemm kernel:

  • there is no unified kernel API for $C = \alpha A B + \beta C$ on PyTorch tensors,
  • cross-backend comparisons for GEMM are not possible within kernel-course,
  • higher-level modules must rely directly on raw PyTorch matmul/linear primitives instead of a standard kernel abstraction.

Proposed solution

Add a PyTorch implementation of gemm under kernel_course.pytorch_ops that matches the Python reference semantics and follows existing backend patterns.

Concretely:

  • Introduce kernel_course/pytorch_ops/gemm.py implementing $C = \alpha A B + \beta C$ for 2D torch.Tensor inputs.
  • Align function signature and behaviour with other PyTorch kernels (argument order, dtype/device handling).
  • Implement the operation using idiomatic PyTorch (e.g. torch.matmul or the @ operator plus a scaled add) while maintaining BLAS-like semantics; see the sketch below.
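
A minimal sketch of what such a kernel could look like is given below. The function name, argument order, and error messages are assumptions modelled on common BLAS gemm conventions, not a confirmed kernel-course API.

```python
# kernel_course/pytorch_ops/gemm.py (sketch; names and argument order are assumptions)
import torch


def gemm(
    alpha: float,
    a: torch.Tensor,
    b: torch.Tensor,
    beta: float,
    c: torch.Tensor,
) -> torch.Tensor:
    """Return alpha * (a @ b) + beta * c for 2D tensors, BLAS gemm style."""
    if a.ndim != 2 or b.ndim != 2 or c.ndim != 2:
        raise ValueError("gemm expects 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"inner dimensions do not match: {tuple(a.shape)} x {tuple(b.shape)}")
    if c.shape != (a.shape[0], b.shape[1]):
        raise ValueError(f"c has shape {tuple(c.shape)}, expected {(a.shape[0], b.shape[1])}")
    # BLAS semantics: when beta == 0, C is not read (so NaNs in C are ignored).
    if beta == 0:
        return alpha * torch.matmul(a, b)
    # torch.matmul dispatches to the backend's optimized GEMM (e.g. cuBLAS on CUDA).
    return alpha * torch.matmul(a, b) + beta * c
```

Alternatively, torch.addmm(c, a, b, beta=beta, alpha=alpha) fuses the matmul and the scaled add in a single call; whether to use it is a design choice for the actual implementation.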

Alternatives considered

Relying entirely on raw PyTorch matmul in higher-level code would:

  • weaken the uniform kernel abstraction across backends,
  • complicate benchmarking GEMM across Triton and CuTe,
  • reduce the pedagogical clarity of mapping mathematical GEMM to backend implementations.

Implementation details

  • Add kernel_course/pytorch_ops/gemm.py with a public gemm function.
  • Accept scalars alpha and beta, and matrices A, B, and C.
  • Validate input shapes and ensure compatibility with BLAS-style gemm semantics.
  • Update kernel_course/pytorch_ops/__init__.py to expose gemm (see the export sketch below).
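
For the export step, a minimal sketch of the updated __init__.py might look like the following; the existing copy and swap imports are inferred from the Related work section and may not match the actual file layout.

```python
# kernel_course/pytorch_ops/__init__.py (sketch; existing exports inferred from this issue)
from .copy import copy
from .swap import swap
from .gemm import gemm

__all__ = ["copy", "swap", "gemm"]
```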

Use case

The PyTorch gemm kernel will:

  • provide a standard GEMM primitive through a unified kernel interface,
  • act as a benchmark target for Triton and CuTe implementations,
  • support construction of higher-level modules such as attention and MLP layers (a small usage illustration follows below).
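
To illustrate the last point, here is a hypothetical usage sketch of a bias-free two-layer MLP forward pass built on the proposed kernel; the gemm signature follows the sketch under Proposed solution and is an assumption.

```python
import torch

# Hypothetical usage; assumes the gemm sketch above: gemm(alpha, a, b, beta, c).
from kernel_course.pytorch_ops import gemm

x = torch.randn(32, 64)      # batch of inputs
w1 = torch.randn(64, 128)    # first layer weights
w2 = torch.randn(128, 10)    # second layer weights

h = torch.relu(gemm(1.0, x, w1, 0.0, torch.zeros(32, 128)))  # hidden activations
y = gemm(1.0, h, w2, 0.0, torch.zeros(32, 10))               # output logits
```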

Related work

  • Existing PyTorch kernels: pytorch_ops.copy, pytorch_ops.swap.
  • PyTorch matmul and linear layer implementations (the relationship to torch.nn.functional.linear is sketched below).
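
For orientation, torch.nn.functional.linear(x, w, b) computes x @ w.T + b, which is a GEMM with alpha = 1 and beta = 1 applied to a broadcast bias. The snippet below checks that relationship using plain PyTorch, without assuming the new kernel exists.

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)
w = torch.randn(32, 16)
b = torch.randn(32)

# F.linear computes x @ w.T + b; the same result written as an explicit
# matmul plus add, i.e. the GEMM the proposed kernel would expose.
out_linear = F.linear(x, w, b)
out_gemm = x @ w.t() + b

torch.testing.assert_close(out_linear, out_gemm)
```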

Additional context

This issue contributes to completing the gemm row for the PyTorch column in the README BLAS table.
