Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mockfirestore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
from mockfirestore.collection import CollectionReference
from mockfirestore.query import Query
from mockfirestore._helpers import Timestamp
from mockfirestore.transaction import Transaction
from mockfirestore.transaction import BatchTransaction, Transaction
7 changes: 4 additions & 3 deletions mockfirestore/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Iterable, Sequence
from mockfirestore.collection import CollectionReference
from mockfirestore.document import DocumentReference, DocumentSnapshot
from mockfirestore.transaction import Transaction
from mockfirestore.transaction import BatchTransaction, Transaction


class MockFirestore:
Expand Down Expand Up @@ -47,6 +47,9 @@ def collection(self, path: str) -> CollectionReference:
def collections(self) -> Sequence[CollectionReference]:
return [CollectionReference(self._data, [collection_name]) for collection_name in self._data]

def batch(self) -> BatchTransaction:
return BatchTransaction(self)

def reset(self):
self._data = {}

Expand All @@ -58,5 +61,3 @@ def get_all(self, references: Iterable[DocumentReference],

def transaction(self, **kwargs) -> Transaction:
return Transaction(self, **kwargs)


9 changes: 9 additions & 0 deletions mockfirestore/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ def __init__(self, data: Store, path: List[str],
self._path = path
self.parent = parent

@property
def id(self):
"""The collection identifier.

Returns:
str: The last component of the path.
"""
return self._path[-1]

def document(self, document_id: Optional[str] = None) -> DocumentReference:
collection = get_by_path(self._data, self._path)
if document_id is None:
Expand Down
24 changes: 24 additions & 0 deletions mockfirestore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,27 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self.commit()


class BatchTransaction:
# array of transactions
def __init__(self, client):
self._transactions = []

def set(self, reference: DocumentReference, document_data: dict, merge=False):
# stash set for later
self._transactions.append(partial(reference.set, document_data, merge=merge))

def update(self, reference: DocumentReference, field_updates: dict, option=None):
# stash update for later
self._transactions.append(partial(reference.update, field_updates))

def delete(self, reference: DocumentReference, option=None):
# stash delete for later
self._transactions.append(reference.delete)

def commit(self):
# execute all transactions
for transaction in self._transactions:
transaction()
self._transactions.clear()
18 changes: 17 additions & 1 deletion tests/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import TestCase
from mockfirestore import MockFirestore, Transaction
from mockfirestore import MockFirestore, BatchTransaction, Transaction


class TestTransaction(TestCase):
Expand Down Expand Up @@ -70,4 +70,20 @@ def test_transaction_delete_documentDoesNotExistAfterDelete(self):
self.assertEqual(False, doc.exists)


class TestBatchTransaction(TestCase):
def setUp(self) -> None:
self.fs = MockFirestore()
self.fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}}

def test_batchTransaction_set_setContentOfDocuments(self):
batch = self.fs.batch()
doc_contents = [{"id": "3"}, {"id": "4"}]
doc_refs = [
self.fs.collection("foo").document("third"),
self.fs.collection("foo").document("fourth"),
]
for doc_ref, doc_content in zip(doc_refs, doc_contents):
batch.set(doc_ref, doc_content)
batch.commit()
for doc_ref, doc_content in zip(doc_refs, doc_contents):
self.assertEqual(doc_ref.get().to_dict(), doc_content)