Friday, April 27, 2007

Hello Scala: Reservoir Sampling

I've started to climb the rather steep learning curve for Scala. The documentation for Scala is mostly fairly high level, and there aren't as many code samples as I would have liked. I'm hoping to help out with that problem here. I've selected what I think is a fairly interesting problem and written Scala code to solve it in three different ways. I had a lot of fun doing it, and I think the code really shows how flexible this language is. I actually wrote it in another dozen ways, but these are the ones that turned out best (some of the others were really horrendous). Hopefully some of you will find this helpful in your own efforts to learn Scala.

I have interviewed nearly four hundred software engineers in the last four years. As a result I have developed an unhealthy fascination with interview questions, and I have a lot of opinions on their relative merit. A friend who has been out interviewing sent me this one recently:

Write a program to select k random lines (with uniform probability) from a file in one pass. You do not know n (the total number of lines) ahead of time.

This is a tough problem if you've never heard it before (or don't happen to be quite good with probability). It took me about twenty minutes to come up with a correct algorithm (along with a half dozen incorrect ones). The correct algorithm had a righteous feeling to it, but I had to enlist the help of a friend (with a degree in math) to help me prove it. I don't think this is a particularly fair question to ask in an interview, unless you are willing to offer a lot of help. A good interview question has an involved algorithm, but an easy way to check its correctness. This is the opposite: the algorithm is fairly simple, but checking its correctness is time consuming.

It turns out that there are a number of correct algorithms for this problem, and some of them run in asymptotic time less than O(n)! They are known as "reservoir sampling" algorithms, and there have been a number of papers written in the area. The earliest one that I could find was by Jeffrey Scott Vitter and was published in the ACM Transactions on Mathematical Software, Vol. 11, No. 1, March 1985, Pages 37-57. It lists four algorithms, which it refers to as R, X, Y and Z. The average CPU time for Z is O(k(1 + log(n/k))). In practice that probably isn't much better than the O(n) for algorithm R when reading from a file, since the I/O time is likely to dominate. In a situation where you are reading a stream of data from memory, or a network card, it could be a big savings, however.

This problem seemed custom designed for a first project in Scala. It includes a lot of different features: file I/O, simple mathematics, conditional logic and some iterative processing. It is also quite interesting, and likely to be very useful for things like log parsing and file sampling (as input to unit tests, for instance). I decided to implement algorithm R since it was the most straightforward.

The intuition behind algorithm R is something like the following. Any correct solution must return min(k, n) lines. In addition, all n lines in the file must have an identical probability of being selected: k / n. The first constraint can be satisfied by just choosing the first k lines with probability one. Each line after the k'th should be selected with probability k / i, where i is the line's number in the file, counting from 1. When i == n, the last line has been selected with exactly the required probability, k / n. Whenever we select the i'th line, we choose a line uniformly at random from the already selected set for it to replace.
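
The claim that every line ends up in the sample with probability k / n can be checked by induction on the number of lines read. What follows is only a sketch of that argument, assuming n >= k, and not the proof from Vitter's paper. The symbol p_j is notation introduced here for the probability that any given one of the first j lines is in the sample after j lines have been read.

```latex
% Lines are numbered 1..n; p_j = P(a given line i <= j is in the sample after j lines).
% Base case: after the first k lines, every one of them is in the sample.
p_k = 1 = \frac{k}{k}

% Step j > k: line j enters with probability k/j and evicts one of the k slots
% uniformly at random, so a line already in the sample survives with probability
1 - \frac{k}{j} \cdot \frac{1}{k} = \frac{j-1}{j}

% Inductive step: lines 1..j-1 were each present with probability k/(j-1), hence
p_j = \frac{k}{j-1} \cdot \frac{j-1}{j} = \frac{k}{j}

% The new line j is also present with probability k/j, so after the last line (j = n)
p_n = \frac{k}{n}
```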

I started by writing the algorithm using standard imperative constructs. Using this approach with Scala produces code that looks a lot like Java:

import java.io._
import Console._
import Math._

object ReservoirSampleImperative {
  def algorithmR(readers: Iterator[BufferedReader], k: Int) : Array[String] = {
      val a = new Array[String](k)

      var t = 0

      while (readers.hasNext) {
          val reader = readers.next
          var line = reader.readLine
          while (line != null) {
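              // t lines have been read so far, so this is line t + 1:
              // once the array is full, keep it with probability k / (t + 1)
              if (t < k)
                  a(t) = line
              else if (random < (k.toDouble/(t + 1).toDouble))
                  a((random * k.toDouble).toInt) = line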

              t = t + 1
              line = reader.readLine
          }
      }

      a
  }

  def main(args: Array[String]) : Unit = {
      val iter = args.elements
      val k = iter.next.toInt

      // k has already been consumed; any remaining arguments are file names
      val readers = if (iter.hasNext)
          iter.map (f => new BufferedReader(
              new FileReader(f)))
      else
          Array[BufferedReader](new BufferedReader(
              new InputStreamReader(System.in))).elements

      val a = algorithmR(readers, k)
      for (val e <- a) println(e)
  }
}

This version actually has a lot going for it. Neither the Array class (which is where it gets its arguments) nor the Java I/O modules are very functional, so the OO/imperative approach works very well with them. Writing the program in a more functional style required that the I/O be converted to a Stream:

import java.io._
import Math._
import Console._

object ReservoirSampleFunct {
  def inputStream(i: Iterator[BufferedReader]) : Stream[String] = {
      def inputStreamString(b: BufferedReader) : Stream[String] = {
          val line = b.readLine()
          if (line == null)
              inputStream(i)
          else
              Stream.cons(line, inputStreamString(b))
      }

      if (i.hasNext)
          inputStreamString(i.next)
      else
          Stream.empty
  }

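  def algorithmR(k: Int)(p: (Int,Array[String]), s: String) : (Int, Array[String]) = {
      val (i,a) = p

      // i lines have been folded so far, so s is line i + 1:
      // once the array is full, keep it with probability k / (i + 1)
      if (i < k)
          a(i) = s
      else if (random < (k.toDouble/(i + 1).toDouble))
          a((random * k.toDouble).toInt) = s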
      (i+1, a)
  }

  def main(args: Array[String]) : Unit = {
      val i = args.elements
      val k = i.next.toInt

      val readers = if (i.hasNext)
          i.map (f => new BufferedReader(new FileReader(f)))
      else
          Array[BufferedReader](new BufferedReader(
              new InputStreamReader(System.in))).elements

      val (_,a) = inputStream(readers).foldLeft (
          (0, new Array[String](k))
      ) (algorithmR(k))

      for (val e <- a) println(e)
  }
}

There is still a lot of non-functional, side-effect producing code here. I'm using Array for both the command line arguments and the list of k sampled lines. I think that is the real power of Scala, actually. Both of those things are cases where having side effects really makes the programming easier. I could have written a helper function to replace one of the k lines in a List, but I think the resulting code would have been longer and no more understandable (in fact, probably less).
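
For comparison, here is a rough sketch of what a side-effect-free fold might look like. It is written against a current Scala release (2.13 or 3) rather than the 2007 syntax used in this post, and it relies on Vector.updated, which was not available back then. The names ReservoirSampleImmutable, State, step and sample are made up for this illustration, and the random number generator is passed in explicitly so that runs can be repeated with a fixed seed.

```scala
import scala.util.Random

object ReservoirSampleImmutable {
  // State threaded through the fold: how many lines have been seen,
  // and the current sample (hypothetical names, for illustration only).
  final case class State(seen: Int, sample: Vector[String])

  def step(k: Int, rng: Random)(st: State, line: String): State = {
    val n = st.seen + 1 // 1-based number of the line being processed
    if (st.sample.size < k)
      State(n, st.sample :+ line)                        // still filling the reservoir
    else if (rng.nextInt(n) < k)                         // true with probability k / n
      State(n, st.sample.updated(rng.nextInt(k), line))  // replace a uniformly chosen slot
    else
      st.copy(seen = n)                                  // line rejected, sample unchanged
  }

  def sample(lines: Iterator[String], k: Int, rng: Random): Vector[String] =
    lines.foldLeft(State(0, Vector.empty[String]))(step(k, rng)).sample
}
```

Something like ReservoirSampleImmutable.sample(scala.io.Source.fromFile(path).getLines(), 10, new Random()) would then draw ten lines from a file, where path is whatever file you want to sample. Each update copies part of the Vector instead of writing into an array slot, so whether this reads better than the Array version is a matter of taste.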

A good friend, who has been at this longer than I have, contributed this version of the code, which uses an implicit conversion. He notes that it probably isn't a great idea to do this in large scale production code. The proliferation of implicit conversions can lead to some truly bizarre behavior due to unintended conversions. Still, it does nicely simplify some of the code from my previous example:

import java.io._
import Console._
import Math._

object ReservoirSampleImplicit {
  implicit def readerToStream(r: BufferedReader) : Stream[String] = {
      val line = r.readLine()
      if (line == null)
          Stream.empty
      else
          Stream.cons(line, readerToStream(r))
  }

  def main(args: Array[String]) : Unit = {
      val k: Int = args(0).toInt
      val r = new BufferedReader(new FileReader(args(1)))

      def algorithmR(p: (Int,Array[String]), e: String) : (Int, Array[String]) = {
          val (i,a) = p

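          // i lines have been folded so far, so e is line i + 1:
          // once the array is full, keep it with probability k / (i + 1)
          if (i < k)
              a(i) = e
          else if (random < (k.toDouble/(i + 1).toDouble))
              a((random * k.toDouble).toInt) = e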

          (i+1,a)
      }

      val (_,rs) = r.foldLeft (
          (0, new Array[String](k))
      ) (algorithmR)

      for (val r <- rs) println(r)
  }
}
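
The warning about implicit conversions can be made concrete with a small experiment. This is a sketch of one way such a conversion might surprise you, not something taken from my friend's code. It targets a current Scala release, where LazyList plays the role Stream plays above, and the names ImplicitSurprise, readerToLines and firstTwoHeads are invented for the example.

```scala
import java.io.{BufferedReader, StringReader}
import scala.language.implicitConversions

object ImplicitSurprise {
  // Same idea as readerToStream above: the first line is read eagerly when
  // the conversion runs, and the rest of the input is read lazily.
  implicit def readerToLines(r: BufferedReader): LazyList[String] = {
    val line = r.readLine()
    if (line == null) LazyList.empty
    else line #:: readerToLines(r)
  }

  def firstTwoHeads(text: String): (String, String) = {
    val r = new BufferedReader(new StringReader(text))
    // BufferedReader has no head method, so each call below silently runs
    // a fresh conversion, and each conversion consumes another line.
    val a = r.head
    val b = r.head
    (a, b)
  }
}
```

Calling ImplicitSurprise.firstTwoHeads("one\ntwo\nthree") gives ("one", "two"): two expressions that look identical return different values, because the conversion has a side effect on the reader. Nothing at the call site hints that a conversion is happening at all.
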

That's it for now. I'm going to try to tackle some of Scala's Actor libraries for the next week or so. I'll post here if I come up with any interesting applications of them. If you know Scala and have a better way to solve the Reservoir Sampling problem, I would love to hear about it. Scala is a huge language with a lot of features, and I don't feel like I've even started to explore the space of possible ways to solve this problem using it.