diff --git a/README.md b/README.md index ce7477c..c4fc8db 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ Options: -i, --input FILENAME [required] -o, --output PATH [required] -t, --template-dir PATH + -m, --model-file Specify generated model file path + name, if not default to models.py --install-completion Install completion for the current shell. --show-completion Show completion for the current shell, to copy it or customize the installation. diff --git a/fastapi_code_generator/__main__.py b/fastapi_code_generator/__main__.py index b42a783..cf314df 100644 --- a/fastapi_code_generator/__main__.py +++ b/fastapi_code_generator/__main__.py @@ -16,18 +16,20 @@ BUILTIN_TEMPLATE_DIR = Path(__file__).parent / "template" -MODEL_PATH: Path = Path("models.py") - - @app.command() def main( input_file: typer.FileText = typer.Option(..., "--input", "-i"), output_dir: Path = typer.Option(..., "--output", "-o"), + model_file: str = typer.Option(None, "--model-file", "-m"), template_dir: Optional[Path] = typer.Option(None, "--template-dir", "-t"), ) -> None: input_name: str = input_file.name input_text: str = input_file.read() - return generate_code(input_name, input_text, output_dir, template_dir) + if model_file: + model_path = Path(f"{model_file}.py") + else: + model_path = Path("models.py") + return generate_code(input_name, input_text, output_dir, template_dir, model_path) def _get_most_of_reference(data_type: DataType) -> Optional[Reference]: @@ -41,7 +43,7 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]: def generate_code( - input_name: str, input_text: str, output_dir: Path, template_dir: Optional[Path] + input_name: str, input_text: str, output_dir: Path, template_dir: Optional[Path], model_path: Path ) -> None: if not output_dir.exists(): output_dir.mkdir(parents=True) @@ -54,7 +56,7 @@ def generate_code( if not models: return elif isinstance(models, str): - output = output_dir / MODEL_PATH + output = output_dir / model_path modules = {output: (models, input_name)} else: raise Exception('Modular references are not supported in this version') @@ -72,7 +74,7 @@ def generate_code( if reference: imports.append(data_type.all_imports) imports.append( - Import.from_full_path(f'.{MODEL_PATH.stem}.{reference.name}') + Import.from_full_path(f'.{model_path.stem}.{reference.name}') ) for from_, imports_ in parser.imports_for_fastapi.items(): imports[from_].update(imports_)