Skip to content

Commit ee83034

Browse files
committed
add take_while()
1 parent 428fb29 commit ee83034

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

tests/test_usage.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,6 +1259,38 @@ def test_count_zero_or_negative(self):
12591259
assert en.take_last(-1).to_list() == []
12601260

12611261

1262+
class TestTakeWhileMethod:
1263+
def test_take_while_some(self):
1264+
lst = ['1', '3', '5', '7', '', '1', '4', '5']
1265+
en = Enumerable(lst)
1266+
q = en.take_while(lambda x: x != '')
1267+
assert q.to_list() == ['1', '3', '5', '7']
1268+
1269+
def test_take_while_all(self):
1270+
lst = ['1', '3', '5', '7', '', '1', '4', '5']
1271+
en = Enumerable(lst)
1272+
q = en.take_while(lambda x: x != '77')
1273+
assert q.to_list() == lst
1274+
1275+
def test_take_while_nothing(self):
1276+
lst = ['1', '3', '5', '7', '', '1', '4', '5']
1277+
en = Enumerable(lst)
1278+
q = en.take_while(lambda x: not isinstance(x, str))
1279+
assert q.to_list() == []
1280+
1281+
def test_take_while2_some(self):
1282+
lst = ['1', '3', '5', '7', '', '1', '4', '5']
1283+
en = Enumerable(lst)
1284+
q = en.take_while2(lambda x, i: isinstance(x, str) and i < 4)
1285+
assert q.to_list() == ['1', '3', '5', '7']
1286+
1287+
def test_take_while2_all(self):
1288+
lst = ['1', '3', '5', '7', '', '1', '4', '5']
1289+
en = Enumerable(lst)
1290+
q = en.take_while2(lambda x, i: isinstance(x, str) and i < 8)
1291+
assert q.to_list() == lst
1292+
1293+
12621294
class TestUnionMethod:
12631295
def test_union(self):
12641296
gen = (i for i in range(5))

types_linq/enumerable.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,22 @@ def inner():
778778
yield from q
779779
return Enumerable(inner)
780780

781+
def take_while(self, predicate: Callable[[TSource_co], bool]) -> Enumerable[TSource_co]:
782+
def inner():
783+
for elem in self:
784+
if not predicate(elem):
785+
break
786+
yield elem
787+
return Enumerable(inner)
788+
789+
def take_while2(self, predicate: Callable[[TSource_co, int], bool]) -> Enumerable[TSource_co]:
790+
def inner():
791+
for i, elem in enumerate(self):
792+
if not predicate(elem, i):
793+
break
794+
yield elem
795+
return Enumerable(inner)
796+
781797
def to_dict(self,
782798
key_selector: Callable[[TSource_co], TKey],
783799
*args: Callable[[TSource_co], TValue],

types_linq/enumerable.pyi

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,16 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
645645
Returns a new sequence that contains the last `count` elements.
646646
'''
647647

648-
# @@@ TODO
648+
def take_while(self, predicate: Callable[[TSource_co], bool]) -> Enumerable[TSource_co]:
649+
'''
650+
Returns elements from the sequence as long as the condition is true and skips the remaining.
651+
'''
652+
653+
def take_while2(self, predicate: Callable[[TSource_co, int], bool]) -> Enumerable[TSource_co]:
654+
'''
655+
Returns elements from the sequence as long as the condition is true and skips the remaining. The
656+
element's index is used in the predicate function.
657+
'''
649658

650659
@overload
651660
def to_dict(self,

0 commit comments

Comments
 (0)