diff --git a/mockfirestore/__init__.py b/mockfirestore/__init__.py index a7f18de..4209cf6 100644 --- a/mockfirestore/__init__.py +++ b/mockfirestore/__init__.py @@ -12,4 +12,4 @@ from mockfirestore.collection import CollectionReference from mockfirestore.query import Query from mockfirestore._helpers import Timestamp -from mockfirestore.transaction import Transaction +from mockfirestore.transaction import BatchTransaction, Transaction diff --git a/mockfirestore/client.py b/mockfirestore/client.py index 75943bd..901c91b 100644 --- a/mockfirestore/client.py +++ b/mockfirestore/client.py @@ -1,7 +1,7 @@ from typing import Iterable, Sequence from mockfirestore.collection import CollectionReference from mockfirestore.document import DocumentReference, DocumentSnapshot -from mockfirestore.transaction import Transaction +from mockfirestore.transaction import BatchTransaction, Transaction class MockFirestore: @@ -47,6 +47,9 @@ def collection(self, path: str) -> CollectionReference: def collections(self) -> Sequence[CollectionReference]: return [CollectionReference(self._data, [collection_name]) for collection_name in self._data] + def batch(self) -> BatchTransaction: + return BatchTransaction(self) + def reset(self): self._data = {} @@ -58,5 +61,3 @@ def get_all(self, references: Iterable[DocumentReference], def transaction(self, **kwargs) -> Transaction: return Transaction(self, **kwargs) - - diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index 431c074..1714da2 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -14,6 +14,15 @@ def __init__(self, data: Store, path: List[str], self._path = path self.parent = parent + @property + def id(self): + """The collection identifier. + + Returns: + str: The last component of the path. + """ + return self._path[-1] + def document(self, document_id: Optional[str] = None) -> DocumentReference: collection = get_by_path(self._data, self._path) if document_id is None: diff --git a/mockfirestore/transaction.py b/mockfirestore/transaction.py index 7f06d2d..3c060fd 100644 --- a/mockfirestore/transaction.py +++ b/mockfirestore/transaction.py @@ -117,3 +117,27 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: self.commit() + + +class BatchTransaction: + # array of transactions + def __init__(self, client): + self._transactions = [] + + def set(self, reference: DocumentReference, document_data: dict, merge=False): + # stash set for later + self._transactions.append(partial(reference.set, document_data, merge=merge)) + + def update(self, reference: DocumentReference, field_updates: dict, option=None): + # stash update for later + self._transactions.append(partial(reference.update, field_updates)) + + def delete(self, reference: DocumentReference, option=None): + # stash delete for later + self._transactions.append(reference.delete) + + def commit(self): + # execute all transactions + for transaction in self._transactions: + transaction() + self._transactions.clear() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 72031fa..a3f0f65 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,5 +1,5 @@ from unittest import TestCase -from mockfirestore import MockFirestore, Transaction +from mockfirestore import MockFirestore, BatchTransaction, Transaction class TestTransaction(TestCase): @@ -70,4 +70,20 @@ def test_transaction_delete_documentDoesNotExistAfterDelete(self): self.assertEqual(False, doc.exists) +class TestBatchTransaction(TestCase): + def setUp(self) -> None: + self.fs = MockFirestore() + self.fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + def test_batchTransaction_set_setContentOfDocuments(self): + batch = self.fs.batch() + doc_contents = [{"id": "3"}, {"id": "4"}] + doc_refs = [ + self.fs.collection("foo").document("third"), + self.fs.collection("foo").document("fourth"), + ] + for doc_ref, doc_content in zip(doc_refs, doc_contents): + batch.set(doc_ref, doc_content) + batch.commit() + for doc_ref, doc_content in zip(doc_refs, doc_contents): + self.assertEqual(doc_ref.get().to_dict(), doc_content)