Tree to greater sum

Photo by Jérôme Prax

This post is part of the Algorithms Problem Solving series.

Problem description

This is the Binary Search Tree to Greater Sum Tree problem. The description looks like this:

Given the root of a binary search tree with distinct values, modify it so that every node has a new value equal to the sum of the values of the original tree that are greater than or equal to node.val.

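As a reminder, a binary search tree is a tree that satisfies these constraints:

- The left subtree of a node contains only nodes with keys less than the node's key.
- The right subtree of a node contains only nodes with keys greater than the node's key.
- Both the left and right subtrees must also be binary search trees.

Example: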
Input: [4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
Output: [30,36,21,36,35,26,15,null,null,null,33,null,null,null,8]
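To try the solutions locally you need a node type and a way to turn these bracketed lists into a tree. The sketch below assumes the usual LeetCode level-order format, where null (None in Python) marks a missing child and trailing nulls are dropped. TreeNode, build_tree, and to_list are hypothetical helpers of mine, not part of the problem statement.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(items: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not items or items[0] is None:
        return None
    root = TreeNode(items[0])
    queue = deque([root])
    i = 1
    while queue and i < len(items):
        node = queue.popleft()
        # The next slot is the left child, the one after it the right child.
        if i < len(items) and items[i] is not None:
            node.left = TreeNode(items[i])
            queue.append(node.left)
        i += 1
        if i < len(items) and items[i] is not None:
            node.right = TreeNode(items[i])
            queue.append(node.right)
        i += 1
    return root


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize back to the same level-order format, trimming trailing Nones."""
    result = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


example = [4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8]
print(to_list(build_tree(example)) == example)  # True
```

Round-tripping the input list is a quick way to confirm the helpers agree with the format before using them on real solutions.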


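My first approach was to traverse the tree once to get the sum of all node values and a list of all the values. Because the traversal is in-order (left subtree, node, right subtree), the list comes out sorted in ascending order.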

# In-order traversal that returns [sum of the subtree's values, the subtree's values in ascending order].
# The caller starts with total = 0 and values = [].
def sum_and_list(node, total, values):
    left_total = 0
    right_total = 0
    left_values = []
    right_values = []

    if node.left:
        [left_total, left_values] = sum_and_list(node.left, total, values)

    if node.right:
        [right_total, right_values] = sum_and_list(node.right, total, values)

    return [
        total + left_total + node.val + right_total,
        values + left_values + [node.val] + right_values
    ]
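sum_and_list concatenates the left values, the node value, and then the right values, which is exactly an in-order traversal. The sketch below shows the same idea more compactly with a generator, to make it visible that a BST yields its values in ascending order. The in_order name and the hand-built example tree are my own assumptions for illustration.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def in_order(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield left subtree, then node, then right subtree: ascending for a BST."""
    if node is None:
        return
    yield from in_order(node.left)
    yield node.val
    yield from in_order(node.right)


# The example tree [4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)

values = list(in_order(root))
total = sum(values)
print(total, values)  # 36 [0, 1, 2, 3, 4, 5, 6, 7, 8]
```

The ascending order is what the hash map step relies on: it walks the values from smallest to largest.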

Then I built a hash map from each node value to its greater sum, the sum of all values greater than or equal to it. For the example above, it looks like this:

{
  0: 36,
  1: 36,
  2: 35,
  3: 33,
  4: 30,
  5: 26,
  6: 21,
  7: 15,
  8: 8
}
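Each entry can be checked against the problem's definition read literally: for a value v, add up every value w with w >= v. A minimal sketch of that brute-force check over a plain list of values (greater_sums is a hypothetical name); it is O(n^2), so it is only useful as a reference to compare faster versions against.

```python
def greater_sums(values):
    """Map each value to the sum of all values >= it (literal, O(n^2) reading)."""
    return {v: sum(w for w in values if w >= v) for v in values}


# Node values of the example tree, in level order.
example_values = [4, 1, 6, 0, 2, 5, 7, 3, 8]
print(greater_sums(example_values))
# {4: 30, 1: 36, 6: 21, 0: 36, 2: 35, 5: 26, 7: 15, 3: 33, 8: 8}
```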

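The hash map creation algorithm is pretty simple: walk the values in ascending order. The greater sum of a value is the total minus the sum of all values smaller than it, so we keep that smaller sum in a running smaller_total.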

smaller_total = 0
mapper = {}

for value in values:
    mapper[value] = total - smaller_total
    smaller_total += value
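As a sanity check, the loop above reproduces the table for the example values. The sketch also shows an equivalent formulation I find handy: running sums over the values in descending order (suffix sums), computed with itertools.accumulate. The suffix-sum variant is my alternative, not the approach used in this post.

```python
from itertools import accumulate

values = [0, 1, 2, 3, 4, 5, 6, 7, 8]  # ascending, as produced by the in-order traversal
total = sum(values)

# The loop from the post: greater sum = total minus the sum of strictly smaller values.
smaller_total = 0
mapper = {}
for value in values:
    mapper[value] = total - smaller_total
    smaller_total += value

# Equivalent: running sums over the values from largest to smallest.
descending = values[::-1]
suffix_mapper = dict(zip(descending, accumulate(descending)))

print(mapper)  # {0: 36, 1: 36, 2: 35, 3: 33, 4: 30, 5: 26, 6: 21, 7: 15, 8: 8}
print(mapper == suffix_mapper)  # True
```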

Now I can use this hash map to modify each tree node: I traverse the tree again, replace each node's value with its entry in the hash map, and return the root of the modified tree.

def modify_helper(node, mapper):
    if node.left:
        modify_helper(node.left, mapper)

    if node.right:
        modify_helper(node.right, mapper)

    node.val = mapper[node.val]
    return node

The bst_to_gst function calls all the pieces we built and returns the modified root.

def bst_to_gst(root):
    [total, values] = sum_and_list(root, 0, [])

    smaller_total = 0
    mapper = {}

    for value in values:
        mapper[value] = total - smaller_total
        smaller_total += value

    return modify_helper(root, mapper)
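To check the first approach against the expected output, the pieces above can be put into one runnable script. The TreeNode class and the level_order helper are assumptions added for testing; the three solution functions follow the post.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sum_and_list(node, total, values):
    # In-order traversal: [sum of the subtree, subtree values in ascending order].
    left_total, right_total = 0, 0
    left_values, right_values = [], []
    if node.left:
        left_total, left_values = sum_and_list(node.left, total, values)
    if node.right:
        right_total, right_values = sum_and_list(node.right, total, values)
    return [
        total + left_total + node.val + right_total,
        values + left_values + [node.val] + right_values,
    ]


def modify_helper(node, mapper):
    # Replace every value with its greater sum (children first, then the node).
    if node.left:
        modify_helper(node.left, mapper)
    if node.right:
        modify_helper(node.right, mapper)
    node.val = mapper[node.val]
    return node


def bst_to_gst(root):
    total, values = sum_and_list(root, 0, [])
    smaller_total = 0
    mapper = {}
    for value in values:
        mapper[value] = total - smaller_total
        smaller_total += value
    return modify_helper(root, mapper)


def level_order(root):
    # Hypothetical helper: level-order values, None for missing children,
    # trailing Nones trimmed (the format used in the problem's examples).
    result, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.extend([node.left, node.right])
    while result and result[-1] is None:
        result.pop()
    return result


root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)
print(level_order(bst_to_gst(root)))
# [30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]
```

The printed list matches the expected output from the problem description.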

We could also use a reverse in-order traversal (right subtree, node, left subtree). It starts at the rightmost node, which holds the largest value, and moves toward the left side. Along the way we keep a running sum of the values visited so far and add it to each node.

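value = 0  # running sum of every value visited so far (reset to 0 before each new tree)

def bst_to_gst(node):
    global value

    if node.right:
        bst_to_gst(node.right)

    node.val = node.val + value
    value = node.val

    if node.left:
        bst_to_gst(node.left)

    return node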
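The recursive version keeps its running sum in a module-level variable, so it has to be reset before each new tree, and very deep (unbalanced) trees can hit Python's recursion limit. Below is a sketch of the same reverse in-order traversal using an explicit stack and a local running sum; bst_to_gst_iterative and pre_order are hypothetical names I chose for this example.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_to_gst_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Reverse in-order traversal (right, node, left) with an explicit stack."""
    running_sum = 0
    stack: List[TreeNode] = []
    node = root
    while stack or node is not None:
        # Walk as far right as possible, remembering the path.
        while node is not None:
            stack.append(node)
            node = node.right
        node = stack.pop()
        # Every node popped before this one holds a larger value.
        running_sum += node.val
        node.val = running_sum
        # Then move on to the smaller values in the left subtree.
        node = node.left
    return root


def pre_order(node: Optional[TreeNode]) -> List[int]:
    if node is None:
        return []
    return [node.val] + pre_order(node.left) + pre_order(node.right)


root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)
bst_to_gst_iterative(root)
print(pre_order(root))  # [30, 36, 36, 35, 33, 21, 26, 15, 8]
```

Because running_sum is local, the function can be called on several trees in a row without any reset.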


Have fun, keep learning, and always keep coding!
