Tuesday, February 14, 2012

Continuations in C++ with fork

[Update, Jan 2015: I've translated this code into Rust.]

While reading "Continuations in C" I came across an intriguing idea:

It is possible to simulate call/cc, or something like it, on Unix systems with system calls like fork() that literally duplicate the running process.

The author sets this idea aside, and instead discusses some code that uses setjmp/longjmp and stack copying. And there are several other continuation-like constructs available for C, such as POSIX getcontext. But the idea of implementing call/cc with fork stuck with me, if only for its amusement value. I'd seen fork used for computing with probability distributions, but I couldn't find an implementation of call/cc itself. So I decided to give it a shot, using my favorite esolang, C++.

Continuations are a famously mind-bending idea, and this article doesn't totally explain what they are or what they're good for. If you aren't familiar with continuations, you might catch on from the examples, or you might want to consult another source first (1, 2, 3, 4, 5, 6).

Small examples

I'll get to the implementation later, but right now let's see what these fork-based continuations can do. The interface looks like this.

template <typename T>
class cont {
public:
    void operator()(const T &x);
};

template <typename T>
T call_cc( std::function< T (cont<T>) > f );

std::function is a wrapper that can hold function-like values, such as function objects or C-style function pointers. So call_cc<T> will accept any function-like value that takes an argument of type cont<T> and returns a value of type T. This wrapper is the first of several C++11 features we'll use.

call_cc stands for "call with current continuation", and that's exactly what it does. call_cc(f) will call f, and return whatever f returns. The interesting part is that it passes to f an instance of our cont class, which represents all the stuff that's going to happen in the program after f returns. That cont object overloads operator() and so can be called like a function. If it's called with some argument x, the program behaves as though f had returned x.

The types reflect this usage. The type parameter T in cont<T> is the return type of the function passed to call_cc. It's also the type of values accepted by cont<T>::operator().

Here's a small example.

int f(cont<int> k) {
    std::cout << "f called" << std::endl;
    k(1);
    std::cout << "k returns" << std::endl;
    return 0;
}

int main() {
    std::cout << "f returns " << call_cc<int>(f) << std::endl;
}

When we run this code we get:

f called
f returns 1

We don't see the "k returns" message. Instead, calling k(1) bails out of f early, and forces it to return 1. This would happen even if we passed k to some deeply nested function call, and invoked it there.

This nonlocal return is kind of like throwing an exception, and is not that surprising. More exciting things happen if a continuation outlives the function call it came from.
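To make the analogy concrete, here is the same early return written with an exception (a sketch of mine, not the post's code; the names early_return, f_throws, and run are made up for illustration):

```cpp
#include <iostream>

// Thrown to force an early, nonlocal return out of f_throws.
struct early_return { int value; };

int f_throws() {
    std::cout << "f called" << std::endl;
    throw early_return{1};   // like k(1): control leaves immediately
    // the "k returns" line would go here, and is never reached
}

int run() {
    try {
        return f_throws();
    } catch (const early_return &e) {
        return e.value;      // the caller observes 1, as with call_cc
    }
}
```

run() prints "f called" and evaluates to 1, mirroring the call_cc example; the difference is that an exception can only propagate up the stack, while a continuation can also be stored and invoked later.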

boost::optional< cont<int> > global_k;

int g(cont<int> k) {
    std::cout << "g called" << std::endl;
    global_k = k;
    return 0;
}

int main() {
    std::cout << "g returns " << call_cc<int>(g) << std::endl;

    if (global_k)
        (*global_k)(1);
}

When we run this, we get:

g called
g returns 0
g returns 1

g is called once, and returns twice! When called, g saves the current continuation in a global variable. After g returns, main calls that continuation, and g returns again with a different value.

What value should global_k have before g is called? There's no such thing as a "default" or "uninitialized" cont<T>. We solve this problem by wrapping it with boost::optional. We use the resulting object much like a pointer, checking for "null" and then dereferencing. The difference is that boost::optional manages storage for the underlying value, if any.
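For illustration, here is the same pattern with C++17's std::optional, the standardized descendant of boost::optional (the demo function is mine; an int stands in for cont<int>):

```cpp
#include <optional>
#include <string>

// The "nullable slot" pattern used for global_k, sketched with
// std::optional. An int stands in for the cont<int> value.
std::string demo_optional() {
    std::optional<int> k;            // no "default" value: starts empty
    std::string out;
    out += k ? "full" : "empty";     // check for "null", like a pointer
    k = 1;                           // the optional manages the storage
    if (k) out += *k ? ",1" : ",0";  // dereference to use the value
    return out;
}
```

demo_optional() returns "empty,1": the slot reports empty before assignment and yields its value after.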

Why isn't this code an infinite loop? Because invoking a cont<T> also resets global state to the values it had when the continuation was captured. The second time g returns, global_k has been reset to the "null" optional value. This is unlike Scheme's call/cc and most other continuation systems. It turns out to be a serious limitation, though it's sometimes convenient. The reason for this behavior is that invoking a continuation is implemented as a transfer of control to another process. More on that later.
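The snapshot behavior is easy to demonstrate directly (a minimal sketch of mine, not the post's code: global_state and the helper function are made-up names):

```cpp
#include <sys/wait.h>
#include <unistd.h>

int global_state = 0;

// A child forked earlier keeps the snapshot of globals taken at fork
// time; the parent's later writes are invisible to it. Returns the
// value the child observed.
int child_view_after_parent_write() {
    int fd[2];
    pipe(fd);
    if (fork() == 0) {
        // child: report our copy of global_state and exit
        write(fd[1], &global_state, sizeof global_state);
        _exit(0);
    }
    global_state = 42;                // modifies only the parent's copy
    int seen = -1;
    read(fd[0], &seen, sizeof seen);  // blocks until the child reports
    wait(nullptr);
    close(fd[0]); close(fd[1]);
    return seen;
}
```

The function returns 0 even though the parent set global_state to 42 before reading: each process owns a separate copy of memory after fork.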

Backtracking

We can use continuations to implement backtracking, as found in logic programming languages. Here is a suitable interface.

bool guess();
void fail();

We will use guess as though it has a magical ability to predict the future. We assume it will only return true if doing so results in a program that never calls fail. Here is the implementation.

boost::optional< cont<bool> > checkpoint;

bool guess() {
    return call_cc<bool>( [](cont<bool> k) {
        checkpoint = k;
        return true;
    } );
}

void fail() {
    if (checkpoint) {
        (*checkpoint)(false);
    } else {
        std::cerr << "Nothing to be done." << std::endl;
        exit(1);
    }
}

guess invokes call_cc on a lambda expression, which saves the current continuation and returns true. A subsequent call to fail will invoke this continuation, retrying execution in a world where guess had returned false instead. In Scheme et al, we would store a whole stack of continuations. But invoking our cont<bool> resets global state, including the checkpoint variable itself, so we only need to explicitly track the most recent continuation.

Now we can implement the integer factoring example from "Continuations in C".

int integer(int m, int n) {
    for (int i = m; i <= n; i++) {
        if (guess())
            return i;
    }
    fail();  // never returns: it either backtracks or exits
}

void factor(int n) {
    const int i = integer(2, 100);
    const int j = integer(2, 100);

    if (i*j != n)
        fail();

    std::cout << i << " * " << j << " = " << n << std::endl;
}

factor(n) will guess two integers, and fail if their product is not n. Calling factor(391) will produce the output

17 * 23 = 391

after a moment's delay. In fact, you might see this after your shell prompt has returned, because the output is produced by a thousand-generation descendant of the process your shell created.

Solving a maze

For a more substantial use of backtracking, let's solve a maze.

const int maze_size = 15;
char maze[] =
"X-------------+\n"
" | |\n"
"|--+ | | | |\n"
"| | | | --+ |\n"
"| | | |\n"
"|-+---+--+- | |\n"
"| | | |\n"
"| | | ---+-+- |\n"
"| | | |\n"
"| +-+-+--| |\n"
"| | | |--- |\n"
"| | |\n"
"|--- -+-------|\n"
"| \n"
"+------------- \n";

void solve_maze() {
    int x = 0, y = 0;

    while ((x != maze_size-1)
        || (y != maze_size-1)) {

        if (guess()) x++;
        else if (guess()) x--;
        else if (guess()) y++;
        else y--;

        if ( (x < 0) || (x >= maze_size) ||
             (y < 0) || (y >= maze_size) )
            fail();

        const int i = y*(maze_size+1) + x;
        if (maze[i] != ' ')
            fail();
        maze[i] = 'X';
    }

    for (char c : maze) {
        if (c == 'X')
            std::cout << "\e[1;32mX\e[0m";
        else
            std::cout << c;
    }
}

Whether code or prose, the algorithm is pretty simple. Start at the upper-left corner. As long as we haven't reached the lower-right corner, guess a direction to move. Fail if we go off the edge, run into a wall, or find ourselves on a square we already visited.

Once we've reached the goal, we iterate over the char array and print it out with some rad ANSI color codes.

Once again, we're making good use of the fact that our continuations reset global state. That's why we see 'X' marks not on the failed detours, but only on a successful path through the maze. Here's what it looks like.


X-------------+
XXXXXXXX|     |
|--+  |X|   | |
|  |  |X| --+ |
|     |XXXXX| |
|-+---+--+-X| |
| |XXX   | XXX|
| |X|X---+-+-X|
|XXX|XXXXXX|XX|
|X+-+-+--|XXX |
|X|   |  |--- |
|XXXX |       |
|---X-+-------|
|   XXXXXXXXXXX
+-------------X

Excess backtracking

We can run both examples in a single program.

int main() {
    factor(391);
    solve_maze();
}

If we change the maze to be unsolvable, we'll get:

17 * 23 = 391
23 * 17 = 391
Nothing to be done.

Factoring 391 a different way won't change the maze layout, but the program doesn't know that. We can add a cut primitive to eliminate unwanted backtracking.

void cut() {
    checkpoint = boost::none;
}

int main() {
    factor(391);
    cut();
    solve_maze();
}

The implementation

For such a crazy idea, the code to implement call_cc with fork is actually pretty reasonable. Here's the core of it.

template <typename T>
// static
T cont<T>::call_cc(call_cc_arg f) {
    int fd[2];
    pipe(fd);
    int read_fd = fd[0];
    int write_fd = fd[1];

    if (fork()) {
        // parent
        close(read_fd);
        return f( cont<T>(write_fd) );
    } else {
        // child
        close(write_fd);
        char buf[sizeof(T)];
        if (read(read_fd, buf, sizeof(T)) < ssize_t(sizeof(T)))
            exit(0);
        close(read_fd);
        return *reinterpret_cast<T*>(buf);
    }
}

template <typename T>
void cont<T>::impl::invoke(const T &x) {
    write(m_pipe, &x, sizeof(T));
    exit(0);
}

To capture a continuation, we fork the process. The resulting processes share a pipe which was created before the fork. The parent process will call f immediately, passing a cont<T> object that holds onto the write end of this pipe. If that continuation is invoked with some argument x, the parent process will send x down the pipe and then exit. The child process wakes up from its read call, and returns x from call_cc.

There are a few more implementation details.

  • If the parent process exits, it will close the write end of the pipe, and the child's read will return 0, i.e. end-of-file. This prevents a buildup of unused continuation processes. But what if the parent deletes the last copy of some cont<T>, yet keeps running? We'd like to kill the corresponding child process immediately.

    This sounds like a use for a reference-counted smart pointer, but we want to hide this detail from the user. So we split off a private implementation class, cont<T>::impl, with a destructor that calls close. The user-facing class cont<T> holds a std::shared_ptr to a cont<T>::impl. And cont<T>::operator() simply calls cont<T>::impl::invoke through this pointer.

  • It would be nice to tell the compiler that cont<T>::operator() won't return, to avoid warnings like "control reaches end of non-void function". GCC provides the noreturn attribute for this purpose.

  • We want the cont<T> constructor to be private, so we had to make call_cc a static member function of that class. But the examples above use a free function call_cc<T>. It's easiest to implement the latter as a 1-line function that calls the former. The alternative is to make it a friend function of cont<T>, which requires some forward declarations and other noise.
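Putting the reference-counting split in code, here's a sketch of the shape described above (mine, not the post's exact source; the constructor is public here purely so the sketch stands alone, and call_cc is omitted):

```cpp
#include <cstdlib>
#include <memory>
#include <unistd.h>

// cont<T> is a cheap, copyable handle; cont<T>::impl owns the pipe fd.
// When the last copy of a cont<T> is destroyed, ~impl closes the write
// end, so the waiting child process reads EOF and exits.
template <typename T>
class cont {
    struct impl {
        int m_pipe;
        explicit impl(int fd) : m_pipe(fd) {}
        ~impl() { close(m_pipe); }   // last copy gone: child sees EOF
        void invoke(const T &x) {
            write(m_pipe, &x, sizeof(T));
            exit(0);
        }
    };
    std::shared_ptr<impl> m_impl;
public:
    explicit cont(int fd) : m_impl(std::make_shared<impl>(fd)) {}
    void operator()(const T &x) { m_impl->invoke(x); }
};
```

Copying a cont<T> just bumps the shared_ptr count; the fd is closed exactly once, however many copies were made.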

There are a number of limitations too.

  • As noted, the forked child process doesn't see changes to the parent's global state. This precludes some interesting uses of continuations, like implementing coroutines. In fact, I had trouble coming up with any application other than backtracking. You could work around this limitation with shared memory, but it seemed like too much hassle.

  • Each captured continuation can only be invoked once. This is easiest to observe if the code using continuations also invokes fork directly. It could possibly be fixed with additional forking inside call_cc.

  • Calling a continuation sends the argument through a pipe using a naive byte-for-byte copy. So the argument needs to be Plain Old Data, and had better not contain pointers to anything not shared by the two processes. This means we can't send continuations through other continuations, sad to say.

  • I left out the error handling you would expect in serious code, because this is anything but.

  • Likewise, I'm assuming that a single write and read will suffice to send the value. Robust code will need to loop until completion, handle EINTR, etc. Or use some higher-level IPC mechanism.

  • At some size, stack-allocating the receive buffer will become a problem.

  • It's slow. Well, actually, I'm impressed with the speed of fork on Linux. My machine solves both backtracking problems in about a second, forking about 2000 processes along the way. You can speed it up more with static linking. But it's still far more overhead than the alternatives.
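For the curious, the robust send loop alluded to above would look something like this (a sketch, not part of the original code; full_write is a made-up helper name):

```cpp
#include <cerrno>
#include <cstddef>
#include <unistd.h>

// Write all of buf, retrying on short writes and restarting on EINTR.
bool full_write(int fd, const void *buf, size_t len) {
    const char *p = static_cast<const char *>(buf);
    while (len > 0) {
        ssize_t n = write(fd, p, len);
        if (n < 0) {
            if (errno == EINTR) continue;  // interrupted by a signal: retry
            return false;                  // genuine error
        }
        p += n;
        len -= static_cast<size_t>(n);
    }
    return true;
}
```

A matching full_read loop would also treat a return of 0 (EOF) as a stopping condition, which is exactly the signal the child uses to detect a vanished parent.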

As usual, you can get the code from GitHub.

Thursday, February 2, 2012

Generating random functions

How can we pick a random Haskell function? Specifically, we want to write an IO action

randomFunction :: IO (Integer -> Bool)

with this behavior:

  • It produces a function of type Integer -> Bool.

  • It always produces a total function — a function which never throws an exception or enters an infinite loop.

  • It is equally likely to produce any such function.

This is tricky, because there are infinitely many such functions (more on that later).

In another language we might produce something which looks like a function, but actually flips a coin on each new integer input. It would use mutable state to remember previous results, so that future calls will be consistent. But the Haskell type we gave for randomFunction forbids this approach. randomFunction uses IO effects to pick a random function, but the function it picks has access to neither coin flips nor mutable state.

Alternatively, we could build a lazy infinite data structure containing all the Bool answers we need. randomFunction could generate an infinite list of random Bools, and produce a function f which indexes into that list. But this indexing will be inefficient in space and time. If the user calls (f 10000000), we'll have to run 10,000,000 steps of the pseudo-random number generator, and build 10,000,000 list elements, before we can return a single Bool result.
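That list-based approach fits in a few lines (my sketch, handling only non-negative arguments for brevity):

```haskell
import System.Random (newStdGen, randoms)
import Data.List (genericIndex)

-- The naive approach: one lazy infinite list of random Bools, indexed
-- per call. Laziness memoizes the list, but looking up n still costs
-- O(n) time and space.
randomFunctionList :: IO (Integer -> Bool)
randomFunctionList = do
  g <- newStdGen
  let bools = randoms g :: [Bool]
  return (genericIndex bools)
```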

We can improve this considerably by using a different infinite data structure. Though our solution is pure functional code, we do end up relying on mutation — the implicit mutation by which lazy thunks become evaluated data.

The data structure

import System.Random
import Data.List ( genericIndex )

Our data structure is an infinite binary tree:

data Tree = Node Bool Tree Tree

We can interpret such a tree as a function from non-negative Integers to Bools. If the Integer argument is zero, the root node holds our Bool answer. Otherwise, we shift off the least-significant bit of the argument, and look at the left or right subtree depending on that bit.

get :: Tree -> (Integer -> Bool)
get (Node b _ _) 0 = b
get (Node _ x y) n =
    case divMod n 2 of
        (m, 0) -> get x m
        (m, _) -> get y m

Now we need to build a suitable tree, starting from a random number generator state. The standard System.Random module is not going to win any speed contests, but it does have one extremely nice property: it supports an operation

split :: StdGen -> (StdGen, StdGen)

The two generator states returned by split will (ideally) produce two independent streams of random values. We use split at each node of the infinite tree.

build :: StdGen -> Tree
build g0 =
    let (b, g1) = random g0
        (g2, g3) = split g1
    in Node b (build g2) (build g3)

This is a recursive function with no base case. Conceptually, it produces an infinite tree. Operationally, it produces a single Node constructor, whose fields are lazily-deferred computations. As get explores this notional infinite tree, new Nodes are created and randomness generated on demand.

get traverses one level per bit of its input integer. So looking up the integer n involves traversing and possibly creating O(log n) nodes. This suggests good space and time efficiency, though only testing will say for sure.

Now we have all the pieces to solve the original puzzle. We build two trees, one to handle positive numbers and another for negative numbers.

randomFunction :: IO (Integer -> Bool)
randomFunction = do
    neg <- build `fmap` newStdGen
    pos <- build `fmap` newStdGen
    let f n | n < 0     = get neg (-n)
            | otherwise = get pos n
    return f

Testing

Here's some code which helps us visualize one of these functions in the vicinity of zero:

test :: (Integer -> Bool) -> IO ()
test f = putStrLn $ map (char . f) [-40..40] where
    char False = ' '
    char True  = '-'

Now we can test randomFunction in GHCi:

λ> randomFunction >>= test
---- -   ---   -    - -   - --   - - -  -- --- -- --          - -- - - --  --- --
λ> randomFunction >>= test
-   ---- - - - -  - - -- -   -     ---  --- -- - --  -  --    - -  - - -  --   - 
λ> randomFunction >>= test
- ---  - - -  --  ---         -  --  -  -    -  -  - ---- - -  ---   -     -    -

Each result from randomFunction is indeed a function: it always gives the same output for a given input. This much should be clear from the fact that we haven't used any unsafe shenanigans. But we can also demonstrate it empirically:

λ> f <- randomFunction
λ> test f
-   -----  - -   -- - -   --- --  - -   - -   - -   -- - -   ---- - - - -  - --- 
λ> test f
-   -----  - -   -- - -   --- --  - -   - -   - -   -- - -   ---- - - - -  - --- 

Let's also test the speed on some very large arguments:

λ> :set +s
λ> f 10000000
True
(0.03 secs, 12648232 bytes)
λ> f (2^65536)
True
(1.10 secs, 569231584 bytes)
λ> f (2^65536)
True
(0.26 secs, 426068040 bytes)

The second call with 2^65536 is faster because the tree nodes already exist in memory. We can expect our tests to be faster yet if we compile with ghc -O rather than using GHCi's bytecode interpreter.

How many functions?

Assume we have infinite memory, so that Integers really can be unboundedly large. And let's ignore negative numbers, for simplicity. How many total functions of type Integer -> Bool are there?

Suppose we made an infinite list xs of all such functions. Now consider this definition:

diag :: [Integer -> Bool] -> (Integer -> Bool)
diag xs n = not $ genericIndex xs n n

For an argument n, diag xs looks at what the nth function of xs would return, and returns the opposite. This means the function diag xs differs from every function in our supposedly comprehensive list of functions. This contradiction shows that there are uncountably many total functions of type Integer -> Bool. It's closely related to Cantor's diagonal argument that the real numbers are uncountable.

But wait, there are only countably many Haskell programs! In fact, you can encode each one as a number. There may be uncountably many functions, but there are only a countable number of computable functions. So the proof breaks down if you restrict it to a real programming language like Haskell.

In that context, the existence of xs implies that there is some algorithm to enumerate the computable total functions. This is the assumption we ultimately contradict. The set of computable total functions is not recursively enumerable, even though it is countable. Intuitively, to produce a single element of this set, we would have to verify that the function halts on every input, which is impossible in the general case.

Now let's revisit randomFunction. Any function it produces is computable: the algorithm is a combination of the pseudo-random number procedure and our tree traversal. In this sense, randomFunction provides extremely poor randomness; it only selects values from a particular measure zero subset of its result type! But if you read the type constructor (->) as "computable function", as one should in a programming language, then randomFunction is closer to doing what it says it does.

Edit: See also Luke Palmer's recent article on this subject.

See also

The libraries data-memocombinators and MemoTrie use similar structures, not for building random functions but for memoizing existing ones.
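To see the connection, here is a sketch (my own, not the libraries' actual API) of the same tree shape used to memoize an existing function rather than store random answers:

```haskell
-- A polymorphic version of the Tree from above, storing precomputed
-- values of a function instead of random Bools.
data MemoTree a = MemoNode a (MemoTree a) (MemoTree a)

-- Same bit-by-bit traversal as get.
lookupTree :: MemoTree a -> Integer -> a
lookupTree (MemoNode v _ _) 0 = v
lookupTree (MemoNode _ x y) n =
    case divMod n 2 of
        (m, 0) -> lookupTree x m
        (m, _) -> lookupTree y m

-- Build the tree so that lookupTree (buildTree f) n == f n for n >= 0.
-- Each node tracks which global index it represents via the accumulated
-- bit transform g; laziness caches each f n on first demand.
buildTree :: (Integer -> a) -> MemoTree a
buildTree f = go id
  where
    go g = MemoNode (f (g 0))
                    (go (g . (2 *)))
                    (go (g . (\m -> 2 * m + 1)))

-- Example: Fibonacci memoized through the tree.
memoFib :: Integer -> Integer
memoFib = lookupTree table
  where
    table = buildTree fib
    fib 0 = 0
    fib 1 = 1
    fib n = memoFib (n - 1) + memoFib (n - 2)
```

Because memoFib is defined point-free, table is evaluated once and shared across calls, so memoFib 50 runs in linear time instead of exponential.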

You can download this post as a Literate Haskell file and play with the code.