ABC273Ex Inv(0,1)ving Insert(1,0)n

题意:定义对序列 $A$($A$ 包含整数二元组)操作一次为如下操作:选定任意两个相邻的二元组 $(a, b)$ 和 $(c, d)$,在他们中间插入 $(a + c, b + d)$。定义一个二元组序列的价值为从 $A = \{(0, 1), (1, 0)\}$ 开始至少要操作多少次才能包含序列中的所有二元组。如果无法的话价值就是 0。现给定一个长度为 $n$ 的二元组序列 $T$,问所有连续子序列的价值和,对 998244353 取模。$n\leq 10 ^ 5$,$a, b\leq 10 ^ 9$。

首先注意到这个的操作方式和 Stern-Brocot 树 的构造方式是一样的,那么直接在这棵树上做似乎是一个不错的选择。

那么根据该树的性质,我们容易得到一个二元组有两种情况是无法得到的:

  1. $\gcd(a, b) \neq 1$
  2. $a = 0\lor b = 0$

那么我们相当于现在是划分为一段一段的分别做,每一段内都是合法的状态,然后不同段之间的显然贡献都为 0。现在就处理掉了没有贡献的区间。

现在我们考虑按照该树的构造办法,我们于是可以这么拆贡献:对于该树上的每一个节点 $[\dfrac ab, \dfrac cd]$(可能 $d$ 为 0,不太严谨,就是 $+\infty$ 的意思),我们考虑要生成 $\dfrac {a + c}{b + d}$ 需要被多少个子串所需要。

一个充要条件是如果一个子串中存在一个 $\dfrac pq$ 满足 $\dfrac ab < \dfrac pq < \dfrac cd$,那么就需要 $\dfrac {a + c}{b + d}$。那么,我们需要统计所有满足分数在 $(\dfrac ab, \dfrac cd)$ 之间的位置。这样的话就是容易容斥计算的。

这样直接做复杂度是不对的,应为单次统计至少需要 $O(len)$,$len$ 为在这个区间之间的分数个数。首先我们直接考虑分治下去,将 $[\dfrac ab, \dfrac {a + c}{b + d}]$ 和 $[\dfrac {a + c}{b + d}, \dfrac cd]$ 的答案分别算出来,然后可以启发式合并一下,将少的合并到多的上,用 std::set 维护并动态统计答案,假设每次都有分支的话,复杂度就是 $O(n\log ^ 2 n)$ 的。

然后我们再来处理假设区间内部的所有数都在 $\dfrac {a + c}{b + d}$ 的一边怎么办。一个极端的情况就是 $(10 ^ 9, 1)$,我们不得不递归 $10 ^ 9$ 层才能找到他。这样显然是不好的,于是我们可以考虑二分一个 $k$ 满足不是所有数都在 $\dfrac {a + kc}{b + kd}$(或者是 $\dfrac {ka + c}{kb + d}$,看在哪一边)的一边。注意到我们相当于在树上是一次跳了 $k$ 层,于是最后的答案要 $\times k$。

于是总复杂度就是 $O(n\log ^ 2 n + n\log a)$,可以通过。实际实现的时候二分的 $k$ 差一两个是没有问题的,效率如何没测试过(因为我的 $k$ 好像就少 1)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
struct Frac {
LL x, y;
int id;
Frac operator +(Frac t) const { return {x + t.x, y + t.y}; }
bool operator <(Frac t) const { return (s128) x * t.y < (s128) y * t.x; }
bool operator >(Frac t) const { return (s128) x * t.y > (s128) y * t.x; }
bool operator ==(Frac t) const { return x == t.x && y == t.y; }
Frac operator *(LL t) const { return {x * t, y * t}; }
} a[N];

void insert(int id, int x)
{
auto iter = s[id].insert(x).first;
int y = *std::prev(iter), z = *std::next(iter);
ans[id] = (ans[id] + (LL) (z - x) * (x - y)) % Mod;
}

int solve(int l, int r, Frac lf, Frac rf)
{
while (l <= n && a[l] == lf) l ++;
while (r && a[r] == rf) r --;
if (l > r) return 1;
int cnt = 1;
if (a[r] < lf + rf) {
int x = 2, y = 1e9;
while (x < y) {
int mid = (x + y + 1) >> 1;
if (lf * mid + rf > a[r]) x = mid;
else y = mid - 1;
}
rf = rf + lf * (x - 1), cnt = x;
} else if (a[l] > lf + rf) {
int x = 2, y = 1e9;
while (x < y) {
int mid = (x + y + 1) >> 1;
if (lf + rf * mid < a[l]) x = mid;
else y = mid - 1;
}
lf = lf + rf * (x - 1), cnt = x;
}
int x = l, y = r + 1, mL, mR;
while (x < y) {
int mid = (x + y) >> 1;
if (a[mid] < lf + rf) x = mid + 1;
else y = mid;
}
mL = x, x = l - 1, y = r;
while (x < y) {
int mid = (x + y + 1) >> 1;
if (a[mid] > lf + rf) y = mid - 1;
else x = mid;
}
mR = x;
// Equal (a + c) / (b + d) range
int lc = solve(l, mL - 1, lf, lf + rf), rc = solve(mR + 1, r, lf + rf, rf);
if (s[lc].size() > s[rc].size()) std::swap(lc, rc);
if (rc == 1) s[rc = ++ tot] = {0, n + 1};
for (int t : s[lc])
if (t >= 1 && t <= n) insert(rc, t);
for (int t = mL; t <= mR; ++ t) insert(rc, a[t].id);
res = (res + (LL) ans[rc] * cnt) % Mod;
return rc;
}

int work(std::vector<Frac> vec)
{
res = 0;
n = vec.size();
for (int i = 0; i < n; ++ i) a[i + 1] = vec[i], a[i + 1].id = i + 1;
std::sort(a + 1, a + n + 1);
s[1] = {0, n + 1};
return solve(1, n, {0, 1}, {1, 0}), res;
}