Tree to greater sum
This post is part of the Algorithms Problem Solving series.
Problem description
This is the Binary Search Tree to Greater Sum Tree problem. The description looks like this:
Given the root of a binary search tree with distinct values, modify it so that every node has a new value equal to the sum of the values of the original tree that are greater than or equal to node.val.
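To make the definition concrete, here is a brute-force sketch. It works on a plain list of the tree's values rather than on the tree itself, which is an assumption made only for illustration. Each value maps to the sum of all values greater than or equal to it. It runs in quadratic time, so it is only useful as a reference for checking faster solutions.

```python
def greater_sums_brute_force(values):
    """Map each value to the sum of all values >= it (O(n^2) reference)."""
    return {v: sum(w for w in values if w >= v) for v in values}


# The node values of the example tree below, in level order.
example_values = [4, 1, 6, 0, 2, 5, 7, 3, 8]
print(greater_sums_brute_force(example_values)[4])  # 30
```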
As a reminder, a binary search tree is a tree that satisfies these constraints:
- The left subtree of a node contains only nodes with keys less than the node's key.
- The right subtree of a node contains only nodes with keys greater than the node's key.
- Both the left and right subtrees must also be binary search trees.
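These three constraints can be checked recursively by passing value bounds down the tree. The sketch below assumes a minimal TreeNode class with val, left, and right attributes. The is_bst name is hypothetical.

```python
class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_bst(node, low=float("-inf"), high=float("inf")):
    """Check the BST constraints by passing value bounds down the tree."""
    if node is None:
        return True
    # Every key must lie strictly between the bounds inherited from its ancestors.
    if not (low < node.val < high):
        return False
    # Left subtree: keys below node.val; right subtree: keys above node.val.
    return is_bst(node.left, low, node.val) and is_bst(node.right, node.val, high)


valid = TreeNode(4, TreeNode(1), TreeNode(6))
invalid = TreeNode(4, TreeNode(1, None, TreeNode(5)), TreeNode(6))
print(is_bst(valid), is_bst(invalid))  # True False
```

Comparing each node only with its parent is not enough. In the invalid tree, 5 is a valid right child of 1, but it still sits in the left subtree of 4.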
Examples
Input: [4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
Output: [30,36,21,36,35,26,15,null,null,null,33,null,null,null,8]
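The Input and Output lists appear to use the LeetCode-style level-order encoding. Nodes are listed level by level and null marks a missing child. Children of missing nodes are not listed, and trailing nulls are dropped. Below is a sketch of helpers that convert between such a list and a tree. The names build_tree and to_level_order are hypothetical, and Python's None stands in for null.

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(items):
    """Build a tree from a level-order list where None marks a missing child."""
    if not items or items[0] is None:
        return None
    root = TreeNode(items[0])
    queue = deque([root])
    i = 1
    while queue and i < len(items):
        node = queue.popleft()
        # Consume up to two entries: the left child, then the right child.
        if i < len(items) and items[i] is not None:
            node.left = TreeNode(items[i])
            queue.append(node.left)
        i += 1
        if i < len(items) and items[i] is not None:
            node.right = TreeNode(items[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root):
    """Serialize a tree back to a level-order list, trimming trailing Nones."""
    result, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


example = [4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8]
print(to_level_order(build_tree(example)) == example)  # True
```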
Solution
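My first approach was to traverse the tree to compute the sum of all node values and to collect the values themselves. Because the traversal is in order (left subtree, node, right subtree), the values of a binary search tree come out in ascending order.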
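def sum_and_list(node, total, values):
    # In-order traversal: left subtree, node, right subtree.
    # Returns [sum of the subtree's values, the subtree's values in ascending order].
    # The top-level call passes total=0 and values=[], which are passed down unchanged.
    left_total = 0
    right_total = 0
    left_values = []
    right_values = []
    if node.left:
        [left_total, left_values] = sum_and_list(node.left, total, values)
    if node.right:
        [right_total, right_values] = sum_and_list(node.right, total, values)
    return [
        total + left_total + node.val + right_total,
        values + left_values + [node.val] + right_values
    ]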
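The same first pass can be sketched more compactly with a generator. This is an alternative, not the post's code. It assumes a minimal TreeNode class, and in_order_values and sum_and_sorted_values are hypothetical names. The in-order generator yields values in ascending order, and sum() then gives the total.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def in_order_values(node):
    """Yield the values of a BST in ascending order (left, node, right)."""
    if node is None:
        return
    yield from in_order_values(node.left)
    yield node.val
    yield from in_order_values(node.right)


def sum_and_sorted_values(root):
    values = list(in_order_values(root))
    return sum(values), values


# The example tree: [4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)
total, values = sum_and_sorted_values(root)
print(total, values)  # 36 [0, 1, 2, 3, 4, 5, 6, 7, 8]
```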
Then I built a hash map from each node value to its greater sum. For the example above, the map looks like this:
{
    0: 36,
    1: 36,
    2: 35,
    3: 33,
    4: 30,
    5: 26,
    6: 21,
    7: 15,
    8: 8,
}
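The pattern in the map follows from the values being distinct. The values greater than or equal to v are all the values except those strictly smaller than v. Writing T for the total and G(v) for the greater sum (notation introduced here only for illustration):

```latex
G(v) = T - \sum_{w < v} w, \qquad
G(3) = 36 - (0 + 1 + 2) = 33, \qquad
G(8) = 36 - (0 + 1 + \dots + 7) = 36 - 28 = 8
```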
Building the hash map is simple: walk the sorted values and map each one to the total minus the sum of the values smaller than it.
smaller_total = 0
mapper = {}
for value in values:
    mapper[value] = total - smaller_total
    smaller_total += value
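The same loop can also be written with itertools.accumulate, which yields running sums. With initial=0 (available since Python 3.8), the first running sum is 0, so each value is paired with the sum of the values before it. This sketch assumes values is the sorted list from the first pass, and build_mapper is a hypothetical name.

```python
from itertools import accumulate


def build_mapper(values):
    """Map each sorted value to the total minus the sum of the smaller values."""
    total = sum(values)
    # accumulate(values, initial=0) yields 0, v0, v0+v1, ...; zip stops after len(values) pairs.
    smaller_sums = accumulate(values, initial=0)
    return {value: total - smaller for value, smaller in zip(values, smaller_sums)}


mapper = build_mapper([0, 1, 2, 3, 4, 5, 6, 7, 8])
print(mapper[0], mapper[4], mapper[8])  # 36 30 8
```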
Now I can use this hash map to modify each tree node. I traverse the tree again, replace each node's value with the one stored in the hash map, and return the root of the modified tree.
def modify_helper(node, mapper):
    if node.left:
        modify_helper(node.left, mapper)
    if node.right:
        modify_helper(node.right, mapper)
    node.val = mapper[node.val]
    return node
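Each lookup depends only on a node's own original value, so the order in which the second pass visits nodes does not matter. As an illustration, the same update can be done with an explicit stack. This is an alternative to the post's recursive helper; modify_iterative is a hypothetical name, and TreeNode is an assumed minimal node class.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def modify_iterative(root, mapper):
    """Replace every node value with mapper[value] using an explicit stack."""
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        # Each node is visited once, so its original value is looked up exactly once.
        node.val = mapper[node.val]
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    return root


root = TreeNode(1, TreeNode(0), TreeNode(2))
modify_iterative(root, {0: 3, 1: 3, 2: 2})
print(root.val, root.left.val, root.right.val)  # 3 3 2
```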
The bst_to_gst function calls all the pieces we built and returns the modified root.
def bst_to_gst(root):
    [total, values] = sum_and_list(root, 0, [])
    smaller_total = 0
    mapper = {}
    for value in values:
        mapper[value] = total - smaller_total
        smaller_total += value
    return modify_helper(root, mapper)
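The hash map can be dropped with a small variation, sketched here as an alternative rather than the post's method. A second in-order traversal visits nodes in the same ascending order in which the values were collected. Each node can therefore receive the total minus a running sum of the smaller values seen so far. The name bst_to_gst_prefix is hypothetical, and TreeNode is an assumed minimal node class.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_to_gst_prefix(root):
    """Sum all values, then assign total - (sum of smaller values) in order."""
    def total_of(node):
        if node is None:
            return 0
        return total_of(node.left) + node.val + total_of(node.right)

    total = total_of(root)
    smaller_total = 0  # sum of the original values visited so far (all smaller)

    def assign(node):
        nonlocal smaller_total
        if node is None:
            return
        assign(node.left)
        original = node.val  # keep the old value before overwriting it
        node.val = total - smaller_total
        smaller_total += original
        assign(node.right)

    assign(root)
    return root


root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)
bst_to_gst_prefix(root)
print(root.val, root.left.right.right.val)  # 30 33
```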
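We could also use a reverse in-order traversal (right subtree, node, left subtree). The algorithm starts at the rightmost node, which holds the largest value, and moves to the left, so nodes are visited in descending order. As we traverse the tree, we keep a running sum of the values visited so far and add it to each node.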
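def bst_to_gst(root):
    # Running sum of every value visited so far (all greater than the current node).
    value = 0

    def visit(node):
        # nonlocal makes the assignment below update value instead of a new local.
        nonlocal value
        if node.right:
            visit(node.right)
        node.val = node.val + value
        value = node.val
        if node.left:
            visit(node.left)

    visit(root)
    return root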
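Recursive traversals use one stack frame per tree level, so a very deep, unbalanced tree could exceed Python's default recursion limit. Below is a sketch of the same reverse in-order traversal with an explicit stack. The name bst_to_gst_iterative and the minimal TreeNode class are assumptions for illustration.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_to_gst_iterative(root):
    """Reverse in-order traversal (right, node, left) with an explicit stack."""
    running_sum = 0
    stack = []
    node = root
    while stack or node:
        # Walk as far right as possible: larger values are visited first.
        while node:
            stack.append(node)
            node = node.right
        node = stack.pop()
        running_sum += node.val
        node.val = running_sum
        # Then move on to the smaller values in the left subtree.
        node = node.left
    return root


root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)
bst_to_gst_iterative(root)
print(root.val, root.left.val, root.right.right.right.val)  # 30 36 8
```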
Resources
- Learning Python: From Zero to Hero
- Algorithms Problem Solving Series
- Stack Data Structure
- Queue Data Structure
- Linked List
- Tree Data Structure
Have fun, keep learning, and always keep coding!