Skip to content

Commit

Permalink
Merge pull request #167 from koxudaxi/refactor_parser
Browse files Browse the repository at this point in the history
Refactor parser
  • Loading branch information
koxudaxi authored Aug 1, 2021
2 parents d4b2cf5 + 84d3a44 commit a143563
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 442 deletions.
62 changes: 42 additions & 20 deletions fastapi_code_generator/__main__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional
from typing import Dict, List, Optional

import typer
from datamodel_code_generator import PythonVersion, chdir
from datamodel_code_generator.format import CodeFormatter
from datamodel_code_generator.parser.openapi import OpenAPIParser as OpenAPIModelParser
from datamodel_code_generator.imports import Import, Imports
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import DataType
from jinja2 import Environment, FileSystemLoader

from fastapi_code_generator.parser import MODEL_PATH, OpenAPIParser, ParsedObject
from fastapi_code_generator.parser import OpenAPIParser, Operation

app = typer.Typer()

BUILTIN_TEMPLATE_DIR = Path(__file__).parent / "template"

MODEL_PATH: Path = Path("models.py")


@app.command()
def main(
Expand All @@ -26,6 +30,16 @@ def main(
return generate_code(input_name, input_text, output_dir, template_dir)


def _get_most_of_reference(data_type: DataType) -> Optional[Reference]:
if data_type.reference:
return data_type.reference
for data_type in data_type.data_types:
reference = _get_most_of_reference(data_type)
if reference:
return reference
return None


def generate_code(
input_name: str, input_text: str, output_dir: Path, template_dir: Optional[Path]
) -> None:
Expand All @@ -34,25 +48,43 @@ def generate_code(
if not template_dir:
template_dir = BUILTIN_TEMPLATE_DIR

model_parser = OpenAPIModelParser(source=input_text,)

parser = OpenAPIParser(input_name, input_text, openapi_model_parser=model_parser)
parsed_object: ParsedObject = parser.parse()
parser = OpenAPIParser(input_text)
with chdir(output_dir):
models = parser.parse()
if not models:
return
elif isinstance(models, str):
output = output_dir / MODEL_PATH
modules = {output: (models, input_name)}
else:
raise Exception('Modular references are not supported in this version')

environment: Environment = Environment(
loader=FileSystemLoader(
template_dir if template_dir else f"{Path(__file__).parent}/template",
encoding="utf8",
),
)
imports = Imports()
imports.update(parser.imports)
for data_type in parser.data_types:
reference = _get_most_of_reference(data_type)
if reference:
imports.append(data_type.all_imports)
imports.append(
Import.from_full_path(f'.{MODEL_PATH.stem}.{reference.name}')
)
for from_, imports_ in parser.imports_for_fastapi.items():
imports[from_].update(imports_)
results: Dict[Path, str] = {}
code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
sorted_operations: List[Operation] = sorted(
parser.operations.values(), key=lambda m: m.path
)
for target in template_dir.rglob("*"):
relative_path = target.relative_to(template_dir)
result = environment.get_template(str(relative_path)).render(
operations=parsed_object.operations,
imports=parsed_object.imports,
info=parsed_object.info,
operations=sorted_operations, imports=imports, info=parser.parse_info(),
)
results[relative_path] = code_formatter.format_code(result)

Expand All @@ -68,16 +100,6 @@ def generate_code(
print("", file=file)
print(code.rstrip(), file=file)

with chdir(output_dir):
results = model_parser.parse()
if not results:
return
elif isinstance(results, str):
output = output_dir / MODEL_PATH
modules = {output: (results, input_name)}
else:
raise Exception('Modular references are not supported in this version')

header = f'''\
# generated by fastapi-codegen:
# filename: {{filename}}'''
Expand Down
Loading

0 comments on commit a143563

Please sign in to comment.