Maximum Sum BST in Binary Tree

Sanjeev Sharma
2 min read

Advertisement

Problem

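Return the maximum sum of keys over all subtrees of the given binary tree that are also binary search trees (BSTs). If no such subtree has a positive sum, return 0 (an empty tree counts as a BST with sum 0).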

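Key insight: use a post-order DFS in which each node returns (is_bst, min_val, max_val, subtree_sum) for its subtree. A node roots a BST exactly when both children are BSTs and left max < node.val < right min. In that case, add the children's sums to the node's value and update the running maximum.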

Solutions

// C++
// Best BST subtree sum seen so far; written by dfs() and reset by maxSumBST().
int maxBSTSum = 0;
// Post-order helper. Returns {isBST, min, max, sum} for the subtree rooted at
// node, and updates maxBSTSum whenever that subtree is a BST.
// An empty subtree yields {true, INT_MAX, INT_MIN, 0}. Those sentinels are never
// compared against a key: a node holding INT_MIN or INT_MAX would otherwise fail
// the strict "<" test against them and be wrongly rejected.
// Sums are accumulated in int; assumes subtree totals fit in int — confirm for
// the intended input range.
tuple<bool,int,int,int> dfs(TreeNode* node) {
    if (!node) return {true, INT_MAX, INT_MIN, 0};
    auto [lb, lmin, lmax, lsum] = dfs(node->left);
    auto [rb, rmin, rmax, rsum] = dfs(node->right);
    // An absent child imposes no bound on node->val, so skip its sentinel.
    bool leftOk  = !node->left  || (lb && lmax < node->val);
    bool rightOk = !node->right || (rb && node->val < rmin);
    if (leftOk && rightOk) {
        int sum = lsum + rsum + node->val;
        maxBSTSum = max(maxBSTSum, sum);
        int lo = node->left  ? lmin : node->val;
        int hi = node->right ? rmax : node->val;
        return {true, lo, hi, sum};
    }
    // Not a BST: ancestors test isBST first, so the other fields are unused.
    return {false, 0, 0, 0};
}
// Public entry point: returns the largest key sum of any BST subtree of root.
// dfs() accumulates the answer in the global maxBSTSum, so that global is
// zeroed first to avoid carrying over a result from an earlier call.
int maxSumBST(TreeNode* root) {
    maxBSTSum = 0;
    (void)dfs(root);  // only the side effect on maxBSTSum is needed here
    return maxBSTSum;
}
// Java
// Best BST subtree sum found by dfs(); reset at the start of every maxSumBST call.
int maxBSTSum = 0;
// Returns the maximum key sum of any BST subtree of root (never below 0).
public int maxSumBST(TreeNode root) {
    maxBSTSum = 0; // clear any result left over from a previous call on this instance
    dfs(root);
    return maxBSTSum;
}
// Post-order helper. Returns [isBST (1/0), min, max, sum] for the subtree at
// node, and updates the maxBSTSum field whenever that subtree is a BST.
// An empty subtree yields [1, Integer.MAX_VALUE, Integer.MIN_VALUE, 0]. Those
// sentinels are skipped rather than compared: a key equal to Integer.MIN_VALUE
// or Integer.MAX_VALUE would otherwise fail the strict "<" test and be wrongly
// rejected.
// Sums use int arithmetic; assumes subtree totals fit in int — confirm for the
// intended input range.
int[] dfs(TreeNode node) {
    if (node == null) return new int[]{1, Integer.MAX_VALUE, Integer.MIN_VALUE, 0};
    int[] l = dfs(node.left), r = dfs(node.right);
    // A missing child imposes no bound on node.val.
    boolean leftOk = node.left == null || (l[0] == 1 && l[2] < node.val);
    boolean rightOk = node.right == null || (r[0] == 1 && node.val < r[1]);
    if (leftOk && rightOk) {
        int sum = l[3] + r[3] + node.val;
        maxBSTSum = Math.max(maxBSTSum, sum);
        int lo = node.left == null ? node.val : l[1];
        int hi = node.right == null ? node.val : r[2];
        return new int[]{1, lo, hi, sum};
    }
    // Not a BST: ancestors check the isBST slot first, so the other slots are unused.
    return new int[]{0, 0, 0, 0};
}
// JavaScript
/**
 * Returns the maximum key sum over all subtrees of `root` that are BSTs
 * (left keys strictly smaller, right keys strictly larger). The result is at
 * least 0 because the running maximum starts at 0.
 *
 * Uses an explicit-stack post-order walk instead of recursion, so deep or
 * skewed trees cannot exhaust the JavaScript call stack.
 *
 * @param {?{val: number, left: ?Object, right: ?Object}} root tree root (may be null)
 * @returns {number} largest BST subtree sum found, or 0
 */
function maxSumBST(root) {
    // Summary of an empty subtree: [isBST, min, max, sum]. The infinities make
    // the strict-inequality checks succeed for any finite key.
    const EMPTY = [true, Infinity, -Infinity, 0];
    const done = new Map(); // node -> summary, held until its parent consumes it
    const take = (child) => {
        if (!child) return EMPTY;
        const summary = done.get(child);
        done.delete(child);
        return summary;
    };

    let maxSum = 0;
    const stack = root ? [[root, false]] : [];
    while (stack.length > 0) {
        const [node, childrenDone] = stack.pop();
        if (!childrenDone) {
            // Revisit this node once both children have been summarised.
            stack.push([node, true]);
            if (node.right) stack.push([node.right, false]);
            if (node.left) stack.push([node.left, false]);
            continue;
        }
        const [lb, lmin, lmax, lsum] = take(node.left);
        const [rb, rmin, rmax, rsum] = take(node.right);
        if (lb && rb && lmax < node.val && node.val < rmin) {
            const sum = lsum + rsum + node.val;
            maxSum = Math.max(maxSum, sum);
            done.set(node, [true, Math.min(lmin, node.val), Math.max(rmax, node.val), sum]);
        } else {
            // Ancestors test isBST first, so the remaining fields are unused.
            done.set(node, [false, 0, 0, 0]);
        }
    }
    return maxSum;
}
# Python
def maxSumBST(root):
    """Return the maximum key sum over all subtrees of ``root`` that are BSTs.

    A subtree qualifies when every key in a node's left subtree is strictly
    smaller than the node's key and every key in its right subtree is strictly
    larger. The result is never below 0, because the running maximum starts at
    0 (an empty tree counts as a BST with sum 0).

    The traversal is an explicit-stack post-order walk rather than recursion,
    so very deep (e.g. skewed) trees do not hit Python's recursion limit.

    Args:
        root: tree root, or None. Nodes must expose ``val``, ``left`` and
            ``right`` attributes.

    Returns:
        The largest BST subtree sum found, or 0.
    """
    # Summary of an empty subtree: (is_bst, min, max, sum). The infinities
    # make the strict-inequality check succeed for any finite key.
    empty = (True, float('inf'), float('-inf'), 0)
    done = {}  # id(node) -> summary, held until the parent consumes it

    def take(child):
        if child is None:
            return empty
        return done.pop(id(child))

    best = 0
    stack = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children have been summarised.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        lb, lmin, lmax, lsum = take(node.left)
        rb, rmin, rmax, rsum = take(node.right)
        if lb and rb and lmax < node.val < rmin:
            s = lsum + rsum + node.val
            best = max(best, s)
            done[id(node)] = (True, min(lmin, node.val), max(rmax, node.val), s)
        else:
            # Ancestors test is_bst first, so the other fields are unused.
            done[id(node)] = (False, 0, 0, 0)
    return best

Complexity

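  • Time: O(n), since each node is processed once.
  • Space: O(h) for the traversal stack, where h is the tree height (O(n) for a skewed tree).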

Key Insight

Propagate each subtree's (min, max) bounds upward so that every node can check the BST property in O(1) time. The whole answer is then computed in a single pass over the tree.

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro