A.

B/C.

注意到题目描述里是从第一列的给定位置走到第 mm 列的给定位置,并且我们往右走是无法回头的。

所以这个提醒我们可以按列进行考虑。

D. 树到云端

因为通过 val[]{val}[] 计算出来的 T[]T[] 数组是已知的,而 val[]{val}[] 是未知的,这就相当于我们在解方程,于是就可以想到消元(其实就是作差)。考虑以 11 为根,考虑树边 uvu\to v 其中 dep[u]<dep[v]\texttt{dep}[u]<\texttt{dep}[v],则有

T[u]T[v]=isubtree(v)val[i]isubtree(v)val[i]=(2×isubtree(v)val[i])S \begin{aligned} T[u]-T[v]&=\sum_{i\in \text{subtree(v)}}val[i]-\sum_{i\notin\text{subtree(v)}}val[i]\\ &=\Big(2\times\sum_{i\in\text{subtree(v)}}val[i]\Big)-S \end{aligned}

这里 S=ival[i]S=\sum_i val[i].

我们再把所有边加起来:

这一步我做的时候没有想到…… -.-

(uv)Edep[u]<dep[v]T[u]T[v]=(n1)S+2(v1isubtree(v)val[i]) \begin{aligned} \sum_{(u\to v)\in E}^{\texttt{dep}[u]<\texttt{dep}[v]}T[u]-T[v] &=-(n-1)S+2\Bigg( \sum_{v\ne 1}\sum_{i\in\text{subtree(v)}}val[i] \Bigg) \end{aligned}

然后我们注意看括号里的和式,我们换一个角度考察。现在的和式是“对于每一个 vv,统计其子树内的 val[]val[] 的和”。我们把视角放回 vv 何时可以被某个 subtree(u)\text{subtree(u)} 统计到,我们发现当且仅当 uuvv 的祖先时可以,而 vv 有多少的祖先呢?一共 dep[v]\texttt{dep[\(v\)]} 个(这里令 dep[root]=0\texttt{dep[root]}=0,而且要注意到 subtree(u)\text{subtree}(u) 要求 urootu\ne\texttt{root}),于是这个和式可以进一步改写为

=(n1)S+2(v1(val[v]×dep[v])) =-(n-1)S+2\Bigg( \sum_{v\ne 1}(\texttt{val[\(v\)]}\times \texttt{dep[\(v\)]}) \Bigg)

刚好这个和式就等于 T[1]T[1],于是:

(uv)ET[u]T[v]=(n1)S+2T[1]S=2T[1](uv)ET[u]T[v]n1 \sum_{(u\to v)\in E}T[u]-T[v]=-(n-1)S+2T[1]\\ S=\frac{2T[1]-\sum_{(u\to v)\in E}T[u]-T[v]}{n-1}

所以我们也可以求出子树 vvval[]val[] 和了:

isubtree(v)val[i]=S+T[u]T[v]2 \sum_{i\in\text{subtree(v)}}val[i]=\frac{S+T[u]-T[v]}{2}

考虑节点 uu,那么其 val[u]val[u] 值就是其子树和减去所有儿子的子树和

val[u]=isubtree(u)val[i]vsonuisubtree(v)val[i] val[u]=\sum_{i\in\text{subtree(u)}}val[i]-\sum_{v\in son_u}\sum_{i\in\text{subtree(v)}}val[i]

root=1root=1 特殊考虑一下。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <iostream>
#include <vector>
using i64 = long long;
using vi = std::vector<i64>;
using vvi = std::vector<vi>;
using tree = std::vector<std::vector<int>>;

int main() {
int n;
std::cin >> n;

tree G(n + 1);
for (int i = 1, u, v; i < n; i++) {
std::cin >> u >> v;
G.at(u).push_back(v);
G.at(v).push_back(u);
}

vi T(n + 1, 0);
for (int i = 1; i <= n; i++) std::cin >> T.at(i);

vi val(n + 1, 0), subtree(n + 1, 0);
i64 delta = 0;
auto dfs = [&](auto &&self, int u, int fa) -> void {
for (auto v : G.at(u)) {
if (v == fa) continue;
delta += T.at(u) - T.at(v);
self(self, v, u);
}
};
dfs(dfs, 1, 1);
i64 sumval = (2 * T.at(1) - delta) / (n - 1);

auto dfs2 = [&](auto &&self, int u, int fa) -> void {
for (auto v : G.at(u)) {
if (v == fa) continue;
self(self, v, u);
subtree.at(u) += subtree.at(v);
}
if (u != 1) val.at(u) = (T.at(fa) - T.at(u) + sumval) / 2 - subtree.at(u);
else val.at(u) = sumval - subtree.at(u);
subtree.at(u) += val.at(u);
};
dfs2(dfs2, 1, 1);
for (int i = 1; i <= n; i++) std::cout << val.at(i) << " \n"[i == n];
}

F. 信条

首先,廻文串是回文串,我们先用马拉车 O(n)O(n) 计算出回文串,然后再在回文串的基础上判断廻文串。

除去长度为 11 的廻文串,因为廻文串满足 wwwwww'ww' 的 pattern,所以长度必须为偶数,因此,在 augmented string A[]A[] 里,廻文串的中心必须是 padding char (#)

考虑第 ii 个位置 A[i]=‘#’A[i]=\texttt{`\#'},以其为中心的最大回文串长度为 p[i]p[i],如下图所示,我们希望知道它是多少个廻文串的中心。

廻文串
廻文串

我们关注右侧的黄色部分,注意到其实 jj 也是一个回文中心,我们可以得出 jj 需要满足的条件:

  1. jj 为中心的最大回文串至少需要包住 ii
  2. 考虑以 A[ij]A[i\dots j] 为左半段的回文串,其右端点不能超过 ii 的回文半径(下图浅蓝色红色部分),即 i+1ji+pi2i+1\le j\le i+\frac{p_i}{2}

我们考虑对于 ii 怎么维护满足条件的 jj。一个想法就是,如果让某个数据结构满足 T[j]=1T[j]=1 表示对于 iijj 满足条件,那么我只需要求和 T[i+1]+T[i+pi2]T[i+1]+\dots T[i+\frac{p_i}{2}] 即可。那么该怎么表示

i+1ji+pi2jp[j]iT[j] \sum_{i+1\le j\le i+\frac{p_i}{2}}^{j-p[j]\le i} T[j]

我们考虑把 jj 挂载到 pos=jp[j]pos=j-p[j] 上,这样我们顺序处理的时候,如果经过 pospos,就表明 jj 的回文中心的字符串可以到达 pospos 这个位置,那么对于 iposi\ge posii 来说,jp[j]ij-p[j]\le i 就自动满足了。然后我们挂载的时候,把 T[j]T[j] 设置为 11,这样,对于位置 ii 他就可以通过查询 i+1ji+pi2T[j]\sum_{i+1\le j\le i+\frac{p_i}{2}} T[j] 找到对 ii 来说所有满足要求的 jj

j 的选择
j 的选择

再来看第二问,这是一个典型的贪心问题。对于廻文中心在 A[i]A[i] 的字符串,我们只需要考虑最长的廻文子串即可。这样一共有 O(n)O(n) 条线段。我们按左端点排序,维护一个“已经被覆盖的区间 [l,r][l,r]”,然后顺序遍历所有线条,我们下一条线段取“左端点在 [l,r][l,r] 里、右端点最大”的线段。时间复杂度 O(nlogn)O(n\log n).

取最长的廻文子串需要找到对于 ii 来说 [i+1,i+pi2][i+1,i+\frac{p_i}{2}] 中最远的那个 11. 这一点可以用线段树完成(如果为 11,maxpos 设置为 pospos,否则为 1-1,查询区间的时候查找区间内的 maxpos)时间复杂度也为 O(nlogn)O(n\log n)

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <algorithm>
#include <cassert>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using pii = std::pair<int, int>;

class SegTree {
std::vector<int> tree;
std::vector<int> maxpos;
int n;
void add(int p, int l, int r, int pos, int val) {
if (l == r) {
tree[p] += val;
maxpos[p] = pos;
return;
}
int mid = l + ((r - l) >> 1);
if (pos <= mid) add(p * 2, l, mid, pos, val);
else add(p * 2 + 1, mid + 1, r, pos, val);
tree[p] = tree[p * 2] + tree[p * 2 + 1];
maxpos[p] = std::max(maxpos[p * 2], maxpos[p * 2 + 1]);
}
int sum(int p, int l, int r, int ql, int qr) const {
if (ql > r || qr < l) return 0;
if (ql <= l && r <= qr) return tree[p];
int mid = l + ((r - l) >> 1);
return sum(p * 2, l, mid, ql, qr) + sum(p * 2 + 1, mid + 1, r, ql, qr);
}
int mr(int p, int l, int r, int ql, int qr) {
if (ql > r || qr < l) return -1;
if (ql <= l && r <= qr) return maxpos[p];
int mid = l + ((r - l) >> 1);
return std::max(mr(p * 2, l, mid, ql, qr), mr(p * 2 + 1, mid + 1, r, ql, qr));
}

public:
SegTree(int n_) : tree(5 * n_, 0), n{n_}, maxpos(5 * n_, -1) {}
void add(int pos, int val) { add(1, 0, n - 1, pos, val); }
int sum(int ql, int qr) const { return ql <= qr ? sum(1, 0, n - 1, ql, qr) : 0; }
int maxr(int pos, int l, int r) {
if (l >= r) return pos;
else return mr(1, 0, n - 1, l, r);
}
int operator[](int pos) const {
assert(pos >= 0 && pos < n);
return sum(pos, pos);
}
};

int main() {
std::string s;
std::cin >> s;

std::string aug = "#";
for (char c : s) {
aug += c;
aug += '#';
}

std::vector<int> maxlen;
std::vector<std::vector<int>> lbound(aug.size() + 1);
int l = 0, r = -1;

for (int i = 0, lr = aug.size(); i < lr; i++) {
int tmp = (i > r ? 0 : std::min(maxlen[l + r - i], r - i));
while (i - tmp >= 0 && i + tmp < lr && aug[i - tmp] == aug[i + tmp]) ++tmp;
tmp--;
maxlen.push_back(tmp);
if (i + tmp > r) {
l = i - tmp;
r = i + tmp;
}
if (i % 2 == 0) lbound.at(i - tmp).push_back(i);
}

long long ans = 0;
SegTree st(aug.size() + 1);
std::vector<pii> str; // L, R

for (int i = 0; i < aug.size(); i++) {
for (auto e : lbound.at(i)) {
assert(e % 2 == 0);
st.add(e, 1);
}

if (i % 2 == 0) {
int v = st.sum(i + 1, i + maxlen[i] / 2);
ans += v;
int len = st.maxr(i, i + 1, i + maxlen[i] / 2);
if (len > i) str.push_back({(i + 1 - (len - i) * 2) / 2, +(i - 1 + (len - i) * 2) / 2});
} else str.push_back({i / 2, i / 2});
}

std::sort(str.begin(), str.end(), [](const pii &a, const pii &b) {
return a.first < b.first || (a.first == b.first && a.second > b.second);
});
int cnt = 0, R = -1;
for (int i = 0; i < str.size();) {
int tr = -1, j = i;
while (j < str.size() && str.at(j).first <= R) {
tr = std::max(tr, str[j].second);
j++;
}
tr = std::max(tr, str[j].second);
if (j < str.size()) cnt++, R = tr;
i = j + 1;
}
std::cout << ans + s.size() << ' ' << cnt << '\n';
}