Skip to content

Commit 3028574

Browse files
committed
refactor to match PR with rust ports
1 parent d2aefa9 commit 3028574

File tree

5 files changed

+43
-40
lines changed

5 files changed

+43
-40
lines changed

Diff for: src/cryptography/hazmat/bindings/_rust/__init__.pyi

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import typing
66

77
from cryptography.hazmat.primitives import padding
88

9-
def check_ansix923_padding(data: bytes) -> bool: ...
10-
119
class PKCS7PaddingContext(padding.PaddingContext):
1210
def __init__(self, block_size: int) -> None: ...
1311
def update(self, data: bytes) -> bytes: ...

Diff for: src/rust/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ mod _rust {
106106
use crate::oid::ObjectIdentifier;
107107
#[pymodule_export]
108108
use crate::padding::{
109-
check_ansix923_padding, ANSIX923PaddingContext, ANSIX923UnpaddingContext,
110-
PKCS7PaddingContext, PKCS7UnpaddingContext,
109+
ANSIX923PaddingContext, ANSIX923UnpaddingContext, PKCS7PaddingContext,
110+
PKCS7UnpaddingContext,
111111
};
112112
#[pymodule_export]
113113
use crate::pkcs12::pkcs12;

Diff for: src/rust/src/padding.rs

-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ pub(crate) fn check_pkcs7_padding(data: &[u8]) -> bool {
4242
(mismatch & 1) == 0
4343
}
4444

45-
#[pyo3::pyfunction]
4645
pub(crate) fn check_ansix923_padding(data: &[u8]) -> bool {
4746
let mut mismatch = 0;
4847
let pad_size = *data.last().unwrap();

Diff for: tests/hazmat/primitives/test_padding.py

+13-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# for complete details.
44

55

6-
import contextlib
76
import sys
87
import threading
98

@@ -12,7 +11,13 @@
1211
from cryptography.exceptions import AlreadyFinalized
1312
from cryptography.hazmat.primitives import padding
1413

15-
from .utils import IS_FREETHREADED_BUILD, run_threaded
14+
from .utils import IS_FREETHREADED_BUILD, SwitchIntervalContext, run_threaded
15+
16+
SHOULD_LOCK_BYTESTRING = (
17+
IS_FREETHREADED_BUILD
18+
or sys.version_info < (3, 10)
19+
or sys.implementation.name == "pypy"
20+
)
1621

1722

1823
class TestPKCS7:
@@ -248,20 +253,6 @@ def test_bytearray(self):
248253
assert final == unpadded + unpadded
249254

250255

251-
class SwitchIntervalContext(contextlib.ContextDecorator):
252-
def __init__(self, interval):
253-
self.interval = interval
254-
255-
def __enter__(self):
256-
self.orig_interval = sys.getswitchinterval()
257-
sys.setswitchinterval(self.interval)
258-
return self
259-
260-
def __exit__(self, *exc):
261-
sys.setswitchinterval(self.orig_interval)
262-
return False
263-
264-
265256
@SwitchIntervalContext(0.0000001)
266257
@pytest.mark.parametrize(
267258
"algorithm",
@@ -292,8 +283,8 @@ def pad_in_chunks(chunk_size):
292283
while index < len(data):
293284
try:
294285
new_content = padder.update(data[index : index + chunk_size])
295-
if IS_FREETHREADED_BUILD or sys.version_info < (3, 10):
296-
# appending to a bytestring is racey on 3.13t and < 3.10
286+
if SHOULD_LOCK_BYTESTRING:
287+
# appending to a bytestring is racey on some Pythons
297288
lock.acquire()
298289
calculated_pad += new_content
299290
lock.release()
@@ -314,7 +305,7 @@ def prepare_args(threadnum):
314305
assert chunk_size % len(chunk) == 0
315306
return (chunk_size,)
316307

317-
run_threaded(num_threads, chunk, pad_in_chunks, prepare_args)
308+
run_threaded(num_threads, pad_in_chunks, prepare_args)
318309

319310
calculated_pad += padder.finalize()
320311
assert expected_pad == calculated_pad
@@ -349,8 +340,8 @@ def unpad_in_chunks():
349340
while index < num_repeats:
350341
try:
351342
new_content = padder.update(block)
352-
if IS_FREETHREADED_BUILD or sys.version_info < (3, 10):
353-
# appending to a bytestring is racey on 3.13t and < 3.10
343+
if SHOULD_LOCK_BYTESTRING:
344+
# appending to a bytestring is racey on some Pythons
354345
lock.acquire()
355346
calculated_unpadded_message += new_content
356347
lock.release()
@@ -365,7 +356,7 @@ def unpad_in_chunks():
365356
continue
366357
index += 1
367358

368-
run_threaded(num_threads, chunk, unpad_in_chunks, lambda x: tuple())
359+
run_threaded(num_threads, unpad_in_chunks, lambda x: tuple())
369360

370361
calculated_unpadded_message += padder.finalize()
371362
assert expected_unpadded_message == calculated_unpadded_message

Diff for: tests/hazmat/primitives/utils.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55

66
import binascii
7+
import concurrent.futures
8+
import contextlib
79
import os
10+
import sys
811
import sysconfig
912
import threading
1013
import typing
@@ -225,18 +228,6 @@ def base_hash_test(backend, algorithm, digest_size):
225228
assert copy.finalize() == m.finalize()
226229

227230

228-
def run_threaded(num_threads, chunk, func, prepare_args):
229-
threads = []
230-
for threadnum in range(num_threads):
231-
thread = threading.Thread(target=func, args=prepare_args(threadnum))
232-
threads.append(thread)
233-
234-
for thread in threads:
235-
thread.start()
236-
for thread in threads:
237-
thread.join()
238-
239-
240231
def multithreaded_hash_test(backend, algorithm):
241232
# adapted from test_threaded_hashing in CPython's hashlib tests
242233
# Updating the same hash object from several threads at once
@@ -276,7 +267,7 @@ def prepare_args(threadnum):
276267
assert chunk_size % len(chunk) == 0
277268
return (chunk_size,)
278269

279-
run_threaded(num_threads, chunk, hash_in_chunks, prepare_args)
270+
run_threaded(num_threads, hash_in_chunks, prepare_args)
280271

281272
calculated_hash = hasher.finalize()
282273

@@ -642,3 +633,27 @@ def skip_fips_traditional_openssl(backend, fmt):
642633
pytest.skip(
643634
"Traditional OpenSSL key format is not supported in FIPS mode."
644635
)
636+
637+
638+
class SwitchIntervalContext(contextlib.ContextDecorator):
639+
"""Context manager to track thread switch interval state"""
640+
641+
def __init__(self, interval):
642+
self.interval = interval
643+
644+
def __enter__(self):
645+
self.orig_interval = sys.getswitchinterval()
646+
sys.setswitchinterval(self.interval)
647+
return self
648+
649+
def __exit__(self, *exc):
650+
sys.setswitchinterval(self.orig_interval)
651+
return False
652+
653+
654+
def run_threaded(num_threads, func, prepare_args):
655+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as e:
656+
futures = [
657+
e.submit(func, *prepare_args(t)) for t in range(num_threads)
658+
]
659+
[f.result() for f in futures]

0 commit comments

Comments
 (0)