When somebody offers to compile your Python code, exactly what kind of mischief are you getting yourself into? What diabolical schemes does a just-in-time compiler enact to transmute sluggish Python code into something speedier?
Toward the end of my last post, I mentioned that I'm working on a library called Parakeet, which accelerates numerical Python. In this post, I'm going to illuminate the mysterious inner workings of a just-in-time compiler by following a function through its various stages of existence within Parakeet.
Caveat: Parakeet isn't finished, and it's awkward to write so extensively about something I don't yet want anyone using. Nonetheless, Parakeet is the compiler I know best and its relatively simple design will hopefully allow me to convey a general sketch of how JITs work.
The function we're going to radically rearrange today is count_thresh, which counts the number of elements in an array that are less than a given threshold.
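Here is a minimal plain-Python sketch of count_thresh. The parameter names values and thresh are assumptions.

```python
def count_thresh(values, thresh):
    """Count how many elements of `values` are strictly less than `thresh`."""
    n = 0
    for elt in values:
        # The comparison yields a bool; adding it to an int counts True as 1.
        n += elt < thresh
    return n
```

Adding the result of elt < thresh directly to n works because Python's bool is a subclass of int.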
This little ditty of looping and summation is simple enough that I hope its relationship to the code we later generate will stay evident. It's not, however, an entirely contrived computation: if you were to throw in an array of class labels and a dash more complexity, you would soon have the core of a decision tree learning algorithm.
Compared with the wild menagerie of run-time compilation techniques that have developed over the past decade, Parakeet is a relatively modest function-specializing compiler. If you want Parakeet to compile a particular function, wrap that function with the @jit decorator.
Like this:
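This is a minimal sketch that assumes Parakeet's jit decorator can be imported from the parakeet package. The identity fallback is a hypothetical stand-in so the snippet still runs where Parakeet isn't installed. In that case the function simply runs as ordinary Python.

```python
try:
    # Use Parakeet's decorator if the library is available.
    from parakeet import jit
except ImportError:
    # Hypothetical stand-in: an identity decorator that returns the
    # function unchanged, so the example runs without Parakeet.
    def jit(fn):
        return fn


@jit
def count_thresh(values, thresh):
    n = 0
    for elt in values:
        n += elt < thresh
    return n
```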
The job of @jit is to intercept calls into the wrapped function and then to initiate the following chain of events:
- Translate the function into an untyped representation, from which we'll later derive multiple type specializations.
- Specialize the untyped function for any argument types which get passed in.
- Optimize the merciless heck out of the typed code, and translate abstractions such as tuples and n-dimensional arrays into simple heap-allocated structures with low-level access code.
- Translate the optimized and lowered code into LLVM, and let someone else worry about lower-level optimizations and how to generate architecture-specific native code.
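The first two steps can be made more concrete with a toy model of the dispatch logic in pure Python. Everything in it is hypothetical rather than Parakeet's code, including specializing_jit, specialize and the cache attribute. The "compilation" step just returns the original function. The only point is that each new combination of argument types triggers one specialization, and later calls with the same types reuse it.

```python
import functools


def specializing_jit(fn):
    """Toy model of a specializing JIT (hypothetical, not Parakeet's code)."""
    cache = {}  # maps a tuple of argument types to a specialized function

    def specialize(arg_types):
        # A real compiler would infer types, optimize, lower and emit native
        # code for these argument types here; this sketch returns fn as-is.
        return fn

    @functools.wraps(fn)
    def wrapper(*args):
        key = tuple(type(a) for a in args)
        if key not in cache:
            cache[key] = specialize(key)  # first call with these types
        return cache[key](*args)

    wrapper.cache = cache  # exposed so the specializations can be inspected
    return wrapper


@specializing_jit
def count_thresh(values, thresh):
    n = 0
    for elt in values:
        n += elt < thresh
    return n


count_thresh([0.5, 1.5], 1.0)  # (list, float): creates a specialization
count_thresh([2.5, 0.1], 1.0)  # same types: reuses the cached one
count_thresh((1, 2, 3), 2)     # (tuple, int): a second specialization
```

Keying on type(a) is a simplification. A real specializer would also need to distinguish an array's element type and number of dimensions.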
I can just stick that decorator atop any Python function and it will magically run faster? Great! I'll paste @jit all over my code and my Python performance problems will be solved!
Easy with those decorators! Parakeet is not a general-purpose compiler for all of Python. Parakeet only supports a handful of Python's data types: numbers, tuples, slices, and NumPy arrays.
To manipulate these values, Parakeet lets you use any of the usual math and logic operators, along with some, but not all, of the built-in functions. Functions such as range are compiled to deviate from their usual behavior: in Python their result would be a list, but in Parakeet such functions create NumPy arrays.
If your performance bottleneck doesn't fit neatly into Parakeet's restrictive universe, then you might benefit from a faster Python implementation, or alternatively you could outsource some of your functionality to native code via Cython.
Let's continue, following count_thresh on its inexorable march toward efficiency.
From Python into Parakeet
When trying to extract an executable representation of a Python function, we face a choice between using its syntax tree:
...or the lower-level bytecode which the Python interpreter actually executes:
Neither is ideal for program analysis and transformation, but since the bytecode is littered with distracting stack manipulation and discards some of the higher-level language constructs, let's start with a syntax tree and quickly slip into something a little more domain-specific.
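Both options are easy to inspect with the standard library's ast and dis modules. This rough sketch uses a plain-Python count_thresh as the subject. The exact dump format differs between Python versions.

```python
import ast
import dis

SOURCE = """
def count_thresh(values, thresh):
    n = 0
    for elt in values:
        n += elt < thresh
    return n
"""

# Option 1: the syntax tree keeps high-level constructs such as the
# for-loop and the augmented assignment intact.
tree = ast.parse(SOURCE)
print(ast.dump(tree.body[0]))

# Option 2: the bytecode the interpreter executes, expressed as stack
# manipulation (loads, stores, jumps) instead of nested expressions.
namespace = {}
exec(compile(SOURCE, "<count_thresh>", "exec"), namespace)
dis.dis(namespace["count_thresh"])
```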
Untyped Representation
When a function is handed over from Python into Parakeet, it is translated into a form that is mostly similar to an ordinary Python syntax tree. There are, however, a few key differences:
- For-loops must traverse numeric ranges, so for x in xs gets translated into for i in range(len(xs)): x = xs[i]
- There is a suspicious-looking phi expression at the top of the loop. What is that thing? Does it have anything to do with the name of this site?
Looking even closer at Parakeet's untyped version of count_thresh, you'll notice the variable n has been split apart into three distinct names: n, n2 and n_loop. What can account for such triplicative witchery?
Calm your agitation, dear reader. You're looking at a variant of Static Single Assignment form. I'll write about SSA in more detail later, but for now the most important things to know about it are:
- Every distinct assignment to a variable in the original program becomes the creation of a distinct variable. Reminiscent of functional programming, no?
- At a point in the program where control flow could have come from multiple places (such as the top of a loop), we explicitly denote the possible sources of a variable's value using a φ-node.
- In exchange for all these variable name gymnastics we get a tremendous simplification in the onerous task of writing program analyses. It may not be immediately obvious why, but this post is already too long, so trust me for now.
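For intuition, here is a hand-written Python approximation of the SSA version of count_thresh. It reuses the names n, n2 and n_loop from above and the range-based loop form described earlier. It is an illustration, not Parakeet's actual output. Python has no φ-nodes, so the sketch emulates n_loop = phi(n, n2) with one copy per incoming edge. Replacing a φ-node with copies along its incoming edges is also roughly how compilers translate out of SSA form.

```python
def count_thresh_ssa(values, thresh):
    n = 0
    # Emulating n_loop = phi(n, n2): this copy stands for the edge
    # that enters the loop from above...
    n_loop = n
    for i in range(len(values)):
        elt = values[i]
        # Each assignment gets a fresh name, so "n += ..." becomes n2.
        n2 = n_loop + (elt < thresh)
        # ...and this copy stands for the back edge from the loop body.
        n_loop = n2
    return n_loop
```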
Another difference from Python is that Parakeet's representation treats many array operations as first-class constructs. For example, in ordinary Python len is a built-in library function, whereas in Parakeet it's actually part of the language syntax and thus can be analyzed with higher-level knowledge of its behavior. This is particularly useful for inferring the shapes of intermediate array values.
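This toy sketch shows why that helps. It assumes a made-up expression tree with Var, Len and Range nodes, and none of these names come from Parakeet. Because len is a node the compiler understands, the shape of range(len(xs)) can be derived from the shape of xs without running anything. Shapes here are tuples of symbolic dimensions, and Var is assumed to name an input array.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str  # refers to an input array (a simplification)


@dataclass(frozen=True)
class Len:
    arg: object


@dataclass(frozen=True)
class Range:
    stop: object


def infer(expr, shapes):
    """Return ("array", shape) or ("scalar", symbolic value) for expr.

    `shapes` maps each input array's name to a tuple of dimensions.
    """
    if isinstance(expr, Var):
        return ("array", shapes[expr.name])
    if isinstance(expr, Len):
        kind, shape = infer(expr.arg, shapes)
        if kind != "array":
            raise TypeError("len expects an array")
        # len(x) is x's first dimension; no code has to run to know it.
        return ("scalar", shape[0])
    if isinstance(expr, Range):
        kind, stop = infer(expr.stop, shapes)
        if kind != "scalar":
            raise TypeError("range expects a scalar")
        # In Parakeet, range produces a 1-D array with `stop` elements.
        return ("array", (stop,))
    raise TypeError(f"unknown expression: {expr!r}")


print(infer(Range(Len(Var("xs"))), {"xs": ("N", "M")}))  # ('array', ('N',))
```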
Type-specialized Representation
When you call an untyped Parakeet function, it gets cloned for each distinct set of input types. The types of the other (non-input) variables are then inferred and the body of the function is rewritten to insert casts wherever necessary.
In the case of count_thresh, observe that the function has been specialized for two inputs of type array1(float64) and float64, and that its return type is known to be int64. Furthermore, the boolean intermediate value produced by checking whether an element is less than the threshold is cast to int64 before getting added to n2.
If you use a variable in a way that defeats type inference (for example, by treating it sometimes as an array and other times as a scalar), then Parakeet gives up on your code and raises an error.
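The following toy type inferencer is written for count_thresh alone and mimics both behaviors under simplified rules. It uses a three-step scalar lattice (bool < int64 < float64) and spells types as strings like array1(float64). All names are hypothetical. It infers int64 for the counter, records the bool-to-int64 cast, and raises TypeError when an array and a scalar would have to be unified.

```python
# Toy scalar types ordered from narrowest to widest (an assumption).
RANK = {"bool": 0, "int64": 1, "float64": 2}


def is_array(t):
    return t.startswith("array")


def join(t1, t2):
    """Smallest type both t1 and t2 can be cast to, under toy rules."""
    if is_array(t1) or is_array(t2):
        if t1 != t2:
            # Mixing arrays and scalars defeats inference: give up.
            raise TypeError(f"cannot unify {t1} with {t2}")
        return t1
    return t1 if RANK[t1] >= RANK[t2] else t2


def elt_type(t):
    """Element type of an array type, e.g. "array1(float64)" -> "float64"."""
    if not is_array(t):
        raise TypeError(f"expected an array type, got {t}")
    return t[t.index("(") + 1 : -1]


def specialize_count_thresh(values_t, thresh_t):
    """Infer variable types of count_thresh; return (types, casts)."""
    types = {"values": values_t, "thresh": thresh_t}
    types["n"] = "int64"               # n = 0
    types["elt"] = elt_type(values_t)  # elt = values[i]
    join(types["elt"], thresh_t)       # elt < thresh needs comparable operands
    types["cmp"] = "bool"              # cmp = elt < thresh
    types["n_loop"] = types["n"]       # phi(n, n2): start from the entry value
    types["n2"] = join(types["n_loop"], types["cmp"])  # n2 = n_loop + cmp
    # Fold in the back edge. A real inferencer would iterate until nothing
    # changes; for this function one pass already reaches the fixed point.
    types["n_loop"] = join(types["n"], types["n2"])
    casts = []
    if types["cmp"] != types["n2"]:
        casts.append(("cmp", types["cmp"], types["n2"]))
    return types, casts
```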
Optimize mercilessly!
Type specialization already gives us a big performance boost by enabling the use of an unboxed representation for numbers. Adding two floats stored in registers is orders of magnitude faster than calling __add__ on PyFloatObjects.
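The cost of boxing is visible even from Python. This CPython-specific illustration uses the standard struct and sys modules. The exact object size depends on the interpreter build.

```python
import struct
import sys

# An unboxed float64 is just 8 bytes, small enough to live in a register.
raw_size = struct.calcsize("d")

# A CPython float is a heap object: a reference count and a type pointer
# sit in front of the 8-byte value.
boxed_size = sys.getsizeof(1.0)

print(raw_size, boxed_size)  # typically "8 24" on 64-bit CPython

# Adding boxed floats goes through method dispatch and produces a new
# float object for the result.
total = (1.5).__add__(2.25)
```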
However, if all Parakeet did was specialize your code, it would still be significantly slower than programming in a lower-level language. The compiler needs to exert more effort to contort and transform array-oriented Python code into the lean, mean loops you would expect to get from a C compiler. Parakeet attacks sluggish code with the usual battery of standard optimizations, such as constant propagation, common sub-expression elimination, and loop-invariant code motion. Furthermore, to mitigate the abstraction cost of array expressions such as 0.5*vec1 + 0.5*vec2, Parakeet fuses array operators, which then exposes further opportunities for optimization.
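To see what fusion buys, compare two pure-Python renderings of 0.5*vec1 + 0.5*vec2, with lists standing in for arrays. The function names are hypothetical. The unfused version materializes a temporary for every operator, while the fused one computes each output element in a single pass.

```python
def unfused(vec1, vec2):
    # Each operator produces a full temporary array.
    t1 = [0.5 * x for x in vec1]
    t2 = [0.5 * y for y in vec2]
    return [a + b for a, b in zip(t1, t2)]


def fused(vec1, vec2):
    # The three elementwise operators collapse into one loop body:
    # a single pass over the inputs and no temporaries.
    return [0.5 * x + 0.5 * y for x, y in zip(vec1, vec2)]
```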
In this case, however, the computation is simple enough that only a few optimizations can meaningfully change it. I turned off loop unrolling for this post since it significantly expands the size of the produced code.
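For reference, unrolling by a factor of 4 looks roughly like the following hand-written Python sketch. The unroll factor and names are my choice. The loop body is duplicated and a cleanup loop handles the leftover elements, which is why the generated code grows.

```python
def count_thresh_unrolled(values, thresh):
    n = 0
    length = len(values)
    main = length - length % 4  # largest multiple of 4 that fits
    i = 0
    while i < main:
        # Four copies of the loop body per iteration: fewer loop-control
        # branches, but roughly four times as much code.
        n += values[i] < thresh
        n += values[i + 1] < thresh
        n += values[i + 2] < thresh
        n += values[i + 3] < thresh
        i += 4
    # Cleanup loop for the final length % 4 elements.
    while i < length:
        n += values[i] < thresh
        i += 1
    return n
```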
In addition to rewriting code for performance gain, Parakeet also "lowers" higher-level constructs such as tuples and arrays into more primitive concepts. Notice that the optimized count_thresh does not directly index into n-dimensional arrays, but rather explicitly computes offsets and indexes directly into an array's data pointer. Lowering complex language constructs simplifies the next stage of program transformation: the escape from Parakeet into LLVM.
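This sketch shows what lowering means for count_thresh. A Python list stands in for the raw data buffer, and the array is described by an offset, a stride and a length. All parameter names are hypothetical. NumPy's own strides are measured in bytes, so lowered code working from NumPy's metadata has to account for the element size somewhere.

```python
def count_thresh_lowered(data, offset, stride, length, thresh):
    """count_thresh over a strided 1-D view into a flat buffer.

    data:   flat sequence standing in for the array's data pointer
    offset: position of the first element within `data`
    stride: distance between consecutive elements, in elements
    length: number of elements in the view (the array's shape[0])
    """
    n = 0
    for i in range(length):
        elt = data[offset + i * stride]  # explicit address arithmetic
        n += int(elt < thresh)           # explicit bool-to-int cast
    return n
```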
LLVM
LLVM is a delightfully well-engineered compiler toolkit that comes with a powerful arsenal of optimizations and generates native code for a variety of platforms. To get LLVM to finish the job of compiling count_thresh, we need to translate it into LLVM's own assembly language. Once the Parakeet representation has been typed, optimized, and stripped clean of abstractions, the translation to LLVM turns out to be surprisingly easy. Sure, there's some plumbing work to map between Parakeet's types and LLVM's type system, but that's probably the most straightforward part of this whole pipeline.
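The type-mapping plumbing can be sketched as a small function that turns type names like those used above into LLVM type syntax. The array layout here is a hypothetical choice: a data pointer followed by shape and stride arrays. It is not necessarily what Parakeet uses. The sketch also uses the older typed-pointer syntax (double*) rather than the opaque ptr type found in recent LLVM versions.

```python
# Scalar types map directly onto LLVM's first-class types.
SCALAR_TO_LLVM = {
    "bool": "i1",
    "int32": "i32",
    "int64": "i64",
    "float32": "float",
    "float64": "double",
}


def to_llvm_type(t):
    """Map a type name such as "float64" or "array1(float64)" to LLVM syntax."""
    if t in SCALAR_TO_LLVM:
        return SCALAR_TO_LLVM[t]
    if t.startswith("array"):
        open_paren = t.index("(")
        rank = int(t[len("array"):open_paren])
        elt = to_llvm_type(t[open_paren + 1 : -1])
        # Hypothetical array layout: data pointer, then shape and strides.
        return f"{{ {elt}*, [{rank} x i64], [{rank} x i64] }}"
    raise TypeError(f"no LLVM mapping for {t}")


print(to_llvm_type("array1(float64)"))  # { double*, [1 x i64], [1 x i64] }
```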
Generated Assembly
Once we pass the torch to LLVM, Parakeet's job is mostly done. LLVM chisels the code we've given it with its bevy of optimization passes. Once every last inefficiency has been ferreted out and exterminated, LLVM uses a platform-specific back-end to translate from its assembly language into native instructions. And thus, at last, we arrive at native code:
Reading x86-64 assembly is tedious, so I don't expect you to make sense of this code dump. But do notice that we end up with the same number of machine instructions as we originally had Python bytecodes. It's safe to suspect that the performance might have somewhat improved.
How much faster is it?
In addition to benchmarking against the Python interpreter (an unfair comparison with a predictable outcome), let's also see how Parakeet stacks up against an equivalent function implemented using NumPy primitives:
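A NumPy equivalent along these lines compares the whole array against the threshold at once and sums the resulting booleans. The name numpy_count_thresh is hypothetical.

```python
import numpy as np


def numpy_count_thresh(values, thresh):
    # values < thresh builds a boolean array; summing it counts the Trues.
    return np.sum(values < thresh)
```

np.count_nonzero(values < thresh) would express the same count.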
I gave the NumPy, Python, and Parakeet versions of count_thresh 1 million randomly generated inputs and averaged the time they took to complete over 5 runs.
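A minimal timing harness in the same spirit might look like the following. The timer (time.perf_counter), the uniform [0, 1) inputs and the 0.5 threshold are all assumptions. A generator expression stands in for the plain-Python count_thresh.

```python
import random
import time


def average_runtime(fn, runs=5):
    """Average wall-clock time of calling fn() over `runs` calls, in seconds."""
    total = 0.0
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        total += time.perf_counter() - start
    return total / runs


# One million uniform random floats in [0, 1).
values = [random.random() for _ in range(1_000_000)]

# Plain-Python baseline: count the elements below the threshold.
python_time = average_runtime(lambda: sum(x < 0.5 for x in values))
print(f"Python: {python_time:.4f} s")
```

To fill in the other columns, pass a lambda that calls the NumPy version on an array or the @jit-decorated version. For the latter, one warm-up call before timing keeps the one-time specialization cost out of the average.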
| Python | NumPy | Parakeet |
|---|---|---|
| 3.7205 | 0.0036 | 0.0025 |

Execution time in seconds
Not bad: Parakeet is about 1500 times faster than vanilla Python and even edges out NumPy, running roughly 1.4 times faster. Still, that NumPy code is tantalizingly more compact than the explicit loop we've been working with throughout this post.
Can Parakeet compile something that looks more like the NumPy version of count_thresh? In fact, yes: you can (and are encouraged to) write code in a high-level array-oriented style with Parakeet. However, an explanation of how such code gets compiled (and parallelized) will have to wait until I discuss Parakeet's data parallel operators in the next post.