1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
| void work() { read(n); all.clear(); for (int i = 1; i <= n; ++ i) read(a[i]); for (int i = 1; i <= n; ++ i) vis[i] = false; for (int i = 1; i <= n; ++ i) cnt[i] = 0; for (int i = 1; i <= n; ++ i) mx1[i] = mx2[i] = mx3[i] = 0; for (int i = 1; i <= n; ++ i) { if (vis[i]) continue; int j = i, cir = 0; while (!vis[j]) cir ++, vis[j] = true, j = a[j]; if (!cnt[cir] ++) all.push_back(cir); } for (int c : all) { int cur = c, t = 0, p; while (cur ^ 1) { p = fac[cur].front(), t = 0; while (cur % p == 0) t ++, cur /= p; for (int cs = 1; cs <= cnt[c]; ++ cs) if (t > mx1[p]) mx3[p] = mx2[p], mx2[p] = mx1[p], mx1[p] = t; else if (t > mx2[p]) mx3[p] = mx2[p], mx2[p] = t; else chkmax(mx3[p], t); } } int lcm = 1, res = 0; for (int i = 1; i < N; ++ i) if (mx1[i]) lcm = (LL) lcm * pw[i][mx1[i]] % Mod; for (int c1 : all) for (int c2 : all) { if (c1 == c2 && cnt[c1] == 1) continue; int cur = lcm; auto del = [&](int x) { int p, t; while (x ^ 1) { p = fac[x].front(), t = 0; if (!havbac[p]) havbac[p] = true, bac1[p] = mx1[p], bac2[p] = mx2[p], bac3[p] = mx3[p]; while (x % p == 0) t ++, x /= p; if (t == mx1[p]) cur = (LL) cur * inv[p][mx1[p] - mx2[p]] % Mod, mx1[p] = mx2[p], mx2[p] = mx3[p]; else if (t == mx2[p]) mx2[p] = mx3[p]; } }; auto ins = [&](int x) { int p, t; while (x ^ 1) { p = fac[x].front(), t = 0; if (!havbac[p]) havbac[p] = true, bac1[p] = mx1[p], bac2[p] = mx2[p], bac3[p] = mx3[p]; while (x % p == 0) t ++, x /= p; if (t > mx1[p]) cur = (LL) cur * pw[p][t - mx1[p]] % Mod, mx3[p] = mx2[p], mx2[p] = mx1[p], mx1[p] = t; else if (t > mx2[p]) mx3[p] = mx2[p], mx2[p] = t; else chkmax(mx3[p], t); } }; auto bac = [&](int x) { int p, t; while (x ^ 1) { p = fac[x].front(), t = 0; while (x % p == 0) t ++, x /= p; if (havbac[p]) havbac[p] = false, mx1[p] = bac1[p], mx2[p] = bac2[p], mx3[p] = bac3[p]; } }; del(c1), del(c2), ins(c1 + c2); bac(c1), bac(c2), bac(c1 + c2); if (c1 == c2) res = (res + (LL) cur * cnt[c1] % Mod * (cnt[c1] - 1) % Mod * c1 % Mod * c1) % Mod; else res = (res + (LL) cur * cnt[c1] % Mod * cnt[c2] % Mod * c1 % Mod * c2) % Mod; } printf("%d\n", res); }
|