Skip to content

Commit

Permalink
Expanded on the context-driven paradigm. Dramatically slimmed down im…
Browse files Browse the repository at this point in the history
…plementation of `Chain` without any test regressions.
  • Loading branch information
Mikael Vejdemo-Johansson committed Feb 23, 2024
1 parent d57f975 commit 9ede311
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 84 deletions.
98 changes: 28 additions & 70 deletions src/main/scala/org/appliedtopology/tda4j/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,77 +21,10 @@ import scala.annotation.targetName
* @param chainMap
* Internal storage of the sorted map of the elements
*/
class Chain[CellT <: Cell[CellT]: Ordering, CoefficientT]
class Chain[CellT <: Cell[CellT]: Ordering, CoefficientT : Fractional]
/** chainMap is an immutable variable and constructor that uses Scala's SortedMap to make a key-value pairing of an CellT as the key and a
* CoefficientT type as the value. Here, we'll use the Using keyword to check for any relevant types for CoefficientT.
*/ (val chainMap: SortedMap[CellT, CoefficientT])(using
fr: Fractional[CoefficientT]
) {

/** Negate takes in a Chain instance and outputs a new Chain instance which
* takes in a chainMap (remember chainMap is using SortedMap 'under the
* hood', thus already has a key/value pair). transform() is used on this
* function's instance of chainMap to transform the key/value pairings in
* chainMap to the negative versions of the values in the value pairs.
* @return
* new Chain of negated values in key/value pairs
*/
def negate: Chain[CellT, CoefficientT] = new Chain(
chainMap.transform((k, v) => fr.negate(v))
)

/** Unary uses negate() for unary negation
*
* @return
* negation
*/
def unary_- : Chain[CellT, CoefficientT] = negate

/** scalarMultiply returns a new chain containing chainMap and its key/value
* pairing. In the chainMap, transform() is used on the value of the
* key/value pairing. On the value, it is transformed by multiplying each key
* by the CoefficientT, using the times() method of the Fractional trait,
* extending the Numeric library.
* @param c:
* method instance of CoefficientT
* @return
* new Chain with values in key/value pairing multiplied by c
*/

def scalarMultiply(c: CoefficientT): Chain[CellT, CoefficientT] =
new Chain(chainMap.transform((k, v) => fr.times(v, c)))

def *: : CoefficientT => Chain[CellT, CoefficientT] = scalarMultiply

/** add() adds the method instance of the keys of a chainMap to the classes
* chainMap keys. It then adds the result of this to a map which maps that
* key to the values of the class and the instance method's chainMap.
* @param that:
* Chain object composed of a Cell/Coefficient pair
* @return
* Chain object composed of a Cell/Coefficient pair
*/
def add(that: Chain[CellT, CoefficientT]): Chain[CellT, CoefficientT] =
Chain(
(chainMap.keySet | that.chainMap.keySet)
.map(k =>
(
k,
fr.plus(
chainMap.getOrElse(k, fr.zero),
that.chainMap.getOrElse(k, fr.zero)
)
)
)
.toSeq*
)

def + : Chain[CellT, CoefficientT] => Chain[CellT, CoefficientT] = add

def subtract(that: Chain[CellT, CoefficientT]): Chain[CellT, CoefficientT] =
this + (-that)

def - : Chain[CellT, CoefficientT] => Chain[CellT, CoefficientT] = subtract
*/ (val chainMap: SortedMap[CellT, CoefficientT]) {

override def toString: String =
chainMap.map((k, v) => s"${v.toString} *: ${k.toString}").mkString(" + ")
Expand Down Expand Up @@ -142,12 +75,37 @@ object Chain {
// Original apply innards
// new Chain[CellT, CoefficientT](SortedMap.from(items))

def apply[CellT <: Cell[CellT]: Ordering, CoefficientT](cell: CellT)(using
def apply[CellT <: Cell[CellT] : Ordering, CoefficientT](cell: CellT)(using
fr: Fractional[CoefficientT]
): Chain[CellT, CoefficientT] =
new Chain[CellT, CoefficientT](SortedMap.from(List(cell -> fr.one)))
}


class ChainOps[CellT <: Cell[CellT] : Ordering, CoefficientT : Fractional] extends RingModule[Chain[CellT, CoefficientT], CoefficientT] {
import Numeric.Implicits._

val zero: org.appliedtopology.tda4j.Chain[CellT, CoefficientT] = Chain()

def plus(x: Chain[CellT, CoefficientT], y: Chain[CellT, CoefficientT]): Chain[CellT, CoefficientT] =
Chain(
(for k <- (x.chainMap.keySet | y.chainMap.keySet)
yield {
val fr = summon[Fractional[CoefficientT]]
val xv : CoefficientT = x.chainMap.getOrElse(k, fr.zero)
val yv : CoefficientT = y.chainMap.getOrElse(k, fr.zero)
k -> fr.plus(xv,yv)
}).toSeq : _*
)

def scale(x: CoefficientT, y: Chain[CellT, CoefficientT]): Chain[CellT, CoefficientT] =
new Chain(y.chainMap.transform((k,v) => x*v))

override def negate(x: Chain[CellT, CoefficientT]): Chain[CellT, CoefficientT] =
new Chain(x.chainMap.transform((k, v) => -v))
}


/** Lightweight trait to define what it means to be a topological "Cell".
*
* Using F-bounded types to ensure reflective typing. See
Expand Down
36 changes: 36 additions & 0 deletions src/main/scala/org/appliedtopology/tda4j/RingModule.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.appliedtopology.tda4j

import scala.annotation.targetName

/** Specifies what it means for the type `T` to be a module (or vector space)
* over the [Numeric] (ie ring-like) type `R`.
*
* @tparam T Type of the module elements.
* @tparam R Type of the ring coefficients
*/
trait RingModule[T, R] {
val zero: T
def plus(x: T, y: T): T
def minus(x: T, y: T): T = plus(x, negate(y))
def negate(x: T): T = minus(zero, x)
def scale(x: R, y: T): T

extension (t: T) {
@targetName("add")
def +(rhs: T): T = plus(t, rhs)
@targetName("subtract")
def -(rhs: T): T = minus(t, rhs)
@targetName("scalarMultiplyRight")
def <*(rhs: R): T = scale(rhs, t)
def unary_- = negate(t)
}

extension (r: R) {
@targetName("scalarMultiplyLeft")
def *>(t: T): T = scale(r, t)
}
}

object RingModule:
def apply[T,R](using rm: RingModule[T,R]) = rm

35 changes: 30 additions & 5 deletions src/main/scala/org/appliedtopology/tda4j/Simplex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,30 @@ import scala.math.Ordering.IntOrdering
import scala.math.Ordering.Double.IeeeOrdering
import math.Ordering.Implicits.sortedSetOrdering

trait SimplexContext[VertexT : Ordering] {
type Simplex = AbstractSimplex[VertexT]

given Ordering[Simplex] = sortedSetOrdering[AbstractSimplex, VertexT]

object Simplex {
def apply(vertices: VertexT*): Simplex = AbstractSimplex[VertexT](vertices : _*)
def empty: Simplex = AbstractSimplex.empty
def from(iterableOnce: IterableOnce[VertexT]): Simplex =
AbstractSimplex.from(iterableOnce)
def newBuilder: mutable.Builder[VertexT, Simplex] = AbstractSimplex.newBuilder
}

object s {
def apply(vertices: VertexT*): Simplex = Simplex(vertices: _*)
}

extension (s: Simplex) {
def className = "Simplex"
}
}



/** Type alias creating `Simplex` as the type representing
* `AbstractSimplex[Int]`
*
Expand All @@ -25,11 +49,11 @@ import math.Ordering.Implicits.sortedSetOrdering
* and override everything to get the type to print out at `Simplex`
* everywhere instead of as the more verbose `AbstractSimplex`
*/
type Simplex = AbstractSimplex[Int]

object Simplex {
def apply(vertices: Int*) = new Simplex(TreeSet[Int](vertices: _*))
}
//type Simplex = AbstractSimplex[Int]
//
//object Simplex {
// def apply(vertices: Int*) = new Simplex(TreeSet[Int](vertices: _*))
//}

/** Class representing an abstract simplex. Abstract simplices are sets (of
* totally ordered vertices) and thus are represented by a type that implements
Expand Down Expand Up @@ -139,3 +163,4 @@ object AbstractSimplex extends SortedIterableFactory[AbstractSimplex] {
}

}

8 changes: 4 additions & 4 deletions src/main/scala/org/appliedtopology/tda4j/VietorisRips.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
package org.appliedtopology.tda4j

import org.appliedtopology.tda4j.FiniteMetricSpace.MaximumDistanceFiltrationValue

import scala.collection.immutable.{LazyList, SortedSet}
import scala.math.Ordering.Implicits.*
import Simplex.*
import org.appliedtopology.tda4j.FiniteMetricSpace.MaximumDistanceFiltrationValue
import scalax.collection.{edge, mutable as gmutable, Graph}
import scalax.collection.{Graph, edge, mutable as gmutable}
import scalax.collection.edge.Implicits.*
import scalax.collection.edge.WUnDiEdge

import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Sorting
import scala.util.control.*
import scala.util.chaining._
import scala.util.chaining.*

/** Convenience definition to allow us to choose a specific implementation.
*
Expand Down
9 changes: 8 additions & 1 deletion src/main/scala/org/appliedtopology/tda4j/package.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package org.appliedtopology

import math.Ordering.Implicits.sortedSetOrdering

/** Package for the Scala library TDA4j
*/
package object tda4j {}
package object tda4j {
class TDAContext[VertexT : Ordering, CoefficientT : Fractional]
extends ChainOps[AbstractSimplex[VertexT], CoefficientT], SimplexContext[VertexT] {
given Conversion[Simplex, Chain[Simplex,CoefficientT]] = Chain.apply
}
}
19 changes: 19 additions & 0 deletions src/test/scala/org/appliedtopology/tda4j/APISpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.appliedtopology.tda4j

import org.specs2.mutable

class APISpec extends mutable.Specification {
"""Test case class for developing the non-Scala facing API functionality
|and the non-expert API functionality""".stripMargin

given ctx : TDAContext[Char, Double]()
import ctx.{*,given}

"we should be able to create and compute with chains" >> {
1.0 *> s(1,2) - s(2,3) should beEqualTo(
Chain(
Simplex(1,2) -> 1.0,
Simplex(2,3) -> -1.0)
)
}
}
39 changes: 37 additions & 2 deletions src/test/scala/org/appliedtopology/tda4j/ChainSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@ class ChainSpec extends mutable.Specification {
"""This is the specification for testing the Chain implementation.
|""".stripMargin.txt

given Conversion[AbstractSimplex[Int], Chain[AbstractSimplex[Int], Double]] =
given sc: SimplexContext[Int]()
import sc.*

given Conversion[Simplex, Chain[Simplex, Double]] =
Chain.apply

given Fractional[Double] = math.Numeric.DoubleIsFractional

given Ordering[Int] = math.Ordering.Int

given rm : RingModule[Chain[Simplex, Double], Double] = ChainOps()
import rm.*

"The `Chain` type should" >> {
val z1 = Chain(Simplex(1, 2, 3))
val z2 =
Expand Down Expand Up @@ -61,7 +67,7 @@ class ChainSpec extends mutable.Specification {
def e1 = {
val chain = z1
val expectedResult = Chain(Simplex(1, 2, 3))
val result = 2 *: chain
val result = 2 *> chain

result must beEqualTo(expectedResult)
}
Expand Down Expand Up @@ -116,4 +122,33 @@ class ChainSpec extends mutable.Specification {

}

"The `Chain` type should be comfortable to write expressions with" >> {
val z1 = Chain(s(1,2,3))
val z2 : Chain[Simplex,Double] = s(1,2) - s(1,3) + s(2,3)
val z3 = 1.0 *> s(1, 2, 5)
val z4 : Chain[Simplex,Double] = Simplex(1, 4, 8) <* 1.0
val z5 : Chain[Simplex,Double] = -s(1,2) + s(1,4) - s(2,3)
val z6 : Chain[Simplex,Double] = s(1,2) + s(2,3) - s(1,3)
val z7 : Chain[Simplex,Double] = s(1,2) - s(1,3) + s(2,3) + 0.0 *> s(3,4)

z1 must haveClass[Chain[Simplex, Double]]
z2 must beEqualTo(z6)
z2 must beEqualTo(z7)

val expectedResult1 = Chain(
Simplex(1, 2, 3) -> 1.0,
Simplex(1, 2) -> 1.0,
Simplex(1, 3) -> -1.0,
Simplex(2, 3) -> 1.0
)
val expectedResult2 = Chain(
Simplex(1, 2) -> 0.0,
Simplex(1, 3) -> 0.0,
Simplex(1, 4) -> 0.0,
Simplex(2, 3) -> 0.0
)

z1 + z2 must beEqualTo(s(1,2,3) + s(1,2) - s(1,3) + s(2,3))
z2 - z6 must beEqualTo(summon[RingModule[Chain[Simplex,Double], Double]].zero)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class FiniteFieldSpec extends mutable.Specification {

val rand = Random
val x = Fp(rand.between(-100, 100))
var y = Fp(rand.between(-100, 100))
var y = Fp(rand.between(1, 100))
val z = Fp(rand.between(-100, 100))
if y == Fp(0) then y = Fp(rand.between(1, 100))
//if y == Fp(0) then y = Fp(rand.between(1, 100))
"all operations stay within -p/2, p/2" >> {
// noinspection ScalaRedundantConversion
eg((x * y).toInt must beBetween(-8, 8))
Expand Down
37 changes: 37 additions & 0 deletions src/test/scala/org/appliedtopology/tda4j/RingModuleSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.appliedtopology.tda4j
import org.specs2.mutable

class RingModuleSpec extends mutable.Specification {
"""This is a test module for developing the RingModule[M,R] interface
|and make sure that it does what we need it to do.
|""".stripMargin

object rm extends RingModule[(Int,Int), Int] {
val zero: (Int, Int) = (0, 0)

def plus(ls: (Int, Int), rs: (Int, Int)): (Int, Int) = (ls._1 + rs._1, ls._2 + rs._2)

override def negate(x: (Int, Int)): (Int, Int) = (-x._1, -x._2)

def scale(x: Int, y: (Int, Int)): (Int, Int) = (x * y._1, x * y._2)
}

given RingModule[(Int,Int), Int] = rm

"zero should exist" >> {
val v : (Int,Int) = rm.zero
v must be_==(rm.zero)
}

"addition should work" >> {
((2,3) + (4,5)) should be_== (6,8)
}

"scalar multiplication should work" >> {
(2,3) <* 4 should be_== (8,12)
}

"scalar left-multiplication should work" >> {
(4 *> (2,3)) should be_== (8,12)
}
}
Loading

0 comments on commit 9ede311

Please sign in to comment.