diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index e7fb4914..b5f494cf 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from typing import Union @@ -6,14 +8,14 @@ from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jwk.rsa import new_rsa_key -from .exceptions import InvalidKid, KidNotFoundError +from .exceptions import InvalidKid, KidNotFoundError, InvalidJwk KEY_TYPES_FUNC = dict( EC=new_ec_key, RSA=new_rsa_key ) -class JWK(): +class JWK: """ The class representing a JWK istance """ @@ -117,6 +119,34 @@ def as_dict(self) -> dict: def __repr__(self): # private part! return self.as_json() + +class RSAJWK(JWK): + def __init__(self, key: dict | None = None, hash_func: str = "SHA-256") -> None: + super().__init__(key, "RSA", hash_func, None) + +class ECJWK(JWK): + def __init__(self, key: dict | None = None, hash_func: str = "SHA-256", ec_crv: str = "P-256") -> None: + super().__init__(key, "EC", hash_func, ec_crv) + +def jwk_form_dict(key: dict, hash_func: str = "SHA-256") -> RSAJWK | ECJWK: + """ + Returns a JWK instance from a dict. + + :param key: a dict that represents the key. + :type key: dict + + :returns: a JWK instance. + :rtype: JWK + """ + _kty = key.get('kty', None) + + if _kty == None or _kty not in ['EC', 'RSA']: + raise InvalidJwk("Invalid JWK") + elif _kty == "RSA": + return RSAJWK(key, hash_func) + else: + ec_crv = key.get('crv', "P-256") + return ECJWK(key, hash_func, ec_crv) def find_jwk(kid: str, jwks: list[dict], as_dict: bool=True) -> dict | JWK: """ diff --git a/pyeudiw/jwk/exceptions.py b/pyeudiw/jwk/exceptions.py index b3a84613..056744f7 100644 --- a/pyeudiw/jwk/exceptions.py +++ b/pyeudiw/jwk/exceptions.py @@ -10,3 +10,6 @@ class InvalidKid(Exception): class JwkError(Exception): pass + +class InvalidJwk(Exception): + pass \ No newline at end of file diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index a00ec60c..30a081ec 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -41,7 +41,7 @@ from .exceptions import HTTPError from .base_http_error_handler import BaseHTTPErrorHandler -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP, BaseHTTPErrorHandler, BaseLogger): """ @@ -152,8 +152,8 @@ def start_auth(self, context: Context, internal_request) -> Response: def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> Response: """ - This endpoint is called by the frontend before calling the request endpoint. - It initializes the session and returns the request_uri to be used by the frontend. + This endpoint is called by the User-Agent/Wallet Instance before calling the request endpoint. + It initializes the session and returns the request_uri to be used by the User-Agent/Wallet Instance. :type context: the context of current request :param context: the request context @@ -213,18 +213,18 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> ) return Response(result, content="text/html; charset=utf8", status="200") - def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: + def request_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: """ - This endpoint is called by the frontend after the user has been authenticated. + This endpoint is called by the User-Agent/Wallet Instance to retrieve the signed signed Request Object. 
         :type context: the context of current request
         :param context: the request context

-        :return: a redirect to the frontend, if is in same device flow, or a json response if is in cross device flow.
+        :return: a redirect to the User-Agent/Wallet Instance if in the same device flow, or a json response if in the cross device flow.
         :rtype: Redirect | JsonResponse
         """

-        self._log_function_debug("redirect_endpoint", context, "args", args)
+        self._log_function_debug("request_endpoint", context, "args", args)

         if context.request_method.lower() != 'post':
             # raise BadRequestError("HTTP Method not supported")
@@ -391,9 +391,9 @@ def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe
             status="200"
         )

-    def request_endpoint(self, context: Context, *args) -> JsonResponse:
+    def redirect_endpoint(self, context: Context, *args) -> JsonResponse:
         """
-        This endpoint is called by the frontend to retrieve the signed signed Request Object.
+        This endpoint is called by the User-Agent/Wallet Instance after the user has been authenticated.

         :type context: the context of current request
         :param context: the request context
@@ -404,7 +404,7 @@ def request_endpoint(self, context: Context, *args) -> JsonResponse:
         :rtype: JsonResponse
         """

-        self._log_function_debug("request_endpoint", context, "args", args)
+        self._log_function_debug("redirect_endpoint", context, "args", args)

         # check DPOP for WIA if any
         try:
@@ -479,7 +479,7 @@ def request_endpoint(self, context: Context, *args) -> JsonResponse:
     def get_response_endpoint(self, context: Context) -> Response:
         """
-        This endpoint is called by the frontend to retrieve the response of the authentication.
+        This endpoint is called by the User-Agent/Wallet Instance to retrieve the response of the authentication.

         :param context: the request context
         :type context: satosa.context.Context
@@ -529,7 +529,7 @@ def get_response_endpoint(self, context: Context) -> Response:
     def status_endpoint(self, context: Context) -> JsonResponse:
         """
-        This endpoint is called by the frontend the url to the response endpoint to finalize the process.
+        This endpoint is called by the User-Agent/Wallet Instance to get the url of the response endpoint and finalize the process.
:param context: the request context :type context: satosa.context.Context diff --git a/pyeudiw/satosa/base_http_error_handler.py b/pyeudiw/satosa/base_http_error_handler.py index 785042c5..cf7cb83f 100644 --- a/pyeudiw/satosa/base_http_error_handler.py +++ b/pyeudiw/satosa/base_http_error_handler.py @@ -1,5 +1,5 @@ from satosa.context import Context -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger from .exceptions import EmptyHTTPError from pyeudiw.satosa.response import JsonResponse diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index 0abc0ef6..778b0696 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -9,7 +9,7 @@ from satosa.context import Context from pydantic import ValidationError -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger from .base_http_error_handler import BaseHTTPErrorHandler class BackendDPoP(BaseHTTPErrorHandler, BaseLogger): diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index ffaea9ca..86540b31 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -14,7 +14,7 @@ from pyeudiw.trust import TrustEvaluationHelper from pyeudiw.trust.trust_anchors import update_trust_anchors_ecs -from .base_logger import BaseLogger +from pyeudiw.tools.base_logger import BaseLogger class BackendTrust(BaseLogger): """ diff --git a/pyeudiw/sd_jwt/__init__.py b/pyeudiw/sd_jwt/__init__.py index 3855623b..c8cab393 100644 --- a/pyeudiw/sd_jwt/__init__.py +++ b/pyeudiw/sd_jwt/__init__.py @@ -1,4 +1,3 @@ -import cryptojwt import json from jwcrypto.common import base64url_encode @@ -23,12 +22,45 @@ import jwcrypto from typing import Any - +from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.ec import ECKey +from cryptography.hazmat.backends.openssl.rsa import _RSAPrivateKey class TrustChainSDJWTIssuer(SDJWTIssuer): - def __init__(self, user_claims: Dict, issuer_key, holder_key=None, sign_alg=None, add_decoy_claims: bool = True, serialization_format: str = "compact", additional_headers: dict = {}): + """ + Class for issue SD-JWT of TrustChain. + """ + def __init__( + self, + user_claims: Dict[str, Any], + issuer_key: dict, + holder_key: dict | None = None, + sign_alg: str | None = None, + add_decoy_claims: bool = True, + serialization_format: str = "compact", + additional_headers: dict = {} + ) -> None: + """ + Crate an istance of TrustChainSDJWTIssuer. + + :param user_claims: the claims of the SD-JWT. + :type user_claims: dict + :param issuer_key: the issuer key. + :type issuer_key: dict + :param holder_key: the holder key. + :type holder_key: dict | None + :param sign_alg: the signing algorithm. + :type sign_alg: str | None + :param add_decoy_claims: if True add decoy claims. + :type add_decoy_claims: bool + :param serialization_format: the serialization format. + :type serialization_format: str + :param additional_headers: additional headers. + :type additional_headers: dict + """ + self.additional_headers = additional_headers - sign_alg = DEFAULT_SIG_KTY_MAP[issuer_key.kty] + sign_alg = sign_alg if sign_alg else DEFAULT_SIG_KTY_MAP[issuer_key.kty] super().__init__( user_claims, @@ -40,6 +72,9 @@ def __init__(self, user_claims: Dict, issuer_key, holder_key=None, sign_alg=None ) def _create_signed_jws(self): + """ + Creates the signed JWS. 
+ """ self.sd_jwt = JWS(payload=dumps(self.sd_jwt_payload)) _protected_headers = {"alg": self._sign_alg} @@ -67,9 +102,19 @@ def _create_signed_jws(self): self.serialized_sd_jwt = dumps(jws_content) -def _serialize_key(key, **kwargs): +def _serialize_key( + key: RSAKey | ECKey | JWK | dict, + **kwargs + ) -> dict: + """ + Serialize a key into dict. + + :param key: the key to serialize. + :type key: RSAKey | ECKey | JWK | dict - if isinstance(key, cryptojwt.jwk.rsa.RSAKey): + :returns: the serialized key into a dict. + """ + if isinstance(key, RSAKey) or isinstance(key, ECKey): key = key.serialize() elif isinstance(key, JWK): key = key.as_dict() @@ -80,7 +125,19 @@ def _serialize_key(key, **kwargs): return key -def pk_encode_int(i, bit_size=None): +def pk_encode_int(i: str, bit_size: int = None) -> str: + """ + Encode an integer as a base64url string with padding. + + :param i: the integer to encode. + :type i: str + :param bit_size: the bit size of the integer. + :type bit_size: int + + :returns: the encoded integer. + :rtype: str + """ + extend = 0 if bit_size is not None: extend = ((bit_size + 7) // 8) * 2 @@ -93,7 +150,22 @@ def pk_encode_int(i, bit_size=None): return base64url_encode(unhexlify(extend * '0' + hexi)) -def import_pyca_pri_rsa(key, **params): +def import_pyca_pri_rsa(key: _RSAPrivateKey, **params) -> jwcrypto.jwk.JWK: + """ + Import a private RSA key from a PyCA object. + + :param key: the key to import. + :type key: RSAKey | ECKey + + :raises ValueError: if the key is not a PyCA RSAKey object. + + :returns: the imported key. + :rtype: RSAKey + """ + + if not isinstance(key, _RSAPrivateKey): + raise ValueError("key must be a ssl RSAPrivateKey object") + pn = key.private_numbers() params.update( kty='RSA', @@ -129,7 +201,19 @@ def import_ec(key, **params): ) return jwcrypto.jwk.JWK(**params) -def _adapt_keys(issuer_key: JWK, holder_key: JWK): +def _adapt_keys(issuer_key: JWK, holder_key: JWK) -> dict: + """ + Adapt the keys to the SD-JWT library. + + :param issuer_key: the issuer key. + :type issuer_key: JWK + :param holder_key: the holder key. + :type holder_key: JWK + + :returns: the adapted keys as a dict. + :rtype: dict + """ + # _iss_key = issuer_key.key.serialize(private=True) # _iss_key['key_ops'] = 'sign' @@ -152,11 +236,45 @@ def _adapt_keys(issuer_key: JWK, holder_key: JWK): ) -def load_specification_from_yaml_string(yaml_specification: str): +def load_specification_from_yaml_string(yaml_specification: str) -> dict: + """ + Load a specification from a yaml string. + + :param yaml_specification: the yaml string. + :type yaml_specification: str + + :returns: the specification as a dict. + :rtype: dict + """ + return _yaml_load_specification(StringIO(yaml_specification)) -def issue_sd_jwt(specification: dict, settings: dict, issuer_key: JWK, holder_key: JWK, trust_chain: list[str] | None = None) -> str: +def issue_sd_jwt( + specification: Dict[str, Any], + settings: dict, + issuer_key: JWK, + holder_key: JWK, + trust_chain: list[str] | None = None + ) -> str: + """ + Issue a SD-JWT. + + :param specification: the specification of the SD-JWT. + :type specification: Dict[str, Any] + :param settings: the settings of the SD-JWT. + :type settings: dict + :param issuer_key: the issuer key. + :type issuer_key: JWK + :param holder_key: the holder key. + :type holder_key: JWK + :param trust_chain: the trust chain. + :type trust_chain: list[str] | None + + :returns: the issued SD-JWT. 
+ :rtype: str + """ + claims = { "iss": settings["issuer"], "iat": iat_now(), @@ -180,7 +298,23 @@ def issue_sd_jwt(specification: dict, settings: dict, issuer_key: JWK, holder_ke return {"jws": sdjwt_at_issuer.serialized_sd_jwt, "issuance": sdjwt_at_issuer.sd_jwt_issuance} -def _cb_get_issuer_key(issuer: str, settings: dict, adapted_keys: dict, *args, **kwargs): +def _cb_get_issuer_key(issuer: str, settings: dict, adapted_keys: dict, *args, **kwargs) -> JWK: + """ + Helper function for get the issuer key. + + :param issuer: the issuer. + :type issuer: str + :param settings: the settings of SD-JWT. + :type settings: dict + :param adapted_keys: the adapted keys. + :type adapted_keys: dict + + :raises Exception: if the issuer is unknown. + + :returns: the issuer key. + :rtype: JWK + """ + if issuer == settings["issuer"]: return adapted_keys["issuer_public_key"] else: @@ -193,6 +327,20 @@ def verify_sd_jwt( holder_key: JWK, settings: dict = {'key_binding': True} ) -> (list | dict | Any): + """ + Verify a SD-JWT. + + :param sd_jwt_presentation: the SD-JWT to verify. + :type sd_jwt_presentation: str + :param issuer_key: the issuer key. + :type issuer_key: JWK + :param holder_key: the holder key. + :type holder_key: JWK + :param settings: the settings of SD-JWT. + + :returns: the verified payload. + :rtype: list | dict | Any + """ settings.update( { diff --git a/pyeudiw/storage/base_cache.py b/pyeudiw/storage/base_cache.py index a29afeed..6292054c 100644 --- a/pyeudiw/storage/base_cache.py +++ b/pyeudiw/storage/base_cache.py @@ -1,5 +1,6 @@ from enum import Enum from typing import Callable +from .base_db import BaseDB class RetrieveStatus(Enum): @@ -7,12 +8,47 @@ class RetrieveStatus(Enum): ADDED = 1 -class BaseCache(): +class BaseCache(BaseDB): + """ + Interface class for cache storage. + """ + def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus]: + """ + Try to retrieve an object from the cache. If the object is not found, call the on_not_found function. + + :param object_name: the name of the object to retrieve. + :type object_name: str + :param on_not_found: the function to call if the object is not found. + :type on_not_found: Callable[[], str] + + :returns: a tuple with the retrieved object and a status. + :rtype: tuple[dict, RetrieveStatus] + """ raise NotImplementedError() def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict: + """ + Overwrite an object in the cache. + + :param object_name: the name of the object to overwrite. + :type object_name: str + :param value_gen_fn: the function to call to generate the new value. + :type value_gen_fn: Callable[[], str] + + :returns: the overwritten object. + :rtype: dict + """ raise NotImplementedError() def set(self, data: dict) -> dict: + """ + Set an object in the cache. + + :param data: the data to set. + :type data: dict + + :returns: the setted object. + :rtype: dict + """ raise NotImplementedError() diff --git a/pyeudiw/storage/base_db.py b/pyeudiw/storage/base_db.py new file mode 100644 index 00000000..1d3abd13 --- /dev/null +++ b/pyeudiw/storage/base_db.py @@ -0,0 +1,22 @@ +class BaseDB: + """ + Interface class for database storage. + """ + + def _connect(self) -> None: + """ + Connect to the database server. + + :raises ConnectionFailure: if the connection fails. + + :returns: None + """ + raise NotImplementedError() + + def close(self) -> None: + """ + Close the connection to the storage. 
+ + :returns: None + """ + raise NotImplementedError() \ No newline at end of file diff --git a/pyeudiw/storage/base_storage.py b/pyeudiw/storage/base_storage.py index f97eaf96..348f4ef5 100644 --- a/pyeudiw/storage/base_storage.py +++ b/pyeudiw/storage/base_storage.py @@ -1,6 +1,9 @@ import datetime from enum import Enum from typing import Union +from pymongo.results import UpdateResult + +from .base_db import BaseDB class TrustType(Enum): X509 = 0 @@ -21,70 +24,306 @@ class TrustType(Enum): TrustType.FEDERATION: "entity_configuration" } -class BaseStorage(object): - def init_session(self, document_id: str, dpop_proof: dict, attestation: dict): - raise NotImplementedError() +class BaseStorage(BaseDB): + """ + Interface class for storage. + """ - def is_connected(self) -> bool: - raise NotImplementedError() + def init_session(self, document_id: str, dpop_proof: dict, attestation: dict) -> str: + """ + Initialize a session. - def close(self) -> None: + :param document_id: the document id. + :type document_id: str + :param dpop_proof: the dpop proof. + :type dpop_proof: dict + :param attestation: the attestation. + """ raise NotImplementedError() - def add_dpop_proof_and_attestation(self, document_id, dpop_proof: dict, attestation: dict): + def add_dpop_proof_and_attestation(self, document_id, dpop_proof: dict, attestation: dict) -> UpdateResult: + """ + Add a dpop proof and an attestation to the session. + + :param document_id: the document id. + :type document_id: str + :param dpop_proof: the dpop proof. + :type dpop_proof: dict + :param attestation: the attestation. + :type attestation: dict + + :returns: the result of the update operation. + :rtype: UpdateResult + """ raise NotImplementedError() - def set_finalized(self, document_id: str): + def set_finalized(self, document_id: str) -> UpdateResult: + """ + Set the session as finalized. + + :param document_id: the document id. + :type document_id: str + + :returns: the result of the update operation. + :rtype: UpdateResult + """ + raise NotImplementedError() - def update_request_object(self, document_id: str, request_object: dict) -> int: + def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult: + """ + Update the request object of the session. + + :param document_id: the document id. + :type document_id: str + :param request_object: the request object. + :type request_object: dict + + :returns: the result of the update operation. + :rtype: UpdateResult + """ raise NotImplementedError() - def update_response_object(self, nonce: str, state: str, response_object: dict) -> int: + def update_response_object(self, nonce: str, state: str, response_object: dict) -> UpdateResult: + """ + Update the response object of the session. + + :param nonce: the nonce. + :type nonce: str + :param state: the state. + :type state: str + :param response_object: the response object. + :type response_object: dict + + :returns: the result of the update operation. + :rtype: UpdateResult + """ raise NotImplementedError() def get_trust_attestation(self, entity_id: str) -> Union[dict, None]: + """ + Get a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + + :returns: the trust attestation. + :rtype: Union[dict, None] + """ raise NotImplementedError() def get_trust_anchor(self, entity_id: str) -> Union[dict, None]: + """ + Get a trust anchor. + + :param entity_id: the entity id. + :type entity_id: str + + :returns: the trust anchor. 
+ :rtype: Union[dict, None] + """ raise NotImplementedError() - def has_trust_attestation(self, entity_id: str): + def has_trust_attestation(self, entity_id: str) -> bool: + """ + Check if a trust attestation exists. + + :param entity_id: the entity id. + :type entity_id: str + + + :returns: True if the trust attestation exists, False otherwise. + :rtype: bool + """ raise NotImplementedError() - def has_trust_anchor(self, entity_id: str): + def has_trust_anchor(self, entity_id: str) -> bool: + """ + Check if a trust anchor exists. + + :param entity_id: the entity id. + :type entity_id: str + + :returns: True if the trust anchor exists, False otherwise. + :rtype: bool + """ raise NotImplementedError() def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: + """ + Add a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + :param attestation: the attestation. + :type attestation: list[str] + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, metadata: dict) -> str: + """ + Add a trust attestation metadata. + + :param entity_id: the entity id. + :type entity_id: str + :param metadata_type: the metadata type. + :type metadata_type: str + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): + """ + Add a trust anchor. + + :param entity_id: the entity id. + :type entity_id: str + :param entity_configuration: the entity configuration. + :type entity_configuration: str + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: + """ + Update a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + :param attestation: the attestation. + :type attestation: list[str] + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def update_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType) -> str: + """ + Update a trust anchor. + + :param entity_id: the entity id. + :type entity_id: str + :param entity_configuration: the entity configuration. + :type entity_configuration: str + :param exp: the expiration date. + :type exp: datetime + :param trust_type: the trust type. + :type trust_type: TrustType + + :returns: the document id. + :rtype: str + """ raise NotImplementedError() def exists_by_state_and_session_id(self, state: str, session_id: str = "") -> bool: + """ + Check if a session exists by state and session id. + + :param state: the state. + :type state: str + :param session_id: the session id. + :type session_id: str + + :returns: True if the session exists, False otherwise. + :rtype: bool + """ raise NotImplementedError() - def get_by_state(self, state: str): + def get_by_state(self, state: str) -> Union[dict, None]: + """ + Get a session by state. + + :param state: the state. 
+ :type state: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def get_by_nonce_state(self, state: str, nonce: str): + def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]: + """ + Get a session by nonce and state. + + :param state: the state. + :type state: str + :param nonce: the nonce. + :type nonce: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def get_by_state_and_session_id(self, state: str, session_id: str = ""): + def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: + """ + Get a session by state and session id. + + :param state: the state. + :type state: str + :param session_id: the session id. + :type session_id: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() - def get_by_session_id(self, session_id: str): + def get_by_session_id(self, session_id: str) -> Union[dict, None]: + """ + Get a session by session id. + + :param session_id: the session id. + :type session_id: str + + :returns: the session. + :rtype: Union[dict, None] + """ raise NotImplementedError() # TODO: create add_or_update for all the write methods def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime) -> str: + """ + Add or update a trust attestation. + + :param entity_id: the entity id. + :type entity_id: str + :param attestation: the attestation. + :type attestation: list[str] + :param exp: the expiration date. + :type exp: datetime + + :returns: the document id. + :rtype: str + """ + raise NotImplementedError() + + @property + def is_connected(self) -> bool: + """ + Check if the storage is connected. + + :returns: True if the storage is connected, False otherwise. + :rtype: bool + """ raise NotImplementedError() diff --git a/pyeudiw/storage/db_engine.py b/pyeudiw/storage/db_engine.py index b1dbc9cb..c2d07614 100644 --- a/pyeudiw/storage/db_engine.py +++ b/pyeudiw/storage/db_engine.py @@ -1,8 +1,7 @@ import uuid -import logging import importlib from datetime import datetime -from typing import Callable, Union +from typing import Callable, Union, Tuple, Dict from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus from pyeudiw.storage.base_storage import BaseStorage, TrustType from pyeudiw.storage.exceptions import ( @@ -10,14 +9,24 @@ StorageWriteError, EntryNotFound ) +from pyeudiw.tools.base_logger import BaseLogger -logger = logging.getLogger(__name__) +from .base_db import BaseDB - -class DBEngine(): +class DBEngine(BaseStorage, BaseCache, BaseLogger): + """ + DB Engine class. + """ def __init__(self, config: dict): - self.caches = [] - self.storages = [] + """ + Create a DB Engine instance. + + :param config: the configuration of all the DBs. 
+ :type config: dict + """ + self.caches: list[Tuple[str, BaseCache]] = [] + self.storages: list[Tuple[str, BaseStorage]] = [] + for db_name, db_conf in config.items(): storage_instance, cache_instance = self._handle_instance(db_conf) @@ -27,25 +36,6 @@ def __init__(self, config: dict): if cache_instance: self.caches.append((db_name, cache_instance)) - def _handle_instance(self, instance: dict) -> dict[BaseStorage | None, BaseCache | None]: - cache_conf = instance.get("cache", None) - storage_conf = instance.get("storage", None) - - storage_instance = None - if storage_conf: - module = importlib.import_module(storage_conf["module"]) - instance_class = getattr(module, storage_conf["class"]) - storage_instance = instance_class( - **storage_conf.get("init_params", {})) - - cache_instance = None - if cache_conf: - module = importlib.import_module(cache_conf["module"]) - instance_class = getattr(module, cache_conf["class"]) - cache_instance = instance_class(**cache_conf["init_params"]) - - return storage_instance, cache_instance - def init_session(self, session_id: str, state: str) -> str: document_id = str(uuid.uuid4()) for db_name, storage in self.storages: @@ -54,49 +44,38 @@ def init_session(self, session_id: str, state: str) -> str: document_id, session_id=session_id, state=state ) except StorageWriteError as e: - logger.critical( - f"Error while initializing session with document_id {document_id}. " - f"Cannot write document with id {document_id} on {db_name}: " - f"{e.__class__.__name__}: {e}" + self._log_critical( + e.__class__.__name__, + ( + f"Error while initializing session with document_id {document_id}. " + f"Cannot write document with id {document_id} on {db_name}: {e}" + ) ) raise e return document_id - @property - def is_connected(self): - _connected = False - _cons = {} - for db_name, storage in self.storages: - try: - _connected = storage.is_connected - _cons[db_name] = _connected - except Exception as e: - logger.debug( - f"Error while checking db engine connection on {db_name}. " - f"{e.__class__.__name__}: {e}" - ) + def close(self): + self._close_list(self.storages) + self._close_list(self.caches) - if True in _cons.values() and not all(_cons.values()): - logger.warning( - f"Not all the storage are found available, storages misalignment: " - f"{_cons}" - ) + def write(self, method: str, *args, **kwargs): + """ + Perform a write operation on the storages. - return _connected + :param method: the method to call. + :type method: str + :param args: the arguments to pass to the method. + :type args: Any + :param kwargs: the keyword arguments to pass to the method. + :type kwargs: Any - def close(self): - for db_name, storage in self.storages: - try: - storage.close() - except Exception as e: - logger.critical( - f"Error while closing db engine {db_name}. " - f"{e.__class__.__name__}: {e}" - ) - raise e + :raises StorageWriteError: if the write operation fails on all the storages. + + :returns: the number of replicas where the write operation is successful. 
+        :rtype: int
+        """
-    def write(self, method: str, *args, **kwargs):
         replica_count = 0
         _err_msg = f"Cannot apply write method '{method}' with {args} {kwargs}"
         for db_name, storage in self.storages:
@@ -104,8 +83,10 @@ def write(self, method: str, *args, **kwargs):
                 getattr(storage, method)(*args, **kwargs)
                 replica_count += 1
             except Exception as e:
-                logger.critical(
-                    f"Error {_err_msg} on {db_name} {storage}: {str(e)}")
+                self._log_critical(
+                    e.__class__.__name__,
+                    f"Error {_err_msg} on {db_name} {storage}: {str(e)}"
+                )

         if not replica_count:
             raise StorageWriteError(_err_msg)
@@ -129,7 +110,23 @@ def update_request_object(self, document_id: str, request_object: dict) -> int:
     def update_response_object(self, nonce: str, state: str, response_object: dict) -> int:
         return self.write("update_response_object", nonce, state, response_object)

-    def get(self, method: str, *args, **kwargs):
+    def get(self, method: str, *args, **kwargs) -> Union[dict, None]:
+        """
+        Perform a get operation on the storages.
+
+        :param method: the method to call.
+        :type method: str
+        :param args: the arguments to pass to the method.
+        :type args: Any
+        :param kwargs: the keyword arguments to pass to the method.
+        :type kwargs: Any
+
+        :raises EntryNotFound: if the entry is not found on any storage.
+
+        :returns: the result of the first element found on the DBs.
+        :rtype: Union[dict, None]
+        """
+
         for db_name, storage in self.storages:
             try:
                 res = getattr(storage, method)(*args, **kwargs)
                 if res:
                     return res

             except EntryNotFound as e:
-                logger.critical(
+                self._log_debug(
+                    e.__class__.__name__,
                     f"Cannot find result by method {method} on {db_name} with {args} {kwargs}: {str(e)}"
                 )
@@ -149,11 +147,11 @@ def get_trust_attestation(self, entity_id: str) -> Union[dict, None]:
     def get_trust_anchor(self, entity_id: str) -> Union[dict, None]:
         return self.get("get_trust_anchor", entity_id)

-    def has_trust_attestation(self, entity_id: str):
-        return self.get_trust_attestation(entity_id)
+    def has_trust_attestation(self, entity_id: str) -> bool:
+        return self.get_trust_attestation(entity_id) != None

-    def has_trust_anchor(self, entity_id: str):
-        return self.get_anchor(entity_id)
+    def has_trust_anchor(self, entity_id: str) -> bool:
+        return self.get_trust_anchor(entity_id) != None

     def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str:
         return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type)
@@ -161,7 +159,7 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat
     def add_trust_attestation_metadata(self, entity_id: str, metadat_type: str, metadata: dict) -> str:
         return self.write("add_trust_attestation_metadata", entity_id, metadat_type, metadata)

-    def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION):
+    def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str:
         return self.write("add_trust_anchor", entity_id, entity_configuration, exp, trust_type)

     def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str:
         return self.write("update_trust_attestation", entity_id, attestation, exp, trust_type)
@@ -177,20 +175,6 @@ def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str]
     def update_trust_anchor(self, entity_id: str, entity_configuration: dict, exp: datetime,
trust_type: TrustType = TrustType.FEDERATION) -> str: return self.write("update_trust_anchor", entity_id, entity_configuration, exp, trust_type) - def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]: - for i, cache in enumerate(self.caches): - try: - cache_object, status = cache.try_retrieve( - object_name, on_not_found) - return cache_object, status, i - except Exception: - logger.critical( - f"Cannot retrieve or write cache object with identifier {object_name} on cache database {i}" - ) - raise ConnectionRefusedError( - "Cannot write cache object on any instance" - ) - def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dict: # if no cache instance exist return the object if len(self.caches): @@ -210,8 +194,9 @@ def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dic for cache_name, cache in replica_instances: try: cache.set(cache_object) - except Exception: - logger.critical( + except Exception as e: + self._log_critical( + e.__class__.__name__, f"Cannot replicate cache object with identifier {object_name} on cache {cache_name}" ) @@ -222,8 +207,9 @@ def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict: cache_object = None try: cache_object = cache.overwrite(object_name, value_gen_fn) - except Exception: - logger.critical( + except Exception as e: + self._log_critical( + e.__class__.__name__, f"Cannot overwrite cache object with identifier {object_name} on cache {cache_name}" ) return cache_object @@ -236,14 +222,114 @@ def exists_by_state_and_session_id(self, state: str, session_id: str = "") -> bo return True return False - def get_by_state(self, state: str): + def get_by_state(self, state: str) -> Union[dict, None]: return self.get_by_state_and_session_id(state=state) - def get_by_nonce_state(self, state: str, nonce: str): + def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]: return self.get('get_by_nonce_state', state=state, nonce=nonce) - def get_by_state_and_session_id(self, state: str, session_id: str = ""): + def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: return self.get("get_by_state_and_session_id", state, session_id) - def get_by_session_id(self, session_id: str): + def get_by_session_id(self, session_id: str) -> Union[dict, None]: return self.get("get_by_session_id", session_id) + + @property + def is_connected(self): + _connected = False + _cons = {} + for db_name, storage in self.storages: + try: + _connected = storage.is_connected() + _cons[db_name] = _connected + except Exception as e: + self._log_debug( + e.__class__.__name__, + f"Error while checking db engine connection on {db_name}: {e} " + ) + + if True in _cons.values() and not all(_cons.values()): + self._log_warning( + "DB Engine", + f"Not all the storage are found available, storages misalignment: " + f"{_cons}" + ) + + return _connected + + def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]: + """ + Try to retrieve an object from the cache. If the object is not found, call the on_not_found function. + + :param object_name: the name of the object to retrieve. + :type object_name: str + :param on_not_found: the function to call if the object is not found. + :type on_not_found: Callable[[], str] + + :raises ConnectionRefusedError: if the object cannot be retrieved on any instance. 
+
+        :returns: a tuple with the retrieved object, a status and the index of the cache instance.
+        :rtype: tuple[dict, RetrieveStatus, int]
+        """
+
+        for i, (cache_name, cache_instance) in enumerate(self.caches):
+            try:
+                cache_object, status = cache_instance.try_retrieve(
+                    object_name, on_not_found)
+                return cache_object, status, i
+            except Exception as e:
+                self._log_critical(
+                    e.__class__.__name__,
+                    f"Cannot retrieve cache object with identifier {object_name} on cache database {cache_name}"
+                )
+        raise ConnectionRefusedError(
+            "Cannot write cache object on any instance"
+        )
+
+    def _close_list(self, db_list: list[Tuple[str,BaseDB]]) -> None:
+        """
+        Close a list of DBs.
+
+        :param db_list: the list of DBs to close.
+        :type db_list: list[Tuple[str,BaseDB]]
+
+        :raises Exception: if an error occurs while closing a DB.
+        """
+
+        for db_name, db in db_list:
+            try:
+                db.close()
+            except Exception as e:
+                self._log_critical(
+                    e.__class__.__name__,
+                    f"Error while closing db engine {db_name}: {e}"
+                )
+                raise e
+
+    def _handle_instance(self, instance: dict) -> dict[BaseStorage | None, BaseCache | None]:
+        """
+        Handle the initialization of a storage/cache instance.
+
+        :param instance: the instance configuration.
+        :type instance: dict
+
+        :returns: a tuple with the storage and cache instance.
+        :rtype: tuple[BaseStorage | None, BaseCache | None]
+        """
+        cache_conf = instance.get("cache", None)
+        storage_conf = instance.get("storage", None)
+
+        storage_instance = None
+        if storage_conf:
+            module = importlib.import_module(storage_conf["module"])
+            instance_class = getattr(module, storage_conf["class"])
+            storage_instance = instance_class(
+                **storage_conf.get("init_params", {}))
+
+        cache_instance = None
+        if cache_conf:
+            module = importlib.import_module(cache_conf["module"])
+            instance_class = getattr(module, cache_conf["class"])
+            cache_instance = instance_class(**cache_conf["init_params"])
+
+        return storage_instance, cache_instance
diff --git a/pyeudiw/storage/mongo_cache.py b/pyeudiw/storage/mongo_cache.py
index f44dc859..12a877f8 100644
--- a/pyeudiw/storage/mongo_cache.py
+++ b/pyeudiw/storage/mongo_cache.py
@@ -4,32 +4,40 @@
 import pymongo

 from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus
+from pymongo.collection import Collection
+from pymongo.mongo_client import MongoClient
+from pymongo.database import Database


 class MongoCache(BaseCache):
+    """
+    MongoDB cache implementation.
+    """
+
     def __init__(self, conf: dict, url: str, connection_params: dict = None) -> None:
+        """
+        Create a MongoCache instance.
+
+        :param conf: the configuration of the cache.
+        :type conf: dict
+        :param url: the url of the MongoDB server.
+        :type url: str
+        :param connection_params: the connection parameters.
+ :type connection_params: dict, optional + """ super().__init__() self.storage_conf = conf self.url = url self.connection_params = connection_params - self.client = None - self.db = None - - def _connect(self): - if not self.client or not self.client.server_info(): - self.client = pymongo.MongoClient( - self.url, **self.connection_params) - self.db = getattr(self.client, self.storage_conf["db_name"]) - self.collection = getattr(self.db, "cache_storage") + self.client: MongoClient = None + self.db: Database = None + self.collection: Collection = None - def _gen_cache_object(self, object_name: str, data: str): - return { - "object_name": object_name, - "data": data, - "creation_date": datetime.now().isoformat() - } + def close(self) -> None: + self._connect() + self.client.close() def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus]: self._connect() @@ -72,3 +80,26 @@ def set(self, data: dict) -> dict: self._connect() return self.collection.insert_one(data) + + def _connect(self) -> None: + if not self.client or not self.client.server_info(): + self.client = pymongo.MongoClient( + self.url, **self.connection_params) + self.db = getattr(self.client, self.storage_conf["db_name"]) + self.collection = getattr(self.db, "cache_storage") + + def _gen_cache_object(self, object_name: str, data: str) -> dict: + """ + Helper function to generate a cache object. + + :param object_name: the name of the object. + :type object_name: str + :param data: the data to store. + :type data: str + """ + + return { + "object_name": object_name, + "data": data, + "creation_date": datetime.now().isoformat() + } diff --git a/pyeudiw/storage/mongo_storage.py b/pyeudiw/storage/mongo_storage.py index 541c5471..eda332cc 100644 --- a/pyeudiw/storage/mongo_storage.py +++ b/pyeudiw/storage/mongo_storage.py @@ -27,7 +27,7 @@ def __init__(self, conf: dict, url: str, connection_params: dict = {}) -> None: self.db = None @property - def is_connected(self): + def is_connected(self) -> bool: if not self.client: return False try: @@ -80,7 +80,7 @@ def get_by_nonce_state(self, nonce: str, state: str | None) -> dict: return document - def get_by_session_id(self, session_id: str): + def get_by_session_id(self, session_id: str) -> Union[dict, None]: self._connect() query = {"session_id": session_id} document = self.sessions.find_one(query) @@ -92,7 +92,7 @@ def get_by_session_id(self, session_id: str): return document - def get_by_state_and_session_id(self, state: str, session_id: str = ""): + def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: self._connect() query = {"state": state} if session_id: @@ -125,7 +125,7 @@ def init_session(self, document_id: str, session_id: str, state: str) -> str: return document_id - def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, attestation: dict): + def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, attestation: dict) -> UpdateResult: self._connect() update_result: UpdateResult = self.sessions.update_one( {"document_id": document_id}, @@ -142,7 +142,7 @@ def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, att return update_result - def update_request_object(self, document_id: str, request_object: dict): + def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult: self.get_by_id(document_id) documentStatus = self.sessions.update_one( {"document_id": document_id}, @@ -177,7 +177,7 @@ def 
set_finalized(self, document_id: str): ) return update_result - def update_response_object(self, nonce: str, state: str, internal_response: dict): + def update_response_object(self, nonce: str, state: str, internal_response: dict) -> UpdateResult: document = self.get_by_nonce_state(nonce, state) document_id = document["_id"] document_status = self.sessions.update_one( @@ -190,24 +190,24 @@ def update_response_object(self, nonce: str, state: str, internal_response: dict return document_status - def _get_trust_attestation(self, collection: str, entity_id: str) -> dict: + def _get_trust_attestation(self, collection: str, entity_id: str) -> dict | None: self._connect() db_collection = getattr(self, collection) return db_collection.find_one({"entity_id": entity_id}) - def get_trust_attestation(self, entity_id: str): + def get_trust_attestation(self, entity_id: str) -> dict | None: return self._get_trust_attestation("trust_attestations", entity_id) - def get_trust_anchor(self, entity_id: str): + def get_trust_anchor(self, entity_id: str) -> dict | None: return self._get_trust_attestation("trust_anchors", entity_id) - def _has_trust_attestation(self, collection: str, entity_id: str): - return self._get_trust_attestation(collection, entity_id) + def _has_trust_attestation(self, collection: str, entity_id: str) -> bool: + return self._get_trust_attestation(collection, entity_id) != None - def has_trust_attestation(self, entity_id: str): + def has_trust_attestation(self, entity_id: str) -> bool: return self._has_trust_attestation("trust_attestations", entity_id) - def has_trust_anchor(self, entity_id: str): + def has_trust_anchor(self, entity_id: str) -> bool: return self._has_trust_attestation("trust_anchors", entity_id) def _add_entry( @@ -257,7 +257,7 @@ def _update_anchor_metadata(self, entity: dict, attestation: list[str], exp: dat return entity - def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType): + def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType) -> str: entity = { "entity_id": entity_id, "federation": {}, @@ -267,7 +267,7 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat updated_entity = self._update_attestation_metadata(entity, attestation, exp, trust_type) - self._add_entry( + return self._add_entry( "trust_attestations", entity_id, updated_entity, exp ) @@ -285,7 +285,7 @@ def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, met def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): if self.has_trust_anchor(entity_id): - self.update_trust_anchor(entity_id, entity_configuration, exp, trust_type) + return self.update_trust_anchor(entity_id, entity_configuration, exp, trust_type) else: entity = { "entity_id": entity_id, @@ -294,7 +294,7 @@ def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datet } updated_entity = self._update_anchor_metadata(entity, entity_configuration, exp, trust_type) - self._add_entry("trust_anchors", entity_id, updated_entity, exp) + return self._add_entry("trust_anchors", entity_id, updated_entity, exp) def _update_trust_attestation(self, collection: str, entity_id: str, entity: dict) -> str: if not self._has_trust_attestation(collection, entity_id): diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index 03a4be00..fcf65f69 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ 
b/pyeudiw/tests/satosa/test_backend.py @@ -241,9 +241,9 @@ def test_vp_validation_in_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "400" - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "400" + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Error while validating VP: unexpected value." @@ -283,9 +283,9 @@ def test_vp_validation_in_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "400" - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "400" + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Error while validating VP: vp has no nonce." @@ -296,9 +296,9 @@ def test_vp_validation_in_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "400" - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "400" + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "DirectPostResponse content parse and validation error. Single VPs are faulty." @@ -382,8 +382,8 @@ def test_redirect_endpoint(self, context): } # no nonce - redirect_endpoint = self.backend.redirect_endpoint(context) - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert "nonce" in msg["error_description"] assert "missing" in msg["error_description"] @@ -395,8 +395,8 @@ def test_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Session lookup by state value failed" @@ -407,8 +407,8 @@ def test_redirect_endpoint(self, context): context.request = { "response": encrypted_response } - redirect_endpoint = self.backend.redirect_endpoint(context) - msg = json.loads(redirect_endpoint.message) + request_endpoint = self.backend.request_endpoint(context) + msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" assert msg["error_description"] == "Session lookup by state value failed" @@ -422,8 +422,8 @@ def test_redirect_endpoint(self, context): self.backend.db_engine.update_request_object( document_id=doc_id, request_object={"nonce": nonce, "state": state}) - redirect_endpoint = self.backend.redirect_endpoint(context) - assert redirect_endpoint.status == "302 Found" + request_endpoint = self.backend.request_endpoint(context) + assert request_endpoint.status == "302 Found" def test_request_endpoint(self, context): @@ -509,12 +509,12 @@ def test_request_endpoint(self, context): request_uri 
= CONFIG['metadata']['request_uris'][0]
         context.request_uri = request_uri

-        request_endpoint = self.backend.request_endpoint(context)
+        redirect_endpoint = self.backend.redirect_endpoint(context)

-        assert request_endpoint
-        assert request_endpoint.status == "200"
-        assert request_endpoint.message
-        msg = json.loads(request_endpoint.message)
+        assert redirect_endpoint
+        assert redirect_endpoint.status == "200"
+        assert redirect_endpoint.message
+        msg = json.loads(redirect_endpoint.message)

         assert msg["response"]

         header = decode_jwt_header(msg["response"])
diff --git a/pyeudiw/satosa/base_logger.py b/pyeudiw/tools/base_logger.py
similarity index 100%
rename from pyeudiw/satosa/base_logger.py
rename to pyeudiw/tools/base_logger.py
diff --git a/pyeudiw/tools/mobile.py b/pyeudiw/tools/mobile.py
index 0c7372f6..b18a6f7a 100644
--- a/pyeudiw/tools/mobile.py
+++ b/pyeudiw/tools/mobile.py
@@ -1,7 +1,16 @@
 from device_detector import DeviceDetector


-def is_smartphone(useragent: str):
+def is_smartphone(useragent: str) -> bool:
+    """Check if the useragent is a smartphone.
+
+    :param useragent: The useragent to check
+    :type useragent: str
+    :return: True if the useragent is a smartphone else False
+    :rtype: bool
+    """
+
     device = DeviceDetector(useragent).parse()
     if device.device_type() == 'smartphone':
         return True
+    return False
diff --git a/pyeudiw/tools/schema_utils.py b/pyeudiw/tools/schema_utils.py
index 9240de53..a8ba8757 100644
--- a/pyeudiw/tools/schema_utils.py
+++ b/pyeudiw/tools/schema_utils.py
@@ -13,7 +13,18 @@
 ]


-def check_algorithm(alg: str, info: FieldValidationInfo):
+def check_algorithm(alg: str, info: FieldValidationInfo) -> None:
+    """
+    Check if the algorithm is supported by the relying party.
+
+    :param alg: The algorithm to check
+    :type alg: str
+    :param info: The field validation info
+    :type info: FieldValidationInfo
+
+    :raises ValueError: If the algorithm is not supported
+    """
+
     if not info.context:
         supported_algorithms = _default_supported_algorithms
     else:
diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py
index 895dab69..cec68985 100644
--- a/pyeudiw/tools/utils.py
+++ b/pyeudiw/tools/utils.py
@@ -10,7 +10,18 @@
 logger = logging.getLogger(__name__)


-def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime.tzinfo = datetime.timezone.utc):
+def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime.tzinfo = datetime.timezone.utc) -> datetime.datetime:
+    """
+    Make a datetime timezone aware.
+
+    :param dt: The datetime to make timezone aware
+    :type dt: datetime.datetime
+    :param tz: The timezone to use
+    :type tz: datetime.timezone | datetime.tzinfo
+
+    :returns: The timezone aware datetime
+    :rtype: datetime.datetime
+    """
     if dt.tzinfo is None:
         return dt.replace(tzinfo=tz)
     else:
@@ -18,16 +29,41 @@ def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime.


 def iat_now() -> int:
+    """
+    Get the current timestamp in seconds.
+
+    :returns: The current timestamp in seconds
+    :rtype: int
+    """
     return int(datetime.datetime.now(datetime.timezone.utc).timestamp())


 def exp_from_now(minutes: int = 33) -> int:
+    """
+    Get the expiration timestamp in seconds for the given minutes from now.
+ + :param minutes: The minutes from now + :type minutes: int + + :returns: The timestamp in seconds for the given minutes from now + :rtype: int + """ now = datetime.datetime.now(datetime.timezone.utc) return int((now + datetime.timedelta(minutes=minutes)).timestamp()) -def datetime_from_timestamp(value) -> datetime.datetime: - return make_timezone_aware(datetime.datetime.fromtimestamp(value)) +def datetime_from_timestamp(timestamp: int | float) -> datetime.datetime: + """ + Get a datetime from a timestamp. + + :param value: The timestamp + :type value: int | float + + :returns: The datetime + :rtype: datetime.datetime + """ + + return make_timezone_aware(datetime.datetime.fromtimestamp(timestamp)) def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]: @@ -57,9 +93,19 @@ def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = T return responses -def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) -> dict: +def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list[dict] = []) -> dict: """ - get jwks or jwks_uri or signed_jwks_uri + Get jwks or jwks_uri or signed_jwks_uri + + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param metadata: metadata of the entity + :type metadata: dict + :param federation_jwks: jwks of the federation + :type federation_jwks: list + + :returns: A list of responses. + :rtype: list[dict] """ jwks_list = [] if metadata.get('jwks'): @@ -85,5 +131,14 @@ def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) -> return jwks_list -def random_token(n=254): +def random_token(n=254) -> str: + """ + Generate a random token. + + :param n: The length of the token + :type n: int + + :returns: The random token + :rtype: str + """ return token_hex(n) diff --git a/pyeudiw/trust/trust_anchors.py b/pyeudiw/trust/trust_anchors.py index 95e3f890..7460b757 100644 --- a/pyeudiw/trust/trust_anchors.py +++ b/pyeudiw/trust/trust_anchors.py @@ -7,7 +7,18 @@ logger = logging.getLogger(__name__) -def update_trust_anchors_ecs(trust_anchors: list, db: DBEngine, httpc_params: dict): +def update_trust_anchors_ecs(trust_anchors: list, db: DBEngine, httpc_params: dict) -> None: + """ + Update the trust anchors entity configurations. + + :param trust_anchors: The trust anchors + :type trust_anchors: list + :param db: The database engine + :type db: DBEngine + :param httpc_params: The HTTP client parameters + :type httpc_params: dict + """ + ta_ecs = get_entity_configurations( trust_anchors, httpc_params=httpc_params ) diff --git a/pyeudiw/trust/trust_chain.py b/pyeudiw/trust/trust_chain.py index bc9c0820..7e7e5b0f 100644 --- a/pyeudiw/trust/trust_chain.py +++ b/pyeudiw/trust/trust_chain.py @@ -1,28 +1,28 @@ import logging -from typing import List from typing import Optional -from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac -from pyeudiw.jwt.utils import decode_jwt_payload +from pyeudiw.tools.base_logger import BaseLogger __author__ = "Roland Hedberg" __license__ = "Apache 2.0" __version__ = "" -logger = logging.getLogger(__name__) - -class TrustChain(object): +class TrustChain(BaseLogger): """ Class in which to store the parsed result from applying metadata policies on a metadata statement. """ - def __init__(self, exp: int = 0, - verified_chain: Optional[list] = None): + verified_chain: Optional[list] = None) -> None: """ + Create a TrustChain instance. 
+        :param exp: Expiration time
+        :type exp: int
+        :param verified_chain: The verified chain
+        :type verified_chain: list
         """
         self.anchor = ""
         self.iss_path = []
@@ -32,10 +32,22 @@ def __init__(self,
-    def keys(self):
+    def keys(self) -> list[str]:
+        """
+        Returns the metadata fields keys
+
+        :return: The metadata fields keys
+        :rtype: list
+        """
         return self.metadata.keys()

-    def items(self):
+    def items(self) -> list[tuple[str, dict]]:
+        """
+        Returns the metadata fields items
+
+        :return: The metadata fields items
+        :rtype: list[tuple[str, dict]]
+        """
         return self.metadata.items()

     def __getitem__(self, item):
@@ -50,23 +62,42 @@ def claims(self):
         """
         return self.metadata

-    def is_expired(self):
+    def is_expired(self) -> bool:
+        """
+        Check if the trust chain is expired.
+
+        :return: True if the trust chain is expired else False
+        :rtype: bool
+        """
         now = utc_time_sans_frac()
         if self.exp < now:
-            logger.debug(f'is_expired: {self.exp} < {now}')
+            self._log_debug(
+                "Trust chain",
+                f'is_expired: {self.exp} < {now}'
+            )
             return True
         else:
             return False

-    def export_chain(self):
+    def export_chain(self) -> list:
         """
         Exports the verified chain in such a way that it can be used as value
         on the trust_chain claim in an authorization or explicit registration
         request.
-        :return:
+
+        :return: The exported chain in reverse order
+        :rtype: list
         """
         _chain = self.verified_chain
         _chain.reverse()
         return _chain

     def set_combined_policy(self, entity_type: str, combined_policy: dict):
+        """
+        Set the combined policy for the given entity type.
+
+        :param entity_type: The entity type
+        :type entity_type: str
+        :param combined_policy: The combined policy
+        :type combined_policy: dict
+        """
         self.combined_policy[entity_type] = combined_policy
diff --git a/pyeudiw/x509/verify.py b/pyeudiw/x509/verify.py
index 8f647e7f..a5835318 100644
--- a/pyeudiw/x509/verify.py
+++ b/pyeudiw/x509/verify.py
@@ -10,6 +10,15 @@
 logger = logging.getLogger(__name__)

 def _verify_x509_certificate_chain(pems: list[str]):
+    """
+    Verify the x509 certificate chain.
+
+    :param pems: The x509 certificate chain
+    :type pems: list[str]
+
+    :returns: True if the x509 certificate chain is valid else False
+    :rtype: bool
+    """
     try:
         store = crypto.X509Store()
@@ -32,6 +41,16 @@
         return False

 def _check_chain_len(pems: list) -> bool:
+    """
+    Check the x509 certificate chain length.
+
+    :param pems: The x509 certificate chain
+    :type pems: list
+
+    :returns: True if the x509 certificate chain length is valid else False
+    :rtype: bool
+    """
+
     chain_len = len(pems)

     if chain_len < 2:
@@ -42,6 +61,15 @@ def _check_chain_len(pems: list) -> bool:
     return True

 def _check_datetime(exp: datetime | None):
+    """
+    Check the x509 certificate chain expiration date.
+
+    :param exp: The x509 certificate chain expiration date
+    :type exp: datetime.datetime | None
+
+    :returns: True if the x509 certificate chain expiration date is valid else False
+    :rtype: bool
+    """
     if exp == None:
         return True
@@ -53,6 +81,18 @@ def _check_datetime(exp: datetime | None):
     return True

 def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) -> bool:
+    """
+    Verify the x509 attestation certificate chain.
+ + :param x5c: The x509 attestation certificate chain + :type x5c: list[bytes] + :param exp: The x509 attestation certificate chain expiration date + :type exp: datetime.datetime | None + + :returns: True if the x509 attestation certificate chain is valid else False + :rtype: bool + """ + if not _check_chain_len(x5c) or not _check_datetime(exp): return False @@ -61,6 +101,17 @@ def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) return _verify_x509_certificate_chain(pems) def verify_x509_anchor(pem_str: str, exp: datetime | None = None) -> bool: + """ + Verify the x509 anchor certificate. + + :param pem_str: The x509 anchor certificate + :type pem_str: str + :param exp: The x509 anchor certificate expiration date + :type exp: datetime.datetime | None + + :returns: True if the x509 anchor certificate is valid else False + :rtype: bool + """ if not _check_datetime(exp): return False @@ -72,10 +123,28 @@ def verify_x509_anchor(pem_str: str, exp: datetime | None = None) -> bool: return _verify_x509_certificate_chain(pems) def get_issuer_from_x5c(x5c: list[bytes]) -> str: + """ + Get the issuer from the x509 certificate chain. + + :param x5c: The x509 certificate chain + :type x5c: list[bytes] + + :returns: The issuer + :rtype: str + """ cert = load_der_x509_certificate(x5c[-1]) return cert.subject.rfc4514_string().split("=")[1] def is_der_format(cert: bytes) -> str: + """ + Check if the certificate is in DER format. + + :param cert: The certificate + :type cert: bytes + + :returns: True if the certificate is in DER format else False + :rtype: bool + """ try: pem = DER_cert_to_PEM_cert(cert) crypto.load_certificate(crypto.FILETYPE_PEM, str(pem))