路径相关树形dp——最长乘积链
问题描述
给定一棵树,树中包含n个结点,编号为1~n,以及n- 1条无向边,每条边都有一个权值。
现从树中任选一个点, 从该点出发,在不走回头路的情况下找出二条到其他点的路径,这二条路径不能有公共边,请问这二条路径长度的乘积最大可以是多少。
注:如果从该点出发只有一个方向可以走,换句话说该点入度出度为1,则乘积为0。
输入格式
第一行输入一个整数n。
接下来n- 1行,每行输入包含三个整数 a i , b i , c i a_i,b_i,c_i ai,bi,ci,表示点 a i a_i ai和 b i b_i bi之间存在一条权值为 c i c_i ci的边。
输出格式
输出一个整数,为1二条路径长度乘积的最大值。
题目分析
考虑对于一个节点i而言,它的最长链要怎么去寻找,我们还是要画图看一下,就拿之前的那张图吧。
对于节点4而言,寻找它的最长链首先有两个寻找方向,向下沿着儿子节点寻找一条最长链和向上沿着父节点寻找一条最长链,那么次长链呢?我只有一个寻找方向,就是向下沿着儿子节点寻找,并且这个儿子节点不能在最长链里面,为什么不能向上沿着父节点寻找呢?因为向上寻找必然经过父节点,但是向上找最长链时父节点已经在了,两条链不能有重复的节点。
接下来看怎么求,先看向下沿着儿子节点寻找一条最长链和次长链怎么求。通过一次dfs可以求出来,代码如下,注意在求最长链和次长链的同时,也要记录他们是沿着哪个儿子节点的链。dp1[u]表示节点u向下走的最长链的大小,p1[u]表示最长链是沿着哪个儿子节点走的,dp2[u]表示节点u向下走的次长链的大小,p2[u]表示次长链是沿着哪个儿子节点走的。假设儿子节点是e,那么节点u沿着节点e的最长链用dp1[e]+w表示,w代表u和e之间的距离。然后用dp1[e]+w和现在的dp1[u]比较,如果dp1[u]小,则更新dp1[u],同时将现在的dp1[u]的值给dp2[u]。否则看dp1[e]+w和现在的dp2[u]的大小,如果dp2[u]小,用dp1[e]+w更新dp2[u]。注意在更新dp1[u]和dp2[u]的同时也要更新对应的p1[u]和p2[u]。代码如下,
private static void dfs1(int u, int fa) {if(map.get(u)==null) return;for(node e:map.get(u)) {int v = e.x;if(v==fa) continue;dfs1(v, u);if(dp1[v]+e.w>dp1[u]) {dp2[u]=dp1[u];p2[u]=p1[u];dp1[u]=dp1[v]+e.w;p1[u]=v;}else if(dp1[v]+e.w>dp2[u]) {dp2[u]=dp1[v]+e.w;p2[u]=v;}}
}
现在看怎么求向上沿着父节点寻找一条最长链,节点到父节点的距离为w,加上父节点向下走的最长链与向上走的最长链的最大值,但是这里要注意,如果父节点的最长链是沿着该节点走的,则不能用,要更改为加上父节点向下走的次长链与向上走的最长链的最大值,这里也体现了为什么要有数组p1和p2,代码如下
private static void dfs2(int u, int fa) {if(map.get(u)==null) return;for(node e:map.get(u)) {int v = e.x;if(v==fa) continue; if(p1[u]==v) {up[v] = e.w+Math.max(up[u], dp2[u]);}else {up[v] = e.w+Math.max(up[u], dp1[u]);}dfs2(v, u);}
}
现在我们需要的东西都求出来了,题目要求次长链和最长链的乘积最大值,那么也就是向下走的最长链乘以向下走的次长链的乘积与向下走的最长链乘以向上走的最长链的乘积里取一个最大值,代码如下
for(int i = 1;i <= n;i++) {res = Math.max(Math.max((long)dp1[i]*up[i], (long)dp1[i]*dp2[i]), res);
}
题目代码
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Scanner;
public class Main{static class node{int x,w;public node(int x, int w) {super();this.x = x;this.w = w;}}static HashMap<Integer, List<node>> map = new HashMap<Integer, List<node>>();static int dp1[],dp2[],up[],p1[],p2[];
public static void main(String[] args) {Scanner scanner = new Scanner(System.in);int n = scanner.nextInt();dp1 = new int[n+1];dp2 = new int[n+1];up = new int[n+1];p1 = new int[n+1];p2 = new int[n+1];for(int i = 1;i < n;i++) {int u = scanner.nextInt();int v = scanner.nextInt();int c =scanner.nextInt();add(u,v,c);add(v,u,c);}dfs1(1,0);dfs2(1,0);long res = 0;for(int i = 1;i <= n;i++) {res = Math.max(Math.max((long)dp1[i]*up[i], (long)dp1[i]*dp2[i]), res);}System.out.println(res);
}
private static void dfs2(int u, int fa) {if(map.get(u)==null) return;for(node e:map.get(u)) {int v = e.x;if(v==fa) continue;if(p1[u]==v) {up[v] = e.w+Math.max(up[u], dp2[u]);}else {up[v] = e.w+Math.max(up[u], dp1[u]);}dfs2(v, u);}
}
private static void dfs1(int u, int fa) {if(map.get(u)==null) return;for(node e:map.get(u)) {int v = e.x;if(v==fa) continue;dfs1(v, u);if(dp1[v]+e.w>dp1[u]) {dp2[u]=dp1[u];p2[u]=p1[u];dp1[u]=dp1[v]+e.w;p1[u]=v;}else if(dp1[v]+e.w>dp2[u]) {dp2[u]=dp1[v]+e.w;p2[u]=v;}}
}
private static void add(int u, int v, int c) {if(!map.containsKey(u)) map.put(u, new ArrayList<node>());map.get(u).add(new node(v, c));
}
}