【bzoj2870】最長道路tree 樹的直徑+並查集


題目描述

給定一棵N個點的樹,求樹上一條鏈使得鏈的長度乘鏈上所有點中的最小權值所得的積最大。
其中鏈長度定義為鏈上點的個數。

輸入

第一行N
第二行N個數分別表示1~N的點權v[i]
接下來N-1行每行兩個數x、y,表示一條連接x和y的邊

輸出

一個數,表示最大的痛苦程度。

樣例輸入

3
5 3 5
1 2
1 3

樣例輸出

10


題解

樹的直徑+並查集

首先肯定是把權值從大到小排序,按照順序加點,維護每個連通塊的最長鏈乘以當前點權值作為貢獻。

那么如何在加上一條邊,連接兩棵樹后快速得出新的直徑呢?

一個結論:將兩棵樹連成一棵,新樹的直徑的兩端點只有可能是原來兩棵樹兩條直徑四個端點中的某兩個。

證明不太容易表述。。。簡單畫一畫就差不多出來了。實在不行可以先推加一個點的情況,然后再推加一棵樹。

於是使用並查集維護樹的直徑長度及端點位置,使用倍增LCA求距離,就做完了。。。

注意需要開long long。

時間復雜度 $O(n\log n)$ 

#include <cstdio>
#include <algorithm>
#define N 50010
using namespace std;
typedef long long ll;
int v[N] , id[N] , head[N] , to[N << 1] , next[N << 1] , cnt , fa[N][17] , deep[N] , log[N] , f[N] , px[N] , py[N];
ll ans = 0;
bool cmp(int a , int b)
{
return v[a] > v[b];
}
inline void add(int x , int y)
{
to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void dfs(int x)
{
int i;
for(i = 1 ; (1 << i) <= deep[x] ; i ++ ) fa[x][i] = fa[fa[x][i - 1]][i - 1];
for(i = head[x] ; i ; i = next[i])
if(to[i] != fa[x][0])
fa[to[i]][0] = x , deep[to[i]] = deep[x] + 1 , dfs(to[i]);
}
inline int lca(int x , int y)
{
int i;
if(deep[x] < deep[y]) swap(x , y);
for(i = log[deep[x] - deep[y]] ; ~i ; i -- )
if(deep[x] - deep[y] >= (1 << i))
x = fa[x][i];
if(x == y) return x;
for(i = log[deep[x]] ; ~i ; i -- )
if(deep[x] >= (1 << i) && fa[x][i] != fa[y][i])
x = fa[x][i] , y = fa[y][i];
return fa[x][0];
}
inline int dis(int x , int y)
{
return deep[x] + deep[y] - (deep[lca(x , y)] << 1);
}
int find(int x)
{
return x == f[x] ? x : f[x] = find(f[x]);
}
void solve(int x)
{
int i , tx , ty , t , vm , vx , vy;
for(i = head[x] ; i ; i = next[i])
{
if(f[to[i]])
{
tx = find(x) , ty = find(to[i]) , vm = -1;
if(vm < (t = dis(px[tx] , py[tx]))) vm = t , vx = px[tx] , vy = py[tx];
if(vm < (t = dis(px[ty] , py[ty]))) vm = t , vx = px[ty] , vy = py[ty];
if(vm < (t = dis(px[tx] , px[ty]))) vm = t , vx = px[tx] , vy = px[ty];
if(vm < (t = dis(px[tx] , py[ty]))) vm = t , vx = px[tx] , vy = py[ty];
if(vm < (t = dis(py[tx] , px[ty]))) vm = t , vx = py[tx] , vy = px[ty];
if(vm < (t = dis(py[tx] , py[ty]))) vm = t , vx = py[tx] , vy = py[ty];
f[ty] = tx , px[tx] = vx , py[tx] = vy;
}
}
tx = find(x) , ans = max(ans , (ll)v[x] * (dis(px[tx] , py[tx]) + 1));
}
int main()
{
int n , i , x , y;
scanf("%d" , &n);
for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &v[i]) , id[i] = i;
for(i = 2 ; i <= n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x) , log[i] = log[i >> 1] + 1;
dfs(1);
sort(id + 1 , id + n + 1 , cmp);
for(i = 1 ; i <= n ; i ++ )
f[id[i]] = px[id[i]] = py[id[i]] = id[i] , solve(id[i]);
printf("%lld\n" , ans);
return 0;
}

 

 


注意!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系我们删除。



 
  © 2014-2022 ITdaan.com 联系我们: