BZOJ3702 二叉树 <线段树合并>

Problem

二叉树

Time  Limit:  15  Sec\mathrm{Time\;Limit:\;15\;Sec}
Memory  Limit:  256  MB\mathrm{Memory\;Limit:\;256\;MB}

Description

现在有一棵二叉树,所有非叶子节点都有两个孩子。
在每个叶子节点上有一个权值(有nn个叶子节点,满足这些权值为1n1\sim n的一个排列)。可以任意交换每个非叶子节点的左右孩子。
要求进行一系列交换,使得最终所有叶子节点的权值按照中序遍历写出来,逆序对个数最少。

Input

第一行一个整数nn
下面每行一个数xx

  • 如果x=0x=0,表示这个节点非叶子节点,递归地向下读入其左孩子和右孩子的信息,
  • 如果x0x\ne0,表示这个节点是叶子节点,权值为xx

Output

一行,最少逆序对个数。

Sample Input

1
2
3
4
5
6
3
0
0
3
1
2

Sample Output

1
1

HINT

对于100%100\%的数据:2n2×1052\le n\le 2\times10^5

标签:线段树合并

Solution

树上贪心时用线段树合并优化复杂度。

考虑朴素贪心,由于子树内如何交换和上面的点如何交换互不影响,只需要对于每个点采用其逆序对数最少的交换方式即可,不用考虑其带来的影响。具体地,对每个点尝试交换/不交换两个儿子,分别求出两种情况下跨过这个点的逆序对数,选取较小的那个即可。如果每次将两棵子树中的所有叶子权值全拿出来归并排序,复杂度是O(n2logn)O(n^2\log{n})的。

用线段树合并优化求跨过某点的逆序对数。对树上每个点维护一棵值域线段树,存储其子树中叶子节点的权值信息,在尝试交换/不交换两个儿子时,合并这两个儿子的线段树,在合并过程中可以递归算两种情况的逆序对数。如此在合并后取较小的对数计入总答案即可。时间复杂度O(nlogn)O(n\log{n})

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
#define MAX_N 400000
#define mid ((s+t)>>1)
using namespace std;
typedef long long lnt;
template <class T> inline void read(T &x) {
x = 0; int c = getchar(), f = 1;
for (; !isdigit(c); c = getchar()) if (c == 45) f = -1;
for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0');
}
int n, m, r, s[MAX_N+5][2];
int sz, rt[MAX_N+5]; lnt s1, s2, ans;
struct node {int ls, rs, c;} tr[MAX_N*20];
int merge(int u, int v) {
if (!u) return v; if (!v) return u;
s1 += 1LL*tr[tr[u].ls].c*tr[tr[v].rs].c;
s2 += 1LL*tr[tr[u].rs].c*tr[tr[v].ls].c;
tr[u].ls = merge(tr[u].ls, tr[v].ls);
tr[u].rs = merge(tr[u].rs, tr[v].rs);
tr[u].c = tr[tr[u].ls].c+tr[tr[u].rs].c;
return u;
}
void modify(int &v, int s, int t, int p) {
if (!v) v = ++sz; if (s == t) {tr[v].c++; return;}
if (p <= mid) modify(tr[v].ls, s, mid, p);
if (p >= mid+1) modify(tr[v].rs, mid+1, t, p);
tr[v].c = tr[tr[v].ls].c+tr[tr[v].rs].c;
}
void build(int &u) {
if (!u) u = ++m; int x; read(x);
if (!x) build(s[u][0]), build(s[u][1]);
else modify(rt[u], 1, n, x);
}
void DFS(int u) {
if (!u) return; DFS(s[u][0]), DFS(s[u][1]), s1 = s2 = 0;
if (s[u][0] || s[u][1])
rt[u] = merge(rt[s[u][0]], rt[s[u][1]]), ans += min(s1, s2);
}
int main() {
read(n), build(r), DFS(r);
return printf("%lld\n", ans), 0;
}