## Brute-forcing 22 trillion parameters

In this year's Advent of Code, one puzzle stands out: Day 24.

You're given a specification for a fictional computer architecture, with only four registers (`x`, `y`, `z`, and `w`) and a small set of assembly operations to manipulate them.

The puzzle input is a long program ("MONAD") for this computer. MONAD takes a 14-digit input (in base 10) and performs some calculation. The goal of the problem is to find the largest possible 14-digit input for which `z = 0` after the calculation is complete.

It's clear from the problem description that you're not meant to brute force this across the entire input space:

> MONAD imposes additional, mysterious restrictions on model numbers, and legend says the last copy of the MONAD documentation was eaten by a tanuki. You'll need to figure out what MONAD does some other way.

Looking at the Solution Megathread, most people reverse-engineered the assembly code to figure out what it was actually doing, then rewrote an optimized function to solve the problem.

This works, but it's not a general solution; can we do better?

This writeup describes a path that brings the solution time from 3.6 years down to 4.2 seconds, with a solution that's completely general-purpose: it can work for any problem input, not just the ones crafted to be reverse-engineered.

### Very Dumb Brute Force

Let's start from the dumbest possible starting point: we'll build an interpreter that evaluates the MONAD code, then start running it on every 14-digit number.

(The problem only allows the digits 1 through 9, so the input space is 9¹⁴ instead of 10¹⁴. This means there are a mere 22 trillion options, instead of 100 trillion.)
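The brute-force loop below scans every value in `11111111111111..=99999999999999` (about 88.9 trillion numbers) and throws away the roughly three quarters that contain a zero. Another option is to step directly from one zero-free number to the next-smaller one. Here's a minimal sketch of that idea. `prev_zero_free` and `to_number` are hypothetical helpers, not part of the original solver, and they store the candidate as an array of digits rather than a `usize`:

```rust
/// Hypothetical helper: step a zero-free 14-digit number (digits 1..=9,
/// most significant digit first) down to the next-smaller zero-free number.
/// Returns `false` if the input was 11111111111111 (nothing smaller exists).
fn prev_zero_free(digits: &mut [u8; 14]) -> bool {
    for d in digits.iter_mut().rev() {
        if *d > 1 {
            *d -= 1;
            return true;
        }
        // This digit is 1: wrap it to 9 and borrow from the next digit up.
        *d = 9;
    }
    false
}

/// Convert the digit array back into an integer.
fn to_number(digits: &[u8; 14]) -> u64 {
    digits.iter().fold(0, |acc, &d| acc * 10 + u64::from(d))
}

fn main() {
    // Only 9^14 candidates are zero-free.
    println!("{}", 9u64.pow(14)); // 22876792454961

    // Walk the first few candidates in descending order, with no skipping.
    let mut digits = [9u8; 14];
    for _ in 0..3 {
        println!("{}", to_number(&digits));
        prev_zero_free(&mut digits);
    }
}
```

For example, the step from `99999999999991` lands on `99999999999989` and never visits `99999999999990`.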

``````rust
use std::io::BufRead;

fn run(lines: &[String], input: usize) -> bool {
    let mut index = 14u32;
    let mut regs = [0i64; 4];
    for line in lines.iter() {
        let mut words = line.split(' ');
        let op = words.next().unwrap();
        let ra = reg_index(words.next().unwrap());
        let a = regs[ra];
        let b = words.next().map(|rb| reg_value(rb, &regs)).unwrap_or(0);
        match op {
            "inp" => {
                index -= 1;
                regs[ra] = (input / 10usize.pow(index)) as i64 % 10;
            }
            "add" => regs[ra] = a + b,
            "mul" => regs[ra] = a * b,
            "div" => regs[ra] = a / b,
            "mod" => regs[ra] = a % b,
            "eql" => regs[ra] = (a == b) as i64,
            _ => panic!("Invalid instruction {}", line),
        }
    }
    regs[2] == 0
}

fn reg_index(s: &str) -> usize {
    match s {
        "x" => 0,
        "y" => 1,
        "z" => 2,
        "w" => 3,
        c => panic!("Invalid register '{}'", c),
    }
}

fn reg_value(s: &str, regs: &[i64; 4]) -> i64 {
    match s {
        "x" | "y" | "z" | "w" => regs[reg_index(s)],
        i => i.parse().unwrap(),
    }
}

fn main() {
    let lines = std::io::stdin()
        .lock()
        .lines()
        .map(|line| line.unwrap())
        .collect::<Vec<String>>();

    for i in (11111111111111..=99999999999999).rev() {
        // Skip any number with a zero in it
        if (0..14).any(|p| (i / 10usize.pow(p)) % 10 == 0) {
            continue;
        }
        if run(&lines, i) {
            println!("Solved: {}", i);
            break;
        }
    }
}
``````

This interpreter checks about 200K values per second, so exploring the full solution space would take... 3.58 years.

There's some low-hanging fruit here (e.g. the program is reparsed every time), but it's clear that minor adjustments won't suffice; we need a completely different approach.
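For reference, here's roughly what fixing the reparsing would look like: parse the program once into a compact instruction list before the search starts. This is a minimal sketch that assumes an illustrative `Instr`/`Operand` representation. Those types and the `parse` / `run_program` names are not from the original code, and unlike the listing above, `run_program` takes the input as a slice of digits:

```rust
/// Second operand: a register index or an immediate constant.
#[derive(Clone, Copy, Debug)]
enum Operand {
    Reg(usize),
    Imm(i64),
}

#[derive(Clone, Copy, Debug)]
enum Op {
    Inp,
    Add,
    Mul,
    Div,
    Mod,
    Eql,
}

/// One pre-parsed instruction; `b` is an unused Imm(0) for `inp`.
#[derive(Clone, Copy, Debug)]
struct Instr {
    op: Op,
    a: usize,
    b: Operand,
}

fn reg(s: &str) -> Option<usize> {
    match s {
        "x" => Some(0),
        "y" => Some(1),
        "z" => Some(2),
        "w" => Some(3),
        _ => None,
    }
}

/// Parse a line such as "add x -2" once, ahead of the search.
fn parse(line: &str) -> Instr {
    let mut words = line.split_whitespace();
    let op = match words.next() {
        Some("inp") => Op::Inp,
        Some("add") => Op::Add,
        Some("mul") => Op::Mul,
        Some("div") => Op::Div,
        Some("mod") => Op::Mod,
        Some("eql") => Op::Eql,
        _ => panic!("invalid instruction {:?}", line),
    };
    let a = words
        .next()
        .and_then(reg)
        .unwrap_or_else(|| panic!("bad register in {:?}", line));
    let b = match words.next() {
        None => Operand::Imm(0),
        Some(s) => reg(s).map(Operand::Reg).unwrap_or_else(|| {
            Operand::Imm(s.parse().unwrap_or_else(|_| panic!("bad operand in {:?}", line)))
        }),
    };
    Instr { op, a, b }
}

/// Evaluate a pre-parsed program on the given digits; true if `z == 0` at the end.
fn run_program(program: &[Instr], input: &[i64]) -> bool {
    let mut digits = input.iter().copied();
    let mut regs = [0i64; 4];
    for ins in program {
        let a = regs[ins.a];
        let b = match ins.b {
            Operand::Reg(r) => regs[r],
            Operand::Imm(v) => v,
        };
        regs[ins.a] = match ins.op {
            Op::Inp => digits.next().expect("ran out of input digits"),
            Op::Add => a + b,
            Op::Mul => a * b,
            Op::Div => a / b,
            Op::Mod => a % b,
            Op::Eql => (a == b) as i64,
        };
    }
    regs[2] == 0
}

fn main() {
    // Parse once, then reuse the parsed program for every candidate.
    let program: Vec<Instr> = ["inp x", "mul x -1", "add z x", "add z 3"]
        .iter()
        .map(|line| parse(line))
        .collect();
    println!("{}", run_program(&program, &[3])); // true
    println!("{}", run_program(&program, &[2])); // false
}
```

This removes string handling from the hot loop, but it's still an interpreter dispatching on every instruction, so it's the kind of minor adjustment that can't close a gap measured in years.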

### Very Dumb Brute Force (with code generation)

What if we remove the interpreter altogether?

We can use a build script to parse the input program at compile-time and transform it into Rust code, which is then compiled with the rest of the program.

The generated code looks something like this:

``````rust
pub fn monad(inp: usize) -> bool {
    let mut x = 0;
    let mut y = 0;
    let mut z = 0;
    let mut w = 0;
    let mut index = 14;

    index -= 1;
    w = ((inp / 10usize.pow(index)) % 10) as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 1; // div z 1
    x = x + 11; // add x 11
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 6; // add y 6
    y = y * x; // mul y x
    z = z + y; // add z y

    // ...lots of code elided here...

    index -= 1;
    w = ((inp / 10usize.pow(index)) % 10) as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 26; // div z 26
    x = x + -2; // add x -2
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 1; // add y 1
    y = y * x; // mul y x
    z = z + y; // add z y
    z == 0
}
``````

(The main loop stays about the same, so I won't reproduce it again.)
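The build script itself isn't shown here. As a hypothetical sketch (the `translate` and `generate` names are made up for illustration), the translation can be a simple line-by-line rewrite that produces code in roughly the same shape as the listing above, including the `% 10` digit extraction for `inp`:

```rust
/// Hypothetical build-script helper: turn one MONAD instruction into one
/// Rust statement, keeping the original instruction as a trailing comment.
fn translate(line: &str) -> String {
    let words: Vec<&str> = line.split_whitespace().collect();
    let stmt = match words.as_slice() {
        ["inp", a] => format!("index -= 1; {} = ((inp / 10usize.pow(index)) % 10) as i64;", a),
        ["add", a, b] => format!("{0} = {0} + {1};", a, b),
        ["mul", a, b] => format!("{0} = {0} * {1};", a, b),
        ["div", a, b] => format!("{0} = {0} / {1};", a, b),
        ["mod", a, b] => format!("{0} = {0} % {1};", a, b),
        ["eql", a, b] => format!("{0} = ({0} == {1}).into();", a, b),
        _ => panic!("invalid instruction {:?}", line),
    };
    format!("{} // {}", stmt, line.trim())
}

/// Wrap the translated program in a `monad` function.
fn generate(program: &str) -> String {
    let mut out = String::from("pub fn monad(inp: usize) -> bool {\n");
    out.push_str("    let (mut x, mut y, mut z, mut w) = (0i64, 0i64, 0i64, 0i64);\n");
    out.push_str("    let mut index = 14u32;\n");
    for line in program.lines().filter(|l| !l.trim().is_empty()) {
        out.push_str("    ");
        out.push_str(&translate(line));
        out.push('\n');
    }
    out.push_str("    z == 0\n}\n");
    out
}

fn main() {
    print!("{}", generate("inp w\nmul x 0\nadd x z\neql x w\n"));
}
```

In a real `build.rs`, the generated string would typically be written to a file under the directory named by Cargo's `OUT_DIR` environment variable, then pulled into the crate with something like `include!(concat!(env!("OUT_DIR"), "/monad.rs"))`. The file name here is an assumption.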

Using this compiled function, our brute-force solver will terminate in a mere 17 days. That's a 74× speedup over the initial interpreter, but still a little slow; I'd like to be done by New Year's Eve.

### State deduplication

Many of the search problems in Advent of Code can be solved by exploring every option, then deduplicating when you reach a state that you've seen before.

Can we apply this strategy to solving MONAD?

A state can be represented by six values:

• The four register values (all `i64`)
• The highest and lowest input (both `usize`) that produces this state (since this is what's eventually asked for as your puzzle solution)

Many instructions can reduce the number of states! For example, `mul x 0` sets the `x` register to `0`, so `{x:1, y:2, z:3, w:4}` and `{x:10, y:2, z:3, w:4}` both become `{x:0, y:2, z:3, w:4}`.

Only the input (`inp`) instruction can increase the number of states. For example, `inp x` causes `{x:1, y:2, z:3, w:4}` to expand into 9 new states: `{x:1, y:2, z:3, w:4}`, `{x:2, y:2, z:3, w:4}`, ..., `{x:9, y:2, z:3, w:4}`.

It's worth noticing that `inp` won't necessarily increase the number of states by a full 9×. If your input states are `{x:1, y:2, z:3, w:4}` and `{x:10, y:2, z:3, w:4}`, you'll only end up with the 9 new states shown above, not 18.
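Both effects are easy to demonstrate. This sketch tracks only register values in a `HashSet`. The `expand_inp` helper and `Regs` alias are illustrative names, and unlike the real solver, this version doesn't record the min/max inputs:

```rust
use std::collections::HashSet;

/// Register file in the order x, y, z, w.
type Regs = [i64; 4];

/// Apply `inp r` to a set of states: each state fans out into nine
/// (one per digit 1..=9), and because register `r` is overwritten,
/// states that differed only in `r` land on the same nine results.
fn expand_inp(states: &HashSet<Regs>, r: usize) -> HashSet<Regs> {
    let mut next = HashSet::new();
    for state in states {
        for digit in 1..=9 {
            let mut s = *state;
            s[r] = digit;
            next.insert(s);
        }
    }
    next
}

fn main() {
    // {x:1, y:2, z:3, w:4} and {x:10, y:2, z:3, w:4}, followed by `inp x`.
    let states = HashSet::from([[1, 2, 3, 4], [10, 2, 3, 4]]);
    println!("{}", expand_inp(&states, 0).len()); // 9, not 18
}
```

If the same two states were followed by `inp y` instead, they would still differ in `x`, and you'd get all 18.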

With all of this explanation out of the way, let's consider our new program:

``````rust
use std::collections::HashMap;
use std::io::BufRead;

// `reg_index` and `reg_value` are the helpers from the first listing.

fn main() {
    let mut state: Vec<([i64; 4], (usize, usize))> = vec![([0; 4], (0, 0))];

    for line in std::io::stdin().lock().lines() {
        let line = line.unwrap();
        let mut words = line.split(' ');

        let op = words.next().unwrap();
        let ra = reg_index(words.next().unwrap());
        let rb = words.next().unwrap_or("");
        match op {
            "inp" => {
                // Each state splits into 9 new states,
                // one for each possible input digit.
                let mut next = Vec::with_capacity(state.len() * 9);
                for (regs, (min, max)) in state.iter() {
                    for i in 1..=9 {
                        let mut regs = *regs;
                        regs[ra] = i;
                        let min = min * 10 + i as usize;
                        let max = max * 10 + i as usize;
                        next.push((regs, (min, max)));
                    }
                }
                state = next;
            }
            "add" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a + b;
                }
            }
            "mul" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a * b;
                }
            }
            "div" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a / b;
                }
            }
            "mod" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a % b;
                }
            }
            "eql" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = (a == b) as i64;
                }
            }
            _ => panic!("Invalid instruction {}", line),
        }
        // Deduplicate by accumulating into a HashMap, then
        // pack back into a Vec for further operations.
        let mut dedup = HashMap::new();
        for (state, (min, max)) in state.into_iter() {
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
    }
    let (min, max) = state
        .iter()
        .filter(|(k, _)| k[2] == 0)
        .map(|(_, v)| *v)
        .reduce(|a, b| (a.0.min(b.0), a.1.max(b.1)))
        .unwrap();
    println!("Part 1: {}", max);
    println!("Part 2: {}", min);
}
``````

After every instruction, we deduplicate the state list with a `HashMap`, tracking the smallest and largest values that led to that state.

This is the first solution that finishes in a reasonable time: 379 seconds.

We can plot the number of states active at every instruction:

[Figure: number of active states at each instruction (log-scale Y axis)]

Looking closely, we can see that the number of active states doesn't change all that often, even though we're deduplicating after every instruction. Furthermore, the deduplication is doing a huge amount of work.

### Less frequent deduplication

What if we just... deduplicated less?

Let's deduplicate states right before each `inp` instruction, since that's when the state count is about to increase:

``````rust
match op {
    "inp" => {
        let mut dedup = HashMap::new();
        for (state, (min, max)) in state.into_iter() {
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
        // ...rest of `inp` handling here
``````

This brings us down to 30.2 seconds!

Plotting the number of active states, we can see that more states are active, but because deduplication is so expensive, we're still coming out ahead!

### Smarter deduplication

Remember the discussion earlier, where we noticed that the `inp` instruction could cause states to merge as well as split?

We can use that to make deduplication even more effective: before adding a new state to the hash map, let's set the register which is about to be overwritten to 0.

This means that if we're about to overwrite `x`, then `{x:1, y:2, z:3, w:4}` and `{x:10, y:2, z:3, w:4}` will be combined in the `dedup` table.

``````rust
match op {
    "inp" => {
        let mut dedup = HashMap::new();
        for (mut state, (min, max)) in state.into_iter() {
            state[ra] = 0; // <-- This is the new line!
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
        // ...rest of `inp` handling here
``````

This brings our running time down to 21.8 seconds. Looking at the graph, you'll see that we're also doing slightly less work.

### In-place deduplication

We can make deduplication faster by skipping the hash table entirely.

Given a sorted list of states, we can deduplicate in a single pass by keeping two indexes and merging states with matching registers.

For example, here's a sorted list of states:

``````
{x: 0, y: 1, z: 2, w: 3} {min: 123, max: 531}
{x: 0, y: 1, z: 2, w: 3} {min: 143, max: 571}
{x: 0, y: 1, z: 2, w: 3} {min: 113, max: 481}
{x: 0, y: 3, z: 2, w: 3} {min: 322, max: 991}
{x: 0, y: 3, z: 2, w: 3} {min: 321, max: 989}
``````

(We're about to write to `x`, so it's been cleared to `0`, as discussed above.)

In a single pass, we can collapse this into:

``````
{x: 0, y: 1, z: 2, w: 3} {min: 113, max: 571}
{x: 0, y: 3, z: 2, w: 3} {min: 321, max: 991}
``````

Sorting is a call to `sort_unstable_by_key`, using the `[x, y, z, w]` state as the key.

Here's what this looks like in our program:

``````rust
match op {
    "inp" => {
        // Clear the register that's about to be written
        state.iter_mut().for_each(|k| k.0[ra] = 0);

        // Sort by register state
        state.sort_unstable_by_key(|k| k.0);

        // Do single-pass compaction
        let mut i = 0;
        let mut j = 1;
        while j < state.len() {
            if state[i].0 == state[j].0 {
                let (imin, imax) = state[i].1;
                let (jmin, jmax) = state[j].1;
                state[i].1 = (imin.min(jmin), imax.max(jmax));
            } else {
                i += 1;
                state[i] = state[j];
            }
            j += 1;
        }
        assert!(i < state.len());
        state.resize(i + 1, ([0; 4], (0, 0)));
``````

This brings running time down to 17.8 seconds.
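The two-index loop can also be written with the standard library's `Vec::dedup_by`, which removes consecutive elements whenever a closure returns `true`. Here's a sketch. The `compact` function and `State` alias are illustrative names, and the closure relies on `dedup_by` passing the later element first and the retained earlier element second:

```rust
/// (registers, (smallest input, largest input)), as in the listings above.
type State = ([i64; 4], (usize, usize));

/// Sort by register values, then merge each run of equal registers into
/// a single state that keeps the smallest `min` and the largest `max`.
fn compact(states: &mut Vec<State>) {
    states.sort_unstable_by_key(|s| s.0);
    // `dedup_by` hands the closure (later, earlier) neighbours; returning
    // `true` removes `later`, so its bounds are folded into `earlier` first.
    states.dedup_by(|later, earlier| {
        if later.0 != earlier.0 {
            return false;
        }
        let (lmin, lmax) = later.1;
        let (emin, emax) = earlier.1;
        earlier.1 = (emin.min(lmin), emax.max(lmax));
        true
    });
}

fn main() {
    // The example list from above (x already cleared to 0).
    let mut states: Vec<State> = vec![
        ([0, 1, 2, 3], (123, 531)),
        ([0, 1, 2, 3], (143, 571)),
        ([0, 1, 2, 3], (113, 481)),
        ([0, 3, 2, 3], (322, 991)),
        ([0, 3, 2, 3], (321, 989)),
    ];
    compact(&mut states);
    println!("{:?}", states); // [([0, 1, 2, 3], (113, 571)), ([0, 3, 2, 3], (321, 991))]
}
```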

### Return of the codegen

We're now running the following loop:

• For each `inp` instruction, deduplicate then expand states
• For every other instruction, evaluate it on every existing state

It turns out we can fuse these two steps into one:

• For each `inp` instruction, deduplicate then expand states; then evaluate every instruction up to the next `inp` on every existing state.

This is a subtle change in perspective, but it means that we can use code generation again! Instead of generating a single function for all of MONAD, we'll generate one function per block, where a block runs from one `inp` up to (but not including) the next `inp`.
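Finding the block boundaries is the easy part. Here's a sketch of that step. `split_blocks` is a hypothetical name, and it assumes the program arrives as one instruction per line:

```rust
/// Split a program into blocks: each block starts at an `inp` and runs up
/// to, but not including, the next `inp`. Any instructions before the first
/// `inp` form a leading block of their own.
fn split_blocks(program: &str) -> Vec<Vec<&str>> {
    let mut blocks: Vec<Vec<&str>> = Vec::new();
    for line in program.lines().map(str::trim).filter(|l| !l.is_empty()) {
        let is_inp = line.split_whitespace().next() == Some("inp");
        if is_inp || blocks.is_empty() {
            blocks.push(Vec::new());
        }
        // `blocks` is non-empty here, so `last_mut` always succeeds.
        blocks.last_mut().expect("at least one block").push(line);
    }
    blocks
}

fn main() {
    let program = "inp w\nadd x 1\ninp w\nmul x 0\nadd z w\n";
    for (i, block) in split_blocks(program).iter().enumerate() {
        println!("block{}: {:?}", i, block);
    }
}
```

Each resulting block is then translated into its own function.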

For example, the first block is:

``````rust
pub fn block0(regs: [i64; 4], inp: u8) -> [i64; 4] {
    let [mut x, mut y, mut z, mut w] = regs;
    let _ = (x, y, z, w);
    w = inp as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 1; // div z 1
    x = x + 11; // add x 11
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 6; // add y 6
    y = y * x; // mul y x
    z = z + y; // add z y
    return [x, y, z, w];
}
``````

The code generation also produces two arrays, each indexed by block: one holding the block functions, and one holding the register that each block's `inp` writes to:

``````rust
const BLOCKS: [fn(Registers, u8) -> Registers; 14] = [
    block0,
    block1,
    block2,
    // ... etc ...
];

const INPUTS: [usize; 14] = [
    3,
    3,
    3,
    // ... etc ...
];
``````

(The MONAD program always sends inputs to register `w`, but our solution doesn't rely on that behavior)

With these functions available, the main loop removes the interpreter entirely, and simply calls them one by one:

``````rust
for (f, r) in BLOCKS.iter().zip(INPUTS) {
    // ... same deduplication logic as above...
    state = (1..=9)
        .flat_map(|i| {
            state.iter().map(move |(regs, (min, max))| {
                let min = min * 10 + i as usize;
                let max = max * 10 + i as usize;
                (f(*regs, i), (min, max))
            })
        })
        .collect();
}
``````

This runs in 12.4 seconds.

### Throw some threads at it

Rayon makes it very easy to throw parallelism at your problems:

• `iter` becomes `par_iter`
• `iter_mut` becomes `par_iter_mut`
• `sort_unstable_by_key` becomes `par_sort_unstable_by_key`

(plus a few more minor tweaks)

This brings our final running time down to 4.2 seconds, another 3× speedup. My machine has 10 cores (8 performance and 2 efficiency), so it's not quite perfect scaling, but it's still an easy win.
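Rayon hides the work-splitting entirely. To show roughly what that splitting involves, here's a std-only sketch of the state-expansion step using `std::thread::scope`, with one contiguous chunk of states per thread. This is a swapped-in technique (static chunking, not Rayon's work-stealing) and not the code behind the 4.2 second result; `expand_parallel` and `demo_block` are hypothetical names:

```rust
use std::thread;

/// (registers, (smallest input, largest input)), as in the listings above.
type State = ([i64; 4], (usize, usize));

/// Hypothetical stand-in for a generated block function like `block0`.
fn demo_block(mut regs: [i64; 4], inp: u8) -> [i64; 4] {
    regs[2] += i64::from(inp);
    regs
}

/// Expand every state with each digit 1..=9 through block function `f`,
/// splitting the states into one contiguous chunk per available thread.
fn expand_parallel(states: &[State], f: fn([i64; 4], u8) -> [i64; 4]) -> Vec<State> {
    let threads = thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
    // Round up so every state lands in some chunk; never use a chunk size of 0.
    let chunk = ((states.len() + threads - 1) / threads).max(1);
    thread::scope(|s| {
        let handles: Vec<_> = states
            .chunks(chunk)
            .map(|part| {
                s.spawn(move || {
                    let mut out = Vec::with_capacity(part.len() * 9);
                    for &(regs, (min, max)) in part {
                        for i in 1..=9u8 {
                            let d = usize::from(i);
                            out.push((f(regs, i), (min * 10 + d, max * 10 + d)));
                        }
                    }
                    out
                })
            })
            .collect();
        // Join in spawn order, concatenating each thread's results.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect::<Vec<State>>()
    })
}

fn main() {
    let states: Vec<State> = (0..1000).map(|k| ([k, 0, 0, 0], (0, 0))).collect();
    let next = expand_parallel(&states, demo_block);
    println!("{} -> {} states", states.len(), next.len());
}
```

Static chunks are fine here because every state costs about the same to expand. Rayon's work-stealing matters more when the work is uneven, and it also provides the parallel sort.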

At this point, I'll declare victory: this is fast enough.

### Conclusions

Here's a list of each optimization and the incremental speedup:

| Version | Runtime | Speedup (vs. previous row) |
| --- | --- | --- |
| Very dumb brute force | 3.58 years | -- |
| Very dumb brute force with codegen | 17 days | 74× |
| State deduplication (every instruction) | 379 seconds | 3875× |
| State deduplication (on `inp`) | 30.2 seconds | 12.5× |
| Smarter deduplication | 21.8 seconds | 1.4× |
| In-place deduplication | 17.8 seconds | 1.2× |
| Return of the codegen | 12.4 seconds | 1.4× |
| Parallelism | 4.2 seconds | 3× |

Overall, we see a 26,880,685× improvement from the initial brute-force solution, without loss of generality!
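The overall figure is the ratio of the first and last runtimes in the table, with 3.58 years converted to seconds using 365-day years:

$$
\frac{3.58 \times 365 \times 86\,400\ \text{s}}{4.2\ \text{s}} = \frac{112\,898\,880\ \text{s}}{4.2\ \text{s}} \approx 26\,880\,685.7
$$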

The final code lives on GitHub:

Did I miss any more low-hanging fruit? Let me know via email or Twitter!

### Appendix I: Things that didn't work

• Using a better hash table for faster deduplication. In my testing, `par_sort_unstable_by_key` and single-pass deduplication are hard to beat, even when I try to use Rayon's `fold` / `reduce` to build and merge hash maps in parallel. It's possible that someone else could make this work; one of my coworkers wrote a single-threaded version that runs in 7.7 seconds using hashbrown, which is very close!
• Tracking register values using interval arithmetic; this is actually how I solved the problem day-of, but it only worked due to a bug in my program. Here's someone else's interval-based code, which presumably works, but I haven't dug into it to see exactly what they're doing.
• Using `i32` instead of `i64` for register state: this speeds up evaluation (down to 3.18 seconds) and gets the same answer, but running in debug mode reveals that integers are overflowing all over the place. Since I'm trying to write a general-purpose solution, that's a hard no.

### Appendix II: Things that I didn't try

• Using an explicitly-concurrent hashmap for deduplication, e.g. DashMap.
• Doing something on the GPU (??)
• Using a SAT/SMT solver (though folks on Reddit succeeded with this strategy).
• Fine-tuning where exactly to do deduplication for peak performance; doing it at each `inp` was a good-enough solution, and tuning it further felt like overfitting to the input.