From 56feaa59d3b4756247c31b439c4ce0d1993e305c Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 23 Nov 2023 11:24:38 +0100 Subject: [PATCH 01/88] feat: added policy apply on metadata --- pyeudiw/trust/__init__.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 0835f386..178b438f 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -13,6 +13,9 @@ MissingTrustType ) +import pyeudiw.metadata.policy as pcl + +#from pyeudiw.metadata.policy import TrustChainPolicy, combine class TrustEvaluationHelper: def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): @@ -158,20 +161,34 @@ def x509(self) -> bool: self.is_valid = self._handle_x509_pem() return self.is_valid - def get_final_metadata(self, metadata_type: str) -> dict: - # TODO - apply metadata policy and get the final metadata - # for now the final_metadata is the EC metadata -> TODO final_metadata + def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: + policy_acc = {"metadata": {}, "metadata_policy": {}} + + for policy in policies: + policy_acc = pcl.combine(policy, policy_acc) + self.final_metadata = unpad_jwt_payload(self.trust_chain[0]) + try: # TODO: there are some cases where the jwks are taken from a uri ... - return self.final_metadata['metadata'][metadata_type] + selected_metadata = { + "metadata": self.final_metadata['metadata'], + "metadata_policy": {} + } + + self.final_metadata = pcl.TrustChainPolicy().apply_policy( + selected_metadata, + policy_acc + ) + return self.final_metadata["metadata"][metadata_type] except KeyError: raise ProtocolMetadataNotFound( f"{metadata_type} not found in the final metadata:" - f" {self.final_metadata}" + f" {self.final_metadata['metadata']}" ) - def get_trusted_jwks(self, metadata_type: str) -> list: + def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list: return self.get_final_metadata( - metadata_type=metadata_type + metadata_type=metadata_type, + policies=policies ).get('jwks', {}).get('keys', []) From 1ea7bac372e00a5b56459649be1347d0e0370c38 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 23 Nov 2023 16:52:04 +0100 Subject: [PATCH 02/88] test: added intial tests for TrustEvaluationHelper --- .../tests/trust/test_TrustEvaluationHelper.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 pyeudiw/tests/trust/test_TrustEvaluationHelper.py diff --git a/pyeudiw/tests/trust/test_TrustEvaluationHelper.py b/pyeudiw/tests/trust/test_TrustEvaluationHelper.py new file mode 100644 index 00000000..69dcb973 --- /dev/null +++ b/pyeudiw/tests/trust/test_TrustEvaluationHelper.py @@ -0,0 +1,60 @@ +import pytest +from datetime import datetime +from pyeudiw.tests.settings import CONFIG +from pyeudiw.trust import TrustEvaluationHelper +from pyeudiw.storage.db_engine import DBEngine, TrustType +from pyeudiw.tests.federation.base import trust_chain_issuer +from pyeudiw.tests.x509.test_x509 import gen_chain, chain_to_pem + + +class TestTrustEvaluationHelper: + @pytest.fixture(autouse=True) + def create_engine_instance(self): + self.engine = DBEngine(CONFIG['storage']) + + def test_evaluation_method_federation(self): + teh = TrustEvaluationHelper(self.engine, {}, "", **{"trust_chain": trust_chain_issuer}) + + assert teh.federation == teh._get_evaluation_method() + + def test_chain_validity_federation(self): + teh = TrustEvaluationHelper(self.engine, {}, "", **{"trust_chain": trust_chain_issuer}) + + assert teh.evaluation_method() == True + + def test_evaluation_method_x509(self): + teh = TrustEvaluationHelper(self.engine, {}, "", **{"trust_chain": gen_chain()}) + + assert teh.x509 == teh._get_evaluation_method() + + def test_chain_validity_x509(self): + date = datetime.now() + + x509_chain = gen_chain() + + self.engine.add_trust_anchor( + "leaf.example.org", chain_to_pem(x509_chain), date, TrustType.X509) + + teh = TrustEvaluationHelper(self.engine, {}, "", **{"trust_chain": x509_chain}) + + assert teh.evaluation_method() == True + + def test_chain_invalid_x509(self): + date = datetime.now() + x509_chain = gen_chain() + x509_chain[1] = x509_chain[0] + + self.engine.add_trust_anchor( + "leaf.example.org", chain_to_pem(x509_chain), date, TrustType.X509) + + teh = TrustEvaluationHelper(self.engine, {}, "", **{"trust_chain": x509_chain}) + + assert teh.evaluation_method() == False + + def test_get_trusted_jwk(self): + teh = TrustEvaluationHelper(self.engine, {}, "", **{"trust_chain": trust_chain_issuer}) + + trusted_jwks = teh.get_trusted_jwks("openid_credential_issuer") + + assert trusted_jwks + assert len(trusted_jwks) == 1 \ No newline at end of file From 47fd884ca8087fce28dd56181cfa71472bab95ac Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 23 Nov 2023 16:53:11 +0100 Subject: [PATCH 03/88] fix: fixed validation issues --- pyeudiw/trust/__init__.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 178b438f..04674edb 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -30,7 +30,7 @@ def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, ** for k, v in kwargs.items(): setattr(self, k, v) - def evaluation_method(self) -> bool: + def _get_evaluation_method(self): # The trust chain can be either federation or x509 # If the trust_chain is empty, and we don't have a trust anchor if not self.trust_chain and not self.trust_anchor: @@ -38,15 +38,24 @@ def evaluation_method(self) -> bool: "Static trust chain is not available" ) - if is_jwt_format(self.trust_chain[0]): - return self.federation() - elif is_der_format(self.trust_chain[0]): - return self.x509() + try: + if is_jwt_format(self.trust_chain[0]): + return self.federation + except TypeError: + pass + + if is_der_format(self.trust_chain[0]): + return self.x509 raise InvalidTrustType( "Invalid Trust Type: trust type not supported" ) + + def evaluation_method(self) -> bool: + ev_method = self._get_evaluation_method() + return ev_method() + def _handle_federation_chain(self): _first_statement = unpad_jwt_payload(self.trust_chain[-1]) trust_anchor_eid = self.trust_anchor or _first_statement.get( @@ -114,7 +123,8 @@ def _handle_federation_chain(self): return _is_valid def _handle_x509_pem(self): - trust_anchor_eid = self.trust_anchor or get_issuer_from_x5c(self.x5c) + trust_anchor_eid = self.trust_anchor or get_issuer_from_x5c(self.trust_chain) + _is_valid = False if not trust_anchor_eid: raise UnknownTrustAnchor( @@ -130,9 +140,12 @@ def _handle_x509_pem(self): "a recognizable Trust Anchor." ) - pem = trust_anchor['x509']['pem'] - - _is_valid = verify_x509_anchor(pem) + try: + pem = trust_anchor['x509']['pem'] + _is_valid = verify_x509_anchor(pem) + except KeyError: + raise MissingTrustType( + f"Trust Anchor: '{trust_anchor_eid}' has no x509 trusst entity") if not self.is_trusted and trust_anchor['federation'].get("chain", None) != None: self._handle_federation_chain() @@ -180,6 +193,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: selected_metadata, policy_acc ) + return self.final_metadata["metadata"][metadata_type] except KeyError: raise ProtocolMetadataNotFound( From bd36786fd59a9d7d199cac32a5896661d1243858 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 23 Nov 2023 16:53:53 +0100 Subject: [PATCH 04/88] feat: implemented method add_trust_attestation_metadata --- pyeudiw/storage/base_storage.py | 3 +++ pyeudiw/storage/db_engine.py | 3 +++ pyeudiw/storage/mongo_storage.py | 19 ++++++++++++++++--- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pyeudiw/storage/base_storage.py b/pyeudiw/storage/base_storage.py index 801ab6d5..3678231e 100644 --- a/pyeudiw/storage/base_storage.py +++ b/pyeudiw/storage/base_storage.py @@ -57,6 +57,9 @@ def has_trust_anchor(self, entity_id: str): def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: raise NotImplementedError() + + def add_trust_attestation_metadata(self, entity_id: str, metadata: dict) -> str: + raise NotImplementedError() def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): raise NotImplementedError() diff --git a/pyeudiw/storage/db_engine.py b/pyeudiw/storage/db_engine.py index 4d999713..252255a4 100644 --- a/pyeudiw/storage/db_engine.py +++ b/pyeudiw/storage/db_engine.py @@ -157,6 +157,9 @@ def has_trust_anchor(self, entity_id: str): def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type) + + def add_trust_attestation_metadata(self, entity_id: str, metadata: dict) -> str: + return self.write("add_trust_attestation_metadata", entity_id, metadata) def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION): return self.write("add_trust_anchor", entity_id, entity_configuration, exp, trust_type) diff --git a/pyeudiw/storage/mongo_storage.py b/pyeudiw/storage/mongo_storage.py index 2bf1f53f..8ef31397 100644 --- a/pyeudiw/storage/mongo_storage.py +++ b/pyeudiw/storage/mongo_storage.py @@ -261,7 +261,8 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat entity = { "entity_id": entity_id, "federation": {}, - "x509": {} + "x509": {}, + "metadata": {} } updated_entity = self._update_attestation_metadata(entity, attestation, exp, trust_type) @@ -270,6 +271,18 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat "trust_attestations", entity_id, updated_entity, exp ) + def add_trust_attestation_metadata(self, entity_id: str, metadata: dict): + entity = self._get_trust_attestation("trust_attestations", entity_id) + + if entity is None: + raise ValueError( + f'Document with entity_id {entity_id} not found.' + ) + + entity["metadata"] = metadata + + return self._update_trust_attestation("trust_attestations", entity_id, entity) + def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): if self.has_trust_anchor(entity_id): self.update_trust_anchor(entity_id, entity_configuration, exp, trust_type) @@ -283,7 +296,7 @@ def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datet updated_entity = self._update_anchor_metadata(entity, entity_configuration, exp, trust_type) self._add_entry("trust_anchors", entity_id, updated_entity, exp) - def _update_trust_attestation(self, collection: str, entity_id: str, entity: dict, exp: datetime) -> str: + def _update_trust_attestation(self, collection: str, entity_id: str, entity: dict) -> str: if not self._has_trust_attestation(collection, entity_id): raise ChainNotExist(f"Chain with entity id {entity_id} not exist") @@ -297,7 +310,7 @@ def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: old_entity = self._get_trust_attestation("trust_attestations", entity_id) or {} upd_entity = self._update_attestation_metadata(old_entity, attestation, exp, trust_type) - return self._update_trust_attestation("trust_attestations", entity_id, upd_entity, exp) + return self._update_trust_attestation("trust_attestations", entity_id, upd_entity) def update_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType) -> str: old_entity = self._get_trust_attestation("trust_attestations", entity_id) or {} From aa0133f452489c29580baa98570b356357910581 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 23 Nov 2023 16:54:40 +0100 Subject: [PATCH 05/88] test: added test for add_trust_attestation_metadata --- pyeudiw/tests/storage/test_db_engine.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pyeudiw/tests/storage/test_db_engine.py b/pyeudiw/tests/storage/test_db_engine.py index 7be985eb..ac1a3614 100755 --- a/pyeudiw/tests/storage/test_db_engine.py +++ b/pyeudiw/tests/storage/test_db_engine.py @@ -131,6 +131,25 @@ def test_update_unexistent_trusted_attestation(self): except StorageWriteError as e: return + def test_update_trusted_attestation_metadata(self): + replica_count = self.engine.add_trust_attestation_metadata( + self.federation_entity_id, {"metadata": "test"}) + + assert replica_count > 0 + + ta = self.engine.get_trust_attestation(self.federation_entity_id) + + assert ta.get("metadata", None) != None + assert ta["metadata"] == {"metadata": "test"} + + def test_update_unexistent_trusted_attestation_metadata(self): + try: + self.engine.add_trust_attestation_metadata( + "test", {"metadata": "test"}) + assert False + except StorageWriteError as e: + return + @pytest.fixture(autouse=True) def test_insert_trusted_anchor_federation(self): self.federation_entity_anchor_id = str(uuid.uuid4()) From 1d3fa1b74b7251c4f4702dabb7aedcc984824572 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 27 Nov 2023 16:04:48 +0100 Subject: [PATCH 06/88] fix: added metadata association by metadata_type field --- pyeudiw/storage/base_storage.py | 2 +- pyeudiw/storage/db_engine.py | 4 ++-- pyeudiw/storage/mongo_storage.py | 4 ++-- pyeudiw/tests/storage/test_db_engine.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyeudiw/storage/base_storage.py b/pyeudiw/storage/base_storage.py index 3678231e..f97eaf96 100644 --- a/pyeudiw/storage/base_storage.py +++ b/pyeudiw/storage/base_storage.py @@ -58,7 +58,7 @@ def has_trust_anchor(self, entity_id: str): def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: raise NotImplementedError() - def add_trust_attestation_metadata(self, entity_id: str, metadata: dict) -> str: + def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, metadata: dict) -> str: raise NotImplementedError() def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): diff --git a/pyeudiw/storage/db_engine.py b/pyeudiw/storage/db_engine.py index 252255a4..b1dbc9cb 100644 --- a/pyeudiw/storage/db_engine.py +++ b/pyeudiw/storage/db_engine.py @@ -158,8 +158,8 @@ def has_trust_anchor(self, entity_id: str): def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type) - def add_trust_attestation_metadata(self, entity_id: str, metadata: dict) -> str: - return self.write("add_trust_attestation_metadata", entity_id, metadata) + def add_trust_attestation_metadata(self, entity_id: str, metadat_type: str, metadata: dict) -> str: + return self.write("add_trust_attestation_metadata", entity_id, metadat_type, metadata) def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION): return self.write("add_trust_anchor", entity_id, entity_configuration, exp, trust_type) diff --git a/pyeudiw/storage/mongo_storage.py b/pyeudiw/storage/mongo_storage.py index 8ef31397..541c5471 100644 --- a/pyeudiw/storage/mongo_storage.py +++ b/pyeudiw/storage/mongo_storage.py @@ -271,7 +271,7 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat "trust_attestations", entity_id, updated_entity, exp ) - def add_trust_attestation_metadata(self, entity_id: str, metadata: dict): + def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, metadata: dict): entity = self._get_trust_attestation("trust_attestations", entity_id) if entity is None: @@ -279,7 +279,7 @@ def add_trust_attestation_metadata(self, entity_id: str, metadata: dict): f'Document with entity_id {entity_id} not found.' ) - entity["metadata"] = metadata + entity["metadata"][metadata_type] = metadata return self._update_trust_attestation("trust_attestations", entity_id, entity) diff --git a/pyeudiw/tests/storage/test_db_engine.py b/pyeudiw/tests/storage/test_db_engine.py index ac1a3614..bc77f542 100755 --- a/pyeudiw/tests/storage/test_db_engine.py +++ b/pyeudiw/tests/storage/test_db_engine.py @@ -133,19 +133,19 @@ def test_update_unexistent_trusted_attestation(self): def test_update_trusted_attestation_metadata(self): replica_count = self.engine.add_trust_attestation_metadata( - self.federation_entity_id, {"metadata": "test"}) + self.federation_entity_id, "test_metadata", {"metadata": "test"}) assert replica_count > 0 ta = self.engine.get_trust_attestation(self.federation_entity_id) assert ta.get("metadata", None) != None - assert ta["metadata"] == {"metadata": "test"} + assert ta["metadata"]["test_metadata"] == {"metadata": "test"} def test_update_unexistent_trusted_attestation_metadata(self): try: self.engine.add_trust_attestation_metadata( - "test", {"metadata": "test"}) + "test", "test_metadata", {"metadata": "test"}) assert False except StorageWriteError as e: return From d90974ce11036794799a476e4307cfd2796149e8 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 27 Nov 2023 16:09:41 +0100 Subject: [PATCH 07/88] fix: minor fix to test for add_trust_attestation_metadata's data type --- pyeudiw/tests/storage/test_db_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyeudiw/tests/storage/test_db_engine.py b/pyeudiw/tests/storage/test_db_engine.py index bc77f542..b4e439af 100755 --- a/pyeudiw/tests/storage/test_db_engine.py +++ b/pyeudiw/tests/storage/test_db_engine.py @@ -133,19 +133,19 @@ def test_update_unexistent_trusted_attestation(self): def test_update_trusted_attestation_metadata(self): replica_count = self.engine.add_trust_attestation_metadata( - self.federation_entity_id, "test_metadata", {"metadata": "test"}) + self.federation_entity_id, "test_metadata", {"metadata": {"data_type": "test"}}) assert replica_count > 0 ta = self.engine.get_trust_attestation(self.federation_entity_id) assert ta.get("metadata", None) != None - assert ta["metadata"]["test_metadata"] == {"metadata": "test"} + assert ta["metadata"]["test_metadata"] == {"metadata": {"data_type": "test"}} def test_update_unexistent_trusted_attestation_metadata(self): try: self.engine.add_trust_attestation_metadata( - "test", "test_metadata", {"metadata": "test"}) + "test", "test_metadata", {"metadata": {"data_type": "test"}}) assert False except StorageWriteError as e: return From 850d432f9239edfa89a5ab2c2f9d0ea4547a50fc Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 27 Nov 2023 16:11:28 +0100 Subject: [PATCH 08/88] chore: renamed test file --- ...t_TrustEvaluationHelper.py => test_trust_evaluation_helper.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pyeudiw/tests/trust/{test_TrustEvaluationHelper.py => test_trust_evaluation_helper.py} (100%) diff --git a/pyeudiw/tests/trust/test_TrustEvaluationHelper.py b/pyeudiw/tests/trust/test_trust_evaluation_helper.py similarity index 100% rename from pyeudiw/tests/trust/test_TrustEvaluationHelper.py rename to pyeudiw/tests/trust/test_trust_evaluation_helper.py From c062b354e94337a19c652e5fdc50240b3a8422f1 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 27 Nov 2023 16:13:16 +0100 Subject: [PATCH 09/88] chore: Removed comment --- pyeudiw/trust/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 82cbb0ff..b5bad52b 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -18,8 +18,6 @@ import pyeudiw.metadata.policy as pcl -#from pyeudiw.metadata.policy import TrustChainPolicy, combine - class TrustEvaluationHelper: def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): self.exp: int = 0 From bf3843cf5708b8826861ae968add41388d6fe0e9 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 27 Nov 2023 16:21:14 +0100 Subject: [PATCH 10/88] fix: fixed x509 verification exception handling --- pyeudiw/trust/__init__.py | 19 ++++++++++++++----- pyeudiw/trust/exceptions.py | 3 +++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index b5bad52b..343b6fff 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -13,7 +13,8 @@ MissingProtocolSpecificJwks, UnknownTrustAnchor, InvalidTrustType, - MissingTrustType + MissingTrustType, + InvalidAnchor ) import pyeudiw.metadata.policy as pcl @@ -141,12 +142,20 @@ def _handle_x509_pem(self): "a recognizable Trust Anchor." ) + pem = trust_anchor['x509'].get('pem') + + if pem == None: + raise MissingTrustType( + f"Trust Anchor: '{trust_anchor_eid}' has no x509 trusst entity" + ) + try: - pem = trust_anchor['x509']['pem'] _is_valid = verify_x509_anchor(pem) - except KeyError: - raise MissingTrustType( - f"Trust Anchor: '{trust_anchor_eid}' has no x509 trusst entity") + except Exception as e: + raise InvalidAnchor( + f"Anchor verification raised the following exception: {e}" + ) + if not self.is_trusted and trust_anchor['federation'].get("chain", None) != None: self._handle_federation_chain() diff --git a/pyeudiw/trust/exceptions.py b/pyeudiw/trust/exceptions.py index dde9b3ad..89a8c943 100644 --- a/pyeudiw/trust/exceptions.py +++ b/pyeudiw/trust/exceptions.py @@ -13,4 +13,7 @@ class MissingTrustType(Exception): pass class InvalidTrustType(Exception): + pass + +class InvalidAnchor(Exception): pass \ No newline at end of file From 5a74ea0a4091d74e13d3975b140767d8b65ab94f Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 27 Nov 2023 16:21:41 +0100 Subject: [PATCH 11/88] chore: fix typo --- pyeudiw/trust/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 343b6fff..479a6117 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -146,7 +146,7 @@ def _handle_x509_pem(self): if pem == None: raise MissingTrustType( - f"Trust Anchor: '{trust_anchor_eid}' has no x509 trusst entity" + f"Trust Anchor: '{trust_anchor_eid}' has no x509 trust entity" ) try: From daeb34368653daf11ccafadd3cfd753ceec4d019 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 29 Nov 2023 15:36:56 +0100 Subject: [PATCH 12/88] fix: merged federation and metadata policy implementation --- pyeudiw/federation/exceptions.py | 5 +- pyeudiw/federation/policy.py | 411 ++++++++++++--------- pyeudiw/federation/trust_chain_builder.py | 4 +- pyeudiw/metadata/__init__.py | 0 pyeudiw/metadata/exceptions.py | 6 - pyeudiw/metadata/policy.py | 413 ---------------------- 6 files changed, 250 insertions(+), 589 deletions(-) delete mode 100644 pyeudiw/metadata/__init__.py delete mode 100644 pyeudiw/metadata/exceptions.py delete mode 100755 pyeudiw/metadata/policy.py diff --git a/pyeudiw/federation/exceptions.py b/pyeudiw/federation/exceptions.py index 87f6dba7..4735e77f 100644 --- a/pyeudiw/federation/exceptions.py +++ b/pyeudiw/federation/exceptions.py @@ -77,4 +77,7 @@ class InvalidEntityHeader(Exception): pass class InvalidEntityStatementPayload(Exception): - pass \ No newline at end of file + pass + +class PolicyError(Exception): + pass diff --git a/pyeudiw/federation/policy.py b/pyeudiw/federation/policy.py index 4b2a2031..cdfc57fa 100644 --- a/pyeudiw/federation/policy.py +++ b/pyeudiw/federation/policy.py @@ -1,29 +1,27 @@ +import logging +from typing import Optional +from .exceptions import PolicyError + __author__ = "Roland Hedberg" __license__ = "Apache 2.0" __version__ = "" -import logging - logger = logging.getLogger(__name__) -class PolicyError(Exception): - pass - - -def combine_subset_of(s1, s2): # pragma: no cover +def combine_subset_of(s1, s2): return list(set(s1).intersection(set(s2))) -def combine_superset_of(s1, s2): # pragma: no cover +def combine_superset_of(s1, s2): return list(set(s1).intersection(set(s2))) -def combine_one_of(s1, s2): # pragma: no cover +def combine_one_of(s1, s2): return list(set(s1).intersection(set(s2))) -def combine_add(s1, s2): # pragma: no cover +def combine_add(s1, s2): if isinstance(s1, list): set1 = set(s1) else: @@ -35,15 +33,7 @@ def combine_add(s1, s2): # pragma: no cover return list(set1.union(set2)) -POLICY_FUNCTIONS = { - "subset_of", - "superset_of", - "one_of", - "add", - "value", - "default", - "essential", -} +POLICY_FUNCTIONS = {"subset_of", "superset_of", "one_of", "add", "value", "default", "essential"} OP2FUNC = { "subset_of": combine_subset_of, @@ -53,7 +43,7 @@ def combine_add(s1, s2): # pragma: no cover } -def do_sub_one_super_add(superior, child, policy): # pragma: no cover +def do_sub_one_super_add(superior, child, policy): if policy in superior and policy in child: comb = OP2FUNC[policy](superior[policy], child[policy]) if comb: @@ -66,7 +56,7 @@ def do_sub_one_super_add(superior, child, policy): # pragma: no cover return child[policy] -def do_value(superior, child, policy): # pragma: no cover +def do_value(superior, child, policy): if policy in superior and policy in child: if superior[policy] == child[policy]: return superior[policy] @@ -78,11 +68,11 @@ def do_value(superior, child, policy): # pragma: no cover return child[policy] -def do_default(superior, child, policy): # pragma: no cover +def do_default(superior, child, policy): # A child's default can not override a superiors if policy in superior and policy in child: - if superior["default"] == child["default"]: - return superior["default"] + if superior['default'] == child['default']: + return superior['default'] else: raise PolicyError("Not allowed to change default") elif policy in superior: @@ -91,12 +81,12 @@ def do_default(superior, child, policy): # pragma: no cover return child[policy] -def do_essential(superior, child, policy): # pragma: no cover - # essential: an child can make it True if a superior has states False +def do_essential(superior, child, policy): + # essential: a child can make it True if a superior has states False # but not the other way around if policy in superior and policy in child: - if not superior[policy] and child["essential"]: + if not superior[policy] and child['essential']: return True else: return superior[policy] @@ -113,21 +103,23 @@ def do_essential(superior, child, policy): # pragma: no cover "add": do_sub_one_super_add, "value": do_value, "default": do_default, - "essential": do_essential, + "essential": do_essential } -def combine_claim_policy(superior, child): # pragma: no cover +def combine_claim_policy(superior, child): """ Combine policy rules. Applying the child policy can only make the combined policy more restrictive. + :param superior: Superior policy :param child: Intermediates policy """ - # weed out everything I don't recognize + # weed out everything I don't recognize superior_set = set(superior).intersection(POLICY_FUNCTIONS) child_set = set(child).intersection(POLICY_FUNCTIONS) + if "value" in superior_set: # An exact value can not be restricted. if child_set: if "essential" in child_set: @@ -135,87 +127,105 @@ def combine_claim_policy(superior, child): # pragma: no cover return {"value": superior["value"], "essential": child["essential"]} else: raise PolicyError( - "value can only be combined with essential, not {}".format( - child_set - ) - ) + f"value can only be combined with essential, not {child_set}") elif "value" in child_set: if child["value"] != superior["value"]: # Not OK - raise PolicyError( - "Child can not set another value then superior") + raise PolicyError("Child can not set another value then superior") else: return superior else: raise PolicyError( - "Not allowed combination of policies: {} + {}".format( - superior, child - ) - ) + f"Not allowed combination of policies: {superior} + {child}") return superior else: if "essential" in superior_set and "essential" in child_set: # can only go from False to True - if ( - superior["essential"] != child["essential"] - and child["essential"] is False - ): + if superior["essential"] != child["essential"] and child["essential"] is False: raise PolicyError("Essential can not go from True to False") comb_policy = superior_set.union(child_set) if "one_of" in comb_policy: if "subset_of" in comb_policy or "superset_of" in comb_policy: - raise PolicyError( - "one_of can not be combined with subset_of/superset_of" - ) + raise PolicyError("one_of can not be combined with subset_of/superset_of") rule = {} for policy in comb_policy: rule[policy] = DO_POLICY[policy](superior, child, policy) - if comb_policy == {"superset_of", "subset_of"}: + if comb_policy == {'superset_of', 'subset_of'}: # make sure the subset_of is a superset of superset_of. - if set(rule["superset_of"]).difference(set(rule["subset_of"])): - raise PolicyError("superset_of not a super set of subset_of") - elif comb_policy == {"superset_of", "subset_of", "default"}: + if set(rule['superset_of']).difference(set(rule['subset_of'])): + raise PolicyError('superset_of not a super set of subset_of') + elif comb_policy == {'superset_of', 'subset_of', 'default'}: # make sure the subset_of is a superset of superset_of. - if set(rule["superset_of"]).difference(set(rule["subset_of"])): - raise PolicyError("superset_of not a super set of subset_of") - if set(rule["default"]).difference(set(rule["subset_of"])): - raise PolicyError("default not a sub set of subset_of") - if set(rule["superset_of"]).difference(set(rule["default"])): - raise PolicyError("default not a super set of subset_of") - elif comb_policy == {"subset_of", "default"}: - if set(rule["default"]).difference(set(rule["subset_of"])): - raise PolicyError("default not a sub set of subset_of") - elif comb_policy == {"superset_of", "default"}: - if set(rule["superset_of"]).difference(set(rule["default"])): - raise PolicyError("default not a super set of subset_of") - elif comb_policy == {"one_of", "default"}: - if isinstance(rule["default"], list): - if set(rule["default"]).difference(set(rule["one_of"])): - raise PolicyError("default not a super set of one_of") + if set(rule['superset_of']).difference(set(rule['subset_of'])): + raise PolicyError('superset_of not a super set of subset_of') + if set(rule['default']).difference(set(rule['subset_of'])): + raise PolicyError('default not a sub set of subset_of') + if set(rule['superset_of']).difference(set(rule['default'])): + raise PolicyError('default not a super set of subset_of') + elif comb_policy == {'subset_of', 'default'}: + if set(rule['default']).difference(set(rule['subset_of'])): + raise PolicyError('default not a sub set of subset_of') + elif comb_policy == {'superset_of', 'default'}: + if set(rule['superset_of']).difference(set(rule['default'])): + raise PolicyError('default not a super set of subset_of') + elif comb_policy == {'one_of', 'default'}: + if isinstance(rule['default'], list): + if set(rule['default']).difference(set(rule['one_of'])): + raise PolicyError('default not a super set of one_of') else: - if {rule["default"]}.difference(set(rule["one_of"])): - raise PolicyError("default not a super set of one_of") + if {rule['default']}.difference(set(rule['one_of'])): + raise PolicyError('default not a super set of one_of') return rule -def combine_policy(superior, child): - res = {} - sup_set = set(superior.keys()) - chi_set = set(child.keys()) +def combine(superior: dict, sub: dict) -> dict: + """ + + :param rule: Dictionary with two keys metadata_policy and metadata + :param sub: Dictionary with two keys metadata_policy and metadata + :return: + """ + sup_metadata = superior.get('metadata', {}) + sub_metadata = sub.get('metadata', {}) + sup_m_set = set(sup_metadata.keys()) + if sub_metadata: + chi_m_set = set(sub_metadata.keys()) + _overlap = chi_m_set.intersection(sup_m_set) + if _overlap: + for key in _overlap: + if sup_metadata[key] != sub_metadata[key]: + raise PolicyError( + 'A subordinate is not allowed to set a value different then the superiors') - for claim in set(sup_set).intersection(chi_set): - res[claim] = combine_claim_policy(superior[claim], child[claim]) + _metadata = sup_metadata.copy() + _metadata.update(sub_metadata) + superior['metadata'] = _metadata - for claim in sup_set.difference(chi_set): - res[claim] = superior[claim] + # Now for metadata_policies + _sup_policy = superior.get('metadata_policy', {}) + _sub_policy = sub.get('metadata_policy', {}) + if _sub_policy: + sup_set = set(_sup_policy.keys()) + chi_set = set(sub['metadata_policy'].keys()) - for claim in chi_set.difference(sup_set): - res[claim] = child[claim] + # A metadata_policy claim can not change a metadata claim + for claim in chi_set.intersection(sup_m_set): + combine_claim_policy({'value': sup_metadata[claim]}, _sub_policy[claim]) - return res + _mp = {} + for claim in set(sup_set).intersection(chi_set): + _mp[claim] = combine_claim_policy(_sup_policy[claim], _sub_policy[claim]) + + for claim in sup_set.difference(chi_set): + _mp[claim] = _sup_policy[claim] + + for claim in chi_set.difference(sup_set): + _mp[claim] = _sub_policy[claim] + superior['metadata_policy'] = _mp + return superior def gather_policies(chain, entity_type): """ @@ -235,11 +245,10 @@ def gather_policies(chain, entity_type): except KeyError: pass else: - combined_policy = combine_policy(combined_policy, child) + combined_policy = combine(combined_policy, child) return combined_policy - def union(val1, val2): if isinstance(val1, list): base = set(val1) @@ -253,91 +262,159 @@ def union(val1, val2): return base.union(ext) -def apply_policy(metadata, policy): - """ - Apply a metadata policy to a metadata statement. - The order is value, add, default and then the checks subset_of/superset_of and one_of - :param metadata: A metadata statement - :param policy: A metadata policy - :return: A metadata statement that adheres to a metadata policy - """ - metadata_set = set(metadata.keys()) - policy_set = set(policy.keys()) - - # Metadata claims that there exists a policy for - for claim in metadata_set.intersection(policy_set): - if "value" in policy[claim]: # value overrides everything - metadata[claim] = policy[claim]["value"] - else: - if "one_of" in policy[claim]: - # The is for claims that can have only one value - if isinstance(metadata[claim], list): # Should not be but ... - _claim = [ - c for c in metadata[claim] if c in policy[claim]["one_of"] - ] - if _claim: - metadata[claim] = _claim[0] - else: - raise PolicyError( - "{}: None of {} among {}".format( - claim, metadata[claim], policy[claim]["one_of"] - ) - ) - else: - if metadata[claim] in policy[claim]["one_of"]: - pass - else: - raise PolicyError( - "{} not among {}".format( - metadata[claim], policy[claim]["one_of"] - ) - ) - else: - # The following is for claims that can have lists of values - if "add" in policy[claim]: - metadata[claim] = list( - union(metadata[claim], policy[claim]["add"])) - - if "subset_of" in policy[claim]: - _val = set(policy[claim]["subset_of"]).intersection( - set(metadata[claim]) - ) - if _val: - metadata[claim] = list(_val) - else: - raise PolicyError( - "{} not subset of {}".format( - metadata[claim], policy[claim]["subset_of"] - ) - ) - if "superset_of" in policy[claim]: - if set(policy[claim]["superset_of"]).difference( - set(metadata[claim]) - ): - raise PolicyError( - "{} not superset of {}".format( - metadata[claim], policy[claim]["superset_of"] - ) - ) - else: - pass +class TrustChainPolicy(object): + def gather_policies(self, chain, entity_type): + """ + Gather and combine all the metadata policies that are defined in the trust chain + :param chain: A list of Entity Statements + :return: The combined metadata policy + """ + + _rule = {'metadata_policy': {}, 'metadata': {}} + for _item in ['metadata_policy', 'metadata']: + try: + _rule[_item] = chain[0][_item][entity_type] + except KeyError: + pass + + for es in chain[1:]: + _sub_policy = {'metadata_policy': {}, 'metadata': {}} + for _item in ['metadata_policy', 'metadata']: + try: + _sub_policy[_item] = es[_item][entity_type] + except KeyError: + pass + + if _sub_policy == {'metadata_policy': {}, 'metadata': {}}: + continue + + _overlap = set(_sub_policy['metadata_policy']).intersection( + set(_sub_policy['metadata'])) + if _overlap: # Not allowed + raise PolicyError( + 'Claim appearing both in metadata and metadata_policy not allowed') + _rule = combine(_rule, _sub_policy) - # In policy but not in metadata - for claim in policy_set.difference(metadata_set): - if "value" in policy[claim]: - metadata[claim] = policy[claim]["value"] - elif "add" in policy[claim]: - metadata[claim] = policy[claim]["add"] - elif "default" in policy[claim]: - metadata[claim] = policy[claim]["default"] + return _rule - if claim not in metadata: - if "essential" in policy[claim] and policy[claim]["essential"]: - raise PolicyError("Essential claim '{}' missing".format(claim)) + def _apply_metadata_policy(self, metadata, metadata_policy): + """ + Apply a metadata policy to a metadata statement. + The order is value, add, default and then check subset_of/superset_of and one_of + """ - # All that are in metadata but not in policy should just remain + policy_set = set(metadata_policy.keys()) + metadata_set = set(metadata.keys()) - return metadata + # Metadata claims that there exists a policy for + for claim in metadata_set.intersection(policy_set): + if "value" in metadata_policy[claim]: # value overrides everything + metadata[claim] = metadata_policy[claim]["value"] + else: + if "one_of" in metadata_policy[claim]: + # The is for claims that can have only one value + if isinstance(metadata[claim], list): # Should not be but ... + _claim = [c for c in metadata[claim] if + c in metadata_policy[claim]['one_of']] + if _claim: + metadata[claim] = _claim[0] + else: + raise PolicyError( + "{}: None of {} among {}".format(claim, metadata[claim], + metadata_policy[claim]['one_of'])) + else: + if metadata[claim] in metadata_policy[claim]['one_of']: + pass + else: + raise PolicyError( + f"{metadata[claim]} not among {metadata_policy[claim]['one_of']}") + else: + # The following is for claims that can have lists of values + if "add" in metadata_policy[claim]: + metadata[claim] = list( + union(metadata[claim], metadata_policy[claim]['add'])) + + if "subset_of" in metadata_policy[claim]: + _val = set(metadata_policy[claim]['subset_of']).intersection( + set(metadata[claim])) + if _val: + metadata[claim] = list(_val) + else: + raise PolicyError("{} not subset of {}".format(metadata[claim], + metadata_policy[claim][ + 'subset_of'])) + if "superset_of" in metadata_policy[claim]: + if set(metadata_policy[claim]['superset_of']).difference( + set(metadata[claim])): + raise PolicyError("{} not superset of {}".format(metadata[claim], + metadata_policy[claim][ + 'superset_of'])) + else: + pass + + # In policy but not in metadata + for claim in policy_set.difference(metadata_set): + if "value" in metadata_policy[claim]: + metadata[claim] = metadata_policy[claim]['value'] + elif "add" in metadata_policy[claim]: + metadata[claim] = metadata_policy[claim]['add'] + elif "default" in metadata_policy[claim]: + metadata[claim] = metadata_policy[claim]['default'] + + if claim not in metadata: + if "essential" in metadata_policy[claim] and metadata_policy[claim]["essential"]: + raise PolicyError(f"Essential claim '{claim}' missing") + + return metadata + + def apply_policy(self, metadata: dict, policy: dict) -> dict: + """ + Apply a metadata policy on metadata. + + :param metadata: Metadata statements + :param policy: A dictionary with metadata and metadata_policy as keys + :return: A metadata statement that adheres to a metadata policy + """ + + if policy['metadata_policy']: + metadata = self._apply_metadata_policy(metadata, policy['metadata_policy']) + + # All that are in metadata but not in policy should just remain + metadata.update(policy['metadata']) + + return metadata + + def _policy(self, trust_chain, entity_type: str): + + + combined_policy = self.gather_policies(trust_chain[:-1], entity_type) + logger.debug("Combined policy: %s", combined_policy) + try: + # This should be the entity configuration + metadata = trust_chain.verified_chain[-1]['metadata'][entity_type] + except KeyError: + return None + else: + # apply the combined metadata policies on the metadata + trust_chain.set_combined_policy(entity_type, combined_policy) + _metadata = self.apply_policy(metadata, combined_policy) + logger.debug(f"After applied policy: {_metadata}") + return _metadata + + def __call__(self, trust_chain, entity_type: Optional[str] = ''): + """ + :param trust_chain: TrustChain instance + :param entity_type: Which Entity Type the entity are + """ + if len(trust_chain.verified_chain) > 1: + if entity_type: + trust_chain.metadata[entity_type] = self._policy(trust_chain, entity_type) + else: + for _type in trust_chain.verified_chain[-1]['metadata'].keys(): + trust_chain.metadata[_type] = self._policy(trust_chain, _type) + else: + trust_chain.metadata = trust_chain.verified_chain[0]["metadata"][entity_type] + trust_chain.combined_policy[entity_type] = {} def diff2policy(new, old): @@ -346,12 +423,12 @@ def diff2policy(new, old): if new[claim] == old[claim]: continue else: - res[claim] = {"value": new[claim]} + res[claim] = {'value': new[claim]} for claim in set(new).difference(set(old)): - if claim in ["contacts"]: - res[claim] = {"add": new[claim]} + if claim in ['contacts']: + res[claim] = {'add': new[claim]} else: - res[claim] = {"value": new[claim]} + res[claim] = {'value': new[claim]} return res diff --git a/pyeudiw/federation/trust_chain_builder.py b/pyeudiw/federation/trust_chain_builder.py index 5c338d9d..ad9b4319 100644 --- a/pyeudiw/federation/trust_chain_builder.py +++ b/pyeudiw/federation/trust_chain_builder.py @@ -5,7 +5,7 @@ from collections import OrderedDict from typing import Union -from pyeudiw.federation.policy import apply_policy +from .policy import TrustChainPolicy from .exceptions import ( InvalidEntityStatement, @@ -150,7 +150,7 @@ def apply_metadata_policy(self) -> dict: for md_type, md in _pol.items(): if not self.final_metadata.get(md_type): continue - self.final_metadata[md_type] = apply_policy( + self.final_metadata[md_type] = TrustChainPolicy().apply_policy( self.final_metadata[md_type], _pol[md_type] ) diff --git a/pyeudiw/metadata/__init__.py b/pyeudiw/metadata/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pyeudiw/metadata/exceptions.py b/pyeudiw/metadata/exceptions.py deleted file mode 100644 index 239c5808..00000000 --- a/pyeudiw/metadata/exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ -__author__ = "Roland Hedberg" -__license__ = "Apache 2.0" -__version__ = "" - -class PolicyError(Exception): - pass \ No newline at end of file diff --git a/pyeudiw/metadata/policy.py b/pyeudiw/metadata/policy.py deleted file mode 100755 index 6eb4c9ab..00000000 --- a/pyeudiw/metadata/policy.py +++ /dev/null @@ -1,413 +0,0 @@ -import logging -from typing import Optional -from pyeudiw.trust.trust_chain import TrustChain -from pyeudiw.metadata.exceptions import PolicyError - -__author__ = "Roland Hedberg" -__license__ = "Apache 2.0" -__version__ = "" - -logger = logging.getLogger(__name__) - - -def combine_subset_of(s1, s2): - return list(set(s1).intersection(set(s2))) - - -def combine_superset_of(s1, s2): - return list(set(s1).intersection(set(s2))) - - -def combine_one_of(s1, s2): - return list(set(s1).intersection(set(s2))) - - -def combine_add(s1, s2): - if isinstance(s1, list): - set1 = set(s1) - else: - set1 = {s1} - if isinstance(s2, list): - set2 = set(s2) - else: - set2 = {s2} - return list(set1.union(set2)) - - -POLICY_FUNCTIONS = {"subset_of", "superset_of", "one_of", "add", "value", "default", "essential"} - -OP2FUNC = { - "subset_of": combine_subset_of, - "superset_of": combine_superset_of, - "one_of": combine_one_of, - "add": combine_add, -} - - -def do_sub_one_super_add(superior, child, policy): - if policy in superior and policy in child: - comb = OP2FUNC[policy](superior[policy], child[policy]) - if comb: - return comb - else: - raise PolicyError("Value sets doesn't overlap") - elif policy in superior: - return superior[policy] - elif policy in child: - return child[policy] - - -def do_value(superior, child, policy): - if policy in superior and policy in child: - if superior[policy] == child[policy]: - return superior[policy] - else: - raise PolicyError("Not allowed to combine values") - elif policy in superior: - return superior[policy] - elif policy in child: - return child[policy] - - -def do_default(superior, child, policy): - # A child's default can not override a superiors - if policy in superior and policy in child: - if superior['default'] == child['default']: - return superior['default'] - else: - raise PolicyError("Not allowed to change default") - elif policy in superior: - return superior[policy] - elif policy in child: - return child[policy] - - -def do_essential(superior, child, policy): - # essential: a child can make it True if a superior has states False - # but not the other way around - - if policy in superior and policy in child: - if not superior[policy] and child['essential']: - return True - else: - return superior[policy] - elif policy in superior: - return superior[policy] - elif policy in child: # Not in superior is the same as essential=True - return True - - -DO_POLICY = { - "superset_of": do_sub_one_super_add, - "subset_of": do_sub_one_super_add, - "one_of": do_sub_one_super_add, - "add": do_sub_one_super_add, - "value": do_value, - "default": do_default, - "essential": do_essential -} - - -def combine_claim_policy(superior, child): - """ - Combine policy rules. - Applying the child policy can only make the combined policy more restrictive. - - :param superior: Superior policy - :param child: Intermediates policy - """ - - # weed out everything I don't recognize - superior_set = set(superior).intersection(POLICY_FUNCTIONS) - child_set = set(child).intersection(POLICY_FUNCTIONS) - - if "value" in superior_set: # An exact value can not be restricted. - if child_set: - if "essential" in child_set: - if len(child_set) == 1: - return {"value": superior["value"], "essential": child["essential"]} - else: - raise PolicyError( - f"value can only be combined with essential, not {child_set}") - elif "value" in child_set: - if child["value"] != superior["value"]: # Not OK - raise PolicyError("Child can not set another value then superior") - else: - return superior - else: - raise PolicyError( - f"Not allowed combination of policies: {superior} + {child}") - return superior - else: - if "essential" in superior_set and "essential" in child_set: - # can only go from False to True - if superior["essential"] != child["essential"] and child["essential"] is False: - raise PolicyError("Essential can not go from True to False") - - comb_policy = superior_set.union(child_set) - if "one_of" in comb_policy: - if "subset_of" in comb_policy or "superset_of" in comb_policy: - raise PolicyError("one_of can not be combined with subset_of/superset_of") - - rule = {} - for policy in comb_policy: - rule[policy] = DO_POLICY[policy](superior, child, policy) - - if comb_policy == {'superset_of', 'subset_of'}: - # make sure the subset_of is a superset of superset_of. - if set(rule['superset_of']).difference(set(rule['subset_of'])): - raise PolicyError('superset_of not a super set of subset_of') - elif comb_policy == {'superset_of', 'subset_of', 'default'}: - # make sure the subset_of is a superset of superset_of. - if set(rule['superset_of']).difference(set(rule['subset_of'])): - raise PolicyError('superset_of not a super set of subset_of') - if set(rule['default']).difference(set(rule['subset_of'])): - raise PolicyError('default not a sub set of subset_of') - if set(rule['superset_of']).difference(set(rule['default'])): - raise PolicyError('default not a super set of subset_of') - elif comb_policy == {'subset_of', 'default'}: - if set(rule['default']).difference(set(rule['subset_of'])): - raise PolicyError('default not a sub set of subset_of') - elif comb_policy == {'superset_of', 'default'}: - if set(rule['superset_of']).difference(set(rule['default'])): - raise PolicyError('default not a super set of subset_of') - elif comb_policy == {'one_of', 'default'}: - if isinstance(rule['default'], list): - if set(rule['default']).difference(set(rule['one_of'])): - raise PolicyError('default not a super set of one_of') - else: - if {rule['default']}.difference(set(rule['one_of'])): - raise PolicyError('default not a super set of one_of') - return rule - - -def combine(superior: dict, sub: dict) -> dict: - """ - - :param rule: Dictionary with two keys metadata_policy and metadata - :param sub: Dictionary with two keys metadata_policy and metadata - :return: - """ - sup_metadata = superior.get('metadata', {}) - sub_metadata = sub.get('metadata', {}) - sup_m_set = set(sup_metadata.keys()) - if sub_metadata: - chi_m_set = set(sub_metadata.keys()) - _overlap = chi_m_set.intersection(sup_m_set) - if _overlap: - for key in _overlap: - if sup_metadata[key] != sub_metadata[key]: - raise PolicyError( - 'A subordinate is not allowed to set a value different then the superiors') - - _metadata = sup_metadata.copy() - _metadata.update(sub_metadata) - superior['metadata'] = _metadata - - # Now for metadata_policies - _sup_policy = superior.get('metadata_policy', {}) - _sub_policy = sub.get('metadata_policy', {}) - if _sub_policy: - sup_set = set(_sup_policy.keys()) - chi_set = set(sub['metadata_policy'].keys()) - - # A metadata_policy claim can not change a metadata claim - for claim in chi_set.intersection(sup_m_set): - combine_claim_policy({'value': sup_metadata[claim]}, _sub_policy[claim]) - - _mp = {} - for claim in set(sup_set).intersection(chi_set): - _mp[claim] = combine_claim_policy(_sup_policy[claim], _sub_policy[claim]) - - for claim in sup_set.difference(chi_set): - _mp[claim] = _sup_policy[claim] - - for claim in chi_set.difference(sup_set): - _mp[claim] = _sub_policy[claim] - - superior['metadata_policy'] = _mp - return superior - -def union(val1, val2): - if isinstance(val1, list): - base = set(val1) - else: - base = {val1} - - if isinstance(val2, list): - ext = set(val2) - else: - ext = {val2} - return base.union(ext) - - -class TrustChainPolicy(object): - def gather_policies(self, chain, entity_type): - """ - Gather and combine all the metadata policies that are defined in the trust chain - :param chain: A list of Entity Statements - :return: The combined metadata policy - """ - - _rule = {'metadata_policy': {}, 'metadata': {}} - for _item in ['metadata_policy', 'metadata']: - try: - _rule[_item] = chain[0][_item][entity_type] - except KeyError: - pass - - for es in chain[1:]: - _sub_policy = {'metadata_policy': {}, 'metadata': {}} - for _item in ['metadata_policy', 'metadata']: - try: - _sub_policy[_item] = es[_item][entity_type] - except KeyError: - pass - - if _sub_policy == {'metadata_policy': {}, 'metadata': {}}: - continue - - _overlap = set(_sub_policy['metadata_policy']).intersection( - set(_sub_policy['metadata'])) - if _overlap: # Not allowed - raise PolicyError( - 'Claim appearing both in metadata and metadata_policy not allowed') - _rule = combine(_rule, _sub_policy) - - return _rule - - def _apply_metadata_policy(self, metadata, metadata_policy): - """ - Apply a metadata policy to a metadata statement. - The order is value, add, default and then check subset_of/superset_of and one_of - """ - - policy_set = set(metadata_policy.keys()) - metadata_set = set(metadata.keys()) - - # Metadata claims that there exists a policy for - for claim in metadata_set.intersection(policy_set): - if "value" in metadata_policy[claim]: # value overrides everything - metadata[claim] = metadata_policy[claim]["value"] - else: - if "one_of" in metadata_policy[claim]: - # The is for claims that can have only one value - if isinstance(metadata[claim], list): # Should not be but ... - _claim = [c for c in metadata[claim] if - c in metadata_policy[claim]['one_of']] - if _claim: - metadata[claim] = _claim[0] - else: - raise PolicyError( - "{}: None of {} among {}".format(claim, metadata[claim], - metadata_policy[claim]['one_of'])) - else: - if metadata[claim] in metadata_policy[claim]['one_of']: - pass - else: - raise PolicyError( - f"{metadata[claim]} not among {metadata_policy[claim]['one_of']}") - else: - # The following is for claims that can have lists of values - if "add" in metadata_policy[claim]: - metadata[claim] = list( - union(metadata[claim], metadata_policy[claim]['add'])) - - if "subset_of" in metadata_policy[claim]: - _val = set(metadata_policy[claim]['subset_of']).intersection( - set(metadata[claim])) - if _val: - metadata[claim] = list(_val) - else: - raise PolicyError("{} not subset of {}".format(metadata[claim], - metadata_policy[claim][ - 'subset_of'])) - if "superset_of" in metadata_policy[claim]: - if set(metadata_policy[claim]['superset_of']).difference( - set(metadata[claim])): - raise PolicyError("{} not superset of {}".format(metadata[claim], - metadata_policy[claim][ - 'superset_of'])) - else: - pass - - # In policy but not in metadata - for claim in policy_set.difference(metadata_set): - if "value" in metadata_policy[claim]: - metadata[claim] = metadata_policy[claim]['value'] - elif "add" in metadata_policy[claim]: - metadata[claim] = metadata_policy[claim]['add'] - elif "default" in metadata_policy[claim]: - metadata[claim] = metadata_policy[claim]['default'] - - if claim not in metadata: - if "essential" in metadata_policy[claim] and metadata_policy[claim]["essential"]: - raise PolicyError(f"Essential claim '{claim}' missing") - - return metadata - - def apply_policy(self, metadata: dict, policy: dict) -> dict: - """ - Apply a metadata policy on metadata. - - :param metadata: Metadata statements - :param policy: A dictionary with metadata and metadata_policy as keys - :return: A metadata statement that adheres to a metadata policy - """ - - if policy['metadata_policy']: - metadata = self._apply_metadata_policy(metadata, policy['metadata_policy']) - - # All that are in metadata but not in policy should just remain - metadata.update(policy['metadata']) - - return metadata - - def _policy(self, trust_chain: TrustChain, entity_type: str): - - - combined_policy = self.gather_policies(trust_chain[:-1], entity_type) - logger.debug("Combined policy: %s", combined_policy) - try: - # This should be the entity configuration - metadata = trust_chain.verified_chain[-1]['metadata'][entity_type] - except KeyError: - return None - else: - # apply the combined metadata policies on the metadata - trust_chain.set_combined_policy(entity_type, combined_policy) - _metadata = self.apply_policy(metadata, combined_policy) - logger.debug(f"After applied policy: {_metadata}") - return _metadata - - def __call__(self, trust_chain: TrustChain, entity_type: Optional[str] = ''): - """ - :param trust_chain: TrustChain instance - :param entity_type: Which Entity Type the entity are - """ - if len(trust_chain.verified_chain) > 1: - if entity_type: - trust_chain.metadata[entity_type] = self._policy(trust_chain, entity_type) - else: - for _type in trust_chain.verified_chain[-1]['metadata'].keys(): - trust_chain.metadata[_type] = self._policy(trust_chain, _type) - else: - trust_chain.metadata = trust_chain.verified_chain[0]["metadata"][entity_type] - trust_chain.combined_policy[entity_type] = {} - - -def diff2policy(new, old): - res = {} - for claim in set(new).intersection(set(old)): - if new[claim] == old[claim]: - continue - else: - res[claim] = {'value': new[claim]} - - for claim in set(new).difference(set(old)): - if claim in ['contacts']: - res[claim] = {'add': new[claim]} - else: - res[claim] = {'value': new[claim]} - - return res \ No newline at end of file From f94b0636b35a51642b16e65340ec79600a8828da Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 29 Nov 2023 15:37:37 +0100 Subject: [PATCH 13/88] test: adapted tests --- pyeudiw/tests/{metadata => federation}/test_metadata.py | 8 ++++---- pyeudiw/tests/federation/test_policy.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) rename pyeudiw/tests/{metadata => federation}/test_metadata.py (98%) diff --git a/pyeudiw/tests/metadata/test_metadata.py b/pyeudiw/tests/federation/test_metadata.py similarity index 98% rename from pyeudiw/tests/metadata/test_metadata.py rename to pyeudiw/tests/federation/test_metadata.py index 0c499ca2..9a4b3180 100644 --- a/pyeudiw/tests/metadata/test_metadata.py +++ b/pyeudiw/tests/federation/test_metadata.py @@ -1,8 +1,8 @@ import pytest -from pyeudiw.metadata.policy import combine -from pyeudiw.metadata.policy import combine_claim_policy -from pyeudiw.metadata.policy import TrustChainPolicy -from pyeudiw.metadata.exceptions import PolicyError +from pyeudiw.federation.policy import combine +from pyeudiw.federation.policy import combine_claim_policy +from pyeudiw.federation.policy import TrustChainPolicy +from pyeudiw.federation.exceptions import PolicyError __author__ = "Roland Hedberg" __license__ = "Apache 2.0" diff --git a/pyeudiw/tests/federation/test_policy.py b/pyeudiw/tests/federation/test_policy.py index 6a235081..fc7c6f1f 100644 --- a/pyeudiw/tests/federation/test_policy.py +++ b/pyeudiw/tests/federation/test_policy.py @@ -1,8 +1,10 @@ from pyeudiw.federation.policy import ( - do_sub_one_super_add, PolicyError, do_value + do_sub_one_super_add, do_value ) +from pyeudiw.federation.exceptions import PolicyError + def test_do_sub_one_super_add_subset_of(): SUPERIOR = { From 67a4d12fb4d4d38f76baced855c65c4218d3257a Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 29 Nov 2023 15:38:47 +0100 Subject: [PATCH 14/88] feat: added final_metadata property --- pyeudiw/federation/trust_chain_validator.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index 69597371..eef834fc 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -218,5 +218,15 @@ def entity_id(self) -> str: chain = self.trust_chain payload = unpad_jwt_payload(chain[0]) return payload["iss"] + + @property + def final_metadata(self) -> dict: + anchor = self.trust_anchor_jwks[-1] + es_anchor_payload = unpad_jwt_payload(anchor) + + policy = es_anchor_payload.get("metadata_policy", {}) + + leaf = self.trust_anchor_jwks[0] + es_leaf_payload = unpad_jwt_payload(leaf) - # TODO - apply metadata policy and get the final metadata + #return TrustChainPolicy().apply_policy(es_leaf_payload["metadata"], policy) From 059e94bf00575397fd679c0cd0fa01885edf8a6d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 29 Nov 2023 15:39:55 +0100 Subject: [PATCH 15/88] feat: added chain discovery plus refactoring --- pyeudiw/trust/__init__.py | 70 +++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 479a6117..ef8cd089 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -1,3 +1,4 @@ +import logging from datetime import datetime from pyeudiw.federation.trust_chain_builder import TrustChainBuilder @@ -17,7 +18,11 @@ InvalidAnchor ) -import pyeudiw.metadata.policy as pcl +from pyeudiw.federation.statements import EntityStatement +from pyeudiw.federation.exceptions import TimeValidationError +from pyeudiw.federation.policy import TrustChainPolicy, combine + +logger = logging.getLogger(__name__) class TrustEvaluationHelper: def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): @@ -57,6 +62,16 @@ def _get_evaluation_method(self): def evaluation_method(self) -> bool: ev_method = self._get_evaluation_method() return ev_method() + + def _update_chain(self, entity_id: str | None = None, exp: datetime | None = None, trust_chain: list | None = None): + if entity_id != None: + self.entity_id = entity_id + + if exp != None: + self.exp = exp + + if trust_chain != None: + self.trust_chain = trust_chain def _handle_federation_chain(self): _first_statement = unpad_jwt_payload(self.trust_chain[-1]) @@ -90,15 +105,22 @@ def _handle_federation_chain(self): tc = StaticTrustChainValidator( self.trust_chain, jwks, self.httpc_params ) - self.entity_id = tc.entity_id - self.exp = tc.exp + self._update_chain( + entity_id=tc.entity_id, + exp=tc.exp + ) + _is_valid = False + try: _is_valid = tc.validate() - except Exception: - # raise / log here that's expired - pass # nosec - B110 + except TimeValidationError: + logger.warn(f"Trust Chain {tc.entity_id} is expired") + except Exception as e: + logger.warn(f"Cannot validate Trust Chain {tc.entity_id} for the following reason: {e}") + db_chain = None + if not _is_valid: try: db_chain = self.storage.get_trust_attestation( @@ -110,9 +132,13 @@ def _handle_federation_chain(self): except (EntryNotFound, Exception): pass + _is_valid = tc.update() - self.exp = tc.exp - self.trust_chain = tc.trust_chain + + self._update_chain( + trust_chain=tc.trust_chain, + exp=tc.exp + ) # the good trust chain is then stored self.storage.add_or_update_trust_attestation( @@ -164,21 +190,14 @@ def _handle_x509_pem(self): return _is_valid def federation(self) -> bool: + if len(self.trust_chain) == 0: + self.discovery(self.entity_id) + if self.trust_chain: self.is_valid = self._handle_federation_chain() return self.is_valid - # TODO - at least a TA entity id is required for a discovery process - # _tc = TrustChainBuilder( - # subject= self.entity_id, - # trust_anchor=trust_anchor_ec, - # trust_anchor_configuration=trust_anchor_ec - # ) - # if _tc.is_valid: - # self.trust_chain = _tc.serialize() - # return self.trust_chain - - return [] + return False def x509(self) -> bool: self.is_valid = self._handle_x509_pem() @@ -188,7 +207,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: policy_acc = {"metadata": {}, "metadata_policy": {}} for policy in policies: - policy_acc = pcl.combine(policy, policy_acc) + policy_acc = combine(policy, policy_acc) self.final_metadata = unpad_jwt_payload(self.trust_chain[0]) @@ -199,7 +218,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: "metadata_policy": {} } - self.final_metadata = pcl.TrustChainPolicy().apply_policy( + self.final_metadata = TrustChainPolicy().apply_policy( selected_metadata, policy_acc ) @@ -217,7 +236,7 @@ def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> lis policies=policies ).get('jwks', {}).get('keys', []) - def discovery(self, entity_id, entity_configuration): + def discovery(self, entity_id: str, entity_configuration: EntityStatement | None = None): """ Updates fields ``trust_chain`` and ``exp`` based on the discovery process. @@ -234,8 +253,11 @@ def discovery(self, entity_id, entity_configuration): subject_configuration=entity_configuration, httpc_params=self.httpc_params ) - self.trust_chain = tcbuilder.get_trust_chain() - self.exp = tcbuilder.exp + + self._update_chain( + trust_chain=tcbuilder.get_trust_chain(), + exp=tcbuilder.exp + ) is_good = tcbuilder.is_valid if not is_good: raise DiscoveryFailedError( From b807998a6c69e6e35b48ec7a1c64263af2cd136f Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 4 Dec 2023 17:42:11 +0100 Subject: [PATCH 16/88] docs: documented file class and functions --- pyeudiw/federation/trust_chain_validator.py | 149 ++++++++++++++++---- 1 file changed, 125 insertions(+), 24 deletions(-) diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index eef834fc..9c00c53a 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -3,6 +3,7 @@ from pyeudiw.jwt import JWSHelper from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header from pyeudiw.federation import is_es +from pyeudiw.federation.policy import TrustChainPolicy from pyeudiw.federation.statements import ( get_entity_configurations, get_entity_statements @@ -17,7 +18,18 @@ logger = logging.getLogger(__name__) -def find_jwk(kid: str, jwks: list) -> dict: +def find_jwk(kid: str, jwks: list[dict]) -> dict: + """ + Find the JWK with the indicated kid in the jwks list. + + :param kid: the identifier of the jwk + :type kid: str + :param jwks: the list of jwks + :type jwks: list[dict] + + :returns: the jwk with the indicated kid or an empty dict if no jwk is found + :rtype: dict + """ if not kid: return {} for jwk in jwks: @@ -27,13 +39,25 @@ def find_jwk(kid: str, jwks: list) -> dict: class StaticTrustChainValidator: + """Helper class for Static Trust Chain validation""" def __init__( self, - static_trust_chain: list, - trust_anchor_jwks: list, + static_trust_chain: list[str], + trust_anchor_jwks: list[dict], httpc_params: dict, **kwargs, ) -> None: + + """ + Generates a new StaticTrustChainValidator istance + + :param static_trust_chain: the list of JWTs, containing the EC, componing the static tust chain + :type static_trust_chain: list[str] + :param trust_anchor_jwks: the list of trust anchor jwks + :type trust_anchor_jwks: list[dict] + :param httpc_params: parameters to perform http requests + :type httpc_params: dict + """ self.static_trust_chain = static_trust_chain self.updated_trust_chain = [] @@ -51,9 +75,45 @@ def __init__( setattr(self, k, v) def _check_expired(self, exp: int) -> bool: + """ + Checks if exp value is expired. + + :param exp: an integer that represent the timestemp to check + :type exp: int + :returns: True if exp is expired and False otherwise + :rtype: bool + """ + return exp < iat_now() + + def _validate_exp(self, exp: int) -> None: + """ + Checks if exp value is expired. + + :param exp: an integer that represent the timestemp to check + :type exp: int + + :raises TimeValidationError: if exp value is expired + """ + + if not self._check_expired(exp): + raise TimeValidationError( + "Expired validation error" + ) + + def _validate_keys(self, fed_jwks: list[dict], st_header: dict) -> None: + """ + Checks that the kid in st_header match with one JWK present + in the federation JWKs list. + + :param fed_jwks: the list of federation's JWKs + :type fed_jwks: list[dict] + :param st_header: the statement header + :type st_header: dict + + :raises KeyValidationError: if no JWK with the kid specified in feild st_header is found + """ - def _validate_keys(self, fed_jwks: list[str], st_header: dict) -> None: current_kid = st_header["kid"] validation_kid = None @@ -65,17 +125,14 @@ def _validate_keys(self, fed_jwks: list[str], st_header: dict) -> None: if not validation_kid: raise KeyValidationError(f"Kid {current_kid} not found") - def _validate_single(self, fed_jwks: list[str], header: dict, payload: dict) -> bool: - try: - self._validate_keys(fed_jwks, header) - self._validate_exp(payload["exp"]) - except Exception as e: - logger.warning(f"Warning: {e}") - return False - - return True - def validate(self) -> bool: + """ + Validates the static chain checking the validity in all jwt inside the field trust_chain. + + :returns: True if static chain is valid and False otherwise + :rtype: bool + """ + # start from the last entity statement rev_tc = [ i for i in reversed(self.trust_chain) @@ -104,9 +161,7 @@ def validate(self) -> bool: self.exp = es_payload["exp"] if self._check_expired(self.exp): - raise TimeValidationError( - "Expired validation error" - ) + return False fed_jwks = es_payload["jwks"]["keys"] @@ -133,11 +188,16 @@ def validate(self) -> bool: return True - @property - def is_valid(self) -> bool: - return self.validate() - def _retrieve_ec(self, iss: str) -> str: + """ + Retrieves the Entity configuration from an on-line source. + + :param iss: The issuer url where retrieve the entity configuration. + :type iss: str + + :returns: the entity configuration in form of JWT. + :rtype: str + """ jwt = get_entity_configurations(iss, self.httpc_params) if not jwt: raise HttpError( @@ -147,6 +207,17 @@ def _retrieve_ec(self, iss: str) -> str: return jwt[0] def _retrieve_es(self, download_url: str, iss: str) -> str: + """ + Retrieves the Entity Statement from an on-line source. + + :param download_url: The path where retrieve the entity configuration. + :type download_url: str + :param iss: The issuer url. + :type iss: str + + :returns: the entity statement in form of JWT. + :rtype: str + """ jwt = get_entity_statements(download_url, self.httpc_params) if not jwt: logger.warning( @@ -157,6 +228,15 @@ def _retrieve_es(self, download_url: str, iss: str) -> str: return jwt def _update_st(self, st: str) -> str: + """ + Updates the statement retrieving the new one using the source end_point and the sub fields of st payload. + + :param st: The statement in form of a JWT. + :type st: str + + :returns: the entity statement in form of JWT. + :rtype: str + """ payload = unpad_jwt_payload(st) iss = payload['iss'] if not is_es(payload): @@ -190,10 +270,22 @@ def _update_st(self, st: str) -> str: return jwt def set_exp(self, exp: int) -> None: + """ + Updates the self.exp field if the exp parameter is more recent than the previous one. + + :param exp: an integer that represent the timestemp to check + :type exp: int + """ if not self.exp or self.exp > exp: self.exp = exp def update(self) -> bool: + """ + Updates the statement retrieving and the exp filed and determines the validity of it. + + :returns: True if the updated chain is valid, False otherwise. + :rtype: bool + """ self.exp = 0 for st in self.static_trust_chain: jwt = self._update_st(st) @@ -204,29 +296,38 @@ def update(self) -> bool: self.updated_trust_chain.append(jwt) return self.is_valid + + @property + def is_valid(self) -> bool: + """Get the validity of chain.""" + return self.validate() @property def trust_chain(self) -> list[str]: + """Get the list of the jwt that compones the trust chain.""" return self.updated_trust_chain or self.static_trust_chain @property def is_expired(self) -> int: + """Get the status of chain expiration.""" return self._check_expired(self.exp) @property def entity_id(self) -> str: + """Get the chain's entity_id.""" chain = self.trust_chain payload = unpad_jwt_payload(chain[0]) return payload["iss"] @property def final_metadata(self) -> dict: - anchor = self.trust_anchor_jwks[-1] + """Apply the metadata and returns the final metadata.""" + anchor = self.static_trust_chain[-1] es_anchor_payload = unpad_jwt_payload(anchor) policy = es_anchor_payload.get("metadata_policy", {}) - leaf = self.trust_anchor_jwks[0] + leaf = self.static_trust_chain[0] es_leaf_payload = unpad_jwt_payload(leaf) - #return TrustChainPolicy().apply_policy(es_leaf_payload["metadata"], policy) + return TrustChainPolicy().apply_policy(es_leaf_payload["metadata"], policy) From 62edae68758f4af9645471d56acb6fdaa9c33d48 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 4 Dec 2023 18:24:36 +0100 Subject: [PATCH 17/88] fix: fixed trust_anchor_entity_conf handling --- pyeudiw/federation/statements.py | 7 ++++++- pyeudiw/federation/trust_chain_builder.py | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index c167475e..fa8d75d5 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from copy import deepcopy from pyeudiw.federation.exceptions import ( UnknownKid, @@ -155,7 +157,7 @@ def __init__( jwt: str, httpc_params: dict, filter_by_allowed_trust_marks: list = [], - trust_anchor_entity_conf=None, + trust_anchor_entity_conf: 'EntityStatement' | None = None, trust_mark_issuers_entity_confs: dict = [], ): self.jwt = jwt @@ -196,6 +198,9 @@ def __init__( self.verified_trust_marks = [] self.is_valid = False + def update_trust_anchor_conf(self, trust_anchor_entity_conf: 'EntityStatement') -> None: + self.trust_anchor_entity_conf = trust_anchor_entity_conf + def validate_by_itself(self) -> bool: """ validates the entity configuration by it self diff --git a/pyeudiw/federation/trust_chain_builder.py b/pyeudiw/federation/trust_chain_builder.py index ad9b4319..e6959f7e 100644 --- a/pyeudiw/federation/trust_chain_builder.py +++ b/pyeudiw/federation/trust_chain_builder.py @@ -64,7 +64,9 @@ def __init__( trust_anchor_configuration = EntityStatement( jwts[0], httpc_params=self.httpc_params ) - trust_anchor_configuration.subject_configuration.validate_by_itself() + + subject_configuration.update_trust_anchor_conf(trust_anchor_configuration) + subject_configuration.validate_by_itself() except Exception as e: _msg = f"Entity Configuration for {self.trust_anchor} failed: {e}" logger.error(_msg) From 9243e61a5e033947b2ea817c57537157847dafb7 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 4 Dec 2023 18:29:26 +0100 Subject: [PATCH 18/88] docs: documented trust_chain_builder.py --- pyeudiw/federation/trust_chain_builder.py | 107 +++++++++++++++++----- 1 file changed, 83 insertions(+), 24 deletions(-) diff --git a/pyeudiw/federation/trust_chain_builder.py b/pyeudiw/federation/trust_chain_builder.py index e6959f7e..0371d3b3 100644 --- a/pyeudiw/federation/trust_chain_builder.py +++ b/pyeudiw/federation/trust_chain_builder.py @@ -26,13 +26,6 @@ class TrustChainBuilder: """ A trust walker that fetches statements and evaluate the evaluables - - max_intermediaries means how many hops are allowed to the trust anchor - max_authority_hints means how much authority_hints to follow on each hop - - required_trust_marks means all the trsut marks needed to start a metadata discovery - at least one of the required trust marks is needed to start a metadata discovery - if this param if absent the filter won't be considered. """ def __init__( @@ -42,14 +35,37 @@ def __init__( httpc_params: dict, trust_anchor_configuration: Union[EntityStatement, str, None] = None, max_authority_hints: int = 10, - subject_configuration: EntityStatement = None, - required_trust_marks: list = [], + subject_configuration: EntityStatement | None = None, + required_trust_marks: list[dict] = [], # TODO - prefetch cache? # pre_fetched_entity_configurations = {}, # pre_fetched_statements = {}, # **kwargs, ) -> None: + """ + Initialized a TrustChainBuilder istance + + :parameter subject: represents the subject url (leaf) of the Trust Chain + :type subject: str + :parameter trust_anchor: represents the issuer url (leaf) of the Trust Chain + :type trust_anchor: str + :param httpc_params: parameters needed to perform http requests + :type httpc_params: dict + :param trust_anchor_configuration: is the entity statment configuration of Trust Anchor. + The assigned value can be an EntityStatement, a str or None. + If the value is a string it will be converted in an EntityStatement istance. + If the value is None it will be retrieved from an http request on the trust_anchor field. + :parameter max_authority_hints: the number of how many authority_hints to follow on each hop + :type max_authority_hints: int + :parameter subject_configuration: the configuration of subject + :type subject_configuration: EntityStatement + :parameter required_trust_marks: means all the trust marks needed to start a metadata discovery + at least one of the required trust marks is needed to start a metadata discovery + if this param if absent the filter won't be considered. + :type required_trust_marks: list[dict] + + """ self.subject = subject self.subject_configuration = subject_configuration @@ -97,8 +113,10 @@ def __init__( def apply_metadata_policy(self) -> dict: """ filters the trust path from subject to trust anchor - apply the metadata policies along the path and - returns the final metadata + apply the metadata policies along the path. + + :returns: the final metadata with policy applied + :rtype: dict """ # find the path of trust if not self.trust_path: @@ -157,23 +175,26 @@ def apply_metadata_policy(self) -> dict: ) # set exp - self.set_exp() + self._set_exp() return self.final_metadata - @property - def exp_datetime(self) -> datetime.datetime: - if self.exp: # pragma: no cover - return datetime_from_timestamp(self.exp) - - def set_exp(self) -> int: + def _set_exp(self) -> None: + """ + updates the internal exp field with the nearest + expiraton date found in the trust_path field + """ exps = [i.payload["exp"] for i in self.trust_path] if exps: self.exp = min(exps) def discovery(self) -> bool: """ - return a chain of verified statements - from the lower up to the trust anchor + discovers the chain of verified statements + from the lower up to the trust anchor and updates + the internal representation of chain. + + :returns: the validity status of the updated chain + :rtype: bool """ logger.info( f"Starting a Walk into Metadata Discovery for {self.subject}") @@ -227,6 +248,10 @@ def discovery(self) -> bool: return self.is_valid def get_trust_anchor_configuration(self) -> None: + """ + Download and updates the internal field trust_anchor_configuration + with the entity statement of trust anchor. + """ if not isinstance(self.trust_anchor, EntityStatement): logger.info( f"Get Trust Anchor Entity Configuration for {self.subject}") @@ -247,8 +272,11 @@ def get_trust_anchor_configuration(self) -> None: self._set_max_path_len() - def _set_max_path_len(self): - + def _set_max_path_len(self) -> None: + """ + Sets the internal field max_path_len with the costraint + found in trust anchor payload + """ if self.trust_anchor_configuration.payload.get("constraints", {}).get( "max_path_length" ): @@ -259,6 +287,12 @@ def _set_max_path_len(self): ) def get_subject_configuration(self) -> None: + """ + Download and updates the internal field subject_configuration + with the entity statement of leaf. + + :rtype: None + """ if not self.subject_configuration: try: jwts = get_entity_configurations( @@ -291,10 +325,22 @@ def get_subject_configuration(self) -> None: else: self.verified_trust_marks.extend(sc.verified_trust_marks) - def serialize(self): + def serialize(self) -> str: + """ + Serializes the chain in JSON format. + + :returns: the serialized chain in JSON format + :rtype: str + """ return json.dumps(self.get_trust_chain()) - def get_trust_chain(self): + def get_trust_chain(self) -> list[str]: + """ + Retrieves the leaf and the Trust Anchor entity configurations. + + :returns: the list containing the ECs + :rtype: list[str] + """ res = [] # we keep just the leaf's and TA's EC, all the intermediates EC will be dropped ta_ec: str = "" @@ -314,6 +360,13 @@ def get_trust_chain(self): return res def start(self): + """ + Retrieves the subject (leaf) configuration and starts + chain discovery. + + :returns: the list containing the ECs + :rtype: list[str] + """ try: # self.get_trust_anchor_configuration() self.get_subject_configuration() @@ -322,3 +375,9 @@ def start(self): self.is_valid = False logger.error(f"{e}") raise e + + @property + def exp_datetime(self) -> datetime.datetime: + """The exp filed converted in datetime format""" + if self.exp: # pragma: no cover + return datetime_from_timestamp(self.exp) \ No newline at end of file From 24a878213f88e0db93cfb38d9455f89627ad4572 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Tue, 5 Dec 2023 15:52:48 +0100 Subject: [PATCH 19/88] fix: moved implementation of get_http_url in utils.py --- pyeudiw/tools/utils.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py index fcd9d1a2..bf605620 100644 --- a/pyeudiw/tools/utils.py +++ b/pyeudiw/tools/utils.py @@ -1,7 +1,11 @@ import datetime import json import logging +import asyncio +import requests + from secrets import token_hex +from pyeudiw.federation.http_client import http_get logger = logging.getLogger(__name__) @@ -26,10 +30,32 @@ def datetime_from_timestamp(value) -> datetime.datetime: return make_timezone_aware(datetime.datetime.fromtimestamp(value)) -def get_http_url(url: str): - raise NotImplementedError( - f"{__name__} get_http_url is not implemented, please see federation.statements" - ) +def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]: + """ + Perform an HTTP Request returning the payload of the call. + + :param urls: The url or a list of url where perform the GET HTTP calls + :type urls: list[str] | str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param http_async: if is set to True the operation will be performed in async (deafault True) + :type http_async: bool + + :returns: A list of responses. + :rtype: list[dict] + """ + + urls = urls if isinstance(urls, list) else [urls] + + if http_async: + responses = asyncio.run( + http_get(urls, httpc_params)) # pragma: no cover + else: + responses = [] + for i in urls: + res = requests.get(i, **httpc_params) # nosec - B113 + responses.append(res.content.decode()) + return responses def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) -> dict: From 29a876e1741a26be2690abccba58b48662ecc766 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Tue, 5 Dec 2023 18:33:26 +0100 Subject: [PATCH 20/88] fix: fixed response handling --- pyeudiw/tools/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py index bf605620..895dab69 100644 --- a/pyeudiw/tools/utils.py +++ b/pyeudiw/tools/utils.py @@ -44,7 +44,6 @@ def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = T :returns: A list of responses. :rtype: list[dict] """ - urls = urls if isinstance(urls, list) else [urls] if http_async: @@ -54,7 +53,7 @@ def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = T responses = [] for i in urls: res = requests.get(i, **httpc_params) # nosec - B113 - responses.append(res.content.decode()) + responses.append(res.content) return responses From 1de5916fa888bdd38d01e9fb76399147c89b162c Mon Sep 17 00:00:00 2001 From: PascalDR Date: Tue, 5 Dec 2023 18:39:54 +0100 Subject: [PATCH 21/88] docs: documented file class and function plus refactoring --- pyeudiw/federation/statements.py | 206 +++++++++++++++++++++++-------- 1 file changed, 157 insertions(+), 49 deletions(-) diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index fa8d75d5..9fd20b44 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -9,20 +9,18 @@ InvalidEntityHeader, InvalidEntityStatementPayload ) -from pyeudiw.federation.http_client import http_get from pyeudiw.federation.schemas.entity_configuration import ( EntityConfigurationHeader, EntityStatementPayload ) from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header from pyeudiw.jwt import JWSHelper +from pyeudiw.tools.utils import get_http_url from pydantic import ValidationError -import asyncio + import json import logging -import requests - try: pass @@ -34,42 +32,84 @@ logger = logging.getLogger(__name__) -def jwks_from_jwks_uri(jwks_uri: str, httpc_params: dict) -> list: - return [json.loads(asyncio.run(http_get([jwks_uri], httpc_params)))] # pragma: no cover +def jwks_from_jwks_uri(jwks_uri: str, httpc_params: dict, http_async: bool = True) -> list[dict]: + """ + Retrieves jwks from an entity uri. + + :param jwks_uri: the uri where the jwks are located. + :type jwks_uri: str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param http_async: if is set to True the operation will be performed in async (deafault True) + :type http_async: bool + + :returns: A list of entity jwks. + :rtype: list[dict] + """ + + response = get_http_url(jwks_uri, httpc_params, http_async) + jwks = json.loads(response) + + return [jwks] + +def get_federation_jwks(jwt_payload: dict) -> list[dict]: + """ + Returns the list of JWKS inside a JWT payload. + + :param jwt_payload: the jwt payload from where extract the JWKs. + :type jwt_payload: dict -def get_federation_jwks(jwt_payload: dict, httpc_params: dict): - return ( - jwt_payload.get("jwks", {}).get("keys", []) - ) + :returns: A list of entity jwk's keys. + :rtype: list[dict] + """ + jwks = jwt_payload.get("jwks", {}) + keys = jwks.get("keys", []) -def get_http_url(urls: list, httpc_params: dict, http_async: bool = True) -> list: - if http_async: - responses = asyncio.run( - http_get(urls, httpc_params)) # pragma: no cover - else: - responses = [] - for i in urls: - res = requests.get(i, **httpc_params) # nosec - B113 - responses.append(res.content.decode()) - return responses + return keys -def get_entity_statements(urls: list, httpc_params: dict) -> list: +def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]: """ - Fetches an entity statement/configuration + Fetches an entity statement from the specified urls. + + :param urls: The url or a list of url where perform the GET HTTP calls + :type urls: list[str] | str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param http_async: if is set to True the operation will be performed in async (deafault True) + :type http_async: bool + + :returns: A list of entity statements. + :rtype: list[dict] """ - if isinstance(urls, str): - urls = [urls] # pragma: no cover + + urls = urls if isinstance(urls, list) else [urls] + for url in urls: logger.debug(f"Starting Entity Statement Request to {url}") - return get_http_url(urls, httpc_params) + return get_http_url(urls, httpc_params, http_async) + + +def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, http_async: bool = True): + """ + Fetches an entity configuration from the specified subjects. + + :param subjects: The url or a list of url where perform the GET HTTP calls + :type subjects: list[str] | str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param http_async: if is set to True the operation will be performed in async (deafault True) + :type http_async: bool + + :returns: A list of entity statements. + :rtype: list[dict] + """ + + subjects = subjects if isinstance(subjects, list) else [subjects] -def get_entity_configurations(subjects: list, httpc_params: dict): - if isinstance(subjects, str): - subjects = [subjects] urls = [] for subject in subjects: if subject[-1] != "/": @@ -77,11 +117,23 @@ def get_entity_configurations(subjects: list, httpc_params: dict): url = f"{subject}{OIDCFED_FEDERATION_WELLKNOWN_URL}" urls.append(url) logger.info(f"Starting Entity Configuration Request for {url}") - return get_http_url(urls, httpc_params) + + return get_http_url(urls, httpc_params, http_async) class TrustMark: + """The class representing a Trust Mark""" + def __init__(self, jwt: str, httpc_params: dict): + """ + Create an instance of Trust Mark + + :param jwt: the JWT containing the trust marks + :type jwt: str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + """ + self.jwt = jwt self.header = unpad_jwt_header(jwt) self.payload = unpad_jwt_payload(jwt) @@ -95,7 +147,16 @@ def __init__(self, jwt: str, httpc_params: dict): self.issuer_entity_configuration = None self.httpc_params = httpc_params - def validate_by(self, ec) -> bool: + def validate_by(self, ec: dict) -> bool: + """ + Validates Trust Marks by an Entity Configuration + + :param ec: the entity configuration to validate by + :type ec: dict + + :returns: True if is valid otherwise False + :rtype: bool + """ try: EntityConfigurationHeader(**self.header) except ValidationError as e: @@ -118,9 +179,15 @@ def validate_by(self, ec) -> bool: return payload def validate_by_its_issuer(self) -> bool: + """ + Validates Trust Marks by it's issuer + + :returns: True if is valid otherwise False + :rtype: bool + """ if not self.issuer_entity_configuration: self.issuer_entity_configuration = get_entity_configurations( - self.iss, self.httpc_params + self.iss, self.httpc_params, False ) try: ec = EntityStatement(self.issuer_entity_configuration[0]) @@ -156,17 +223,30 @@ def __init__( self, jwt: str, httpc_params: dict, - filter_by_allowed_trust_marks: list = [], + filter_by_allowed_trust_marks: list[str] = [], trust_anchor_entity_conf: 'EntityStatement' | None = None, - trust_mark_issuers_entity_confs: dict = [], + trust_mark_issuers_entity_confs: list[EntityStatement] = [], ): + """ + Creates EntityStatement istance + + :param jwt: the JWT containing the trust marks. + :type jwt: str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param filter_by_allowed_trust_marks: allowed trust marks list. + :type filter_by_allowed_trust_marks: list[str] + :param trust_anchor_entity_conf: the trust anchor entity conf or None + :type trust_anchor_entity_conf: EntityStatement | None + :param trust_mark_issuers_entity_confs: the list containig the trust mark's entiity confs + """ self.jwt = jwt self.header = unpad_jwt_header(jwt) self.payload = unpad_jwt_payload(jwt) self.sub = self.payload["sub"] self.iss = self.payload["iss"] self.exp = self.payload["exp"] - self.jwks = get_federation_jwks(self.payload, httpc_params) + self.jwks = get_federation_jwks(self.payload) if not self.jwks or not self.jwks[0]: _msg = f"Missing jwks in the statement for {self.sub}" logger.error(_msg) @@ -199,6 +279,12 @@ def __init__( self.is_valid = False def update_trust_anchor_conf(self, trust_anchor_entity_conf: 'EntityStatement') -> None: + """ + Updates the internal Trust Anchor conf. + + :param trust_anchor_entity_conf: the trust anchor entity conf + :type trust_anchor_entity_conf: EntityStatement + """ self.trust_anchor_entity_conf = trust_anchor_entity_conf def validate_by_itself(self) -> bool: @@ -324,12 +410,22 @@ def validate_by_allowed_trust_marks(self) -> bool: def get_superiors( self, - authority_hints: list = [], + authority_hints: list[str] = [], max_authority_hints: int = 0, - superiors_hints: list = [], + superiors_hints: list[dict] = [], ) -> dict: """ get superiors entity configurations + + :param authority_hints: the authority hint list + :type authority_hints: list[str] + :param max_authority_hints: the number of max authority hint + :type max_authority_hints: int + :param superiors_hints: the list of superior hints + :type superiors_hints: list[dict] + + :returns: a dict with the superior's entity configurations + :rtype: dict """ # apply limits if defined authority_hints = authority_hints or deepcopy( @@ -366,7 +462,7 @@ def get_superiors( if not jwts: jwts = get_entity_configurations( - authority_hints, self.httpc_params) + authority_hints, self.httpc_params, False) for jwt in jwts: try: @@ -398,6 +494,12 @@ def get_superiors( def validate_descendant_statement(self, jwt: str) -> bool: """ jwt is a descendant entity statement issued by self + + :param jwt: the JWT to validate by + :type jwt: str + + :returns: True if is valid or False otherwise + :rtype: bool """ header = unpad_jwt_header(jwt) payload = unpad_jwt_payload(jwt) @@ -430,13 +532,16 @@ def validate_descendant_statement(self, jwt: str) -> bool: self.verified_descendant_statements_as_jwt[payload["sub"]] = jwt return self.verified_descendant_statements - def validate_by_superior_statement(self, jwt: str, ec): + def validate_by_superior_statement(self, jwt: str, ec: 'EntityStatement') -> str: """ - jwt is a statement issued by a superior - ec is a superior entity configuration - - this method validates self with the jwks contained in statement - of the superior + validates self with the jwks contained in statement of the superior + :param jwt: the statement issued by a superior in form of JWT + :type jwt: str + :param ec: is a superior entity configuration + :type ec: EntityStatement + + :returns: the entity configuration subject if is valid + :rtype: str """ is_valid = None payload = {} @@ -444,7 +549,7 @@ def validate_by_superior_statement(self, jwt: str, ec): payload = unpad_jwt_payload(jwt) ec.validate_by_itself() ec.validate_descendant_statement(jwt) - _jwks = get_federation_jwks(payload, self.httpc_params) + _jwks = get_federation_jwks(payload) _kids = [i.get("kid") for i in _jwks] jwsh = JWSHelper(_jwks[_kids.index(self.header["kid"])]) @@ -476,11 +581,14 @@ def validate_by_superiors( superiors_entity_configurations: dict = {}, ) -> dict: """ - validates the entity configuration with the entity statements - issued by its superiors + validates the entity configuration with the entity statements issued by its superiors + this methods create self.verified_superiors and failed ones and self.verified_by_superiors and failed ones - this methods create self.verified_superiors and failed ones - and self.verified_by_superiors and failed ones + :param superiors_entity_configurations: an object containing the entity configurations of superiors + :type superiors_entity_configurations: dict + + :returns: an object containing the superior validations + :rtype: dict """ for ec in superiors_entity_configurations: if ec.sub in ec.verified_by_superiors: @@ -503,7 +611,7 @@ def validate_by_superiors( else: _url = f"{fetch_api_url}?sub={self.sub}" logger.info(f"Getting entity statements from {_url}") - jwts = get_entity_statements([_url], self.httpc_params) + jwts = get_entity_statements([_url], self.httpc_params, False) if not jwts: logger.error( f"Empty response for {_url}" From 7d5a27389223af492ab02d1bdd78d9bde266932d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 6 Dec 2023 16:31:57 +0100 Subject: [PATCH 22/88] docs: documented file __init__.py --- pyeudiw/federation/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pyeudiw/federation/__init__.py b/pyeudiw/federation/__init__.py index e8044ceb..afdc418d 100644 --- a/pyeudiw/federation/__init__.py +++ b/pyeudiw/federation/__init__.py @@ -1,6 +1,16 @@ from pyeudiw.federation.schemas.entity_configuration import EntityStatementPayload, EntityConfigurationPayload def is_es(payload: dict) -> bool: + """ + Determines if payload dict is an Entity Statement + + :param payload: the object to determine if is an Entity Statement + :type payload: dict + + :returns: True if is an Entity Statement and False otherwise + :rtype: bool + """ + try: EntityStatementPayload(**payload) if payload["iss"] != payload["sub"]: @@ -10,6 +20,16 @@ def is_es(payload: dict) -> bool: def is_ec(payload: dict) -> bool: + """ + Determines if payload dict is an Entity Configuration + + :param payload: the object to determine if is an Entity Configuration + :type payload: dict + + :returns: True if is an Entity Configuration and False otherwise + :rtype: bool + """ + try: EntityConfigurationPayload(**payload) return True From baaede8e5a37ef76ed856707b126c762f3962a1c Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 6 Dec 2023 18:47:20 +0100 Subject: [PATCH 23/88] docs: added docs for http_client.py --- pyeudiw/federation/http_client.py | 44 +++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/pyeudiw/federation/http_client.py b/pyeudiw/federation/http_client.py index 177ab666..a7cb3b04 100644 --- a/pyeudiw/federation/http_client.py +++ b/pyeudiw/federation/http_client.py @@ -3,7 +3,21 @@ import requests -async def fetch(session, url, httpc_params: dict): +async def fetch(session: dict, url: str, httpc_params: dict) -> str: + """ + Fetches the content of a URL. + + :param session: a dict representing the current session + :type session: dict + :param url: the url where fetch the content + :type url: str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + + :returns: the response in string format + :rtype: str + """ + async with session.get(url, **httpc_params.get("connection", {})) as response: if response.status != 200: # pragma: no cover # response.raise_for_status() @@ -11,7 +25,21 @@ async def fetch(session, url, httpc_params: dict): return await response.text() -async def fetch_all(session, urls, httpc_params: dict): +async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[str]: + """ + Fetches the content of a list of URL. + + :param session: a dict representing the current session + :type session: dict + :param urls: the url list where fetch the content + :type urls: list[str] + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + + :returns: the list of responses in string format + :rtype: list[str] + """ + tasks = [] for url in urls: task = asyncio.create_task(fetch(session, url, httpc_params)) @@ -21,7 +49,19 @@ async def fetch_all(session, urls, httpc_params: dict): async def http_get(urls, httpc_params: dict, sync=True): + """ + Perform a GET http call. + + :param session: a dict representing the current session + :type session: dict + :param urls: the url list where fetch the content + :type urls: list[str] + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :returns: the list of responses in string format + :rtype: list[str] + """ if sync: _conf = { 'verify': httpc_params['connection']['ssl'], From 2d04f1cb2a4acb0297a355d3f01940d66f15740e Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 10:31:58 +0100 Subject: [PATCH 24/88] docs: documented the content of __init__.py --- pyeudiw/jwk/__init__.py | 48 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index 1834a15f..e3515dbd 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -13,6 +13,10 @@ class JWK(): + """ + The class representing a JWK istance + """ + def __init__( self, key: Union[dict, None] = None, @@ -20,7 +24,19 @@ def __init__( hash_func: str = 'SHA-256', ec_crv: str = "P-256" ) -> None: + """ + Creates an instance of JWK. + :param key: An optional key in dict form. + If no key is provided a randomic key will be generated. + :type key: Union[dict, None] + :param key_type: a string that represents the key type. Can be EC or RSA. + :type key_type: str + :param hash_func: a string that represents the hash function to use with the instance. + :type hash_func: str + :param ec_crv: a string that represents the curve to use with the instance. + :type ec_crv: str + """ kwargs = {} self.kid = "" @@ -46,10 +62,22 @@ def __init__( self.public_key = self.key.serialize() self.public_key['kid'] = self.jwk["kid"] - def as_json(self): + def as_json(self) -> str: + """ + Returns the JWK in format of json string. + + :returns: A json string that represents the key. + :rtype: str + """ return json.dumps(self.jwk) - def export_private_pem(self): + def export_private_pem(self) -> str: + """ + Returns the JWK in format of a private pem certificte. + + :returns: A private pem certificate that represents the key. + :rtype: str + """ _k = key_from_jwk_dict(self.jwk) pk = _k.private_key() pem = pk.private_bytes( @@ -59,7 +87,13 @@ def export_private_pem(self): ) return pem.decode() - def export_public_pem(self): + def export_public_pem(self) -> str: + """ + Returns the JWK in format of a public pem certificte. + + :returns: A public pem certificate that represents the key. + :rtype: str + """ _k = key_from_jwk_dict(self.jwk) pk = _k.public_key() cert = pk.public_bytes( @@ -68,7 +102,13 @@ def export_public_pem(self): ) return cert.decode() - def as_dict(self): + def as_dict(self) -> dict: + """ + Returns the JWK in format of dict. + + :returns: The key in form of dict. + :rtype: dict + """ return self.jwk def __repr__(self): From 0f612358ecd97ab0ee20c9ce91fb464737f3ea7b Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 11:03:46 +0100 Subject: [PATCH 25/88] docs: documented contento of __init__.py --- pyeudiw/jwt/__init__.py | 70 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/pyeudiw/jwt/__init__.py b/pyeudiw/jwt/__init__.py index ed9661e3..4d4ba2b8 100644 --- a/pyeudiw/jwt/__init__.py +++ b/pyeudiw/jwt/__init__.py @@ -1,6 +1,6 @@ import binascii import json -from typing import Union +from typing import Union, Any import cryptojwt from cryptojwt.exception import VerificationError @@ -38,10 +38,32 @@ class JWEHelper(): - def __init__(self, jwk: JWK): + """ + The helper class for work with JWEs. + """ + def __init__(self, jwk: Union[JWK, dict]): + """ + Creates an instance of JWEHelper. + + :param jwk: The JWK used to crypt and encrypt the content of JWE. + :type jwk: JWK + """ self.jwk = jwk + if isinstance(jwk, dict): + self.jwk = JWK(jwk) + self.alg = DEFAULT_SIG_KTY_MAP[self.jwk.key.kty] def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: + """ + Generate a encrypted JWE string. + + :param plain_dict: The payload of JWE. + :type plain_dict: Union[dict, str, int, None] + :param kwargs: Other optional fields to generate the JWE. + + :returns: A string that represents the JWE. + :rtype: str + """ _key = key_from_jwk_dict(self.jwk.as_dict()) if isinstance(_key, cryptojwt.jwk.rsa.RSAKey): @@ -75,6 +97,15 @@ def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: return _keyobj.encrypt(key=_key.public_key()) def decrypt(self, jwe: str) -> dict: + """ + Generate a dict containing the content of decrypted JWE string. + + :param jwe: A string representing the jwe. + :type jwe: str + + :returns: A dict that represents the payload of decrypted JWE. + :rtype: dict + """ try: jwe_header = unpad_jwt_header(jwe) except (binascii.Error, Exception) as e: @@ -97,7 +128,16 @@ def decrypt(self, jwe: str) -> dict: class JWSHelper: + """ + The helper class for work with JWEs. + """ def __init__(self, jwk: Union[JWK, dict]): + """ + Creates an instance of JWSHelper. + + :param jwk: The JWK used to sign and verify the content of JWS. + :type jwk: Union[JWK, dict] + """ self.jwk = jwk if isinstance(jwk, dict): self.jwk = JWK(jwk) @@ -109,7 +149,19 @@ def sign( protected: dict = {}, **kwargs ) -> str: - + """ + Generate a encrypted JWS string. + + :param plain_dict: The payload of JWS. + :type plain_dict: Union[dict, str, int, None] + :param protected: a dict containing all the values + to include in the protected header. + :type protected: dict + :param kwargs: Other optional fields to generate the JWE. + + :returns: A string that represents the JWS. + :rtype: str + """ _key = key_from_jwk_dict(self.jwk.as_dict()) _payload: str | int | bytes = "" @@ -126,7 +178,17 @@ def sign( _signer = JWSec(_payload, alg=self.alg, **kwargs) return _signer.sign_compact([_key], protected=protected, **kwargs) - def verify(self, jws: str, **kwargs): + def verify(self, jws: str, **kwargs) -> (str | Any | bytes): + """ + Verify a JWS string. + + :param jws: A string representing the jwe. + :type jws: str + :param kwargs: Other optional fields to generate the JWE. + + :returns: A string that represents the payload of JWS. + :rtype: str + """ _key = key_from_jwk_dict(self.jwk.as_dict()) _jwk_dict = self.jwk.as_dict() _head = unpad_jwt_header(jws) From 1fef46182edd540aa0ea70f953a5179ada53077d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 11:23:07 +0100 Subject: [PATCH 26/88] fix: method name refactoring --- pyeudiw/federation/statements.py | 16 +++++++-------- pyeudiw/federation/trust_chain_validator.py | 22 ++++++++++----------- pyeudiw/jwt/__init__.py | 6 +++--- pyeudiw/jwt/utils.py | 17 ++++++++++------ pyeudiw/oauth2/dpop/__init__.py | 8 ++++---- pyeudiw/openid4vp/direct_post_response.py | 4 ++-- pyeudiw/openid4vp/vp.py | 10 +++++----- pyeudiw/satosa/dpop.py | 6 +++--- pyeudiw/satosa/trust.py | 5 ++--- pyeudiw/sd_jwt/__init__.py | 4 ++-- pyeudiw/tests/oauth2/test_dpop.py | 8 ++++---- pyeudiw/tests/satosa/test_backend.py | 8 ++++---- pyeudiw/tests/test_jwt.py | 6 +++--- pyeudiw/trust/__init__.py | 8 ++++---- pyeudiw/trust/trust_chain.py | 2 +- 15 files changed, 67 insertions(+), 63 deletions(-) diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index 9fd20b44..596cb2b1 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -13,7 +13,7 @@ EntityConfigurationHeader, EntityStatementPayload ) -from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.jwt import JWSHelper from pyeudiw.tools.utils import get_http_url from pydantic import ValidationError @@ -135,8 +135,8 @@ def __init__(self, jwt: str, httpc_params: dict): """ self.jwt = jwt - self.header = unpad_jwt_header(jwt) - self.payload = unpad_jwt_payload(jwt) + self.header = decode_jwt_header(jwt) + self.payload = decode_jwt_payload(jwt) self.id = self.payload["id"] self.sub = self.payload["sub"] @@ -241,8 +241,8 @@ def __init__( :param trust_mark_issuers_entity_confs: the list containig the trust mark's entiity confs """ self.jwt = jwt - self.header = unpad_jwt_header(jwt) - self.payload = unpad_jwt_payload(jwt) + self.header = decode_jwt_header(jwt) + self.payload = decode_jwt_payload(jwt) self.sub = self.payload["sub"] self.iss = self.payload["iss"] self.exp = self.payload["exp"] @@ -501,8 +501,8 @@ def validate_descendant_statement(self, jwt: str) -> bool: :returns: True if is valid or False otherwise :rtype: bool """ - header = unpad_jwt_header(jwt) - payload = unpad_jwt_payload(jwt) + header = decode_jwt_header(jwt) + payload = decode_jwt_payload(jwt) try: EntityConfigurationHeader(**header) @@ -546,7 +546,7 @@ def validate_by_superior_statement(self, jwt: str, ec: 'EntityStatement') -> str is_valid = None payload = {} try: - payload = unpad_jwt_payload(jwt) + payload = decode_jwt_payload(jwt) ec.validate_by_itself() ec.validate_descendant_statement(jwt) _jwks = get_federation_jwks(payload) diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index 9c00c53a..23540fe4 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -1,7 +1,7 @@ import logging from pyeudiw.tools.utils import iat_now from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.federation import is_es from pyeudiw.federation.policy import TrustChainPolicy from pyeudiw.federation.statements import ( @@ -141,8 +141,8 @@ def validate(self) -> bool: # inspect the entity statement kid header to know which # TA's public key to use for the validation last_element = rev_tc[0] - es_header = unpad_jwt_header(last_element) - es_payload = unpad_jwt_payload(last_element) + es_header = decode_jwt_header(last_element) + es_payload = decode_jwt_payload(last_element) ta_jwk = find_jwk( es_header.get("kid", None), self.trust_anchor_jwks @@ -169,8 +169,8 @@ def validate(self) -> bool: # validate the entire chain taking in cascade using fed_jwks # if valid -> update fed_jwks with $st for st in rev_tc[1:]: - st_header = unpad_jwt_header(st) - st_payload = unpad_jwt_payload(st) + st_header = decode_jwt_header(st) + st_payload = decode_jwt_payload(st) jwk = find_jwk( st_header.get("kid", None), fed_jwks ) @@ -237,7 +237,7 @@ def _update_st(self, st: str) -> str: :returns: the entity statement in form of JWT. :rtype: str """ - payload = unpad_jwt_payload(st) + payload = decode_jwt_payload(st) iss = payload['iss'] if not is_es(payload): # It's an entity configuration @@ -251,7 +251,7 @@ def _update_st(self, st: str) -> str: ) else: ec = self._retrieve_ec(iss) - ec_data = unpad_jwt_payload(ec) + ec_data = decode_jwt_payload(ec) fetch_api_url = None try: @@ -290,7 +290,7 @@ def update(self) -> bool: for st in self.static_trust_chain: jwt = self._update_st(st) - exp = unpad_jwt_payload(jwt)["exp"] + exp = decode_jwt_payload(jwt)["exp"] self.set_exp(exp) self.updated_trust_chain.append(jwt) @@ -316,18 +316,18 @@ def is_expired(self) -> int: def entity_id(self) -> str: """Get the chain's entity_id.""" chain = self.trust_chain - payload = unpad_jwt_payload(chain[0]) + payload = decode_jwt_payload(chain[0]) return payload["iss"] @property def final_metadata(self) -> dict: """Apply the metadata and returns the final metadata.""" anchor = self.static_trust_chain[-1] - es_anchor_payload = unpad_jwt_payload(anchor) + es_anchor_payload = decode_jwt_payload(anchor) policy = es_anchor_payload.get("metadata_policy", {}) leaf = self.static_trust_chain[0] - es_leaf_payload = unpad_jwt_payload(leaf) + es_leaf_payload = decode_jwt_payload(leaf) return TrustChainPolicy().apply_policy(es_leaf_payload["metadata"], policy) diff --git a/pyeudiw/jwt/__init__.py b/pyeudiw/jwt/__init__.py index 4d4ba2b8..ed7d4a13 100644 --- a/pyeudiw/jwt/__init__.py +++ b/pyeudiw/jwt/__init__.py @@ -12,7 +12,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwk.exceptions import KidError -from pyeudiw.jwt.utils import unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header DEFAULT_HASH_FUNC = "SHA-256" @@ -107,7 +107,7 @@ def decrypt(self, jwe: str) -> dict: :rtype: dict """ try: - jwe_header = unpad_jwt_header(jwe) + jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: raise VerificationError("The JWT is not valid") @@ -191,7 +191,7 @@ def verify(self, jws: str, **kwargs) -> (str | Any | bytes): """ _key = key_from_jwk_dict(self.jwk.as_dict()) _jwk_dict = self.jwk.as_dict() - _head = unpad_jwt_header(jws) + _head = decode_jwt_header(jws) if _head.get("kid"): if _head["kid"] != _jwk_dict["kid"]: # pragma: no cover diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index ca89ca50..d9a95785 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -2,11 +2,16 @@ import json import re +from pyeudiw.jwt.exceptions import JWTInvalidElementPosition + # JWT_REGEXP = r"^(([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*)$" JWT_REGEXP = r'^[\w\-]+\.[\w\-]+\.[\w\-]+' -def unpad_jwt_element(jwt: str, position: int) -> dict: +def decode_jwt_element(jwt: str, position: int) -> dict: + if position > 1: + raise JWTInvalidElementPosition(f"JWT has no jwt element in position {position}") + if isinstance(jwt, bytes): jwt = jwt.decode() b = jwt.split(".")[position] @@ -15,19 +20,19 @@ def unpad_jwt_element(jwt: str, position: int) -> dict: return data -def unpad_jwt_header(jwt: str) -> dict: - return unpad_jwt_element(jwt, position=0) +def decode_jwt_header(jwt: str) -> dict: + return decode_jwt_element(jwt, position=0) -def unpad_jwt_payload(jwt: str) -> dict: - return unpad_jwt_element(jwt, position=1) +def decode_jwt_payload(jwt: str) -> dict: + return decode_jwt_element(jwt, position=1) def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: """ docs here """ - head = unpad_jwt_header(jwt) + head = decode_jwt_header(jwt) kid = head["kid"] if isinstance(provider_jwks, dict) and provider_jwks.get('keys'): provider_jwks = provider_jwks['keys'] diff --git a/pyeudiw/oauth2/dpop/__init__.py b/pyeudiw/oauth2/dpop/__init__.py index 2672df96..d478a9da 100644 --- a/pyeudiw/oauth2/dpop/__init__.py +++ b/pyeudiw/oauth2/dpop/__init__.py @@ -11,7 +11,7 @@ ) from pyeudiw.jwk.exceptions import KidError from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_header, unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop.schema import ( DPoPTokenHeaderSchema, DPoPTokenPayloadSchema @@ -72,7 +72,7 @@ def __init__( ) # If the jwt is invalid, this will raise an exception try: - unpad_jwt_header(http_header_dpop) + decode_jwt_header(http_header_dpop) except UnicodeDecodeError as e: logger.error( "DPoP proof validation error, " @@ -108,7 +108,7 @@ def validate(self) -> bool: f"{e.__class__.__name__}: {e}" ) - header = unpad_jwt_header(self.proof) + header = decode_jwt_header(self.proof) DPoPTokenHeaderSchema(**header) if header['jwk'] != self.public_jwk: @@ -118,7 +118,7 @@ def validate(self) -> bool: f"{header['jwk']} != {self.public_jwk}" )) - payload = unpad_jwt_payload(self.proof) + payload = decode_jwt_payload(self.proof) DPoPTokenPayloadSchema(**payload) _ath = hashlib.sha256(self.dpop_token.encode()) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index 38eae6a7..f5592086 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -3,7 +3,7 @@ from pyeudiw.jwt import JWEHelper from pyeudiw.jwt.exceptions import JWEDecryptionError from pyeudiw.jwk.exceptions import KidNotFoundError -from pyeudiw.jwt.utils import unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header from pyeudiw.openid4vp.exceptions import ( VPNotFound, VPInvalidNonce, @@ -16,7 +16,7 @@ class DirectPostResponse: def __init__(self, jwt: str, jwks_by_kids: dict, nonce: str = ""): - self.headers = unpad_jwt_header(jwt) + self.headers = decode_jwt_header(jwt) self.jwks_by_kids = jwks_by_kids self.jwt = jwt self.nonce = nonce diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 69ab61ad..fcfb385b 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -1,5 +1,5 @@ -from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt @@ -7,9 +7,9 @@ class Vp(VpSdJwt): def __init__(self, jwt: str): # TODO: what if the credential is not a JWT? - self.headers = unpad_jwt_header(jwt) + self.headers = decode_jwt_header(jwt) self.jwt = jwt - self.payload = unpad_jwt_payload(jwt) + self.payload = decode_jwt_payload(jwt) self.credential_headers: dict = {} self.credential_payload: dict = {} @@ -35,8 +35,8 @@ def credential_issuer(self): def parse_digital_credential(self): _typ = self._detect_vp_type() if _typ == 'jwt': - self.credential_headers = unpad_jwt_header(self.payload['vp']) - self.credential_payload = unpad_jwt_payload(self.payload['vp']) + self.credential_headers = decode_jwt_header(self.payload['vp']) + self.credential_payload = decode_jwt_payload(self.payload['vp']) else: raise NotImplementedError( f"VP Digital credentials type not implemented yet: {_typ}" diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index c0775cb9..366bb45b 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -3,7 +3,7 @@ from typing import Union -from pyeudiw.jwt.utils import unpad_jwt_header, unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPVerifier from pyeudiw.openid4vp.schemas.wallet_instance_attestation import WalletInstanceAttestationPayload, \ WalletInstanceAttestationHeader @@ -25,8 +25,8 @@ def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: # take WIA dpop_jws = context.http_headers['HTTP_AUTHORIZATION'].split()[-1] - _head = unpad_jwt_header(dpop_jws) - wia = unpad_jwt_payload(dpop_jws) + _head = decode_jwt_header(dpop_jws) + wia = decode_jwt_payload(dpop_jws) self._log( context, diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index 73a066f8..5842c78a 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -8,8 +8,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_header -from pyeudiw.federation.trust_chain_builder import TrustChainBuilder +from pyeudiw.jwt.utils import decode_jwt_header from pyeudiw.satosa.exceptions import ( NotTrustedFederationError, DiscoveryFailedError ) @@ -166,7 +165,7 @@ def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: ) ) - headers = unpad_jwt_header(jws) + headers = decode_jwt_header(jws) trust_eval = TrustEvaluationHelper( self.db_engine, httpc_params=self.config['network']['httpc_params'], diff --git a/pyeudiw/sd_jwt/__init__.py b/pyeudiw/sd_jwt/__init__.py index 14cd1d89..0fbe0e53 100644 --- a/pyeudiw/sd_jwt/__init__.py +++ b/pyeudiw/sd_jwt/__init__.py @@ -13,7 +13,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import DEFAULT_SIG_KTY_MAP -from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_payload from pyeudiw.tools.utils import exp_from_now, iat_now from jwcrypto.jws import JWS @@ -167,7 +167,7 @@ def verify_sd_jwt( settings.update( { - "issuer": unpad_jwt_payload(sd_jwt_presentation)["iss"] + "issuer": decode_jwt_payload(sd_jwt_presentation)["iss"] } ) adapted_keys = { diff --git a/pyeudiw/tests/oauth2/test_dpop.py b/pyeudiw/tests/oauth2/test_dpop.py index f65a8a10..b5a82e2b 100644 --- a/pyeudiw/tests/oauth2/test_dpop.py +++ b/pyeudiw/tests/oauth2/test_dpop.py @@ -4,7 +4,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_header, unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPIssuer, DPoPVerifier from pyeudiw.oauth2.dpop.exceptions import InvalidDPoPKid from pyeudiw.tools.utils import iat_now @@ -67,7 +67,7 @@ def wia_jws(jwshelper): def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK): # create - header = unpad_jwt_header(wia_jws) + header = decode_jwt_header(wia_jws) assert header assert isinstance(header["trust_chain"], list) assert isinstance(header["x5c"], list) @@ -82,13 +82,13 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK): proof = new_dpop.proof assert proof - header = unpad_jwt_header(proof) + header = decode_jwt_header(proof) assert header["typ"] == "dpop+jwt" assert header["alg"] assert "mac" not in str(header["alg"]).lower() assert "d" not in header["jwk"] - payload = unpad_jwt_payload(proof) + payload = decode_jwt_payload(proof) assert payload["ath"] == base64.urlsafe_b64encode( hashlib.sha256(wia_jws.encode() ).digest()).rstrip(b'=').decode() diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index 15b9540e..cd507927 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -13,9 +13,9 @@ from sd_jwt.holder import SDJWTHolder from pyeudiw.jwk import JWK -from pyeudiw.jwt import JWEHelper, JWSHelper, unpad_jwt_header, DEFAULT_SIG_KTY_MAP +from pyeudiw.jwt import JWEHelper, JWSHelper, decode_jwt_header, DEFAULT_SIG_KTY_MAP from cryptojwt.jws.jws import JWS -from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPIssuer from pyeudiw.satosa.backend import OpenID4VPBackend from pyeudiw.sd_jwt import ( @@ -522,8 +522,8 @@ def test_request_endpoint(self, context): msg = json.loads(request_endpoint.message) assert msg["response"] - header = unpad_jwt_header(msg["response"]) - payload = unpad_jwt_payload(msg["response"]) + header = decode_jwt_header(msg["response"]) + payload = decode_jwt_payload(msg["response"]) assert header["alg"] assert header["kid"] assert payload["scope"] == " ".join(CONFIG["authorization"]["scopes"]) diff --git a/pyeudiw/tests/test_jwt.py b/pyeudiw/tests/test_jwt.py index 0e8771a9..d0982098 100644 --- a/pyeudiw/tests/test_jwt.py +++ b/pyeudiw/tests/test_jwt.py @@ -3,7 +3,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import (DEFAULT_ENC_ALG_MAP, DEFAULT_ENC_ENC_MAP, JWEHelper, JWSHelper) -from pyeudiw.jwt.utils import unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header JWKs_EC = [ (JWK(key_type="EC"), {"key": "value"}), @@ -24,11 +24,11 @@ @pytest.mark.parametrize("jwk, payload", JWKs_RSA) -def test_unpad_jwt_header(jwk, payload): +def test_decode_jwt_header(jwk, payload): jwe_helper = JWEHelper(jwk) jwe = jwe_helper.encrypt(payload) assert jwe - header = unpad_jwt_header(jwe) + header = decode_jwt_header(jwe) assert header assert header["alg"] == DEFAULT_ENC_ALG_MAP[jwk.jwk["kty"]] assert header["enc"] == DEFAULT_ENC_ENC_MAP[jwk.jwk["kty"]] diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index ef8cd089..82c1bad4 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -6,7 +6,7 @@ from pyeudiw.federation.exceptions import ProtocolMetadataNotFound from pyeudiw.satosa.exceptions import DiscoveryFailedError from pyeudiw.storage.db_engine import DBEngine -from pyeudiw.jwt.utils import unpad_jwt_payload, is_jwt_format +from pyeudiw.jwt.utils import decode_jwt_payload, is_jwt_format from pyeudiw.x509.verify import verify_x509_anchor, get_issuer_from_x5c, is_der_format from pyeudiw.storage.exceptions import EntryNotFound @@ -74,7 +74,7 @@ def _update_chain(self, entity_id: str | None = None, exp: datetime | None = Non self.trust_chain = trust_chain def _handle_federation_chain(self): - _first_statement = unpad_jwt_payload(self.trust_chain[-1]) + _first_statement = decode_jwt_payload(self.trust_chain[-1]) trust_anchor_eid = self.trust_anchor or _first_statement.get( 'iss', None) @@ -92,7 +92,7 @@ def _handle_federation_chain(self): "a recognizable Trust Anchor." ) - decoded_ec = unpad_jwt_payload( + decoded_ec = decode_jwt_payload( trust_anchor['federation']['entity_configuration'] ) jwks = decoded_ec.get('jwks', {}).get('keys', []) @@ -209,7 +209,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: for policy in policies: policy_acc = combine(policy, policy_acc) - self.final_metadata = unpad_jwt_payload(self.trust_chain[0]) + self.final_metadata = decode_jwt_payload(self.trust_chain[0]) try: # TODO: there are some cases where the jwks are taken from a uri ... diff --git a/pyeudiw/trust/trust_chain.py b/pyeudiw/trust/trust_chain.py index fc074a64..bc9c0820 100644 --- a/pyeudiw/trust/trust_chain.py +++ b/pyeudiw/trust/trust_chain.py @@ -4,7 +4,7 @@ from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac -from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_payload __author__ = "Roland Hedberg" __license__ = "Apache 2.0" From f3fae08fecc7baffc3f4552d86d17136c816bdca Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 11:23:54 +0100 Subject: [PATCH 27/88] fix: added exception --- pyeudiw/jwt/exceptions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyeudiw/jwt/exceptions.py b/pyeudiw/jwt/exceptions.py index 2e059616..cec4a78c 100644 --- a/pyeudiw/jwt/exceptions.py +++ b/pyeudiw/jwt/exceptions.py @@ -1,2 +1,5 @@ class JWEDecryptionError(Exception): pass + +class JWTInvalidElementPosition(Exception): + pass \ No newline at end of file From d9a8f3e8f8a20d84fb183856ae809ea5ff9fbb46 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 16:40:23 +0100 Subject: [PATCH 28/88] fix: refactored method find_jwk --- pyeudiw/federation/statements.py | 41 ++++++++++++++------- pyeudiw/federation/trust_chain_validator.py | 32 ++++------------ pyeudiw/jwk/__init__.py | 26 ++++++++++++- pyeudiw/jwk/exceptions.py | 2 + 4 files changed, 63 insertions(+), 38 deletions(-) diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index 596cb2b1..b74012fa 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -18,6 +18,7 @@ from pyeudiw.tools.utils import get_http_url from pydantic import ValidationError +from pyeudiw.jwk import find_jwk import json import logging @@ -165,15 +166,19 @@ def validate_by(self, ec: dict) -> bool: f"Trust Mark validation failed: " f"{e}" ) + + _kid = self.header["kid"] - - if self.header.get("kid") not in ec.kids: + if _kid not in ec.kids: raise UnknownKid( # pragma: no cover f"Trust Mark validation failed: " f"{self.header.get('kid')} not found in {ec.jwks}" ) + + _jwk = find_jwk(_kid, ec.jwks) + # verify signature - jwsh = JWSHelper(ec.jwks[ec.kids.index(self.header["kid"])]) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) self.is_valid = True return payload @@ -189,13 +194,15 @@ def validate_by_its_issuer(self) -> bool: self.issuer_entity_configuration = get_entity_configurations( self.iss, self.httpc_params, False ) + + _kid = self.header.get('kid') try: ec = EntityStatement(self.issuer_entity_configuration[0]) ec.validate_by_itself() except UnknownKid: logger.warning( f"Trust Mark validation failed by its Issuer: " - f"{self.header.get('kid')} not found in " + f"{_kid} not found in " f"{self.issuer_entity_configuration.jwks}") return False except Exception: @@ -205,7 +212,8 @@ def validate_by_its_issuer(self) -> bool: return False # verify signature - jwsh = JWSHelper(ec.jwks[ec.kids.index(self.header["kid"])]) + _jwk = find_jwk(_kid, ec.jwks) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) self.is_valid = True return payload @@ -300,11 +308,15 @@ def validate_by_itself(self) -> bool: f"{e}" ) - if self.header.get("kid") not in self.kids: + _kid = self.header.get("kid") + + if _kid not in self.kids: raise UnknownKid( - f"{self.header.get('kid')} not found in {self.jwks}") # pragma: no cover + f"{_kid} not found in {self.jwks}") # pragma: no cover + # verify signature - jwsh = JWSHelper(self.jwks[self.kids.index(self.header["kid"])]) + _jwk = find_jwk(_kid, self.jwks) + jwsh = JWSHelper(_jwk) jwsh.verify(self.jwt) self.is_valid = True return True @@ -520,12 +532,15 @@ def validate_descendant_statement(self, jwt: str) -> bool: f"{e}" ) - if header.get("kid") not in self.kids: + _kid = header.get("kid") + + if _kid not in self.kids: raise UnknownKid( - f"{self.header.get('kid')} not found in {self.jwks}") + f"{_kid} not found in {self.jwks}") # verify signature - jwsh = JWSHelper(self.jwks[self.kids.index(header["kid"])]) + _jwk = find_jwk(_kid, self.jwks) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(jwt) self.verified_descendant_statements[payload["sub"]] = payload @@ -550,9 +565,9 @@ def validate_by_superior_statement(self, jwt: str, ec: 'EntityStatement') -> str ec.validate_by_itself() ec.validate_descendant_statement(jwt) _jwks = get_federation_jwks(payload) - _kids = [i.get("kid") for i in _jwks] + _jwk = find_jwk(self.header["kid"], _jwks) - jwsh = JWSHelper(_jwks[_kids.index(self.header["kid"])]) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) is_valid = True diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index 23540fe4..6426d9b8 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -15,27 +15,10 @@ KeyValidationError ) -logger = logging.getLogger(__name__) - - -def find_jwk(kid: str, jwks: list[dict]) -> dict: - """ - Find the JWK with the indicated kid in the jwks list. - - :param kid: the identifier of the jwk - :type kid: str - :param jwks: the list of jwks - :type jwks: list[dict] +from pyeudiw.jwk import find_jwk +from pyeudiw.jwk.exceptions import KidNotFoundError, InvalidKid - :returns: the jwk with the indicated kid or an empty dict if no jwk is found - :rtype: dict - """ - if not kid: - return {} - for jwk in jwks: - valid_jwk = jwk.get("kid", None) - if valid_jwk and kid == valid_jwk: - return jwk +logger = logging.getLogger(__name__) class StaticTrustChainValidator: @@ -171,11 +154,12 @@ def validate(self) -> bool: for st in rev_tc[1:]: st_header = decode_jwt_header(st) st_payload = decode_jwt_payload(st) - jwk = find_jwk( - st_header.get("kid", None), fed_jwks - ) - if not jwk: + try: + jwk = find_jwk( + st_header.get("kid", None), fed_jwks + ) + except (KidNotFoundError, InvalidKid): return False jwsh = JWSHelper(jwk) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index e3515dbd..5c4b6a2a 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -6,12 +6,13 @@ from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jwk.rsa import new_rsa_key +from .exceptions import InvalidKid, KidNotFoundError + KEY_TYPES_FUNC = dict( EC=new_ec_key, RSA=new_rsa_key ) - class JWK(): """ The class representing a JWK istance @@ -114,3 +115,26 @@ def as_dict(self) -> dict: def __repr__(self): # private part! return self.as_json() + +def find_jwk(kid: str, jwks: list[dict], as_dict: bool=True) -> dict | JWK: + """ + Find the JWK with the indicated kid in the jwks list. + + :param kid: the identifier of the jwk + :type kid: str + :param jwks: the list of jwks + :type jwks: list[dict] + :param as_dict: if True the return type will be a dict, JWK otherwise. + :type as_dict: bool + + :returns: the jwk with the indicated kid or an empty dict if no jwk is found + :rtype: dict | JWK + """ + if not kid: + raise InvalidKid("Kid cannot be empty") + for jwk in jwks: + valid_jwk = jwk.get("kid", None) + if valid_jwk and kid == valid_jwk: + return jwk if as_dict else JWK(jwk) + + raise KidNotFoundError(f"Key with Kid {kid} not found") \ No newline at end of file diff --git a/pyeudiw/jwk/exceptions.py b/pyeudiw/jwk/exceptions.py index 7f05a493..b3a84613 100644 --- a/pyeudiw/jwk/exceptions.py +++ b/pyeudiw/jwk/exceptions.py @@ -5,6 +5,8 @@ class KidError(Exception): class KidNotFoundError(Exception): pass +class InvalidKid(Exception): + pass class JwkError(Exception): pass From eb0dfda00d0b64edcf4646573bbdb7f90e147a15 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 16:52:54 +0100 Subject: [PATCH 29/88] docs: fixed documentation --- pyeudiw/jwk/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index 5c4b6a2a..0106d2e4 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -127,6 +127,9 @@ def find_jwk(kid: str, jwks: list[dict], as_dict: bool=True) -> dict | JWK: :param as_dict: if True the return type will be a dict, JWK otherwise. :type as_dict: bool + :raises InvalidKid: if kid is None. + :raises KidNotFoundError: if kid is not in jwks list. + :returns: the jwk with the indicated kid or an empty dict if no jwk is found :rtype: dict | JWK """ From 46ac8bbb86c41a2394e7cd2fd47db8f34d08286c Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 16:54:35 +0100 Subject: [PATCH 30/88] fix: refactoring --- pyeudiw/jwt/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index d9a95785..3dff1efb 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -2,18 +2,21 @@ import json import re +from typing import Dict from pyeudiw.jwt.exceptions import JWTInvalidElementPosition +from pyeudiw.jwk import find_jwk # JWT_REGEXP = r"^(([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*)$" JWT_REGEXP = r'^[\w\-]+\.[\w\-]+\.[\w\-]+' def decode_jwt_element(jwt: str, position: int) -> dict: - if position > 1: - raise JWTInvalidElementPosition(f"JWT has no jwt element in position {position}") + if position > 1 or position < 0: + raise JWTInvalidElementPosition(f"JWT has no element in position {position}") if isinstance(jwt, bytes): jwt = jwt.decode() + b = jwt.split(".")[position] padded = f"{b}{'=' * divmod(len(b), 4)[1]}" data = json.loads(base64.urlsafe_b64decode(padded)) @@ -36,10 +39,8 @@ def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: kid = head["kid"] if isinstance(provider_jwks, dict) and provider_jwks.get('keys'): provider_jwks = provider_jwks['keys'] - for jwk in provider_jwks: - if jwk["kid"] == kid: - return jwk - return {} + + return find_jwk(kid, provider_jwks) def is_jwt_format(jwt: str) -> bool: From 3559531e668b2fcfa6a5c7d8be6bb350a40855fa Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 16:54:55 +0100 Subject: [PATCH 31/88] docs: documented content of utils.py --- pyeudiw/jwt/utils.py | 57 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index 3dff1efb..b44e8a6e 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -11,6 +11,19 @@ def decode_jwt_element(jwt: str, position: int) -> dict: + """ + Decodes the element in a determinated position. + + :param jwt: a string that represents the jwt. + :type jwt: str + :param position: the position of segment to unpad. + :type position: int + + :raises JWTInvalidElementPosition: If the JWT element position is greather then one or less of 0 + + :returns: a dict with the content of the decoded section. + :rtype: dict + """ if position > 1 or position < 0: raise JWTInvalidElementPosition(f"JWT has no element in position {position}") @@ -24,16 +37,46 @@ def decode_jwt_element(jwt: str, position: int) -> dict: def decode_jwt_header(jwt: str) -> dict: + """ + Decodes the jwt header. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: a dict with the content of the decoded header. + :rtype: dict + """ return decode_jwt_element(jwt, position=0) def decode_jwt_payload(jwt: str) -> dict: + """ + Decodes the jwt payload. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: a dict with the content of the decoded payload. + :rtype: dict + """ return decode_jwt_element(jwt, position=1) -def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: +def get_jwk_from_jwt(jwt: str, provider_jwks: Dict[str, dict]) -> dict: """ - docs here + Find the JWK inside the provider JWKs with the kid + specified in jwt header. + + :param jwt: a string that represents the jwt. + :type jwt: str + :param provider_jwks: a dictionary that contains one or more JWKs with the KID as the key. + :type provider_jwks: Dict[str, dict] + + :raises InvalidKid: if kid is None. + :raises KidNotFoundError: if kid is not in jwks list. + + :returns: the jwk as dict. + :rtype: dict """ head = decode_jwt_header(jwt) kid = head["kid"] @@ -44,5 +87,15 @@ def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: def is_jwt_format(jwt: str) -> bool: + """ + Check if a string is in JWT format. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: True if the string is a JWT, False otherwise. + :rtype: bool + """ + res = re.match(JWT_REGEXP, jwt) return bool(res) From f7a85cce78c7966894d1fb0f5a1f5a3440ce067e Mon Sep 17 00:00:00 2001 From: PascalDR Date: Mon, 11 Dec 2023 17:21:09 +0100 Subject: [PATCH 32/88] docs: documented __init__.py content --- pyeudiw/oauth2/dpop/__init__.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/pyeudiw/oauth2/dpop/__init__.py b/pyeudiw/oauth2/dpop/__init__.py index d478a9da..986efd44 100644 --- a/pyeudiw/oauth2/dpop/__init__.py +++ b/pyeudiw/oauth2/dpop/__init__.py @@ -22,7 +22,20 @@ class DPoPIssuer: + """ + Helper class for generate DPoP proofs. + """ def __init__(self, htu: str, token: str, private_jwk: dict): + """ + Generates an instance of DPoPIssuer. + + :param htu: a string representing the htu value. + :type htu: str + :param token: a string representing the token value. + :type token: str + :param private_jwk: a dict representing the private JWK of DPoP. + :type private_jwk: dict + """ self.token = token self.private_jwk = private_jwk self.signer = JWSHelper(private_jwk) @@ -30,6 +43,7 @@ def __init__(self, htu: str, token: str, private_jwk: dict): @property def proof(self): + """Returns the proof.""" data = { "jti": str(uuid.uuid4()), "htm": "GET", @@ -48,6 +62,10 @@ def proof(self): class DPoPVerifier: + """ + Helper class for validate DPoP proofs. + """ + dpop_header_prefix = 'DPoP ' def __init__( @@ -56,6 +74,19 @@ def __init__( http_header_authz: str, http_header_dpop: str, ): + """ + Generate an instance of DPoPVerifier. + + :param public_jwk: a dict representing the public JWK of DPoP. + :type public_jwk: dict + :param http_header_authz: a string representing the authz value. + :type http_header_authz: str + :param http_header_dpop: a string representing the DPoP value. + :type http_header_dpop: str + + :raises ValueError: if DPoP proof is not a valid JWT + + """ self.public_jwk = public_jwk self.dpop_token = ( http_header_authz.replace(self.dpop_header_prefix, '') @@ -89,9 +120,20 @@ def __init__( @property def is_valid(self) -> bool: + """Returns True if DPoP is valid.""" return self.validate() def validate(self) -> bool: + """ + Validates the content of DPoP. + + :raises InvalidDPoPKid: if the kid of DPoP is invalid. + :raises InvalidDPoPAth: if the header's JWK is different from public_jwk's one. + + :returns: True if the validation is correctly executed, False otherwise + :rtype: bool + """ + jws_verifier = JWSHelper(self.public_jwk) try: dpop_valid = jws_verifier.verify(self.proof) From f221ac16241b7f9964c49eeeccd425ef7e58066c Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 12:19:48 +0100 Subject: [PATCH 33/88] fix: Resolved todo (what if the credential is not a JWT?) --- pyeudiw/openid4vp/vp.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index fcfb385b..6159959c 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -1,12 +1,14 @@ - -from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header +from .exceptions import InvalidVPToken +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header, is_jwt_format from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt class Vp(VpSdJwt): def __init__(self, jwt: str): - # TODO: what if the credential is not a JWT? + if not is_jwt_format(jwt): + raise InvalidVPToken(f"VP is not in JWT format.") + self.headers = decode_jwt_header(jwt) self.jwt = jwt self.payload = decode_jwt_payload(jwt) From 052262ae03c7b866f1193e683cff83ac566d4b4a Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:22:27 +0100 Subject: [PATCH 34/88] feat: implemented is_jwe_format and is_jws_format --- pyeudiw/jwt/utils.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index b44e8a6e..cf0285ea 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -99,3 +99,40 @@ def is_jwt_format(jwt: str) -> bool: res = re.match(JWT_REGEXP, jwt) return bool(res) + +def is_jwe_format(jwt: str): + """ + Check if a string is in JWE format. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: True if the string is a JWE, False otherwise. + :rtype: bool + """ + + if not is_jwt_format(jwt): + return False + + header = decode_jwt_header(jwt) + + if header.get("enc", None) == None: + return False + + return True + +def is_jws_format(jwt: str): + """ + Check if a string is in JWS format. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: True if the string is a JWS, False otherwise. + :rtype: bool + """ + breakpoint() + if not is_jwt_format(jwt): + return False + + return not is_jwe_format(jwt) \ No newline at end of file From 2711ba824030c0f7c73f0f628710d4a0cba02d4c Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:22:37 +0100 Subject: [PATCH 35/88] test: amplied test --- pyeudiw/tests/test_jwt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyeudiw/tests/test_jwt.py b/pyeudiw/tests/test_jwt.py index d0982098..2130d43d 100644 --- a/pyeudiw/tests/test_jwt.py +++ b/pyeudiw/tests/test_jwt.py @@ -3,7 +3,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import (DEFAULT_ENC_ALG_MAP, DEFAULT_ENC_ENC_MAP, JWEHelper, JWSHelper) -from pyeudiw.jwt.utils import decode_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format, is_jws_format JWKs_EC = [ (JWK(key_type="EC"), {"key": "value"}), @@ -47,6 +47,7 @@ def test_jwe_helper_encrypt(jwk, payload): helper = JWEHelper(jwk) jwe = helper.encrypt(payload) assert jwe + assert is_jwe_format(jwe) @pytest.mark.parametrize("jwk, payload", JWKs_RSA) @@ -83,7 +84,6 @@ def test_jws_helper_sign(jwk, payload): jws = helper.sign(payload) assert jws - @pytest.mark.parametrize("jwk, payload", JWKs_RSA) def test_jws_helper_verify(jwk, payload): helper = JWSHelper(jwk) From 8a99ab114afe2619013be65ac85c365084c41f73 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:38:53 +0100 Subject: [PATCH 36/88] fix: refactored code --- pyeudiw/jwt/__init__.py | 15 ++++++++++++--- pyeudiw/jwt/exceptions.py | 3 +++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pyeudiw/jwt/__init__.py b/pyeudiw/jwt/__init__.py index ed7d4a13..cb9f0f7c 100644 --- a/pyeudiw/jwt/__init__.py +++ b/pyeudiw/jwt/__init__.py @@ -3,7 +3,6 @@ from typing import Union, Any import cryptojwt -from cryptojwt.exception import VerificationError from cryptojwt.jwe.jwe import factory from cryptojwt.jwe.jwe_ec import JWE_EC from cryptojwt.jwe.jwe_rsa import JWE_RSA @@ -14,6 +13,8 @@ from pyeudiw.jwk.exceptions import KidError from pyeudiw.jwt.utils import decode_jwt_header +from .exceptions import JWEDecryptionError, JWSVerificationError + DEFAULT_HASH_FUNC = "SHA-256" DEFAULT_SIG_KTY_MAP = { @@ -103,13 +104,15 @@ def decrypt(self, jwe: str) -> dict: :param jwe: A string representing the jwe. :type jwe: str + :raises JWEDecryptionError: if jwe field is not in a JWE Format + :returns: A dict that represents the payload of decrypted JWE. :rtype: dict """ try: jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: - raise VerificationError("The JWT is not valid") + raise JWEDecryptionError("Not a valid JWE format") _alg = jwe_header.get("alg") _enc = jwe_header.get("enc") @@ -186,12 +189,18 @@ def verify(self, jws: str, **kwargs) -> (str | Any | bytes): :type jws: str :param kwargs: Other optional fields to generate the JWE. + :raises JWSVerificationError: if jws field is not in a JWS Format + :returns: A string that represents the payload of JWS. :rtype: str """ _key = key_from_jwk_dict(self.jwk.as_dict()) _jwk_dict = self.jwk.as_dict() - _head = decode_jwt_header(jws) + + try: + _head = decode_jwt_header(jws) + except (binascii.Error, Exception) as e: + raise JWSVerificationError("Not a valid JWS format") if _head.get("kid"): if _head["kid"] != _jwk_dict["kid"]: # pragma: no cover diff --git a/pyeudiw/jwt/exceptions.py b/pyeudiw/jwt/exceptions.py index cec4a78c..f9428711 100644 --- a/pyeudiw/jwt/exceptions.py +++ b/pyeudiw/jwt/exceptions.py @@ -2,4 +2,7 @@ class JWEDecryptionError(Exception): pass class JWTInvalidElementPosition(Exception): + pass + +class JWSVerificationError(Exception): pass \ No newline at end of file From 9b54e9371205423e1b9dab028b2ec4d5bf68f0c6 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:47:40 +0100 Subject: [PATCH 37/88] feat: resolved todo (detect if it is encrypted otherwise) --- pyeudiw/openid4vp/direct_post_response.py | 35 ++++++++++++++--------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index 8a76c9c1..fa760c9d 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -1,9 +1,9 @@ from pyeudiw.jwk import JWK -from pyeudiw.jwt import JWEHelper +from pyeudiw.jwt import JWEHelper, JWSHelper from pyeudiw.jwt.exceptions import JWEDecryptionError from pyeudiw.jwk.exceptions import KidNotFoundError -from pyeudiw.jwt.utils import decode_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format from pyeudiw.openid4vp.exceptions import ( VPNotFound, VPInvalidNonce, @@ -26,13 +26,7 @@ def __init__(self, jwt: str, jwks_by_kids: dict, nonce: str = ""): self.credentials_by_issuer: dict = {} self._claims_by_issuer: dict = {} - @property - def payload(self) -> dict: - # TODO: detect if it is encrypted otherwise ... - # here we support only the encrypted jwt - if not self._payload: - self.decrypt() - return self._payload + def _decode_payload(self) -> None: def decrypt(self) -> None: _kid = self.headers.get('kid', None) @@ -41,12 +35,13 @@ def decrypt(self) -> None: f"The JWT headers {self.headers} doesnt have any KID value" ) self.jwk = JWK(self.jwks_by_kids[_kid]) - jweHelper = JWEHelper(self.jwk) - try: + + if is_jwe_format(self.jwt): + jweHelper = JWEHelper(self.jwk) self._payload = jweHelper.decrypt(self.jwt) - except Exception as e: - _msg = f"Response decryption error: {e}" - raise JWEDecryptionError(_msg) + else: + jwsHelper = JWSHelper(self.jwk) + self._payload = jwsHelper.verify(self.jwt) def load_nonce(self, nonce: str): self.nonce = nonce @@ -95,3 +90,15 @@ def get_presentation_vps(self): self.credentials_by_issuer[cred_iss].append(_vp.payload['vp']) return self._vps + + @property + def vps(self): + if not self._vps: + self.get_presentation_vps() + return self._vps + + @property + def payload(self) -> dict: + if not self._payload: + self._decode_payload() + return self._payload \ No newline at end of file From 2c3dc7d0fa34d8e0710b371de7007f58e2122028 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:50:15 +0100 Subject: [PATCH 38/88] fix: code refactoring --- pyeudiw/openid4vp/direct_post_response.py | 27 ++++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index fa760c9d..ec9ad883 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -1,4 +1,4 @@ - +from typing import Dict from pyeudiw.jwk import JWK from pyeudiw.jwt import JWEHelper, JWSHelper from pyeudiw.jwt.exceptions import JWEDecryptionError @@ -11,10 +11,10 @@ ) from pyeudiw.openid4vp.schemas.vp_token import VPTokenPayload, VPTokenHeader from pyeudiw.openid4vp.vp import Vp - +from pydantic import ValidationError class DirectPostResponse: - def __init__(self, jwt: str, jwks_by_kids: dict, nonce: str = ""): + def __init__(self, jwt: str, jwks_by_kids: Dict[str, dict], nonce: str = ""): self.headers = decode_jwt_header(jwt) self.jwks_by_kids = jwks_by_kids @@ -43,13 +43,13 @@ def decrypt(self) -> None: jwsHelper = JWSHelper(self.jwk) self._payload = jwsHelper.verify(self.jwt) - def load_nonce(self, nonce: str): + def load_nonce(self, nonce: str) -> None: self.nonce = nonce - def validate(self) -> bool: + def _validate_vp(self, vp: dict) -> bool: - # check nonces - for vp in self.get_presentation_vps(): + try: + # check nonce if self.nonce: if not vp.payload.get('nonce', None): raise NoNonceInVPToken() @@ -61,7 +61,17 @@ def validate(self) -> bool: ) VPTokenPayload(**vp.payload) VPTokenHeader(**vp.headers) + except ValidationError: + return False + return True + + def validate(self) -> bool: + + for vp in self.get_presentation_vps(): + if not self._validate_vp(vp): + return False + return True @property @@ -76,8 +86,9 @@ def get_presentation_vps(self): _vps = self.payload.get('vp_token', []) vps = [_vps] if isinstance(_vps, str) else _vps + if not vps: - raise VPNotFound("vp is null") + raise VPNotFound(f"Vps for response with nonce \"{self.nonce}\" are empty") for vp in vps: _vp = Vp(vp) From 61707f171cfe2a734c1eac090157d554310a0c24 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:55:33 +0100 Subject: [PATCH 39/88] docs: documented content of direct_post_response.py --- pyeudiw/openid4vp/direct_post_response.py | 57 +++++++++++++++++++---- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index ec9ad883..c5406ed5 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -1,7 +1,6 @@ from typing import Dict from pyeudiw.jwk import JWK from pyeudiw.jwt import JWEHelper, JWSHelper -from pyeudiw.jwt.exceptions import JWEDecryptionError from pyeudiw.jwk.exceptions import KidNotFoundError from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format from pyeudiw.openid4vp.exceptions import ( @@ -14,8 +13,20 @@ from pydantic import ValidationError class DirectPostResponse: + """ + Helper class for generate Direct Post Response. + """ def __init__(self, jwt: str, jwks_by_kids: Dict[str, dict], nonce: str = ""): - + """ + Generate an instance of DirectPostResponse. + + :param jwt: a string that represents the jwt. + :type jwt: str + :param jwks_by_kids: a dictionary that contains one or more JWKs with the KID as the key. + :type jwks_by_kids: Dict[str, dict] + :param nonce: a string that represents the nonce. + :type nonce: str + """ self.headers = decode_jwt_header(jwt) self.jwks_by_kids = jwks_by_kids self.jwt = jwt @@ -27,8 +38,12 @@ def __init__(self, jwt: str, jwks_by_kids: Dict[str, dict], nonce: str = ""): self._claims_by_issuer: dict = {} def _decode_payload(self) -> None: + """ + Internally decrypts the content of the JWT. - def decrypt(self) -> None: + :raises JWSVerificationError: if jws field is not in a JWS Format + :raises JWEDecryptionError: if jwe field is not in a JWE Format + """ _kid = self.headers.get('kid', None) if not _kid: raise KidNotFoundError( @@ -44,10 +59,24 @@ def decrypt(self) -> None: self._payload = jwsHelper.verify(self.jwt) def load_nonce(self, nonce: str) -> None: + """ + Load a nonce string inside the body of response. + + :param nonce: a string that represents the nonce. + :type nonce: str + """ self.nonce = nonce def _validate_vp(self, vp: dict) -> bool: + """ + Validate a single Verifiable Presentation. + :param vp: the verifiable presentation to validate. + :type vp: str + + :returns: True if is valid, False otherwhise. + :rtype: bool + """ try: # check nonce if self.nonce: @@ -67,6 +96,12 @@ def _validate_vp(self, vp: dict) -> bool: def validate(self) -> bool: + """ + Validates all VPs inside JWT's body. + + :returns: True if all VP are valid, False otherwhise. + :rtype: bool + """ for vp in self.get_presentation_vps(): if not self._validate_vp(vp): @@ -74,13 +109,13 @@ def validate(self) -> bool: return True - @property - def vps(self): - if not self._vps: - self.get_presentation_vps() - return self._vps + def get_presentation_vps(self) -> list[dict]: + """ + Returns the presentation's verifiable presentations - def get_presentation_vps(self): + :returns: the list of vps. + :rtype: list[dict] + """ if self._vps: return self._vps @@ -103,13 +138,15 @@ def get_presentation_vps(self): return self._vps @property - def vps(self): + def vps(self) -> list[dict]: + """Returns the presentation's verifiable presentations""" if not self._vps: self.get_presentation_vps() return self._vps @property def payload(self) -> dict: + """Returns the decoded payload of presentation""" if not self._payload: self._decode_payload() return self._payload \ No newline at end of file From 084d1f2536ddb2747e6f51749a422770c54d8583 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 13:55:50 +0100 Subject: [PATCH 40/88] fix: amplied error messages --- pyeudiw/jwt/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyeudiw/jwt/__init__.py b/pyeudiw/jwt/__init__.py index cb9f0f7c..f5b14f44 100644 --- a/pyeudiw/jwt/__init__.py +++ b/pyeudiw/jwt/__init__.py @@ -112,7 +112,7 @@ def decrypt(self, jwe: str) -> dict: try: jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: - raise JWEDecryptionError("Not a valid JWE format") + raise JWEDecryptionError(f"Not a valid JWE format for the following reason: {e}") _alg = jwe_header.get("alg") _enc = jwe_header.get("enc") @@ -200,7 +200,7 @@ def verify(self, jws: str, **kwargs) -> (str | Any | bytes): try: _head = decode_jwt_header(jws) except (binascii.Error, Exception) as e: - raise JWSVerificationError("Not a valid JWS format") + raise JWSVerificationError(f"Not a valid JWS format for the following reason: {e}") if _head.get("kid"): if _head["kid"] != _jwk_dict["kid"]: # pragma: no cover From 2ab400213f7d25d2734db1f39a1ff58c0ea1df61 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 14:03:26 +0100 Subject: [PATCH 41/88] feat: resolved todo (automatic detection of the credential) --- pyeudiw/openid4vp/vp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 6159959c..91622fe6 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -20,8 +20,7 @@ def __init__(self, jwt: str): self.disclosed_user_attributes: dict = {} def _detect_vp_type(self): - # TODO - automatic detection of the credential - return 'jwt' + return self.headers["typ"].lower() def get_credential_jwks(self): if not self.credential_jwks: From 2c2c80ea5500d63006b9c65e47b5df064fce9c3a Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 15:54:10 +0100 Subject: [PATCH 42/88] docs: amplied the documentation --- pyeudiw/jwk/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index 0106d2e4..e7fb4914 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -37,6 +37,8 @@ def __init__( :type hash_func: str :param ec_crv: a string that represents the curve to use with the instance. :type ec_crv: str + + :raises NotImplementedError: the key_type is not implemented """ kwargs = {} self.kid = "" From ebbf8a459d901d98d85ac151b66e7e712f892767 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 15:55:38 +0100 Subject: [PATCH 43/88] fix: refactored code --- pyeudiw/openid4vp/vp.py | 10 +--------- pyeudiw/openid4vp/vp_sd_jwt.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 91622fe6..da17e5cd 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -6,15 +6,7 @@ class Vp(VpSdJwt): def __init__(self, jwt: str): - if not is_jwt_format(jwt): - raise InvalidVPToken(f"VP is not in JWT format.") - - self.headers = decode_jwt_header(jwt) - self.jwt = jwt - self.payload = decode_jwt_payload(jwt) - - self.credential_headers: dict = {} - self.credential_payload: dict = {} + super().__init__(jwt) self.parse_digital_credential() self.disclosed_user_attributes: dict = {} diff --git a/pyeudiw/openid4vp/vp_sd_jwt.py b/pyeudiw/openid4vp/vp_sd_jwt.py index 7ba57a06..16c41c9a 100644 --- a/pyeudiw/openid4vp/vp_sd_jwt.py +++ b/pyeudiw/openid4vp/vp_sd_jwt.py @@ -1,16 +1,28 @@ +from typing import Dict from pyeudiw.jwk import JWK from pyeudiw.jwt import JWSHelper +from pyeudiw.jwt.utils import is_jwt_format, decode_jwt_header, decode_jwt_payload from pyeudiw.sd_jwt import verify_sd_jwt from pyeudiw.jwk.exceptions import KidNotFoundError class VpSdJwt: + def __init__(self, jwt: str): + if not is_jwt_format(jwt): + raise InvalidVPToken(f"VP is not in JWT format.") + + self.headers = decode_jwt_header(jwt) + self.jwt = jwt + self.payload = decode_jwt_payload(jwt) + + self.credential_headers: dict = {} + self.credential_payload: dict = {} def verify_sdjwt( self, - issuer_jwks_by_kid: dict = {} - ) -> dict: + issuer_jwks_by_kid: Dict[str, dict] = {} + ) -> bool: if not issuer_jwks_by_kid.get(self.credential_headers["kid"], None): raise KidNotFoundError( From 75437528ea20bf2e918f6e0f42d0284bcd80050e Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 15:56:05 +0100 Subject: [PATCH 44/88] fix: added dependency --- pyeudiw/openid4vp/vp_sd_jwt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyeudiw/openid4vp/vp_sd_jwt.py b/pyeudiw/openid4vp/vp_sd_jwt.py index 16c41c9a..35192a93 100644 --- a/pyeudiw/openid4vp/vp_sd_jwt.py +++ b/pyeudiw/openid4vp/vp_sd_jwt.py @@ -6,6 +6,7 @@ from pyeudiw.jwk.exceptions import KidNotFoundError +from .exceptions import InvalidVPToken class VpSdJwt: def __init__(self, jwt: str): From b8a9b27d1ab3445b97231c2e0317926ea425d8b9 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 15:57:05 +0100 Subject: [PATCH 45/88] docs: documented content of vp_sd_jwt.py --- pyeudiw/openid4vp/vp_sd_jwt.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pyeudiw/openid4vp/vp_sd_jwt.py b/pyeudiw/openid4vp/vp_sd_jwt.py index 35192a93..a1115820 100644 --- a/pyeudiw/openid4vp/vp_sd_jwt.py +++ b/pyeudiw/openid4vp/vp_sd_jwt.py @@ -9,7 +9,18 @@ from .exceptions import InvalidVPToken class VpSdJwt: + """Class for SD-JWT Format""" + def __init__(self, jwt: str): + """ + Generates a VpSdJwt istance + + :param jwt: a string that represents the jwt. + :type jwt: str + + :raises InvalidVPToken: if the jwt field's value is not a JWT. + """ + if not is_jwt_format(jwt): raise InvalidVPToken(f"VP is not in JWT format.") @@ -24,7 +35,18 @@ def verify_sdjwt( self, issuer_jwks_by_kid: Dict[str, dict] = {} ) -> bool: + """ + Verifies a SDJWT. + + :param jwks_by_kids: a dictionary that contains one or more JWKs with the KID as the key. + :type jwks_by_kids: Dict[str, dict] + + :raises KidNotFoundError: if the needed kid is not inside the issuer_jwks_by_kid. + :raises NotImplementedError: the key_type of one or more JWK is not implemented. + :raises JWSVerificationError: if self.jwt field is not in a JWS Format. + :returns: True if is valid, False otherwise. + """ if not issuer_jwks_by_kid.get(self.credential_headers["kid"], None): raise KidNotFoundError( f"issuer jwks {issuer_jwks_by_kid} doesn't contain " From 91c5952e7557b663e9714928909d8514c571e83f Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 16:01:37 +0100 Subject: [PATCH 46/88] fix: refactored code --- pyeudiw/openid4vp/vp.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index da17e5cd..925df3ea 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -1,36 +1,36 @@ -from .exceptions import InvalidVPToken -from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header, is_jwt_format +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt class Vp(VpSdJwt): - def __init__(self, jwt: str): + def __init__(self, jwt: str) -> None: super().__init__(jwt) self.parse_digital_credential() self.disclosed_user_attributes: dict = {} - def _detect_vp_type(self): + def _detect_vp_type(self) -> str: return self.headers["typ"].lower() - def get_credential_jwks(self): + def get_credential_jwks(self) -> list[dict]: if not self.credential_jwks: return {} return self.credential_jwks - @property - def credential_issuer(self): - if not self.credential_payload.get('iss', None): - self.parse_digital_credential() - return self.credential_payload.get('iss', None) - - def parse_digital_credential(self): + def parse_digital_credential(self) -> None: _typ = self._detect_vp_type() - if _typ == 'jwt': - self.credential_headers = decode_jwt_header(self.payload['vp']) - self.credential_payload = decode_jwt_payload(self.payload['vp']) - else: + + if _typ != 'jwt': raise NotImplementedError( f"VP Digital credentials type not implemented yet: {_typ}" ) + + self.credential_headers = decode_jwt_header(self.payload['vp']) + self.credential_payload = decode_jwt_payload(self.payload['vp']) + + @property + def credential_issuer(self) -> str: + if not self.credential_payload.get('iss', None): + self.parse_digital_credential() + return self.credential_payload.get('iss', None) \ No newline at end of file From b89312a108fe4325acfe7ff745c110a417e23144 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 13 Dec 2023 16:08:28 +0100 Subject: [PATCH 47/88] docs: documented content of vp.py --- pyeudiw/openid4vp/vp.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 925df3ea..74aaf261 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -3,22 +3,47 @@ class Vp(VpSdJwt): - + "Class for SD-JWT Format" def __init__(self, jwt: str) -> None: + """ + Generates a VP istance. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :raises InvalidVPToken: if the jwt field's value is not a JWT. + """ super().__init__(jwt) self.parse_digital_credential() self.disclosed_user_attributes: dict = {} def _detect_vp_type(self) -> str: + """ + Detects and return the type of verifiable presentation. + + :returns: the type of VP. + :rtype: str + """ return self.headers["typ"].lower() def get_credential_jwks(self) -> list[dict]: + """ + Returns the credential JWKs. + + :returns: the list containing credential's JWKs. + :rtype: list[dict] + """ if not self.credential_jwks: return {} return self.credential_jwks def parse_digital_credential(self) -> None: + """ + Parse the digital credential of VP. + + :raises NotImplementedError: if VP Digital credentials type not implemented. + """ _typ = self._detect_vp_type() if _typ != 'jwt': @@ -31,6 +56,7 @@ def parse_digital_credential(self) -> None: @property def credential_issuer(self) -> str: + """Returns the credential issuer""" if not self.credential_payload.get('iss', None): self.parse_digital_credential() return self.credential_payload.get('iss', None) \ No newline at end of file From 86ebd730686441f7ff1ee5fabc42b3ce89ae9534 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 14 Dec 2023 18:03:22 +0100 Subject: [PATCH 48/88] fix: refactoring for better redability --- pyeudiw/satosa/backend.py | 694 +++++++++++++---------------------- pyeudiw/satosa/exceptions.py | 6 + pyeudiw/satosa/http_error.py | 41 +++ 3 files changed, 306 insertions(+), 435 deletions(-) create mode 100644 pyeudiw/satosa/http_error.py diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index c8131646..5a33176f 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -34,9 +34,14 @@ from pyeudiw.storage.db_engine import DBEngine from pyeudiw.storage.exceptions import StorageWriteError from pyeudiw.federation.schemas.wallet_relying_party import WalletRelyingParty +from pyeudiw.openid4vp.vp import Vp +from typing import Callable from pydantic import ValidationError +from .http_error import HTTPErrorHandler +from .exceptions import HTTPError, EmptyHTTPError + logger = logging.getLogger(__name__) @@ -45,22 +50,28 @@ class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP): A backend module (acting as a OpenID4VP SP). """ - def __init__(self, auth_callback_func, internal_attributes, config, base_url, name): + def __init__( + self, + auth_callback_func: Callable[[Context, InternalData], Response], + internal_attributes: dict[str, dict[str, str | list[str]]], + config: dict[str, dict[str, str] | list[str]], + base_url: str, + name: str + ) -> None: """ OpenID4VP backend module. :param auth_callback_func: Callback should be called by the module after the authorization in the backend is done. + :type auth_callback_func: Callable[[Context, InternalData], Response] :param internal_attributes: Mapping dictionary between SATOSA internal attribute names and the names returned by underlying IdP's/OP's as well as what attributes the calling SP's and RP's expects namevice. + :type internal_attributes: dict[str, dict[str, str | list[str]]] :param config: Configuration parameters for the module. - :param base_url: base url of the service - :param name: name of the plugin - :type auth_callback_func: - (satosa.context.Context, satosa.internal.InternalData) -> satosa.response.Response - :type internal_attributes: dict[string, dict[str, str | list[str]]] :type config: dict[str, dict[str, str] | list[str]] + :param base_url: base url of the service :type base_url: str + :param name: name of the plugin :type name: str """ super().__init__(auth_callback_func, internal_attributes, base_url, name) @@ -69,10 +80,8 @@ def __init__(self, auth_callback_func, internal_attributes, config, base_url, na WalletRelyingParty(**config['metadata']) except ValidationError as e: logger.warning( - """ - The backend configuration presents the following validation issues: - {} - """.format(logger.warning(e))) + f"""The backend configuration presents the following validation issues: + {logger.warning(e)}""") self.config = config self.client_id = self.config['metadata']['client_id'] @@ -99,47 +108,8 @@ def __init__(self, auth_callback_func, internal_attributes, config, base_url, na self._render_metadata_conf_elements() self.init_trust_resources() - logger.debug( - lu.LOG_FMT.format( - id="OpenID4VP init", - message=f"Loaded configuration: {json.dumps(config)}" - ) - ) - - @property - def db_engine(self) -> DBEngine: - - try: - self._db_engine.is_connected - except Exception as e: - if getattr(self, '_db_engine', None): - logger.debug( - lu.LOG_FMT.format( - id="OpenID4VP db storage handling", - message=f"connection check silently fails and get restored: {e}" - ) - ) - self._db_engine = DBEngine(self.config["storage"]) - - return self._db_engine - - def _render_metadata_conf_elements(self) -> None: - for k, v in self.config['metadata'].items(): - if isinstance(v, (int, float, dict, list)): - continue - if not v or len(v) == 0: - continue - if all(( - v[0] == '<', - v[-1] == '>', - '.' in v - )): - conf_section, conf_k = v[1:-1].split('.') - self.config['metadata'][k] = self.config[conf_section][conf_k] - - @property - def default_metadata_private_jwk(self) -> tuple: - return tuple(self.metadata_jwks_by_kids.values())[0] + self.http_error_handler = HTTPErrorHandler("templates", "error.html", self._log) + self._log_debug("OpenID4VP init", f"Loaded configuration: {json.dumps(config)}") def register_endpoints(self) -> list: """ @@ -171,7 +141,7 @@ def register_endpoints(self) -> list: self.absolute_status_url = _endpoint return url_map - def start_auth(self, context, internal_request): + def start_auth(self, context: Context, internal_request): """ This is the start up function of the backend authorization. @@ -185,25 +155,9 @@ def start_auth(self, context, internal_request): """ return self.pre_request_endpoint(context, internal_request) - def _log(self, context: Context, level: str, message: str) -> None: - log_level = getattr(logger, level) - log_level( - lu.LOG_FMT.format( - id=lu.get_session_id(context.state), - message=message - ) - ) - - def pre_request_endpoint(self, context, internal_request, **kwargs): + def pre_request_endpoint(self, context: Context, internal_request, **kwargs): - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] pre_request_endpoint with Context: " - f"{context.__dict__} and internal_request: {internal_request}" - ) - ) + self._log_function_debug("pre_request_endpoint", context, "internal_request", internal_request) session_id = context.state["SESSION_ID"] state = str(uuid.uuid4()) @@ -218,28 +172,14 @@ def pre_request_endpoint(self, context, internal_request, **kwargs): session_id=session_id ) except (StorageWriteError) as e: - _msg = ( - f"Error while initializing session with state {state} and {session_id}." - ) - logger.error(f"{_msg} for the following reason {e}") - return self.handle_error( - context, - message="server_error", - troubleshoot=f"{_msg}", - err=f"{_msg}. {e.__class__.__name__}: {e}", - err_code="500" - ) + _msg = f"Error while initializing session with state {state} and {session_id}." + self._log_error(context, f"{_msg} for the following reason {e}") + return self._handle_500(context, _msg, e) + except (Exception) as e: - _msg = ( - f"Error while initializing session with state {state} and {session_id}. " - ) - return self.handle_error( - context, - message="server_error", - troubleshoot=f"{_msg}", - err=f"{_msg}. {e.__class__.__name__}: {e}", - err_code="500" - ) + _msg = f"Error while initializing session with state {state} and {session_id}." + self._log_error(context, _msg) + return self._handle_500(context, _msg, e) # PAR payload = { @@ -269,220 +209,76 @@ def pre_request_endpoint(self, context, internal_request, **kwargs): ) return Response(result, content="text/html; charset=utf8", status="200") - def _translate_response(self, response: dict, issuer: str, context: Context): - """ - Translates wallet response to SATOSA internal response. - :type response: dict[str, str] - :type issuer: str - :type subject_type: str - :rtype: InternalData - :param response: Dictioary with attribute name as key. - :param issuer: The oidc op that gave the repsonse. - :param subject_type: public or pairwise according to oidc standard. - :return: A SATOSA internal response. - """ - # it may depends by credential type and attested security context evaluated - # if WIA was previously submitted by the Wallet + def redirect_endpoint(self, context: Context, *args): - timestamp_epoch = ( - response.get("auth_time") - or response.get("iat") - or iat_now() - ) - timestamp_dt = datetime.datetime.fromtimestamp( - timestamp_epoch, - datetime.timezone.utc - ) - timestamp_iso = timestamp_dt.isoformat().replace("+00:00", "Z") - - auth_class_ref = ( - response.get("acr") or - response.get("amr") or - self.config["authorization"]["default_acr_value"] - ) - auth_info = AuthenticationInformation( - auth_class_ref, timestamp_iso, issuer) - - # TODO - ACR values - internal_resp = InternalData(auth_info=auth_info) - - sub = "" - pepper = self.config.get("user_attributes", {})['subject_id_random_value'] - for i in self.config.get("user_attributes", {}).get("unique_identifiers", []): - if response.get(i): - _sub = response[i] - sub = hashlib.sha256( - f"{_sub}~{pepper}".encode( - ) - ).hexdigest() - break - - if not sub: - self._log( - context, - level='warning', - message=( - "[USER ATTRIBUTES] Missing subject id from OpenID4VP presentation " - "setting a random one for interop for internal frontends" - ) - ) - sub = hashlib.sha256( - f"{json.dumps(response).encode()}~{pepper}".encode() - ).hexdigest() - - response["sub"] = [sub] - internal_resp.attributes = self.converter.to_internal( - "openid4vp", response - ) - internal_resp.subject_id = sub - return internal_resp - - @property - def server_url(self): - return ( - self.base_url[:-1] - if self.base_url[-1] == '/' - else self.base_url - ) + self._log_function_debug("redirect_endpoint", context, "args", args) - def redirect_endpoint(self, context, *args): - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] redirect_endpoint with Context: " - f"{context.__dict__} and args: {args}" - ) - ) if context.request_method.lower() != 'post': # raise BadRequestError("HTTP Method not supported") - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="HTTP Method not supported", - err_code="400" - ) + return self._handle_400(context, "HTTP Method not supported") _endpoint = f'{self.server_url}{context.request_uri}' if self.config["metadata"].get('redirect_uris', None): if _endpoint not in self.config["metadata"]['redirect_uris']: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="request_uri not valid", - err_code="400" - ) + return self._handle_400(context, "request_uri not valid") # take the encrypted jwt, decrypt with my public key (one of the metadata) -> if not -> exception jwt = context.request.get("response", None) if not jwt: - _msg = f"Response error, missing JWT" - self._log(context, level='error', message=_msg) - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err_code="400" - ) + self._log_error(context, f"Response error, missing JWT") + return self._handle_400(context, _msg) try: vpt = DirectPostResponse(jwt, self.metadata_jwks_by_kids) - self._log( - context, - level='debug', - message=( - f"Redirect uri endpoint Response using direct post contains: {vpt.payload}" - ) - ) + + debug_message = f"Redirect uri endpoint Response using direct post contains: {vpt.payload}" + self._log_debug(context, debug_message) + ResponseSchema(**vpt.payload) except Exception as e: _msg = f"DirectPostResponse parse and validation error: {e}" - self._log(context, level='error', message=_msg) - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err_code="400", - err=f"Error:{e}, with JWT: {jwt}" - ) + self._log_error(context, _msg) + return self._handle_400(context, _msg, HTTPError(f"Error:{e}, with JWT: {jwt}")) # state MUST be present in the response since it was in the request # then do lookup on the db -> if not -> exception state = vpt.payload.get("state", None) if not state: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="state not found in the response", - err_code="400", - err=f"{_msg} with: {vpt.payload}" - ) + return self._handle_400(context, _msg, HTTPError(f"{_msg} with: {vpt.payload}")) try: stored_session = self.db_engine.get_by_state(state=state) except Exception as e: _msg = f"Session lookup by state value failed" - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) if stored_session["finalized"]: _msg = f"Session already finalized" - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=_msg, - err_code="400" - ) + return self._handle_400(context, _msg, HTTPError(_msg)) try: vpt.load_nonce(stored_session['nonce']) - vps: list = vpt.get_presentation_vps() + vps: list[Vp] = vpt.get_presentation_vps() vpt.validate() + except VPNotFound as e: _msg = "Error while retrieving VP. Payload 'vp_token' is empty or has an unexpected value." - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) + except NoNonceInVPToken as e: _msg = "Error while validating VP: vp has no nonce." - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) + except VPInvalidNonce as e: _msg = "Error while validating VP: unexpected value." - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) + except Exception as e: _msg = ( "DirectPostResponse content parse and validation error. " - f"Single VPs are faulty." - ) - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" + "Single VPs are faulty." ) + return self._handle_400(context, _msg, e) # evaluate the trust to each credential issuer found in the vps # look for trust chain or x509 or do discovery! @@ -500,51 +296,40 @@ def redirect_endpoint(self, context, *args): tchelper = self._validate_trust(context, vp.payload['vp']) if not tchelper.is_trusted: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=f"Trust Evaluation failed for {tchelper.entity_id}", - err_code="400" - ) + return self._handle_400(context, f"Trust Evaluation failed for {tchelper.entity_id}") # TODO: generalyze also for x509 - vp.credential_jwks = tchelper.get_trusted_jwks( + credential_jwks = tchelper.get_trusted_jwks( metadata_type='openid_credential_issuer' ) + vp.set_credential_jwks(credential_jwks) except InvalidVPToken: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Cannot validate VP: {vp.jwt}", err_code="400") + return self._handle_400(context, f"Cannot validate VP: {vp.jwt}") except ValidationError as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Error validating schemas: {e}", err_code="400") + return self._handle_400(context, f"Error validating schemas: {e}") except KIDNotFound as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Kid error: {e}", err_code="400") + return self._handle_400(context, f"Kid error: {e}") except NotTrustedFederationError as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Not trusted federation error: {e}", err_code="400") + return self._handle_400(context, f"Not trusted federation error: {e}") except Exception as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"VP parsing error: {e}", err_code="400") + return self._handle_400(context, f"VP parsing error: {e}") # the trust is established to the credential issuer, then we can get the disclosed user attributes # TODO - what if the credential is different from sd-jwt? -> generalyze within Vp class try: vp.verify_sdjwt( - issuer_jwks_by_kid={ + issuer_jwks_by_kid = { i['kid']: i for i in vp.credential_jwks} ) except Exception as e: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=f"VP SD-JWT validation error: {e}", - err_code="400" - ) + return self._handle_400(context, f"VP SD-JWT validation error: {e}") # vp.result attributes_by_issuers[vp.credential_issuer] = vp.disclosed_user_attributes - self._log( - context, - level='debug', - message=f"Disclosed user attributes from {vp.credential_issuer}: {vp.disclosed_user_attributes}" - ) + + debug_message = f"Disclosed user attributes from {vp.credential_issuer}: {vp.disclosed_user_attributes}" + self._log_debug(context, debug_message) # TODO: check the revocation of the credential # ... @@ -555,10 +340,7 @@ def redirect_endpoint(self, context, *args): for i in attributes_by_issuers.values(): all_user_attributes.update(**i) - self._log( - context, level='debug', - message=f"Wallet disclosure: {all_user_attributes}" - ) + self._log_debug(context, f"Wallet disclosure: {all_user_attributes}") # TODO: not sure that we want these issuers in the following form ... please recheck. _info = {"issuer": ';'.join(cred_issuers)} @@ -574,25 +356,12 @@ def redirect_endpoint(self, context, *args): self.db_engine.set_finalized(stored_session['document_id']) if logger.getEffectiveLevel() == logging.DEBUG: stored_session = self.db_engine.get_by_state(state=state) - self._log( - context, - level="debug", - message=f"Session update on storage: {stored_session}" - ) + self._log_debug(context, f"Session update on storage: {stored_session}") + except StorageWriteError as e: # TODO - do we have to block in the case the update cannot be done? - self._log( - context, - level="error", - message=f"Session update on storage failed: {e}" - ) - return self.handle_error( - context=context, - message="server_error", - troubleshoot=f"Cannot update response object.", - err=f"{e.__class__.__name__}: {e}", - err_code="500" - ) + self._log_error(context, f"Session update on storage failed: {e}") + return self._handle_500(context, f"Cannot update response object.", e) if stored_session['session_id'] == context.state["SESSION_ID"]: # Same device flow @@ -608,16 +377,10 @@ def redirect_endpoint(self, context, *args): status="200" ) - def request_endpoint(self, context, *args): + def request_endpoint(self, context: Context, *args): + + self._log_function_debug("request_endpoint", context, "args", args) - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] request_endpoint with Context: " - f"{context.__dict__} and args: {args}" - ) - ) # check DPOP for WIA if any try: dpop_validation_error: JsonResponse = self._request_endpoint_dpop( @@ -626,16 +389,8 @@ def request_endpoint(self, context, *args): if dpop_validation_error: return dpop_validation_error except Exception as e: - _msg = ( - f"[DPoP VALIDATION ERROR] WIA evalution error: {e}." - ) - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err=f"{e} with {context.__dict__}", - err_code="401" - ) + _msg = f"[DPoP VALIDATION ERROR] WIA evalution error: {e}." + return self._handle_401(context, _msg, e) try: state = context.qs_params["id"] @@ -644,13 +399,7 @@ def request_endpoint(self, context, *args): "Error while retrieving id from qs_params: " f"{e.__class__.__name__}: {e}" ) - return self.handle_error( - context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e} with {context.__dict__}", - err_code="400" - ) + return self._handle_400(context, _msg, HTTPError(f"{e} with {context.__dict__}")) data = { "scope": ' '.join(self.config['authorization']['scopes']), @@ -672,12 +421,7 @@ def request_endpoint(self, context, *args): attestation = context.http_headers['HTTP_AUTHORIZATION'] except KeyError as e: _msg = f"Error while accessing http headers: {e}" - return self.handle_error( - context, - message="invalid_request", - err=f"{e} with {context.__dict__}", - err_code="400" - ) + return self._handle_400(context, _msg, HTTPError(f"{e} with {context.__dict__}")) # take the session created in the pre-request authz endpoint try: @@ -687,25 +431,14 @@ def request_endpoint(self, context, *args): document_id, dpop_proof, attestation ) self.db_engine.update_request_object(document_id, data) + except ValueError as e: - _msg = ( - "Error while retrieving request object from database." - ) - return self.handle_error( - context, - message="server_error", - troubleshoot=_msg, - err=f"{e} with {context.__dict__}", - err_code="500" - ) + _msg = "Error while retrieving request object from database." + return self._handle_500(context, _msg, HTTPError(f"{e} with {context.__dict__}")) + except (Exception, BaseException) as e: _msg = f"Error while updating request object: {e}" - return self.handle_error( - context, - message="server_error", - err=_msg, - err_code="500" - ) + return self._handle_500(context, _msg, e) helper = JWSHelper(self.default_metadata_private_jwk) @@ -719,44 +452,9 @@ def request_endpoint(self, context, *args): status="200" ) - def handle_error( - self, - context: dict, - message: str, - troubleshoot: str = "", - err="", - err_code="500", - template_path="templates", - error_template="error.html", - level="error" - ): - - _msg = f"{message}:" - if err: - _msg += f" {err}." - self._log( - context, level=level, - message=f"{_msg} {troubleshoot}" - ) + def get_response_endpoint(self, context: Context): - return JsonResponse( - { - "error": message, - "error_description": troubleshoot - }, - status=err_code - ) - - def get_response_endpoint(self, context): - - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] get_response_endpoint with Context: " - f"{context.__dict__}" - ) - ) + self._log_function_debug("get_response_endpoint", context) state = context.qs_params.get("id", None) session_id = context.state["SESSION_ID"] @@ -775,30 +473,15 @@ def get_response_endpoint(self, context): ) except Exception as e: _msg = f"Error while retrieving session by state {state} and session_id {session_id}: {e}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401" - ) + return self._handle_401(context, _msg, e) if not finalized_session: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="session not found or invalid", - err_code="400" - ) + return self._handle_400(context, "session not found or invalid") _now = iat_now() _exp = finalized_session['request_object']['exp'] if _exp < _now: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=f"session expired, request object exp is {_exp} while now is {_now}", - err_code="400" - ) + return self._handle_400(context, f"session expired, request object exp is {_exp} while now is {_now}") internal_response = InternalData() resp = internal_response.from_dict( @@ -810,16 +493,9 @@ def get_response_endpoint(self, context): resp ) - def status_endpoint(self, context): + def status_endpoint(self, context: Context): - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] state_endpoint with Context: " - f"{context.__dict__}" - ) - ) + self._log_function_debug("status_endpoint", context) session_id = context.state["SESSION_ID"] _err_msg = "" @@ -833,12 +509,7 @@ def status_endpoint(self, context): _err_msg = f"No id found in qs_params: {e}" if _err_msg: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_err_msg, - err_code="400" - ) + return self._handle_400(context, _err_msg) try: session = self.db_engine.get_by_state_and_session_id( @@ -846,22 +517,12 @@ def status_endpoint(self, context): ) except Exception as e: _msg = f"Error while retrieving session by state {state} and session_id {session_id}: {e}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401" - ) + return self._handle_401(context, _msg) request_object = session.get("request_object", None) if request_object: if iat_now() > request_object["exp"]: - return self.handle_error( - context=context, - message="expired", - troubleshoot=f"Request object expired", - err_code="403" - ) + return self._handle_403("expired", f"Request object expired") if session["finalized"]: # return Redirect( @@ -880,3 +541,166 @@ def status_endpoint(self, context): }, status="201" ) + + def _render_metadata_conf_elements(self) -> None: + """Renders the elements of config's metadata""" + for k, v in self.config['metadata'].items(): + if isinstance(v, (int, float, dict, list)): + continue + if not v or len(v) == 0: + continue + if all(( + v[0] == '<', + v[-1] == '>', + '.' in v + )): + conf_section, conf_k = v[1:-1].split('.') + self.config['metadata'][k] = self.config[conf_section][conf_k] + + def _log(self, context: str | Context, level: str, message: str) -> None: + context = context if isinstance(context, str) else context.state + + log_level = getattr(logger, level) + log_level( + lu.LOG_FMT.format( + id=lu.get_session_id(context), + message=message + ) + ) + + def _log_debug(self, context: str | Context, message: str) -> None: + self._log(context, "debug", message) + + def _log_function_debug(self, fn_name: str, context: Context, args_name: str | None = None, args = None) -> None: + args_str = f" and {args_name}: {args}" if not args_name else "" + + debug_message = ( + f"[INCOMING REQUEST] {fn_name} with Context: " + f"{context.__dict__}{args_str}" + ) + self._log_debug(context, debug_message) + + def _log_error(self, context: str | Context, message: str) -> None: + self._log(context, "error", message) + + def _translate_response(self, response: dict, issuer: str, context: Context): + """ + Translates wallet response to SATOSA internal response. + :type response: dict[str, str] + :type issuer: str + :type subject_type: str + :rtype: InternalData + :param response: Dictioary with attribute name as key. + :param issuer: The oidc op that gave the repsonse. + :param subject_type: public or pairwise according to oidc standard. + :return: A SATOSA internal response. + """ + # it may depends by credential type and attested security context evaluated + # if WIA was previously submitted by the Wallet + + timestamp_epoch = ( + response.get("auth_time") + or response.get("iat") + or iat_now() + ) + timestamp_dt = datetime.datetime.fromtimestamp( + timestamp_epoch, + datetime.timezone.utc + ) + timestamp_iso = timestamp_dt.isoformat().replace("+00:00", "Z") + + auth_class_ref = ( + response.get("acr") or + response.get("amr") or + self.config["authorization"]["default_acr_value"] + ) + auth_info = AuthenticationInformation( + auth_class_ref, timestamp_iso, issuer) + + # TODO - ACR values + internal_resp = InternalData(auth_info=auth_info) + + sub = "" + pepper = self.config.get("user_attributes", {})['subject_id_random_value'] + for i in self.config.get("user_attributes", {}).get("unique_identifiers", []): + if response.get(i): + _sub = response[i] + sub = hashlib.sha256( + f"{_sub}~{pepper}".encode( + ) + ).hexdigest() + break + + if not sub: + self._log( + context, + level='warning', + message=( + "[USER ATTRIBUTES] Missing subject id from OpenID4VP presentation " + "setting a random one for interop for internal frontends" + ) + ) + sub = hashlib.sha256( + f"{json.dumps(response).encode()}~{pepper}".encode() + ).hexdigest() + + response["sub"] = [sub] + internal_resp.attributes = self.converter.to_internal( + "openid4vp", response + ) + internal_resp.subject_id = sub + return internal_resp + + def _handle_500(self, context, msg: str, err: Exception): + return self.http_error_handler.handle500( + context=context, + troubleshoot=f"{msg}", + err=f"{msg}. {err.__class__.__name__}: {err}", + ) + + def _handle_40X(self, code_number: str, message: str, context, troubleshoot: str, err: Exception): + return self.http_error_handler.handle40X( + code_number, + message, + context, + troubleshoot=f"{troubleshoot}", + err=f"{err.__class__.__name__}: {err}", + ) + + def _handle_400(self, context, troubleshoot: str, err: Exception = EmptyHTTPError("")): + return self._handle_40X("0", "invalid_request", context, troubleshoot, err) + + def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + return self._handle_40X("1", "invalid_client", context, troubleshoot, err) + + def _handle_403(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + return self._handle_40X("3", "expired", context, troubleshoot, err) + + @property + def db_engine(self) -> DBEngine: + """Returns the DBEngine instance used by the class""" + try: + self._db_engine.is_connected + except Exception as e: + if getattr(self, '_db_engine', None): + logger.debug( + lu.LOG_FMT.format( + id="OpenID4VP db storage handling", + message=f"connection check silently fails and get restored: {e}" + ) + ) + self._db_engine = DBEngine(self.config["storage"]) + + return self._db_engine + + @property + def default_metadata_private_jwk(self) -> tuple: + return tuple(self.metadata_jwks_by_kids.values())[0] + + @property + def server_url(self): + return ( + self.base_url[:-1] + if self.base_url[-1] == '/' + else self.base_url + ) \ No newline at end of file diff --git a/pyeudiw/satosa/exceptions.py b/pyeudiw/satosa/exceptions.py index 805e1a42..ea012311 100644 --- a/pyeudiw/satosa/exceptions.py +++ b/pyeudiw/satosa/exceptions.py @@ -21,3 +21,9 @@ class DiscoveryFailedError(Exception): Raised when the discovery fails """ pass + +class HTTPError(Exception): + pass + +class EmptyHTTPError(HTTPError): + pass \ No newline at end of file diff --git a/pyeudiw/satosa/http_error.py b/pyeudiw/satosa/http_error.py new file mode 100644 index 00000000..3ef9a84e --- /dev/null +++ b/pyeudiw/satosa/http_error.py @@ -0,0 +1,41 @@ +import logging +from pyeudiw.satosa.response import JsonResponse + +logger = logging.getLogger(__name__) + +class HTTPErrorHandler: + def __init__(self, template_path: str, error_template: str, log): + self.template_path = template_path + self.error_template = error_template + self._log = log + + def _serialize_error( + self, + context, + message: str, + troubleshoot: str, + err: str, + err_code: str, + level: str + ): + _msg = f"{message}:" + if err: + _msg += f" {err}." + self._log( + context, level=level, + message=f"{_msg} {troubleshoot}" + ) + + return JsonResponse( + { + "error": message, + "error_description": troubleshoot + }, + status=err_code + ) + + def handle500(self, context, troubleshoot: str = "", err: str = ""): + return self._serialize_error(context, "server_error", troubleshoot, err, "500", "error") + + def handle40X(self, code_number: str, message: str, context, troubleshoot: str = "", err: str = ""): + return self._serialize_error(context, message, troubleshoot, err, f"40{code_number}", "error") \ No newline at end of file From 289bb3ffdbef3162822ee0b15af3ed8dbdd4c8d2 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 14 Dec 2023 18:16:16 +0100 Subject: [PATCH 49/88] fix: redability fix --- pyeudiw/openid4vp/direct_post_response.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index c5406ed5..a51ad414 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -34,7 +34,7 @@ def __init__(self, jwt: str, jwks_by_kids: Dict[str, dict], nonce: str = ""): self._payload: dict = {} self._vps: list = [] - self.credentials_by_issuer: dict = {} + self.credentials_by_issuer: Dict[str, list[dict]] = {} self._claims_by_issuer: dict = {} def _decode_payload(self) -> None: @@ -109,9 +109,11 @@ def validate(self) -> bool: return True - def get_presentation_vps(self) -> list[dict]: + def get_presentation_vps(self) -> list[Vp]: """ - Returns the presentation's verifiable presentations + Returns the presentation's verifiable presentations. + + :raises VPNotFound: if no VPs are found. :returns: the list of vps. :rtype: list[dict] @@ -123,7 +125,7 @@ def get_presentation_vps(self) -> list[dict]: vps = [_vps] if isinstance(_vps, str) else _vps if not vps: - raise VPNotFound(f"Vps for response with nonce \"{self.nonce}\" are empty") + raise VPNotFound(f"Vps are empty for response with nonce \"{self.nonce}\"") for vp in vps: _vp = Vp(vp) From b5b65ef09408bde62830153e4a463096f9247f89 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 14 Dec 2023 18:20:30 +0100 Subject: [PATCH 50/88] feat: added methods for handling credential's JWKs --- pyeudiw/openid4vp/vp.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 74aaf261..213f9ec0 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -17,6 +17,7 @@ def __init__(self, jwt: str) -> None: self.parse_digital_credential() self.disclosed_user_attributes: dict = {} + self._credential_jwks: list[dict] = [] def _detect_vp_type(self) -> str: """ @@ -54,6 +55,20 @@ def parse_digital_credential(self) -> None: self.credential_headers = decode_jwt_header(self.payload['vp']) self.credential_payload = decode_jwt_payload(self.payload['vp']) + def set_credential_jwks(self, credential_jwks: list[dict]) -> None: + """ + Set the credential JWKs for the current istance. + + :param credential_jwks: a list containing the credential's JWKs. + :type credential_jwks: list[dict] + """ + self._credential_jwks = credential_jwks + + @property + def credential_jwks(self) -> list[dict]: + """Returns the credential JWKs""" + return self._credential_jwks + @property def credential_issuer(self) -> str: """Returns the credential issuer""" From 94f28b9371b1395e800172d6615e9083ab8f2915 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 14 Dec 2023 18:21:32 +0100 Subject: [PATCH 51/88] fix: fixed signatures --- pyeudiw/sd_jwt/__init__.py | 4 +++- pyeudiw/trust/__init__.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyeudiw/sd_jwt/__init__.py b/pyeudiw/sd_jwt/__init__.py index 0fbe0e53..14da89ff 100644 --- a/pyeudiw/sd_jwt/__init__.py +++ b/pyeudiw/sd_jwt/__init__.py @@ -21,6 +21,8 @@ import jwcrypto +from typing import Any + class TrustChainSDJWTIssuer(SDJWTIssuer): def __init__(self, user_claims: Dict, issuer_key, holder_key=None, sign_alg=None, add_decoy_claims: bool = True, serialization_format: str = "compact", additional_headers: dict = {}): @@ -163,7 +165,7 @@ def verify_sd_jwt( issuer_key: JWK, holder_key: JWK, settings: dict = {'key_binding': True} -) -> dict: +) -> (list | dict | Any): settings.update( { diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 82c1bad4..836e229a 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -230,7 +230,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: f" {self.final_metadata['metadata']}" ) - def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list: + def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list[dict]: return self.get_final_metadata( metadata_type=metadata_type, policies=policies From 7592b97a1d3d7ee43cdd53b524b723962ef376ee Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 14 Dec 2023 18:21:48 +0100 Subject: [PATCH 52/88] test: fixed test --- pyeudiw/tests/satosa/test_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index cd507927..bd8eb714 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -548,8 +548,8 @@ def test_request_endpoint(self, context): # assert msg["response"] == "Authentication successful" def test_handle_error(self, context): - error_message = "Error message!" - error_resp = self.backend.handle_error(context, error_message) + error_message = "server_error" + error_resp = self.backend._handle_500(context, error_message, Exception()) assert error_resp.status == "500" assert error_resp.message err = json.loads(error_resp.message) From 2bc8494411a16602ea7f0a3ec98d330db4d29090 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 16:19:04 +0100 Subject: [PATCH 53/88] docs: documented the content of backend.py --- pyeudiw/satosa/backend.py | 184 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 175 insertions(+), 9 deletions(-) diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index 5a33176f..784775df 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -141,7 +141,7 @@ def register_endpoints(self) -> list: self.absolute_status_url = _endpoint return url_map - def start_auth(self, context: Context, internal_request): + def start_auth(self, context: Context, internal_request) -> Response: """ This is the start up function of the backend authorization. @@ -155,7 +155,19 @@ def start_auth(self, context: Context, internal_request): """ return self.pre_request_endpoint(context, internal_request) - def pre_request_endpoint(self, context: Context, internal_request, **kwargs): + def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> Response: + """ + This endpoint is called by the frontend before calling the request endpoint. + It initializes the session and returns the request_uri to be used by the frontend. + + :type context: the context of current request + :param context: the request context + :type internal_request: satosa.internal.InternalData + :param internal_request: Information about the authorization request + + :return: a response containing the request_uri + :rtype: satosa.response.Response + """ self._log_function_debug("pre_request_endpoint", context, "internal_request", internal_request) @@ -209,7 +221,16 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs): ) return Response(result, content="text/html; charset=utf8", status="200") - def redirect_endpoint(self, context: Context, *args): + def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: + """ + This endpoint is called by the frontend after the user has been authenticated. + + :type context: the context of current request + :param context: the request context + + :return: a redirect to the frontend, if is in same device flow, or a json response if is in cross device flow. + :rtype: Redirect | JsonResponse + """ self._log_function_debug("redirect_endpoint", context, "args", args) @@ -377,7 +398,18 @@ def redirect_endpoint(self, context: Context, *args): status="200" ) - def request_endpoint(self, context: Context, *args): + def request_endpoint(self, context: Context, *args) -> JsonResponse: + """ + This endpoint is called by the frontend to retrieve the signed signed Request Object. + + :type context: the context of current request + :param context: the request context + :param args: the request arguments + :type args: tuple + + :return: a json response containing the request object + :rtype: JsonResponse + """ self._log_function_debug("request_endpoint", context, "args", args) @@ -452,7 +484,16 @@ def request_endpoint(self, context: Context, *args): status="200" ) - def get_response_endpoint(self, context: Context): + def get_response_endpoint(self, context: Context) -> Response: + """ + This endpoint is called by the frontend to retrieve the response of the authentication. + + :param context: the request context + :type context: satosa.context.Context + + :return: a response containing the response object with the authenctication status + :rtype: Response + """ self._log_function_debug("get_response_endpoint", context) @@ -493,7 +534,16 @@ def get_response_endpoint(self, context: Context): resp ) - def status_endpoint(self, context: Context): + def status_endpoint(self, context: Context) -> JsonResponse: + """ + This endpoint is called by the frontend the url to the response endpoint to finalize the process. + + :param context: the request context + :type context: satosa.context.Context + + :return: a json response containing the status and the url to get the response + :rtype: JsonResponse + """ self._log_function_debug("status_endpoint", context) @@ -558,6 +608,17 @@ def _render_metadata_conf_elements(self) -> None: self.config['metadata'][k] = self.config[conf_section][conf_k] def _log(self, context: str | Context, level: str, message: str) -> None: + """ + Log a message with the given level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param level: the log level + :type level: str + :param message: the message to log + :type message: str + """ + context = context if isinstance(context, str) else context.state log_level = getattr(logger, level) @@ -569,9 +630,30 @@ def _log(self, context: str | Context, level: str, message: str) -> None: ) def _log_debug(self, context: str | Context, message: str) -> None: + """ + Log a message with the DEBUG level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + self._log(context, "debug", message) def _log_function_debug(self, fn_name: str, context: Context, args_name: str | None = None, args = None) -> None: + """ + Logs a message at the start of a backend function. + + :param fn_name: the name of the function + :type fn_name: str + :param context: the request context + :param args_name: the name of the arguments field + :type args_name: str | None + :param args: the arguments provided to the function + :type args: Any + """ + args_str = f" and {args_name}: {args}" if not args_name else "" debug_message = ( @@ -581,6 +663,15 @@ def _log_function_debug(self, fn_name: str, context: Context, args_name: str | N self._log_debug(context, debug_message) def _log_error(self, context: str | Context, message: str) -> None: + """ + Log a message with the ERROR level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + self._log(context, "error", message) def _translate_response(self, response: dict, issuer: str, context: Context): @@ -651,14 +742,46 @@ def _translate_response(self, response: dict, issuer: str, context: Context): internal_resp.subject_id = sub return internal_resp - def _handle_500(self, context, msg: str, err: Exception): + def _handle_500(self, context: Context, msg: str, err: Exception) -> JsonResponse: + """ + Handles a 500 error. + + :param context: the request context + :type context: satosa.context.Context + :param msg: the error message + :type msg: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + return self.http_error_handler.handle500( context=context, troubleshoot=f"{msg}", err=f"{msg}. {err.__class__.__name__}: {err}", ) - def _handle_40X(self, code_number: str, message: str, context, troubleshoot: str, err: Exception): + def _handle_40X(self, code_number: str, message: str, context, troubleshoot: str, err: Exception) -> JsonResponse: + """ + Handles a 40X error. + + :param code_number: the code number + :type code_number: str + :param message: the error message + :type message: str + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + return self.http_error_handler.handle40X( code_number, message, @@ -667,13 +790,54 @@ def _handle_40X(self, code_number: str, message: str, context, troubleshoot: str err=f"{err.__class__.__name__}: {err}", ) - def _handle_400(self, context, troubleshoot: str, err: Exception = EmptyHTTPError("")): + def _handle_400(self, context: Context, troubleshoot: str, err: Exception = EmptyHTTPError("")) -> JsonResponse: + """ + Handles a 400 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ return self._handle_40X("0", "invalid_request", context, troubleshoot, err) def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + """ + Handles a 401 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + return self._handle_40X("1", "invalid_client", context, troubleshoot, err) def _handle_403(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + """ + Handles a 403 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + return self._handle_40X("3", "expired", context, troubleshoot, err) @property @@ -695,10 +859,12 @@ def db_engine(self) -> DBEngine: @property def default_metadata_private_jwk(self) -> tuple: + """Returns the default metadata private JWK""" return tuple(self.metadata_jwks_by_kids.values())[0] @property def server_url(self): + """Returns the server url""" return ( self.base_url[:-1] if self.base_url[-1] == '/' From fde3b68eff507049781707d7f4fcaeaaf46f65ef Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:38:39 +0100 Subject: [PATCH 54/88] docs: documented code of dpop.py --- pyeudiw/satosa/dpop.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index 366bb45b..3bada939 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -14,11 +14,17 @@ logger = logging.getLogger(__name__) - -class BackendDPoP: - - def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: - """ This validates, if any, the DPoP http request header """ + """ + Validates, if any, the DPoP http request header + + :param context: The current context + :type context: Context + :param args: The current request arguments + :type args: tuple + + :return: + :rtype: Union[JsonResponse, None] + """ if context.http_headers and 'HTTP_AUTHORIZATION' in context.http_headers: # The wallet instance uses the endpoint authentication to give its WIA From 0109d7dfa7c028bdf0aa926f493a5f50675de2b3 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:39:27 +0100 Subject: [PATCH 55/88] feat: created class BaseHTTPErrorHandler --- pyeudiw/satosa/base_http_error_handler.py | 132 ++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 pyeudiw/satosa/base_http_error_handler.py diff --git a/pyeudiw/satosa/base_http_error_handler.py b/pyeudiw/satosa/base_http_error_handler.py new file mode 100644 index 00000000..f26ddbc8 --- /dev/null +++ b/pyeudiw/satosa/base_http_error_handler.py @@ -0,0 +1,132 @@ +from satosa.context import Context +from .base_logger import BaseLogger +from .exceptions import EmptyHTTPError +from pyeudiw.satosa.response import JsonResponse + + +class BaseHTTPErrorHandler(BaseLogger): + def _serialize_error( + self, + context: Context, + message: str, + troubleshoot: str, + err: str, + err_code: str, + level: str + ): + _msg = f"{message}:" + if err: + _msg += f" {err}." + self._log( + context, level=level, + message=f"{_msg} {troubleshoot}" + ) + + return JsonResponse({ + "error": message, + "error_description": troubleshoot + }, + status=err_code + ) + + def _handle_500(self, context: Context, msg: str, err: Exception) -> JsonResponse: + """ + Handles a 500 error. + + :param context: the request context + :type context: satosa.context.Context + :param msg: the error message + :type msg: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._serialize_error( + context, + "server_error", + f"{msg}", + f"{msg}. {err.__class__.__name__}: {err}", + "500", + "error" + ) + + def _handle_40X(self, code_number: str, message: str, context: Context, troubleshoot: str, err: Exception) -> JsonResponse: + """ + Handles a 40X error. + + :param code_number: the code number + :type code_number: str + :param message: the error message + :type message: str + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._serialize_error( + context, + message, + troubleshoot, + f"{err.__class__.__name__}: {err}", + f"40{code_number}", + "error" + ) + + def _handle_400(self, context: Context, troubleshoot: str, err: Exception = EmptyHTTPError("")) -> JsonResponse: + """ + Handles a 400 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + return self._handle_40X("0", "invalid_request", context, troubleshoot, err) + + def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + """ + Handles a 401 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._handle_40X("1", "invalid_client", context, troubleshoot, err) + + def _handle_403(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + """ + Handles a 403 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._handle_40X("3", "expired", context, troubleshoot, err) \ No newline at end of file From edbf46d548b00a91296b61061379290a66748086 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:39:50 +0100 Subject: [PATCH 56/88] feat: created class BaseLogger --- pyeudiw/satosa/base_logger.py | 120 ++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 pyeudiw/satosa/base_logger.py diff --git a/pyeudiw/satosa/base_logger.py b/pyeudiw/satosa/base_logger.py new file mode 100644 index 00000000..6dbda148 --- /dev/null +++ b/pyeudiw/satosa/base_logger.py @@ -0,0 +1,120 @@ +import logging +import satosa.logging_util as lu +from satosa.context import Context + +logger = logging.getLogger(__name__) + +class BaseLogger: + def _log(self, context: str | Context, level: str, message: str) -> None: + """ + Log a message with the given level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param level: the log level + :type level: str + :param message: the message to log + :type message: str + """ + + context = context if isinstance(context, str) else context.state + + log_level = getattr(logger, level) + log_level( + lu.LOG_FMT.format( + id=lu.get_session_id(context), + message=message + ) + ) + + def _log_debug(self, context: str | Context, message: str) -> None: + """ + Log a message with the DEBUG level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "debug", message) + + def _log_function_debug(self, fn_name: str, context: Context, args_name: str | None = None, args = None) -> None: + """ + Logs a message at the start of a backend function. + + :param fn_name: the name of the function + :type fn_name: str + :param context: the request context + :param args_name: the name of the arguments field + :type args_name: str | None + :param args: the arguments provided to the function + :type args: Any + """ + + args_str = f" and {args_name}: {args}" if not args_name else "" + + debug_message = ( + f"[INCOMING REQUEST] {fn_name} with Context: " + f"{context.__dict__}{args_str}" + ) + self._log_debug(context, debug_message) + + def _log_error(self, context: str | Context, message: str) -> None: + """ + Log a message with the ERROR level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "error", message) + + def _log_warning(self, context: str | Context, message: str) -> None: + """ + Log a message with the WARNING level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "warning", message) + + def _log_info(self, context: str | Context, message: str) -> None: + """ + Log a message with the INFO level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "info", message) + + def _log_critical(self, context: str | Context, message: str) -> None: + """ + Log a message with the CRITICAL level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "critical", message) + + @property + def effective_log_level(self) -> int: + """ + Returns the effective log level. + + :return: the effective log level + :rtype: int + """ + + return logger.getEffectiveLevel() \ No newline at end of file From b64e3a18e282614f090805667c5c10255fb125c3 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:40:31 +0100 Subject: [PATCH 57/88] chore: removed unused implementation --- pyeudiw/satosa/http_error.py | 41 ------------------------------------ 1 file changed, 41 deletions(-) delete mode 100644 pyeudiw/satosa/http_error.py diff --git a/pyeudiw/satosa/http_error.py b/pyeudiw/satosa/http_error.py deleted file mode 100644 index 3ef9a84e..00000000 --- a/pyeudiw/satosa/http_error.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -from pyeudiw.satosa.response import JsonResponse - -logger = logging.getLogger(__name__) - -class HTTPErrorHandler: - def __init__(self, template_path: str, error_template: str, log): - self.template_path = template_path - self.error_template = error_template - self._log = log - - def _serialize_error( - self, - context, - message: str, - troubleshoot: str, - err: str, - err_code: str, - level: str - ): - _msg = f"{message}:" - if err: - _msg += f" {err}." - self._log( - context, level=level, - message=f"{_msg} {troubleshoot}" - ) - - return JsonResponse( - { - "error": message, - "error_description": troubleshoot - }, - status=err_code - ) - - def handle500(self, context, troubleshoot: str = "", err: str = ""): - return self._serialize_error(context, "server_error", troubleshoot, err, "500", "error") - - def handle40X(self, code_number: str, message: str, context, troubleshoot: str = "", err: str = ""): - return self._serialize_error(context, message, troubleshoot, err, f"40{code_number}", "error") \ No newline at end of file From c6e469572a34cd4040b4d7fad76c4dac15789e1f Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:41:08 +0100 Subject: [PATCH 58/88] fix: code refactoring --- pyeudiw/satosa/backend.py | 197 +++----------------------------------- pyeudiw/satosa/dpop.py | 78 +++++---------- pyeudiw/satosa/trust.py | 130 ++++++++++--------------- 3 files changed, 86 insertions(+), 319 deletions(-) diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index 784775df..1fb3c223 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -39,13 +39,11 @@ from typing import Callable from pydantic import ValidationError -from .http_error import HTTPErrorHandler -from .exceptions import HTTPError, EmptyHTTPError +from .exceptions import HTTPError +from .base_http_error_handler import BaseHTTPErrorHandler +from .base_logger import BaseLogger -logger = logging.getLogger(__name__) - - -class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP): +class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP, BaseHTTPErrorHandler, BaseLogger): """ A backend module (acting as a OpenID4VP SP). """ @@ -79,9 +77,8 @@ def __init__( try: WalletRelyingParty(**config['metadata']) except ValidationError as e: - logger.warning( - f"""The backend configuration presents the following validation issues: - {logger.warning(e)}""") + debug_message = f"""The backend configuration presents the following validation issues: {e}""" + self._log_warning("OpenID4VPBackend", debug_message) self.config = config self.client_id = self.config['metadata']['client_id'] @@ -107,11 +104,9 @@ def __init__( # resolve metadata pointers/placeholders self._render_metadata_conf_elements() self.init_trust_resources() - - self.http_error_handler = HTTPErrorHandler("templates", "error.html", self._log) self._log_debug("OpenID4VP init", f"Loaded configuration: {json.dumps(config)}") - def register_endpoints(self) -> list: + def register_endpoints(self) -> list[tuple[str, Callable[[Context], Response]]]: """ Creates a list of all the endpoints this backend module needs to listen to. In this case it's the authentication response from the underlying OP that is redirected from the OP to @@ -128,7 +123,8 @@ def register_endpoints(self) -> list: ) ) _endpoint = f"{self.client_id}{v}" - logger.debug( + self._log_debug( + "OpenID4VPBackend", f"Exposing backend entity endpoint = {_endpoint}" ) if k == 'get_response': @@ -375,7 +371,7 @@ def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe ) # authentication finalized! self.db_engine.set_finalized(stored_session['document_id']) - if logger.getEffectiveLevel() == logging.DEBUG: + if self.effective_log_level == logging.DEBUG: stored_session = self.db_engine.get_by_state(state=state) self._log_debug(context, f"Session update on storage: {stored_session}") @@ -607,73 +603,6 @@ def _render_metadata_conf_elements(self) -> None: conf_section, conf_k = v[1:-1].split('.') self.config['metadata'][k] = self.config[conf_section][conf_k] - def _log(self, context: str | Context, level: str, message: str) -> None: - """ - Log a message with the given level. - - :param context: the request context or the scope of the class - :type context: satosa.context.Context | str - :param level: the log level - :type level: str - :param message: the message to log - :type message: str - """ - - context = context if isinstance(context, str) else context.state - - log_level = getattr(logger, level) - log_level( - lu.LOG_FMT.format( - id=lu.get_session_id(context), - message=message - ) - ) - - def _log_debug(self, context: str | Context, message: str) -> None: - """ - Log a message with the DEBUG level. - - :param context: the request context or the scope of the class - :type context: satosa.context.Context | str - :param message: the message to log - :type message: str - """ - - self._log(context, "debug", message) - - def _log_function_debug(self, fn_name: str, context: Context, args_name: str | None = None, args = None) -> None: - """ - Logs a message at the start of a backend function. - - :param fn_name: the name of the function - :type fn_name: str - :param context: the request context - :param args_name: the name of the arguments field - :type args_name: str | None - :param args: the arguments provided to the function - :type args: Any - """ - - args_str = f" and {args_name}: {args}" if not args_name else "" - - debug_message = ( - f"[INCOMING REQUEST] {fn_name} with Context: " - f"{context.__dict__}{args_str}" - ) - self._log_debug(context, debug_message) - - def _log_error(self, context: str | Context, message: str) -> None: - """ - Log a message with the ERROR level. - - :param context: the request context or the scope of the class - :type context: satosa.context.Context | str - :param message: the message to log - :type message: str - """ - - self._log(context, "error", message) - def _translate_response(self, response: dict, issuer: str, context: Context): """ Translates wallet response to SATOSA internal response. @@ -742,104 +671,6 @@ def _translate_response(self, response: dict, issuer: str, context: Context): internal_resp.subject_id = sub return internal_resp - def _handle_500(self, context: Context, msg: str, err: Exception) -> JsonResponse: - """ - Handles a 500 error. - - :param context: the request context - :type context: satosa.context.Context - :param msg: the error message - :type msg: str - :param err: the exception raised - :type err: Exception - - :return: a json response containing the error - :rtype: JsonResponse - """ - - return self.http_error_handler.handle500( - context=context, - troubleshoot=f"{msg}", - err=f"{msg}. {err.__class__.__name__}: {err}", - ) - - def _handle_40X(self, code_number: str, message: str, context, troubleshoot: str, err: Exception) -> JsonResponse: - """ - Handles a 40X error. - - :param code_number: the code number - :type code_number: str - :param message: the error message - :type message: str - :param context: the request context - :type context: satosa.context.Context - :param troubleshoot: the troubleshoot message - :type troubleshoot: str - :param err: the exception raised - :type err: Exception - - :return: a json response containing the error - :rtype: JsonResponse - """ - - return self.http_error_handler.handle40X( - code_number, - message, - context, - troubleshoot=f"{troubleshoot}", - err=f"{err.__class__.__name__}: {err}", - ) - - def _handle_400(self, context: Context, troubleshoot: str, err: Exception = EmptyHTTPError("")) -> JsonResponse: - """ - Handles a 400 error. - - :param context: the request context - :type context: satosa.context.Context - :param troubleshoot: the troubleshoot message - :type troubleshoot: str - :param err: the exception raised - :type err: Exception - - :return: a json response containing the error - :rtype: JsonResponse - """ - return self._handle_40X("0", "invalid_request", context, troubleshoot, err) - - def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): - """ - Handles a 401 error. - - :param context: the request context - :type context: satosa.context.Context - :param troubleshoot: the troubleshoot message - :type troubleshoot: str - :param err: the exception raised - :type err: Exception - - :return: a json response containing the error - :rtype: JsonResponse - """ - - return self._handle_40X("1", "invalid_client", context, troubleshoot, err) - - def _handle_403(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): - """ - Handles a 403 error. - - :param context: the request context - :type context: satosa.context.Context - :param troubleshoot: the troubleshoot message - :type troubleshoot: str - :param err: the exception raised - :type err: Exception - - :return: a json response containing the error - :rtype: JsonResponse - """ - - return self._handle_40X("3", "expired", context, troubleshoot, err) - @property def db_engine(self) -> DBEngine: """Returns the DBEngine instance used by the class""" @@ -847,11 +678,9 @@ def db_engine(self) -> DBEngine: self._db_engine.is_connected except Exception as e: if getattr(self, '_db_engine', None): - logger.debug( - lu.LOG_FMT.format( - id="OpenID4VP db storage handling", - message=f"connection check silently fails and get restored: {e}" - ) + self._log_debug( + "OpenID4VP db storage handling", + f"connection check silently fails and get restored: {e}" ) self._db_engine = DBEngine(self.config["storage"]) diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index 3bada939..c07c0b39 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -1,19 +1,19 @@ -import logging - from typing import Union - - from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPVerifier -from pyeudiw.openid4vp.schemas.wallet_instance_attestation import WalletInstanceAttestationPayload, \ +from pyeudiw.openid4vp.schemas.wallet_instance_attestation import ( + WalletInstanceAttestationPayload, WalletInstanceAttestationHeader +) from pyeudiw.satosa.response import JsonResponse - -import satosa.logging_util as lu from satosa.context import Context +from pydantic import ValidationError -logger = logging.getLogger(__name__) +from .base_logger import BaseLogger +from .base_http_error_handler import BaseHTTPErrorHandler +class BackendDPoP(BaseHTTPErrorHandler, BaseLogger): + def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, None]: """ Validates, if any, the DPoP http request header @@ -34,43 +34,27 @@ _head = decode_jwt_header(dpop_jws) wia = decode_jwt_payload(dpop_jws) - self._log( - context, - level='debug', - message=( - f"[FOUND WIA] Headers: {_head} and Payload: {wia}" - ) - ) + self._log_debug(context, message=f"[FOUND WIA] Headers: {_head} and Payload: {wia}") try: WalletInstanceAttestationHeader(**_head) + except ValidationError as e: + self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}! \nValidation error: {e}") except Exception as e: - self._log( - context, - level='warning', - message=f"[FOUND WIA] Invalid Headers: {_head}! \nValidation error: {e}" - ) + self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}! \nUnexpected error: {e}") try: WalletInstanceAttestationPayload(**wia) + except ValidationError as e: + self._log_warning(context, message=f"[FOUND WIA] Invalid WIA: {wia}! \nValidation error: {e}") except Exception as e: - self._log( - context, - level='warning', - message=f"[FOUND WIA] Invalid WIA: {wia}! \nValidation error: {e}" - ) + self._log_warning(context, message=f"[FOUND WIA] Invalid WIA: {wia}! \nUnexpected error: {e}") try: self._validate_trust(context, dpop_jws) except Exception as e: _msg = f"Trust Chain validation failed for dpop JWS {dpop_jws}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401", - err=f"{e}" - ) + return self._handle_401(context, _msg, e) try: dpop = DPoPVerifier( @@ -78,27 +62,18 @@ http_header_authz=context.http_headers['HTTP_AUTHORIZATION'], http_header_dpop=context.http_headers['HTTP_DPOP'] ) + except ValidationError as e: + _msg = f"DPoP validation error: {e}" + return self._handle_401(context, _msg, e) except Exception as e: _msg = f"DPoP verification error: {e}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401", - err=f"{e}" - ) + return self._handle_401(context, _msg, e) try: dpop.validate() except Exception as e: _msg = "DPoP validation exception" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err=f"{e}", - err_code="401" - ) + return self._handle_401(context, _msg, e) # TODO: assert and configure the wallet capabilities # TODO: assert and configure the wallet Attested Security Context @@ -108,13 +83,4 @@ "The Wallet Instance doesn't provide a valid Wallet Instance Attestation " "a default set of capabilities and a low security level are applied." ) - self._log(context, level='warning', message=_msg) - - def _log(self, context: Context, level: str, message: str) -> None: - log_level = getattr(logger, level) - log_level( - lu.LOG_FMT.format( - id=lu.get_session_id(context.state), - message=message - ) - ) + self._log_warning(context, message=_msg) \ No newline at end of file diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index 5842c78a..f2d781eb 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -1,7 +1,4 @@ import json -import logging - - import satosa.logging_util as lu from satosa.context import Context from satosa.response import Response @@ -17,11 +14,9 @@ from pyeudiw.trust import TrustEvaluationHelper from pyeudiw.trust.trust_anchors import update_trust_anchors_ecs +from .base_logger import BaseLogger -logger = logging.getLogger(__name__) - - -class BackendTrust: +class BackendTrust(BaseLogger): def init_trust_resources(self) -> None: # private keys by kid @@ -39,14 +34,12 @@ def init_trust_resources(self) -> None: try: self.get_backend_trust_chain() except Exception as e: - logger.critical( - f"Cannot fetch the trust anchor configuration: {e}" - ) + self._log_critical("Backend Trust", f"Cannot fetch the trust anchor configuration: {e}") self.db_engine.close() self._db_engine = None - def entity_configuration_endpoint(self, context): + def entity_configuration_endpoint(self, context: Context): data = self.entity_configuration_as_dict if context.qs_params.get('format', '') == 'json': @@ -64,12 +57,8 @@ def entity_configuration_endpoint(self, context): def update_trust_anchors(self): tas = self.config['federation']['trust_anchors'] - logger.info( - lu.LOG_FMT.format( - id="Trust Anchors updates", - message=f"Trying to update: {tas}" - ) - ) + self._log_info("Trust Anchors updates", f"Trying to update: {tas}") + for ta in tas: try: update_trust_anchors_ecs( @@ -78,22 +67,9 @@ def update_trust_anchors(self): httpc_params=self.config['network']['httpc_params'] ) except Exception as e: - logger.warning( - lu.LOG_FMT.format( - id=f"Trust Anchor updates", - message=f"{ta} update failed: {e}" - ) - ) - logger.info( - lu.LOG_FMT.format( - id="Trust Anchor update", - message=f"Trust Anchor updated: {ta}" - ) - ) + self._log_warning("Trust Anchor updates", f"{ta} update failed: {e}") - @property - def default_federation_private_jwk(self): - return tuple(self.federations_jwks_by_kids.values())[0] + self._log_info("Trust Anchor updates", f"{ta} updated") def get_backend_trust_chain(self) -> list: """ @@ -118,12 +94,53 @@ def get_backend_trust_chain(self) -> list: return trust_evaluation_helper.trust_chain except (DiscoveryFailedError, EntryNotFound, Exception) as e: - logger.warning( + message = ( f"Error while building trust chain for client with id: {self.client_id}\n" f"{e.__class__.__name__}: {e}" ) + self._log_warning("Trust Chain", message) return [] + + def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: + self._log_debug(context, "[TRUST EVALUATION] evaluating trust.") + + headers = decode_jwt_header(jws) + trust_eval = TrustEvaluationHelper( + self.db_engine, + httpc_params=self.config['network']['httpc_params'], + **headers + ) + + try: + trust_eval.evaluation_method() + except EntryNotFound: + message = ( + "[TRUST EVALUATION] not found for " + f"{trust_eval.entity_id}" + ) + self._log_error(context, message) + + raise NotTrustedFederationError( + f"{trust_eval.entity_id} not found for Trust evaluation." + ) + except Exception as e: + message = ( + "[TRUST EVALUATION] failed for " + f"{trust_eval.entity_id}: {e}" + ) + self._log_error(context, message) + + raise NotTrustedFederationError( + f"{trust_eval.entity_id} is not trusted." + ) + + return trust_eval + + + @property + def default_federation_private_jwk(self): + return tuple(self.federations_jwks_by_kids.values())[0] @property def entity_configuration_as_dict(self) -> dict: @@ -154,49 +171,4 @@ def entity_configuration(self) -> dict: "typ": "entity-statement+jwt" }, plain_dict=data - ) - - def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: - self._log( - context, - level='debug', - message=( - "[TRUST EVALUATION] evaluating trust." - ) - ) - - headers = decode_jwt_header(jws) - trust_eval = TrustEvaluationHelper( - self.db_engine, - httpc_params=self.config['network']['httpc_params'], - **headers - ) - - try: - trust_eval.evaluation_method() - except Exception as e: - self._log( - context, - level='error', - message=( - "[TRUST EVALUATION] failed for " - f"{trust_eval.entity_id}: {e}" - ) - ) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} is not trusted." - ) - except EntryNotFound: - self._log( - context, - level='error', - message=( - "[TRUST EVALUATION] not found for " - f"{trust_eval.entity_id}" - ) - ) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} not found for Trust evaluation." - ) - - return trust_eval + ) \ No newline at end of file From a3b6bd0cc9d788d8d6eee1576ccd800c2511ac36 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:43:30 +0100 Subject: [PATCH 59/88] docs: added doc for _serialize_error --- pyeudiw/satosa/base_http_error_handler.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pyeudiw/satosa/base_http_error_handler.py b/pyeudiw/satosa/base_http_error_handler.py index f26ddbc8..785042c5 100644 --- a/pyeudiw/satosa/base_http_error_handler.py +++ b/pyeudiw/satosa/base_http_error_handler.py @@ -13,7 +13,27 @@ def _serialize_error( err: str, err_code: str, level: str - ): + ) -> JsonResponse: + """ + Serializes an error. + + :param context: the request context + :type context: satosa.context.Context + :param message: the error message + :type message: str + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: more info about the error + :type err: str + :param err_code: the error code + :type err_code: str + :param level: the log level + :type level: str + + :return: a json response containing the error + :rtype: JsonResponse + """ + _msg = f"{message}:" if err: _msg += f" {err}." From 0b8a93364e68b0f56fdac183334fb555dd0a9148 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:45:38 +0100 Subject: [PATCH 60/88] docs: documented HTTPError and EmptyHTTPError --- pyeudiw/satosa/exceptions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyeudiw/satosa/exceptions.py b/pyeudiw/satosa/exceptions.py index ea012311..1940f1fa 100644 --- a/pyeudiw/satosa/exceptions.py +++ b/pyeudiw/satosa/exceptions.py @@ -23,7 +23,13 @@ class DiscoveryFailedError(Exception): pass class HTTPError(Exception): + """ + Raised when an error occurs during an HTTP request + """ pass class EmptyHTTPError(HTTPError): + """ + Default HTTP empty error + """ pass \ No newline at end of file From 14487a2e8bd33427232f7bc71739cb7414cb5505 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 17:46:26 +0100 Subject: [PATCH 61/88] docs: fixed doc --- pyeudiw/satosa/dpop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index c07c0b39..0abc0ef6 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -13,6 +13,10 @@ from .base_http_error_handler import BaseHTTPErrorHandler class BackendDPoP(BaseHTTPErrorHandler, BaseLogger): + """ + Backend DPoP class. + """ + def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, None]: """ Validates, if any, the DPoP http request header From d5d4a5b966f62c64bb5c656861d06e99c5545e9d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 18:00:31 +0100 Subject: [PATCH 62/88] docs: documented content of html_template.py --- pyeudiw/satosa/html_template.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pyeudiw/satosa/html_template.py b/pyeudiw/satosa/html_template.py index 3788818b..4320254d 100644 --- a/pyeudiw/satosa/html_template.py +++ b/pyeudiw/satosa/html_template.py @@ -1,9 +1,17 @@ +from typing import Any, Dict from jinja2 import Environment, FileSystemLoader, select_autoescape - class Jinja2TemplateHandler: + """ + Jinja2 template handler + """ + def __init__(self, config: Dict[str, Any]): + """ + Create an istance of Jinja2TemplateHandler - def __init__(self, config): + :param config: a dictionary that contains the configuration for initalize the template handler. + :type config: Dict[str, Any] + """ # error pages handler self.loader = Environment( From 5905212df5cf29ee520769e46ac9cd9fbd2e9cf0 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 18:01:07 +0100 Subject: [PATCH 63/88] docs: documented content of response.py --- pyeudiw/satosa/response.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pyeudiw/satosa/response.py b/pyeudiw/satosa/response.py index 542c91b6..4d73d811 100644 --- a/pyeudiw/satosa/response.py +++ b/pyeudiw/satosa/response.py @@ -1,12 +1,23 @@ import json - from satosa.response import Response class JsonResponse(Response): + """ + A JSON response istance class. + """ + _content_type = "application/json" def __init__(self, *args, **kwargs): + """ + Creates an instance of JsonResponse. + + :param args: a list of arguments + :type args: Any + :param kwargs: a dictionary of arguments + :type kwargs: Any + """ super().__init__(*args, **kwargs) if isinstance(self.message, list): From 380df4f0015c90cc461c677d83fd52c1ed92d82d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 18:01:42 +0100 Subject: [PATCH 64/88] docs: documented content of trust.py --- pyeudiw/satosa/trust.py | 43 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index f2d781eb..ffaea9ca 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -17,8 +17,15 @@ from .base_logger import BaseLogger class BackendTrust(BaseLogger): + """ + Backend Trust class. + """ def init_trust_resources(self) -> None: + """ + Initializes the trust resources. + """ + # private keys by kid self.federations_jwks_by_kids = { i['kid']: i for i in self.config['federation']['federation_jwks'] @@ -39,7 +46,16 @@ def init_trust_resources(self) -> None: self.db_engine.close() self._db_engine = None - def entity_configuration_endpoint(self, context: Context): + def entity_configuration_endpoint(self, context: Context) -> Response: + """ + Entity Configuration endpoint. + + :param context: The current context + :type context: Context + + :return: The entity configuration + :rtype: Response + """ data = self.entity_configuration_as_dict if context.qs_params.get('format', '') == 'json': @@ -56,6 +72,10 @@ def entity_configuration_endpoint(self, context: Context): ) def update_trust_anchors(self): + """ + Updates the trust anchors of current instance. + """ + tas = self.config['federation']['trust_anchors'] self._log_info("Trust Anchors updates", f"Trying to update: {tas}") @@ -71,7 +91,7 @@ def update_trust_anchors(self): self._log_info("Trust Anchor updates", f"{ta} updated") - def get_backend_trust_chain(self) -> list: + def get_backend_trust_chain(self) -> list[str]: """ Get the backend trust chain. In case something raises an Exception (e.g. faulty storage), logs a warning message and returns an empty list. @@ -103,6 +123,20 @@ def get_backend_trust_chain(self) -> list: return [] def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: + """ + Validates the trust of the given jws. + + :param context: the request context + :type context: satosa.context.Context + :param jws: the jws to validate + :type jws: str + + :raises: NotTrustedFederationError: raises an error if the trust evaluation fails. + + :return: the trust evaluation helper + :rtype: TrustEvaluationHelper + """ + self._log_debug(context, "[TRUST EVALUATION] evaluating trust.") headers = decode_jwt_header(jws) @@ -139,11 +173,13 @@ def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: @property - def default_federation_private_jwk(self): + def default_federation_private_jwk(self) -> dict: + """Returns the default federation private jwk.""" return tuple(self.federations_jwks_by_kids.values())[0] @property def entity_configuration_as_dict(self) -> dict: + """Returns the entity configuration as a dictionary.""" ec_payload = { "exp": exp_from_now(minutes=self.default_exp), "iat": iat_now(), @@ -162,6 +198,7 @@ def entity_configuration_as_dict(self) -> dict: @property def entity_configuration(self) -> dict: + """Returns the entity configuration as a JWT.""" data = self.entity_configuration_as_dict jwshelper = JWSHelper(self.default_federation_private_jwk) return jwshelper.sign( From 900edc37fae3062e047037547c9abb01ff7b5a82 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 18:02:02 +0100 Subject: [PATCH 65/88] fix: fixed signature --- pyeudiw/trust/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 836e229a..7952fae1 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -27,7 +27,7 @@ class TrustEvaluationHelper: def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): self.exp: int = 0 - self.trust_chain: list = [] + self.trust_chain: list[str] = [] self.trust_anchor = trust_anchor self.storage = storage self.entity_id: str = "" From b8e929aac0f6633ff379c73f29bf98b9d6929ebb Mon Sep 17 00:00:00 2001 From: PascalDR Date: Fri, 15 Dec 2023 18:15:31 +0100 Subject: [PATCH 66/88] fix: fixed message passing --- pyeudiw/satosa/backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index 1fb3c223..2cd3f245 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -242,7 +242,8 @@ def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe # take the encrypted jwt, decrypt with my public key (one of the metadata) -> if not -> exception jwt = context.request.get("response", None) if not jwt: - self._log_error(context, f"Response error, missing JWT") + _msg = f"Response error, missing JWT" + self._log_error(context, _msg) return self._handle_400(context, _msg) try: From 97e0244edbc4b309bb7ea83d98ac3e8c433a8cee Mon Sep 17 00:00:00 2001 From: PascalDR Date: Tue, 19 Dec 2023 20:25:55 +0100 Subject: [PATCH 67/88] docs: documented content of __init__.py --- pyeudiw/sd_jwt/__init__.py | 172 ++++++++++++++++++++++++++++++++++--- 1 file changed, 160 insertions(+), 12 deletions(-) diff --git a/pyeudiw/sd_jwt/__init__.py b/pyeudiw/sd_jwt/__init__.py index 14da89ff..65a4366d 100644 --- a/pyeudiw/sd_jwt/__init__.py +++ b/pyeudiw/sd_jwt/__init__.py @@ -1,4 +1,3 @@ -import cryptojwt import json from jwcrypto.common import base64url_encode @@ -22,12 +21,45 @@ import jwcrypto from typing import Any - +from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.ec import ECKey +from cryptography.hazmat.backends.openssl.rsa import _RSAPrivateKey class TrustChainSDJWTIssuer(SDJWTIssuer): - def __init__(self, user_claims: Dict, issuer_key, holder_key=None, sign_alg=None, add_decoy_claims: bool = True, serialization_format: str = "compact", additional_headers: dict = {}): + """ + Class for issue SD-JWT of TrustChain. + """ + def __init__( + self, + user_claims: Dict[str, Any], + issuer_key: dict, + holder_key: dict | None = None, + sign_alg: str | None = None, + add_decoy_claims: bool = True, + serialization_format: str = "compact", + additional_headers: dict = {} + ) -> None: + """ + Crate an istance of TrustChainSDJWTIssuer. + + :param user_claims: the claims of the SD-JWT. + :type user_claims: dict + :param issuer_key: the issuer key. + :type issuer_key: dict + :param holder_key: the holder key. + :type holder_key: dict | None + :param sign_alg: the signing algorithm. + :type sign_alg: str | None + :param add_decoy_claims: if True add decoy claims. + :type add_decoy_claims: bool + :param serialization_format: the serialization format. + :type serialization_format: str + :param additional_headers: additional headers. + :type additional_headers: dict + """ + self.additional_headers = additional_headers - sign_alg = DEFAULT_SIG_KTY_MAP[issuer_key.kty] + sign_alg = sign_alg if sign_alg else DEFAULT_SIG_KTY_MAP[issuer_key.kty] super().__init__( user_claims, @@ -39,6 +71,9 @@ def __init__(self, user_claims: Dict, issuer_key, holder_key=None, sign_alg=None ) def _create_signed_jws(self): + """ + Creates the signed JWS. + """ self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload)) _protected_headers = {"alg": self._sign_alg} @@ -66,9 +101,19 @@ def _create_signed_jws(self): self.serialized_sd_jwt = dumps(jws_content) -def _serialize_key(key, **kwargs): +def _serialize_key( + key: RSAKey | ECKey | JWK | dict, + **kwargs + ) -> dict: + """ + Serialize a key into dict. + + :param key: the key to serialize. + :type key: RSAKey | ECKey | JWK | dict - if isinstance(key, cryptojwt.jwk.rsa.RSAKey): + :returns: the serialized key into a dict. + """ + if isinstance(key, RSAKey) or isinstance(key, ECKey): key = key.serialize() elif isinstance(key, JWK): key = key.as_dict() @@ -79,7 +124,19 @@ def _serialize_key(key, **kwargs): return key -def pk_encode_int(i, bit_size=None): +def pk_encode_int(i: str, bit_size: int = None) -> str: + """ + Encode an integer as a base64url string with padding. + + :param i: the integer to encode. + :type i: str + :param bit_size: the bit size of the integer. + :type bit_size: int + + :returns: the encoded integer. + :rtype: str + """ + extend = 0 if bit_size is not None: extend = ((bit_size + 7) // 8) * 2 @@ -92,7 +149,22 @@ def pk_encode_int(i, bit_size=None): return base64url_encode(unhexlify(extend * '0' + hexi)) -def import_pyca_pri_rsa(key, **params): +def import_pyca_pri_rsa(key: _RSAPrivateKey, **params) -> jwcrypto.jwk.JWK: + """ + Import a private RSA key from a PyCA object. + + :param key: the key to import. + :type key: RSAKey | ECKey + + :raises ValueError: if the key is not a PyCA RSAKey object. + + :returns: the imported key. + :rtype: RSAKey + """ + + if not isinstance(key, _RSAPrivateKey): + raise ValueError("key must be a ssl RSAPrivateKey object") + pn = key.private_numbers() params.update( kty='RSA', @@ -108,7 +180,19 @@ def import_pyca_pri_rsa(key, **params): return jwcrypto.jwk.JWK(**params) -def _adapt_keys(issuer_key: JWK, holder_key: JWK): +def _adapt_keys(issuer_key: JWK, holder_key: JWK) -> dict: + """ + Adapt the keys to the SD-JWT library. + + :param issuer_key: the issuer key. + :type issuer_key: JWK + :param holder_key: the holder key. + :type holder_key: JWK + + :returns: the adapted keys as a dict. + :rtype: dict + """ + # _iss_key = issuer_key.key.serialize(private=True) # _iss_key['key_ops'] = 'sign' _issuer_key = import_pyca_pri_rsa( @@ -124,11 +208,45 @@ def _adapt_keys(issuer_key: JWK, holder_key: JWK): ) -def load_specification_from_yaml_string(yaml_specification: str): +def load_specification_from_yaml_string(yaml_specification: str) -> dict: + """ + Load a specification from a yaml string. + + :param yaml_specification: the yaml string. + :type yaml_specification: str + + :returns: the specification as a dict. + :rtype: dict + """ + return _yaml_load_specification(StringIO(yaml_specification)) -def issue_sd_jwt(specification: dict, settings: dict, issuer_key: JWK, holder_key: JWK, trust_chain: list[str] | None = None) -> str: +def issue_sd_jwt( + specification: Dict[str, Any], + settings: dict, + issuer_key: JWK, + holder_key: JWK, + trust_chain: list[str] | None = None + ) -> str: + """ + Issue a SD-JWT. + + :param specification: the specification of the SD-JWT. + :type specification: Dict[str, Any] + :param settings: the settings of the SD-JWT. + :type settings: dict + :param issuer_key: the issuer key. + :type issuer_key: JWK + :param holder_key: the holder key. + :type holder_key: JWK + :param trust_chain: the trust chain. + :type trust_chain: list[str] | None + + :returns: the issued SD-JWT. + :rtype: str + """ + claims = { "iss": settings["issuer"], "iat": iat_now(), @@ -153,7 +271,23 @@ def issue_sd_jwt(specification: dict, settings: dict, issuer_key: JWK, holder_ke return {"jws": sdjwt_at_issuer.serialized_sd_jwt, "issuance": sdjwt_at_issuer.sd_jwt_issuance} -def _cb_get_issuer_key(issuer: str, settings: dict, adapted_keys: dict, *args, **kwargs): +def _cb_get_issuer_key(issuer: str, settings: dict, adapted_keys: dict, *args, **kwargs) -> JWK: + """ + Helper function for get the issuer key. + + :param issuer: the issuer. + :type issuer: str + :param settings: the settings of SD-JWT. + :type settings: dict + :param adapted_keys: the adapted keys. + :type adapted_keys: dict + + :raises Exception: if the issuer is unknown. + + :returns: the issuer key. + :rtype: JWK + """ + if issuer == settings["issuer"]: return adapted_keys["issuer_public_key"] else: @@ -166,6 +300,20 @@ def verify_sd_jwt( holder_key: JWK, settings: dict = {'key_binding': True} ) -> (list | dict | Any): + """ + Verify a SD-JWT. + + :param sd_jwt_presentation: the SD-JWT to verify. + :type sd_jwt_presentation: str + :param issuer_key: the issuer key. + :type issuer_key: JWK + :param holder_key: the holder key. + :type holder_key: JWK + :param settings: the settings of SD-JWT. + + :returns: the verified payload. + :rtype: list | dict | Any + """ settings.update( { From 0247bc62c4cec2bab5f6945c54ceaf4d8765f475 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 12:16:36 +0100 Subject: [PATCH 68/88] feat: added specialized classes for JWK --- pyeudiw/jwk/__init__.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index e7fb4914..b5f494cf 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from typing import Union @@ -6,14 +8,14 @@ from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jwk.rsa import new_rsa_key -from .exceptions import InvalidKid, KidNotFoundError +from .exceptions import InvalidKid, KidNotFoundError, InvalidJwk KEY_TYPES_FUNC = dict( EC=new_ec_key, RSA=new_rsa_key ) -class JWK(): +class JWK: """ The class representing a JWK istance """ @@ -117,6 +119,34 @@ def as_dict(self) -> dict: def __repr__(self): # private part! return self.as_json() + +class RSAJWK(JWK): + def __init__(self, key: dict | None = None, hash_func: str = "SHA-256") -> None: + super().__init__(key, "RSA", hash_func, None) + +class ECJWK(JWK): + def __init__(self, key: dict | None = None, hash_func: str = "SHA-256", ec_crv: str = "P-256") -> None: + super().__init__(key, "EC", hash_func, ec_crv) + +def jwk_form_dict(key: dict, hash_func: str = "SHA-256") -> RSAJWK | ECJWK: + """ + Returns a JWK instance from a dict. + + :param key: a dict that represents the key. + :type key: dict + + :returns: a JWK instance. + :rtype: JWK + """ + _kty = key.get('kty', None) + + if _kty == None or _kty not in ['EC', 'RSA']: + raise InvalidJwk("Invalid JWK") + elif _kty == "RSA": + return RSAJWK(key, hash_func) + else: + ec_crv = key.get('crv', "P-256") + return ECJWK(key, hash_func, ec_crv) def find_jwk(kid: str, jwks: list[dict], as_dict: bool=True) -> dict | JWK: """ From 93489edb73b554d2e2cce859bba2f7f0e61cbc20 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 12:16:53 +0100 Subject: [PATCH 69/88] feat: added error type --- pyeudiw/jwk/exceptions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyeudiw/jwk/exceptions.py b/pyeudiw/jwk/exceptions.py index b3a84613..056744f7 100644 --- a/pyeudiw/jwk/exceptions.py +++ b/pyeudiw/jwk/exceptions.py @@ -10,3 +10,6 @@ class InvalidKid(Exception): class JwkError(Exception): pass + +class InvalidJwk(Exception): + pass \ No newline at end of file From 6b6ea9f2d572ee9cf36d56059fb30c7591f8780a Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 12:17:46 +0100 Subject: [PATCH 70/88] chore: moved file base_logger.py --- pyeudiw/satosa/backend.py | 2 +- pyeudiw/satosa/dpop.py | 2 +- pyeudiw/{satosa => tools}/base_logger.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename pyeudiw/{satosa => tools}/base_logger.py (100%) diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index 2cd3f245..5fea7b9b 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -41,7 +41,7 @@ from .exceptions import HTTPError from .base_http_error_handler import BaseHTTPErrorHandler -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP, BaseHTTPErrorHandler, BaseLogger): """ diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index 0abc0ef6..778b0696 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -9,7 +9,7 @@ from satosa.context import Context from pydantic import ValidationError -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger from .base_http_error_handler import BaseHTTPErrorHandler class BackendDPoP(BaseHTTPErrorHandler, BaseLogger): diff --git a/pyeudiw/satosa/base_logger.py b/pyeudiw/tools/base_logger.py similarity index 100% rename from pyeudiw/satosa/base_logger.py rename to pyeudiw/tools/base_logger.py From 0853b82778818ebd23ea33553b3d6ac75d364b41 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:22:54 +0100 Subject: [PATCH 71/88] chore: fixed import --- pyeudiw/satosa/base_http_error_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyeudiw/satosa/base_http_error_handler.py b/pyeudiw/satosa/base_http_error_handler.py index 785042c5..cf7cb83f 100644 --- a/pyeudiw/satosa/base_http_error_handler.py +++ b/pyeudiw/satosa/base_http_error_handler.py @@ -1,5 +1,5 @@ from satosa.context import Context -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger from .exceptions import EmptyHTTPError from pyeudiw.satosa.response import JsonResponse From 7112f4ed80ef74c260566a559ccd62b1cbf94601 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:23:14 +0100 Subject: [PATCH 72/88] chore: fixed import --- pyeudiw/satosa/trust.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index ffaea9ca..86540b31 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -14,7 +14,7 @@ from pyeudiw.trust import TrustEvaluationHelper from pyeudiw.trust.trust_anchors import update_trust_anchors_ecs -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger class BackendTrust(BaseLogger): """ From 516cc9547a5ef7ec1eaafcf235e488ce2fe6674d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:23:58 +0100 Subject: [PATCH 73/88] feat: added BaseDB class and it's documentation --- pyeudiw/storage/base_db.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 pyeudiw/storage/base_db.py diff --git a/pyeudiw/storage/base_db.py b/pyeudiw/storage/base_db.py new file mode 100644 index 00000000..1d3abd13 --- /dev/null +++ b/pyeudiw/storage/base_db.py @@ -0,0 +1,22 @@ +class BaseDB: + """ + Interface class for database storage. + """ + + def _connect(self) -> None: + """ + Connect to the database server. + + :raises ConnectionFailure: if the connection fails. + + :returns: None + """ + raise NotImplementedError() + + def close(self) -> None: + """ + Close the connection to the storage. + + :returns: None + """ + raise NotImplementedError() \ No newline at end of file From 989fb7b6d7e0ba578b8bf6b379e046e5a24ef685 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:25:29 +0100 Subject: [PATCH 74/88] docs: documented content of base_cache.py --- pyeudiw/storage/base_cache.py | 36 ++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/pyeudiw/storage/base_cache.py b/pyeudiw/storage/base_cache.py index a29afeed..79f2daf4 100644 --- a/pyeudiw/storage/base_cache.py +++ b/pyeudiw/storage/base_cache.py @@ -7,12 +7,46 @@ class RetrieveStatus(Enum): ADDED = 1 -class BaseCache(): + """ + Interface class for cache storage. + """ + def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus]: + """ + Try to retrieve an object from the cache. If the object is not found, call the on_not_found function. + + :param object_name: the name of the object to retrieve. + :type object_name: str + :param on_not_found: the function to call if the object is not found. + :type on_not_found: Callable[[], str] + + :returns: a tuple with the retrieved object and a status. + :rtype: tuple[dict, RetrieveStatus] + """ raise NotImplementedError() def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict: + """ + Overwrite an object in the cache. + + :param object_name: the name of the object to overwrite. + :type object_name: str + :param value_gen_fn: the function to call to generate the new value. + :type value_gen_fn: Callable[[], str] + + :returns: the overwritten object. + :rtype: dict + """ raise NotImplementedError() def set(self, data: dict) -> dict: + """ + Set an object in the cache. + + :param data: the data to set. + :type data: dict + + :returns: the setted object. + :rtype: dict + """ raise NotImplementedError() From df983ec2294f46ea94101abcdb53c4347f83e179 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:26:29 +0100 Subject: [PATCH 75/88] feat: added inheritance with BaseDB --- pyeudiw/storage/base_cache.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyeudiw/storage/base_cache.py b/pyeudiw/storage/base_cache.py index 79f2daf4..6292054c 100644 --- a/pyeudiw/storage/base_cache.py +++ b/pyeudiw/storage/base_cache.py @@ -1,5 +1,6 @@ from enum import Enum from typing import Callable +from .base_db import BaseDB class RetrieveStatus(Enum): @@ -7,6 +8,7 @@ class RetrieveStatus(Enum): ADDED = 1 +class BaseCache(BaseDB): """ Interface class for cache storage. """ From 8dfd086f9edb52db1d8e3bd680a3edd974ac8302 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:27:37 +0100 Subject: [PATCH 76/88] feat: added inheritance of BaseDB --- pyeudiw/storage/base_storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyeudiw/storage/base_storage.py b/pyeudiw/storage/base_storage.py index f97eaf96..b1085425 100644 --- a/pyeudiw/storage/base_storage.py +++ b/pyeudiw/storage/base_storage.py @@ -1,6 +1,7 @@ import datetime from enum import Enum from typing import Union +from .base_db import BaseDB class TrustType(Enum): X509 = 0 @@ -21,8 +22,7 @@ class TrustType(Enum): TrustType.FEDERATION: "entity_configuration" } -class BaseStorage(object): - def init_session(self, document_id: str, dpop_proof: dict, attestation: dict): +class BaseStorage(BaseDB): raise NotImplementedError() def is_connected(self) -> bool: From 872ac119ac6facad6aff026b274ffa3d87204135 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:28:10 +0100 Subject: [PATCH 77/88] docs: documented content of base_storage.py --- pyeudiw/storage/base_storage.py | 267 ++++++++++++++++++++++++++++++-- 1 file changed, 253 insertions(+), 14 deletions(-) diff --git a/pyeudiw/storage/base_storage.py b/pyeudiw/storage/base_storage.py index b1085425..348f4ef5 100644 --- a/pyeudiw/storage/base_storage.py +++ b/pyeudiw/storage/base_storage.py @@ -1,6 +1,8 @@ import datetime from enum import Enum from typing import Union +from pymongo.results import UpdateResult + from .base_db import BaseDB class TrustType(Enum): @@ -23,68 +25,305 @@ class TrustType(Enum): } class BaseStorage(BaseDB): - raise NotImplementedError() + """ + Interface class for storage. + """ - def is_connected(self) -> bool: - raise NotImplementedError() + def init_session(self, document_id: str, dpop_proof: dict, attestation: dict) -> str: + """ + Initialize a session. - def close(self) -> None: + :param document_id: the document id. + :type document_id: str + :param dpop_proof: the dpop proof. + :type dpop_proof: dict + :param attestation: the attestation. + """ raise NotImplementedError() - def add_dpop_proof_and_attestation(self, document_id, dpop_proof: dict, attestation: dict): + def add_dpop_proof_and_attestation(self, document_id, dpop_proof: dict, attestation: dict) -> UpdateResult: + """ + Add a dpop proof and an attestation to the session. + + :param document_id: the document id. + :type document_id: str + :param dpop_proof: the dpop proof. + :type dpop_proof: dict + :param attestation: the attestation. + :type attestation: dict + + :returns: the result of the update operation. + :rtype: UpdateResult + """ raise NotImplementedError() - def set_finalized(self, document_id: str): + def set_finalized(self, document_id: str) -> UpdateResult: + """ + Set the session as finalized. + + :param document_id: the document id. + :type document_id: str + + :returns: the result of the update operation. + :rtype: UpdateResult + """ + raise NotImplementedError() - def update_request_object(self, document_id: str, request_object: dict) -> int: + def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult: + """ + Update the request object of the session. + + :param document_id: the document id. + :type document_id: str + :param request_object: the request object. + :type request_object: dict + + :returns: the result of the update operation. + :rtype: UpdateResult + """ raise NotImplementedError() - def update_response_object(self, nonce: str, state: str, response_object: dict) -> int: + def update_response_object(self, nonce: str, state: str, response_object: dict) -> UpdateResult: + """ + Update the response object of the session. + + :param nonce: the nonce. + :type nonce: str + :param state: the state. + :type state: str + :param response_object: the response object. + :type response_object: dict + + :returns: the result of the update operation. + :rtype: UpdateResult + """ raise NotImplementedError() def get_trust_attestation(self, entity_id: str) -> Union[dict, None]: + """ + Get a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + + :returns: the trust attestation. + :rtype: Union[dict, None] + """ raise NotImplementedError() def get_trust_anchor(self, entity_id: str) -> Union[dict, None]: + """ + Get a trust anchor. + + :param entity_id: the entity id. + :type entity_id: str + + :returns: the trust anchor. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def has_trust_attestation(self, entity_id: str): + def has_trust_attestation(self, entity_id: str) -> bool: + """ + Check if a trust attestation exists. + + :param entity_id: the entity id. + :type entity_id: str + + + :returns: True if the trust attestation exists, False otherwise. + :rtype: bool + """ raise NotImplementedError() - def has_trust_anchor(self, entity_id: str): + def has_trust_anchor(self, entity_id: str) -> bool: + """ + Check if a trust anchor exists. + + :param entity_id: the entity id. + :type entity_id: str + + :returns: True if the trust anchor exists, False otherwise. + :rtype: bool + """ raise NotImplementedError() def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: + """ + Add a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + :param attestation: the attestation. + :type attestation: list[str] + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, metadata: dict) -> str: + """ + Add a trust attestation metadata. + + :param entity_id: the entity id. + :type entity_id: str + :param metadata_type: the metadata type. + :type metadata_type: str + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): + """ + Add a trust anchor. + + :param entity_id: the entity id. + :type entity_id: str + :param entity_configuration: the entity configuration. + :type entity_configuration: str + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: + """ + Update a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + :param attestation: the attestation. + :type attestation: list[str] + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def update_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType) -> str: + """ + Update a trust anchor. + + :param entity_id: the entity id. + :type entity_id: str + :param entity_configuration: the entity configuration. + :type entity_configuration: str + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def exists_by_state_and_session_id(self, state: str, session_id: str = "") -> bool: + """ + Check if a session exists by state and session id. + + :param state: the state. + :type state: str + :param session_id: the session id. + :type session_id: str + + :returns: True if the session exists, False otherwise. + :rtype: bool + """ raise NotImplementedError() - def get_by_state(self, state: str): + def get_by_state(self, state: str) -> Union[dict, None]: + """ + Get a session by state. + + :param state: the state. + :type state: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def get_by_nonce_state(self, state: str, nonce: str): + def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]: + """ + Get a session by nonce and state. + + :param state: the state. + :type state: str + :param nonce: the nonce. + :type nonce: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def get_by_state_and_session_id(self, state: str, session_id: str = ""): + def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: + """ + Get a session by state and session id. + + :param state: the state. + :type state: str + :param session_id: the session id. + :type session_id: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def get_by_session_id(self, session_id: str): + def get_by_session_id(self, session_id: str) -> Union[dict, None]: + """ + Get a session by session id. + + :param session_id: the session id. + :type session_id: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() # TODO: create add_or_update for all the write methods def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime) -> str: + """ + Add or update a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + :param attestation: the attestation. + :type attestation: list[str] + :param exp: the expiration date. + :type exp: datetime + + :returns: the document id. + :rtype: str + """ + raise NotImplementedError() + + @property + def is_connected(self) -> bool: + """ + Check if the storage is connected. + + :returns: True if the storage is connected, False otherwise. + :rtype: bool + """ raise NotImplementedError() From 7a6c1ea4c994275c4fa21e20b88cfc5728b467a2 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:28:49 +0100 Subject: [PATCH 78/88] feat: documentation and refactoring --- pyeudiw/storage/db_engine.py | 268 +++++++++++++++++++++++------------ 1 file changed, 177 insertions(+), 91 deletions(-) diff --git a/pyeudiw/storage/db_engine.py b/pyeudiw/storage/db_engine.py index b1dbc9cb..c2d07614 100644 --- a/pyeudiw/storage/db_engine.py +++ b/pyeudiw/storage/db_engine.py @@ -1,8 +1,7 @@ import uuid -import logging import importlib from datetime import datetime -from typing import Callable, Union +from typing import Callable, Union, Tuple, Dict from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus from pyeudiw.storage.base_storage import BaseStorage, TrustType from pyeudiw.storage.exceptions import ( @@ -10,14 +9,24 @@ StorageWriteError, EntryNotFound ) +from pyeudiw.tools.base_logger import BaseLogger -logger = logging.getLogger(__name__) +from .base_db import BaseDB - -class DBEngine(): +class DBEngine(BaseStorage, BaseCache, BaseLogger): + """ + DB Engine class. + """ def __init__(self, config: dict): - self.caches = [] - self.storages = [] + """ + Create a DB Engine instance. + + :param config: the configuration of all the DBs. + :type config: dict + """ + self.caches: list[Tuple[str, BaseCache]] = [] + self.storages: list[Tuple[str, BaseStorage]] = [] + for db_name, db_conf in config.items(): storage_instance, cache_instance = self._handle_instance(db_conf) @@ -27,25 +36,6 @@ def __init__(self, config: dict): if cache_instance: self.caches.append((db_name, cache_instance)) - def _handle_instance(self, instance: dict) -> dict[BaseStorage | None, BaseCache | None]: - cache_conf = instance.get("cache", None) - storage_conf = instance.get("storage", None) - - storage_instance = None - if storage_conf: - module = importlib.import_module(storage_conf["module"]) - instance_class = getattr(module, storage_conf["class"]) - storage_instance = instance_class( - **storage_conf.get("init_params", {})) - - cache_instance = None - if cache_conf: - module = importlib.import_module(cache_conf["module"]) - instance_class = getattr(module, cache_conf["class"]) - cache_instance = instance_class(**cache_conf["init_params"]) - - return storage_instance, cache_instance - def init_session(self, session_id: str, state: str) -> str: document_id = str(uuid.uuid4()) for db_name, storage in self.storages: @@ -54,49 +44,38 @@ def init_session(self, session_id: str, state: str) -> str: document_id, session_id=session_id, state=state ) except StorageWriteError as e: - logger.critical( - f"Error while initializing session with document_id {document_id}. " - f"Cannot write document with id {document_id} on {db_name}: " - f"{e.__class__.__name__}: {e}" + self._log_critical( + e.__class__.__name__, + ( + f"Error while initializing session with document_id {document_id}. " + f"Cannot write document with id {document_id} on {db_name}: {e}" + ) ) raise e return document_id - @property - def is_connected(self): - _connected = False - _cons = {} - for db_name, storage in self.storages: - try: - _connected = storage.is_connected - _cons[db_name] = _connected - except Exception as e: - logger.debug( - f"Error while checking db engine connection on {db_name}. " - f"{e.__class__.__name__}: {e}" - ) + def close(self): + self._close_list(self.storages) + self._close_list(self.caches) - if True in _cons.values() and not all(_cons.values()): - logger.warning( - f"Not all the storage are found available, storages misalignment: " - f"{_cons}" - ) + def write(self, method: str, *args, **kwargs): + """ + Perform a write operation on the storages. - return _connected + :param method: the method to call. + :type method: str + :param args: the arguments to pass to the method. + :type args: Any + :param kwargs: the keyword arguments to pass to the method. + :type kwargs: Any - def close(self): - for db_name, storage in self.storages: - try: - storage.close() - except Exception as e: - logger.critical( - f"Error while closing db engine {db_name}. " - f"{e.__class__.__name__}: {e}" - ) - raise e + :raises StorageWriteError: if the write operation fails on all the storages. + + :returns: the number of replicas where the write operation is successful. + :rtype: int + """ - def write(self, method: str, *args, **kwargs): replica_count = 0 _err_msg = f"Cannot apply write method '{method}' with {args} {kwargs}" for db_name, storage in self.storages: @@ -104,8 +83,10 @@ def write(self, method: str, *args, **kwargs): getattr(storage, method)(*args, **kwargs) replica_count += 1 except Exception as e: - logger.critical( - f"Error {_err_msg} on {db_name} {storage}: {str(e)}") + self._log_critical( + e.__class__.__name__, + f"Error {_err_msg} on {db_name} {storage}: {str(e)}" + ) if not replica_count: raise StorageWriteError(_err_msg) @@ -129,7 +110,23 @@ def update_request_object(self, document_id: str, request_object: dict) -> int: def update_response_object(self, nonce: str, state: str, response_object: dict) -> int: return self.write("update_response_object", nonce, state, response_object) - def get(self, method: str, *args, **kwargs): + def get(self, method: str, *args, **kwargs) -> Union[dict, None]: + """ + Perform a get operation on the storages. + + :param method: the method to call. + :type method: str + :param args: the arguments to pass to the method. + :type args: Any + :param kwargs: the keyword arguments to pass to the method. + :type kwargs: Any + + :raises EntryNotFound: if the entry is not found on any storage. + + :returns: the result of the first elment found on DBs. + :rtype: Union[dict, None] + """ + for db_name, storage in self.storages: try: res = getattr(storage, method)(*args, **kwargs) @@ -137,7 +134,8 @@ def get(self, method: str, *args, **kwargs): return res except EntryNotFound as e: - logger.critical( + self._log_debug( + e.__class__.__name__, f"Cannot find result by method {method} on {db_name} with {args} {kwargs}: {str(e)}" ) @@ -149,11 +147,11 @@ def get_trust_attestation(self, entity_id: str) -> Union[dict, None]: def get_trust_anchor(self, entity_id: str) -> Union[dict, None]: return self.get("get_trust_anchor", entity_id) - def has_trust_attestation(self, entity_id: str): - return self.get_trust_attestation(entity_id) + def has_trust_attestation(self, entity_id: str) -> bool: + return self.get_trust_attestation(entity_id) != None - def has_trust_anchor(self, entity_id: str): - return self.get_anchor(entity_id) + def has_trust_anchor(self, entity_id: str) -> bool: + return self.get_trust_anchor(entity_id) != None def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type) @@ -161,7 +159,7 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat def add_trust_attestation_metadata(self, entity_id: str, metadat_type: str, metadata: dict) -> str: return self.write("add_trust_attestation_metadata", entity_id, metadat_type, metadata) - def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION): + def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: return self.write("add_trust_anchor", entity_id, entity_configuration, exp, trust_type) def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: @@ -177,20 +175,6 @@ def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str] def update_trust_anchor(self, entity_id: str, entity_configuration: dict, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: return self.write("update_trust_anchor", entity_id, entity_configuration, exp, trust_type) - def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]: - for i, cache in enumerate(self.caches): - try: - cache_object, status = cache.try_retrieve( - object_name, on_not_found) - return cache_object, status, i - except Exception: - logger.critical( - f"Cannot retrieve or write cache object with identifier {object_name} on cache database {i}" - ) - raise ConnectionRefusedError( - "Cannot write cache object on any instance" - ) - def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dict: # if no cache instance exist return the object if len(self.caches): @@ -210,8 +194,9 @@ def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dic for cache_name, cache in replica_instances: try: cache.set(cache_object) - except Exception: - logger.critical( + except Exception as e: + self._log_critical( + e.__class__.__name__, f"Cannot replicate cache object with identifier {object_name} on cache {cache_name}" ) @@ -222,8 +207,9 @@ def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict: cache_object = None try: cache_object = cache.overwrite(object_name, value_gen_fn) - except Exception: - logger.critical( + except Exception as e: + self._log_critical( + e.__class__.__name__, f"Cannot overwrite cache object with identifier {object_name} on cache {cache_name}" ) return cache_object @@ -236,14 +222,114 @@ def exists_by_state_and_session_id(self, state: str, session_id: str = "") -> bo return True return False - def get_by_state(self, state: str): + def get_by_state(self, state: str) -> Union[dict, None]: return self.get_by_state_and_session_id(state=state) - def get_by_nonce_state(self, state: str, nonce: str): + def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]: return self.get('get_by_nonce_state', state=state, nonce=nonce) - def get_by_state_and_session_id(self, state: str, session_id: str = ""): + def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: return self.get("get_by_state_and_session_id", state, session_id) - def get_by_session_id(self, session_id: str): + def get_by_session_id(self, session_id: str) -> Union[dict, None]: return self.get("get_by_session_id", session_id) + + @property + def is_connected(self): + _connected = False + _cons = {} + for db_name, storage in self.storages: + try: + _connected = storage.is_connected() + _cons[db_name] = _connected + except Exception as e: + self._log_debug( + e.__class__.__name__, + f"Error while checking db engine connection on {db_name}: {e} " + ) + + if True in _cons.values() and not all(_cons.values()): + self._log_warning( + "DB Engine", + f"Not all the storage are found available, storages misalignment: " + f"{_cons}" + ) + + return _connected + + def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]: + """ + Try to retrieve an object from the cache. If the object is not found, call the on_not_found function. + + :param object_name: the name of the object to retrieve. + :type object_name: str + :param on_not_found: the function to call if the object is not found. + :type on_not_found: Callable[[], str] + + :raises ConnectionRefusedError: if the object cannot be retrieved on any instance. + + :returns: a tuple with the retrieved object, a status and the index of the cache instance. + :rtype: tuple[dict, RetrieveStatus, int] + """ + + for i, (cache_name, cache_istance) in enumerate(self.caches): + try: + cache_object, status = cache_istance.try_retrieve( + object_name, on_not_found) + return cache_object, status, i + except Exception as e: + self._log_critical( + e.__class__.__name__, + f"Cannot retrieve cache object with identifier {object_name} on cache database {cache_name}" + ) + raise ConnectionRefusedError( + "Cannot write cache object on any instance" + ) + + def _close_list(self, db_list: list[Tuple[str,BaseDB]]) -> None: + """ + Close a list of db. + + :param db_list: the list of db to close. + :type db_list: list[Tuple[str,BaseDB]] + + :raises Exception: if an error occurs while closing a db. + """ + + for db_name, db in db_list: + try: + db.close() + except Exception as e: + self._log_critical( + e.__class__.__name__, + f"Error while closing db engine {db_name}: {e}" + ) + raise e + + def _handle_instance(self, instance: dict) -> dict[BaseStorage | None, BaseCache | None]: + """ + Handle the initialization of a storage/cache instance. + + :param instance: the instance configuration. + :type instance: dict + + :returns: a tuple with the storage and cache instance. + :rtype: tuple[BaseStorage | None, BaseCache | None] + """ + cache_conf = instance.get("cache", None) + storage_conf = instance.get("storage", None) + + storage_instance = None + if storage_conf: + module = importlib.import_module(storage_conf["module"]) + instance_class = getattr(module, storage_conf["class"]) + storage_instance = instance_class( + **storage_conf.get("init_params", {})) + + cache_instance = None + if cache_conf: + module = importlib.import_module(cache_conf["module"]) + instance_class = getattr(module, cache_conf["class"]) + cache_instance = instance_class(**cache_conf["init_params"]) + + return storage_instance, cache_instance From 3c084a583e6b9c3b279718c90738a39bdbc1d7d6 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:29:28 +0100 Subject: [PATCH 79/88] feat: documentation and refactoring --- pyeudiw/storage/mongo_cache.py | 61 +++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/pyeudiw/storage/mongo_cache.py b/pyeudiw/storage/mongo_cache.py index f44dc859..12a877f8 100644 --- a/pyeudiw/storage/mongo_cache.py +++ b/pyeudiw/storage/mongo_cache.py @@ -4,32 +4,40 @@ import pymongo from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus +from pymongo.collection import Collection +from pymongo.mongo_client import MongoClient +from pymongo.database import Database class MongoCache(BaseCache): + """ + MongoDB cache implementation. + """ + def __init__(self, conf: dict, url: str, connection_params: dict = None) -> None: + """ + Create a MongoCache istance. + + :param conf: the configuration of the cache. + :type conf: dict + :param url: the url of the MongoDB server. + :type url: str + :param connection_params: the connection parameters. + :type connection_params: dict, optional + """ super().__init__() self.storage_conf = conf self.url = url self.connection_params = connection_params - self.client = None - self.db = None - - def _connect(self): - if not self.client or not self.client.server_info(): - self.client = pymongo.MongoClient( - self.url, **self.connection_params) - self.db = getattr(self.client, self.storage_conf["db_name"]) - self.collection = getattr(self.db, "cache_storage") + self.client: MongoClient = None + self.db: Database = None + self.collection: Collection = None - def _gen_cache_object(self, object_name: str, data: str): - return { - "object_name": object_name, - "data": data, - "creation_date": datetime.now().isoformat() - } + def close(self) -> None: + self._connect() + self.client.close() def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus]: self._connect() @@ -72,3 +80,26 @@ def set(self, data: dict) -> dict: self._connect() return self.collection.insert_one(data) + + def _connect(self) -> None: + if not self.client or not self.client.server_info(): + self.client = pymongo.MongoClient( + self.url, **self.connection_params) + self.db = getattr(self.client, self.storage_conf["db_name"]) + self.collection = getattr(self.db, "cache_storage") + + def _gen_cache_object(self, object_name: str, data: str) -> dict: + """ + Helper function to generate a cache object. + + :param object_name: the name of the object. + :type object_name: str + :param data: the data to store. + :type data: str + """ + + return { + "object_name": object_name, + "data": data, + "creation_date": datetime.now().isoformat() + } From 3a741d984614f107c0689903a5add20cc7b6f026 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Wed, 20 Dec 2023 13:30:05 +0100 Subject: [PATCH 80/88] fix: varius minor fixs --- pyeudiw/storage/mongo_storage.py | 34 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pyeudiw/storage/mongo_storage.py b/pyeudiw/storage/mongo_storage.py index 541c5471..eda332cc 100644 --- a/pyeudiw/storage/mongo_storage.py +++ b/pyeudiw/storage/mongo_storage.py @@ -27,7 +27,7 @@ def __init__(self, conf: dict, url: str, connection_params: dict = {}) -> None: self.db = None @property - def is_connected(self): + def is_connected(self) -> bool: if not self.client: return False try: @@ -80,7 +80,7 @@ def get_by_nonce_state(self, nonce: str, state: str | None) -> dict: return document - def get_by_session_id(self, session_id: str): + def get_by_session_id(self, session_id: str) -> Union[dict, None]: self._connect() query = {"session_id": session_id} document = self.sessions.find_one(query) @@ -92,7 +92,7 @@ def get_by_session_id(self, session_id: str): return document - def get_by_state_and_session_id(self, state: str, session_id: str = ""): + def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: self._connect() query = {"state": state} if session_id: @@ -125,7 +125,7 @@ def init_session(self, document_id: str, session_id: str, state: str) -> str: return document_id - def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, attestation: dict): + def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, attestation: dict) -> UpdateResult: self._connect() update_result: UpdateResult = self.sessions.update_one( {"document_id": document_id}, @@ -142,7 +142,7 @@ def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, att return update_result - def update_request_object(self, document_id: str, request_object: dict): + def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult: self.get_by_id(document_id) documentStatus = self.sessions.update_one( {"document_id": document_id}, @@ -177,7 +177,7 @@ def set_finalized(self, document_id: str): ) return update_result - def update_response_object(self, nonce: str, state: str, internal_response: dict): + def update_response_object(self, nonce: str, state: str, internal_response: dict) -> UpdateResult: document = self.get_by_nonce_state(nonce, state) document_id = document["_id"] document_status = self.sessions.update_one( @@ -190,24 +190,24 @@ def update_response_object(self, nonce: str, state: str, internal_response: dict return document_status - def _get_trust_attestation(self, collection: str, entity_id: str) -> dict: + def _get_trust_attestation(self, collection: str, entity_id: str) -> dict | None: self._connect() db_collection = getattr(self, collection) return db_collection.find_one({"entity_id": entity_id}) - def get_trust_attestation(self, entity_id: str): + def get_trust_attestation(self, entity_id: str) -> dict | None: return self._get_trust_attestation("trust_attestations", entity_id) - def get_trust_anchor(self, entity_id: str): + def get_trust_anchor(self, entity_id: str) -> dict | None: return self._get_trust_attestation("trust_anchors", entity_id) - def _has_trust_attestation(self, collection: str, entity_id: str): - return self._get_trust_attestation(collection, entity_id) + def _has_trust_attestation(self, collection: str, entity_id: str) -> bool: + return self._get_trust_attestation(collection, entity_id) != None - def has_trust_attestation(self, entity_id: str): + def has_trust_attestation(self, entity_id: str) -> bool: return self._has_trust_attestation("trust_attestations", entity_id) - def has_trust_anchor(self, entity_id: str): + def has_trust_anchor(self, entity_id: str) -> bool: return self._has_trust_attestation("trust_anchors", entity_id) def _add_entry( @@ -257,7 +257,7 @@ def _update_anchor_metadata(self, entity: dict, attestation: list[str], exp: dat return entity - def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType): + def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: entity = { "entity_id": entity_id, "federation": {}, @@ -267,7 +267,7 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat updated_entity = self._update_attestation_metadata(entity, attestation, exp, trust_type) - self._add_entry( + return self._add_entry( "trust_attestations", entity_id, updated_entity, exp ) @@ -285,7 +285,7 @@ def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, met def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): if self.has_trust_anchor(entity_id): - self.update_trust_anchor(entity_id, entity_configuration, exp, trust_type) + return self.update_trust_anchor(entity_id, entity_configuration, exp, trust_type) else: entity = { "entity_id": entity_id, @@ -294,7 +294,7 @@ def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datet } updated_entity = self._update_anchor_metadata(entity, entity_configuration, exp, trust_type) - self._add_entry("trust_anchors", entity_id, updated_entity, exp) + return self._add_entry("trust_anchors", entity_id, updated_entity, exp) def _update_trust_attestation(self, collection: str, entity_id: str, entity: dict) -> str: if not self._has_trust_attestation(collection, entity_id): From f766608f6541892aa9014963bcd184aeb0797c6d Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:08:07 +0100 Subject: [PATCH 81/88] docs: documented content of mobile.py --- pyeudiw/tools/mobile.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pyeudiw/tools/mobile.py b/pyeudiw/tools/mobile.py index 0c7372f6..b18a6f7a 100644 --- a/pyeudiw/tools/mobile.py +++ b/pyeudiw/tools/mobile.py @@ -1,7 +1,16 @@ from device_detector import DeviceDetector -def is_smartphone(useragent: str): +def is_smartphone(useragent: str) -> bool: + """Check if the useragent is a smartphone + + :param useragent: The useragent to check + :type useragent: str + :return: True if the useragent is a smartphone else False + :rtype: bool + """ + device = DeviceDetector(useragent).parse() if device.device_type() == 'smartphone': return True + return False From 693acfce4b709eadcd4d74190fdd6d6bbb6346fa Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:08:40 +0100 Subject: [PATCH 82/88] docs: documented content of schema_utils.py --- pyeudiw/tools/schema_utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pyeudiw/tools/schema_utils.py b/pyeudiw/tools/schema_utils.py index 9240de53..a8ba8757 100644 --- a/pyeudiw/tools/schema_utils.py +++ b/pyeudiw/tools/schema_utils.py @@ -13,7 +13,18 @@ ] -def check_algorithm(alg: str, info: FieldValidationInfo): +def check_algorithm(alg: str, info: FieldValidationInfo) -> None: + """ + Check if the algorithm is supported by the relaying party. + + :param alg: The algorithm to check + :type alg: str + :param info: The field validation info + :type info: FieldValidationInfo + + :raises ValueError: If the algorithm is not supported + """ + if not info.context: supported_algorithms = _default_supported_algorithms else: From 77e0bcd5f37922d49a2b8a99af1993a9918158a3 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:09:07 +0100 Subject: [PATCH 83/88] docs: documented content of utils.py --- pyeudiw/tools/utils.py | 67 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 6 deletions(-) diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py index 895dab69..cec68985 100644 --- a/pyeudiw/tools/utils.py +++ b/pyeudiw/tools/utils.py @@ -10,7 +10,18 @@ logger = logging.getLogger(__name__) -def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime.tzinfo = datetime.timezone.utc): +def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime.tzinfo = datetime.timezone.utc) -> datetime.datetime: + """ + Make a datetime timezone aware. + + :param dt: The datetime to make timezone aware + :type dt: datetime.datetime + :param tz: The timezone to use + :type tz: datetime.timezone | datetime.tzinfo + + :returns: The timezone aware datetime + :rtype: datetime.datetime + """ if dt.tzinfo is None: return dt.replace(tzinfo=tz) else: @@ -18,16 +29,41 @@ def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime. def iat_now() -> int: + """ + Get the current timestamp in seconds. + + :returns: The current timestamp in seconds + :rtype: int + """ return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) def exp_from_now(minutes: int = 33) -> int: + """ + Get the expiration timestamp in seconds for the given minutes from now. + + :param minutes: The minutes from now + :type minutes: int + + :returns: The timestamp in seconds for the given minutes from now + :rtype: int + """ now = datetime.datetime.now(datetime.timezone.utc) return int((now + datetime.timedelta(minutes=minutes)).timestamp()) -def datetime_from_timestamp(value) -> datetime.datetime: - return make_timezone_aware(datetime.datetime.fromtimestamp(value)) +def datetime_from_timestamp(timestamp: int | float) -> datetime.datetime: + """ + Get a datetime from a timestamp. + + :param value: The timestamp + :type value: int | float + + :returns: The datetime + :rtype: datetime.datetime + """ + + return make_timezone_aware(datetime.datetime.fromtimestamp(timestamp)) def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]: @@ -57,9 +93,19 @@ def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = T return responses -def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) -> dict: +def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list[dict] = []) -> dict: """ - get jwks or jwks_uri or signed_jwks_uri + Get jwks or jwks_uri or signed_jwks_uri + + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param metadata: metadata of the entity + :type metadata: dict + :param federation_jwks: jwks of the federation + :type federation_jwks: list + + :returns: A list of responses. + :rtype: list[dict] """ jwks_list = [] if metadata.get('jwks'): @@ -85,5 +131,14 @@ def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) -> return jwks_list -def random_token(n=254): +def random_token(n=254) -> str: + """ + Generate a random token. + + :param n: The length of the token + :type n: int + + :returns: The random token + :rtype: str + """ return token_hex(n) From faf57deb73bc7de28fc5971b1466f95e78830111 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:09:53 +0100 Subject: [PATCH 84/88] docs: documented content of trust_anchors.py --- pyeudiw/trust/trust_anchors.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pyeudiw/trust/trust_anchors.py b/pyeudiw/trust/trust_anchors.py index 95e3f890..7460b757 100644 --- a/pyeudiw/trust/trust_anchors.py +++ b/pyeudiw/trust/trust_anchors.py @@ -7,7 +7,18 @@ logger = logging.getLogger(__name__) -def update_trust_anchors_ecs(trust_anchors: list, db: DBEngine, httpc_params: dict): +def update_trust_anchors_ecs(trust_anchors: list, db: DBEngine, httpc_params: dict) -> None: + """ + Update the trust anchors entity configurations. + + :param trust_anchors: The trust anchors + :type trust_anchors: list + :param db: The database engine + :type db: DBEngine + :param httpc_params: The HTTP client parameters + :type httpc_params: dict + """ + ta_ecs = get_entity_configurations( trust_anchors, httpc_params=httpc_params ) From f7faa9e4d7db45b884f930e177c4018a4907c7fd Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:10:12 +0100 Subject: [PATCH 85/88] docs: documented content of trust_chain.py --- pyeudiw/trust/trust_chain.py | 59 +++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/pyeudiw/trust/trust_chain.py b/pyeudiw/trust/trust_chain.py index bc9c0820..7e7e5b0f 100644 --- a/pyeudiw/trust/trust_chain.py +++ b/pyeudiw/trust/trust_chain.py @@ -1,28 +1,28 @@ import logging -from typing import List from typing import Optional -from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac -from pyeudiw.jwt.utils import decode_jwt_payload +from pyeudiw.tools.base_logger import BaseLogger __author__ = "Roland Hedberg" __license__ = "Apache 2.0" __version__ = "" -logger = logging.getLogger(__name__) - -class TrustChain(object): +class TrustChain(BaseLogger): """ Class in which to store the parsed result from applying metadata policies on a metadata statement. """ - def __init__(self, exp: int = 0, - verified_chain: Optional[list] = None): + verified_chain: Optional[list] = None) -> None: """ + Create a TrustChain instance. + :param exp: Expiration time + :type exp: int + :param verified_chain: The verified chain + :type verified_chain: list """ self.anchor = "" self.iss_path = [] @@ -32,10 +32,22 @@ def __init__(self, self.verified_chain = verified_chain self.combined_policy = {} - def keys(self): + def keys(self) -> list[str]: + """ + Returns the metadata fields keys + + :return: The metadata fields keys + :rtype: list + """ return self.metadata.keys() - def items(self): + def items(self) -> list[tuple[str, dict]]: + """ + Returns the metadata fields items + + :return: The metadata fields items + :rtype: list[tuple[str, dict]] + """ return self.metadata.items() def __getitem__(self, item): @@ -50,23 +62,42 @@ def claims(self): """ return self.metadata - def is_expired(self): + def is_expired(self) -> bool: + """ + Check if the trust chain is expired. + + :return: True if the trust chain is expired else False + :rtype: bool + """ now = utc_time_sans_frac() if self.exp < now: - logger.debug(f'is_expired: {self.exp} < {now}') + self._log_debug( + "Trust chain", + f'is_expired: {self.exp} < {now}' + ) return True else: return False - def export_chain(self): + def export_chain(self) -> list: """ Exports the verified chain in such a way that it can be used as value on the trust_chain claim in an authorization or explicit registration request. - :return: + + :return: The exported chain in reverse order + :rtype: list """ _chain = self.verified_chain _chain.reverse() return _chain def set_combined_policy(self, entity_type: str, combined_policy: dict): + """ + Set the combined policy for the given entity type. + + :param entity_type: The entity type + :type entity_type: str + :param combined_policy: The combined policy + :type combined_policy: dict + """ self.combined_policy[entity_type] = combined_policy From 2427dce969a11e8ba62871f61ff4eb39d46d0275 Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:10:35 +0100 Subject: [PATCH 86/88] docs: documented content of verify.py --- pyeudiw/x509/verify.py | 69 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/pyeudiw/x509/verify.py b/pyeudiw/x509/verify.py index 8f647e7f..a5835318 100644 --- a/pyeudiw/x509/verify.py +++ b/pyeudiw/x509/verify.py @@ -10,6 +10,15 @@ logger = logging.getLogger(__name__) def _verify_x509_certificate_chain(pems: list[str]): + """ + Verify the x509 certificate chain. + + :param pems: The x509 certificate chain + :type pems: list[str] + + :returns: True if the x509 certificate chain is valid else False + :rtype: bool + """ try: store = crypto.X509Store() @@ -32,6 +41,16 @@ def _verify_x509_certificate_chain(pems: list[str]): return False def _check_chain_len(pems: list) -> bool: + """ + Check the x509 certificate chain lenght. + + :param pems: The x509 certificate chain + :type pems: list + + :returns: True if the x509 certificate chain lenght is valid else False + :rtype: bool + """ + chain_len = len(pems) if chain_len < 2: @@ -42,6 +61,15 @@ def _check_chain_len(pems: list) -> bool: return True def _check_datetime(exp: datetime | None): + """ + Check the x509 certificate chain expiration date. + + :param exp: The x509 certificate chain expiration date + :type exp: datetime.datetime | None + + :returns: True if the x509 certificate chain expiration date is valid else False + :rtype: bool + """ if exp == None: return True @@ -53,6 +81,18 @@ def _check_datetime(exp: datetime | None): return True def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) -> bool: + """ + Verify the x509 attestation certificate chain. + + :param x5c: The x509 attestation certificate chain + :type x5c: list[bytes] + :param exp: The x509 attestation certificate chain expiration date + :type exp: datetime.datetime | None + + :returns: True if the x509 attestation certificate chain is valid else False + :rtype: bool + """ + if not _check_chain_len(x5c) or not _check_datetime(exp): return False @@ -61,6 +101,17 @@ def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) return _verify_x509_certificate_chain(pems) def verify_x509_anchor(pem_str: str, exp: datetime | None = None) -> bool: + """ + Verify the x509 anchor certificate. + + :param pem_str: The x509 anchor certificate + :type pem_str: str + :param exp: The x509 anchor certificate expiration date + :type exp: datetime.datetime | None + + :returns: True if the x509 anchor certificate is valid else False + :rtype: bool + """ if not _check_datetime(exp): return False @@ -72,10 +123,28 @@ def verify_x509_anchor(pem_str: str, exp: datetime | None = None) -> bool: return _verify_x509_certificate_chain(pems) def get_issuer_from_x5c(x5c: list[bytes]) -> str: + """ + Get the issuer from the x509 certificate chain. + + :param x5c: The x509 certificate chain + :type x5c: list[bytes] + + :returns: The issuer + :rtype: str + """ cert = load_der_x509_certificate(x5c[-1]) return cert.subject.rfc4514_string().split("=")[1] def is_der_format(cert: bytes) -> str: + """ + Check if the certificate is in DER format. + + :param cert: The certificate + :type cert: bytes + + :returns: True if the certificate is in DER format else False + :rtype: bool + """ try: pem = DER_cert_to_PEM_cert(cert) crypto.load_certificate(crypto.FILETYPE_PEM, str(pem)) From 3b435293405231dce88d84ca2d0c4088055a62ac Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 15:12:49 +0100 Subject: [PATCH 87/88] docs: fixed docs --- pyeudiw/satosa/backend.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index 5fea7b9b..a92f44c9 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -153,8 +153,8 @@ def start_auth(self, context: Context, internal_request) -> Response: def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> Response: """ - This endpoint is called by the frontend before calling the request endpoint. - It initializes the session and returns the request_uri to be used by the frontend. + This endpoint is called by the User-Agent/Wallet Instance before calling the request endpoint. + It initializes the session and returns the request_uri to be used by the User-Agent/Wallet Instance. :type context: the context of current request :param context: the request context @@ -219,12 +219,12 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: """ - This endpoint is called by the frontend after the user has been authenticated. + This endpoint is called by the User-Agent/Wallet Instance after the user has been authenticated. :type context: the context of current request :param context: the request context - :return: a redirect to the frontend, if is in same device flow, or a json response if is in cross device flow. + :return: a redirect to the User-Agent/Wallet Instance, if is in same device flow, or a json response if is in cross device flow. :rtype: Redirect | JsonResponse """ @@ -397,7 +397,7 @@ def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe def request_endpoint(self, context: Context, *args) -> JsonResponse: """ - This endpoint is called by the frontend to retrieve the signed signed Request Object. + This endpoint is called by the User-Agent/Wallet Instance to retrieve the signed signed Request Object. :type context: the context of current request :param context: the request context @@ -483,7 +483,7 @@ def request_endpoint(self, context: Context, *args) -> JsonResponse: def get_response_endpoint(self, context: Context) -> Response: """ - This endpoint is called by the frontend to retrieve the response of the authentication. + This endpoint is called by the User-Agent/Wallet Instance to retrieve the response of the authentication. :param context: the request context :type context: satosa.context.Context @@ -533,7 +533,7 @@ def get_response_endpoint(self, context: Context) -> Response: def status_endpoint(self, context: Context) -> JsonResponse: """ - This endpoint is called by the frontend the url to the response endpoint to finalize the process. + This endpoint is called by the User-Agent/Wallet Instance the url to the response endpoint to finalize the process. :param context: the request context :type context: satosa.context.Context From fc4f7d96333a17558d428c1d245ae86a31c251ef Mon Sep 17 00:00:00 2001 From: PascalDR Date: Thu, 21 Dec 2023 17:12:39 +0100 Subject: [PATCH 88/88] fix: fixed functions name --- pyeudiw/satosa/backend.py | 12 ++++---- pyeudiw/tests/satosa/test_backend.py | 44 ++++++++++++++-------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index cdfa7074..30a081ec 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -213,9 +213,9 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> ) return Response(result, content="text/html; charset=utf8", status="200") - def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: + def request_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: """ - This endpoint is called by the User-Agent/Wallet Instance after the user has been authenticated. + This endpoint is called by the User-Agent/Wallet Instance to retrieve the signed signed Request Object. :type context: the context of current request :param context: the request context @@ -224,7 +224,7 @@ def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe :rtype: Redirect | JsonResponse """ - self._log_function_debug("redirect_endpoint", context, "args", args) + self._log_function_debug("request_endpoint", context, "args", args) if context.request_method.lower() != 'post': # raise BadRequestError("HTTP Method not supported") @@ -391,9 +391,9 @@ def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe status="200" ) - def request_endpoint(self, context: Context, *args) -> JsonResponse: + def redirect_endpoint(self, context: Context, *args) -> JsonResponse: """ - This endpoint is called by the User-Agent/Wallet Instance to retrieve the signed signed Request Object. + This endpoint is called by the User-Agent/Wallet Instance after the user has been authenticated. :type context: the context of current request :param context: the request context @@ -404,7 +404,7 @@ def request_endpoint(self, context: Context, *args) -> JsonResponse: :rtype: JsonResponse """ - self._log_function_debug("request_endpoint", context, "args", args) + self._log_function_debug("redirect_endpoint", context, "args", args) # check DPOP for WIA if any try: diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index 03a4be00..fcf65f69 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -241,9 +241,9 @@ def test_vp_validation_in_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "400" - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "400" + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Error while validating VP: unexpected value." @@ -283,9 +283,9 @@ def test_vp_validation_in_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "400" - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "400" + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Error while validating VP: vp has no nonce." @@ -296,9 +296,9 @@ def test_vp_validation_in_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "400" - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "400" + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "DirectPostResponse content parse and validation error. Single VPs are faulty." @@ -382,8 +382,8 @@ def test_redirect_endpoint(self, context): } # no nonce - redirect_endpoint = self.backend.redirect_endpoint(context) - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert "nonce" in msg["error_description"] assert "missing" in msg["error_description"] @@ -395,8 +395,8 @@ def test_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Session lookup by state value failed" @@ -407,8 +407,8 @@ def test_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Session lookup by state value failed" @@ -422,8 +422,8 @@ def test_redirect_endpoint(self, context): self.backend.db_engine.update_request_object( document_id=doc_id, request_object={"nonce": nonce, "state": state}) - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "302 Found" + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "302 Found" def test_request_endpoint(self, context): @@ -509,12 +509,12 @@ def test_request_endpoint(self, context): request_uri = CONFIG['metadata']['request_uris'][0] context.request_uri = request_uri - request_endpoint = self.backend.request_endpoint(context) + redirect_endpoint = self.backend.redirect_endpoint(context) - assert request_endpoint - assert request_endpoint.status == "200" - assert request_endpoint.message - msg = json.loads(request_endpoint.message) + assert redirect_endpoint + assert redirect_endpoint.status == "200" + assert redirect_endpoint.message + msg = json.loads(redirect_endpoint.message) assert msg["response"] header = decode_jwt_header(msg["response"])