33. All Nodes Distance K in Binary Tree

We are given a binary tree (with root node root), a target node, and an integer value K.

Return a list of the values of all nodes that have a distance K from the target node. The answer can be returned in any order.

Example 1:

Input: root = [3,5,1,6,2,0,8,null,null,7,4], target = 5, K = 2

Output: [7,4,1]

Explanation: 
The nodes that are a distance 2 from the target node (with value 5)
have values 7, 4, and 1.

Solution:

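Approach: If we know the parent of every node x, then we know every node at distance 1 from x: its left child, its right child, and its parent. The tree can therefore be searched like an undirected graph using two BFS passes:
1. Traverse the tree and store the parent of each node in a map.
2. Starting from the target, BFS outward one level at a time, never revisiting a node. After K levels, the queue contains exactly the nodes at distance K from the target.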

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
 
class Solution
{
public:
    unordered_map<int, TreeNode *> m;
    set<int> vs;
    vector<int> res;

    void findNodes(TreeNode *root, int k)
    {

        queue<TreeNode *> q;
        q.push(root);
        vs.insert(root->val);

        int dist = 0;

        while (!q.empty())
        {

            if (dist == k)
            {
                break;
            }

            int s = q.size();

            while (s--)
            {
                TreeNode *t = q.front();
                q.pop();

                if (t->left)
                {
                    if (vs.find(t->left->val) == vs.end())
                    {
                        q.push(t->left);
                        vs.insert(t->left->val);
                    }
                }

                if (t->right)
                {
                    if (vs.find(t->right->val) == vs.end())
                    {
                        q.push(t->right);
                        vs.insert(t->right->val);
                    }
                }

                if (m.find(t->val) != m.end())
                {
                    TreeNode *pr = m[t->val];
                    if (vs.find(pr->val) == vs.end())
                    {
                        q.push(pr);
                        vs.insert(pr->val);
                    }
                }
            }
            dist++;
        }

        while (!q.empty())
        {
            res.push_back(q.front()->val);
            q.pop();
        }
    }

    void findParent(TreeNode *root)
    {

        queue<TreeNode *> q;
        q.push(root);

        while (!q.empty())
        {
            TreeNode *t = q.front();
            q.pop();

            if (t->left)
            {
                m[t->left->val] = t;
                q.push(t->left);
            }

            if (t->right)
            {
                m[t->right->val] = t;
                q.push(t->right);
            }
        }

        return;
    }

    vector<int> distanceK(TreeNode *root, TreeNode *target, int K)
    {

        if(root == NULL){
            return res;
        }
        
        findParent(root);
        findNodes(target, K);

        return res;
    }
};

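Time Complexity: O(n log n), where n is the number of nodes. Each BFS pass processes every node at most once, and each lookup or insert in the ordered `set` of visited nodes costs O(log n); replacing it with a hash set makes this O(n) expected.

Space Complexity: O(n) for the parent map, the visited set, and the BFS queue.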

Last updated

Was this helpful?