Brute-forcing 22 trillion parameters

In this year's Advent of Code, one puzzle stands out: Day 24.

You're given a specification for a fictional computer architecture with only four registers (x, y, z, and w) and a small set of assembly operations to manipulate them.

The puzzle input is a long program ("MONAD") for this computer. MONAD takes a 14-digit input (in base 10) and performs some calculation. The goal of the problem is to find the largest possible 14-digit input for which z = 0 after the calculation is complete.

It's clear from the problem description that you're not meant to brute force this across the entire input space:

MONAD imposes additional, mysterious restrictions on model numbers, and legend says the last copy of the MONAD documentation was eaten by a tanuki. You'll need to figure out what MONAD does some other way.

Looking at the Solution Megathread, most people reverse-engineered the assembly code to figure out what it was actually doing, then rewrote an optimized function to solve the problem.

This works, but it's not a general solution; can we do better?

This writeup describes a path that brings the solution time from 3.6 years down to 4.2 seconds, with a solution that's completely general-purpose: it can work for any problem input, not just the ones crafted to be reverse-engineered.

Very Dumb Brute Force

Let's start from the dumbest possible starting point: we'll build an interpreter that evaluates the MONAD code, then start running it on every 14-digit number.

(The problem is limited to numbers without zeros, so the input space is 9¹⁴ instead of 10¹⁴. This means there are a mere 22 trillion options, instead of 100 trillion.)
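The digits 1 to 9 behave like a base-9 number system, so the zero-free inputs could also be enumerated directly instead of filtered. Below is a minimal sketch. The helper `zero_free` is hypothetical and not part of the original solution. It writes a rank in base 9 and shifts each base-9 digit (0 to 8) up to a decimal digit (1 to 9).

```rust
/// Size of the zero-free input space: 9^14.
const SPACE: u64 = 9u64.pow(14);

/// Hypothetical helper: return the `rank`-th smallest 14-digit number with
/// no zero digits (rank 0 is 11111111111111). The rank is written in base 9,
/// and each base-9 digit 0..=8 becomes the decimal digit 1..=9.
fn zero_free(rank: u64) -> u64 {
    assert!(rank < SPACE, "rank out of range");
    let mut r = rank;
    let mut out = 0;
    let mut place = 1;
    for _ in 0..14 {
        // Lowest base-9 digit of the rank, shifted up by one.
        out += (r % 9 + 1) * place;
        r /= 9;
        place *= 10;
    }
    out
}
```

Counting `rank` down from `SPACE - 1` to 0 visits exactly the 9¹⁴ zero-free inputs, from largest to smallest. The alternative is stepping through all 88,888,888,888,889 numbers from 99999999999999 down to 11111111111111 and skipping the roughly 66 trillion that contain a zero.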

use std::io::BufRead;

fn run(lines: &[String], input: usize) -> bool {
    let mut index = 14u32;
    let mut regs = [0i64; 4];
    for line in lines.iter() {
        let mut words = line.split(' ');
        let op = words.next().unwrap();
        let ra = reg_index(words.next().unwrap());
        let a = regs[ra];
        let b = words.next().map(|rb| reg_value(rb, &regs)).unwrap_or(0);
        match op {
            "inp" => {
                index -= 1;
                regs[ra] = (input / 10usize.pow(index)) as i64 % 10;
            }
            "add" => regs[ra] = a + b,
            "mul" => regs[ra] = a * b,
            "div" => regs[ra] = a / b,
            "mod" => regs[ra] = a % b,
            "eql" => regs[ra] = (a == b) as i64,
            _ => panic!("Invalid instruction {}", line),
        }
    }
    regs[2] == 0
}

fn reg_index(s: &str) -> usize {
    match s {
        "x" => 0,
        "y" => 1,
        "z" => 2,
        "w" => 3,
        c => panic!("Invalid register '{}'", c),
    }
}

fn reg_value(s: &str, regs: &[i64; 4]) -> i64 {
    match s {
        "x" | "y" | "z" | "w" => regs[reg_index(s)],
        i => i.parse().unwrap(),
    }
}

fn main() {
    let lines = std::io::stdin()
        .lock()
        .lines()
        .map(|line| line.unwrap())
        .collect::<Vec<String>>();

    for i in (11111111111111..=99999999999999).rev() {
        // Skip any number with a zero in it
        if (0..14).any(|p| (i / 10usize.pow(p)) % 10 == 0) {
            continue;
        }
        if run(&lines, i) {
            println!("Solved: {}", i);
            break;
        }
    }
}

This interpreter checks about 200K values per second, so exploring the full input space would take... 3.58 years.

There's some low-hanging fruit here (e.g. the program is reparsed every time), but it's clear that minor adjustments won't suffice; we need a completely different approach.
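One piece of that low-hanging fruit is parsing the program only once, up front. Here's a minimal sketch using hypothetical `Instr` and `Operand` enums, which are not the post's code. The program text becomes a `Vec<Instr>` a single time, so `run` only matches on enum variants. Unlike the program above, this `run` takes a slice of digits and consumes one per `inp`.

```rust
/// Second operand of an instruction: a register index or an immediate.
#[derive(Clone, Copy, Debug)]
enum Operand {
    Reg(usize),
    Imm(i64),
}

/// One pre-parsed MONAD instruction (register indices: x=0, y=1, z=2, w=3).
#[derive(Clone, Copy, Debug)]
enum Instr {
    Inp(usize),
    Add(usize, Operand),
    Mul(usize, Operand),
    Div(usize, Operand),
    Mod(usize, Operand),
    Eql(usize, Operand),
}

fn reg_index(s: &str) -> usize {
    match s {
        "x" => 0,
        "y" => 1,
        "z" => 2,
        "w" => 3,
        c => panic!("Invalid register '{}'", c),
    }
}

fn parse_operand(s: &str) -> Operand {
    match s {
        "x" | "y" | "z" | "w" => Operand::Reg(reg_index(s)),
        i => Operand::Imm(i.parse().expect("invalid immediate")),
    }
}

/// Parse the whole program once, before any inputs are tried.
fn parse(src: &str) -> Vec<Instr> {
    src.lines()
        .map(str::trim)
        .filter(|l| !l.is_empty())
        .map(|line| {
            let mut words = line.split_whitespace();
            let op = words.next().unwrap();
            let a = reg_index(words.next().unwrap());
            let b = words.next().map(parse_operand);
            match (op, b) {
                ("inp", None) => Instr::Inp(a),
                ("add", Some(b)) => Instr::Add(a, b),
                ("mul", Some(b)) => Instr::Mul(a, b),
                ("div", Some(b)) => Instr::Div(a, b),
                ("mod", Some(b)) => Instr::Mod(a, b),
                ("eql", Some(b)) => Instr::Eql(a, b),
                _ => panic!("Invalid instruction {}", line),
            }
        })
        .collect()
}

/// Read an operand's current value.
fn value(regs: &[i64; 4], b: Operand) -> i64 {
    match b {
        Operand::Reg(r) => regs[r],
        Operand::Imm(i) => i,
    }
}

/// Run a pre-parsed program, consuming one digit per `inp`.
/// Returns true if z == 0 at the end.
fn run(prog: &[Instr], digits: &[i64]) -> bool {
    let mut regs = [0i64; 4];
    let mut input = digits.iter().copied();
    for &ins in prog {
        match ins {
            Instr::Inp(a) => regs[a] = input.next().expect("ran out of input digits"),
            Instr::Add(a, b) => regs[a] = regs[a] + value(&regs, b),
            Instr::Mul(a, b) => regs[a] = regs[a] * value(&regs, b),
            Instr::Div(a, b) => regs[a] = regs[a] / value(&regs, b),
            Instr::Mod(a, b) => regs[a] = regs[a] % value(&regs, b),
            Instr::Eql(a, b) => regs[a] = (regs[a] == value(&regs, b)) as i64,
        }
    }
    regs[2] == 0
}
```

For example, `run(&parse("inp w\nadd z w\nmod z 2"), &[4])` returns `true`, because z ends up as 4 % 2 = 0. This removes the per-input parsing cost, but every candidate still runs through the interpreter one at a time.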

Very Dumb Brute Force (with code generation)

What if we remove the interpreter altogether?

We can use a build script to parse the input program at compile-time and transform it into Rust code, which is then compiled with the rest of the program.
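Here's a minimal sketch of such a generator. The function `codegen` is hypothetical and maps each instruction to one Rust statement. In a build script, its output would typically be written to a file in the directory named by the `OUT_DIR` environment variable, then pulled into the crate with `include!(concat!(env!("OUT_DIR"), "/monad.rs"))`. The file name `monad.rs` is an assumption.

```rust
use std::fmt::Write;

/// Hypothetical generator: translate MONAD assembly into the source text of
/// `pub fn monad(inp: usize) -> bool`, one Rust statement per instruction.
fn codegen(src: &str) -> String {
    let mut out = String::new();
    out.push_str("pub fn monad(inp: usize) -> bool {\n");
    for r in ["x", "y", "z", "w"] {
        writeln!(out, "    let mut {r}: i64 = 0;").unwrap();
    }
    out.push_str("    let mut index: u32 = 14;\n");
    for line in src.lines().map(str::trim).filter(|l| !l.is_empty()) {
        let words: Vec<&str> = line.split_whitespace().collect();
        let (op, a) = (words[0], words[1]);
        // The second operand is a register name or an integer literal;
        // both are valid Rust expressions exactly as written.
        let b = words.get(2).copied().unwrap_or("");
        match op {
            "inp" => {
                // Extract the next decimal digit, most significant first.
                out.push_str("    index -= 1;\n");
                writeln!(out, "    {a} = ((inp / 10usize.pow(index)) % 10) as i64; // {line}").unwrap();
            }
            "add" => writeln!(out, "    {a} = {a} + {b}; // {line}").unwrap(),
            "mul" => writeln!(out, "    {a} = {a} * {b}; // {line}").unwrap(),
            "div" => writeln!(out, "    {a} = {a} / {b}; // {line}").unwrap(),
            "mod" => writeln!(out, "    {a} = {a} % {b}; // {line}").unwrap(),
            "eql" => writeln!(out, "    {a} = ({a} == {b}).into(); // {line}").unwrap(),
            _ => panic!("Invalid instruction {}", line),
        }
    }
    out.push_str("    z == 0\n}\n");
    out
}
```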

The generated code looks something like this:

pub fn monad(inp: usize) -> bool {
    let mut x = 0;
    let mut y = 0;
    let mut z = 0;
    let mut w = 0;
    let mut index = 14;

    index -= 1;
    w = ((inp / 10usize.pow(index)) % 10) as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 1; // div z 1
    x = x + 11; // add x 11
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 6; // add y 6
    y = y * x; // mul y x
    z = z + y; // add z y

    // ...lots of code elided here...

    index -= 1;
    w = ((inp / 10usize.pow(index)) % 10) as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 26; // div z 26
    x = x + -2; // add x -2
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 1; // add y 1
    y = y * x; // mul y x
    z = z + y; // add z y
    z == 0
}

(The main loop stays about the same, so I won't reproduce it again.)

Using this compiled function, our brute-force solver will terminate in a mere 17 days. That's a 74× speedup over the initial interpreter, but still a little slow; I'd like to be done by New Year's Eve.

State deduplication

Many of the search problems in Advent of Code can be solved by exploring every option, then deduplicating when you reach a state that you've seen before.

Can we apply this strategy to solving MONAD?

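A state can be represented by six values: the four registers (x, y, z, and w), plus the smallest and largest input prefixes (the digits read so far) that reach those register values.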

Many instructions can reduce the number of states! For example, mul x 0 sets the x register to 0, so {x:1, y:2, z:3, w:4} and {x:10, y:2, z:3, w:4} both become {x:0, y:2, z:3, w:4}.

Only the input (inp) instruction can increase the number of states. For example, inp x causes {x:1, y:2, z:3, w:4} to expand into 9 new states: {x:1, y:2, z:3, w:4}, {x:2, y:2, z:3, w:4}, ..., {x:9, y:2, z:3, w:4}.

It's worth noticing that inp won't necessarily increase the number of states by a full 9×. If your input states are {x:1, y:2, z:3, w:4} and {x:10, y:2, z:3, w:4}, you'll only end up with the 9 new states shown above, not 18.
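To make the state-count argument concrete, here's a minimal sketch that models a set of states as a `HashSet` of register arrays and ignores the min/max bookkeeping. The helper names `mul_zero` and `inp` are hypothetical.

```rust
use std::collections::HashSet;

type Registers = [i64; 4]; // [x, y, z, w]

/// `mul <ra> 0`: a deterministic instruction maps each state to exactly one
/// successor, so it can merge states but never add new ones.
fn mul_zero(states: &HashSet<Registers>, ra: usize) -> HashSet<Registers> {
    states
        .iter()
        .map(|regs| {
            let mut r = *regs;
            r[ra] = 0;
            r
        })
        .collect()
}

/// `inp <ra>`: each state splits into 9 successors (digits 1..=9), but
/// successors that coincide are stored only once.
fn inp(states: &HashSet<Registers>, ra: usize) -> HashSet<Registers> {
    let mut next = HashSet::new();
    for regs in states {
        for digit in 1..=9 {
            let mut r = *regs;
            r[ra] = digit;
            next.insert(r);
        }
    }
    next
}
```

Starting from {x:1, y:2, z:3, w:4} and {x:10, y:2, z:3, w:4}, `mul_zero(&s, 0)` leaves 1 state and `inp(&s, 0)` leaves 9, not 18.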

With all of this explanation out of the way, let's consider our new program:

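use std::collections::HashMap;
use std::io::BufRead;

// reg_index and reg_value are unchanged from the first program.

fn main() {
    // Each state: the registers [x, y, z, w], plus the (min, max)
    // input prefixes that reached them.
    let mut state: Vec<([i64; 4], (usize, usize))> = vec![([0; 4], (0, 0))];

    for line in std::io::stdin().lock().lines() {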
        let line = line.unwrap();
        let mut words = line.split(' ');

        let op = words.next().unwrap();
        let ra = reg_index(words.next().unwrap());
        let rb = words.next().unwrap_or("");
        match op {
            "inp" => {
                // Each state splits into 9 new states,
                // one for each possible input digit.
                let mut next = Vec::with_capacity(state.len() * 9);
                for (regs, (min, max)) in state.iter() {
                    for i in 1..=9 {
                        let mut regs = *regs;
                        regs[ra] = i;
                        let min = min * 10 + i as usize;
                        let max = max * 10 + i as usize;
                        next.push((regs, (min, max)));
                    }
                }
                state = next;
            }
            "add" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a + b;
                }
            }
            "mul" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a * b;
                }
            }
            "div" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a / b;
                }
            }
            "mod" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a % b;
                }
            }
            "eql" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = (a == b) as i64;
                }
            }
            _ => panic!("Invalid instruction {}", line),
        }
        // Deduplicate by accumulating into a HashMap, then
        // pack back into a Vec for further operations.
        let mut dedup = HashMap::new();
        for (state, (min, max)) in state.into_iter() {
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
    }
    let (min, max) = state
        .iter()
        .filter(|(k, _)| k[2] == 0)
        .map(|(_, v)| *v)
        .reduce(|a, b| (a.0.min(b.0), a.1.max(b.1)))
        .unwrap();
    println!("Part 1: {}", max);
    println!("Part 2: {}", min);
}

After every instruction, we deduplicate the state list with a HashMap, tracking the smallest and largest values that led to that state.

This is the first solution that finishes in a reasonable time: 379 seconds.

We can plot the number of states active at every instruction:

[Figure: number of active states at each instruction, with deduplication after every instruction (log scale on the Y axis)]

Looking closely, we can see that the number of active states doesn't change all that often, even though we're deduplicating after every instruction. Furthermore, the deduplication is doing a huge amount of work:

[Figure: too much hashing]

Less frequent deduplication

What if we just... deduplicated less?

Let's deduplicate states right before each inp instruction, since that's when the state count is about to increase:

        match op {
            "inp" => {
                let mut dedup = HashMap::new();
                for (state, (min, max)) in state.into_iter() {
                    let entry = dedup.entry(state).or_insert((usize::MAX, 0));
                    entry.0 = entry.0.min(min);
                    entry.1 = entry.1.max(max);
                }
                state = dedup.into_iter().collect();
                // ...rest of `inp` handling here

This brings us down to 30.2 seconds!

Plotting the number of active states, we can see that more states are active, but because deduplication is so expensive, we're still coming out ahead!

[Figure: number of active states with sparse deduplication]

Smarter deduplication

Remember the discussion earlier, where we noticed that the inp instruction could cause states to merge as well as split?

We can use that to make deduplication even more effective: before adding a new state to the hash map, let's set the register which is about to be overwritten to 0.

This means that if we're about to overwrite x, then {x:1, y:2, z:3, w:4} and {x:10, y:2, z:3, w:4} will be combined in the dedup table.

        match op {
            "inp" => {
                let mut dedup = HashMap::new();
                for (mut state, (min, max)) in state.into_iter() {
                    state[ra] = 0; // <-- This is the new line!
                    let entry = dedup.entry(state).or_insert((usize::MAX, 0));
                    entry.0 = entry.0.min(min);
                    entry.1 = entry.1.max(max);
                }
                state = dedup.into_iter().collect();
                // ...rest of `inp` handling here

This brings our running time down to 21.8 seconds. Looking at the graph, you'll see that we're also doing slightly less work:

[Figure: number of active states with smart deduplication]

In-place deduplication

We can make deduplication faster by skipping the hash table entirely.

Given a sorted list of states, we can deduplicate in a single pass by keeping two indexes and merging states whose registers match.

For example, here's a sorted list of states:

{x: 0, y: 1, z: 2, w: 3} {min: 123, max: 531}
{x: 0, y: 1, z: 2, w: 3} {min: 143, max: 571}
{x: 0, y: 1, z: 2, w: 3} {min: 113, max: 481}
{x: 0, y: 3, z: 2, w: 3} {min: 322, max: 991}
{x: 0, y: 3, z: 2, w: 3} {min: 321, max: 989}

(We're about to write to x, so it's been cleared to 0, as discussed above)

In a single pass, we can collapse this into

{x: 0, y: 1, z: 2, w: 3} {min: 113, max: 571}
{x: 0, y: 3, z: 2, w: 3} {min: 321, max: 991}
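The same single-pass merge can also be written with the standard library's `Vec::dedup_by`. It removes consecutive elements for which a closure returns true, and it passes the later element to that closure first. This is a swapped-in alternative to the two-index loop in the post, not the post's code. The sketch replays the example above, starting from a shuffled order.

```rust
type Registers = [i64; 4]; // [x, y, z, w]
type State = (Registers, (usize, usize)); // registers, (min, max)

/// Sort by register state, then merge runs of equal registers in place,
/// keeping the smallest `min` and the largest `max` of each run.
fn compact(state: &mut Vec<State>) {
    state.sort_unstable_by_key(|s| s.0);
    // dedup_by passes (later, kept); returning true drops `later`.
    state.dedup_by(|later, kept| {
        if later.0 != kept.0 {
            return false;
        }
        let (lmin, lmax) = later.1;
        let (kmin, kmax) = kept.1;
        kept.1 = (kmin.min(lmin), kmax.max(lmax));
        true
    });
}
```

Running `compact` on the five states above, in any order, leaves exactly the two merged states shown: `([0, 1, 2, 3], (113, 571))` and `([0, 3, 2, 3], (321, 991))`.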

Sorting is a call to sort_unstable_by_key, using the [x, y, z, w] state as the key.

Here's what this looks like in our program:

        match op {
            "inp" => {
                // Clear the register that's about to be written
                state.iter_mut().for_each(|k| k.0[ra] = 0);

                // Sort by register state
                state.sort_unstable_by_key(|k| k.0);

                // Do single-pass compaction
                let mut i = 0;
                let mut j = 1;
                while j < state.len() {
                    if state[i].0 == state[j].0 {
                        let (imin, imax) = state[i].1;
                        let (jmin, jmax) = state[j].1;
                        state[i].1 = (imin.min(jmin), imax.max(jmax));
                    } else {
                        i += 1;
                        state[i] = state[j];
                    }
                    j += 1;
                }
                assert!(i < state.len());
                state.resize(i + 1, ([0; 4], (0, 0)));

This brings running time down to 17.8 seconds.

Return of the codegen

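We're now running the following loop: at each inp instruction, deduplicate the states, split each state into 9 new states (one per input digit), then run the following instructions on every state, one at a time, until the next inp.

It turns out we can fuse the last two steps into one. Splitting a state on an input digit and then running the rest of the block is just a function that takes a register state and a digit and returns a new register state.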

This is a subtle change in perspective, but it means that we can use code generation again! Instead of generating a single function for all of MONAD, we'll generate one function per block, where a block runs from one inp instruction up to (but not including) the next one.

For example, the first block is

pub fn block0(regs: [i64; 4], inp: u8) -> [i64; 4] {
    let [mut x, mut y, mut z, mut w] = regs;
    let _ = (x, y, z, w);
    w = inp as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 1; // div z 1
    x = x + 11; // add x 11
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 6; // add y 6
    y = y * x; // mul y x
    z = z + y; // add z y
    return [x, y, z, w];
}
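To produce one function like `block0` per block, the generator first has to cut the program at each `inp`. Here's a rough sketch of that step. `Block` and `split_blocks` are hypothetical names, and the sketch assumes the program starts with an `inp` instruction, as MONAD does.

```rust
/// One block: the register its leading `inp` writes, plus every
/// instruction after it up to (not including) the next `inp`.
#[derive(Debug, PartialEq)]
struct Block<'a> {
    input_reg: &'a str,
    body: Vec<&'a str>,
}

/// Hypothetical helper: cut the program text at each `inp` instruction.
fn split_blocks(src: &str) -> Vec<Block<'_>> {
    let mut blocks: Vec<Block> = Vec::new();
    for line in src.lines().map(str::trim).filter(|l| !l.is_empty()) {
        if let Some(reg) = line.strip_prefix("inp ") {
            // Start a new block at every input instruction.
            blocks.push(Block { input_reg: reg.trim(), body: Vec::new() });
        } else {
            blocks
                .last_mut()
                .expect("program must start with an inp instruction")
                .body
                .push(line);
        }
    }
    blocks
}
```

Each `Block` then becomes one generated `blockN` function. Mapping each `input_reg` through `reg_index` gives the per-block input register table.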

The code generation also produces arrays of functions (for each block) and input register number (again for each block):

type Registers = [i64; 4]; // [x, y, z, w]

const BLOCKS: [fn(Registers, u8) -> Registers; 14] = [
    block0,
    block1,
    block2,
    // ... etc ...
];

const INPUTS: [usize; 14] = [
    3,
    3,
    3,
    // ... etc ...
];

(The MONAD program always sends inputs to register w, which is index 3, but our solution doesn't rely on that behavior.)

With these functions available, the main loop removes the interpreter entirely, and simply calls them one by one:

    for (f, r) in BLOCKS.iter().zip(INPUTS) {
        // ... same deduplication logic as above...
        state = (1..=9)
            .flat_map(|i| {
                state.iter().map(move |(regs, (min, max))| {
                    let min = min * 10 + i as usize;
                    let max = max * 10 + i as usize;
                    (f(*regs, i), (min, max))
                })
            })
            .collect();
    }

This runs in 12.4 seconds.
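Here's a scaled-down sketch of the whole fused loop, end to end. Two hypothetical toy blocks stand in for MONAD's fourteen generated ones. `block0` stores its digit in z and `block1` subtracts its digit from z, so z == 0 exactly when both digits match. The deduplication step sorts and then calls `Vec::dedup_by`, a swapped-in equivalent of the post's two-index compaction.

```rust
type Registers = [i64; 4]; // [x, y, z, w]
type State = (Registers, (usize, usize)); // registers, (min, max) input prefix

// Hypothetical toy blocks standing in for generated code.
fn block0(regs: Registers, inp: u8) -> Registers {
    let [x, y, _, _] = regs;
    let w = inp as i64;
    [x, y, w, w] // z = w = digit
}

fn block1(regs: Registers, inp: u8) -> Registers {
    let [x, y, z, _] = regs;
    let w = inp as i64;
    [x, y, z - w, w] // z -= digit
}

const BLOCKS: [fn(Registers, u8) -> Registers; 2] = [block0, block1];
const INPUTS: [usize; 2] = [3, 3]; // both blocks read into w

/// Clear the register about to be overwritten, sort, and merge equal states.
fn dedup(state: &mut Vec<State>, r: usize) {
    state.iter_mut().for_each(|s| s.0[r] = 0);
    state.sort_unstable_by_key(|s| s.0);
    state.dedup_by(|later, kept| {
        if later.0 != kept.0 {
            return false;
        }
        let (lmin, lmax) = later.1;
        let (kmin, kmax) = kept.1;
        kept.1 = (kmin.min(lmin), kmax.max(lmax));
        true
    });
}

/// Returns (smallest, largest) input that leaves z == 0, if any.
fn solve() -> Option<(usize, usize)> {
    let mut state: Vec<State> = vec![([0; 4], (0, 0))];
    for (f, r) in BLOCKS.iter().zip(INPUTS) {
        dedup(&mut state, r);
        // Fused step: every state, every digit, through one block function.
        state = (1..=9u8)
            .flat_map(|i| {
                state.iter().map(move |(regs, (min, max))| {
                    let d = i as usize;
                    (f(*regs, i), (min * 10 + d, max * 10 + d))
                })
            })
            .collect();
    }
    state
        .iter()
        .filter(|(regs, _)| regs[2] == 0)
        .map(|(_, v)| *v)
        .reduce(|a, b| (a.0.min(b.0), a.1.max(b.1)))
}
```

With these toy blocks, `solve()` returns `Some((11, 99))`: the smallest and largest two-digit inputs whose digits are equal.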

Throw some threads at it

Rayon makes it very easy to throw parallelism at your problems:

(plus a few more minor tweaks)
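The Rayon version isn't shown here. As a dependency-free illustration of the same idea, this sketch swaps in the standard library's `std::thread::scope` (stable since Rust 1.63) to parallelize the expansion step. The states are split into chunks, each chunk is expanded on its own thread, and the results are concatenated. `expand_parallel` and `toy_block` are hypothetical names; this is not the post's code, and the 4.2-second figure comes from the Rayon version, not this sketch.

```rust
use std::thread;

type Registers = [i64; 4]; // [x, y, z, w]
type State = (Registers, (usize, usize)); // registers, (min, max)

/// Hypothetical toy block: add the input digit to z and store it in w.
fn toy_block(regs: Registers, inp: u8) -> Registers {
    let [x, y, z, _] = regs;
    [x, y, z + inp as i64, inp as i64]
}

/// Expand every state with every digit 1..=9 through block `f`, splitting the
/// states into chunks handled by scoped threads. The output order matches a
/// serial, state-by-state expansion.
fn expand_parallel(
    state: &[State],
    f: fn(Registers, u8) -> Registers,
    n_threads: usize,
) -> Vec<State> {
    assert!(n_threads > 0);
    let chunk = ((state.len() + n_threads - 1) / n_threads).max(1);
    thread::scope(|s| {
        let handles: Vec<_> = state
            .chunks(chunk)
            .map(|part| {
                s.spawn(move || {
                    let mut out = Vec::with_capacity(part.len() * 9);
                    for &(regs, (min, max)) in part {
                        for i in 1..=9u8 {
                            let d = i as usize;
                            out.push((f(regs, i), (min * 10 + d, max * 10 + d)));
                        }
                    }
                    out
                })
            })
            .collect();
        // Join in spawn order so the result is deterministic.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}
```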

This brings our final running time down to 4.2 seconds, another 3× speedup. My machine has 10 cores (8 performance and 2 efficiency), so it's not quite perfect scaling, but it's still an easy win.

At this point, I'll declare victory: this is fast enough.

Conclusions

Here's a list of each optimization and the incremental speedup:

Version                                    Runtime       Speedup
Very dumb brute force                      3.58 years    --
Very dumb brute force with codegen         17 days       74×
State deduplication (every instruction)    379 seconds   3875×
State deduplication (on inp)               30.2 seconds  12.5×
Smarter deduplication                      21.8 seconds  1.4×
In-place deduplication                     17.8 seconds  1.2×
Return of the codegen                      12.4 seconds  1.4×
Parallelism                                4.2 seconds   3×

Overall, we see a 26,880,685× improvement from the initial brute-force solution, without loss of generality!

The final code lives on GitHub:

Did I miss any more low-hanging fruit? Let me know via email or Twitter!

Appendix I: Things that didn't work

Appendix II: Things that I didn't try

Appendix III: Further reading

This is not the first time I've overengineered the heck out of an Advent of Code problem, then written a long blog post about it. If you enjoyed this writeup, you may also enjoy

(2020 was an easier year, so nothing required dramatic over-engineering)