main

square/leakcanary

Last updated at: 29/12/2023 09:39

TreemapLayout.kt

TLDR

The provided file, TreemapLayout.kt, is a Kotlin file that implements a treemap layout algorithm based on the d3-hierarchy (D3.js) treemap implementation, using squarified tiling.

Classes

TreemapLayout<T>

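This is the main class implementing the treemap layout algorithm. It is generic over the type T of the content attached to each node. Its constructor takes the following parameters, each a function that receives a NodeLayout<T> and returns a Float: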

  • paddingInner: Returns the inner padding (spacing between adjacent children) for a node that has children. Defaults to a function that always returns 0.
  • paddingLeft: Returns the left padding for a node that has children. Defaults to a function that always returns 0.
  • paddingTop: Returns the top padding for a node that has children. Defaults to a function that always returns 0.
  • paddingRight: Returns the right padding for a node that has children. Defaults to a function that always returns 0.
  • paddingBottom: Returns the bottom padding for a node that has children. Defaults to a function that always returns 0.

The class also contains the following methods:

  • layout(root: NodeValue<T>, size: Size): NodeLayout<T>
    • This method calculates the position and size of each node in the treemap layout based on the provided root node and size.
    • It returns the layout of the root node; the layouts of all descendants are reachable through its children.

The class also contains a nested data class, NodeValue<T>, which represents a node of the input tree to be laid out. It has the following properties:

  • value: The value associated with the node.
  • content: The content associated with the node.
  • children: A list of child nodes.

NodeLayout<T>

This nested interface represents a node of the computed treemap layout, as returned by layout(). It has the following properties:

  • value: The value associated with the node.
  • content: The content associated with the node.
  • depth: The depth of the node in the tree (the root has depth 0).
  • children: A list of child nodes.
  • topLeft: The offset of the top-left corner of the node's rectangle.
  • size: The width and height of the node's rectangle.

InternalNodeLayout<T>

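This private data class is the concrete implementation of the NodeLayout<T> interface used for every node (leaf or not) while the layout is computed; "internal" refers to its visibility, not to the node's position in the tree. It adds the following mutable properties, from which topLeft and size are derived: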

  • x0: The x-coordinate of the top left corner of the node.
  • y0: The y-coordinate of the top left corner of the node.
  • x1: The x-coordinate of the bottom right corner of the node.
  • y1: The y-coordinate of the bottom right corner of the node.

Row

This class is used internally by the treemap layout algorithm and represents a row of nodes in the layout. It has the following properties:

  • value: The sum of the values of the nodes in the row.
  • children: The nodes that make up the row.

Methods

depthFirstTraversal(callback: (N) -> Unit)

This top-level extension function on NodeLayout<T> performs a depth-first, pre-order traversal of the node and its descendants: every node is visited after all of its ancestors, and siblings are visited in the order of their parent's children list. The callback is invoked once per node and receives only the current node.

package org.leakcanary.screens

import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import kotlin.math.max
import kotlin.math.sqrt
import org.leakcanary.screens.TreemapLayout.NodeLayout

/**
 * Computes a squarified treemap layout: every node of a tree is assigned a rectangle, and each
 * node's rectangle is split among its children in proportion to their values.
 *
 * Based on https://github.com/d3/d3-hierarchy Treemap implementation (`d3.treemap()` tiled with
 * `d3.treemapSquarify`).
 *
 * The padding functions are only invoked for nodes that have children, and receive that node:
 * - [paddingInner] is the spacing left between adjacent children of the node.
 * - [paddingLeft], [paddingTop], [paddingRight] and [paddingBottom] are the spacing left between
 *   the node's edges and its children.
 *
 * @param ratio the target aspect ratio of the rectangles produced by the squarified tiling.
 * Values below 1 are treated as 1, as in d3. Defaults to the golden ratio [phi].
 */
class TreemapLayout<T>(
  private val paddingInner: (NodeLayout<T>) -> Float = { 0f },
  private val paddingLeft: (NodeLayout<T>) -> Float = { 0f },
  private val paddingTop: (NodeLayout<T>) -> Float = { 0f },
  private val paddingRight: (NodeLayout<T>) -> Float = { 0f },
  private val paddingBottom: (NodeLayout<T>) -> Float = { 0f },
  ratio: Float = phi
) {

  // Clamped like d3's `treemapSquarify.ratio()`: squarify needs a ratio of at least 1.
  private val tileRatio: Float = ratio.coerceAtLeast(1f)

  /**
   * A node of the input tree.
   *
   * @property value the weight of this node. A parent's [value] is used as the total when its
   * rectangle is split among its children, so it should be the sum of its children's values for
   * the children to fill the parent (assumption — the layout does not compute sums itself).
   * @property content arbitrary data carried over to the matching [NodeLayout].
   * @property children the child nodes, laid out in this order.
   */
  data class NodeValue<T>(
    // TODO Float
    val value: Int,
    val content: T,
    val children: List<NodeValue<T>>
  )

  /** A node of the computed layout, mirroring a [NodeValue] of the input tree. */
  interface NodeLayout<T> {
    val value: Int
    val content: T

    /** Depth of this node in the tree; the root has depth 0. */
    val depth: Int
    val children: List<NodeLayout<T>>

    /** Top-left corner of this node's rectangle. */
    val topLeft: Offset

    /** Width and height of this node's rectangle. */
    val size: Size
  }

  /**
   * Lays out [root] and all of its descendants within a rectangle of [size] whose top-left corner
   * is at the origin.
   *
   * @return a new tree mirroring [root], in which every node carries its computed rectangle.
   */
  fun layout(
    root: NodeValue<T>,
    size: Size
  ): NodeLayout<T> {
    val rootLayout = root.mapNode()
    rootLayout.x0 = 0f
    rootLayout.y0 = 0f
    rootLayout.x1 = size.width
    rootLayout.y1 = size.height
    // paddingStack[d] is the half inner padding applied around every node at depth d. The root has
    // no parent, hence no inner padding.
    val paddingStack = mutableListOf(0f)
    // Pre-order: a node's rectangle has been assigned by its parent's tiling before it is visited.
    rootLayout.depthFirstTraversal { node ->
      positionNode(node, paddingStack)
    }
    return rootLayout
  }

  /** Converts this input node and its descendants into mutable layout nodes, recording depths. */
  private fun NodeValue<T>.mapNode(depth: Int = 0): InternalNodeLayout<T> {
    return InternalNodeLayout(
      value = value,
      content = content,
      depth = depth,
      children = children.map { it.mapNode(depth + 1) }
    )
  }

  /**
   * Shrinks [node]'s rectangle by the half inner padding set by its parent, stores the result,
   * then (if [node] has children) tiles the children inside it, inset by [node]'s outer paddings.
   */
  private fun positionNode(
    node: InternalNodeLayout<T>,
    paddingStack: MutableList<Float>
  ) {
    var p = paddingStack[node.depth]
    var x0 = node.x0 + p
    var y0 = node.y0 + p
    var x1 = node.x1 - p
    var y1 = node.y1 - p
    // When the padding is larger than the available space, collapse to the midpoint rather than
    // producing an inverted rectangle.
    if (x1 < x0) {
      x1 = (x0 + x1) / 2
      x0 = x1
    }
    if (y1 < y0) {
      y1 = (y0 + y1) / 2
      y0 = y1
    }
    // Store the padded rectangle (as d3 does): this is what separates siblings by paddingInner.
    node.x0 = x0
    node.y0 = y0
    node.x1 = x1
    node.y1 = y1
    if (node.children.isNotEmpty()) {
      val halfPaddingInner = paddingInner(node) / 2
      val childDepth = node.depth + 1
      if (childDepth < paddingStack.size) {
        paddingStack[childDepth] = halfPaddingInner
      } else {
        paddingStack += halfPaddingInner
      }
      p = halfPaddingInner
      // Each child later shrinks itself by p on every side; compensate on the outer edges so that
      // the children end up inset from this node by exactly the outer paddings.
      x0 += paddingLeft(node) - p
      y0 += paddingTop(node) - p
      x1 -= paddingRight(node) - p
      y1 -= paddingBottom(node) - p
      if (x1 < x0) {
        x1 = (x0 + x1) / 2
        x0 = x1
      }
      if (y1 < y0) {
        y1 = (y0 + y1) / 2
        y0 = y1
      }
      squarifyRatio(tileRatio, node, x0, y0, x1, y1)
    }
  }

  /**
   * Mutable [NodeLayout] used while the layout is computed. The rectangle is stored as its
   * top-left ([x0], [y0]) and bottom-right ([x1], [y1]) corners, from which [topLeft] and [size]
   * are derived.
   */
  private data class InternalNodeLayout<T>(
    override val value: Int,
    override val content: T,
    override val depth: Int,
    override val children: List<InternalNodeLayout<T>>
  ) : NodeLayout<T> {

    var x0 = 0f
    var y0 = 0f
    var x1 = 0f
    var y1 = 0f

    override val topLeft: Offset
      get() = Offset(x0, y0)
    override val size: Size
      get() = Size(x1 - x0, y1 - y0)
  }

  /** A run of consecutive siblings placed together; [value] is the sum of their values. */
  private class Row(
    val value: Int,
    val children: List<InternalNodeLayout<*>>
  )

  /**
   * Squarified tiling of [parent]'s children within ([x0Start], [y0Start]) - ([x1Start], [y1Start]).
   *
   * Consecutive children are greedily grouped into rows for as long as adding a child does not
   * worsen the row's worst aspect ratio (measured against [ratio]). Each row spans the full length
   * of the shorter side of the remaining space. Children keep their order.
   */
  private fun squarifyRatio(
    ratio: Float,
    parent: InternalNodeLayout<*>,
    x0Start: Float,
    y0Start: Float,
    x1Start: Float,
    y1Start: Float
  ) {
    // TODO Check out resquarify and try that?
    val nodes = parent.children

    // Value not yet assigned to a row; the remaining space is split in proportion to it.
    var value = parent.value

    var x0 = x0Start
    var y0 = y0Start
    val x1 = x1Start
    val y1 = y1Start

    var i0 = 0
    var i1 = 0
    val n = nodes.size
    while (i0 < n) {
      val dx = x1 - x0
      val dy = y1 - y0

      // Find the next non-empty node.
      var sumValue: Int
      do {
        sumValue = nodes[i1].value
        i1++
      } while (sumValue == 0 && i1 < n)
      var minValue = sumValue
      var maxValue = sumValue
      val alpha = max(dy / dx, dx / dy) / (value * ratio)
      // Square in Float: sumValue * sumValue in Int overflows as soon as sumValue exceeds 46340.
      var beta = sumValue.toFloat() * sumValue * alpha
      var minRatio = max(maxValue / beta, beta / minValue)

      // Keep adding nodes while the aspect ratio maintains or improves.
      while (i1 < n) {
        val nodeValue = nodes[i1].value
        sumValue += nodeValue
        if (nodeValue < minValue) minValue = nodeValue
        if (nodeValue > maxValue) maxValue = nodeValue
        beta = sumValue.toFloat() * sumValue * alpha
        val newRatio = max(maxValue / beta, beta / minValue)
        if (newRatio > minRatio) {
          sumValue -= nodeValue
          break
        }
        minRatio = newRatio
        i1++
      }

      // Position and record the row orientation.
      val row = Row(
        value = sumValue,
        children = nodes.slice(i0 until i1)
      )

      if (dx < dy) {
        // Tall space: the row spans the full width and consumes a horizontal band at the top.
        val initialY0 = y0
        val lastY = if (value > 0) {
          y0 += dy * sumValue / value
          y0
        } else {
          y1
        }
        treemapDice(row, x0, initialY0, x1, lastY)
      } else {
        // Wide space: the row spans the full height and consumes a vertical band on the left.
        val initialX0 = x0
        val lastX = if (value > 0) {
          x0 += dx * sumValue / value
          x0
        } else {
          x1
        }
        treemapSlice(row, initialX0, y0, lastX, y1)
      }
      value -= sumValue
      i0 = i1
    }
  }

  /**
   * Stacks [parent]'s nodes top to bottom within ([x0Start], [y0Start]) - ([x1Start], [y1Start]):
   * each node spans the full width and gets a height proportional to its value.
   */
  private fun treemapSlice(
    parent: Row,
    x0Start: Float,
    y0Start: Float,
    x1Start: Float,
    y1Start: Float
  ) {
    // Height per unit of value; 0 for an empty row, so every node collapses onto y0Start.
    val k = if (parent.value > 0) (y1Start - y0Start) / parent.value else 0f
    var y0 = y0Start
    for (node in parent.children) {
      node.x0 = x0Start
      node.x1 = x1Start
      node.y0 = y0
      y0 += node.value.toFloat() * k
      node.y1 = y0
    }
  }

  /**
   * Places [parent]'s nodes left to right within ([x0Start], [y0Start]) - ([x1Start], [y1Start]):
   * each node spans the full height and gets a width proportional to its value.
   */
  private fun treemapDice(
    parent: Row,
    x0Start: Float,
    y0Start: Float,
    x1Start: Float,
    y1Start: Float
  ) {
    // Width per unit of value; 0 for an empty row, so every node collapses onto x0Start.
    val k = if (parent.value > 0) (x1Start - x0Start) / parent.value else 0f
    var x0 = x0Start
    for (node in parent.children) {
      node.y0 = y0Start
      node.y1 = y1Start
      node.x0 = x0
      x0 += node.value.toFloat() * k
      node.x1 = x0
    }
  }

  companion object {
    /** The golden ratio, used as the default target aspect ratio of the squarified tiling. */
    val phi = (1 + sqrt(5f)) / 2
  }
}

/**
 * Invokes [callback] for this node and each of its descendants in depth-first pre-order: a node is
 * only visited after all of its ancestors have been visited, and siblings are visited in the order
 * of their parent's [NodeLayout.children]. [callback] receives the node being visited.
 *
 * Uses an explicit stack instead of recursion. Each child is cast to [N] without a check, which
 * assumes every descendant of an [N] is itself an [N] (holds for trees built by a single
 * [NodeLayout] implementation — confirm for mixed implementations).
 */
inline fun <T, N : NodeLayout<T>> N.depthFirstTraversal(callback: (N) -> Unit) {
  val stack = ArrayDeque<N>()
  stack.addLast(this)
  while (stack.isNotEmpty()) {
    val node = stack.removeLast()
    callback(node)
    val children = node.children
    // Push in reverse order (without copying the list) so that the first child is popped first.
    for (index in children.lastIndex downTo 0) {
      @Suppress("UNCHECKED_CAST")
      stack.addLast(children[index] as N)
    }
  }
}