Post

2024牛客寒假2C

2024牛客寒假2C

01trie.

题目链接

Tokitsukaze and Min-Max XOR

思路

不难发现,如果有两个数异或起来小于等于 $k$,那么值位于这两个数之间的数选或不选都没有影响,选定两个异或起来小于等于 $k$ 的数,它的贡献为 $2^x$,$x$ 是两个数之间数的个数。

不妨将整个数组排序,对于每一个数,找到它前面所有与它异或起来小于等于 $k$ 的数,然后分别计算贡献。考虑用 01trie 优化复杂度,假设排完序后的数组 $a$,对于 $a_{pos}$,假设在它之前有 $a_{b_1},a_{b_2},…$ 与它异或起来小于等于 $k$,那么 $a_{pos}$ 的贡献为:

\[\sum_{i}{2^{pos-b_i-1}} = 2^{pos}\sum_{i}{\frac{1}{2^{b_i+1}}}\]

只需要在 01trie 上维护 $\sum{({2^{i+1}})^{-1}}$,然后每次查询即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>
using namespace std;
#define pii pair<int, int>
const int N = 2e5 + 5, mod = 1e9 + 7;
int n, a[N], k;

long long ksm(int a, int b) {
    int base = a, res = 1;
    while (b) {
        if (b & 1)
            res = 1ll * res * base % mod;
        base = 1ll * base * base % mod;
        b >>= 1;
    }
    return res;
}

long long inv(int x) {
    return ksm(x, mod - 2);
}

class Trie {
private:
    struct node {
        long long nxt[2], cnt, val;
    };
    int tot = 0;
    vector<node> tr;

public:
    Trie(int sz) {
        tr.resize(sz * 32, {{0, 0}, 0, 0});
    }

    void insert(int v, int pos) {
        int x = 0;
        for (int i = 30; i >= 0; i--) {
            int nxt = v >> i & 1;
            if (!tr[x].nxt[nxt]) {
                tr[x].nxt[nxt] = ++tot;
                tr[tr[x].nxt[nxt]].val = tr[tr[x].nxt[nxt]].cnt = 0;
            }
            x = tr[x].nxt[nxt];
            tr[x].cnt++;
            tr[x].val = (tr[x].val + inv(ksm(2, pos + 1))) % mod;
        }
    }

    long long query(int v) {
        long long res = 0, x = 0;
        for (int i = 30; i >= 0; i--) {
            int bitv = v >> i & 1, bitk = k >> i & 1;
            if (bitk == 1 && tr[x].nxt[bitv]) {
                res = (res + tr[tr[x].nxt[bitv]].val) % mod;
            }
            x = tr[x].nxt[bitv ^ bitk];
            if (!x)
                return res;
        }
        res = (res + tr[x].val) % mod;
        return res;
    }
};

void solve() {
    long long res = 0;
    cin >> n >> k;
    Trie *trie = new Trie(n);
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    sort(a + 1, a + n + 1);
    for (int i = 1; i <= n; i++) {
        res = (res + trie->query(a[i]) * ksm(2, i) % mod) % mod;
        trie->insert(a[i], i);
    }
    cout << (res + n) % mod << '\n';
    free(trie);
}

int main() {
    cout.tie(0), cin.tie(0)->sync_with_stdio(0);
    int T;
    cin >> T;
    while (T--) {
        solve();
    }
    return 0;
}
This post is licensed under CC BY 4.0 by the author.