문제
풀이
비트 연산의 특징을 알아야 풀 수 있는 문제였다.
- 임의의 노드에서 시작한 결과를 dp에 저장한다.
- 각 비트를 0인 그룹과, 1인 그룹으로 나눈다.(temp에 1일 경우만 저장)
- 모든 노드에 대해 헤더와 XOR연산으로 비용을 구한다.
- 최소값을 출력한다.
3번에서 dp에 저장된 값과 다른 비트가 헤더로 들어가면 0인 그룹과 1인 그룹을 서로 바꿔주기만 하면 된다. 따라서 각 노드마다 O(1)만에 비용을 구할 수 있다.
비트 연산자에 아직 익숙하지가 않아서 dp를 설정하는 것부터 어려움이 있었다. xor연산의 특징을 이해하고, 아이디어만 떠올리면 쉽게 풀이 가능한 문제였다.
풀이 코드
import sys
rl = sys.stdin.readline
n, x = map(int, rl().split())
g = [[] for _ in range(n)]
for _ in range(n-1):
u, v, w = map(int, rl().split())
g[u-1].append([v-1,w])
g[v-1].append([u-1,w])
dp = [0 for _ in range(n)]
visited = [0 for _ in range(n)]
dp[0], visited[0] = x, 1
q = [[0, x]]
while q:
node, nowx = q.pop()
for nextNode, w in g[node]:
if visited[nextNode]: continue
visited[nextNode] = 1
dp[nextNode] = nowx^w
q.append([nextNode, nowx^w])
temp = [[]for _ in range(20)]
for i,a in enumerate(dp):
for j in range(20):
if a&(1<<j): temp[j].append(i)
ans = sys.maxsize
for i in dp:
out = 0
i^=x
for j in range(20):
if i&(1<<j): out += n-len(temp[j])
else: out += len(temp[j])
ans = min(out, ans)
print(ans)