d0tfi1e’s blog

趣味と日記

ARC068 E - Snuke Line

解説がよくわからなかったので自分用にまとめ直します。

問題

数列 1, 2, ..., mがあり、これ上に n個の区間があります。区間 [l_i, r_i]で表されます。

 m行出力しなさい。 i = 1, \ldots, m行目には、 iの倍数が何種類の区間に含まれているかを答えなさい。

制約: 1 \le n \le 3 \times 10^5, 1 \le m \le 10^5, 1 \le l_i \le r_i \le m

ポイント

 iの倍数を見ているとして、区間の長さが i以上の区間は確実に含まれます。 では区間の長さが i未満の区間については何種類含まれるでしょうか?これは実は次のように考えても大丈夫です。

  1. 長さ mの配列 cntを用意し、 j = 1, \ldots m番目を、その区間を長さ i未満の区間が何種類カバーしているか、を表すことにします。
  2. 長さ i未満の区間を幅が短い順に見ていき(そのような区間のインデックスを jとする)、 l_j \le k \le r_jとなるような kについて  cnt[k]をインクリメントします。この操作はセグメント木やBITを使うことで O(\log m)で可能です。

このとき、 cnt[i], cnt[2i], \ldotsの合計が、区間の長さが i未満の区間のうち、 iの倍数が含まれる区間の種類になります。 なぜなら、長さが i未満の区間は、複数の cnt[ij], cnt[ik]にまたがってカウントされることはありえないからです。 もちろん、上記の作業をすべての(長さが i以上の区間をも含む)区間についてに行ってしまうと重複が発生するのでうまく数えることができません。

さて、このアルゴリズムを愚直に行った場合の計算量は、 iについてのループと、調べる区間の数と、区間についての加算があり、さらに最後に iの倍数を順に見ていく作業があるので  O(nm\log m + n\log m) = O(nm\log m)です。まだダメです。

高速化

さて、先程述べたアルゴリズムには明らかに高速化できる点があります。

長さ i未満の区間を順に見ていき(そのような区間のインデックスを jとする)、 l_j \le k \le r_jとなるような kについて  cnt[k]をインクリメントします。この操作はセグメント木やBITを使うことで O(\log m)で可能です。

の部分で、 iを1から mまで昇順に調べるとすると、 i番目を調べているとき、 i - 1番目を調べたときにインクリメントした区間を 再びインクリメントしています。これは明らかに無駄であり、 i - 1番目を調べたときの続きから該当する区間をインクリメントしていくことで、 それぞれの区間についてインクリメントされる回数を1回に抑えることができます。こうすると、計算が間に合います。

実装

みやすさのため、セグメント木部分を隠蔽します。

struct range {
    int w, l, r;
    bool operator<(const range& rhs) const {
        return w < rhs.w;
    }
};

int main() {
    int n, m;
    cin >> n >> m;
    
    vector<range> vs;
    for (int i = 0; i < n; i++) {
        int l, r;
        cin >> l >> r;
        vs.push_back({r - l, l, r});
    }
    
    sort(vs.begin(), vs.end());
    SegmentTree cnt(vector<int>(m + 1));
    
    int small_range_begin = 0, small_range_end;
    for (int d = 1; d <= m; d++) {
        int cnt = 0;
        int long_range_cnt = n - int(lower_bound(vs.begin(), vs.end(), range({d, 0, 0})) - vs.begin());
        cnt += long_range_cnt;
        small_range_end = n - long_range_cnt;
        for (int i = small_range_begin; i < small_range_end; i++) {
            cnt.add(vs[i].l, vs[i].r + 1, 1);
        }
        small_range_begin = small_range_end;
        for (int i = d; i <= m; i += d) {
            cnt += cnt.getsum(i, i + 1);
        }
        cout << cnt << endl;
    }
}