Skip to content

Commit

Permalink
feat: documentation and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalDR committed Dec 20, 2023
1 parent 872ac11 commit 7a6c1ea
Showing 1 changed file with 177 additions and 91 deletions.
268 changes: 177 additions & 91 deletions pyeudiw/storage/db_engine.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
import uuid
import logging
import importlib
from datetime import datetime
from typing import Callable, Union
from typing import Callable, Union, Tuple, Dict
from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus
from pyeudiw.storage.base_storage import BaseStorage, TrustType
from pyeudiw.storage.exceptions import (
ChainNotExist,
StorageWriteError,
EntryNotFound
)
from pyeudiw.tools.base_logger import BaseLogger

logger = logging.getLogger(__name__)
from .base_db import BaseDB


class DBEngine():
class DBEngine(BaseStorage, BaseCache, BaseLogger):
"""
DB Engine class.
"""
def __init__(self, config: dict):
self.caches = []
self.storages = []
"""
Create a DB Engine instance.
:param config: the configuration of all the DBs.
:type config: dict
"""
self.caches: list[Tuple[str, BaseCache]] = []
self.storages: list[Tuple[str, BaseStorage]] = []

for db_name, db_conf in config.items():
storage_instance, cache_instance = self._handle_instance(db_conf)

Expand All @@ -27,25 +36,6 @@ def __init__(self, config: dict):
if cache_instance:
self.caches.append((db_name, cache_instance))

def _handle_instance(self, instance: dict) -> dict[BaseStorage | None, BaseCache | None]:
cache_conf = instance.get("cache", None)
storage_conf = instance.get("storage", None)

storage_instance = None
if storage_conf:
module = importlib.import_module(storage_conf["module"])
instance_class = getattr(module, storage_conf["class"])
storage_instance = instance_class(
**storage_conf.get("init_params", {}))

cache_instance = None
if cache_conf:
module = importlib.import_module(cache_conf["module"])
instance_class = getattr(module, cache_conf["class"])
cache_instance = instance_class(**cache_conf["init_params"])

return storage_instance, cache_instance

def init_session(self, session_id: str, state: str) -> str:
document_id = str(uuid.uuid4())
for db_name, storage in self.storages:
Expand All @@ -54,58 +44,49 @@ def init_session(self, session_id: str, state: str) -> str:
document_id, session_id=session_id, state=state
)
except StorageWriteError as e:
logger.critical(
f"Error while initializing session with document_id {document_id}. "
f"Cannot write document with id {document_id} on {db_name}: "
f"{e.__class__.__name__}: {e}"
self._log_critical(
e.__class__.__name__,
(
f"Error while initializing session with document_id {document_id}. "
f"Cannot write document with id {document_id} on {db_name}: {e}"
)
)
raise e

return document_id

@property
def is_connected(self):
_connected = False
_cons = {}
for db_name, storage in self.storages:
try:
_connected = storage.is_connected
_cons[db_name] = _connected
except Exception as e:
logger.debug(
f"Error while checking db engine connection on {db_name}. "
f"{e.__class__.__name__}: {e}"
)
def close(self):
    """
    Close all the storage and cache connections.

    Delegates to ``_close_list``, which logs and re-raises the first
    error encountered while closing any backend.
    """
    self._close_list(self.storages)
    self._close_list(self.caches)

if True in _cons.values() and not all(_cons.values()):
logger.warning(
f"Not all the storage are found available, storages misalignment: "
f"{_cons}"
)
def write(self, method: str, *args, **kwargs):
"""
Perform a write operation on the storages.
return _connected
:param method: the method to call.
:type method: str
:param args: the arguments to pass to the method.
:type args: Any
:param kwargs: the keyword arguments to pass to the method.
:type kwargs: Any
def close(self):
for db_name, storage in self.storages:
try:
storage.close()
except Exception as e:
logger.critical(
f"Error while closing db engine {db_name}. "
f"{e.__class__.__name__}: {e}"
)
raise e
:raises StorageWriteError: if the write operation fails on all the storages.
:returns: the number of replicas where the write operation is successful.
:rtype: int
"""

def write(self, method: str, *args, **kwargs):
replica_count = 0
_err_msg = f"Cannot apply write method '{method}' with {args} {kwargs}"
for db_name, storage in self.storages:
try:
getattr(storage, method)(*args, **kwargs)
replica_count += 1
except Exception as e:
logger.critical(
f"Error {_err_msg} on {db_name} {storage}: {str(e)}")
self._log_critical(
e.__class__.__name__,
f"Error {_err_msg} on {db_name} {storage}: {str(e)}"
)

if not replica_count:
raise StorageWriteError(_err_msg)
Expand All @@ -129,15 +110,32 @@ def update_request_object(self, document_id: str, request_object: dict) -> int:
def update_response_object(self, nonce: str, state: str, response_object: dict) -> int:
    """
    Update the response object of the session identified by nonce and state.

    :param nonce: the nonce value of the session entry.
    :type nonce: str
    :param state: the state value of the session entry.
    :type state: str
    :param response_object: the response object to store.
    :type response_object: dict
    :returns: the number of storage replicas successfully written.
    :rtype: int
    """
    return self.write("update_response_object", nonce, state, response_object)

def get(self, method: str, *args, **kwargs):
def get(self, method: str, *args, **kwargs) -> Union[dict, None]:
"""
Perform a get operation on the storages.
:param method: the method to call.
:type method: str
:param args: the arguments to pass to the method.
:type args: Any
:param kwargs: the keyword arguments to pass to the method.
:type kwargs: Any
:raises EntryNotFound: if the entry is not found on any storage.
:returns: the result of the first element found on the DBs.
:rtype: Union[dict, None]
"""

for db_name, storage in self.storages:
try:
res = getattr(storage, method)(*args, **kwargs)
if res:
return res

except EntryNotFound as e:
logger.critical(
self._log_debug(
e.__class__.__name__,
f"Cannot find result by method {method} on {db_name} with {args} {kwargs}: {str(e)}"
)

Expand All @@ -149,19 +147,19 @@ def get_trust_attestation(self, entity_id: str) -> Union[dict, None]:
def get_trust_anchor(self, entity_id: str) -> Union[dict, None]:
    """
    Retrieve a trust anchor by entity id from the first storage that has it.

    :param entity_id: the entity identifier of the trust anchor.
    :type entity_id: str
    :returns: the trust anchor document when found (see ``get`` for the
        not-found behaviour).
    :rtype: Union[dict, None]
    """
    return self.get("get_trust_anchor", entity_id)

def has_trust_attestation(self, entity_id: str) -> Union[dict, None]:
    """
    Check whether a trust attestation exists for the given entity.

    :param entity_id: the entity identifier.
    :type entity_id: str
    :returns: the attestation document itself (truthy when present),
        not a strict boolean.
    :rtype: Union[dict, None]
    """
    return self.get_trust_attestation(entity_id)
def has_trust_attestation(self, entity_id: str) -> bool:
    """
    Check whether a trust attestation exists for the given entity.

    :param entity_id: the entity identifier.
    :type entity_id: str
    :returns: True if a trust attestation is found, False otherwise.
    :rtype: bool
    """
    # PEP 8: identity comparison with None (`is not None`), not `!= None`.
    return self.get_trust_attestation(entity_id) is not None

def has_trust_anchor(self, entity_id: str) -> Union[dict, None]:
    """
    Check whether a trust anchor exists for the given entity.

    :param entity_id: the entity identifier.
    :type entity_id: str
    :returns: the trust anchor document itself (truthy when present),
        not a strict boolean.
    :rtype: Union[dict, None]
    """
    # BUG FIX: the original called self.get_anchor(), a method that does
    # not exist on this class; the correct accessor is get_trust_anchor().
    return self.get_trust_anchor(entity_id)
def has_trust_anchor(self, entity_id: str) -> bool:
    """
    Check whether a trust anchor exists for the given entity.

    :param entity_id: the entity identifier.
    :type entity_id: str
    :returns: True if a trust anchor is found, False otherwise.
    :rtype: bool
    """
    # PEP 8: identity comparison with None (`is not None`), not `!= None`.
    return self.get_trust_anchor(entity_id) is not None

def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> int:
    """
    Add a trust attestation on every configured storage.

    :param entity_id: the entity identifier the attestation belongs to.
    :type entity_id: str
    :param attestation: the attestation chain to store.
    :type attestation: list[str]
    :param exp: the expiration datetime of the attestation.
    :type exp: datetime
    :param trust_type: the trust model of the attestation.
    :type trust_type: TrustType
    :returns: the number of storage replicas successfully written
        (``write`` documents an int return; the previous ``-> str``
        annotation contradicted it).
    :rtype: int
    """
    return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type)

def add_trust_attestation_metadata(self, entity_id: str, metadat_type: str, metadata: dict) -> int:
    """
    Attach metadata to a stored trust attestation on every storage.

    :param entity_id: the entity identifier the attestation belongs to.
    :type entity_id: str
    :param metadat_type: the metadata type key (name kept as-is — likely a
        typo for ``metadata_type`` — to preserve keyword-argument
        compatibility for existing callers).
    :type metadat_type: str
    :param metadata: the metadata document to attach.
    :type metadata: dict
    :returns: the number of storage replicas successfully written.
    :rtype: int
    """
    return self.write("add_trust_attestation_metadata", entity_id, metadat_type, metadata)

def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION):
def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> int:
    """
    Add a trust anchor on every configured storage.

    :param entity_id: the entity identifier of the trust anchor.
    :type entity_id: str
    :param entity_configuration: the entity configuration to store.
    :type entity_configuration: str
    :param exp: the expiration datetime of the trust anchor.
    :type exp: datetime
    :param trust_type: the trust model of the anchor.
    :type trust_type: TrustType
    :returns: the number of storage replicas successfully written.
    :rtype: int
    """
    return self.write("add_trust_anchor", entity_id, entity_configuration, exp, trust_type)

def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str:
Expand All @@ -177,20 +175,6 @@ def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str]
def update_trust_anchor(self, entity_id: str, entity_configuration: dict, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> int:
    """
    Update a trust anchor on every configured storage.

    :param entity_id: the entity identifier of the trust anchor.
    :type entity_id: str
    :param entity_configuration: the entity configuration to store.
    :type entity_configuration: dict
    :param exp: the expiration datetime of the trust anchor.
    :type exp: datetime
    :param trust_type: the trust model of the anchor.
    :type trust_type: TrustType
    :returns: the number of storage replicas successfully written.
    :rtype: int
    """
    return self.write("update_trust_anchor", entity_id, entity_configuration, exp, trust_type)

def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]:
for i, cache in enumerate(self.caches):
try:
cache_object, status = cache.try_retrieve(
object_name, on_not_found)
return cache_object, status, i
except Exception:
logger.critical(
f"Cannot retrieve or write cache object with identifier {object_name} on cache database {i}"
)
raise ConnectionRefusedError(
"Cannot write cache object on any instance"
)

def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dict:
# if no cache instance exist return the object
if len(self.caches):
Expand All @@ -210,8 +194,9 @@ def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dic
for cache_name, cache in replica_instances:
try:
cache.set(cache_object)
except Exception:
logger.critical(
except Exception as e:
self._log_critical(
e.__class__.__name__,
f"Cannot replicate cache object with identifier {object_name} on cache {cache_name}"
)

Expand All @@ -222,8 +207,9 @@ def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict:
cache_object = None
try:
cache_object = cache.overwrite(object_name, value_gen_fn)
except Exception:
logger.critical(
except Exception as e:
self._log_critical(
e.__class__.__name__,
f"Cannot overwrite cache object with identifier {object_name} on cache {cache_name}"
)
return cache_object
Expand All @@ -236,14 +222,114 @@ def exists_by_state_and_session_id(self, state: str, session_id: str = "") -> bo
return True
return False

def get_by_state(self, state: str):
def get_by_state(self, state: str) -> Union[dict, None]:
    """
    Retrieve a session document by its state value.

    :param state: the state value of the session entry.
    :type state: str
    :returns: the first matching document found on the storages.
    :rtype: Union[dict, None]
    """
    return self.get_by_state_and_session_id(state=state)

def get_by_nonce_state(self, state: str, nonce: str):
def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]:
    """
    Retrieve a session document by its nonce and state values.

    :param state: the state value of the session entry.
    :type state: str
    :param nonce: the nonce value of the session entry.
    :type nonce: str
    :returns: the first matching document found on the storages.
    :rtype: Union[dict, None]
    """
    return self.get('get_by_nonce_state', state=state, nonce=nonce)

def get_by_state_and_session_id(self, state: str, session_id: str = ""):
def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]:
    """
    Retrieve a session document by state and, optionally, session id.

    :param state: the state value of the session entry.
    :type state: str
    :param session_id: the session id; an empty string matches by state only.
    :type session_id: str
    :returns: the first matching document found on the storages.
    :rtype: Union[dict, None]
    """
    return self.get("get_by_state_and_session_id", state, session_id)

def get_by_session_id(self, session_id: str):
def get_by_session_id(self, session_id: str) -> Union[dict, None]:
    """
    Retrieve a session document by its session id.

    :param session_id: the session id of the session entry.
    :type session_id: str
    :returns: the first matching document found on the storages.
    :rtype: Union[dict, None]
    """
    return self.get("get_by_session_id", session_id)

@property
def is_connected(self):
    """
    Report the connection status of the configured storages.

    Probes each storage in order; a warning is logged when only some of
    them answer as connected (storages misalignment).

    :returns: the status reported by the last storage that answered
        without raising (False when none did).
    """
    last_status = False
    statuses: dict = {}
    for name, store in self.storages:
        try:
            last_status = store.is_connected()
            statuses[name] = last_status
        except Exception as err:
            # A failing probe is only logged; the storage is left out of
            # the misalignment bookkeeping.
            self._log_debug(
                err.__class__.__name__,
                f"Error while checking db engine connection on {name}: {err} "
            )

    some_up = True in statuses.values()
    if some_up and not all(statuses.values()):
        self._log_warning(
            "DB Engine",
            f"Not all the storage are found available, storages misalignment: {statuses}"
        )

    return last_status

def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]:
    """
    Retrieve an object from the first responsive cache; on a cache miss
    the cache invokes ``on_not_found`` to produce and store the value.

    :param object_name: the name of the object to retrieve.
    :type object_name: str
    :param on_not_found: the function to call if the object is not found.
    :type on_not_found: Callable[[], str]
    :raises ConnectionRefusedError: if the object cannot be retrieved on any instance.
    :returns: a tuple with the retrieved object, a status and the index of the cache instance.
    :rtype: tuple[dict, RetrieveStatus, int]
    """
    for index, (cache_name, cache_instance) in enumerate(self.caches):
        try:
            retrieved, retrieve_status = cache_instance.try_retrieve(
                object_name, on_not_found
            )
        except Exception as err:
            # Log and fall through to the next cache instance.
            self._log_critical(
                err.__class__.__name__,
                f"Cannot retrieve cache object with identifier {object_name} on cache database {cache_name}"
            )
        else:
            return retrieved, retrieve_status, index
    raise ConnectionRefusedError(
        "Cannot write cache object on any instance"
    )

def _close_list(self, db_list: list[Tuple[str,BaseDB]]) -> None:
"""
Close a list of db.
:param db_list: the list of db to close.
:type db_list: list[Tuple[str,BaseDB]]
:raises Exception: if an error occurs while closing a db.
"""

for db_name, db in db_list:
try:
db.close()
except Exception as e:
self._log_critical(
e.__class__.__name__,
f"Error while closing db engine {db_name}: {e}"
)
raise e

def _handle_instance(self, instance: dict) -> dict[BaseStorage | None, BaseCache | None]:
"""
Handle the initialization of a storage/cache instance.
:param instance: the instance configuration.
:type instance: dict
:returns: a tuple with the storage and cache instance.
:rtype: tuple[BaseStorage | None, BaseCache | None]
"""
cache_conf = instance.get("cache", None)
storage_conf = instance.get("storage", None)

storage_instance = None
if storage_conf:
module = importlib.import_module(storage_conf["module"])
instance_class = getattr(module, storage_conf["class"])
storage_instance = instance_class(
**storage_conf.get("init_params", {}))

cache_instance = None
if cache_conf:
module = importlib.import_module(cache_conf["module"])
instance_class = getattr(module, cache_conf["class"])
cache_instance = instance_class(**cache_conf["init_params"])

return storage_instance, cache_instance

0 comments on commit 7a6c1ea

Please sign in to comment.