110. Balanced Binary Tree - Explanation

Description

Given a binary tree, return true if it is height-balanced and false otherwise.

A height-balanced binary tree is defined as a binary tree in which the left and right subtrees of every node differ in height by no more than 1.

Example 1:

Input: root = [1,2,3,null,null,4]

Output: true

Example 2:

Input: root = [1,2,3,null,null,4,null,5]

Output: false

Example 3:

Input: root = []

Output: true
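
The inputs above are level-order lists in which null marks a missing child. To try the examples locally, one option is a small helper that turns such a list into linked nodes. The sketch below assumes the usual TreeNode class; build_tree and height are hypothetical names that do not appear in the problem statement.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None

    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries describe this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def height(node: Optional[TreeNode]) -> int:
    """Number of nodes on the longest path from node down to a leaf."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


example1 = build_tree([1, 2, 3, None, None, 4])
example2 = build_tree([1, 2, 3, None, None, 4, None, 5])
print(height(example1.left), height(example1.right))  # 1 2
print(height(example2.left), height(example2.right))  # 1 3
```

In Example 1 the root's subtrees have heights 1 and 2, and no other node violates the rule, so the tree is balanced. In Example 2 the heights are 1 and 3, so it is not.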

Constraints:

  • The number of nodes in the tree is in the range [0, 1000].
  • -1000 <= Node.val <= 1000


Recommended Time & Space Complexity

You should aim for a solution with O(n) time and O(n) space, where n is the number of nodes in the tree.


Hint 1

A brute force solution would involve traversing every node and checking whether the tree rooted at each node is balanced by computing the heights of its left and right subtrees. This approach would result in an O(n^2) solution. Can you think of a more efficient way? Perhaps you could avoid repeatedly computing the heights for every node by determining balance and height in a single traversal.


Hint 2

We can use Depth First Search (DFS) to compute the height at each node. While computing the heights of the left and right subtrees, we also check whether the tree rooted at the current node is balanced. If abs(leftHeight - rightHeight) > 1, we set a global variable, such as isBalanced, to False. After traversing all the nodes, the value of isBalanced indicates whether the entire tree is balanced.
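
One way to sketch this hint in Python is shown below. The function name is_balanced_with_flag and the nested height helper are illustrative choices, and Python's nonlocal stands in for the global variable the hint describes.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_balanced_with_flag(root: Optional[TreeNode]) -> bool:
    balanced = True  # shared flag, flipped as soon as any node violates the rule

    def height(node: Optional[TreeNode]) -> int:
        nonlocal balanced
        if node is None:
            return 0
        left = height(node.left)
        right = height(node.right)
        # Compare the absolute difference so either side may be the taller one.
        if abs(left - right) > 1:
            balanced = False
        return 1 + max(left, right)

    height(root)
    return balanced
```

Each node is visited once, so the traversal takes O(n) time. The flag only records the outcome; the traversal itself still runs to completion.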




1. Brute Force

Intuition

A tree is balanced if every node’s left and right subtree heights differ by at most 1.

The brute-force approach directly follows the definition:

  • For every node, compute the height of its left subtree.
  • Compute the height of its right subtree.
  • Check if their difference is ≤ 1.
  • Recursively repeat this check for all nodes.

Algorithm

  1. If the current node is null, the subtree is balanced.
  2. Compute:
    • leftHeight = height(left subtree)
    • rightHeight = height(right subtree)
  3. If abs(leftHeight - rightHeight) > 1, return False.
  4. Recursively check if:
    • left subtree is balanced
    • right subtree is balanced
  5. If all checks pass, return True.

Height function:

  • If node is null → return 0
  • Otherwise → 1 + max(height(left), height(right))

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def isBalanced(self, root: Optional[TreeNode]) -> bool:
        if not root:
            return True

        left = self.height(root.left)
        right = self.height(root.right)
        if abs(left - right) > 1:
            return False
        return self.isBalanced(root.left) and self.isBalanced(root.right)

    def height(self, root: Optional[TreeNode]) -> int:
        if not root:
            return 0

        return 1 + max(self.height(root.left), self.height(root.right))

Time & Space Complexity

  • Time complexity: O(n^2)
  • Space complexity: O(n)

2. Depth First Search

Intuition

The brute-force solution wastes time by repeatedly recomputing subtree heights.
We fix this by doing one DFS that returns two things at once for every node:

  1. Is the subtree balanced? (True/False)
  2. What is its height?

This way, each subtree is processed only once.
If the height difference at any node exceeds 1, we mark that subtree as unbalanced, and the result propagates up to the root.


Algorithm

  1. Write a DFS function that:
    • Returns [isBalanced, height].
  2. For each node:
    • Recursively get results from left and right children.
    • A node is balanced if:
      • Left subtree is balanced
      • Right subtree is balanced
      • Height difference ≤ 1
  3. Height of the current node = 1 + max(leftHeight, rightHeight)
  4. Run DFS on the root and return the isBalanced value.

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def isBalanced(self, root: Optional[TreeNode]) -> bool:
        def dfs(root):
            if not root:
                return [True, 0]

            left, right = dfs(root.left), dfs(root.right)
            balanced = left[0] and right[0] and abs(left[1] - right[1]) <= 1
            return [balanced, 1 + max(left[1], right[1])]

        return dfs(root)[0]

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(h)

Where n is the number of nodes in the tree and h is the height of the tree.
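
The DFS above still visits every node after an imbalance has been found. A possible variant, sketched here under the assumption that -1 can serve as a sentinel (real heights are never negative), stops descending as soon as one subtree reports failure. The names is_balanced_early_exit and height_or_fail are illustrative only.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_balanced_early_exit(root: Optional[TreeNode]) -> bool:
    def height_or_fail(node: Optional[TreeNode]) -> int:
        """Return the subtree height, or -1 if the subtree is unbalanced."""
        if node is None:
            return 0
        left = height_or_fail(node.left)
        if left == -1:
            return -1  # left side already failed, so skip the right subtree
        right = height_or_fail(node.right)
        if right == -1 or abs(left - right) > 1:
            return -1
        return 1 + max(left, right)

    return height_or_fail(root) != -1
```

The worst case is unchanged at O(n) time and O(h) space; the benefit is only that unbalanced inputs can be rejected before the whole tree is explored.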


3. Iterative DFS

Intuition

The recursive DFS solution computes height and balance in one postorder traversal.
This iterative version does the same thing, but simulates recursion using a stack.

The idea:

  • We must visit each node after its children (postorder).
  • Once both children of a node are processed, we already know their heights.
  • Then we:
    1. Check if the height difference ≤ 1
    2. Save the node’s height (1 + max(left, right))

If any node is unbalanced, return False immediately.


Algorithm

  1. Use a stack to simulate postorder traversal.
  2. Use a dictionary/map (depths) to store the height of each visited node.
  3. For each node:
    • Traverse left as far as possible.
    • When left is done, try right.
    • When both children are done:
      • Get their heights from depths.
      • If the difference > 1 → tree is unbalanced → return False.
      • Compute current node height and store it.
  4. If the traversal completes without violations → return True.

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def isBalanced(self, root):
        stack = []
        node = root
        last = None
        depths = {}

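        while stack or node:
            if node:
                # Descend left, remembering each ancestor on the stack.
                stack.append(node)
                node = node.left
            else:
                node = stack[-1]
                # Process the node once it has no right child, or once its
                # right subtree was the last thing finished (postorder).
                if not node.right or last == node.right:
                    stack.pop()
                    # A missing child (None) is never stored, so it counts as height 0.
                    left = depths.get(node.left, 0)
                    right = depths.get(node.right, 0)

                    if abs(left - right) > 1:
                        return False

                    depths[node] = 1 + max(left, right)
                    last = node
                    node = None  # go back to the stack on the next iteration
                else:
                    # The right subtree has not been visited yet.
                    node = node.right

        return True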

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(n)
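
The traversal above tracks the last visited node to decide when a node's children are done. Another common way to get postorder, sketched below as an alternative and not as the solution above, is to push (node, children_done) pairs onto the stack. The function name is_balanced_iterative and the heights dictionary are illustrative choices.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_balanced_iterative(root: Optional[TreeNode]) -> bool:
    heights = {}  # node -> height, kept only until the parent consumes it
    stack = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # First visit: revisit this node after both children are processed.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
        else:
            # Second visit: both child heights are known (missing child = 0).
            left = heights.pop(node.left, 0)
            right = heights.pop(node.right, 0)
            if abs(left - right) > 1:
                return False
            heights[node] = 1 + max(left, right)

    return True
```

Removing each child's height once its parent has read it keeps the dictionary smaller in practice, though the stack alone can still grow with the height of the tree.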