Skip to content

Commit 6cc2214

Browse files
committed
added context managers in tests
1 parent cacf124 commit 6cc2214

File tree

2 files changed

+42
-48
lines changed

2 files changed

+42
-48
lines changed

tests/test_kem.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ def test_correctness():
1717

1818

1919
def check_correctness(alg_name):
20-
kem = oqs.KeyEncapsulation(alg_name)
21-
public_key = kem.generate_keypair()
22-
ciphertext, shared_secret_server = kem.encap_secret(public_key)
23-
shared_secret_client = kem.decap_secret(ciphertext)
24-
assert shared_secret_client == shared_secret_server
25-
kem.free()
20+
with oqs.KeyEncapsulation(alg_name) as kem:
21+
public_key = kem.generate_keypair()
22+
ciphertext, shared_secret_server = kem.encap_secret(public_key)
23+
shared_secret_client = kem.decap_secret(ciphertext)
24+
assert shared_secret_client == shared_secret_server
2625

2726

2827
def test_wrong_ciphertext():
@@ -33,19 +32,18 @@ def test_wrong_ciphertext():
3332

3433

3534
def check_wrong_ciphertext(alg_name):
36-
kem = oqs.KeyEncapsulation(alg_name)
37-
public_key = kem.generate_keypair()
38-
ciphertext, shared_secret_server = kem.encap_secret(public_key)
39-
wrong_ciphertext = bytes(random.getrandbits(8) for _ in range(kem.details['length_ciphertext']))
40-
shared_secret_client = kem.decap_secret(wrong_ciphertext)
41-
assert shared_secret_client != shared_secret_server
42-
kem.free()
35+
with oqs.KeyEncapsulation(alg_name) as kem:
36+
public_key = kem.generate_keypair()
37+
ciphertext, shared_secret_server = kem.encap_secret(public_key)
38+
wrong_ciphertext = bytes(random.getrandbits(8) for _ in range(kem.details['length_ciphertext']))
39+
shared_secret_client = kem.decap_secret(wrong_ciphertext)
40+
assert shared_secret_client != shared_secret_server
4341

4442

4543
def test_not_supported():
4644
try:
47-
kem = oqs.KeyEncapsulation("bogus")
48-
raise AssertionError("oqs.MechanismNotSupportedError was not raised.")
45+
with oqs.KeyEncapsulation("bogus") as kem:
46+
raise AssertionError("oqs.MechanismNotSupportedError was not raised.")
4947
except oqs.MechanismNotSupportedError:
5048
pass
5149
except E:
@@ -58,8 +56,8 @@ def test_not_enabled():
5856
if alg_name not in oqs.get_enabled_KEM_mechanisms():
5957
# found a non-enabled but supported alg
6058
try:
61-
kem = oqs.KeyEncapsulation(alg_name)
62-
raise AssertionError("oqs.MechanismNotEnabledError was not raised.")
59+
with oqs.KeyEncapsulation(alg_name) as kem:
60+
raise AssertionError("oqs.MechanismNotEnabledError was not raised.")
6361
except oqs.MechanismNotEnabledError:
6462
pass
6563
except E:

tests/test_sig.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ def test_correctness():
1717

1818

1919
def check_correctness(alg_name):
20-
message = bytes(random.getrandbits(8) for _ in range(100))
21-
sig = oqs.Signature(alg_name)
22-
public_key = sig.generate_keypair()
23-
signature = sig.sign(message)
24-
assert sig.verify(message, signature, public_key)
25-
sig.free()
20+
with oqs.Signature(alg_name) as sig:
21+
message = bytes(random.getrandbits(8) for _ in range(100))
22+
public_key = sig.generate_keypair()
23+
signature = sig.sign(message)
24+
assert sig.verify(message, signature, public_key)
2625

2726

2827
def test_wrong_message():
@@ -33,13 +32,12 @@ def test_wrong_message():
3332

3433

3534
def check_wrong_message(alg_name):
36-
message = bytes(random.getrandbits(8) for _ in range(100))
37-
sig = oqs.Signature(alg_name)
38-
public_key = sig.generate_keypair()
39-
signature = sig.sign(message)
40-
wrong_message = bytes(random.getrandbits(8) for _ in range(100))
41-
assert not (sig.verify(wrong_message, signature, public_key))
42-
sig.free()
35+
with oqs.Signature(alg_name) as sig:
36+
message = bytes(random.getrandbits(8) for _ in range(100))
37+
public_key = sig.generate_keypair()
38+
signature = sig.sign(message)
39+
wrong_message = bytes(random.getrandbits(8) for _ in range(100))
40+
assert not (sig.verify(wrong_message, signature, public_key))
4341

4442

4543
def test_wrong_signature():
@@ -50,13 +48,12 @@ def test_wrong_signature():
5048

5149

5250
def check_wrong_signature(alg_name):
53-
message = bytes(random.getrandbits(8) for _ in range(100))
54-
sig = oqs.Signature(alg_name)
55-
public_key = sig.generate_keypair()
56-
signature = sig.sign(message)
57-
wrong_signature = bytes(random.getrandbits(8) for _ in range(sig.details['length_signature']))
58-
assert not (sig.verify(message, wrong_signature, public_key))
59-
sig.free()
51+
with oqs.Signature(alg_name) as sig:
52+
message = bytes(random.getrandbits(8) for _ in range(100))
53+
public_key = sig.generate_keypair()
54+
signature = sig.sign(message)
55+
wrong_signature = bytes(random.getrandbits(8) for _ in range(sig.details['length_signature']))
56+
assert not (sig.verify(message, wrong_signature, public_key))
6057

6158

6259
def test_wrong_public_key():
@@ -67,19 +64,18 @@ def test_wrong_public_key():
6764

6865

6966
def check_wrong_public_key(alg_name):
70-
message = bytes(random.getrandbits(8) for _ in range(100))
71-
sig = oqs.Signature(alg_name)
72-
public_key = sig.generate_keypair()
73-
signature = sig.sign(message)
74-
wrong_public_key = bytes(random.getrandbits(8) for _ in range(sig.details['length_public_key']))
75-
assert not (sig.verify(message, signature, wrong_public_key))
76-
sig.free()
67+
with oqs.Signature(alg_name) as sig:
68+
message = bytes(random.getrandbits(8) for _ in range(100))
69+
public_key = sig.generate_keypair()
70+
signature = sig.sign(message)
71+
wrong_public_key = bytes(random.getrandbits(8) for _ in range(sig.details['length_public_key']))
72+
assert not (sig.verify(message, signature, wrong_public_key))
7773

7874

7975
def test_not_supported():
8076
try:
81-
sig = oqs.Signature("bogus")
82-
raise AssertionError("oqs.MechanismNotSupportedError was not raised.")
77+
with oqs.Signature("bogus") as sig:
78+
raise AssertionError("oqs.MechanismNotSupportedError was not raised.")
8379
except oqs.MechanismNotSupportedError:
8480
pass
8581
except E:
@@ -92,8 +88,8 @@ def test_not_enabled():
9288
if alg_name not in oqs.get_enabled_sig_mechanisms():
9389
# found a non-enabled but supported alg
9490
try:
95-
sig = oqs.Signature(alg_name)
96-
raise AssertionError("oqs.MechanismNotEnabledError was not raised.")
91+
with oqs.Signature(alg_name) as sig:
92+
raise AssertionError("oqs.MechanismNotEnabledError was not raised.")
9793
except oqs.MechanismNotEnabledError:
9894
pass
9995
except E:

0 commit comments

Comments
 (0)