diff --git a/tests/test_google_credentials.py b/tests/test_google_credentials.py index 213168689..5a5fb1e44 100644 --- a/tests/test_google_credentials.py +++ b/tests/test_google_credentials.py @@ -1,4 +1,37 @@ -from toucan_connectors.google_credentials import GoogleCredentials +import json + +from pytest_mock import MockFixture + +from toucan_connectors.google_credentials import GoogleCredentials, get_google_oauth2_credentials + + +def test_google_credentials(mocker: MockFixture): + conf = { + "type": "service_account", + "project_id": "my_project_id", + "private_key_id": "my_private_key_id", + "private_key": "-----BEGIN PRIVATE KEY-----\naaa\nbbb\n-----END PRIVATE KEY-----\n", + "client_email": "my_client_email", + "client_id": "my_client_id", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/xxx.iam.gserviceaccount.com", # noqa: E501 + } + credentials = GoogleCredentials(**conf) + # Ensure `private_key_id` and `private_key` are masked + assert credentials.model_dump_json() == json.dumps( + { + **conf, + "private_key_id": "**********", + "private_key": "**********", + }, + separators=(",", ":"), + ) + # Ensure `Credentials` is called with the right values of secrets + mock_credentials = mocker.patch("toucan_connectors.google_credentials.Credentials") + get_google_oauth2_credentials(credentials) + mock_credentials.from_service_account_info.assert_called_once_with(conf) def test_unespace_break_lines(): @@ -15,4 +48,9 @@ def test_unespace_break_lines(): "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/xxx.iam.gserviceaccount.com", # noqa: E501 } credentials = GoogleCredentials(**conf) - assert credentials.private_key == "-----BEGIN PRIVATE KEY-----\n" "aaa\n" "bbb\n" "-----END PRIVATE KEY-----\n" + assert ( + credentials.private_key.get_secret_value() == "-----BEGIN PRIVATE KEY-----\n" + "aaa\n" + "bbb\n" + "-----END PRIVATE KEY-----\n" + ) diff --git a/toucan_connectors/google_credentials.py b/toucan_connectors/google_credentials.py index b2a1a36bd..2d6136da6 100644 --- a/toucan_connectors/google_credentials.py +++ b/toucan_connectors/google_credentials.py @@ -1,5 +1,7 @@ +from typing import Annotated + from google.oauth2.service_account import Credentials -from pydantic import BaseModel, Field, HttpUrl, field_validator +from pydantic import BaseModel, Field, HttpUrl, PlainSerializer, SecretStr, field_validator CREDENTIALS_INFO_MESSAGE = ( "This information is provided in your " @@ -22,11 +24,15 @@ class JWTCredentials(BaseModel): ) +# The lambda is ugly but pydantic does signature inspection, which does not work with built-in types +StrLikeHttpUrl = Annotated[HttpUrl, PlainSerializer(lambda x: str(x), return_type=str, when_used="always")] + + class GoogleCredentials(BaseModel): type: str = Field("service_account", title="Service account", description=CREDENTIALS_INFO_MESSAGE) project_id: str = Field(..., title="Project ID", description=CREDENTIALS_INFO_MESSAGE) - private_key_id: str = Field(..., title="Private Key ID", description=CREDENTIALS_INFO_MESSAGE) - private_key: str = Field( + private_key_id: SecretStr = Field(..., title="Private Key ID", description=CREDENTIALS_INFO_MESSAGE) + private_key: SecretStr = Field( ..., title="Private Key", description=f"A private key in the form " @@ -34,22 +40,22 @@ class GoogleCredentials(BaseModel): ) client_email: str = Field(..., title="Client email", description=CREDENTIALS_INFO_MESSAGE) client_id: str = Field(..., title="Client ID", description=CREDENTIALS_INFO_MESSAGE) - auth_uri: HttpUrl = Field( + auth_uri: StrLikeHttpUrl = Field( "https://accounts.google.com/o/oauth2/auth", title="Authentication URI", description=CREDENTIALS_INFO_MESSAGE, ) - token_uri: HttpUrl = Field( + token_uri: StrLikeHttpUrl = Field( "https://oauth2.googleapis.com/token", title="Token URI", description=f"{CREDENTIALS_INFO_MESSAGE}. You should not need to change the default value.", ) - auth_provider_x509_cert_url: HttpUrl = Field( + auth_provider_x509_cert_url: StrLikeHttpUrl = Field( "https://www.googleapis.com/oauth2/v1/certs", title="Authentication provider X509 certificate URL", description=f"{CREDENTIALS_INFO_MESSAGE}. You should not need to change the default value.", ) - client_x509_cert_url: HttpUrl = Field( + client_x509_cert_url: StrLikeHttpUrl = Field( "https://www.client_cert.test", title="Client X509 certification URL", description=CREDENTIALS_INFO_MESSAGE, @@ -57,15 +63,18 @@ class GoogleCredentials(BaseModel): @field_validator("private_key") @classmethod - def unescape_break_lines(cls, v): + def unescape_break_lines(cls, v: SecretStr) -> SecretStr: """ `private_key` is a long string like '-----BEGIN PRIVATE KEY-----\nxxx...zzz\n-----END PRIVATE KEY-----\n As the breaking line are often escaped by the client, we need to be sure it's unescaped """ - return v.replace("\\n", "\n") + return SecretStr(v.get_secret_value().replace("\\n", "\n")) def get_google_oauth2_credentials(google_credentials: GoogleCredentials) -> Credentials: - return Credentials.from_service_account_info(google_credentials.dict()) + creds = google_credentials.model_dump() + for secret_field in ("private_key_id", "private_key"): + creds[secret_field] = creds[secret_field].get_secret_value() + return Credentials.from_service_account_info(creds)