Search Trees

In this reading (notebook version), we'll learn how to use trees to search for data efficiently.

By the end, you should be comfortable with the following terms:

  • binary tree
  • search
  • range query
  • binary search tree
  • balanced tree

formatting...

In [20]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}
In [21]:
from graphviz import Graph, Digraph

Binary Tree

Below is a binary tree, cleaned up from last time.

Remember that a tree is a directed graph with exactly one root, a node without a parent. Every other node has exactly one parent. Nodes without children are called leaves.

This tree is a binary tree because each node has at most two children.

In [54]:
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
    
    def to_graphviz(self, g=None):
        if g is None:
            g = Digraph()
            
        # draw self
        g.node(repr(self.val))
    
        for label, child in [("L", self.left), ("R", self.right)]:
            if child is not None:
                # draw child, recursively
                child.to_graphviz(g)
                
                # draw edge from self to child
                g.edge(repr(self.val), repr(child.val), label=label)
        return g
    
    def _repr_svg_(self):
        # render the graph to SVG text so Jupyter can draw the tree inline
        return self.to_graphviz().pipe(format="svg").decode("utf-8")
    
root = Node("A")
root.left = Node("B")
root.right = Node("C")
root.left.left = Node("Y")
root.left.right = Node("X")
root.right.right = Node("Z")
root
Out[54]:
[Tree diagram: 'A' -> 'B' (L), 'A' -> 'C' (R), 'B' -> 'Y' (L), 'B' -> 'X' (R), 'C' -> 'Z' (R)]
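To make the vocabulary concrete, here is a small sketch that counts the nodes of a tree and lists its leaves (nodes with no children). It assumes a minimal stand-in class, `TreeNode` (a hypothetical name, chosen so it doesn't replace the notebook's `Node`), and rebuilds the same A/B/C/X/Y/Z tree drawn above.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def count_nodes(node):
    # An empty subtree has no nodes; otherwise count this node plus both subtrees.
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)

def leaves(node):
    # A leaf has no children; collect leaves from left to right.
    if node is None:
        return []
    if node.left is None and node.right is None:
        return [node.val]
    return leaves(node.left) + leaves(node.right)

# Same shape as the tree drawn above.
demo = TreeNode("A",
                TreeNode("B", TreeNode("Y"), TreeNode("X")),
                TreeNode("C", None, TreeNode("Z")))

print(count_nodes(demo))  # 6
print(leaves(demo))       # ['Y', 'X', 'Z']
```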

Search

What if we want to check whether a tree contains a value? We know it does if one of the following is true:

  1. the root has that value
  2. the left subtree contains that value
  3. the right subtree contains that value

Let's write a recursive function, contains, to do this search. At each step, we'll display the subtree being searched.

In [55]:
from IPython.display import display, HTML

def contains(node, target):
    if node is None:
        return False

    display(HTML("Is the root %s?" % target))
    display(node)

    if node.val == target:
        return True
    return contains(node.left, target) or contains(node.right, target)

contains(root, "B")
Is the root B?
[Tree diagram: 'A' -> 'B' (L), 'A' -> 'C' (R), 'B' -> 'Y' (L), 'B' -> 'X' (R), 'C' -> 'Z' (R)]
Is the root B?
[Tree diagram: 'B' -> 'Y' (L), 'B' -> 'X' (R)]
Out[55]:
True

Cool, we found the value in the second place we looked because we check left first. What if the data is deep on the right side? Worse, what if the thing we're searching for isn't even in the tree? Let's try that:

In [56]:
contains(root, "M")
Is the root M?
[Tree diagram: 'A' -> 'B' (L), 'A' -> 'C' (R), 'B' -> 'Y' (L), 'B' -> 'X' (R), 'C' -> 'Z' (R)]
Is the root M?
[Tree diagram: 'B' -> 'Y' (L), 'B' -> 'X' (R)]
Is the root M?
[Tree diagram: single node 'Y']
Is the root M?
[Tree diagram: single node 'X']
Is the root M?
[Tree diagram: 'C' -> 'Z' (R)]
Is the root M?
[Tree diagram: single node 'Z']
Out[56]:
False
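Instead of drawing each step, we can record which nodes a search checks. This sketch (the helper name `contains_counting` and the `TreeNode` class are hypothetical) follows the same logic as `contains` and shows that the unsuccessful search for "M" touched all 6 nodes: without any ordering, a miss costs O(N).

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def contains_counting(node, target, visited):
    # Same logic as contains(), but records each checked node instead of drawing it.
    if node is None:
        return False
    visited.append(node.val)
    if node.val == target:
        return True
    # "or" short-circuits: the right subtree is searched only if the left fails.
    return (contains_counting(node.left, target, visited)
            or contains_counting(node.right, target, visited))

demo = TreeNode("A",
                TreeNode("B", TreeNode("Y"), TreeNode("X")),
                TreeNode("C", None, TreeNode("Z")))

hit, miss = [], []
print(contains_counting(demo, "B", hit), hit)    # True ['A', 'B']
print(contains_counting(demo, "M", miss), miss)  # False ['A', 'B', 'Y', 'X', 'C', 'Z']
```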

Search Tree

Ouch, that was slow. It would be great if we could determine that an entry isn't in the tree without needing to look at every entry.

One way to guarantee this is to require that, for every node, all values in its left subtree are less than the node's value and all values in its right subtree are greater. Then comparing the target with a node's value tells us which single subtree could contain it.
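Note that the rule is about whole subtrees, not just a node and its children. A sketch of a checker, assuming a hypothetical `is_bst` helper and minimal `TreeNode` class: it passes down the allowed (low, high) bounds, so it catches a value that is fine relative to its parent but on the wrong side of an ancestor.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_bst(node, low=None, high=None):
    # Every value in this subtree must be strictly between low and high
    # (None means "no bound on that side"); strict because duplicates aren't allowed.
    if node is None:
        return True
    if low is not None and node.val <= low:
        return False
    if high is not None and node.val >= high:
        return False
    # Left subtree must stay below node.val; right subtree must stay above it.
    return (is_bst(node.left, low, node.val)
            and is_bst(node.right, node.val, high))

good = TreeNode("C", TreeNode("A", None, TreeNode("B")), TreeNode("E"))
# "D" is greater than its parent "A", but it sits in C's left subtree.
bad = TreeNode("C", TreeNode("A", None, TreeNode("D")), TreeNode("E"))

print(is_bst(good))  # True
print(is_bst(bad))   # False
```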

Constructing a Search Tree

Let's create a function for adding values that guarantees this.

In [57]:
# TODO: make this a method...
def add(node, val):
    if node.val == val:
        return # no duplicates
    elif val < node.val:
        if node.left is not None:
            add(node.left, val)
        else:
            node.left = Node(val)
    else:
        if node.right is not None:
            add(node.right, val)
        else:
            node.right = Node(val)
In [58]:
root = Node("C")
root
Out[58]:
[Tree diagram: single node 'C']
In [59]:
add(root, "A")
root
Out[59]:
[Tree diagram: 'C' -> 'A' (L)]
In [60]:
# duplicate shouldn't be added
add(root, "A")
root
Out[60]:
[Tree diagram: 'C' -> 'A' (L)]
In [61]:
add(root, "B")
root
Out[61]:
[Tree diagram: 'C' -> 'A' (L), 'A' -> 'B' (R)]
In [62]:
add(root, "E")
root
Out[62]:
[Tree diagram: 'C' -> 'A' (L), 'C' -> 'E' (R), 'A' -> 'B' (R)]
In [63]:
add(root, "D1")
root
Out[63]:
[Tree diagram: 'C' -> 'A' (L), 'C' -> 'E' (R), 'A' -> 'B' (R), 'E' -> 'D1' (L)]
In [64]:
add(root, "D2")
root
Out[64]:
[Tree diagram: 'C' -> 'A' (L), 'C' -> 'E' (R), 'A' -> 'B' (R), 'E' -> 'D1' (L), 'D1' -> 'D2' (R)]
In [65]:
add(root, "F")
root
Out[65]:
[Tree diagram: 'C' -> 'A' (L), 'C' -> 'E' (R), 'A' -> 'B' (R), 'E' -> 'D1' (L), 'E' -> 'F' (R), 'D1' -> 'D2' (R)]
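One quick way to check that `add` keeps the search-tree property is an in-order traversal (left subtree, then the node, then the right subtree): for a search tree it lists the values in sorted order. This sketch re-implements the same insertion rules as a hypothetical `bst_add` on a minimal `TreeNode` class and replays the insertions above, including the duplicate "A".

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def bst_add(node, val):
    # Same rules as add(): smaller goes left, larger goes right, duplicates ignored.
    if val == node.val:
        return
    if val < node.val:
        if node.left is None:
            node.left = TreeNode(val)
        else:
            bst_add(node.left, val)
    else:
        if node.right is None:
            node.right = TreeNode(val)
        else:
            bst_add(node.right, val)

def in_order(node, out=None):
    # Visit the left subtree, then the node, then the right subtree.
    if out is None:
        out = []
    if node is not None:
        in_order(node.left, out)
        out.append(node.val)
        in_order(node.right, out)
    return out

demo = TreeNode("C")
for v in ["A", "A", "B", "E", "D1", "D2", "F"]:
    bst_add(demo, v)

print(in_order(demo))  # ['A', 'B', 'C', 'D1', 'D2', 'E', 'F']
```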

Using a Search Tree

Now that we've built a search tree, we can write a function for searching it efficiently. It's like the previous contains function, but now we only need to check one child instead of both each time.

In [66]:
def contains(node, target):
    if node is None:
        return False

    display("Is the root %s?" % target)
    display(node)

    if node.val == target:
        return True
    
    if target < node.val:
        display("Go Left")
        return contains(node.left, target)
    else:
        display("Go Right")
        return contains(node.right, target)

contains(root, "D2")
'Is the root D2?'
[Tree diagram: 'C' -> 'A' (L), 'C' -> 'E' (R), 'A' -> 'B' (R), 'E' -> 'D1' (L), 'E' -> 'F' (R), 'D1' -> 'D2' (R)]
'Go Right'
'Is the root D2?'
[Tree diagram: 'E' -> 'D1' (L), 'E' -> 'F' (R), 'D1' -> 'D2' (R)]
'Go Left'
'Is the root D2?'
[Tree diagram: 'D1' -> 'D2' (R)]
'Go Right'
'Is the root D2?'
[Tree diagram: single node 'D2']
Out[66]:
True
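Because each step follows a single child, the recursion can also be written as a loop. This sketch (the name `bst_contains` and the step counter are illustrative additions, on a hypothetical `TreeNode` class) drops the display calls and reports how many nodes it checked; that count is never more than the number of levels in the tree.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def bst_contains(node, target):
    """Return (found, number_of_nodes_checked)."""
    checked = 0
    while node is not None:
        checked += 1
        if target == node.val:
            return True, checked
        # Only one subtree can hold target, so follow just that side.
        node = node.left if target < node.val else node.right
    return False, checked

# Same tree as above.
demo = TreeNode("C",
                TreeNode("A", None, TreeNode("B")),
                TreeNode("E", TreeNode("D1", None, TreeNode("D2")), TreeNode("F")))

print(bst_contains(demo, "D2"))  # (True, 4)
print(bst_contains(demo, "M"))   # (False, 3)
```

Unlike the unordered search, the miss for "M" stops after three checks (C, E, F) instead of visiting every node.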

Range Query

For the previous lookups, a Python set would probably do about as well. But what if we want to get all the values in some range?

We can write a similar function. But now we'll sometimes need to search both subtrees: whenever a node's value lies strictly inside the range, matches could be on either side.

Rather than returning found values, we can accumulate everything in a list.

In [67]:
def range_query(node, lower, upper, results=None):
    if results is None:
        results = []

    if node is None:
        return results
        
    display("Is the root %s between %s and %s" % (node.val, str(lower), str(upper)))
    if lower <= node.val <= upper:
        display("YES")
        results.append(node.val)
    else:
        display("NO")
        
    display(node)

    if lower < node.val:
        range_query(node.left, lower, upper, results)
    if upper > node.val:
        range_query(node.right, lower, upper, results)

    return results

range_query(root, "D1", "D9")
'Is the root C between D1 and D9'
'NO'
[Tree diagram: 'C' -> 'A' (L), 'C' -> 'E' (R), 'A' -> 'B' (R), 'E' -> 'D1' (L), 'E' -> 'F' (R), 'D1' -> 'D2' (R)]
'Is the root E between D1 and D9'
'NO'
[Tree diagram: 'E' -> 'D1' (L), 'E' -> 'F' (R), 'D1' -> 'D2' (R)]
'Is the root D1 between D1 and D9'
'YES'
[Tree diagram: 'D1' -> 'D2' (R)]
'Is the root D2 between D1 and D9'
'YES'
[Tree diagram: single node 'D2']
Out[67]:
['D1', 'D2']
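The range_query function above checks each node before its subtrees, so results come back in visit order; they happen to be sorted here, but for the range "B" to "E" it would return ['C', 'B', 'E', 'D1', 'D2']. Moving the check between the two recursive calls makes it an in-order traversal, which returns matches in ascending order. A sketch, with the hypothetical name `range_query_sorted` and a minimal `TreeNode` class:

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def range_query_sorted(node, lower, upper, results=None):
    # In-order variant: left subtree, then this node, then right subtree.
    if results is None:
        results = []
    if node is None:
        return results
    if lower < node.val:  # smaller matches can only be on the left
        range_query_sorted(node.left, lower, upper, results)
    if lower <= node.val <= upper:
        results.append(node.val)
    if upper > node.val:  # larger matches can only be on the right
        range_query_sorted(node.right, lower, upper, results)
    return results

# Same tree as above.
demo = TreeNode("C",
                TreeNode("A", None, TreeNode("B")),
                TreeNode("E", TreeNode("D1", None, TreeNode("D2")), TreeNode("F")))

print(range_query_sorted(demo, "D1", "D9"))  # ['D1', 'D2']
print(range_query_sorted(demo, "B", "E"))    # ['B', 'C', 'D1', 'D2', 'E']
```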

Balancing

If all the leaves are at roughly the same depth, the time to check for a value will be O(log N), which is pretty great! But insertion order matters a lot. Let's consider these 20 numbers:

In [78]:
nums1 = list(range(20))
nums1
Out[78]:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
In [79]:
from numpy import random

nums2 = nums1[:] # shallow copy
random.seed(320)
random.shuffle(nums2)
nums2
Out[79]:
[17, 9, 4, 18, 12, 7, 0, 6, 11, 15, 16, 3, 8, 1, 14, 19, 5, 10, 2, 13]
In [80]:
tree1 = Node(nums1[0])
for num in nums1[1:]:
    add(tree1, num)
    
tree2 = Node(nums2[0])
for num in nums2[1:]:
    add(tree2, num)
In [83]:
# lookup will be very slow!
tree1
Out[83]:
[Tree diagram: a single chain in which each node n has only a right child n + 1: 0 -> 1 -> 2 -> ... -> 18 -> 19 (every edge R)]
In [84]:
# a bit better!
tree2
Out[84]:
[Tree diagram: 17 -> 9 (L), 17 -> 18 (R), 9 -> 4 (L), 9 -> 12 (R), 4 -> 0 (L), 4 -> 7 (R), 0 -> 3 (R), 3 -> 1 (L), 1 -> 2 (R), 7 -> 6 (L), 7 -> 8 (R), 6 -> 5 (L), 12 -> 11 (L), 12 -> 15 (R), 11 -> 10 (L), 15 -> 14 (L), 15 -> 16 (R), 14 -> 13 (L), 18 -> 19 (R)]
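How far from the best case are these trees? A binary tree of height h (counting edges) holds at most 2^(h+1) - 1 nodes, so N nodes need height at least floor(log2(N)), which is 4 for N = 20. Sorted insertion instead builds a chain of height N - 1. This sketch compares the two; the helpers `bst_add`, `height`, and `best_height` and the `TreeNode` class are hypothetical, and `bst_add` is an iterative variant of `add`.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def bst_add(node, val):
    # Iterative version of add(): walk down until we reach an empty slot.
    while True:
        if val == node.val:
            return  # no duplicates, as in add()
        if val < node.val:
            if node.left is None:
                node.left = TreeNode(val)
                return
            node = node.left
        else:
            if node.right is None:
                node.right = TreeNode(val)
                return
            node = node.right

def height(node):
    # Edges on the longest root-to-leaf path (a single node has height 0).
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))

def best_height(n):
    # floor(log2(n)) for n >= 1, computed exactly from the integer's bit length.
    return n.bit_length() - 1

sorted_tree = TreeNode(0)
for v in range(1, 20):
    bst_add(sorted_tree, v)

print(best_height(20))      # 4
print(height(sorted_tree))  # 19
```

For comparison, the shuffled tree2 has height 6, close to the best case of 4.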

Balance

The second tree is definitely a lot more balanced than the first. To measure this, we can look for openings (empty child positions) that are shallower than the deepest node: if one exists, that node could be moved closer to the root.

In [104]:
def nearest_open(node):
    if node is None:
        return 0
    return min(nearest_open(node.left), nearest_open(node.right)) + 1

def max_depth(node):
    if node is None or (node.left is None and node.right is None):
        return 0
    return 1 + max(max_depth(node.left), max_depth(node.right))

nearest_open(tree2), max_depth(tree2)
Out[104]:
(2, 6)
In [110]:
def is_balanced(node):
    return nearest_open(node) >= max_depth(node)
In [114]:
# test is_balanced
t = Node("B")
t.left = Node("A")
t.right = Node("C")
display(t)
is_balanced(t)
[Tree diagram: 'B' -> 'A' (L), 'B' -> 'C' (R)]
Out[114]:
True
In [115]:
t.right.right = Node("D")
display(t)
is_balanced(t)
[Tree diagram: 'B' -> 'A' (L), 'B' -> 'C' (R), 'C' -> 'D' (R)]
Out[115]:
True
In [116]:
t.right.right.right = Node("E")
display(t)
is_balanced(t)
[Tree diagram: 'B' -> 'A' (L), 'B' -> 'C' (R), 'C' -> 'D' (R), 'D' -> 'E' (R)]
Out[116]:
False

Conclusion

In this reading, we have seen that a binary tree is a BST (Binary Search Tree) if all the left descendants of a node have smaller values than the node, and all the right descendants have greater values. Binary search trees allow us to find values and ranges of values without checking every node.

In a perfectly balanced tree, looking up a single item is O(log N). A tree is balanced if there are no nodes that could be moved closer to the root.

Randomizing insertion order can improve balance. There are also algorithms (not covered) to rearrange trees as values are inserted, maintaining balance (perhaps within some tolerance).
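When all the values are known up front, one simple way to choose a good insertion order (a swapped-in technique, not one of the self-balancing algorithms mentioned above) is to insert the median of the sorted values first, then the median of each half, recursively. A sketch, assuming sorted input; `median_first`, `bst_add`, and `TreeNode` are hypothetical names, and `nearest_open`/`max_depth` are copied from above:

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for Node above)."""
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def bst_add(node, val):
    # Same rules as add(): smaller values go left, larger go right, duplicates ignored.
    while True:
        if val == node.val:
            return
        if val < node.val:
            if node.left is None:
                node.left = TreeNode(val)
                return
            node = node.left
        else:
            if node.right is None:
                node.right = TreeNode(val)
                return
            node = node.right

def nearest_open(node):
    # Copied from above: depth of the shallowest empty child position.
    if node is None:
        return 0
    return min(nearest_open(node.left), nearest_open(node.right)) + 1

def max_depth(node):
    # Copied from above: depth of the deepest node (the root is at depth 0).
    if node is None or (node.left is None and node.right is None):
        return 0
    return 1 + max(max_depth(node.left), max_depth(node.right))

def median_first(values):
    # For sorted values: the middle value, then the middle of each half, recursively.
    if not values:
        return []
    mid = len(values) // 2
    return [values[mid]] + median_first(values[:mid]) + median_first(values[mid + 1:])

order = median_first(list(range(20)))
tree = TreeNode(order[0])
for v in order[1:]:
    bst_add(tree, v)

print(order[:3])                            # [10, 5, 2]
print(nearest_open(tree), max_depth(tree))  # 4 4
```

Since nearest_open(tree) >= max_depth(tree), this tree passes the is_balanced test from above, even though the input values were sorted.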
