Skip to main content
  1. Posts/

LeetCode-2846 边权重均等查询

·2 mins·

LeetCode-2846 边权重均等查询 #

Solution 1 #

最近公共祖先的模板题. 思路来自 0x3f , 这里是他的讲解视频 模运算 最近公共祖先【力扣周赛 361】

class Solution:
    def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        # 用邻接表格式存储图
        g = [[] for _ in range(n)]
        for x, y, w in edges:
            g[x].append((y, w - 1))
            g[y].append((x, w - 1))

        m = n.bit_length() # 最多跳 2^m 步
        pa = [[-1] * m for _ in range(n)] # 父节点
        cnt = [[[0] * 26 for _ in range(m)] for _ in range(n)] # cnt[x][i][w] 存储节点 x 向上跳 2^i 步, 途径的路径中有多少条边权为 w 的边
        depth = [0] * n # 记录深度

        # dfs 更新父节点, 深度以及 cnt[x][0]
        def dfs(x: int, fa: int) -> None:
            pa[x][0] = fa
            for y, w in g[x]:
                if y != fa:
                    cnt[y][0][w] = 1
                    depth[y] = depth[x] + 1
                    dfs(y, x)
        dfs(0, -1)

        # 倍增
        for i in range(m - 1):
            for x in range(n):
                p = pa[x][i]
                if p != -1:
                    pp = pa[p][i]
                    pa[x][i + 1] = pp
                    for j, (c1, c2) in enumerate(zip(cnt[x][i], cnt[p][i])):
                        cnt[x][i + 1][j] = c1 + c2

        ans = []
        for x, y in queries:
            path_len = depth[x] + depth[y]  # 由于 x, y 会变化, 所以提前计算, 最后减去 depth[lca] * 2
            cw = [0] * 26
            if depth[x] > depth[y]:
                x, y = y, x

            # 先提升 y, 使 y 和 x 在同一深度
            k = depth[y] - depth[x]
            for i in range(k.bit_length()):
                if (k >> i) & 1:  # 深度差 k 二进制从低到高第 i 位是 1, 就向上跳 2^i 步
                    p = pa[y][i]
                    for j, c in enumerate(cnt[y][i]):
                        cw[j] += c
                    y = p

            # 如果不在同一侧, 同时向上跳, 寻找 lca
            if y != x:
                for i in range(m - 1, -1, -1):
                    px, py = pa[x][i], pa[y][i]
                    if px != py: # 没跳过 lca, 所以跳
                        for j, (c1, c2) in enumerate(zip(cnt[x][i], cnt[y][i])):
                            cw[j] += c1 + c2
                        x, y = px, py  # 同时上跳 2^i 步
                # 最后跳一步就到达 lca, 注意这里没有跳
                for j, (c1, c2) in enumerate(zip(cnt[x][0], cnt[y][0])):
                    cw[j] += c1 + c2
                x = pa[x][0] # 最后跳一步

            lca = x
            path_len -= depth[lca] * 2
            ans.append(path_len - max(cw))
        return ans