Skip to content

Commit

Permalink
fix: check current session's pending-write queue when recalling snaps…
Browse files Browse the repository at this point in the history
…hots (e.g. diffing) (#927)

* fix: check current session's pending-write queue when recalling snapshots (e.g. diffing)

* Make PyTestLocation hashable

* Explicitly set methodname to None for doctests

----------------------------------------------------------------------------------- benchmark: 3 tests -----------------------------------------------------------------------------------
Name (time in ms)          Min                 Max                Mean             StdDev              Median                IQR            Outliers     OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_1000x_reads      666.9710 (1.0)      748.6652 (1.0)      705.2418 (1.0)      37.2862 (1.0)      703.0552 (1.0)      70.1912 (1.07)          2;0  1.4180 (1.0)           5           1
test_standard         669.7840 (1.00)     843.3747 (1.13)     733.8905 (1.04)     68.2257 (1.83)     705.8282 (1.00)     85.6269 (1.30)          1;0  1.3626 (0.96)          5           1
test_1000x_writes     793.8229 (1.19)     937.1953 (1.25)     850.9716 (1.21)     54.4067 (1.46)     847.3260 (1.21)     65.9041 (1.0)           2;0  1.1751 (0.83)          5           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

* Queue writes with a dict for O(1) look-ups

Name (time in ms)          Min                   Max                Mean              StdDev              Median                 IQR            Outliers     OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_1000x_reads      625.5781 (1.0)        887.4346 (1.0)      694.6221 (1.0)      109.0048 (1.0)      658.3128 (1.0)       87.7517 (1.0)           1;1  1.4396 (1.0)           5           1
test_1000x_writes     637.3099 (1.02)     1,021.0924 (1.15)     812.9789 (1.17)     150.2342 (1.38)     757.7635 (1.15)     215.9572 (2.46)          2;0  1.2300 (0.85)          5           1
test_standard         694.1814 (1.11)     1,037.9224 (1.17)     845.1463 (1.22)     136.2068 (1.25)     785.6973 (1.19)     194.9636 (2.22)          2;0  1.1832 (0.82)          5           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

* Use type aliases

* return both keys from _snapshot_write_queue_key

* Use a defaultdict

* Update comments
  • Loading branch information
huonw authored Jan 13, 2025
1 parent ef8189c commit 0f6bb55
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 27 deletions.
6 changes: 1 addition & 5 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,7 @@ def _recall_data(
) -> Tuple[Optional["SerializableData"], bool]:
try:
return (
self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
),
self.session.recall_snapshot(self.extension, self.test_location, index),
False,
)
except SnapshotDoesNotExist:
Expand Down
35 changes: 25 additions & 10 deletions src/syrupy/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from syrupy.constants import PYTEST_NODE_SEP


@dataclass
@dataclass(frozen=True)
class PyTestLocation:
item: "pytest.Item"
nodename: Optional[str] = field(init=False)
Expand All @@ -23,27 +23,42 @@ class PyTestLocation:
filepath: str = field(init=False)

def __post_init__(self) -> None:
# NB. we're in a frozen dataclass, but need to transform the values that the caller
# supplied... we do so by (ab)using object.__setattr__ to forcibly set the attributes. (See
# rejected PEP-0712 for an example of a better way to handle this.)
#
# This is safe because this all happens during initialization: `self` hasn't been hashed
# (or, e.g., stored in a dict), so the mutation won't be noticed.
if self.is_doctest:
return self.__attrs_post_init_doc__()
self.__attrs_post_init_def__()

def __attrs_post_init_def__(self) -> None:
node_path: Path = getattr(self.item, "path") # noqa: B009
self.filepath = str(node_path.absolute())
# See __post_init__ for discussion of object.__setattr__
object.__setattr__(self, "filepath", str(node_path.absolute()))
obj = getattr(self.item, "obj") # noqa: B009
self.modulename = obj.__module__
self.methodname = obj.__name__
self.nodename = getattr(self.item, "name", None)
self.testname = self.nodename or self.methodname
object.__setattr__(self, "modulename", obj.__module__)
object.__setattr__(self, "methodname", obj.__name__)
object.__setattr__(self, "nodename", getattr(self.item, "name", None))
object.__setattr__(self, "testname", self.nodename or self.methodname)

def __attrs_post_init_doc__(self) -> None:
doctest = getattr(self.item, "dtest") # noqa: B009
self.filepath = doctest.filename
# See __post_init__ for discussion of object.__setattr__
object.__setattr__(self, "filepath", doctest.filename)
test_relfile, test_node = self.nodeid.split(PYTEST_NODE_SEP)
test_relpath = Path(test_relfile)
self.modulename = ".".join([*test_relpath.parent.parts, test_relpath.stem])
self.nodename = test_node.replace(f"{self.modulename}.", "")
self.testname = self.nodename or self.methodname
object.__setattr__(
self,
"modulename",
".".join([*test_relpath.parent.parts, test_relpath.stem]),
)
object.__setattr__(self, "methodname", None)
object.__setattr__(
self, "nodename", test_node.replace(f"{self.modulename}.", "")
)
object.__setattr__(self, "testname", self.nodename or self.methodname)

@property
def classname(self) -> Optional[str]:
Expand Down
67 changes: 55 additions & 12 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class ItemStatus(Enum):
SKIPPED = "skipped"


# Outer key of the pending-write queue: the extension class paired with the
# snapshot location string obtained from `extension.get_location(...)`.
_QueuedWriteExtensionKey = Tuple[Type["AbstractSyrupyExtension"], str]
# Inner key of the pending-write queue: the test location paired with the
# snapshot index, identifying one snapshot within a given outer key.
_QueuedWriteTestLocationKey = Tuple["PyTestLocation", "SnapshotIndex"]


@dataclass
class SnapshotSession:
pytest_session: "pytest.Session"
Expand All @@ -62,10 +66,28 @@ class SnapshotSession:
default_factory=lambda: defaultdict(set)
)

_queued_snapshot_writes: Dict[
Tuple[Type["AbstractSyrupyExtension"], str],
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
] = field(default_factory=dict)
# For performance, we buffer snapshot writes in memory before flushing them to disk. In
# particular, we want to be able to write to a file on disk only once, rather than having to
# repeatedly rewrite it.
#
# That batching leads to using two layers of dicts here: the outer layer represents the
# extension/file-location pair that will be written, and the inner layer represents the
# snapshots within that, "indexed" to allow efficient recall.
_queued_snapshot_writes: DefaultDict[
_QueuedWriteExtensionKey,
Dict[_QueuedWriteTestLocationKey, "SerializedData"],
] = field(default_factory=lambda: defaultdict(dict))

def _snapshot_write_queue_keys(
    self,
    extension: "AbstractSyrupyExtension",
    test_location: "PyTestLocation",
    index: "SnapshotIndex",
) -> Tuple[_QueuedWriteExtensionKey, _QueuedWriteTestLocationKey]:
    """Build the two keys that address one snapshot in the write queue.

    Returns a pair ``(extension_key, location_key)`` where:

    * ``extension_key`` is ``(type of extension, snapshot location)``, the
      location being whatever ``extension.get_location`` reports for this
      test location and index;
    * ``location_key`` is ``(test_location, index)``.
    """
    extension_key = (
        extension.__class__,
        extension.get_location(test_location=test_location, index=index),
    )
    location_key = (test_location, index)
    return extension_key, location_key

def queue_snapshot_write(
self,
Expand All @@ -74,13 +96,10 @@ def queue_snapshot_write(
data: "SerializedData",
index: "SnapshotIndex",
) -> None:
snapshot_location = extension.get_location(
test_location=test_location, index=index
ext_key, loc_key = self._snapshot_write_queue_keys(
extension, test_location, index
)
key = (extension.__class__, snapshot_location)
queue = self._queued_snapshot_writes.get(key, [])
queue.append((data, test_location, index))
self._queued_snapshot_writes[key] = queue
self._queued_snapshot_writes[ext_key][loc_key] = data

def flush_snapshot_write_queue(self) -> None:
for (
Expand All @@ -89,9 +108,33 @@ def flush_snapshot_write_queue(self) -> None:
), queued_write in self._queued_snapshot_writes.items():
if queued_write:
extension_class.write_snapshot(
snapshot_location=snapshot_location, snapshots=queued_write
snapshot_location=snapshot_location,
snapshots=[
(data, loc, index)
for (loc, index), data in queued_write.items()
],
)
self._queued_snapshot_writes = {}
self._queued_snapshot_writes.clear()

def recall_snapshot(
    self,
    extension: "AbstractSyrupyExtension",
    test_location: "PyTestLocation",
    index: "SnapshotIndex",
) -> Optional["SerializedData"]:
    """Return the current value of a snapshot as seen by this session.

    A write that is still pending in this session's queue takes precedence;
    otherwise the snapshot is read through the extension.

    Args:
        extension: Extension used to locate and, if needed, read the snapshot.
        test_location: Location of the test that owns the snapshot.
        index: Index (or name) of the snapshot within that test.

    Returns:
        The queued data if a write for this snapshot is pending, else the
        result of ``extension.read_snapshot``.

    Raises:
        Whatever ``extension.read_snapshot`` raises is propagated unchanged.
    """
    ext_key, loc_key = self._snapshot_write_queue_keys(
        extension, test_location, index
    )

    # Look up with `.get` instead of indexing: the queue is a defaultdict, so
    # `queue[ext_key]` would insert an empty entry on every miss and a pure
    # read would end up mutating (and growing) the pending-write queue.
    pending_for_file = self._queued_snapshot_writes.get(ext_key)
    if pending_for_file is not None:
        data = pending_for_file.get(loc_key)
        if data is not None:
            return data

    # No matching write queued, so just read the snapshot directly:
    return extension.read_snapshot(
        test_location=test_location, index=index, session_id=str(id(self))
    )

@property
def update_snapshots(self) -> bool:
Expand Down
56 changes: 56 additions & 0 deletions tests/integration/test_snapshot_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

_TEST = """
def test_foo(snapshot):
assert {**base} == snapshot(name="a")
assert {**base, **extra} == snapshot(name="b", diff="a")
"""


def _make_file(testdir, base, extra):
    """Write the test module: `base` and `extra` definitions, then `_TEST`.

    The `repr` of each value is embedded, so `base` and `extra` must be
    values whose `repr` is valid Python source (e.g. plain dicts).
    """
    sections = [f"base = {base!r}", f"extra = {extra!r}", _TEST]
    source = "\n\n".join(sections)
    testdir.makepyfile(test_file=source)


def _run_test(testdir, base, extra, expected_update_lines):
    """Regenerate the snapshots for `base`/`extra`, then verify them.

    Two pytest runs are made against the freshly written test file:

    1. with ``--snapshot-update``, whose output must match the regex
       `expected_update_lines`;
    2. without it, which must report that both snapshots pass.

    Both runs must exit with status 0.
    """
    _make_file(testdir, base=base, extra=extra)

    # Pass 1: generate/update the snapshots.
    update_run = testdir.runpytest("-v", "--snapshot-update")
    update_run.stdout.re_match_lines((expected_update_lines,))
    assert update_run.ret == 0

    # Pass 2: no update flag, so this validates what pass 1 wrote.
    verify_run = testdir.runpytest("-v")
    verify_run.stdout.re_match_lines((r"2 snapshots passed\.",))
    assert verify_run.ret == 0


def test_diff_lifecycle(testdir: "pytest.Testdir") -> None:
    """Exercise create -> edit base -> edit extra for a diffed snapshot pair.

    Each stage rewrites the test file, runs with ``--snapshot-update`` and then
    re-runs without it, so a diff snapshot written against a stale base
    snapshot would be caught by the second run.
    """
    # first: create both snapshots completely from scratch
    _run_test(
        testdir,
        base={"A": 1},
        extra={"X": 10},
        expected_update_lines=r"2 snapshots generated\.",
    )

    # second: edit the base data. This changes the data for both snapshots,
    # but only the serialized output of the base snapshot `a` changes.
    _run_test(
        testdir,
        base={"A": 1, "B": 2},
        extra={"X": 10},
        expected_update_lines=r"1 snapshot passed\. 1 snapshot updated\.",
    )

    # third: edit just the extra data. Only the serialized output of the diff
    # snapshot `b` changes.
    _run_test(
        testdir,
        base={"A": 1, "B": 2},
        extra={"X": 10, "Y": 20},
        expected_update_lines=r"1 snapshot passed\. 1 snapshot updated\.",
    )

0 comments on commit 0f6bb55

Please sign in to comment.