[백준/BOJ] 백준 22959번 : 신촌 수열과 쿼리

2022. 8. 17. 23:16알고리즘 문제풀이

https://www.acmicpc.net/problem/22959

 

22959번: 신촌 수열과 쿼리

첫째 줄에 수열의 크기 $N$이 주어진다. ($1 \le N \le 200\,000$) 둘째 줄에는 수열의 원소 $a_1,$ $a_2,$ $\cdots,$ $a_N$ 이 주어진다. ($1 \le a_i \le 10^9$) 셋째 줄에는 쿼리의 개수 $M$이 주어진다. ($1 \le M \le 200

www.acmicpc.net

구간의 최솟값을 저장하는 세그먼트 트리와, 구간의 합을 저장하는 세그먼트 트리를 이용하여 문제를 해결했다. 1번 쿼리의 경우에는 세그먼트 트리의 업데이트를 수행했고, 2번 쿼리의 경우에는 l과 r을 찾아서 해당 구간의 합을 세그먼트 트리를 이용해 구하면 되는데, 이때 l의 경우 l ~ i의 범위의 최솟값이 j이상인 최소 l을 찾아야 하는데, 이때 이분 탐색을 통해 mid ~ i의 범위 구간의 최솟값이 j이상인지로 판별하는 방법으로 찾았고, r의 경우 i ~ r의 범위의 최솟값이 j이상인 최대 r을 찾아야 하는데, 이때 이분 탐색을 통해 i ~ mid의 범위 구간의 최솟값이 j이상인지 판별하는 방법으로 찾았다.

 

코드

#include <iostream>
#include <vector>
#include <algorithm>
#include <limits>
using namespace std;

int n;
vector<int> a;
vector<int> min_sgmtt(800005, 0);
vector<long long> sum_sgmtt(800005, 0);
int m;

int MakeMinSgmtt(int here, int range_left, int range_right)
{
	if (range_left == range_right)
		return min_sgmtt[here] = a[range_left];

	int range_mid = (range_left + range_right) / 2;
	int left_child = here * 2 + 1;
	int right_child = here * 2 + 2;

	return min_sgmtt[here] = min(MakeMinSgmtt(left_child, range_left, range_mid), MakeMinSgmtt(right_child, range_mid + 1, range_right));
}

long long MakeSumSgmtt(int here, int range_left, int range_right)
{
	if (range_left == range_right)
		return sum_sgmtt[here] = a[range_left];

	int range_mid = (range_left + range_right) / 2;
	int left_child = here * 2 + 1;
	int right_child = here * 2 + 2;

	return sum_sgmtt[here] = MakeSumSgmtt(left_child, range_left, range_mid) + MakeSumSgmtt(right_child, range_mid + 1, range_right);
}

int UpdateMinSgmtt(int here, int range_left, int range_right, int index)
{
	if (range_left == range_right && range_right == index)
		return min_sgmtt[here] = a[index];

	if (index < range_left || index > range_right)
		return min_sgmtt[here];

	int range_mid = (range_left + range_right) / 2;
	int left_child = here * 2 + 1;
	int right_child = here * 2 + 2;

	return min_sgmtt[here] = min(UpdateMinSgmtt(left_child, range_left, range_mid, index), UpdateMinSgmtt(right_child, range_mid + 1, range_right, index));
}

long long UpdateSumSgmtt(int here, int range_left, int range_right, int index)
{
	if (range_left == range_right && range_right == index)
		return sum_sgmtt[here] = a[index];

	if (index < range_left || index > range_right)
		return sum_sgmtt[here];

	int range_mid = (range_left + range_right) / 2;
	int left_child = here * 2 + 1;
	int right_child = here * 2 + 2;

	return sum_sgmtt[here] = UpdateSumSgmtt(left_child, range_left, range_mid, index) + UpdateSumSgmtt(right_child, range_mid + 1, range_right, index);
}

int QueryMinSgmtt(int here, int range_left, int range_right, int find_left, int find_right)
{
	if (find_left <= range_left && range_right <= find_right)
		return min_sgmtt[here];

	if (find_right < range_left || find_left > range_right)
		return numeric_limits<int>::max();

	int range_mid = (range_left + range_right) / 2;
	int left_child = here * 2 + 1;
	int right_child = here * 2 + 2;

	return min(QueryMinSgmtt(left_child, range_left, range_mid, find_left, find_right), QueryMinSgmtt(right_child, range_mid + 1, range_right, find_left, find_right));
}

long long QuerySumSgmtt(int here, int range_left, int range_right, int find_left, int find_right)
{
	if (find_left <= range_left && range_right <= find_right)
		return sum_sgmtt[here];

	if (find_right < range_left || find_left > range_right)
		return 0;

	int range_mid = (range_left + range_right) / 2;
	int left_child = here * 2 + 1;
	int right_child = here * 2 + 2;

	return QuerySumSgmtt(left_child, range_left, range_mid, find_left, find_right) + QuerySumSgmtt(right_child, range_mid + 1, range_right, find_left, find_right);
}

int main()
{
	cin.tie(NULL);
	ios_base::sync_with_stdio(false);

	cin >> n;

	for (int i = 0; i < n; i++) {
		int input;
		cin >> input;

		a.push_back(input);
	}
	MakeMinSgmtt(0, 0, n - 1); //구간의 최솟값을 저장하는 세그먼트 트리
	MakeSumSgmtt(0, 0, n - 1); //구간의 합을 저장하는 세그먼트 트리

	cin >> m;

	for (int k = 0; k < m; k++) {
		int order, i, j;

		cin >> order >> i >> j;

		//0번인덱스부터 시작하는것으로 수정
		i--;

		if (order == 1)
		{
			a[i] = j;
			UpdateMinSgmtt(0, 0, n - 1, i);
			UpdateSumSgmtt(0, 0, n - 1, i);
		}

		else
		{
			//l과 r을 이분탐색을 통해 찾는다
			int left;
			int right;
			int mid;
			int l, r;

			//l찾기
			//l~i의 범위의 최소값이 j이상인 최소 l을 찾는다
			left = 0;
			right = i;
			while (left <= right) {

				mid = (left + right) / 2;

				//조건을 만족하는 l을 찾았을때
				if (QueryMinSgmtt(0, 0, n - 1, mid, i) >= j)
				{
					l = mid;
					right = mid - 1;
				}

				else
				{
					left = mid + 1;
				}
			}

			//r찾기
			//i~r의 범위의 최소값이 j이상인 최대 r을 찾는다
			left = i;
			right = n - 1;
			while (left <= right) {

				mid = (left + right) / 2;

				//조건을 만족하는 r을 찾았을때
				if (QueryMinSgmtt(0, 0, n - 1, i, mid) >= j)
				{
					r = mid;
					left = mid + 1;
				}

				else
				{
					right = mid - 1;
				}
			}
			cout << QuerySumSgmtt(0, 0, n - 1, l, r) << "\n";
		}
	}
}