[백준/BOJ] 백준 17831번 : 대기업 승범이네

2021. 6. 29. 01:27알고리즘 문제풀이

https://www.acmicpc.net/problem/17831

 

17831번: 대기업 승범이네

첫 번째 줄에 판매원들의 수 N(2 ≤ N ≤ 200,000)이 주어진다. 판매원들은 1번, 2번, …, N번으로 번호가 매겨지며, 승범이는 항상 1번이다. 두 번째 줄에 2번 판매원부터 N번 판매원의 사수가 순서대

www.acmicpc.net

트리DP를 이용하여 문제를 해결했다. 해당 노드의 위치가 멘토일 때와 멘토가 아닐 때를 고려하여 문제를 해결하였는데, here가 멘토일 때는 there(자식 노드)중 하나는 멘티로 선택되어야 하는데 이때 here의 자식 노드(there)들 중 가장 효과적인 자식 노드를 뽑아 멘티로 고른다.

 

코드

#include <iostream>
#include <algorithm>
#include <vector>
#include <utility>
using namespace std;

int n;
vector<int> tree[200001];
vector<int> energy(200001);
vector<vector<int>> cache(200001, vector<int>(2, -1));
int result;

int Solve(int here, int mentor)
{
	int& ret = cache[here][mentor];

	if (ret != -1)
		return ret;

	ret = 0;

	//here이 멘토일때
	if (mentor == 1)
	{
		if (tree[here].size() == 0)
			return ret;

		int mentor_energy = energy[here];

		int mentee = -1;
		int min_diff = 987654321;

		//there중 하나를 멘티로 선택해야 한다
		for (int i = 0; i < tree[here].size(); i++)
		{
			int there = tree[here][i];
			int this_diff = max(Solve(there, 0), Solve(there, 1)) - (mentor_energy * energy[there] + Solve(there, 0)); //there를 멘티로 고르지 않을때와 골랐을때 차이를 구한다(there를 멘티로 골랐을때 손해가 얼만큼인지 구함)
			ret += max(Solve(there, 0), Solve(there, 1)); //일단 there를 멘티로 고르지 않은 경우일때로 고려한다

			//기존것보다 손해가 더 작다면
			if (this_diff < min_diff)
			{
				mentee = there;
				min_diff = this_diff;
			}
		}

		//mentee에 저장되있는 번호를 멘티로 하는것으로 한다
		ret -= max(Solve(mentee, 0), Solve(mentee, 1)); //기존에 멘티로 고르지 않는 경우로 더한것을 빼고
		ret += (mentor_energy * energy[mentee] + Solve(mentee, 0)); //멘티로 고른 경우를 더한다
	}

	else
	{
		for (int i = 0; i < tree[here].size(); i++)
		{
			int there = tree[here][i];

			ret += max(Solve(there, 0), Solve(there, 1));
		}
	}

	return ret;
}

int main()
{
	cin.tie(NULL);
	ios_base::sync_with_stdio(false);

	cin >> n;

	for (int i = 2; i <= n; i++)
	{
		int input;
		cin >> input;

		tree[input].push_back(i);
	}

	for (int i = 1; i <= n; i++)
	{
		int input;
		cin >> input;

		energy[i] = input;
	}

	result = max(Solve(1, 0), Solve(1, 1));

	cout << result << "\n";

	return 0;
}