diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 98e3a081f..cfa465702 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -95,6 +95,8 @@ class MySqlConnectionInfo(BaseModel): database: SecretStr user: SecretStr password: SecretStr + ssl_mode: SecretStr | None = Field(alias="sslMode", default=None) + ssl_ca: SecretStr | None = Field(alias="sslCA", default=None) kwargs: dict[str, str] | None = Field( description="Additional keyword arguments to pass to PyMySQL", default=None ) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index ba282e973..c6947bc96 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -1,8 +1,10 @@ from __future__ import annotations import base64 +import ssl from enum import Enum, StrEnum, auto from json import loads +from typing import Optional import ibis from google.oauth2 import service_account @@ -30,6 +32,12 @@ ) +class SSLMode(str, Enum): + DISABLE = "Disable" + REQUIRE = "Require" + VERIFY_CA = "Verify CA" + + class DataSource(StrEnum): bigquery = auto() canner = auto() @@ -123,15 +131,19 @@ def get_mssql_connection(cls, info: MSSqlConnectionInfo) -> BaseBackend: **info.kwargs if info.kwargs else dict(), ) - @staticmethod - def get_mysql_connection(info: MySqlConnectionInfo) -> BaseBackend: + @classmethod + def get_mysql_connection(cls, info: MySqlConnectionInfo) -> BaseBackend: + ssl_context = cls._create_ssl_context(info) + kwargs = {"ssl": ssl_context} if ssl_context else {} + if info.kwargs: + kwargs.update(info.kwargs) return ibis.mysql.connect( host=info.host.get_secret_value(), port=int(info.port.get_secret_value()), database=info.database.get_secret_value(), user=info.user.get_secret_value(), password=info.password.get_secret_value(), - **info.kwargs if info.kwargs else dict(), + **kwargs, ) @staticmethod @@ -168,3 +180,27 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend: @staticmethod def _escape_special_characters_for_odbc(value: str) -> str: return "{" + value.replace("}", "}}") + "}" + + @staticmethod + def _create_ssl_context(info: ConnectionInfo) -> Optional[ssl.SSLContext]: + ssl_mode = ( + info.ssl_mode.get_secret_value() if hasattr(info, "ssl_mode") else None + ) + + if not ssl_mode or ssl_mode == SSLMode.DISABLE: + return None + + ctx = ssl.create_default_context() + ctx.check_hostname = False + + if ssl_mode == SSLMode.REQUIRE: + ctx.verify_mode = ssl.CERT_NONE + elif ssl_mode == SSLMode.VERIFY_CA: + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations( + cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8") + if info.ssl_ca + else None + ) + + return ctx