Beware of Taking Float16s From the Cookie Jar
As tempting as it may be, always stick to what your hardware supports!
Part of being an industry practitioner is keeping up to date on current topics and research. While there are many cool and exciting new frontiers on the horizon, the hardware-software interface has been making exceptional strides. A physics paper in particular caught my eye for suggesting the use of 16-bit floating point (FP) values in stochastic models. The reasoning is that stochastic models are already random by nature, so they should be numerically robust against FP round-off error.
I decided to try this idea out on a real example by replicating Burger's paper on modeling El Niño as a stochastic oscillator1. I used Julia to make sure the code was being compiled down and optimized for my hardware, then I ran the same code with three different data types. Here are the results:
Type | Memory | Runtime
---|---|---
Float64 | 10.44 GiB | 2.431 s
Float32 | 5.22 GiB | 1.738 s
Float16 | 2.61 GiB | 31.14 s
Five of these six numbers make a lot of sense, but that sixth one is pretty surprising, isn't it? The other five strongly imply that halving the size of the data type halves the memory footprint and roughly halves the runtime, but explaining why that sixth number isn't a fluke will require some digging, so let's get started.2
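For concreteness, the experiment boils down to something like the sketch below. This is a minimal stand-in (a trivial random walk, not Burger's actual oscillator model); the only point is that the floating-point type is a parameter we can swap out:

# Minimal stand-in for the experiment: the same simulation code,
# parameterized over the floating-point type T.
function simulate(::Type{T}, n) where {T<:AbstractFloat}
    x = zeros(T, n)
    for i in 2:n
        # a trivial stochastic update; the real model is an oscillator
        x[i] = x[i-1] + rand(T) - T(0.5)
    end
    return x
end

for T in (Float64, Float32, Float16)
    simulate(T, 10)      # warm-up call so compilation isn't timed
    @time simulate(T, 10^7)
end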
A basic example
Let’s simplify life as much as possible and test-compile the simplest function I can think of that deals with floats - addition:
function foo(T::DataType, x, y)
    a = convert(T, x)
    b = convert(T, y)
    a + b
end
Note that this code is not like C++'s reinterpret_cast - it isn't reinterpreting the underlying bits as a different data type, but rather converting the value wholesale.
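To see the distinction concretely, Julia also has a reinterpret function that gives the bit-level behavior: convert produces the same value in a new type, while reinterpret reuses the same bits (here, the bit pattern 0x00000001 read as the smallest subnormal float):

julia> convert(Float32, 1)
1.0f0

julia> reinterpret(Float32, Int32(1))
1.0f-45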
Our guinea pig will be a Ryzen 2700 in a desktop computer, standing in for the modern desktop. Compiling the code on this machine using 64-bit FPs, we get a truly lovely result:
julia> @code_native foo(Float64, 1, 2)
vcvtsi2sd %rdi, %xmm0, %xmm0
vcvtsi2sd %rsi, %xmm1, %xmm1
vaddsd %xmm1, %xmm0, %xmm0
retq
Okay, for people who've never seen assembly before this may look confusing, but I promise that once you know what the acronyms mean it will be (very) clear that Julia and LLVM did a fantastic job optimizing our code. First off, what the hell is vcvtsi2sd? It's an acronym, so let's expand it piece by piece. The first v stands for vector and indicates that we are using the SIMD extension to the x86_64 instruction set. This means that Julia thought it best to use the vectorized instructions, even for non-array inputs. The next piece is cvt, which stands for convert. The thing being converted comes from the rdi register, a general-purpose register currently holding the variable x. What is it being converted to? Well, let's expand the rest of the acronym: the si is for scalar integer, the 2 is self-explanatory, and the sd ending is scalar double. Ah! So it's converting the integer to a double (i.e. a 64-bit FP), just like we told it to, placing the result in xmm0, one of the XMM registers used by the SIMD extension. The next line is a copy of the first but for the y variable, and finally we reach vaddsd, which should hopefully be clear by now is an acronym for "vector add scalar double." It writes the sum back into xmm0 (the register that held x, and the register the return value is expected in) as a simple optimization, then finally returns with retq.
Clear as mud, right?
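Here is the whole decoding again in one place:

vcvtsi2sd  =  v (vector/SIMD) + cvt (convert) + si (scalar integer) + 2 ("to") + sd (scalar double)
vaddsd     =  v (vector/SIMD) + add + sd (scalar double)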
We can make the assembly even smaller by using floats as inputs to begin with:
julia> @code_native foo(Float64, 1., 2.)
vaddsd %xmm1, %xmm0, %xmm0
retq
By changing the inputs from 1 and 2 to 1.0 and 2.0 we've cut out the two conversion instructions and removed literally all the overhead we can at this point. In other words, Julia allowed us to write a clean function in a high-level language and gave us type safety and SIMD instructions at zero overhead cost. Pretty nifty if you ask me.
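As a quick sanity check (my own REPL session, not part of the original benchmark), both versions of course compute the same thing; Julia simply compiles a separate, fully optimized method instance for each combination of concrete argument types:

julia> foo(Float64, 1, 2)       # the four-instruction version with conversions
3.0

julia> foo(Float64, 1.0, 2.0)   # the two-instruction version
3.0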
Okay, so 64-bit FP passes the test. What about 32-bit FPs? Well, rather than seeing sd, or scalar double, we're going to see a new suffix: ss, or "scalar single":
julia> @code_native foo(Float32, 1, 2)
vcvtsi2ss %rdi, %xmm0, %xmm0
vcvtsi2ss %rsi, %xmm1, %xmm1
vaddss %xmm1, %xmm0, %xmm0
retq
and likewise, if we hand it Float32 values as inputs to begin with, Julia removes the conversions since it doesn't need them:
julia> @code_native foo(Float32, 1.0f0, 2.0f0)
vaddss %xmm1, %xmm0, %xmm0
retq
This is all going quite splendidly, if I do say so myself. Well, if going from doubles to singles halves the memory use and the runtime, then we should get another halving by going to 16-bit floats, right? Let's check out what our Ryzen CPU has to say about that…
julia> @code_native foo(Float16, 1, 2)
pushq %rax
vcvtsi2ss %rdi, %xmm0, %xmm0
movabsq $139889297183744, %rdx
movabsq $139889297184256, %r8
vmovd %xmm0, %ecx
movq %rcx, %rax
andl $8388607, %ecx # imm = 0x7FFFFF
shrq $23, %rax
orl $8388608, %ecx # imm = 0x800000
movzbl (%rax,%rdx), %edx
shrxl %edx, %ecx, %edi
andl $1023, %edi # imm = 0x3FF
addw (%r8,%rax,2), %di
movl $1, %r8d
movl %edi, %eax
andl $31744, %eax # imm = 0x7C00
cmpl $31744, %eax # imm = 0x7C00
je L157
decq %rdx
cmpq $31, %rdx
ja L157
shlxl %edx, %r8d, %eax
andl %ecx, %eax
je L157
movw $1, %ax
testb $1, %dil
jne L150
movq $-1, %rax
cmpq $63, %rdx
shlxq %rdx, %rax, %rax
movl $16777215, %edx # imm = 0xFFFFFF
notl %eax
cmovbel %eax, %edx
xorl %eax, %eax
testl %ecx, %edx
setne %al
L150:
movzwl %ax, %eax
addl %edi, %eax
movl %eax, %edi
L157:
vcvtsi2ss %rsi, %xmm1, %xmm0
movabsq $.rodata, %rdx
movabsq $139889297182720, %r9
vmovd %xmm0, %ecx
movq %rcx, %rax
andl $8388607, %ecx # imm = 0x7FFFFF
shrq $23, %rax
orl $8388608, %ecx # imm = 0x800000
movzbl (%rax,%rdx), %edx
shrxl %edx, %ecx, %esi
andl $1023, %esi # imm = 0x3FF
addw (%r9,%rax,2), %si
movl %esi, %eax
andl $31744, %eax # imm = 0x7C00
cmpl $31744, %eax # imm = 0x7C00
je L307
decq %rdx
cmpq $31, %rdx
ja L307
shlxl %edx, %r8d, %eax
andl %ecx, %eax
je L307
movw $1, %ax
testb $1, %sil
jne L300
movq $-1, %rax
cmpq $63, %rdx
shlxq %rdx, %rax, %rax
movl $16777215, %edx # imm = 0xFFFFFF
notl %eax
cmovbel %eax, %edx
xorl %eax, %eax
testl %ecx, %edx
setne %al
L300:
movzwl %ax, %eax
addl %esi, %eax
movl %eax, %esi
L307:
movabsq $"+", %rax
callq *%rax
popq %rcx
retq
Ouch. What the hell just happened? It seems we've passed the honeymoon phase: the CPU took the kids, only to come back with divorce papers.
If you sift through the much more cryptic result of using Float16, you will notice that most of this code is actually error-handling code. Why is it doing error handling, you might ask? Because it's downcasting the result to a smaller data type and needs to make sure the result can actually be represented in that smaller type. You will notice a lot of comparisons (cmp), moves (movq), and conditional branches (je, jne), all of which are a direct result of Julia not being able to tell ahead of time that the result will fit in the Float16 data type (hence it has to check at runtime rather than compile time).
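To see why the code balloons, here is roughly what all that machine code is doing, written back out in Julia. This is a simplified sketch of a software Float32-to-Float16 downcast (it truncates instead of rounding and ignores NaN and subnormals); it is not Julia's actual implementation:

# Sketch of a software Float32 -> Float16 downcast: the masking,
# shifting, and comparing below is the same work the assembly
# above is doing, one scalar at a time.
function to_float16(x::Float32)
    bits = reinterpret(UInt32, x)
    sign = UInt16((bits >> 31) << 15)
    expo = Int((bits >> 23) & 0xff) - 127 + 15            # rebias the exponent
    frac = bits & 0x007fffff                              # mask off the mantissa
    expo >= 31 && return reinterpret(Float16, sign | 0x7c00)  # overflow -> Inf (NaN ignored in this sketch)
    expo <= 0  && return reinterpret(Float16, sign)            # underflow -> signed zero (no subnormals)
    bits16 = sign | (UInt16(expo) << 10) | UInt16(frac >> 13)  # truncate the mantissa
    return reinterpret(Float16, bits16)
end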
We don't even need to be hardware experts to know that this will run like shit compared to the last two versions - and that's exactly what I saw in my stochastic oscillator code. We do still get the expected improvement in memory footprint, if that matters to you.
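The moral: before taking Float16 (or any narrow type) into a hot loop, check what your own hardware and compiler actually do with it. A quick microbenchmark is enough; assuming you have BenchmarkTools.jl installed, something like this will reveal the cliff immediately:

using BenchmarkTools

a32, b32 = rand(Float32, 10^6), rand(Float32, 10^6)
a16, b16 = Float16.(a32), Float16.(b32)

@btime $a32 .+ $b32;  # native single-precision adds
@btime $a16 .+ $b16;  # software conversion on CPUs without native Float16 arithmetic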