Skip to content

Commit af7fcc7

Browse files
committed
fix mt
1 parent 294e85d commit af7fcc7

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

examples/mt.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ def timeit(task_name):
1515
print(f"{task_name} took {end - start:.2f} seconds")
1616

1717

18-
def mt19937(bs):
18+
def mt19937(bs, samples=None):
1919
print("bs:", bs)
2020
rand = random.Random(3142)
2121
st = tuple(rand.getstate()[1][:-1])
2222

2323
effective_bs = ((bs - 1) & bs) or bs
24-
out = [rand.getrandbits(bs) for _ in range(624 * 32 // effective_bs)]
24+
samples = 624 * 32 // effective_bs if samples is None else samples
25+
out = [rand.getrandbits(bs) for _ in range(samples)]
2526

2627
lin = LinearSystem([32] * 624)
2728
mt = lin.gens()
@@ -39,10 +40,14 @@ def mt19937(bs):
3940
pyrand = rng.to_python_random()
4041
assert all(rng.getrandbits(bs) == o for o in out)
4142
assert all(pyrand.getrandbits(bs) == o for o in out)
43+
for _ in range(100):
44+
assert rng.getrandbits(bs) == rand.getrandbits(bs)
4245

4346

4447
if __name__ == "__main__":
4548
mt19937(32)
4649
mt19937(17)
4750
mt19937(9)
4851
mt19937(1)
52+
mt19937(1337, 19968 // 1337 + 10)
53+
mt19937(137, 19968 // 137 + 60)

gf2bv/__init__.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def __rshift__(self, n: int):
5252
def __lshift__(self, n: int):
5353
return BitVec((0,) * n + self._bits[:-n])
5454

55+
def lshift_ext(self, n: int):
56+
return BitVec((0,) * n + self._bits)
57+
5558
def __and__(self, mask: int):
5659
bs = to_bits(len(self._bits), mask)
5760
if all(bs):
@@ -61,7 +64,24 @@ def __and__(self, mask: int):
6164

6265
__rand__ = __and__
6366

64-
def __or__(self, mask: int):
67+
def __or__(self, mask: BitVec | int):
68+
if isinstance(mask, BitVec):
69+
if len(self._bits) > len(mask._bits):
70+
self, mask = mask, self
71+
ar = [0] * len(mask._bits)
72+
for i in range(len(self._bits)):
73+
if self._bits[i] not in (0, 1) and mask._bits[i] not in (0, 1):
74+
raise ValueError(
75+
"Cannot compute logical or using bitvecs with non-zero bits"
76+
)
77+
if self._bits[i] == 1 or mask._bits[i] == 1:
78+
ar[i] = 1
79+
elif self._bits[i] == 0:
80+
ar[i] = mask._bits[i]
81+
elif mask._bits[i] == 0:
82+
ar[i] = self._bits[i]
83+
ar[len(self._bits) :] = mask._bits[len(self._bits) :]
84+
return BitVec(tuple(ar))
6585
bs = to_bits(len(self._bits), mask)
6686
if all(bs):
6787
# if all bits are set, it becomes all ones

gf2bv/crypto/mt.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,29 @@ def __call__(self):
5252
self.mti += 1
5353
return self.temper(y)
5454

55+
def _getrandbits_word(self, k):
56+
r = self()
57+
if isinstance(r, BitVec):
58+
return r[self.w - k :]
59+
return r >> (self.w - k)
60+
5561
def getrandbits(self, k=None):
5662
"""Uses the CPython's implementation of random.getrandbits()"""
5763
if k is None:
5864
k = self.w
59-
if k == 0:
60-
return 0
65+
if k < 0:
66+
raise ValueError("number of bits cannot be negative")
6167
if k <= self.w:
62-
return self.__call__() >> (self.w - k)
68+
return self._getrandbits_word(k)
69+
words = (k - 1) // self.w + 1
6370
x = 0
64-
for i in range(0, k, self.w):
65-
r = self.__call__()
66-
if i + self.w > k:
67-
r = r >> (self.w - (k - i))
68-
x |= r << i
71+
for i in range(words):
72+
r = self._getrandbits_word(min(k, self.w))
73+
if isinstance(r, BitVec):
74+
x |= r.lshift_ext(self.w * i)
75+
else:
76+
x |= r << (self.w * i)
77+
k -= self.w
6978
return x
7079

7180

0 commit comments

Comments
 (0)