티스토리 뷰

알고리즘 관련/BOJ

BOJ)12995 트리나라

JASON 자손9319 2017. 6. 8. 01:41

문제: icpc.me/12995


트리가 주어질 때 노드가 K개인 서브트리의 개수를 구하는 문제이다.


트리DP임을 인지하고 dp[pos][state] <=( pos번 정점을 루트로하고 노드가 state개인 서브트리의 개수 )로 테이블을 잡은 뒤


점화식을 세우려고 했는데 자식들에게 state를 분배하는 방법이 아무리 생각해도 처리하기가 힘들어서 솔루션을 확인했다.


솔루션의 방법은 pos번 정점의 자식들을 분할하여 생각하는 방법이였는데 


우선 DP테이블은 dp[pos][idx][state]가 된다.


pos와 state는 아까와 의미가 같지만 idx라는 개념이 추가되었다. 


idx의 개념은 pos번 정점의 자식들이 C개 있다고 할 때 idx번~C번까지의 자식들만 존재한다고 생각하는 것으로 트리를 분할시킨다.


이렇게 트리를 분할 시킬 경우 2차원 디피일 때 처리하기 힘들었던 자식에게 state를 분배하는 경우를 처리해주는게 가능해진다.


예를 들어 3번 정점의 자식이 5개이고 state가 7이라고 해보자


그러면 우리는 3번 정점이 루트인 노드의 개수가 7개인 서브트리의 개수를 구할 때 dp[3][0][7]을 호출하게 된다.


이때 자식에게 분배를 할 때 idx번 자식을 노드로하는 서브트리를 분리하는 방식으로 분배를 한다. 


즉 어떠한 서브트리를 구하는 개수를

idx번 자식을 루트로 하는 서브트리 + pos를 루트로 하고 idx+1번 자식부터 취급하는 서브트리 

로 분배를 한다는것이다. 이 때 분배할 때 state의 개수를 분배를 해야하므로 for문을 돌리게 되므로 총 O(N^4)의 시간복잡도를 가지게 된다.


다만 for문으로 분배를 할 때 주의할게 현재 pos번은 루트로서 무조건 선택이 되야하므로 idx번 자식을 루트로하는 서브트리에는 state를 0~state-1 까지만 분배를 해야한다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#define MOD 1000000007LL
using namespace std;
typedef long long ll;
ll n, k, x, y, dp[55][55][55], visited[55], r;
vector<vector<int>> vt;
vector<vector<int>> tr;        
void dfs(int here) {
    visited[here] = true;
    for (auto next : vt[here]) {
        if (visited[next])
            continue;
        tr[here].push_back(next);
        dfs(next);
    }
}
ll func(ll pos, ll idx, ll state) {
    if (!state)return 1;
    if (idx >= tr[pos].size()) return state == 0;
    ll &ret = dp[pos][idx][state];
    if (ret != -1)return ret;
    ret = 0;
    for (int i = 0; i < state; i++) {    //루트가 항상 선택되어야 하기 때문에 state==i인 경우 X
        ret += func(tr[pos][idx], 0, i)*func(pos, idx + 1, state - i);
        ret %= MOD;
    }
    return ret;
}
int main() {
    scanf("%lld%lld"&n, &k);
    vt.resize(n + 1);
    tr.resize(n + 1);
    for (int i = 0; i < n - 1; i++) {
        scanf("%d%d"&x, &y);
        vt[x].push_back(y);
        vt[y].push_back(x);
    }
    dfs(1);
    memset(dp, -1sizeof(dp));
    for (int i = 1; i <= n; i++) {
        r += func(i, 0, k);
        r %= MOD;
    }
    printf("%lld\n", r);
    return 0;
}
cs


'알고리즘 관련 > BOJ' 카테고리의 다른 글

BOJ)1799 비숍  (0) 2017.06.08
BOJ)7154 Job Postings  (0) 2017.06.08
BOJ)12995 트리나라  (1) 2017.06.08
BOJ)1405 미친 로봇  (0) 2017.06.07
BOJ)10838 트리  (0) 2017.06.07
BOJ)14590 KUBC League (Small)  (0) 2017.06.06
댓글
댓글쓰기 폼