Boj 11658) 구간 합 구하기 3

19 분 소요

문제

백준 11658

방법 1

설명

2차원 세그먼트트리는 구현이 복잡해서 간단한 2차원 펜윅트리를 써서 풀었음.

시간 복잡도

O(\(N^2 \log{N} + M (\log{N})^2\))

코드

template<typename T, size_t _SIZE>
struct BIT
{
	inline void Update(int x, int y, const T& v) {
		for (int j = y; j <= _SIZE; j += j & -j)
			for (int i = x; i <= _SIZE; i += i & -i)
				nodes[j][i] += v;
	}

	inline T Query(int x, int y) {
		T ans = 0;
		for (int j = y; j > 0; j -= j & -j)
			for (int i = x; i > 0; i -= i & -i)
				ans += nodes[j][i];
		return ans;
	}

	inline T Query(int x1, int y1, int x2, int y2)
	{
		return Query(x2, y2) - Query(x1 - 1, y2) - Query(x2, y1 - 1) + Query(x1 - 1, y1 - 1);
	}

	T nodes[_SIZE + 1][_SIZE + 1];
};

BIT<int, 1024> fw;

int main()
{
	fastio;

	int n, m;

	cin >> n >> m;

	int t;
	for (int y = 1; y <= n; y++)
		for (int x = 1; x <= n; x++)
		{
			cin >> t;
			fw.Update(x, y, t);
		}

	int x1, y1, x2, y2;
	while (m--)
	{
		cin >> t >> y1 >> x1 >> y2;
		if (t == 0) fw.Update(x1, y1, y2 - fw.Query(x1, y1, x1, y1));
		else {
			cin >> x2;
			cout << fw.Query(x1, y1, x2, y2) << '\n';
		}
	}
}

틀린이유 1

template<typename T, size_t _SIZE>
struct BIT
{
	void Update(int x, int y, const T& v) {
		for (int i = x; i <= _SIZE; i += i & -i)
			nodes[y][i] += v;
	}

	T Query(int x, int y) {
		T ans = 0;
		for (int i = x; i > 0; i -= i & -i)
			ans += nodes[y][i];
		return ans;
	}

	T Query(int x1, int y1, int x2, int y2)
	{
		T ans = 0;
		for (int i = y1; i <= y2; i++)
			ans += Query(x2, i) - Query(x1-1, i);
		return ans;
	}

	T nodes[_SIZE + 1][_SIZE + 1];
};

위처럼 n 개의 펜윅트리는 TLE 가 뜸

  • 왜냐하면 탐색 및 수정에서 시간복잡도가 O(\(N \log{N}\)) 이 걸리기 때문
  • 그래서 누적합으로 푸는 \(O(N)\) 보다 느림.

틀린이유 2

template<typename T, size_t _H>
class SegmentTree
{
	template<typename F>
	struct Node {
		Node() {}
		Node(F v) : v(v) {}
		Node operator+(const Node& in) { return v + in.v; }
		friend Node operator+(const F& l, const Node& in) { return in + l; }
		F v = 0;
	};

	inline const int P(int i) { return (i >> 2) + 2; }
	inline const int C(int i) { return (i - 2) << 2; }

public:
	void Init() {
		for (int i = BASE, j = C(i); i >= 3; i--, j = C(i))
			nodes[i] = nodes[j] + nodes[j + 1] + nodes[j + 2] + nodes[j + 3];
	}

	void Update(int x, int y, T v) { _t1 = x, _t2 = y, _v = v; Update_Recursive(3, 1, INDEX_MAX, 1, INDEX_MAX); }
	Node<T> Query(int l, int r, int t, int b) { _t1 = l, _t2 = r, _t3 = t, _t4 = b; return Query_Recursive(3, 1, INDEX_MAX, 1, INDEX_MAX); }

	void Update_Recursive(int x, int l, int r, int t, int d)
	{
		if (l == r && t == d && l == _t1 && t == _t2) {
			nodes[x] = T(_v);
			return;
		}

		const int m1 = (l + r) / 2, m2 = (t + d) / 2;
		if (_t1 <= m1) {
			if (_t2 <= m2) Update_Recursive(C(x), l, m1, t, m2);
			else Update_Recursive(C(x)+1, l, m1, m2 + 1, d);
		}
		else {
			if (_t2 <= m2) Update_Recursive(C(x)+2, m1 + 1, r, t, m2);
			else Update_Recursive(C(x)+3, m1 + 1, r, m2 + 1, d);
		}

		T curV = T();
		for (int i = C(x); i <= C(x)+3; i++)
			curV += nodes[i].v;
		nodes[x].v = curV;
	}

	Node<T> Query_Recursive(int x, int l, int r, int t, int b)
	{
		if (_t2 < l || r < _t1 || _t4 < t || b < _t3)     return Node<T>();
		if (_t1 <= l && r <= _t2 && _t3 <= t && b <= _t4)
			return nodes[x];

		const int m1 = (l + r) / 2, m2 = (t + b) / 2;
		return Query_Recursive(C(x), l, m1, t, m2) + Query_Recursive(C(x)+1, l, m1, m2 + 1, b)
			   + Query_Recursive(C(x)+2, m1 + 1, r, t, m2) + Query_Recursive(C(x)+3, m1 + 1, r, m2 + 1, b);
	}

	Node<T> nodes[((1 << _H * 2)) / 3 + 3];
	int BASE = ((1 << (_H - 1) * 2) / 3) + 2;  // end of not leaf
	int INDEX_MAX = 1 << _H - 1;
	int _t1, _t2, _t3, _t4; T _v;
};

위처럼 쿼드트리를 쓰면 처음 초기화 때 O(\(N^2 \log{N}\))이 걸림

  • 문제는 이게 1024 * 1024 * 10 = 10,485,760
  • 상수 덕분에 1초 컷당해 TLE 가 나는듯함.

방법 2

설명

심플한 누적합.

계산하면 \(N \times M\) 이 1억인데 계산이 간단해서 1초 내에 풀림.

시간 복잡도

O(\(N^2 + NM\))

코드

int n;
int sums[1025][1025];

int main()
{
	int t, m;
	cin >> n >> m;
	for (int y = 1; y <= n; y++)
		for (int x = 1; x <= n; x++)
		{
			cin >> t;
			sums[y][x] = sums[y][x - 1] + t;
		}

	int x, y, t1, t2;
	while (m--)
	{
		cin >> t;
		if (t)
		{
			cin >> y >> x >> t2 >> t1;
			int ans = 0;
			for (; y <= t2; y++)
				ans += sums[y][t1] - sums[y][x - 1];
			cout << ans << '\n';
		}
		else {
			cin >> y >> x >> t1;
			int diff = t1 - sums[y][x] + sums[y][x - 1];
			sums[y][x++] += diff;
			for (; x <= n; x++)
				sums[y][x] += diff;
		}
	}
}

댓글남기기