Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ibis): Added BE Support for MySQL SSL Connection #1024

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class MySqlConnectionInfo(BaseModel):
database: SecretStr
user: SecretStr
password: SecretStr
ssl_mode: SecretStr | None = Field(alias="sslMode", default=None)
ssl_ca: SecretStr | None = Field(alias="sslCA", default=None)
kwargs: dict[str, str] | None = Field(
description="Additional keyword arguments to pass to PyMySQL", default=None
)
Expand Down
42 changes: 39 additions & 3 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import base64
import ssl
from enum import Enum, StrEnum, auto
from json import loads
from typing import Optional

import ibis
from google.oauth2 import service_account
Expand Down Expand Up @@ -30,6 +32,12 @@
)


class SSLMode(str, Enum):
DISABLE = "Disable"
REQUIRE = "Require"
VERIFY_CA = "Verify CA"
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DISABLE = "Disable"
REQUIRE = "Require"
VERIFY_CA = "Verify CA"
DISABLED = "disabled"
ENABLED = "enabled"
VERIFY_CA = "verify_ca"

I prefer to rename require -> enabled.
We can use the snake case for the enum value. The request body could be

{
      "sslMode": "verify_ca",
      "sslCA": "xxx",
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review. I’ll work on this shortly 😊.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @goldmedal, I have tried adding a test for ssl_mode=enabled as shown below:

async def test_connection_ssl_enabled(client, mysql: MySqlContainer):
    connection_info = _to_connection_info(mysql)
    connection_info["ssl_mode"] = "enabled"
    response = await client.post(
        url=f"{base_url}/metadata/version",
        json={"connectionInfo": connection_info},
    )
    assert response.status_code == 200
    assert response.text == '"8.0.40"'

However, the test passes, and I'm unsure if this is expected behavior due to MySQL server's default SSL configuration [ref: link].

By default, MySQL server always installs and enables SSL configuration. However, it is not enforced that clients connect using SSL. Clients can choose to connect with or without SSL as the server allows both types of connections.

Your help is very much appreciated!



class DataSource(StrEnum):
bigquery = auto()
canner = auto()
Expand Down Expand Up @@ -123,15 +131,19 @@ def get_mssql_connection(cls, info: MSSqlConnectionInfo) -> BaseBackend:
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
def get_mysql_connection(info: MySqlConnectionInfo) -> BaseBackend:
@classmethod
def get_mysql_connection(cls, info: MySqlConnectionInfo) -> BaseBackend:
ssl_context = cls._create_ssl_context(info)
kwargs = {"ssl": ssl_context} if ssl_context else {}
if info.kwargs:
kwargs.update(info.kwargs)
return ibis.mysql.connect(
host=info.host.get_secret_value(),
port=int(info.port.get_secret_value()),
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
**info.kwargs if info.kwargs else dict(),
**kwargs,
)

@staticmethod
Expand Down Expand Up @@ -168,3 +180,27 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
@staticmethod
def _escape_special_characters_for_odbc(value: str) -> str:
return "{" + value.replace("}", "}}") + "}"

@staticmethod
def _create_ssl_context(info: ConnectionInfo) -> Optional[ssl.SSLContext]:
ssl_mode = (
info.ssl_mode.get_secret_value() if hasattr(info, "ssl_mode") else None
)

if not ssl_mode or ssl_mode == SSLMode.DISABLE:
return None

ctx = ssl.create_default_context()
ctx.check_hostname = False

if ssl_mode == SSLMode.REQUIRE:
ctx.verify_mode = ssl.CERT_NONE
elif ssl_mode == SSLMode.VERIFY_CA:
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(
cadata=base64.b64decode(info.ssl_ca.get_secret_value()).decode("utf-8")
if info.ssl_ca
else None
)

return ctx
Loading