Sunday, December 7, 2014

Functional Programming in Mainstream Languages, Part 5: Higher-Order Functions in Scala

Scala is a relatively new language; it belongs to the bigger-is-better school of programming languages, and includes many kinds of object-oriented features as well as a functional subset.  Scala captured a lot of interest in the academic community, and for several years it featured in many papers in the leading conferences on object-oriented programming.

Scala attempts (among other things) to combine object-oriented and functional programming in a graceful way.  In this post, I will show how higher-order functions are expressed in Scala, using the same examples as in previous posts, and contrast them with the Java 8 implementations.

Scala compiles to Java bytecode, and is therefore compatible with Java code.  It extends the Java type system in ways that make it more consistent.  For example, all values in Scala are objects, so it doesn't have the artificial and annoying distinction between primitive types (int, double, ...) and classes (everything else, including Integer and Double).  Like Java, Scala is strongly typed, but contains extensive type-inference mechanisms so that programmers can avoid a lot of type specifications.
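As a small illustration of the point that all values are objects, here is a minimal sketch.  The object name ValuesAreObjects and the value names are mine, chosen only for this example.

```scala
object ValuesAreObjects {
  val n: Int = 42

  // An Int has methods, just like any other object.
  val asText: String = n.toString

  // Operators are ordinary method calls: 3 + 4 is shorthand for (3).+(4).
  val sum: Int = (3).+(4)

  // Int can be used directly as a type argument; Java needs List<Integer> here.
  val doubled: List[Int] = List(1, 2, 3).map(k => k * 2)
}
```

In Java, the last line would have to go through the wrapper class Integer (or a specialized interface such as IntStream), because int cannot be used as a type argument.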

Here is the simple example of the curried addition function:

Java 8
    IntFunction<IntUnaryOperator> makeAdder = x -> y -> x + y;
    IntUnaryOperator increment = makeAdder.apply(1);
    int res = increment.applyAsInt(4);

Scala
    def makeAdder = (x: Int) => (y: Int) => x + y
    def increment = makeAdder(1)
    def res = increment(4)

Note that in this example type inference goes in opposite directions in the two languages: Java infers the parameter types from the declared type of the variable, whereas Scala infers the function's type from the parameter types and the expression in the body.  Scala infers the following type for makeAdder: Int => (Int => Int).  We could specify this type manually, but don't need to.

Clearly, Scala's notation for functional types is much nicer than Java's use of multiple interfaces; and the function application notation doesn't require an artificial method call.  In both aspects, Scala notation is similar to the usual mathematical notation.
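For readers who want to see the type written out, here is a sketch of both directions of inference in Scala.  The object name AdderTypes and the names inferred, declared and sameType are mine; I also use val rather than def here, which is an arbitrary choice for this sketch.

```scala
object AdderTypes {
  // Parameter types given: Scala infers the type Int => (Int => Int).
  val inferred = (x: Int) => (y: Int) => x + y

  // Function type given: the parameter types are inferred from it,
  // which is the direction Java uses.
  val declared: Int => (Int => Int) = x => y => x + y

  // The arrow associates to the right, so the parentheses are optional.
  val sameType: Int => Int => Int = inferred

  val increment: Int => Int = declared(1)
  val res: Int = increment(4) // 5
}
```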

Here is the recursive fixed-point function:

Java 8
    public double fixedpoint(DoubleUnaryOperator f, double v) {
        double next = f.applyAsDouble(v);
        if (Math.abs(next - v) < epsilon)
            return next;
        else
            return fixedpoint(f, next);
    }

Scala
    def fixedpoint(f: Double => Double, v: Double): Double = {
      val next = f(v)
      if (Math.abs(next - v) < epsilon) next
      else fixedpoint(f, next)
    }

These are quite similar, except for the notations for types and for function application, and the fact that Scala doesn't require explicit return statements.

Scala emphasizes recursion, and its compiler therefore optimizes tail recursion where it can: a method that calls itself in tail position is compiled into a loop, provided the method cannot be overridden (for example, because it is final, private, local, or defined in an object).  If you want to make sure that the function is indeed optimized in this way, you can add the annotation @tailrec (from scala.annotation) to the method.  This will cause the compiler to produce an error if the tail-recursion optimization can't be applied.  In this case, the compiler accepts this annotation, so that this version is just as efficient as the imperative one.  Because recursion is natural in Scala, this would be the preferred way of writing this method.  For comparison, here is the imperative version:

Java 8
    public double fixedpoint(DoubleUnaryOperator f, double v) {
        double prev = v;
        double next = v;
        do {
            prev = next;
            next = f.applyAsDouble(next);
        } while (Math.abs(next - prev) >= epsilon);
        return next;
    }

Scala
    def fixedpoint(f: Double => Double, v: Double) = {
      var next = v
      var prev = v
      do {
        prev = next
        next = f(next)
      } while (Math.abs(next - prev) >= epsilon)
      next
    }

This is not too bad; still, I recommend using recursion instead of assignments in all languages that fully support it (that is, guarantee the tail-recursion optimization), including Scala, where @tailrec gives you that guarantee for a given method.
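Here is roughly what the recursive version looks like with the @tailrec annotation applied.  This is a sketch under two assumptions of mine: the value of epsilon (its definition is not shown in this post) and the enclosing object TailRecFixedPoint, which I use so that the method cannot be overridden.

```scala
import scala.annotation.tailrec

object TailRecFixedPoint {
  // Assumed tolerance; pick whatever precision you need.
  val epsilon = 1e-10

  // @tailrec makes the compiler reject this method if the
  // self-call cannot be turned into a loop.
  @tailrec
  def fixedpoint(f: Double => Double, v: Double): Double = {
    val next = f(v)
    if (math.abs(next - v) < epsilon) next
    else fixedpoint(f, next)
  }
}
```

For example, TailRecFixedPoint.fixedpoint(y => math.cos(y), 1.0) converges to the fixed point of the cosine function, approximately 0.739.  If you wrap the recursive call in another expression, say 0.0 + fixedpoint(f, next), the call is no longer in tail position and the compiler should report an error instead of silently producing a version that consumes stack.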

Now for the naive non-terminating version of sqrt that uses fixedpoint:

Java 8
    public double sqrt(double x) {
        return fixedpoint(y -> x / y, 1.0);
    }

Scala
    def sqrt(x: Double) = fixedpoint(y => x / y, 1.0)

The differences are small, and have to do with Java's verbosity more than anything else.  Gone are the access-level designator (public), the return type, the return keyword, and the braces.  The Scala version is more concise, and can easily be written in one line.  These are small savings, but they add up over a large program and reduce the noise-to-signal ratio.  Except for Haskell (forthcoming), Scala is the most concise (for this kind of code) of all the languages I survey in this series.
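To see why the naive version never terminates, it helps to look at the first few values the iteration produces.  The helper trace below is hypothetical (it is not part of the series' code); it simply applies y => x / y a bounded number of times, starting from 1.0.

```scala
object NaiveSqrtTrace {
  // The first `steps` values of the iteration y -> x / y, starting at 1.0.
  def trace(x: Double, steps: Int): List[Double] =
    Iterator.iterate(1.0)(y => x / y).take(steps).toList
}
```

NaiveSqrtTrace.trace(2.0, 5) yields List(1.0, 2.0, 1.0, 2.0, 1.0): the guesses bounce between 1 and x forever, so consecutive values never get within epsilon of each other.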

The terminating version of sqrt that uses average damping should now come as no surprise:

Java 8
    public double sqrt(double x) {
        return fixedpoint(y -> (y + x / y) / 2.0, 1.0);
    }

Scala
    def sqrt(x: Double) = fixedpoint(y => (y + x / y) / 2.0, 1.0)

Here is the general average-damp procedure and the version of sqrt that uses it:

Java 8
    public DoubleUnaryOperator averageDamp(DoubleUnaryOperator f) {
        return x -> (x + f.applyAsDouble(x)) / 2.0;
    }

    public double sqrt(double x) {
        return fixedpoint(averageDamp(y -> x / y), 1.0);
    }

Scala
    def averageDamp(f: Double => Double) = (x: Double) => (x + f(x)) / 2.0

    def sqrt(x: Double) = fixedpoint(averageDamp(y => x / y), 1.0)

Again, the Scala formulation is concise and clear, while still being strongly typed.  By design, Scala is well-suited for expressing functional-programming abstractions, using type inference effectively in order to reduce the burden of specifying types.  Scala is too large for my taste, but functional programming is easy and natural in it.
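Because averageDamp is a general procedure, it can be reused for other fixed-point computations.  The following self-contained sketch puts the pieces together and adds a cube root as one possible reuse.  The object name DampedRoots, the value of epsilon, and the cubeRoot function are my additions for illustration; they do not appear in the examples above.

```scala
import scala.annotation.tailrec

object DampedRoots {
  // Assumed tolerance.
  val epsilon = 1e-10

  @tailrec
  def fixedpoint(f: Double => Double, v: Double): Double = {
    val next = f(v)
    if (math.abs(next - v) < epsilon) next
    else fixedpoint(f, next)
  }

  // Takes a function and returns its average-damped version.
  def averageDamp(f: Double => Double): Double => Double =
    x => (x + f(x)) / 2.0

  def sqrt(x: Double): Double =
    fixedpoint(averageDamp(y => x / y), 1.0)

  // Hypothetical reuse: the cube root of x is a fixed point of y => x / (y * y).
  def cubeRoot(x: Double): Double =
    fixedpoint(averageDamp(y => x / (y * y)), 1.0)
}
```

DampedRoots.sqrt(2.0) comes out close to 1.41421, and DampedRoots.cubeRoot(27.0) close to 3.0.  Only the function passed to averageDamp changes between the two definitions.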

#functionalprogramming #scala
