Skip to content

Commit

Permalink
Add TypedDict for params and data arguments; add type hint for kwargs (
Browse files Browse the repository at this point in the history
  • Loading branch information
DXsmiley authored Sep 28, 2023
1 parent fd70145 commit dab0fd0
Show file tree
Hide file tree
Showing 100 changed files with 1,523 additions and 219 deletions.
2 changes: 2 additions & 0 deletions atproto/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
DISCLAIMER = f'{"#" * _MAX_DISCLAIMER_LEN}\n{DISCLAIMER}\n{"#" * _MAX_DISCLAIMER_LEN}\n\n'

PARAMS_MODEL = 'Params'
PARAMS_DICT = 'ParamsDict'
INPUT_MODEL = 'Data'
INPUT_DICT = 'DataDict'
OUTPUT_MODEL = 'Response'


Expand Down
99 changes: 91 additions & 8 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from atproto.codegen import (
DISCLAIMER,
INPUT_DICT,
INPUT_MODEL,
OUTPUT_MODEL,
PARAMS_DICT,
PARAMS_MODEL,
_resolve_nsid_ref,
append_code,
Expand Down Expand Up @@ -36,6 +38,11 @@ class ModelType(Enum):
RECORD = 'Record'


class TypedDictType(Enum):
PARAMS = 'Parameters'
DATA = 'Input data'


def save_code(nsid: NSID, code: str) -> None:
path_to_file = _MODELS_OUTPUT_DIR.joinpath(*get_file_path_parts(nsid))
write_code(_MODELS_OUTPUT_DIR.joinpath(path_to_file), code)
Expand All @@ -57,6 +64,7 @@ def _get_model_imports() -> str:
'if t.TYPE_CHECKING:',
f'{_(1)}from atproto.xrpc_client import models',
f'{_(1)}from atproto.xrpc_client.models.unknown_type import UnknownType',
f'{_(1)}from atproto.xrpc_client.models.unknown_type import UnknownInputType',
f'{_(1)}from atproto.xrpc_client.models.blob_ref import BlobRef',
f'{_(1)}from atproto import CIDType',
'from atproto.xrpc_client.models import base',
Expand All @@ -78,7 +86,7 @@ def _save_code_import_if_not_exist(nsid: NSID) -> None:


def _get_model_class_def(name: str, model_type: ModelType) -> str:
lines = []
lines: t.List[str] = []

if model_type is ModelType.PARAMS:
lines.append(f'class {PARAMS_MODEL}(base.ParamsModelBase):')
Expand All @@ -96,6 +104,19 @@ def _get_model_class_def(name: str, model_type: ModelType) -> str:
return join_code(lines)


def _get_typeddict_class_def(name: str, model_type: TypedDictType) -> str:
lines: t.List[str] = []

if model_type is TypedDictType.PARAMS:
lines.append(f'class {PARAMS_DICT}(te.TypedDict):')
elif model_type is TypedDictType.DATA:
lines.append(f'class {INPUT_DICT}(te.TypedDict):')

lines.append('')

return join_code(lines)


_LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT = {
models.LexString: 'str',
models.LexInteger: 'int',
Expand Down Expand Up @@ -135,19 +156,23 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> st
return _get_optional_typehint(annotated_union, optional=optional)


def _get_model_field_typehint(nsid: NSID, field_type_def, *, optional: bool) -> str:
def _get_model_field_typehint(nsid: NSID, field_type_def, *, optional: bool, is_input_type: bool = False) -> str:
field_type = type(field_type_def)

if field_type == models.LexUnknown:
# unknown type is a generic response with records or any not described type in the lexicon. for example, didDoc
if is_input_type:
return _get_optional_typehint("'UnknownInputType'", optional=optional)
return _get_optional_typehint("'UnknownType'", optional=optional)

type_hint = _LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT.get(field_type)
if type_hint:
return _get_optional_typehint(type_hint, optional=optional)

if field_type is models.LexArray:
items_type_hint = _get_model_field_typehint(nsid, field_type_def.items, optional=False)
items_type_hint = _get_model_field_typehint(
nsid, field_type_def.items, optional=False, is_input_type=is_input_type
)
return _get_optional_typehint(f't.List[{items_type_hint}]', optional=optional)

if field_type is models.LexRef:
Expand Down Expand Up @@ -308,7 +333,9 @@ def _is_reserved_pydantic_name(name: str) -> bool:
return name in _get_pydantic_reserved_names()


def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcParameters]) -> str:
def _get_model(
nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcParameters], *, is_input_type: bool = False
) -> str:
required_fields = _get_req_fields_set(lex_object)

fields = []
Expand All @@ -329,7 +356,7 @@ def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcP
snake_cased_field_name += '_' # add underscore to the end
alias_name = field_name

type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional)
type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional, is_input_type=is_input_type)
value = _get_model_field_value(field_type_def, alias_name, optional=is_optional)
description = _get_field_docstring(field_name, field_type_def)

Expand All @@ -351,6 +378,44 @@ def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcP
return join_code(fields)


def _get_typeddict(
nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcParameters], *, is_input_type: bool = False
) -> str:
required_fields = _get_req_fields_set(lex_object)

fields: t.List[str] = []
optional_fields: t.List[str] = []

for field_name, field_type_def in lex_object.properties.items():
is_optional = field_name not in required_fields

snake_cased_field_name = convert_camel_case_to_snake_case(field_name)

type_hint = _get_model_field_typehint(nsid, field_type_def, optional=is_optional, is_input_type=is_input_type)
description = _get_field_docstring(field_name, field_type_def)

# Allow optional params to actually be ommitted from the dict entirely
type_hint_defaulting = f'te.NotRequired[{type_hint}]' if is_optional else type_hint
field_def = f'{_(1)}{snake_cased_field_name}: {type_hint_defaulting} #: {description}'

if is_optional:
optional_fields.append(field_def)
else:
fields.append(field_def)

optional_fields.sort()
fields.sort()

fields.extend(optional_fields)

if len(fields) == 0:
fields.append(f'{_(1)}pass')

fields.append('')

return join_code(fields)


def _get_model_raw_data(name: str) -> str:
lines = [f'#: {name} raw data type.', f'{name}: te.TypeAlias = bytes\n\n']
return join_code(lines)
Expand All @@ -361,7 +426,12 @@ def _generate_params_model(nsid: NSID, definition: t.Union[models.LexXrpcQuery,

if definition.parameters:
lines.append(_get_model_docstring(nsid, definition.parameters, ModelType.PARAMS))
lines.append(_get_model(nsid, definition.parameters))
lines.append(_get_model(nsid, definition.parameters, is_input_type=True))

lines.append(_get_typeddict_class_def(nsid.name, TypedDictType.PARAMS))

if definition.parameters:
lines.append(_get_typeddict(nsid, definition.parameters, is_input_type=True))

return join_code(lines)

Expand All @@ -372,7 +442,7 @@ def _generate_xrpc_body_model(nsid: NSID, body: models.LexXrpcBody, model_type:
if isinstance(body.schema, models.LexObject):
lines.append(_get_model_class_def(nsid.name, model_type))
lines.append(_get_model_docstring(nsid, body.schema, model_type))
lines.append(_get_model(nsid, body.schema))
lines.append(_get_model(nsid, body.schema, is_input_type=(model_type is ModelType.DATA)))
else:
if model_type is ModelType.DATA:
model_name = INPUT_MODEL
Expand All @@ -386,8 +456,18 @@ def _generate_xrpc_body_model(nsid: NSID, body: models.LexXrpcBody, model_type:
return join_code(lines)


def _generate_data_typedict(nsid: NSID, body: models.LexXrpcBody) -> str:
lines: t.List[str] = []
if isinstance(body.schema, models.LexObject):
lines.append(_get_typeddict_class_def(nsid.name, TypedDictType.DATA))
lines.append(_get_typeddict(nsid, body.schema, is_input_type=True))
return join_code(lines)


def _generate_data_model(nsid: NSID, input_body: models.LexXrpcBody) -> str:
return _generate_xrpc_body_model(nsid, input_body, ModelType.DATA)
return join_code(
[_generate_xrpc_body_model(nsid, input_body, ModelType.DATA), _generate_data_typedict(nsid, input_body)]
)


def _generate_response_model(nsid: NSID, output_body: models.LexXrpcBody) -> str:
Expand Down Expand Up @@ -560,6 +640,9 @@ def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None:
unknown_record_type_pydantic_lines.append(
"UnknownType: te.TypeAlias = t.Union[UnknownRecordTypePydantic, 'dot_dict.DotDictType']"
)
unknown_record_type_pydantic_lines.append(
'UnknownInputType: te.TypeAlias = t.Union[UnknownType, t.Dict[str, t.Any]]'
)
unknown_type_lines = [*import_lines, *unknown_record_type_hint_lines, *unknown_record_type_pydantic_lines]

write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(type_conversion_lines))
Expand Down
20 changes: 15 additions & 5 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from atproto.codegen import (
DISCLAIMER,
INPUT_DICT,
INPUT_MODEL,
OUTPUT_MODEL,
PARAMS_DICT,
PARAMS_MODEL,
_resolve_nsid_ref,
convert_camel_case_to_snake_case,
Expand Down Expand Up @@ -175,13 +177,17 @@ def _override_arg_line(name: str, model_name: str) -> str:


def _get_namespace_method_signature_arg(
name: str, nsid: NSID, model_name: str, *, optional: bool, alias: bool = False
name: str, nsid: NSID, model_name: t.Union[t.List[str], str], *, optional: bool, alias: bool = False
) -> str:
if alias:
return f"{name}: 'models.{get_import_path(nsid)}.{model_name}'"

default_value = ''
type_hint = f"t.Union[dict, 'models.{get_import_path(nsid)}.{model_name}']"
type_hint = (
f"t.Union[dict, 'models.{get_import_path(nsid)}.{model_name}']"
if isinstance(model_name, str)
else 't.Union[' + ', '.join(f'models.{get_import_path(nsid)}.{i}' for i in model_name) + ']'
)
if optional:
type_hint = f't.Optional[{type_hint}]'
default_value = ' = None'
Expand Down Expand Up @@ -226,7 +232,9 @@ def is_optional_arg(lex_obj) -> bool:
params = method_info.definition.parameters
is_optional = is_optional_arg(params)

arg = _get_namespace_method_signature_arg('params', method_info.nsid, PARAMS_MODEL, optional=is_optional)
arg = _get_namespace_method_signature_arg(
'params', method_info.nsid, [PARAMS_MODEL, PARAMS_DICT], optional=is_optional
)
_add_arg(arg, optional=is_optional)

if isinstance(method_info, ProcedureInfo) and method_info.definition.input:
Expand All @@ -235,7 +243,9 @@ def is_optional_arg(lex_obj) -> bool:
is_optional = is_optional_arg(schema)

if schema and isinstance(schema, LexObject):
arg = _get_namespace_method_signature_arg('data', method_info.nsid, INPUT_MODEL, optional=is_optional)
arg = _get_namespace_method_signature_arg(
'data', method_info.nsid, [INPUT_MODEL, INPUT_DICT], optional=is_optional
)
_add_arg(arg, optional=is_optional)
else:
raise ValueError(f'Bad type {type(schema)}') # probably LexRefVariant
Expand All @@ -244,7 +254,7 @@ def is_optional_arg(lex_obj) -> bool:
_add_arg(arg, optional=False)

args.extend(optional_args)
args.append('**kwargs')
args.append('**kwargs: t.Any')
return ', '.join(args)


Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
from atproto.xrpc_client import models
from atproto.xrpc_client.models import base
Expand All @@ -17,6 +19,10 @@ class Params(base.ParamsModelBase):
"""Parameters model for :obj:`app.bsky.actor.getPreferences`."""


class ParamsDict(te.TypedDict):
pass


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.getPreferences`."""
Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
pass
from atproto.xrpc_client.models import base
Expand All @@ -17,3 +19,7 @@ class Params(base.ParamsModelBase):
"""Parameters model for :obj:`app.bsky.actor.getProfile`."""

actor: str #: Actor.


class ParamsDict(te.TypedDict):
actor: str #: Actor.
5 changes: 5 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import typing as t

import typing_extensions as te
from pydantic import Field

if t.TYPE_CHECKING:
Expand All @@ -21,6 +22,10 @@ class Params(base.ParamsModelBase):
actors: t.List[str] = Field(max_length=25) #: Actors.


class ParamsDict(te.TypedDict):
actors: t.List[str] #: Actors.


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.getProfiles`."""
Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/get_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import typing as t

import typing_extensions as te
from pydantic import Field

if t.TYPE_CHECKING:
Expand All @@ -22,6 +23,11 @@ class Params(base.ParamsModelBase):
limit: t.Optional[int] = Field(default=50, ge=1, le=100) #: Limit.


class ParamsDict(te.TypedDict):
cursor: te.NotRequired[t.Optional[str]] #: Cursor.
limit: te.NotRequired[t.Optional[int]] #: Limit.


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.getSuggestions`."""
Expand Down
6 changes: 6 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/put_preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
from atproto.xrpc_client import models
from atproto.xrpc_client.models import base
Expand All @@ -17,3 +19,7 @@ class Data(base.DataModelBase):
"""Input data model for :obj:`app.bsky.actor.putPreferences`."""

preferences: 'models.AppBskyActorDefs.Preferences' #: Preferences.


class DataDict(te.TypedDict):
preferences: 'models.AppBskyActorDefs.Preferences' #: Preferences.
7 changes: 7 additions & 0 deletions atproto/xrpc_client/models/app/bsky/actor/search_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import typing as t

import typing_extensions as te
from pydantic import Field

if t.TYPE_CHECKING:
Expand All @@ -23,6 +24,12 @@ class Params(base.ParamsModelBase):
term: t.Optional[str] = None #: Term.


class ParamsDict(te.TypedDict):
cursor: te.NotRequired[t.Optional[str]] #: Cursor.
limit: te.NotRequired[t.Optional[int]] #: Limit.
term: te.NotRequired[t.Optional[str]] #: Term.


class Response(base.ResponseModelBase):

"""Output data model for :obj:`app.bsky.actor.searchActors`."""
Expand Down
Loading

0 comments on commit dab0fd0

Please sign in to comment.