线段树,但是抽象
构建和查询
线段树,按字面要有线段(区间),也就是
下面简记为,为。
关注上的合并操作,是一个偏函数:
每个线段上附加了信息,也就是
平凡的函数和平凡的信息没什么意思,所以来附加一点结构吧。
我们希望可以组合,这样区间的信息能够从子区间的信息合并出来。也就是最好是一个半群。
但实现上只需要相邻区间的合并,所以只需要给一个更弱的结构,暂且称之为偏半群, 即要求一个满足结合律但不必为全函数的乘法。注意也是一个偏半群。
要使查询可以分解为查询和再合并,也就是要。 简而言之,是一个同态。
暂时不考虑修改,线段树是一个偏半群同态的计算程序。
这一性质使能线段树的对数时间复杂度查询(通过二分+合并,只要乘法是O(1))。
当然你可以给任意偏半群添加一个元素使原来未定义乘法的,使,那么 就是个半群(相当于套个
Option或者说加个null)。虽然更熟悉,但totality性质用不上,还和的结构不同了,反而加大了思维难度。 线段树是个同态这个表述真是干净又清爽,轻轻松松reason代码正确性(
实现
要存储这样一个,一方面必须存储所有的,因为没法拆成真子区间的合并了。 另一方面,,所以存储是充分的。
而线段树就是在这些必要信息之外存储一些冗余的部分合并了的信息,来换取更好的复杂度。
于是建树需要提供信息:
- 所有的
- 乘法
目前的接口和实现:
import scala.reflect.ClassTag
trait SegmentTree[A] {
def query(l: Int, r: Int): A
}
object SegmentTree {
def apply[A: ClassTag](n: Int, init: Int => A, mul: (A, A) => A): SegmentTree[A] = {
val tree = Array.ofDim[A](4 * n)
def build(l: Int, r: Int, idx: Int): Unit = if (l == r) {
tree(idx) = init(l)
} else {
val m = l + (r - l) / 2
build(l, m, idx * 2)
build(m + 1, r, idx * 2 + 1)
tree(idx) = mul(tree(idx * 2), tree(idx * 2 + 1))
}
build(0, n - 1, 1)
new SegmentTree[A] {
override def query(l: Int, r: Int): A = {
def loop(l0: Int, r0: Int, idx: Int): A = if (l <= l0 && r0 <= r) {
tree(idx)
} else {
val m0 = l0 + (r0 - l0) / 2
val left = if (l <= m0) Some(loop(l0, m0, idx * 2)) else None
val right = if (m0 < r) Some(loop(m0 + 1, r0, idx * 2 + 1)) else None
List(left, right).flatten.reduce(mul)
}
loop(0, n - 1, 1)
}
}
}
}
简单的用例(半群):
val a = Array(1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0)
val n = a.length
val tSum = SegmentTree(n, a, _ + _)
val theAnswer = tSum.query(2, 10) // 42
val tMax = SegmentTree(n, a, _ max _)
val nine = tMax.query(2, 10) // 9
val tConcat = SegmentTree(n, a.map(_.toString), _ + _)
val koishi = tConcat.query(3, 5) // "514"复杂一点的:
val s = "aaaabbbbbccdeeefgghhhhhh"
val t = SegmentTree(
s.length,
s.map(ch => (1, ch, 1, ch, 1, true)),
{ case ((len1, lch1, llen1, rch1, rlen1, mono1), (len2, lch2, llen2, rch2, rlen2, mono2)) =>
// 最长的连续相同字符的子串的长度, 区间最左边字符, 区间最左边连续相同字符个数, 区间最右边字符, 区间最右边连续相同字符个数, 区间是否是连续相同字符
if (rch1 == lch2) {
val len = len1 max len2 max (rlen1 + llen2)
val llen = if (mono1) llen1 + llen2 else llen1
val rlen = if (mono2) rlen2 + rlen1 else rlen2
(len, lch1, llen, rch2, rlen, mono1 && mono2)
} else (len1 max len2, lch1, llen1, rch2, rlen2, false)
},
)
// 最长的连续相同字符的子串的长度
t.query(10, 19)._1 // 3更新(修改)
一个更新是一个自同态。如果更新整个,那么线段树就从变成了。
线段树支持在一个子区间上部分应用这个更新。具体来讲,比如在上应用,那么生成元变成了。
支持所有的更新是不可能的,故只考虑的某个子集,用索引记为。 代码中用来表示和存储一个修改操作,放在lazy标记里。
须是的子幺半群,这是lazy propagation的实现所需: 两个修改需要能合并成一个,而且lazy标记要能表示是否已修改(无修改是,即幺元)。少一点性质都做不到lazy。
实际进行计算的是,可选取是幺半群,为幺半群同构。
实现
要拥有修改操作,需要以下额外信息:
- 的幺半群实例
- ,或者说一个嵌入,或者说
import scala.reflect.ClassTag
trait SegmentTree[A, I] {
def query(l: Int, r: Int): A
def update(l: Int, r: Int, i: I): Unit
}
object SegmentTree {
def apply[A: ClassTag, I: ClassTag](
n: Int,
init: Int => A,
mul: (A, A) => A,
iOne: I,
iMul: (I, I) => I,
update: (I, A) => A,
): SegmentTree[A, I] = {
val tree = Array.ofDim[A](4 * n)
val lazi = Array.fill(4 * n)(iOne)
def build(l: Int, r: Int, idx: Int): Unit = if (l == r) {
tree(idx) = init(l)
} else {
val m = l + (r - l) / 2
build(l, m, idx * 2)
build(m + 1, r, idx * 2 + 1)
tree(idx) = mul(tree(idx * 2), tree(idx * 2 + 1))
}
build(0, n - 1, 1)
def pushDown(l0: Int, r0: Int, idx: Int): Unit = {
val m = l0 + (r0 - l0) / 2
val lIdx = idx * 2
val rIdx = lIdx + 1
tree(lIdx) = update(lazi(idx), tree(lIdx))
tree(rIdx) = update(lazi(idx), tree(rIdx))
lazi(lIdx) = iMul(lazi(lIdx), lazi(idx))
lazi(rIdx) = iMul(lazi(rIdx), lazi(idx))
lazi(idx) = iOne
}
val updateWith = update
new SegmentTree[A, I] {
override def query(l: Int, r: Int): A = {
def loop(l0: Int, r0: Int, idx: Int): A = if (l <= l0 && r0 <= r) {
tree(idx)
} else {
if (lazi(idx) != iOne) {
pushDown(l0, r0, idx)
}
val m0 = l0 + (r0 - l0) / 2
val left = if (l <= m0) Some(loop(l0, m0, idx * 2)) else None
val right = if (m0 < r) Some(loop(m0 + 1, r0, idx * 2 + 1)) else None
List(left, right).flatten.reduce(mul)
}
loop(0, n - 1, 1)
}
override def update(l: Int, r: Int, i: I): Unit = {
def loop(l0: Int, r0: Int, idx: Int): Unit = if (l <= l0 && r0 <= r) {
tree(idx) = updateWith(i, tree(idx))
lazi(idx) = iMul(lazi(idx), i)
} else {
val m = l0 + (r0 - l0) / 2
if (lazi(idx) != iOne && l0 != r0) {
pushDown(l0, r0, idx)
}
if (l <= m) loop(l0, m, idx * 2)
if (m < r) loop(m + 1, r0, idx * 2 + 1)
tree(idx) = mul(tree(idx * 2), tree(idx * 2 + 1))
}
loop(0, n - 1, 1)
}
}
}
}用例(查询区间和,逐点更新加某一数):
val t = SegmentTree(
10,
Array(1, 1, 4, 5, 1, 4, 1, 9, 1, 9).map(a => (a, 1)),
{ case ((s1, len1), (s2, len2)) => (s1 + s2, len1 + len2) },
0,
_ + _,
{ case (delta, (s, len)) => (s + delta * len, len) },
)
println(t.query(0, 3)._1) // 11
t.update(2, 5, 10)
println(t.query(0, 3)._1) // 31
println(t.query(5, 8)._1) // 25用例(扫描线):
val t = SegmentTree(
6,
Array(1, 1, 4, 5, 1, 4).map(len => (len, 0, len)),
{ case ((len1, minThickness1, minThicknessLen1), (len2, minThickness2, minThicknessLen2)) =>
if (minThickness1 < minThickness2) {
(len1 + len2, minThickness1, minThicknessLen1)
} else if (minThickness1 > minThickness2) {
(len1 + len2, minThickness2, minThicknessLen2)
} else {
(len1 + len2, minThickness1, minThicknessLen1 + minThicknessLen2)
}
},
0,
_ + _,
{ case (delta, (len, minThickness, minThicknessLen)) =>
(len, minThickness + delta, minThicknessLen)
},
)
def queryCoverage(): Int = {
val (len, minThickness, minThicknessLen) = t.query(0, 5)
if (minThickness > 0) len else len - minThicknessLen
}
println(queryCoverage()) // 0
t.update(0, 2, 1) // add (0, 6)
println(queryCoverage()) // 6
t.update(2, 3, 1) // add (2, 11)
println(queryCoverage()) // 11
t.update(0, 2, -1) // remove (0, 6)
println(queryCoverage()) // 9
t.update(2, 3, -1) // remove (2, 11)
println(queryCoverage()) // 0优化
- 幺元建树:如果是幺半群,且初始化全部为幺元,直接用幺元填充
tree即可。 - 冗余的区间端点信息:更新过程中是知道每个被更新的区间的左右端点的,可传给
update: (I, A) => A(改成update: (Int, Int, I, A) => A), 从而使中不再需要左右端点或者区间长度之类的字段。 - AoS to SoA:特化然后改写。