NOIP1126 疫情控制

原题链接:https://www.luogu.org/problem/P1084

收录于 NOI 系列题解集

思路

在有限的时间内,尽量使军队向上跳,检查是否为合法解,不难发现答案具有单调性,我们可以用二分答案求出所需的最少时间。

观察数据范围,暴跳的话一定会TLE,因此考虑通过树上倍增进行预处理,吐槽一下:倍增太容易写错了,有一处不小心把fadist写反了 debug 了好久的说(菜的真实)

使军队跳到最高非根节点后,若不是根节点的子节点,则用mark标记一下;否则判断一下该军队是否可以走到根节点,如果不行同样标记一下,如果行,则使其走到根节点,并将(剩余时间, 节点编号)存入到一个二元组中。(这样做是因为军队可以通过根节点转移子树)

然后更新根节点的子节点的标记,若该子节点不可以向下到达叶节点,则标记一下,表明不需要从其他子树调来军队。

接下来,检查从所有去往根节点的军队的原节点是否有标记,若没有,检查其剩余时间是否小于等于到达原节点所需的时间,如果是,那不如不转移子树,使它回到原本的子树上,并标记一下。

最后是一个贪心操作,从大到小排序可转移军队的剩余时间,同时从大到小排序到达各个无标记节点所需的时间,若可以转移到每个子树,则该解是合法的。

细节非常多,不愧是紫题。。

代码

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

const int MAXN = 50005, MAXM = 50005;
int N, M, m[MAXM], fa[MAXN][20], lb[MAXN], top[MAXN];
long long dist[MAXN][20];

struct Edge {
  int v, w, next;
} edge[MAXN << 1];
int head[MAXN], cnt;

void add(int u, int v, int w) {
  edge[cnt].v = v;
  edge[cnt].w = w;
  edge[cnt].next = head[u];
  head[u] = cnt;
  ++cnt;
}

void dfs(int u, int last) {
  for (int i = head[u]; ~i; i = edge[i].next) {
    int v = edge[i].v;
    if (v == last) continue;
    top[v] = top[u] + 1;
    fa[v][0] = u;
    dist[v][0] = edge[i].w;
    for (int j = 1; j <= lb[top[v]]; ++j) {
      fa[v][j] = fa[fa[v][j - 1]][j - 1];
      dist[v][j] = dist[v][j - 1] + dist[fa[v][j - 1]][j - 1];
    }
    dfs(v, u);
  }
}

bool mark[MAXN];
std::pair<long long, int> node[MAXN];
long long rem[MAXN], fix[MAXN];
int nsz, rsz, fsz;

bool cmp(long long a, long long b) {
  return a > b;
}

bool remark(int u, int last) {
  if (mark[u]) return true;
  bool ok = false;
  for (int i = head[u]; ~i; i = edge[i].next) {
    int v = edge[i].v;
    if (v == last) continue;
    if (!remark(v, u)) return false;
    ok = true;
  }
  return ok;
}

bool check(long long limit) {
  nsz = rsz = fsz = 0;
  memset(mark, false, sizeof(mark));

  for (int i = 0; i < M; ++i) {
    int x = m[i];
    long long tot = limit;
    for (int j = lb[top[x]]; j >= 0; --j) {
      if (fa[x][0] == 1) break;
      if (fa[x][j] != 1 && dist[x][j] <= tot) {
        tot -= dist[x][j];
        x = fa[x][j];
      }
    }
    if (fa[x][0] == 1 && tot > dist[x][0]) {
      node[nsz++] = std::make_pair(tot - dist[x][0], x);
    } else {
      mark[x] = true;
    }
  }

  for (int i = head[1]; ~i; i = edge[i].next) {
    int v = edge[i].v;
    mark[v] = remark(v, 1);
  }

  for (int i = 0; i < nsz; ++i) {
    int x = node[i].second;
    if (!mark[x] && node[i].first <= dist[x][0]) {
      mark[x] = true;
    } else {
      rem[rsz++] = node[i].first;
    }
  }

  for (int i = head[1]; ~i; i = edge[i].next) {
    int v = edge[i].v;
    if (!mark[v]) fix[fsz++] = edge[i].w;
  }

  if (fsz == 0) return true;
  if (rsz < fsz) return false;
  std::sort(rem, rem + rsz, cmp);
  std::sort(fix, fix + fsz, cmp);

  for (int i = 0; i < fsz; ++i) {
    if (rem[i] < fix[i]) return false;
  }
  
  return true;
}

long long L = 0, R = 0, ans = -1;

int main() {
  memset(head, 0xFF, sizeof(head));
  scanf("%d", &N);

  for (int i = 2; i < N; ++i) lb[i] = lb[i >> 1] + 1;

  for (int i = 0; i < N - 1; ++i) {
    int u, v, w;
    scanf("%d%d%d", &u, &v, &w);
    add(u, v, w);
    add(v, u, w);
    R += w;
  }
  scanf("%d", &M);
  for (int i = 0; i < M; ++i) {
    scanf("%d", &m[i]);
  }

  top[1] = 0;
  dfs(1, 0);

  while (L <= R) {
    long long mid = (L + R) >> 1;
    if (check(mid)) {
      R = mid - 1;
      ans = mid;
    }
    else {
      L = mid + 1;
    }
  }

  printf("%lld\n", ans);
  return 0;
}
点赞

Leave a Reply