BZOJ1415【NOI2005】聪聪和可可 <概率DP+最短路>

Problem

【NOI2005】聪聪和可可

Time  Limit:  10  Sec\mathrm{Time\;Limit:\;10\;Sec}
Memory  Limit:  162  MB\mathrm{Memory\;Limit:\;162\;MB}

Description

Input

第一行为两个整数NNEE,以空格分隔,分别表示森林中的景点数和连接相邻景点的路的条数。
第二行包含两个整数CCMM,以空格分隔,分别表示初始时聪聪和可可所在的景点的编号。
接下来EE行,每行两个整数,第i+2i+2行的两个整数AiA_iBiB_i表示景点AiA_i和景点BiB_i之间有一条无向边。
输入保证任何两个景点之间不会有多于一条路直接相连,且聪聪和可可之间必有路直接或间接的相连。

Output

输出一个实数,四舍五入保留三位小数,表示平均多少个时间单位后聪聪会把可可吃掉。

Sample Input

Input #1

1
2
3
4
5
4 3
1 4
1 2
2 3
3 4

Input #2

1
2
3
4
5
6
7
8
9
10
11
9 9
9 3
1 2
2 3
3 4
4 5
3 6
4 6
4 7
7 8
8 9

Sample Output

Output #1

1
1.500

Output #2

1
2.167

HINT

对于50%50\%的数据,1N501\le N\le50
对于所有的数据,1N,E10001\le N,E\le1000

标签:概率DP 最短路

Solution

经典概率DP\mathrm{DP}

暴力BFS\mathrm{BFS}预处理数组nxt[][]nxt[][],其中nxt[x][y]nxt[x][y]表示聪聪在xx,可可在yy时,聪聪下一步会走到哪个点。
f[c][k]f[c][k]表示聪聪在cc,可可在kk时,期望下还需多少单位时间可可才会被抓住。显然状态之间是不会存在环的,这是因为两者间的距离会不断缩小。
状态转移:

  • c=kc=k,则f[c][k]=0f[c][k]=0
  • dis(c,k)2dis(c,k)\le2,则f[c][k]=1f[c][k]=1
  • 其他情况下,令p=nxt[nxt[c][k]][k]p=nxt[nxt[c][k]][k],则有

    f[c][k]=f[p][k]+kqf[p][q]deg(k)+1+1f[c][k]=\frac{f[p][k]+\sum_{k\to q}f[p][q]}{deg(k)+1}+1

在记忆化搜索时特判一下即可。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
#define MAX_N 1000
#define INF 0x3f3f3f3f
using namespace std;
typedef double dnt;
template <class T> inline void read(T &x) {
x = 0; int c = getchar(), f = 1;
for (; !isdigit(c); c = getchar()) if (c == 45) f = -1;
for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0');
}
int n, m, sc, sk;
vector <int> G[MAX_N+5];
int dis[MAX_N+5][MAX_N+5];
int nxt[MAX_N+5][MAX_N+5];
dnt d[MAX_N+5], f[MAX_N+5][MAX_N+5];
void BFS(int s) {
queue <int> que; que.push(s), dis[s][s] = 0;
while (!que.empty()) {
int u = que.front(); que.pop();
for (int i = 0, v; i < (int)G[u].size(); i++)
if (dis[s][v = G[u][i]] == INF)
dis[s][v] = dis[s][u]+1, que.push(v);
}
}
dnt DP(int c, int k) {
dnt &ret = f[c][k];
if (ret >= 0) return ret;
if (c == k) return ret = 0;
if (dis[c][k] <= 2) return ret = 1; ret = 0;
for (int i = 0, t; i < (int)G[k].size(); i++)
t = G[k][i], ret += DP(nxt[nxt[c][k]][k], t)/(d[k]+1);
return ret += DP(nxt[nxt[c][k]][k], k)/(d[k]+1)+1;
}
int main() {
read(n), read(m), read(sc), read(sk);
for (int i = 1, u, v; i <= m; i++)
read(u), read(v),
G[u].push_back(v), d[u]++,
G[v].push_back(u), d[v]++;
memset(dis, INF, sizeof dis);
for (int i = 1; i <= n; i++) BFS(i);
for (int c = 1; c <= n; c++) for (int k = 1; k <= n; k++)
for (int i = 0, t; i < (int)G[c].size(); i++)
if (dis[t = G[c][i]][k] < dis[nxt[c][k]][k] || !nxt[c][k]) nxt[c][k] = t;
else if (dis[t][k] == dis[nxt[c][k]][k] && t < nxt[c][k]) nxt[c][k] = t;
memset(f, -1, sizeof f); return printf("%.3lf\n", DP(sc, sk)), 0;
}