Do Not Taunt Happy Fun Branch Predictor

I've been writing a lot of AArch64 assembly, for reasons.

I recently came up with a "clever" idea to eliminate one jump from an inner loop, and was surprised to find that it slowed things down. Allow me to explain my terrible error, so that you don't fall victim in the future.

A toy model of the relevant code looks something like this:

float run(const float* data, size_t n) {
    float g = 0.0;
    while (n) {
        n--;
        const float f = *data++;
        foo(f, &g);
    }
    return g;
}

static void foo(float f, float* g) {
    // do some stuff, modifying g
}

(eliding headers and the forward declaration of foo for space)

A simple translation into AArch64 assembly gives something like this:

// x0: const float* data
// x1: size_t n
// Returns a single float in s0

// Prelude: store frame and link registers
stp   x29, x30, [sp, #-16]!

// Initialize g = 0.0
fmov s0, #0.0

loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

    bl foo   // call the function
    b loop   // keep looping

foo:
    // Do some work, reading from s1 and accumulating into s0
    // ...
    ret

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

Here, foo is kinda like a naked function: it uses the same stack frame and registers as the parent function, reads from s1, and writes to s0.

The call to foo uses the bl instruction, which is "branch and link": it jumps to the given label, and stores the next instruction address in the link register (lr or x30).

When foo is done, the ret instruction jumps to the address in the link register, which is the instruction following the original bl.

Looking at this code, I was struck by the fact that it does two branches, one after the other. Surely, it would be more efficient to only branch once.

I had the clever idea to do so without changing foo:

stp   x29, x30, [sp, #-16]!
fmov s0, #0.0

bl loop // Set up x30 to point to the loop entrance
loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

foo:
    // Do some work, accumulating into `s0`
    // ...
    ret

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

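This is a little subtle:

The initial bl loop stores the address of the next instruction, which is the loop label itself, in x30, and then jumps to it. When the loop body finishes, execution falls straight through into foo. Within the body of the loop, we never change x30, so the repeated ret instructions always return to the same place: the top of the loop.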
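If the control flow is hard to follow in assembly, here is a rough Rust walk-through of the same idea. It is only a sketch: Label and run_single_bl are hypothetical names, an index stands in for the pointer in x0, foo is assumed to simply add s1 into s0, and the prelude and the final ldp / ret are left out.

```rust
/// The three labels in the single-`bl` version of the routine.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum Label {
    Loop,
    Foo,
    Exit,
}

/// Walks the control flow of the single-`bl` version and returns the sum
/// together with the number of `ret` instructions executed inside `foo`.
pub fn run_single_bl(data: &[f32]) -> (f32, usize) {
    let mut s0 = 0.0f32; // accumulator
    let mut s1 = 0.0f32; // most recently loaded value
    let mut n = data.len(); // x1
    let mut next = 0; // index standing in for the pointer in x0
    let mut rets = 0;

    // `bl loop`: the link register is written exactly once.
    let x30 = Label::Loop;
    let mut pc = Label::Loop;
    loop {
        match pc {
            Label::Loop => {
                if n == 0 {
                    pc = Label::Exit; // b.eq exit
                    continue;
                }
                n -= 1;
                s1 = data[next];
                next += 1;
                pc = Label::Foo; // no branch: `foo` follows directly
            }
            Label::Foo => {
                s0 += s1; // assumed body of `foo`
                rets += 1;
                pc = x30; // `ret` jumps to whatever x30 holds
            }
            Label::Exit => return (s0, rets),
        }
    }
}
```

Every pass through Foo ends in a ret, but x30 is only ever written by the one bl before the loop.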

I set up a benchmark using a very simple foo:

foo:
    fadd s0, s0, s1
    ret

With this foo, the function as a whole sums the incoming array of float values.

Benchmarking with criterion (on an M1 Max CPU), with a 1024-element array:

Program        Time
Original       969 ns
"Optimized"    3.85 µs

The "optimized" code with one jump per loop is about 4x slower than the original version with two jumps per loop!
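A gap this large should be visible even with a crude timer. The following is a minimal standard-library sketch, not the criterion harness behind the numbers above; sum_scalar and mean_time are hypothetical names, and sum_scalar merely stands in for the assembly routine under test.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Hypothetical stand-in for the routine being measured.
pub fn sum_scalar(data: &[f32]) -> f32 {
    let mut g = 0.0f32;
    for &f in data {
        g += f;
    }
    g
}

/// Calls `f` on `data` `iters` times and returns the mean time per call.
/// `iters` must be nonzero. `black_box` keeps the optimizer from deleting
/// the work being timed.
pub fn mean_time(f: fn(&[f32]) -> f32, data: &[f32], iters: u32) -> Duration {
    let start = Instant::now();
    for _ in 0..iters {
        black_box(f(black_box(data)));
    }
    start.elapsed() / iters
}
```

Calling mean_time(sum_scalar, &data, 10_000) on a 1024-element slice gives a rough per-call figure. Criterion adds warm-up and statistics on top of this, so treat the sketch as a smoke test only.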

I found this surprising, so I asked a few colleagues about it.

Between Cliff and Dan, the consensus was that mismatched bl / ret pairs were confusing the branch predictor.

The ARM documentation agrees:

Why do we need a special function return instruction? Functionally, BR LR would do the same job as RET. Using RET tells the processor that this is a function return. Most modern processors, and all Cortex-A processors, support branch prediction. Knowing that this is a function return allows processors to more accurately predict the branch.

Branch predictors guess the direction the program flow will take across branches. The guess is used to decide what to load into a pipeline with instructions waiting to be processed. If the branch predictor guesses correctly, the pipeline has the correct instructions and the processor does not have to wait for instructions to be loaded from memory.

More specifically, the branch predictor probably keeps an internal stack of function return addresses, which is pushed to whenever a bl is executed. When the branch predictor sees a ret coming down the pipeline, it assumes that you're returning to the address associated with the most recent bl (and begins prefetching / speculative execution / whatever), then pops that top address from its internal stack.

This works if you've got matched bl / ret pairs, but the prediction will fail if the same address is used by multiple ret instructions; you'll end up with (vague handwaving) useless prefetching, incorrect speculative execution, and pipeline stalls / flushes.
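To make the handwaving a bit more concrete, here is a toy model of such a return-address stack in Rust. Everything in it is an assumption for illustration: ReturnStack, matched_pairs, one_bl_many_ret, and the addresses are made up, and the real predictor is undocumented and surely more elaborate.

```rust
/// Toy return-address stack: push on `bl`, pop on `ret`.
pub struct ReturnStack {
    stack: Vec<u64>,
    pub mispredictions: u64,
}

impl ReturnStack {
    pub fn new() -> Self {
        ReturnStack { stack: Vec::new(), mispredictions: 0 }
    }

    /// `bl`: remember the address of the instruction after the call.
    pub fn on_bl(&mut self, return_addr: u64) {
        self.stack.push(return_addr);
    }

    /// `ret`: predict the most recently pushed address, then compare it
    /// with the address actually held in x30.
    pub fn on_ret(&mut self, actual_target: u64) {
        let predicted = self.stack.pop();
        if predicted != Some(actual_target) {
            self.mispredictions += 1;
        }
    }
}

/// Original loop: one `bl foo` and one `ret` per iteration.
pub fn matched_pairs(iterations: u64) -> u64 {
    const AFTER_BL: u64 = 0x1000; // hypothetical address of `b loop`
    let mut ras = ReturnStack::new();
    for _ in 0..iterations {
        ras.on_bl(AFTER_BL);
        ras.on_ret(AFTER_BL);
    }
    ras.mispredictions
}

/// "Optimized" loop: a single `bl loop`, then one `ret` per iteration.
pub fn one_bl_many_ret(iterations: u64) -> u64 {
    const LOOP_START: u64 = 0x2000; // hypothetical address of `loop:`
    let mut ras = ReturnStack::new();
    ras.on_bl(LOOP_START);
    for _ in 0..iterations {
        ras.on_ret(LOOP_START);
    }
    ras.mispredictions
}
```

In this model, matched pairs never mispredict, while the single-bl version gets only its first ret right: every later ret pops an empty stack. Real hardware presumably has some fallback when the stack runs dry, so don't expect the model to match measured rates exactly.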

Dan made the great suggestion of replacing ret with br x30 to test this theory. Sure enough, this fixes the performance regression:

Program                Time
Matched bl / ret       969 ns
One bl, many ret       3.85 µs
One bl, many br x30    913 ns

In fact, it's slightly faster, probably because it's only doing one branch per loop instead of two!

To further test the "branch predictor" theory, I opened up Instruments and examined performance counters for the first two programs. Picking out the worst offenders, the results seem conclusive:

Counter                                Matched bl / ret    One bl, many ret
BRANCH_RET_INDIR_MISPRED_NONSPECIFIC   92                  928,644,975
FETCH_RESTART                          61,121              987,765,276
MAP_DISPATCH_BUBBLE                    1,155,632           7,350,085,139
MAP_REWIND                             6,412,734           2,789,499,545

These measurements are captured while summing an array of 1B elements. We see that with mismatched bl / ret pairs, the return branch predictor fails about 93% of the time!
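The 93% figure is just the mispredict counter divided by the number of returns. A tiny sketch of that arithmetic, assuming exactly one ret per element and taking "1B" to mean exactly 1,000,000,000 (mispredict_rate is a hypothetical helper):

```rust
/// Fraction of returns that the predictor got wrong.
pub fn mispredict_rate(mispredicted: u64, total_returns: u64) -> f64 {
    mispredicted as f64 / total_returns as f64
}
```

With the mismatched run's counter, mispredict_rate(928_644_975, 1_000_000_000) comes out a little under 0.93.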

Apple doesn't fully document these counters, but I'm guessing that the other counters are downstream effects of bad branch prediction.

In conclusion, do not taunt happy fun branch predictor with asymmetric usage of bl and ret instructions.


Appendix: Going Fast

Take a second look at this program:

stp   x29, x30, [sp, #-16]!
fmov s0, #0.0

loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

    bl foo   // call the function
    b loop   // keep looping

foo:
    fadd s0, s0, s1
    ret

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

Upon seeing this program, it's a common reaction to ask "why is foo a subroutine at all?"

The answer is "because this is a didactic example, not code that's trying to go as fast as possible".

Still, it's a fair question. You wanna go fast? Let's go fast.

If we know the contents of foo when building this function (and it's shorter than the maximum jump distance), we can remove the bl and ret entirely:

loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

    // foo is completely inlined here
    fadd s0, s0, s1

    b loop

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

This is a roughly 6% speedup: from 969 ns to 911 ns.

We can get faster still by trusting the compiler:

pub fn sum_slice(f: &[f32]) -> f32 {
    f.iter().sum()
}

This brings us down to 833 ns, a significant improvement!

Looking at the assembly, it's doing some loop unrolling. However, even when compiled with -C target-cpu=native, it's not generating NEON SIMD instructions. Can we beat it?

We sure can!

stp   x29, x30, [sp, #-16]!

fmov s0, #0.0
dup v1.4s, v0.s[0]
dup v2.4s, v0.s[0]

loop:  // 1x per loop
    ands xzr, x1, #3
    b.eq simd

    sub x1, x1, #1
    ldr s3, [x0], #4

    fadd s0, s0, s3
    b loop

simd:  // 4x SIMD per loop
    ands xzr, x1, #7
    b.eq simd2

    sub x1, x1, #4
    ldp d3, d4, [x0], #16
    mov v3.d[1], v4.d[0]

    fadd v1.4s, v1.4s, v3.4s

    b simd

simd2:  // 2 x 4x SIMD per loop
    cmp x1, #0
    b.eq exit

    sub x1, x1, #8

    ldp d3, d4, [x0], #16
    mov v3.d[1], v4.d[0]
    fadd v1.4s, v1.4s, v3.4s

    ldp d5, d6, [x0], #16
    mov v5.d[1], v6.d[0]
    fadd v2.4s, v2.4s, v5.4s

    b simd2

exit: // function exit
    fadd v2.4s, v2.4s, v1.4s
    mov s1, v2.s[0]
    fadd s0, s0, s1
    mov s1, v2.s[1]
    fadd s0, s0, s1
    mov s1, v2.s[2]
    fadd s0, s0, s1
    mov s1, v2.s[3]
    fadd s0, s0, s1

    ldp   x29, x30, [sp], #16
    ret

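This code includes three different loops: a scalar loop (loop) that adds one value at a time until the remaining count is a multiple of 4, a SIMD loop (simd) that adds four values at a time until the remaining count is a multiple of 8, and an unrolled SIMD loop (simd2) that adds eight values per iteration, using two accumulators, until nothing is left.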

At the function exit, we accumulate the values in the vector registers v1/v2 into s0, which is returned.
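For readers who'd rather see that structure in a high-level language, here is a scalar Rust model of it. It is a sketch of the control flow only: sum_three_loops is a hypothetical name, the "vector registers" are plain arrays, and nothing here promises that the compiler will emit NEON instructions.

```rust
/// Sums `data` with the same loop structure as the assembly:
/// scalar, then 4-wide, then 2 x 4-wide.
pub fn sum_three_loops(mut data: &[f32]) -> f32 {
    let mut s0 = 0.0f32; // scalar accumulator
    let mut v1 = [0.0f32; 4]; // first vector accumulator
    let mut v2 = [0.0f32; 4]; // second vector accumulator

    // loop: one element at a time until the remaining length is a multiple of 4
    while (data.len() & 3) != 0 {
        s0 += data[0];
        data = &data[1..];
    }

    // simd: four elements at a time until the remaining length is a multiple of 8
    while (data.len() & 7) != 0 {
        for i in 0..4 {
            v1[i] += data[i];
        }
        data = &data[4..];
    }

    // simd2: eight elements at a time until nothing is left
    while !data.is_empty() {
        for i in 0..4 {
            v1[i] += data[i];
            v2[i] += data[i + 4];
        }
        data = &data[8..];
    }

    // exit: combine both vector accumulators, then fold the lanes into s0
    for i in 0..4 {
        s0 += v2[i] + v1[i];
    }
    s0
}
```

The second loop runs at most once, since the first loop already leaves a multiple of 4; it exists only to line things up for the 8-wide loop.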

The type punning here is particularly cute:

ldp d3, d4, [x0], #16
mov v3.d[1], v4.d[0]
fadd v1.4s, v1.4s, v3.4s

Remember, x0 holds a float*. We pretend that it's a double* to load 128 bits (i.e. 4x float values) into d3 and d4. Then, we move the "double" in d4 to occupy the top 64 bits of the v3 vector register (of which d3 is the lower 64 bits).

Of course, each "double" is two floats, but that doesn't matter when shuffling them around. When summing with fadd, we tell the processor to treat them as four floats (the .4s suffix), and everything works out fine.
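The same trick can be spelled out in Rust by shuffling raw bytes. This is a minimal sketch with a hypothetical function name (load_as_two_doubles); it uses u64 values in place of the d registers and native byte order throughout, so it only demonstrates that moving floats around in 64-bit chunks leaves each 32-bit lane untouched.

```rust
use std::convert::TryInto;

/// Loads four floats as two opaque 64-bit halves, glues the halves into
/// one 128-bit value, and reads that value back as four f32 lanes.
pub fn load_as_two_doubles(data: &[f32; 4]) -> [f32; 4] {
    // The 16 bytes as they sit in memory.
    let mut bytes = [0u8; 16];
    for (chunk, f) in bytes.chunks_exact_mut(4).zip(data) {
        chunk.copy_from_slice(&f.to_ne_bytes());
    }

    // `ldp d3, d4`: two 64-bit loads that know nothing about floats.
    let d3 = u64::from_ne_bytes(bytes[0..8].try_into().unwrap());
    let d4 = u64::from_ne_bytes(bytes[8..16].try_into().unwrap());

    // `mov v3.d[1], v4.d[0]`: d4 becomes the upper half of the vector.
    let mut v3 = [0u8; 16];
    v3[0..8].copy_from_slice(&d3.to_ne_bytes());
    v3[8..16].copy_from_slice(&d4.to_ne_bytes());

    // The `.4s` view: the same 128 bits read as four f32 lanes.
    let mut lanes = [0.0f32; 4];
    for (lane, chunk) in lanes.iter_mut().zip(v3.chunks_exact(4)) {
        *lane = f32::from_ne_bytes(chunk.try_into().unwrap());
    }
    lanes
}
```

The floats come back out exactly as they went in, which is all the assembly needs before it hands the register to fadd.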

How fast are we now?

This runs in 94 ns, or about 8.8x faster than our previous best.

Here's a summary of performance:

Program                         Time
Matched bl / ret                969 ns
One bl, many ret                3.85 µs
One bl, many br x30             913 ns
Plain loop with b               911 ns
Rewrite it in Rust              833 ns
SIMD + manual loop unrolling    94 ns
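To put those times on a single scale, here is a small sketch that divides the baseline by each entry. The times are copied from the summary above (3.85 µs written as 3850 ns); TIMES_NS and speedups are hypothetical names.

```rust
/// Benchmark times from the summary, in nanoseconds.
pub const TIMES_NS: [(&str, f64); 6] = [
    ("Matched bl / ret", 969.0),
    ("One bl, many ret", 3850.0),
    ("One bl, many br x30", 913.0),
    ("Plain loop with b", 911.0),
    ("Rewrite it in Rust", 833.0),
    ("SIMD + manual loop unrolling", 94.0),
];

/// Speedup of each program relative to the first row.
/// Values below 1.0 are slowdowns.
pub fn speedups() -> Vec<(&'static str, f64)> {
    let baseline = TIMES_NS[0].1;
    TIMES_NS
        .iter()
        .map(|&(name, t)| (name, baseline / t))
        .collect()
}
```

Relative to the original matched bl / ret loop, the SIMD version works out to a bit over 10x faster, and the mismatched version to roughly a quarter of the baseline speed.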

Could we get even faster? I'm sure it's possible; I make no claims to being the Agner Fog of AArch64 assembly.

Still, this is a reasonable point to wrap up: we've demystified the initial performance regression, and had some fun hand-writing assembly to go very fast indeed.

The SIMD code does come with one asterisk, though: because floating-point addition is not associative, and it performs the summation in a different order, it may not get the same result as straight-line code. In retrospect, this is likely why the compiler doesn't generate SIMD instructions to compute the sum!

Does this matter for your use case? Only you can know!
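If the non-associativity point feels abstract, here is a small Rust sketch where the order visibly matters. The function names and the input are made up for illustration; sum_two_lanes imitates a two-lane accumulator rather than the exact lane layout of the assembly above.

```rust
/// Sums left to right, as the straight-line loop does.
pub fn sum_in_order(values: &[f32]) -> f32 {
    values.iter().fold(0.0f32, |acc, &v| acc + v)
}

/// Sums even-indexed and odd-indexed elements in separate accumulators,
/// then combines them at the end, like a two-lane vector sum.
pub fn sum_two_lanes(values: &[f32]) -> f32 {
    let mut lanes = [0.0f32; 2];
    for (i, &v) in values.iter().enumerate() {
        lanes[i % 2] += v;
    }
    lanes[0] + lanes[1]
}
```

For the input [1e8, 1.0, -1e8, 1.0], the in-order sum loses the first 1.0 when it is added to 1e8 and ends up at 1.0, while the two-lane sum cancels the large values in one lane and returns 2.0.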


All of the code from this post is published to GitHub.

You can reproduce benchmarks by running cargo bench on an ARM64 machine.