diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index d9d360d1bb..cffb76d035 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -466,12 +466,10 @@ def get_client_options( api_endpoint = self.api_endpoint - if ( - api_endpoint is None - and not self._project - and not self._location - and not location_override - ) or (self._location == "global"): + if api_endpoint is None and ( + (not self._project and not self._location and not location_override) + or self._location == "global" + ): # Default endpoint is location invariant if using API key or global # location. api_endpoint = "aiplatform.googleapis.com" diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index c11ae16799..8b7347efd5 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -310,6 +310,19 @@ def test_create_client_with_global_location(self): assert isinstance(client, utils.PredictionClientWithOverride) assert client._transport._host == f"https://{constants.API_BASE_PATH}" + def test_create_client_with_global_location_and_api_endpoint(self): + initializer.global_config.init( + project=_TEST_PROJECT, + location="global", + api_endpoint="test.aiplatform.googleapis.com", + ) + client = initializer.global_config.create_client( + client_class=utils.PredictionClientWithOverride + ) + assert initializer.global_config.location == "global" + assert isinstance(client, utils.PredictionClientWithOverride) + assert client._transport._host == "https://test.aiplatform.googleapis.com" + def test_create_client_with_global_location_and_grpc_transport(self): initializer.global_config.init( project=_TEST_PROJECT, location="global", api_transport="grpc" @@ -437,6 +450,12 @@ def test_not_set_api_endpoint(self): "test.aiplatform.googleapis.com", "test.aiplatform.googleapis.com", ), + ( + "global", + None, + "test.aiplatform.googleapis.com", + "test.aiplatform.googleapis.com", + ), ], ) def test_get_client_options(