线段树,但是抽象

10 min

构建和查询

线段树,按字面要有线段(区间),也就是

S[0,n1]={[l,r]Z:l,rZ,0lrn1}S_{[0, n-1]}=\left\{[l,r]\cap\mathbb Z: l,r\in\mathbb Z, 0\le l\le r\le n-1\right\}

下面简记S[0,n1]S_{[0, n-1]}SS[l,r]Z[l,r]\cap\mathbb Z[l,r][l,r]

关注SS上的合并操作,是一个偏函数:

:S×SS,[a,b][b+1,c]=[a,c]\circ: S\times S\rightharpoonup S, [a, b]\circ[b+1, c] = [a, c]

每个线段上附加了AA信息,也就是

f:SAf: S \to A

平凡的函数和平凡的信息没什么意思,所以来附加一点结构吧。

我们希望AA可以组合,这样区间的信息能够从子区间的信息合并出来。也就是AA最好是一个半群。

但实现上只需要相邻区间的合并,所以只需要给AA一个更弱的结构,暂且称之为偏半群, 即要求一个满足结合律但不必为全函数的乘法。注意(S,)(S, \circ)也是一个偏半群。

要使查询[l,r][l,r]可以分解为查询[l,m][l,m][m+1,r][m+1,r]再合并,也就是要f([l,m])f([m+1,r])=f([l,r])=f([l,m][m+1,r])f([l, m])f([m+1, r]) = f([l, r]) = f([l, m]\circ[m+1, r])。 简而言之,ff是一个同态。

暂时不考虑修改,线段树一个偏半群同态f:SnAf: S_n \to A的计算程序。

这一性质使能线段树的对数时间复杂度查询(通过二分+合并,只要乘法是O(1))。

当然你可以给任意偏半群AA添加一个元素0A0\notin A使原来未定义乘法的ab=0 (a,bA)ab=0\ (a,b\in A),使0a=00a=0,那么A{0}A\cup\{0\} 就是个半群(相当于套个Option或者说加个null)。虽然更熟悉,但totality性质用不上,还和SS的结构不同了,反而加大了思维难度。 线段树是个同态这个表述真是干净又清爽,轻轻松松reason代码正确性(

实现

要存储这样一个ff,一方面必须存储所有的f([i,i]) (1in)f([i, i])\ (1 \le i \le n),因为[i,i][i,i]没法拆成真子区间的合并了。 另一方面,f([l,r])=i=lrf([i,i])f([l, r])=\sum_{i=l}^r f([i, i]),所以存储f([i,i])f([i, i])是充分的。

而线段树就是在这些必要信息之外存储一些冗余的部分合并了的信息,来换取更好的复杂度。

于是建树需要提供信息:

  • 所有的f([i,i]) (0in1)f([i, i])\ (0 \le i \le n-1)
  • 乘法:A×AA\cdot: A\times A\rightharpoonup A

目前的接口和实现:

import scala.reflect.ClassTag

trait SegmentTree[A] {
  def query(l: Int, r: Int): A
}

object SegmentTree {
  def apply[A: ClassTag](n: Int, init: Int => A, mul: (A, A) => A): SegmentTree[A] = {
    val tree = Array.ofDim[A](4 * n)
    def build(l: Int, r: Int, idx: Int): Unit = if (l == r) {
      tree(idx) = init(l)
    } else {
      val m = l + (r - l) / 2
      build(l, m, idx * 2)
      build(m + 1, r, idx * 2 + 1)
      tree(idx) = mul(tree(idx * 2), tree(idx * 2 + 1))
    }
    build(0, n - 1, 1)

    new SegmentTree[A] {
      override def query(l: Int, r: Int): A = {
        def loop(l0: Int, r0: Int, idx: Int): A = if (l <= l0 && r0 <= r) {
          tree(idx)
        } else {
          val m0 = l0 + (r0 - l0) / 2
          val left = if (l <= m0) Some(loop(l0, m0, idx * 2)) else None
          val right = if (m0 < r) Some(loop(m0 + 1, r0, idx * 2 + 1)) else None
          List(left, right).flatten.reduce(mul)
        }
        loop(0, n - 1, 1)
      }
    }
  }
}

简单的用例(半群):

val a = Array(1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0)
val n = a.length
val tSum = SegmentTree(n, a, _ + _)
val theAnswer = tSum.query(2, 10) // 42
val tMax = SegmentTree(n, a, _ max _)
val nine = tMax.query(2, 10) // 9
val tConcat = SegmentTree(n, a.map(_.toString), _ + _)
val koishi = tConcat.query(3, 5) // "514"

复杂一点的:

val s = "aaaabbbbbccdeeefgghhhhhh"
val t = SegmentTree(
  s.length,
  s.map(ch => (1, ch, 1, ch, 1, true)),
  { case ((len1, lch1, llen1, rch1, rlen1, mono1), (len2, lch2, llen2, rch2, rlen2, mono2)) =>
    // 最长的连续相同字符的子串的长度, 区间最左边字符, 区间最左边连续相同字符个数, 区间最右边字符, 区间最右边连续相同字符个数, 区间是否是连续相同字符
    if (rch1 == lch2) {
      val len = len1 max len2 max (rlen1 + llen2)
      val llen = if (mono1) llen1 + llen2 else llen1
      val rlen = if (mono2) rlen2 + rlen1 else rlen2
      (len, lch1, llen, rch2, rlen, mono1 && mono2)
    } else (len1 max len2, lch1, llen1, rch2, rlen2, false)
  },
)
// 最长的连续相同字符的子串的长度
t.query(10, 19)._1 // 3

更新(修改)

一个更新是一个自同态u:AAu:A\rightharpoonup A。如果更新整个[0,n1][0,n-1],那么线段树就从ff变成了ufu\circ f

线段树支持在一个子区间上部分应用这个更新。具体来讲,比如在[l,r][l,r]上应用,那么生成元f([0,0]),f([1,1]),,f([n1,n1])f([0, 0]),f([1, 1]),\dots,f([n-1, n-1])变成了f([0,0]),,f([l1,l1]),uf([l,l]),,uf([r,r]),f([r+1,r+1]),f([n1,n1])f([0, 0]),\dots,f([l-1,l-1]),u\circ f([l, l]),\dots,u\circ f([r,r]),f([r+1,r+1])\dots,f([n-1, n-1])

支持所有的AAA\rightharpoonup A更新是不可能的,故只考虑Hom(A,A)\mathrm{Hom}(A, A)的某个子集,用索引II记为U={ui:iI}U=\{u_i:i\in I\}。 代码中用II来表示和存储一个修改操作,放在lazy标记里。

UU须是Hom(A,A)\mathrm{Hom}(A, A)的子幺半群,这是lazy propagation的实现所需: 两个修改需要能合并成一个,而且lazy标记要能表示是否已修改(无修改是id:AA\mathrm{id}:A\to A,即幺元)。少一点性质都做不到lazy。

实际进行计算的是II,可选取II是幺半群,u:IUu_{\bullet}:I\to U为幺半群同构。

实现

要拥有修改操作,需要以下额外信息:

  • II的幺半群实例
  • u:IUu_{\bullet}: I\to U,或者说一个嵌入IHom(A,A)I\to \mathrm{Hom}(A,A),或者说I×AAI\times A\rightharpoonup A
import scala.reflect.ClassTag

trait SegmentTree[A, I] {
  def query(l: Int, r: Int): A
  def update(l: Int, r: Int, i: I): Unit
}

object SegmentTree {
  def apply[A: ClassTag, I: ClassTag](
      n: Int,
      init: Int => A,
      mul: (A, A) => A,
      iOne: I,
      iMul: (I, I) => I,
      update: (I, A) => A,
  ): SegmentTree[A, I] = {
    val tree = Array.ofDim[A](4 * n)
    val lazi = Array.fill(4 * n)(iOne)
    def build(l: Int, r: Int, idx: Int): Unit = if (l == r) {
      tree(idx) = init(l)
    } else {
      val m = l + (r - l) / 2
      build(l, m, idx * 2)
      build(m + 1, r, idx * 2 + 1)
      tree(idx) = mul(tree(idx * 2), tree(idx * 2 + 1))
    }
    build(0, n - 1, 1)

    def pushDown(l0: Int, r0: Int, idx: Int): Unit = {
      val m = l0 + (r0 - l0) / 2
      val lIdx = idx * 2
      val rIdx = lIdx + 1
      tree(lIdx) = update(lazi(idx), tree(lIdx))
      tree(rIdx) = update(lazi(idx), tree(rIdx))
      lazi(lIdx) = iMul(lazi(lIdx), lazi(idx))
      lazi(rIdx) = iMul(lazi(rIdx), lazi(idx))
      lazi(idx) = iOne
    }

    val updateWith = update

    new SegmentTree[A, I] {
      override def query(l: Int, r: Int): A = {
        def loop(l0: Int, r0: Int, idx: Int): A = if (l <= l0 && r0 <= r) {
          tree(idx)
        } else {
          if (lazi(idx) != iOne) {
            pushDown(l0, r0, idx)
          }
          val m0 = l0 + (r0 - l0) / 2
          val left = if (l <= m0) Some(loop(l0, m0, idx * 2)) else None
          val right = if (m0 < r) Some(loop(m0 + 1, r0, idx * 2 + 1)) else None
          List(left, right).flatten.reduce(mul)
        }
        loop(0, n - 1, 1)
      }
      override def update(l: Int, r: Int, i: I): Unit = {
        def loop(l0: Int, r0: Int, idx: Int): Unit = if (l <= l0 && r0 <= r) {
          tree(idx) = updateWith(i, tree(idx))
          lazi(idx) = iMul(lazi(idx), i)
        } else {
          val m = l0 + (r0 - l0) / 2
          if (lazi(idx) != iOne && l0 != r0) {
            pushDown(l0, r0, idx)
          }
          if (l <= m) loop(l0, m, idx * 2)
          if (m < r) loop(m + 1, r0, idx * 2 + 1)
          tree(idx) = mul(tree(idx * 2), tree(idx * 2 + 1))
        }
        loop(0, n - 1, 1)
      }
    }
  }
}

用例(查询区间和,逐点更新加某一数):

val t = SegmentTree(
  10,
  Array(1, 1, 4, 5, 1, 4, 1, 9, 1, 9).map(a => (a, 1)),
  { case ((s1, len1), (s2, len2)) => (s1 + s2, len1 + len2) },
  0,
  _ + _,
  { case (delta, (s, len)) => (s + delta * len, len) },
)
println(t.query(0, 3)._1) // 11
t.update(2, 5, 10)
println(t.query(0, 3)._1) // 31
println(t.query(5, 8)._1) // 25

用例(扫描线):

val t = SegmentTree(
  6,
  Array(1, 1, 4, 5, 1, 4).map(len => (len, 0, len)),
  { case ((len1, minThickness1, minThicknessLen1), (len2, minThickness2, minThicknessLen2)) =>
    if (minThickness1 < minThickness2) {
      (len1 + len2, minThickness1, minThicknessLen1)
    } else if (minThickness1 > minThickness2) {
      (len1 + len2, minThickness2, minThicknessLen2)
    } else {
      (len1 + len2, minThickness1, minThicknessLen1 + minThicknessLen2)
    }
  },
  0,
  _ + _,
  { case (delta, (len, minThickness, minThicknessLen)) =>
    (len, minThickness + delta, minThicknessLen)
  },
)
def queryCoverage(): Int = {
  val (len, minThickness, minThicknessLen) = t.query(0, 5)
  if (minThickness > 0) len else len - minThicknessLen
}
println(queryCoverage()) // 0
t.update(0, 2, 1) // add (0, 6)
println(queryCoverage()) // 6
t.update(2, 3, 1) // add (2, 11)
println(queryCoverage()) // 11
t.update(0, 2, -1) // remove (0, 6)
println(queryCoverage()) // 9
t.update(2, 3, -1) // remove (2, 11)
println(queryCoverage()) // 0

优化

  • 幺元建树:如果AA是幺半群,且初始化全部为幺元,直接用幺元填充tree即可。
  • 冗余的区间端点信息:更新过程中是知道每个被更新的区间的左右端点的,可传给update: (I, A) => A(改成update: (Int, Int, I, A) => A), 从而使AA中不再需要左右端点或者区间长度之类的字段。
  • AoS to SoA:特化然后改写。