Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikael Vejdemo-Johansson committed May 8, 2024
1 parent 95a2c18 commit 447ea48
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 71 deletions.
29 changes: 13 additions & 16 deletions src/main/scala/org/appliedtopology/tda4j/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Chain[CellT <: Cell[CellT]: Ordering, CoefficientT: Fractional] private[td
val cmp = entries.ord
entries.headOption match {
case None => ()
case Some(_) => {
case Some(_) =>
val head = entries.dequeue
val cell = head._1
var acc = head._2
Expand All @@ -68,42 +68,39 @@ class Chain[CellT <: Cell[CellT]: Ordering, CoefficientT: Fractional] private[td
collapseHead()
else
entries.enqueue((cell, acc))
}
}
}

def collapseAll()(using fr: Fractional[CoefficientT]): Unit = {
def collapseAll()(using fr: Fractional[CoefficientT]): Unit =
entries = mutable.PriorityQueue.from(
entries
.groupMapReduce
(_._1) // group by cell
{(x) => x._2} // extract coefficient
.groupMapReduce(_._1) // group by cell
(x => x._2) // extract coefficient
(fr.plus) // sum the coefficient parts
.filter {(c,x) => x != fr.zero}
.filter((c, x) => x != fr.zero)
.iterator
.toSeq)
}
.toSeq
)

def isZero(): Boolean = {
collapseHead()
entries.isEmpty || (entries.head._2 == summon[Fractional[CoefficientT]].zero)
}
def items: Seq[(CellT,CoefficientT)] = entries.toSeq

def items: Seq[(CellT, CoefficientT)] = entries.toSeq

/** WARNING - this is potentially an expensive operation
*/
*/
override def equals(obj: Any): Boolean = obj match {
case other : Chain[CellT,CoefficientT] => {
case other: Chain[CellT, CoefficientT] =>
collapseAll()
other.collapseAll()
entries.iterator.toList.sorted(using entries.ord) == other.entries.iterator.toList.sorted(using other.entries.ord)
}
case _ => false
}

override def toString: String =
entries.iterator.map{(c,x) => s"${x.toString}${c.toString}"}.mkString(" + ")
entries.iterator.map((c, x) => s"${x.toString}${c.toString}").mkString(" + ")
}

object Chain {
Expand Down
52 changes: 30 additions & 22 deletions src/main/scala/org/appliedtopology/tda4j/Cube.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ case class DegenerateInterval(n: Int) extends ElementaryInterval {
}
case class FullInterval(n: Int) extends ElementaryInterval {
override def boundary[CoefficientT: Fractional]: Chain[ElementaryInterval, CoefficientT] =
ChainOps[ElementaryInterval,CoefficientT]().minus(Chain(DegenerateInterval(n+1)),Chain(DegenerateInterval(n)))
override def toString: String = s"[$n,${n+1}]"
ChainOps[ElementaryInterval, CoefficientT]().minus(Chain(DegenerateInterval(n + 1)), Chain(DegenerateInterval(n)))
override def toString: String = s"[$n,${n + 1}]"
}

import Ordering.Implicits.seqOrdering
import scala.annotation.tailrec
given Ordering[ElementaryInterval] = Ordering.by { (i) => i.n }
given Ordering[ElementaryCube] = Ordering.by { (c) => c.intervals }
given Ordering[ElementaryInterval] = Ordering.by(i => i.n)
given Ordering[ElementaryCube] = Ordering.by(c => c.intervals)

case class ElementaryCube(val intervals: List[ElementaryInterval]) extends Cell[ElementaryCube] {
override def boundary[CoefficientT: Fractional]: Chain[ElementaryCube, CoefficientT] = {
val chainOps = ChainOps[ElementaryCube, CoefficientT]()
import chainOps.{*,given}
import chainOps.{*, given}
val fr = summon[Fractional[CoefficientT]]
import math.Fractional.Implicits.infixFractionalOps

Expand All @@ -31,40 +31,48 @@ case class ElementaryCube(val intervals: List[ElementaryInterval]) extends Cell[

// Given stuff-already-processed and an I*P decomposition, figure out whether this I changes the accumulated
// sign, and produce the ∂I * P + sign pair
def process(left: List[ElementaryInterval], current: ElementaryInterval, right: List[ElementaryInterval]):
(CoefficientT, Chain[ElementaryCube,CoefficientT]) = current match {
def process(
left: List[ElementaryInterval],
current: ElementaryInterval,
right: List[ElementaryInterval]
): (CoefficientT, Chain[ElementaryCube, CoefficientT]) = current match {
case DegenerateInterval(n) => (fr.one, Chain())
case FullInterval(n) => (fr.negate(fr.one), Chain(
ElementaryCube(left ++ (DegenerateInterval(n+1) :: right)) -> fr.one,
ElementaryCube(left ++ (DegenerateInterval(n) :: right)) -> fr.negate(fr.one),
))
case FullInterval(n) =>
(
fr.negate(fr.one),
Chain(
ElementaryCube(left ++ (DegenerateInterval(n + 1) :: right)) -> fr.one,
ElementaryCube(left ++ (DegenerateInterval(n) :: right)) -> fr.negate(fr.one)
)
)
}

@tailrec
def boundaryOf(left: List[ElementaryInterval],
current: ElementaryInterval,
right: List[ElementaryInterval],
sign: CoefficientT,
acc: Chain[ElementaryCube, CoefficientT]): Chain[ElementaryCube,CoefficientT] = {
def boundaryOf(
left: List[ElementaryInterval],
current: ElementaryInterval,
right: List[ElementaryInterval],
sign: CoefficientT,
acc: Chain[ElementaryCube, CoefficientT]
): Chain[ElementaryCube, CoefficientT] = {
val (sgn, newchain) = process(left, current, right)
right match {
case Nil => {
case Nil =>
// process current, then return the resulting acc
acc + ((sign * sgn) newchain)
}
case c :: cs =>
// process current, then call with c as new current
boundaryOf(left.appended(current), c, cs, sign*sgn, acc+((sign*sgn)newchain))
boundaryOf(left.appended(current), c, cs, sign * sgn, acc + ((sign * sgn) newchain))
}
}
intervals match {
case Nil => Chain()
case (i::is) => boundaryOf(List.empty, i, is, fr.one, Chain())
case Nil => Chain()
case (i :: is) => boundaryOf(List.empty, i, is, fr.one, Chain())
}
}

def emb: Int = intervals.size
def dim: Int = intervals.count { (i) => i.isInstanceOf[FullInterval] }
def dim: Int = intervals.count(i => i.isInstanceOf[FullInterval])

infix def cubeProduct(left: ElementaryCube, right: ElementaryCube): ElementaryCube =
ElementaryCube(left.intervals ++ right.intervals)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/org/appliedtopology/tda4j/RingModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import scala.annotation.targetName
trait RingModule[T, R] {
rmod =>
def zero: T
def isZero(t: T): Boolean = (t == zero)
def isZero(t: T): Boolean = t == zero
def plus(x: T, y: T): T
def minus(x: T, y: T): T = plus(x, negate(y))
def negate(x: T): T = minus(zero, x)
def scale(x: R, y: T): T

extension (t: T) {
@targetName("add")
def +(rhs: T): T = plus(t, rhs)
Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/org/appliedtopology/tda4j/SimplexStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ trait SimplexStream[VertexT: Ordering, FiltrationT: Ordering]
}

object SimplexStream {
def from[VertexT:Ordering](stream: Seq[AbstractSimplex[VertexT]], metricSpace: FiniteMetricSpace[VertexT]):
SimplexStream[VertexT,Double] = new SimplexStream[VertexT,Double] {
def from[VertexT: Ordering](
stream: Seq[AbstractSimplex[VertexT]],
metricSpace: FiniteMetricSpace[VertexT]
): SimplexStream[VertexT, Double] = new SimplexStream[VertexT, Double] {

override def filtrationValue: PartialFunction[AbstractSimplex[VertexT], Double] =
FiniteMetricSpace.MaximumDistanceFiltrationValue[VertexT](metricSpace)(using summon[Ordering[VertexT]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ object LazyStratifiedVietorisRips {
.combinations(2)
.toVector
.filter { xys =>
val List(x,y) = xys; metricSpace.distance(x,y) <= maxFVal
val List(x, y) = xys; metricSpace.distance(x, y) <= maxFVal
}
.sortBy { xys =>
val List(x, y) = xys; metricSpace.distance(x, y)
Expand Down
8 changes: 5 additions & 3 deletions src/test/scala/org/appliedtopology/tda4j/APISpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ class APISpec extends mutable.Specification {
val lazyHomology = persistentHomology(
SimplexStream.from(
LazyVietorisRips(
metricSpace,
1.5, 4
), metricSpace
metricSpace,
1.5,
4
),
metricSpace
)
)

Expand Down
22 changes: 12 additions & 10 deletions src/test/scala/org/appliedtopology/tda4j/CubicalSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ import org.specs2.mutable

class CubicalSpec extends mutable.Specification with ScalaCheck {
"Testing our Cubical Set implementation" >> {
prop { (lst : List[(Int,Boolean)]) =>
val cube = ElementaryCube(lst.map { (j,b) => if(b) FullInterval(j) else DegenerateInterval(j) })
cube.boundary[Double] must beAnInstanceOf[Chain[ElementaryCube,Double]]
prop { (lst: List[(Int, Boolean)]) =>
val cube = ElementaryCube(lst.map((j, b) => if (b) FullInterval(j) else DegenerateInterval(j)))
cube.boundary[Double] must beAnInstanceOf[Chain[ElementaryCube, Double]]
}

val simpleCube = ElementaryCube(List(FullInterval(0),FullInterval(0)))
simpleCube.boundary[Double] must be_==(Chain[ElementaryCube,Double](
ElementaryCube(List(DegenerateInterval(1),FullInterval(0))) -> -1.0,
ElementaryCube(List(DegenerateInterval(0),FullInterval(0))) -> 1.0,
ElementaryCube(List(FullInterval(0),DegenerateInterval(1))) -> 1.0,
ElementaryCube(List(FullInterval(0),DegenerateInterval(0))) -> -1.0
))
val simpleCube = ElementaryCube(List(FullInterval(0), FullInterval(0)))
simpleCube.boundary[Double] must be_==(
Chain[ElementaryCube, Double](
ElementaryCube(List(DegenerateInterval(1), FullInterval(0))) -> -1.0,
ElementaryCube(List(DegenerateInterval(0), FullInterval(0))) -> 1.0,
ElementaryCube(List(FullInterval(0), DegenerateInterval(1))) -> 1.0,
ElementaryCube(List(FullInterval(0), DegenerateInterval(0))) -> -1.0
)
)
}
}
22 changes: 7 additions & 15 deletions src/test/scala/org/appliedtopology/tda4j/VietorisRipsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.appliedtopology.tda4j

import org.scalacheck.Gen
import org.scalacheck.Prop.forAll
import org.specs2.{ScalaCheck, Specification, mutable as s2mutable}
import org.specs2.{mutable as s2mutable, ScalaCheck, Specification}
import org.specs2.specification.core.Fragment
import org.specs2.execute.Result

Expand All @@ -11,11 +11,11 @@ import scala.collection.mutable
import scala.math.{cos, sin}
import scala.reflect.ClassTag

def matrixGen[T:ClassTag](g: Gen[T], dimension: Gen[Int], size: Gen[Int]): Gen[Array[Array[T]]] =
def matrixGen[T: ClassTag](g: Gen[T], dimension: Gen[Int], size: Gen[Int]): Gen[Array[Array[T]]] =
for
dim <- dimension
sz <- size
values <- Gen.listOfN(dim*sz, g)
values <- Gen.listOfN(dim * sz, g)
yield values.toArray.grouped(dim).toArray

class VietorisRipsSpec extends s2mutable.Specification with ScalaCheck {
Expand Down Expand Up @@ -75,7 +75,7 @@ class VietorisRipsSpec extends s2mutable.Specification with ScalaCheck {
CliqueFinder.simplexOrdering(metricSpace)
val lazyStream: LazyList[AbstractSimplex[Int]] =
LazyVietorisRips[Int](metricSpace, 0.4, maxD)
val strictStream: SimplexStream[Int,Double] =
val strictStream: SimplexStream[Int, Double] =
VietorisRips[Int](metricSpace, 0.4, maxD, ZomorodianIncremental[Int]())

val lazy100 = lazyStream.take(100).toList
Expand All @@ -85,10 +85,7 @@ class VietorisRipsSpec extends s2mutable.Specification with ScalaCheck {
}

s"Lazy Stratified Vietoris-Rips streams should" >> {
forAll(matrixGen[Double](
Gen.double,
Gen.chooseNum(1,10),
Gen.chooseNum(5,15))) { (pts : Array[Array[Double]]) =>
forAll(matrixGen[Double](Gen.double, Gen.chooseNum(1, 10), Gen.chooseNum(5, 15))) { (pts: Array[Array[Double]]) =>
val metricSpace: FiniteMetricSpace[Int] = EuclideanMetricSpace(pts.toSeq.map(_.toSeq))

given Ordering[AbstractSimplex[Int]] =
Expand All @@ -98,17 +95,12 @@ class VietorisRipsSpec extends s2mutable.Specification with ScalaCheck {
FiniteMetricSpace.MaximumDistanceFiltrationValue[Int](metricSpace)

val lazyStreams: Array[LazyList[AbstractSimplex[Int]]] =
LazyStratifiedVietorisRips(
metricSpace,
0.4,
2)
LazyStratifiedVietorisRips(metricSpace, 0.4, 2)

lazyStreams.toSeq must contain((lz:LazyList[AbstractSimplex[Int]]) =>
lazyStreams.toSeq must contain((lz: LazyList[AbstractSimplex[Int]]) =>
"must be in filtration order" ==>
(lz.iterator.map(filtrationValue).toSeq must beSorted)
)
}
}
}


0 comments on commit 447ea48

Please sign in to comment.