diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index 7302ab8..b4f54ad 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -57,6 +57,10 @@ def offset(self, offset: int) -> Query: query = Query(self, offset=offset) return query + def select(self, field_paths: List[str]) -> Query: + query = Query(self, projection=field_paths) + return query + def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> Query: query = Query(self, start_at=(document_fields_or_snapshot, True)) return query diff --git a/mockfirestore/query.py b/mockfirestore/query.py index 4761a92..5aaa2fb 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -50,6 +50,15 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: if self._limit: doc_snapshots = islice(doc_snapshots, self._limit) + if self.projection: + doc_snapshots = [ + DocumentSnapshot( + x.reference, + {k: v for k, v in x.to_dict().items() if k in self.projection} + ) + for x in doc_snapshots + ] + return iter(doc_snapshots) def get(self) -> Iterator[DocumentSnapshot]: @@ -77,6 +86,10 @@ def offset(self, offset_amount: int) -> 'Query': self._offset = offset_amount return self + def select(self, field_paths: List[str]) -> 'Query': + self.projection = field_paths + return self + def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> 'Query': self._start_at = (document_fields_or_snapshot, True) return self diff --git a/tests/test_collection_reference.py b/tests/test_collection_reference.py index 59397be..0bdc75d 100644 --- a/tests/test_collection_reference.py +++ b/tests/test_collection_reference.py @@ -243,6 +243,42 @@ def test_collection_orderby_offset(self): self.assertEqual({'id': 3}, docs[1].to_dict()) self.assertEqual(2, len(docs)) + def test_collection_select(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'id': 1, 'value': 'one'}, + 'second': {'id': 2, 'value': 'two'}, + 'third': {'id': 3, 'value': 'three'} + }} + docs = list(fs.collection('foo').select(['id']).stream()) + self.assertEqual({'id': 1}, docs[0].to_dict()) + self.assertEqual({'id': 2}, docs[1].to_dict()) + self.assertEqual({'id': 3}, docs[2].to_dict()) + + def test_collection_select_missing_fields(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'id': 1, 'value': 'one'}, + 'second': {'id': 2, 'value': 'two'}, + 'third': {'id': 3, 'value': 'three'} + }} + docs = list(fs.collection("foo").select(["name"]).stream()) + self.assertEqual({}, docs[0].to_dict()) + self.assertEqual({}, docs[1].to_dict()) + self.assertEqual({}, docs[2].to_dict()) + + def test_collection_select_one_missing_field(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'id': 1, 'value': 'one'}, + 'second': {'id': 2, 'value': 'two'}, + 'third': {'id': 3, 'value': 'three'} + }} + docs = list(fs.collection("foo").select(["id", "name"]).stream()) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual({"id": 3}, docs[2].to_dict()) + def test_collection_start_at(self): fs = MockFirestore() fs._data = {'foo': {