Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TypedDict for params and data arguments; add type hint for kwargs #166

Merged
merged 9 commits into from
Sep 28, 2023
2 changes: 2 additions & 0 deletions atproto/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
DISCLAIMER = f'{"#" * _MAX_DISCLAIMER_LEN}\n{DISCLAIMER}\n{"#" * _MAX_DISCLAIMER_LEN}\n\n'

PARAMS_MODEL = 'Params'
PARAMS_DICT = 'ParamsDict'
INPUT_MODEL = 'Data'
INPUT_DICT = 'DataDict'
OUTPUT_MODEL = 'Response'


Expand Down
103 changes: 95 additions & 8 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from atproto.codegen import (
DISCLAIMER,
INPUT_DICT,
INPUT_MODEL,
OUTPUT_MODEL,
PARAMS_DICT,
PARAMS_MODEL,
_resolve_nsid_ref,
append_code,
Expand Down Expand Up @@ -36,6 +38,11 @@ class ModelType(Enum):
RECORD = 'Record'


class TypedDictType(Enum):
PARAMS = 'Parameters'
DATA = 'Input data'


def save_code(nsid: NSID, code: str) -> None:
path_to_file = _MODELS_OUTPUT_DIR.joinpath(*get_file_path_parts(nsid))
write_code(_MODELS_OUTPUT_DIR.joinpath(path_to_file), code)
Expand All @@ -57,6 +64,7 @@ def _get_model_imports() -> str:
'if t.TYPE_CHECKING:',
f'{_(1)}from atproto.xrpc_client import models',
f'{_(1)}from atproto.xrpc_client.models.unknown_type import UnknownType',
f'{_(1)}from atproto.xrpc_client.models.unknown_type import UnknownInputType',
f'{_(1)}from atproto.xrpc_client.models.blob_ref import BlobRef',
f'{_(1)}from atproto import CIDType',
'from atproto.xrpc_client.models import base',
Expand All @@ -78,7 +86,7 @@ def _save_code_import_if_not_exist(nsid: NSID) -> None:


def _get_model_class_def(name: str, model_type: ModelType) -> str:
lines = []
lines: t.List[str] = []

if model_type is ModelType.PARAMS:
lines.append(f'class {PARAMS_MODEL}(base.ParamsModelBase):')
Expand All @@ -96,6 +104,19 @@ def _get_model_class_def(name: str, model_type: ModelType) -> str:
return join_code(lines)


def _get_typeddict_class_def(name: str, model_type: TypedDictType) -> str:
lines: t.List[str] = []

if model_type is TypedDictType.PARAMS:
lines.append(f'class {PARAMS_DICT}(te.TypedDict):')
elif model_type is TypedDictType.DATA:
lines.append(f'class {INPUT_DICT}(te.TypedDict):')

lines.append('')

return join_code(lines)


_LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT = {
models.LexString: 'str',
models.LexInteger: 'int',
Expand Down Expand Up @@ -135,19 +156,23 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> st
return _get_optional_typehint(annotated_union, optional=optional)


def _get_model_field_typehint(nsid: NSID, field_type_def, *, optional: bool) -> str:
def _get_model_field_typehint(nsid: NSID, field_type_def, *, optional: bool, is_input_type: bool = False) -> str:
field_type = type(field_type_def)

if field_type == models.LexUnknown:
# unknown type is a generic response with records or any not described type in the lexicon. for example, didDoc
if is_input_type:
return _get_optional_typehint("'UnknownInputType'", optional=optional)
return _get_optional_typehint("'UnknownType'", optional=optional)

type_hint = _LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT.get(field_type)
if type_hint:
return _get_optional_typehint(type_hint, optional=optional)

if field_type is models.LexArray:
items_type_hint = _get_model_field_typehint(nsid, field_type_def.items, optional=False)
items_type_hint = _get_model_field_typehint(
nsid, field_type_def.items, optional=False, is_input_type=is_input_type
)
return _get_optional_typehint(f't.List[{items_type_hint}]', optional=optional)

if field_type is models.LexRef:
Expand Down Expand Up @@ -308,7 +333,9 @@ def _is_reserved_pydantic_name(name: str) -> bool:
return name in _get_pydantic_reserved_names()


def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcParameters]) -> str:
def _get_model(
nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcParameters], *, is_input_type: bool = False
) -> str:
required_fields = _get_req_fields_set(lex_object)

fields = []
Expand All @@ -329,7 +356,7 @@ def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcP
snake_cased_field_name += '_' # add underscore to the end
alias_name = field_name

type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional)
type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional, is_input_type=is_input_type)
value = _get_model_field_value(field_type_def, alias_name, optional=is_optional)
description = _get_field_docstring(field_name, field_type_def)

Expand All @@ -351,6 +378,48 @@ def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcP
return join_code(fields)


def _get_typeddict(
nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcParameters], *, is_input_type: bool = False
) -> str:
required_fields = _get_req_fields_set(lex_object)

fields: t.List[str] = []
optional_fields: t.List[str] = []

for field_name, field_type_def in lex_object.properties.items():
is_optional = field_name not in required_fields

snake_cased_field_name = convert_camel_case_to_snake_case(field_name)

if _is_reserved_pydantic_name(snake_cased_field_name):
# make aliases for fields with reserved names
snake_cased_field_name += '_' # add underscore to the end
DXsmiley marked this conversation as resolved.
Show resolved Hide resolved

type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional, is_input_type=is_input_type)
description = _get_field_docstring(field_name, field_type_def)

# Allow optional params to actually be ommitted from the dict entirely
type_hint_defaulting = f'te.NotRequired[{type_hint}]' if is_optional else type_hint
field_def = f'{_(1)}{snake_cased_field_name}: {type_hint_defaulting} #: {description}'

if is_optional:
optional_fields.append(field_def)
else:
fields.append(field_def)

optional_fields.sort()
fields.sort()

fields.extend(optional_fields)

if len(fields) == 0:
fields.append(f'{_(1)}pass')

fields.append('')

return join_code(fields)


def _get_model_raw_data(name: str) -> str:
lines = [f'#: {name} raw data type.', f'{name}: te.TypeAlias = bytes\n\n']
return join_code(lines)
Expand All @@ -361,7 +430,12 @@ def _generate_params_model(nsid: NSID, definition: t.Union[models.LexXrpcQuery,

if definition.parameters:
lines.append(_get_model_docstring(nsid, definition.parameters, ModelType.PARAMS))
lines.append(_get_model(nsid, definition.parameters))
lines.append(_get_model(nsid, definition.parameters, is_input_type=True))

lines.append(_get_typeddict_class_def(nsid.name, TypedDictType.PARAMS))

if definition.parameters:
lines.append(_get_typeddict(nsid, definition.parameters, is_input_type=True))

return join_code(lines)

Expand All @@ -372,7 +446,7 @@ def _generate_xrpc_body_model(nsid: NSID, body: models.LexXrpcBody, model_type:
if isinstance(body.schema, models.LexObject):
lines.append(_get_model_class_def(nsid.name, model_type))
lines.append(_get_model_docstring(nsid, body.schema, model_type))
lines.append(_get_model(nsid, body.schema))
lines.append(_get_model(nsid, body.schema, is_input_type=(model_type is ModelType.DATA)))
else:
if model_type is ModelType.DATA:
model_name = INPUT_MODEL
Expand All @@ -386,8 +460,18 @@ def _generate_xrpc_body_model(nsid: NSID, body: models.LexXrpcBody, model_type:
return join_code(lines)


def _generate_data_typedict(nsid: NSID, body: models.LexXrpcBody) -> str:
lines: t.List[str] = []
if isinstance(body.schema, models.LexObject):
lines.append(_get_typeddict_class_def(nsid.name, TypedDictType.DATA))
lines.append(_get_typeddict(nsid, body.schema, is_input_type=True))
return join_code(lines)


def _generate_data_model(nsid: NSID, input_body: models.LexXrpcBody) -> str:
return _generate_xrpc_body_model(nsid, input_body, ModelType.DATA)
return join_code(
[_generate_xrpc_body_model(nsid, input_body, ModelType.DATA), _generate_data_typedict(nsid, input_body)]
)


def _generate_response_model(nsid: NSID, output_body: models.LexXrpcBody) -> str:
Expand Down Expand Up @@ -560,6 +644,9 @@ def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None:
unknown_record_type_pydantic_lines.append(
"UnknownType: te.TypeAlias = t.Union[UnknownRecordTypePydantic, 'dot_dict.DotDictType']"
)
unknown_record_type_pydantic_lines.append(
"UnknownInputType: te.TypeAlias = t.Union[UnknownRecordTypePydantic, 'dot_dict.DotDictType', t.Dict[str, t.Any]]"
)
unknown_type_lines = [*import_lines, *unknown_record_type_hint_lines, *unknown_record_type_pydantic_lines]

write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(type_conversion_lines))
Expand Down
20 changes: 15 additions & 5 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from atproto.codegen import (
DISCLAIMER,
INPUT_DICT,
INPUT_MODEL,
OUTPUT_MODEL,
PARAMS_DICT,
PARAMS_MODEL,
_resolve_nsid_ref,
convert_camel_case_to_snake_case,
Expand Down Expand Up @@ -175,13 +177,17 @@ def _override_arg_line(name: str, model_name: str) -> str:


def _get_namespace_method_signature_arg(
name: str, nsid: NSID, model_name: str, *, optional: bool, alias: bool = False
name: str, nsid: NSID, model_name: t.Union[t.List[str], str], *, optional: bool, alias: bool = False
) -> str:
if alias:
return f"{name}: 'models.{get_import_path(nsid)}.{model_name}'"

default_value = ''
type_hint = f"t.Union[dict, 'models.{get_import_path(nsid)}.{model_name}']"
type_hint = (
f"t.Union[dict, 'models.{get_import_path(nsid)}.{model_name}']"
if isinstance(model_name, str)
else 't.Union[' + ', '.join(f'models.{get_import_path(nsid)}.{i}' for i in model_name) + ']'
)
if optional:
type_hint = f't.Optional[{type_hint}]'
default_value = ' = None'
Expand Down Expand Up @@ -226,7 +232,9 @@ def is_optional_arg(lex_obj) -> bool:
params = method_info.definition.parameters
is_optional = is_optional_arg(params)

arg = _get_namespace_method_signature_arg('params', method_info.nsid, PARAMS_MODEL, optional=is_optional)
arg = _get_namespace_method_signature_arg(
'params', method_info.nsid, [PARAMS_MODEL, PARAMS_DICT], optional=is_optional
)
_add_arg(arg, optional=is_optional)

if isinstance(method_info, ProcedureInfo) and method_info.definition.input:
Expand All @@ -235,7 +243,9 @@ def is_optional_arg(lex_obj) -> bool:
is_optional = is_optional_arg(schema)

if schema and isinstance(schema, LexObject):
arg = _get_namespace_method_signature_arg('data', method_info.nsid, INPUT_MODEL, optional=is_optional)
arg = _get_namespace_method_signature_arg(
'data', method_info.nsid, [INPUT_MODEL, INPUT_DICT], optional=is_optional
)
_add_arg(arg, optional=is_optional)
else:
raise ValueError(f'Bad type {type(schema)}') # probably LexRefVariant
Expand All @@ -244,7 +254,7 @@ def is_optional_arg(lex_obj) -> bool:
_add_arg(arg, optional=False)

args.extend(optional_args)
args.append('**kwargs')
args.append('**kwargs: t.Any')
return ', '.join(args)


Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
from atproto.xrpc_client import models
from atproto.xrpc_client.models import base
Expand All @@ -17,6 +19,10 @@ class Params(base.ParamsModelBase):
"""Parameters model for :obj:`app.bsky.actor.getPreferences`."""


class ParamsDict(te.TypedDict):
pass


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.getPreferences`."""
Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
pass
from atproto.xrpc_client.models import base
Expand All @@ -17,3 +19,7 @@ class Params(base.ParamsModelBase):
"""Parameters model for :obj:`app.bsky.actor.getProfile`."""

actor: str #: Actor.


class ParamsDict(te.TypedDict):
actor: str #: Actor.
5 changes: 5 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import typing as t

import typing_extensions as te
from pydantic import Field

if t.TYPE_CHECKING:
Expand All @@ -21,6 +22,10 @@ class Params(base.ParamsModelBase):
actors: t.List[str] = Field(max_length=25) #: Actors.


class ParamsDict(te.TypedDict):
actors: t.List[str] #: Actors.


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.getProfiles`."""
Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import typing as t

import typing_extensions as te
from pydantic import Field

if t.TYPE_CHECKING:
Expand All @@ -22,6 +23,11 @@ class Params(base.ParamsModelBase):
limit: t.Optional[int] = Field(default=50, ge=1, le=100) #: Limit.


class ParamsDict(te.TypedDict):
cursor: te.NotRequired[t.Optional[str]] #: Cursor.
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
limit: te.NotRequired[t.Optional[int]] #: Limit.


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.getSuggestions`."""
Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/put_preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
from atproto.xrpc_client import models
from atproto.xrpc_client.models import base
Expand All @@ -17,3 +19,7 @@ class Data(base.DataModelBase):
"""Input data model for :obj:`app.bsky.actor.putPreferences`."""

preferences: 'models.AppBskyActorDefs.Preferences' #: Preferences.


class DataDict(te.TypedDict):
preferences: 'models.AppBskyActorDefs.Preferences' #: Preferences.
7 changes: 7 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/search_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import typing as t

import typing_extensions as te
from pydantic import Field

if t.TYPE_CHECKING:
Expand All @@ -23,6 +24,12 @@ class Params(base.ParamsModelBase):
term: t.Optional[str] = None #: Term.


class ParamsDict(te.TypedDict):
cursor: te.NotRequired[t.Optional[str]] #: Cursor.
limit: te.NotRequired[t.Optional[int]] #: Limit.
term: te.NotRequired[t.Optional[str]] #: Term.


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.searchActors`."""
Expand Down
Loading