Recursion and tail recursion recalled

Let us start by recalling the basics of recursion and tail recursion. For a more in-depth introduction see the material you studied during “Programming 1”.

We will also introduce some new terminology (call-stacks and the expansion of function calls), so you may want to read through this part even if you have done all the programming exercises on recursion in the “Programming 1” course.

Example: Palindromes

Let us first recall the palindrome example from Programming 1, Chapter 12.2.

A palindrome is a word or sentence that reads the same backwards as forwards. For example, the following are palindromes:

  • A man, a plan, a canal: Panama

  • testset

  • ufo tofu

  • Nurses run

That is, a string is a palindrome if, after omitting spaces and punctuation symbols (“,”, “!” and so on) and ignoring the difference between upper- and lower-case letters, the string remains the same when written backwards. For strings without spaces and punctuation, and with all letters in lower case, we thus get the following recursive definition of palindromes:

  1. An empty string is a palindrome.

  2. A string with only one character is a palindrome.

  3. If a string (i) starts and ends with the same character, and (ii) the substring between these is a palindrome, then the string is a palindrome.

  4. There are no other palindromes.

From this recursive definition it is easy to derive a recursive function for checking whether a string is a palindrome:

def isPalindrome(s: String): Boolean =
  // Remove spaces and punctuation, transform to lower case
  val sPlain = s.filter(_.isLetterOrDigit).map(_.toLower)
  // The cases
  if sPlain.length <= 1 then true // empty or one character?
  else if sPlain.head != sPlain.last then false // start and end differ?
  else isPalindrome(sPlain.substring(1, sPlain.length - 1)) // substring palindrome?
end isPalindrome

Observe the line val sPlain = s.filter(_.isLetterOrDigit).map(_.toLower) in the function body: it first removes spaces and punctuation symbols with filter(_.isLetterOrDigit) and then converts all the remaining characters to lower case with map(_.toLower). Without this line, the function would return false on all the examples above with the exception of “testset”.
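To see the preprocessing step on its own, the sketch below pulls it out into a helper. The name normalize is our own and hypothetical; it is not part of the course code. The body is exactly the expression used in isPalindrome.

```scala
// Hypothetical helper: the preprocessing step of isPalindrome on its own.
def normalize(s: String): String =
  // keep only letters and digits, then convert every character to lower case
  s.filter(_.isLetterOrDigit).map(_.toLower)
```

For example, normalize("A man, a plan, a canal: Panama") evaluates to "amanaplanacanalpanama" and normalize("Nurses run") to "nursesrun", both of which read the same backwards.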

There is one obvious inefficiency in this function: the preprocessing line is executed in every recursive call, even though it would suffice to execute it only once, in the first call. How can we capture this “first call” concept? Perhaps the most elegant way is to use an inner function actualRecursion that performs the actual recursive task; the main function then simply performs the string preprocessing first and calls the inner function on the result.

def isPalindrome(s: String): Boolean =
  // Helper inner function that encapsulates the actual recursive task.
  // Assumes that the string contains only lower case characters
  // and no spaces or punctuation.
  def actualRecursion(t: String): Boolean =
    if t.length <= 1 then true
    else if t.head != t.last then false
    else actualRecursion(t.substring(1, t.length - 1))
  end actualRecursion
  // Remove spaces and punctuation, transform to lower case
  val sPlain = s.filter(_.isLetterOrDigit).map(_.toLower)
  // Run the actual recursion
  actualRecursion(sPlain)
end isPalindrome

We have defined the inner function actualRecursion here inside the main function isPalindrome because it is very closely related to the main function: we have no use for this function elsewhere.
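The inner function also creates a new string with substring on every call. As a sketch of one possible alternative (our own variant, not the course's solution), the same recursive definition can be applied to two indices into sPlain instead. The names isPalindromeIdx and check are hypothetical. Note that check reads sPlain from its enclosing function, just as actualRecursion could.

```scala
// Hypothetical variant: same recursive definition, but the inner function
// works on two indices into sPlain instead of creating substrings.
def isPalindromeIdx(s: String): Boolean =
  val sPlain = s.filter(_.isLetterOrDigit).map(_.toLower)
  // Is the part of sPlain from index lo to index hi (inclusive) a palindrome?
  def check(lo: Int, hi: Int): Boolean =
    if lo >= hi then true                        // empty or one character
    else if sPlain(lo) != sPlain(hi) then false  // first and last differ
    else check(lo + 1, hi - 1)                   // check the inner part
  end check
  check(0, sPlain.length - 1)
end isPalindromeIdx
```

For the empty string the initial call is check(0, -1), which immediately returns true, matching rule 1 of the definition.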

Call-stacks and expansion

Let us now introduce two tools that are helpful in understanding and analysing what happens in recursive function calls: “call-stacks” and “expansion”. We illustrate a call-stack as a sequence of lines consisting of

  1. function calls with concrete parameters (these lines will end in the : character), and

  2. selected commands executed in the function. (For convenience we display only selected commands with a focus on what is needed to compute the return value of each function call.)

In displaying a call stack we adopt the convention that time flows top-down and indentation indicates in which function call each command is executed; we increase indentation every time we make or enter a new function call.

For example, the call-stack for isPalindrome with the argument “ufo tofu” looks like this:

isPalindrome("ufo tofu"):
 val sPlain = "ufotofu" // We could skip this line in the presentation
 return actualRecursion("ufotofu")
  actualRecursion("ufotofu"):
   // tests (t.length <= 1) and (t.head != t.last) both fail
   return actualRecursion("fotof")
    actualRecursion("fotof"):
     return actualRecursion("oto")
      actualRecursion("oto"):
       return actualRecursion("t")
        actualRecursion("t"):
         return true // Test (t.length <= 1) succeeds

For a graphical animation illustrating the evaluation steps and the call-stack, it is useful to visit Chapter 12.2 of the Programming 1 course.
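A call-stack like the one above can also be produced mechanically. The sketch below is a hypothetical instrumented copy of actualRecursion (the name tracedRecursion and its extra parameters depth and log are our own): besides computing the result, it appends one line per call and per return to a buffer, indenting each call two spaces deeper than its caller and each body one space deeper than its call line.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical instrumented copy of actualRecursion that records its
// calls in the call-stack notation (indentation = nesting depth).
def tracedRecursion(t: String, depth: Int, log: ArrayBuffer[String]): Boolean =
  val callIndent = "  " * depth      // indentation of this call's line
  val bodyIndent = callIndent + " "  // the body is indented one more space
  log += s"""${callIndent}actualRecursion("$t"):"""
  if t.length <= 1 then
    log += s"${bodyIndent}return true"
    true
  else if t.head != t.last then
    log += s"${bodyIndent}return false"
    false
  else
    val inner = t.substring(1, t.length - 1)
    log += s"""${bodyIndent}return actualRecursion("$inner")"""
    tracedRecursion(inner, depth + 1, log)
end tracedRecursion
```

Calling tracedRecursion("ufotofu", 0, buf) with an empty ArrayBuffer buf and printing buf line by line gives the actualRecursion part of the call-stack above (without the isPalindrome lines and the comments).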

If the function is free of side effects, we can also illustrate its computation by simply “expanding” the code. Here is the same example again:

isPalindrome("ufo tofu") =
actualRecursion("ufo tofu".filter(_.isLetterOrDigit).map(_.toLower)) =
actualRecursion("ufotofu".map(_.toLower)) =
actualRecursion("ufotofu") =
actualRecursion("fotof") =
actualRecursion("oto") =
actualRecursion("t") =
true

That is, we simply expand, one by one, the function calls that actually take place, in the same order as they are made when the Scala code is executed. For conciseness we may also omit obvious steps, such as the following two lines of the expansion above:

actualRecursion("ufo tofu".filter(_.isLetterOrDigit).map(_.toLower)) =
actualRecursion("ufotofu".map(_.toLower))
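The expansion of actualRecursion is just the sequence of arguments of the successive calls. As a small sketch, the hypothetical helper expansionSteps (our own name) lists exactly these arguments, stopping at the call that returns a Boolean directly.

```scala
// Hypothetical helper: the arguments of the successive actualRecursion
// calls, i.e. the actualRecursion lines of an expansion.
def expansionSteps(t: String): List[String] =
  if t.length <= 1 || t.head != t.last then List(t) // this call returns directly
  else t :: expansionSteps(t.substring(1, t.length - 1))
```

For "ufotofu" this yields List("ufotofu", "fotof", "oto", "t"), matching the expansion above.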

A note for the curious: for this very simple task of detecting palindromes we could have used the reverse and equals methods already defined for strings:

def isPalindrome(s: String): Boolean =
  // Remove spaces and punctuation, transform to lower case
  val sPlain = s.filter(_.isLetterOrDigit).map(_.toLower)
  // Compare whether the string is the same when reversed
  sPlain == sPlain.reverse
end isPalindrome

However, our intent in this round is to practice recursion, so we will pursue recursive solutions.

Tail recursion

One technical limitation of recursion is that only a limited amount of space is reserved for the call-stack at run time. That is, the number of nested function calls (calls that have started but not yet returned) is bounded; on the JVM, exceeding this bound raises a java.lang.StackOverflowError.

For example, let us implement the factorial function. For a positive integer \(n=1,2,\ldots\) the factorial \(n!\) is defined by

\[\begin{split}n! = \begin{cases} 1 & \text{if }n=1;\\ n\cdot (n-1)! & \text{if }n\geq 2\,. \end{cases}\end{split}\]

Or, equivalently, using recursion in Scala:

def fact(n : Int): BigInt =
  require(n >= 1, "n should be a positive integer")
  if n == 1 then BigInt(1)
  else n * fact(n-1)
end fact

Now we get:

scala> fact(10)
res1: BigInt = 3628800

This is as expected. However, with a larger value of \(n\) we run into trouble:

scala> fact(100000)
java.lang.StackOverflowError
       at .fact(<console>:10)
       at .fact(<console>:10)
       ...

If we consider the call-stack of the execution with a small value of \(n\), we observe the following:

fact(4):
 val temp1 = fact(3) // the right part of the expression "n * fact(n-1)"
  fact(3):
   val temp2 = fact(2)
    fact(2):
     val temp3 = fact(1)
      fact(1):
       return BigInt(1)
     return 2*temp3  // 2*1 = 2
   return 3*temp2  // 3*2 = 6
 return 4*temp1  // 4*6 = 24

Observe that on every recursive call with \(n \geq 2\), the function must first compute the value fact(n-1) before it can multiply the value with \(n\) and return the result to its caller. Thus, the depth of the call stack grows linearly as a function of \(n\).
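The linear growth of the call stack can be checked with a sketch. The hypothetical function factWithDepth (our own instrumented copy of fact) returns the factorial together with the deepest nesting level reached; the parameter depth is an addition for illustration only.

```scala
// Hypothetical instrumented version of the non-tail-recursive fact:
// returns (n!, deepest nesting level of fact calls).
def factWithDepth(n: Int, depth: Int = 1): (BigInt, Int) =
  require(n >= 1, "n should be a positive integer")
  if n == 1 then (BigInt(1), depth)
  else
    // the recursive call must finish first ...
    val (rest, maxDepth) = factWithDepth(n - 1, depth + 1)
    // ... before the multiplication by n can be done
    (BigInt(n) * rest, maxDepth)
end factWithDepth
```

For n = 4 the deepest level is 4 and for n = 10 it is 10: one nested call per value of n, as in the call-stack above.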

Because the function fact does not have side effects, we could also compute its value by expanding it:

fact(4) =
 (4 * fact(3)) =
 (4 * (3 * fact(2))) =
 (4 * (3 * (2 * fact(1)))) =
 (4 * (3 * (2 * 1))) =
 (4 * (3 * 2)) =
 (4 * 6) =
 24

In this simple case, when the expression is written like this, a human would probably simplify it on the fly while constructing it. The Scala compiler does not do this: each pending multiplication waits in its own call frame until the recursive call it depends on has returned. Moreover, in general the operations that are performed on the result of a recursive call can be arbitrarily complex.
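The pending multiplications can be made visible as text. The sketch below uses a hypothetical helper expandFact (our own name) that builds the fully expanded expression for fact(n), that is, the state of the expansion at the moment the innermost call fact(1) is reached.

```scala
// Hypothetical helper: the fully expanded expression of fact(n) as text,
// showing all multiplications still pending when fact(1) is reached.
def expandFact(n: Int): String =
  if n == 1 then "1"
  else s"($n * ${expandFact(n - 1)})"
```

Here expandFact(4) gives "(4 * (3 * (2 * 1)))", the fourth line of the expansion above; its length grows linearly with n, just like the call stack.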

However, there is a special type of recursion that the Scala compiler can recognise and execute without growing the call stack on recursive calls. This type of recursion is called tail recursion. We say that a function call in a method or function body is a tail call if the call is the last operation (the tailing operation) performed when the body is executed.

For example, let us consider our factorial function

def fact(n : Int) : BigInt =
  require(n >= 1, "the argument should be a positive integer")
  if n == 1 then BigInt(1)
  else n * fact(n-1)
end fact

The call fact(n-1) is in fact not a tail call because its return value gets multiplied by \(n\) before the function body returns. Indeed, the tailing operations in the body of fact are

  1. the call BigInt(1), and

  2. the multiplication operation * in the expression n * fact(n-1).
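The same distinction can be seen in another small example of our own: two hypothetical functions summing a list of integers. In sumList the recursive call is followed by an addition, so it is not a tail call; in sumListTail the running total is carried along in an extra parameter acc, so the recursive call to loop is the last operation.

```scala
// Not tail-recursive: after sumList(xs.tail) returns, its result is
// still added to xs.head, so `+` is the tailing operation.
def sumList(xs: List[Int]): Int =
  if xs.isEmpty then 0
  else xs.head + sumList(xs.tail)

// Tail-recursive inner function: the call loop(...) is the last
// operation; the running total is passed along in acc.
def sumListTail(xs: List[Int]): Int =
  def loop(rest: List[Int], acc: Int): Int =
    if rest.isEmpty then acc
    else loop(rest.tail, acc + rest.head)
  loop(xs, 0)
```

Both return 10 for List(1, 2, 3, 4); they differ only in what happens on the call stack.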

Conceptually, tail calls can be exploited by immediately “recycling” the call frame (the stack frame), that is, the space in the call stack used for storing the values of the local variables and other information specific to the current function call, so that the next call can reuse it. This stops the call-stack from growing. Some (but not all) Java Virtual Machines support such automatic tail-call optimisation.

We say that a function is tail-recursive if all of its calls to itself (that is, direct calls to itself in the function body or indirect calls via other functions) are tail calls. In the following, we will only consider tail recursion in which the function calls to itself are direct tail calls (that is, we do not consider indirect forms of tail recursion, such as alternating calls of the form def f(x) = ...; g(y) and def g(y) = ...; f(x)). The reason for restricting to direct tail recursion is that the Scala compiler can, at compile time, automatically optimise direct tail recursion into a loop-based iteration, thus simulating tail-call optimisation at compile-time.
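To make the restriction to direct tail recursion concrete, here is a sketch with two hypothetical functions isEven and isOdd (our own example, valid only for non-negative arguments). Every recursive call is a tail call, but the functions call each other rather than themselves, so this is indirect tail recursion, which the Scala compiler does not rewrite into a loop.

```scala
// Indirect tail recursion: each call is a tail call, but isEven calls
// isOdd and vice versa, so neither calls itself directly.
// Assumes n >= 0; otherwise the recursion would never reach 0.
def isEven(n: Int): Boolean =
  require(n >= 0, "n should be non-negative")
  if n == 0 then true else isOdd(n - 1)

def isOdd(n: Int): Boolean =
  require(n >= 0, "n should be non-negative")
  if n == 0 then false else isEven(n - 1)
```

On a JVM without tail-call optimisation, each of these calls still adds a frame to the call stack, so a large enough argument leads to a StackOverflowError.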

Let us now make a tail-recursive version of the factorial function. Again it will be useful to introduce an auxiliary inner function:

def fact(n : Int) : BigInt =
  require(n >= 1, "n should be a positive integer")
  def iterate(i : Int, result : BigInt) : BigInt =
    if i > n then result
    else iterate(i+1, result*i)
  end iterate
  iterate(2, BigInt(1))
end fact

Recall the notion of closure from Round 6: Collections and functions, and observe that the function above uses a closure: the inner function iterate can access n because n is defined in its enclosing context, the body of fact.

After a few seconds of waiting, we now get:

scala> fact(100000)
res: BigInt = 2824229407960347874293421578024535518477494926091224850578918086542977950901063017872551...

To understand how the function iterate works, let us compare it with an iterative implementation of factorial:

def fact(n : Int) : BigInt =
  require(n >= 1, "n should be a positive integer")
  var result = BigInt(1)
  for i <- 2 to n do result = result * i
  result
end fact

In a sense, the inner function iterate “implements” the for-loop of the iterative version. Indeed, let us study how the tail-recursive implementation expands:

fact(4) =
iterate(2, 1) =
iterate(3, 2) =
iterate(4, 6) =
iterate(5, 24) =
24

Here we see that the first parameter of iterate evolves exactly like the variable i in the iterative version, and the second parameter evolves exactly like the variable result in the iterative version.
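The correspondence can be spelled out as a while loop. The sketch below is our own illustration of the idea, not the code the compiler actually generates; the name factLoop is hypothetical. The parameters of iterate become variables, the base case i > n becomes the loop condition, and the recursive call iterate(i+1, result*i) becomes an update of both variables (using the old value of i in result*i).

```scala
// Hand-written sketch of the loop corresponding to the tail-recursive
// iterate; an illustration, not the compiler's actual output.
def factLoop(n: Int): BigInt =
  require(n >= 1, "n should be a positive integer")
  var i = 2               // first argument of iterate(2, BigInt(1))
  var result = BigInt(1)  // second argument of iterate(2, BigInt(1))
  while i <= n do         // iterate stops (base case) when i > n
    result = result * i   // new second argument: result*i (old i)
    i = i + 1             // new first argument: i+1
  result
end factLoop
```

For n = 4 the pairs (i, result) at the start of each loop round are (2, 1), (3, 2) and (4, 6), and the loop ends with (5, 24), the same values as in the expansion above.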

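We can see the same in the call-stack when the tail-call optimisation is performed. Note that all the calls to iterate appear at the same indentation level: each call reuses the frame of the previous one instead of nesting inside it.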

fact(4):
 return iterate(2, 1)
  iterate(2, 1):
   return iterate(3, 1*2)
  iterate(3, 2):
   return iterate(4, 2*3)
  iterate(4, 6):
   return iterate(5, 6*4)
  iterate(5, 24):
   return 24
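The argument pairs in this call-stack can also be recorded by running the code. The sketch below is a hypothetical instrumented copy of the tail-recursive fact (the name factTrace and the buffer calls are our own); it returns the result together with the arguments of every iterate call. Recording happens before the recursive call, so that call is still the last operation.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical instrumented copy of the tail-recursive fact: records
// the arguments (i, result) of every call to iterate.
def factTrace(n: Int): (BigInt, List[(Int, BigInt)]) =
  require(n >= 1, "n should be a positive integer")
  val calls = ListBuffer[(Int, BigInt)]()
  def iterate(i: Int, result: BigInt): BigInt =
    calls += ((i, result))           // record this call's arguments
    if i > n then result
    else iterate(i + 1, result * i)  // still a tail call
  end iterate
  val value = iterate(2, BigInt(1))
  (value, calls.toList)
end factTrace
```

For n = 4 the recorded pairs are (2, 1), (3, 2), (4, 6) and (5, 24), one per iterate line of the call-stack above.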

The Scala standard library includes an annotation @tailrec (scala.annotation.tailrec) that declares that the compiler must optimise tail recursion into iteration. For example, we could annotate our tail-recursive factorial function as follows:

import scala.annotation.tailrec

def fact(n : Int) : BigInt =
  require(n >= 1, "n should be a positive integer")
  @tailrec def iterate(i : Int, result : BigInt) : BigInt =
    if i > n then result
    else iterate(i+1, result*i)
  end iterate
  iterate(2, BigInt(1))
end fact

There are two requirements for the annotated method/function:

  1. tail-recursive calls to the function must be direct calls (that is, indirect calls such as def f(x) = {...; g(y)} and def g(y) = {...; f(x)} are not allowed),

  2. it must not be possible to override the method/function, meaning that it must be, for example, an inner function, a private or final method, or a method of an object (see e.g. this blog post for further discussion).
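As an illustration of the second requirement, the sketch below defines a hypothetical class DigitTools (our own example, for non-negative arguments) with two annotated methods that cannot be overridden: one declared final and one declared private.

```scala
import scala.annotation.tailrec

// Hypothetical class: @tailrec on methods that cannot be overridden.
// Assumes non-negative arguments.
class DigitTools:
  // final: a subclass cannot override this method
  @tailrec final def digitSum(n: Long, acc: Long = 0L): Long =
    if n == 0L then acc
    else digitSum(n / 10, acc + n % 10)

  // private: not visible to subclasses, so it cannot be overridden either
  @tailrec private def countDigits(n: Long, acc: Int): Int =
    if n < 10L then acc + 1
    else countDigits(n / 10, acc + 1)

  def numDigits(n: Long): Int = countDigits(n, 0)
end DigitTools
```

Removing the word final from digitSum, or the word private from countDigits, makes the method overridable, and the compiler should then reject the corresponding @tailrec annotation.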

Warning

If we have annotated our function with @tailrec, then the compiler issues an error whenever it cannot optimise the recursion, for example because a recursive call is not in tail position or because one of the conditions above is violated. Thus, such annotations are a useful defensive programming strategy: they make sure that what we as programmers intend to be tail-recursive and automatically optimisable is actually optimised by the compiler as we intended.
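The palindrome checker from the beginning of this section is a further example. Its inner function actualRecursion makes its recursive call as the last operation of its branch, so it is already tail-recursive and can carry the annotation. The commented-out lines at the end sketch, for contrast, an annotation that the compiler rejects.

```scala
import scala.annotation.tailrec

// The palindrome checker from the beginning of this section, with the
// (already tail-recursive) inner function annotated with @tailrec.
def isPalindrome(s: String): Boolean =
  @tailrec def actualRecursion(t: String): Boolean =
    if t.length <= 1 then true
    else if t.head != t.last then false
    else actualRecursion(t.substring(1, t.length - 1)) // tail call
  end actualRecursion
  val sPlain = s.filter(_.isLetterOrDigit).map(_.toLower)
  actualRecursion(sPlain)
end isPalindrome

// By contrast, annotating the first, non-tail-recursive factorial is
// rejected at compile time, since fact(n-1) is not in tail position:
//   @tailrec def fact(n: Int): BigInt =
//     if n == 1 then BigInt(1) else n * fact(n - 1)   // compile error
```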