4.16. Fold

We consider several concurrent implementations of the classic "list fold" function from functional programming, starting from its standard sequential definition:

def fold(_, [x])  = x
def fold(f, x:xs) = f(x, fold(f, xs))

This is a seedless fold (sometimes called fold1). It requires the list to be nonempty and uses the list's last element as the seed. This implementation is short-circuiting: it may finish early if the reduction operator f does not use its second argument. However, it is not concurrent, because no two calls to f can proceed in parallel. If f is associative, we can overcome this restriction and implement fold concurrently. If f is also commutative, we can increase concurrency further.
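As a hand-worked illustration, consider a hypothetical four-element list [a, b, c, d] (not part of the original example). Unfolding fold shows the right-nested shape of its result. Each call to f takes the result of the call nested inside it as its second argument, so the three calls (n-1 in general) form a single chain of dependencies:

$$
\begin{aligned}
\mathrm{fold}(f, [a, b, c, d]) &= f(a, \mathrm{fold}(f, [b, c, d])) \\
&= f(a, f(b, \mathrm{fold}(f, [c, d]))) \\
&= f(a, f(b, f(c, \mathrm{fold}(f, [d])))) \\
&= f(a, f(b, f(c, d)))
\end{aligned}
$$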

4.16.1. Associative Fold

We first consider the case where the reduction operator is associative. We define afold(f, xs), where f is a binary associative function and xs is a nonempty list. The implementation iteratively reduces xs to a single value. Each iteration applies the auxiliary function step, which reduces disjoint pairs of adjacent items. This roughly halves the length of xs, and an unpaired trailing item is carried over unchanged.

def afold(_, [x]) = x
def afold(f, xs) =
  def step([]) = []
  def step([x]) = [x]
  def step(x:y:xs) = f(x,y):step(xs)
  afold(f, step(xs))

Notice that f(x,y):step(xs) is an implicit fork-join. The call f(x,y) executes in parallel with the recursive call step(xs), and the resulting list is formed once both have published. As a result, all calls to f within a single iteration of afold execute concurrently. Only the iterations themselves are sequential.
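For illustration, here is a hand trace of afold on a hypothetical five-element list [a, b, c, d, e]. Each line is one iteration (one application of step), and all calls to f on the same line run concurrently:

$$
\begin{aligned}
[a, b, c, d, e] &\xrightarrow{\text{step}} [f(a, b),\ f(c, d),\ e] \\
&\xrightarrow{\text{step}} [f(f(a, b), f(c, d)),\ e] \\
&\xrightarrow{\text{step}} [f(f(f(a, b), f(c, d)), e)]
\end{aligned}
$$

The result f(f(f(a,b), f(c,d)), e) equals the sequential f(a, f(b, f(c, f(d, e)))) by associativity alone. Adjacent items are only ever paired in their original left-to-right order, so commutativity is not needed. In general, a list of n items takes ⌈log₂ n⌉ iterations (here 3) and n-1 calls to f in total (here 2 + 1 + 1 = 4).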

4.16.2. Associative, Commutative Fold

We can make the implementation even more concurrent when the fold operator is both associative and commutative. We define cfold(f, xs), where f is an associative and commutative binary function and xs is a nonempty list of n items. The implementation first uses the auxiliary function xfer to copy all list items into a channel in arbitrary order. xfer also counts the number of items copied. The auxiliary function combine then repeatedly takes pairs of items from the channel, reduces them, and puts the result back in the channel. Pairs are reduced in parallel as soon as they become available, in whatever order they arrive. Because f is associative and commutative, that order does not affect the result. Each reduction removes two items and adds one. After n-1 reductions, the single item left in the channel is the result of the overall fold.

def cfold(f, xs) =
  val c = Channel()
  
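  -- Put every item into c (in parallel) and publish the number of items.
  def xfer([])   = 0
  def xfer(x:xs) = c.put(x) >> stop | xfer(xs)+1

  -- Reduce the m items in c to one. Each pair is combined in parallel
  -- with the rest of the reduction.
  def combine(0) = stop
  def combine(1) = c.get()
  def combine(m) = c.get() >x> c.get() >y>
                   ( c.put(f(x,y)) >> stop | combine(m-1) )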

  xfer(xs) >n> combine(n)
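To see why commutativity is needed here, consider one possible schedule for a hypothetical four-element list [a, b, c, d]. The actual order depends on timing. Write x ⊕ y as shorthand for f(x, y). Suppose c and a are taken from the channel first, then b and d. The two partial results then arrive in the order f(b, d), f(c, a):

$$
f\bigl(f(b, d),\ f(c, a)\bigr) = (b \oplus d) \oplus (c \oplus a) = a \oplus b \oplus c \oplus d
$$

The last equality uses both associativity and commutativity. With associativity alone, items combined out of their list order could give a different result. The count argument holds for every schedule. Each call combine(m) performs one reduction and then continues as combine(m-1), so exactly n-1 calls to f are made (here 3) before combine(1) retrieves the result.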