Another day, another opportunity to learn cool new math tricks. One of the coolest parts about machine learning, at least in my humble opinion, is the great lengths the community has gone to in making it as easy as possible. Libraries like TensorFlow, PyTorch, JAX and Julia’s Flux especially go above and beyond by giving us the blessing of automatic differentiation (AD).
The brilliance of these libraries can hardly be overstated; they take care of everything from providing linear algebra libraries to creating CUDA kernels for GPU offloading. They even go as far as fusing multiple operations together and having their very own intermediate representation for JIT compiling!
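Most machine learning is done through some variant of the gradient descent algorithm, which is basically just this update rule (with parameters $\theta$, learning rate $\eta$ and loss function $L$):

$$\theta \leftarrow \theta - \eta \, \nabla L(\theta)$$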
applied over and over again. The issue, though, is that finding the gradient of the loss function with pen and paper can get pretty difficult for more complex loss functions. What’s even more challenging is keeping track of all the intermediate gradients of a deep neural network. Are we seriously expected to spend hours re-checking our math every time we want to change the loss function or model? AD is meant to help with this problem, but how exactly do we write code that can differentiate complex functions?
The derivative is a Monad
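I’m assuming you have at least a working knowledge of derivatives and what they do [1]. For example, you shouldn’t need to squint too hard to mentally verify these theorems:

$$\nabla(ax) = a, \qquad \nabla(x + a) = 1, \qquad \nabla e^{x} = e^{x}, \qquad \nabla \frac{1}{x} = -\frac{1}{x^{2}}$$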
If the notation is confusing, just mentally replace every instance of $\nabla f$ with $\frac{df}{dx}$ instead. I explained in a previous blog post why I prefer the former notation; you can find it in the footnotes.
Assuming a working knowledge, then, I want to walk us through a proof [2] that the gradient operator forms a Monad over the category of differentiable functions. Some of you may have heard of Monads in passing and been scared off; others may not have heard of Monads at all, and that’s okay too!
While I won’t make us go through the suffering of explaining what a Monad is, I’m going to state up front all the ingredients we need to claim that something forms a Monad. Our ingredient list is as follows:
- An EndoFunctor: We need to show that the gradient follows the definition of an EndoFunctor.
- A binary operator: We need some well-defined way of taking two derivatives and combining them together to form a third, new derivative.
- Identity: We need some function such that, when our shiny new binary operator combines any derivative with the derivative of this identity, we simply get the original derivative back.
“EndoFunctor” sounds like a scary word, but it’s really more old-school than complex. Endo is a prefix derived from Greek meaning “inner” or “within”, e.g. endoscopy, endoskeleton, etc. A functor can be thought of programmatically as a function acting on functions (the details of functors are much more nuanced than this, but this definition is good enough for us).
The part about the EndoFunctor
Proving that the gradient is an EndoFunctor is actually quite trivial. All we need to do is recall the definition and compare it to the inputs and outputs of the gradient. To be an EndoFunctor, the input’s type has to be the same as the output’s type, and that type has to be a function.
Well, if we look at what the gradient does, it takes a function as input (e.g. $f$) and returns a new function as its output (i.e. $\nabla f$). Wow, that was pretty darn straightforward! The gradient is clearly an EndoFunctor.
The part about a binary operator
Okay, if we can get through this step then everything else is smooth sailing. What we’re looking for here is some way of taking two separate derivative functions and combining them to make a new derivative function.
That sounds like a deceptively simple task, because there are so many wacky options to choose from. As a dumb example, we could just take the addition operator as our binary operator. So for arbitrary functions $f$ and $g$ we have

$$\nabla f + \nabla g = \nabla (f + g)$$

but this is a pretty stupid choice because we haven’t actually gained anything of interest. Specifically, this is just a natural consequence of the derivative being a linear operator. What would be really awesome is a binary operator $\star$ such that the right-hand side is the derivative of the composition of $f$ with $g$. In other words, what we really want is something like

$$\nabla f \star \nabla g = \nabla (f \circ g)$$
Well, maybe we’re stumped because this notation looks weird. Let’s instead use an intermediate value. Define the variable $u = g(x)$, so that we can write $(f \circ g)(x) = f(u)$. I can’t take the derivative of $f(u)$ with respect to $x$ directly, but I can use the chain rule, which gives me

$$\frac{d f(u)}{dx} = \frac{df}{du} \cdot \frac{du}{dx}$$

and if we substitute $u = g(x)$ back in we get

$$\nabla (f \circ g)(x) = \nabla f\big(g(x)\big) \cdot \nabla g(x)$$
That looks almost identical to what we wanted, except that $\nabla f$ is now evaluated at $g(x)$ instead of $x$. But think about it for a second: does it really matter what letter or function $\nabla f$ is tied to? If we’re differentiating a single-variable function with respect to that single variable, we can rename the underlying variable without affecting the derivative! So let’s write $\nabla f\big(g(x)\big)$ as the composition $(\nabla f \circ g)(x)$ and get

$$\nabla (f \circ g)(x) = (\nabla f \circ g)(x) \cdot \nabla g(x)$$

or, if we want to write it in a variable-agnostic way:

$$\nabla (f \circ g) = (\nabla f \circ g) \cdot \nabla g$$
and just like that we have a rule that lets us build up arbitrarily complex functions! Any crazy-looking function you can cook up from elementary pieces can be broken up into a composition of simpler functions, ad nauseam, until every function in the composed version is one we know how to differentiate.
The part about the Identity
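I promised smooth sailing, and smooth sailing is what we’re getting. All we need for this step is a function $h$ such that, for any function $f$,

$$\nabla f \star \nabla h = \nabla (f \circ h) = \nabla f$$

Okay, well, let’s think here: is there a differentiable function that leaves its input untouched and whose derivative is always 1? Well, of course! That’s just the identity function $h(x) = x$: composing with it changes nothing, and plugging $\nabla h = 1$ into the chain rule gives $(\nabla f \circ h) \cdot 1 = \nabla f$. And just like that we’ve proven the last step.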
Step 3: Profit
At this point you might be wondering why in the world we went through all this rigmarole in the first place. Well, I’m glad you asked, silent reader; allow me to show you! Let’s sketch out a quick plan for taking the derivative of a not-so-trivial function, the sigmoid $\sigma(x) = \frac{1}{1 + e^{-x}}$:
1. Start off with the identity function we discovered earlier. We need this step because the derivative is only defined on functions, not on real numbers! So you can think of this step as our initialization.
2. Apply a single composition to our original identity function. We do this so that we can construct our target function (in this case the sigmoid) using nothing but our elementary basis functions.
3. For every composition we do in step 2, use our derivative rules for the elementary functions and the chain rule to evaluate the gradient of the new composed function.
4. Repeat steps 2 and 3 ad nauseam until we reach the target function.
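To make the plan concrete, here is one possible decomposition (mine, chosen to match the elementary functions used in the code below), writing $m(x) = -x$, $a(x) = 1 + x$ and $r(x) = 1/x$, and pushing the chain rule through every composition by hand:

$$
\begin{aligned}
\sigma &= r \circ a \circ \exp \circ m \circ \mathrm{id} \\
\nabla \sigma(x) &= \nabla r\big(1 + e^{-x}\big) \cdot \nabla a\big(e^{-x}\big) \cdot \nabla \exp(-x) \cdot \nabla m(x) \cdot \nabla \mathrm{id}(x) \\
&= -\frac{1}{(1 + e^{-x})^{2}} \cdot 1 \cdot e^{-x} \cdot (-1) \cdot 1 \\
&= \frac{e^{-x}}{(1 + e^{-x})^{2}} = \sigma(x)\big(1 - \sigma(x)\big)
\end{aligned}
$$

The last equality is the familiar closed form, which makes a handy test target.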
That seems simple enough, right? Let’s start coding then!
Attempting an implementation
I’m going to start with the simplest implementation I can think of, in the simplest language I know. We’re going to keep track of the gradient as a global variable. We then define our elementary functions so that they update the gradient using the chain rule, and finally they return the function result:

```python
import math

# Running gradient, updated in place by every elementary function (chain rule).
grad = 1.0

def multiply(x, a):
    global grad
    grad *= a  # d/dx (a * x) = a
    return a * x

def add(x, a):
    global grad
    grad *= 1.0  # d/dx (a + x) = 1
    return a + x

def exp(x):
    global grad
    value = math.exp(x)  # math.exp instead of a hard-coded e = 2.71828
    grad *= value  # d/dx e^x = e^x
    return value

def reciprocal(x):
    global grad
    grad *= -1.0 / (x * x)  # d/dx 1/x = -1/x^2
    return 1.0 / x
```
Since these functions are composable, we can simply plug them into each other and we should get the correct gradient. For testing purposes, we can look up the actual derivative of the sigmoid, which turns out to be

$$\nabla \sigma(x) = \sigma(x)\big(1 - \sigma(x)\big)$$

and simply test our code using $x = 0$, which gives us the output

$$\sigma(0) = \frac{1}{1 + e^{0}} = 0.5$$

with gradient

$$\nabla \sigma(0) = 0.5 \cdot (1 - 0.5) = 0.25$$

and sure enough, if we define the sigmoid function and apply it to 0 we get
```python
def sigmoid(x):
    x = multiply(x, -1.0)
    x = exp(x)
    x = add(x, 1.0)
    x = reciprocal(x)
    return x

value = sigmoid(0.0)  # grad must still be 1.0 before this call
print(value, grad)  # prints: 0.5 0.25
```
Great success! But our current implementation kinda… sucks. Apart from the fact that these are impure functions (which you can easily verify by noticing that they can’t be memoized correctly), this implementation can also only track one gradient at a time, since every call shares (and must remember to reset) the single global grad. Let’s try to fix these issues.
Memoization is the process of saving the output of an expensive-to-call function into some container like a map. This way, for any given input you only need to do the expensive calculation once. Every subsequent call to the function with identical inputs can be substituted with the stored output.
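Here is a minimal sketch of memoization, using a hand-rolled dict cache (memoize is a hypothetical helper for illustration; the standard library’s functools.lru_cache does the same job), and of why it breaks the global-gradient implementation: the cached call returns the right value but silently skips the gradient update.

```python
# Hypothetical memoizer: cache results in a dict keyed by the argument.
def memoize(func):
    cache = {}

    def wrapper(x):
        if x not in cache:
            cache[x] = func(x)  # the expensive call only happens on a cache miss
        return cache[x]

    return wrapper


grad = 1.0  # global gradient, as in the first implementation


@memoize
def reciprocal(x):
    global grad
    grad *= -1.0 / (x * x)  # side effect: only runs on a cache miss!
    return 1.0 / x


grad = 1.0
reciprocal(2.0)
first_grad = grad   # -0.25: the body ran and applied the chain rule

grad = 1.0
reciprocal(2.0)     # cache hit: the body is skipped entirely
second_grad = grad  # still 1.0: the gradient update never happened

print(first_grad, second_grad)
```

The value part survives memoization; only the hidden state goes wrong, which is exactly what makes these functions impure.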
A second attempt
How should we get rid of the global variable? The simplest tweak we can do is treat the gradient as state and pass it as an input to the function. We then need to also return the modified gradient as an output, which looks like this:
```python
import math

# The gradient is now explicit state: passed in, and returned next to the value.
def multiply(x, a, grad):
    return a * x, grad * a

def add(x, a, grad):
    return a + x, grad * 1.0

def exp(x, grad):
    value = math.exp(x)
    return value, grad * value

def reciprocal(x, grad):
    return 1.0 / x, grad * -1.0 / (x * x)
```
Composition of these functions would then look like this:
```python
def sigmoid(x, grad):
    x, grad = multiply(x, -1.0, grad)
    x, grad = exp(x, grad)
    x, grad = add(x, 1.0, grad)
    x, grad = reciprocal(x, grad)
    return x, grad
```
It’s pretty obvious that this is not a great pattern to use. First off, our functions are still not practically memoizable: even when x is the same value, there are almost infinitely many possible values for grad (all of which are totally valid!), so a cache would almost never hit. We’ve also mucked up the function signature and made it weird.
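To see the memoization problem concretely, here is the second attempt assembled into one runnable script (with math.exp in place of the hard-coded e). The incoming grad is just whatever upstream chain-rule factor we happen to be carrying, so the same x can legitimately produce different outputs:

```python
import math

def multiply(x, a, grad):
    return a * x, grad * a

def add(x, a, grad):
    return a + x, grad * 1.0

def exp(x, grad):
    value = math.exp(x)
    return value, grad * value

def reciprocal(x, grad):
    return 1.0 / x, grad * -1.0 / (x * x)

def sigmoid(x, grad):
    x, grad = multiply(x, -1.0, grad)
    x, grad = exp(x, grad)
    x, grad = add(x, 1.0, grad)
    x, grad = reciprocal(x, grad)
    return x, grad

# Same x, two different incoming gradients: both results are "correct",
# so a cache keyed on x alone would hand back the wrong one.
print(sigmoid(0.0, 1.0))  # (0.5, 0.25)
print(sigmoid(0.0, 2.0))  # (0.5, 0.5)
```

Keying the cache on the pair (x, grad) would be correct, but it would almost never produce a hit.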
Making things classier
Object-oriented programming is still the most popular paradigm these days, so why not try using a class to make a better implementation? We can make a class to represent one variable, with all the base operators implemented as methods. An example solution would look like this:

```python
import math

# Each method mutates the number in place, applying the chain rule to grad.
class Number:
    def __init__(self, value, grad):
        self.value = value
        self.grad = grad

    def multiply(self, a):
        self.value *= a
        self.grad *= a

    def add(self, a):
        self.value += a
        self.grad *= 1.0

    def exp(self):
        result = math.exp(self.value)
        self.value = result
        self.grad *= result

    def reciprocal(self):
        value = self.value
        self.value = 1.0 / value
        self.grad *= -1.0 / (value * value)
```
This kind of works, but what about composition? How are we supposed to write our sigmoid function? There are multiple ways to solve this issue, but in my opinion none of them really stand out as great. The most object-oriented solution would be inheritance, i.e. we would make a new Sigmoid class using this Number class as the base:
```python
class Sigmoid(Number):
    def sigmoid(self):
        self.multiply(-1.0)
        self.exp()
        self.add(1.0)
        self.reciprocal()
```
But this is just absurd! Are we supposed to make a new class every time we need a new composed function? We could add the method to the base Number class, but then shipping Number as library code would be an absolute nightmare. What about making a new function that isn’t a method at all? So something like this:
```python
def sigmoid(number: Number):
    number.multiply(-1.0)
    number.exp()
    number.add(1.0)
    number.reciprocal()
```
This works, but only as long as all our base operations are public methods.
Another inconvenience is that the syntax for calling a base function (number.exp()) is different from the syntax for calling a composed function (sigmoid(number)), which kind of defeats the whole purpose of composition, don’t you think?
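A short runnable sketch of that mismatch, using the Number class with two small liberties of mine: a default grad of 1.0 in the constructor, and math.exp in place of the hard-coded e. Note that both calls mutate their argument in place:

```python
import math

class Number:
    """A value paired with the gradient accumulated so far (mutated in place)."""

    def __init__(self, value, grad=1.0):  # default grad=1.0 is an assumption
        self.value = value
        self.grad = grad

    def multiply(self, a):
        self.value *= a
        self.grad *= a

    def add(self, a):
        self.value += a  # d/dx (x + a) = 1, so grad is unchanged

    def exp(self):
        self.value = math.exp(self.value)
        self.grad *= self.value

    def reciprocal(self):
        old = self.value
        self.value = 1.0 / old
        self.grad *= -1.0 / (old * old)


def sigmoid(number: Number):
    number.multiply(-1.0)
    number.exp()
    number.add(1.0)
    number.reciprocal()


n = Number(0.0)
n.exp()     # base operation: method-call syntax
m = Number(0.0)
sigmoid(m)  # composed operation: function-call syntax
print(n.value, n.grad)  # 1.0 1.0
print(m.value, m.grad)  # 0.5 0.25
```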
Back to basics
It feels like we’re just running around in circles. Our first implementation was actually pretty good, were it not for the global variable: the function signature was super clean, composition was a breeze, and memoization only broke the gradient part of the code, not the value part. So rather than tapping into a global variable and composing the gradients through state, why not start off assuming there is no gradient object at all? Then each function just returns its own local derivative alongside its value. For the previous functions, that looks like this:

```python
import math

# Each function returns (value, local derivative) instead of touching shared state.
def multiply(a):
    return lambda x: (a * x, a)

mult_neg_1 = multiply(-1.0)

def add(a):
    return lambda x: (a + x, 1.0)

add_1 = add(1.0)

def exp(x):
    value = math.exp(x)
    return value, value

def reciprocal(x):
    return 1.0 / x, -1.0 / (x * x)
```
Okay… but how do we compose these functions? The input is no longer the same type as the output! We’ve destroyed composability :(
Okay, it’s not that dire of a situation. We clearly see that the output is almost like the input, except it’s embellished in some way. Composability is no longer as simple as just shoving one function’s output into the next function’s input, but it’s close enough that we could very easily write a helper function to pretend like that’s what we did, right?
Let’s write a function that takes the embellished output of one of these functions, along with the next function (which expects only the value part). This helper will then “shove” the value part of the previous output into the next function and combine the two gradients the correct way (through multiplication, as the chain rule says). Let’s call it “bind”:

```python
def bind(val, foo):
    x, grad = val                   # unpack the embellished (value, gradient) pair
    new_x, new_grad = foo(x)        # feed only the value into the next function
    return new_x, grad * new_grad   # chain rule: multiply the derivatives
```
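As a quick sanity check, here is bind doing a single chain-rule step by hand: it threads the value of $-x$ at $x = 0$ into exp and multiplies the two local derivatives, which should give $\nabla e^{-x} = -e^{-x} = -1$ at $x = 0$ (math.exp again stands in for the hard-coded e):

```python
import math

def multiply(a):
    return lambda x: (a * x, a)

def exp(x):
    value = math.exp(x)
    return value, value

def bind(val, foo):
    x, grad = val
    new_x, new_grad = foo(x)
    return new_x, grad * new_grad

step1 = multiply(-1.0)(0.0)  # (-0.0, -1.0): the value of -x and its derivative
step2 = bind(step1, exp)     # feed -0.0 into exp, multiply the derivatives
print(step2)                 # (1.0, -1.0)
```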
So what’s the signature of this function, you might ask? Writing M a for a value of type a embellished with its gradient (in our Python code, a (value, grad) tuple), bind took an embellished M a and a function from a plain a to an embellished M b, then returned an M b. So in Haskell’s type signature format it would be

```haskell
bind :: M a -> (a -> M b) -> M b
```
We can now use this to create a new function that takes two of these embellished functions and composes them into a new one!

```python
def comp(foo, bar):
    # Run bar first, then foo: the embellished version of foo ∘ bar.
    return lambda x: bind(bar(x), foo)
```
which now has the function signature
```haskell
comp :: (b -> M c) -> (a -> M b) -> (a -> M c)
```
and now we can make our sigmoid function like so:
```python
def sigmoid(x):
    foo = comp(exp, mult_neg_1)
    foo = comp(add_1, foo)
    foo = comp(reciprocal, foo)
    return foo(x)
```
or, better yet, we can even write this function as a variable instead:
```python
sigmoid = comp(reciprocal,
               comp(add_1,
                    comp(exp, mult_neg_1)))
```
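For reference, here is everything from this section gathered into one runnable script. Two additions of mine are labeled as such: math.exp replaces the hard-coded 2.71828, and a hypothetical unit function plays the role of the identity from the proof (it returns its input with derivative 1). The loop compares the result against the closed form $\sigma(x)(1 - \sigma(x))$, and the last two lines check that unit behaves as an identity for comp on both sides:

```python
import math

# Elementary functions: each maps a plain value to a (value, derivative) pair.
def multiply(a):
    return lambda x: (a * x, a)

def add(a):
    return lambda x: (a + x, 1.0)

def exp(x):
    value = math.exp(x)
    return value, value

def reciprocal(x):
    return 1.0 / x, -1.0 / (x * x)

def bind(val, foo):
    # Feed the value into foo and combine the derivatives with the chain rule.
    x, grad = val
    new_x, new_grad = foo(x)
    return new_x, grad * new_grad

def comp(foo, bar):
    # Apply bar first, then foo.
    return lambda x: bind(bar(x), foo)

def unit(x):
    # Hypothetical name: the identity function, whose derivative is always 1.
    return x, 1.0

sigmoid = comp(reciprocal, comp(add(1.0), comp(exp, multiply(-1.0))))

for x in [-2.0, 0.0, 3.0]:
    value, grad = sigmoid(x)
    closed_form = value * (1.0 - value)  # sigma'(x) = sigma(x) * (1 - sigma(x))
    print(x, value, grad, closed_form)

# Identity laws: composing with unit on either side changes nothing.
print(comp(sigmoid, unit)(0.5) == sigmoid(0.5))
print(comp(unit, sigmoid)(0.5) == sigmoid(0.5))
```

Because unit contributes a factor of exactly 1.0 to bind’s product, the identity checks hold exactly rather than just approximately.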
A Monoid in the Category of EndoFunctors
I’d like to reflect on what we just did here, because Monads show up constantly in programming and rarely get taken advantage of. Did we really need the concept of the Monad to write a good AD library? Not really, no; indeed, none of the libraries I praised earlier use this functional approach, for performance reasons. Numerical computing is a special case, though; in most other settings the overhead of creating function objects would be negligible compared to the rest of the code [3].
So if the main libraries don’t do it this way, how do they do it? PyTorch went the OOP route by storing the gradient inside their torch.Tensor class. TensorFlow went a more functional route while still staying Pythonic; they use a context manager called tf.GradientTape, which watches certain variables inside that context and builds a computational graph. JAX and Flux take an even bigger step towards the functional paradigm. Rather than computing the gradients in real time like PyTorch does, or implicitly like TensorFlow does, JAX and Flux provide APIs to transform differentiable functions into gradient functions. The caveat, though, is that they implement AD in a more complex way that involves intermediate representations, operator fusing, special traced arrays and some more wild things.
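Our comp-based version already has that “function in, function out” flavor. Here is a minimal sketch with a hypothetical grad helper that only loosely mimics the shape of a gradient-transform API; it is not how JAX or Flux are implemented, and it reuses the functions from earlier (with math.exp in place of the hard-coded e):

```python
import math

def multiply(a):
    return lambda x: (a * x, a)

def add(a):
    return lambda x: (a + x, 1.0)

def exp(x):
    value = math.exp(x)
    return value, value

def reciprocal(x):
    return 1.0 / x, -1.0 / (x * x)

def bind(val, foo):
    x, grad = val
    new_x, new_grad = foo(x)
    return new_x, grad * new_grad

def comp(foo, bar):
    return lambda x: bind(bar(x), foo)

def grad(f):
    # Hypothetical helper: turn an embellished function into a plain
    # function that returns only the derivative part of its output.
    return lambda x: f(x)[1]

sigmoid = comp(reciprocal, comp(add(1.0), comp(exp, multiply(-1.0))))
dsigmoid = grad(sigmoid)
print(dsigmoid(0.0))  # 0.25
```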
To be fair to all these libraries, they are cross-pollinating with each other and converging on very similar APIs. PyTorch is adding a functional API, and TensorFlow switched to eager execution by default in its 2.0 release.
Hopefully this provides some insight into how AD works on a math level and why the popular APIs look the way they do :).
Footnotes
1. I’ve also written a blog post about differentiation, wink wink, hint nudge.
2. Okay, calling this blog post a formal proof is quite the overreach, but let’s be real here: were you honestly looking to read a 10-page rigorous proof with all the little technicalities ironed out? Yeah, that’s what I thought…
3. A great example where functional programming works really well is the Pandas library, which lets you perform almost every manipulation on a DataFrame through methods that accept lambdas, so you can chain all your manipulations sequentially without creating any intermediate variables.