diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index bd57340ca6..cce0e64b5a 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -13,8 +13,8 @@ from sys import version_info import pydantic +from pydantic import Field, computed_field from packaging import version -from pydantic import Field from pydantic_core import from_json from sqlglot import exp from sqlglot.errors import ParseError @@ -110,7 +110,14 @@ class ConnectionConfig(abc.ABC, BaseConfig): catalog_type_overrides: t.Optional[t.Dict[str, str]] = None # Whether to share a single connection across threads or create a new connection per thread. - shared_connection: t.ClassVar[bool] = False + # + # MyPy throws a "Decorators on top of @property are not supported" error despite this being a + # valid decoration, and Pydantic recommends disabling the MyPy hint for this reason - see: + # https://pydantic.dev/docs/validation/2.0/usage/computed_fields/ + @computed_field # type: ignore[prop-decorator] + @property + def shared_connection(self) -> bool: + return False @property @abc.abstractmethod @@ -311,7 +318,10 @@ class BaseDuckDBConnectionConfig(ConnectionConfig): token: t.Optional[str] = None - shared_connection: t.ClassVar[bool] = True + @computed_field # type: ignore[prop-decorator] + @property + def shared_connection(self) -> bool: + return True _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} @@ -820,11 +830,15 @@ class DatabricksConnectionConfig(ConnectionConfig): DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 - shared_connection: t.ClassVar[bool] = True - _concurrent_tasks_validator = concurrent_tasks_validator _http_headers_validator = http_headers_validator + @computed_field # type: ignore[prop-decorator] + @property + def shared_connection(self) -> bool: + """The connection should only be shared if U2M OAuth is being used""" + return self.auth_type is not 
None and self.oauth_client_id is None + @model_validator(mode="before") def _databricks_connect_validator(cls, data: t.Any) -> t.Any: # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index e92c62960a..a3b6b213ab 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -272,4 +272,4 @@ def convert_config_type( config_obj: Config, config_type: t.Type[C], ) -> C: - return config_type.parse_obj(config_obj.dict()) + return config_type.parse_obj(config_obj.dict(exclude_computed_fields=True)) diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 7af556d6a3..0350ad5d3f 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -532,6 +532,7 @@ def test_connection_config_serialization(): "extensions": [], "pre_ping": False, "pretty_sql": False, + "shared_connection": True, "connector_config": {}, "secrets": [], "filesystems": [], @@ -544,6 +545,7 @@ def test_connection_config_serialization(): "extensions": [], "pre_ping": False, "pretty_sql": False, + "shared_connection": True, "connector_config": {}, "secrets": [], "filesystems": [], diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index e1813df9b9..9c080eb560 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1426,18 +1426,19 @@ def test_databricks(make_config): ) -def test_databricks_shared_connection(make_config): - """Databricks should use a shared connection pool to prevent OAuth CSRF races. +def test_databricks__u2m_oauth__shared_connection_pool(make_config): + """Databricks should use a shared connection pool when using U2M OAuth to prevent CSRF races. When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per thread. 
For U2M OAuth, each thread triggers its own browser-based OAuth flow; these race on the CSRF state parameter and cause MismatchingStateError. - Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be - used instead: a single connection is created (behind a lock) and each thread - gets its own cursor, so only one OAuth flow is ever initiated. + For authentication types other than U2M OAuth (e.g. access_token and M2M OAuth), + ThreadLocalConnectionPool should still be used. - See: https://github.com/tobymao/sqlmesh/issues/5646 + See: + https://github.com/tobymao/sqlmesh/issues/5646 + https://github.com/SQLMesh/sqlmesh/issues/5858 """ from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool @@ -1445,7 +1446,7 @@ def test_databricks_shared_connection(make_config): type="databricks", server_hostname="dbc-test.cloud.databricks.com", http_path="sql/test/foo", - access_token="test-token", + auth_type="databricks-oauth", concurrent_tasks=4, ) assert isinstance(config, DatabricksConnectionConfig) @@ -1455,6 +1456,41 @@ def test_databricks_shared_connection(make_config): assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool) +def test_databricks__m2m_oauth__connection_pool(make_config): + from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + auth_type="databricks-oauth", + oauth_client_id="oauth_client_id", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is False + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool) + + +def test_databricks__access_token__connection_pool(make_config): + from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", 
+ http_path="sql/test/foo", + access_token="any-token", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is False + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool) + + def test_engine_import_validator(): with pytest.raises( ConfigError,