Skip to content

Commit

Permalink
Use Pydantic to type check API responses
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaso committed Aug 13, 2023
1 parent 8736afc commit 9771109
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 24 deletions.
132 changes: 115 additions & 17 deletions hv4gha/gh.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""GitHub specific code"""

import json
from datetime import datetime, timezone
from typing import Final, TypedDict
from datetime import datetime
from typing import Final, Literal, TypedDict

import requests
from pydantic import BaseModel, Field, TypeAdapter, ValidationError

PermARW = None | Literal["admin", "read", "write"]
PermRW = None | Literal["read", "write"]
PermR = None | Literal["read"]
PermW = None | Literal["write"]


class TokenResponse(TypedDict, total=False):
Expand Down Expand Up @@ -32,6 +38,89 @@ class NotInstalledError(Exception):
"""The GitHub App isn't installed in the specified account"""


class GitHubErrors(BaseModel):
"""
https://docs.github.com/en/rest/overview/resources-in-the-rest-api?apiVersion=2022-11-28
"""

message: str


class AccountInfo(BaseModel):
"""Part of Installation"""

login: str = Field(
max_length=39, pattern=r"^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?$"
)


class Installation(BaseModel):
"""
https://docs.github.com/en/rest/apps/apps?apiVersion=2022-11-28#list-installations-for-the-authenticated-app
"""

id: int
account: AccountInfo


class TokenPermissions(BaseModel):
"""Part of AccessToken"""

# Repository permissions
actions: PermRW = None
administration: PermRW = None
checks: PermRW = None
contents: PermRW = None
deployments: PermRW = None
environments: PermRW = None
issues: PermRW = None
metadata: PermRW = None
packages: PermRW = None
pages: PermRW = None
pull_requests: PermRW = None
repository_hooks: PermRW = None
repository_projects: PermARW = None
secret_scanning_alerts: PermRW = None
secrets: PermRW = None
security_events: PermRW = None
single_file: PermRW = None
statuses: PermRW = None
vulnerability_alerts: PermRW = None
workflows: PermW = None
# Organizational permissions
members: PermRW = None
organization_administration: PermRW = None
organization_custom_roles: PermRW = None
organization_announcement_banners: PermRW = None
organization_hooks: PermRW = None
organization_personal_access_tokens: PermRW = None
organization_personal_access_token_requests: PermRW = None
organization_plan: PermR = None
organization_projects: PermARW = None
organization_packages: PermRW = None
organization_secrets: PermRW = None
organization_self_hosted_runners: PermRW = None
organization_user_blocking: PermRW = None
team_discussions: PermRW = None


class Repository(BaseModel):
"""Part of AccessToken"""

name: str = Field(max_length=100, pattern=r"^[a-zA-Z0-9_\-\.]+$")


class AccessToken(BaseModel):
"""
https://docs.github.com/en/rest/apps/apps?apiVersion=2022-11-28#create-an-installation-access-token-for-an-app
"""

token: str
expires_at: datetime
permissions: TokenPermissions
repositories: None | list[Repository] = None


class GitHubApp:
"""GitHub App Access Tokens, etc"""

Expand Down Expand Up @@ -68,16 +157,23 @@ def __find_installation(self) -> str:
)
response.raise_for_status()
except requests.exceptions.HTTPError as http_error:
error_message: str
try:
error_message = http_error.response.json()["message"]
errors_bm = GitHubErrors(**http_error.response.json())
error_message = errors_bm.message
except Exception: # pylint: disable=broad-exception-caught
error_message = "<Failed to parse GitHub API error response>"
raise InstallationLookupError(error_message) from http_error

for installation in response.json():
if installation["account"]["login"].lower() == self.account.lower():
return str(installation["id"])
try:
ita = TypeAdapter(list[Installation])
installations = ita.validate_python(response.json())
except ValidationError as validation_error:
error_message = "<Failed to parse Installations API response>"
raise InstallationLookupError(error_message) from validation_error

for installation in installations:
if installation.account.login.lower() == self.account.lower():
return str(installation.id)

if "next" in response.links.keys():
pagination_params["page"] += 1
Expand Down Expand Up @@ -125,26 +221,28 @@ def issue_token(
)
response.raise_for_status()
except requests.exceptions.HTTPError as http_error:
error_message: str
try:
error_message = http_error.response.json()["message"]
errors_bm = GitHubErrors(**http_error.response.json())
error_message = errors_bm.message
except Exception: # pylint: disable=broad-exception-caught
error_message = "<Failed to parse GitHub API error response>"
raise TokenIssueError(error_message) from http_error

expiry = datetime.strptime(
response.json()["expires_at"], "%Y-%m-%dT%H:%M:%SZ"
).replace(tzinfo=timezone.utc)
try:
access_token_bm = AccessToken(**response.json())
except ValidationError as validation_error:
error_message = "<Failed to parse Token Issue API response>"
raise TokenIssueError(error_message) from validation_error

access_token: TokenResponse = {
"access_token": response.json()["token"],
"expires_at": expiry,
"permissions": response.json()["permissions"],
"access_token": access_token_bm.token,
"expires_at": access_token_bm.expires_at,
"permissions": access_token_bm.permissions.model_dump(exclude_unset=True),
}

if "repositories" in response.json().keys():
if access_token_bm.repositories is not None:
access_token["repositories"] = sorted(
[repo["name"] for repo in response.json()["repositories"]]
[repo.name for repo in access_token_bm.repositories]
)

return access_token
61 changes: 55 additions & 6 deletions hv4gha/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
from cryptography.hazmat.primitives import hashes, keywrap, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from pydantic import BaseModel, ValidationError


class VaultAPIError(Exception):
Expand All @@ -31,6 +32,42 @@ class WrappingKeyDownloadError(VaultAPIError):
"""Failure to download the Vault Transit wrapping key"""


class VaultErrors(BaseModel):
"""
https://developer.hashicorp.com/vault/api-docs#error-response
"""

errors: list[str]


class JWTData(BaseModel):
"""Part of SignedJWT"""

signature: str


class SignedJWT(BaseModel):
"""
https://developer.hashicorp.com/vault/api-docs/secret/transit#sign-data
"""

data: JWTData


class KeyData(BaseModel):
"""Part of WrappingKey"""

public_key: str


class WrappingKey(BaseModel):
"""
https://developer.hashicorp.com/vault/api-docs/secret/transit#get-wrapping-key
"""

data: KeyData


class VaultTransit:
"""Interact with Vault's Transit Secrets Engine"""

Expand Down Expand Up @@ -74,14 +111,20 @@ def __download_wrapping_key(self) -> rsa.RSAPublicKey:
)
response.raise_for_status()
except requests.exceptions.HTTPError as http_error:
error_message: str
try:
error_message = "\n".join(http_error.response.json()["errors"])
errors_bm = VaultErrors(**http_error.response.json())
error_message = "\n".join(errors_bm.errors)
except Exception: # pylint: disable=broad-exception-caught
error_message = "<Failed to parse Vault API error response>"
raise WrappingKeyDownloadError(error_message) from http_error

wrapping_pem_key = response.json()["data"]["public_key"].encode()
try:
wrapping_key_bm = WrappingKey(**response.json())
except ValidationError as validation_error:
error_message = "<Failed to parse Wrapping Key API response>"
raise WrappingKeyDownloadError(error_message) from validation_error

wrapping_pem_key = wrapping_key_bm.data.public_key.encode()
wrapping_key = serialization.load_pem_public_key(wrapping_pem_key)

if not isinstance(wrapping_key, rsa.RSAPublicKey):
Expand Down Expand Up @@ -125,9 +168,9 @@ def __api_write(
)
response.raise_for_status()
except requests.exceptions.HTTPError as http_error:
error_message: str
try:
error_message = "\n".join(http_error.response.json()["errors"])
errors_bm = VaultErrors(**http_error.response.json())
error_message = "\n".join(errors_bm.errors)
except Exception: # pylint: disable=broad-exception-caught
error_message = "<Failed to parse Vault API error response>"
raise vault_exception(error_message) from http_error
Expand Down Expand Up @@ -201,7 +244,13 @@ def sign_jwt(self, key_name: str, app_id: str) -> str:
api_path, payload, JWTSigningError
)

signature: str = response.json()["data"]["signature"].removeprefix("vault:v1:")
try:
signature_bm = SignedJWT(**response.json())
except ValidationError as validation_error:
error_message = "<Failed to parse Sign JWT API response>"
raise JWTSigningError(error_message) from validation_error

signature = signature_bm.data.signature.removeprefix("vault:v1:")
signature = self.__b64str(base64.b64decode(signature), urlsafe=True)

jwt_token = header_and_claims + "." + signature
Expand Down
Loading

0 comments on commit 9771109

Please sign in to comment.