Bovine Genetics (USACO Gold 2020 December)

Bovine Genetics (USACO Gold 2020 December)

这是一道难度大的多维DP题。读题之后,从有规律的分隔,方案数大到long long装不下等迹象,可以看出很可能是一道DP题。同时N=105N=10^5,所以需要O(N)\mathcal{O}(N)或者O(NlogN)\mathcal{O}(N \log N),不能出现N2N^2

首先我们比较容易看出来,一种隔断字符串的方式,再加?字符填入什么字母,就唯一确定了一个原始字符串。所以我们要统计的,就是有多少个隔断和填入?字符的方案数。如果字符串的分段为...st...|s|t,即sstt是最后两个段,那这里可以观察出来两个限制条件:

基于这些规则,我们可以设计如下的DP状态:将当前处理到的字符位置,以及最后一段的起始位置作为状态,也就是dp[i][j]\text{dp}[i][j]是到第ii个字符,最后一段从jj个字符开始的方案总数,这样就可以判断在哪些位置可以插入隔断,这样下去明显是一个N2N^2的算法,实现起来也不复杂,能过一半的测试点。这个就不写了,可以参考官方题解。

我们主要讨论正解,上面的状态设计主要的缺陷,是最后一段可能从任何位置开始,甚至从整个串一开头开始,所以变成了N2N^2,但如果我们观察题目规则,我们会发现后续的状态转换中,我们并不关心最后一段开始的位置,而只关心段的开头和结尾的字母。所以,我们尝试列举重要的几个状态(还比较多):

  1. 首先我们要跟踪目前处理到的位置ii
  2. 然后我们当然关心当前位置ii是什么字母,以计算?对应的不同方案。
  3. 我们还关心最后一段以什么字母开始,因为这决定了下一段要以什么字母结束。
  4. 最后我们还关心最后第二段以什么字母开始,因为它决定当前段(最后一段)可以以什么字母结束。

可以证明有这四个状态就可以了。也就是:

然后状态转移是:

dp[i][j][k=a[i]][k]=dp[i1][l][j][l],(i1,i之间发生隔断)dp[i][j][k][l]=mldp[i1][j][k][m],la[i1](没有隔断)\text{dp}[i][j][k=a[i]][k] = \text{dp}[i-1][l][j][l], (i-1, i\text{之间发生隔断}) \\ \text{dp}[i][j][k][l] = \sum_{m \neq l} \text{dp}[i-1][j][k][m], l \neq a[i-1] (\text{没有隔断})

两个公式的值是加起来的。这样就可以解决问题了。初值处理还有一点细节,见下方代码。jkljkl的范围都是141-4,所以整体复杂度是可以认为是一个大常数的O(N)\mathcal{O}(N),可以通过。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
 
int n;
int id[256];        // 0 for ?, 1-4 for ACGT
char s[100005];
int dp[100005][5][5][5];
const int M = 1e9 + 7;
 
int main() {
    scanf("%s", s);
    n = strlen(s);
    id['A'] = 1; id['C'] = 2; id['G'] = 3; id['T'] = 4;
 
    int c = id[s[0]];
    for (int i = c ? c : 1; i <= (c ? c : 4); i++)
        dp[0][0][i][i] = 1;         // j == 0是个特殊值,表示第一个段
    for (int i = 1; i < n; i++) {
        int c0 = id[s[i-1]];
        int c = id[s[i]];
        if (c != c0 || c==0) {             // 不隔断
            for (int j = 0; j <= 4; j++)
                for (int k = 1; k <= 4; k++)
                    for (int l = c ? c : 1; l <= (c ? c : 4); l++)
                        for (int m = 1; m <= 4; m++) {
                            if (m == l) continue;       // 末尾两字母不能相同
                            dp[i][j][k][l] += dp[i-1][j][k][m];
                            dp[i][j][k][l] %= M;
                        }
        }
        for (int j = 1; j <= 4; j++)        // 新增隔断
            for (int k = c ? c : 1; k <= (c ? c : 4); k++)
                for (int l = 1; l <= 4; l++) {
                    dp[i][j][k][k] = (dp[i][j][k][k] + dp[i-1][l][j][l]) % M;
                    dp[i][j][k][k] = (dp[i][j][k][k] + dp[i-1][0][j][l]) % M;       // 处理第一段
                }
    }
    int ans = 0;
    int last = id[s[n-1]];
    for (int j = 1; j <= 4; j++)
        for (int k = last ? last : 1; k <= (last ? last : 4); k++) {
            ans = (ans + dp[n-1][k][j][k]) % M;
            ans = (ans + dp[n-1][0][j][k]) % M;     // 只有一段的情况
        }
    printf("%d", ans);
    return 0;
}