fix: check current session's pending-write queue when recalling snapshots (e.g. diffing) #927

Merged
merged 8 commits on Jan 13, 2025
6 changes: 1 addition & 5 deletions src/syrupy/assertion.py
@@ -377,11 +377,7 @@ def _recall_data(
) -> Tuple[Optional["SerializableData"], bool]:
try:
return (
self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
Contributor Author
I note that #606 added this session_id parameter, which I imagine was meant to help with circumstances like this, but it seems like it ends up ignored/unused in practice?

And I think it may be quite hard to use in any case, since the relevant cached data is held in memory by the session itself (i.e. the extension's _read_snapshot_data_from_location calls don't have access to it at all).

Collaborator @noahnu (Dec 26, 2024)
The session_id is used as the cache_key argument in Amber's __cacheable_read_snapshot method. Although it's not used in the function body, it is used by the lru_cache decorator (which I believe caches on all of the call's arguments, kwargs included), so passing a new session_id essentially invalidates the cache.
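For illustration, a minimal sketch (not syrupy's actual code) of how an argument that's unused in the function body still participates in lru_cache keying:

```python
from functools import lru_cache

@lru_cache(maxsize=4)
def read_file(path: str, cache_key: str) -> str:
    # cache_key never appears in the body, but lru_cache memoizes on
    # all of the call's arguments, so a new cache_key forces a re-read.
    with open(path) as f:
        return f.read()

read_file("snap.ambr", cache_key="session-1")  # reads from disk
read_file("snap.ambr", cache_key="session-1")  # served from the cache
read_file("snap.ambr", cache_key="session-2")  # cache miss: reads again
```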

),
self.session.recall_snapshot(self.extension, self.test_location, index),
False,
)
except SnapshotDoesNotExist:
35 changes: 25 additions & 10 deletions src/syrupy/location.py
@@ -13,7 +13,7 @@
from syrupy.constants import PYTEST_NODE_SEP


@dataclass
@dataclass(frozen=True)
class PyTestLocation:
item: "pytest.Item"
nodename: Optional[str] = field(init=False)
@@ -23,27 +23,42 @@ class PyTestLocation:
filepath: str = field(init=False)

def __post_init__(self) -> None:
# NB. we're in a frozen dataclass, but need to transform the values that the caller
# supplied... we do so by (ab)using object.__setattr__ to forcibly set the attributes. (See
# rejected PEP-0712 for an example of a better way to handle this.)
#
# This is safe because this all happens during initialization: `self` hasn't been hashed
# (or, e.g., stored in a dict), so the mutation won't be noticed.
if self.is_doctest:
return self.__attrs_post_init_doc__()
self.__attrs_post_init_def__()

def __attrs_post_init_def__(self) -> None:
node_path: Path = getattr(self.item, "path") # noqa: B009
self.filepath = str(node_path.absolute())
# See __post_init__ for discussion of object.__setattr__
object.__setattr__(self, "filepath", str(node_path.absolute()))
obj = getattr(self.item, "obj") # noqa: B009
self.modulename = obj.__module__
self.methodname = obj.__name__
self.nodename = getattr(self.item, "name", None)
self.testname = self.nodename or self.methodname
object.__setattr__(self, "modulename", obj.__module__)
object.__setattr__(self, "methodname", obj.__name__)
object.__setattr__(self, "nodename", getattr(self.item, "name", None))
object.__setattr__(self, "testname", self.nodename or self.methodname)

def __attrs_post_init_doc__(self) -> None:
doctest = getattr(self.item, "dtest") # noqa: B009
self.filepath = doctest.filename
# See __post_init__ for discussion of object.__setattr__
object.__setattr__(self, "filepath", doctest.filename)
test_relfile, test_node = self.nodeid.split(PYTEST_NODE_SEP)
test_relpath = Path(test_relfile)
self.modulename = ".".join([*test_relpath.parent.parts, test_relpath.stem])
self.nodename = test_node.replace(f"{self.modulename}.", "")
self.testname = self.nodename or self.methodname
object.__setattr__(
self,
"modulename",
".".join([*test_relpath.parent.parts, test_relpath.stem]),
)
object.__setattr__(self, "methodname", None)
Contributor Author
Setting this attribute here is new: it previously wasn't being set on this codepath, and attempting to hash/compare the location values (to put them into the queue dictionary) was blowing up when this field was accessed.

It seems like it was previously not read at all for doc tests?

object.__setattr__(
self, "nodename", test_node.replace(f"{self.modulename}.", "")
)
object.__setattr__(self, "testname", self.nodename or self.methodname)

@property
def classname(self) -> Optional[str]:
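For reference, the frozen-dataclass initialization pattern used above, as a minimal standalone sketch (the class and fields are illustrative, not syrupy code):

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Point:
    raw: str
    x: int = field(init=False)
    y: int = field(init=False)

    def __post_init__(self) -> None:
        # A frozen dataclass rejects normal assignment (raising
        # FrozenInstanceError), so derived fields are set by calling
        # object.__setattr__ directly. This is safe during __init__,
        # before the instance has been hashed or shared.
        a, b = self.raw.split(",")
        object.__setattr__(self, "x", int(a))
        object.__setattr__(self, "y", int(b))

p = Point("3,4")
assert (p.x, p.y) == (3, 4)
```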
67 changes: 55 additions & 12 deletions src/syrupy/session.py
@@ -46,6 +46,10 @@ class ItemStatus(Enum):
SKIPPED = "skipped"


_QueuedWriteExtensionKey = Tuple[Type["AbstractSyrupyExtension"], str]
_QueuedWriteTestLocationKey = Tuple["PyTestLocation", "SnapshotIndex"]


@dataclass
class SnapshotSession:
pytest_session: "pytest.Session"
@@ -62,10 +66,28 @@ class SnapshotSession:
default_factory=lambda: defaultdict(set)
)

_queued_snapshot_writes: Dict[
Tuple[Type["AbstractSyrupyExtension"], str],
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
] = field(default_factory=dict)
# For performance, we buffer snapshot writes in memory before flushing them to disk. In
# particular, we want to be able to write to a file on disk only once, rather than having to
# repeatedly rewrite it.
#
# That batching leads to using two layers of dicts here: the outer layer represents the
# extension/file-location pair that will be written, and the inner layer represents the
# snapshots within that, "indexed" to allow efficient recall.
_queued_snapshot_writes: DefaultDict[
_QueuedWriteExtensionKey,
Dict[_QueuedWriteTestLocationKey, "SerializedData"],
] = field(default_factory=lambda: defaultdict(dict))

def _snapshot_write_queue_keys(
self,
extension: "AbstractSyrupyExtension",
test_location: "PyTestLocation",
index: "SnapshotIndex",
) -> Tuple[_QueuedWriteExtensionKey, _QueuedWriteTestLocationKey]:
snapshot_location = extension.get_location(
test_location=test_location, index=index
)
return (extension.__class__, snapshot_location), (test_location, index)

def queue_snapshot_write(
self,
@@ -74,13 +96,10 @@ def queue_snapshot_write(
data: "SerializedData",
index: "SnapshotIndex",
) -> None:
snapshot_location = extension.get_location(
test_location=test_location, index=index
ext_key, loc_key = self._snapshot_write_queue_keys(
extension, test_location, index
)
key = (extension.__class__, snapshot_location)
queue = self._queued_snapshot_writes.get(key, [])
queue.append((data, test_location, index))
self._queued_snapshot_writes[key] = queue
self._queued_snapshot_writes[ext_key][loc_key] = data

def flush_snapshot_write_queue(self) -> None:
for (
@@ -89,9 +108,33 @@
), queued_write in self._queued_snapshot_writes.items():
if queued_write:
extension_class.write_snapshot(
snapshot_location=snapshot_location, snapshots=queued_write
snapshot_location=snapshot_location,
snapshots=[
(data, loc, index)
for (loc, index), data in queued_write.items()
],
)
self._queued_snapshot_writes = {}
self._queued_snapshot_writes.clear()

def recall_snapshot(
self,
extension: "AbstractSyrupyExtension",
test_location: "PyTestLocation",
index: "SnapshotIndex",
) -> Optional["SerializedData"]:
"""Find the current value of the snapshot, for this session, either a pending write or the actual snapshot."""

ext_key, loc_key = self._snapshot_write_queue_keys(
extension, test_location, index
)
data = self._queued_snapshot_writes[ext_key].get(loc_key)
if data is not None:
return data

# No matching write queued, so just read the snapshot directly:
return extension.read_snapshot(
test_location=test_location, index=index, session_id=str(id(self))
)

@property
def update_snapshots(self) -> bool:
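To make the two-layer queue concrete, here is a simplified sketch of the write/recall flow (names and types are illustrative, not the actual session API):

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Tuple

ExtKey = Tuple[type, str]  # (extension class, snapshot file location)
LocKey = Tuple[str, Any]   # (test location, snapshot index)

# Outer dict: one entry per snapshot file that will be written.
# Inner dict: the snapshots within that file, keyed for O(1) recall.
_queue: Dict[ExtKey, Dict[LocKey, str]] = defaultdict(dict)

def queue_write(ext_key: ExtKey, loc_key: LocKey, data: str) -> None:
    # A later write to the same snapshot replaces the earlier one, so
    # each file is flushed to disk at most once per flush cycle.
    _queue[ext_key][loc_key] = data

def recall(ext_key: ExtKey, loc_key: LocKey,
           read_from_disk: Callable[[], str]) -> str:
    # Prefer a pending in-memory write; fall back to the file on disk.
    pending = _queue[ext_key].get(loc_key)
    return pending if pending is not None else read_from_disk()
```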
56 changes: 56 additions & 0 deletions tests/integration/test_snapshot_diff.py
@@ -0,0 +1,56 @@
import pytest

_TEST = """
def test_foo(snapshot):
assert {**base} == snapshot(name="a")
assert {**base, **extra} == snapshot(name="b", diff="a")
"""


def _make_file(testdir, base, extra):
testdir.makepyfile(
test_file="\n\n".join([f"base = {base!r}", f"extra = {extra!r}", _TEST])
)


def _run_test(testdir, base, extra, expected_update_lines):
_make_file(testdir, base=base, extra=extra)

# Run with --snapshot-update, to generate/update snapshots:
result = testdir.runpytest(
"-v",
"--snapshot-update",
)
result.stdout.re_match_lines((expected_update_lines,))
assert result.ret == 0

# Run without --snapshot-update, to validate the snapshots are actually up-to-date
result = testdir.runpytest("-v")
result.stdout.re_match_lines((r"2 snapshots passed\.",))
assert result.ret == 0


def test_diff_lifecycle(testdir) -> None:
# first: create both snapshots completely from scratch
_run_test(
testdir,
base={"A": 1},
extra={"X": 10},
expected_update_lines=r"2 snapshots generated\.",
)

# second: edit the base data; this changes the underlying data for both snapshots, but only the serialized output of the base snapshot `a` (snapshot `b` stores only the diff, which is unchanged).
_run_test(
testdir,
base={"A": 1, "B": 2},
extra={"X": 10},
expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.",
)

# third: edit just the extra data (only changes the serialized output for the diff snapshot `b`)
_run_test(
testdir,
base={"A": 1, "B": 2},
extra={"X": 10, "Y": 20},
expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.",
)
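Note how this lifecycle exercises the fix: on the second and third runs with --snapshot-update, snapshot `b` is serialized as a diff against the freshly updated value of `a`, which at that point exists only in the session's pending-write queue, so recall_snapshot must consult the queue before falling back to the snapshot file on disk.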