From 60f27dc3f52ec9352cf8866e07a817ae1dddfa35 Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 18:01:14 -0500 Subject: [PATCH 01/25] add AsyncDocument --- mockfirestore/async_document.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 mockfirestore/async_document.py diff --git a/mockfirestore/async_document.py b/mockfirestore/async_document.py new file mode 100644 index 0000000..21186c5 --- /dev/null +++ b/mockfirestore/async_document.py @@ -0,0 +1,16 @@ +from typing import Dict, Any +from mockfirestore.document import DocumentReference, DocumentSnapshot + + +class AsyncDocumentReference(DocumentReference): + async def get(self) -> DocumentSnapshot: + return super().get() + + async def delete(self): + super().delete() + + async def set(self, data: Dict[str, Any], merge=False): + super().set(data, merge=merge) + + async def update(self, data: Dict[str, Any]): + super().update(data) From f40c4d00af66782f29038a3060c387cbb2e93ede Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 18:09:29 -0500 Subject: [PATCH 02/25] Fix Query.get since it's not actually deprecated --- mockfirestore/query.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mockfirestore/query.py b/mockfirestore/query.py index 7a4618d..f7a946d 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -52,10 +52,8 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: return iter(doc_snapshots) - def get(self) -> Iterator[DocumentSnapshot]: - warnings.warn('Query.get is deprecated, please use Query.stream', - category=DeprecationWarning) - return self.stream() + def get(self, transaction=None) -> List[DocumentSnapshot]: + return list(self.stream()) def _add_field_filter(self, field: str, op: str, value: Any): compare = self._compare_func(op) From 3c8c623ef109c2e24454c55a6b87f0255b28fb1a Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 18:11:32 -0500 Subject: [PATCH 03/25] Add AsyncQuery --- mockfirestore/async_query.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 mockfirestore/async_query.py diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py new file mode 100644 index 0000000..4e03100 --- /dev/null +++ b/mockfirestore/async_query.py @@ -0,0 +1,16 @@ +from typing import AsyncIterator, List +from mockfirestore.document import DocumentSnapshot +from mockfirestore.query import Query + + +class AsyncQuery(Query): + async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: + doc_snapshots = super().stream() + for doc_snapshot in doc_snapshots: + yield doc_snapshot + + async def get(self, transaction=None) -> List[DocumentSnapshot]: + return super().get() + + + From 41949bf8eb125879149c220759e01138f2a57f7a Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 18:19:42 -0500 Subject: [PATCH 04/25] Fix Collection.get since it's not deprecated --- mockfirestore/collection.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index 431c074..1dfb84f 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -23,10 +23,8 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: set_by_path(self._data, new_path, {}) return DocumentReference(self._data, new_path, parent=self) - def get(self) -> Iterable[DocumentSnapshot]: - warnings.warn('Collection.get is deprecated, please use Collection.stream', - category=DeprecationWarning) - return self.stream() + def get(self) -> List[DocumentSnapshot]: + return list(self.stream()) def add(self, document_data: Dict, document_id: str = None) \ -> Tuple[Timestamp, DocumentReference]: From 76c1bda2479d132e04f46a0222d63dc794c70180 Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 19:28:15 -0500 Subject: [PATCH 05/25] Add AsyncCollection --- mockfirestore/async_collection.py | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 mockfirestore/async_collection.py diff --git a/mockfirestore/async_collection.py b/mockfirestore/async_collection.py new file mode 100644 index 0000000..145afd7 --- /dev/null +++ b/mockfirestore/async_collection.py @@ -0,0 +1,36 @@ +from typing import Optional, List, Tuple, Dict, AsyncIterator +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.collection import CollectionReference +from mockfirestore.document import DocumentSnapshot, DocumentReference +from mockfirestore._helpers import Timestamp, get_by_path + + +class AsyncCollectionReference(CollectionReference): + def document(self, document_id: Optional[str] = None) -> AsyncDocumentReference: + doc_ref = super().document(document_id) + return AsyncDocumentReference( + doc_ref._data, doc_ref._path, parent=doc_ref.parent + ) + + async def get(self) -> List[DocumentSnapshot]: + return super().get() + + async def add( + self, document_data: Dict, document_id: str = None + ) -> Tuple[Timestamp, AsyncDocumentReference]: + timestamp, doc_ref = super().add(document_data, document_id=document_id) + async_doc_ref = AsyncDocumentReference( + doc_ref._data, doc_ref._path, parent=doc_ref.parent + ) + return timestamp, async_doc_ref + + async def list_documents( + self, page_size: Optional[int] = None + ) -> AsyncIterator[DocumentReference]: + docs = super().list_documents() + for doc in docs: + yield doc + + async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: + for doc_snapshot in super().stream(): + yield doc_snapshot From 69cf7f82b48e4bd2e72fa9730697381b421955e1 Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 20:02:49 -0500 Subject: [PATCH 06/25] Add AsyncTransaction --- mockfirestore/async_transaction.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 mockfirestore/async_transaction.py diff --git a/mockfirestore/async_transaction.py b/mockfirestore/async_transaction.py new file mode 100644 index 0000000..3ceef02 --- /dev/null +++ b/mockfirestore/async_transaction.py @@ -0,0 +1,38 @@ +from typing import AsyncIterable, Iterable + +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.document import DocumentSnapshot +from mockfirestore.transaction import Transaction, WriteResult + + +class AsyncTransaction(Transaction): + async def _begin(self, retry_id=None): + return super()._begin() + + async def _rollback(self): + super()._rollback() + + async def _commit(self) -> Iterable[WriteResult]: + return super()._commit() + + async def get(self, ref_or_query) -> AsyncIterable[DocumentSnapshot]: + doc_snapshots = super().get(ref_or_query) + for doc_snapshot in doc_snapshots: + yield doc_snapshot + + async def get_all( + self, references: Iterable[AsyncDocumentReference] + ) -> AsyncIterable[DocumentSnapshot]: + doc_snapshots = super().get_all(references) + for doc_snapshot in doc_snapshots: + yield doc_snapshot + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + await self.commit() + + + From 2d341c801e5a60b40b17f03ef7161eb75fdac52e Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 20:06:45 -0500 Subject: [PATCH 07/25] Change type to isinstance to work with subclasses --- mockfirestore/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mockfirestore/client.py b/mockfirestore/client.py index 75943bd..87c67a2 100644 --- a/mockfirestore/client.py +++ b/mockfirestore/client.py @@ -13,7 +13,7 @@ def _ensure_path(self, path): current_position = self for el in path[:-1]: - if type(current_position) in (MockFirestore, DocumentReference): + if isinstance(current_position, (MockFirestore, DocumentReference)): current_position = current_position.collection(el) else: current_position = current_position.document(el) From 26ccfa0ee481beaab2bc34bb5af09c0a2323c200 Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 20:20:15 -0500 Subject: [PATCH 08/25] Add AsyncClient --- mockfirestore/__init__.py | 6 +++++ mockfirestore/async_client.py | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 mockfirestore/async_client.py diff --git a/mockfirestore/__init__.py b/mockfirestore/__init__.py index a7f18de..fcefaba 100644 --- a/mockfirestore/__init__.py +++ b/mockfirestore/__init__.py @@ -13,3 +13,9 @@ from mockfirestore.query import Query from mockfirestore._helpers import Timestamp from mockfirestore.transaction import Transaction + +from mockfirestore.async_client import AsyncMockFirestore +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.async_collection import AsyncCollectionReference +from mockfirestore.async_query import AsyncQuery +from mockfirestore.async_transaction import AsyncTransaction diff --git a/mockfirestore/async_client.py b/mockfirestore/async_client.py new file mode 100644 index 0000000..65a2141 --- /dev/null +++ b/mockfirestore/async_client.py @@ -0,0 +1,41 @@ +from typing import AsyncIterable, Iterable + +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.async_collection import AsyncCollectionReference +from mockfirestore.async_transaction import AsyncTransaction +from mockfirestore.client import MockFirestore +from mockfirestore.document import DocumentSnapshot + + +class AsyncMockFirestore(MockFirestore): + def document(self, path: str) -> AsyncDocumentReference: + doc = super().document(path) + assert isinstance(doc, AsyncDocumentReference) + return doc + + def collection(self, path: str) -> AsyncCollectionReference: + path = path.split("/") + + if len(path) % 2 != 1: + raise Exception("Cannot create collection at path {}".format(path)) + + name = path[-1] + if len(path) > 1: + current_position = self._ensure_path(path) + return current_position.collection(name) + else: + if name not in self._data: + self._data[name] = {} + return AsyncCollectionReference(self._data, [name]) + + async def get_all( + self, + references: Iterable[AsyncDocumentReference], + field_paths=None, + transaction=None, + ) -> AsyncIterable[DocumentSnapshot]: + for doc_ref in set(references): + yield doc_ref.get() + + def transaction(self, **kwargs) -> AsyncTransaction: + return AsyncTransaction(self, **kwargs) From e571eefdd6a015cd6803e30fd381c3f52e27ed4b Mon Sep 17 00:00:00 2001 From: annahope Date: Wed, 16 Mar 2022 20:45:10 -0500 Subject: [PATCH 09/25] Implement collection on AsyncDocumentReference --- mockfirestore/async_document.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mockfirestore/async_document.py b/mockfirestore/async_document.py index 21186c5..978817f 100644 --- a/mockfirestore/async_document.py +++ b/mockfirestore/async_document.py @@ -14,3 +14,8 @@ async def set(self, data: Dict[str, Any], merge=False): async def update(self, data: Dict[str, Any]): super().update(data) + + def collection(self, name) -> 'AsyncCollectionReference': + from mockfirestore.async_collection import AsyncCollectionReference + coll_ref = super().collection(name) + return AsyncCollectionReference(coll_ref._data, coll_ref._path, self) From 1d3fc91d7353f0f49f6cc12da38a2c552a5fb32e Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 12:29:57 -0500 Subject: [PATCH 10/25] Add tests for AsyncDocumentReference --- mockfirestore/async_document.py | 10 +- requirements-dev-minimal.txt | 3 +- tests/test_async_collection_reference.py | 7 + tests/test_async_document_reference.py | 337 +++++++++++++++++++++++ 4 files changed, 355 insertions(+), 2 deletions(-) create mode 100644 tests/test_async_collection_reference.py create mode 100644 tests/test_async_document_reference.py diff --git a/mockfirestore/async_document.py b/mockfirestore/async_document.py index 978817f..f197044 100644 --- a/mockfirestore/async_document.py +++ b/mockfirestore/async_document.py @@ -1,4 +1,6 @@ +from copy import deepcopy from typing import Dict, Any +from mockfirestore import NotFound from mockfirestore.document import DocumentReference, DocumentSnapshot @@ -10,7 +12,13 @@ async def delete(self): super().delete() async def set(self, data: Dict[str, Any], merge=False): - super().set(data, merge=merge) + if merge: + try: + await self.update(deepcopy(data)) + except NotFound: + await self.set(data) + else: + super().set(data, merge=merge) async def update(self, data: Dict[str, Any]): super().update(data) diff --git a/requirements-dev-minimal.txt b/requirements-dev-minimal.txt index 38604d8..3a3bf3f 100644 --- a/requirements-dev-minimal.txt +++ b/requirements-dev-minimal.txt @@ -1 +1,2 @@ -google-cloud-firestore \ No newline at end of file +google-cloud-firestore +aiounittest \ No newline at end of file diff --git a/tests/test_async_collection_reference.py b/tests/test_async_collection_reference.py new file mode 100644 index 0000000..68afcdc --- /dev/null +++ b/tests/test_async_collection_reference.py @@ -0,0 +1,7 @@ +import aiounittest + + +class TestAsyncCollectionReference(aiounittest.AsyncTestCase): + def test_something(self): + self.assertEqual(True, False) # add assertion here + diff --git a/tests/test_async_document_reference.py b/tests/test_async_document_reference.py new file mode 100644 index 0000000..aec8ef4 --- /dev/null +++ b/tests/test_async_document_reference.py @@ -0,0 +1,337 @@ +import aiounittest + +from google.cloud import firestore +from mockfirestore import AsyncMockFirestore, NotFound + + +class TestAsyncDocumentReference(aiounittest.AsyncTestCase): + async def test_get_document_by_path(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc = await fs.document("foo/first").get() + self.assertEqual({"id": 1}, doc.to_dict()) + self.assertEqual("first", doc.id) + + async def test_set_document_by_path(self): + fs = AsyncMockFirestore() + fs._data = {} + doc_content = {"id": "bar"} + await fs.document("foo/doc1/bar/doc2").set(doc_content) + doc = await fs.document("foo/doc1/bar/doc2").get() + doc = doc.to_dict() + self.assertEqual(doc_content, doc) + + async def test_document_get_returnsDocument(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 1}, doc.to_dict()) + self.assertEqual("first", doc.id) + + async def test_document_get_documentIdEqualsKey(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc_ref = fs.collection("foo").document("first") + self.assertEqual("first", doc_ref.id) + + async def test_document_get_newDocumentReturnsDefaultId(self): + fs = AsyncMockFirestore() + doc_ref = fs.collection("foo").document() + doc = await doc_ref.get() + self.assertNotEqual(None, doc_ref.id) + self.assertFalse(doc.exists) + + async def test_document_get_documentDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc = await fs.collection("foo").document("bar").get() + self.assertEqual({}, doc.to_dict()) + + async def test_get_nestedDocument(self): + fs = AsyncMockFirestore() + fs._data = { + "top_collection": { + "top_document": { + "id": 1, + "nested_collection": {"nested_document": {"id": 1.1}}, + } + } + } + doc = ( + await fs.collection("top_collection") + .document("top_document") + .collection("nested_collection") + .document("nested_document") + .get() + ) + + self.assertEqual({"id": 1.1}, doc.to_dict()) + + async def test_get_nestedDocument_documentDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = { + "top_collection": {"top_document": {"id": 1, "nested_collection": {}}} + } + doc = ( + await fs.collection("top_collection") + .document("top_document") + .collection("nested_collection") + .document("nested_document") + .get() + ) + + self.assertEqual({}, doc.to_dict()) + + async def test_document_set_setsContentOfDocument(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc_content = {"id": "bar"} + await fs.collection("foo").document("bar").set(doc_content) + doc = await fs.collection("foo").document("bar").get() + self.assertEqual(doc_content, doc.to_dict()) + + async def test_document_set_mergeNewValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").set({"updated": True}, merge=True) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 1, "updated": True}, doc.to_dict()) + + async def test_document_set_mergeNewValueForNonExistentDoc(self): + fs = AsyncMockFirestore() + await fs.collection("foo").document("first").set({"updated": True}, merge=True) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"updated": True}, doc.to_dict()) + + async def test_document_set_overwriteValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").set({"new_id": 1}, merge=False) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"new_id": 1}, doc.to_dict()) + + async def test_document_set_isolation(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc_content = {"id": "bar"} + await fs.collection("foo").document("bar").set(doc_content) + doc_content["id"] = "new value" + doc = await fs.collection("foo").document("bar").get() + self.assertEqual({"id": "bar"}, doc.to_dict()) + + async def test_document_update_addNewValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").update({"updated": True}) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 1, "updated": True}, doc.to_dict()) + + async def test_document_update_changeExistingValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").update({"id": 2}) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 2}, doc.to_dict()) + + async def test_document_update_documentDoesNotExist(self): + fs = AsyncMockFirestore() + with self.assertRaises(NotFound): + await fs.collection("foo").document("nonexistent").update({"id": 2}) + docsnap = await fs.collection("foo").document("nonexistent").get() + self.assertFalse(docsnap.exists) + + async def test_document_update_isolation(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"nested": {"id": 1}}}} + update_doc = {"nested": {"id": 2}} + await fs.collection("foo").document("first").update(update_doc) + update_doc["nested"]["id"] = 3 + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"nested": {"id": 2}}, doc.to_dict()) + + async def test_document_update_transformerIncrementBasic(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}}} + await fs.collection("foo").document("first").update( + {"count": firestore.Increment(2)} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"count": 3}) + + async def test_document_update_transformerIncrementNested(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": { + "nested": {"count": 1}, + "other": {"likes": 0}, + } + } + } + await fs.collection("foo").document("first").update( + { + "nested": {"count": firestore.Increment(-1)}, + "other": {"likes": firestore.Increment(1), "smoked": "salmon"}, + } + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), + {"nested": {"count": 0}, "other": {"likes": 1, "smoked": "salmon"}}, + ) + + async def test_document_update_transformerIncrementNonExistent(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"spicy": "tuna"}}} + await fs.collection("foo").document("first").update( + {"count": firestore.Increment(1)} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"count": 1, "spicy": "tuna"}) + + async def test_document_delete_documentDoesNotExistAfterDelete(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").delete() + doc = await fs.collection("foo").document("first").get() + self.assertEqual(False, doc.exists) + + async def test_document_parent(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + coll = fs.collection("foo") + document = coll.document("first") + self.assertIs(document.parent, coll) + + async def test_document_update_transformerArrayUnionBasic(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2]}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayUnion([3, 4])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2, 3, 4]) + + async def test_document_update_transformerArrayUnionNested(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": { + "nested": {"arr": [1]}, + "other": {"labels": ["a"]}, + } + } + } + await fs.collection("foo").document("first").update( + { + "nested": {"arr": firestore.ArrayUnion([2])}, + "other": {"labels": firestore.ArrayUnion(["b"]), "smoked": "salmon"}, + } + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), + { + "nested": {"arr": [1, 2]}, + "other": {"labels": ["a", "b"], "smoked": "salmon"}, + }, + ) + + async def test_document_update_transformerArrayUnionNonExistent(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"spicy": "tuna"}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayUnion([1])} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"arr": [1], "spicy": "tuna"}) + + async def test_document_update_nestedFieldDotNotation(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"nested": {"value": 1, "unchanged": "foo"}}}} + + await fs.collection("foo").document("first").update({"nested.value": 2}) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"nested": {"value": 2, "unchanged": "foo"}}) + + async def test_document_update_nestedFieldDotNotationNestedFieldCreation(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"other": None}} + } # non-existent nested field is created + + await fs.collection("foo").document("first").update({"nested.value": 2}) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"nested": {"value": 2}, "other": None}) + + async def test_document_update_nestedFieldDotNotationMultipleNested(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"other": None}}} + + await fs.collection("foo").document("first").update( + {"nested.subnested.value": 42} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), {"nested": {"subnested": {"value": 42}}, "other": None} + ) + + async def test_document_update_nestedFieldDotNotationMultipleNestedWithTransformer( + self, + ): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"other": None}}} + + await fs.collection("foo").document("first").update( + {"nested.subnested.value": firestore.ArrayUnion([1, 3])} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), {"nested": {"subnested": {"value": [1, 3]}}, "other": None} + ) + + async def test_document_update_transformerSentinel(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"spicy": "tuna"}}} + await fs.collection("foo").document("first").update( + {"spicy": firestore.DELETE_FIELD} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {}) + + async def test_document_update_transformerArrayRemoveBasic(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayRemove([3, 4])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2]) + + async def test_document_update_transformerArrayRemoveNonExistentField(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayRemove([5])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2, 3, 4]) + + async def test_document_update_transformerArrayRemoveNonExistentArray(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}} + await fs.collection("foo").document("first").update( + {"non_existent_array": firestore.ArrayRemove([1, 2])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2, 3, 4]) From d24a6402c9740cc91f0b3da9473271f8b51baad8 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 12:33:03 -0500 Subject: [PATCH 11/25] Update my name and GitHub link in the list of contributors --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 40a4ba9..43cf5af 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ transaction.commit() * [Matt Dowds](https://github.com/mdowds) * [Chris Tippett](https://github.com/christippett) -* [Anton Melnikov](https://github.com/notnami) +* [Anna Melnikov](https://github.com/anna-hope) * [Ben Riggleman](https://github.com/briggleman) * [Steve Atwell](https://github.com/satwell) * [ahti123](https://github.com/ahti123) From 32eadfd23e1557a9bcd97f42ea3b7e3e3699c003 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 13:05:22 -0500 Subject: [PATCH 12/25] Add helper coroutine to convert async iterable to list --- mockfirestore/_helpers.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/mockfirestore/_helpers.py b/mockfirestore/_helpers.py index f80fc22..0fb6925 100644 --- a/mockfirestore/_helpers.py +++ b/mockfirestore/_helpers.py @@ -3,16 +3,18 @@ import string from datetime import datetime as dt from functools import reduce -from typing import (Dict, Any, Tuple, TypeVar, Sequence, Iterator) +from typing import Dict, Any, Tuple, TypeVar, Sequence, Iterator, AsyncIterable, List -T = TypeVar('T') +T = TypeVar("T") KeyValuePair = Tuple[str, Dict[str, Any]] Document = Dict[str, Any] Collection = Dict[str, Document] Store = Dict[str, Collection] -def get_by_path(data: Dict[str, T], path: Sequence[str], create_nested: bool = False) -> T: +def get_by_path( + data: Dict[str, T], path: Sequence[str], create_nested: bool = False +) -> T: """Access a nested object in root by item sequence.""" def get_or_create(a, b): @@ -26,7 +28,9 @@ def get_or_create(a, b): return reduce(operator.getitem, path, data) -def set_by_path(data: Dict[str, T], path: Sequence[str], value: T, create_nested: bool = True): +def set_by_path( + data: Dict[str, T], path: Sequence[str], value: T, create_nested: bool = True +): """Set a value in a nested object in root by item sequence.""" get_by_path(data, path[:-1], create_nested=True)[path[-1]] = value @@ -37,7 +41,9 @@ def delete_by_path(data: Dict[str, T], path: Sequence[str]): def generate_random_string(): - return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(20)) + return "".join( + random.choice(string.ascii_letters + string.digits) for _ in range(20) + ) class Timestamp: @@ -55,14 +61,16 @@ def from_now(cls): @property def seconds(self): - return str(self._timestamp).split('.')[0] + return str(self._timestamp).split(".")[0] @property def nanos(self): - return str(self._timestamp).split('.')[1] + return str(self._timestamp).split(".")[1] -def get_document_iterator(document: Dict[str, Any], prefix: str = '') -> Iterator[Tuple[str, Any]]: +def get_document_iterator( + document: Dict[str, Any], prefix: str = "" +) -> Iterator[Tuple[str, Any]]: """ :returns: (dot-delimited path, value,) """ @@ -74,4 +82,8 @@ def get_document_iterator(document: Dict[str, Any], prefix: str = '') -> Iterato if not prefix: yield key, value else: - yield '{}.{}'.format(prefix, key), value + yield "{}.{}".format(prefix, key), value + + +async def consume_async_iterable(iterable: AsyncIterable[T]) -> List[T]: + return [item async for item in iterable] From 1b9d25f3e3ad4e1bd6d713b1ae94c789572e5ffe Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 13:11:11 -0500 Subject: [PATCH 13/25] Fix where method in AsyncCollection --- mockfirestore/async_collection.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mockfirestore/async_collection.py b/mockfirestore/async_collection.py index 145afd7..d6cffb9 100644 --- a/mockfirestore/async_collection.py +++ b/mockfirestore/async_collection.py @@ -1,5 +1,6 @@ -from typing import Optional, List, Tuple, Dict, AsyncIterator +from typing import Optional, List, Tuple, Dict, AsyncIterator, Any from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.async_query import AsyncQuery from mockfirestore.collection import CollectionReference from mockfirestore.document import DocumentSnapshot, DocumentReference from mockfirestore._helpers import Timestamp, get_by_path @@ -32,5 +33,10 @@ async def list_documents( yield doc async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: - for doc_snapshot in super().stream(): + for key in sorted(get_by_path(self._data, self._path)): + doc_snapshot = await self.document(key).get() yield doc_snapshot + + def where(self, field: str, op: str, value: Any) -> AsyncQuery: + query = AsyncQuery(self, field_filters=[(field, op, value)]) + return query From e62b3bebe8e57d89408fc6d0a6a6cf465231c9e7 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 13:15:59 -0500 Subject: [PATCH 14/25] Move pagination code to separate method to allow reuse in AsyncQuery --- mockfirestore/async_query.py | 7 ++++++- mockfirestore/query.py | 21 ++++++++++++--------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py index 4e03100..966d66f 100644 --- a/mockfirestore/async_query.py +++ b/mockfirestore/async_query.py @@ -5,7 +5,12 @@ class AsyncQuery(Query): async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: - doc_snapshots = super().stream() + doc_snapshots = self.parent.stream() + for field, compare, value in self._field_filters: + doc_snapshots = [doc_snapshot async for doc_snapshot in doc_snapshots + if compare(doc_snapshot._get_by_field_path(field), value)] + + doc_snapshots = super()._process_pagination(doc_snapshots) for doc_snapshot in doc_snapshots: yield doc_snapshot diff --git a/mockfirestore/query.py b/mockfirestore/query.py index f7a946d..7f9174c 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -1,6 +1,5 @@ -import warnings from itertools import islice, tee -from typing import Iterator, Any, Optional, List, Callable, Union +from typing import Iterator, Any, Optional, List, Callable, Union, AsyncIterable from mockfirestore.document import DocumentSnapshot from mockfirestore._helpers import T @@ -24,13 +23,7 @@ def __init__(self, parent: 'CollectionReference', projection=None, for field_filter in field_filters: self._add_field_filter(*field_filter) - def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: - doc_snapshots = self.parent.stream() - - for field, compare, value in self._field_filters: - doc_snapshots = [doc_snapshot for doc_snapshot in doc_snapshots - if compare(doc_snapshot._get_by_field_path(field), value)] - + def _process_pagination(self, doc_snapshots: Iterator[DocumentSnapshot]): if self.orders: for key, direction in self.orders: doc_snapshots = sorted(doc_snapshots, @@ -52,6 +45,16 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: return iter(doc_snapshots) + def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: + doc_snapshots = self.parent.stream() + + for field, compare, value in self._field_filters: + doc_snapshots = [doc_snapshot for doc_snapshot in doc_snapshots + if compare(doc_snapshot._get_by_field_path(field), value)] + + return self._process_pagination(doc_snapshots) + + def get(self, transaction=None) -> List[DocumentSnapshot]: return list(self.stream()) From e98c80f8f489147df26e42546fe27c1bf821191a Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 13:18:23 -0500 Subject: [PATCH 15/25] Implement pagination methods for AsyncCollection --- mockfirestore/async_collection.py | 30 ++++++++++- mockfirestore/query.py | 90 ++++++++++++++++++++----------- 2 files changed, 89 insertions(+), 31 deletions(-) diff --git a/mockfirestore/async_collection.py b/mockfirestore/async_collection.py index d6cffb9..2bd8bea 100644 --- a/mockfirestore/async_collection.py +++ b/mockfirestore/async_collection.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple, Dict, AsyncIterator, Any +from typing import Optional, List, Tuple, Dict, AsyncIterator, Any, Union from mockfirestore.async_document import AsyncDocumentReference from mockfirestore.async_query import AsyncQuery from mockfirestore.collection import CollectionReference @@ -40,3 +40,31 @@ async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: def where(self, field: str, op: str, value: Any) -> AsyncQuery: query = AsyncQuery(self, field_filters=[(field, op, value)]) return query + + def order_by(self, key: str, direction: Optional[str] = None) -> AsyncQuery: + query = AsyncQuery(self, orders=[(key, direction)]) + return query + + def limit(self, limit_amount: int) -> AsyncQuery: + query = AsyncQuery(self, limit=limit_amount) + return query + + def offset(self, offset: int) -> AsyncQuery: + query = AsyncQuery(self, offset=offset) + return query + + def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, start_at=(document_fields_or_snapshot, True)) + return query + + def start_after(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, start_at=(document_fields_or_snapshot, False)) + return query + + def end_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, end_at=(document_fields_or_snapshot, True)) + return query + + def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, end_at=(document_fields_or_snapshot, False)) + return query diff --git a/mockfirestore/query.py b/mockfirestore/query.py index 7f9174c..c686318 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -6,9 +6,18 @@ class Query: - def __init__(self, parent: 'CollectionReference', projection=None, - field_filters=(), orders=(), limit=None, offset=None, - start_at=None, end_at=None, all_descendants=False) -> None: + def __init__( + self, + parent: "CollectionReference", + projection=None, + field_filters=(), + orders=(), + limit=None, + offset=None, + start_at=None, + end_at=None, + all_descendants=False, + ) -> None: self.parent = parent self.projection = projection self._field_filters = [] @@ -26,16 +35,22 @@ def __init__(self, parent: 'CollectionReference', projection=None, def _process_pagination(self, doc_snapshots: Iterator[DocumentSnapshot]): if self.orders: for key, direction in self.orders: - doc_snapshots = sorted(doc_snapshots, - key=lambda doc: doc.to_dict()[key], - reverse=direction == 'DESCENDING') + doc_snapshots = sorted( + doc_snapshots, + key=lambda doc: doc.to_dict()[key], + reverse=direction == "DESCENDING", + ) if self._start_at: document_fields_or_snapshot, before = self._start_at - doc_snapshots = self._apply_cursor(document_fields_or_snapshot, doc_snapshots, before, True) + doc_snapshots = self._apply_cursor( + document_fields_or_snapshot, doc_snapshots, before, True + ) if self._end_at: document_fields_or_snapshot, before = self._end_at - doc_snapshots = self._apply_cursor(document_fields_or_snapshot, doc_snapshots, before, False) + doc_snapshots = self._apply_cursor( + document_fields_or_snapshot, doc_snapshots, before, False + ) if self._offset: doc_snapshots = islice(doc_snapshots, self._offset, None) @@ -49,12 +64,14 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: doc_snapshots = self.parent.stream() for field, compare, value in self._field_filters: - doc_snapshots = [doc_snapshot for doc_snapshot in doc_snapshots - if compare(doc_snapshot._get_by_field_path(field), value)] + doc_snapshots = [ + doc_snapshot + for doc_snapshot in doc_snapshots + if compare(doc_snapshot._get_by_field_path(field), value) + ] return self._process_pagination(doc_snapshots) - def get(self, transaction=None) -> List[DocumentSnapshot]: return list(self.stream()) @@ -62,40 +79,53 @@ def _add_field_filter(self, field: str, op: str, value: Any): compare = self._compare_func(op) self._field_filters.append((field, compare, value)) - def where(self, field: str, op: str, value: Any) -> 'Query': + def where(self, field: str, op: str, value: Any) -> "Query": self._add_field_filter(field, op, value) return self - def order_by(self, key: str, direction: Optional[str] = 'ASCENDING') -> 'Query': + def order_by(self, key: str, direction: Optional[str] = "ASCENDING") -> "Query": self.orders.append((key, direction)) return self - def limit(self, limit_amount: int) -> 'Query': + def limit(self, limit_amount: int) -> "Query": self._limit = limit_amount return self - def offset(self, offset_amount: int) -> 'Query': + def offset(self, offset_amount: int) -> "Query": self._offset = offset_amount return self - def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> 'Query': + def start_at( + self, document_fields_or_snapshot: Union[dict, DocumentSnapshot] + ) -> "Query": self._start_at = (document_fields_or_snapshot, True) return self - def start_after(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> 'Query': + def start_after( + self, document_fields_or_snapshot: Union[dict, DocumentSnapshot] + ) -> "Query": self._start_at = (document_fields_or_snapshot, False) return self - def end_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> 'Query': + def end_at( + self, document_fields_or_snapshot: Union[dict, DocumentSnapshot] + ) -> "Query": self._end_at = (document_fields_or_snapshot, True) return self - def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> 'Query': + def end_before( + self, document_fields_or_snapshot: Union[dict, DocumentSnapshot] + ) -> "Query": self._end_at = (document_fields_or_snapshot, False) return self - def _apply_cursor(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot], doc_snapshot: Iterator[DocumentSnapshot], - before: bool, start: bool) -> Iterator[DocumentSnapshot]: + def _apply_cursor( + self, + document_fields_or_snapshot: Union[dict, DocumentSnapshot], + doc_snapshot: Iterator[DocumentSnapshot], + before: bool, + start: bool, + ) -> Iterator[DocumentSnapshot]: docs, doc_snapshot = tee(doc_snapshot) for idx, doc in enumerate(doc_snapshot): index = None @@ -120,21 +150,21 @@ def _apply_cursor(self, document_fields_or_snapshot: Union[dict, DocumentSnapsho return islice(docs, 0, index, None) def _compare_func(self, op: str) -> Callable[[T, T], bool]: - if op == '==': + if op == "==": return lambda x, y: x == y - elif op == '!=': + elif op == "!=": return lambda x, y: x != y - elif op == '<': + elif op == "<": return lambda x, y: x < y - elif op == '<=': + elif op == "<=": return lambda x, y: x <= y - elif op == '>': + elif op == ">": return lambda x, y: x > y - elif op == '>=': + elif op == ">=": return lambda x, y: x >= y - elif op == 'in': + elif op == "in": return lambda x, y: x in y - elif op == 'array_contains': + elif op == "array_contains": return lambda x, y: y in x - elif op == 'array_contains_any': + elif op == "array_contains_any": return lambda x, y: any([val in y for val in x]) From 5152cbcca5d599e09e40afee15118c877500f3b2 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 13:59:54 -0500 Subject: [PATCH 16/25] Move processing field filters to separate method to allow reuse in AsyncQuery --- mockfirestore/async_query.py | 11 +++-------- mockfirestore/query.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py index 966d66f..0104dbc 100644 --- a/mockfirestore/async_query.py +++ b/mockfirestore/async_query.py @@ -1,21 +1,16 @@ from typing import AsyncIterator, List from mockfirestore.document import DocumentSnapshot from mockfirestore.query import Query +from mockfirestore._helpers import consume_async_iterable class AsyncQuery(Query): async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: - doc_snapshots = self.parent.stream() - for field, compare, value in self._field_filters: - doc_snapshots = [doc_snapshot async for doc_snapshot in doc_snapshots - if compare(doc_snapshot._get_by_field_path(field), value)] - + doc_snapshots = await consume_async_iterable(self.parent.stream()) + doc_snapshots = super()._process_field_filters(doc_snapshots) doc_snapshots = super()._process_pagination(doc_snapshots) for doc_snapshot in doc_snapshots: yield doc_snapshot async def get(self, transaction=None) -> List[DocumentSnapshot]: return super().get() - - - diff --git a/mockfirestore/query.py b/mockfirestore/query.py index c686318..9abbb55 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -1,5 +1,5 @@ from itertools import islice, tee -from typing import Iterator, Any, Optional, List, Callable, Union, AsyncIterable +from typing import Iterator, Any, Optional, List, Callable, Union, Iterable from mockfirestore.document import DocumentSnapshot from mockfirestore._helpers import T @@ -60,16 +60,20 @@ def _process_pagination(self, doc_snapshots: Iterator[DocumentSnapshot]): return iter(doc_snapshots) - def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: - doc_snapshots = self.parent.stream() - + def _process_field_filters( + self, doc_snapshots: Iterator[DocumentSnapshot] + ) -> Iterable[DocumentSnapshot]: for field, compare, value in self._field_filters: doc_snapshots = [ doc_snapshot for doc_snapshot in doc_snapshots if compare(doc_snapshot._get_by_field_path(field), value) ] + return doc_snapshots + def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: + doc_snapshots = self.parent.stream() + doc_snapshots = self._process_field_filters(doc_snapshots) return self._process_pagination(doc_snapshots) def get(self, transaction=None) -> List[DocumentSnapshot]: From 83963024c30b5f23a30a9fadaaad40d4f3e71d23 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 14:00:33 -0500 Subject: [PATCH 17/25] Add tests for AsyncCollectionReference --- tests/test_async_collection_reference.py | 521 ++++++++++++++++++++++- 1 file changed, 519 insertions(+), 2 deletions(-) diff --git a/tests/test_async_collection_reference.py b/tests/test_async_collection_reference.py index 68afcdc..eb1145f 100644 --- a/tests/test_async_collection_reference.py +++ b/tests/test_async_collection_reference.py @@ -1,7 +1,524 @@ import aiounittest +from mockfirestore import ( + AsyncMockFirestore, + DocumentReference, + DocumentSnapshot, + AlreadyExists, +) +from mockfirestore._helpers import consume_async_iterable class TestAsyncCollectionReference(aiounittest.AsyncTestCase): - def test_something(self): - self.assertEqual(True, False) # add assertion here + async def test_collection_get_returnsDocuments(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + docs = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + + async def test_collection_get_collectionDoesNotExist(self): + fs = AsyncMockFirestore() + docs = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual([], docs) + + async def test_collection_get_nestedCollection(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1, "bar": {"first_nested": {"id": 1.1}}}}} + docs = await consume_async_iterable( + fs.collection("foo").document("first").collection("bar").stream() + ) + self.assertEqual({"id": 1.1}, docs[0].to_dict()) + + async def test_collection_get_nestedCollection_by_path(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1, "bar": {"first_nested": {"id": 1.1}}}}} + docs = await consume_async_iterable(fs.collection("foo/first/bar").stream()) + self.assertEqual({"id": 1.1}, docs[0].to_dict()) + + async def test_collection_get_nestedCollection_collectionDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + docs = await consume_async_iterable( + fs.collection("foo").document("first").collection("bar").stream() + ) + self.assertEqual([], docs) + + async def test_collection_get_nestedCollection_by_path_collectionDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + docs = await consume_async_iterable(fs.collection("foo/first/bar").stream()) + self.assertEqual([], docs) + + async def test_collection_get_ordersByAscendingDocumentId_byDefault(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"beta": {"id": 1}, "alpha": {"id": 2}}} + docs = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual({"id": 2}, docs[0].to_dict()) + + async def test_collection_whereEquals(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"valid": True}, "second": {"gumby": False}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("valid", "==", True).stream() + ) + self.assertEqual({"valid": True}, docs[0].to_dict()) + + async def test_collection_whereNotEquals(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", "!=", 1).stream() + ) + self.assertEqual({"count": 5}, docs[0].to_dict()) + + async def test_collection_whereLessThan(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", "<", 5).stream() + ) + self.assertEqual({"count": 1}, docs[0].to_dict()) + + async def test_collection_whereLessThanOrEqual(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", "<=", 5).stream() + ) + self.assertEqual({"count": 1}, docs[0].to_dict()) + self.assertEqual({"count": 5}, docs[1].to_dict()) + + async def test_collection_whereGreaterThan(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", ">", 1).stream() + ) + self.assertEqual({"count": 5}, docs[0].to_dict()) + + async def test_collection_whereGreaterThanOrEqual(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", ">=", 1).stream() + ) + self.assertEqual({"count": 1}, docs[0].to_dict()) + self.assertEqual({"count": 5}, docs[1].to_dict()) + + async def test_collection_whereMissingField(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("no_field", "==", 1).stream() + ) + self.assertEqual(len(docs), 0) + + async def test_collection_whereNestedField(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"nested": {"a": 1}}, "second": {"nested": {"a": 2}}} + } + + docs = await consume_async_iterable( + fs.collection("foo").where("nested.a", "==", 1).stream() + ) + self.assertEqual(len(docs), 1) + self.assertEqual({"nested": {"a": 1}}, docs[0].to_dict()) + + async def test_collection_whereIn(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"field": "a1"}, + "second": {"field": "a2"}, + "third": {"field": "a3"}, + "fourth": {"field": "a4"}, + } + } + + docs = await consume_async_iterable( + fs.collection("foo").where("field", "in", ["a1", "a3"]).stream() + ) + self.assertEqual(len(docs), 2) + self.assertEqual({"field": "a1"}, docs[0].to_dict()) + self.assertEqual({"field": "a3"}, docs[1].to_dict()) + + async def test_collection_whereArrayContains(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"field": ["val4"]}, + "second": {"field": ["val3", "val2"]}, + "third": {"field": ["val3", "val2", "val1"]}, + } + } + + docs = await consume_async_iterable( + fs.collection("foo").where("field", "array_contains", "val1").stream() + ) + self.assertEqual(len(docs), 1) + self.assertEqual(docs[0].to_dict(), {"field": ["val3", "val2", "val1"]}) + + async def test_collection_whereArrayContainsAny(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"field": ["val4"]}, + "second": {"field": ["val3", "val2"]}, + "third": {"field": ["val3", "val2", "val1"]}, + } + } + + contains_any_docs = await consume_async_iterable( + fs.collection("foo") + .where("field", "array_contains_any", ["val1", "val4"]) + .stream() + ) + self.assertEqual(len(contains_any_docs), 2) + self.assertEqual({"field": ["val4"]}, contains_any_docs[0].to_dict()) + self.assertEqual( + {"field": ["val3", "val2", "val1"]}, contains_any_docs[1].to_dict() + ) + + async def test_collection_orderBy(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"order": 2}, "second": {"order": 1}}} + + docs = await consume_async_iterable( + fs.collection("foo").order_by("order").stream() + ) + self.assertEqual({"order": 1}, docs[0].to_dict()) + self.assertEqual({"order": 2}, docs[1].to_dict()) + + async def test_collection_orderBy_descending(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 3}, + "third": {"order": 1}, + } + } + + docs = await consume_async_iterable( + fs.collection("foo").order_by("order", direction="DESCENDING").stream() + ) + self.assertEqual({"order": 3}, docs[0].to_dict()) + self.assertEqual({"order": 2}, docs[1].to_dict()) + self.assertEqual({"order": 1}, docs[2].to_dict()) + + async def test_collection_limit(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + docs = await consume_async_iterable(fs.collection("foo").limit(1).stream()) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_offset(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable(fs.collection("foo").offset(1).stream()) + + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual({"id": 3}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_orderby_offset(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").offset(1).stream() + ) + + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual({"id": 3}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_at(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").start_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_at_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_at_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1}, + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("second").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_at(doc).stream() + ) + self.assertEqual(4, len(docs)) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual({"id": 3}, docs[1].to_dict()) + self.assertEqual({"id": 4}, docs[2].to_dict()) + self.assertEqual({"id": 5}, docs[3].to_dict()) + + async def test_collection_start_after(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").start_after({"id": 1}).stream() + ) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_after_similar_objects(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1, "value": 1}, + "second": {"id": 2, "value": 2}, + "third": {"id": 3, "value": 2}, + "fourth": {"id": 4, "value": 3}, + } + } + docs = await consume_async_iterable( + fs.collection("foo") + .order_by("id") + .start_after({"id": 3, "value": 2}) + .stream() + ) + self.assertEqual({"id": 4, "value": 3}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_start_after_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_after({"id": 2}).stream() + ) + self.assertEqual({"id": 3}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_start_after_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("second").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_after(doc).stream() + ) + self.assertEqual(3, len(docs)) + self.assertEqual({"id": 3}, docs[0].to_dict()) + self.assertEqual({"id": 4}, docs[1].to_dict()) + self.assertEqual({"id": 5}, docs[2].to_dict()) + + async def test_collection_end_before(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").end_before({"id": 2}).stream() + ) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_end_before_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_before({"id": 2}).stream() + ) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_end_before_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1}, + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("fourth").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_before(doc).stream() + ) + self.assertEqual(3, len(docs)) + + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual({"id": 3}, docs[2].to_dict()) + + async def test_collection_end_at(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").end_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_end_at_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_end_at_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1}, + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("fourth").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_at(doc).stream() + ) + self.assertEqual(4, len(docs)) + + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual({"id": 3}, docs[2].to_dict()) + self.assertEqual({"id": 4}, docs[3].to_dict()) + + async def test_collection_limitAndOrderBy(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("order").limit(2).stream() + ) + self.assertEqual({"order": 1}, docs[0].to_dict()) + self.assertEqual({"order": 2}, docs[1].to_dict()) + + async def test_collection_listDocuments(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + doc_refs = await consume_async_iterable(fs.collection("foo").list_documents()) + self.assertEqual(3, len(doc_refs)) + for doc_ref in doc_refs: + self.assertIsInstance(doc_ref, DocumentReference) + + async def test_collection_stream(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + doc_snapshots = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual(3, len(doc_snapshots)) + for doc_snapshot in doc_snapshots: + self.assertIsInstance(doc_snapshot, DocumentSnapshot) + + async def test_collection_parent(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + doc_snapshots = await consume_async_iterable(fs.collection("foo").stream()) + for doc_snapshot in doc_snapshots: + doc_reference = doc_snapshot.reference + subcollection = doc_reference.collection("order") + self.assertIs(subcollection.parent, doc_reference) + + async def test_collection_addDocument(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc_id = "bar" + doc_content = {"id": doc_id, "xy": "z"} + timestamp, doc_ref = await fs.collection("foo").add(doc_content) + doc_snapshot = await doc_ref.get() + self.assertEqual(doc_content, doc_snapshot.to_dict()) + + doc = await fs.collection("foo").document(doc_id).get() + self.assertEqual(doc_content, doc.to_dict()) + + with self.assertRaises(AlreadyExists): + await fs.collection("foo").add(doc_content) + + async def test_collection_useDocumentIdKwarg(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc = await fs.collection("foo").document(document_id="first").get() + self.assertEqual({"id": 1}, doc.to_dict()) From 5cd0d49bc8e0ef75fef9386a37ca99d0a26640a1 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 14:08:30 -0500 Subject: [PATCH 18/25] Add tests for AsyncMockFirestore --- mockfirestore/async_client.py | 6 +++++- tests/test_async_mock_client.py | 25 +++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 tests/test_async_mock_client.py diff --git a/mockfirestore/async_client.py b/mockfirestore/async_client.py index 65a2141..2eb1ecd 100644 --- a/mockfirestore/async_client.py +++ b/mockfirestore/async_client.py @@ -28,6 +28,10 @@ def collection(self, path: str) -> AsyncCollectionReference: self._data[name] = {} return AsyncCollectionReference(self._data, [name]) + async def collections(self) -> AsyncIterable[AsyncCollectionReference]: + for collection_name in self._data: + yield AsyncCollectionReference(self._data, [collection_name]) + async def get_all( self, references: Iterable[AsyncDocumentReference], @@ -35,7 +39,7 @@ async def get_all( transaction=None, ) -> AsyncIterable[DocumentSnapshot]: for doc_ref in set(references): - yield doc_ref.get() + yield await doc_ref.get() def transaction(self, **kwargs) -> AsyncTransaction: return AsyncTransaction(self, **kwargs) diff --git a/tests/test_async_mock_client.py b/tests/test_async_mock_client.py new file mode 100644 index 0000000..6131e31 --- /dev/null +++ b/tests/test_async_mock_client.py @@ -0,0 +1,25 @@ +import aiounittest + +from mockfirestore import AsyncMockFirestore +from mockfirestore._helpers import consume_async_iterable + + +class TestAsyncMockFirestore(aiounittest.AsyncTestCase): + async def test_client_get_all(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + doc = fs.collection("foo").document("first") + results = await consume_async_iterable(fs.get_all([doc])) + returned_doc_snapshot = results[0].to_dict() + expected_doc_snapshot = (await doc.get()).to_dict() + self.assertEqual(returned_doc_snapshot, expected_doc_snapshot) + + async def test_client_collections(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}, "bar": {}} + collections = await consume_async_iterable(fs.collections()) + expected_collections = fs._data + + self.assertEqual(len(collections), len(expected_collections)) + for collection in collections: + self.assertTrue(collection._path[0] in expected_collections) From ae8d02b0b7d26b4f4cc8bf0f3a72a75b9d835efc Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 14:45:14 -0500 Subject: [PATCH 19/25] Fix _commit for AsyncTransaction --- mockfirestore/async_transaction.py | 20 +++++++++++++------- mockfirestore/transaction.py | 20 ++++++++------------ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/mockfirestore/async_transaction.py b/mockfirestore/async_transaction.py index 3ceef02..6a3b6d2 100644 --- a/mockfirestore/async_transaction.py +++ b/mockfirestore/async_transaction.py @@ -2,7 +2,7 @@ from mockfirestore.async_document import AsyncDocumentReference from mockfirestore.document import DocumentSnapshot -from mockfirestore.transaction import Transaction, WriteResult +from mockfirestore.transaction import Transaction, WriteResult, _CANT_COMMIT class AsyncTransaction(Transaction): @@ -13,18 +13,27 @@ async def _rollback(self): super()._rollback() async def _commit(self) -> Iterable[WriteResult]: - return super()._commit() + if not self.in_progress: + raise ValueError(_CANT_COMMIT) + + results = [] + for write_op in self._write_ops: + await write_op() + results.append(WriteResult()) + self.write_results = results + self._clean_up() + return results async def get(self, ref_or_query) -> AsyncIterable[DocumentSnapshot]: doc_snapshots = super().get(ref_or_query) - for doc_snapshot in doc_snapshots: + async for doc_snapshot in doc_snapshots: yield doc_snapshot async def get_all( self, references: Iterable[AsyncDocumentReference] ) -> AsyncIterable[DocumentSnapshot]: doc_snapshots = super().get_all(references) - for doc_snapshot in doc_snapshots: + async for doc_snapshot in doc_snapshots: yield doc_snapshot async def __aenter__(self): @@ -33,6 +42,3 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_type is None: await self.commit() - - - diff --git a/mockfirestore/transaction.py b/mockfirestore/transaction.py index 7f06d2d..e611dc2 100644 --- a/mockfirestore/transaction.py +++ b/mockfirestore/transaction.py @@ -1,5 +1,4 @@ from functools import partial -import random from typing import Iterable, Callable from mockfirestore._helpers import generate_random_string, Timestamp from mockfirestore.document import DocumentReference, DocumentSnapshot @@ -22,8 +21,8 @@ class Transaction: This mostly follows the model from https://googleapis.dev/python/firestore/latest/transaction.html """ - def __init__(self, client, - max_attempts=MAX_ATTEMPTS, read_only=False): + + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): self._client = client self._max_attempts = max_attempts self._read_only = read_only @@ -65,8 +64,9 @@ def _commit(self) -> Iterable[WriteResult]: self._clean_up() return results - def get_all(self, - references: Iterable[DocumentReference]) -> Iterable[DocumentSnapshot]: + def get_all( + self, references: Iterable[DocumentReference] + ) -> Iterable[DocumentSnapshot]: return self._client.get_all(references) def get(self, ref_or_query) -> Iterable[DocumentSnapshot]: @@ -84,9 +84,7 @@ def get(self, ref_or_query) -> Iterable[DocumentSnapshot]: def _add_write_op(self, write_op: Callable): if self._read_only: - raise ValueError( - "Cannot perform write operation in read-only transaction." - ) + raise ValueError("Cannot perform write operation in read-only transaction.") self._write_ops.append(write_op) def create(self, reference: DocumentReference, document_data): @@ -94,13 +92,11 @@ def create(self, reference: DocumentReference, document_data): # it's already in the MockFirestore ... - def set(self, reference: DocumentReference, document_data: dict, - merge=False): + def set(self, reference: DocumentReference, document_data: dict, merge=False): write_op = partial(reference.set, document_data, merge=merge) self._add_write_op(write_op) - def update(self, reference: DocumentReference, - field_updates: dict, option=None): + def update(self, reference: DocumentReference, field_updates: dict, option=None): write_op = partial(reference.update, field_updates) self._add_write_op(write_op) From 04b531d2ab110944d0c909910aebe90c2fccc59c Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 14:55:00 -0500 Subject: [PATCH 20/25] Add tests for AsyncTransaction --- tests/test_async_transaction.py | 71 +++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/test_async_transaction.py diff --git a/tests/test_async_transaction.py b/tests/test_async_transaction.py new file mode 100644 index 0000000..abd41cf --- /dev/null +++ b/tests/test_async_transaction.py @@ -0,0 +1,71 @@ +import aiounittest + +from mockfirestore import AsyncMockFirestore, AsyncTransaction +from mockfirestore._helpers import consume_async_iterable + + +class TestAsyncTransaction(aiounittest.AsyncTestCase): + def setUp(self) -> None: + self.fs = AsyncMockFirestore() + self.fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + + async def test_transaction_getAll(self): + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + docs = [ + self.fs.collection("foo").document(doc_name) + for doc_name in self.fs._data["foo"] + ] + results = await consume_async_iterable(transaction.get_all(docs)) + returned_docs_snapshots = [result.to_dict() for result in results] + expected_doc_snapshots = [(await doc.get()).to_dict() for doc in docs] + for expected_snapshot in expected_doc_snapshots: + self.assertIn(expected_snapshot, returned_docs_snapshots) + + async def test_transaction_getDocument(self): + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + doc = self.fs.collection("foo").document("first") + returned_doc = await anext(transaction.get(doc)) + self.assertEqual((await doc.get()).to_dict(), returned_doc.to_dict()) + + async def test_transaction_getQuery(self): + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + query = self.fs.collection("foo").order_by("id") + returned_docs = [doc.to_dict() async for doc in transaction.get(query)] + query = self.fs.collection("foo").order_by("id") + expected_docs = [doc.to_dict() async for doc in query.stream()] + self.assertEqual(returned_docs, expected_docs) + + async def test_transaction_set_setsContentOfDocument(self): + doc_content = {"id": "3"} + doc_ref = self.fs.collection("foo").document("third") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.set(doc_ref, doc_content) + self.assertEqual((await doc_ref.get()).to_dict(), doc_content) + + async def test_transaction_set_mergeNewValue(self): + doc = self.fs.collection("foo").document("first") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.set(doc, {"updated": True}, merge=True) + updated_doc = {"id": 1, "updated": True} + self.assertEqual((await doc.get()).to_dict(), updated_doc) + + async def test_transaction_update_changeExistingValue(self): + doc = self.fs.collection("foo").document("first") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.update(doc, {"updated": False}) + updated_doc = {"id": 1, "updated": False} + self.assertEqual((await doc.get()).to_dict(), updated_doc) + + async def test_transaction_delete_documentDoesNotExistAfterDelete(self): + doc = self.fs.collection("foo").document("first") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.delete(doc) + doc = await self.fs.collection("foo").document("first").get() + self.assertEqual(False, doc.exists) From 95c7bf9f9497b0b8b17fe74a592139d95d1cc32f Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 16:49:54 -0500 Subject: [PATCH 21/25] Add async examples to the README.md --- README.md | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 43cf5af..496e83c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Python Mock Firestore -An in-memory implementation of the [Python client library](https://github.com/googleapis/python-firestore) for Google Cloud Firestore, intended for use in tests to replace the real thing. This project is in early stages and is only a partial implementation of the real client library. +An in-memory implementation of the [Python client library](https://github.com/googleapis/python-firestore) for Google Cloud Firestore, intended for use in tests to replace the real thing. This project is only a partial implementation of the real client library. To install: @@ -10,6 +10,8 @@ Python 3.6+ is required for it to work. ## Usage +### Sync + ```python db = firestore.Client() mock_db = MockFirestore() @@ -19,14 +21,28 @@ db.collection('users').get() mock_db.collection('users').get() ``` +### Async + +```python +db = firestore.AsyncClient() +mock_db = AsyncMockFirestore() + +await db.collection('users').get() +await mock_db.collection('users').get() +``` + To reset the store to an empty state, use the `reset()` method: ```python mock_db = MockFirestore() mock_db.reset() ``` +or the equivalent method of `AsyncMockFirestore` + ## Supported operations +### Sync + ```python mock_db = MockFirestore() @@ -57,7 +73,7 @@ mock_db.collection('users').document('alovelace').update({'favourite.color': 're mock_db.collection('users').document('alovelace').update({'associates': ['Charles Babbage', 'Michael Faraday']}) mock_db.collection('users').document('alovelace').collection('friends') mock_db.collection('users').document('alovelace').delete() -mock_db.collection('users').document(document_id: 'alovelace').delete() +mock_db.collection('users').document('alovelace').delete() mock_db.collection('users').add({'first': 'Ada', 'last': 'Lovelace'}, 'alovelace') mock_db.get_all([mock_db.collection('users').document('alovelace')]) mock_db.document('users/alovelace') @@ -104,6 +120,81 @@ transaction.delete(mock_db.collection('users').document('alovelace')) transaction.commit() ``` +### Async +*(Where usage of those differs from the above)* + +*Note: all iterator methods like `stream` or `list_documents` in AsyncMockFirestore and its associated async classes +return asynchronous iterators, so when iterating over them, +`async for` syntax must be used.* + +```python +mock_db = AsyncMockFirestore() + +# Collections +await mock_db.collection('users').get() + +# async iterators +[doc_ref async for doc_ref in mock_db.collection('users').list_documents()] +[doc_snapshot async for doc_snapshot in mock_db.collection('users').stream()] + +# Documents +await mock_db.collection('users').document('alovelace').get() +doc_snapshot = await mock_db.collection('users').document('alovelace').get() +doc_snapshot.exists +doc_snapshot.to_dict() +await mock_db.collection('users').document('alovelace').set({ + 'first': 'Ada', + 'last': 'Lovelace' +}) +await mock_db.collection('users').document('alovelace').set({'first': 'Augusta Ada'}, merge=True) +await mock_db.collection('users').document('alovelace').update({'born': 1815}) +await mock_db.collection('users').document('alovelace').update({'favourite.color': 'red'}) +await mock_db.collection('users').document('alovelace').update({'associates': ['Charles Babbage', 'Michael Faraday']}) +await mock_db.collection('users').document('alovelace').delete() +await mock_db.collection('users').document('alovelace').delete() +await mock_db.collection('users').add({'first': 'Ada', 'last': 'Lovelace'}, 'alovelace') +await mock_db.get_all([mock_db.collection('users').document('alovelace')]) +await mock_db.document('users/alovelace').update({'born': 1815}) + +# Querying +await mock_db.collection('users').order_by('born').get() +await mock_db.collection('users').order_by('born', direction='DESCENDING').get() +await mock_db.collection('users').limit(5).get() +await mock_db.collection('users').where('born', '==', 1815).get() +await mock_db.collection('users').where('born', '!=', 1815).get() +await mock_db.collection('users').where('born', '<', 1815).get() +await mock_db.collection('users').where('born', '>', 1815).get() +await mock_db.collection('users').where('born', '<=', 1815).get() +await mock_db.collection('users').where('born', '>=', 1815).get() + +# async iterators +mock_db.collection('users').where('born', 'in', [1815, 1900]).stream() +mock_db.collection('users').where('born', 'in', [1815, 1900]).stream() +mock_db.collection('users').where('associates', 'array_contains', 'Charles Babbage').stream() +mock_db.collection('users').where('associates', 'array_contains_any', ['Charles Babbage', 'Michael Faraday']).stream() + +# Transforms +await mock_db.collection('users').document('alovelace').update({'likes': firestore.Increment(1)}) +await mock_db.collection('users').document('alovelace').update({'associates': firestore.ArrayUnion(['Andrew Cross', 'Charles Wheatstone'])}) +await mock_db.collection('users').document('alovelace').update({firestore.DELETE_FIELD: "born"}) +await mock_db.collection('users').document('alovelace').update({'associates': firestore.ArrayRemove(['Andrew Cross'])}) + + +# Transactions +transaction = mock_db.transaction() +transaction.id +transaction.in_progress +await transaction.get(mock_db.collection('users').where('born', '==', 1815)) +await transaction.get(mock_db.collection('users').document('alovelace')) +await transaction.get_all([mock_db.collection('users').document('alovelace')]) + +transaction.set(mock_db.collection('users').document('alovelace'), {'born': 1815}) +transaction.update(mock_db.collection('users').document('alovelace'), {'born': 1815}) +transaction.delete(mock_db.collection('users').document('alovelace')) +await transaction.commit() +``` + + ## Running the tests * Create and activate a virtualenv with a Python version of at least 3.6 * Install dependencies with `pip install -r requirements-dev-minimal.txt` From 22acce40cc580e4172ace2233a0094511aef09f8 Mon Sep 17 00:00:00 2001 From: annahope Date: Thu, 17 Mar 2022 17:00:36 -0500 Subject: [PATCH 22/25] Don't use anext since it's only in Python 3.10 --- tests/test_async_transaction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_async_transaction.py b/tests/test_async_transaction.py index abd41cf..d365a06 100644 --- a/tests/test_async_transaction.py +++ b/tests/test_async_transaction.py @@ -26,7 +26,7 @@ async def test_transaction_getDocument(self): async with AsyncTransaction(self.fs) as transaction: await transaction._begin() doc = self.fs.collection("foo").document("first") - returned_doc = await anext(transaction.get(doc)) + returned_doc = [doc async for doc in transaction.get(doc)][0] self.assertEqual((await doc.get()).to_dict(), returned_doc.to_dict()) async def test_transaction_getQuery(self): From 260aaa2ecc281446e125418d65a93c98043814c9 Mon Sep 17 00:00:00 2001 From: Ben van der Harg Date: Thu, 2 Jun 2022 14:40:49 +0200 Subject: [PATCH 23/25] Fix error: 'async_generator' object is not iterable --- mockfirestore/async_query.py | 6 +++--- tests/test_async_query.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 tests/test_async_query.py diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py index 0104dbc..b073e5d 100644 --- a/mockfirestore/async_query.py +++ b/mockfirestore/async_query.py @@ -1,11 +1,11 @@ -from typing import AsyncIterator, List +from typing import List, AsyncGenerator from mockfirestore.document import DocumentSnapshot from mockfirestore.query import Query from mockfirestore._helpers import consume_async_iterable class AsyncQuery(Query): - async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: + async def stream(self, transaction=None) -> AsyncGenerator[DocumentSnapshot]: doc_snapshots = await consume_async_iterable(self.parent.stream()) doc_snapshots = super()._process_field_filters(doc_snapshots) doc_snapshots = super()._process_pagination(doc_snapshots) @@ -13,4 +13,4 @@ async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: yield doc_snapshot async def get(self, transaction=None) -> List[DocumentSnapshot]: - return super().get() + return [result async for result in self.stream()] diff --git a/tests/test_async_query.py b/tests/test_async_query.py new file mode 100644 index 0000000..4931ba2 --- /dev/null +++ b/tests/test_async_query.py @@ -0,0 +1,13 @@ +import aiounittest + +from mockfirestore import AsyncMockFirestore + + +class TestAsyncMockFirestore(aiounittest.AsyncTestCase): + async def test_query_get(self): + fs = AsyncMockFirestore() + doc_in_fs = {"id": 1} + fs._data = {"foo": {"first": doc_in_fs}} + docs = await fs.collection("foo").where("id", "==", 1).get() + self.assertEqual(len(docs), 1) + self.assertEqual(docs[0].to_dict()["id"], 1) From a29ed6a0420d1c3e68a859fe72b11b2fbf713a45 Mon Sep 17 00:00:00 2001 From: Ben van der Harg Date: Thu, 2 Jun 2022 15:02:55 +0200 Subject: [PATCH 24/25] Replace async list comprehension with existing consume_async_iterable --- mockfirestore/async_query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py index b073e5d..2a96fba 100644 --- a/mockfirestore/async_query.py +++ b/mockfirestore/async_query.py @@ -1,11 +1,11 @@ -from typing import List, AsyncGenerator +from typing import List, AsyncGenerator, AsyncIterator from mockfirestore.document import DocumentSnapshot from mockfirestore.query import Query from mockfirestore._helpers import consume_async_iterable class AsyncQuery(Query): - async def stream(self, transaction=None) -> AsyncGenerator[DocumentSnapshot]: + async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: doc_snapshots = await consume_async_iterable(self.parent.stream()) doc_snapshots = super()._process_field_filters(doc_snapshots) doc_snapshots = super()._process_pagination(doc_snapshots) @@ -13,4 +13,4 @@ async def stream(self, transaction=None) -> AsyncGenerator[DocumentSnapshot]: yield doc_snapshot async def get(self, transaction=None) -> List[DocumentSnapshot]: - return [result async for result in self.stream()] + return await consume_async_iterable(self.stream()) From 3cdf10b87744f5c54b44f5e3c452bf904ed3f90c Mon Sep 17 00:00:00 2001 From: Ben van der Harg Date: Thu, 2 Jun 2022 15:11:11 +0200 Subject: [PATCH 25/25] Remove unused import --- mockfirestore/async_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py index 2a96fba..173cc4c 100644 --- a/mockfirestore/async_query.py +++ b/mockfirestore/async_query.py @@ -1,4 +1,4 @@ -from typing import List, AsyncGenerator, AsyncIterator +from typing import List, AsyncIterator from mockfirestore.document import DocumentSnapshot from mockfirestore.query import Query from mockfirestore._helpers import consume_async_iterable