// Sum of the k largest segment sums a[s] + ... + a[s + len - 1] with l <= len <= r (prefix sums + sparse table + max-heap).
#include <bits/stdc++.h>
// Signed integer type of at least 64 bits, used for every counter, prefix
// sum and the accumulated answer in this file. A type alias (rather than a
// macro) cannot swallow the rest of its line and scopes like a normal name.
using ll = long long;
// Capacity of every per-index table (values are 1-indexed up to n).
constexpr ll MaxN = 500010;
struct node { ll maxp, val; ll l, r, pos; bool operator<(node x) const { return val < x.val; } };
ll n, k, l, r; std::priority_queue<node> q; ll a[MaxN], lg[MaxN], sum[MaxN], max[MaxN][21], maxp[MaxN][21];
// Range-maximum query over the prefix sums sum[l .. r] via the sparse table.
// Stores the maximum in `val` and an index attaining it in `pos`; when the two
// overlapping power-of-two windows tie, the right window's index is reported.
// Expects 1 <= l <= r <= n and the tables built by prework().
void query(ll l, ll r, ll &val, ll &pos)
{
    const ll level = lg[r - l + 1];
    const ll rightStart = r - (1 << level) + 1;  // right-aligned second window
    const ll leftBest = max[l][level];
    const ll rightBest = max[rightStart][level];
    if (leftBest > rightBest) {
        val = leftBest;
        pos = maxp[l][level];
    } else {
        val = rightBest;
        pos = maxp[rightStart][level];
    }
}
void prework() { lg[0] = -1; for (ll i = 1; i <= n; i++) maxp[i][0] = i, max[i][0] = sum[i], lg[i] = lg[i >> 1] + 1; for (ll j = 1; j <= 20; j++) for (ll i = 1; i <= n - (1 << j) + 1; i++) max[i][j] = std::max(max[i][j - 1], max[i + (1 << (j - 1))][j - 1]); for (ll j = 1; j <= 20; j++) for (ll i = 1; i <= n - (1 << j) + 1; i++) maxp[i][j] = ((max[i][j - 1] > max[i + (1 << (j - 1))][j - 1]) ? maxp[i][j - 1] : maxp[i + (1 << (j - 1))][j - 1]); }
// Reads one integer from stdin, skipping any non-digit characters before it.
// A '-' seen among the skipped characters makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of looping
// forever, which the char-based version did on EOF).
inline long long read()
{
    long long x = 0;
    bool negative = false;
    int ch = getchar();  // int, not char: EOF must stay distinguishable from data
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}
int main() { n = read(), k = read(); l = read(), r = read(); for (ll i = 1; i <= n; i++) a[i] = read(), sum[i] = sum[i - 1] + a[i]; prework(); for (ll i = 1; i <= n; i++) { ll pos, val; if (i + l - 1 > n) break; query(i + l - 1, std::min(i + r - 1, n), val, pos); val -= sum[i - 1], pos -= i - 1; q.push((node){pos, val, l, std::min(r, n - i + 1), i}); } ll ans = 0; for(ll i = 1; i <= k; i++) { node x = q.top(); q.pop(), ans += x.val; if(x.maxp > x.l) { ll pos, val; query(x.pos + x.l - 1, x.pos + x.maxp - 2, val, pos); val -= sum[x.pos - 1], pos -= x.pos - 1; q.push((node){pos, val, x.l, x.maxp - 1, x.pos}); } if(x.maxp < x.r) { ll pos, val; query(x.pos + x.maxp, x.pos + x.r - 1, val, pos); val -= sum[x.pos - 1], pos -= x.pos - 1; q.push((node){pos, val, x.maxp + 1, x.r, x.pos}); } } printf("%lld\n", ans); return 0; }