Check Transmission

Check Transmission

穷举所有可能的“0字符串”的长度,根据s/ts/t的长度就可以算出对应的“1字符串”的长度,这样0/1字符串就确定了,然后判断是否合法就行了。

判断合法性的办法,是用字符串哈希判断每个0/1字符串对应的子串是否是正确的值。最终复杂度是O(nlogn)\mathcal{O}(n\log n)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
 
int n, m;
const int MAXM = 1000005;
char s[MAXM], t[MAXM];  // s is binary, t is ascii
int first[2], cnt[2];
int h[MAXM], p[MAXM];   // string hash
const int A = 911382323, M = 972663749;
 
int main() {
    scanf("%s %s", s, t);
    int n = strlen(s), m = strlen(t);
    first[0] = first[1] = -1;
    for (int i = 0; i < n; i++) {
        int v = s[i]-'0';
        if (first[v] == -1) first[v] = i;
        cnt[v]++;
    }
 
    p[0] = 1; h[0] = t[0]; 
    for (int i = 1; i < m; i++) {
        p[i] = (ll)p[i-1] * A % M;
        h[i] = ((ll)h[i-1] * A + t[i]) % M;
    }
    auto hsh = [](int a, int b) -> int {
        int res;
        if (a == 0) res = h[b];
        else res = (h[b] - (ll)h[a-1] * p[b-a+1] + (ll)h[a-1] * M) % M;
        // printf("h[%d,%d]=%d\n", a, b, res);
        return res;
    };
 
    int ans = 0;
    for (int i = 1; i <= m / cnt[0]; i++) { // lenght of 0 string
        if ((m - i*cnt[0]) % cnt[1] != 0) continue;     // length of 1 string needs to be whole number
        int j = (m - i*cnt[0]) / cnt[1];    // length of 1 string
        if (j == 0) continue;               // 1 string has to be non-empty
 
        int l[2], r[2], pos = 0;
        bool ok = true;
        for (int k = 0; k < n; k++) {
            int v = s[k]-'0';
            int len = v ? j : i;
            if (k == first[0] || k == first[1])
                l[v] = pos, r[v] = pos + len - 1;       // mark left and right positions
            else {
                if (hsh(l[v],r[v]) != hsh(pos, pos+len-1)) {  // check validity
                    ok = false;
                    break;
                }
            }
            pos += len;
        }
        // 0 and 1 strings must be different
        if (ok && hsh(l[0],r[0]) != hsh(l[1],r[1])) ans++;
    }
    printf("%d", ans);
 
    return 0;
}