読者です 読者をやめる 読者になる 読者になる

naoya_t@hatenablog

いわゆるチラシノウラであります

Facebook HackerCup Round2 : 問題A. Road Removal

何かの最小全域木っぽい感じがするんだけど…とか思ったけれどグラフ系は苦手意識があって飛ばしちゃって勿体ないお化けが出ている日曜の昼下がり、皆様いかがお過ごしでしょうか。

問題文はこちら

DFS書いたりBFS書いたりDP書いたりと無駄な試行錯誤の末に辿りついた答えはこんな感じ:

(0) 都市 { 0...N-1 } のうち最初のK個 { 0...K-1 } が重要な都市。{ K...N-1 } は重要でない都市。都市A,B間には道路があったりなかったり。(入力データ参照)


(1) すべての道路は
〈分類0〉重要な都市 - 重要な都市
〈分類1〉重要な都市 - 重要でない都市
〈分類2〉重要でない都市 - 重要でない都市
に分類できる


(2) 〈分類2〉の道路(重要でない都市ー重要でない都市を結ぶもの)だけしかなかったとして、union-findでグループ分けしてみる。
そして各グループをそれぞれ1つの都市としてみなし、いちばん小さい都市番号(root)union-findのrootで代表させる。〈分類2〉の道路はループしていても重要な都市に影響しないので放置。


(3) 〈分類1〉の道路で、ある重要な都市から、ある重要でない都市(都市グループ)への道路の数を数えてみる。これが2本以上あると、その重要な都市を含んだループになってまずいので1本だけ残して削除。


(4) 〈分類1〉の残り + 〈分類0〉の道路 について、改めてunion-findでグループ分けしてみる。
そしてそれぞれのグループについて最小全域木を作る。最小全域木で使われなかった道路は削除。
なぜそれが最小全域木で済むのか… この時点で残っている道路網のどこにループがあっても重要な都市に絡んでしまうので、ループさせるわけには行かないから。


(5) (3)と(4)で削除した道路の総数が答え。

実装はこんな感じ:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
#include <set>

using namespace std;

#define pb  push_back
#define all(c)  (c).begin(), (c).end()
#define tr(c, i)  for (typeof((c).begin()) i = (c).begin(); i != (c).end(); ++i)
#define found(s, e)  ((s).find(e) != (s).end())

class UnionFind {
  vector<int> data;
 public:
  explicit UnionFind(int size) : data(size, -1) { }
  bool unionSet(int x, int y) {
    x = root(x);
    y = root(y);
    if (x != y) {
      if (data[y] < data[x]) swap(x, y);
      data[x] += data[y];
      data[y] = x;
    }
    return x != y;
  }
  bool findSet(int x, int y) { return root(x) == root(y); }
  int root(int x) { return data[x] < 0 ? x : data[x] = root(data[x]); }
  int size(int x) { return -data[root(x)]; }
  bool has(int x) { return data[x] >= 0; }
};

typedef pair<int, int> i_i;

// 最小全域木で使われるarcの本数を返す。クラスカル法。
int kruscal_count(const vector<int>& nodes, const vector<i_i>& arcs) {
  int N = nodes.size(), used = 0;
  map<int, int> trans;
  for (int i = 0; i < N; ++i) trans[nodes[i]] = i;
  UnionFind uf(N);
  tr(arcs, it) {
    int u = trans[it->first], v = trans[it->second];
    if (uf.root(u) != uf.root(v)) {
      uf.unionSet(u, v);
      ++used;
    }
  }
  return used;
}

int main() {
  int T;
  cin >> T;  // 1-20
  for (int t = 1; t <= T; ++t) {
    int N, M, K;
    cin >> N >> M >> K;  // 1-10000, 50000, (10000)

    UnionFind r2_uf(N);  // 〈分類2〉をunion-findするため
    vector<i_i> r0, r1, r2;
    for (int m = 0; m < M; ++m) {  // 50000
      int a, b;
      cin >> a >> b;
      if (a > b) swap(a, b);  // a < b を保証

      i_i road(a, b);

      if (b < K) {
        // a < b < K
        r0.push_back(road);  // 〈分類0〉重要-重要
      } else if (a < K) {
        // a < K <= b
        r1.push_back(road);  // 〈分類1〉重要-重要でない
      } else {
        // K <= a < b
        r2.push_back(road);  // 〈分類2〉重要でない-重要でない
        r2_uf.unionSet(a, b);
      }
    }

    // #2〈分類2〉の各グループを代表する都市は r2_uf.root(重要でない都市) で分かる。

    // #3〈分類1〉の道路で、ある重要な都市から、ある重要でない都市を結ぶ道路は1つあれば良いので複数あれば削除。
    // #3〈分類1' 〉=〈分類1〉から要らない道路を消したもの
    set<i_i> r1_;
    tr(r1, it) {  // 50000
      i_i road(it->first, r2_uf.root(it->second));
      if (!found(r1_, road)) r1_.insert(road);
    }
    int cut = r1.size() - r1_.size();


    // #4〈分類0+1' 〉で改めてunion-findでグループ分けしてみる。
    vector<i_i> arcs;  //〈分類0+1' 〉の道路
    arcs.insert(arcs.end(), all(r0));
    arcs.insert(arcs.end(), all(r1_));
    set<int> nodes; //〈分類0+1' 〉に含まれる都市の集合
    UnionFind r01_uf(N);
    tr(arcs, it) {  // 50000
      nodes.insert(it->first);
      nodes.insert(it->second);
      r01_uf.unionSet(it->first, it->second);
    }

    map<int, vector<int> > nodes_for_root;
    set<int> roots;
    tr(nodes, it) {  // 10000
      int node = *it, root = r01_uf.root(node);
      roots.insert(root);
      nodes_for_root[root].push_back(node);
    }
    map<int, vector<i_i> > arcs_for_root;
    tr(arcs, it) {  // 50000
      i_i arc = *it;
      int root = r01_uf.root(arc.first);
      arcs_for_root[root].push_back(arc);
    }

    // #4 それぞれのグループについて最小全域木を作る。
    // #4 最小全域木で使われなかった道路は削除。
    tr(roots, it) {  // 10000
      int root = *it;
      cut += arcs_for_root[root].size()
          - kruscal_count(nodes_for_root[root], arcs_for_root[root]);
    }

    printf("Case #%d: %d¥n", t, cut);
  }
}

とりあえず問題文のサンプルデータは通ったんだけどどうなんだろう…
→ Practiceモードになったので試してみたらこれで行けました!