naoya_t@hatenablog

いわゆるチラシノウラであります

LeetCode: 97. Interleaving String


LeetCodeさんからこんな挑戦状が届いた
f:id:n4_t:20181217164859p:plain:w480
ので挑発に乗ってみる
文字列s_3が、2つの文字列s_1,s_2をinterleaveして出来たものか否かを判定する問題。
(例えば s_1:aaa, s_2:bbb の場合、s_3が aaabbb, ababab, abbaab, babaab ならtrueだが、aaa, bbb, aaaaaa, aabbbb, aabbcc などだと falseになる)
O(N^2)解はすぐに思いつくんだけどHardってなってるしO(N\log N)解がある?

と悩んで
まあとりあえずO(N^2)解を投げて様子を見よう
→AC
おい

自分の解法はこんな感じ:

// いろいろ略
class Solution {
public:
    bool isInterleave(string s1, string s2, string s3) {
        int L1 = s1.size(), L2 = s2.size(), L3 = s3.size();
        if (L1+L2 != L3) return false;

        vi cnt1(26,0), cnt2(26,0), cnt3(26,0);
        rep(i,L1) ++cnt1[s1[i]-'a'];
        rep(i,L2) ++cnt2[s2[i]-'a'];
        rep(i,L3) ++cnt3[s3[i]-'a'];
        rep(i,26) if (cnt1[i]+cnt2[i] != cnt3[i]) return false;

        vvi dp(2);
        dp[0].pb(0);

        rep(i,L3) {
            int i0=i%2, i1=1-i0;
            dp[i1].clear();
            for(int k:dp[i0]) {
                //k: ここまででs1からk文字,s2からi-k文字取れてた。
                if (i-k<L2 && s3[i]==s2[i-k]) {
                    if (dp[i1].empty() || dp[i1].back() < k)
                        dp[i1].pb(k);
                }
                if (k<L1 && s3[i]==s1[k]) {
                    dp[i1].pb(k+1);
                }
            }
        }

        return !dp[L3%2].empty();
    }
};

|s_1|vectorやbitsetを用意してそれを更新していく方法、ではなく
s_3を左から1文字ずつ見ていって、その時点でs_1から何文字消費していることがありうるかのリストを持ち回る方式。

解説を見たら、Brute force → メモ化再帰 → 二次元DP → 一次元DP という風に順を追って説明されていたけれど

  • Time complexity: O(m\cdot n)
  • Space complexity: O(n)

までだ

bitsetを使った解(これの方が若干遅い)

#define MAXSIZE 40001

class Solution {
public:
    bool isInterleave(string s1, string s2, string s3) {
        int L1 = s1.size(), L2 = s2.size(), L3 = s3.size();
        if (L1+L2 != L3) return false;

        vi cnt1(26,0), cnt2(26,0), cnt3(26,0);
        rep(i,L1) ++cnt1[s1[i]-'a'];
        rep(i,L2) ++cnt2[s2[i]-'a'];
        rep(i,L3) ++cnt3[s3[i]-'a'];
        rep(i,26) if (cnt1[i]+cnt2[i] != cnt3[i]) return false;

        vector<bitset<MAXSIZE>> dp(2);
        dp[0].set(0, true);

        rep(i,L3) {
            int i0=i%2, i1=1-i0;
            dp[i1].reset();
            for(int k=0; k<=L1; ++k) {
                if (!dp[i0].test(k)) continue;
                //k: ここまででs1からk文字,s2からi-k文字取れてた。
                if (i-k<L2 && s3[i]==s2[i-k]) {
                    dp[i1].set(k, true);
                }
                if (k<L1 && s3[i]==s1[k]) {
                    dp[i1].set(k+1, true);
                }
            }
        }

        return dp[L3%2].test(L1);
    }
};