-
Notifications
You must be signed in to change notification settings - Fork 2
/
segmentTree.cpp
79 lines (67 loc) · 2.21 KB
/
segmentTree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
using namespace std;
// One node of a sum segment tree covering the inclusive index range [start, end].
// `sum` holds the value stored for that range; `left`/`right` point to the child
// subtrees and start out null. The node has no destructor, so it never frees its
// children — whoever allocates the tree is responsible for releasing it.
struct SegmentTreeNode {
    int start, end, sum;
    SegmentTreeNode *left, *right;

    SegmentTreeNode(int start, int end, int sum)
        : start(start), end(end), sum(sum), left(nullptr), right(nullptr) {}
};
class NumArray {
SegmentTreeNode* root;
SegmentTreeNode* build(int start, int end, const vector<int>& A) {
if (start > end) return nullptr;
else if (start == end) {
SegmentTreeNode *tree = new SegmentTreeNode(start, end, A[start]);
return tree;
} else {
SegmentTreeNode *left = build(start, (start+end)/2, A);
SegmentTreeNode *right = build((start+end)/2+1, end, A);
SegmentTreeNode *tree = new SegmentTreeNode(start, end, left->sum + right->sum);
tree -> left = left;
tree -> right = right;
return tree;
}
}
void modify(SegmentTreeNode *root, int i, int val) {
if (root->start == root->end) {
root->sum = val;
return;
}
if (i <= root->left->end) modify(root->left, i, val);
else modify(root->right, i, val);
root->sum = root->left->sum + root->right->sum;
}
int query(SegmentTreeNode *root, int i, int j) {
if (root->start == root->end) return root->sum;
if (root->start == i && root->end == j) return root->sum;
if(j <= root->left->end) {
return query(root->left, i, j);
} else if (i >= root->right->start) {
return query(root->right, i, j);
} else {
return query(root->left, i, root->left->end) + query(root->right, root->right->start, j);
}
}
public:
NumArray(vector<int> &nums) {
root = build(0, nums.size()-1, nums);
}
void update(int i, int val) {
modify(root, i, val);
}
int sumRange(int i, int j) {
return query(root, i, j);
}
};
int main() {
vector<int> nums{1, 2, 3};
NumArray numArray(nums);
cout << numArray.sumRange(0, 1) << endl;
numArray.update(1, 10);
cout << numArray.sumRange(1, 2) << endl;
}