From 8727e91739f2d04f2417705e098331362b8ac5c8 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:18:05 -0400 Subject: [PATCH] feat(firestore): Add Firestore Multi Database Support (#818) * Added multi db support for firestore and firestore_async * Added unit and integration tests * fix docs strings --- firebase_admin/firestore.py | 88 ++++++++++++++++----------- firebase_admin/firestore_async.py | 94 ++++++++++++++++------------- integration/test_firestore.py | 55 +++++++++++++++++ integration/test_firestore_async.py | 69 +++++++++++++++++++-- tests/test_firestore.py | 86 ++++++++++++++++++++++++++ tests/test_firestore_async.py | 86 ++++++++++++++++++++++++++ 6 files changed, 396 insertions(+), 82 deletions(-) diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 224ba3aeb..52ea90671 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -18,59 +18,75 @@ Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ +from __future__ import annotations +from typing import Optional, Dict +from firebase_admin import App +from firebase_admin import _utils + try: - from google.cloud import firestore # pylint: disable=import-error,no-name-in-module + from google.cloud import firestore + from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE existing = globals().keys() for key, value in firestore.__dict__.items(): if not key.startswith('_') and key not in existing: globals()[key] = value -except ImportError: +except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' - 'to install the "google-cloud-firestore" module.') - -from firebase_admin import _utils + 'to install the "google-cloud-firestore" module.') from error _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app=None) -> firestore.Client: +def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: - app: An App instance (optional). + app: An App instance (optional). + database_id: The database ID of the Google Cloud Firestore database to be used. + Defaults to the default Firestore database ID if not specified or an empty string + (optional). Returns: - google.cloud.firestore.Firestore: A `Firestore Client`_. + google.cloud.firestore.Firestore: A `Firestore Client`_. Raises: - ValueError: If a project ID is not specified either via options, credentials or - environment variables, or if the specified project ID is not a valid string. + ValueError: If the specified database ID is not a valid string, or if a project ID is not + specified either via options, credentials or environment variables, or if the specified + project ID is not a valid string. - .. _Firestore Client: https://googlecloudplatform.github.io/google-cloud-python/latest\ - /firestore/client.html + .. _Firestore Client: https://cloud.google.com/python/docs/reference/firestore/latest/\ + google.cloud.firestore_v1.client.Client """ - fs_client = _utils.get_app_service(app, _FIRESTORE_ATTRIBUTE, _FirestoreClient.from_app) - return fs_client.get() - - -class _FirestoreClient: - """Holds a Google Cloud Firestore client instance.""" - - def __init__(self, credentials, project): - self._client = firestore.Client(credentials=credentials, project=project) - - def get(self): - return self._client - - @classmethod - def from_app(cls, app): - """Creates a new _FirestoreClient for the specified app.""" - credentials = app.credential.get_credential() - project = app.project_id - if not project: - raise ValueError( - 'Project ID is required to access Firestore. Either set the projectId option, ' - 'or use service account credentials. Alternatively, set the GOOGLE_CLOUD_PROJECT ' - 'environment variable.') - return _FirestoreClient(credentials, project) + # Validate database_id + if database_id is not None and not isinstance(database_id, str): + raise ValueError(f'database_id "{database_id}" must be a string or None.') + fs_service = _utils.get_app_service(app, _FIRESTORE_ATTRIBUTE, _FirestoreService) + return fs_service.get_client(database_id) + + +class _FirestoreService: + """Service that maintains a collection of firestore clients.""" + + def __init__(self, app: App) -> None: + self._app: App = app + self._clients: Dict[str, firestore.Client] = {} + + def get_client(self, database_id: Optional[str]) -> firestore.Client: + """Creates a client based on the database_id. These clients are cached.""" + database_id = database_id or DEFAULT_DATABASE + if database_id not in self._clients: + # Create a new client and cache it in _clients + credentials = self._app.credential.get_credential() + project = self._app.project_id + if not project: + raise ValueError( + 'Project ID is required to access Firestore. Either set the projectId option, ' + 'or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + + fs_client = firestore.Client( + credentials=credentials, project=project, database=database_id) + self._clients[database_id] = fs_client + + return self._clients[database_id] diff --git a/firebase_admin/firestore_async.py b/firebase_admin/firestore_async.py index a63d5a761..4a197e9df 100644 --- a/firebase_admin/firestore_async.py +++ b/firebase_admin/firestore_async.py @@ -18,65 +18,75 @@ associated with Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from typing import Type - -from firebase_admin import ( - App, - _utils, -) -from firebase_admin.credentials import Base +from __future__ import annotations +from typing import Optional, Dict +from firebase_admin import App +from firebase_admin import _utils try: - from google.cloud import firestore # type: ignore # pylint: disable=import-error,no-name-in-module + from google.cloud import firestore + from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE existing = globals().keys() for key, value in firestore.__dict__.items(): if not key.startswith('_') and key not in existing: globals()[key] = value -except ImportError: +except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' - 'to install the "google-cloud-firestore" module.') + 'to install the "google-cloud-firestore" module.') from error + _FIRESTORE_ASYNC_ATTRIBUTE: str = '_firestore_async' -def client(app: App = None) -> firestore.AsyncClient: +def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.AsyncClient: """Returns an async client that can be used to interact with Google Cloud Firestore. Args: - app: An App instance (optional). + app: An App instance (optional). + database_id: The database ID of the Google Cloud Firestore database to be used. + Defaults to the default Firestore database ID if not specified or an empty string + (optional). Returns: - google.cloud.firestore.Firestore_Async: A `Firestore Async Client`_. + google.cloud.firestore.Firestore_Async: A `Firestore Async Client`_. Raises: - ValueError: If a project ID is not specified either via options, credentials or - environment variables, or if the specified project ID is not a valid string. + ValueError: If the specified database ID is not a valid string, or if a project ID is not + specified either via options, credentials or environment variables, or if the specified + project ID is not a valid string. - .. _Firestore Async Client: https://googleapis.dev/python/firestore/latest/client.html + .. _Firestore Async Client: https://cloud.google.com/python/docs/reference/firestore/latest/\ + google.cloud.firestore_v1.async_client.AsyncClient """ - fs_client = _utils.get_app_service( - app, _FIRESTORE_ASYNC_ATTRIBUTE, _FirestoreAsyncClient.from_app) - return fs_client.get() - - -class _FirestoreAsyncClient: - """Holds a Google Cloud Firestore Async Client instance.""" - - def __init__(self, credentials: Type[Base], project: str) -> None: - self._client = firestore.AsyncClient(credentials=credentials, project=project) - - def get(self) -> firestore.AsyncClient: - return self._client - - @classmethod - def from_app(cls, app: App) -> "_FirestoreAsyncClient": - # Replace remove future reference quotes by importing annotations in Python 3.7+ b/238779406 - """Creates a new _FirestoreAsyncClient for the specified app.""" - credentials = app.credential.get_credential() - project = app.project_id - if not project: - raise ValueError( - 'Project ID is required to access Firestore. Either set the projectId option, ' - 'or use service account credentials. Alternatively, set the GOOGLE_CLOUD_PROJECT ' - 'environment variable.') - return _FirestoreAsyncClient(credentials, project) + # Validate database_id + if database_id is not None and not isinstance(database_id, str): + raise ValueError(f'database_id "{database_id}" must be a string or None.') + + fs_service = _utils.get_app_service(app, _FIRESTORE_ASYNC_ATTRIBUTE, _FirestoreAsyncService) + return fs_service.get_client(database_id) + +class _FirestoreAsyncService: + """Service that maintains a collection of firestore async clients.""" + + def __init__(self, app: App) -> None: + self._app: App = app + self._clients: Dict[str, firestore.AsyncClient] = {} + + def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: + """Creates an async client based on the database_id. These clients are cached.""" + database_id = database_id or DEFAULT_DATABASE + if database_id not in self._clients: + # Create a new client and cache it in _clients + credentials = self._app.credential.get_credential() + project = self._app.project_id + if not project: + raise ValueError( + 'Project ID is required to access Firestore. Either set the projectId option, ' + 'or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + + fs_client = firestore.AsyncClient( + credentials=credentials, project=project, database=database_id) + self._clients[database_id] = fs_client + + return self._clients[database_id] diff --git a/integration/test_firestore.py b/integration/test_firestore.py index 2bc3d1931..fd39d9b8a 100644 --- a/integration/test_firestore.py +++ b/integration/test_firestore.py @@ -17,6 +17,20 @@ from firebase_admin import firestore +_CITY = { + 'name': u'Mountain View', + 'country': u'USA', + 'population': 77846, + 'capital': False + } + +_MOVIE = { + 'Name': u'Interstellar', + 'Year': 2014, + 'Runtime': u'2h 49m', + 'Academy Award Winner': True + } + def test_firestore(): client = firestore.client() @@ -35,6 +49,47 @@ def test_firestore(): doc.delete() assert doc.get().exists is False +def test_firestore_explicit_database_id(): + client = firestore.client(database_id='testing-database') + expected = _CITY + doc = client.collection('cities').document() + doc.set(expected) + + data = doc.get() + assert data.to_dict() == expected + + doc.delete() + data = doc.get() + assert data.exists is False + +def test_firestore_multi_db(): + city_client = firestore.client() + movie_client = firestore.client(database_id='testing-database') + + expected_city = _CITY + expected_movie = _MOVIE + + city_doc = city_client.collection('cities').document() + movie_doc = movie_client.collection('movies').document() + + city_doc.set(expected_city) + movie_doc.set(expected_movie) + + city_data = city_doc.get() + movie_data = movie_doc.get() + + assert city_data.to_dict() == expected_city + assert movie_data.to_dict() == expected_movie + + city_doc.delete() + movie_doc.delete() + + city_data = city_doc.get() + movie_data = movie_doc.get() + + assert city_data.exists is False + assert movie_data.exists is False + def test_server_timestamp(): client = firestore.client() expected = { diff --git a/integration/test_firestore_async.py b/integration/test_firestore_async.py index 2a5b93217..8b73dda0f 100644 --- a/integration/test_firestore_async.py +++ b/integration/test_firestore_async.py @@ -13,20 +13,31 @@ # limitations under the License. """Integration tests for firebase_admin.firestore_async module.""" +import asyncio import datetime import pytest from firebase_admin import firestore_async -@pytest.mark.asyncio -async def test_firestore_async(): - client = firestore_async.client() - expected = { +_CITY = { 'name': u'Mountain View', 'country': u'USA', 'population': 77846, 'capital': False } + +_MOVIE = { + 'Name': u'Interstellar', + 'Year': 2014, + 'Runtime': u'2h 49m', + 'Academy Award Winner': True + } + + +@pytest.mark.asyncio +async def test_firestore_async(): + client = firestore_async.client() + expected = _CITY doc = client.collection('cities').document() await doc.set(expected) @@ -37,6 +48,56 @@ async def test_firestore_async(): data = await doc.get() assert data.exists is False +@pytest.mark.asyncio +async def test_firestore_async_explicit_database_id(): + client = firestore_async.client(database_id='testing-database') + expected = _CITY + doc = client.collection('cities').document() + await doc.set(expected) + + data = await doc.get() + assert data.to_dict() == expected + + await doc.delete() + data = await doc.get() + assert data.exists is False + +@pytest.mark.asyncio +async def test_firestore_async_multi_db(): + city_client = firestore_async.client() + movie_client = firestore_async.client(database_id='testing-database') + + expected_city = _CITY + expected_movie = _MOVIE + + city_doc = city_client.collection('cities').document() + movie_doc = movie_client.collection('movies').document() + + await asyncio.gather( + city_doc.set(expected_city), + movie_doc.set(expected_movie) + ) + + data = await asyncio.gather( + city_doc.get(), + movie_doc.get() + ) + + assert data[0].to_dict() == expected_city + assert data[1].to_dict() == expected_movie + + await asyncio.gather( + city_doc.delete(), + movie_doc.delete() + ) + + data = await asyncio.gather( + city_doc.get(), + movie_doc.get() + ) + assert data[0].exists is False + assert data[1].exists is False + @pytest.mark.asyncio async def test_server_timestamp(): client = firestore_async.client() diff --git a/tests/test_firestore.py b/tests/test_firestore.py index 768eb637e..47debd54b 100644 --- a/tests/test_firestore.py +++ b/tests/test_firestore.py @@ -50,6 +50,7 @@ def test_project_id(self): client = firestore.client() assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_project_id_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -57,6 +58,7 @@ def test_project_id_with_explicit_app(self): client = firestore.client(app=app) assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_service_account(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -64,6 +66,7 @@ def test_service_account(self): client = firestore.client() assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' def test_service_account_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -71,6 +74,89 @@ def test_service_account_with_explicit_app(self): client = firestore.client(app=app) assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' + + @pytest.mark.parametrize('database_id', [123, False, True, {}, []]) + def test_invalid_database_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + with pytest.raises(ValueError) as excinfo: + firestore.client(database_id=database_id) + assert str(excinfo.value) == f'database_id "{database_id}" must be a string or None.' + + def test_database_id(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + @pytest.mark.parametrize('database_id', ['', '(default)', None]) + def test_database_id_with_default_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + client = firestore.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == '(default)' + + def test_database_id_with_explicit_app(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore.client(app, database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + def test_database_id_with_multi_db(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = 'mock-database-id-1' + database_id_2 = 'mock-database-id-2' + client_1 = firestore.client(database_id=database_id_1) + client_2 = firestore.client(database_id=database_id_2) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is not client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id-1' + assert client_2._database == 'mock-database-id-2' + + def test_database_id_with_multi_db_uses_cache(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client_1 = firestore.client(database_id=database_id) + client_2 = firestore.client(database_id=database_id) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id' + assert client_2._database == 'mock-database-id' + + def test_database_id_with_multi_db_uses_cache_default(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = '' + database_id_2 = '(default)' + client_1 = firestore.client(database_id=database_id_1) + client_2 = firestore.client(database_id=database_id_2) + client_3 = firestore.client() + assert (client_1 is not None) and (client_2 is not None) and (client_3 is not None) + assert client_1 is client_2 + assert client_1 is client_3 + assert client_2 is client_3 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_3.project == 'mock-project-id' + assert client_1._database == '(default)' + assert client_2._database == '(default)' + assert client_3._database == '(default)' + def test_geo_point(self): geo_point = firestore.GeoPoint(10, 20) # pylint: disable=no-member diff --git a/tests/test_firestore_async.py b/tests/test_firestore_async.py index 0fb17c813..3d17cbfc5 100644 --- a/tests/test_firestore_async.py +++ b/tests/test_firestore_async.py @@ -50,6 +50,7 @@ def test_project_id(self): client = firestore_async.client() assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_project_id_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -57,6 +58,7 @@ def test_project_id_with_explicit_app(self): client = firestore_async.client(app=app) assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_service_account(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -64,6 +66,7 @@ def test_service_account(self): client = firestore_async.client() assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' def test_service_account_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -71,6 +74,89 @@ def test_service_account_with_explicit_app(self): client = firestore_async.client(app=app) assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' + + @pytest.mark.parametrize('database_id', [123, False, True, {}, []]) + def test_invalid_database_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + with pytest.raises(ValueError) as excinfo: + firestore_async.client(database_id=database_id) + assert str(excinfo.value) == f'database_id "{database_id}" must be a string or None.' + + def test_database_id(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore_async.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + @pytest.mark.parametrize('database_id', ['', '(default)', None]) + def test_database_id_with_default_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + client = firestore_async.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == '(default)' + + def test_database_id_with_explicit_app(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore_async.client(app, database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + def test_database_id_with_multi_db(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = 'mock-database-id-1' + database_id_2 = 'mock-database-id-2' + client_1 = firestore_async.client(database_id=database_id_1) + client_2 = firestore_async.client(database_id=database_id_2) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is not client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id-1' + assert client_2._database == 'mock-database-id-2' + + def test_database_id_with_multi_db_uses_cache(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client_1 = firestore_async.client(database_id=database_id) + client_2 = firestore_async.client(database_id=database_id) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id' + assert client_2._database == 'mock-database-id' + + def test_database_id_with_multi_db_uses_cache_default(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = '' + database_id_2 = '(default)' + client_1 = firestore_async.client(database_id=database_id_1) + client_2 = firestore_async.client(database_id=database_id_2) + client_3 = firestore_async.client() + assert (client_1 is not None) and (client_2 is not None) and (client_3 is not None) + assert client_1 is client_2 + assert client_1 is client_3 + assert client_2 is client_3 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_3.project == 'mock-project-id' + assert client_1._database == '(default)' + assert client_2._database == '(default)' + assert client_3._database == '(default)' + def test_geo_point(self): geo_point = firestore_async.GeoPoint(10, 20) # pylint: disable=no-member