BZOJ4565【HAOI2016】字符合并 <状压DP>

Problem

【HAOI2016】字符合并

Time  Limit:  20  Sec\mathrm{Time\;Limit:\;20\;Sec}
Memory  Limit:  256  MB\mathrm{Memory\;Limit:\;256\;MB}

Description

有一个长度为 nn0101 串,你可以每次将相邻的 kk 个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这 kk 个字符确定。你需要求出你能获得的最大分数。

Input

第一行两个整数nnkk。接下来一行长度为nn0101串,表示初始串。接下来2k2^k行,每行一个字符cic_i和一个整数wiw_icic_i表示长度为kk0101串连成二进制后按从小到大顺序得到的第ii种合并方案得到的新字符,wiw_i表示对应的第ii种方案对应获得的分数。1n3001\le n\le 300, 0ci10\le c_i\le 1, wi1w_i\ge 1, k8k\le 8

Output

输出一个整数表示答案

Sample Input

1
2
3
4
5
6
3 2
101
1 10
1 10
0 20
1 30

Sample Output

1
40

标签:状压DP

Solution

这显然是一道状压DPDPk8k\le 8
考虑用f[i][j][k]f[i][j][k]表示将字符序列[i,j][i, j]表示为状态kk的最大分数。
初始状态为f[i][i][s[i]]=0f[i][i][s[i]] = 0 (1in1\le i\le n)
转移则将[i,j][i, j]斩成两半[i,mid][i, mid][mid+1,j][mid+1, j]f[i][j][k<<1]=max{f[i][mid][k]+f[mid+1][j][0]}f[i][j][k<<1] = max\{f[i][mid][k]+f[mid+1][j][0]\}f[i][j][k<<11]=max{f[i][mid][k]+f[mid+1][j][1]}f[i][j][k<<1|1] = max\{f[i][mid][k]+f[mid+1][j][1]\}
特别需要注意的是,当(ji)(m1)(j-i)|(m-1)时(mm为题目中的kk),会刚好变为一个字符,因此不能像上面那样递推,应为f[i][j][0]=max{f[i][j][sta]+w[sta]}f[i][j][0] = max\{f[i][j][sta]+w[sta]\}(0sta<(1<<m)0\le sta<(1<<m)stasta可为00),f[i][j][1]=max{f[i][j][sta]+w[sta]}f[i][j][1] = max\{f[i][j][sta]+w[sta]\}(0sta<(1<<m)0\le sta<(1<<m)stasta可为00)。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <iostream>
#include <cstdio>
#include <cstring>
#define MAX_K 8
#define MAX_N 300
using namespace std;
typedef long long lnt;
char str[MAX_N+5];
int n, m, s[MAX_N+5], c[1<<MAX_K];
lnt w[1<<MAX_K], f[MAX_N+5][MAX_N+5][1<<MAX_K];
void upd(lnt &a, lnt b) {if (a < b) a = b;}
int main() {
scanf("%d%d%s", &n, &m, str), memset(f, -1, sizeof(f));
for (int i = 0; i < (1<<m); i++) scanf("%d%lld", c+i, w+i);
for (int i = 1; i <= n; i++) s[i] = str[i-1]-'0', f[i][i][s[i]] = 0;
for (int l = 2; l <= n; l++)
for (int s = 1, t = s+l-1; t <= n; t = ++s+l-1) {
int tar = l-1; while (tar >= m) tar -= m-1;
for (int mid = t-1; mid >= s; mid -= m-1) for (int sta = 0; sta < (1<<tar); sta++) {
if (~f[s][mid][sta] && ~f[mid+1][t][0]) upd(f[s][t][sta<<1], f[s][mid][sta]+f[mid+1][t][0]);
if (~f[s][mid][sta] && ~f[mid+1][t][1]) upd(f[s][t][sta<<1|1], f[s][mid][sta]+f[mid+1][t][1]);
}
if (tar == m-1) {
lnt g[2]; g[0] = g[1] = -1;
for (int sta = 0; sta < (1<<m); sta++)
if (~f[s][t][sta]) upd(g[c[sta]], f[s][t][sta]+w[sta]);
f[s][t][0] = g[0], f[s][t][1] = g[1];
}
}
lnt ans = 0; for (int i = 0; i < (1<<m); i++) ans = max(ans, f[1][n][i]); printf("%lld", ans); return 0;
}