「ZJOI2020」抽卡
「ZJOI2020」抽卡

2020 年 06 月 28 日

$m$ 张带编号卡牌,每次你可以随机抽取一张。抽中每张的概率均为 $\frac 1 m$。当编号连续的 $k$ 张牌都被抽取过时,游戏结束。

问游戏结束的期望步数。

$1 \leq k \leq m \leq 2 \times 10^5$

题解

Part 1

我们可以直接对每张牌第一次被抽中的操作序列计数。

把牌的每一段编号连续区间分开考虑,每一段处理出选中连续区间长度不超过 $k$ 的方案数(同时容易得到超过的方案数),然后分治 + NTT 合并,这是平凡的。

这个做法的时间复杂度是 $O(n^2 + n \log^2 n)$ 的,瓶颈在于前半部分,即处理出把 $n$ 分成 $m = 1 \dots n$ 段、满足每一段长度都不超过 $k$ 的方案数,更进一步的可以表示为:

$$B(u) = [x^{n+1}] \frac{1}{1 - u \frac{x - x^{k+1}}{1 - x}}$$

我们需要求出多项式 $B$。

Part2

注意到这是一个拓展拉格朗日反演的形式,我们需要求出 $\displaystyle{F(x) = \frac{x - x^{k+1}}{1 - x}}$ 的复合逆。相当于我们要求 $G(x)$ 满足 $F(G(x)) = x$,根据多项式牛顿迭代,有

$$
\begin{aligned}
T(G(x)) &= F(G(x)) - x = \frac{G(x) - G^{k+1}(x)}{1 - G(x)} - x \\
T'(G(x)) &= \frac{(1 - (k+1)G^k(x))(1 - G(x)) + (G(x) - G^{k+1}(x))}{(1 - G(x))^2} \\
&= \frac{1 - (k+1)G^k(x) + kG^{k+1}(x)}{1 - 2G(x) + G^2(x)}
\end{aligned}
$$

由多项式牛顿迭代,我们可以倍增得到 $G(x)$。

Part3

代入拓展拉格朗日反演的式子,令 $\displaystyle{H(x) = \frac{1}{1 - ux}}$,我们可以得到

$$S = [x^{n+1}] H(F(x)) = \frac{1}{n+1} [x^n] H'(x) \left(\frac{x}{G(x)}\right)^{n+1}$$

令 $\displaystyle{T(x) = \frac{1}{n+1} \left(\frac{x}{G(x)}\right)^{n+1}}$,则有

$$
\begin{aligned}
S &= [x^n] H'(x) T(x) = [x^n] T(x) \frac{u}{(1 - ux)^2} \\
&= [x^n] T(x) \, u \sum_{i=0}^{\infty} (i+1) (ux)^i
\end{aligned}
$$

即可直接得到 $S(u)$。

问题解决,总时间复杂度 $O(n \log^2 n)$。

代码

#include <bits/stdc++.h>
using namespace std;
// N: capacity of the fixed-size arrays; mod: the prime modulus used by all z arithmetic
// (init() uses 3 as its primitive root when building NTT twiddles).
const int N = 2e5 + 10, mod = 998244353;
// n, k: read from input in main; m: number of runs of consecutive card ids found in main;
// a[1..n]: card ids; b[1..m]: run lengths; rev: bit-reversal table refilled by every init() call.
int n, m, k, a[N], b[N], rev[N << 2];
// Residue modulo `mod`. The int constructor stores its argument unreduced, and
// operator+ / operator- apply a single correction step, so they expect both
// operands to already lie in [0, mod).
struct z {
  int x;
  z() : x(0) {}
  z(int x) : x(x) {}
  friend z operator*(z lhs, z rhs) {
    const long long product = static_cast<long long>(lhs.x) * rhs.x;
    return z(static_cast<int>(product % mod));
  }
  friend z operator-(z lhs, z rhs) {
    int diff = lhs.x - rhs.x;
    if (diff < 0) diff += mod;
    return z(diff);
  }
  friend z operator+(z lhs, z rhs) {
    int sum = lhs.x + rhs.x;
    if (sum >= mod) sum -= mod;
    return z(sum);
  }
} ans, len[N], fac[N], ifac[N], w[N << 2], S[N];
// Polynomial over z; index i holds the coefficient of x^i.
using vec = vector<z>;
// Binomial coefficient n-choose-m from the fac/ifac tables; 0 when n < m.
inline z C(int n, int m) {
  if (n < m) return 0;
  return fac[n] * ifac[m] * ifac[n - m];
}
// Modular exponentiation by repeated squaring: returns a^b for b >= 0.
inline z fpow(z a, int b) {
  z result = 1;
  while (b) {
    if (b & 1) result = result * a;
    a = a * a;
    b >>= 1;
  }
  return result;
}
// Value-returning resize: yields a copy of `a` truncated or zero-padded to n coefficients.
inline vec resize(vec a, int n) {
  const size_t target = n;
  if (a.size() != target) a.resize(target);
  return a;
}
// Fills fac[0..n] with factorials and ifac[0..n] with inverse factorials modulo mod.
void initfac(int n) {
  fac[0] = ifac[0] = 1;
  fac[1] = ifac[1] = 1;
  // First pass: factorials, plus the plain modular inverse of each i stored in ifac[i]
  // (the recurrence reads ifac[mod % i], which must still be a plain inverse here).
  for (int i = 2; i <= n; ++i) {
    fac[i] = fac[i - 1] * i;
    ifac[i] = (mod - mod / i) * ifac[mod % i];
  }
  // Second pass: prefix products turn the inverses into inverse factorials.
  for (int i = 2; i <= n; ++i) ifac[i] = ifac[i] * ifac[i - 1];
}
int init(int n) {
  int lim = 1, k = 0;
  while (lim < n) lim <<= 1, ++k;
  for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
  static int len = 1;
  for (; len < lim; len <<= 1) {
    z wn = fpow(3, (mod - 1) / (len << 1));
    w[len] = 1;
    for (int i = 1; i < len; i++) w[i + len] = w[i + len - 1] * wn;
  }
  return lim;
}
// In-place forward NTT of size lim. `a` is resized to lim first; rev[] and w[] must
// have been prepared by init() for this lim.
void dft(vec &a, int lim) {
  a.resize(lim);
  for (int i = 0; i < lim; ++i) {
    const int j = rev[i];
    if (i < j) swap(a[i], a[j]);
  }
  for (int half = 1; half < lim; half *= 2) {
    const int span = half * 2;
    for (int base = 0; base < lim; base += span) {
      for (int t = 0; t < half; ++t) {
        z &lo = a[base + t];
        z &hi = a[base + t + half];
        const z u = lo;
        const z v = hi * w[half + t];
        lo = u + v;
        hi = u - v;
      }
    }
  }
}
// In-place inverse NTT of size lim: forward transform, reverse positions 1..lim-1,
// then scale every entry by lim^{-1}.
void idft(vec &a, int lim) {
  dft(a, lim);
  // Iterator arithmetic instead of &a[lim]: after dft() a.size() == lim, so indexing
  // element lim through operator[] would be an out-of-range access.
  reverse(a.begin() + 1, a.begin() + lim);
  const z inv_lim = fpow(lim, mod - 2);
  for (int i = 0; i < lim; i++) a[i] = a[i] * inv_lim;
}
// Product of polynomials a and b, truncated or zero-padded to l coefficients.
// Small operands use schoolbook multiplication; otherwise an NTT convolution.
inline vec mul(vec a, vec b, int l) {
  // An empty operand is the zero polynomial. Handling it here avoids the unsigned
  // wrap-around of a.size() + b.size() - 1 when both operands are empty.
  if (a.empty() || b.empty()) return vec(max(l, 0));
  if (a.size() < 10 || b.size() < 10) {
    vec c(a.size() + b.size() - 1);
    for (size_t i = 0; i < a.size(); i++)
      for (size_t j = 0; j < b.size(); j++) c[i + j] = c[i + j] + a[i] * b[j];
    c.resize(l);
    return c;
  }
  const int len = a.size() + b.size() - 1, lim = init(len);
  dft(a, lim), dft(b, lim);
  for (int i = 0; i < lim; i++) a[i] = a[i] * b[i];
  idft(a, lim);
  a.resize(l);
  return a;
}
// Full (untruncated) polynomial product: size is a.size() + b.size() - 1.
inline vec operator*(const vec &a, const vec &b) {
  const int out_len = a.size() + b.size() - 1;
  return mul(a, b, out_len);
}
// Coefficient-wise sum; the result has the length of the longer operand.
inline vec operator+(vec a, const vec &b) {
  if (a.size() < b.size()) a.resize(b.size());
  for (size_t i = 0; i != b.size(); ++i) a[i] = a[i] + b[i];
  return a;
}
// Coefficient-wise difference a - b; the result has the length of the longer operand.
inline vec operator-(vec a, const vec &b) {
  if (a.size() < b.size()) a.resize(b.size());
  for (size_t i = 0; i != b.size(); ++i) a[i] = a[i] - b[i];
  return a;
}
// Multiplicative inverse of the power series f modulo x^len, by Newton doubling
// (b <- b * (2 - f * b)). len == -1 means "use f.size()". Requires f[0] != 0 and
// len <= f.size(). Returns an empty vector when there is nothing to invert.
vec inv(const vec &f, int len = -1) {
  if (len == -1) len = f.size();
  // Without this guard an empty input would recurse forever (len stays 0).
  if (len <= 0) return vec();
  if (len == 1) return {fpow(f[0], mod - 2)};
  // Iterators instead of &f[len]: len may equal f.size(), where operator[] is out of range.
  vec a(f.begin(), f.begin() + len), b = inv(f, (len + 1) >> 1);
  const int lim = init((len << 1) - 1);
  dft(a, lim), dft(b, lim);
  for (int i = 0; i < lim; i++) a[i] = b[i] * (2 - a[i] * b[i]);
  idft(a, lim);
  a.resize(len);
  return a;
}
// Formal derivative, keeping the same length (the last coefficient becomes 0).
// f must be non-empty.
vec deri(vec f) {
  const int n = f.size();
  for (int i = 1; i < n; ++i) f[i - 1] = f[i] * i;
  f.back() = 0;
  return f;
}
// Formal antiderivative with zero constant term, keeping the same length
// (the original top coefficient is dropped). f must be non-empty.
vec inte(vec f) {
  int i = static_cast<int>(f.size());
  while (--i >= 1) {
    const z inv_i = fpow(i, mod - 2);
    f[i] = f[i - 1] * inv_i;
  }
  f.front() = 0;
  return f;
}
// Power-series logarithm modulo x^{f.size()}, computed as the integral of f' / f.
vec ln(const vec &f) {
  const int n = f.size();
  const vec reciprocal = inv(f);
  const vec derivative = deri(f);
  return inte(mul(reciprocal, derivative, n));
}
// Power-series exponential modulo x^len by Newton doubling
// (b <- b * (1 + f - ln b)). len == -1 means "use f.size()". The base case returns
// {1}, so f[0] is expected to be 0; requires len <= f.size().
vec exp(const vec &f, int len = -1) {
  if (len == -1) len = f.size();
  // Without this guard an empty input would recurse forever (len stays 0).
  if (len <= 0) return vec();
  if (len == 1) return {1};
  // Iterators instead of &f[len]: len may equal f.size(), where operator[] is out of range.
  vec a(f.begin(), f.begin() + len), b = exp(f, (len + 1) >> 1);
  b.resize(len);
  return mul(b, a + vec{1} - ln(b), len);
}
// Power-series power: returns a^b modulo x^{a.size()} for b >= 0, via exp(b * ln(.)).
// If the lowest non-zero term of a is lead * x^c, the result is
// x^{c*b} * lead^b * (a / (lead * x^c))^b. An all-zero input is returned unchanged.
vec fpow(vec a, int b) {
  const int n = a.size();
  for (int c = 0; c < n; c++) {
    if (!a[c].x) continue;
    // 64-bit product: c * b can exceed the int range for large shifts and exponents.
    const long long shift = 1LL * c * b;
    if (shift >= n) return vec(n);
    const int l = n - static_cast<int>(shift);
    // ln() only sees the series up to a constant factor, so normalise the leading
    // coefficient to 1 here and multiply lead^b back in afterwards.
    const z lead = a[c], lead_inv = fpow(lead, mod - 2);
    vec base(l);
    // The i + c < n bound matters only when b == 0 (then l == n and c may be > 0).
    for (int i = 0; i < l && i + c < n; i++) base[i] = a[i + c] * lead_inv;
    base = ln(base);
    for (int i = 0; i < l; i++) base[i] = base[i] * b;
    base = exp(base);
    const z scale = fpow(lead, b);
    for (int i = 0; i < l; i++) base[i] = base[i] * scale;
    vec s(static_cast<size_t>(shift));
    s.insert(s.end(), base.begin(), base.end());
    return s;
  }
  return a;
}
// Evaluates F(g) = g + g^2 + ... + g^k, truncated to g.size() coefficients, by
// repeated multiplication. Returns an empty vector when k < 1.
vec complex(const vec &g) {
  const int n = g.size();
  vec power{1}, total;
  for (int e = 0; e < k; ++e) {
    power = mul(power, g, n);
    total = total + power;
  }
  return total;
}
// Compositional inverse of F(x) = (x - x^{k+1}) / (1 - x), modulo x^len: the series G
// with F(G(x)) = x. Each level doubles the precision with one Newton step
//   G <- G - (G - G^{k+1} - x(1 - G)) (1 - G) / (1 - (k+1) G^k + k G^{k+1}).
vec complex_inv(int len) {
  if (len == 1) return {0};
  const vec g = resize(complex_inv((len + 1) >> 1), len);
  const vec gk = fpow(g, k);
  const vec gk1 = mul(gk, g, len);
  const vec one_minus_g = vec{1} - g;
  const vec numerator = mul(g - gk1 - vec{0, 1} * one_minus_g, one_minus_g, len);
  const vec denominator = vec{1} - vec{k + 1} * gk + vec{k} * gk1;
  return g - mul(numerator, inv(denominator), len);
}
// Counts ways to split n+1 balls into m ordered groups of 1..k balls each.
// With T(x) = (x / G(x))^{n+1} / (n+1) and G = complex_inv, entry j of the result
// is (n - j + 1) * [x^j] T(x) for j < n, i.e. the count for m = n - j + 1 groups;
// entry n is the single-group case, which is 1 exactly when n + 1 <= k.
inline vec sol(int n) {
  vec g = complex_inv(n + 1);
  g.erase(g.begin());  // G(x) / x
  const z inv_np1 = fpow(n + 1, mod - 2);
  g = fpow(inv(g), n + 1) * vec{inv_np1};
  vec res(n + 1);
  for (int j = 0; j < n; ++j) res[j] = (n - j + 1) * g[j];
  res[n] = n + 1 <= k;
  return res;
}
// Divide and conquer over the runs b[l..r]; returns a pair of polynomials (first, second).
// Leaf: second is sol(n) for the run length n = b[l]; first is derived from it
// through the prefix sums kept in the global S[]. Presumably second counts
// selections of the run with no k consecutive cards and first counts the
// selections that complete such a window - TODO confirm against the write-up.
// Merge: first = L.first * R.second + L.second * R.first, second = L.second * R.second.
pair<vec, vec> solve(int l, int r) {
  if (l == r) {
    const int n = b[l];
    vec F(n + 1), G = sol(n);
    for (int i = 0; i <= n; i++) {
      const z prev = i ? S[i - 1] : z(0);
      F[i] = fac[n] * ifac[n - i] - G[i] * fac[i] - prev * ifac[n - i];
      S[i] = prev + F[i] * fac[n - i];
    }
    // Shift down by one index and divide by i!.
    for (int i = 0; i < n; i++) F[i] = F[i + 1] * ifac[i];
    F.pop_back();
    return {F, G};
  }
  const int mid = (l + r) >> 1;
  auto L = solve(l, mid), R = solve(mid + 1, r);
  return {L.first * R.second + L.second * R.first, L.second * R.second};
}
int main() {
#ifdef memset0
  freopen("1.in", "r", stdin);
#endif
  cin >> n >> k;
  initfac(n + 5);
  for (int i = 1; i <= n; i++) cin >> a[i];
  sort(a + 1, a + n + 1);
  for (int i = 1, j; i <= n; i = j + 1) {
    for (j = i; j < n && a[j + 1] == a[i] + j - i + 1; j++)
      ;
    b[++m] = j - i + 1;
  }
  auto res = solve(1, m);
  for (int i = 1; i <= n; i++) {
    len[i] = len[i - 1] + n * fpow(n - i + 1, mod - 2);
    ans = ans + res.first[i - 1] * fac[i - 1] * fac[n - i] * len[i];
  }
  cout << (ans * ifac[n]).x << endl;
}

评论

TABLE OF CONTENTS

题解
Part 1
Part2
Part3
代码