LOJ3730 [SNOI2022]数位

有趣的数数题,但压轴题确实码量较大。

题意:给定 $L, R$,问有多少个 $n$ 元组 $(a_1, a_2, \cdots, a_n)$ 满足 $a_i\in [L, R]$ 并且 $\sum a$ 10 进制表示从高位向低位数字不增。$L, R\leq 10 ^ {1000}$,$k\leq 50$,6s。

并不清楚为什么要给 6s,反正我代码最大点 200- ms,目前(2022-06-23)LOJ 最优解,大概是 rk2 速度的 4 倍。

以下令 $m = 10$,用于表示复杂度。

容易发现我们只有 $L$ 的限制是好做的,我们可以用插板法用一个组合数表示。

考虑容斥,计算钦定有 $k$ 个超出 $R$ 的限制的(似乎是二项式反演的弱化版),那么我们可以选择 $k$ 个超出限制,至少为 $R + 1$,这一部分需要乘上 $\binom nk$。假设最终的和为 $s$,那么可以得到方案数为:

注意到 $s$ 必须 $\geq k(R + 1) + (n - k)L$,同时需要满足题目给的数位单调不降的特点,显然使用数位 DP。下面考虑如何计算这个组合数。

容易发现这个组合数是一个关于 $s - k(R + 1) - (n - k)L$ 的不超过 $n - 1$ 次多项式,那么我们可以把这个式子拆分为不同幂次的和。维护不同幂次的答案显然可以使用一个结构体维护,然后统一转移。具体的,使用二项式定理展开,$a(i), b(j)$ 乘 $\binom{i + j}i$ 到 $c(i + j)$。

考虑数位 DP 的具体过程,注意到我们没有上界,于是我们可以强行规定一个上界,使它比所有的 $k(R + 1) + (n - k)L$ 都要大。为了简洁我们直接设定位数不超过 $|R| + 3$,容易发现 $k(R + 1) + (n - k)L$ 肯定无法达到上界。

我们可以使用一个简单的预处理 $f(i, j)$ 表示 $i$ 位,最高位是 $j$ 且满足条件的所有数的各次幂。容易发现转移可以做到 $O(|R|\times m\times n ^ 2)$。

然后计算 $\geq cur$ 的答案,首先统计位数更大的答案,注意 $f(i, j)$ 不是最终答案,我们需要减去一个 $cur$ 的贡献,相当于对每个数加 $P - sum$。然后逐位比较,假设最高的 $i$ 位已经确定与 $cur$ 相同,枚举当前位置比 $cur$ 当前位置大并且合法的位置,然后就可以用我们预处理的 $f$。注意需要减去 $cur$ 的低位,因为高位相同,无需管。

最后一个问题是怎样通过次幂计算组合数。好像有二项式反演做法,但是 $n$ 太小了,直接暴力高斯消元得到每一项系数即可。

总时间复杂度 $O(|R|mn ^ 2 + n ^ 3)$,轻松通过。具体可以看代码。可以优化到 $O(|R|mn\log n + n ^ 3)$,不过没意义罢。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
struct BigNum {
std::vector<int> a;
BigNum() {}
BigNum(int _v) { while (_v) a.push_back(_v % 10), _v /= 10; }
int len() { return a.size(); }
int& operator [](int x) { return a[x]; }
} L, R;
int n, C[N][N], mxlen, pw10[1010], a[N][N];

inline int& adj(int &x) { return x += x >> 31 & Mod; }

int qpow(int a, int k = Mod - 2)
{
int res = 1;
for (; k; k >>= 1, a = (LL) a * a % Mod)
if (k & 1) res = (LL) res * a % Mod;
return res;
}

struct Node {
std::vector<int> a;
Node() {}
Node(int _v) : a(n + 1, 1) {
for (int i = 1; i <= n; ++ i) a[i] = a[i - 1] * (LL) _v % Mod;
}
int& operator [](int x) { return a[x]; }
Node operator *(Node t) const {
Node res;
res.a.resize(n + 1);
for (int i = 0; i <= n; ++ i)
for (int j = 0; j <= n - i; ++ j)
res[i + j] = (res[i + j] + (LL) a[i] * t[j] % Mod * C[i + j][i]) % Mod;
return res;
}
Node& operator +=(Node t) {
for (int i = 0; i <= n; ++ i) adj(a[i] += t[i] - Mod);
return *this;
}
Node operator +(Node t) const { return t += *this; }
} dp[1010][10];

void Gauss()
{
for (int i = 1; i <= n; ++ i)
{
int t = -1;
for (int j = i; j <= n; ++ j)
if (a[j][i]) {
t = j;
break;
}
assert(~t);
if (t ^ i) std::swap(a[t], a[i]);
int Inv = qpow(a[i][i]);
for (int j = i; j <= n + 1; ++ j) a[i][j] = (LL) a[i][j] * Inv % Mod;
for (int j = 1; j <= n; ++ j)
if (j != i && a[j][i])
for (int k = n + 1; k >= i; -- k)
a[j][k] = (a[j][k] + (LL) (Mod - a[j][i]) * a[i][k]) % Mod;
}
}

BigNum operator +(BigNum a, BigNum b)
{
BigNum res;
int len = std::max(a.len(), b.len()), ls = 0;
a.a.resize(len), b.a.resize(len);
for (int i = 0; i < len; ++ i)
res.a.push_back((a[i] + b[i] + ls) % 10), ls = (a[i] + b[i] + ls) >= 10;
if (ls) res.a.push_back(ls);
return res;
}

BigNum operator *(BigNum a, int k)
{
if (!k) return {};
BigNum res;
LL ls = 0;
for (int x : a.a)
ls = (LL) x * k + ls, res.a.push_back(ls % 10), ls /= 10;
while (ls) res.a.push_back(ls % 10), ls /= 10;
return res;
}

std::istream& operator >>(std::istream &fin, BigNum &res)
{
std::string buf;
fin >> buf;
std::reverse(buf.begin(), buf.end());
int len = buf.length();
for (int i = 0; i < len; ++ i) res.a.push_back(buf[i] ^ 48);
return fin;
}

std::ostream& operator <<(std::ostream &fout, BigNum res)
{
for (int i = res.len() - 1; ~i; -- i) fout << char(res[i] % 10 | 48);
return fout;
}

int solve(BigNum le)
{
Node res;
res.a.resize(n + 1);
int ls = 9, sum = 0;
for (int i = le.len() - 1; ~i; -- i) sum = (sum * 10LL + le[i]) % Mod;
for (int i = le.len() + 1; i <= mxlen; ++ i)
for (int j = 1; j <= 9; ++ j) res += dp[i][j] * Node(Mod - sum);
for (int i = le.len() - 1; ~i; -- i)
{
for (int j = le[i] + 1; j <= ls; ++ j) res += dp[i + 1][j] * Node(Mod - sum);
if (le[i] > ls) break;
sum = (sum + (LL) (Mod - pw10[i]) * le[i]) % Mod, ls = le[i];
if (i == 0) res += Node(0);
}
int ret = 0;
for (int i = 0; i < n; ++ i) ret = (ret + (LL) res[i] * a[i + 1][n + 1]) % Mod;
// std::cout << ret << ' ' << le << '\n';
return ret;
}

int main()
{
// freopen("digit.in", "r", stdin);
// freopen("digit.out", "w", stdout);
std::cin >> L >> R >> n, mxlen = R.len() + 3;
for (int i = C[0][0] = 1; i < N; ++ i)
for (int j = C[i][0] = 1; j <= i; ++ j)
adj(C[i][j] = C[i - 1][j - 1] + C[i - 1][j] - Mod);
for (int i = pw10[0] = 1; i <= mxlen; ++ i) pw10[i] = pw10[i - 1] * 10LL % Mod;
for (int i = 2; i <= mxlen; ++ i)
for (int j = 0; j <= 9; ++ j) dp[i][j].a.resize(n + 1);
for (int j = 0; j <= 9; ++ j) dp[1][j] = Node(j);
for (int l = 1; l < mxlen; ++ l)
for (int i = 0; i <= 9; ++ i)
for (int j = i; j <= 9; ++ j)
dp[l + 1][j] += dp[l][i] * Node(pw10[l] * (LL) j % Mod);
for (int i = 1; i <= n; ++ i)
for (int j = 1; j <= n; ++ j) a[i][j] = qpow(i, j - 1);
for (int i = 1; i <= n; ++ i) a[i][n + 1] = C[i + n - 1][n - 1];
Gauss();

// mxlen = 2, solve(10), exit(0);
int res = 0;
for (int i = 0, op = 1; i <= n; ++ i, op = Mod - op)
res = (res + solve(L * (n - i) + (R + 1) * i) * (LL) C[n][i] % Mod * op) % Mod;
std::cout << res << std::endl;
return 0;
}