
Deep Equilibrium Models in Jax

Implicit layers and Deep Equilibrium models (DEQ) have recently been proposed as memory-efficient alternatives to super-deep networks. In this post we explore:

  • the mathematical background behind implicit layers and gradients used by auto-differentiation systems;
  • jax implementations, including Jacobian-vector and vector-Jacobian products; and
  • deqx, a clean and flexible jax library that includes haiku implementations.

This post will not go into motivations for implicit layers or DEQ models themselves. For this, we direct the interested reader to this NeurIPS workshop recording.

Mathematical Background

\(\newcommand{\pdiff}[2]{\frac{\partial {#1}}{\partial #2}}\) \(\newcommand{\diff}[2]{\frac{d {#1}}{d #2}}\)

We begin by considering a scalar function $z(\theta)$ that is implicitly defined as the root of some function $g$, i.e.

\begin{equation} g(z(\theta), \theta) = 0. \label{eqn:implicit-def} \end{equation}

Note this is slightly different from the formulation generally used in Deep Equilibrium Models, which use the fixed point of a function with an additional input $x$, $z = f(z, x; \phi)$. We use Equation \ref{eqn:implicit-def} because it leads to simpler equations; the fixed point formulation can be recovered by setting $f(z, x; \phi) = g(z, [x, \phi]) + z$ (equivalently, $g(z, [x, \phi]) = f(z, x; \phi) - z$).

Given a sufficiently accurate approximate solution (e.g. from Newton’s method), let us consider how to compute gradients and derivatives as they are used in auto-differentiation systems.
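Before deriving gradients, it may help to ground Equation \ref{eqn:implicit-def} in a toy example. The sketch below is purely illustrative (the `g`, `fixed_point_iter`, `W` and `x` names are not from any library discussed in this post): it defines $g(z, \theta) = \tanh(Wz + x) - z$ with $\theta = (W, x)$ and finds its root by naively iterating the corresponding fixed-point map $f(z) = g(z, \theta) + z$.

import jax
import jax.numpy as jnp


def g(z, W, x):
    # g(z, theta) = tanh(W z + x) - z, with theta = (W, x)
    return jnp.tanh(W @ z + x) - z


def fixed_point_iter(fun, z0, *args, num_steps=100):
    # naive Picard iteration on f(z) = fun(z, *args) + z; assumes f is a contraction
    def step(z, _):
        return fun(z, *args) + z, None

    return jax.lax.scan(step, z0, None, length=num_steps)[0]


key_W, key_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.1 * jax.random.normal(key_W, (8, 8))  # small weights keep the map contractive
x = jax.random.normal(key_x, (8,))
root = fixed_point_iter(g, jnp.zeros(8), W, x)
print(jnp.max(jnp.abs(g(root, W, x))))  # approximately zero

With small enough weights the map is a contraction and the iteration converges; in practice a more robust solver such as Newton's method would be used.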

Forward-Mode Differentiation a.k.a. Jacobian-vector Products (jvp)

Forward-mode differentiation propagates derivatives forward through a computation: given the derivative of $\theta$ with respect to some scalar input $t$, we compute the derivative of $z$ via the chain rule, \begin{equation} \diff{z}{t} = \diff{z}{\theta}\diff{\theta}{t}. \label{eqn:jvp} \end{equation}

In the case where $z$ and $\theta$ are vectors and $g$ is a vector-valued function, the leading derivative on the right is a Jacobian matrix, so this is known as a Jacobian-vector product.

To compute this quantity, we can differentiate Equation \ref{eqn:implicit-def} with respect to $\theta$:

\begin{equation} \pdiff{g}{z}\diff{z}{\theta} + \pdiff{g}{\theta} = 0. \end{equation}

Rearranging gives us the quantity we require:

\begin{equation} \diff{z}{\theta} = -\left[\pdiff{g}{z}\right]^{-1}\pdiff{g}{\theta}. \label{eqn:jacobian} \end{equation}

Substituting this into Equation \ref{eqn:jvp} gives

\begin{equation} \diff{z}{t} = -\left[\pdiff{g}{z}\right]^{-1}\pdiff{g}{\theta}\diff{\theta}{t}. \end{equation}

Using an auto-differentiation framework, we can compute this quantity in terms of the framework's existing Jacobian-vector products and an iterative linear system solver:

import typing as tp
from functools import partial

import jax
import jax.numpy as jnp


def rootfind_jvp(
    fun: tp.Callable,
    root: jnp.ndarray,
    args: tp.Tuple,
    tangents: tp.Tuple,
    jacobian_solver: tp.Callable = jax.scipy.sparse.linalg.gmres,
) -> jnp.ndarray:
    """
    Get jvp of root w.r.t. args.

    Args:
        fun: function with found root.
        root: root of fun such that `fun(root, *args) == 0`.
        args: additional arguments to `fun`.
        tangents: tangents corresponding to `args`.
        jacobian_solver: iterative linear system solver, i.e. if
            `x, _ = jacobian_solver(A, b)` then `A(x) == b` (approximately).

    Returns:
        jvp of `root` given `tangents` of `args`.
    """
    if len(args) == 0:
        # if there are no other parameters there will be no gradient
        return ()

    # fun_dot is the jvp of fun w.r.t. all `*args`, i.e. (dg/dtheta)(dtheta/dt)
    _, fun_dot = jax.jvp(partial(fun, root), args, tangents)

    def Jx(v):
        # Jacobian-vector product of fun(x, *args) w.r.t. x, evaluated at (root, *args)
        return jax.jvp(lambda x: fun(x, *args), (root,), (v,))[1]

    # solve (dg/dz) sol = fun_dot, so -sol = -(dg/dz)^{-1} (dg/dtheta)(dtheta/dt)
    sol, _ = jacobian_solver(Jx, fun_dot)
    tangent_out = -sol
    return tangent_out
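To sketch how rootfind_jvp might be hooked into jax's custom-derivative machinery, the snippet below wraps a placeholder forward solver with jax.custom_jvp. The `rootfind` and `_forward_solve` names here are illustrative assumptions for this post, not the deqx API:

from functools import partial

import jax


def _forward_solve(fun, z0, *args, num_steps=100):
    # placeholder forward solver: naive fixed-point iteration on f(z) = fun(z, *args) + z
    def step(z, _):
        return fun(z, *args) + z, None

    return jax.lax.scan(step, z0, None, length=num_steps)[0]


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def rootfind(fun, z0, *args):
    return _forward_solve(fun, z0, *args)


@rootfind.defjvp
def rootfind_custom_jvp(fun, primals, tangents):
    z0, *args = primals
    _, *arg_tangents = tangents
    root = _forward_solve(fun, z0, *args)
    # gradients flow through `args` only; the initial guess z0 does not affect the root
    return root, rootfind_jvp(fun, root, tuple(args), tuple(arg_tangents))

With this wrapper in place, jax.jvp (and hence jax.jacfwd) of rootfind uses rootfind_jvp rather than differentiating through the solver's iterations.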

Reverse-Mode Differentiation a.k.a. vector-Jacobian Products (vjp)

While propagating gradients forward using Jacobian-vector products is handy, training deep learning models requires gradients to be propagated backwards. For a scalar loss function $L(z(\theta))$, this means computing gradients of $L$ with respect to $\theta$,

\begin{equation} \diff{L}{\theta} = \diff{L}{z}\diff{z}{\theta}. \label{eqn:vjp} \end{equation}

In the vector case, the Jacobian is the second factor on the right, so this is referred to as a vector-Jacobian product. Substituting the result from Equation \ref{eqn:jacobian} gives

\begin{equation} \diff{L}{\theta} = -\diff{L}{z}\left[\pdiff{g}{z}\right]^{-1}\pdiff{g}{\theta}. \end{equation}

As in the Jacobian-vector product case above, this vector-Jacobian product can be implemented using existing vector-Jacobian product implementations and an iterative linear system solver:

import typing as tp

import jax
import jax.numpy as jnp


def rootfind_vjp(
    fun: tp.Callable,
    root: jnp.ndarray,
    args: tp.Tuple,
    g: jnp.ndarray,
    jacobian_solver: tp.Callable = jax.scipy.sparse.linalg.gmres,
) -> tp.Tuple:
    """
    rootfind vjp implementation computed using a vector-inverse-Jacobian product.

    Args:
        fun: Callable such that `jnp.all(fun(root, *args) == 0)`.
        root: x such that `fun(x, *args) == 0`.
        args: tuple of additional arguments to `fun`.
        g: gradient of some scalar (e.g. a neural network loss) w.r.t. root.
        jacobian_solver: iterative linear system solver, i.e. if
            `x, _ = jacobian_solver(A, b)` then `A(x) == b` (approximately).

    Returns:
        vjp of root w.r.t. args, same structure as args.
    """
    # vjp_fun(v) computes v (dg/dz), evaluated at (root, *args)
    _, vjp_fun = jax.vjp(lambda x: fun(x, *args), root)
    # solve vJ (dg/dz) = -g, i.e. vJ = -g [dg/dz]^{-1}
    vJ, _ = jacobian_solver(lambda x: vjp_fun(x)[0], -g)
    # vJ (dg/dtheta) = -(dL/dz) [dg/dz]^{-1} (dg/dtheta)
    out = jax.vjp(lambda *args: fun(root, *args), *args)[1](vJ)
    return out
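A parallel wrapper can sketch how rootfind_vjp might plug into jax.custom_vjp. The snippet below reuses the placeholder `_forward_solve` and `rootfind` names from the forward-mode sketch above; again, these are illustrative assumptions rather than the deqx API:

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def rootfind(fun, z0, *args):
    return _forward_solve(fun, z0, *args)


def rootfind_fwd(fun, z0, *args):
    root = _forward_solve(fun, z0, *args)
    # save the root and args as residuals for the backward pass
    return root, (root, args)


def rootfind_bwd(fun, res, g):
    root, args = res
    arg_grads = rootfind_vjp(fun, root, args, g)
    # the initial guess z0 receives no gradient (assumed to have the root's shape)
    return (jnp.zeros_like(root),) + tuple(arg_grads)


rootfind.defvjp(rootfind_fwd, rootfind_bwd)

jax.grad of any scalar loss applied to the returned root then routes gradients to `*args` through rootfind_vjp, which is what training a DEQ layer requires.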

DEQX: Clean and Flexible Implementations

To ease development, we've released deqx, a set of jax implementations relevant to deep learning, including:

  • multiple root-finding and fixed point iteration implementations;
  • Jacobian-vector products and vector-Jacobian products; and
  • haiku DEQ layers.

While tests all pass, the basic mnist example fails to perform well. Whether this is down to an implementation bug or just a poorly designed model is an open question at this stage. Think you know what I’ve done wrong? Open an issue/PR in the repository or let us know in the comments!

This post is licensed under CC BY 4.0 by the author.