https://www.acmicpc.net/problem/13511

 

13511번: 트리와 쿼리 2

N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 아래의 두 쿼리를 수행하는 프로그램을 작성하시오. 1 u v: u에서 v로 가는 경로의 비용을 출력한다. 2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.

www.acmicpc.net


DFS와 점화식을 이용하여 parser_table( dp table )를 할 수 없다면 다른 쉬운 LCA 문제를 먼저 푸는게 좋을 것 같다.

 

u에서 v로 가는데 k번째 노드번호를 출력하는 문제이다.

u에서 v로 가는 경로를 생각해보면

u -> (u와 v의 lca) -> v

의 경로가 될 것이다.

u가 1번째 노드라고 한다면

k번째 노드가 lca 보다 왼쪽에 있는지 오른쪽에 있는지 판단한 뒤에

u 또는 v 부터 diff 만큼 올라간 노드가 답이 된다 !


#include <iostream>
#include <vector>
using namespace std;

#define MAX_N 100001
#define SWAP(a,b) {int t = b;  b = a; a = t;}
typedef long long ll;

vector<pair<int,int>> vec[MAX_N];
ll dis[MAX_N];
int depth[MAX_N];
int par[MAX_N][20];
bool visit[MAX_N];

void dfs(int now, int prev, ll nowDis, int nowDepth) {
     visit[now] = true;
     par[now][0] = prev;
     dis[now] = nowDis;
     depth[now] = nowDepth;
     for (int i = 0; i < vec[now].size(); i++) {
          if (visit[ vec[now][i].first] == false) {
               dfs(vec[now][i].first, now, nowDis + vec[now][i].second, nowDepth + 1);
          }
     }

}

int lca(int x, int y) {
     if (depth[x] > depth[y]) SWAP(x, y);

     for (int i = 19; i >= 0; i--) {
          int diff = depth[y] - depth[x];
          if (diff >= (1 << i)) {
               y = par[y][i];
          }
     }
     if (x == y) return x;
     for (int i = 19; i >= 0; i--) {
          if (par[x][i] != par[y][i]) {
               x = par[x][i];
               y = par[y][i];
          }
     }
     return par[x][0];
}

int getK(int x, int y, int k) {
     int lcaNode = lca(x, y);
     int x_to_lca = depth[x] - depth[lcaNode];
     int y_to_lca = depth[y] - depth[lcaNode];
     if (k <= x_to_lca + 1) {
          int diff = k - 1;
          for (int i = 19; i >= 0; i--) {
               if (diff >= (1 << i)) {
                    diff -= (1 << i);
                    x = par[x][i];
               }
          }
          return x;
     }
     else {
          int diff = (y_to_lca)-(k - (x_to_lca + 1));
          for (int i = 19; i >= 0; i--) {
               if (diff >= (1 << i)) {
                    diff -= (1 << i);
                    y = par[y][i];
               }
          }
          return y;
     }

}

int main() {
     ios::sync_with_stdio(false);
     cin.tie(0); cout.tie(0);
     int N; cin >> N;
     for (int i = 0; i < N-1; i++) {
          int a, b, w;   cin >> a >> b >> w;
          vec[a].push_back({ b,w });
          vec[b].push_back({ a,w });
     }
     
     dfs(1, 1, 0, 0);
     for (int i = 1; i <= 19; i++) {
          for (int j = 1; j <= N; j++) {
               par[j][i] = par[par[j][i - 1]][i - 1];
          }
     }

     int M; cin >> M;
     for (int i = 0; i < M; i++) {
          int tag; cin >> tag;
          if (tag == 1) {
               int x, y; cin >> x >> y;
               cout << dis[x] + dis[y] - 2 * dis[lca(x, y)] << "\n";
          }
          else {
               int x, y, k; cin >> x >> y >> k;
               cout << getK(x, y, k) << "\n";
          }
     }
}

+ Recent posts