Skip to content

Commit

Permalink
properly type annotate default value generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamilcuk committed Sep 29, 2024
1 parent 972b469 commit e9e7e8c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ log_file_format = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s %(message)s
log_file_level = "DEBUG"

[tool.pyright]
include = ["src"]
include = ["src", "tests"]
pythonVersion = "3.7"
typeCheckingMode = "basic"
reportUnnecessaryComparison = "error"
Expand Down
57 changes: 47 additions & 10 deletions src/clickdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@
Tuple,
Union,
get_type_hints,
overload,
)

import click
from typing_extensions import Protocol, Type, get_args, get_origin
from typing_extensions import Protocol, Type, TypeVar, get_args, get_origin

try:
from typing import NoneType # pyright: ignore
except ImportError:
NoneType = type(None)

T = TypeVar("T")

###############################################################################

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -295,12 +298,46 @@ def _myfields(arg_class: DataclassType) -> List[Field]:
return ret


def _mkfield(func: Decorator, clickdc: _OptsArg, args, kwargs):
@overload
def _mkfield(
func: Decorator,
clickdc: _OptsArg,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any: ...
@overload
def _mkfield(
func: Decorator,
clickdc: _OptsArg,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
default: Callable[[], T],
) -> T: ...


@overload
def _mkfield(
func: Decorator,
clickdc: _OptsArg,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
default: T,
) -> T: ...
def _mkfield(
func: Decorator,
clickdc: _OptsArg,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
default: Any = dataclasses.MISSING,
):
clickdc = Opts(no=True) if clickdc is None else clickdc
return dataclasses.field(
default_factory=lambda: kwargs.get("default"),
metadata={TAG: FieldDesc(func, clickdc, args, kwargs)},
)
metadata = {TAG: FieldDesc(func, clickdc, args, kwargs)}
if default is dataclasses.MISSING:
return dataclasses.field(metadata=metadata)
elif callable(default):
return dataclasses.field(default_factory=default, metadata=metadata)
else:
return dataclasses.field(default=default, metadata=metadata)


###############################################################################
Expand All @@ -317,13 +354,13 @@ def command(*args, clickdc: _OptsArg = Opts(), **kwargs):


@functools.wraps(click.option)
def option(*args, clickdc: _OptsArg = Opts(), **kwargs):
return _mkfield(click.option, clickdc, args, kwargs)
def option(*args, clickdc: _OptsArg = Opts(), default: Any = None, **kwargs):
return _mkfield(click.option, clickdc, args, kwargs, default)


@functools.wraps(click.argument)
def argument(*args, clickdc: _OptsArg = Opts(), **kwargs):
return _mkfield(click.argument, clickdc, args, kwargs)
def argument(*args, clickdc: _OptsArg = Opts(), default: Any = None, **kwargs):
return _mkfield(click.argument, clickdc, args, kwargs, default)


def _assert_annotations(arg_class: DataclassType):
Expand Down

0 comments on commit e9e7e8c

Please sign in to comment.