This question seems straightforward on the surface, but an efficient solution is tricky to implement in Java because of Java's pass-by-value semantics. Read on to see why!

Question:

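Given a BST of integer elements, return the nth smallest element (that is, the nth element visited by an in-order traversal) as quickly as possible. Counting starts at 1, not 0. The first line of the input holds n, and each remaining line holds a value to insert into the tree, in order. For the sample below, the values in sorted order are -1, 2, 3, 7, 15, 16, so the 4th element is 7.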

Input (input.txt):

    4
    3
    -1
    15
    2
    7
    16

Output (stdout):

    7

Solution:

A very naive solution is to traverse the entire tree in order, appending each element to an ArrayList, and then return the element at index n - 1 of that list. This is simple to implement, but it uses extra memory proportional to the size of the tree and wastes time when n is small compared to the size of the tree (although early termination, i.e. stopping once n elements have been collected, improves the run-time).
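
A minimal sketch of the naive approach, together with the early-termination variant mentioned above. To stay self-contained it defines its own hypothetical `NaiveNth` class and `Node` type instead of reusing the `BSTNode` class from the solution; like that class, it ignores duplicate values.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch of the naive approach; class and method names are hypothetical.
class NaiveNth {

    // Minimal BST node.
    static final class Node {
        final int value;
        Node left, right;
        Node(int value) { this.value = value; }
    }

    // Standard recursive BST insertion (duplicates ignored); returns the subtree root.
    static Node insert(Node root, int value) {
        if (root == null) return new Node(value);
        if (value < root.value) root.left = insert(root.left, value);
        else if (value > root.value) root.right = insert(root.right, value);
        return root;
    }

    // Builds a BST by inserting the values in the given order.
    static Node build(int... values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        return root;
    }

    // Appends every value of the subtree to 'out' in ascending order.
    static void inOrder(Node node, List<Integer> out) {
        if (node == null) return;
        inOrder(node.left, out);
        out.add(node.value);
        inOrder(node.right, out);
    }

    // Naive version: visits every node, then indexes into the list (1-based n).
    static Integer nthSmallest(Node root, int n) {
        List<Integer> values = new ArrayList<>();
        inOrder(root, values);
        return (n >= 1 && n <= values.size()) ? values.get(n - 1) : null;
    }

    // Like inOrder, but stops as soon as 'out' holds 'limit' values.
    static void collectFirst(Node node, List<Integer> out, int limit) {
        if (node == null || out.size() >= limit) return;
        collectFirst(node.left, out, limit);
        if (out.size() >= limit) return;
        out.add(node.value);
        collectFirst(node.right, out, limit);
    }

    // Early-termination version: stores at most n values.
    static Integer nthSmallestEarlyStop(Node root, int n) {
        if (n < 1) return null;
        List<Integer> values = new ArrayList<>();
        collectFirst(root, values, n);
        return values.size() == n ? values.get(n - 1) : null;
    }

    public static void main(String[] args) {
        Node root = build(3, -1, 15, 2, 7, 16);
        System.out.println(nthSmallest(root, 4));          // prints 7
        System.out.println(nthSmallestEarlyStop(root, 4)); // prints 7
    }
}
```

The early-stopping variant quietly avoids the pass-by-value problem because the shared `List` object itself plays the role of the counter, but it still stores up to n values.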

A better idea is to perform the recursive in-order traversal while keeping a single counter that every level of the recursion stack can modify. Because Java passes primitives by value, decrementing an int parameter inside a recursive call would not be visible to its caller, so we define a small mutable wrapper object (Counter) and pass a reference to it instead, which simulates pass-by-reference. The counter starts at n and is decremented on each in-order visit; as soon as it reaches 0, the current node is the nth element, and we return it immediately without visiting the rest of the tree. One final trick: the method returns the Integer wrapper class rather than int, so it can return null when the nth element is not found in a subtree.
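
The pass-by-value issue can be seen in isolation with a small sketch before the full solution. The class `PassByValueDemo` and its methods are illustrative names, not part of the solution: decrementing an `int` parameter changes only the callee's local copy, whereas decrementing a field of an object that the caller also holds is visible after the call returns.

```java
// Illustrative only: why an int parameter cannot act as a shared counter.
class PassByValueDemo {

    // Mutable holder, playing the same role as the solution's Counter class.
    static final class Counter {
        int value;
        Counter(int value) { this.value = value; }
    }

    // Decrements the callee's own copy of the int; the caller's variable is unchanged.
    static void decrementInt(int count) {
        count--;
    }

    // Decrements a field of an object the caller also references; the change is visible.
    static void decrementCounter(Counter counter) {
        counter.value--;
    }

    public static void main(String[] args) {
        int plain = 4;
        decrementInt(plain);
        System.out.println(plain);        // prints 4

        Counter shared = new Counter(4);
        decrementCounter(shared);
        System.out.println(shared.value); // prints 3
    }
}
```

In the recursive search, each call would receive its own copy of an `int` counter, so decrements made while visiting a left subtree would be lost when control returned to the parent; sharing one `Counter` object avoids that.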

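    import java.io.FileReader;
    import java.io.IOException;
    import java.util.Scanner;

    // The class name is arbitrary; input is read from src/input.txt.
    public class Main {

        public static void main(String[] args) throws IOException {
            // try-with-resources closes the Scanner and the underlying FileReader
            try (Scanner in = new Scanner(new FileReader("src/input.txt"))) {
                // First value: n, the 1-based rank of the element to find
                int n = in.nextInt();
                // Remaining values: inserted into the BST in input order
                BSTNode root = new BSTNode(in.nextInt());
                while (in.hasNextInt()) {
                    root.insert(new BSTNode(in.nextInt()));
                }
                // Prints "null" if n is out of range
                System.out.println(findNthElement(root, new Counter(n)));
            }
        }

        // Returns the value of the node at which the shared counter reaches 0 during
        // an in-order traversal of this subtree, or null if it does not reach 0 here.
        // Called on the root with new Counter(n), this yields the nth smallest value
        // in O(h + n) time, where h is the height of the tree.
        static Integer findNthElement(BSTNode node, Counter counter) {
            // Reached the end of a subtree: return null to signal "not found in subtree"
            if (node == null) {
                return null;
            }
            // Traverse the left subtree first (smaller values)
            Integer result = findNthElement(node.left, counter);
            // Propagate the result if it was found in the left subtree
            if (result != null) {
                return result;
            }
            // In-order visit: this node is the next-smallest value
            counter.value--;
            // This node is the nth element once the counter reaches 0
            if (counter.value == 0) {
                return node.value;
            }
            // Otherwise continue with the right subtree (larger values)
            return findNthElement(node.right, counter);
        }

        // Mutable holder so that every recursive call shares and updates the same count
        static class Counter {
            int value;

            public Counter(int value) {
                this.value = value;
            }
        }

        static class BSTNode {
            BSTNode left;
            BSTNode right;
            int value;

            public BSTNode(int value) {
                this.value = value;
            }

            // Standard BST insertion algorithm; duplicate values are ignored
            public void insert(BSTNode node) {
                if (node.value < value) {
                    if (left == null) {
                        left = node;
                    } else {
                        left.insert(node);
                    }
                } else if (node.value > value) {
                    if (right == null) {
                        right = node;
                    } else {
                        right.insert(node);
                    }
                }
            }
        }
    }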