diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index 0106d2e4..e7fb4914 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -37,6 +37,8 @@ def __init__( :type hash_func: str :param ec_crv: a string that represents the curve to use with the instance. :type ec_crv: str + + :raises NotImplementedError: the key_type is not implemented """ kwargs = {} self.kid = "" diff --git a/pyeudiw/jwt/__init__.py b/pyeudiw/jwt/__init__.py index ed7d4a13..f5b14f44 100644 --- a/pyeudiw/jwt/__init__.py +++ b/pyeudiw/jwt/__init__.py @@ -3,7 +3,6 @@ from typing import Union, Any import cryptojwt -from cryptojwt.exception import VerificationError from cryptojwt.jwe.jwe import factory from cryptojwt.jwe.jwe_ec import JWE_EC from cryptojwt.jwe.jwe_rsa import JWE_RSA @@ -14,6 +13,8 @@ from pyeudiw.jwk.exceptions import KidError from pyeudiw.jwt.utils import decode_jwt_header +from .exceptions import JWEDecryptionError, JWSVerificationError + DEFAULT_HASH_FUNC = "SHA-256" DEFAULT_SIG_KTY_MAP = { @@ -103,13 +104,15 @@ def decrypt(self, jwe: str) -> dict: :param jwe: A string representing the jwe. :type jwe: str + :raises JWEDecryptionError: if jwe field is not in a JWE Format + :returns: A dict that represents the payload of decrypted JWE. :rtype: dict """ try: jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: - raise VerificationError("The JWT is not valid") + raise JWEDecryptionError(f"Not a valid JWE format for the following reason: {e}") _alg = jwe_header.get("alg") _enc = jwe_header.get("enc") @@ -186,12 +189,18 @@ def verify(self, jws: str, **kwargs) -> (str | Any | bytes): :type jws: str :param kwargs: Other optional fields to generate the JWE. + :raises JWSVerificationError: if jws field is not in a JWS Format + :returns: A string that represents the payload of JWS. :rtype: str """ _key = key_from_jwk_dict(self.jwk.as_dict()) _jwk_dict = self.jwk.as_dict() - _head = decode_jwt_header(jws) + + try: + _head = decode_jwt_header(jws) + except (binascii.Error, Exception) as e: + raise JWSVerificationError(f"Not a valid JWS format for the following reason: {e}") if _head.get("kid"): if _head["kid"] != _jwk_dict["kid"]: # pragma: no cover diff --git a/pyeudiw/jwt/exceptions.py b/pyeudiw/jwt/exceptions.py index cec4a78c..f9428711 100644 --- a/pyeudiw/jwt/exceptions.py +++ b/pyeudiw/jwt/exceptions.py @@ -2,4 +2,7 @@ class JWEDecryptionError(Exception): pass class JWTInvalidElementPosition(Exception): + pass + +class JWSVerificationError(Exception): pass \ No newline at end of file diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index b44e8a6e..cf0285ea 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -99,3 +99,40 @@ def is_jwt_format(jwt: str) -> bool: res = re.match(JWT_REGEXP, jwt) return bool(res) + +def is_jwe_format(jwt: str): + """ + Check if a string is in JWE format. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: True if the string is a JWE, False otherwise. + :rtype: bool + """ + + if not is_jwt_format(jwt): + return False + + header = decode_jwt_header(jwt) + + if header.get("enc", None) == None: + return False + + return True + +def is_jws_format(jwt: str): + """ + Check if a string is in JWS format. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: True if the string is a JWS, False otherwise. + :rtype: bool + """ + breakpoint() + if not is_jwt_format(jwt): + return False + + return not is_jwe_format(jwt) \ No newline at end of file diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index 8a76c9c1..c5406ed5 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -1,9 +1,8 @@ - +from typing import Dict from pyeudiw.jwk import JWK -from pyeudiw.jwt import JWEHelper -from pyeudiw.jwt.exceptions import JWEDecryptionError +from pyeudiw.jwt import JWEHelper, JWSHelper from pyeudiw.jwk.exceptions import KidNotFoundError -from pyeudiw.jwt.utils import decode_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format from pyeudiw.openid4vp.exceptions import ( VPNotFound, VPInvalidNonce, @@ -11,11 +10,23 @@ ) from pyeudiw.openid4vp.schemas.vp_token import VPTokenPayload, VPTokenHeader from pyeudiw.openid4vp.vp import Vp - +from pydantic import ValidationError class DirectPostResponse: - def __init__(self, jwt: str, jwks_by_kids: dict, nonce: str = ""): - + """ + Helper class for generate Direct Post Response. + """ + def __init__(self, jwt: str, jwks_by_kids: Dict[str, dict], nonce: str = ""): + """ + Generate an instance of DirectPostResponse. + + :param jwt: a string that represents the jwt. + :type jwt: str + :param jwks_by_kids: a dictionary that contains one or more JWKs with the KID as the key. + :type jwks_by_kids: Dict[str, dict] + :param nonce: a string that represents the nonce. + :type nonce: str + """ self.headers = decode_jwt_header(jwt) self.jwks_by_kids = jwks_by_kids self.jwt = jwt @@ -26,35 +37,48 @@ def __init__(self, jwt: str, jwks_by_kids: dict, nonce: str = ""): self.credentials_by_issuer: dict = {} self._claims_by_issuer: dict = {} - @property - def payload(self) -> dict: - # TODO: detect if it is encrypted otherwise ... - # here we support only the encrypted jwt - if not self._payload: - self.decrypt() - return self._payload + def _decode_payload(self) -> None: + """ + Internally decrypts the content of the JWT. - def decrypt(self) -> None: + :raises JWSVerificationError: if jws field is not in a JWS Format + :raises JWEDecryptionError: if jwe field is not in a JWE Format + """ _kid = self.headers.get('kid', None) if not _kid: raise KidNotFoundError( f"The JWT headers {self.headers} doesnt have any KID value" ) self.jwk = JWK(self.jwks_by_kids[_kid]) - jweHelper = JWEHelper(self.jwk) - try: + + if is_jwe_format(self.jwt): + jweHelper = JWEHelper(self.jwk) self._payload = jweHelper.decrypt(self.jwt) - except Exception as e: - _msg = f"Response decryption error: {e}" - raise JWEDecryptionError(_msg) + else: + jwsHelper = JWSHelper(self.jwk) + self._payload = jwsHelper.verify(self.jwt) - def load_nonce(self, nonce: str): + def load_nonce(self, nonce: str) -> None: + """ + Load a nonce string inside the body of response. + + :param nonce: a string that represents the nonce. + :type nonce: str + """ self.nonce = nonce - def validate(self) -> bool: + def _validate_vp(self, vp: dict) -> bool: + """ + Validate a single Verifiable Presentation. - # check nonces - for vp in self.get_presentation_vps(): + :param vp: the verifiable presentation to validate. + :type vp: str + + :returns: True if is valid, False otherwhise. + :rtype: bool + """ + try: + # check nonce if self.nonce: if not vp.payload.get('nonce', None): raise NoNonceInVPToken() @@ -66,23 +90,40 @@ def validate(self) -> bool: ) VPTokenPayload(**vp.payload) VPTokenHeader(**vp.headers) + except ValidationError: + return False + return True + + + def validate(self) -> bool: + """ + Validates all VPs inside JWT's body. + :returns: True if all VP are valid, False otherwhise. + :rtype: bool + """ + + for vp in self.get_presentation_vps(): + if not self._validate_vp(vp): + return False + return True - @property - def vps(self): - if not self._vps: - self.get_presentation_vps() - return self._vps + def get_presentation_vps(self) -> list[dict]: + """ + Returns the presentation's verifiable presentations - def get_presentation_vps(self): + :returns: the list of vps. + :rtype: list[dict] + """ if self._vps: return self._vps _vps = self.payload.get('vp_token', []) vps = [_vps] if isinstance(_vps, str) else _vps + if not vps: - raise VPNotFound("vp is null") + raise VPNotFound(f"Vps for response with nonce \"{self.nonce}\" are empty") for vp in vps: _vp = Vp(vp) @@ -95,3 +136,17 @@ def get_presentation_vps(self): self.credentials_by_issuer[cred_iss].append(_vp.payload['vp']) return self._vps + + @property + def vps(self) -> list[dict]: + """Returns the presentation's verifiable presentations""" + if not self._vps: + self.get_presentation_vps() + return self._vps + + @property + def payload(self) -> dict: + """Returns the decoded payload of presentation""" + if not self._payload: + self._decode_payload() + return self._payload \ No newline at end of file diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index fcfb385b..74aaf261 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -1,43 +1,62 @@ - from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt class Vp(VpSdJwt): + "Class for SD-JWT Format" + def __init__(self, jwt: str) -> None: + """ + Generates a VP istance. - def __init__(self, jwt: str): - # TODO: what if the credential is not a JWT? - self.headers = decode_jwt_header(jwt) - self.jwt = jwt - self.payload = decode_jwt_payload(jwt) + :param jwt: a string that represents the jwt. + :type jwt: str - self.credential_headers: dict = {} - self.credential_payload: dict = {} + :raises InvalidVPToken: if the jwt field's value is not a JWT. + """ + super().__init__(jwt) self.parse_digital_credential() self.disclosed_user_attributes: dict = {} - def _detect_vp_type(self): - # TODO - automatic detection of the credential - return 'jwt' - - def get_credential_jwks(self): + def _detect_vp_type(self) -> str: + """ + Detects and return the type of verifiable presentation. + + :returns: the type of VP. + :rtype: str + """ + return self.headers["typ"].lower() + + def get_credential_jwks(self) -> list[dict]: + """ + Returns the credential JWKs. + + :returns: the list containing credential's JWKs. + :rtype: list[dict] + """ if not self.credential_jwks: return {} return self.credential_jwks - @property - def credential_issuer(self): - if not self.credential_payload.get('iss', None): - self.parse_digital_credential() - return self.credential_payload.get('iss', None) - - def parse_digital_credential(self): + def parse_digital_credential(self) -> None: + """ + Parse the digital credential of VP. + + :raises NotImplementedError: if VP Digital credentials type not implemented. + """ _typ = self._detect_vp_type() - if _typ == 'jwt': - self.credential_headers = decode_jwt_header(self.payload['vp']) - self.credential_payload = decode_jwt_payload(self.payload['vp']) - else: + + if _typ != 'jwt': raise NotImplementedError( f"VP Digital credentials type not implemented yet: {_typ}" ) + + self.credential_headers = decode_jwt_header(self.payload['vp']) + self.credential_payload = decode_jwt_payload(self.payload['vp']) + + @property + def credential_issuer(self) -> str: + """Returns the credential issuer""" + if not self.credential_payload.get('iss', None): + self.parse_digital_credential() + return self.credential_payload.get('iss', None) \ No newline at end of file diff --git a/pyeudiw/openid4vp/vp_sd_jwt.py b/pyeudiw/openid4vp/vp_sd_jwt.py index 7ba57a06..a1115820 100644 --- a/pyeudiw/openid4vp/vp_sd_jwt.py +++ b/pyeudiw/openid4vp/vp_sd_jwt.py @@ -1,17 +1,52 @@ +from typing import Dict from pyeudiw.jwk import JWK from pyeudiw.jwt import JWSHelper +from pyeudiw.jwt.utils import is_jwt_format, decode_jwt_header, decode_jwt_payload from pyeudiw.sd_jwt import verify_sd_jwt from pyeudiw.jwk.exceptions import KidNotFoundError +from .exceptions import InvalidVPToken class VpSdJwt: + """Class for SD-JWT Format""" + + def __init__(self, jwt: str): + """ + Generates a VpSdJwt istance + + :param jwt: a string that represents the jwt. + :type jwt: str + + :raises InvalidVPToken: if the jwt field's value is not a JWT. + """ + + if not is_jwt_format(jwt): + raise InvalidVPToken(f"VP is not in JWT format.") + + self.headers = decode_jwt_header(jwt) + self.jwt = jwt + self.payload = decode_jwt_payload(jwt) + + self.credential_headers: dict = {} + self.credential_payload: dict = {} def verify_sdjwt( self, - issuer_jwks_by_kid: dict = {} - ) -> dict: + issuer_jwks_by_kid: Dict[str, dict] = {} + ) -> bool: + """ + Verifies a SDJWT. + + :param jwks_by_kids: a dictionary that contains one or more JWKs with the KID as the key. + :type jwks_by_kids: Dict[str, dict] + + :raises KidNotFoundError: if the needed kid is not inside the issuer_jwks_by_kid. + :raises NotImplementedError: the key_type of one or more JWK is not implemented. + :raises JWSVerificationError: if self.jwt field is not in a JWS Format. + :returns: True if is valid, False otherwise. + """ if not issuer_jwks_by_kid.get(self.credential_headers["kid"], None): raise KidNotFoundError( f"issuer jwks {issuer_jwks_by_kid} doesn't contain " diff --git a/pyeudiw/tests/test_jwt.py b/pyeudiw/tests/test_jwt.py index d0982098..bd9373fc 100644 --- a/pyeudiw/tests/test_jwt.py +++ b/pyeudiw/tests/test_jwt.py @@ -3,7 +3,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import (DEFAULT_ENC_ALG_MAP, DEFAULT_ENC_ENC_MAP, JWEHelper, JWSHelper) -from pyeudiw.jwt.utils import decode_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format JWKs_EC = [ (JWK(key_type="EC"), {"key": "value"}), @@ -47,6 +47,7 @@ def test_jwe_helper_encrypt(jwk, payload): helper = JWEHelper(jwk) jwe = helper.encrypt(payload) assert jwe + assert is_jwe_format(jwe) @pytest.mark.parametrize("jwk, payload", JWKs_RSA) @@ -83,7 +84,6 @@ def test_jws_helper_sign(jwk, payload): jws = helper.sign(payload) assert jws - @pytest.mark.parametrize("jwk, payload", JWKs_RSA) def test_jws_helper_verify(jwk, payload): helper = JWSHelper(jwk)