diff --git a/fastapi_code_generator/__main__.py b/fastapi_code_generator/__main__.py index 6631be1..3e0732b 100644 --- a/fastapi_code_generator/__main__.py +++ b/fastapi_code_generator/__main__.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional import typer -from datamodel_code_generator import PythonVersion, chdir +from datamodel_code_generator import LiteralType, PythonVersion, chdir from datamodel_code_generator.format import CodeFormatter from datamodel_code_generator.imports import Import, Imports from datamodel_code_generator.reference import Reference @@ -25,6 +25,9 @@ def main( output_dir: Path = typer.Option(..., "--output", "-o"), model_file: str = typer.Option(None, "--model-file", "-m"), template_dir: Optional[Path] = typer.Option(None, "--template-dir", "-t"), + enum_field_as_literal: Optional[LiteralType] = typer.Option( + None, "--enum-field-as-literal" + ), ) -> None: input_name: str = input_file.name input_text: str = input_file.read() @@ -32,6 +35,15 @@ def main( model_path = Path(model_file).with_suffix('.py') else: model_path = MODEL_PATH + if enum_field_as_literal: + return generate_code( + input_name, + input_text, + output_dir, + template_dir, + model_path, + enum_field_as_literal, + ) return generate_code(input_name, input_text, output_dir, template_dir, model_path) @@ -51,6 +63,7 @@ def generate_code( output_dir: Path, template_dir: Optional[Path], model_path: Optional[Path] = None, + enum_field_as_literal: Optional[str] = None, ) -> None: if not model_path: model_path = MODEL_PATH @@ -58,8 +71,10 @@ def generate_code( output_dir.mkdir(parents=True) if not template_dir: template_dir = BUILTIN_TEMPLATE_DIR - - parser = OpenAPIParser(input_text) + if enum_field_as_literal: + parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) + else: + parser = OpenAPIParser(input_text) with chdir(output_dir): models = parser.parse() if not models: