Skip to content

Commit

Permalink
Add **kwargs for various subclasses of AbstractSyrupyExtension
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-2001 committed Sep 25, 2023
1 parent 7730070 commit 343ae08
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def delete_snapshots(
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
def _read_snapshot_collection(self, snapshot_location: str, **kwargs: Any) -> "SnapshotCollection":
return self.serializer_class.read_file(snapshot_location)

@classmethod
Expand All @@ -72,7 +72,7 @@ def _read_snapshot_data_from_location(

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
cls.serializer_class.write_file(snapshot_collection, merge=True)

Expand Down
23 changes: 17 additions & 6 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Optional,
Set,
Tuple,
Any,
)

from syrupy.constants import (
Expand Down Expand Up @@ -67,6 +68,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down Expand Up @@ -108,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool:
return location.endswith(self._file_extension)

def discover_snapshots(
self, *, test_location: "PyTestLocation"
self, *, test_location: "PyTestLocation", **kwargs: Any
) -> "SnapshotCollections":
"""
Returns all snapshot collections in test site
Expand Down Expand Up @@ -216,7 +218,7 @@ def delete_snapshots(

@abstractmethod
def _read_snapshot_collection(
self, *, snapshot_location: str
self, *, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot collection object
Expand All @@ -235,15 +237,17 @@ def _read_snapshot_data_from_location(
@classmethod
@abstractmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(
cls, *, test_location: "PyTestLocation", **kwargs: Any
) -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

Expand All @@ -259,15 +263,21 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down Expand Up @@ -407,6 +417,7 @@ def matches(
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
**kwargs: Any,
) -> bool:
"""
Compares serialized data and snapshot data and returns
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
data = self._filter(
data=data,
Expand Down
23 changes: 13 additions & 10 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from enum import Enum
from gettext import gettext
from pathlib import Path
from typing import (
TYPE_CHECKING,
Optional,
Set,
Type,
Union,
)
from typing import TYPE_CHECKING, Optional, Set, Type, Union, Dict, Any
from unicodedata import category

from syrupy.constants import TEXT_ENCODING
Expand Down Expand Up @@ -49,6 +43,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
return self.get_supported_dataclass()(data)

Expand All @@ -74,12 +69,17 @@ def _get_file_basename(
return cls.get_snapshot_name(test_location=test_location, index=index)

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(
cls, *, test_location: "PyTestLocation", **kwargs: Any
) -> str:
original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location)
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
self, *, snapshot_location: str
self,
*,
snapshot_location: str,
**kwargs: Any,
) -> "SnapshotCollection":
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
filename_wo_ext = snapshot_location[:-file_ext_len]
Expand Down Expand Up @@ -116,7 +116,10 @@ def get_write_encoding(cls) -> Optional[str]:

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls,
*,
snapshot_collection: "SnapshotCollection",
**kwargs: Any,
) -> None:
filepath, data = (
snapshot_collection.location,
Expand Down

0 comments on commit 343ae08

Please sign in to comment.