From 3272cbf48e9086581ec5f43d67d7c733ef38767f Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Thu, 18 Jun 2020 01:35:16 +0900 Subject: [PATCH 1/3] add field to argument --- fastapi_code_generator/parser.py | 97 +++++++++++++++----------------- 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index 7092cfa..aaacf69 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -2,7 +2,7 @@ import re from contextvars import ContextVar -from functools import cached_property +from functools import cached_property, lru_cache from typing import Any, Dict, List, Optional, Union import stringcase @@ -65,11 +65,20 @@ def camelcase(self) -> str: return stringcase.camelcase(self) -class Argument(BaseModel): +class Argument(CachedPropertyModel): name: UsefulStr + type_hint: UsefulStr + default: Optional[UsefulStr] + required: bool + + def __str__(self) -> str: + return self.argument - # def __str__(self) -> UsefulStr: - # return self.name + @cached_property + def argument(self) -> str: + if not self.default and self.required: + return f'{self.name}: {self.type_hint}' + return f'{self.name}: {self.type_hint} = {self.default}' class Operation(CachedPropertyModel): @@ -93,22 +102,28 @@ def snake_case_path(self) -> str: ) @cached_property - def request(self) -> Optional[str]: - models: List[str] = [] + def request(self) -> Optional[Argument]: + arguments: List[Argument] = [] for requests in self.request_objects: for content_type, schema in requests.contents.items(): + # TODO: support other content-types if content_type == "application/json": - models.append(schema.ref_object_name) + arguments.append( + # TODO: support multiple body + Argument( + name='body', # type: ignore + type_hint=schema.ref_object_name, # type: ignore + required=requests.required, + ) + ) self.imports.append( Import( from_=model_path_var.get(), import_=schema.ref_object_name ) ) - if not models: + if not arguments: return None - if len(models) > 1: - return f'Union[{",".join(models)}]' - return models[0] + return arguments[0] @cached_property def request_objects(self) -> List[Request]: @@ -171,69 +186,47 @@ def snake_case_arguments(self) -> str: return self.get_arguments(snake_case=True) def get_arguments(self, snake_case: bool) -> str: - arguments: List[str] = [] - - if self.parameters: - for parameter in self.parameters: - arguments.append(self.get_parameter_type(parameter, snake_case)) - - if self.request: - arguments.append(f"body: {self.request}") - - return ", ".join(arguments) + return ", ".join( + argument.argument for argument in self.get_argument_list(snake_case) + ) @cached_property def argument_list(self) -> List[Argument]: + return self.get_argument_list(False) + + def get_argument_list(self, snake_case: bool) -> List[Argument]: arguments: List[Argument] = [] if self.parameters: for parameter in self.parameters: - arguments.append(Argument.parse_obj(parameter)) + arguments.append(self.get_parameter_type(parameter, snake_case)) if self.request: - arguments.append(Argument(name=UsefulStr('body'))) - + arguments.append(self.request) return arguments def get_parameter_type( self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool - ) -> str: + ) -> Argument: schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"]) format_ = schema.format or "default" type_ = json_schema_data_formats[schema.type][format_] - return self.get_data_type_hint( - name=stringcase.snakecase(parameter["name"]) - if snake_case - else parameter["name"], + name: str = parameter["name"] + + field = DataModelField( + name=stringcase.snakecase(name) if snake_case else name, data_types=[type_map[type_]], required=parameter.get("required") == "true" or parameter.get("in") == "path", - snake_case=snake_case, default=schema.typed_default, ) - - def get_data_type_hint( - self, - name: str, - data_types: List[DataType], - required: bool, - snake_case: bool, - default: Optional[str] = None, - auto_import: bool = True, - ) -> str: - field = DataModelField( - name=stringcase.snakecase(name) if snake_case else name, - data_types=data_types, - required=required, - default=default, + self.imports.extend(field.imports) + return Argument( + name=field.name, # type: ignore + type_hint=field.type_hint, # type: ignore + default=field.default, # type: ignore + required=field.required, # type: ignore ) - if auto_import: - self.imports.extend(field.imports) - - if not default and field.required: - return f"{field.name}: {field.type_hint}" - - return f'{field.name}: {field.type_hint} = {default}' @cached_property def response(self) -> str: From 46720b5db7c75de6f735affc3c42a08c5168ae07 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Thu, 18 Jun 2020 01:37:03 +0900 Subject: [PATCH 2/3] remove unused import --- fastapi_code_generator/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index aaacf69..d1c07ca 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -2,7 +2,7 @@ import re from contextvars import ContextVar -from functools import cached_property, lru_cache +from functools import cached_property from typing import Any, Dict, List, Optional, Union import stringcase From d49a7d4134652ca0702f5e300f62abe4c11d83f8 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Thu, 18 Jun 2020 01:46:01 +0900 Subject: [PATCH 3/3] fix types --- fastapi_code_generator/parser.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index d1c07ca..08ae197 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -112,7 +112,7 @@ def request(self) -> Optional[Argument]: # TODO: support multiple body Argument( name='body', # type: ignore - type_hint=schema.ref_object_name, # type: ignore + type_hint=schema.ref_object_name, required=requests.required, ) ) @@ -211,7 +211,7 @@ def get_parameter_type( schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"]) format_ = schema.format or "default" type_ = json_schema_data_formats[schema.type][format_] - name: str = parameter["name"] + name: str = parameter["name"] # type: ignore field = DataModelField( name=stringcase.snakecase(name) if snake_case else name, @@ -222,10 +222,10 @@ def get_parameter_type( ) self.imports.extend(field.imports) return Argument( - name=field.name, # type: ignore - type_hint=field.type_hint, # type: ignore - default=field.default, # type: ignore - required=field.required, # type: ignore + name=field.name, + type_hint=field.type_hint, + default=field.default, + required=field.required, ) @cached_property