Chapter 1: Introduction
Introduction: Explicit layers in deep learning
At the heart of modern deep learning methods is the notion of a layer. Deep learning models are traditionally built by stacking many of these layers together to create an architecture designed to solve some particular task. Convolutional networks, for example, consist of convolutional layers, typically followed by elementwise nonlinearities like the ReLU, with additional operations like normalization or dropout, and possibly connected together in multiple different ways to form things like residual layers. Likewise, architectures like Transformer networks consist of combinations of so-called self-attention layers and fully-connected layers, again stacked together in a manner that results in the final form of the model.
A common defining characteristic, which at this point is so standard that it often goes unnoticed by practitioners, is that the vast majority of these layers in modern deep learning are defined explicitly. That is, they are specified by an exact sequence of operations that it takes to go from the input to the output of a layer. Let’s take a scaled self-attention layer as an example [Vaswani et al.]. This layer is a mapping from three matrices $K,Q,V \in \mathbb{R}^{T \times n}$ to an output $Z \in \mathbb{R}^{T \times n}$, and is defined by the operation
\[Z = \mathrm{SelfAttention}(K,Q,V) \equiv \mathrm{softmax}\left (\frac{K Q^T}{\sqrt{n}} \right) V\](this is a simplified version of self-attention, just for illustration, with no masks or multi-head structure). We could write this layer as a simple Python function (again, just for illustration, not how you would actually want to write e.g., the softmax operation, let alone the fact that you’d likely want to use an automatic differentiation library rather than plain numpy to write these functions).
import numpy as np

def self_attention(K, Q, V):
    # row-wise softmax of the scaled scores K Q^T / sqrt(n), then weight the values V
    A = np.exp(K @ Q.T / np.sqrt(K.shape[1]))
    return (A / np.sum(A, axis=1, keepdims=True)) @ V
K, Q, V = np.random.randn(3, 5, 4)
print(self_attention(K, Q, V))
[[-0.51533073 -0.00353644  1.32042545  0.44677479]
 [-1.42569894  0.04472804  1.28433061 -0.5801782 ]
 [-0.3144326   0.08306728  0.42332847 -0.12556929]
 [-0.08648921  0.01096141  0.90249335  0.54552451]
 [-1.61778227 -0.00595687  1.35908765 -0.71728098]]
Naturally, things start to get a bit more complex as we add more functionality to the layers themselves, implement them within automatic differentiation libraries, etc., but this explicit form of the typical layer remains through it all: layers are built largely like typical computer programs, where we directly write the code to generate the output of the layer as a function of its input. This may be so ingrained, in fact, that it’s hard to imagine that there is an entirely different way that layers can be defined, namely via implicit layers.
Implicit layers
The crux of an implicit layer, as we will use the term throughout this document, is that instead of specifying how to compute the layer’s output from the input, we specify the conditions that we want the layer’s output to satisfy. That is, if we were to write explicit layers (with input $x \in \mathcal{X}$ and output $z \in \mathcal{Z}$) as an application of some explicit function $f : \mathcal{X} \rightarrow \mathcal{Z}$
\[z = f(x)\]then an implicit layer would instead be defined via a function $g : \mathcal{X} \times \mathcal{Z} \rightarrow \mathbb{R}^n$, which is a joint function of both $x$ and $z$, and where the output of the layer $z$ is required to satisfy some constraint, e.g., finding a root of the equation
\[\mbox{Find $z$ such that } g(x,z) = 0.\]The notation here may suggest that $g(x,z)$ is a simple algebraic equation, but in practice this same formalism can capture algebraic equations and fixed points, leading to recurrent backpropagation models or deep equilibrium models; differential equations, leading to Neural ODEs; or the optimality conditions of optimization problems, leading to differentiable optimization approaches.
Before we move on to a concrete example, we should highlight the fact that initially moving to this implicit formulation may seem like a trivial point. After all, in order to actually implement a layer like this, we would need to specify some way of actually computing a root of the equation $g$. But as we will see shortly, there are numerous practical advantages to considering the implicit form of a layer.
Most fundamentally, implicit form layers separate the solution procedure of the layer from the definition of the layer itself. This level of modularity has proven extremely useful in a number of domains. Differential equation solvers, for example, which attempt to find a numerical solution to an ordinary differential equation, can implement all sorts of adaptive step sizes, corrections for so-called “stiff” equations, etc., all in service of attempting to find a low-error solution to the differential equation. Or as another example, optimization solvers often involve very complex heuristics for solving certain types of problems, but they are all aimed at finding the minimum-objective solution to an optimization task. Indeed, because we rarely find exact solutions to e.g., algebraic or differential equations, the different solution methods can be evaluated objectively against each other based upon how well they satisfy the conditions that the layer is attempting to satisfy.
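To make this modularity concrete, here is a minimal numpy sketch (the residual $g$ and the weight matrix below are arbitrary, made-up choices, and we use SciPy’s generic scipy.optimize.root solver purely as an illustration) showing that the very same definition of a layer can be handed to entirely different solution procedures:

import numpy as np
from scipy.optimize import root

# The layer is *defined* only by its residual: find z such that g(x, z) = 0.
rng = np.random.default_rng(0)
W = 0.1 * rng.standard_normal((5, 5))   # arbitrary fixed weight matrix for the demo
x = rng.standard_normal(5)

def g(z, x):
    return z - np.tanh(W @ z + x)

# Solver 1: naive fixed point iteration
z_fp = np.zeros(5)
for _ in range(100):
    z_fp = np.tanh(W @ z_fp + x)

# Solver 2: a generic off-the-shelf root finder
z_root = root(g, np.zeros(5), args=(x,)).x

# Both solvers are judged by the same criterion: how well they satisfy g(x, z) = 0
print(np.linalg.norm(g(z_fp, x)), np.linalg.norm(g(z_root, x)))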
This separation of the layer’s objective and its solution method is desirable enough in and of itself, but a second advantage to implicit layers emerges specifically in the context of deep learning and automatic differentiation. The traditional approach to automatic differentiation (AD) in machine learning is to implement all layers within an automatic differentiation framework (such as PyTorch, TensorFlow, or JAX), which immediately lets us include these layers in deep models that require gradients for fitting the models to data. Yet implementing solution procedures, especially those involving iterative updates, like standard differential equation or optimization solvers, directly within an AD library would mean that we need to store the computation graph for the complete solution procedure, along with the value of temporary iterates created during this solution. This requires storing a great deal of information in memory, which can often be a bottleneck during training of large deep learning models. Fortunately, as we will illustrate below, and highlight several times within this tutorial, implicit layers have the notable advantage that we can use the implicit function theorem to directly compute gradients at the solution point of these equations, without having to store any intermediate variables along the way. This vastly improves the memory consumption and often the numerical accuracy of these methods, providing another notable benefit for implicit models in the setting of deep learning in particular.
Applications and illustrations
Since the remainder of this chapter will focus on an extremely simple demonstration intended as a pedagogical illustration of the methods (rather than an illustration of state-of-the-art performance), we want to briefly highlight the wide array of applications that have been addressed using implicit layers. The following are just a small sampling (of instances that the authors happen to be most familiar with), but they hopefully give an illustration of the breadth of approaches addressed by current research in implicit layers. We’ll dive into more detail about a few of these in the text, but for the most part, you will want to look at the current research in the field to stay abreast of all the application areas being addressed by these approaches.
Implicit layers have been used to:
- Solve arbitrary structured convex problems (using the cvxpy library) in a differentiable manner.
- Solve smoothed relaxations of combinatorial optimization problems, such as graph cuts, satisfiability, and many others.
- Integrate differential equations as layers in deep networks (with numerous applications in and of itself, such as integrating continuous time observations, or approximating continuous versions of traditional residual networks).
- Create architectures for efficient representation of smooth densities, for use in generative modeling and beyond.
- Achieve performance on par with state-of-the-art Transformer models (at the same parameter count), for language modeling and on par with state-of-the-art computer vision architectures on tasks such as classification and semantic segmentation.
Outline of this work
With the above brief introduction as context, we outline the remainder of this work and how the chapters fit together.
- In the remainder of Chapter 1, we will give a brief introduction to your first implicit layer, defined via a fixed point iteration. This is essentially a version of recurrent backpropagation, which was one of the first forms of implicit layers, tracing back to the late 80s, and is also the approach that underlies deep equilibrium (DEQ) models.
- In Chapter 2 we will discuss the mathematical background behind implicit models, including the implicit function theorem and its implementation in automatic differentiation tools.
- In Chapter 3 we will present Neural ODEs, an instantiation of implicit layers that has received substantial attention in recent years. We will present the basic mathematical framework as well as highlight a number of applications of the model.
- In Chapter 4 we will present Deep Equilibrium Models in greater detail, focusing on extending the basic idea presented in Chapter 1 to modern deep learning frameworks, and highlighting some ongoing directions and applications of the models.
- In Chapter 5 we will present differentiable optimization, which embeds the solution to optimization problems as layers. We will specifically show how this relates to, captures, and generalizes many existing layers in deep learning.
Your first implicit model: fixed point iteration
Before diving into the mathematical details and many different forms of implicit models, let’s start with a particularly simple example: a network layer defined by a fixed point iteration. As mentioned above, this type of layer dates back to some of the original formulations of recurrent backpropagation, and is also the basis for the Deep Equilibrium Models we will discuss soon.
A fixed point iteration layer
Although we will shortly adopt a view of this layer as the root of a particular equation, to introduce the layer, suppose we have an input and output $x,z \in \mathbb{R}^n$ of the same dimension, and consider the following approach to computing the output $z$ as a function of $x$
\[\begin{split} & z := 0 \\ & \mbox{Repeat until convergence:} \\ & \quad z := \tanh(Wz + x) \end{split}\]for some network parameters $W \in \mathbb{R}^{n \times n}$. This is an instance of a fixed point iteration: under certain conditions the procedure will converge to some fixed output $z^\star$, which of course has the property that
\[z^\star = \tanh(W z^\star + x).\]We’ll delay, for now, any discussion of why this might be a particularly nice form for a layer to take, but briefly say that this type of layer can, e.g., be interpreted as a simple recurrent network where $z$ is the hidden layer, and where we repeatedly apply the network to the same input $x$. A layer like this can also, e.g., reap some of the benefits of a “deep” neural network (due to the fact that it involves repeated application of a nonlinearity), while only having the parameters $W$ of a “single” layer. But we will discuss these advantages later, and for now focus on simply using a layer such as this.
Note that, of course, this can be written as an implicit layer in the form above, i.e., that $z^\star$ is the solution to the root-finding equation
\[\mbox{Find $z$ such that } g(x,z) = 0, \quad \mbox{where } g(x,z) \equiv z - \tanh(W z + x).\]Note that this iteration need not actually converge: although the $\tanh$ activation will enforce that the values of $z$ never leave the range $[-1,+1]$, depending on the value of $W$ it could be that, e.g., the values cycle endlessly and never reach a fixed point. On the other hand, if e.g., $W=0$, then the iteration reaches a “fixed point” $z^\star = \tanh(x)$ after a single iteration. All that we will say here is that for “typical” values of $W$ (read as: the default initial values of a linear layer used by most deep learning libraries, plus the values that they reach over optimization), this iteration will indeed converge, and we’ll cover the issues of existence and uniqueness of fixed points later.
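As a quick, purely illustrative numpy sketch of this point (the seed, dimension, and scaling factors below are arbitrary choices), we can run the iteration for two different scalings of a random $W$ and inspect the residual $\|z - \tanh(Wz + x)\|$ after many iterations to see whether the iteration actually settled at a fixed point:

import numpy as np

rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal(n)
W = 0.1 * rng.standard_normal((n, n))   # an arbitrary, modestly-scaled weight matrix

for scale in [1.0, 20.0]:
    z = np.zeros(n)
    for _ in range(1000):
        z = np.tanh(scale * W @ z + x)
    # the residual of the fixed point condition tells us whether the iteration settled
    print(scale, np.linalg.norm(z - np.tanh(scale * W @ z + x)))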
Implementing a fixed point iteration layer
Implicit layers certainly take a bit more effort to implement, compared to traditional layers, within an automatic differentiation library like PyTorch or JAX. But the actual core of the implementation is still quite straightforward, and still made much easier via these tools.
To start, let’s consider the simplest possible implementation of a layer like this, which simply repeats the fixed point iteration until convergence, all via the normal autograd functionality of the library (i.e., we are just “unrolling” the fixed point computation). Because this is happening via the normal autograd mechanisms, each intermediate iterate has to be stored in memory, and the backward pass must proceed similarly over the same iterations but in the reverse order. For now, we’ll make a single layer that implements exactly the composition of the $\tanh$ and linear layer above (plus a few other hacks for simplicity, like storing the most recent iteration count and error), but in later chapters we’ll make this much more modular so that we can find similar fixed points of generic layers implemented using the same library.
import torch
import torch.nn as nn

class TanhFixedPointLayer(nn.Module):
    def __init__(self, out_features, tol=1e-4, max_iter=50):
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False)
        self.tol = tol
        self.max_iter = max_iter

    def forward(self, x):
        # initialize output z to be zero
        z = torch.zeros_like(x)
        self.iterations = 0

        # iterate until convergence
        while self.iterations < self.max_iter:
            z_next = torch.tanh(self.linear(z) + x)
            self.err = torch.norm(z - z_next)
            z = z_next
            self.iterations += 1
            if self.err < self.tol:
                break

        return z
We can run this layer on random input, to see that it does in fact reach a fixed point.
layer = TanhFixedPointLayer(50)
X = torch.randn(10,50)
Z = layer(X)
print(f"Terminated after {layer.iterations} iterations with error {layer.err}")
Terminated after 15 iterations with error 5.143991802469827e-05
Although it may ultimately not be that much more informative than simply running the layer on random data, it’s a bit more interesting if we can see the layer used within a real model. So with this in mind, we’ll also present below a simple model trained on the MNIST dataset, using a single fixed point layer (with an additional linear input layer before the fixed point layer, and a linear layer after it). The model isn’t intended to break any records here, but it provides a slightly more useful example upon which to base further experiments than just running the layer in isolation.
# import the MNIST dataset and data loaders
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
mnist_train = datasets.MNIST(".", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST(".", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# construct the simple model with fixed point layer
import torch.optim as optim
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(),
                      nn.Linear(784, 100),
                      TanhFixedPointLayer(100, max_iter=200),
                      nn.Linear(100, 10)
                      ).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)
# a generic function for running a single epoch (training or evaluation)
from tqdm.notebook import tqdm
def epoch(loader, model, opt=None, monitor=None):
    total_loss, total_err, total_monitor = 0., 0., 0.
    model.eval() if opt is None else model.train()
    for X, y in tqdm(loader, leave=False):
        X, y = X.to(device), y.to(device)
        yp = model(X)
        loss = nn.CrossEntropyLoss()(yp, y)
        if opt:
            opt.zero_grad()
            loss.backward()
            # skip the update if any gradient contains NaNs (see discussion below)
            if sum(torch.sum(torch.isnan(p.grad)) for p in model.parameters()) == 0:
                opt.step()

        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
        if monitor is not None:
            total_monitor += monitor(model)
    return total_err / len(loader.dataset), total_loss / len(loader.dataset), total_monitor / len(loader)
Let’s finally train the model for 10 epochs. In addition to the train/test error/loss, we’ll also print the average number of fixed point iterations that our layer required to converge to a fixed point.
for i in range(10):
    if i == 5:
        opt.param_groups[0]["lr"] = 1e-2

    train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, lambda x: x[2].iterations)
    test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor=lambda x: x[2].iterations)
    print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, FP Iters: {train_fpiter:.2f} | " +
          f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, FP Iters: {test_fpiter:.2f}")
Again, not exactly breaking any records here (a single hidden layer network achieves the same performance, with much faster execution/training), but it’s nice that the network at least trains with this layer. There are a few things to notice, however. The first one is that we end up running the fixed point iteration for a fairly large number of iterations in order to get convergence to within $10^{-4}$ of a fixed point. And if you were to look at the individual iterations required for each minibatch, you’d see that several of them don’t even reach this tolerance at all, but exit after 200 steps with a larger error (it’s possible that the fixed point iteration can even go unstable at some points during training, and this typically degrades the model significantly if it does happen without any error handling in place). This seems to be a pretty substantial downside. We are effectively running a 50-80 “layer” network in practice, and not seeing much advantage over a standard MLP (note that, because we re-add the input to our layer at each iteration $z := \tanh(Wz + x)$, this is not the same as a traditional MLP of this depth, which would suffer from vanishing/exploding gradients).
To really see the potential upside of these layers, then, we need to introduce a few more ideas.
Alternative root finding techniques
Recall that one benefit of implicit layers is that they provide a separation between what is computed by the layer, and how the layer computes this. In the above example, our goal of the fixed point iteration was to find some $z$ such that
\[z = \tanh(W z + x).\]One way to do this is to simply iterate this equation, but it is by no means the only way. Alternatively, we can employ a much faster root-finding method, such as Newton’s method, to try to find this solution more efficiently.
Newton’s method is a generic root-solving technique. For some function $g : \mathbb{R}^n \rightarrow \mathbb{R}^n$, if we wish to find a root $g(z) = 0$, then Newton’s method repeats the update
\[z := z - \left ( \frac{\partial g}{\partial z} \right ) ^{-1} g(z)\]where $\frac{\partial g}{\partial z}$ denotes the Jacobian of $g$ with respect to $z$ (a “guarded” update is often required in practice that makes smaller steps to ensure sufficient decrease of the residual $\|g(z)\|$, but we won’t consider this here). Although we could resort to automatic differentiation to compute the Jacobians (and we’ll need to do this in later chapters, when we have a more generic layer we’re using within fixed point iteration), for the case of our $\tanh$ plus linear layer, it’s easy to compute the Jacobian in closed form. Specifically, we are attempting to find the root of the equation $g(x,z) = 0$ (returning to the notation from the previous sections of this chapter, where we make the dependence on layer input $x$ explicit), where
\[g(x,z) = z - \tanh(Wz + x).\]Then our Jacobian is given by
\[\frac{\partial g}{\partial z} = I - \mathrm{diag}(\tanh'(Wz + x)) W\]where $\tanh’$ denotes the derivative of the $\tanh$ function, given by
\[\tanh'(x) = \mathrm{sech}^2(x).\]Let’s see what an implementation of Newton’s method looks like in code. The implementation is slightly more involved than the simple fixed point iteration, owing to the need to compute the Newton step, but this ends up being only a few additional lines of code.
class TanhNewtonLayer(nn.Module):
    def __init__(self, out_features, tol=1e-4, max_iter=50):
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False)
        self.tol = tol
        self.max_iter = max_iter

    def forward(self, x):
        # initialize output z with a single tanh application
        z = torch.tanh(x)
        self.iterations = 0

        # iterate until convergence
        while self.iterations < self.max_iter:
            z_linear = self.linear(z) + x
            g = z - torch.tanh(z_linear)
            self.err = torch.norm(g)
            if self.err < self.tol:
                break

            # newton step: form the per-sample Jacobian and solve the linear system
            J = torch.eye(z.shape[1], device=z.device)[None,:,:] - (1 / torch.cosh(z_linear)**2)[:,:,None] * self.linear.weight[None,:,:]
            z = z - torch.linalg.solve(J, g[:,:,None])[:,:,0]
            self.iterations += 1

        # zero out any samples that did not converge to within tolerance
        g = z - torch.tanh(self.linear(z) + x)
        z[torch.norm(g, dim=1) > self.tol, :] = 0
        return z
layer = TanhNewtonLayer(50)
X = torch.randn(10,50)
Z = layer(X)
print(f"Terminated after {layer.iterations} iterations with error {layer.err}")
Terminated after 3 iterations with error 1.1832605650852202e-06
This is able to converge nicely and more quickly than the fixed point iteration, but with the (major) caveat that we now have to solve a linear system at each iteration. And again, since we’re implementing this whole procedure using automatic differentiation, we also need to backprop through the solves in the backward pass. But we can simply plug this into the same training process as before.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(),
                      nn.Linear(784, 100),
                      TanhNewtonLayer(100, max_iter=40),
                      nn.Linear(100, 10)
                      ).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)
for i in range(8):
    if i == 5:
        opt.param_groups[0]["lr"] = 1e-2

    train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, lambda x: x[2].iterations)
    test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor=lambda x: x[2].iterations)
    print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, Newton Iters: {train_fpiter:.2f} | " +
          f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, Newton Iters: {test_fpiter:.2f}")
Again, the method works reasonably well. However, there are some notable issues with the approach as implemented. The first, as you will immediately notice if you run the code, is that the approach is noticeably slower than the simpler fixed point iteration method above. Even though the number of iterations needed is much smaller than for fixed point iteration, each individual iteration is also much slower, as it involves forming and inverting a separate (in this case, $100 \times 100$) Jacobian matrix for each sample in the minibatch. And for larger hidden unit sizes (especially, e.g., for convolutional networks), it will quickly become intractable to invert or even store these matrices. Indeed, in practice, an exact Newton method is rarely used; instead we can employ quasi-Newton methods to improve convergence over the standard fixed point iteration, while also improving wall clock time.
The second issue with this approach is a bit more subtle, but actually an even larger problem. Because we implemented Newton’s method directly within an automatic differentiation toolkit, there are a few large downsides to this method as presented. First, as with the fixed point iteration, the automatic differentiation tool will need to save the intermediate iterates of the hidden units; but here, this means that we also need to store in memory the intermediate iterates of the Jacobian terms as well, which drastically increases memory consumption even for the case where we could store and invert the full Jacobian. Furthermore, backpropagation through repeated inverses can be a numerically unstable routine: if an inverse is close to singular then even if the forward pass converges properly, the backward pass can still generate numerical errors in the gradients. Indeed, you will notice that we included a “NaN check” in our epoch() function above. If we did not do this, then the Newton’s method approach would immediately fail: if you check, you will see that around 5% of the updates actually have NaN values in the gradient, due to conditioning of the Jacobians, which is also what causes this method to actually converge more slowly than the fixed point iteration version.
This paints a pretty dim picture for “efficient” methods for solving implicit models. Fortunately, however, there is a much better way to implement these layers, courtesy of the implicit function theorem.
Differentiation in implicit layers
Thus far, we implemented our solver for our implicit layer in exactly the same manner as we would implement any other layer, and let the automatic differentiation library take care of the backward pass. However, there is a much nicer way to differentiate with respect to the fixed point of a hidden layer. To see how to do this, let’s consider the generic form of our implicit layer, namely given $x$, finding some $z$ such that
\[g(x,z) = 0.\]Let’s denote $z^\star(x)$ as the value that solves this equation, written this way to emphasize that the output of the implicit layer is still of course an (implicit) function of the input.
Now let’s consider how we can compute the Jacobian of this output with respect to the input
\[\frac{\partial z^\star(x)}{\partial x}.\]Unlike the traditional functions you’re used to, where we are given an explicit form for computing the output from the input, it may not be obvious how to determine such a Jacobian. But in fact it’s very straightforward to compute the term using implicit differentiation, a technique that goes back several centuries in Calculus. In particular, to derive an expression for this Jacobian, we start with the fixed point condition, which we know holds for $z^\star(x)$, and differentiate both sides with respect to $x$,
\[\frac{\partial g(x,z^\star(x))}{\partial x} = 0.\]Now we just use the chain rule to expand this partial derivative: since $g$ is a function of two variables, there will be a term involving the derivative with respect to each
\[\frac{\partial g(x,z^\star)}{\partial x} + \frac{\partial g(x,z^\star)}{\partial z^\star}\frac{\partial z^\star(x)}{\partial x} = 0\]where the notation $z^\star$ (not indicated as a function of $x$) just means we are treating $z^\star$ as a fixed value here (i.e., the Jacobian $\frac{\partial g(x,z^\star)}{\partial x}$ would just be the Jacobian of $g$ with respect to $x$, evaluated at the point $(x,z^\star)$). Thus, this term, along with the $\frac{\partial g(x,z^\star)}{\partial z^\star}$ term, can themselves be computed using ordinary automatic differentiation libraries. Finally, we then just rewrite this equation to give us the expression we are after in terms of the expressions we know
\[\frac{\partial z^\star(x)}{\partial x} = - \left ( \frac{\partial g(x,z^\star)}{\partial z^\star} \right )^{-1} \frac{\partial g(x,z^\star)}{\partial x}.\]Technically, in order to ensure that we can actually apply this theorem, certain conditions must be satisfied, so that the implicit function $z^\star(x)$ is guaranteed to exist: these conditions are reflected in what is known as the implicit function theorem, which will be discussed in the following chapter. Additionally, just as with Newton vs. quasi-Newton methods, in practice it is often not possible to compute this inverse directly, and instead an iterative process is needed. We’ll cover the mathematical details and formalisms a lot more in the next chapter, but for the purposes of most of what we actually need to derive, this “informal” derivation is virtually all you’ll need. Finally, although we wrote the above formula for the Jacobian with respect to $x$, when $g$ is also a function of some parameters $\theta$ (e.g., the weights and biases), precisely the same derivation holds to find the Jacobian with respect to these parameters.
Moving back from the detailed derivation of this formula though, the implicit function theorem leads to a very practical consequence. Namely, the formula gives a form for the necessary Jacobian without needing to backpropagate through the method used to obtain the fixed point. In other words, it doesn’t matter at all how we compute the zero of the function (whether via fixed point iteration, Newton’s method, or quasi-Newton methods). All that matters is finding the fixed point (using whatever technique you want), at which point we can directly compute the necessary Jacobians using this analytical form (or more precisely, compute the backward pass, which often will not require explicit computation of the Jacobian in practice). No intermediate terms of the iterative method used to compute the fixed point need to be stored in memory (making the methods much more memory efficient), and there is no need to unroll the forward computations within an automatic differentiation layer.
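As a quick sanity check of this formula (an illustrative sketch only: the small weight matrix below is an arbitrary choice, and we use differentiation through a long unrolled iteration purely as a reference), we can compare the Jacobian given by implicit differentiation against the one obtained by differentiating directly through the unrolled fixed point iteration:

import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
n = 4
W = 0.1 * torch.randn(n, n)   # small arbitrary weight matrix so the iteration contracts
x0 = torch.randn(n)

def fixed_point(x, iters=500):
    # naive solver: just unroll the iteration z := tanh(Wz + x)
    z = torch.zeros_like(x)
    for _ in range(iters):
        z = torch.tanh(W @ z + x)
    return z

z_star = fixed_point(x0)

# Jacobian from implicit differentiation: -(dg/dz)^{-1} (dg/dx), with g(x,z) = z - tanh(Wz + x)
dtanh = 1 / torch.cosh(W @ z_star + x0)**2
dg_dz = torch.eye(n) - torch.diag(dtanh) @ W
dg_dx = -torch.diag(dtanh)
J_implicit = -torch.linalg.solve(dg_dz, dg_dx)

# reference: differentiate through the unrolled iterations
J_unrolled = jacobian(fixed_point, x0)

print(torch.allclose(J_implicit, J_unrolled, atol=1e-4))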
Implementing implicit differentiation
Let’s see how an implementation of implicit differentiation will work in practice. First, let’s consider our tanh plus linear layer again, where the $g(x,z)$ function is given by
\[g(x,z) = z - \tanh(Wz + x)\]In this case the Jacobian $\frac{\partial g}{\partial z^\star}$, needed for implicit differentation, is given by
\[\frac{\partial g}{\partial z^\star} = I - \mathrm{diag}(\tanh'(Wz^\star+x)) W.\]You may notice that this is the exact same Jacobian that we formed when using Newton’s method to solve for the fixed point. This is not an accident: indeed, precisely the same Jacobian term is needed for finding a root within Newton’s method as for computing the backward pass via implicit differentiation. This results in a very nice property: for the situation where we find a solution to our root via a method like Newton’s method (or any approach that computes and inverts the Jacobian), computing the backward pass via implicit differentiation is effectively “free” (at least, relative to the complexity of solving the fixed point to begin with): we can simply reuse the Jacobian (and its inverse) that we computed in the forward pass. Of course, since in practice we often use quasi-Newton or first order methods for finding fixed points of implicit layers, this is not quite as big an advantage as it may seem. But, nonetheless, in the case where we do compute even an approximation to the Jacobian during the forward pass, it can be beneficial to leverage this computation in the backward pass as well.
Before we move on to our implementation, we should emphasize how the actual implicit differentiation process works within backpropagation (i.e., reverse mode automatic differentiation). In backpropagation, we don’t actually need to compute full Jacobians of the intermediate layers in a network. Rather, the goal of backprop is to compute the gradient with respect to some scalar loss. If we write this out in terms of our gradient above, it would look something like
\[\frac{\partial \ell}{\partial x} = \frac{\partial \ell}{\partial z^\star} \frac{\partial z^\star}{\partial x} = - \frac{\partial \ell}{\partial z^\star} \left (\frac{\partial g}{\partial z^\star} \right )^{-1} \frac{\partial g}{\partial x}\]where we applied the implicit differentiation formula above in the last equality. In backpropagation, this term is computed left-to-right, meaning that instead of actually needing to compute the full Jacobian $\frac{\partial z^\star}{\partial x}$, we just need to compute the vector-Jacobian product shown above. As a matter of convention, most automatic differentiation frameworks frame this in terms of operations on the gradient (the transpose of the Jacobian for a scalar-valued function)
\[\nabla_{z^\star} \ell = \left ( \frac{\partial \ell}{\partial z^\star} \right )^T\]so we need to multiply by the transpose of the Jacobian
\[\nabla_x \ell = -\left (\frac{\partial g}{\partial x} \right )^T \left (\frac{\partial g}{\partial z^\star} \right )^{-T} \nabla_{z^\star} \ell.\]Again, however, we emphasize that we don’t actually need to store and compute the actual inverse $\left (\frac{\partial g}{\partial z^\star} \right )^{-T}$, just be able to solve the (linear) system that arises in this formula.
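In code, this backward computation amounts to one linear solve followed by one vector-Jacobian product. The following standalone sketch (with made-up values for $W$, $x$, and the incoming gradient; it is not the layer we implement below) spells out the two steps:

import torch

torch.manual_seed(0)
n = 4
W = 0.1 * torch.randn(n, n)   # small arbitrary weight matrix for the demo
x = torch.randn(n, requires_grad=True)
grad_z = torch.randn(n)       # stand-in for the incoming gradient dl/dz* from the rest of the network

# solve for the fixed point (outside of autograd)
with torch.no_grad():
    z_star = torch.zeros(n)
    for _ in range(200):
        z_star = torch.tanh(W @ z_star + x)

# step 1: solve (dg/dz*)^T u = grad_z
dtanh = 1 / torch.cosh(W @ z_star + x.detach())**2
dg_dz = torch.eye(n) - torch.diag(dtanh) @ W
u = torch.linalg.solve(dg_dz.T, grad_z)

# step 2: the gradient w.r.t. x is -(dg/dx)^T u, computed as a vector-Jacobian product
g = z_star - torch.tanh(W @ z_star + x)
grad_x, = torch.autograd.grad(g, x, grad_outputs=-u)
print(grad_x)

This mirrors the structure of the implementation we build below: the linear solve is attached as a backward hook, while the product with $-\frac{\partial g}{\partial x}$ comes from re-running a single step of the update inside the autograd tape.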
Finally, let’s discuss how we can implement a formula like this within an automatic differentiation toolkit. The details will of course vary from framework to framework, but since we’re ultimately talking about implementing a new type of function here (i.e., in the forward pass computing a fixed point outside of any automatic differentiation, and then computing a “custom” backward pass), you may be tempted to use a feature like the autograd.Function interface (if you were implementing this in PyTorch, for instance), which lets you specify a forward and backward pass entirely outside the normal automatic differentiation pass of the library. But this would actually be a bit cumbersome in practice: after all, one of the benefits of automatic differentiation is that we could potentially implement the function $g$ (whether it used convolutions, self-attention, or any other feature) inside the same automatic differentiation library, and we would want to automatically include all these gradients without writing a new function for each particular function $g$ we want to implement. Fortunately, there is a fairly straightforward though subtle way to deal with this issue. We’ll return to several examples of efficient implicit differentiation in later sections, which each have their own implementation quirks, but for a simple example like this, a common paradigm that works is the following three steps:
- Outside the automatic differentiation tape, solve for the root of the implicit layer $g(x,z^\star) = 0$.
- “Reengage” the automatic differentiation by running the following assignment within the automatic differentiation tape:
\[z := z^\star - g(x,z^\star).\]This has the effect of “reinserting” the partial derivatives $-\frac{\partial g}{\partial x}$ into the autograd tape (and is a no-op in terms of the value of $z$, since $g(x,z^\star) = 0$).
- Add a “backward hook” to the backward pass that multiplies by $(\frac{\partial g}{\partial z^\star})^{-T}$. This will fix the backward pass so that it correctly implements the gradient according to the implicit function theorem.
For the tanh + linear layer from before, this results in an implementation like that below. Notice that the layer is essentially identical to the version we implemented previously, except that Newton’s method runs within a torch.no_grad(): block, and we add the backward pass hook via the register_hook function. For the second step above, given the $g$ function as highlighted before, the assignment is simply
\[z := \tanh(W z^\star + x),\]i.e., we run a single fixed point iteration, within the automatic differentiation tape, after finding the fixed point with Newton’s method.
class TanhNewtonImplicitLayer(nn.Module):
    def __init__(self, out_features, tol=1e-4, max_iter=50):
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False)
        self.tol = tol
        self.max_iter = max_iter

    def forward(self, x):
        # Run Newton's method outside of the autograd framework
        with torch.no_grad():
            z = torch.tanh(x)
            self.iterations = 0
            while self.iterations < self.max_iter:
                z_linear = self.linear(z) + x
                g = z - torch.tanh(z_linear)
                self.err = torch.norm(g)
                if self.err < self.tol:
                    break

                # newton step
                J = torch.eye(z.shape[1], device=z.device)[None,:,:] - (1 / torch.cosh(z_linear)**2)[:,:,None] * self.linear.weight[None,:,:]
                z = z - torch.linalg.solve(J, g[:,:,None])[:,:,0]
                self.iterations += 1

        # reengage autograd (one fixed point iteration at z^*) and add the gradient hook
        z = torch.tanh(self.linear(z) + x)
        z.register_hook(lambda grad: torch.linalg.solve(J.transpose(1, 2), grad[:,:,None])[:,:,0])
        return z
Note that this is a fairly non-standard implementation: we are implementing an element of the forward pass outside the normal automatic differentiation tape, and then adding a backward hook to “fix” the gradient. We can verify the correctness of this layer using the built-in gradcheck function. Note that this implementation won’t work for double backprop (i.e., gradgradcheck will not work), but this can be addressed by a slightly more involved approach, and often isn’t needed in practice, so we ignore it for now.
from torch.autograd import gradcheck
layer = TanhNewtonImplicitLayer(5, tol=1e-10).double()
gradcheck(layer, torch.randn(3, 5, requires_grad=True, dtype=torch.double), check_undefined_grad=False)
True
Finally, again for demonstration, we’ll train our MNIST network with this new variant of the implicit layer. As hoped, the method is indeed quite a bit faster and substantially more stable than the previous implementation of Newton’s method. And while we again highlight that using Newton’s method exactly isn’t typically a reasonable approach for settings like these, something very similar will in fact be quite useful when we discuss differentiable optimization in later chapters.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(),
                      nn.Linear(784, 100),
                      TanhNewtonImplicitLayer(100, max_iter=40),
                      nn.Linear(100, 10)
                      ).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)
for i in range(10):
    if i == 5:
        opt.param_groups[0]["lr"] = 1e-2

    train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, lambda x: x[2].iterations)
    test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor=lambda x: x[2].iterations)
    print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, Newton Iters: {train_fpiter:.2f} | " +
          f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, Newton Iters: {test_fpiter:.2f}")
Final chapter remarks
Before delving into the much more realistic and varied world of practical implicit models, we want to highlight our accomplishments so far. Using very little additional code beyond a “traditional” deep model (and definitely not that much more than a traditional recurrent model), we’re able to code a layer that 1) solves a non-linear root-finding problem via Newton’s method, equivalent to finding the fixed point of an infinite-depth network, and 2) integrates easily into automatic differentiation tools. The relative ease of these approaches, once you push past a bit of the mathematical notation of implicit differentiation, is indeed one of the more compelling aspects of using implicit layers within deep learning as a whole.
Throughout the rest of this tutorial, we will provide you with the tools and background you need to apply implicit layers to a wide variety of problems and settings, with code examples throughout. Our hope is that this will enable readers to quickly integrate these methods and make progress in these new directions.