-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSortedSet.py
131 lines (111 loc) · 4.12 KB
/
SortedSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
import math
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, List, TypeVar, Union
T = TypeVar('T')


class SortedSet(Generic[T]):
    """Set of unique, mutually comparable elements kept in sorted order.

    The elements are stored in ``self.a``, a list of sorted lists
    ("buckets") whose concatenation is the whole set in ascending order.
    ``self.size`` holds the total number of elements.
    """

    # _build() creates ceil(sqrt(size / BUCKET_RATIO)) buckets.
    BUCKET_RATIO = 50
    # add() rebuilds all buckets once the bucket it inserted into holds more
    # than (number of buckets * REBUILD_RATIO) elements.
    REBUILD_RATIO = 170

    def _build(self, a: Union[List[T], None] = None) -> None:
        """Evenly divide `a` into buckets.

        `a` is stored as given, so it must already be sorted and free of
        duplicates. When omitted, the current contents of the set are
        redistributed.
        """
        if a is None:
            a = list(self)
        size = self.size = len(a)
        # An empty input yields zero buckets, i.e. ``self.a == []``.
        bucket_count = math.ceil(math.sqrt(size / self.BUCKET_RATIO))
        self.a = [
            a[size * i // bucket_count:size * (i + 1) // bucket_count]
            for i in range(bucket_count)
        ]

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)"""
        a = list(a)
        # Only deduplicate and sort when the input is not already strictly
        # increasing.
        if not all(x < y for x, y in zip(a, a[1:])):
            a = sorted(set(a))
        self._build(a)

    def __iter__(self) -> Iterator[T]:
        """Yield the elements in ascending order."""
        for bucket in self.a:
            yield from bucket

    def __reversed__(self) -> Iterator[T]:
        """Yield the elements in descending order."""
        for bucket in reversed(self.a):
            yield from reversed(bucket)

    def __len__(self) -> int:
        """Return the number of elements."""
        return self.size

    def __repr__(self) -> str:
        """Return "SortedSet" followed by the list of buckets."""
        return "SortedSet" + str(self.a)

    def __str__(self) -> str:
        """Return the elements as a brace-delimited, comma-separated list."""
        s = str(list(self))
        # Swap the surrounding "[" and "]" of the list text for braces.
        return "{" + s[1:-1] + "}"

    def _find_bucket(self, x: T) -> List[T]:
        """Find the bucket which should contain x. self must not be empty."""
        for a in self.a:
            if x <= a[-1]:
                return a
        # x is greater than every stored element: it belongs in the last bucket.
        return self.a[-1]

    def __contains__(self, x: T) -> bool:
        """Return True if x is in the set."""
        if self.size == 0:
            return False
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        return i != len(a) and a[i] == x

    def add(self, x: T) -> bool:
        """Add an element and return True if added. / O(√N)"""
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        if i != len(a) and a[i] == x:
            return False
        a.insert(i, x)
        self.size += 1
        # Redistribute everything once this bucket has grown too large
        # relative to the number of buckets.
        if len(a) > len(self.a) * self.REBUILD_RATIO:
            self._build()
        return True

    def discard(self, x: T) -> bool:
        """Remove an element and return True if removed. / O(√N)"""
        if self.size == 0:
            return False
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        if i == len(a) or a[i] != x:
            return False
        a.pop(i)
        self.size -= 1
        # The lookup methods index a[0] / a[-1], so an emptied bucket must not
        # stay in self.a; rebuilding drops it.
        if len(a) == 0:
            self._build()
        return True

    def lt(self, x: T) -> Union[T, None]:
        """Find the largest element < x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
        return None

    def le(self, x: T) -> Union[T, None]:
        """Find the largest element <= x, or None if it doesn't exist."""
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
        return None

    def gt(self, x: T) -> Union[T, None]:
        """Find the smallest element > x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
        return None

    def ge(self, x: T) -> Union[T, None]:
        """Find the smallest element >= x, or None if it doesn't exist."""
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
        return None

    def __getitem__(self, x: int) -> T:
        """Return the x-th element, or IndexError if it doesn't exist."""
        if x < 0:
            x += self.size
        if x < 0:
            raise IndexError("SortedSet index out of range")
        # Walk the buckets, subtracting each skipped bucket's length from x.
        for a in self.a:
            if x < len(a):
                return a[x]
            x -= len(a)
        raise IndexError("SortedSet index out of range")

    def index(self, x: T) -> int:
        """Count the number of elements < x."""
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        """Count the number of elements <= x."""
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans