Skip to content

Commit

Permalink
fix: handle positional dataclass arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jan 17, 2025
1 parent ea2977b commit 219dd6d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ dependencies = [
installer="uv"

[tool.hatch.envs.cov.scripts]
gh=[
github=[
"- rm htmlcov/*",
"gh run download -n html-report -D htmlcov",
"xdg-open htmlcov/index.html",
Expand Down
7 changes: 5 additions & 2 deletions src/inline_snapshot/_adapter/generic_call_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,11 @@ def arguments(cls, value):
return ([], kwargs)

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, str)
return getattr(value, pos_or_name)
if isinstance(pos_or_name, str):
return getattr(value, pos_or_name)
else:
args = [field for field in fields(value) if field.init]
return args[pos_or_name]


try:
Expand Down
59 changes: 52 additions & 7 deletions tests/adapter/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class A:
c:list=field(default_factory=list)
def test_something():
assert A(a=1) == snapshot(A(a=1,b=2,c=[]))
for _ in [1,2]:
assert A(a=1) == snapshot(A(a=1,b=2,c=[]))
"""
).run_inline(
["--inline-snapshot=update"],
Expand All @@ -112,7 +113,47 @@ class A:
c:list=field(default_factory=list)
def test_something():
assert A(a=1) == snapshot(A(a=1))
for _ in [1,2]:
assert A(a=1) == snapshot(A(a=1))
"""
}
),
)


def test_dataclass_positional_arguments():
Example(
"""\
from inline_snapshot import snapshot,Is
from dataclasses import dataclass,field
@dataclass
class A:
a:int
b:int=2
c:list=field(default_factory=list)
def test_something():
for _ in [1,2]:
assert A(a=1) == snapshot(A(1,2,c=[]))
"""
).run_inline(
["--inline-snapshot=update"],
changed_files=snapshot(
{
"test_something.py": """\
from inline_snapshot import snapshot,Is
from dataclasses import dataclass,field
@dataclass
class A:
a:int
b:int=2
c:list=field(default_factory=list)
def test_something():
for _ in [1,2]:
assert A(a=1) == snapshot(A(1,2))
"""
}
),
Expand Down Expand Up @@ -400,12 +441,14 @@ def argument(cls, value, pos_or_name):
return value.l[pos_or_name]
def test_L1():
assert L(1,2) == snapshot(L(1)), "not equal"
for _ in [1,2]:
assert L(1,2) == snapshot(L(1)), "not equal"
def test_L2():
assert L(1,2) == snapshot(L(1, 2, 3)), "not equal"
for _ in [1,2]:
assert L(1,2) == snapshot(L(1, 2, 3)), "not equal"
"""
).run_pytest(
).run_pytest().run_pytest(
["--inline-snapshot=fix"],
changed_files=snapshot(
{
Expand Down Expand Up @@ -439,10 +482,12 @@ def argument(cls, value, pos_or_name):
return value.l[pos_or_name]
def test_L1():
assert L(1,2) == snapshot(L(1, 2)), "not equal"
for _ in [1,2]:
assert L(1,2) == snapshot(L(1, 2)), "not equal"
def test_L2():
assert L(1,2) == snapshot(L(1, 2)), "not equal"
for _ in [1,2]:
assert L(1,2) == snapshot(L(1, 2)), "not equal"
"""
}
),
Expand Down

0 comments on commit 219dd6d

Please sign in to comment.