I’ve noticed a few repeated patterns related to the for loop from code-base to code-base. In more performant systems languages like C, C++, Rust and so on, hand-writing these iterative algorithms is less problematic because they get translated to pretty good codegen anyway. In Python, however, these algorithms are too dynamic to be optimized and hence severely bog down your performance. Let’s dissect a few case studies to see some better alternatives.
Python’s dirty (not so) secret
People often think of a Python for loop as analogous to a C-style for loop, just “slower” for some reason. In reality, apart from the keyword, Python’s for loop has nothing in common with a C-style for loop. To understand why, we have to appreciate the fact that Python actually lets you write your own iterators in a minimal amount of code, using the magic methods __iter__ and __next__:
class Evens:
    def __iter__(self):
        # set up the iteration state
        self.num = 0
        return self

    def __next__(self):
        # hand back the current value, then advance by two
        _num = self.num
        self.num += 2
        return _num
By defining these magic methods (and yes, these are the only ones you strictly need!) we can now use instances of this class as the suffix of a for loop. Can you guess what the output of this loop is?
for even in Evens():
    print(even)
Of course, this program will print out even numbers and never halt, which would be the equivalent of this for loop in C:
for (int num = 0;; num += 2)
    printf("%d\n", num);
So far so good, right? Then why did I say that Python iterators have nothing in common with C loops? Well, let’s look back at our Evens class and try to add a stopping condition. I said earlier that there are no other magic methods associated with iterators (and I wasn’t lying), so how are we going to halt our loop? If you look at the Python documentation you might be surprised to find the following excerpt:
exception StopIteration: Raised by built-in function next() and an iterator’s __next__() method to signal that there are no further items produced by the iterator.
Ah, so if we want our iterator to have a stopping condition we’re going to need to raise an exception… [1] Is the reader starting to understand where Python’s additional overhead comes from?
Insert facepalm here
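To make the stopping condition concrete, here’s a sketch of a bounded variant of Evens; the BoundedEvens name and the limit parameter are my own additions:

class BoundedEvens:
    def __init__(self, limit):
        self.limit = limit

    def __iter__(self):
        self.num = 0
        return self

    def __next__(self):
        # signal exhaustion the only way Python allows: an exception
        if self.num >= self.limit:
            raise StopIteration
        _num = self.num
        self.num += 2
        return _num

for even in BoundedEvens(10):
    print(even)  # prints 0, 2, 4, 6, 8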
To be clear, when you program a for loop using an iterator, you’re essentially doing the equivalent of this:
the_iter = iter(old_list)
while True:
    try:
        x = next(the_iter)
    except StopIteration:
        break
    else:
        pass  # do something with x ...
And here we’ve exposed the for loop’s dirty little (not so) secret: every time a for loop runs to completion, we have to halt execution and catch a StopIteration exception! This makes for loops really slow in Python and can really hurt performance. [2]
Now that we’ve pulled off the mask, let’s look at some case studies and see if we can either avoid the for loop or mitigate its cost in some way.
For loops in action
For loops to accumulate
Say you have some collection of objects and you want to compute some accumulated value that depends on its elements. A basic example code snippet might look like this:
value = initial_value
for element in collection:
    value = accumulate(value, element)
This is probably one of, if not the, most popular use cases for for loops. In fact, it’s so common a pattern that Python added this functor to the standard library.
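In the standard library this lives as functools.reduce; here is a minimal sketch of the loop above as a single reduce call, with toy addition standing in for the generic accumulate:

from functools import reduce

collection = [1, 2, 3, 4]
# reduce threads the accumulated value through the collection,
# applying the two-argument function from left to right
value = reduce(lambda acc, element: acc + element, collection, 0)
print(value)  # 10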
Even further, if the accumulate function is algebraic addition, we can use the built-in function sum for even better performance:

value = sum(collection)
For some reason Python thought it necessary to implement this sum function but not to implement a multiply function, which would always be well-defined over an algebra by definition. The inner mathematician in me has shed a tear.
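For what it’s worth, the standard library has since gained math.prod (Python 3.8+), which fills exactly this gap:

import math

values = [1, 2, 3, 4]
# math.prod multiplies the elements together, the way sum adds them
product = math.prod(values)
print(product)  # 24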
For loops to map
If accumulate loops take the cake for most common pattern, then container initializers take home the silver medal. If you’re unfamiliar with map from functional programming, what I’m talking about looks something like this:
new_collection = []
for element in previous_collection:
    x = do_something(element)
    new_collection.append(x)
What’s going on here is that we’re mapping every element of previous_collection to some other element in new_collection, without overwriting the old collection. Another flavor of this pattern looks like this:
new_collection = []
for element in previous_collection:
    x = do_something(element)
    if predicate(x):
        new_collection.append(x)
This is really just a map paired with a filter, but I digress.
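In terms of the built-ins, that pattern reads like this; a sketch, with hypothetical stand-ins for the do_something and predicate helpers used above:

# hypothetical stand-ins for the helpers above
def do_something(x):
    return x * x

def predicate(x):
    return x % 2 == 0

previous_collection = range(10)
# filter keeps only the mapped values that satisfy the predicate
new_collection = list(filter(predicate, map(do_something, previous_collection)))
print(new_collection)  # [0, 4, 16, 36, 64]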
Python also supports comprehensions as part of the language, so you can write this code like so:
# Make a new list
new_collection = [do_something(x) for x in previous_collection]
# or also a new set
new_collection = {do_something(x) for x in previous_collection}
# or even a dictionary!
new_collection = {i: do_something(x) for i, x in enumerate(previous_collection)}
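The filtered flavor from above also fits in a comprehension with an if clause; a sketch (the walrus operator requires Python 3.8+ and just avoids calling do_something twice):

# keep only the mapped values that satisfy the predicate
new_collection = [y for x in previous_collection if predicate(y := do_something(x))]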
For loops to transform
I admit, this one is a bit less common in Python, mainly because most libraries provide more convenient methods for users. It’s still common enough to be worth mentioning, though: you modify the collection you’re iterating over by accessing it somehow. For a list type, this is programmatically easy to do:
for idx, x in enumerate(collection):
    collection[idx] = do_something(x)
Funnily enough, this use-case is harder to perform in Python than in C++, which lets you grab a reference to the element and modify it without going through an access function:

for (auto &x : collection)
    x = do_something(x);
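If the goal is rewriting a whole list in place, slice assignment is the usual Python idiom; a sketch with the same hypothetical do_something:

# rebuild the contents and assign through the full slice, so the
# original list object is mutated instead of the name being rebound
collection[:] = [do_something(x) for x in collection]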
Root causing the issue (and fixing it)
The main issue here (if you’re of the ilk concerned with performance and code cleanliness) is that these for loops are often overused, iterating over thousands of elements or more. We can often get more readable and performant code by using a few alternatives.
The map function
Python provides a built-in function called map, which takes as input a function and an iterator object (or an object that can be converted to an iterator). In return it outputs a new iterator object which, when stepped through, will apply your function to each element before giving it back to you. Here’s a simple example:
def kernel(x):
    return x + 1

N = 10  # any length will do
old_collection = range(N)

# no calls to kernel yet!
new_collection = map(kernel, old_collection)

# now we made a list
new_collection = list(new_collection)
3rd-party vectorization
map will often perform better than an equivalent for loop because it’s implemented in C code and can make a few more optimizations. However, for numerical programming, map is simply not going to cut it for performance. This is why libraries like Numba, JAX and NumPy provide better alternatives for dispatching math kernels to collections of data. Let’s use JAX and NumPy as examples: JAX provides the vmap function, and likewise NumPy provides vectorize:
import numpy
import jax
foo_numpy = numpy.vectorize(do_something)
foo_jax = jax.vmap(do_something)
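A quick usage sketch of the two wrappers, assuming NumPy and JAX are installed and that do_something is our x + 1 toy kernel:

import numpy as np
import jax
import jax.numpy as jnp

def do_something(x):
    return x + 1

foo_numpy = np.vectorize(do_something)
foo_jax = jax.vmap(do_something)

xs = np.arange(5)
print(foo_numpy(xs))             # [1 2 3 4 5]
print(foo_jax(jnp.arange(5.0)))  # [1. 2. 3. 4. 5.], mapped over the leading axis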
The wrapped functions have the same type signature as the originals, except that they now accept an ndarray and a DeviceArray with an extra axis, respectively. Even Pandas allows you to map a function over a collection of data, with both map and apply:
new_series = old_series.apply(lambda x: x + 1)
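The map method looks essentially the same for an elementwise callable; a quick sketch:

import pandas as pd

old_series = pd.Series([1, 2, 3])
# Series.map applies the callable to each element, like apply does here
new_series = old_series.map(lambda x: x + 1)
print(new_series.tolist())  # [2, 3, 4]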
Performance
To test the differences we will use a very simple example kernel which just adds 1:

def kernel(x):
    return x + 1
Running these different variations on a list of length 10 million with this simple kernel, we get the following results:
| Function Call | Time (s) |
|---|---|
| for loop on list | 0.646 |
| map on list | 0.372 |
| list comprehension on list | 0.496 |
| for loop on ndarray | 0.957 |
| vectorize on ndarray | 0.765 |
| vmap on DeviceArray | 0.0125 |
What we see is that for larger arrays we can get up to double the performance of a for loop just by using map! The list comprehension also beat the plain loop, landing somewhere between the two. What’s really surprising is how well the JAX implementation of NumPy’s API scaled the operation.
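For reference, here is a sketch of how such a comparison could be timed with timeit; the names and sizes are my own choices, and the exact figures will vary by machine:

import timeit

import jax
import jax.numpy as jnp

def kernel(x):
    return x + 1

N = 10_000_000
data_list = list(range(N))
data_jax = jnp.arange(N)

print(timeit.timeit(lambda: [kernel(x) for x in data_list], number=1))
print(timeit.timeit(lambda: list(map(kernel, data_list)), number=1))

vkernel = jax.vmap(kernel)
vkernel(data_jax).block_until_ready()  # warm-up call, excludes trace/compile time
print(timeit.timeit(lambda: vkernel(data_jax).block_until_ready(), number=1))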
Footnotes
1. There’s a very relevant exception (haha, punny) to this rule, and that’s the built-in sequence types, which include the list type, the tuple type and the range type. None of these things implement the iterator magic methods, but they can all be used in for loops and other iterative language patterns as if they were iterators (but with less runtime cost). ↩
2. Falling off an array has a funny history attached to it; Bryan Cantrill from Sun Microsystems recalls a time when Twitter was spending over 400ms per request because of an abuse of the Ruby language. ↩