Skip to content

Commit

Permalink
Merge pull request #204 from baophamtd/add-option-to-specify-model-fi…
Browse files Browse the repository at this point in the history
…le-path

Added option to specify model file instead of defaulting to models.py
  • Loading branch information
koxudaxi authored Nov 28, 2021
2 parents 8a03243 + de846a2 commit 4f93df2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Options:
-i, --input FILENAME [required]
-o, --output PATH [required]
-t, --template-dir PATH
-m, --model-file Specify generated model file path + name, if not default to models.py
--install-completion Install completion for the current shell.
--show-completion Show completion for the current shell, to copy it
or customize the installation.
Expand Down
16 changes: 9 additions & 7 deletions fastapi_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@

BUILTIN_TEMPLATE_DIR = Path(__file__).parent / "template"

MODEL_PATH: Path = Path("models.py")


@app.command()
def main(
input_file: typer.FileText = typer.Option(..., "--input", "-i"),
output_dir: Path = typer.Option(..., "--output", "-o"),
model_file: str = typer.Option(None, "--model-file", "-m"),
template_dir: Optional[Path] = typer.Option(None, "--template-dir", "-t"),
) -> None:
input_name: str = input_file.name
input_text: str = input_file.read()
return generate_code(input_name, input_text, output_dir, template_dir)
if model_file:
model_path = Path(f"{model_file}.py")
else:
model_path = Path("models.py")
return generate_code(input_name, input_text, output_dir, template_dir, model_path)


def _get_most_of_reference(data_type: DataType) -> Optional[Reference]:
Expand All @@ -41,7 +43,7 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]:


def generate_code(
input_name: str, input_text: str, output_dir: Path, template_dir: Optional[Path]
input_name: str, input_text: str, output_dir: Path, template_dir: Optional[Path], model_path: Path
) -> None:
if not output_dir.exists():
output_dir.mkdir(parents=True)
Expand All @@ -54,7 +56,7 @@ def generate_code(
if not models:
return
elif isinstance(models, str):
output = output_dir / MODEL_PATH
output = output_dir / model_path
modules = {output: (models, input_name)}
else:
raise Exception('Modular references are not supported in this version')
Expand All @@ -72,7 +74,7 @@ def generate_code(
if reference:
imports.append(data_type.all_imports)
imports.append(
Import.from_full_path(f'.{MODEL_PATH.stem}.{reference.name}')
Import.from_full_path(f'.{model_path.stem}.{reference.name}')
)
for from_, imports_ in parser.imports_for_fastapi.items():
imports[from_].update(imports_)
Expand Down

0 comments on commit 4f93df2

Please sign in to comment.