atcoder.jp
自力で解法が分からなかったため、解説を見てざっくりとした方針を把握してから実装した。
木DPなるものは聞いたことがなかったけど以下のように実装
>|python|
import sys
sys.setrecursionlimit(10 ** 8)
n, q = map(int, input().split())
x_list = list(map(int, input().split()))
G = [[] for _ in range(n + 1)]
for _ in range(n - 1):
a, b = map(int, input().split())
G[a].append(b)
G[b].append(a)
list_dict = [[] for _ in range(n + 1)]
flag_list = [False for _ in range(n + 1)]
def dfs(s):
global list_dict,flag_list
flag_list[s] = True
if len(G[s]) == 0:
list_dict[s].append(x_list[s - 1])
return
tmp = [x_list[s - 1]]
for next_p in G[s]:
if flag_list[next_p] == True:
continue
dfs(next_p)
tmp.extend(list_dict[next_p])
tmp.sort()
tmp = tmp[-max_k:]
list_dict[s] = tmp
v_list = []
k_list = []
for _ in range(q):
v, k = map(int, input().split())
v_list.append(v)
k_list.append(k)
max_k = max(k_list)
dfs(1)
for v, k in zip(v_list, k_list):
print(list_dict[v][-k])
||<
最初に書いたときは
>|python|
tmp.extend(list_dict[next_p])
||<
ではなく
>|python|
tmp = tmp + list_dict[next_p]
||<
としていたが、これだとTLEだった。
+演算子からextendにするだけで2200msから600msに改善できたので、今後はextendを使うようにする。