BZOJ5418【NOI2018】屠龙勇士 < 扩展CRT >

Problem

标签:扩展CRT

Solution

用 $\mathrm{set}$(剑的攻击力可能重复,实现时使用 $\mathrm{multiset}$)维护现有的所有剑,支持每次查找巨龙血量的前驱(即攻击力不超过血量的最大的剑;若不存在,则取攻击力最小的剑),每处理完一条龙后加入一把新的剑。问题转化为每次解方程 $x\times ATK-y\times p_i=a_i$,即裸的扩展 $\mathrm{CRT}$,用扩展欧几里得合并线性同余方程组即可。

$\triangle$:由于计算过程中乘积会超过 $\mathrm{long\;long}$ 的存储范围,做取模乘法时需要先转换为 $\mathrm{long\;double}$ 估算商 $\lfloor ab/p\rfloor$,再转回整数,用 $ab-\lfloor ab/p\rfloor\times p$ 求出余数,这是因为 $\mathrm{long\;double}$ 的数据范围比 $\mathrm{long\;long}$ 大。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <bits/stdc++.h>
// Capacity bound for the global arrays below; each is declared with MAX_N+5 elements.
#define MAX_N 200000
// Large 64-bit constant (9e18); not referenced by the code visible in this file.
#define INF 9000000000000000000LL
using namespace std;
// lnt: the 64-bit signed integer type used for all arithmetic in this solution.
typedef long long lnt;
// dnt: extended-precision floating-point alias.
typedef long double dnt;
// Reads one decimal integer from stdin into x.
//
// Leading non-digit characters are skipped; if any of them is '-', the value
// is read as negative. The character immediately following the number is
// consumed. If end of input is reached before any digit is seen, x is left
// at 0 and the function returns instead of spinning on getchar() forever.
template <class T> inline void read(T &x) {
    x = 0;
    int c = getchar(), f = 1;
    // Skip everything up to the first digit, remembering a minus sign.
    for (; c != EOF && !isdigit(c); c = getchar()) {
        if (c == '-') f = -1;
    }
    // Accumulate with the sign applied per digit, so the most negative
    // value of T can be parsed without an intermediate positive overflow.
    for (; c != EOF && isdigit(c); c = getchar()) {
        x = x * 10 + f * (c - '0');
    }
}
// n, m: the two counts read by main at the start of each test case
// (n entries of a/p/reward swords, m initial swords).
// s: multiset of the attack values of the swords currently available;
// getd removes the chosen sword and main inserts the reward sword afterwards.
int n, m; multiset <lnt> s;
// b[i], c[i]: congruence x ≡ b[i] (mod c[i]) produced by calc for entry i and
// later combined by CRT. mx: running maximum of ceil(a[i]/dmg), used by CRT
// as a lower bound on the returned answer.
lnt b[MAX_N+5], c[MAX_N+5]; lnt mx;
// a[i], p[i]: per-entry right-hand side and modulus of dmg*x ≡ a[i] (mod p[i]).
// d[1..m]: initial swords; d[m+i]: sword inserted after entry i is processed,
// so indices up to n+m are written and n+m must fit in the array.
lnt a[MAX_N+5], p[MAX_N+5], d[MAX_N+5];
// Greatest common divisor by the Euclidean algorithm.
// Returns a when b is 0 (so GCD(0, 0) is 0).
lnt GCD(lnt a, lnt b) {
    while (b != 0) {
        const lnt rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}
// Returns (a * b) mod p as a value in [0, p). Requires p > 0.
//
// a and b may be negative (calc passes a Bezout coefficient as a).
// The product is formed in a 128-bit integer, so it is exact for any pair of
// 64-bit operands: there is no signed 64-bit overflow and no dependence on
// the precision of long double on the target platform.
lnt mul(lnt a, lnt b, lnt p) {
    __int128 r = static_cast<__int128>(a) * b % p;
    // C++ '%' keeps the sign of the dividend; shift into [0, p).
    if (r < 0) r += p;
    return static_cast<lnt>(r);
}
// Extended Euclidean algorithm.
// On return, x and y satisfy a*x + b*y == GCD(a, b).
// Base case b == 0 yields x = 1, y = 0.
void ExGCD(lnt a, lnt b, lnt &x, lnt &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return;
    }
    // Solve b*u + (a % b)*v == g first, then lift to a and b:
    // a*v + b*(u - (a / b)*v) == g.
    lnt u, v;
    ExGCD(b, a % b, u, v);
    x = v;
    y = u - a / b * v;
}
// Solves d*x ≡ a (mod p) and stores the solution set as the congruence
// x ≡ b[id] (mod c[id]), where c[id] = p / GCD(d, p).
// Returns false (leaving b[id]/c[id] untouched) when GCD(d, p) does not
// divide a, i.e. the congruence has no solution.
bool calc(int id, lnt d, lnt a, lnt p) {
    const lnt g = GCD(d, p);
    if (a % g != 0) {
        return false;
    }

    // coefD*d + coefP*p == g; only coefD is needed below.
    lnt coefD, coefP;
    ExGCD(d, p, coefD, coefP);

    const lnt modulus = p / g;
    const lnt root = mul(coefD, a / g, modulus);

    b[id] = (root + modulus) % modulus;
    c[id] = modulus;
    return true;
}
// Removes one sword from the global multiset s and returns its attack value.
//
// The sword chosen is the largest value not exceeding a; if every sword is
// greater than a, the smallest sword is chosen instead.
// Precondition: s is non-empty (otherwise the iterator dereferenced below
// is s.end()).
lnt getd(lnt a) {
    // s is a multiset, so its own iterator type is used here rather than
    // set<lnt>::iterator, which is a distinct type as far as the standard goes.
    multiset <lnt> :: iterator it = s.upper_bound(a);
    // upper_bound gives the first element > a; step back to the last
    // element <= a unless there is none.
    if (it != s.begin()) --it;
    const lnt ret = *it;
    s.erase(it);  // erase by iterator removes exactly one copy
    return ret;
}
// Returns the ExGCD coefficient of a, reduced into [0, b) for positive b.
// When GCD(a, b) == 1 this is the modular inverse of a modulo b.
lnt sol(lnt a, lnt b) {
    lnt coefA, coefB;
    ExGCD(a, b, coefA, coefB);
    const lnt shifted = coefA % b + b;
    return shifted % b;
}
// Combines x ≡ a1 (mod m1) and x ≡ a2 (mod m2) into x ≡ a3 (mod m3).
// m3 is set to m1 / g * (m2 / g) * g with g = GCD(m1, m2), and a3 is
// reduced into [0, m3). Returns false, leaving a3 and m3 untouched, when
// the two congruences are incompatible (g does not divide a2 - a1).
bool merge(lnt a1, lnt m1, lnt a2, lnt m2, lnt &a3, lnt &m3) {
    const lnt g = GCD(m1, m2);
    lnt diff = a2 - a1;
    if (diff % g != 0) {
        return false;
    }

    // Normalise the difference modulo the original m2 before dividing by g.
    diff = (diff % m2 + m2) % m2;
    const lnt q1 = m1 / g;
    const lnt q2 = m2 / g;
    diff /= g;

    m3 = q1 * q2 * g;

    // k = diff * q1^{-1} (mod q2); the merged residue is a1 + k * q1 * g.
    lnt k = mul(diff, sol(q1, q2), q2);
    k = mul(k, q1, m3);
    k = mul(k, g, m3);

    a3 = ((k + a1) % m3 + m3) % m3;
    return true;
}
lnt CRT(lnt mx) {
lnt a1 = b[1], m1 = c[1];
for (int i = 2; i <= n; i++) {
lnt a2 = b[i], m2 = c[i], a3, m3;
if (!merge(a1, m1, a2, m2, a3, m3)) return -1;
a1 = a3, m1 = m3;
}
a1 = (a1%m1+m1)%m1;
if (a1 < mx) a1 += (mx-a1)/m1*m1;
if (a1 < mx) a1 += m1;
return a1;
}
// Driver: reads T test cases. For each one it reads n, m, the arrays a[1..n],
// p[1..n], the n reward swords (stored at d[m+1..m+n]) and the m initial
// swords (d[1..m]); then, for every entry in order, it picks a sword with
// getd, turns the entry into a congruence with calc, and finally prints the
// result of CRT, or -1 as soon as some entry has no solution.
int main() {
    int T;
    read(T);
    while (T--) {
        s.clear();
        mx = 0;
        read(n);
        read(m);

        for (int i = 1; i <= n; i++) read(a[i]);
        for (int i = 1; i <= n; i++) read(p[i]);
        for (int i = 1; i <= n; i++) read(d[i+m]);
        for (int i = 1; i <= m; i++) {
            read(d[i]);
            s.insert(d[i]);
        }

        bool solvable = true;
        for (int i = 1; i <= n && solvable; i++) {
            const lnt dmg = getd(a[i]);
            // ceil(a[i] / dmg): tracked as a lower bound for the answer.
            mx = max(mx, (a[i]-1)/dmg+1);
            if (calc(i, dmg, a[i], p[i])) {
                s.insert(d[i+m]);
            } else {
                solvable = false;
            }
        }

        if (!solvable) {
            puts("-1");
        } else {
            printf("%lld\n", CRT(mx));
        }
    }
    return 0;
}