본문 바로가기

알고리즘 관련/알고리즘&이론

Persistent Segment Tree

먼저 시작하기 전에 ..


-오늘 공부한 내용을 제 멍청한 머리가 기억하지 못할까봐 기록하는 개인용 글입니다.  

-해당 자료구조에 대하여 완벽하게 알고 적는 글이 아니기 때문에 Persistent segment tree를 공부하는 용도로는 좋지 못한 글이 될 확률이 높습니다.

-BOJ 11932 문제를 풀기 위해 https://hongjun7.tistory.com/m/64http://blog.myungwoo.kr/100 를 많은 부분에서 참고하였습니다.




persistent segment tree가 뭔지는 알고있었지만 귀찮아서 공부를 미루다가 오늘 BOJ 11932번을 접하고 도전하게 되었다.


persistent segment tree는 가상의 N개의 세그먼트 트리를 O(NlogN)의 공간복잡도로 유지하는 기법이다.


주로 update 질의가 없는 2D 구간 query를 처리할 때 사용하며 2D 구간처리를 할 때 N개의 세그먼트트리를 유지할 경우 쿼리를 처리하기 용이해지지만 이를 직접 구현할 경우 O(N^2)의 공간복잡도를 지니므로 굉장히 비효율적인데 (세그먼트 트리를 사용해야 할 경우 쿼리 당 시간 복잡도가 O(log N)에 처리 가능한 문제일 텐데 이는 곧 N이 10만 이상인 경우를 의미한다.) persistent segment tree를 이용할 경우 이를 가능케 해준다.


간단한 예를 하나 들면 2차원 좌표 평면에서 x좌표가 서로 다른 N개의 점이 있을 때 (x1,y1)~(x2,y2)의 구간에 있는 점의 수를 구하는 문제(x1,y1,x2,y2<=100000)를 생각해보자.


모든 x좌표가 하나의 세그먼트 트리를 가지고 있다고 생각해보자 

seg[x][node]가 x좌표가 x인 세그먼트 트리에서 node가 담당하는 구간(y1~y2)의 구간에 있는 점의 수를 저장하고 있다고 생각한다면

어떠한 x좌표값이 정해져 있을 때 y1~ y2의 구간의 합은 O(logN)에 구할 수 있다.


이 때 모든 x1~x2에 대하여 y1~y2 구간에 있는 점의 수를 구하는 경우는 O(NlogN)이 걸린다.


하지만 우리는 쿼리를 O(logN)에 처리하기를 원하고 이를 위하여 partial sum을 생각해보자


seg[x]가 1~x까지에 대하여 y1~y2의 구간에 있는 점의 수를 저장하고 있다고 생각하면 아까 구하고 싶은 쿼리는


seg[x2][node(y2~y1)]-seg[x1][node(y2~y1)]으로 구할 수 있을 것이다. 즉 O(logN)에 원하는 쿼리를 처리할 수 있는것이다.


이러한 개념을 가지고 Persistent segment tree에 update를 해보자


seg[x]라는 노드에 업데이트를 할 경우 seg[x-1]에서 x좌표가 x인 1차원 평면만 업데이트 되는 것이다.


고로 seg[x]는 seg[x-1]을 가르키면서 x좌표가 x인 부분만 업데이트를 하면 된다. x좌표가 x인 점은 오직 하나 뿐이니 이 점은 O(logN)의 구간에 업데이트 되므로 N개의 세그먼트 트리를 구축하는 시간복잡도와 공간복잡도는 O(NlogN)이 된다.


해당 예시를 확인해보고 알 수 있는점이 있다.


persistent segment tree를 구성하려면 업데이트 해야 할 점이 총 N개가 존재해야한다는 것이다.


업데이트 해야 할 점이 많아 질 경우 해당 시간/공간 복잡도도 같이 증가하기 때문에 업데이트 되어야 할 점이 적당한 범위내에 존재해야 구축이 가능하다.


다시 BOJ 11932번으로 넘어가보자


2차원 좌표평면은 아니지만 해당 트리를 루트에서 부터 리프로 업데이트 한다는 느낌으로 처리해준다면


X라는 정점까지 처리 됬을 때 X의 자식을 업데이트 하는 경우는 오로지 하나의 원소만 업데이트 되기 때문에 persistent segment tree를 적용시킬 수 있다.


각 정점의 가중치의 범위를 1~10000으로 설정해주고 싶기 때문에 좌표압축 기법을 이용하여 정점의 가중치를 해싱해준다.


해당 문제에서 묻는 쿼리는 X에서 Y까지의 경로에 K번째로 작은 가중치를 출력하는 문제이다.


seg[X,val]를 루트에서 X까지의 정점들에 대하여 val이하의 값을 가지는 정점의 수라고 정의하자


query(x,y,val)이 x에서 y까지 정점들에 대하여 val이하의 값을 가지는 정점의 수라고 정의한다면


query(x,y,val)= seg[x,val]+seg[y,val]-seg[lca(x,y),val]-seg[parent[lca(x,y)],val] 이 된다.

(이 수식은 그림을 그려보면 이해할 수 있을것이다.)


고로 우리는 query(x,y,val)을 알 수 있다면 이분탐색을 통하여 답을 구해낼 수 있을 것이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
int n, m, g[100010], par[100010][22], visited[100010], h[100010], sz;
vector<vector<int>> vt;
vector<int> idx;
struct node {
    int v;
    node *l, *r;
    node(int v, node *l, node *r) :v(v), l(l), r(r) {}
    node *update(int lo, int hi, int pos) {
        if (lo <= pos&&pos <= hi) {
            if (lo == hi)return new node(v + 1, NULL, NULL);
            int mid = (lo + hi) >> 1;
            return new node(v + 1, l->update(lo, mid, pos), r->update(mid + 1, hi, pos));
        }
        return this;
    }
}*seg[100010];    //segment tree의 노드
int query(node *x, node *y, node *anc, node* ancp, int l, int r, int k) {
    if (l == r)return l;
    int cnt = x->l->+ y->l->- anc->l->- ancp->l->v;
    int mid = (l + r) >> 1;
    if (cnt >= k)return query(x->l, y->l, anc->l, ancp->l, l, mid, k);
    return query(x->r, y->r, anc->r, ancp->r, mid + 1, r, k - cnt);
}        //이분 탐색을 통하여 k번째 수를 담은 버킷 번호를 return
int getidx(int pos) {
    return lower_bound(idx.begin(), idx.end(), pos) - idx.begin() + 1;
}    //좌표압축 된 좌표를 return
void dfs(int here, int p, int dph) {
    seg[here] = seg[p]->update(1, sz, getidx(g[here]));
    visited[here] = 1;
    h[here] = dph;
    for (auto next : vt[here]) {
        if (!visited[next]) {
            par[next][0= here;
            dfs(next, here, dph + 1);
        }
    }
}    //dfs를 통해 lca 전처리와 persistent segment tree를 구축
int lca(int x, int y) {
    if (h[x] > h[y])
        swap(x, y);
    for (int i = 20; i >= 0; i--) {
        if ((<< i) <= h[y] - h[x])
            y = par[y][i];
    }
    if (x == y)return x;
    for (int i = 20; i >= 0; i--) {
        if (par[x][i] != par[y][i]) {
            x = par[x][i];
            y = par[y][i];
        }
    }
    return par[x][0];
}    //lca를 return
int main() {
    scanf("%d%d"&n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d"&g[i]);
        idx.push_back(g[i]);
    }
    sort(idx.begin(), idx.end());
    idx.erase(unique(idx.begin(), idx.end()), idx.end());    //가중치를 좌표 압축
    sz = idx.size();    //업데이트 될 value의 범위
    vt.resize(n + 1);
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        scanf("%d%d"&x, &y);
        vt[x].push_back(y);
        vt[y].push_back(x);
    }
    seg[0]= new node(0, NULL, NULL);
    seg[0]->= seg[0]->= seg[0];   //init
    dfs(100);
    for (int i = 1; i < 21; i++) {
        for (int j = 1; j <= n; j++) {
            par[j][i] = par[par[j][i - 1]][i - 1];
        }
    }    //lca init
    for (int i = 0; i < m; i++) {
        int x, y, k;
        scanf("%d%d%d"&x, &y, &k);
        int anc = lca(x, y);
        printf("%d\n", idx[query(seg[x], seg[y], seg[anc], seg[par[anc][0]], 1, sz, k) - 1]);
    }    //return query
    return 0;
}
cs



관련 문제:http://jason9319.tistory.com/263

'알고리즘 관련 > 알고리즘&이론' 카테고리의 다른 글

Parallel binary search  (0) 2017.06.14
Fenwick Tree(Binary Indexed Tree)  (3) 2017.06.14
제1종 스털링 수  (0) 2017.05.29
Minimum Path cover in DAG  (0) 2017.05.29
이항계수를 구하는 알고리즘  (9) 2017.02.21