Consider two compatible tensors, at least one of which has rank greater than 2. In most implementations of matmul (such as those found in NumPy, PyTorch, and TensorFlow), multiplying them performs an operation called batched matrix multiplication.

In this setting, all dimensions except the last two are treated as “batch dimensions.” This is hard to think about in the abstract, so let’s work through a concrete 2×2×2 example.
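As a quick sketch of what “batch dimensions” means in practice, here is the same convention in NumPy (whose `matmul` batches over leading dimensions just like PyTorch and TensorFlow do):

```python
import numpy as np

# Two stacks of matrices: the leading dimension (3) is a batch
# dimension; only the last two dimensions are matrix-multiplied.
A = np.random.rand(3, 4, 5)
B = np.random.rand(3, 5, 6)

C = np.matmul(A, B)  # equivalently A @ B
print(C.shape)       # (3, 4, 6): one 4x5 @ 5x6 product per batch element
```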

Suppose we have two 2×2×2 tensors $A$ and $B$. Writing $a_{ijk}$ and $b_{ijk}$ for their entries, the two 2×2 slices of $A$ along the first axis are

$$A_0 = \begin{bmatrix} a_{000} & a_{001} \\ a_{010} & a_{011} \end{bmatrix}, \qquad A_1 = \begin{bmatrix} a_{100} & a_{101} \\ a_{110} & a_{111} \end{bmatrix}$$

and likewise for $B$:

$$B_0 = \begin{bmatrix} b_{000} & b_{001} \\ b_{010} & b_{011} \end{bmatrix}, \qquad B_1 = \begin{bmatrix} b_{100} & b_{101} \\ b_{110} & b_{111} \end{bmatrix}.$$
What happens when we do torch.matmul(A, B)? Every entry of $A$ and $B$ participates in the result, but which entries get multiplied with which? To understand the answer, it’s helpful to think of the tensors as vectors of matrices:

$$A = \begin{bmatrix} A_0 \\ A_1 \end{bmatrix}, \qquad B = \begin{bmatrix} B_0 \\ B_1 \end{bmatrix},$$

where $A_i$ and $B_i$ are the 2×2 slices $A[i]$ and $B[i]$.
Then we can think of the “batch dimensions” as dimensions along which we perform completely separate multiplications. Let $\ast$ be matrix multiplication, possibly with batching. Then

$$A \ast B = \begin{bmatrix} A_0 \\ A_1 \end{bmatrix} \ast \begin{bmatrix} B_0 \\ B_1 \end{bmatrix} = \begin{bmatrix} A_0 \ast B_0 \\ A_1 \ast B_1 \end{bmatrix}.$$
At this point, each pair of slices is combined by the usual matrix multiplication, e.g.:

$$A_0 \ast B_0 = \begin{bmatrix} a_{000} & a_{001} \\ a_{010} & a_{011} \end{bmatrix} \begin{bmatrix} b_{000} & b_{001} \\ b_{010} & b_{011} \end{bmatrix} = \begin{bmatrix} a_{000}b_{000} + a_{001}b_{010} & a_{000}b_{001} + a_{001}b_{011} \\ a_{010}b_{000} + a_{011}b_{010} & a_{010}b_{001} + a_{011}b_{011} \end{bmatrix}.$$
This line of reasoning generalizes to tensors of arbitrary rank, where the batch dimensions of the two operands may also be broadcast against each other.
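To illustrate the broadcasting case, here is a sketch in NumPy, where batch dimensions broadcast against each other the same way they do in elementwise operations:

```python
import numpy as np

A = np.random.rand(3, 1, 4, 5)  # batch dims (3, 1), matrices are 4x5
B = np.random.rand(6, 5, 2)     # batch dim (6,), matrices are 5x2

# Batch dims (3, 1) and (6,) broadcast to (3, 6); the matrix
# dims multiply as 4x5 @ 5x2 = 4x2.
C = A @ B
print(C.shape)  # (3, 6, 4, 2)
```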