class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return every root label that yields a tree of minimum height.

        Uses rerooting DP: root the tree at node 0, compute for every node
        the longest downward path (and the second longest through a
        different child), then push "longest path going up through the
        parent" top-down. A node's height as root is the max of the two.

        Args:
            n: Number of nodes, labelled 0..n-1.
            edges: Undirected edges [a, b] forming a tree on n nodes.

        Returns:
            Labels of all minimum-height roots, in increasing order.
            Returns [] when n <= 0.
        """
        if n <= 0:
            return []

        graph = [[] for _ in range(n)]
        for a, b in edges:
            graph[a].append(b)
            graph[b].append(a)

        # Iterative DFS from node 0. `order` is a preorder, so every node
        # appears after its parent. Iteration avoids hitting Python's
        # recursion limit on long path-shaped trees.
        parent = [-1] * n
        visited = [False] * n
        visited[0] = True
        order = []
        stack = [0]
        while stack:
            u = stack.pop()
            order.append(u)
            for v in graph[u]:
                if not visited[v]:
                    visited[v] = True
                    parent[v] = u
                    stack.append(v)

        # Bottom-up pass (children before parents). mx_down[u] is the longest
        # downward path from u. smx_down[u] is the longest downward path that
        # starts through a child other than best_child[u].
        mx_down = [0] * n
        smx_down = [0] * n
        best_child = [-1] * n
        for u in reversed(order):
            p = parent[u]
            if p == -1:
                continue
            h = mx_down[u] + 1
            if h >= mx_down[p]:
                smx_down[p] = mx_down[p]
                mx_down[p] = h
                best_child[p] = u
            elif h > smx_down[p]:  # must be elif, or smx_down collapses onto mx_down
                smx_down[p] = h

        # Top-down pass (parents before children). up[u] is the longest path
        # from u that first steps to its parent. A sibling branch may be used,
        # but not u's own subtree, hence smx_down when u is the best child.
        up = [0] * n
        for u in order:
            p = parent[u]
            if p == -1:
                continue
            sibling_best = smx_down[p] if best_child[p] == u else mx_down[p]
            up[u] = 1 + max(up[p], sibling_best)

        heights = [max(down, upward) for down, upward in zip(mx_down, up)]
        best = min(heights)
        return [i for i, h in enumerate(heights) if h == best]