diff --git a/packages/authentication/azure/kiota_authentication_azure/azure_identity_access_token_provider.py b/packages/authentication/azure/kiota_authentication_azure/azure_identity_access_token_provider.py
index 986226d5..8803420d 100644
--- a/packages/authentication/azure/kiota_authentication_azure/azure_identity_access_token_provider.py
+++ b/packages/authentication/azure/kiota_authentication_azure/azure_identity_access_token_provider.py
@@ -96,21 +96,18 @@ async def get_authorization_token(
                 decoded_bytes = base64.b64decode(additional_authentication_context[self.CLAIMS_KEY])
                 decoded_claim = decoded_bytes.decode("utf-8")
 
-            if not self._scopes:
-                self._scopes = [f"{parsed_url.scheme}://{parsed_url.netloc}/.default"]
-            span.set_attribute(self.SCOPES, ",".join(self._scopes))
+            # Derive the scope per-call from the request hostname.
+            scopes = self._resolve_scopes(parsed_url, span)
+            span.set_attribute(self.SCOPES, ",".join(scopes))
             span.set_attribute(self.ADDITIONAL_CLAIMS_PROVIDED, bool(self._options))
 
             if self._options:
                 result = self._credentials.get_token(
-                    *self._scopes,
-                    claims=decoded_claim,
-                    enable_cae=self._is_cae_enabled,
-                    **self._options
+                    *scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled, **self._options
                 )
             else:
                 result = self._credentials.get_token(
-                    *self._scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled
+                    *scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled
                 )
 
             if inspect.isawaitable(result):
@@ -127,3 +124,36 @@ def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
             AllowedHostsValidator: The allowed hosts validator.
         """
         return self._allowed_hosts_validator
+
+    def _resolve_scopes(self, parsed_url, span) -> list[str]:
+        """Return the scopes to pass to `get_token` for this request.
+
+        Caller-supplied scopes are returned verbatim. Otherwise a default
+        `.default` scope is derived from the request hostname only, so that
+        userinfo (`user:password@`) and ports (which Entra ID rejects for
+        `.default` scopes) are never copied into the scope or telemetry.
+        IPv6 literal brackets stripped by `urlparse` are re-added.
+
+        Args:
+            parsed_url: Parsed request URL; its `scheme` and `hostname` are read.
+            span: Tracing span for this call; flagged as an invalid URL and
+                given the exception when no hostname is present.
+
+        Returns:
+            list[str]: The scopes to request a token for.
+
+        Raises:
+            HTTPError: If the URL has no hostname to derive a scope from.
+        """
+        if self._scopes:
+            return self._scopes
+        hostname = parsed_url.hostname
+        if not hostname:
+            span.set_attribute(self.IS_VALID_URL, False)
+            exc = HTTPError("Valid url scheme and host required")
+            span.record_exception(exc)
+            raise exc
+        if ":" in hostname:
+            # `hostname` never carries the port, so a colon means an IPv6 literal.
+            hostname = f"[{hostname}]"
+        return [f"{parsed_url.scheme}://{hostname}/.default"]
diff --git a/packages/authentication/azure/tests/test_azure_identity_access_token_provider.py b/packages/authentication/azure/tests/test_azure_identity_access_token_provider.py
index a1d087d6..60a6a611 100644
--- a/packages/authentication/azure/tests/test_azure_identity_access_token_provider.py
+++ b/packages/authentication/azure/tests/test_azure_identity_access_token_provider.py
@@ -83,4 +83,95 @@ async def test_get_authorization_token_localhost():
     token_provider = AzureIdentityAccessTokenProvider(DummySyncAzureTokenCredential(), None)
     token = await token_provider.get_authorization_token('HTTP://LOCALHOST:8080')
     assert token
-    
+
+
+class RecordingSyncAzureTokenCredential(DummySyncAzureTokenCredential):
+    """Sync credential that records the scopes passed to get_token."""
+
+    def __init__(self):
+        self.received_scopes: list[tuple[str, ...]] = []
+
+    def get_token(self, *scopes, **kwargs):
+        self.received_scopes.append(scopes)
+        return super().get_token(*scopes, **kwargs)
+
+
+@pytest.mark.asyncio
+async def test_derived_scope_strips_userinfo_and_port():
+    """The default `.default` scope passed to `get_token` must be derived
+    from the hostname only — never include userinfo or a `:port` (which
+    Entra ID rejects for `.default` scopes).
+    """
+    credential = RecordingSyncAzureTokenCredential()
+    token_provider = AzureIdentityAccessTokenProvider(credential, None)
+
+    await token_provider.get_authorization_token(
+        'https://alice:secret@graph.microsoft.com:8443/v1.0/me'
+    )
+
+    assert credential.received_scopes == [('https://graph.microsoft.com/.default',)]
+
+
+@pytest.mark.asyncio
+async def test_derived_scope_is_not_cached_across_hosts():
+    """The first URL's derived scope must not be reused for later URLs.
+
+    Previously the scope was assigned to `self._scopes`, making it sticky for
+    the lifetime of the provider instance and causing tokens to be requested
+    for the wrong audience after the first call.
+    """
+    credential = RecordingSyncAzureTokenCredential()
+    token_provider = AzureIdentityAccessTokenProvider(credential, None)
+
+    await token_provider.get_authorization_token('https://graph.microsoft.com/v1.0/me')
+    await token_provider.get_authorization_token('https://graph.microsoft.us/v1.0/me')
+
+    assert credential.received_scopes == [
+        ('https://graph.microsoft.com/.default',),
+        ('https://graph.microsoft.us/.default',),
+    ]
+    # Provider must not have cached derived scopes into `_scopes`.
+    assert token_provider._scopes == []
+
+
+@pytest.mark.asyncio
+async def test_explicit_scopes_are_respected():
+    """Caller-supplied scopes are passed to `get_token` unchanged for every host."""
+    credential = RecordingSyncAzureTokenCredential()
+    token_provider = AzureIdentityAccessTokenProvider(
+        credential, None, scopes=['https://graph.microsoft.com/.default']
+    )
+
+    await token_provider.get_authorization_token('https://graph.microsoft.com/v1.0/me')
+    await token_provider.get_authorization_token('https://graph.microsoft.us/v1.0/me')
+
+    assert credential.received_scopes == [
+        ('https://graph.microsoft.com/.default',),
+        ('https://graph.microsoft.com/.default',),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_derived_scope_rejects_url_without_hostname():
+    """A URI whose netloc has no hostname (e.g. `https://@/path`) must not
+    silently derive a scope like `https://None/.default`; it must raise.
+    """
+    credential = RecordingSyncAzureTokenCredential()
+    token_provider = AzureIdentityAccessTokenProvider(credential, None)
+
+    with pytest.raises(Exception, match='Valid url scheme and host required'):
+        await token_provider.get_authorization_token('https://@/path')
+    assert credential.received_scopes == []
+
+
+@pytest.mark.asyncio
+async def test_derived_scope_brackets_ipv6_hostname():
+    """`urlparse` strips brackets from IPv6 literals; the derived scope
+    must re-add them so the resulting URL is syntactically valid.
+    """
+    credential = RecordingSyncAzureTokenCredential()
+    token_provider = AzureIdentityAccessTokenProvider(credential, None)
+
+    await token_provider.get_authorization_token('https://[2001:db8::1]/v1.0/me')
+
+    assert credential.received_scopes == [('https://[2001:db8::1]/.default',)]