diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index c5406ed5..a51ad414 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -34,7 +34,7 @@ def __init__(self, jwt: str, jwks_by_kids: Dict[str, dict], nonce: str = ""): self._payload: dict = {} self._vps: list = [] - self.credentials_by_issuer: dict = {} + self.credentials_by_issuer: Dict[str, list[dict]] = {} self._claims_by_issuer: dict = {} def _decode_payload(self) -> None: @@ -109,9 +109,11 @@ def validate(self) -> bool: return True - def get_presentation_vps(self) -> list[dict]: + def get_presentation_vps(self) -> list[Vp]: """ - Returns the presentation's verifiable presentations + Returns the presentation's verifiable presentations. + + :raises VPNotFound: if no VPs are found. :returns: the list of vps. :rtype: list[dict] @@ -123,7 +125,7 @@ def get_presentation_vps(self) -> list[dict]: vps = [_vps] if isinstance(_vps, str) else _vps if not vps: - raise VPNotFound(f"Vps for response with nonce \"{self.nonce}\" are empty") + raise VPNotFound(f"Vps are empty for response with nonce \"{self.nonce}\"") for vp in vps: _vp = Vp(vp) diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 74aaf261..213f9ec0 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -17,6 +17,7 @@ def __init__(self, jwt: str) -> None: self.parse_digital_credential() self.disclosed_user_attributes: dict = {} + self._credential_jwks: list[dict] = [] def _detect_vp_type(self) -> str: """ @@ -54,6 +55,20 @@ def parse_digital_credential(self) -> None: self.credential_headers = decode_jwt_header(self.payload['vp']) self.credential_payload = decode_jwt_payload(self.payload['vp']) + def set_credential_jwks(self, credential_jwks: list[dict]) -> None: + """ + Set the credential JWKs for the current istance. + + :param credential_jwks: a list containing the credential's JWKs. + :type credential_jwks: list[dict] + """ + self._credential_jwks = credential_jwks + + @property + def credential_jwks(self) -> list[dict]: + """Returns the credential JWKs""" + return self._credential_jwks + @property def credential_issuer(self) -> str: """Returns the credential issuer""" diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index c8131646..2cd3f245 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -34,33 +34,42 @@ from pyeudiw.storage.db_engine import DBEngine from pyeudiw.storage.exceptions import StorageWriteError from pyeudiw.federation.schemas.wallet_relying_party import WalletRelyingParty +from pyeudiw.openid4vp.vp import Vp +from typing import Callable from pydantic import ValidationError -logger = logging.getLogger(__name__) +from .exceptions import HTTPError +from .base_http_error_handler import BaseHTTPErrorHandler +from .base_logger import BaseLogger - -class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP): +class OpenID4VPBackend(BackendModule, BackendTrust, BackendDPoP, BaseHTTPErrorHandler, BaseLogger): """ A backend module (acting as a OpenID4VP SP). """ - def __init__(self, auth_callback_func, internal_attributes, config, base_url, name): + def __init__( + self, + auth_callback_func: Callable[[Context, InternalData], Response], + internal_attributes: dict[str, dict[str, str | list[str]]], + config: dict[str, dict[str, str] | list[str]], + base_url: str, + name: str + ) -> None: """ OpenID4VP backend module. :param auth_callback_func: Callback should be called by the module after the authorization in the backend is done. + :type auth_callback_func: Callable[[Context, InternalData], Response] :param internal_attributes: Mapping dictionary between SATOSA internal attribute names and the names returned by underlying IdP's/OP's as well as what attributes the calling SP's and RP's expects namevice. + :type internal_attributes: dict[str, dict[str, str | list[str]]] :param config: Configuration parameters for the module. - :param base_url: base url of the service - :param name: name of the plugin - :type auth_callback_func: - (satosa.context.Context, satosa.internal.InternalData) -> satosa.response.Response - :type internal_attributes: dict[string, dict[str, str | list[str]]] :type config: dict[str, dict[str, str] | list[str]] + :param base_url: base url of the service :type base_url: str + :param name: name of the plugin :type name: str """ super().__init__(auth_callback_func, internal_attributes, base_url, name) @@ -68,11 +77,8 @@ def __init__(self, auth_callback_func, internal_attributes, config, base_url, na try: WalletRelyingParty(**config['metadata']) except ValidationError as e: - logger.warning( - """ - The backend configuration presents the following validation issues: - {} - """.format(logger.warning(e))) + debug_message = f"""The backend configuration presents the following validation issues: {e}""" + self._log_warning("OpenID4VPBackend", debug_message) self.config = config self.client_id = self.config['metadata']['client_id'] @@ -98,50 +104,9 @@ def __init__(self, auth_callback_func, internal_attributes, config, base_url, na # resolve metadata pointers/placeholders self._render_metadata_conf_elements() self.init_trust_resources() + self._log_debug("OpenID4VP init", f"Loaded configuration: {json.dumps(config)}") - logger.debug( - lu.LOG_FMT.format( - id="OpenID4VP init", - message=f"Loaded configuration: {json.dumps(config)}" - ) - ) - - @property - def db_engine(self) -> DBEngine: - - try: - self._db_engine.is_connected - except Exception as e: - if getattr(self, '_db_engine', None): - logger.debug( - lu.LOG_FMT.format( - id="OpenID4VP db storage handling", - message=f"connection check silently fails and get restored: {e}" - ) - ) - self._db_engine = DBEngine(self.config["storage"]) - - return self._db_engine - - def _render_metadata_conf_elements(self) -> None: - for k, v in self.config['metadata'].items(): - if isinstance(v, (int, float, dict, list)): - continue - if not v or len(v) == 0: - continue - if all(( - v[0] == '<', - v[-1] == '>', - '.' in v - )): - conf_section, conf_k = v[1:-1].split('.') - self.config['metadata'][k] = self.config[conf_section][conf_k] - - @property - def default_metadata_private_jwk(self) -> tuple: - return tuple(self.metadata_jwks_by_kids.values())[0] - - def register_endpoints(self) -> list: + def register_endpoints(self) -> list[tuple[str, Callable[[Context], Response]]]: """ Creates a list of all the endpoints this backend module needs to listen to. In this case it's the authentication response from the underlying OP that is redirected from the OP to @@ -158,7 +123,8 @@ def register_endpoints(self) -> list: ) ) _endpoint = f"{self.client_id}{v}" - logger.debug( + self._log_debug( + "OpenID4VPBackend", f"Exposing backend entity endpoint = {_endpoint}" ) if k == 'get_response': @@ -171,7 +137,7 @@ def register_endpoints(self) -> list: self.absolute_status_url = _endpoint return url_map - def start_auth(self, context, internal_request): + def start_auth(self, context: Context, internal_request) -> Response: """ This is the start up function of the backend authorization. @@ -185,25 +151,21 @@ def start_auth(self, context, internal_request): """ return self.pre_request_endpoint(context, internal_request) - def _log(self, context: Context, level: str, message: str) -> None: - log_level = getattr(logger, level) - log_level( - lu.LOG_FMT.format( - id=lu.get_session_id(context.state), - message=message - ) - ) + def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> Response: + """ + This endpoint is called by the frontend before calling the request endpoint. + It initializes the session and returns the request_uri to be used by the frontend. - def pre_request_endpoint(self, context, internal_request, **kwargs): + :type context: the context of current request + :param context: the request context + :type internal_request: satosa.internal.InternalData + :param internal_request: Information about the authorization request + + :return: a response containing the request_uri + :rtype: satosa.response.Response + """ - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] pre_request_endpoint with Context: " - f"{context.__dict__} and internal_request: {internal_request}" - ) - ) + self._log_function_debug("pre_request_endpoint", context, "internal_request", internal_request) session_id = context.state["SESSION_ID"] state = str(uuid.uuid4()) @@ -218,28 +180,14 @@ def pre_request_endpoint(self, context, internal_request, **kwargs): session_id=session_id ) except (StorageWriteError) as e: - _msg = ( - f"Error while initializing session with state {state} and {session_id}." - ) - logger.error(f"{_msg} for the following reason {e}") - return self.handle_error( - context, - message="server_error", - troubleshoot=f"{_msg}", - err=f"{_msg}. {e.__class__.__name__}: {e}", - err_code="500" - ) + _msg = f"Error while initializing session with state {state} and {session_id}." + self._log_error(context, f"{_msg} for the following reason {e}") + return self._handle_500(context, _msg, e) + except (Exception) as e: - _msg = ( - f"Error while initializing session with state {state} and {session_id}. " - ) - return self.handle_error( - context, - message="server_error", - troubleshoot=f"{_msg}", - err=f"{_msg}. {e.__class__.__name__}: {e}", - err_code="500" - ) + _msg = f"Error while initializing session with state {state} and {session_id}." + self._log_error(context, _msg) + return self._handle_500(context, _msg, e) # PAR payload = { @@ -269,220 +217,86 @@ def pre_request_endpoint(self, context, internal_request, **kwargs): ) return Response(result, content="text/html; charset=utf8", status="200") - def _translate_response(self, response: dict, issuer: str, context: Context): - """ - Translates wallet response to SATOSA internal response. - :type response: dict[str, str] - :type issuer: str - :type subject_type: str - :rtype: InternalData - :param response: Dictioary with attribute name as key. - :param issuer: The oidc op that gave the repsonse. - :param subject_type: public or pairwise according to oidc standard. - :return: A SATOSA internal response. + def redirect_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: """ - # it may depends by credential type and attested security context evaluated - # if WIA was previously submitted by the Wallet + This endpoint is called by the frontend after the user has been authenticated. - timestamp_epoch = ( - response.get("auth_time") - or response.get("iat") - or iat_now() - ) - timestamp_dt = datetime.datetime.fromtimestamp( - timestamp_epoch, - datetime.timezone.utc - ) - timestamp_iso = timestamp_dt.isoformat().replace("+00:00", "Z") - - auth_class_ref = ( - response.get("acr") or - response.get("amr") or - self.config["authorization"]["default_acr_value"] - ) - auth_info = AuthenticationInformation( - auth_class_ref, timestamp_iso, issuer) - - # TODO - ACR values - internal_resp = InternalData(auth_info=auth_info) - - sub = "" - pepper = self.config.get("user_attributes", {})['subject_id_random_value'] - for i in self.config.get("user_attributes", {}).get("unique_identifiers", []): - if response.get(i): - _sub = response[i] - sub = hashlib.sha256( - f"{_sub}~{pepper}".encode( - ) - ).hexdigest() - break - - if not sub: - self._log( - context, - level='warning', - message=( - "[USER ATTRIBUTES] Missing subject id from OpenID4VP presentation " - "setting a random one for interop for internal frontends" - ) - ) - sub = hashlib.sha256( - f"{json.dumps(response).encode()}~{pepper}".encode() - ).hexdigest() + :type context: the context of current request + :param context: the request context - response["sub"] = [sub] - internal_resp.attributes = self.converter.to_internal( - "openid4vp", response - ) - internal_resp.subject_id = sub - return internal_resp + :return: a redirect to the frontend, if is in same device flow, or a json response if is in cross device flow. + :rtype: Redirect | JsonResponse + """ - @property - def server_url(self): - return ( - self.base_url[:-1] - if self.base_url[-1] == '/' - else self.base_url - ) + self._log_function_debug("redirect_endpoint", context, "args", args) - def redirect_endpoint(self, context, *args): - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] redirect_endpoint with Context: " - f"{context.__dict__} and args: {args}" - ) - ) if context.request_method.lower() != 'post': # raise BadRequestError("HTTP Method not supported") - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="HTTP Method not supported", - err_code="400" - ) + return self._handle_400(context, "HTTP Method not supported") _endpoint = f'{self.server_url}{context.request_uri}' if self.config["metadata"].get('redirect_uris', None): if _endpoint not in self.config["metadata"]['redirect_uris']: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="request_uri not valid", - err_code="400" - ) + return self._handle_400(context, "request_uri not valid") # take the encrypted jwt, decrypt with my public key (one of the metadata) -> if not -> exception jwt = context.request.get("response", None) if not jwt: _msg = f"Response error, missing JWT" - self._log(context, level='error', message=_msg) - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err_code="400" - ) + self._log_error(context, _msg) + return self._handle_400(context, _msg) try: vpt = DirectPostResponse(jwt, self.metadata_jwks_by_kids) - self._log( - context, - level='debug', - message=( - f"Redirect uri endpoint Response using direct post contains: {vpt.payload}" - ) - ) + + debug_message = f"Redirect uri endpoint Response using direct post contains: {vpt.payload}" + self._log_debug(context, debug_message) + ResponseSchema(**vpt.payload) except Exception as e: _msg = f"DirectPostResponse parse and validation error: {e}" - self._log(context, level='error', message=_msg) - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err_code="400", - err=f"Error:{e}, with JWT: {jwt}" - ) + self._log_error(context, _msg) + return self._handle_400(context, _msg, HTTPError(f"Error:{e}, with JWT: {jwt}")) # state MUST be present in the response since it was in the request # then do lookup on the db -> if not -> exception state = vpt.payload.get("state", None) if not state: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="state not found in the response", - err_code="400", - err=f"{_msg} with: {vpt.payload}" - ) + return self._handle_400(context, _msg, HTTPError(f"{_msg} with: {vpt.payload}")) try: stored_session = self.db_engine.get_by_state(state=state) except Exception as e: _msg = f"Session lookup by state value failed" - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) if stored_session["finalized"]: _msg = f"Session already finalized" - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=_msg, - err_code="400" - ) + return self._handle_400(context, _msg, HTTPError(_msg)) try: vpt.load_nonce(stored_session['nonce']) - vps: list = vpt.get_presentation_vps() + vps: list[Vp] = vpt.get_presentation_vps() vpt.validate() + except VPNotFound as e: _msg = "Error while retrieving VP. Payload 'vp_token' is empty or has an unexpected value." - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) + except NoNonceInVPToken as e: _msg = "Error while validating VP: vp has no nonce." - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) + except VPInvalidNonce as e: _msg = "Error while validating VP: unexpected value." - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" - ) + return self._handle_400(context, _msg, e) + except Exception as e: _msg = ( "DirectPostResponse content parse and validation error. " - f"Single VPs are faulty." - ) - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e.__class__.__name__}: {e}", - err_code="400" + "Single VPs are faulty." ) + return self._handle_400(context, _msg, e) # evaluate the trust to each credential issuer found in the vps # look for trust chain or x509 or do discovery! @@ -500,51 +314,40 @@ def redirect_endpoint(self, context, *args): tchelper = self._validate_trust(context, vp.payload['vp']) if not tchelper.is_trusted: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=f"Trust Evaluation failed for {tchelper.entity_id}", - err_code="400" - ) + return self._handle_400(context, f"Trust Evaluation failed for {tchelper.entity_id}") # TODO: generalyze also for x509 - vp.credential_jwks = tchelper.get_trusted_jwks( + credential_jwks = tchelper.get_trusted_jwks( metadata_type='openid_credential_issuer' ) + vp.set_credential_jwks(credential_jwks) except InvalidVPToken: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Cannot validate VP: {vp.jwt}", err_code="400") + return self._handle_400(context, f"Cannot validate VP: {vp.jwt}") except ValidationError as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Error validating schemas: {e}", err_code="400") + return self._handle_400(context, f"Error validating schemas: {e}") except KIDNotFound as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Kid error: {e}", err_code="400") + return self._handle_400(context, f"Kid error: {e}") except NotTrustedFederationError as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"Not trusted federation error: {e}", err_code="400") + return self._handle_400(context, f"Not trusted federation error: {e}") except Exception as e: - return self.handle_error(context=context, message="invalid_request", troubleshoot=f"VP parsing error: {e}", err_code="400") + return self._handle_400(context, f"VP parsing error: {e}") # the trust is established to the credential issuer, then we can get the disclosed user attributes # TODO - what if the credential is different from sd-jwt? -> generalyze within Vp class try: vp.verify_sdjwt( - issuer_jwks_by_kid={ + issuer_jwks_by_kid = { i['kid']: i for i in vp.credential_jwks} ) except Exception as e: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=f"VP SD-JWT validation error: {e}", - err_code="400" - ) + return self._handle_400(context, f"VP SD-JWT validation error: {e}") # vp.result attributes_by_issuers[vp.credential_issuer] = vp.disclosed_user_attributes - self._log( - context, - level='debug', - message=f"Disclosed user attributes from {vp.credential_issuer}: {vp.disclosed_user_attributes}" - ) + + debug_message = f"Disclosed user attributes from {vp.credential_issuer}: {vp.disclosed_user_attributes}" + self._log_debug(context, debug_message) # TODO: check the revocation of the credential # ... @@ -555,10 +358,7 @@ def redirect_endpoint(self, context, *args): for i in attributes_by_issuers.values(): all_user_attributes.update(**i) - self._log( - context, level='debug', - message=f"Wallet disclosure: {all_user_attributes}" - ) + self._log_debug(context, f"Wallet disclosure: {all_user_attributes}") # TODO: not sure that we want these issuers in the following form ... please recheck. _info = {"issuer": ';'.join(cred_issuers)} @@ -572,27 +372,14 @@ def redirect_endpoint(self, context, *args): ) # authentication finalized! self.db_engine.set_finalized(stored_session['document_id']) - if logger.getEffectiveLevel() == logging.DEBUG: + if self.effective_log_level == logging.DEBUG: stored_session = self.db_engine.get_by_state(state=state) - self._log( - context, - level="debug", - message=f"Session update on storage: {stored_session}" - ) + self._log_debug(context, f"Session update on storage: {stored_session}") + except StorageWriteError as e: # TODO - do we have to block in the case the update cannot be done? - self._log( - context, - level="error", - message=f"Session update on storage failed: {e}" - ) - return self.handle_error( - context=context, - message="server_error", - troubleshoot=f"Cannot update response object.", - err=f"{e.__class__.__name__}: {e}", - err_code="500" - ) + self._log_error(context, f"Session update on storage failed: {e}") + return self._handle_500(context, f"Cannot update response object.", e) if stored_session['session_id'] == context.state["SESSION_ID"]: # Same device flow @@ -608,16 +395,21 @@ def redirect_endpoint(self, context, *args): status="200" ) - def request_endpoint(self, context, *args): + def request_endpoint(self, context: Context, *args) -> JsonResponse: + """ + This endpoint is called by the frontend to retrieve the signed signed Request Object. + + :type context: the context of current request + :param context: the request context + :param args: the request arguments + :type args: tuple + + :return: a json response containing the request object + :rtype: JsonResponse + """ + + self._log_function_debug("request_endpoint", context, "args", args) - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] request_endpoint with Context: " - f"{context.__dict__} and args: {args}" - ) - ) # check DPOP for WIA if any try: dpop_validation_error: JsonResponse = self._request_endpoint_dpop( @@ -626,16 +418,8 @@ def request_endpoint(self, context, *args): if dpop_validation_error: return dpop_validation_error except Exception as e: - _msg = ( - f"[DPoP VALIDATION ERROR] WIA evalution error: {e}." - ) - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err=f"{e} with {context.__dict__}", - err_code="401" - ) + _msg = f"[DPoP VALIDATION ERROR] WIA evalution error: {e}." + return self._handle_401(context, _msg, e) try: state = context.qs_params["id"] @@ -644,13 +428,7 @@ def request_endpoint(self, context, *args): "Error while retrieving id from qs_params: " f"{e.__class__.__name__}: {e}" ) - return self.handle_error( - context, - message="invalid_request", - troubleshoot=_msg, - err=f"{e} with {context.__dict__}", - err_code="400" - ) + return self._handle_400(context, _msg, HTTPError(f"{e} with {context.__dict__}")) data = { "scope": ' '.join(self.config['authorization']['scopes']), @@ -672,12 +450,7 @@ def request_endpoint(self, context, *args): attestation = context.http_headers['HTTP_AUTHORIZATION'] except KeyError as e: _msg = f"Error while accessing http headers: {e}" - return self.handle_error( - context, - message="invalid_request", - err=f"{e} with {context.__dict__}", - err_code="400" - ) + return self._handle_400(context, _msg, HTTPError(f"{e} with {context.__dict__}")) # take the session created in the pre-request authz endpoint try: @@ -687,25 +460,14 @@ def request_endpoint(self, context, *args): document_id, dpop_proof, attestation ) self.db_engine.update_request_object(document_id, data) + except ValueError as e: - _msg = ( - "Error while retrieving request object from database." - ) - return self.handle_error( - context, - message="server_error", - troubleshoot=_msg, - err=f"{e} with {context.__dict__}", - err_code="500" - ) + _msg = "Error while retrieving request object from database." + return self._handle_500(context, _msg, HTTPError(f"{e} with {context.__dict__}")) + except (Exception, BaseException) as e: _msg = f"Error while updating request object: {e}" - return self.handle_error( - context, - message="server_error", - err=_msg, - err_code="500" - ) + return self._handle_500(context, _msg, e) helper = JWSHelper(self.default_metadata_private_jwk) @@ -719,44 +481,18 @@ def request_endpoint(self, context, *args): status="200" ) - def handle_error( - self, - context: dict, - message: str, - troubleshoot: str = "", - err="", - err_code="500", - template_path="templates", - error_template="error.html", - level="error" - ): - - _msg = f"{message}:" - if err: - _msg += f" {err}." - self._log( - context, level=level, - message=f"{_msg} {troubleshoot}" - ) + def get_response_endpoint(self, context: Context) -> Response: + """ + This endpoint is called by the frontend to retrieve the response of the authentication. + + :param context: the request context + :type context: satosa.context.Context - return JsonResponse( - { - "error": message, - "error_description": troubleshoot - }, - status=err_code - ) + :return: a response containing the response object with the authenctication status + :rtype: Response + """ - def get_response_endpoint(self, context): - - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] get_response_endpoint with Context: " - f"{context.__dict__}" - ) - ) + self._log_function_debug("get_response_endpoint", context) state = context.qs_params.get("id", None) session_id = context.state["SESSION_ID"] @@ -775,30 +511,15 @@ def get_response_endpoint(self, context): ) except Exception as e: _msg = f"Error while retrieving session by state {state} and session_id {session_id}: {e}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401" - ) + return self._handle_401(context, _msg, e) if not finalized_session: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot="session not found or invalid", - err_code="400" - ) + return self._handle_400(context, "session not found or invalid") _now = iat_now() _exp = finalized_session['request_object']['exp'] if _exp < _now: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=f"session expired, request object exp is {_exp} while now is {_now}", - err_code="400" - ) + return self._handle_400(context, f"session expired, request object exp is {_exp} while now is {_now}") internal_response = InternalData() resp = internal_response.from_dict( @@ -810,16 +531,18 @@ def get_response_endpoint(self, context): resp ) - def status_endpoint(self, context): + def status_endpoint(self, context: Context) -> JsonResponse: + """ + This endpoint is called by the frontend the url to the response endpoint to finalize the process. - self._log( - context, - level='debug', - message=( - "[INCOMING REQUEST] state_endpoint with Context: " - f"{context.__dict__}" - ) - ) + :param context: the request context + :type context: satosa.context.Context + + :return: a json response containing the status and the url to get the response + :rtype: JsonResponse + """ + + self._log_function_debug("status_endpoint", context) session_id = context.state["SESSION_ID"] _err_msg = "" @@ -833,12 +556,7 @@ def status_endpoint(self, context): _err_msg = f"No id found in qs_params: {e}" if _err_msg: - return self.handle_error( - context=context, - message="invalid_request", - troubleshoot=_err_msg, - err_code="400" - ) + return self._handle_400(context, _err_msg) try: session = self.db_engine.get_by_state_and_session_id( @@ -846,22 +564,12 @@ def status_endpoint(self, context): ) except Exception as e: _msg = f"Error while retrieving session by state {state} and session_id {session_id}: {e}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401" - ) + return self._handle_401(context, _msg) request_object = session.get("request_object", None) if request_object: if iat_now() > request_object["exp"]: - return self.handle_error( - context=context, - message="expired", - troubleshoot=f"Request object expired", - err_code="403" - ) + return self._handle_403("expired", f"Request object expired") if session["finalized"]: # return Redirect( @@ -880,3 +588,115 @@ def status_endpoint(self, context): }, status="201" ) + + def _render_metadata_conf_elements(self) -> None: + """Renders the elements of config's metadata""" + for k, v in self.config['metadata'].items(): + if isinstance(v, (int, float, dict, list)): + continue + if not v or len(v) == 0: + continue + if all(( + v[0] == '<', + v[-1] == '>', + '.' in v + )): + conf_section, conf_k = v[1:-1].split('.') + self.config['metadata'][k] = self.config[conf_section][conf_k] + + def _translate_response(self, response: dict, issuer: str, context: Context): + """ + Translates wallet response to SATOSA internal response. + :type response: dict[str, str] + :type issuer: str + :type subject_type: str + :rtype: InternalData + :param response: Dictioary with attribute name as key. + :param issuer: The oidc op that gave the repsonse. + :param subject_type: public or pairwise according to oidc standard. + :return: A SATOSA internal response. + """ + # it may depends by credential type and attested security context evaluated + # if WIA was previously submitted by the Wallet + + timestamp_epoch = ( + response.get("auth_time") + or response.get("iat") + or iat_now() + ) + timestamp_dt = datetime.datetime.fromtimestamp( + timestamp_epoch, + datetime.timezone.utc + ) + timestamp_iso = timestamp_dt.isoformat().replace("+00:00", "Z") + + auth_class_ref = ( + response.get("acr") or + response.get("amr") or + self.config["authorization"]["default_acr_value"] + ) + auth_info = AuthenticationInformation( + auth_class_ref, timestamp_iso, issuer) + + # TODO - ACR values + internal_resp = InternalData(auth_info=auth_info) + + sub = "" + pepper = self.config.get("user_attributes", {})['subject_id_random_value'] + for i in self.config.get("user_attributes", {}).get("unique_identifiers", []): + if response.get(i): + _sub = response[i] + sub = hashlib.sha256( + f"{_sub}~{pepper}".encode( + ) + ).hexdigest() + break + + if not sub: + self._log( + context, + level='warning', + message=( + "[USER ATTRIBUTES] Missing subject id from OpenID4VP presentation " + "setting a random one for interop for internal frontends" + ) + ) + sub = hashlib.sha256( + f"{json.dumps(response).encode()}~{pepper}".encode() + ).hexdigest() + + response["sub"] = [sub] + internal_resp.attributes = self.converter.to_internal( + "openid4vp", response + ) + internal_resp.subject_id = sub + return internal_resp + + @property + def db_engine(self) -> DBEngine: + """Returns the DBEngine instance used by the class""" + try: + self._db_engine.is_connected + except Exception as e: + if getattr(self, '_db_engine', None): + self._log_debug( + "OpenID4VP db storage handling", + f"connection check silently fails and get restored: {e}" + ) + self._db_engine = DBEngine(self.config["storage"]) + + return self._db_engine + + @property + def default_metadata_private_jwk(self) -> tuple: + """Returns the default metadata private JWK""" + return tuple(self.metadata_jwks_by_kids.values())[0] + + @property + def server_url(self): + """Returns the server url""" + return ( + self.base_url[:-1] + if self.base_url[-1] == '/' + else self.base_url + ) \ No newline at end of file diff --git a/pyeudiw/satosa/base_http_error_handler.py b/pyeudiw/satosa/base_http_error_handler.py new file mode 100644 index 00000000..785042c5 --- /dev/null +++ b/pyeudiw/satosa/base_http_error_handler.py @@ -0,0 +1,152 @@ +from satosa.context import Context +from .base_logger import BaseLogger +from .exceptions import EmptyHTTPError +from pyeudiw.satosa.response import JsonResponse + + +class BaseHTTPErrorHandler(BaseLogger): + def _serialize_error( + self, + context: Context, + message: str, + troubleshoot: str, + err: str, + err_code: str, + level: str + ) -> JsonResponse: + """ + Serializes an error. + + :param context: the request context + :type context: satosa.context.Context + :param message: the error message + :type message: str + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: more info about the error + :type err: str + :param err_code: the error code + :type err_code: str + :param level: the log level + :type level: str + + :return: a json response containing the error + :rtype: JsonResponse + """ + + _msg = f"{message}:" + if err: + _msg += f" {err}." + self._log( + context, level=level, + message=f"{_msg} {troubleshoot}" + ) + + return JsonResponse({ + "error": message, + "error_description": troubleshoot + }, + status=err_code + ) + + def _handle_500(self, context: Context, msg: str, err: Exception) -> JsonResponse: + """ + Handles a 500 error. + + :param context: the request context + :type context: satosa.context.Context + :param msg: the error message + :type msg: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._serialize_error( + context, + "server_error", + f"{msg}", + f"{msg}. {err.__class__.__name__}: {err}", + "500", + "error" + ) + + def _handle_40X(self, code_number: str, message: str, context: Context, troubleshoot: str, err: Exception) -> JsonResponse: + """ + Handles a 40X error. + + :param code_number: the code number + :type code_number: str + :param message: the error message + :type message: str + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._serialize_error( + context, + message, + troubleshoot, + f"{err.__class__.__name__}: {err}", + f"40{code_number}", + "error" + ) + + def _handle_400(self, context: Context, troubleshoot: str, err: Exception = EmptyHTTPError("")) -> JsonResponse: + """ + Handles a 400 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + return self._handle_40X("0", "invalid_request", context, troubleshoot, err) + + def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + """ + Handles a 401 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._handle_40X("1", "invalid_client", context, troubleshoot, err) + + def _handle_403(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + """ + Handles a 403 error. + + :param context: the request context + :type context: satosa.context.Context + :param troubleshoot: the troubleshoot message + :type troubleshoot: str + :param err: the exception raised + :type err: Exception + + :return: a json response containing the error + :rtype: JsonResponse + """ + + return self._handle_40X("3", "expired", context, troubleshoot, err) \ No newline at end of file diff --git a/pyeudiw/satosa/base_logger.py b/pyeudiw/satosa/base_logger.py new file mode 100644 index 00000000..6dbda148 --- /dev/null +++ b/pyeudiw/satosa/base_logger.py @@ -0,0 +1,120 @@ +import logging +import satosa.logging_util as lu +from satosa.context import Context + +logger = logging.getLogger(__name__) + +class BaseLogger: + def _log(self, context: str | Context, level: str, message: str) -> None: + """ + Log a message with the given level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param level: the log level + :type level: str + :param message: the message to log + :type message: str + """ + + context = context if isinstance(context, str) else context.state + + log_level = getattr(logger, level) + log_level( + lu.LOG_FMT.format( + id=lu.get_session_id(context), + message=message + ) + ) + + def _log_debug(self, context: str | Context, message: str) -> None: + """ + Log a message with the DEBUG level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "debug", message) + + def _log_function_debug(self, fn_name: str, context: Context, args_name: str | None = None, args = None) -> None: + """ + Logs a message at the start of a backend function. + + :param fn_name: the name of the function + :type fn_name: str + :param context: the request context + :param args_name: the name of the arguments field + :type args_name: str | None + :param args: the arguments provided to the function + :type args: Any + """ + + args_str = f" and {args_name}: {args}" if not args_name else "" + + debug_message = ( + f"[INCOMING REQUEST] {fn_name} with Context: " + f"{context.__dict__}{args_str}" + ) + self._log_debug(context, debug_message) + + def _log_error(self, context: str | Context, message: str) -> None: + """ + Log a message with the ERROR level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "error", message) + + def _log_warning(self, context: str | Context, message: str) -> None: + """ + Log a message with the WARNING level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "warning", message) + + def _log_info(self, context: str | Context, message: str) -> None: + """ + Log a message with the INFO level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "info", message) + + def _log_critical(self, context: str | Context, message: str) -> None: + """ + Log a message with the CRITICAL level. + + :param context: the request context or the scope of the class + :type context: satosa.context.Context | str + :param message: the message to log + :type message: str + """ + + self._log(context, "critical", message) + + @property + def effective_log_level(self) -> int: + """ + Returns the effective log level. + + :return: the effective log level + :rtype: int + """ + + return logger.getEffectiveLevel() \ No newline at end of file diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index 366bb45b..0abc0ef6 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -1,24 +1,34 @@ -import logging - from typing import Union - - from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPVerifier -from pyeudiw.openid4vp.schemas.wallet_instance_attestation import WalletInstanceAttestationPayload, \ +from pyeudiw.openid4vp.schemas.wallet_instance_attestation import ( + WalletInstanceAttestationPayload, WalletInstanceAttestationHeader +) from pyeudiw.satosa.response import JsonResponse - -import satosa.logging_util as lu from satosa.context import Context - -logger = logging.getLogger(__name__) - - -class BackendDPoP: - - def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: - """ This validates, if any, the DPoP http request header """ +from pydantic import ValidationError + +from .base_logger import BaseLogger +from .base_http_error_handler import BaseHTTPErrorHandler + +class BackendDPoP(BaseHTTPErrorHandler, BaseLogger): + """ + Backend DPoP class. + """ + + def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, None]: + """ + Validates, if any, the DPoP http request header + + :param context: The current context + :type context: Context + :param args: The current request arguments + :type args: tuple + + :return: + :rtype: Union[JsonResponse, None] + """ if context.http_headers and 'HTTP_AUTHORIZATION' in context.http_headers: # The wallet instance uses the endpoint authentication to give its WIA @@ -28,43 +38,27 @@ def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: _head = decode_jwt_header(dpop_jws) wia = decode_jwt_payload(dpop_jws) - self._log( - context, - level='debug', - message=( - f"[FOUND WIA] Headers: {_head} and Payload: {wia}" - ) - ) + self._log_debug(context, message=f"[FOUND WIA] Headers: {_head} and Payload: {wia}") try: WalletInstanceAttestationHeader(**_head) + except ValidationError as e: + self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}! \nValidation error: {e}") except Exception as e: - self._log( - context, - level='warning', - message=f"[FOUND WIA] Invalid Headers: {_head}! \nValidation error: {e}" - ) + self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}! \nUnexpected error: {e}") try: WalletInstanceAttestationPayload(**wia) + except ValidationError as e: + self._log_warning(context, message=f"[FOUND WIA] Invalid WIA: {wia}! \nValidation error: {e}") except Exception as e: - self._log( - context, - level='warning', - message=f"[FOUND WIA] Invalid WIA: {wia}! \nValidation error: {e}" - ) + self._log_warning(context, message=f"[FOUND WIA] Invalid WIA: {wia}! \nUnexpected error: {e}") try: self._validate_trust(context, dpop_jws) except Exception as e: _msg = f"Trust Chain validation failed for dpop JWS {dpop_jws}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401", - err=f"{e}" - ) + return self._handle_401(context, _msg, e) try: dpop = DPoPVerifier( @@ -72,27 +66,18 @@ def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: http_header_authz=context.http_headers['HTTP_AUTHORIZATION'], http_header_dpop=context.http_headers['HTTP_DPOP'] ) + except ValidationError as e: + _msg = f"DPoP validation error: {e}" + return self._handle_401(context, _msg, e) except Exception as e: _msg = f"DPoP verification error: {e}" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err_code="401", - err=f"{e}" - ) + return self._handle_401(context, _msg, e) try: dpop.validate() except Exception as e: _msg = "DPoP validation exception" - return self.handle_error( - context=context, - message="invalid_client", - troubleshoot=_msg, - err=f"{e}", - err_code="401" - ) + return self._handle_401(context, _msg, e) # TODO: assert and configure the wallet capabilities # TODO: assert and configure the wallet Attested Security Context @@ -102,13 +87,4 @@ def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: "The Wallet Instance doesn't provide a valid Wallet Instance Attestation " "a default set of capabilities and a low security level are applied." ) - self._log(context, level='warning', message=_msg) - - def _log(self, context: Context, level: str, message: str) -> None: - log_level = getattr(logger, level) - log_level( - lu.LOG_FMT.format( - id=lu.get_session_id(context.state), - message=message - ) - ) + self._log_warning(context, message=_msg) \ No newline at end of file diff --git a/pyeudiw/satosa/exceptions.py b/pyeudiw/satosa/exceptions.py index 805e1a42..1940f1fa 100644 --- a/pyeudiw/satosa/exceptions.py +++ b/pyeudiw/satosa/exceptions.py @@ -21,3 +21,15 @@ class DiscoveryFailedError(Exception): Raised when the discovery fails """ pass + +class HTTPError(Exception): + """ + Raised when an error occurs during an HTTP request + """ + pass + +class EmptyHTTPError(HTTPError): + """ + Default HTTP empty error + """ + pass \ No newline at end of file diff --git a/pyeudiw/satosa/html_template.py b/pyeudiw/satosa/html_template.py index 3788818b..4320254d 100644 --- a/pyeudiw/satosa/html_template.py +++ b/pyeudiw/satosa/html_template.py @@ -1,9 +1,17 @@ +from typing import Any, Dict from jinja2 import Environment, FileSystemLoader, select_autoescape - class Jinja2TemplateHandler: + """ + Jinja2 template handler + """ + def __init__(self, config: Dict[str, Any]): + """ + Create an istance of Jinja2TemplateHandler - def __init__(self, config): + :param config: a dictionary that contains the configuration for initalize the template handler. + :type config: Dict[str, Any] + """ # error pages handler self.loader = Environment( diff --git a/pyeudiw/satosa/response.py b/pyeudiw/satosa/response.py index 542c91b6..4d73d811 100644 --- a/pyeudiw/satosa/response.py +++ b/pyeudiw/satosa/response.py @@ -1,12 +1,23 @@ import json - from satosa.response import Response class JsonResponse(Response): + """ + A JSON response istance class. + """ + _content_type = "application/json" def __init__(self, *args, **kwargs): + """ + Creates an instance of JsonResponse. + + :param args: a list of arguments + :type args: Any + :param kwargs: a dictionary of arguments + :type kwargs: Any + """ super().__init__(*args, **kwargs) if isinstance(self.message, list): diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index 5842c78a..ffaea9ca 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -1,7 +1,4 @@ import json -import logging - - import satosa.logging_util as lu from satosa.context import Context from satosa.response import Response @@ -17,13 +14,18 @@ from pyeudiw.trust import TrustEvaluationHelper from pyeudiw.trust.trust_anchors import update_trust_anchors_ecs +from .base_logger import BaseLogger -logger = logging.getLogger(__name__) - - -class BackendTrust: +class BackendTrust(BaseLogger): + """ + Backend Trust class. + """ def init_trust_resources(self) -> None: + """ + Initializes the trust resources. + """ + # private keys by kid self.federations_jwks_by_kids = { i['kid']: i for i in self.config['federation']['federation_jwks'] @@ -39,14 +41,21 @@ def init_trust_resources(self) -> None: try: self.get_backend_trust_chain() except Exception as e: - logger.critical( - f"Cannot fetch the trust anchor configuration: {e}" - ) + self._log_critical("Backend Trust", f"Cannot fetch the trust anchor configuration: {e}") self.db_engine.close() self._db_engine = None - def entity_configuration_endpoint(self, context): + def entity_configuration_endpoint(self, context: Context) -> Response: + """ + Entity Configuration endpoint. + + :param context: The current context + :type context: Context + + :return: The entity configuration + :rtype: Response + """ data = self.entity_configuration_as_dict if context.qs_params.get('format', '') == 'json': @@ -63,13 +72,13 @@ def entity_configuration_endpoint(self, context): ) def update_trust_anchors(self): + """ + Updates the trust anchors of current instance. + """ + tas = self.config['federation']['trust_anchors'] - logger.info( - lu.LOG_FMT.format( - id="Trust Anchors updates", - message=f"Trying to update: {tas}" - ) - ) + self._log_info("Trust Anchors updates", f"Trying to update: {tas}") + for ta in tas: try: update_trust_anchors_ecs( @@ -78,24 +87,11 @@ def update_trust_anchors(self): httpc_params=self.config['network']['httpc_params'] ) except Exception as e: - logger.warning( - lu.LOG_FMT.format( - id=f"Trust Anchor updates", - message=f"{ta} update failed: {e}" - ) - ) - logger.info( - lu.LOG_FMT.format( - id="Trust Anchor update", - message=f"Trust Anchor updated: {ta}" - ) - ) + self._log_warning("Trust Anchor updates", f"{ta} update failed: {e}") - @property - def default_federation_private_jwk(self): - return tuple(self.federations_jwks_by_kids.values())[0] + self._log_info("Trust Anchor updates", f"{ta} updated") - def get_backend_trust_chain(self) -> list: + def get_backend_trust_chain(self) -> list[str]: """ Get the backend trust chain. In case something raises an Exception (e.g. faulty storage), logs a warning message and returns an empty list. @@ -118,15 +114,72 @@ def get_backend_trust_chain(self) -> list: return trust_evaluation_helper.trust_chain except (DiscoveryFailedError, EntryNotFound, Exception) as e: - logger.warning( + message = ( f"Error while building trust chain for client with id: {self.client_id}\n" f"{e.__class__.__name__}: {e}" ) + self._log_warning("Trust Chain", message) return [] + + def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: + """ + Validates the trust of the given jws. + + :param context: the request context + :type context: satosa.context.Context + :param jws: the jws to validate + :type jws: str + + :raises: NotTrustedFederationError: raises an error if the trust evaluation fails. + + :return: the trust evaluation helper + :rtype: TrustEvaluationHelper + """ + + self._log_debug(context, "[TRUST EVALUATION] evaluating trust.") + + headers = decode_jwt_header(jws) + trust_eval = TrustEvaluationHelper( + self.db_engine, + httpc_params=self.config['network']['httpc_params'], + **headers + ) + + try: + trust_eval.evaluation_method() + except EntryNotFound: + message = ( + "[TRUST EVALUATION] not found for " + f"{trust_eval.entity_id}" + ) + self._log_error(context, message) + + raise NotTrustedFederationError( + f"{trust_eval.entity_id} not found for Trust evaluation." + ) + except Exception as e: + message = ( + "[TRUST EVALUATION] failed for " + f"{trust_eval.entity_id}: {e}" + ) + self._log_error(context, message) + + raise NotTrustedFederationError( + f"{trust_eval.entity_id} is not trusted." + ) + + return trust_eval + + + @property + def default_federation_private_jwk(self) -> dict: + """Returns the default federation private jwk.""" + return tuple(self.federations_jwks_by_kids.values())[0] @property def entity_configuration_as_dict(self) -> dict: + """Returns the entity configuration as a dictionary.""" ec_payload = { "exp": exp_from_now(minutes=self.default_exp), "iat": iat_now(), @@ -145,6 +198,7 @@ def entity_configuration_as_dict(self) -> dict: @property def entity_configuration(self) -> dict: + """Returns the entity configuration as a JWT.""" data = self.entity_configuration_as_dict jwshelper = JWSHelper(self.default_federation_private_jwk) return jwshelper.sign( @@ -154,49 +208,4 @@ def entity_configuration(self) -> dict: "typ": "entity-statement+jwt" }, plain_dict=data - ) - - def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: - self._log( - context, - level='debug', - message=( - "[TRUST EVALUATION] evaluating trust." - ) - ) - - headers = decode_jwt_header(jws) - trust_eval = TrustEvaluationHelper( - self.db_engine, - httpc_params=self.config['network']['httpc_params'], - **headers - ) - - try: - trust_eval.evaluation_method() - except Exception as e: - self._log( - context, - level='error', - message=( - "[TRUST EVALUATION] failed for " - f"{trust_eval.entity_id}: {e}" - ) - ) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} is not trusted." - ) - except EntryNotFound: - self._log( - context, - level='error', - message=( - "[TRUST EVALUATION] not found for " - f"{trust_eval.entity_id}" - ) - ) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} not found for Trust evaluation." - ) - - return trust_eval + ) \ No newline at end of file diff --git a/pyeudiw/sd_jwt/__init__.py b/pyeudiw/sd_jwt/__init__.py index 0fbe0e53..14da89ff 100644 --- a/pyeudiw/sd_jwt/__init__.py +++ b/pyeudiw/sd_jwt/__init__.py @@ -21,6 +21,8 @@ import jwcrypto +from typing import Any + class TrustChainSDJWTIssuer(SDJWTIssuer): def __init__(self, user_claims: Dict, issuer_key, holder_key=None, sign_alg=None, add_decoy_claims: bool = True, serialization_format: str = "compact", additional_headers: dict = {}): @@ -163,7 +165,7 @@ def verify_sd_jwt( issuer_key: JWK, holder_key: JWK, settings: dict = {'key_binding': True} -) -> dict: +) -> (list | dict | Any): settings.update( { diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index cd507927..bd8eb714 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -548,8 +548,8 @@ def test_request_endpoint(self, context): # assert msg["response"] == "Authentication successful" def test_handle_error(self, context): - error_message = "Error message!" - error_resp = self.backend.handle_error(context, error_message) + error_message = "server_error" + error_resp = self.backend._handle_500(context, error_message, Exception()) assert error_resp.status == "500" assert error_resp.message err = json.loads(error_resp.message) diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 82c1bad4..7952fae1 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -27,7 +27,7 @@ class TrustEvaluationHelper: def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): self.exp: int = 0 - self.trust_chain: list = [] + self.trust_chain: list[str] = [] self.trust_anchor = trust_anchor self.storage = storage self.entity_id: str = "" @@ -230,7 +230,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: f" {self.final_metadata['metadata']}" ) - def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list: + def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list[dict]: return self.get_final_metadata( metadata_type=metadata_type, policies=policies