diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractKeySet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractKeySet.kt new file mode 100644 index 0000000..bc56418 --- /dev/null +++ b/collect/src/main/kotlin/com/certora/collect/AbstractKeySet.kt @@ -0,0 +1,59 @@ +package com.certora.collect + +/** + Presents the keys of a [TreapMap] as a [TreapSet]. + + The idea here is that a `TreapMap` is stored with the same Treap structure as a `TreapSet`, so we can very + quickly create the corresponding `TreapSet` when needed, in O(n) time (as opposed to the naive O(n*log(n)) + method). + + We lazily initialize the set, so that we don't create it until we need it. For many operations, we can avoid + creating the set entirely, and just use the map directly. However, many operations, e.g. [addAll]/[union] and + [retainAll/intersect], are much more efficient when we have a [TreapSet], so we create it when needed. + */ +internal abstract class AbstractKeySet<@Treapable K, S : TreapSet> : TreapSet { + /** + The map whose keys we are presenting as a set. We prefer to use the map directly when possible, so we don't + need to create the set. + */ + abstract val map: AbstractTreapMap + /** + The set of keys. This is a lazy property so that we don't create the set until we need it. + */ + abstract val keys: Lazy + + @Suppress("Treapability") + override fun hashCode() = keys.value.hashCode() + override fun equals(other: Any?) = keys.value.equals(other) + override fun toString() = keys.value.toString() + + override val size get() = map.size + override fun isEmpty() = map.isEmpty() + override fun clear() = treapSetOf() + + override operator fun contains(element: K) = map.containsKey(element) + override operator fun iterator() = map.entrySequence().map { it.key }.iterator() + + override fun add(element: K) = keys.value.add(element) + override fun addAll(elements: Collection) = keys.value.addAll(elements) + override fun remove(element: K) = keys.value.remove(element) + override fun removeAll(elements: Collection) = keys.value.removeAll(elements) + override fun removeAll(predicate: (K) -> Boolean) = keys.value.removeAll(predicate) + override fun retainAll(elements: Collection) = keys.value.retainAll(elements) + + override fun single() = map.single().key + override fun singleOrNull() = map.singleOrNull()?.key + override fun arbitraryOrNull() = map.arbitraryOrNull()?.key + + override fun containsAny(elements: Iterable) = keys.value.containsAny(elements) + override fun containsAny(predicate: (K) -> Boolean) = (this as Iterable).any(predicate) + override fun containsAll(elements: Collection) = keys.value.containsAll(elements) + override fun findEqual(element: K) = keys.value.findEqual(element) + + override fun forEachElement(action: (K) -> Unit) = map.forEachEntry { action(it.key) } + + override fun mapReduce(map: (K) -> R, reduce: (R, R) -> R) = + this.map.mapReduce({ k, _ -> map(k) }, reduce) + override fun parallelMapReduce(map: (K) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int) = + this.map.parallelMapReduce({ k, _ -> map(k) }, reduce, parallelThresholdLog2) +} diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index d0a28c7..f3994d4 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -102,6 +102,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT return when { otherMap == null -> false otherMap === this -> true + otherMap.isEmpty() -> false // NB AbstractTreapMap always contains at least one entry else -> otherMap.useAsTreap( { otherTreap -> this.self.deepEquals(otherTreap) }, { other.size == this.size && other.entries.all { this.containsEntry(it) }} @@ -112,6 +113,9 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override val size: Int get() = computeSize() override fun isEmpty(): Boolean = false + // NB AbstractTreapMap always contains at least one entry + override fun single() = singleOrNull() ?: throw IllegalArgumentException("Map contains more than one entry") + override fun containsKey(key: K) = key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false @@ -140,14 +144,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override fun iterator() = entrySequence().iterator() } - override val keys: ImmutableSet - get() = object: AbstractSet(), ImmutableSet { - override val size get() = this@AbstractTreapMap.size - override fun isEmpty() = this@AbstractTreapMap.isEmpty() - override operator fun contains(element: K) = containsKey(element) - override operator fun iterator() = entrySequence().map { it.key }.iterator() - } - override val values: ImmutableCollection get() = object: AbstractCollection(), ImmutableCollection { override val size get() = this@AbstractTreapMap.size diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt index 4900c4e..856214e 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt @@ -56,8 +56,6 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> */ abstract fun shallowForEach(action: (element: E) -> Unit): Unit - abstract fun shallowGetSingleElement(): E? - abstract infix fun shallowUnion(that: S): S abstract infix fun shallowIntersect(that: S): S? abstract infix fun shallowDifference(that: S): S? @@ -85,6 +83,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> other == null -> false this === other -> true other !is Set<*> -> false + other.isEmpty() -> false // NB AbstractTreapSet always contains at least one element else -> (other as Set).useAsTreap( { otherTreap -> this.self.deepEquals(otherTreap) }, { this.size == other.size && this.containsAll(other) } @@ -136,26 +135,12 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> override fun findEqual(element: E): E? = element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element) - @Suppress("UNCHECKED_CAST") - override fun single(): E = getSingleElement() ?: when { - isEmpty() -> throw NoSuchElementException("Set is empty") - size > 1 -> throw IllegalArgumentException("Set has more than one element") - else -> null as E // The single element must have been null! - } - - override fun singleOrNull(): E? = getSingleElement() - override fun forEachElement(action: (element: E) -> Unit): Unit { left?.forEachElement(action) shallowForEach(action) right?.forEachElement(action) } - internal fun getSingleElement(): E? = when { - left === null && right === null -> shallowGetSingleElement() - else -> null - } - override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R = notForking(self) { mapReduceImpl(map, reduce) } @@ -186,7 +171,7 @@ internal infix fun <@Treapable E, S : AbstractTreapSet> S?.treapUnion(that this == null -> that that == null -> this this === that -> this - that.getSingleElement() != null -> add(that) + that.singleOrNull() != null -> add(that) else -> { // remember, a.comparePriorityTo(b)==0 <=> a.compareKeyTo(b)==0 val c = this.comparePriorityTo(that) diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index c33a07a..e092758 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -20,6 +20,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap = this override fun remove(key: K, value: V): TreapMap = this + override fun single(): Map.Entry = throw NoSuchElementException("Empty map.") + override fun singleOrNull(): Map.Entry? = null override fun arbitraryOrNull(): Map.Entry? = null override fun forEachEntry(action: (Map.Entry) -> Unit): Unit {} @@ -73,7 +75,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap> get() = persistentSetOf>() - override val keys: ImmutableSet get() = persistentSetOf() + override val keys: TreapSet get() = treapSetOf() override val values: ImmutableCollection get() = persistentSetOf() @Suppress("Treapability", "UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 59cc204..957f7db 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -32,6 +32,7 @@ internal class HashTreapMap<@Treapable K, V>( this as? HashTreapMap ?: (this as? PersistentMap.Builder)?.build() as? HashTreapMap + override fun singleOrNull() = MapEntry(key, value).takeIf { next == null && left == null && right == null } override fun arbitraryOrNull(): Map.Entry? = MapEntry(key, value) override fun getShallowMerger( @@ -358,6 +359,17 @@ internal class HashTreapMap<@Treapable K, V>( forEachPair { (k, v) -> action(MapEntry(k, v)) } right?.forEachEntry(action) } + + private fun treapSetFromKeys(): HashTreapSet = + HashTreapSet(treapKey, next?.toKeyList(), left?.treapSetFromKeys(), right?.treapSetFromKeys()) + + inner class KeySet : AbstractKeySet>() { + override val map get() = this@HashTreapMap + override val keys = lazy { treapSetFromKeys() } + override fun hashCode() = super.hashCode() // avoids treapability warning + } + + override val keys get() = KeySet() } internal interface KeyValuePairList { @@ -367,6 +379,8 @@ internal interface KeyValuePairList { operator fun component1() = key operator fun component2() = value + fun toKeyList(): ElementList.More = ElementList.More(key, next?.toKeyList()) + class More( override val key: K, override val value: V, diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index 987e503..e0ac2a7 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -25,6 +25,7 @@ internal class HashTreapSet<@Treapable E>( override fun Iterable.toTreapSetOrNull(): HashTreapSet? = (this as? HashTreapSet) ?: (this as? TreapSet.Builder)?.build() as? HashTreapSet + ?: (this as? HashTreapMap.KeySet)?.keys?.value private inline fun ElementList?.forEachNodeElement(action: (E) -> Unit) { var current = this @@ -228,8 +229,13 @@ internal class HashTreapSet<@Treapable E>( } }.iterator() - override fun shallowGetSingleElement(): E? = element.takeIf { next == null } - + override fun singleOrNull(): E? = element.takeIf { next == null && left == null && right == null } + override fun single(): E { + if (next != null || left != null || right != null) { + throw IllegalArgumentException("Set contains more than one element") + } + return element + } override fun arbitraryOrNull(): E? = element override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R { diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index 37459eb..c2ad6cc 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -32,6 +32,7 @@ internal class SortedTreapMap<@Treapable K, V>( this as? SortedTreapMap ?: (this as? PersistentMap.Builder)?.build() as? SortedTreapMap + override fun singleOrNull(): Map.Entry? = MapEntry(key, value).takeIf { left == null && right == null } override fun arbitraryOrNull(): Map.Entry? = MapEntry(key, value) override fun getShallowUnionMerger( @@ -170,4 +171,15 @@ internal class SortedTreapMap<@Treapable K, V>( action(this.asEntry()) right?.forEachEntry(action) } + + private fun treapSetFromKeys(): SortedTreapSet = + SortedTreapSet(treapKey, left?.treapSetFromKeys(), right?.treapSetFromKeys()) + + inner class KeySet : AbstractKeySet>() { + override val map get() = this@SortedTreapMap + override val keys = lazy { treapSetFromKeys() } + override fun hashCode() = super.hashCode() // avoids treapability warning + } + + override val keys get() = KeySet() } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index 5704f0a..401031f 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -27,6 +27,7 @@ internal class SortedTreapSet<@Treapable E>( override fun Iterable.toTreapSetOrNull(): SortedTreapSet? = (this as? SortedTreapSet) ?: (this as? PersistentSet.Builder)?.build() as? SortedTreapSet + ?: (this as? SortedTreapMap.KeySet)?.keys?.value override val self get() = this override fun iterator(): Iterator = this.asTreapSequence().map { it.treapKey }.iterator() @@ -49,7 +50,13 @@ internal class SortedTreapSet<@Treapable E>( override fun shallowRemove(element: E): SortedTreapSet? = null override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet? = this.takeIf { !predicate(treapKey) } override fun shallowComputeHashCode(): Int = treapKey.hashCode() - override fun shallowGetSingleElement(): E = treapKey + override fun singleOrNull(): E? = treapKey.takeIf { left == null && right == null } + override fun single(): E { + if (left != null || right != null) { + throw IllegalArgumentException("Set contains more than one element") + } + return treapKey + } override fun arbitraryOrNull(): E? = treapKey override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) } override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey) diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index ce67672..3d8c641 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -12,6 +12,7 @@ public sealed interface TreapMap : PersistentMap { override fun remove(key: K, value: @UnsafeVariance V): TreapMap override fun putAll(m: Map): TreapMap override fun clear(): TreapMap + override val keys: TreapSet /** A [PersistentMap.Builder] that produces a [TreapMap]. @@ -23,6 +24,16 @@ public sealed interface TreapMap : PersistentMap { @Suppress("Treapability") override fun builder(): Builder = TreapMapBuilder(this) + /** + If this map contains exactly one entry, returns that entry. Otherwise, throws. + */ + public fun single(): Map.Entry + + /** + If this map contains exactly one entry, returns that entry. Otherwise, returns null + */ + public fun singleOrNull(): Map.Entry? + /** Returns an arbitrary entry from the map, or null if the map is empty. */ diff --git a/collect/src/main/kotlin/com/certora/collect/TreapSet.kt b/collect/src/main/kotlin/com/certora/collect/TreapSet.kt index 7d8d643..ba94236 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapSet.kt @@ -40,7 +40,7 @@ public sealed interface TreapSet : PersistentSet { public fun containsAny(predicate: (T) -> Boolean): Boolean /** - If this set contains exactly one element, returns that element. Otherwise, throws [NoSuchElementException]. + If this set contains exactly one element, returns that element. Otherwise, throws. */ public fun single(): T