Balancing Inversions (USACO Gold 2019 US Open)

Balancing Inversions (USACO Gold 2019 US Open)

我们首先观察,两两交换数字,这个带来的状态变化,比较难以用DP这样的常用方式来充分表达,所以多半还是需要利用01数据,以及交换操作的特性,来简化问题的状态空间。

以下比较简单的规律容易发现:只有0和1交换才有意义,00和11不需要交换。交换可以看1的左右移动。其次,大部分01交换将逆序对数量加1,10交换将逆序对数量减1,但n-1和n位置的交换特殊(“中央交换”),这也是题目所给例子提示我们的。

以下是关键一步,仔细推公式可以发现:“中央01交换”使得左右逆序对数量差增加“N - 序列中1的数量”(例如题目例子的中央01交换,使得左右差增加2),同样道理,“中央10交换”使得结果减少 “N - 序列中1的数量”。

有了这个发现,我们就知道了每次中央交换对结果改变是恒定的,那就比较容易接着想到,最后的方案是由若干个中央交换(同一方向)和若干个其它交换(也是对左右差而言的同样方向)组成的。

所以我们的解法:不失一般性,只考虑10中央交换的情况。直接模拟0个、1个...10中央交换,得到需要多少步,同时也得到交换完成后的新的左右差,再加上与最终目标间的差,就是这个数量的中央交换的情况下的步数。对所有数量的中央交换求最小值,就是答案。这里有两个细节:

  1. 我们可以证明,对于每一个可以实现的中央交换次数,之后都可以通过非中央交换实现左右逆序对差为0。至少一个可行的操作,是把1都移到左右半边的最右侧,则逆序对数都为0。
  2. 持续进行中央10交换的办法,是将左侧离中央最近的1,以及右侧离中央最近的0,通过交换移动到中央,这个模拟可以通过两个指针进行(并不需要实际交换,找到对应的数字的位置就可以,见代码)。

最后对于“01中央交换”,可以对应看成0的向右移动,以上算法可以参数化后可以复用。我们枚举所有10中央交换和01中央交换的次数,就可以得到最优解。

这个题还是比较奇特的,需要有足够的耐心一步步推导,层层解开谜团,才能找到最终的数学规律。最后O(N)解决问题。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
 
int n;
int a[200005];
ll c1, c2;
int cnt, lcnt, rcnt;    // count of ones
ll ans;
 
ll count(int *a, int n) {
    int zero = 0;
    ll res = 0;
    for (int i = n-1; i >= 0; i--)
        if (a[i]) res += zero;
        else zero++;
    return res;
}
 
void solve(int t) {  // t: 1 for moving ones to the right, 0 for moving zeros
    int tot = min(t ? lcnt : n-lcnt, t ? n-rcnt : rcnt);
    int delta = t ? cnt - n : n - cnt;  // change result to per central swap
    int l = n-1, r = n;                 // ptr to next t (l) and 1-t (r)
    ll steps = 0;
    ll res = c1-c2;
    for (int i = 1; i <= tot; i++) {    // move tot 1's to the right
        while (l > 0 && a[l] != t) l--;
        while (r < 2*n && a[r] != 1-t) r++;
        steps += (n-1-l) + (r - n) + 1;
        res += delta;                   // changes to result from central swap
        if (t) res += -(n-1-l) + (r-n); // changes to result from other swaps
        else res += (n-1-l) - (r-n);
        ans = min(ans, steps + abs(res));
        l--; r++;
    }
}
 
int main() {
    scanf("%d", &n);
    for (int i = 0; i < 2*n; i++) {
        scanf("%d", &a[i]);
        cnt += a[i];
        if (i < n) lcnt += a[i]; else rcnt += a[i];
    }
    c1 = count(a, n);
    c2 = count(a+n, n);
    ans = abs(c1-c2);
    solve(0);
    solve(1);
    printf("%lld", ans);
    return 0;
}