How to check if a binary tree is balanced

Question 4.1 of Cracking the Coding Interview:

Implement a function to check if a binary tree is balanced. For the purposes of this question, a balanced tree is defined to be a tree such that the heights of the two subtrees of any node never differ by more than one.

To implement this, we can translate the English definition directly into code. Here it is in Haskell:

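module BalancedTree where

-- A binary tree shape: a Leaf is an empty tree, a Branch has two subtrees.
data Tree = Leaf | Branch Tree Tree

-- A tree is balanced if both subtrees are balanced and their heights
-- differ by at most one.
isBalanced :: Tree -> Bool
isBalanced Leaf = True
isBalanced (Branch l r) =
  isBalanced l && isBalanced r && diff (height l) (height r) <= 1

-- Height of a tree, recomputed from scratch on every call.
height :: Tree -> Int
height Leaf = 0
height (Branch l r) = max (height l) (height r) + 1

diff :: Int -> Int -> Int
diff n m = abs (n - m)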
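To see how much repeated work the naive version does, here is a minimal sketch that counts the Branch steps spent inside `height` when `isBalanced` runs on a perfect tree, where every check succeeds and so `&&` never short-circuits. The names `size`, `heightWork` and `perfect` are hypothetical helpers for this illustration, not part of the original code.

```haskell
-- Hypothetical instrumentation of the naive isBalanced; Tree is
-- redefined here so the sketch is self-contained.
data Tree = Leaf | Branch Tree Tree

-- Number of Branch nodes in a tree. One call to `height` on this tree
-- recurses through exactly this many Branch nodes (plus size + 1 Leafs).
size :: Tree -> Int
size Leaf = 0
size (Branch l r) = size l + size r + 1

-- Total Branch steps spent inside `height` by the naive isBalanced,
-- assuming every check succeeds so `&&` never short-circuits.
heightWork :: Tree -> Int
heightWork Leaf = 0
heightWork (Branch l r) =
  heightWork l + heightWork r   -- the recursive isBalanced calls
    + size l + size r           -- height l and height r at this node

-- A perfect tree of depth d: 2^d - 1 Branch nodes, every node balanced.
perfect :: Int -> Tree
perfect 0 = Leaf
perfect d = Branch (perfect (d - 1)) (perfect (d - 1))
```

In GHCi, `map (size . perfect) [1 .. 6]` gives `[1,3,7,15,31,63]` and `map (heightWork . perfect) [1 .. 6]` gives `[0,2,10,34,98,258]`. For depth d this is (d-2)*2^d + 2, which grows like n*log(n) in the number of nodes n = 2^d - 1.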

This naive translation gives an O(n*log(n)) algorithm, which is not too bad. Its weakness is that at each node it recomputes the heights of both subtrees, which takes time linear in the size of those subtrees. We can reduce the work at each node to constant time by passing the height up from the recursive calls, using Nothing to signal that a subtree is already unbalanced. This gives us an O(n) algorithm:

module BalancedTree where

import qualified Data.Maybe

data Tree = Leaf | Branch Tree Tree

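isBalanced :: Tree -> Bool
isBalanced = Data.Maybe.isJust . isBalancedWithHeight

-- Returns Just the height of a balanced tree, or Nothing as soon as any
-- subtree is found to be unbalanced (the Maybe monad short-circuits).
isBalancedWithHeight :: Tree -> Maybe Int
isBalancedWithHeight Leaf = Just 0
isBalancedWithHeight (Branch l r) = do
  lh <- isBalancedWithHeight l
  rh <- isBalancedWithHeight r
  if diff lh rh <= 1
    then Just $ max lh rh + 1
    else Nothing

diff :: Int -> Int -> Int
diff n m = abs (n - m)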

main :: IO ()
main = print $ and [
    isBalanced Leaf,
    isBalanced (Branch Leaf Leaf),
    isBalanced (Branch (Branch Leaf Leaf) Leaf),
    not $ isBalanced (Branch Leaf (Branch Leaf (Branch Leaf Leaf)))
  ]
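A quick way to gain confidence that the two versions compute the same thing is to compare them on every small tree. The sketch below does that exhaustively; `naiveBalanced`, `fastBalanced`, `trees` and `agreeUpTo` are hypothetical names for this check, and both versions are restated inline so it compiles on its own.

```haskell
import Data.Maybe (isJust)

data Tree = Leaf | Branch Tree Tree

-- The naive version from the first listing.
naiveBalanced :: Tree -> Bool
naiveBalanced Leaf = True
naiveBalanced (Branch l r) =
  naiveBalanced l && naiveBalanced r && abs (height l - height r) <= 1

height :: Tree -> Int
height Leaf = 0
height (Branch l r) = max (height l) (height r) + 1

-- The linear version: Just h for a balanced tree of height h, else Nothing.
fastBalanced :: Tree -> Bool
fastBalanced = isJust . go
  where
    go :: Tree -> Maybe Int
    go Leaf = Just 0
    go (Branch l r) = do
      lh <- go l
      rh <- go r
      if abs (lh - rh) <= 1 then Just (max lh rh + 1) else Nothing

-- Every tree with exactly k Branch nodes (there are Catalan(k) of them).
trees :: Int -> [Tree]
trees 0 = [Leaf]
trees k = [Branch l r | i <- [0 .. k - 1], l <- trees i, r <- trees (k - 1 - i)]

-- True if both versions agree on every tree with at most k Branch nodes.
agreeUpTo :: Int -> Bool
agreeUpTo k = and [naiveBalanced t == fastBalanced t | n <- [0 .. k], t <- trees n]
```

Evaluating `agreeUpTo 10` in GHCi checks all 23,714 trees with at most 10 Branch nodes and returns `True`. An exhaustive check like this is cheap for small sizes and catches off-by-one mistakes in the height bookkeeping.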

How did I work out that the naive algorithm is O(n*log(n))? I actually did it by noticing “this looks like a sorting algorithm” (such as merge sort), in that it makes two recursive sub-calls and then does a linear amount of work, and then remembering that “these sorting algorithms are O(n*log(n))”. A more principled way is to write out the recurrence relation, T(n) = 2*T(n/2) + n (in a balanced tree, each subtree holds about half the nodes), and then apply the magic master theorem. I’m sure I’ll have to learn the master theorem properly in a future question from this book.
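Spelled out, the master theorem step goes like this, under the same assumption that the two subtrees split the n nodes roughly in half. It is case 2 of the theorem, where the per-call work f(n) matches n^(log_b a):

```latex
T(n) = a\,T(n/b) + f(n), \qquad a = 2,\quad b = 2,\quad f(n) = \Theta(n) \\
n^{\log_b a} = n^{\log_2 2} = n \;\Rightarrow\; f(n) = \Theta\!\left(n^{\log_b a}\right) \\
\Rightarrow\; T(n) = \Theta\!\left(n^{\log_b a} \log n\right) = \Theta(n \log n)
```

Intuitively, the recursion is about log_2(n) levels deep, and each level does about n work in total.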

Tagged #ctci, #programming, #haskell.

This page copyright James Fisher 2020. Content is not associated with my employer.