What is the Convex Hull Trick? (2020-06-14)
In [1]:
# TITLE: What is the Convex Hull Trick?
# COVER: https://i.imgur.com/8Xb6NSb.png
# DATE: 2020-06-14
# TAGS: algorithms,convex hull trick
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from functools import lru_cache
%matplotlib inline

I had a chance to poke at a hard algorithm problem recently and failed miserably. This bothered me for quite some time, and then someone told me that the problem I was struggling with could be easily solved with something advanced I'd only heard of: the convex hull trick. So I decided to learn it. But it seems to me that every resource I could find on the internet assumes a reader who is already pretty knowledgeable about algorithms. So here's my version of it: working knowledge, if you will. If you want to skip the problem part and get straight to the explanation of the convex hull trick, please click here.

The problem that broke me: APIO10A Commando

The problem I was looking at was in Korean on a Korean website, but this is the same problem in English: Link

Basically, the gist of the problem is that you have an array of integers ($1 \leq x_i \leq 100$) like below:

In [2]:
x = (2, 2, 3, 4)
x
Out[2]:
(2, 2, 3, 4)

The number of integers can go up to 1,000,000, but the above example is just 4 integers.

Your goal here is to slice the above array so that the sum of the adjusted sums of the slices is the greatest. You adjust each slice's sum by plugging it into $f(x) = ax^2+bx+c$, where $a$, $b$, $c$ are given in the input. ($-5 \leq a \leq -1$, $|b| \leq 10,000,000$, $|c| \leq 10,000,000$)

Let $a=-1$, $b=10$, $c=-20$. The following happens to get the answer:

In [3]:
def get_all_combinations(x):
    # Enumerate every way to split x into contiguous, non-empty slices.
    if not len(x):
        return []
    rst = [[x]]  # keeping all of x as a single slice is one option
    for i in range(1, len(x) + 1):
        # split off the prefix x[:i], then recurse on the remainder
        for y in get_all_combinations(x[i:]):
            rst.append([x[:i]] + y)
    return rst

def get_adjusted_sum(xs):
    S = 0
    for x in xs:
        s = sum(x)
        S += a * s ** 2 + b * s + c
    return S
    
a, b, c = -1, 10, -20
df = pd.DataFrame(
    [
        (x_, get_adjusted_sum(x_)) 
        for x_ in sorted(get_all_combinations(tuple(x)), key=lambda x: len(x))
    ],
    columns=["Slice Combination", "Sum of Adjusted Sums"]
)

def hl(x):
    if x == [(2, 2), (3,), (4,)] or x == 9:
        return 'color: red'
    return 'color: unset'
df.style.applymap(hl)
Out[3]:
Slice Combination Sum of Adjusted Sums
0 [(2, 2, 3, 4)] -31
1 [(2,), (2, 3, 4)] -15
2 [(2, 2), (3, 4)] 5
3 [(2, 2, 3), (4,)] 5
4 [(2,), (2,), (3, 4)] -7
5 [(2,), (2, 3), (4,)] 5
6 [(2, 2), (3,), (4,)] 9
7 [(2,), (2,), (3,), (4,)] -3

You can see that the combination at Row 6 has the maximum value, 9. At first glance, this seems like an easy problem. Don't let the simple example above trick you; if the array gets longer, this becomes a non-trivial problem. If you'd like a challenge, I suggest you stop here and go poke at the problem on your own. It's fun!

You probably understood the problem by now. Let's solve this in a naive way together first.

The naive way

My go-to first approach is to write a recursive function. It gives me a better understanding of the problem because it lets me write things in a more stateless way (i.e. no state variables controlling if statements everywhere), which in turn helps me formulate the process as a math equation. And if you memoize the function, you can get pretty far with it, too.

Of course, we won't be using a recursive function for the final solution, since the stack size will be the limiting factor. But let's just go with it for now.

In [4]:
adjusted = lambda x: a * x ** 2 + b * x + c

@lru_cache(maxsize=1024)
def f(i):
    if i == 0:
        return 0
    return max(f(j) + adjusted(sum(x[j:i])) for j in range(i))

n = 4
a, b, c = -1, 10, -20
x = (2, 2, 3, 4)
%time f(n)
CPU times: user 19 µs, sys: 6 µs, total: 25 µs
Wall time: 27.7 µs
Out[4]:
9

The above recursive function translates into the following equation:

$$ \begin{align*} f(0) &= 0 \\ f(i) &= \max_{j<i}{\{f(j) + ay^2 + by + c\}} \text{, where } y = \sum_{k=j+1}^{i}{x_k} \end{align*} $$

This is bad because it means that we have to:

  1. Recursively call the memoized function for all previous $i$'s
  2. Take the maximum over all $j < i$
  3. Sum the integers for each candidate $j$ inside the $\max$ (i.e. calculate $y$)

This makes our solution's time complexity $O(n^3)$.

Normally this naive approach at least lets me barely pass most medium-difficulty, and sometimes even hard-difficulty, problems. But this time it didn't work even for 5,000 integers. (The last test case has 1,000,000 integers...)

If you run the above solution with 1,000 integers:

In [5]:
f.cache_clear() # clear memoized values
with open('data/commando_1000.txt') as file:
    n, a, b, c, *x = [int(x) for x in file.read().split()]

%time f(n)
CPU times: user 1.81 s, sys: 0 ns, total: 1.81 s
Wall time: 1.82 s
Out[5]:
1588599

It now takes ~2s vs. ~30µs ... Obviously, we need a better way.

The clever way

So the next approach I took was to pre-calculate the partial sums of $x$ and substitute $y$ with a term that uses them. Let's define the array of partial sums like below:

$$ \begin{align*} S_0 &= 0 \\ S_i &= S_{i-1} + x_i \end{align*} $$

If you have the above array pre-built, you can replace $y$ easily with the property demonstrated by the below code:

In [6]:
A = [1, 2, 3, 4, 5]
s = sum(A[3:5])                   # sum([4, 5])
s_prime = sum(A[:5]) - sum(A[:3]) # sum([1, 2, 3, 4, 5]) - sum([1, 2, 3])

s == s_prime
Out[6]:
True

The revised equation using the above is the following:

$$ \begin{align*} f(0) &= 0 \\ f(i) &= \max_{j<i}{\{f(j) + a(S_i - S_j)^2 + b(S_i - S_j) + c\}} \end{align*} $$

Now, the summation for each iteration is gone, which makes our solution's time complexity $O(n^2)$.

This was the part where I was able to get to without knowing the later discussed algorithm, and I felt pretty good!

In [7]:
@lru_cache(maxsize=1024)
def f(i):
    if i == 0:
        return 0
    return max(f(j) + adjusted(S[i] - S[j]) for j in range(i))

S = [0]
for x_i in x:
    S.append(S[-1] + x_i)
    
%time f(n)
CPU times: user 356 ms, sys: 985 µs, total: 357 ms
Wall time: 358 ms
Out[7]:
1588599

Now it's ~400ms! Not bad!

THE way

It's pretty good, but the above solution wasn't even close to handling 1,000,000 integers. I couldn't get it to produce a result no matter how long I waited... Now it's big brain time.

Big Brain Time

The convex hull trick

Before getting to how it works, let's first talk about what it solves. Simply speaking, at least the way I understand it, it's a way to find the maximum value at $x$ over a set of given linear functions in (amortized) constant time, if some conditions are met. If your function looks like the one below, you can plug this in to optimize it:

$$ f(i) = g(i) + \max_{j}{\{a_j x_i + b_j \}} \text{, where } a_j > 0 \text{ and } x \text{ is in ascending order} $$

The $a_j x_i + b_j$ part represents the given linear functions.

The idea of the convex hull trick is the following:

  • The convex hull trick is basically a data structure that ...
    • stores line functions as pairs of $(a, b)$, and
    • has two operations: add line and query
  • When you query for the value at $x$, it removes unneeded lines and gets the value from the line function that yields the maximum value.
    • In the bottom figure, the orange line is the line that gets evaluated at $x = 0$
    • When $x \geq 3$, the orange line gets deleted because we don't need that anymore.
      • In order to be able to do this, the order of queries made for $x_i$ needs to be in ascending order (e.g. query(x=1), query(x=2), ...). That's why $x$ has to be in ascending order.
    • It's basically always maintaining the line functions that return the maximum value in the given domain. If you draw only the maximum value of this collection of line functions, it looks like a lower hull of a convex hull. Hence the name, the convex hull trick.
  • When you add a new line, it compares the $x$ value of the intersection between the new line and the second-to-last existing line against the $x$ value of the intersection between the last existing line and that same second-to-last line. If the former is less than or equal to the latter, the last line gets removed.
    • This is how we check whether the last line we added can be removed in lieu of the new line we're adding.
    • We can do this because, with the condition $a > 0$ and the queries coming in ascending order (and, in this problem, the lines therefore also being added in ascending slope order), we're only maintaining the right side of a lower hull (the right side of the shape U). That lets us compare just the intersection points to decide a line isn't needed, instead of handling all the cases for the left side of the lower hull.

By implementing the above data structure, we only need to check a small number of lines (most of the time, 1 or 2) for each $\max$, instead of going through all of them. Pretty awesome, isn't it? In the figure below, at $x = 4$ we'd only keep the purple line instead of checking all of them. Constant time, baby!!

In [8]:
X = [1, 2, 3, 4, 5]
A = [1, 2, 4, 2, 9]
B = [10, 20, -10, 9, -8]

f = lambda a, x, b: a * x + b

L = list([f(a, x, b) for x in X] for a, b in zip(A, B))
for l in L:
    plt.plot(l)

If we implement the above description, we get the following code:

In [9]:
class ConvexHullTrick:
    def __init__(self, lines):
        # Lines are stored as (a, b) pairs, in ascending order of slope
        self.lines = lines

    def add(self, a, b):
        # Lines must be added in ascending order of slope a
        self.lines.append((a, b))
        while len(self.lines) >= 3:
            new = self.lines[-1]
            top = self.lines[-2]
            base = self.lines[-3]
            # Remove the previous top line if it's no longer part of the hull
            if (base[1] - new[1]) * (top[0] - base[0]) <= (base[1] - top[1]) * (new[0] - base[0]): # (1)
                self.lines.pop(-2)
            else:
                break

    def query(self, x):
        # Queries must be made with x in ascending order
        while len(self.lines) >= 2:
            a, b = self.lines[0]
            c, d = self.lines[1]
            # Drop the front line once the next one beats it at this x
            if a * x + b < c * x + d:
                self.lines.pop(0)
            else:
                break
        if self.lines:
            a, b = self.lines[0]
            return a * x + b
        return None
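Before digging into the details, here's a tiny usage sketch with made-up lines (my own toy numbers, not from the problem or the figure above), just to show the API. Remember: lines must be added in ascending order of slope, and queries must come with ascending $x$.

cht = ConvexHullTrick([(1, 0)])  # start with the line y = x
cht.add(2, -5)                   # y = 2x - 5
cht.add(3, -6)                   # y = 3x - 6 makes y = 2x - 5 redundant, so it gets pruned

print(cht.lines)     # [(1, 0), (3, -6)]
print(cht.query(2))  # 2 -> y = x is still the maximum for small x
print(cht.query(4))  # 6 -> y = 3x - 6 takes over; y = x is popped from the front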

Most of the code should make sense if you read it together with the idea of the convex hull trick I described above. One thing I didn't cover is the line marked (1). It basically compares intersections of pairs of lines; here's how I arrived at it:

Let $l$, $m$ be lines as pairs of $(a, b)$.

$$ \begin{align*} l_ax + l_b &= m_ax + m_b \\ (l_a - m_a)x &= m_b - l_b \\ x &= \frac{m_b - l_b}{l_a - m_a} \end{align*} $$

Now, taking the intersection of new with base and the intersection of top with base, the condition for the previous top line to be removed becomes this:

Let $l$, $m$, $n$ be new, top, base.

$$ \begin{align*} \frac{n_b - l_b}{l_a - n_a} &\leq \frac{n_b - m_b}{m_a - n_a} \end{align*} $$

Perfect! But the only problem with the above is that in computer programs, divisions produce float values, or worse, in languages like C they may become truncating integer divisions instead. Either way, the comparison loses precision, slightly or greatly. We can easily avoid this uncomfortable situation by multiplying both sides by $(l_a - n_a)(m_a - n_a)$. Since the lines are added in ascending order of slope, both factors are positive, so the direction of the inequality is preserved.

$$ \begin{align*} (n_b - l_b)(m_a - n_a) &\leq (n_b - m_b)(l_a - n_a) \end{align*} $$
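Just as a quick numerical sanity check (with my own toy lines, the same ones as in the usage sketch above), the cross-multiplied form agrees with the division form:

l = (3, -6)  # new
m = (2, -5)  # top
n = (1, 0)   # base

by_division   = (n[1] - l[1]) / (l[0] - n[0]) <= (n[1] - m[1]) / (m[0] - n[0])
by_cross_mult = (n[1] - l[1]) * (m[0] - n[0]) <= (n[1] - m[1]) * (l[0] - n[0])

by_division, by_cross_mult  # (True, True): the top line can be removed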

I made it as a class so someone can just copy it and use it. 🙄 But how do we use it?

How to plug in

You might be thinking: our equation doesn't look like a case where we can use the above! There are no linear functions in there!

Let's take a look again:

$$ \begin{align*} f(0) &= 0 \\ f(i) &= \max_{j<i}{\{f(j) + a(S_i - S_j)^2 + b(S_i - S_j) + c\}} \end{align*} $$

If you tinker with the above equation a little bit, you can turn the part inside the $\max$ into linear functions.

First, expand all the parentheses and exponentials:

$$ \begin{align*} f(i) &= \max_{j<i}{\{f(j) + aS_i^2 - 2aS_iS_j + aS_j^2 + bS_i - bS_j + c\}} \end{align*} $$

Then, take all the terms that don't involve $j$ out of the $\max$, since they don't need to be there:

$$ \begin{align*} &= \max_{j<i}{\{f(j) - 2 a S_i S_j + a S_j^2 - bS_j\}} + aS_i^2 + bS_i + c \end{align*} $$

And voilà! You should be able to make this look like linear functions:

$$ \begin{align*} &= \max_{j<i}{\{\color{red}{(-2aS_j)} S_i + \color{blue}{(aS_j^2 - bS_j + f(j))}\}} + aS_i^2 + bS_i + c \end{align*} $$

Now we have to make sure the special conditions can be met:

  1. $-2aS_j$ is always positive (for $j \geq 1$) because $-5 \leq a \leq -1$ is given by the problem definition and $S_j$ is a sum of positive integers. Since $S_j$ is increasing, the slopes are also added in ascending order.
  2. $S_i$ is in ascending order because the $S_i$ are partial sums of an array of positive integers ($1 \leq x_i \leq 100$).

Therefore, our solution can be optimized by the convex hull trick. (Well, by the simpler version I learned. Apparently, there are more advanced algorithms with different restrictions.)
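If you'd like to verify these two conditions empirically, here's a quick check against the same 1,000-integer input file used earlier:

with open('data/commando_1000.txt') as file:
    n, a, b, c, *x = [int(v) for v in file.read().split()]

S = [0]
for x_i in x:
    S.append(S[-1] + x_i)

# Slopes -2*a*S[j] are positive, and S is strictly increasing
all(-2 * a * S[j] > 0 for j in range(1, n + 1)), all(S[i] < S[i + 1] for i in range(n))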

Now, the function to optimize is defined:

$$ \begin{align*} f(i) = g(i) + \max_{j}{\{a_j x_i + b_j \}} \text{, where } & a_j = -2aS_j \\ & x_i = S_i \\ & b_j = aS_j^2 - bS_j + f(j) \\ & g(i) = aS_i^2 + bS_i + c \end{align*} $$

Let's try to solve it for 1,000 integers again first:

In [10]:
with open('data/commando_1000.txt') as file:
    n, a, b, c, *x = [int(x) for x in file.read().split()]

S = [0]
for x_i in x:
    S.append(S[-1] + x_i)
    
f   = [0]
a_j = lambda j: -2 * a * S[j]
b_j = lambda j: a * S[j] ** 2 - b * S[j] + f[j]
g   = lambda i: a * S[i] ** 2 + b * S[i] + c

def solve():
    convhull = ConvexHullTrick([(0, 0)])
    for i in range(1, n + 1):
        f.append(convhull.query(S[i]) + g(i))
        convhull.add(a_j(i), b_j(i))
    return f[-1]

%time solve()
CPU times: user 2.93 ms, sys: 0 ns, total: 2.93 ms
Wall time: 2.94 ms
Out[10]:
1588599

~400ms -> 2.94ms 💪💪💪💪💪💪

Let's run it with 1,000,000 integers now:

In [11]:
with open('data/commando_1000000.txt') as file:
    n, a, b, c, *x = [int(x) for x in file.read().split()]

f = [0]
S = [0]
for x_i in x:
    S.append(S[-1] + x_i)

%time solve()
CPU times: user 2.68 s, sys: 20.4 ms, total: 2.7 s
Wall time: 2.71 s
Out[11]:
2616046388

Magic

It ran, and it finished in 2.71s!

I hope you find this post helpful.