naoya_t@hatenablog

いわゆるチラシノウラであります

〈復習〉AGC023 C - Painting Machines (800)

解説放送を見て完全に理解した


agwたんに説明したときのメモ

N個のマスをK回ジャストで塗りつぶせる順列の個数を求めるのは難しそうだけど、
K回「以内で」塗りつぶせる順列の個数は割と楽に求められる(あとで説明する)。
↑K回以内での個数のリストさえあれば、K回ジャストの個数のリストもそこから差分で求められるのでそれぞれのKを掛けて和を取れば答え

N個のマスをK回以内で塗りつぶせる順列ってどういうパターンになってるか。
順列の最初のK回で
・最初と最後のペイントマシン(#1,#N-1)は必ず稼働している
・その間のマシンもすくなくとも最小限は押さえている。すなわち、塗りつぶせているということは連続か1つ飛ばしでペイントマシンが稼働している
・稼働ペイントマシンが隣り合っている区間がa個、1つ飛ばしになっている区間をb個、とすると
 a + 2b = 全区間(=N-2)
 1 + a + b = K
なので
 a = 2K - N, b = N - K - 1
になる。

・a+b区間のどれをaの区間にするか、の (a+b)Ca = (a+b)!/(a!b!) 通り
・最初のK回の全permutation の K! 通り
・残り(N-1)-K回の全permutation の (N-1-K)! 通り
の積、が「N個のマスをK回以内で塗りつぶせる順列の個数」
factorialをあらかじめ100万まで計算しておけば1つ1つはO(1)で求まるので
全体としてはO(N)

そこまでわかれば実装簡単

自分で実装してみた
→AC
(あらかじめ100万、を10万しか取ってなくて1投目でRE出した><)

#include <bits/stdc++.h>
using namespace std;

#define NDEBUG
#include <cassert>

typedef long long ll;

#define pb  push_back
#define rep(var,n)  for(int var=0;var<(n);++var)
#define ALL(c)  (c).begin(),(c).end()
#define IN(x,a,b) ((a)<=(x)&&(x)<=(b))


const ll M=1000000007LL;

ll ADD(ll x, ll y) { return (x+y) % M; }
ll SUB(ll x, ll y) { return (x-y+M) % M; }
ll MUL(ll x, ll y) { return x*y % M; }
ll POW(ll x, ll e) { ll v=1; for(; e; x=MUL(x,x), e>>=1) if (e&1) v = MUL(v,x); return v; }
ll DIV(ll x, ll y) { assert(y%M!=0); return MUL(x, POW(y, M-2)); }


int N;


ll _fact[1000001];

void _prepare_fact(){
    _fact[0] = _fact[1] = 1LL;
    for(int i=2; i<=N; ++i) {
        _fact[i] = MUL(_fact[i-1], i);
    }
}

inline ll fact(int x){
    assert(IN(x,0,N));
    return _fact[x];
}

ll sub(int k) {
    // k回以内で
    // a +  b = k - 1
    // a + 2b = N - 2
    int a = k*2 - N;
    if (a < 0) return 0;
    int b = N - k - 1;
    assert(a+b == k-1);
    assert(a+b*2 == N-2);

    return MUL(
            DIV(fact(a+b),
                MUL(fact(a), fact(b))),
            MUL(fact(k), fact(N-1-k))
        );
}

ll solve() {
    ll total = 0;
    ll last = 0;
    for (int k=1; k<=N-1; ++k) {
        ll a = sub(k), x = SUB(a, last);
        total = ADD(total, MUL(x, k));
        last = a;
    }
    return total;
}

int main() {
    cin >> N;
    assert(IN(N,2,1000000));

    _prepare_fact();

    cout << solve() << endl;
    return 0;
}