Module 2c - Automatic differentiation: VJP and intro to JAX


Autodiff and Backpropagation

Jacobian

Let $\mathbf{f}:\mathbb{R}^n\to \mathbb{R}^m$; we define its Jacobian as:

$$
\begin{aligned}
\frac{\partial \mathbf{f}}{\partial \mathbf{x}} = J_{\mathbf{f}}(\mathbf{x}) &= \left( \begin{array}{ccc}
\frac{\partial f_1}{\partial x_1} & \dots & \frac{\partial f_1}{\partial x_n}\\
\vdots & & \vdots\\
\frac{\partial f_m}{\partial x_1} & \dots & \frac{\partial f_m}{\partial x_n}
\end{array}\right)\\
&= \left( \frac{\partial \mathbf{f}}{\partial x_1},\dots, \frac{\partial \mathbf{f}}{\partial x_n}\right)\\
&= \left( \begin{array}{c}
\nabla f_1(\mathbf{x})^T\\
\vdots\\
\nabla f_m(\mathbf{x})^T
\end{array}\right)
\end{aligned}
$$

Hence the Jacobian $J_{\mathbf{f}}(\mathbf{x})\in \mathbb{R}^{m\times n}$ is a linear map from $\mathbb{R}^n$ to $\mathbb{R}^m$ such that for $\mathbf{x},\mathbf{v} \in \mathbb{R}^n$ and $h\in \mathbb{R}$:

$$
\mathbf{f}(\mathbf{x}+h\mathbf{v}) = \mathbf{f}(\mathbf{x}) + h J_{\mathbf{f}}(\mathbf{x})\mathbf{v} + o(h).
$$

The term $J_{\mathbf{f}}(\mathbf{x})\mathbf{v}\in \mathbb{R}^m$ is a Jacobian Vector Product (JVP), corresponding to the interpretation where the Jacobian is the linear map $J_{\mathbf{f}}(\mathbf{x}):\mathbb{R}^n \to \mathbb{R}^m$, with $J_{\mathbf{f}}(\mathbf{x})(\mathbf{v})=J_{\mathbf{f}}(\mathbf{x})\mathbf{v}$.
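In JAX, this JVP can be computed directly with jax.jvp, which evaluates the linear map without ever materializing the Jacobian matrix. Below is a minimal sketch; the particular function f and the dimensions are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp

# an arbitrary f: R^3 -> R^2
def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])    # point where we linearize
v = jnp.array([0.1, 0.0, -0.2])   # direction

# jax.jvp returns f(x) and J_f(x) v without materializing J_f(x)
y, jvp_out = jax.jvp(f, (x,), (v,))

# sanity check against the explicit Jacobian
J = jax.jacobian(f)(x)
print(jnp.allclose(jvp_out, J @ v))   # True
```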

Chain composition

In machine learning, we compute the gradient of the loss function with respect to the parameters. Even when the parameters are high-dimensional, the loss itself is a single real number. Hence, consider a real-valued composition $\mathbf{f}:\mathbb{R}^n\stackrel{\mathbf{g}_1}{\to}\mathbb{R}^m \stackrel{\mathbf{g}_2}{\to}\mathbb{R}^d\stackrel{h}{\to}\mathbb{R}$, so that $\mathbf{f}(\mathbf{x}) = h(\mathbf{g}_2(\mathbf{g}_1(\mathbf{x})))\in \mathbb{R}$. We have

$$
\underbrace{\nabla\mathbf{f}(\mathbf{x})}_{n\times 1}=\underbrace{J_{\mathbf{g}_1}(\mathbf{x})^T}_{n\times m}\underbrace{J_{\mathbf{g}_2}(\mathbf{g}_1(\mathbf{x}))^T}_{m\times d}\underbrace{\nabla h(\mathbf{g}_2(\mathbf{g}_1(\mathbf{x})))}_{d\times 1}.
$$
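As a quick numerical sanity check of this formula, the sketch below picks arbitrary $\mathbf{g}_1$, $\mathbf{g}_2$ and $h$ and compares jax.grad of the composition with the product of transposed Jacobians (the specific functions and dimensions are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

n, m, d = 4, 3, 2
g1 = lambda x: jnp.tanh(x)[:m]                       # R^n -> R^m
g2 = lambda y: jnp.stack([y.sum(), (y ** 2).sum()])  # R^m -> R^d
h  = lambda z: jnp.sum(jnp.exp(z))                   # R^d -> R
f  = lambda x: h(g2(g1(x)))

x = jnp.arange(1.0, n + 1.0)

grad_f = jax.grad(f)(x)                  # n-vector, left-hand side of the formula
J1 = jax.jacobian(g1)(x)                 # m x n
J2 = jax.jacobian(g2)(g1(x))             # d x m
grad_h = jax.grad(h)(g2(g1(x)))          # d-vector

print(jnp.allclose(grad_f, J1.T @ (J2.T @ grad_h)))  # True
```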

If we do this computation starting from the right, we first multiply a matrix by a vector to obtain a vector (of size $m$) and then perform one more matrix-vector product, resulting in $O(nm+md)$ operations. If we start from the left with the matrix-matrix multiplication, we get $O(nmd+nd)$ operations. Hence we see that as soon as $m\approx d$, starting from the right is much more efficient. Note however that doing the computation from right to left requires keeping in memory the values of $\mathbf{g}_1(\mathbf{x})\in\mathbb{R}^m$ and $\mathbf{x}\in \mathbb{R}^n$.
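Since matrix multiplication is associative, both orderings give the same vector; only the cost differs. A rough timing sketch in numpy (the dimensions are arbitrary and the exact timings will vary by machine):

```python
import time
import numpy as np

n, m, d = 1000, 1000, 1000
A = np.random.randn(n, m)   # plays the role of J_g1(x)^T
B = np.random.randn(m, d)   # plays the role of J_g2(g1(x))^T
u = np.random.randn(d)      # plays the role of the gradient of h

t0 = time.perf_counter()
right = A @ (B @ u)         # two matrix-vector products
t1 = time.perf_counter()
left = (A @ B) @ u          # matrix-matrix product, then matrix-vector
t2 = time.perf_counter()

print(np.allclose(right, left))                                   # same result
print(f"right first: {t1 - t0:.4f}s, left first: {t2 - t1:.4f}s")  # very different cost
```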

Backpropagation is an efficient algorithm computing the gradient "from the right to the left", i.e. backward. In particular, we will need to compute quantities of the form $J_{\mathbf{f}}(\mathbf{x})^T\mathbf{u} \in \mathbb{R}^n$ with $\mathbf{u} \in\mathbb{R}^m$. This can be rewritten as $\mathbf{u}^T J_{\mathbf{f}}(\mathbf{x})$, which is a Vector Jacobian Product (VJP), corresponding to the interpretation where the Jacobian is the linear map $J_{\mathbf{f}}(\mathbf{x}):\mathbb{R}^n \to \mathbb{R}^m$ composed with the linear map $\mathbf{u}:\mathbb{R}^m\to \mathbb{R}$, so that $\mathbf{u}^TJ_{\mathbf{f}}(\mathbf{x}) = \mathbf{u} \circ J_{\mathbf{f}}(\mathbf{x})$.
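This is exactly what jax.vjp provides: it returns $\mathbf{f}(\mathbf{x})$ together with a function mapping a cotangent $\mathbf{u}$ to $J_{\mathbf{f}}(\mathbf{x})^T\mathbf{u}$. A minimal sketch, with the same illustrative f as in the JVP example above:

```python
import jax
import jax.numpy as jnp

# the same illustrative f: R^3 -> R^2 as in the JVP sketch
def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
u = jnp.array([1.0, -1.0])              # cotangent vector in R^2

y, vjp_fn = jax.vjp(f, x)               # vjp_fn(u) computes J_f(x)^T u
(vjp_out,) = vjp_fn(u)

J = jax.jacobian(f)(x)
print(jnp.allclose(vjp_out, J.T @ u))   # True
```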

Example: let $\mathbf{f}(\mathbf{x}, W) = \mathbf{x} W\in \mathbb{R}^b$ where $W\in \mathbb{R}^{a\times b}$ and $\mathbf{x}\in \mathbb{R}^a$. We clearly have

$$
J_{\mathbf{f}}(\mathbf{x}) = W^T.
$$

Note that here we are slightly abusing notation and considering the partial function $\mathbf{x}\mapsto \mathbf{f}(\mathbf{x}, W)$. To see this, we can write $f_j = \sum_{i}x_iW_{ij}$ so that

$$
\frac{\partial \mathbf{f}}{\partial x_i}= \left( W_{i1}\dots W_{ib}\right)^T
$$

Then recall from the definition above that

$$
J_{\mathbf{f}}(\mathbf{x}) = \left( \frac{\partial \mathbf{f}}{\partial x_1},\dots, \frac{\partial \mathbf{f}}{\partial x_a}\right)=W^T.
$$

Now we clearly have

$$
J_{\mathbf{f}}(W) = \mathbf{x}, \text{ since } \mathbf{f}(\mathbf{x}, W+\Delta W) = \mathbf{f}(\mathbf{x}, W) + \mathbf{x} \Delta W.
$$

Note that multiplying by $\mathbf{x}$ on the left is actually convenient when using broadcasting, i.e. we can take a batch of input vectors of shape $\text{bs}\times a$ without modifying the math above.
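For this linear example, the two VJPs can be written down explicitly, including the batched case just mentioned. Below is a small numpy sketch; the helper name linear_vjps and the dimensions are illustrative choices:

```python
import numpy as np

def linear_vjps(x, W):
    """f(x, W) = x @ W with x of shape (bs, a) and W of shape (a, b)."""
    out = x @ W                        # shape (bs, b)
    vjp_x = lambda u: u @ W.T          # J_f(x)^T u, shape (bs, a)
    vjp_W = lambda u: x.T @ u          # J_f(W)^T u, shape (a, b)
    return out, vjp_x, vjp_W

bs, a, b = 4, 3, 2
x = np.random.randn(bs, a)
W = np.random.randn(a, b)
out, vjp_x, vjp_W = linear_vjps(x, W)

u = np.random.randn(bs, b)             # incoming gradient of shape (bs, b)
print(vjp_x(u).shape, vjp_W(u).shape)  # (4, 3) (3, 2)
```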

Implementation

In PyTorch, torch.autograd provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions. To create a custom autograd.Function, subclass this class and implement the forward() and backward() static methods. Here is an example:

```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        # compute exp(i) and save the result for the backward pass
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d exp(i)/di = exp(i), so we reuse the saved forward result
        result, = ctx.saved_tensors
        return grad_output * result

# Use it by calling the apply method:
input = torch.randn(3, requires_grad=True)
output = Exp.apply(input)
```
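As a quick sanity check, torch.autograd.gradcheck can compare this custom backward against finite differences (the double-precision test input below is an illustrative choice):

```python
import torch
from torch.autograd import gradcheck

# gradcheck compares the analytical backward with finite differences;
# it needs double-precision inputs for numerical stability.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(gradcheck(Exp.apply, (x,)))  # prints True if backward() is correct
```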

You can have a look at Module 2b to learn more about this approach, as well as an MLP implemented from scratch.

Backprop the functional way

Here we will implement in numpy a different approach, mimicking the functional approach of JAX; see The Autodiff Cookbook.

Each function will take 2 arguments: the input $\mathbf{x}$ and the parameters $\mathbf{w}$. For each function, we build 2 vjp functions, taking as argument a gradient $\mathbf{u}$ and corresponding to $J_{\mathbf{f}}(\mathbf{x})$ and $J_{\mathbf{f}}(\mathbf{w})$, so that these functions return $J_{\mathbf{f}}(\mathbf{x})^T \mathbf{u}$ and $J_{\mathbf{f}}(\mathbf{w})^T \mathbf{u}$ respectively. To summarize, for $\mathbf{x} \in \mathbb{R}^n$, $\mathbf{w} \in \mathbb{R}^d$, and $\mathbf{f}(\mathbf{x},\mathbf{w}) \in \mathbb{R}^m$,

$$
\begin{aligned}
{\bf vjp}_\mathbf{x}(\mathbf{u}) &= J_{\mathbf{f}}(\mathbf{x})^T \mathbf{u}, \text{ with } J_{\mathbf{f}}(\mathbf{x})\in\mathbb{R}^{m\times n},\ \mathbf{u}\in \mathbb{R}^m\\
{\bf vjp}_\mathbf{w}(\mathbf{u}) &= J_{\mathbf{f}}(\mathbf{w})^T \mathbf{u}, \text{ with } J_{\mathbf{f}}(\mathbf{w})\in\mathbb{R}^{m\times d},\ \mathbf{u}\in \mathbb{R}^m
\end{aligned}
$$

Then backpropagation is simply done by first computing the gradient of the loss and then composing the vjp functions in the right order.
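Below is a minimal sketch of this recipe on a tiny two-layer model with a squared loss; the function names and shapes are illustrative choices, not a prescribed API. Each function returns its output together with its vjp functions, and the backward pass threads the incoming gradient through them from right to left:

```python
import numpy as np

def linear(x, W):
    out = x @ W
    return out, (lambda u: u @ W.T), (lambda u: x.T @ u)

def relu(x):
    out = np.maximum(x, 0.0)
    return out, (lambda u: u * (x > 0.0)), None    # no parameters

def mse_loss(pred, target):
    loss = 0.5 * np.sum((pred - target) ** 2)
    grad = pred - target            # gradient of the loss w.r.t. pred
    return loss, grad

# forward pass: keep the vjp functions around for the backward pass
x = np.random.randn(4, 3)           # batch of 4 inputs in R^3
W1, W2 = np.random.randn(3, 5), np.random.randn(5, 2)
target = np.random.randn(4, 2)

h1, vjp_x1, vjp_W1 = linear(x, W1)
a1, vjp_h1, _ = relu(h1)
out, vjp_a1, vjp_W2 = linear(a1, W2)
loss, u = mse_loss(out, target)

# backward pass: compose the vjp functions from right to left
grad_W2 = vjp_W2(u)
u = vjp_a1(u)
u = vjp_h1(u)
grad_W1 = vjp_W1(u)
grad_x = vjp_x1(u)
print(grad_W1.shape, grad_W2.shape, grad_x.shape)   # (3, 5) (5, 2) (4, 3)
```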

Practice