From 22ef165a0cc86987fa1984347f48fc46d50cf690 Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Sat, 9 Aug 2025 10:55:29 +0200 Subject: [PATCH 1/4] implement __eq__ for HashAlgorithm instances --- src/cryptography/hazmat/primitives/hashes.py | 71 +++++++++++++++ tests/doubles.py | 7 ++ tests/hazmat/primitives/test_hashes.py | 91 +++++++++++++++++++- tests/hazmat/primitives/utils.py | 9 ++ 4 files changed, 177 insertions(+), 1 deletion(-) diff --git a/src/cryptography/hazmat/primitives/hashes.py b/src/cryptography/hazmat/primitives/hashes.py index 4b55ec33dbff..d3e6fbbc64fc 100644 --- a/src/cryptography/hazmat/primitives/hashes.py +++ b/src/cryptography/hazmat/primitives/hashes.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import typing from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.utils import Buffer @@ -36,6 +37,13 @@ class HashAlgorithm(metaclass=abc.ABCMeta): + @abc.abstractmethod + def __eq__(self, other: typing.Any) -> bool: + """ + Implement equality checking. + """ + ... + @property @abc.abstractmethod def name(self) -> str: @@ -103,66 +111,99 @@ class SHA1(HashAlgorithm): digest_size = 20 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA1) + class SHA512_224(HashAlgorithm): # noqa: N801 name = "sha512-224" digest_size = 28 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA512_224) + class SHA512_256(HashAlgorithm): # noqa: N801 name = "sha512-256" digest_size = 32 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA512_256) + class SHA224(HashAlgorithm): name = "sha224" digest_size = 28 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA224) + class SHA256(HashAlgorithm): name = "sha256" digest_size = 32 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA256) + class SHA384(HashAlgorithm): name = "sha384" digest_size = 48 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA384) + class SHA512(HashAlgorithm): name = "sha512" digest_size = 64 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA512) + class SHA3_224(HashAlgorithm): # noqa: N801 name = "sha3-224" digest_size = 28 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_224) + class SHA3_256(HashAlgorithm): # noqa: N801 name = "sha3-256" digest_size = 32 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_256) + class SHA3_384(HashAlgorithm): # noqa: N801 name = "sha3-384" digest_size = 48 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_384) + class SHA3_512(HashAlgorithm): # noqa: N801 name = "sha3-512" digest_size = 64 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_512) + class SHAKE128(HashAlgorithm, ExtendableOutputFunction): name = "shake128" @@ -177,6 +218,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, SHAKE128) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -195,6 +242,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, SHAKE256) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -205,6 +258,9 @@ class MD5(HashAlgorithm): digest_size = 16 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, MD5) + class BLAKE2b(HashAlgorithm): name = "blake2b" @@ -218,6 +274,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, BLAKE2b) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -235,6 +297,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, BLAKE2s) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -244,3 +312,6 @@ class SM3(HashAlgorithm): name = "sm3" digest_size = 32 block_size = 64 + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SM3) diff --git a/tests/doubles.py b/tests/doubles.py index cf2c96a3e83c..760fc1ba7c49 100644 --- a/tests/doubles.py +++ b/tests/doubles.py @@ -2,6 +2,7 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. +import typing from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding @@ -40,6 +41,12 @@ class DummyHashAlgorithm(hashes.HashAlgorithm): def __init__(self, digest_size: int = 32) -> None: self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(self, DummyHashAlgorithm) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size diff --git a/tests/hazmat/primitives/test_hashes.py b/tests/hazmat/primitives/test_hashes.py index 092ba9af41d4..0076d439961b 100644 --- a/tests/hazmat/primitives/test_hashes.py +++ b/tests/hazmat/primitives/test_hashes.py @@ -12,7 +12,7 @@ from ...doubles import DummyHashAlgorithm from ...utils import raises_unsupported_algorithm -from .utils import generate_base_hash_test +from .utils import generate_base_hash_test, generate_eq_hash_test class TestHashContext: @@ -52,6 +52,7 @@ class TestSHA1: hashes.SHA1(), digest_size=20, ) + test_sha1_eq = generate_eq_hash_test(hashes.SHA1()) @pytest.mark.supported( @@ -63,6 +64,7 @@ class TestSHA224: hashes.SHA224(), digest_size=28, ) + test_sha224_eq = generate_eq_hash_test(hashes.SHA224()) @pytest.mark.supported( @@ -74,6 +76,7 @@ class TestSHA256: hashes.SHA256(), digest_size=32, ) + test_sha256_eq = generate_eq_hash_test(hashes.SHA256()) @pytest.mark.supported( @@ -85,6 +88,7 @@ class TestSHA384: hashes.SHA384(), digest_size=48, ) + test_sha384_eq = generate_eq_hash_test(hashes.SHA384()) @pytest.mark.supported( @@ -96,6 +100,79 @@ class TestSHA512: hashes.SHA512(), digest_size=64, ) + test_sha512_eq = generate_eq_hash_test(hashes.SHA512()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA512_224()), + skip_message="Does not support SHA512 224", +) +class TestSHA512224: + test_sha512_224 = generate_base_hash_test( + hashes.SHA512_224(), + digest_size=28, + ) + test_sha512_224_eq = generate_eq_hash_test(hashes.SHA512_224()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA512_256()), + skip_message="Does not support SHA512 256", +) +class TestSHA512256: + test_sha512_256 = generate_base_hash_test( + hashes.SHA512_256(), + digest_size=32, + ) + test_sha512_256_eq = generate_eq_hash_test(hashes.SHA512_256()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_224()), + skip_message="Does not support SHA3 224", +) +class TestSHA3224: + test_sha3_224 = generate_base_hash_test( + hashes.SHA3_224(), + digest_size=28, + ) + test_sha3_224_eq = generate_eq_hash_test(hashes.SHA3_224()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_256()), + skip_message="Does not support SHA3 256", +) +class TestSHA3256: + test_sha3_256 = generate_base_hash_test( + hashes.SHA3_256(), + digest_size=32, + ) + test_sha3_256_eq = generate_eq_hash_test(hashes.SHA3_256()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_384()), + skip_message="Does not support SHA3 384", +) +class TestSHA3384: + test_sha3_384 = generate_base_hash_test( + hashes.SHA3_384(), + digest_size=48, + ) + test_sha3_384_eq = generate_eq_hash_test(hashes.SHA3_384()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_512()), + skip_message="Does not support SHA3 512", +) +class TestSHA3512: + test_sha3_512 = generate_base_hash_test( + hashes.SHA3_512(), + digest_size=64, + ) + test_sha3_512_eq = generate_eq_hash_test(hashes.SHA3_512()) @pytest.mark.supported( @@ -107,6 +184,7 @@ class TestMD5: hashes.MD5(), digest_size=16, ) + test_md5_eq = generate_eq_hash_test(hashes.MD5()) @pytest.mark.supported( @@ -120,6 +198,7 @@ class TestBLAKE2b: hashes.BLAKE2b(digest_size=64), digest_size=64, ) + test_blake2b_eq = generate_eq_hash_test(hashes.BLAKE2b(digest_size=64)) def test_invalid_digest_size(self, backend): with pytest.raises(ValueError): @@ -143,6 +222,7 @@ class TestBLAKE2s: hashes.BLAKE2s(digest_size=32), digest_size=32, ) + test_blake2s_eq = generate_eq_hash_test(hashes.BLAKE2s(digest_size=32)) def test_invalid_digest_size(self, backend): with pytest.raises(ValueError): @@ -165,6 +245,14 @@ def test_buffer_protocol_hash(backend): class TestSHAKE: + @pytest.mark.parametrize("xof", [hashes.SHAKE128, hashes.SHAKE256]) + def test_eq(self, xof): + value_one = xof(digest_size=32) + value_two = xof(digest_size=32) # identical + value_three = xof(digest_size=64) + assert value_one == value_two + assert value_one != value_three + @pytest.mark.parametrize("xof", [hashes.SHAKE128, hashes.SHAKE256]) def test_invalid_digest_type(self, xof): with pytest.raises(TypeError): @@ -188,3 +276,4 @@ class TestSM3: hashes.SM3(), digest_size=32, ) + test_sm3_eq = generate_eq_hash_test(hashes.SM3()) diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index aad324683a81..af53155410aa 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -35,6 +35,7 @@ Mode, ) +from ...doubles import DummyHashAlgorithm from ...utils import load_vectors_from_file @@ -207,6 +208,14 @@ def test_base_hash(self, backend): return test_base_hash +def generate_eq_hash_test(algorithm): + def test_eq(self): + assert algorithm == algorithm + assert algorithm != DummyHashAlgorithm() + + return test_eq + + def base_hash_test(backend, algorithm, digest_size): m = hashes.Hash(algorithm, backend=backend) assert m.algorithm.digest_size == digest_size From 6fc6966a33e886fa43397c7d66285eee945299ae Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Sat, 9 Aug 2025 10:56:44 +0200 Subject: [PATCH 2/4] add __eq__ for padding classes --- CHANGELOG.rst | 5 + .../hazmat/primitives/asymmetric/padding.py | 34 ++++++ tests/hazmat/backends/test_openssl.py | 4 + tests/hazmat/primitives/test_rsa.py | 115 ++++++++++++++++++ 4 files changed, 158 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9894665ab412..929cc043147a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,11 @@ Changelog * Removed the deprecated ``CAST5``, ``SEED``, ``IDEA``, and ``Blowfish`` classes from the cipher module. These are still available in :doc:`/hazmat/decrepit/index`. +* Make instances of + :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm` as well as + instances of classes in + :mod:`~cryptography.hazmat.primitives.asymmetric.padding` + comparable. .. _v45-0-6: diff --git a/src/cryptography/hazmat/primitives/asymmetric/padding.py b/src/cryptography/hazmat/primitives/asymmetric/padding.py index 5121a288fcc7..6448227ee298 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/padding.py +++ b/src/cryptography/hazmat/primitives/asymmetric/padding.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import typing from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives._asymmetric import ( @@ -16,6 +17,9 @@ class PKCS1v15(AsymmetricPadding): name = "EMSA-PKCS1-v1_5" + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, PKCS1v15) + class _MaxLength: "Sentinel value for `MAX_LENGTH`." @@ -56,6 +60,18 @@ def __init__( self._salt_length = salt_length + def __eq__(self, other: typing.Any) -> bool: + if isinstance(self._salt_length, int): + eq_salt_length = self._salt_length == other._salt_length + else: + eq_salt_length = self._salt_length is other._salt_length + + return ( + isinstance(other, PSS) + and eq_salt_length + and self._mgf == other._mgf + ) + @property def mgf(self) -> MGF: return self._mgf @@ -77,6 +93,14 @@ def __init__( self._algorithm = algorithm self._label = label + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, OAEP) + and self._mgf == other._mgf + and self._algorithm == other._algorithm + and self._label == other._label + ) + @property def algorithm(self) -> hashes.HashAlgorithm: return self._algorithm @@ -89,6 +113,13 @@ def mgf(self) -> MGF: class MGF(metaclass=abc.ABCMeta): _algorithm: hashes.HashAlgorithm + @abc.abstractmethod + def __eq__(self, other: typing.Any) -> bool: + """ + Implement equality checking. + """ + ... + class MGF1(MGF): def __init__(self, algorithm: hashes.HashAlgorithm): @@ -97,6 +128,9 @@ def __init__(self, algorithm: hashes.HashAlgorithm): self._algorithm = algorithm + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, MGF1) and self._algorithm == other._algorithm + def calculate_max_pss_salt_length( key: rsa.RSAPrivateKey | rsa.RSAPublicKey, diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py index a48dc653f033..e8e89efb3f9d 100644 --- a/tests/hazmat/backends/test_openssl.py +++ b/tests/hazmat/backends/test_openssl.py @@ -4,6 +4,7 @@ import itertools +import typing import pytest @@ -32,6 +33,9 @@ class DummyMGF(padding.MGF): _salt_length = 0 _algorithm = hashes.SHA1() + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, DummyMGF) + class TestOpenSSL: def test_backend_exists(self): diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 25edfb07592c..3011d4de7822 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -7,6 +7,7 @@ import copy import itertools import os +import typing import pytest @@ -70,6 +71,9 @@ class DummyMGF(padding.MGF): _salt_length = 0 _algorithm = hashes.SHA256() + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, DummyMGF) + def _check_fips_key_length(backend, private_key): if ( @@ -1603,6 +1607,14 @@ class TestRSAPKCS1Verification: ) +class TestPKCS1v15: + def test_eq(self): + assert padding.PKCS1v15() == padding.PKCS1v15() + assert padding.PKCS1v15() != padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), salt_length=32 + ) + + class TestPSS: def test_calculate_max_pss_salt_length(self): with pytest.raises(TypeError): @@ -1644,8 +1656,68 @@ def test_mgf_property(self): assert pss.mgf == mgf assert pss.mgf == pss._mgf + @pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()]) + @pytest.mark.parametrize( + "salt_length", + [ + 1, + 32, + padding.PSS.MAX_LENGTH, + padding.PSS.AUTO, + padding.PSS.DIGEST_LENGTH, + ], + ) + def test_eq( + self, xof: hashes.HashAlgorithm, salt_length: typing.Any + ) -> None: + assert padding.PSS( + salt_length=salt_length, mgf=padding.MGF1(algorithm=xof) + ) == padding.PSS( + salt_length=salt_length, mgf=padding.MGF1(algorithm=xof) + ) + + @pytest.mark.parametrize( + "salt_length", + [ + 1, + 32, + padding.PSS.MAX_LENGTH, + padding.PSS.AUTO, + padding.PSS.DIGEST_LENGTH, + ], + ) + def test_not_eq_with_different_salt_length( + self, salt_length: typing.Any + ) -> None: + xof = hashes.SHA256() + assert padding.PSS( + salt_length=salt_length, mgf=padding.MGF1(algorithm=xof) + ) != padding.PSS(salt_length=64, mgf=padding.MGF1(algorithm=xof)) + + def test_not_eq_with_salt_length_object_identity(self) -> None: + xof = hashes.SHA256() + assert padding.PSS( + salt_length=padding.PSS.AUTO, mgf=padding.MGF1(algorithm=xof) + ) != padding.PSS( + salt_length=padding.PSS.DIGEST_LENGTH, + mgf=padding.MGF1(algorithm=xof), + ) + + def test_not_eq_with_different_mgf(self) -> None: + assert padding.PSS( + salt_length=padding.PSS.AUTO, + mgf=padding.MGF1(algorithm=hashes.SHA256()), + ) != padding.PSS( + salt_length=padding.PSS.AUTO, + mgf=padding.MGF1(algorithm=hashes.SHA512()), + ) + class TestMGF1: + def test_eq(self) -> None: + assert padding.MGF1(hashes.SHA256()) == padding.MGF1(hashes.SHA256()) + assert padding.MGF1(hashes.SHA256()) != padding.MGF1(hashes.SHA512()) + def test_invalid_hash_algorithm(self): with pytest.raises(TypeError): padding.MGF1(b"not_a_hash") # type:ignore[arg-type] @@ -1680,6 +1752,49 @@ def test_mgf_property(self): assert oaep.mgf == mgf assert oaep.mgf == oaep._mgf + @pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()]) + @pytest.mark.parametrize("label", [None, b"", b"foo"]) + def test_eq( + self, xof: hashes.HashAlgorithm, label: typing.Optional[bytes] + ) -> None: + mgf = padding.MGF1(algorithm=xof) + assert padding.OAEP( + mgf=mgf, algorithm=xof, label=label + ) == padding.OAEP(mgf=mgf, algorithm=xof, label=label) + + def test_not_eq_with_different_mgf(self) -> None: + assert padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ) != padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=None, + ) + + def test_not_eq_with_different_algorithm(self) -> None: + assert padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA512(), + label=None, + ) != padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=None, + ) + + def test_not_eq_with_different_label(self) -> None: + assert padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=None, + ) != padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=b"", + ) + class TestRSADecryption: @pytest.mark.supported( From f842d602760a76743d3de1457b5f12dd1df906f7 Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Thu, 7 Aug 2025 13:04:38 +0200 Subject: [PATCH 3/4] add missing properties to padding classes --- CHANGELOG.rst | 7 ++++++ docs/hazmat/primitives/asymmetric/rsa.rst | 24 +++++++++++++++++++ .../hazmat/primitives/asymmetric/padding.py | 12 ++++++++++ tests/hazmat/primitives/test_rsa.py | 19 +++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 929cc043147a..3cb76c937075 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -23,6 +23,13 @@ Changelog instances of classes in :mod:`~cryptography.hazmat.primitives.asymmetric.padding` comparable. +* Added `salt_length` property to + :class:`~cryptography.hazmat.primitives.asymmetric.padding.PSS`. +* Added `label` property to + :class:`~cryptography.hazmat.primitives.asymmetric.padding.OAEP`. +* Added `algorithm` property to + :class:`~cryptography.hazmat.primitives.asymmetric.padding.MGF1`. + .. _v45-0-6: diff --git a/docs/hazmat/primitives/asymmetric/rsa.rst b/docs/hazmat/primitives/asymmetric/rsa.rst index 54190ae2dd38..5e68a784d6a6 100644 --- a/docs/hazmat/primitives/asymmetric/rsa.rst +++ b/docs/hazmat/primitives/asymmetric/rsa.rst @@ -325,6 +325,14 @@ Padding The padding's mask generation function (MGF). + .. attribute:: salt_length + + :type: int + + .. versionadded:: 46.0.0 + + The length of the salt. + .. class:: OAEP(mgf, algorithm, label) .. versionadded:: 0.4 @@ -351,6 +359,14 @@ Padding The padding's hash algorithm. + .. attribute:: label + + :type: bytes | None + + .. versionadded:: 42.0.0 + + The padding's hash algorithm. + .. attribute:: mgf :type: :class:`~cryptography.hazmat.primitives.asymmetric.padding.MGF` @@ -411,6 +427,14 @@ Mask generation functions :param algorithm: An instance of :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`. + .. attribute:: algorithm + + :type: :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`. + + .. versionadded:: 46.0.0 + + The algorithm of this instance. + Numbers ~~~~~~~ diff --git a/src/cryptography/hazmat/primitives/asymmetric/padding.py b/src/cryptography/hazmat/primitives/asymmetric/padding.py index 6448227ee298..cea9c1ae7446 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/padding.py +++ b/src/cryptography/hazmat/primitives/asymmetric/padding.py @@ -76,6 +76,10 @@ def __eq__(self, other: typing.Any) -> bool: def mgf(self) -> MGF: return self._mgf + @property + def salt_length(self) -> int | _MaxLength | _Auto | _DigestLength: + return self._salt_length + class OAEP(AsymmetricPadding): name = "EME-OAEP" @@ -105,6 +109,10 @@ def __eq__(self, other: typing.Any) -> bool: def algorithm(self) -> hashes.HashAlgorithm: return self._algorithm + @property + def label(self) -> bytes | None: + return self._label + @property def mgf(self) -> MGF: return self._mgf @@ -131,6 +139,10 @@ def __init__(self, algorithm: hashes.HashAlgorithm): def __eq__(self, other: typing.Any) -> bool: return isinstance(other, MGF1) and self._algorithm == other._algorithm + @property + def algorithm(self) -> hashes.HashAlgorithm: + return self._algorithm + def calculate_max_pss_salt_length( key: rsa.RSAPrivateKey | rsa.RSAPublicKey, diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 3011d4de7822..cd54321ef24e 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -1656,6 +1656,13 @@ def test_mgf_property(self): assert pss.mgf == mgf assert pss.mgf == pss._mgf + def test_salt_length_property(self): + algorithm = hashes.SHA256() + mgf = padding.MGF1(algorithm) + pss = padding.PSS(mgf=mgf, salt_length=padding.PSS.MAX_LENGTH) + assert pss.salt_length == padding.PSS.MAX_LENGTH + assert pss._salt_length == padding.PSS.MAX_LENGTH + @pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()]) @pytest.mark.parametrize( "salt_length", @@ -1727,6 +1734,11 @@ def test_valid_mgf1_parameters(self): mgf = padding.MGF1(algorithm) assert mgf._algorithm == algorithm + def test_algorithm_property(self): + mgf = padding.MGF1(hashes.SHA256()) + assert mgf.algorithm == hashes.SHA256() + assert mgf._algorithm == hashes.SHA256() + class TestOAEP: def test_invalid_algorithm(self): @@ -1745,6 +1757,13 @@ def test_algorithm_property(self): assert oaep.algorithm == algorithm assert oaep.algorithm == oaep._algorithm + def test_label_property(self): + algorithm = hashes.SHA256() + mgf = padding.MGF1(algorithm) + oaep = padding.OAEP(mgf=mgf, algorithm=algorithm, label=None) + assert oaep.label is None + assert oaep._label is None + def test_mgf_property(self): algorithm = hashes.SHA256() mgf = padding.MGF1(algorithm) From 2be27b4f179acad9dd0bd0f2e0e78d6ce80e431f Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Sat, 9 Aug 2025 11:11:06 +0200 Subject: [PATCH 4/4] update documentation to use new equality checks --- docs/hazmat/primitives/asymmetric/cloudhsm.rst | 2 +- docs/x509/reference.rst | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/hazmat/primitives/asymmetric/cloudhsm.rst b/docs/hazmat/primitives/asymmetric/cloudhsm.rst index 8934133a228a..09ae30f6fc87 100644 --- a/docs/hazmat/primitives/asymmetric/cloudhsm.rst +++ b/docs/hazmat/primitives/asymmetric/cloudhsm.rst @@ -88,7 +88,7 @@ if you only need a subset of functionality. ... Maps the cryptography padding and algorithm to the corresponding KMS signing algorithm. ... This is specific to your implementation. ... """ - ... if isinstance(padding, PKCS1v15) and isinstance(algorithm, hashes.SHA256): + ... if padding == PKCS1v15() and algorithm == hashes.SHA256(): ... return b"RSA_PKCS1_V1_5_SHA_256" ... else: ... raise NotImplementedError() diff --git a/docs/x509/reference.rst b/docs/x509/reference.rst index 74d6da68bad4..6acc63a0f4bc 100644 --- a/docs/x509/reference.rst +++ b/docs/x509/reference.rst @@ -248,7 +248,7 @@ Loading Certificate Revocation Lists >>> from cryptography import x509 >>> from cryptography.hazmat.primitives import hashes >>> crl = x509.load_pem_x509_crl(pem_crl_data) - >>> isinstance(crl.signature_hash_algorithm, hashes.SHA256) + >>> crl.signature_hash_algorithm == hashes.SHA256() True .. function:: load_der_x509_crl(data) @@ -287,7 +287,7 @@ Loading Certificate Signing Requests >>> from cryptography import x509 >>> from cryptography.hazmat.primitives import hashes >>> csr = x509.load_pem_x509_csr(pem_req_data) - >>> isinstance(csr.signature_hash_algorithm, hashes.SHA256) + >>> csr.signature_hash_algorithm == hashes.SHA256() True .. function:: load_der_x509_csr(data) @@ -477,7 +477,7 @@ X.509 Certificate Object .. doctest:: >>> from cryptography.hazmat.primitives import hashes - >>> isinstance(cert.signature_hash_algorithm, hashes.SHA256) + >>> cert.signature_hash_algorithm == hashes.SHA256() True .. attribute:: signature_algorithm_oid @@ -716,7 +716,7 @@ X.509 CRL (Certificate Revocation List) Object .. doctest:: >>> from cryptography.hazmat.primitives import hashes - >>> isinstance(crl.signature_hash_algorithm, hashes.SHA256) + >>> crl.signature_hash_algorithm == hashes.SHA256() True .. attribute:: signature_algorithm_oid @@ -1119,7 +1119,7 @@ X.509 CSR (Certificate Signing Request) Object .. doctest:: >>> from cryptography.hazmat.primitives import hashes - >>> isinstance(csr.signature_hash_algorithm, hashes.SHA256) + >>> csr.signature_hash_algorithm == hashes.SHA256() True .. attribute:: signature_algorithm_oid