1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| int det(int n) { int x = 0, res = 1; for (int i = 1; i <= n; ++ i) { int t = -1; for (int j = i; j <= n; ++ j) if (a[j][i]) { t = j; break; } if (!~t) return 0; if (i != t) std::swap(a[i], a[t]), x ^= 1; LL Inv = qpow(a[i][i]); res = (LL) res * a[i][i] % Mod; for (int j = i; j <= n; ++ j) a[i][j] = (LL) a[i][j] * Inv % Mod; for (int j = i + 1; j <= n; ++ j) if (a[j][i]) for (int k = n; k >= i; -- k) a[j][k] = (a[j][k] + (LL) (Mod - a[i][k]) * a[j][i]) % Mod; } return x ? adj(res = -res) : res; }
// Enumerates every subset of abl[id..ed] whose running total never exceeds
// the global `lim`, and appends each subset's total to v[k], where k is the
// number of elements taken.
//
//   id  - index of the element currently being decided
//   ed  - last index of the range (inclusive)
//   cur - total of the elements taken so far
//   ct  - how many elements have been taken so far
//   v   - array of buckets indexed by subset size; the caller must provide
//         at least as many buckets as elements that can be taken, plus one
void dfs(int id, int ed, int cur, int ct, std::vector<int> *v) {
    if (id > ed) {
        // Every element has been decided: record this subset's total.
        v[ct].push_back(cur);
        return;
    }

    // Branch 1: leave abl[id] out.
    dfs(id + 1, ed, cur, ct, v);

    // Branch 2: take abl[id], but only while the total stays within lim.
    const auto taken = cur + abl[id];
    if (taken <= lim) {
        dfs(id + 1, ed, taken, ct + 1, v);
    }
}
int work() { for (int i = C[0][0] = 1; i < N; ++ i) for (int j = C[i][0] = 1; j <= i; ++ j) adj(C[i][j] = C[i - 1][j - 1] + C[i - 1][j] - Mod); std::vector<int> vl[N]{}, vr[N]{}; int ok[N]{}, ans[N]{}; cnt = 0; for (int i = 1; i <= n; ++ i) if (v[i] >= 0) abl[++ cnt] = v[i]; int lcnt = cnt / 2, rcnt = cnt - cnt / 2; dfs(1, cnt / 2, 0, 0, vl), dfs(cnt / 2 + 1, cnt, 0, 0, vr); for (int i = 1; i <= lcnt; ++ i) std::sort(vl[i].begin(), vl[i].end()); for (int i = 1; i <= rcnt; ++ i) std::sort(vr[i].begin(), vr[i].end()); for (int i = 0; i <= lcnt; ++ i) for (int j = 0; j <= rcnt; ++ j) { int r = vr[j].size() - 1; for (int &x : vl[i]) { while (r >= 0 && vr[j][r] + x > lim) r --; if (r < 0) break; adj(ok[i + j] += r + 1 - Mod); } } int res = 0; for (int ncnt = 0; ncnt <= cnt; ++ ncnt) { for (int i = 1; i <= n; ++ i) for (int j = 1; j <= n; ++ j) a[i][j] = i != j; for (int i = 1; i <= cnt - ncnt; ++ i) for (int j = 1; j <= cnt - ncnt; ++ j) a[i][j] = 0; for (int i = 1; i <= cnt - ncnt; ++ i) for (int j = 1; j <= cnt; ++ j) a[i][j] = a[j][i] = 0; for (int i = 1; i <= n; ++ i) { int t = 0; for (int j = 1; j <= n; ++ j) t += a[i][j]; for (int j = 1; j <= n; ++ j) if (a[i][j]) a[i][j] = Mod - a[i][j]; a[i][i] = t; }
ans[ncnt] = det(n - 1); for (int ex = 0; ex < ncnt; ++ ex) ans[ncnt] = (ans[ncnt] + (LL) (Mod - ans[ex]) * C[ncnt][ex]) % Mod; res = (res + (LL) ans[ncnt] * ok[ncnt]) % Mod; } return res; }
// Entry-point wrapper: loads the arguments into the file-level globals that
// work() reads, then returns work()'s result.
class SweetFruits {
public:
    // val  - input values; copied into the 1-based global array v[1..n]
    // _lim - limit, stored in the global `lim`
    signed countTrees(std::vector<signed> val, signed _lim) {
        n = val.size();
        lim = _lim;

        int slot = 0;
        for (const signed value : val) {
            v[++slot] = value;
        }

        return work();
    }
};
|