본문 바로가기

알고리즘 관련/알고리즘&이론

advanced disjoint set

오늘은 disjoint set에 대해서 글을 적어보겠습니다.


직역하면 상호배타적 집합이며, union-find 라는 자료구조로 표현됩니다.  [disjoint set, DSU(disjoint union),union find] 는 전부 union find 자료구조를 칭하는 말입니다. 


union find 는 union연산과 find연산을 매우 빠르게 해주는 자료구조로서, 상호배타 집합에서는 어떠한 집합에서도 공통 된 원소가 존재하지 않으며, 동시에 모든 집합의 합집합은 전체집합이 됩니다.


이해를 돕기 위하여 특수한 상황을 가정해보겠습니다. 


a라는 나라와 b라는 나라가 동맹이고, b라는 나라와 c라는 나라가 동맹을 맺게된다면 a와 c나라도 자동으로 동맹이 되는 경우를 생각해봅시다.


해당 경우로 각각의 나라의 동맹이 체결된다면, 동맹으로 이루어진 국가들의 관계가 하나의 집합을 이룰것이며, 어떠한 나라도 두개 이상의 집합에 속하지 않게됩니다.(동맹이 이루어지면 집합이 하나로 합쳐지기 때문)


이러한 특수한 상황이 문제에서 주어진다면 dsu로 모델링하여 문제를 해결할 수 있게됩니다.


다시 본론으로 돌아가서 위에서 말했듯 union find는 union 연산과 find 연산을 지원하는 포레스트 모양의 자료구조입니다.


먼저 union 연산에 대한 개념을 알아봅시다.

자 이런식으로 초기 상태에는 각각의 정점이 하나의 컴포넌트를 이루며 모두 분리되어 있는 상태입니다. 여기서 union 연산을 이용하여 두 집합을 하나로 합치게 됩니다.


만약에 위의 상태에서 1번 정점이 속한 집합과 3번 정점이 속한 집합을 합치게 된다면

이런식으로 1번 정점과 3번 정점이 하나의 집합을 이루게 되며 그래프 개념에서는 하나의 컴포넌트를 이루는 모양이 됩니다.


여기서 만약 1번 정점과 4번 정점이 union 하게 된다면

이렇게 1,3,4번 정점이 하나의 집합을 이루게 됩니다.


이제 find 연산에 대한 개념을 알아봅시다.


사실 위의 그림들은 disjoint set에 대한 이해를 돕기 위해 일자로 된 모양을 그렸지만 실제로는 하나의 집합은 트리이며 전체 그래프는 포레스트의 모양을 띕니다.


union 연산이 이루어지는 과정은 union을 수행하려는 두 집합의 대표값(트리의 루트) a와 b가 있을 때, a의 자식으로 b를 연결시켜주어 하나의 트리로 만드는 연산입니다. 


이 union 연산을 위하여 대표값(트리의 루트)를 구하는 연산을 빠르게 수행하는 연산이 find 연산입니다.


이를 구현하는 방법은 단순히 자기의 부모를 저장하는 방법으로 구현할 수 있습니다.


처음에는 자신의 부모를 자신으로(루프)를 만들고, 이후 union연산을 하면서 부모 관계를 재설정 해주는 것입니다.


우선 초기값과 find함수를 소스코드로 구현하면

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=1001;
int par[N];
void init(){
    for(int i=1;i<N;i++)
        par[i]=i;
}
int find(int here){
    if(here==par[here])return here;
    return find(par[here]);
}
cs


이렇게 구현할 수 있습니다. init 함수에서 자신의 부모를 자신으로 설정해준 뒤


find함수를 실행하면 만약 자신의 부모가 자신일 경우(루트를 의미) 자신을 반환 그게 아니라면 계속 부모를 따라올라가는 작업을 수행하여 루트를 찾아줍니다.


하지만 이런 경우를 생각해봅시다.

이렇게 트리의 모양이 리스트 형태로 이루어져 있다면 N번 정점에서 find 함수를 호출한다면 N번 정점에서 find 함수를 호출 할 때 마다 O(N)의 시간에 루트를 찾아낼 것입니다. 이는 너무 비효율 적이기 때문에 소스를 살짝 수정하여

1
2
3
4
int find(int here){
    if(here==par[here])return here;
    return par[here]=find(par[here]);
}

cs

이렇게 작성을 해준다면 한번 find 연산을 수행하면 해당 트리는 depth가 최대 2인 트리가 됩니다. 따라서 연산 재호출 시 시간을 줄여줄 수 있습니다.


그림으로 보자면 N 정점에서 find 연산을 수행한다면 

이렇게 트리의 깊이가 최대 2인 모양이 됩니다.


자 이제 union 연산은 소스로 어떻게 구현되는지 한번 봅시다.


1
2
3
4
5
void merge(int x,int y){
    x=find(x),y=find(y);
    if(x==y)return;
    par[x]=y;
}
cs

union 연산도 사실 매우 간단한 소스코드로 이루어져 있습니다. x와 y를 find함수를 이용하여 각각 루트로 설정해준 뒤, 두 정점이 서로 다른 집합에 속한다면 하나의 정점을 하나의 자식으로 설정해주는 간단한 소스로 이루어져 있습니다. 사실 상 시간복잡도는 find함수에 지배받습니다.


find 함수의 시간복잡도는 아크만 함수에 의해 정의되는 함수로 사실 상 상수시간에 근접하다고 생각하시면 될 것 같습니다.




여기까지가 일반 dsu 문제를 해결하는데 필요한 정보이며, 이러한 정보는 다른 블로그에 더 자세하게 설명되어 있을거라고 생각합니다. 굳이 dsu를 포스팅한 이유는 사실 뒤에 다루어질 내용을 위해서 입니다.



이제 dsu의 조금 특별한 상태에 대해서 이야기 해보겠습니다.


상호배타적 집합에서 union 연산만 고려하는것이 아니라, 서로 대립되는 관계가 주어진다고 가정해봅시다.


예를들어 1,2,3,4의 독립적인 집합이 존재한다고 가정을 하고,


1번과 3번 집합이 대립, 2번과 4번 집합이 대립된다고 할 때


1번 정점과 2번 정점이 union 된다면, 3번 정점과 4번 정점도 union되게 됩니다.

이러한 특수한 경우의 dsu를 구현해보겠습니다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=1001;
int par[N],enemy[N];
void init(){
    for(int i=1;i<N;i++){
        par[i]=i;
        enemy[i]=-1;
    }
}
int find(int here){
    if(here==par[here]||here==-1)return here;
    return par[here]=find(par[here]);
}
void make_friend(int x,int y){
    x=find(x),y=find(y);
    if(x==y)return;
    int ex=find(enemy[x]),ey=find(enemy[y]);
    par[x]=y;   //y가 root가 됨
    if(ex>ey)swap(ex,ey);   //ex<ey가 되도록 강제
    if(ex!=-1)
        par[ex]=ey; //ey가 root가 됨
    if(ey!=-1)
        enemy[ey]=y;    //대립 관계
}
void make_enemy(int x,int y){
    x=find(x),y=find(y);
    int ex=find(enemy[x]),ey=find(enemy[y]);
    if(x==ey||y==ex)return;
    if(ex!=-1)par[ex]=y;    //y를 root로
    if(ey!=-1)par[ey]=x;    //x를 root로
    enemy[x]=y;
    enemy[y]=x; //대립 관계
}
cs


이런식으로 dsu의 대립 관계를 정의 해줄 수 있습니다.


이런 개념으로 풀 수 있는 문제로 boj 15674 가로수 문제가 있습니다.


해당 문제에서 같은 나무를 심는 쿼리는 friend로 다른 나무를 심는 쿼리는 enemy로 처리해준다면 효율적으로 문제를 해결할 수 있습니다.

 

해당 문제의 소스코드를 첨부하는 것으로 글을 마치겠습니다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
ll n,d,q,a[200020],b[200020],par[200020],enemy[200020],x,y,z,aa[200020],bb[200020],res;
ll find(ll h){
    if(h==-1||par[h]==h)return h;
    return par[h]=find(par[h]);
}
ll mincost(ll x){
    x=find(x);
    ll ex=find(enemy[x]);
    ll l=a[x],r=b[x];
    if(ex!=-1){
        l+=b[ex],r+=a[ex];
    }
    return min(l,r);
}
void merge(ll x,ll y){
    par[x]=y;
    a[y]+=a[x],b[y]+=b[x];
}
void make_friend(ll x,ll y){
    x=find(x),y=find(y);
    if(x==y)return;
    ll ex=find(enemy[x]),ey=find(enemy[y]);
    res-=mincost(x)+mincost(y);
    merge(x,y);
    if(ex>ey)swap(ex,ey);
    if(ex!=-1)merge(ex,ey);
    if(ey!=-1)enemy[ey]=y;
    enemy[y]=ey;
    res+=mincost(x);
}
void make_enemy(ll x,ll y){
    x=find(x),y=find(y);
    ll ex=find(enemy[x]),ey=find(enemy[y]);
    if(x==ey||y==ex)return;
    res-=mincost(x)+mincost(y);
    if(ex!=-1)merge(ex,y);
    if(ey!=-1)merge(ey,x);
    enemy[x]=y;
    enemy[y]=x;
    res+=mincost(x);
}
int main(){
    scanf("%lld%lld",&n,&d);
    for(int i=1;i<=n;i++){
        scanf("%lld%lld",&a[i],&b[i]);
        par[i]=i,enemy[i]=-1,aa[i]=a[i],bb[i]=b[i];
        res+=mincost(i);
    }
    while(d--){
        scanf("%lld%lld%lld",&x,&y,&z);
        if(x)
            make_enemy(y,z);
        else
            make_friend(y,z);
    }
    printf("%lld\n",res);
    scanf("%lld",&q);
    while(q--){
        scanf("%lld%lld%lld",&x,&y,&z);
        if(x==0)
            make_friend(y,z);
        else if(x==1)
            make_enemy(y,z);
        else if(x==2){
            res-=mincost(y);
            ll curr=find(y);
            a[curr]+=-aa[y]+z;
            aa[y]=z;
            res+=mincost(y);
        }
        else{
            res-=mincost(y);
            ll curr=find(y);
            b[curr]+=-bb[y]+z;
            bb[y]=z;
            res+=mincost(y);
        }
        printf("%lld\n",res);
    }
    return 0;
}
cs