Tail Recursion
There is a special form of recursion called tail recursion that can be compiled to a loop in the machine code. This can save a significant amount of memory if the code needs to make a lot of recursive calls, and in a functional language like Haskell this optimization can make the difference between the code working and running out of memory.
To understand how this works, let’s review how a function keeps track of its information.
1. Function calls and frames
When a program calls a function, the computer allocates memory for the particular instance of that call. This memory is called a stack frame or sometimes an activation record.
For example, suppose we have the following Haskell code:
foo x =
  let a = 10
      b = 20
  in a + b + bar x

bar x =
  let a = 30
      b = 40
  in a + b + baz x

baz x =
  let a = 50
      b = 60
  in a + b + x
If we call foo 10, the program will allocate a frame for foo. When foo calls bar, the program allocates a frame for bar, and when bar calls baz, a frame gets allocated for baz too.
The important thing to note is that while baz is running, the frames for foo and bar are still in memory, since they are waiting for the result of baz.
When a function is recursive, we can end up with a lot of frames. Take this factorial code for example:
fact 0 = 1
fact n = n * fact (n-1)
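Computing the factorial of 100 this way will require about 100 frames, one for each call from fact 100 down to fact 0, because every call has to wait for its recursive call to return before it can do its multiplication.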
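One way to see the pending work is to count the nested calls. The sketch below is only an illustration: factDepth is a hypothetical helper, not part of the original code, that computes the factorial in direct style and also reports how many calls were active at the deepest point.

```haskell
-- Hypothetical helper (not from the text above): computes n! in direct
-- style and also reports how many nested calls were needed.
factDepth :: Integer -> (Integer, Int)
factDepth 0 = (1, 1)               -- the base case counts as one call
factDepth n =
  let (r, d) = factDepth (n - 1)   -- the recursive call has to finish first
  in (n * r, d + 1)                -- only then can this call multiply
```

For example, factDepth 3 evaluates to (6, 4): one call each for 3, 2, 1, and 0.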
2. Tail calls
Now consider this code.
foo x =
  let a = 10
      b = 20
  in bar (a + b + x)

bar x =
  let a = 30
      b = 40
  in baz (a + b + x)

baz x =
  let a = 50
      b = 60
  in a + b + x
When foo calls bar, note that whatever bar returns will be the value of foo as well. Similarly, when bar calls baz, whatever baz returns will be the return value of bar. The function foo does not need to keep track of its variables after calling bar, since it will not use them again, and likewise when bar calls baz.
So, instead of keeping those stack frames around, the compiler will allow bar to use the same memory that foo was using, and in turn allow baz to use that same memory as well once bar calls baz. The function calls get eliminated and replaced by straight-line code.
The question is: how do we know when it’s safe to do this? The answer is that we can eliminate the function call if the call is in tail position.
3. Tail Code
There are certain kinds of expressions that have subexpressions that will become the value of the whole expression if evaluated. These subexpressions are said to be in tail position. Consider this code:
if a > b then c else d

let e = 10 + 20
in f

foo 0 = 1
foo n = foo (n-1)
The variables c, d, and f are in tail position. The expressions a>b, 10+20, and n-1 are not in tail position. If a>b is true, and c has value 99, then the value of the whole if expression will be 99.
Note that not every expression has subexpressions that are in tail position.
Here is a non-exhaustive list of tail subexpressions:
- The then and else branches of an if expression
- The return values in a case expression
- The body of a let expression
- The return value of a function
If the expression that is in tail position happens to be a function call, then we call it a tail call.
In the code above, the call foo (n-1) is a tail call.
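To make the distinction concrete, here is a small sketch contrasting calls that are in tail position with one that is not. The function names (countdown, member, len) are hypothetical and chosen only for illustration.

```haskell
-- Hypothetical examples; none of these names come from the text above.

-- The recursive call is the entire right-hand side of the equation,
-- so it is in tail position: a tail call.
countdown :: Int -> Int
countdown 0 = 0
countdown n = countdown (n - 1)

-- The else branch of an if is in tail position, so the recursive call
-- there is a tail call. The condition x == y is not in tail position.
member :: Int -> [Int] -> Bool
member _ []     = False
member y (x:xs) = if x == y then True else member y xs

-- Not a tail call: after len xs returns, the function still has to
-- add 1 to the result, so the call is an argument to (+).
len :: [a] -> Int
len []     = 0
len (_:xs) = 1 + len xs
```

A quick test for tail position is to ask whether anything remains to be done with the result of the call. In len the addition remains, so the call is not a tail call.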
4. Accumulators
In direct-style recursion, the function makes all the calls and traverses the data structure. Once it hits the end of the data structure (or base case), it returns to the caller and performs the computations “on the way out”.
Consider this function that sums up the elements of a list.
sum [] = 0
sum (x:xs) = x + sum xs
If we apply sum to the list [2,3,4,5,6], sum will traverse to the end of the list, and then do all the adding as things get returned. We can trace the execution this way, where each line is the next call or return:
1: sum [2,3,4,5,6]
2: = 2 + sum [3,4,5,6]
3: = 2 + 3 + sum [4,5,6]
4: = 2 + 3 + 4 + sum [5,6]
5: = 2 + 3 + 4 + 5 + sum [6]
6: = 2 + 3 + 4 + 5 + 6 + sum []
7: = 2 + 3 + 4 + 5 + 6 + 0
8: = 2 + 3 + 4 + 5 + 6
9: = 2 + 3 + 4 + 11
10: = 2 + 3 + 15
11: = 2 + 18
12: = 20
If the function is in tail form, then we need to do the computation as we make the tail calls. This usually requires an extra parameter to keep track of the computation as we go. For example:
tsum [] a = a
tsum (x:xs) a = tsum xs (x+a)
If we call tsum [2,3,4,5,6] 0, it will look like this:
1: tsum [2,3,4,5,6] 0
2: = tsum [3,4,5,6] 2
3: = tsum [4,5,6] 5
4: = tsum [5,6] 9
5: = tsum [6] 14
6: = tsum [] 20
7: = 20
Notice that the computation happens “on the way in”, and when we reach the base case the computation is complete.
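One Haskell-specific caveat is worth a brief aside. Because Haskell is lazy, the accumulator argument x+a may be passed along as an unevaluated expression rather than as a number, unless the compiler's optimizer works out that it is always needed. The following is a minimal sketch, assuming we want to force the accumulator at every step with the standard seq function. The name tsumStrict is hypothetical.

```haskell
-- Hypothetical strict variant of tsum (the name is not from the text above).
tsumStrict :: [Integer] -> Integer -> Integer
tsumStrict []     a = a
tsumStrict (x:xs) a =
  let a' = x + a                  -- the updated accumulator
  in a' `seq` tsumStrict xs a'    -- evaluate a' before making the tail call
```

Calling tsumStrict [2,3,4,5,6] 0 gives 20, the same result as tsum. The only intended difference is when the additions are performed.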
5. How to write a tail recursive function
There are four steps:
- Add an accumulator parameter to your function.
- The base case returns the accumulator. It is okay to do some computation on it first.
- The recursive case makes the recursive call with an updated value for the accumulator. Once the recursive call returns, the code cannot do anything else with the result, or else the call is not in tail position.
- When calling the tail-recursive function initially, pass in the result that should get returned by the base case. (E.g., for tsum we called it with 0. If it were a product, we would have passed in 1.)
We often like to create a helper function that does the recursion, and have the main function be responsible for passing in the initial value. Here is what tsum would look like:
tsum xx = aux xx 0
  where aux [] a = a
        aux (x:xs) a = aux xs (x+a)
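The four steps can be sketched on a different function to show how they line up with the code. This example is an illustration rather than part of the original notes: countIf is a hypothetical function that counts the elements satisfying a predicate, and each comment names the step it corresponds to.

```haskell
-- Hypothetical example: count the elements that satisfy a predicate.
countIf :: (a -> Bool) -> [a] -> Int
countIf p xx = aux xx 0               -- step 4: start with the base-case result, 0
  where
    aux [] a = a                      -- step 2: the base case returns the accumulator
    aux (x:xs) a                      -- step 1: a is the accumulator parameter
      | p x       = aux xs (a + 1)    -- step 3: tail call with an updated accumulator
      | otherwise = aux xs a          -- step 3: tail call, accumulator unchanged
```

For instance, countIf even [1,2,3,4] evaluates to 2. Keeping aux local to countIf means callers never see the accumulator and cannot pass a wrong initial value.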
6. Problems
Here are some functions written in direct style. Try converting them to tail recursion. The answers are below.
prod [] = 1
prod (x:xs) = x * prod xs

fact 0 = 1
fact n = n * fact (n-1)

filter f [] = []
filter f (x:xs) = if f x then x : rest
                  else rest
  where rest = filter f xs
solution for prod

prod xx = aux xx 1
  where aux [] a = a
        aux (x:xs) a = aux xs (x*a)
solution for fact

fact n = aux n 1
  where aux 0 a = a
        aux m a = aux (m-1) (m*a)

solution for filter

filter f xx = aux xx []
  where aux [] a = reverse a   -- We need to reverse the list to preserve the original order
        aux (x:xs) a = if f x
                       then aux xs (x:a)
                       else aux xs a

-- or --

filter f xx = aux (reverse xx) []   -- We can reverse the input first if we like.
  where aux [] a = a
        aux (x:xs) a = if f x
                       then aux xs (x:a)
                       else aux xs a
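A quick way to gain confidence in a conversion is to compare it with a known-good version on a few inputs. The sketch below restates the solutions under hypothetical names (prodT, factT, filterT) so that they do not clash with the Prelude's filter, and checks them against the Prelude's product and filter. It is a spot check on sample inputs, not a proof.

```haskell
-- The solutions above, renamed (hypothetical names) to avoid Prelude clashes.
prodT :: [Integer] -> Integer
prodT xx = aux xx 1
  where aux [] a = a
        aux (x:xs) a = aux xs (x*a)

factT :: Integer -> Integer
factT n = aux n 1
  where aux 0 a = a
        aux m a = aux (m-1) (m*a)

filterT :: (a -> Bool) -> [a] -> [a]
filterT f xx = aux xx []
  where aux [] a = reverse a            -- restore the original order
        aux (x:xs) a = if f x
                       then aux xs (x:a)
                       else aux xs a

-- True when every tail-recursive version agrees with its Prelude
-- counterpart on the sample inputs below.
agrees :: Bool
agrees = and
  [ prodT [1..5] == product [1..5]
  , factT 5 == product [1..5]
  , filterT even [1..10 :: Integer] == filter even [1..10]
  , null (filterT even ([] :: [Integer]))
  ]
```

Evaluating agrees in GHCi should give True.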
6.1. Thanks for Suggestions and Finding Typos
- Ruben Serrano
- Enguang Fan