Maximum Sum BST in Binary Tree
Advertisement
Problem
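Given a binary tree, return the maximum sum of node values over all of its subtrees that are binary search trees (BSTs). The empty subtree counts as a BST with sum 0, so the answer is 0 when every BST subtree has a negative sum.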
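Key insight: use a post-order DFS in which each node returns (is_bst, min_val, max_val, subtree_sum). A node roots a BST exactly when both child subtrees are BSTs and left max < node.val < right min. In that case its subtree sum is left sum + right sum + node.val, which is compared against the best sum found so far.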
Solutions
// C++
// Best BST-subtree sum seen so far; maxSumBST resets it to 0 before each traversal.
int maxBSTSum = 0;
// Post-order helper. Returns {isBST, min, max, sum} for the subtree rooted at node.
// An empty subtree is reported as a BST with min = INT_MAX, max = INT_MIN and sum 0,
// so min()/max() below fold it away. When the subtree is not a BST, the remaining
// fields are not meaningful (returned as 0).
// Side effect: raises maxBSTSum to the sum of every BST subtree it finds.
// NOTE(review): sums are accumulated in int; assumes subtree sums fit in int.
tuple<bool,int,int,int> dfs(TreeNode* node) {
    if (!node) return {true, INT_MAX, INT_MIN, 0};
    auto [lb, lmin, lmax, lsum] = dfs(node->left);
    auto [rb, rmin, rmax, rsum] = dfs(node->right);
    // Compare against a child's bounds only when that child exists. Otherwise the
    // empty-subtree sentinels INT_MIN/INT_MAX would reject a node whose own value is
    // INT_MIN (no left child) or INT_MAX (no right child).
    bool leftOk = lb && (!node->left || lmax < node->val);
    bool rightOk = rb && (!node->right || node->val < rmin);
    if (leftOk && rightOk) {
        int sum = lsum + rsum + node->val;
        maxBSTSum = max(maxBSTSum, sum);
        return {true, min(lmin, node->val), max(rmax, node->val), sum};
    }
    return {false, 0, 0, 0};
}
// Entry point: clears the running best, walks the tree once via dfs, and reports
// the largest BST-subtree sum found (0 when none is positive, since the best starts at 0).
int maxSumBST(TreeNode* root) {
    maxBSTSum = 0;      // discard any result left over from a previous call
    (void)dfs(root);    // only dfs's update of maxBSTSum is needed here
    return maxBSTSum;
}
// Java
// Best BST-subtree sum found by the current traversal.
int maxBSTSum = 0;
// Returns the largest key sum over all subtrees of root that are BSTs; 0 if no such sum is positive.
public int maxSumBST(TreeNode root) {
    // Reset first so a reused instance does not report a best sum left over from an earlier tree.
    maxBSTSum = 0;
    dfs(root);
    return maxBSTSum;
}
// Post-order helper. Returns [isBST (1/0), min, max, sum] for the subtree rooted at node.
// An empty subtree is a BST with min = Integer.MAX_VALUE, max = Integer.MIN_VALUE and sum 0.
// When the subtree is not a BST, the remaining fields are not meaningful (returned as 0).
// Side effect: raises maxBSTSum to the sum of every BST subtree it finds.
int[] dfs(TreeNode node) {
    if (node == null) return new int[]{1, Integer.MAX_VALUE, Integer.MIN_VALUE, 0};
    int[] l = dfs(node.left), r = dfs(node.right);
    // Compare against a child's bounds only when that child exists. Otherwise the
    // empty-subtree sentinels would reject a node whose own value is Integer.MIN_VALUE
    // (no left child) or Integer.MAX_VALUE (no right child).
    boolean leftOk = l[0] == 1 && (node.left == null || l[2] < node.val);
    boolean rightOk = r[0] == 1 && (node.right == null || node.val < r[1]);
    if (leftOk && rightOk) {
        int sum = l[3] + r[3] + node.val;
        maxBSTSum = Math.max(maxBSTSum, sum);
        return new int[]{1, Math.min(l[1], node.val), Math.max(r[2], node.val), sum};
    }
    return new int[]{0, 0, 0, 0};
}
// JavaScript
/**
 * Returns the largest sum of node values over all subtrees of `root` that are
 * binary search trees: every left descendant is strictly smaller and every right
 * descendant strictly larger, recursively. The result is at least 0, because the
 * empty subtree counts as a BST with sum 0.
 *
 * Nodes are read through `val`, `left` and `right` (falsy for a missing child).
 * The traversal is an iterative post-order walk with an explicit stack, so deep or
 * skewed trees do not exhaust the JavaScript call stack.
 *
 * @param {?Object} root - tree root, or null/undefined for an empty tree
 * @returns {number}
 */
function maxSumBST(root) {
    // Summary of an empty subtree: [isBST, min, max, sum].
    const EMPTY = [true, Infinity, -Infinity, 0];
    // Summaries of finished subtrees, held only until their parent consumes them.
    const done = new Map();
    // Fetch and discard a child's summary (EMPTY for a missing child).
    const take = (child) => {
        if (!child) return EMPTY;
        const summary = done.get(child);
        done.delete(child);
        return summary;
    };
    let maxSum = 0;
    // Entries are [node, childrenDone]. Each node is pushed twice so it is finalized
    // only after both of its children (post-order).
    const stack = root ? [[root, false]] : [];
    while (stack.length > 0) {
        const [node, childrenDone] = stack.pop();
        if (!childrenDone) {
            stack.push([node, true]);
            if (node.right) stack.push([node.right, false]);
            if (node.left) stack.push([node.left, false]);
            continue;
        }
        const [lb, lmin, lmax, lsum] = take(node.left);
        const [rb, rmin, rmax, rsum] = take(node.right);
        if (lb && rb && lmax < node.val && node.val < rmin) {
            const sum = lsum + rsum + node.val;
            maxSum = Math.max(maxSum, sum);
            done.set(node, [true, Math.min(lmin, node.val), Math.max(rmax, node.val), sum]);
        } else {
            done.set(node, [false, 0, 0, 0]);
        }
    }
    return maxSum;
}
# Python
def maxSumBST(root):
    """Return the largest sum of node values over all subtrees of ``root`` that are BSTs.

    A subtree is a BST when every value in its left subtree is strictly smaller than
    its root, every value in its right subtree is strictly larger, and both subtrees
    are BSTs themselves. The empty subtree counts as a BST with sum 0, so the result
    is never negative.

    Nodes are read through ``.val``, ``.left`` and ``.right`` (``None`` for a missing
    child). The traversal is an iterative post-order walk with an explicit stack, so
    deep or skewed trees do not hit Python's recursion limit.
    """
    # Summary of an empty subtree: (is_bst, min, max, sum).
    empty = (True, float('inf'), float('-inf'), 0)
    # Summaries of finished subtrees keyed by id(node), kept only until the parent consumes them.
    done = {}
    best = 0
    # Entries are (node, children_done). Each node is pushed twice so it is finalized
    # only after both of its children (post-order).
    stack = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            for child in (node.right, node.left):
                if child is not None:
                    stack.append((child, False))
            continue
        lb, lmin, lmax, lsum = done.pop(id(node.left)) if node.left is not None else empty
        rb, rmin, rmax, rsum = done.pop(id(node.right)) if node.right is not None else empty
        if lb and rb and lmax < node.val < rmin:
            s = lsum + rsum + node.val
            best = max(best, s)
            done[id(node)] = (True, min(lmin, node.val), max(rmax, node.val), s)
        else:
            done[id(node)] = (False, 0, 0, 0)
    return best
Complexity
- Time: O(n)
- Space: O(h), where h is the height of the tree (the depth of the traversal stack)
Key Insight
Propagating each subtree's (min, max) upward lets every node check the BST property in O(1) from its children's results, so the whole tree is validated in a single pass.
Advertisement