# description

来源:洛谷 P4242

给你一棵树,每个节点都有颜色,有两种操作:

  1. 指定其中 kk 个节点,对于每个节点,计算它 到每个给定的节点的路径上颜色的段数 的和。
  2. 修改一条路径上所有节点的颜色。

# solution

首先看描述就可以识算法了。

看到 k2×105\sum k \le 2\times10^5 显然是建虚树出来。

看到修改路径上的颜色,显然可以用树剖维护。

建虚树时,刚好可以利用树剖得到两个点之间的所有颜色。

然后问题是建虚树出来之后考虑怎么做。

首先看到路径上这三个字,可以考虑点分治。

# 点分治做法

描述约定:

  1. 对于题目中给出的 kk 个点,我们叫他关键点。
  2. 对于虚树上的点,叫它虚树点。(显然关键点都是虚树点)

然后点分治步骤:

  1. 找到树重心,将其作为根节点
  2. 依次遍历根节点的每棵子树,计算子树中所有关键点到根节点的路径中颜色的段数,并 sum += ColorCount, nodeCount += 1
  3. 依次遍历根节点的每棵子树
    1. 删除这棵子树的贡献
    2. 计算每个关键点的答案(我们将经过根节点的答案拆分成两段,一段从根根节点到其它关键点,一段从当前点到根节点,答案显然应该是:
      [子树外节点数×当前关键点到根节点路径颜色的段数+sum][\text{子树外节点数} \times \text{当前关键点到根节点路径颜色的段数} + \text{sum}]
    3. 加回子树贡献
  4. 依次递归每棵子树

时间复杂度 O(nlogn)O(n\log n)

除了点分治以外,还可以考虑虚树题目的常用解法:树形 dp。

# 树形 dp 做法

首先考虑有根树的做法。

以节点 uu 为根,考虑虚树上这样一条边 (u,v)(u, v)

对于子树 vv 中的每个点,到 uu 的路径都必须经过 vv,所以子树 vv 中的每个点到 uu 的贡献都会多上 Count(u,v)1\operatorname{Count}(u, v) - 1,其中 Count(u,v)\operatorname{Count}(u, v) 表示 uuvv 的路径上(包含 u,vu,v)颜色段数。

fuf_u 为以 uu 为根的子树的答案,cuc_u 为以 uu 为根的子树中关键点的个数,得到下面这个式子:

fu=[u 是关键点]+(u,v)Efv+cv×(Count(u,v)1)f_u=[\text{u 是关键点}] + \sum_{(u,v)\in E}f_v + c_v\times(\operatorname{Count}(u,v)-1)

对于之后的换根,同样按照这个式子再减去子树贡献就行(记 pp 为节点父亲):

fu=fpcu×(Count(p,u)1)+(kcu)×(Count(p,u)1)f_u = f_p - c_u\times(\operatorname{Count}(p,u)-1)+(k-c_u)\times(\operatorname{Count}(p,u)-1)

想展开就展开吧,总之我懒得展开 不过不展开也有个好处,就是逻辑更加清晰。

dp 的复杂度是 O(n)O(n) 的,但是树剖和虚树的复杂度是 O(nlogn)O(n\log n) 的,总复杂度 O(nlogn)O(n\log n)

这个做法贴个 Code:

/*
1. 修改一条路径上所有节点的颜色
2. 计算所有节点到其他节点的颜色总数
*/
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
#define rep(i, s, e) for(int i=s;i<=e;++i)
#define pb push_back
const int mxn = 1e5+10;
int n, q, c[mxn], dfn[mxn], ed[mxn], map[mxn];
vector <int> G[mxn];
struct node {
	int lC, rC, sum;
	inline node operator + (const node& rhs) const { 
		if(!rhs.sum) return *this; if(!sum) return rhs;
		return (node) {lC, rhs.rC, sum + rhs.sum - (rC == rhs.lC)};
	}
	inline node reverse() { swap(lC, rC); return *this; }
};
const node E0 = {0, 0, 0};
#define lc (o<<1)
#define rc (o<<1|1)
#define mid ((L+R)>>1)
struct Segtr {
	node s[mxn<<2]; int tag[mxn<<2];
	inline void cover(int o, int tg) { s[o] = {tg, tg, 1}; tag[o] = tg; }
	inline void psdn(int o) { if(tag[o]) { cover(lc, tag[o]), cover(rc, tag[o]); tag[o] = 0; } }
	inline void psup(int o) { s[o] = s[lc] + s[rc]; }
	void build(int o=1, int L=1, int R=n) {
		if(L == R) return (void)(s[o] = {c[map[L]], c[map[L]], 1});
		build(lc, L, mid); build(rc, mid+1, R); psup(o);
	}
	node query(int cl, int cr, int o=1, int L=1, int R=n) {
		if(cl <= L && R <= cr) return s[o]; psdn(o);
		node res = (cl <= mid ? query(cl, cr, lc, L, mid) : E0) + (cr > mid ? query(cl, cr, rc, mid+1, R) : E0);
		return res;
	}
	void set(int cl, int cr, int p, int o=1, int L=1, int R=n) {
		if(cl <= L && R <= cr) return cover(o, p); psdn(o);
		if(cl <= mid) set(cl, cr, p, lc, L, mid);
		if(cr > mid) set(cl, cr, p, rc, mid+1, R); psup(o);
	}
};
#undef lc
#undef rc
#undef mid
struct tr_cut {
	int sz[mxn], son[mxn], fa[mxn], dep[mxn], dfc, top[mxn];
	Segtr T;
	void dfs1(int u, int fat) {
		fa[u] = fat; dep[u] = dep[fat] + 1; sz[u] = 1;
		for(auto v: G[u]) if(v != fat) {
			dfs1(v, u); sz[u] += sz[v]; if(sz[v] > sz[son[u]]) son[u] = v;
		}
	}
	void dfs2(int u) {
		dfn[u] = ++dfc; map[dfc] = u;
		if(son[u]) { top[son[u]] = top[u]; dfs2(son[u]); }
		for(auto v: G[u]) if(v != fa[u] && v != son[u]) { top[v] = v; dfs2(v); }
		ed[u] = dfc;
	}
	inline int lca(int u, int v) {
		while(top[u] != top[v]) dep[top[u]] > dep[top[v]] ? u = fa[top[u]] : v = fa[top[v]];
		return dep[u] > dep[v] ? v : u;
	}
	inline void modify(int u, int v, int p) {
		while(top[u] != top[v]) {
			if(dep[top[u]] < dep[top[v]]) swap(u, v);
			T.set(dfn[top[u]], dfn[u], p);
			u = fa[top[u]];
		}
		if(dep[u] > dep[v]) swap(u, v);
		T.set(dfn[u], dfn[v], p);
	}
	inline node query(int u, int v) {
		// 从浅到深
		if(dep[u] > dep[v]) swap(u, v);
		node res = E0;
		while(top[v] != top[u]) {
			res = T.query(dfn[top[v]], dfn[v]) + res;
			v = fa[top[v]];
		}
		if(u != v) res = T.query(dfn[u]+1, dfn[v]) + res;
		return res;
	}
	inline void init() { dfs1(1, 0); top[1] = 1; dfs2(1); T.build(); }
} cut;
inline bool cmp(int x, int y) { return dfn[x] < dfn[y]; }
struct fake_tree {
	int pt[mxn<<1], m, M, tot, key[mxn], stk[mxn], top, sz[mxn], f[mxn], fa[mxn], ask[mxn];
	struct edge { int v; node w; };
	vector <edge> G[mxn]; node s[mxn];
	inline void init(int GETIN) {
		m = M = GETIN;
		rep(i, 1, m) { int x; scanf("%d", &x); pt[i] = ask[i] = x; key[x] = 2; }
		if(!key[1]) key[1] = 1, pt[++M] = 1;
		sort(pt + 1, pt + M + 1, cmp);
		tot = M;
		rep(i, 2, M) {
			int C = cut.lca(pt[i], pt[i-1]);
			if(!key[C]) key[C] = 1, pt[++tot] = C;
		}
		sort(pt + 1, pt + tot + 1, cmp);
		stk[top = 1] = 1;
		rep(i, 1, tot) c[pt[i]] = cut.T.query(dfn[pt[i]], dfn[pt[i]]).lC;
		#define T (stk[top]) 
		rep(i, 2, tot) {
			int cur = pt[i];
			while(dfn[cur] > ed[T]) --top; node w = cut.query(T, cut.fa[cur]);
			G[T].pb({cur, w}); stk[++top] = pt[i];
		}
		#undef T
	}
	void dfs1(int u, int fat) {
		s[u] = {c[u], c[u], 1}; fa[u] = fat; f[u] = sz[u] = (key[u] == 2);
		for(auto e: G[u]) {
			dfs1(e.v, u); sz[u] += sz[e.v];
			f[u] += f[e.v] + sz[e.v] * ( (s[u] + e.w + s[e.v]).sum - 1);
		}
	}
	void dfs2(int u, node from) {
		if(fa[u]) f[u] = f[fa[u]] - sz[u] * (from.sum - 1) + (m - sz[u]) * (from.sum - 1);
		for(auto e: G[u]) dfs2(e.v, s[u] + e.w + s[e.v]);
	}
	inline void solve() {
		dfs1(1, 0); dfs2(1, E0);
		rep(i, 1, m) printf("%d ", f[ask[i]]); puts("");
		rep(i, 1, tot) {
			int cur = pt[i]; G[cur].clear(); s[cur] = E0;
			key[cur] = f[cur] = fa[cur] = sz[cur] = ask[i] = pt[i] = 0;
		}
	}
} sol;
int main() {
	scanf("%d%d", &n, &q);
	rep(i, 1, n) scanf("%d", c+i);
	rep(i, 2, n) {
		int u, v; scanf("%d%d", &u, &v);
		G[u].pb(v); G[v].pb(u);
	}
	cut.init();
	while(q--) {
		int op; scanf("%d", &op);
		if(op&1) {
			int u, v, w; scanf("%d%d%d", &u, &v, &w);
			cut.modify(u, v, w);
		} else {
			int m; scanf("%d", &m); sol.init(m); sol.solve();
		}
	}
	return 0;
}