Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mockfirestore/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def offset(self, offset: int) -> Query:
query = Query(self, offset=offset)
return query

def select(self, field_paths: List[str]) -> Query:
query = Query(self, projection=field_paths)
return query

def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> Query:
query = Query(self, start_at=(document_fields_or_snapshot, True))
return query
Expand Down
13 changes: 13 additions & 0 deletions mockfirestore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
if self._limit:
doc_snapshots = islice(doc_snapshots, self._limit)

if self.projection:
doc_snapshots = [
DocumentSnapshot(
x.reference,
{k: v for k, v in x.to_dict().items() if k in self.projection}
)
for x in doc_snapshots
]

return iter(doc_snapshots)

def get(self) -> Iterator[DocumentSnapshot]:
Expand Down Expand Up @@ -77,6 +86,10 @@ def offset(self, offset_amount: int) -> 'Query':
self._offset = offset_amount
return self

def select(self, field_paths: List[str]) -> 'Query':
self.projection = field_paths
return self

def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> 'Query':
self._start_at = (document_fields_or_snapshot, True)
return self
Expand Down
36 changes: 36 additions & 0 deletions tests/test_collection_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,42 @@ def test_collection_orderby_offset(self):
self.assertEqual({'id': 3}, docs[1].to_dict())
self.assertEqual(2, len(docs))

def test_collection_select(self):
fs = MockFirestore()
fs._data = {'foo': {
'first': {'id': 1, 'value': 'one'},
'second': {'id': 2, 'value': 'two'},
'third': {'id': 3, 'value': 'three'}
}}
docs = list(fs.collection('foo').select(['id']).stream())
self.assertEqual({'id': 1}, docs[0].to_dict())
self.assertEqual({'id': 2}, docs[1].to_dict())
self.assertEqual({'id': 3}, docs[2].to_dict())

def test_collection_select_missing_fields(self):
fs = MockFirestore()
fs._data = {'foo': {
'first': {'id': 1, 'value': 'one'},
'second': {'id': 2, 'value': 'two'},
'third': {'id': 3, 'value': 'three'}
}}
docs = list(fs.collection("foo").select(["name"]).stream())
self.assertEqual({}, docs[0].to_dict())
self.assertEqual({}, docs[1].to_dict())
self.assertEqual({}, docs[2].to_dict())

def test_collection_select_one_missing_field(self):
fs = MockFirestore()
fs._data = {'foo': {
'first': {'id': 1, 'value': 'one'},
'second': {'id': 2, 'value': 'two'},
'third': {'id': 3, 'value': 'three'}
}}
docs = list(fs.collection("foo").select(["id", "name"]).stream())
self.assertEqual({"id": 1}, docs[0].to_dict())
self.assertEqual({"id": 2}, docs[1].to_dict())
self.assertEqual({"id": 3}, docs[2].to_dict())

def test_collection_start_at(self):
fs = MockFirestore()
fs._data = {'foo': {
Expand Down