Skip to content

Commit

Permalink
refactor!: improved run error handling and improved API
Browse files Browse the repository at this point in the history
  • Loading branch information
noirbizarre committed Dec 11, 2023
1 parent f3ae5cf commit 91e12bb
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 41 deletions.
4 changes: 2 additions & 2 deletions src/pytest_copier/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ class CopierTaskError(CopierError):
"""Triggered by post-generation tasks"""


class ProjectRunError(CopierError):
"""Triggered by a command executed in th project"""
class RunError(CopierError):
"""Triggered by a failed command"""
99 changes: 60 additions & 39 deletions src/pytest_copier/plugin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import shlex
import subprocess

from dataclasses import dataclass
from functools import cached_property
from io import StringIO
from pathlib import Path
from shutil import copy, copytree
from typing import TYPE_CHECKING, Any, Mapping, cast
from typing import TYPE_CHECKING, Any, cast

import pytest
import yaml
Expand All @@ -18,7 +18,7 @@
from plumbum import local
from pytest_dir_equal import DEFAULT_IGNORES, DiffRepr, assert_dir_equal

from .errors import CopierTaskError, ProjectRunError
from .errors import CopierTaskError, RunError

if TYPE_CHECKING:
from pytest_gitconfig import GitConfig
Expand All @@ -45,22 +45,24 @@ def expected_lines(self) -> list[str]:
return self._as_lines(self.expected)


def run(cmd: str, *args, **kwargs) -> None:
args = [cmd, *args] if args else shlex.split(cmd) # type: ignore
def run(cmd: str, *args, **kwargs) -> str:
args = [cmd, *args] if args else cmd # type: ignore
try:
subprocess.run(args, check=True, capture_output=True, **kwargs)
return subprocess.check_output(
args, text=True, stderr=subprocess.STDOUT, shell=True, **kwargs
)
except subprocess.CalledProcessError as e:
out = StringIO()
tw = TerminalWriter(out)
output = e.stdout.decode("utf-8").replace("\n", "\n>>> ")
error = e.stderr.decode("utf-8").replace("\n", "\n>>> ")
tw.hasmarkup = True
tw.line(f"{str(e)}\n")
if output:
tw.line(f"Standard output:\n>>> {output}\n")
if error:
tw.line(f"Standard error:\n>>> {error}")
raise RuntimeError(out.getvalue()) from e
tw.line(f"❌ {str(e)}\n")
if e.output:
prefix = tw.markup("│ ", red=True)
tw.line("╭╼ Combined output:")
for line in e.output.splitlines():
tw.line(f"{prefix} {line}")
tw.line("╰╼")
raise RunError(out.getvalue()) from e


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -119,16 +121,16 @@ def copier_template(
@dataclass
class CopierFixture:
template: Path
dst: Path
defaults: dict[str, Any]
monkeypatch: pytest.MonkeyPatch

def copy(self, **data) -> CopierProject:
def copy(self, dst: Path, **data) -> CopierProject:
"""Copy a template given some answers"""
__tracebackhide__ = True
try:
run_copy(
str(self.template),
self.dst,
dst,
overwrite=True,
cleanup_on_error=False,
unsafe=True,
Expand All @@ -140,15 +142,14 @@ def copy(self, **data) -> CopierProject:
# we can produce a more streamlined error report
# we explicitly raise form None to cut the inner stacktrace too
raise CopierTaskError(f"❌ {e}") from None
return CopierProject(self.dst)
return CopierProject(dst, self)

def update(self, **data) -> CopierProject:
def update(self, project: Path, **data) -> CopierProject:
"""Update a template given some answers"""
__tracebackhide__ = True
try:
run_update(
str(self.template),
self.dst,
project,
overwrite=True,
cleanup_on_error=False,
unsafe=True,
Expand All @@ -159,9 +160,9 @@ def update(self, **data) -> CopierProject:
# we catch those error which are triggered by tasks
# we can produce a more streamlined error report
raise CopierTaskError(f"❌ {e}") from None
return CopierProject(self.dst)
return CopierProject(project, self)

def context(self, **answers) -> Mapping[str, Any]:
def context(self, **answers) -> dict[str, Any]:
"""Get the context rendered given some answers"""
__tracebackhide__ = True
worker = self.worker(**answers)
Expand All @@ -171,38 +172,54 @@ def context(self, **answers) -> Mapping[str, Any]:
ctx = env.context_class(env, data, "", {}, env.globals)
return ctx.get_all()

def worker(self, **answers) -> Worker:
def worker(self, dst: Path = Path(), **answers) -> Worker:
"""Get a worker with prefilled answers"""
return Worker(
src_path=str(self.template),
dst_path=self.dst,
dst_path=dst,
unsafe=True,
defaults=True,
data={**self.defaults, **answers},
)

def delenv(self, var: str):
"""Shortcut to monkeypatch.delenv both builtin os.environ and plumbum.local.env in Copier"""
self.monkeypatch.delenv(var, raising=False)
self.monkeypatch.delitem(local.env, var, raising=False)

def setenv(self, var: str, value: str):
"""Shortcut to monkeypatch.setenv builtin os.environ and plumbum.local.env in Copier"""
self.monkeypatch.setenv(var, value)
self.monkeypatch.setitem(local.env, var, value)


@dataclass
class CopierProject:
path: Path
copier: CopierFixture

def update(
self,
):
pass
def update(self, **data) -> CopierProject:
return self.copier.update(self.path, **data)

@cached_property
def answers(self) -> dict[str, Any]:
return self.load_answers(self.path)

@cached_property
def context(self) -> dict[str, Any]:
return self.copier.context(**self.answers)

def assert_answers(self, expected: Path):
__tracebackhide__ = True
expected_answers = self.load_answers(expected)
answers = self.load_answers(self.path)
if answers != expected_answers:
if self.answers != expected_answers:
out = StringIO()
tw = TerminalWriter(out)
tw.hasmarkup = True
tw.line("❌ Answers are different")
AnsersDiffRepr("Answers", answers, expected_answers).toterminal(tw)
AnsersDiffRepr("Answers", self.answers, expected_answers).toterminal(tw)
raise AssertionError(out.getvalue())
assert answers == expected_answers
assert self.answers == expected_answers

def load_answers(self, root: Path) -> dict[str, Any]:
file = root / ANSWERS_FILE
Expand All @@ -217,15 +234,15 @@ def assert_equal(self, expected: Path, ignore: list[str] | None = None):
ignore = DEFAULT_IGNORES + [ANSWERS_FILE] + (ignore or [])
assert_dir_equal(self.path, expected, ignore=ignore)

def run(self, command: str, **kwargs):
def run(self, command: str, **kwargs) -> str:
"""Run a command in the rendered project"""
__tracebackhide__ = True
try:
run(*shlex.split(command), cwd=self.path, **kwargs)
except subprocess.CalledProcessError as e:
return run(command, cwd=self.path, **kwargs)
except RunError as e:
# produce a more streamlined error report
# we explicitly raise form None to cut the inner stacktrace too
raise ProjectRunError(f"❌ {e}") from None
raise RuntimeError(str(e)) from None

def __truediv__(self, key):
"""Provide pathlib-like support"""
Expand All @@ -241,5 +258,9 @@ def copier_defaults() -> dict[str, Any]:


@pytest.fixture
def copier(tmp_path: Path, copier_template: Path, copier_defaults: dict[str, Any]) -> CopierFixture:
return CopierFixture(template=copier_template, dst=tmp_path / "dst", defaults=copier_defaults)
def copier(
copier_template: Path, copier_defaults: dict[str, Any], monkeypatch: pytest.MonkeyPatch
) -> CopierFixture:
return CopierFixture(
template=copier_template, defaults=copier_defaults, monkeypatch=monkeypatch
)

0 comments on commit 91e12bb

Please sign in to comment.