Skip to content

Commit

Permalink
Merge pull request #1163 from typelevel/topic/lru-cache
Browse files Browse the repository at this point in the history
Replace SemispaceCache with LRU Cache
  • Loading branch information
mpilquist authored Jan 16, 2025
2 parents b314c69 + 68ed3ae commit 1884e71
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 242 deletions.
24 changes: 12 additions & 12 deletions modules/core/shared/src/main/scala/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ object Session {
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Resource[F, Resource[F, Session[F]]] = {
Expand Down Expand Up @@ -470,9 +470,9 @@ object Session {
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Resource[F, Tracer[F] => Resource[F, Session[F]]] = {
Expand Down Expand Up @@ -508,9 +508,9 @@ object Session {
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Resource[F, Session[F]] =
Expand All @@ -532,9 +532,9 @@ object Session {
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Tracer[F] => Resource[F, Session[F]] =
Expand Down
96 changes: 96 additions & 0 deletions modules/core/shared/src/main/scala/data/Cache.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

/**
 * Immutable cache with least-recently-used eviction.
 *
 * Values live in the `entries` map. Recency is tracked separately in `usages`,
 * a bidirectional map between a numeric stamp and a key, ordered by stamp.
 * `entries` and `usages` always hold the same number of elements.
 *
 * Every successful `get` and every `put` (when `max` is positive) assigns the
 * touched key the current `stamp` and hands the next cache `stamp + 1`, so
 * stamps only grow and are never reused. Consequently the first element of
 * `usages` always names the least recently used key.
 */
sealed abstract case class Cache[K, V] private (
  max: Int,
  entries: Map[K, V]
)(usages: SortedBiMap[Long, K],
  stamp: Long
) {
  // the recency index must describe exactly the keys that are stored
  assert(entries.size == usages.size)

  /** Number of entries currently held. */
  def size: Int = entries.size

  /** Whether a value is stored for `k`; does not count as an access. */
  def contains(k: K): Boolean = entries.contains(k)

  /**
   * Looks up the value stored for `k`.
   *
   * A hit counts as an access, so the result carries a new cache in which `k`
   * is the most recently used key, paired with the value. A miss yields `None`.
   */
  def get(k: K): Option[(Cache[K, V], V)] =
    entries.get(k).map { v =>
      // re-stamping k replaces its previous stamp inside usages
      val touched = Cache(max, entries, usages + (stamp -> k), stamp + 1)
      (touched, v)
    }

  /**
   * Stores `v` under `k`, returning the resulting cache together with the
   * pairing that was pushed out, if any.
   *
   * A pairing is reported as evicted in these situations:
   *  - `max` is not positive: nothing can be stored, so `(k, v)` itself is
   *    returned and the cache is unchanged
   *  - `k` is already present with a different value: the old pairing is
   *    returned
   *  - `k` is new and the cache has reached its max size: the least recently
   *    used pairing is returned
   *
   * Note: if the cache already contains `(k, v)`, calling `put(k, v)` only
   * refreshes recency and reports no eviction.
   */
  def put(k: K, v: V): (Cache[K, V], Option[(K, V)]) =
    if (max <= 0)
      (this, Some(k -> v))
    else if (contains(k) || entries.size < max) {
      // the cache does not grow past max here; only an overwritten,
      // different value counts as evicted
      val displaced = entries.get(k).collect { case old if old != v => (k, old) }
      val updated = Cache(max, entries + (k -> v), usages + (stamp -> k), stamp + 1)
      (updated, displaced)
    } else {
      // full and k is a new key: the entry with the smallest stamp has to go
      val (oldestStamp, oldestKey) = usages.head
      val victim = oldestKey -> entries(oldestKey)
      val updated =
        Cache(max, entries - oldestKey + (k -> v), usages - oldestStamp + (stamp -> k), stamp + 1)
      (updated, Some(victim))
    }

  /** All stored values, in the iteration order of the underlying map. */
  def values: Iterable[V] = entries.values

  /** Renders the entries from least to most recently used. */
  override def toString: String = {
    val rendered = usages.entries.valuesIterator.map(key => s"$key -> ${entries(key)}")
    rendered.mkString("Cache(", ", ", ")")
  }
}

object Cache {
  // Sole construction path; private so every instance originates from `empty`
  // or from the get/put methods of an existing cache.
  private def apply[K, V](max: Int, entries: Map[K, V], usages: SortedBiMap[Long, K], stamp: Long): Cache[K, V] =
    new Cache[K, V](max, entries)(usages, stamp) {}

  /**
   * Creates a cache with no entries that holds at most `max` of them.
   * A negative `max` is clamped to 0.
   */
  def empty[K, V](max: Int): Cache[K, V] = {
    val capacity = math.max(max, 0)
    apply(capacity, Map.empty[K, V], SortedBiMap.empty[Long, K], 0L)
  }
}


83 changes: 0 additions & 83 deletions modules/core/shared/src/main/scala/data/SemispaceCache.scala

This file was deleted.

48 changes: 48 additions & 0 deletions modules/core/shared/src/main/scala/data/SortedBiMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

import scala.collection.immutable.SortedMap
import scala.math.Ordering

/**
 * Immutable bi-directional map that is sorted by key.
 *
 * `entries` is the forward (key to value) direction, ordered by key, and
 * `inverse` is the value to key direction. Both always have the same size,
 * i.e. each key is paired with exactly one value and vice versa.
 */
sealed abstract case class SortedBiMap[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]) {
  // both directions must describe the same set of pairings
  assert(entries.size == inverse.size)

  /** Number of pairings. */
  def size: Int = entries.size

  /** The pairing with the smallest key; throws if the map is empty. */
  def head: (K, V) = entries.head

  /** Value paired with key `k`, if any. */
  def get(k: K): Option[V] = entries.get(k)

  /**
   * Pairs `k` with `v`, discarding any pairing that either of them was
   * previously part of:
   *  - SortedBiMap(k0 -> v, v -> k0).put(k1, v) == SortedBiMap(k1 -> v, v -> k1)
   *  - SortedBiMap(k -> v0, v0 -> k).put(k, v1) == SortedBiMap(k -> v1, v1 -> k)
   */
  def put(k: K, v: V): SortedBiMap[K, V] = {
    // drop the key that currently owns v, so v ends up owned by k only
    val prunedForward = inverse.get(v) match {
      case Some(staleKey) => entries - staleKey
      case None           => entries
    }
    // drop the value currently owned by k, so k ends up pointing at v only
    val prunedInverse = entries.get(k) match {
      case Some(staleValue) => inverse - staleValue
      case None             => inverse
    }
    SortedBiMap(prunedForward + (k -> v), prunedInverse + (v -> k))
  }

  /** Operator alias for `put`. */
  def +(kv: (K, V)): SortedBiMap[K, V] = {
    val (k, v) = kv
    put(k, v)
  }

  /** Removes the pairing for key `k`; returns this same map when `k` is absent. */
  def -(k: K): SortedBiMap[K, V] =
    entries.get(k).map(v => SortedBiMap(entries - k, inverse - v)).getOrElse(this)

  /** Key paired with value `v`, if any. */
  def inverseGet(v: V): Option[K] = inverse.get(v)

  /** Renders the pairings in key order. */
  override def toString: String = {
    val rendered = entries.iterator.map(pair => s"${pair._1} <-> ${pair._2}")
    rendered.mkString("SortedBiMap(", ", ", ")")
  }
}

object SortedBiMap {
  // Sole construction path; private so callers cannot supply two maps that
  // disagree with each other.
  private def apply[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]): SortedBiMap[K, V] =
    new SortedBiMap[K, V](entries, inverse) {}

  /** A map with no pairings. */
  def empty[K: Ordering, V]: SortedBiMap[K, V] =
    apply(SortedMap.empty[K, V], Map.empty[V, K])
}

35 changes: 23 additions & 12 deletions modules/core/shared/src/main/scala/util/StatementCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import cats.{ Functor, ~> }
import cats.syntax.all._
import skunk.Statement
import cats.effect.kernel.Ref
import skunk.data.SemispaceCache
import skunk.data.Cache

/** An LRU (by access) cache, keyed by statement `CacheKey`. */
sealed trait StatementCache[F[_], V] { outer =>
Expand All @@ -35,31 +35,42 @@ sealed trait StatementCache[F[_], V] { outer =>
object StatementCache {

def empty[F[_]: Functor: Ref.Make, V](max: Int, trackEviction: Boolean): F[StatementCache[F, V]] =
Ref[F].of(SemispaceCache.empty[Statement.CacheKey, V](max, trackEviction)).map { ref =>
// State is the cache and a set of evicted values; the evicted set only grows when trackEviction is true
Ref[F].of((Cache.empty[Statement.CacheKey, V](max), Set.empty[V])).map { ref =>
new StatementCache[F, V] {

def get(k: Statement[_]): F[Option[V]] =
ref.modify { c =>
c.lookup(k.cacheKey) match {
case Some((cʹ, v)) => (cʹ, Some(v))
case None => (c, None)
ref.modify { case (c, evicted) =>
c.get(k.cacheKey) match {
case Some((cʹ, v)) => (cʹ -> evicted, Some(v))
case None => (c -> evicted, None)
}
}

def put(k: Statement[_], v: V): F[Unit] =
ref.update(_.insert(k.cacheKey, v))
ref.update { case (c, evicted) =>
val (c2, e) = c.put(k.cacheKey, v)
// Remove the value we just inserted from the evicted set and add the newly evicted value, if any
val evicted2 = e.filter(_ => trackEviction).fold(evicted - v) { case (_, ev) => evicted - v + ev }
(c2, evicted2)
}

def containsKey(k: Statement[_]): F[Boolean] =
ref.get.map(_.containsKey(k.cacheKey))
ref.get.map(_._1.contains(k.cacheKey))

def clear: F[Unit] =
ref.update(_.evictAll)
ref.update { case (c, evicted) =>
val evicted2 = if (trackEviction) evicted ++ c.values else evicted
(Cache.empty[Statement.CacheKey, V](max), evicted2)
}

def values: F[List[V]] =
ref.get.map(_.values)
ref.get.map(_._1.values.toList)

def clearEvicted: F[List[V]] =
ref.modify(_.clearEvicted)
def clearEvicted: F[List[V]] =
ref.modify { case (c, evicted) =>
(c, Set.empty[V]) -> evicted.toList
}
}
}
}
26 changes: 14 additions & 12 deletions modules/tests/shared/src/test/scala/PrepareCacheTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import skunk.implicits._
import skunk.codec.numeric.int8
import skunk.codec.text
import skunk.codec.boolean
import cats.syntax.all.*
import cats.syntax.all._

class PrepareCacheTest extends SkunkTest {

Expand All @@ -17,16 +17,8 @@ class PrepareCacheTest extends SkunkTest {
private val pgStatementsCountByStatement = sql"select count(*) from pg_prepared_statements where statement = ${text.text}".query(int8)
private val pgStatementsCount = sql"select count(*) from pg_prepared_statements".query(int8)
private val pgStatements = sql"select statement from pg_prepared_statements order by prepare_time".query(text.text)

pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p =>
List.fill(4)(
p.use { s =>
s.execute(pgStatementsByName)("foo").void >> s.execute(pgStatementsByStatement)("bar").void >> s.execute(pgStatementsCountByStatement)("baz").void
}
).sequence
}

pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 1) { p =>

pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p =>
p.use { s =>
s.execute(pgStatementsByName)("foo").void >>
s.execute(pgStatementsByStatement)("bar").void >>
Expand All @@ -49,7 +41,7 @@ class PrepareCacheTest extends SkunkTest {
}
}

pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 1) { p =>
pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 2) { p =>
p.use { s =>
// creates entry in cache
s.prepare(pgStatementsByName)
Expand Down Expand Up @@ -97,4 +89,14 @@ class PrepareCacheTest extends SkunkTest {
}
}
}

pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 4) { p =>
List.fill(8)(
p.use { s =>
s.execute(pgStatementsByName)("foo").void >>
s.execute(pgStatementsByStatement)("bar").void >>
s.execute(pgStatementsCountByStatement)("baz").void
}
).sequence
}
}
Loading

0 comments on commit 1884e71

Please sign in to comment.