我们发现光是从一个序列 { a i } \set{a_i} { a i } 计算其 f ( { a } ) f(\set{a}) f ({ a }) 的值都必须要 O ( n ) O(n) O ( n ) ,更别说枚举所有可能序列了。
所以我们反着考虑:我们考虑一个 token removal sequence 可能被几个序列统计到。对于单个 token 设其位置为 p p p ,则如果 a i a_i a i 可以把 p p p 取走,那么应该满足
a i ≤ p ≤ i
a_i\le p\le i
a i ≤ p ≤ i 我们考虑有多少 [ l , r ] [l,r] [ l , r ] 满足 l ≤ p ≤ r l\le p\le r l ≤ p ≤ r ,则满足这个条件的数对表示 a r = l a_r=l a r = l 时刻把 token p p p 取走。那么考虑 token 被取走的集合 (不是顺序)1 ≤ p 1 < p 2 < p 3 < … p k ≤ n 1\le p_1\lt p_2\lt p_3\lt \dots p_{k}\le n 1 ≤ p 1 < p 2 < p 3 < … p k ≤ n ,我们从 p k p_{k} p k 往 p 1 p_1 p 1 数有多少 { [ l , r ] } i ∈ [ 1 , k ] \set{[l,r]}_{i\in [1,k]} { [ l , r ] } i ∈ [ 1 , k ] 满足 r i r_i r i 互不相同且 p i ∈ [ l i , r i ] p_i\in [l_i,r_i] p i ∈ [ l i , r i ] ,这样的 [ l , r ] [l,r] [ l , r ] set 集合表示这个 token sequence 会被 a r = l a_r=l a r = l (其余元素为 0 0 0 ) 的序列统计到。而有多少这样的 { [ l , r ] } i \set{[l,r]}_i { [ l , r ] } i 集合呢?
∏ i = k 1 p i ( n + 1 − p i − ( k − i ) )
\prod_{i=k}^1 p_i(n+1-p_i-(k-i))
i = k ∏ 1 p i ( n + 1 − p i − ( k − i )) − ( k − i ) -(k-i) − ( k − i ) 表示比 p i p_i p i 打的 token 位置里已经被取走 k − i k-i k − i 个。所以我们最后希望统计的是
∑ f = { p } ∏ i = ∣ f ∣ 1 p i ( n + 1 − p i − ( f − i ) )
\sum_{f=\set{p}}\prod_{i= |f|}^1 p_i(n+1-p_i-(f-i))
f = { p } ∑ i = ∣ f ∣ ∏ 1 p i ( n + 1 − p i − ( f − i )) 我们显然不能枚举 { p } \set{p} { p } . 所以我们转而考虑如何递推,我们从长度出发进行递推,考虑每一个位置上的 token 会被几个 token seq 计算到。
我们令 f ℓ , p f_{\ell, p} f ℓ , p 表示长度为 ℓ \ell ℓ ,其中最小值 min = p \min=p min = p 的和式结果。则就有递推式
f ℓ , p = ( ∑ p ′ > p f ℓ − 1 , p ′ ) × p × ( n + 1 − p − ( ℓ − 1 ) )
f_{\ell, p} = \Big(\sum_{p'\gt p}f_{\ell-1, p'}\Big) \times p\times (n+1-p-(\ell-1))
f ℓ , p = ( p ′ > p ∑ f ℓ − 1 , p ′ ) × p × ( n + 1 − p − ( ℓ − 1 )) 括号内的和式很容易用前缀和优化掉,所以整体的时间复杂度为 O ( n 2 ) O(n^2) O ( n 2 ) 的.
尝试从长度递推之后,一个自然而然的问题就是,多出的那个元素应该是什么?最大值?最小值?如果你尝试用最大值的话,就会发现你还需要额外保存 ∑ p < max p \sum_{p\lt \max} p ∑ p < m a x p ,这很麻烦。所以应该选择 min \min min 作为额外添加的那个元素。
AC Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 #include <bits/stdc++.h> constexpr int N = 5005 ;int n, m;int F[N][N], g[N]; void run () { std::cin >> n >> m; for (int i = 0 ; i <= n + 1 ; i++) for (int j = 0 ; j <= n + 1 ; j++) F[i][j] = 0 ; F[0 ][n + 1 ] = 1 ; for (int len = 1 ; len <= n; len++) { for (int i = n + 1 ; i >= 1 ; i--) g[i] = (g[i + 1 ] + F[len - 1 ][i]) % m; for (int i = 1 ; i <= n; i++) F[len][i] = 1ll * g[i + 1 ] * i % m * (n - i + 1 - (len - 1 )) % m; } int ans = 1 ; for (int i = 1 ; i <= n; i++) for (int j = 1 ; j <= n; j++) ans = (ans + F[i][j]) % m; std::cout << ans << '\n' ; } int main () { std::cin.tie (0 )->sync_with_stdio (0 ); int t; std::cin >> t; while (t--) run (); }