To do this computation, we need to evaluate a product of Jacobians of the form $\partial f_1(x)^\top\, \partial f_2(x_1)^\top\, u$, where $u \in \mathbb{R}^p$ is the gradient of the (scalar) loss, $\partial f_2(x_1) \in \mathbb{R}^{p \times m}$, $\partial f_1(x) \in \mathbb{R}^{m \times n}$, and we write $x_1 = f_1(x)$, $x_2 = f_2(x_1)$. If we start from the right, we start with a matrix times a vector to obtain a vector (of size $m$) and we then need to make another matrix times a vector, resulting in $O(pm + mn)$ operations. If we start from the left with the matrix-matrix multiplication, we get $O(nmp)$ operations. Hence we see that as soon as the dimensions $n$, $m$, $p$ are large, starting from the right is much more efficient. Note however that doing the computation from the right to the left requires keeping in memory the values of $x$, $x_1$ and $x_2$ computed during the forward pass.
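As a quick illustration, here is a minimal numpy sketch of this comparison; the dimensions and the random Jacobians below are made up for illustration only.

import numpy as np

# Made-up dimensions: x in R^n, x1 in R^m, x2 in R^p.
n, m, p = 1000, 800, 600
rng = np.random.default_rng(0)
J1 = rng.standard_normal((m, n))   # Jacobian of f1 at x
J2 = rng.standard_normal((p, m))   # Jacobian of f2 at x1
u = rng.standard_normal(p)         # gradient of the scalar loss w.r.t. x2

# Right to left: two matrix-vector products, roughly p*m + m*n multiplications.
g_right = J1.T @ (J2.T @ u)

# Left to right: a matrix-matrix product first, roughly n*m*p multiplications.
g_left = (J1.T @ J2.T) @ u

# Both orders give the same gradient, only the cost differs.
assert np.allclose(g_right, g_left)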
Backpropagation is an efficient algorithm computing the gradient "from the right to the left", i.e. backward. In particular, we will need to compute quantities of the form $\partial f(x)^\top u$ with $u \in \mathbb{R}^m$, which can be rewritten as $u^\top \partial f(x)$: a Vector Jacobian Product (VJP), corresponding to the interpretation where the Jacobian is the linear map $\partial f(x): \mathbb{R}^n \to \mathbb{R}^m$, composed with the linear map $u^\top: \mathbb{R}^m \to \mathbb{R}$, so that $u^\top \partial f(x): \mathbb{R}^n \to \mathbb{R}$.
Example: let $f(x, W) = Wx$ where $x \in \mathbb{R}^n$ and $W \in \mathbb{R}^{m \times n}$. We clearly have
$$\frac{\partial f(x, W)}{\partial x} = W.$$
Note that here, we are slightly abusing notation and considering the partial function $x \mapsto f(x, W)$. To see this, we can write $f_i(x, W) = \sum_j W_{ij} x_j$ so that
$$\frac{\partial f_i(x, W)}{\partial x_j} = W_{ij}.$$
Then recall from the definitions that the Jacobian with respect to the parameters, $\frac{\partial f(x, W)}{\partial W}$, is the linear map from $\mathbb{R}^{m \times n}$ to $\mathbb{R}^m$ given by $V \mapsto Vx$, since $f(x, W + V) = Wx + Vx$; composing it with $u^\top$ gives a linear form on $\mathbb{R}^{m \times n}$, which we represent by a matrix of the same shape as $W$.
Now we clearly have
$$u^\top \frac{\partial f(x, W)}{\partial x} = u^\top W \in \mathbb{R}^{1 \times n} \quad \text{and} \quad u^\top \frac{\partial f(x, W)}{\partial W} = u\, x^\top \in \mathbb{R}^{m \times n},$$
where the second equality holds because $u^\top V x = \langle u x^\top, V \rangle$ for all $V \in \mathbb{R}^{m \times n}$.
Note that multiplying by $u$ on the left is actually convenient when using broadcasting, i.e. we can take a batch of input vectors of shape $(bs, n)$ (with a corresponding batch of gradients $u$ of shape $(bs, m)$) without modifying the math above.
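As a sanity check, here is a small numpy sketch computing these two VJPs for $f(x, W) = Wx$ and comparing the first one with a finite-difference approximation; the dimensions and random data are made up for illustration.

import numpy as np

rng = np.random.default_rng(0)
n, m = 5, 3
x = rng.standard_normal(n)
W = rng.standard_normal((m, n))
u = rng.standard_normal(m)          # "upstream" gradient, u in R^m

# VJPs derived above for f(x, W) = W x
vjp_x = u @ W                        # u^T df/dx = u^T W, shape (n,)
vjp_W = np.outer(u, x)               # u^T df/dW = u x^T, shape (m, n)

# Finite-difference check on the scalar function g(x) = u^T f(x, W)
eps = 1e-6
fd = np.array([(u @ (W @ (x + eps * e)) - u @ (W @ x)) / eps
               for e in np.eye(n)])
assert np.allclose(vjp_x, fd, atol=1e-4)

# Batched version: inputs of shape (bs, n) and upstream gradients of
# shape (bs, m); the same formula applies row by row thanks to broadcasting.
bs = 4
U = rng.standard_normal((bs, m))
vjp_X = U @ W                        # shape (bs, n)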
In PyTorch, torch.autograd provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions. To create a custom autograd.Function, subclass this class and implement the forward() and backward() static methods. Here is an example:
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        # Compute exp(i) and save the result for the backward pass.
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # The derivative of exp is exp itself: the VJP is grad_output * exp(i).
        result, = ctx.saved_tensors
        return grad_output * result

# Use it by calling the apply method:
input = torch.randn(3, requires_grad=True)
output = Exp.apply(input)
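To gain confidence in a custom backward(), one option is torch.autograd.gradcheck, which compares the analytical VJP with finite differences. A short sketch, assuming the Exp class above (the input tensor is made up; gradcheck expects double precision):

import torch

x = torch.randn(4, dtype=torch.double, requires_grad=True)
# Returns True if the analytical gradients of Exp.apply match
# numerical estimates within tolerance.
print(torch.autograd.gradcheck(Exp.apply, (x,)))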
You can have a look at Module 2b to learn more about this approach, as well as at the MLP from scratch notebook.
Here we will implement in numpy a different approach, mimicking the functional approach of JAX (see The Autodiff Cookbook).
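For reference, here is a minimal sketch of what this functional style looks like in JAX itself: jax.vjp returns the output of $f$ together with a function computing $u \mapsto u^\top \partial f$. The toy function, shapes and data below are made up for illustration.

import jax
import jax.numpy as jnp

def f(x, W):
    # Same toy function as above: f(x, W) = W x.
    return W @ x

x = jnp.arange(3.0)
W = jnp.ones((2, 3))
u = jnp.array([1.0, -1.0])           # "upstream" gradient in R^2

y, f_vjp = jax.vjp(f, x, W)          # forward pass + VJP closure
vjp_x, vjp_W = f_vjp(u)              # u^T df/dx and u^T df/dW
print(vjp_x.shape, vjp_W.shape)      # (3,) (2, 3)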
Each function will take 2 arguments: one being the input x and the other being the parameters w. For each function, we build 2 vjp functions taking as argument a gradient $u$, corresponding to $\frac{\partial f}{\partial x}$ and $\frac{\partial f}{\partial w}$, so that these functions return $u^\top \frac{\partial f}{\partial x}$ and $u^\top \frac{\partial f}{\partial w}$ respectively. To summarize, for $f(x, w) \in \mathbb{R}^m$ with $x \in \mathbb{R}^n$, $w \in \mathbb{R}^d$ and $u \in \mathbb{R}^m$:
$$\mathrm{vjp}_x(u) = u^\top \frac{\partial f(x, w)}{\partial x} \in \mathbb{R}^n, \qquad \mathrm{vjp}_w(u) = u^\top \frac{\partial f(x, w)}{\partial w} \in \mathbb{R}^d.$$
Then backpropagation is simply done by first computing the gradient of the loss and then composing the vjp functions in the right order.
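Here is a minimal numpy sketch of this scheme under assumed conventions: a single linear layer $f(x, w) = wx$ followed by a squared-error loss; the function and variable names are hypothetical, not the ones used in the notebooks.

import numpy as np

# A layer is a function of (x, w) together with its two vjp functions.
def linear(x, w):
    # f(x, w) = w x with w of shape (m, n) and x of shape (n,)
    return w @ x

def linear_vjp_x(u, x, w):
    # u^T df/dx = u^T w
    return u @ w

def linear_vjp_w(u, x, w):
    # u^T df/dw = u x^T
    return np.outer(u, x)

def mse_loss(y, target):
    return 0.5 * np.sum((y - target) ** 2)

def mse_grad(y, target):
    # gradient of the loss with respect to y
    return y - target

# Forward pass (keeping the intermediate values in memory)...
rng = np.random.default_rng(0)
n, m = 4, 2
x = rng.standard_normal(n)
w = rng.standard_normal((m, n))
target = rng.standard_normal(m)
y = linear(x, w)

# ...then backward pass: gradient of the loss, composed with the vjps.
u = mse_grad(y, target)              # dL/dy
grad_w = linear_vjp_w(u, x, w)       # dL/dw, shape (m, n)
grad_x = linear_vjp_x(u, x, w)       # dL/dx, shape (n,), passed on to the
                                     # vjp functions of the previous layer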
intro to JAX: autodiff the functional way autodiff_functional_empty.ipynb and its solution autodiff_functional_sol.ipynb
Linear regression in JAX linear_regression_jax.ipynb