Recursion + Generators in Python

So here I am, solving a few problems on trees, when I stumble upon this one: Two Sum IV - Input is a BST.

One approach is to traverse the nodes, storing each value in a set() and checking along the way whether target - x (where x is the current node's value) is already in the set. This works, but it can't be adapted to find the closest pair sum when no pair adds up to the target exactly. In a sorted array, we could use the two-pointer approach instead: one pointer moves from left to right and the other from right to left, and which one moves next depends on how the sum of the values at the two pointers compares with the target. [Check this out for full solution]
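For reference, here's a rough sketch of both ideas on plain Python lists (the function names are mine, just for illustration): the one-pass set lookup, and a two-pointer scan over a sorted list that keeps track of the closest sum as it goes.

```python
from typing import Iterable, List, Optional


def has_pair_with_sum(values: Iterable[int], k: int) -> bool:
    """Hash-set approach: one pass, remembering every value seen so far."""
    seen = set()
    for x in values:
        if k - x in seen:  # complement already visited -> pair found
            return True
        seen.add(x)
    return False


def closest_pair_sum(sorted_vals: List[int], k: int) -> Optional[int]:
    """Two-pointer approach on a sorted list: the pair sum closest to k."""
    lo, hi = 0, len(sorted_vals) - 1
    best: Optional[int] = None
    while lo < hi:
        s = sorted_vals[lo] + sorted_vals[hi]
        if best is None or abs(s - k) < abs(best - k):
            best = s
        if s == k:
            return s  # exact match, can't do better
        elif s < k:
            lo += 1  # need a bigger sum: move the left pointer right
        else:
            hi -= 1  # need a smaller sum: move the right pointer left
    return best  # None if there are fewer than two values
```

For example, closest_pair_sum([1, 4, 9], 6) returns 5, something the set-based version has no natural way to report.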

But we don't have an array here, we have a Binary Search Tree (BST).

  • We know that an inorder traversal (left, node, right) of a BST visits the values in increasing, sorted order.
  • ... which also means that a reverse inorder traversal (right, node, left) visits them in decreasing order.
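To convince ourselves of both bullets, here's a quick list-building version of each traversal. TreeNode is assumed to look like LeetCode's usual node class, and the sample tree is just a small BST for illustration.

```python
class TreeNode:
    # Assumed to mirror LeetCode's usual binary tree node.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root):
    # left subtree, node, right subtree -> ascending for a BST
    if root is None:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)


def rev_inorder_values(root):
    # right subtree, node, left subtree -> descending for a BST
    if root is None:
        return []
    return rev_inorder_values(root.right) + [root.val] + rev_inorder_values(root.left)


#        5
#      /   \
#     3     6
#    / \     \
#   2   4     7
root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
print(inorder_values(root))      # [2, 3, 4, 5, 6, 7]
print(rev_inorder_values(root))  # [7, 6, 5, 4, 3, 2]
```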

If we could run both traversals at the same time, with control over which one advances next, we would effectively be implementing the two-pointer approach.

To achieve this, we could reach for the recursive inorder traversal. The problem is that a plain recursive function runs to completion once it's called, so it gives us no control over the traversal: we'd need to start, stop and resume it midway through the chain of function calls. The iterative inorder traversal (with an explicit stack) will certainly work, but it is a bit tedious and less intuitive than its recursive counterpart.
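For comparison, this is roughly what the iterative route could look like: an explicit stack that holds a node's chain of left children (or right children, in reverse mode), plus a loop that drives two such iterators as the two pointers. InorderIterator, next_node() and find_target_iterative() are hypothetical names of mine; this is just one way to write it.

```python
class TreeNode:
    # Assumed to mirror LeetCode's usual binary tree node.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class InorderIterator:
    """Iterative inorder with an explicit stack; reverse=True walks right-to-left."""

    def __init__(self, root, reverse=False):
        self.stack = []
        self.reverse = reverse
        self._push_chain(root)

    def _push_chain(self, node):
        # Push node and its leftmost (or rightmost, if reverse) descendants.
        while node is not None:
            self.stack.append(node)
            node = node.right if self.reverse else node.left

    def has_next(self):
        return bool(self.stack)

    def next_node(self):
        node = self.stack.pop()
        # After visiting node, continue with its "other" subtree.
        self._push_chain(node.left if self.reverse else node.right)
        return node


def find_target_iterative(root, k):
    if root is None:
        return False
    left, right = InorderIterator(root), InorderIterator(root, reverse=True)
    lp, rp = left.next_node(), right.next_node()
    while lp is not rp:  # stop when the two pointers meet
        s = lp.val + rp.val
        if s == k:
            return True
        elif s < k:
            lp = left.next_node()
        else:
            rp = right.next_node()
    return False


root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
print(find_target_iterative(root, 9))   # True
print(find_target_iterative(root, 28))  # False
```

It works, but all the bookkeeping the call stack would normally do for us is now spelled out by hand.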

That need to start, stop and resume a traversal midway seems to scream for generators. Well, let's try it.
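Before applying this to the tree, a minimal reminder of the behavior we're after (count_up() is just a toy example): calling a generator function hands back an object, and each next() runs the body up to the next yield and pauses there.

```python
def count_up(n):
    # Each yield pauses the function here and hands a value to the caller.
    for i in range(n):
        yield i


gen = count_up(3)  # nothing runs yet; we just get a generator object
print(next(gen))   # 0   (runs until the first yield, then pauses)
print(next(gen))   # 1   (resumes right after the previous yield)
print(list(gen))   # [2] (drains whatever is left)
```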

def inorder(root):
    if root.left:
        inorder(root.left)
    yield root
    if root.right:
        inorder(root.right)

def rev_inorder(root):
    if root.right:
        rev_inorder(root.right)
    yield root
    if root.left:
        rev_inorder(root.left)

class Solution:
    def findTarget(self, root: TreeNode, k: int) -> bool:
        l, r = inorder(root), rev_inorder(root)
        lp = next(l)
        rp = next(r)
        while lp is not rp:
            sm = lp.val + rp.val
            if sm == k:
                return True
            elif sm < k:
                lp = next(l)
            else:
                rp = next(r)
        return False

Here, l and r are generator objects, and lp and rp are the left and right pointers. The idea is that each yield hands control back to the caller, and the next node can be obtained (generated) with next() whenever we want.

But this... doesn't work! Why?

If you add some log statements to your recursive inorder(), you'll notice that the recursion isn't happening: inorder() doesn't seem to call itself at all.

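This is because calling a generator function doesn't run its body; it only creates a generator object. The recursive inorder() call inside inorder() is no different from the call in findTarget(): it builds a generator object that nobody ever iterates over, so the recursive call's body never executes and the nodes in that subtree are never yielded.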
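Here's a tiny toy reproduction of the problem (visit() and calls are made-up names), using a list as the "log": only the outermost call's body ever runs.

```python
calls = []


def visit(depth):
    calls.append(depth)   # side effect: record that this body actually ran
    if depth < 2:
        visit(depth + 1)  # on purpose: this only creates a generator object
    yield depth


print(list(visit(0)))  # [0]
print(calls)           # [0] -> the "recursive" calls never ran
```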

So, it isn't possible to implement it recursively? Not quite.

yield from to the rescue!

def inorder(root):
    if root.left:
        yield from inorder(root.left)
    yield root
    if root.right:
        yield from inorder(root.right)

def rev_inorder(root):
    if root.right:
        yield from rev_inorder(root.right)
    yield root
    if root.left:
        yield from rev_inorder(root.left)

Now this works!
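Putting the pieces together, here's a self-contained version you can run locally. The TreeNode class is an assumption modeled on LeetCode's usual definition (LeetCode supplies its own), the sample tree is made up, and the empty-tree guard in findTarget() is my addition.

```python
from typing import Iterator, Optional


class TreeNode:
    # Assumed to mirror LeetCode's usual binary tree node.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def inorder(root: TreeNode) -> Iterator[TreeNode]:
    # Ascending order: delegate to the left subtree, yield self, then the right.
    if root.left:
        yield from inorder(root.left)
    yield root
    if root.right:
        yield from inorder(root.right)


def rev_inorder(root: TreeNode) -> Iterator[TreeNode]:
    # Descending order: right subtree first, then self, then the left subtree.
    if root.right:
        yield from rev_inorder(root.right)
    yield root
    if root.left:
        yield from rev_inorder(root.left)


class Solution:
    def findTarget(self, root: Optional[TreeNode], k: int) -> bool:
        if root is None:  # added guard for an empty tree
            return False
        l, r = inorder(root), rev_inorder(root)
        lp, rp = next(l), next(r)
        while lp is not rp:  # the pointers meet before they can cross
            sm = lp.val + rp.val
            if sm == k:
                return True
            elif sm < k:
                lp = next(l)
            else:
                rp = next(r)
        return False


root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
print(Solution().findTarget(root, 9))   # True
print(Solution().findTarget(root, 28))  # False
```

One thing to keep in mind: each level of yield from is another live generator frame, so a very deep (skewed) tree can still run into Python's recursion limit, just like the ordinary recursive traversal.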

You can have a look at PEP 380 for further reading :)
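As a rough mental model for plain iteration, yield from inner behaves like looping over inner and re-yielding each item, as in the hypothetical inorder_loop() below. PEP 380 goes further: yield from also forwards send() and throw() to the subgenerator and evaluates to its return value, which the small sub()/outer() pair shows.

```python
class TreeNode:
    # Assumed to mirror LeetCode's usual binary tree node.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_loop(root):
    # Same traversal, delegating by hand instead of with `yield from`.
    if root.left:
        for node in inorder_loop(root.left):
            yield node
    yield root
    if root.right:
        for node in inorder_loop(root.right):
            yield node


def sub():
    yield 1
    yield 2
    return "done"


def outer():
    result = yield from sub()  # receives sub()'s return value
    yield result


root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
print([n.val for n in inorder_loop(root)])  # [2, 3, 4, 5, 6, 7]
print(list(outer()))                        # [1, 2, 'done']
```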