BZOJ2219 数论之神

题意:求 $x ^ a \equiv b\pmod p$ 的解数,$a, b, p$ 给定。$a, b, p\leq 10 ^ 9 + 1$,$p$ 是奇数,$T(T\leq 1000)$ 组数据。

容易发现单纯的 $p$ 不好做,于是显然根据中国剩余定理拆成 $p ^ k(p\in P)$ 做,然后直接把各部分的答案乘起来即可。

下面直接讨论 $x ^ a\equiv b\pmod {p ^ k}$ 的解数。

Case 1:$b = 0$

容易发现只要 $x$ 在 $p$ 的次幂至少是 $\left\lceil\dfrac ka \right\rceil$ 即可。于是此时的贡献就是 $p ^ {k - \lceil\frac ka\rceil}$。

Case 2:$b\bmod p = 0\land b\not= 0$

此时假设 $b = c \times p ^ t$,那么此时 $x$ 在 $p$ 的次幂一定是 $\dfrac ta$(注意一定整除,否则无解),然后可以对两边同时除以 $p ^ t$,那么方程就变为了 $x ^ a\equiv c\pmod {p ^ {k - t}}$。注意到 $\bmod p ^ {k - t}$ 映射到 $\bmod p ^ k$ 时,由于需要乘上 $p ^ {\frac ta}$,所以剩下的 $k - (k - t) - \dfrac ta = t - \dfrac ta$ 次幂是完全定义域扩大的,所以需要乘上 $p ^ {t - \frac ta}$。剩下的是第三种情况,就可以解决这个问题了。

Case 3:$b\bmod p \not= 0$

这个可以使用离散对数(知道奇数的作用了吗?),假设 $b = g ^ t$,那么就是 $x ^ a \equiv g ^ t \pmod {p ^ k}$,容易发现如果 $t\bmod \gcd(\varphi(p ^ k), a)\not= 0$,说明不可能凑出 $t$,这样的话就返回 0。

然后考虑有解的情况有多少个。假设 $x = g ^ y$,那么就是 $ay\equiv t\pmod {\varphi(p ^ k)}$。容易发现此时当找到一个解时,$+\dfrac{\text{lcm}(\varphi(p ^ k), a)}{a}$ 也是一个解,而且是最小周期。除一下,一共就有 $\gcd(\varphi(p ^ k), a)$ 个解。

把所有的乘起来即可。时间复杂度 $O(\sqrt p\log p)$ 或 $O(\sqrt p)$,瓶颈在求离散对数的 BSGS 上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
int BSGS(int a, int b, int Mod)
{
if (b == 1 || Mod == 1) return 0;
std::unordered_map<int, int> H;
int cur = b, K = std::sqrt(Mod) + 1;
for (int i = 0; i < K; ++ i, cur = (LL) cur * a % Mod) H[cur] = i;
int ak = cur = qpow(a, K, Mod);
for (int i = 1; i <= K; ++ i, cur = (LL) cur * ak % Mod)
if (H.count(cur)) return i * K - H[cur];
return -1;
}

int findrt(int p, int k)
{
int pk = qpow(p, k);
std::vector<int> fac;
int n = pk - 1;
for (int i = 2; i <= n / i; ++ i)
if (n % i == 0) {
fac.push_back(i);
while (n % i == 0) n /= i;
}
if (n ^ 1) fac.push_back(n);
auto check = [&](int g) {
for (int x : fac)
if (qpow(g, pk / p * (p - 1) / x, pk) == 1) return false;
return true;
};
for (int i = 2; i < pk; ++ i)
if (check(i)) return i;
return -1;
}

int solve(int a, int b, int p, int k)
{
int pk = qpow(p, k);
b %= pk;
// std::cout << a << ' ' << b << ' ' << p << ' ' << k << '\n';
if (b == 0) return qpow(p, k - (k + a - 1) / a);
if (b % p == 0) {
int t = 0;
while (b % p == 0) b /= p, ++ t;
if (t % a) return 0;
return solve(a, b, p, k - t) * qpow(p, t - t / a);
}
int g = findrt(p, k), _b = BSGS(g, b, pk);
// std::cout << pk << ' ' << g << ' ' << b << ' ' << _b << ' ' << Gcd(a, pk / p * (p - 1)) << std::endl;
if (_b % Gcd(a, pk / p * (p - 1))) return 0;
return Gcd(a, pk / p * (p - 1));
}

void work()
{
int a, b, p, res = 1;
scanf("%d %d %d", &a, &b, &p), p = 2 * p + 1;
for (int i = 2; i <= p / i; ++ i)
if (p % i == 0) {
int t = 0;
while (p % i == 0) ++ t, p /= i;
// std::cout << "Solve " << a << ' ' << b << ' ' << i << ' ' << t << '\n';
res *= solve(a, b, i, t);
}
if (p ^ 1) res *= solve(a, b, p, 1);
printf("%d\n", res);
}