Balancing Inversions (USACO Gold 2019 US Open)
我们首先观察,两两交换数字,这个带来的状态变化,比较难以用DP这样的常用方式来充分表达,所以多半还是需要利用01数据,以及交换操作的特性,来简化问题的状态空间。
以下比较简单的规律容易发现:只有0和1交换才有意义,00和11不需要交换。交换可以看1的左右移动。其次,大部分01交换将逆序对数量加1,10交换将逆序对数量减1,但n-1和n位置的交换特殊(“中央交换”),这也是题目所给例子提示我们的。
以下是关键一步,仔细推公式可以发现:“中央01交换”使得左右逆序对数量差增加“N - 序列中1的数量”(例如题目例子的中央01交换,使得左右差增加2),同样道理,“中央10交换”使得结果减少 “N - 序列中1的数量”。
有了这个发现,我们就知道了每次中央交换对结果改变是恒定的,那就比较容易接着想到,最后的方案是由若干个中央交换(同一方向)和若干个其它交换(也是对左右差而言的同样方向)组成的。
所以我们的解法:不失一般性,只考虑10中央交换的情况。直接模拟0个、1个...10中央交换,得到需要多少步,同时也得到交换完成后的新的左右差,再加上与最终目标间的差,就是这个数量的中央交换的情况下的步数。对所有数量的中央交换求最小值,就是答案。这里有两个细节:
- 我们可以证明,对于每一个可以实现的中央交换次数,之后都可以通过非中央交换实现左右逆序对差为0。至少一个可行的操作,是把1都移到左右半边的最右侧,则逆序对数都为0。
- 持续进行中央10交换的办法,是将左侧离中央最近的1,以及右侧离中央最近的0,通过交换移动到中央,这个模拟可以通过两个指针进行(并不需要实际交换,找到对应的数字的位置就可以,见代码)。
最后对于“01中央交换”,可以对应看成0的向右移动,以上算法可以参数化后可以复用。我们枚举所有10中央交换和01中央交换的次数,就可以得到最优解。
这个题还是比较奇特的,需要有足够的耐心一步步推导,层层解开谜团,才能找到最终的数学规律。最后O(N)解决问题。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
int n;
int a[200005];
ll c1, c2;
int cnt, lcnt, rcnt; // count of ones
ll ans;
ll count(int *a, int n) {
int zero = 0;
ll res = 0;
for (int i = n-1; i >= 0; i--)
if (a[i]) res += zero;
else zero++;
return res;
}
void solve(int t) { // t: 1 for moving ones to the right, 0 for moving zeros
int tot = min(t ? lcnt : n-lcnt, t ? n-rcnt : rcnt);
int delta = t ? cnt - n : n - cnt; // change result to per central swap
int l = n-1, r = n; // ptr to next t (l) and 1-t (r)
ll steps = 0;
ll res = c1-c2;
for (int i = 1; i <= tot; i++) { // move tot 1's to the right
while (l > 0 && a[l] != t) l--;
while (r < 2*n && a[r] != 1-t) r++;
steps += (n-1-l) + (r - n) + 1;
res += delta; // changes to result from central swap
if (t) res += -(n-1-l) + (r-n); // changes to result from other swaps
else res += (n-1-l) - (r-n);
ans = min(ans, steps + abs(res));
l--; r++;
}
}
int main() {
scanf("%d", &n);
for (int i = 0; i < 2*n; i++) {
scanf("%d", &a[i]);
cnt += a[i];
if (i < n) lcnt += a[i]; else rcnt += a[i];
}
c1 = count(a, n);
c2 = count(a+n, n);
ans = abs(c1-c2);
solve(0);
solve(1);
printf("%lld", ans);
return 0;
}