diff --git a/src/openapi_python_generator/generate_data.py b/src/openapi_python_generator/generate_data.py index fbce719..e637012 100644 --- a/src/openapi_python_generator/generate_data.py +++ b/src/openapi_python_generator/generate_data.py @@ -7,6 +7,7 @@ import httpx import isort import orjson +import yaml from black import NothingChanged from httpx import ConnectError from httpx import ConnectTimeout @@ -45,30 +46,61 @@ def write_code(path: Path, content) -> None: def get_open_api(source: Union[str, Path]) -> OpenAPI: """ - Tries to fetch the openapi.json file from the web or load from a local file. Returns the according OpenAPI object. - :param source: - :return: + Tries to fetch the openapi specification file from the web or load from a local file. + Supports both JSON and YAML formats. Returns the according OpenAPI object. + + Args: + source: URL or file path to the OpenAPI specification + + Returns: + OpenAPI: Parsed OpenAPI specification object + + Raises: + FileNotFoundError: If the specified file cannot be found + ConnectError: If the URL cannot be accessed + ValidationError: If the specification is invalid + JSONDecodeError/YAMLError: If the file cannot be parsed """ try: + # Handle remote files if not isinstance(source, Path) and ( - source.startswith("http://") or source.startswith("https://") + source.startswith("http://") or source.startswith("https://") ): - return OpenAPI(**orjson.loads(httpx.get(source).text)) + content = httpx.get(source).text + # Try JSON first, then YAML for remote files + try: + return OpenAPI(**orjson.loads(content)) + except orjson.JSONDecodeError: + return OpenAPI(**yaml.safe_load(content)) + # Handle local files with open(source, "r") as f: file_content = f.read() - return OpenAPI(**orjson.loads(file_content)) + + # Try JSON first + try: + return OpenAPI(**orjson.loads(file_content)) + except orjson.JSONDecodeError: + # If JSON fails, try YAML + try: + return OpenAPI(**yaml.safe_load(file_content)) + except yaml.YAMLError as e: + click.echo( + f"File {source} is neither a valid JSON nor YAML file: {str(e)}" + ) + raise + except FileNotFoundError: click.echo( - f"File {source} not found. Please make sure to pass the path to the OpenAPI 3.0 specification." + f"File {source} not found. Please make sure to pass the path to the OpenAPI specification." ) raise except (ConnectError, ConnectTimeout): click.echo(f"Could not connect to {source}.") raise ConnectError(f"Could not connect to {source}.") from None - except (ValidationError, orjson.JSONDecodeError): + except ValidationError: click.echo( - f"File {source} is not a valid OpenAPI 3.0 specification, or there may be a problem with your JSON." + f"File {source} is not a valid OpenAPI 3.0 specification." ) raise diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 693990d..4214941 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -1,7 +1,9 @@ import shutil import pytest +import yaml from httpx import ConnectError +from orjson import orjson from pydantic import ValidationError from openapi_python_generator.common import HTTPLibrary @@ -16,14 +18,31 @@ def test_get_open_api(model_data): + # Test JSON file assert get_open_api(test_data_path) == model_data + # Create YAML version of the test file + yaml_path = test_data_path.with_suffix('.yaml') + with open(test_data_path) as f: + json_content = orjson.loads(f.read()) + with open(yaml_path, 'w') as f: + yaml.dump(json_content, f) + + # Test YAML file + assert get_open_api(yaml_path) == model_data + + # Cleanup YAML file + yaml_path.unlink() + + # Test remote file failure with pytest.raises(ConnectError): assert get_open_api("http://localhost:8080/api/openapi.json") + # Test invalid OpenAPI spec with pytest.raises(ValidationError): assert get_open_api(test_data_folder / "failing_api.json") + # Test non-existent file with pytest.raises(FileNotFoundError): assert get_open_api(test_data_folder / "file_does_not_exist.json")