Boj 16993) 연속합과 쿼리

10 분 소요

문제

백준 16993

설명

연속합을 분할정복으로 풀기

연속합을 푸는 방법은 대체로 DP 로 선형시간에 푸는 방법이 잘 알려져 있다. 하지만 변수의 크기와 쿼리의 갯수만 보아도 쿼리당 로그시간으로 풀어야 한다는 감이 온다. 즉 세그트리로 구해야 한다는 것인데, 어떻게 구할 수 있을까?

분할정복으로도 풀 수 있는데 아이디어는 다음과 같다. 어떤 수열에서 최대 연속합은 그 수열을 L, R 로 쪼갰을 때 3가지 경우로 나뉜다.

  1. 온전히 L 에 속한다.
  2. 온전히 R 에 속한다.
  3. LR 사이에 걸쳐있다.

우리는 분할정복을 하면서 경우 3을 어떻게 구할지 생각을 해야한다. 하나하나 세면 시간복잡도가 커지니, 왼쪽/오른쪽 끝에서 시작해서 구할 수 있는 최대 연속합을 미리 저장해야겠다고 여기서 생각해내야한다. 그러면 우리는 total_m(max), lm(left max), rm(right max) 의 세 값을 노드마다 저장해야한다는 것을 알 수 있다.

그러면 L+R(res)total_m 은 다음의 세가지 경우 중 하나가 된다.

  1. L.total_m
  2. R.total_m
  3. L.rm + R.lm

그럼 reslmrm 은 어떻게 구할 수 있을까? res.lmR 이 결과에 영향을 주지 않는다면 L 의 왼쪽 끝부터 구할 수 있는 최댓값인 L.lm 이다. 만약 R 이 영향을 준다면 R 의 왼쪽 끝부터 구할 수 있는 최댓값이 연속합의 오른쪽 끝이 된다. 즉 L 영역의 합과 R.lm 을 더한 값이 된다. 둘중 최댓값이 res.lm 이 되며 res.rm 은 비슷하게 구하면 된다.

이를 위해선 우리는 노드의 영역에서 전체 합을 알고 있어야 한다. 즉 우리는 acc 값도 저장해둬야한다.

요약하면 Nodelm, rm, total_m, acc , 총 4개의 변수를 저장하고, 위에서 설명한 대로 업데이트 해나가면 된다.

세그트리에 관해서

연속합을 구할 때 왼쪽/오른쪽의 순서가 매우 중요하다. Bottom-Up 방식도 구현방식에 따라 가능할지 모르지만, 순서를 고려하기엔 Top-Down 이 훨씬 안정적이었다.

시간 복잡도

O(\((\mathrm{N} + \mathrm{M})\log{\mathrm{N}}\))

코드

template<typename T, size_t _H>
class SegmentTree
{
	struct Node {
		Node() {}
		Node(int v) : acc(v), total_m(v), lm(v), rm(v) {}
		Node operator+(const Node& in)
		{
			Node res;

			res.acc = acc + in.acc;
			res.lm = max(lm, acc + in.lm);
			res.rm = max(in.rm, in.acc + rm);
			res.total_m = max(rm + in.lm, max(total_m, in.total_m));
			
			return res;
		}
		int total_m = -1e9, lm = -1e9, rm = -1e9, acc = 0;
	};

public:
	void Init() { for (int i = INDEX_MAX - 1; i > 0; i--) nodes[i] = nodes[i << 1] + nodes[i << 1 | 1]; }

	void Update(int i, const T& v) {
		for (i += INDEX_MAX - 1, nodes[i] = v; i > 1; i >>= 1)
			nodes[i >> 1] = nodes[min(i, i ^ 1)] + nodes[max(i, i ^ 1)];
	}

	Node Query(int l, int r) { _t1 = l; _t2 = r; return Query_Recursive(1, 1, INDEX_MAX); }
	Node Query_Recursive(int x, int s, int e)
	{
		if (_t2 < s || e < _t1) return Node();
		if (_t1 <= s && e <= _t2) return nodes[x];
		int m = (s + e) / 2;
		return Query_Recursive(x * 2, s, m) + Query_Recursive(x * 2 + 1, m + 1, e);
	}

	Node nodes[1 << _H];
	int INDEX_MAX = 1 << _H - 1;
	int _t1, _t2;
};

SegmentTree<int, 18> seg;

int main()
{
	fastio;
	int n;
	cin >> n;
	seg.Init();
	for (int i = 1; i <= n; i++) {
		int a;  cin >> a;
		seg.Update(i, a);
	}

	int m;
	cin >> m;
	while (m--)
	{
		int s, e; cin >> s >> e;
		cout << seg.Query(s, e).total_m << '\n';
	}
}

댓글남기기