Dance Mooves (USACO Gold 2021 January)

Dance Mooves (USACO Gold 2021 January)

首先容易看到,每KK次操作后,完成NN头牛的一个排列,再次重复这些操作的时候,就是把这个排列再应用一次。这个排列可以用一个功能图(Functional Graph)来表示,每个点有且只有一个出度,记为next[ii]。

然后如果多次应用这个排列,可以看到每头牛的位置经过多次变换,总会回到原来位置形成一个环,实际上所有节点会形成1个或多个独立不相交的环,通过O(n)\mathcal{O}(n)遍历,可以把这些环的周期都求出来。

继续往下想,容易看到M>KCM>K \cdot C之后(C是当前节点在环的周期,CNC \leq N),环上的所有位置能到达的位置的集合就完全一样了,因为路径已经完全重复了。

那么问题主要是要解决M<KCM < K \cdot C的情况。因为N105N \leq 10^5,这个MM还有101010^{10}大,模拟肯定是不可能的,需要找到更多的规律。比如因为是求unique positions,也会考虑用线段树区间求和,但主要困难是MM太大,走不通。

这时候,我们回到题目做的操作本身。因为每次只交换22个位置,那我们可以看到,KK次操作之后,所有牛能覆盖的位置数量加起来不超过2K+N2K+N(每次操作最多增加两个位置),所有位置列表是可以O(N)\mathcal{O}(N)求出来的,那我们可以把这些列表都求出来,放到每个出发位置的集合s[i]里面存下来,有了这个集合列表问题就清楚了。

这时候可以看到一个很简单的办法:对于MM次操作,是MK\lfloor \frac{M}{K} \rfloorKK操作周期,再加余下的MmodKM \mod K个操作接起来,前者能碰到的位置,就是对应的s[i]的集合并,对于所有位置来说,可以用sliding window的办法合并这些集合,一遍遍历计数就可以算出对于所有节点的结果。

余数的部分,把s[i]中间每步到达的时间记录下,就可以处理了,具体见程序。最终复杂度是O(N+K)\mathcal{O}(N+K)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
 
int n,K;
ll m;
pi ops[200005];     // K个操作
int nxt[100005];    // nxt[i]: 经过K次操作后i位置的元素到的位置
set<pi> s[100005];  // 每个位置经过K次操作能到达的{位置, 最早到达时间}
int p[100005];      // p[i]是K次操作后在i位置的数字
int cycle[100005];  // 每个位置的周期
map<int,int> sw;    // sliding window of touched positions
int ans[100005];
 
void reachable(int pos, int pos_now, int t) {   // mark pos_now reachable from pos, at time t
    auto it = s[pos].find({pos_now,-1});
    if (it != s[pos].end() && it->first == pos_now) return;
    s[pos].insert({pos_now,t});
}
 
void add(int x, int r) {        // add s[x] to sliding window
    for (pi k: s[x])   
        if (k.second < r)
            sw[k.first]++;
}
 
void remove(int x, int r) {     // remove s[x] from sliding window
    for (pi k: s[x])
        if (k.second < r)
            if (sw[k.first] == 1)
                sw.erase(k.first);
            else
                sw[k.first]--;                        
}
 
int main() {
    scanf("%d %d %lld", &n, &K, &m);
    for (int i = 0; i < n; i++)
        p[i] = i, s[i].insert({i,0});
    for (int i = 0; i < K; i++) {
        int u,v;
        scanf("%d %d", &u, &v);
        u--,v--;
        ops[i] = {u,v};
        swap(p[u], p[v]);   
        reachable(p[u], u, i);
        reachable(p[v], v, i);
    }
    for (int i = 0; i < n; i++)
        nxt[p[i]]=i;
    for (int i = 0; i < n; i++)     // 计算周期
        if (!cycle[i]) {
            set<int> ps;
            int x = i;
            while (!ps.count(x)) 
                ps.insert(x), x = nxt[x];
            for (int j: ps)
                cycle[j] = ps.size();
        }
    ll d = m / K, r = m % K;
    fill(ans, ans+n, -1);
    for (int i = 0; i < n; i++) {
        if (ans[i] > 0) continue;
        int x = i, cyc = cycle[i];
        sw.clear();
        int y = x;
        for (int j = 0; j < min(d+1, (ll)cyc); j++) {
            if (j) y = nxt[y];
            for (pi k: s[y])
                if (j < d || k.second < r)
                    sw[k.first]++;
        }
        ans[i] = sw.size();
        while (1) {                 // sliding window: s[x]..s[y]
            int x0 = x;
            x = nxt[x];
            if (x == i) break;
            if (d < cyc) {          // 移动sliding window
                remove(y, r);       // 去掉头部余下集合
                if (d > 0) {
                    remove(x0, K);  // 去掉尾部
                    add(y, K);      // 新增头部
                }
                y = nxt[y];
                add(y, r);          // 新增头部余下集合
            }
            ans[x] = sw.size();
        }
    }
    for (int i = 0; i < n; i++)
        printf("%d\n", ans[i]);
    return 0;
}