From c2cf4ed6605b196a75f423c0f626cc2e10ab628b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 15:19:44 +0000 Subject: [PATCH 01/15] Add client extension declaration vocabulary ClientExtension, ResultClaim, NotificationBinding, ClaimContext, and the advertise() factory in mcp/client/extension.py: a closed declarative surface mirroring the server Extension, with all shape rules enforced at construction. The extension identifier grammar moves to mcp/shared/extension.py (one source of truth for both sides); the server module re-exports it. --- src/mcp/client/__init__.py | 12 ++ src/mcp/client/extension.py | 197 +++++++++++++++++++++ src/mcp/server/extension.py | 25 +-- src/mcp/shared/extension.py | 33 ++++ tests/client/test_extension.py | 312 +++++++++++++++++++++++++++++++++ tests/shared/test_extension.py | 31 ++++ 6 files changed, 589 insertions(+), 21 deletions(-) create mode 100644 src/mcp/client/extension.py create mode 100644 src/mcp/shared/extension.py create mode 100644 tests/client/test_extension.py create mode 100644 tests/shared/test_extension.py diff --git a/src/mcp/client/__init__.py b/src/mcp/client/__init__.py index b7823f5ef..9a3c3ae0f 100644 --- a/src/mcp/client/__init__.py +++ b/src/mcp/client/__init__.py @@ -12,6 +12,13 @@ ) from mcp.client.client import Client from mcp.client.context import ClientRequestContext +from mcp.client.extension import ( + ClaimContext, + ClientExtension, + NotificationBinding, + ResultClaim, + advertise, +) from mcp.client.session import ClientSession __all__ = [ @@ -19,11 +26,16 @@ "CacheEntry", "CacheKey", "CacheMode", + "ClaimContext", "Client", + "ClientExtension", "ClientRequestContext", "ClientSession", "InMemoryResponseCacheStore", "InputRequiredRoundsExceededError", + "NotificationBinding", "ResponseCacheStore", + "ResultClaim", "Transport", + "advertise", ] diff --git a/src/mcp/client/extension.py b/src/mcp/client/extension.py new file mode 100644 index 000000000..15ca9a08c --- /dev/null +++ b/src/mcp/client/extension.py @@ -0,0 +1,197 @@ +"""Opt-in extension interface for MCP clients. + +To make an extension: subclass `ClientExtension`, set `identifier`, and +override whichever of `settings()` / `claims()` / `notifications()` apply. To +use one: pass instances to `Client(extensions=[...])` — the client folds the +declarations into its own machinery; the extension never receives the client. +To advertise an extension identifier with no client-side behaviour, use +`advertise()`. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args + +from mcp_types import CallToolResult, InputRequiredResult, Result +from mcp_types.version import MODERN_PROTOCOL_VERSIONS +from pydantic import BaseModel + +from mcp.shared.extension import validate_extension_identifier + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + +__all__ = [ + "ClaimContext", + "ClientExtension", + "NotificationBinding", + "ResultClaim", + "advertise", +] + +ClaimedT = TypeVar("ClaimedT", bound=Result) +NotifyParamsT = TypeVar("NotifyParamsT", bound=BaseModel) + + +@dataclass(frozen=True, kw_only=True) +class ClaimContext: + """Host-injected context for one `ResultClaim.resolve` call. + + `session` is the sanctioned public low-level handle — the same one users + already reach via `client.session`; the resolver gets no `Client` and no + new authority. + """ + + session: ClientSession + tool_name: str + read_timeout_seconds: float | None + + +@dataclass(frozen=True, kw_only=True) +class ResultClaim(Generic[ClaimedT]): + """One extra result shape on one spec verb, keyed by the wire `resultType`. + + A claim is active only while the declaring extension is constructed in AND + the negotiated version admits it; otherwise parsing stays byte-identical to + a claim-less client, so an undeclared shape still fails validation — the + supported `resultType` set is always core plus declared claims. + + `resolve` finishes a claimed result on the transparent path: it may send + follow-ups through `ctx.session` and must return the verb's ordinary + result. It is required — a claim nothing can finish would be useless. A + package that wants explicit-only handling ships a resolver that raises a + typed error naming `session.call_tool(allow_claimed=True)`, which is also + how callers reach the undriven shape per-call. + + `model` must declare `result_type` as a Literal of exactly the claimed tag, + and must not subclass a core result type — a core subclass would satisfy + the session's isinstance branches and bypass claim routing. `protocol_versions`, + when set, restricts the claim to a subset of the modern protocol revisions; + `None` (the default) means every modern version. The modern floor is + structural, not a restriction: claimed shapes cannot be delivered on a + legacy wire. All of this is enforced at construction. + """ + + result_type: str + model: type[ClaimedT] + resolve: Callable[[ClaimedT, ClaimContext], Awaitable[CallToolResult]] + method: Literal["tools/call"] = "tools/call" + protocol_versions: frozenset[str] | None = None + + def __post_init__(self) -> None: + if self.result_type in ("complete", "input_required"): + raise ValueError(f"resultType {self.result_type!r} is core protocol vocabulary") + if issubclass(self.model, CallToolResult | InputRequiredResult): + raise ValueError("claim models must not subclass core result types") + field = self.model.model_fields.get("result_type") + if field is None or get_args(field.annotation) != (self.result_type,): + raise ValueError(f"{self.model.__name__}.result_type must be Literal[{self.result_type!r}]") + if self.protocol_versions is not None and not self.protocol_versions: + raise ValueError("empty protocol_versions could never activate; use None for all") + if self.protocol_versions is not None and not self.protocol_versions.issubset(MODERN_PROTOCOL_VERSIONS): + unrecognized = sorted(self.protocol_versions.difference(MODERN_PROTOCOL_VERSIONS)) + raise ValueError( + f"protocol_versions {unrecognized} are not modern protocol revisions; claimed shapes " + "cannot be delivered on a legacy wire (None means every modern version)" + ) + + +@dataclass(frozen=True, kw_only=True) +class NotificationBinding(Generic[NotifyParamsT]): + """Deliver server notifications for `method` to `handler` (today: silently dropped). + + Observation-only: the handler receives validated params, returns None, and + cannot short-circuit anything. Delivery is per-binding serialized through a + bounded FIFO — one consumer task per binding, so a handler sees events in + arrival order and may do session I/O without deadlocking the in-process + dispatch path; on overflow the oldest event is dropped with a warning + (observation semantics make the drop acceptable). + + There is deliberately no spec-table check at construction: bindings are + consulted only for methods the negotiated version's core tables do NOT + know, so they are additive by construction. If a future core version + adopts the method, the binding goes quiet — detected and warned once at + activation, not per delivery — instead of import-erroring every package. + + `method` is the bare wire name (e.g. `notifications/tasks`); `params_type` + validates the notification params before `handler` runs. + """ + + method: str + params_type: type[NotifyParamsT] + handler: Callable[[NotifyParamsT], Awaitable[None]] + + +class ClientExtension: + """Base class for an opt-in client extension. Override only what you need. + + Mirror of `mcp.server.extension.Extension` in feel: a closed declarative + surface, fixed at construction, that never receives the client. The + contribution kinds are the ones a 2026 client actually has — there is + deliberately no served-request kind (servers do not initiate requests) and + no open interceptor (the only sanctioned augmentation is extension + `resultType` values, and a claim already names its owner, so composition + and ordering questions dissolve by construction). + """ + + #: Reverse-DNS extension identifier, advertised under `ClientCapabilities.extensions`. + identifier: str + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Validate a class-level `identifier` at definition time. A subclass may + # instead assign `identifier` in `__init__` (per-instance ids); that case + # is validated when the extension is consumed, since no class attribute + # exists to inspect here. + if (identifier := cls.__dict__.get("identifier")) is not None: + validate_extension_identifier(identifier, owner=cls.__name__) + + def settings(self) -> dict[str, Any]: + """Per-extension settings advertised at `ClientCapabilities.extensions[identifier]`. + + Read ONCE at `Client` construction — dynamic per-request settings are + out of scope. An empty dict (the default) advertises the extension with + no settings. + + A claim-bearing extension's identifier is advertised only at protocol + versions where at least one of its claims is active: the ad and the + claims dissolve together, so the client never advertises an extension + on a request whose claimed result shapes it would reject. Claim-less + extensions advertise at every version. + """ + return {} + + def claims(self) -> Sequence[ResultClaim[Any]]: + """Extra result shapes this extension claims, with their resolvers.""" + return () + + def notifications(self) -> Sequence[NotificationBinding[Any]]: + """Server notifications this extension observes.""" + return () + + +class _AdvertiseOnly(ClientExtension): + """Ad-only extension returned by `advertise()`: an identifier plus captured settings.""" + + def __init__(self, identifier: str, settings: dict[str, Any]) -> None: + self.identifier = identifier + self._settings = settings + + def settings(self) -> dict[str, Any]: + return self._settings + + +def advertise(identifier: str, settings: dict[str, Any] | None = None) -> ClientExtension: + """Advertise an extension identifier (with optional settings) and nothing else. + + Returns an extension that contributes only the capability ad: no claims, no + notification bindings. The identifier is validated eagerly, at this call. + + WARNING: advertising an extension you do not implement asserts wire support + you don't have — for behavioral extensions (e.g. tasks) construct the real + extension object instead. + """ + validate_extension_identifier(identifier, owner="advertise") + return _AdvertiseOnly(identifier, {} if settings is None else settings) diff --git a/src/mcp/server/extension.py b/src/mcp/server/extension.py index e045e6f29..11705943d 100644 --- a/src/mcp/server/extension.py +++ b/src/mcp/server/extension.py @@ -19,7 +19,6 @@ from __future__ import annotations -import re from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -30,31 +29,15 @@ from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext +# The identifier grammar moved to `mcp.shared.extension` (the client extension +# surface shares it); re-exported here for existing importers. +from mcp.shared.extension import validate_extension_identifier as validate_extension_identifier + if TYPE_CHECKING: from mcp.server.mcpserver.resources import Resource RequestHandler = Callable[[ServerRequestContext[Any, Any], Any], Awaitable[HandlerResult]] -# Extension identifiers follow the `_meta` key grammar with a mandatory prefix -# (SEP-2133 / basic/index.mdx): dot-separated labels, each starting with a -# letter and ending with a letter or digit (hyphens interior), then `/`, then a -# name that starts and ends alphanumeric (`.`/`_`/`-` interior). -_LABEL = r"[A-Za-z](?:[A-Za-z0-9-]*[A-Za-z0-9])?" -_NAME = r"[A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?" -_IDENTIFIER_RE = re.compile(rf"{_LABEL}(?:\.{_LABEL})*/{_NAME}") - - -def validate_extension_identifier(identifier: Any, *, owner: str) -> None: - """Raise `TypeError` unless `identifier` is a `vendor-prefix/name` string. - - SEP-2133 requires extension identifiers to carry a reverse-DNS prefix. - """ - if not isinstance(identifier, str) or not _IDENTIFIER_RE.fullmatch(identifier): - raise TypeError( - f"{owner}.identifier must be a `vendor-prefix/name` string " - f"(reverse-DNS prefix required), got {identifier!r}" - ) - @dataclass(frozen=True) class ToolBinding: diff --git a/src/mcp/shared/extension.py b/src/mcp/shared/extension.py new file mode 100644 index 000000000..f275e2436 --- /dev/null +++ b/src/mcp/shared/extension.py @@ -0,0 +1,33 @@ +"""Extension-identifier grammar shared by the server and client extension surfaces. + +Server extensions (`mcp.server.extension`) and client extensions +(`mcp.client.extension`) carry the same kind of identifier; this module is the +one source of truth for its validation. +""" + +from __future__ import annotations + +import re +from typing import Any + +__all__ = ["validate_extension_identifier"] + +# Extension identifiers follow the `_meta` key grammar with a mandatory prefix +# (SEP-2133 / basic/index.mdx): dot-separated labels, each starting with a +# letter and ending with a letter or digit (hyphens interior), then `/`, then a +# name that starts and ends alphanumeric (`.`/`_`/`-` interior). +_LABEL = r"[A-Za-z](?:[A-Za-z0-9-]*[A-Za-z0-9])?" +_NAME = r"[A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?" +_IDENTIFIER_RE = re.compile(rf"{_LABEL}(?:\.{_LABEL})*/{_NAME}") + + +def validate_extension_identifier(identifier: Any, *, owner: str) -> None: + """Raise `TypeError` unless `identifier` is a `vendor-prefix/name` string. + + SEP-2133 requires extension identifiers to carry a reverse-DNS prefix. + """ + if not isinstance(identifier, str) or not _IDENTIFIER_RE.fullmatch(identifier): + raise TypeError( + f"{owner}.identifier must be a `vendor-prefix/name` string " + f"(reverse-DNS prefix required), got {identifier!r}" + ) diff --git a/tests/client/test_extension.py b/tests/client/test_extension.py new file mode 100644 index 000000000..a06260703 --- /dev/null +++ b/tests/client/test_extension.py @@ -0,0 +1,312 @@ +"""Tests for the client extension vocabulary (`mcp.client.extension`). + +Everything here is construction-time: claims, notification bindings, the +`ClientExtension` base class, and the `advertise()` factory. No session or +client is ever opened — the classes are pure declarations, and every +validation rule fires before an instance exists. +""" + +from dataclasses import FrozenInstanceError +from typing import Any, Literal + +import pytest +from inline_snapshot import snapshot +from mcp_types import CallToolResult, InputRequiredResult, Result +from mcp_types.version import MODERN_PROTOCOL_VERSIONS +from pydantic import BaseModel + +from mcp.client.extension import ( + ClaimContext, + ClientExtension, + NotificationBinding, + ResultClaim, + advertise, +) + + +class _TaskResult(Result): + """A well-formed claimed shape: `result_type` is a Literal of the claimed tag.""" + + result_type: Literal["task"] = "task" + task_id: str = "t-1" + + +class _UntaggedResult(Result): + """No `result_type` field at all.""" + + +class _PlainStringTagResult(Result): + """`result_type` declared as a plain `str`, not a Literal.""" + + result_type: str = "task" + + +class _OtherTagResult(Result): + """`result_type` is a Literal of a tag other than the one claimed.""" + + result_type: Literal["other"] = "other" + + +class _ClaimedCallToolResult(CallToolResult): + """A core-result subclass; rejected as a claim model regardless of its tag.""" + + +class _ClaimedInputRequiredResult(InputRequiredResult): + """A core-result subclass; rejected as a claim model regardless of its tag.""" + + +async def _resolve(result: Result, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # construction-only tests never drive a claim + + +def _claim(model: type[Result] = _TaskResult, **kwargs: Any) -> ResultClaim[Result]: + return ResultClaim(result_type="task", model=model, resolve=_resolve, **kwargs) + + +# ── ResultClaim construction ──────────────────────────────────────────────── + + +def test_claim_with_literal_discriminated_model_constructs() -> None: + """SDK-defined: a claim whose model carries `result_type: Literal[]` + constructs, defaulting to the `tools/call` verb at every modern version.""" + claim = ResultClaim(result_type="task", model=_TaskResult, resolve=_resolve) + + assert claim.result_type == "task" + assert claim.model is _TaskResult + assert claim.resolve is _resolve + assert claim.method == "tools/call" + assert claim.protocol_versions is None + + +def test_claim_accepts_modern_protocol_versions() -> None: + """SDK-defined: a non-None `protocol_versions` is accepted when it is a subset of + the modern protocol revisions.""" + versions = frozenset(MODERN_PROTOCOL_VERSIONS) + + claim = _claim(protocol_versions=versions) + + assert claim.protocol_versions == versions + + +@pytest.mark.parametrize("result_type", ["complete", "input_required"]) +def test_claim_rejects_core_result_type_vocabulary(result_type: str) -> None: + """SDK-defined: "complete" and "input_required" are core protocol vocabulary — + a claim cannot re-key the shapes the session itself routes on.""" + with pytest.raises(ValueError, match="core protocol vocabulary"): + ResultClaim(result_type=result_type, model=_TaskResult, resolve=_resolve) + + +@pytest.mark.parametrize("model", [_ClaimedCallToolResult, _ClaimedInputRequiredResult]) +def test_claim_rejects_model_subclassing_core_result_types(model: type[Result]) -> None: + """SDK-defined: a claim model subclassing `CallToolResult` or `InputRequiredResult` + would satisfy the session's isinstance branches and bypass claim routing.""" + with pytest.raises(ValueError, match="must not subclass core result types"): + _claim(model=model) + + +def test_claim_rejects_model_without_result_type_field() -> None: + """SDK-defined: the claim model must declare the discriminating `result_type` + field; without it the claimed shape could never be routed.""" + with pytest.raises(ValueError) as exc_info: + _claim(model=_UntaggedResult) + + assert str(exc_info.value) == snapshot("_UntaggedResult.result_type must be Literal['task']") + + +def test_claim_rejects_plain_str_result_type_field() -> None: + """SDK-defined: a plain `str` tag would let one model validate any claimed shape; + the field must be a Literal of exactly the claimed tag.""" + with pytest.raises(ValueError) as exc_info: + _claim(model=_PlainStringTagResult) + + assert str(exc_info.value) == snapshot("_PlainStringTagResult.result_type must be Literal['task']") + + +def test_claim_rejects_mismatched_result_type_literal() -> None: + """SDK-defined: the model's Literal tag must equal the claim's `result_type` — + a mismatch would register the model under a tag it refuses to validate.""" + with pytest.raises(ValueError) as exc_info: + _claim(model=_OtherTagResult) + + assert str(exc_info.value) == snapshot("_OtherTagResult.result_type must be Literal['task']") + + +def test_claim_rejects_empty_protocol_versions() -> None: + """SDK-defined: an empty version set could never activate; `None` is the + spelling for "every modern version".""" + with pytest.raises(ValueError) as exc_info: + _claim(protocol_versions=frozenset()) + + assert str(exc_info.value) == snapshot("empty protocol_versions could never activate; use None for all") + + +@pytest.mark.parametrize( + "versions", + [ + frozenset({"2025-11-25"}), + frozenset({"2026-07-28", "2025-11-25"}), + frozenset({"never-a-version"}), + ], +) +def test_claim_rejects_non_modern_protocol_versions(versions: frozenset[str]) -> None: + """SDK-defined: claimed shapes cannot be delivered on a legacy wire, so a + non-None version set must be a subset of the modern protocol revisions.""" + with pytest.raises(ValueError, match="not modern protocol revisions"): + _claim(protocol_versions=versions) + + +def test_result_claim_is_frozen() -> None: + """SDK-defined: claims are immutable declarations — mutating one after + construction raises.""" + claim = _claim() + + with pytest.raises(FrozenInstanceError): + setattr(claim, "result_type", "other") # direct assignment is also a type error + + +# ── NotificationBinding construction ──────────────────────────────────────── + + +class _TaskNotificationParams(BaseModel): + task_id: str + + +async def _on_task(params: _TaskNotificationParams) -> None: + raise NotImplementedError # construction-only tests never deliver + + +def test_notification_binding_constructs() -> None: + """SDK-defined: a binding is a bare declaration — wire method name, params + model, async observer — with no construction-time validation.""" + binding = NotificationBinding(method="notifications/tasks", params_type=_TaskNotificationParams, handler=_on_task) + + assert binding.method == "notifications/tasks" + assert binding.params_type is _TaskNotificationParams + assert binding.handler is _on_task + + +def test_notification_binding_accepts_core_known_method() -> None: + """SDK-defined: deliberately NO spec-table check at construction — bindings are + consulted only for methods core does not know, so they are additive by + construction, and an import-time table check would break packages whenever a + core version adopts a method.""" + binding = NotificationBinding( + method="notifications/progress", params_type=_TaskNotificationParams, handler=_on_task + ) + + assert binding.method == "notifications/progress" + + +def test_notification_binding_is_frozen() -> None: + """SDK-defined: bindings are immutable declarations — mutating one after + construction raises.""" + binding = NotificationBinding(method="notifications/tasks", params_type=_TaskNotificationParams, handler=_on_task) + + with pytest.raises(FrozenInstanceError): + setattr(binding, "method", "notifications/other") # direct assignment is also a type error + + +# ── ClientExtension subclassing ───────────────────────────────────────────── + + +def test_extension_defaults_advertise_nothing() -> None: + """SDK-defined: a minimal subclass overrides nothing — empty settings, no + claims, no notification bindings.""" + + class _MinimalExt(ClientExtension): + identifier = "com.example/minimal" + + ext = _MinimalExt() + + assert ext.settings() == {} + assert ext.claims() == () + assert ext.notifications() == () + + +@pytest.mark.parametrize( + "identifier", + [ + "io.modelcontextprotocol/ui", + "com.example/my_ext", + "com.x-y.z2/n.a-b_c", + "example/x", + "a/b", + "com.example/9start", + ], +) +def test_grammar_conformant_identifiers_accepted_at_class_definition(identifier: str) -> None: + """Spec `_meta` key grammar: dot-separated labels (letter start, letter/digit end, + hyphens interior), a slash, then a name that starts and ends alphanumeric.""" + cls = type("_GoodExt", (ClientExtension,), {"identifier": identifier}) + + assert cls.identifier == identifier + + +@pytest.mark.parametrize( + "identifier", + [ + "noprefix", + "-foo/bar", + ".leading/x", + "a..b/x", + "foo-/x", + "9foo/x", + "foo/-bar", + "foo/bar-", + "foo/", + "/bar", + "foo/ba r", + "io.modelcontextprotocol/ui\n", + "", + 42, + ], +) +def test_malformed_identifier_rejected_at_class_definition(identifier: Any) -> None: + """SDK-defined: SEP-2133 requires a `vendor-prefix/name` identifier, enforced the + moment the subclass is defined — same grammar and helper as the server side.""" + with pytest.raises(TypeError): + type("_BadExt", (ClientExtension,), {"identifier": identifier}) + + +def test_subclass_without_identifier_allowed_at_definition() -> None: + """SDK-defined: a subclass that sets no class-level `identifier` (an abstract-ish + intermediate base, or one assigning per-instance ids in `__init__`) is allowed at + definition time; the identifier is validated when the extension is consumed.""" + + class _AbstractishExt(ClientExtension): + """Intermediate base; concrete subclasses supply the identifier.""" + + class _ConcreteExt(_AbstractishExt): + identifier = "com.example/concrete" + + assert _ConcreteExt.identifier == "com.example/concrete" + + +# ── advertise() factory ───────────────────────────────────────────────────── + + +def test_advertise_serves_captured_settings() -> None: + """SDK-defined: `advertise()` returns an ad-only extension whose `settings()` + override serves the captured settings.""" + ext = advertise("com.example/flags", {"enabled": True}) + + assert isinstance(ext, ClientExtension) + assert ext.identifier == "com.example/flags" + assert ext.settings() == {"enabled": True} + assert ext.claims() == () + assert ext.notifications() == () + + +def test_advertise_defaults_to_empty_settings() -> None: + """SDK-defined: omitting settings advertises the extension with an empty map.""" + ext = advertise("com.example/flags") + + assert ext.settings() == {} + + +@pytest.mark.parametrize("identifier", ["noprefix", "foo/", ""]) +def test_advertise_validates_identifier_eagerly(identifier: str) -> None: + """SDK-defined: `advertise()` validates the identifier at the call, not at some + later consumption point — a bad ad-only id fails where it is written.""" + with pytest.raises(TypeError): + advertise(identifier) diff --git a/tests/shared/test_extension.py b/tests/shared/test_extension.py new file mode 100644 index 000000000..4a94c75b4 --- /dev/null +++ b/tests/shared/test_extension.py @@ -0,0 +1,31 @@ +"""Tests for `mcp.shared.extension` — the extension-identifier grammar shared by +the server and client extension surfaces. + +The grammar matrix (accepted and rejected identifiers) lives with the original +server tests in `tests/server/mcpserver/test_extension.py`, which exercise the +same function via the server module's re-export. +""" + +import pytest + +import mcp.server.extension +import mcp.shared.extension + + +def test_validator_importable_from_shared_home() -> None: + """SDK-defined: the identifier grammar lives in `mcp.shared.extension` — one + source of truth for both the server and client extension surfaces.""" + mcp.shared.extension.validate_extension_identifier("com.example/thing", owner="T") + + +def test_validator_rejects_malformed_identifier_via_shared_path() -> None: + """SDK-defined: the shared-home function enforces the same `vendor-prefix/name` + grammar the server side always has.""" + with pytest.raises(TypeError): + mcp.shared.extension.validate_extension_identifier("noprefix", owner="T") + + +def test_server_extension_module_reexports_shared_validator() -> None: + """SDK-defined: `mcp.server.extension.validate_extension_identifier` remains + importable after the move and is the very same function object.""" + assert mcp.server.extension.validate_extension_identifier is mcp.shared.extension.validate_extension_identifier From c4cfebdccd481c1f56d2a4180fb53175ceb1b46e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 15:23:24 +0000 Subject: [PATCH 02/15] Support vendor request types in ClientSession.send_request Request gains a name_param ClassVar naming the wire-params key to mirror into the Mcp-Name header; send_request emits it whenever the stamp has not already set the header, so core NAME_BEARING_METHODS rows win by ordering and a missing value fails loud instead of silently omitting the header. send_request's typing widens to accept any Request[Any, Any] (runtime was always duck-typed), retiring the cast ceremony in the extension and custom-method stories. dispatch_input_request and validate_tool_result are promoted to public ClientSession surface. --- docs/advanced/extensions.md | 4 +- docs_src/extensions/tutorial004.py | 4 +- examples/stories/custom_methods/README.md | 7 +- examples/stories/custom_methods/client.py | 11 +- examples/stories/extensions/client.py | 7 +- src/mcp-types/mcp_types/_types.py | 8 +- src/mcp/client/client.py | 2 +- src/mcp/client/session.py | 41 +++- tests/client/test_send_request_mcp_name.py | 228 +++++++++++++++++++++ tests/client/test_session_promotions.py | 83 ++++++++ tests/docs_src/test_extensions.py | 6 +- tests/server/mcpserver/test_extension.py | 6 +- tests/types/test_request_name_param.py | 37 ++++ 13 files changed, 408 insertions(+), 36 deletions(-) create mode 100644 tests/client/test_send_request_mcp_name.py create mode 100644 tests/client/test_session_promotions.py create mode 100644 tests/types/test_request_name_param.py diff --git a/docs/advanced/extensions.md b/docs/advanced/extensions.md index 5a6d7d524..024c7af69 100644 --- a/docs/advanced/extensions.md +++ b/docs/advanced/extensions.md @@ -119,8 +119,8 @@ The same file's `main()` is the whole client story, both halves of it: which: `require_client_extension(ctx, ...)` and `ctx.session.check_client_capability(...)` read the right source on both paths. * Vendor methods drop one layer to `client.session.send_request(...)`; `Client` - only grows first-class methods for spec verbs. The `cast` is there because - `send_request` is typed against the spec's closed request union. + only grows first-class methods for spec verbs. `send_request` accepts any + `Request` subclass, so the vendor request passes as-is. ### Intercepting `tools/call` diff --git a/docs_src/extensions/tutorial004.py b/docs_src/extensions/tutorial004.py index 4a0a022af..b08840705 100644 --- a/docs_src/extensions/tutorial004.py +++ b/docs_src/extensions/tutorial004.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Literal, cast +from typing import Any, Literal import mcp_types as types from pydantic import Field @@ -53,6 +53,6 @@ def methods(self) -> Sequence[MethodBinding]: async def main() -> None: async with Client(mcp, extensions={EXTENSION_ID: {}}) as client: request = SearchRequest(params=SearchParams(query="mcp", limit=3)) - result = await client.session.send_request(cast("types.ClientRequest", request), SearchResult) + result = await client.session.send_request(request, SearchResult) print(result.items) # ['mcp-0', 'mcp-1', 'mcp-2'] diff --git a/examples/stories/custom_methods/README.md b/examples/stories/custom_methods/README.md index 924ea0298..817437b3a 100644 --- a/examples/stories/custom_methods/README.md +++ b/examples/stories/custom_methods/README.md @@ -28,10 +28,9 @@ uv run python -m stories.custom_methods.client --http method string is the wire `method`; use a vendor prefix so it can never collide with a future spec method. - `client.py` `client.session.send_request(...)` — `Client` only exposes spec - verbs, so vendor methods go through the underlying `ClientSession`. The - `cast("types.ClientRequest", ...)` is needed because `send_request`'s - `request` parameter is currently typed as the closed spec union; widening it - (or adding `Client.send_request`) is tracked for beta. + verbs, so vendor methods go through the underlying `ClientSession`. + `send_request` accepts any `types.Request` subclass, so the vendor request + passes as-is, no cast. ## Caveats diff --git a/examples/stories/custom_methods/client.py b/examples/stories/custom_methods/client.py index 4003885fa..5282d584d 100644 --- a/examples/stories/custom_methods/client.py +++ b/examples/stories/custom_methods/client.py @@ -1,6 +1,6 @@ """Send a vendor-prefixed request via the `client.session` escape hatch.""" -from typing import Literal, cast +from typing import Literal import mcp_types as types @@ -26,12 +26,11 @@ async def main(target: Target, *, mode: str = "auto") -> None: async with Client(target, mode=mode) as client: # `Client` only exposes spec-defined verbs, so vendor methods have to drop one # layer to `client.session` today — there is no `Client`-level API for them - # yet, and whether `.session` stays public is undecided. `send_request` is - # typed against the closed `ClientRequest` union, hence the cast; at runtime - # the body only calls `.model_dump()` and the unknown method skips the - # per-spec result-validation registry. + # yet, and whether `.session` stays public is undecided. `send_request` + # accepts any `Request` subclass; the unknown method skips the per-spec + # result-validation registry. request = SearchRequest(params=SearchParams(query="mcp", limit=3)) - result = await client.session.send_request(cast("types.ClientRequest", request), SearchResult) + result = await client.session.send_request(request, SearchResult) assert result.items == ["mcp-0", "mcp-1", "mcp-2"], result diff --git a/examples/stories/extensions/client.py b/examples/stories/extensions/client.py index d3aacc140..849586f6f 100644 --- a/examples/stories/extensions/client.py +++ b/examples/stories/extensions/client.py @@ -1,6 +1,6 @@ """Discover an extension's capability entry, call its tool, then send its vendor method.""" -from typing import Literal, cast +from typing import Literal import mcp_types as types from mcp_types import TextContent @@ -43,10 +43,9 @@ async def main(target: Target, *, mode: str = "auto") -> None: assert isinstance(result.content[0], TextContent) assert result.content[0].text == "mcp-suggestion", result.content[0].text - # Vendor methods drop one layer to `client.session` (see custom_methods/); - # the cast is needed because `send_request` is typed against the spec union. + # Vendor methods drop one layer to `client.session` (see custom_methods/). request = SearchRequest(params=SearchParams(query="mcp", limit=3)) - found = await client.session.send_request(cast("types.ClientRequest", request), SearchResult) + found = await client.session.send_request(request, SearchResult) assert found.items == ["mcp-0", "mcp-1", "mcp-2"], found diff --git a/src/mcp-types/mcp_types/_types.py b/src/mcp-types/mcp_types/_types.py index 34dc10083..94156879a 100644 --- a/src/mcp-types/mcp_types/_types.py +++ b/src/mcp-types/mcp_types/_types.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar +from typing import Annotated, Any, ClassVar, Final, Generic, Literal, TypeAlias, TypeVar from pydantic import ( BaseModel, @@ -128,6 +128,12 @@ class Request(MCPModel, Generic[RequestParamsT, MethodT]): method: MethodT params: RequestParamsT + name_param: ClassVar[str | None] = None + """Wire-params key mirrored into the `Mcp-Name` header on sends (SEP-2243 + family; SEP-2663 requires it for tasks/*). The request type declares; the + host emits. Subclasses override by bare assignment (`name_param = "taskId"`) + — re-annotating as `ClassVar[str]` trips pyright's ClassVar invariance.""" + class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 638ea63a9..2b0cabe61 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -715,7 +715,7 @@ async def _drive_input_required( async def dispatch(key: str, req: InputRequest) -> InputResponse | ErrorData: ctx = ClientRequestContext(session=session, request_id=key, meta=req.params.meta if req.params else None) - return await session._dispatch_input_request(ctx, req) # pyright: ignore[reportPrivateUsage] + return await session.dispatch_input_request(ctx, req) return await run_input_required_driver( first, dispatch=dispatch, retry=retry, max_rounds=self.input_required_max_rounds diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 6a2298ad9..8a401b198 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -304,7 +304,7 @@ async def __aexit__( async def send_request( self, - request: types.ClientRequest, + request: types.ClientRequest | types.Request[Any, Any], result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT], request_read_timeout_seconds: float | None = None, metadata: ClientMessageMetadata | None = None, @@ -318,11 +318,22 @@ async def send_request( Raises: MCPError: Error response, read timeout, or connection closed. RuntimeError: Called before entering the context manager. + ValueError: The request type declares `name_param` but the params + carry no string value under that key for the `Mcp-Name` header. """ data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] opts: CallOptions = {} self._stamp(data, opts) + # Presence-keyed so the stamp's NAME_BEARING_METHODS rows win by ordering; + # a missing/non-string name fails loud rather than omitting a MUST header. + headers = opts.setdefault("headers", {}) + if (key := type(request).name_param) is not None and MCP_NAME_HEADER not in headers: + params_data: dict[str, Any] = data.get("params") or {} + name = params_data.get(key) + if not isinstance(name, str): + raise ValueError(f"{method} requires params[{key!r}] for Mcp-Name") + headers[MCP_NAME_HEADER] = encode_header_value(name) timeout = ( request_read_timeout_seconds if request_read_timeout_seconds is not None @@ -766,7 +777,7 @@ async def call_tool( ) if isinstance(result, types.CallToolResult) and not result.is_error: - await self._validate_tool_result(name, result) + await self.validate_tool_result(name, result) if isinstance(result, types.InputRequiredResult) and not allow_input_required: raise _input_required_unexpected("call_tool") @@ -779,8 +790,17 @@ def _resolve_param_headers(self, name: str, arguments: Mapping[str, Any]) -> dic return {} return mcp_param_headers(header_map, arguments) - async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: - """Validate the structured content of a tool result against its output schema.""" + async def validate_tool_result(self, name: str, result: types.CallToolResult) -> None: + """Revalidate a `CallToolResult` against the tool's declared output schema. + + Fetches the tool listing first when `name` has no cached schema. Tools + without an output schema (or not listed by the server) pass without + validation. + + Raises: + RuntimeError: The result's structured content is missing or does + not conform to the tool's output schema. + """ if name not in self._tool_output_schemas: # refresh output schema cache await self.list_tools() @@ -970,7 +990,7 @@ async def _on_request( ctx = ClientRequestContext( session=self, request_id=dctx.request_id, meta=request.params.meta if request.params else None ) - response = await self._dispatch_input_request(ctx, request) + response = await self.dispatch_input_request(ctx, request) client_response = ClientResponse.validate_python(response) if isinstance(client_response, types.ErrorData): raise MCPError.from_error_data(client_response) @@ -982,16 +1002,19 @@ async def _on_request( raise MCPError(code=INTERNAL_ERROR, message="Client callback returned an invalid result") from None return dumped - async def _dispatch_input_request( - self, ctx: ClientRequestContext, req: types.InputRequest + async def dispatch_input_request( + self, ctx: ClientRequestContext, request: types.InputRequest ) -> types.InputResponse | types.ErrorData: - """Route a server-initiated input request to the matching constructor callback. + """Route an input request through the client's callback table. Shared by the legacy server→client RPC path (`_on_request`) and the 2026-07-28 multi-round-trip driver, which dispatches the embedded `InputRequiredResult.input_requests` through the same callbacks. + + Returns the callback's `InputResponse`, or `ErrorData` when the + callback declines — the refusal path; callers must handle that arm. """ - match req: + match request: case types.CreateMessageRequest(params=p): return await self._sampling_callback(ctx, p) case types.ElicitRequest(params=p): diff --git a/tests/client/test_send_request_mcp_name.py b/tests/client/test_send_request_mcp_name.py new file mode 100644 index 000000000..0cc318ff9 --- /dev/null +++ b/tests/client/test_send_request_mcp_name.py @@ -0,0 +1,228 @@ +"""`ClientSession.send_request` mirrors `Request.name_param` into the `Mcp-Name` header. + +The modern stamp emits `Mcp-Name` for the core `NAME_BEARING_METHODS` table; the +`name_param` delta covers every other send path (vendor request types, the +legacy handshake stamp), keyed on header presence so the stamp's table rows +always win by ordering. The vendor sends below also pin the widened +`send_request` typing: a `Request[...]` subclass passes without a cast. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Literal + +import anyio +import anyio.abc +import mcp_types as types +import pytest +from mcp_types import ( + CallToolResult, + Implementation, + ListToolsResult, + Request, + ServerCapabilities, + TextContent, + Tool, +) +from mcp_types.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION + +from mcp.client.session import ClientSession +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest +from mcp.shared.inbound import MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER, encode_header_value + + +class _RecordingDispatcher: + """Records `send_raw_request` opts and answers with canned per-method results.""" + + def __init__(self) -> None: + self.calls: list[tuple[str, CallOptions]] = [] + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, opts or {})) + if method == "tools/call": + return CallToolResult(content=[TextContent(type="text", text="ok")]).model_dump( + by_alias=True, mode="json", exclude_none=True + ) + if method == "tools/list": + return ListToolsResult(tools=[Tool(name="my-tool", input_schema={"type": "object"})]).model_dump( + by_alias=True, mode="json", exclude_none=True + ) + return {} + + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + raise NotImplementedError + + +class _GetWidgetParams(types.RequestParams): + widget_id: str + + +class _GetWidgetRequest(Request[_GetWidgetParams, Literal["vendor/widgets/get"]]): + method: Literal["vendor/widgets/get"] = "vendor/widgets/get" + name_param = "widgetId" + + +class _RawWidgetRequest(Request[dict[str, Any], Literal["vendor/widgets/get"]]): + """Same wire shape with untyped params, so tests can omit or mistype the name value.""" + + method: Literal["vendor/widgets/get"] = "vendor/widgets/get" + name_param = "widgetId" + + +class _ShadowCallToolRequest(Request[dict[str, Any], Literal["tools/call"]]): + """A vendor type declaring `name_param` for a method the core table already covers.""" + + method: Literal["tools/call"] = "tools/call" + name_param = "customKey" + + +class _PlainVendorRequest(Request[dict[str, Any], Literal["vendor/widgets/list"]]): + method: Literal["vendor/widgets/list"] = "vendor/widgets/list" + + +def _adopt_modern(session: ClientSession) -> None: + session.adopt( + types.DiscoverResult( + supported_versions=[LATEST_MODERN_VERSION], + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ) + ) + + +def _adopt_handshake(session: ClientSession) -> None: + session.adopt( + types.InitializeResult( + protocol_version=LATEST_HANDSHAKE_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ) + ) + + +def _headers(opts: CallOptions) -> dict[str, str]: + return opts.get("headers") or {} + + +@pytest.mark.anyio +async def test_vendor_name_param_emits_mcp_name_on_the_modern_path() -> None: + """A vendor request type declaring `name_param` gets `Mcp-Name` on a modern + wire even though its method is not in `NAME_BEARING_METHODS`.""" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_modern(session) + await session.send_request(_GetWidgetRequest(params=_GetWidgetParams(widget_id="w-1")), types.EmptyResult) + [(_, opts)] = dispatcher.calls + assert _headers(opts)[MCP_NAME_HEADER] == "w-1" + + +@pytest.mark.anyio +async def test_vendor_name_param_emits_mcp_name_on_the_handshake_path() -> None: + """The handshake stamp sets no `Mcp-Name` at all, so for a legacy wire the + delta is the responsible emitter — emission is era-unconditional.""" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_handshake(session) + await session.send_request(_GetWidgetRequest(params=_GetWidgetParams(widget_id="w-1")), types.EmptyResult) + [(_, opts)] = dispatcher.calls + assert _headers(opts)[MCP_NAME_HEADER] == "w-1" + # The stamp's own headers survive the delta. + assert _headers(opts)[MCP_PROTOCOL_VERSION_HEADER] == LATEST_HANDSHAKE_VERSION + + +@pytest.mark.anyio +async def test_name_value_passes_through_encode_header_value() -> None: + """A name that cannot ride as a plain ASCII header value is base64-sentinel + encoded (spec MUST for `Mcp-Name`).""" + name = "wídget ✨" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_handshake(session) + await session.send_request(_GetWidgetRequest(params=_GetWidgetParams(widget_id=name)), types.EmptyResult) + [(_, opts)] = dispatcher.calls + assert _headers(opts)[MCP_NAME_HEADER] == encode_header_value(name) + assert _headers(opts)[MCP_NAME_HEADER].startswith("=?base64?") + + +@pytest.mark.anyio +async def test_core_tools_call_header_comes_from_the_stamp_alone() -> None: + """Core `tools/call` is unchanged: the modern stamp emits `Mcp-Name` from the + table; `CallToolRequest` declares no `name_param`, and on a legacy wire core + methods stay headerless exactly as today.""" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_modern(session) + await session.call_tool("my-tool", {}) + _adopt_handshake(session) + await session.call_tool("my-tool", {}) + (_, modern_opts), (_, legacy_opts) = (call for call in dispatcher.calls if call[0] == "tools/call") + assert _headers(modern_opts)[MCP_NAME_HEADER] == "my-tool" + assert MCP_NAME_HEADER not in _headers(legacy_opts) + + +@pytest.mark.anyio +async def test_stamp_table_rows_win_over_name_param_by_ordering() -> None: + """Header-presence keying: when the modern stamp already emitted `Mcp-Name` + from the core table, a `name_param` on the request type does not overwrite it.""" + dispatcher = _RecordingDispatcher() + request = _ShadowCallToolRequest(params={"name": "real-tool", "customKey": "other-value"}) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_modern(session) + await session.send_request(request, types.CallToolResult) + [(_, opts)] = dispatcher.calls + assert _headers(opts)[MCP_NAME_HEADER] == "real-tool" + + +@pytest.mark.anyio +async def test_missing_name_value_fails_loud_naming_method_and_key() -> None: + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_handshake(session) + with pytest.raises(ValueError, match=r"vendor/widgets/get requires params\['widgetId'\] for Mcp-Name"): + await session.send_request(_RawWidgetRequest(params={}), types.EmptyResult) + assert dispatcher.calls == [] # raised before reaching the wire + + +@pytest.mark.anyio +async def test_non_string_name_value_fails_loud() -> None: + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_handshake(session) + with pytest.raises(ValueError, match=r"vendor/widgets/get requires params\['widgetId'\] for Mcp-Name"): + await session.send_request(_RawWidgetRequest(params={"widgetId": 7}), types.EmptyResult) + assert dispatcher.calls == [] + + +@pytest.mark.anyio +async def test_request_without_name_param_sends_no_mcp_name() -> None: + """No `name_param` and a method outside `NAME_BEARING_METHODS`: neither + emitter produces an `Mcp-Name` header, on either era's stamp.""" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_modern(session) + await session.send_request(_PlainVendorRequest(params={}), types.EmptyResult) + _adopt_handshake(session) + await session.send_ping() + for _, opts in dispatcher.calls: + assert MCP_NAME_HEADER not in _headers(opts) diff --git a/tests/client/test_session_promotions.py b/tests/client/test_session_promotions.py new file mode 100644 index 000000000..6010834b3 --- /dev/null +++ b/tests/client/test_session_promotions.py @@ -0,0 +1,83 @@ +"""`dispatch_input_request` and `validate_tool_result` are public `ClientSession` API.""" + +import re +from pathlib import Path + +import mcp_types as types +import pytest +from mcp_types import ( + CallToolResult, + ErrorData, + ListRootsResult, + ListToolsResult, + PaginatedRequestParams, + Tool, +) + +from mcp.client.client import Client +from mcp.client.session import ClientRequestContext, ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair + + +@pytest.mark.anyio +async def test_dispatch_input_request_routes_through_the_callback_table() -> None: + expected = ListRootsResult(roots=[]) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return expected + + client_side, _server_side = create_direct_dispatcher_pair() + session = ClientSession(dispatcher=client_side, list_roots_callback=list_roots) + ctx = ClientRequestContext(session=session, request_id="r-1") + response = await session.dispatch_input_request(ctx, types.ListRootsRequest()) + assert response is expected + + +@pytest.mark.anyio +async def test_dispatch_input_request_returns_error_data_on_refusal() -> None: + """The `ErrorData` arm is the refusal path: with no callback registered, the + default callback declines and the caller receives the error, not a raise.""" + client_side, _server_side = create_direct_dispatcher_pair() + session = ClientSession(dispatcher=client_side) + ctx = ClientRequestContext(session=session, request_id="r-1") + response = await session.dispatch_input_request(ctx, types.ListRootsRequest()) + assert isinstance(response, ErrorData) + assert response.code == types.INVALID_REQUEST + + +def _make_server(output_schema: dict[str, object]) -> Server: + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"}, output_schema=output_schema)]) + + return Server("test-server", on_list_tools=on_list_tools) + + +@pytest.mark.anyio +async def test_validate_tool_result_passes_a_conforming_result() -> None: + server = _make_server({"type": "object", "properties": {"x": {"type": "integer"}}, "required": ["x"]}) + async with Client(server) as client: + # The session fetches the listing itself when the tool isn't cached yet. + await client.session.validate_tool_result("t", CallToolResult(content=[], structured_content={"x": 1})) + + +@pytest.mark.anyio +async def test_validate_tool_result_raises_on_schema_mismatch() -> None: + server = _make_server({"type": "object", "properties": {"x": {"type": "integer"}}, "required": ["x"]}) + async with Client(server) as client: + with pytest.raises(RuntimeError, match="Invalid structured content returned by tool t"): + await client.session.validate_tool_result("t", CallToolResult(content=[], structured_content={"x": "no"})) + + +def _spell_private(name: str) -> str: + return f"_{name}" + + +def test_no_private_spelling_references_remain() -> None: + """The promotions are renames, not aliases — the old private names are gone from `src/`.""" + pattern = re.compile(f"{_spell_private('dispatch_input_request')}|{_spell_private('validate_tool_result')}") + src = Path(__file__).resolve().parents[2] / "src" + offenders = [ + (path.name, match) for path in sorted(src.rglob("*.py")) for match in pattern.findall(path.read_text()) + ] + assert not offenders diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index ebe00e5a8..cd412b85a 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -1,9 +1,7 @@ """`docs/advanced/extensions.md`: every claim the page makes, proved against the real SDK.""" import logging -from typing import cast -import mcp_types as types import pytest from inline_snapshot import snapshot from mcp_types import METHOD_NOT_FOUND, MISSING_REQUIRED_CLIENT_CAPABILITY, TextContent @@ -70,7 +68,7 @@ async def test_vendor_method_rejects_a_non_declaring_client_with_32021() -> None async with Client(tutorial004.mcp) as client: request = tutorial004.SearchRequest(params=tutorial004.SearchParams(query="mcp")) with pytest.raises(MCPError) as exc_info: - await client.session.send_request(cast("types.ClientRequest", request), tutorial004.SearchResult) + await client.session.send_request(request, tutorial004.SearchResult) assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY assert exc_info.value.error.data == {"requiredCapabilities": {"extensions": {"com.example/search": {}}}} @@ -81,7 +79,7 @@ async def test_version_pinned_method_is_not_found_on_a_legacy_connection() -> No async with Client(tutorial004.mcp, mode="legacy", extensions={tutorial004.EXTENSION_ID: {}}) as client: request = tutorial004.SearchRequest(params=tutorial004.SearchParams(query="mcp")) with pytest.raises(MCPError) as exc_info: - await client.session.send_request(cast("types.ClientRequest", request), tutorial004.SearchResult) + await client.session.send_request(request, tutorial004.SearchResult) assert exc_info.value.code == METHOD_NOT_FOUND diff --git a/tests/server/mcpserver/test_extension.py b/tests/server/mcpserver/test_extension.py index e2ec366b2..8fec58e99 100644 --- a/tests/server/mcpserver/test_extension.py +++ b/tests/server/mcpserver/test_extension.py @@ -193,7 +193,7 @@ async def test_extension_method_reachable_via_session_send_request() -> None: async with Client(server) as client: request = _PingRequest(params=_PingParams()) - result = await client.session.send_request(cast("types.ClientRequest", request), _PingResult) + result = await client.session.send_request(request, _PingResult) assert result == snapshot(_PingResult(pong=True)) @@ -343,7 +343,7 @@ async def test_version_pinned_method_is_served_at_an_allowed_version() -> None: async with Client(server, mode="2026-07-28") as client: request = _VersionPinnedRequest(params=_VersionPinnedParams()) - result = await client.session.send_request(cast("types.ClientRequest", request), _VersionPinnedResult) + result = await client.session.send_request(request, _VersionPinnedResult) assert result == snapshot(_VersionPinnedResult(ok=True)) @@ -356,7 +356,7 @@ async def test_version_pinned_method_is_method_not_found_at_a_disallowed_version async with Client(server, mode="legacy") as client: request = _VersionPinnedRequest(params=_VersionPinnedParams()) with pytest.raises(MCPError) as exc_info: - await client.session.send_request(cast("types.ClientRequest", request), _VersionPinnedResult) + await client.session.send_request(request, _VersionPinnedResult) assert exc_info.value.code == METHOD_NOT_FOUND assert exc_info.value.error.data == "com.example/pinned" diff --git a/tests/types/test_request_name_param.py b/tests/types/test_request_name_param.py new file mode 100644 index 000000000..8666f2fca --- /dev/null +++ b/tests/types/test_request_name_param.py @@ -0,0 +1,37 @@ +"""`Request.name_param` — the wire-params key a request type declares for `Mcp-Name` emission.""" + +from typing import Literal + +import mcp_types as types +from mcp_types import CallToolRequest, PingRequest, Request + + +class _VendorParams(types.RequestParams): + task_id: str + + +class _VendorRequest(Request[_VendorParams, Literal["vendor/tasks/get"]]): + method: Literal["vendor/tasks/get"] = "vendor/tasks/get" + name_param = "taskId" + + +def test_request_base_declares_no_name_param() -> None: + assert Request.name_param is None + + +def test_core_request_types_inherit_none() -> None: + assert CallToolRequest.name_param is None + assert PingRequest.name_param is None + + +def test_subclass_overrides_by_bare_assignment() -> None: + """Subclasses set `name_param` by bare assignment; the override is class-local.""" + assert _VendorRequest.name_param == "taskId" + assert Request.name_param is None + + +def test_name_param_is_not_a_pydantic_field() -> None: + request = _VendorRequest(params=_VendorParams(task_id="t-1")) + assert "name_param" not in _VendorRequest.model_fields + dumped = request.model_dump(by_alias=True, mode="json", exclude_none=True) + assert dumped == {"method": "vendor/tasks/get", "params": {"taskId": "t-1"}} From a9edb2d59042d835da2ad5d06b59d27b1667ed95 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 15:26:44 +0000 Subject: [PATCH 03/15] Snapshot extension validation messages instead of matching substrings --- tests/client/test_extension.py | 51 ++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/tests/client/test_extension.py b/tests/client/test_extension.py index a06260703..8201bd3c8 100644 --- a/tests/client/test_extension.py +++ b/tests/client/test_extension.py @@ -88,21 +88,32 @@ def test_claim_accepts_modern_protocol_versions() -> None: assert claim.protocol_versions == versions -@pytest.mark.parametrize("result_type", ["complete", "input_required"]) -def test_claim_rejects_core_result_type_vocabulary(result_type: str) -> None: +def test_claim_rejects_core_result_type_vocabulary() -> None: """SDK-defined: "complete" and "input_required" are core protocol vocabulary — a claim cannot re-key the shapes the session itself routes on.""" - with pytest.raises(ValueError, match="core protocol vocabulary"): - ResultClaim(result_type=result_type, model=_TaskResult, resolve=_resolve) + messages: dict[str, str] = {} + for result_type in ("complete", "input_required"): + with pytest.raises(ValueError) as exc_info: + ResultClaim(result_type=result_type, model=_TaskResult, resolve=_resolve) + messages[result_type] = str(exc_info.value) + + assert messages == snapshot( + { + "complete": "resultType 'complete' is core protocol vocabulary", + "input_required": "resultType 'input_required' is core protocol vocabulary", + } + ) @pytest.mark.parametrize("model", [_ClaimedCallToolResult, _ClaimedInputRequiredResult]) def test_claim_rejects_model_subclassing_core_result_types(model: type[Result]) -> None: """SDK-defined: a claim model subclassing `CallToolResult` or `InputRequiredResult` would satisfy the session's isinstance branches and bypass claim routing.""" - with pytest.raises(ValueError, match="must not subclass core result types"): + with pytest.raises(ValueError) as exc_info: _claim(model=model) + assert str(exc_info.value) == snapshot("claim models must not subclass core result types") + def test_claim_rejects_model_without_result_type_field() -> None: """SDK-defined: the claim model must declare the discriminating `result_type` @@ -140,19 +151,29 @@ def test_claim_rejects_empty_protocol_versions() -> None: assert str(exc_info.value) == snapshot("empty protocol_versions could never activate; use None for all") -@pytest.mark.parametrize( - "versions", - [ +def test_claim_rejects_non_modern_protocol_versions() -> None: + """SDK-defined: claimed shapes cannot be delivered on a legacy wire, so a + non-None version set must be a subset of the modern protocol revisions.""" + messages: list[str] = [] + for versions in ( frozenset({"2025-11-25"}), frozenset({"2026-07-28", "2025-11-25"}), frozenset({"never-a-version"}), - ], -) -def test_claim_rejects_non_modern_protocol_versions(versions: frozenset[str]) -> None: - """SDK-defined: claimed shapes cannot be delivered on a legacy wire, so a - non-None version set must be a subset of the modern protocol revisions.""" - with pytest.raises(ValueError, match="not modern protocol revisions"): - _claim(protocol_versions=versions) + ): + with pytest.raises(ValueError) as exc_info: + _claim(protocol_versions=versions) + messages.append(str(exc_info.value)) + + assert messages == snapshot( + [ + "protocol_versions ['2025-11-25'] are not modern protocol revisions; claimed shapes " + "cannot be delivered on a legacy wire (None means every modern version)", + "protocol_versions ['2025-11-25'] are not modern protocol revisions; claimed shapes " + "cannot be delivered on a legacy wire (None means every modern version)", + "protocol_versions ['never-a-version'] are not modern protocol revisions; claimed shapes " + "cannot be delivered on a legacy wire (None means every modern version)", + ] + ) def test_result_claim_is_frozen() -> None: From 9de9e44820b2beb825e92a6209d6424c1c71cc43 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 15:31:58 +0000 Subject: [PATCH 04/15] Pin preconnect and absent-params arms of the Mcp-Name delta Emission must hold with no version adopted at all, and a dropped None params must still surface the documented ValueError. Also snapshot the SDK-authored failure messages instead of regex-matching them. --- tests/client/test_send_request_mcp_name.py | 41 ++++++++++++++++++++-- tests/client/test_session_promotions.py | 1 + 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/client/test_send_request_mcp_name.py b/tests/client/test_send_request_mcp_name.py index 0cc318ff9..8cd045787 100644 --- a/tests/client/test_send_request_mcp_name.py +++ b/tests/client/test_send_request_mcp_name.py @@ -16,6 +16,7 @@ import anyio.abc import mcp_types as types import pytest +from inline_snapshot import snapshot from mcp_types import ( CallToolResult, Implementation, @@ -93,6 +94,14 @@ class _PlainVendorRequest(Request[dict[str, Any], Literal["vendor/widgets/list"] method: Literal["vendor/widgets/list"] = "vendor/widgets/list" +class _OptionalParamsWidgetRequest(Request[dict[str, Any] | None, Literal["vendor/widgets/get"]]): + """Optional params, so a send can carry no params key at all.""" + + method: Literal["vendor/widgets/get"] = "vendor/widgets/get" + params: dict[str, Any] | None = None + name_param = "widgetId" + + def _adopt_modern(session: ClientSession) -> None: session.adopt( types.DiscoverResult( @@ -191,15 +200,28 @@ async def test_stamp_table_rows_win_over_name_param_by_ordering() -> None: assert _headers(opts)[MCP_NAME_HEADER] == "real-tool" +@pytest.mark.anyio +async def test_vendor_name_param_emits_mcp_name_on_the_preconnect_path() -> None: + """Emission is era-unconditional: a lowlevel caller that never adopts any + version (the preconnect stamp) still gets `Mcp-Name` from `name_param`.""" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.send_request(_GetWidgetRequest(params=_GetWidgetParams(widget_id="w-1")), types.EmptyResult) + [(_, opts)] = dispatcher.calls + assert _headers(opts) == {MCP_NAME_HEADER: "w-1"} # and no era headers: nothing adopted + + @pytest.mark.anyio async def test_missing_name_value_fails_loud_naming_method_and_key() -> None: dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: _adopt_handshake(session) - with pytest.raises(ValueError, match=r"vendor/widgets/get requires params\['widgetId'\] for Mcp-Name"): + with pytest.raises(ValueError) as exc_info: await session.send_request(_RawWidgetRequest(params={}), types.EmptyResult) assert dispatcher.calls == [] # raised before reaching the wire + assert str(exc_info.value) == snapshot("vendor/widgets/get requires params['widgetId'] for Mcp-Name") @pytest.mark.anyio @@ -208,9 +230,24 @@ async def test_non_string_name_value_fails_loud() -> None: with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: _adopt_handshake(session) - with pytest.raises(ValueError, match=r"vendor/widgets/get requires params\['widgetId'\] for Mcp-Name"): + with pytest.raises(ValueError) as exc_info: await session.send_request(_RawWidgetRequest(params={"widgetId": 7}), types.EmptyResult) assert dispatcher.calls == [] + assert str(exc_info.value) == snapshot("vendor/widgets/get requires params['widgetId'] for Mcp-Name") + + +@pytest.mark.anyio +async def test_absent_params_fails_loud_not_attribute_error() -> None: + """`exclude_none` drops a None params entirely; the delta still answers with + the documented ValueError, not an AttributeError on the missing key.""" + dispatcher = _RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + _adopt_handshake(session) + with pytest.raises(ValueError) as exc_info: + await session.send_request(_OptionalParamsWidgetRequest(), types.EmptyResult) + assert dispatcher.calls == [] + assert str(exc_info.value) == snapshot("vendor/widgets/get requires params['widgetId'] for Mcp-Name") @pytest.mark.anyio diff --git a/tests/client/test_session_promotions.py b/tests/client/test_session_promotions.py index 6010834b3..f944ef7e0 100644 --- a/tests/client/test_session_promotions.py +++ b/tests/client/test_session_promotions.py @@ -65,6 +65,7 @@ async def test_validate_tool_result_passes_a_conforming_result() -> None: async def test_validate_tool_result_raises_on_schema_mismatch() -> None: server = _make_server({"type": "object", "properties": {"x": {"type": "integer"}}, "required": ["x"]}) async with Client(server) as client: + # Stable SDK prefix only: the message tail is jsonschema text that shifts with the dependency. with pytest.raises(RuntimeError, match="Invalid structured content returned by tool t"): await client.session.validate_tool_result("t", CallToolResult(content=[], structured_content={"x": "no"})) From c73a27f2730877655c956aad60804d8c57e038d1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 15:45:31 +0000 Subject: [PATCH 05/15] Fold result claims and notification bindings into ClientSession MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Claims arrive keyed by their owning extension identifier; at adopt() the active set is computed for the negotiated version and the tools/call result adapter is rebuilt as a discriminated union (zero active claims keeps the module-level adapter, byte-identical parsing). The capability ad is built per version through the same association, so an extension whose claims are all inactive is not advertised — the ad and the claims dissolve together, on the initialize, discover-probe, and modern-stamp paths alike. call_tool gains allow_claimed: a claimed shape raises UnexpectedClaimedResult carrying the parsed value unless allowed. Notification bindings deliver through per-binding bounded FIFOs with one serialized consumer task each; enqueue never awaits, and overflow drops the oldest event with a warning. CORE_RESULT_TYPES derives from the ResultType literal in mcp-types, one source of truth for both the claim constructor and the session. --- src/mcp-types/mcp_types/__init__.py | 2 + src/mcp-types/mcp_types/_types.py | 9 +- src/mcp/client/__init__.py | 2 + src/mcp/client/extension.py | 28 +- src/mcp/client/session.py | 276 ++++++++++- tests/client/test_session_claims.py | 446 ++++++++++++++++++ .../test_session_notification_bindings.py | 310 ++++++++++++ 7 files changed, 1051 insertions(+), 22 deletions(-) create mode 100644 tests/client/test_session_claims.py create mode 100644 tests/client/test_session_notification_bindings.py diff --git a/src/mcp-types/mcp_types/__init__.py b/src/mcp-types/mcp_types/__init__.py index 2ed97cba3..87c0c5d59 100644 --- a/src/mcp-types/mcp_types/__init__.py +++ b/src/mcp-types/mcp_types/__init__.py @@ -8,6 +8,7 @@ from mcp_types._types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, + CORE_RESULT_TYPES, DEFAULT_NEGOTIATED_VERSION, LOG_LEVEL_META_KEY, PROTOCOL_VERSION_META_KEY, @@ -231,6 +232,7 @@ "CLIENT_CAPABILITIES_META_KEY", "LOG_LEVEL_META_KEY", # Type aliases and variables + "CORE_RESULT_TYPES", "ContentBlock", "ElicitRequestedSchema", "ElicitRequestParams", diff --git a/src/mcp-types/mcp_types/_types.py b/src/mcp-types/mcp_types/_types.py index 94156879a..2b5698770 100644 --- a/src/mcp-types/mcp_types/_types.py +++ b/src/mcp-types/mcp_types/_types.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Annotated, Any, ClassVar, Final, Generic, Literal, TypeAlias, TypeVar +from typing import Annotated, Any, ClassVar, Final, Generic, Literal, TypeAlias, TypeVar, get_args from pydantic import ( BaseModel, @@ -150,7 +150,9 @@ class Notification(MCPModel, Generic[NotificationParamsT, MethodT]): params: NotificationParamsT -ResultType = Literal["complete", "input_required"] | str +_CoreResultType = Literal["complete", "input_required"] + +ResultType = _CoreResultType | str """Tags a `Result` so the client knows how to parse it (2026-07-28). "complete" means the result is final; "input_required" means it is an @@ -158,6 +160,9 @@ class Notification(MCPModel, Generic[NotificationParamsT, MethodT]): Absent `resultType` is equivalent to "complete". """ +CORE_RESULT_TYPES: Final[frozenset[str]] = frozenset(get_args(_CoreResultType)) +"""The `resultType` tags owned by the core protocol vocabulary; extension claims may not re-key them.""" + class Result(MCPModel): """Base class for JSON-RPC results. diff --git a/src/mcp/client/__init__.py b/src/mcp/client/__init__.py index 9a3c3ae0f..21581749d 100644 --- a/src/mcp/client/__init__.py +++ b/src/mcp/client/__init__.py @@ -17,6 +17,7 @@ ClientExtension, NotificationBinding, ResultClaim, + UnexpectedClaimedResult, advertise, ) from mcp.client.session import ClientSession @@ -37,5 +38,6 @@ "ResponseCacheStore", "ResultClaim", "Transport", + "UnexpectedClaimedResult", "advertise", ] diff --git a/src/mcp/client/extension.py b/src/mcp/client/extension.py index 15ca9a08c..21fcd28ed 100644 --- a/src/mcp/client/extension.py +++ b/src/mcp/client/extension.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args -from mcp_types import CallToolResult, InputRequiredResult, Result +from mcp_types import CORE_RESULT_TYPES, CallToolResult, InputRequiredResult, Result from mcp_types.version import MODERN_PROTOCOL_VERSIONS from pydantic import BaseModel @@ -28,6 +28,7 @@ "ClientExtension", "NotificationBinding", "ResultClaim", + "UnexpectedClaimedResult", "advertise", ] @@ -81,7 +82,7 @@ class ResultClaim(Generic[ClaimedT]): protocol_versions: frozenset[str] | None = None def __post_init__(self) -> None: - if self.result_type in ("complete", "input_required"): + if self.result_type in CORE_RESULT_TYPES: raise ValueError(f"resultType {self.result_type!r} is core protocol vocabulary") if issubclass(self.model, CallToolResult | InputRequiredResult): raise ValueError("claim models must not subclass core result types") @@ -98,9 +99,30 @@ def __post_init__(self) -> None: ) +class UnexpectedClaimedResult(RuntimeError): + """A claimed (extension) result shape arrived on a `call_tool` that did not opt in. + + Raised by `ClientSession.call_tool` when a claimed shape parses and + `allow_claimed` is False. By the time this raises the server may have + durably created state (e.g. a task) — the parsed value is carried as + `result` so the caller can reach its id to clean up, not just read a + message. To handle claimed shapes, pass the owning extension to + `Client(extensions=[...])` (the transparent path) or call with + `allow_claimed=True` and handle the shape yourself. + """ + + def __init__(self, result: Result) -> None: + super().__init__( + f"Server returned a claimed result ({type(result).__name__}); pass the owning extension to " + "Client(extensions=[...]) for transparent resolution, or call with allow_claimed=True " + "and handle the shape. The carried result may reference server-side state needing cleanup." + ) + self.result = result + + @dataclass(frozen=True, kw_only=True) class NotificationBinding(Generic[NotifyParamsT]): - """Deliver server notifications for `method` to `handler` (today: silently dropped). + """Deliver server notifications for `method` to `handler` (unbound methods stay silently dropped). Observation-only: the handler receives validated params, returns None, and cannot short-circuit anything. Delivery is per-binding serialized through a diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8a401b198..b2e551ba0 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,15 +1,18 @@ from __future__ import annotations import logging -from collections.abc import Callable, Mapping +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass +from functools import reduce +from operator import or_ from types import TracebackType -from typing import Any, Literal, Protocol, cast, overload +from typing import Annotated, Any, Final, Literal, Protocol, cast, overload import anyio import anyio.abc import anyio.lowlevel import mcp_types as types +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp_types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -27,10 +30,11 @@ LATEST_MODERN_VERSION, MODERN_PROTOCOL_VERSIONS, ) -from pydantic import BaseModel, TypeAdapter, ValidationError +from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, ValidationError from typing_extensions import Self, TypeVar, deprecated from mcp.client._transport import ReadStream, WriteStream +from mcp.client.extension import NotificationBinding, ResultClaim, UnexpectedClaimedResult from mcp.shared._compat import resync_tracer from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning, MCPError @@ -51,6 +55,7 @@ DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") DISCOVER_TIMEOUT_SECONDS = 10.0 +_NOTIFICATION_QUEUE_SIZE: Final = 256 logger = logging.getLogger("client") @@ -189,7 +194,9 @@ async def _default_logging_callback( ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) -_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter( +# Declared against the wide tools/call parse union so the session's adopt-built +# claim adapters (which add `Result` arms) share one attribute type with it. +_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult | types.Result] = TypeAdapter( types.CallToolResult | types.InputRequiredResult ) _GetPromptResultAdapter: TypeAdapter[types.GetPromptResult | types.InputRequiredResult] = TypeAdapter( @@ -200,6 +207,97 @@ async def _default_logging_callback( ) +def _claim_active(claim: ResultClaim[Any], version: str) -> bool: + """A claim is active at modern versions only, narrowed by its optional version subset.""" + return version in MODERN_PROTOCOL_VERSIONS and ( + claim.protocol_versions is None or version in claim.protocol_versions + ) + + +def _active_claims_at( + claims_by_extension: Mapping[str, tuple[ResultClaim[Any], ...]], version: str +) -> dict[str, ResultClaim[Any]]: + """Claims active at `version`, keyed by wire tag (unique across extensions by construction). + + Empty at any legacy version: no claim is active there, so both `adopt()` arms + share this one rule. + """ + return { + claim.result_type: claim + for claims in claims_by_extension.values() + for claim in claims + if _claim_active(claim, version) + } + + +def _build_call_tool_adapter( + active: Mapping[str, ResultClaim[Any]], +) -> TypeAdapter[types.CallToolResult | types.InputRequiredResult | types.Result]: + """Discriminated tools/call adapter: a core arm plus one arm per active claim. + + Zero active claims returns the module-level `_CallToolResultAdapter` itself, keeping + the no-extensions parse path byte-identical. + """ + if not active: + return _CallToolResultAdapter + tags = frozenset(active) + + def _route(value: Any) -> str: + # pydantic hands the discriminator either the raw inbound dict or an + # already-built model (revalidation). A non-string or unknown tag stays on + # the core arm, so a malformed `resultType` fails core validation instead + # of blowing up the discriminator lookup. + if isinstance(value, dict): + tag = cast("dict[str, Any]", value).get("resultType") + else: + tag = getattr(value, "result_type", None) + return tag if isinstance(tag, str) and tag in tags else "core" + + arms: list[Any] = [Annotated[types.CallToolResult | types.InputRequiredResult, Tag("core")]] + arms += [Annotated[claim.model, Tag(tag)] for tag, claim in active.items()] + # reduce(or_, ...) builds the Union dynamically; PEP-646 star-unpack needs py3.11+. + return TypeAdapter(Annotated[reduce(or_, arms), Discriminator(_route)]) + + +def _index_claims( + result_claims: Mapping[str, Sequence[ResultClaim[Any]]] | None, + extensions: dict[str, dict[str, Any]] | None, +) -> dict[str, tuple[ResultClaim[Any], ...]]: + """Validate and copy the claims-by-extension mapping. + + The mapping keys ARE the claim/ad association: at adopt the capability ad and the + claims dissolve together per extension identifier, so a key must name an + advertised extension and no wire tag may be claimed twice. + """ + indexed: dict[str, tuple[ResultClaim[Any], ...]] = {} + seen: set[tuple[str, str]] = set() + for identifier, claims in (result_claims or {}).items(): + if extensions is None or identifier not in extensions: + raise ValueError( + f"result_claims key {identifier!r} has no extensions entry; a claim is only " + "advertised through its extension's capability ad" + ) + for claim in claims: + key = (claim.method, claim.result_type) + if key in seen: + raise ValueError(f"duplicate result claim for {claim.method!r} resultType {claim.result_type!r}") + seen.add(key) + indexed[identifier] = tuple(claims) + return indexed + + +def _index_bindings( + notification_bindings: Sequence[NotificationBinding[Any]] | None, +) -> dict[str, NotificationBinding[Any]]: + """Index bindings by wire method, rejecting duplicates.""" + indexed: dict[str, NotificationBinding[Any]] = {} + for binding in notification_bindings or (): + if binding.method in indexed: + raise ValueError(f"duplicate notification binding for method {binding.method!r}") + indexed[binding.method] = binding + return indexed + + def _input_required_unexpected(method: str) -> RuntimeError: return RuntimeError( "Server returned InputRequiredResult; pass allow_input_required=True to receive it " @@ -216,6 +314,12 @@ class ClientSession: correlation; this class owns the typed MCP layer and the constructor callbacks. Transport `Exception` items reach `message_handler` only when the session builds its own dispatcher from a stream pair. + + Extension contributions enter here too: `result_claims` (keyed by the + advertising identifier in `extensions`, so the capability ad and the claims + dissolve together) fold into tools/call parsing at `adopt()`, and + `notification_bindings` observe vendor notifications through per-binding + bounded FIFOs. """ def __init__( @@ -232,6 +336,8 @@ def __init__( *, sampling_capabilities: types.SamplingCapability | None = None, extensions: dict[str, dict[str, Any]] | None = None, + result_claims: Mapping[str, Sequence[ResultClaim[Any]]] | None = None, + notification_bindings: Sequence[NotificationBinding[Any]] | None = None, dispatcher: Dispatcher[Any] | None = None, ) -> None: self._session_read_timeout_seconds = read_timeout_seconds @@ -239,6 +345,13 @@ def __init__( self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities self._extensions = extensions + self._result_claims = _index_claims(result_claims, extensions) + self._notification_bindings = _index_bindings(notification_bindings) + self._active_claims: dict[str, ResultClaim[Any]] = {} + self._call_tool_adapter = _CallToolResultAdapter + self._binding_queues: dict[ + str, tuple[MemoryObjectSendStream[BaseModel], MemoryObjectReceiveStream[BaseModel]] + ] = {} self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback @@ -275,6 +388,10 @@ async def __aenter__(self) -> Self: await self._task_group.__aenter__() try: await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + for binding in self._notification_bindings.values(): + send, receive = anyio.create_memory_object_stream[BaseModel](_NOTIFICATION_QUEUE_SIZE) + self._binding_queues[binding.method] = (send, receive) + self._task_group.start_soon(self._deliver_bound_notifications, binding, receive) except BaseException: # Unwind the entered task group before propagating: a cancellation # landing here (e.g. `move_on_after` around connect) would abandon @@ -286,6 +403,7 @@ async def __aenter__(self) -> Self: # so a pending outer cancellation cannot re-fire inside __aexit__. task_group.cancel_scope.shield = True await task_group.__aexit__(None, None, None) + self._close_binding_queues() raise return self @@ -295,13 +413,39 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - # Exit must not block: cancel the dispatcher and in-flight callbacks. + # Exit must not block: cancel the dispatcher, binding consumers, and in-flight callbacks. assert self._task_group is not None self._task_group.cancel_scope.cancel() result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + self._close_binding_queues() await resync_tracer() return result + def _close_binding_queues(self) -> None: + # Memory object streams warn at garbage collection unless closed; the consumers + # die by task-group cancellation, so both ends are closed here (close is idempotent). + for send, receive in self._binding_queues.values(): + send.close() + receive.close() + self._binding_queues.clear() + + async def _deliver_bound_notifications( + self, binding: NotificationBinding[Any], receive: MemoryObjectReceiveStream[BaseModel] + ) -> None: + """Serialized consumer for one binding's FIFO. + + Spawn-decoupled from the dispatcher so the handler may do session I/O without + deadlocking in-process delivery; dies with the session's task group. + """ + while True: + params = await receive.receive() + try: + await binding.handler(params) + except Exception: + # Same containment contract as the notification callbacks in `_on_notify`: + # a raising handler costs only that delivery. + logger.exception("notification binding handler for %r raised", binding.method) + async def send_request( self, request: types.ClientRequest | types.Request[Any, Any], @@ -371,7 +515,23 @@ async def send_notification(self, notification: types.ClientNotification) -> Non self._stamp(data, opts) await self._dispatcher.notify(data["method"], data.get("params"), opts) - def _build_capabilities(self) -> types.ClientCapabilities: + def _build_capabilities(self, version: str) -> types.ClientCapabilities: + """Build the capability ad for a wire speaking `version`. + + An identifier with result claims contributes to the ad only when at least one + of its claims is active at `version` — the ad and the claims dissolve + together, so the client never advertises an extension on a request whose + claimed result shapes it would reject. Claim-less identifiers always + contribute; when every identifier drops, the ad omits `extensions` entirely. + """ + extensions = self._extensions + if extensions is not None and self._result_claims: + extensions = { + identifier: settings + for identifier, settings in extensions.items() + if identifier not in self._result_claims + or any(_claim_active(claim, version) for claim in self._result_claims[identifier]) + } or None sampling = ( (self._sampling_capabilities or types.SamplingCapability()) if self._sampling_callback is not _default_sampling_callback @@ -391,7 +551,7 @@ def _build_capabilities(self) -> types.ClientCapabilities: else None ) return types.ClientCapabilities( - sampling=sampling, elicitation=elicitation, experimental=None, extensions=self._extensions, roots=roots + sampling=sampling, elicitation=elicitation, experimental=None, extensions=extensions, roots=roots ) async def initialize(self) -> types.InitializeResult: @@ -401,7 +561,9 @@ async def initialize(self) -> types.InitializeResult: types.InitializeRequest( params=types.InitializeRequestParams( protocol_version=LATEST_HANDSHAKE_VERSION, - capabilities=self._build_capabilities(), + # The handshake can only negotiate legacy versions, where no claim is + # active — every claim-bearing identifier drops from this ad. + capabilities=self._build_capabilities(LATEST_HANDSHAKE_VERSION), client_info=self._client_info, ), ), @@ -435,17 +597,32 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: f"No mutually supported modern protocol version " f"(server: {result.supported_versions}, client: {list(MODERN_PROTOCOL_VERSIONS)})" ) + version = mutual[-1] client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) - capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) - self._stamp = _make_modern_stamp(mutual[-1], client_info, capabilities, self._resolve_param_headers) + capabilities = self._build_capabilities(version).model_dump(by_alias=True, mode="json", exclude_none=True) + self._stamp = _make_modern_stamp(version, client_info, capabilities, self._resolve_param_headers) self._discover_result = result self._initialize_result = None - self._negotiated_version = mutual[-1] else: - self._stamp = _make_handshake_stamp(result.protocol_version) + version = result.protocol_version + self._stamp = _make_handshake_stamp(version) self._initialize_result = result self._discover_result = None - self._negotiated_version = result.protocol_version + self._negotiated_version = version + # Assigned fresh in both arms (re-adopt safe): empty at any legacy version by the + # one activation rule. Claims tagged with core vocabulary are unconstructible + # (`ResultClaim.__post_init__`), so activation needs no core-tag exclusion. + self._active_claims = _active_claims_at(self._result_claims, version) + self._call_tool_adapter = _build_call_tool_adapter(self._active_claims) + for method in self._notification_bindings: + # Bindings are consulted only for methods core does not know (`_on_notify`'s + # KeyError branch), so a core-known binding can never fire — say so once here. + if (method, version) in _methods.SERVER_NOTIFICATIONS: + logger.warning( + "notification binding for %r will never fire at %s: the core protocol defines this method", + method, + version, + ) async def send_discover(self, version: str) -> dict[str, Any]: """Send a single ``server/discover`` at ``version`` and return the raw result dict. @@ -461,7 +638,7 @@ async def send_discover(self, version: str) -> dict[str, Any]: synthesized into a JSON-RPC error by the transport). """ client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) - capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) + capabilities = self._build_capabilities(version).model_dump(by_alias=True, mode="json", exclude_none=True) request = types.DiscoverRequest( params=types.RequestParams( _meta={ @@ -715,6 +892,7 @@ async def call_tool( request_state: str | None = None, meta: RequestParamsMeta | None = None, allow_input_required: Literal[False] = False, + allow_claimed: Literal[False] = False, ) -> types.CallToolResult: ... @overload @@ -729,8 +907,39 @@ async def call_tool( request_state: str | None = None, meta: RequestParamsMeta | None = None, allow_input_required: bool, + allow_claimed: Literal[False] = False, ) -> types.CallToolResult | types.InputRequiredResult: ... + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: Literal[False] = False, + allow_claimed: bool, + ) -> types.CallToolResult | types.Result: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool, + allow_claimed: bool, + ) -> types.CallToolResult | types.InputRequiredResult | types.Result: ... + async def call_tool( self, name: str, @@ -742,7 +951,8 @@ async def call_tool( request_state: str | None = None, meta: RequestParamsMeta | None = None, allow_input_required: bool = False, - ) -> types.CallToolResult | types.InputRequiredResult: + allow_claimed: bool = False, + ) -> types.CallToolResult | types.InputRequiredResult | types.Result: """Send a tools/call request with optional progress callback support. On a modern (2026-07-28) connection, arguments annotated with `x-mcp-header` @@ -756,10 +966,15 @@ async def call_tool( allow_input_required: When ``False`` (default), an `InputRequiredResult` from the server raises `RuntimeError`; when ``True``, it is returned so the caller can resolve the requests and retry. + allow_claimed: When ``False`` (default), a claimed (extension) result + shape raises `UnexpectedClaimedResult`; when ``True``, the parsed + claim model is returned for the caller to handle. Raises: RuntimeError: If the server returns an `InputRequiredResult` and ``allow_input_required`` is ``False``. + UnexpectedClaimedResult: If a claimed result shape parses and + ``allow_claimed`` is ``False``; carries the parsed value. """ result = await self.send_request( types.CallToolRequest( @@ -771,7 +986,7 @@ async def call_tool( _meta=meta, ), ), - _CallToolResultAdapter, + self._call_tool_adapter, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, ) @@ -779,8 +994,12 @@ async def call_tool( if isinstance(result, types.CallToolResult) and not result.is_error: await self.validate_tool_result(name, result) + # Driver-innermost ordering: the input_required arm stays first — a claimed + # shape exits the multi-round-trip driver as terminal. if isinstance(result, types.InputRequiredResult) and not allow_input_required: raise _input_required_unexpected("call_tool") + if not isinstance(result, types.CallToolResult | types.InputRequiredResult) and not allow_claimed: + raise UnexpectedClaimedResult(result) return result def _resolve_param_headers(self, name: str, arguments: Mapping[str, Any]) -> dict[str, str]: @@ -1031,7 +1250,30 @@ async def _on_notify( try: notification = cast(types.ServerNotification, _methods.parse_server_notification(method, version, params)) except KeyError: - logger.debug("dropped %r: not defined at %s", method, version) + # Methods the negotiated version's core tables do not know are offered to + # the notification bindings; core-known methods structurally never get here. + binding = self._notification_bindings.get(method) + if binding is None: + logger.debug("dropped %r: not defined at %s", method, version) + return + try: + bound_params = binding.params_type.model_validate(params or {}) + except ValidationError: + # Mirrors the core notification arm below: warn and drop. + logger.warning("Failed to validate notification: %s", method, exc_info=True) + return + send, receive = self._binding_queues[method] + try: + # Never awaits: DirectDispatcher awaits _on_notify inline in the peer's + # notify(), so blocking here would deadlock in-process servers. + send.send_nowait(bound_params) + except anyio.WouldBlock: + # Bounded FIFO: evict the oldest queued event (observation semantics + # tolerate the loss). No checkpoint since the failed send, so the + # buffer is still full and the eviction cannot itself block. + receive.receive_nowait() + logger.warning("notification queue for %r is full; dropped the oldest event", method) + send.send_nowait(bound_params) return except ValidationError: logger.warning("Failed to validate notification: %s", method, exc_info=True) diff --git a/tests/client/test_session_claims.py b/tests/client/test_session_claims.py new file mode 100644 index 000000000..133b826f5 --- /dev/null +++ b/tests/client/test_session_claims.py @@ -0,0 +1,446 @@ +"""`ClientSession` result claims: construction validation, adopt-time activation, +the discriminated tools/call adapter, the version-aware capability ad, and the +`allow_claimed` escape hatch. + +Claims activate at `adopt()` — the stamp-swap moment — and only at modern +protocol versions; everywhere else parsing stays byte-identical to a claim-less +session (the zero-claims adapter is the module-level constant, by identity). +""" + +from collections.abc import Mapping +from typing import Any, Literal + +import anyio +import anyio.abc +import mcp_types as types +import pytest +from inline_snapshot import snapshot +from mcp_types import ( + CLIENT_CAPABILITIES_META_KEY, + CallToolResult, + Implementation, + InputRequiredResult, + ListToolsResult, + Result, + ServerCapabilities, + TextContent, + Tool, +) +from mcp_types.methods import validate_server_result +from mcp_types.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION +from pydantic import ValidationError +from typing_extensions import assert_type + +from mcp.client.extension import ClaimContext, ResultClaim, UnexpectedClaimedResult +from mcp.client.session import ClientSession, _CallToolResultAdapter +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest + +_TASKS_EXT = "com.example/tasks" +_AD_ONLY_EXT = "com.example/flags" + + +class _TaskResult(Result): + """A claimed result shape, tagged `task`.""" + + result_type: Literal["task"] = "task" + task_id: str + + +async def _resolve_task(result: _TaskResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # session-tier tests never drive a resolver; that is the Client's job + + +def _task_claim(**kwargs: Any) -> ResultClaim[_TaskResult]: + return ResultClaim(result_type="task", model=_TaskResult, resolve=_resolve_task, **kwargs) + + +_COMPLETE_TOOL_RESULT = CallToolResult(content=[TextContent(type="text", text="ok")]).model_dump( + by_alias=True, mode="json", exclude_none=True +) +_CLAIMED_TASK_RESULT = {"resultType": "task", "taskId": "t-1"} +_TOOL_LISTING = ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]).model_dump( + by_alias=True, mode="json", exclude_none=True +) +_INITIALIZE_RESULT = types.InitializeResult( + protocol_version=LATEST_HANDSHAKE_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), +).model_dump(by_alias=True, mode="json", exclude_none=True) + + +class _RecordingDispatcher: + """Records every send and answers each method with a canned result.""" + + def __init__(self, tool_result: dict[str, Any] | None = None) -> None: + self.calls: list[tuple[str, Mapping[str, Any] | None, CallOptions]] = [] + self.notifications: list[str] = [] + self._tool_result = tool_result if tool_result is not None else _COMPLETE_TOOL_RESULT + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, params, opts or {})) + if method == "tools/call": + return self._tool_result + if method == "tools/list": + return _TOOL_LISTING + if method == "initialize": + return _INITIALIZE_RESULT + return {} + + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + self.notifications.append(method) + + +def _claims_session(dispatcher: _RecordingDispatcher, *claims: ResultClaim[Any]) -> ClientSession: + return ClientSession(dispatcher=dispatcher, extensions={_TASKS_EXT: {}}, result_claims={_TASKS_EXT: list(claims)}) + + +def _adopt_modern(session: ClientSession) -> None: + session.adopt( + types.DiscoverResult( + supported_versions=[LATEST_MODERN_VERSION], + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ) + ) + + +def _adopt_handshake(session: ClientSession) -> None: + session.adopt( + types.InitializeResult( + protocol_version=LATEST_HANDSHAKE_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ) + ) + + +# ── Construction-time validation ──────────────────────────────────────────── + + +def test_duplicate_claim_tag_across_extensions_rejected() -> None: + """SDK-defined: two claims on the same (method, resultType) — even from different + extensions — could not be routed apart, so construction fails.""" + with pytest.raises(ValueError) as exc_info: + ClientSession( + dispatcher=_RecordingDispatcher(), + extensions={_TASKS_EXT: {}, _AD_ONLY_EXT: {}}, + result_claims={_TASKS_EXT: [_task_claim()], _AD_ONLY_EXT: [_task_claim()]}, + ) + + assert str(exc_info.value) == snapshot("duplicate result claim for 'tools/call' resultType 'task'") + + +def test_claims_keyed_to_unadvertised_extension_rejected() -> None: + """SDK-defined: a claim rides its extension's capability ad — a `result_claims` key + with no `extensions` entry (including extensions=None) advertises nothing the + server could ever act on, so construction fails.""" + messages: list[str] = [] + for extensions in (None, {_AD_ONLY_EXT: {"flag": True}}): + with pytest.raises(ValueError) as exc_info: + ClientSession( + dispatcher=_RecordingDispatcher(), + extensions=extensions, + result_claims={_TASKS_EXT: [_task_claim()]}, + ) + messages.append(str(exc_info.value)) + + assert messages == snapshot( + [ + "result_claims key 'com.example/tasks' has no extensions entry; a claim is only " + "advertised through its extension's capability ad", + "result_claims key 'com.example/tasks' has no extensions entry; a claim is only " + "advertised through its extension's capability ad", + ] + ) + + +def test_empty_settings_count_as_an_advertised_extension() -> None: + """SDK-defined: an extension advertised with empty settings ({}) is still an ad — + claims keyed to it construct fine.""" + session = _claims_session(_RecordingDispatcher(), _task_claim()) + + assert isinstance(session, ClientSession) + + +# ── Activation at adopt() ─────────────────────────────────────────────────── + + +def test_without_claims_the_call_tool_adapter_is_the_module_constant() -> None: + """SDK-defined: the no-extensions parse path stays byte-identical — with zero + active claims the session holds the module-level adapter itself, not a rebuild, + before and after either adopt arm.""" + session = ClientSession(dispatcher=_RecordingDispatcher()) + + assert session._call_tool_adapter is _CallToolResultAdapter + _adopt_modern(session) + assert session._call_tool_adapter is _CallToolResultAdapter + _adopt_handshake(session) + assert session._call_tool_adapter is _CallToolResultAdapter + + +@pytest.mark.anyio +@pytest.mark.parametrize("protocol_versions", [None, frozenset({LATEST_MODERN_VERSION})]) +async def test_modern_adopt_activates_claims_and_routes_claimed_results( + protocol_versions: frozenset[str] | None, +) -> None: + """SDK-defined: at a modern adopt, a claim active at the negotiated version (None = + every modern version; an explicit subset containing it) routes the claimed raw to + the claim model.""" + dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) + session = _claims_session(dispatcher, _task_claim(protocol_versions=protocol_versions)) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + result = await session.call_tool("t", {}, allow_claimed=True) + + assert isinstance(result, _TaskResult) + assert result.task_id == "t-1" + + +@pytest.mark.anyio +async def test_legacy_adopt_clears_active_claims() -> None: + """SDK-defined: re-adopt safe — a session that adopts modern then legacy fully + clears its active claims, restoring the module-level adapter by identity.""" + dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + assert isinstance(await session.call_tool("t", {}, allow_claimed=True), _TaskResult) + + _adopt_handshake(session) + assert session._call_tool_adapter is _CallToolResultAdapter + with pytest.raises(ValidationError): + await session.call_tool("t", {}, allow_claimed=True) + # The rejection came from response parsing — the request did reach the wire. + assert dispatcher.calls[-1][0] == "tools/call" + + +# ── The version-aware capability ad ───────────────────────────────────────── + + +@pytest.mark.anyio +async def test_legacy_initialize_ad_drops_claim_bearing_identifiers() -> None: + """SDK-defined: the legacy handshake can never negotiate a modern version, so no + claim can activate — a claim-bearing identifier drops from the initialize ad while + ad-only identifiers ride along.""" + dispatcher = _RecordingDispatcher() + session = ClientSession( + dispatcher=dispatcher, + extensions={_TASKS_EXT: {}, _AD_ONLY_EXT: {"flag": True}}, + result_claims={_TASKS_EXT: [_task_claim()]}, + ) + with anyio.fail_after(5): + async with session: + await session.initialize() + + [(_, params, _)] = [call for call in dispatcher.calls if call[0] == "initialize"] + assert params is not None + assert params["capabilities"]["extensions"] == {_AD_ONLY_EXT: {"flag": True}} + + +@pytest.mark.anyio +async def test_legacy_ad_omits_extensions_entirely_when_every_identifier_drops() -> None: + """SDK-defined: when the filter drops every identifier, the ad omits the + `extensions` key — an empty extensions object advertises nothing.""" + dispatcher = _RecordingDispatcher() + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + await session.initialize() + + [(_, params, _)] = [call for call in dispatcher.calls if call[0] == "initialize"] + assert params is not None + assert "extensions" not in params["capabilities"] + + +@pytest.mark.anyio +async def test_modern_adopt_ad_includes_active_claim_identifiers() -> None: + """SDK-defined: the modern stamp's per-request `_meta` ad includes a claim-bearing + identifier when its claims are active at the adopted version.""" + dispatcher = _RecordingDispatcher() + session = ClientSession( + dispatcher=dispatcher, + extensions={_TASKS_EXT: {}, _AD_ONLY_EXT: {"flag": True}}, + result_claims={_TASKS_EXT: [_task_claim()]}, + ) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + await session.send_ping() + + [(_, params, _)] = dispatcher.calls + assert params is not None + capabilities = params["_meta"][CLIENT_CAPABILITIES_META_KEY] + assert capabilities["extensions"] == {_TASKS_EXT: {}, _AD_ONLY_EXT: {"flag": True}} + + +@pytest.mark.anyio +async def test_discover_probe_ad_includes_claim_identifiers_at_the_probe_version() -> None: + """SDK-defined: `send_discover` builds its `_meta` ad at the probe version — modern, + so claim-bearing identifiers contribute.""" + dispatcher = _RecordingDispatcher() + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + await session.send_discover(LATEST_MODERN_VERSION) + + [(_, params, _)] = dispatcher.calls + assert params is not None + capabilities = params["_meta"][CLIENT_CAPABILITIES_META_KEY] + assert capabilities["extensions"] == {_TASKS_EXT: {}} + + +# ── Routing through the adopt-built adapter ───────────────────────────────── + + +@pytest.mark.anyio +@pytest.mark.parametrize("with_claims", [True, False]) +async def test_unknown_result_type_fails_validation_with_and_without_claims(with_claims: bool) -> None: + """SDK-defined: a resultType outside the active claim set routes to the core arm + and fails core validation — exactly the claim-less session's behaviour.""" + raw = {"resultType": "weird", "taskId": "t-1"} + dispatcher = _RecordingDispatcher(tool_result=raw) + session = _claims_session(dispatcher, _task_claim()) if with_claims else ClientSession(dispatcher=dispatcher) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + with pytest.raises(ValidationError): + await session.call_tool("t", {}, allow_claimed=True) + # The rejection came from response parsing — the request did reach the wire. + assert dispatcher.calls[-1][0] == "tools/call" + + +@pytest.mark.anyio +async def test_non_string_result_type_fails_core_validation_not_discrimination() -> None: + """SDK-defined: a malformed (non-string) resultType stays on the core arm — the + discriminator never uses it as a lookup key, so the failure is today's + ValidationError, not a TypeError.""" + raw: dict[str, Any] = {"resultType": {"nested": True}} + dispatcher = _RecordingDispatcher(tool_result=raw) + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + with pytest.raises(ValidationError): + await session.call_tool("t", {}, allow_claimed=True) + # The rejection came from response parsing — the request did reach the wire. + assert dispatcher.calls[-1][0] == "tools/call" + + +def test_adopt_built_adapter_revalidates_model_instances() -> None: + """SDK-defined: pydantic hands the callable discriminator either a raw dict or an + already-built model (revalidation); both route — a claim instance to its arm, a + core instance to the core arm.""" + session = _claims_session(_RecordingDispatcher(), _task_claim()) + _adopt_modern(session) + adapter = session._call_tool_adapter + + claimed = adapter.validate_python(_TaskResult(task_id="t-2")) + assert isinstance(claimed, _TaskResult) + core = adapter.validate_python(CallToolResult(content=[])) + assert isinstance(core, CallToolResult) + + +@pytest.mark.anyio +async def test_input_required_routes_to_core_arm_with_claims_active() -> None: + """Spec-mandated: `input_required` is core vocabulary — active claims leave the + multi-round-trip arm untouched.""" + raw = {"resultType": "input_required", "requestState": "s-1"} + session = _claims_session(_RecordingDispatcher(tool_result=raw), _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + result = await session.call_tool("t", {}, allow_input_required=True, allow_claimed=True) + + assert isinstance(result, InputRequiredResult) + assert result.request_state == "s-1" + + +# ── allow_claimed ──────────────────────────────────────────────────────────── + + +@pytest.mark.anyio +async def test_claimed_result_raises_unexpected_claimed_result_by_default() -> None: + """SDK-defined: without `allow_claimed=True` a claimed shape raises, carrying the + parsed value — the server may have durably created state (e.g. a task), and the + carried result is how the caller reaches its id to clean up.""" + dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + with pytest.raises(UnexpectedClaimedResult) as exc_info: + await session.call_tool("t", {}) + # The shape parsed and then raised — the request did reach the wire. + assert dispatcher.calls[-1][0] == "tools/call" + + assert isinstance(exc_info.value.result, _TaskResult) + assert exc_info.value.result.task_id == "t-1" + assert str(exc_info.value) == snapshot( + "Server returned a claimed result (_TaskResult); pass the owning extension to " + "Client(extensions=[...]) for transparent resolution, or call with allow_claimed=True " + "and handle the shape. The carried result may reference server-side state needing cleanup." + ) + + +@pytest.mark.anyio +async def test_call_tool_result_path_identical_under_both_allow_claimed_values() -> None: + """SDK-defined: `allow_claimed` only affects claimed shapes — an ordinary + CallToolResult comes back identical with the flag on or off.""" + dispatcher = _RecordingDispatcher() + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + r_default = await session.call_tool("t", {}) + r_opted = await session.call_tool("t", {}, allow_claimed=True) + + assert isinstance(r_opted, CallToolResult) + assert r_opted == r_default + + +@pytest.mark.anyio +async def test_call_tool_overload_matrix_narrows_statically() -> None: + """SDK-defined: the allow_input_required x allow_claimed overload matrix — each + combination narrows to its documented return union (assert_type is checked by + pyright; the canned CallToolResult satisfies every combination at runtime).""" + dispatcher = _RecordingDispatcher() + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + r1 = await session.call_tool("t", {}) + assert_type(r1, CallToolResult) + r2 = await session.call_tool("t", {}, allow_input_required=True) + assert_type(r2, CallToolResult | InputRequiredResult) + r3 = await session.call_tool("t", {}, allow_claimed=True) + assert_type(r3, CallToolResult | Result) + r4 = await session.call_tool("t", {}, allow_input_required=True, allow_claimed=True) + assert_type(r4, CallToolResult | InputRequiredResult | Result) + + assert [type(r) for r in (r1, r2, r3, r4)] == [CallToolResult] * 4 + + +# ── The pinned dependency ──────────────────────────────────────────────────── + + +def test_claimed_raw_passes_v2026_tools_call_surface_validation() -> None: + """Pins the claim path's load-bearing dependency: a tools/call raw with an unknown + resultType passes `validate_server_result` at 2026-07-28 because the v2026 surface + InputRequiredResult keeps resultType open with optional fields. If mcp-types ever + tightens that surface, claimed results would be rejected before the session's + claim adapter — this failure is the signal, not a silent break.""" + validate_server_result("tools/call", LATEST_MODERN_VERSION, {"resultType": "task", "taskId": "t-1"}) diff --git a/tests/client/test_session_notification_bindings.py b/tests/client/test_session_notification_bindings.py new file mode 100644 index 000000000..900817be4 --- /dev/null +++ b/tests/client/test_session_notification_bindings.py @@ -0,0 +1,310 @@ +"""`ClientSession` notification bindings: per-binding serialized delivery through a +bounded FIFO, spawn-decoupled from the dispatcher so handlers may do session I/O +without deadlocking the in-process (DirectDispatcher) path. + +Bindings are consulted only for methods the negotiated version's core tables do +NOT know; a binding for a core-known method goes quiet, warned once at adopt(). +""" + +import logging + +import anyio +import mcp_types as types +import pytest +from mcp_types import EmptyResult, Implementation, ServerCapabilities +from mcp_types.version import LATEST_MODERN_VERSION +from pydantic import BaseModel + +from mcp.client.extension import NotificationBinding +from mcp.client.session import _NOTIFICATION_QUEUE_SIZE, ClientSession +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.transport_context import TransportContext + +_VENDOR_METHOD = "notifications/vendor/task_done" + + +class _EventParams(BaseModel): + seq: int + + +async def _server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None +) -> dict[str, object]: + assert method == "ping" + return {} + + +async def _server_on_notify( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None +) -> None: + raise NotImplementedError + + +def _adopt_modern(session: ClientSession) -> None: + session.adopt( + types.DiscoverResult( + supported_versions=[LATEST_MODERN_VERSION], + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ) + ) + + +async def _noop_handler(params: _EventParams) -> None: + raise NotImplementedError # construction-only tests never deliver + + +def test_duplicate_binding_method_rejected() -> None: + """SDK-defined: two bindings on one wire method could not be routed apart, so + construction fails.""" + client_side, _ = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=_noop_handler) + + with pytest.raises(ValueError) as exc_info: + ClientSession(dispatcher=client_side, notification_bindings=[binding, binding]) + + assert str(exc_info.value) == "duplicate notification binding for method 'notifications/vendor/task_done'" + + +@pytest.mark.anyio +async def test_bound_vendor_notifications_are_delivered_in_order() -> None: + """SDK-defined: one consumer per binding serializes delivery — events arrive at the + handler in the order the server sent them.""" + delivered: list[int] = [] + done = anyio.Event() + + async def on_event(params: _EventParams) -> None: + delivered.append(params.seq) + if params.seq == 3: + done.set() + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=on_event) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + for seq in (1, 2, 3): + await server_side.notify(_VENDOR_METHOD, {"seq": seq}) + await done.wait() + server_side.close() + + assert delivered == [1, 2, 3] + + +@pytest.mark.anyio +async def test_binding_handler_may_do_session_io_without_deadlock() -> None: + """SDK-defined: delivery is spawn-decoupled from the dispatcher, so a handler that + awaits session I/O completes even on the in-process path, where the peer's + notify() awaits `_on_notify` inline.""" + pongs: list[EmptyResult] = [] + done = anyio.Event() + + client_side, server_side = create_direct_dispatcher_pair() + + async def on_event(params: _EventParams) -> None: + pongs.append(await session.send_ping()) + done.set() + + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=on_event) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + await server_side.notify(_VENDOR_METHOD, {"seq": 1}) + await done.wait() + server_side.close() + + assert pongs == [EmptyResult()] + + +@pytest.mark.anyio +async def test_overflow_drops_oldest_event_with_a_warning(caplog: pytest.LogCaptureFixture) -> None: + """SDK-defined: the per-binding FIFO is bounded; on overflow the OLDEST queued + event is dropped with a warning and the new event is enqueued (observation + semantics tolerate the loss; enqueueing never blocks the dispatcher). + + Steps: + 1. Deliver event 0 and block the consumer inside its handler. + 2. Fill the queue with events 1.._NOTIFICATION_QUEUE_SIZE. + 3. One more event overflows: event 1 is evicted, with a warning. + 4. Release the consumer; everything still queued is delivered in order. + """ + delivered: list[int] = [] + consumer_blocked = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + last_seq = _NOTIFICATION_QUEUE_SIZE + 1 + + async def on_event(params: _EventParams) -> None: + delivered.append(params.seq) + if params.seq == 0: + consumer_blocked.set() + await gate.wait() + if params.seq == last_seq: + done.set() + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=on_event) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + await server_side.notify(_VENDOR_METHOD, {"seq": 0}) + await consumer_blocked.wait() + for seq in range(1, last_seq + 1): + await server_side.notify(_VENDOR_METHOD, {"seq": seq}) + gate.set() + await done.wait() + server_side.close() + + assert delivered == [0, *range(2, last_seq + 1)] + assert caplog.text.count(f"notification queue for {_VENDOR_METHOD!r} is full") == 1 + + +@pytest.mark.anyio +async def test_invalid_params_are_warned_and_dropped_without_reaching_handler( + caplog: pytest.LogCaptureFixture, +) -> None: + """SDK-defined: params failing the binding's model are warned and dropped — + mirroring the core notification ValidationError arm — and the handler never runs + for them; later valid events still deliver.""" + delivered: list[int] = [] + done = anyio.Event() + + async def on_event(params: _EventParams) -> None: + delivered.append(params.seq) + done.set() + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=on_event) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + await server_side.notify(_VENDOR_METHOD, {"bogus": "no seq"}) + await server_side.notify(_VENDOR_METHOD, {"seq": 1}) + await done.wait() + server_side.close() + + assert delivered == [1] + assert f"Failed to validate notification: {_VENDOR_METHOD}" in caplog.text + + +@pytest.mark.anyio +async def test_unbound_vendor_notification_keeps_the_debug_drop(caplog: pytest.LogCaptureFixture) -> None: + """SDK-defined: a vendor method with no binding keeps today's behaviour — a debug + log and a silent drop.""" + caplog.set_level(logging.DEBUG, logger="client") + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=_noop_handler) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + await server_side.notify("notifications/vendor/unbound", {"seq": 1}) + server_side.close() + + assert f"dropped 'notifications/vendor/unbound': not defined at {LATEST_MODERN_VERSION}" in caplog.text + + +@pytest.mark.anyio +async def test_core_known_method_never_reaches_binding_and_warns_once_at_adopt( + caplog: pytest.LogCaptureFixture, +) -> None: + """SDK-defined: bindings are consulted only for methods core does not know at the + negotiated version — a binding for `notifications/message` goes quiet (the typed + logging callback still runs), warned exactly once at adopt().""" + logged: list[types.LoggingMessageNotificationParams] = [] + + async def logging_callback(params: types.LoggingMessageNotificationParams) -> None: + logged.append(params) + + async def on_message(params: BaseModel) -> None: + raise NotImplementedError # structurally unreachable: core parses the method first + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method="notifications/message", params_type=BaseModel, handler=on_message) + session = ClientSession(dispatcher=client_side, logging_callback=logging_callback, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + # The in-process peer awaits _on_notify inline, so the typed callback ran + # by the time notify() returns. + await server_side.notify("notifications/message", {"level": "info", "data": "hello"}) + server_side.close() + + assert [params.data for params in logged] == ["hello"] + # The bound handler never ran — a delivery would have logged its NotImplementedError. + assert "notification binding handler" not in caplog.text + expected = f"notification binding for 'notifications/message' will never fire at {LATEST_MODERN_VERSION}" + assert caplog.text.count(expected) == 1 + + +@pytest.mark.anyio +async def test_handler_exception_is_contained_and_later_events_deliver(caplog: pytest.LogCaptureFixture) -> None: + """SDK-defined: a raising handler costs only that delivery — the consumer logs the + exception and keeps serving subsequent events.""" + delivered: list[int] = [] + done = anyio.Event() + + async def on_event(params: _EventParams) -> None: + if params.seq == 1: + raise ValueError("handler boom") + delivered.append(params.seq) + done.set() + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=on_event) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + _adopt_modern(session) + await server_side.notify(_VENDOR_METHOD, {"seq": 1}) + await server_side.notify(_VENDOR_METHOD, {"seq": 2}) + await done.wait() + server_side.close() + + assert delivered == [2] + assert f"notification binding handler for {_VENDOR_METHOD!r} raised" in caplog.text + + +@pytest.mark.anyio +async def test_binding_delivery_works_without_adopt() -> None: + """SDK-defined: bindings need no negotiated version — pre-handshake sessions fall + back to the default version tables, where a vendor method is just as unknown.""" + delivered: list[int] = [] + done = anyio.Event() + + async def on_event(params: _EventParams) -> None: + delivered.append(params.seq) + done.set() + + client_side, server_side = create_direct_dispatcher_pair() + binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=on_event) + session = ClientSession(dispatcher=client_side, notification_bindings=[binding]) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, _server_on_request, _server_on_notify) + async with session: + await server_side.notify(_VENDOR_METHOD, {"seq": 7}) + await done.wait() + server_side.close() + + assert delivered == [7] From 414caeff15dc63c81be84ee6bf9d06e92c1363c0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 16:04:40 +0000 Subject: [PATCH 06/15] Harden claim and binding construction invariants Reject an empty claim sequence at the session constructor: the ad-filter would otherwise treat the identifier as claim-bearing and silently drop it from the capability ad at every version, leaving the invariant to caller discipline. Validate ResultClaim.method against the closed verb set at construction so an unchecked runtime value cannot fold into tools/call parsing. Create notification binding queues before the dispatcher starts so the enqueue path indexes a complete dict by construction rather than by scheduling order. Copy the extensions ad dict at the constructor boundary. Pins added: modern re-adoption after legacy reactivates claims; a legacy-version discover probe drops claim-bearing identifiers from its ad. --- src/mcp/client/extension.py | 7 ++++- src/mcp/client/session.py | 13 ++++++-- tests/client/test_extension.py | 11 ++++++- tests/client/test_session_claims.py | 48 +++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/extension.py b/src/mcp/client/extension.py index 21fcd28ed..d7cdb2522 100644 --- a/src/mcp/client/extension.py +++ b/src/mcp/client/extension.py @@ -12,7 +12,7 @@ from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args +from typing import TYPE_CHECKING, Any, Final, Generic, Literal, TypeVar, get_args from mcp_types import CORE_RESULT_TYPES, CallToolResult, InputRequiredResult, Result from mcp_types.version import MODERN_PROTOCOL_VERSIONS @@ -32,6 +32,9 @@ "advertise", ] +_CLAIM_METHODS: Final[frozenset[str]] = frozenset({"tools/call"}) +"""The closed set of verbs a claim may attach to (widened with the `method` Literal).""" + ClaimedT = TypeVar("ClaimedT", bound=Result) NotifyParamsT = TypeVar("NotifyParamsT", bound=BaseModel) @@ -82,6 +85,8 @@ class ResultClaim(Generic[ClaimedT]): protocol_versions: frozenset[str] | None = None def __post_init__(self) -> None: + if self.method not in _CLAIM_METHODS: + raise ValueError(f"claims attach to {sorted(_CLAIM_METHODS)} only; got method {self.method!r}") if self.result_type in CORE_RESULT_TYPES: raise ValueError(f"resultType {self.result_type!r} is core protocol vocabulary") if issubclass(self.model, CallToolResult | InputRequiredResult): diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b2e551ba0..948ac2e0c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -277,6 +277,11 @@ def _index_claims( f"result_claims key {identifier!r} has no extensions entry; a claim is only " "advertised through its extension's capability ad" ) + if not claims: + raise ValueError( + f"result_claims[{identifier!r}] is empty; an empty claim set would drop the " + "extension from the capability ad at every version — omit the key instead" + ) for claim in claims: key = (claim.method, claim.result_type) if key in seen: @@ -344,7 +349,7 @@ def __init__( self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities - self._extensions = extensions + self._extensions = dict(extensions) if extensions is not None else None self._result_claims = _index_claims(result_claims, extensions) self._notification_bindings = _index_bindings(notification_bindings) self._active_claims: dict[str, ResultClaim[Any]] = {} @@ -387,10 +392,14 @@ async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() try: - await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + # Queues exist before the dispatcher can deliver: _on_notify may run as + # soon as the dispatcher starts, and its enqueue indexes this dict. for binding in self._notification_bindings.values(): send, receive = anyio.create_memory_object_stream[BaseModel](_NOTIFICATION_QUEUE_SIZE) self._binding_queues[binding.method] = (send, receive) + await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + for binding in self._notification_bindings.values(): + _, receive = self._binding_queues[binding.method] self._task_group.start_soon(self._deliver_bound_notifications, binding, receive) except BaseException: # Unwind the entered task group before propagating: a cancellation diff --git a/tests/client/test_extension.py b/tests/client/test_extension.py index 8201bd3c8..927399b43 100644 --- a/tests/client/test_extension.py +++ b/tests/client/test_extension.py @@ -7,7 +7,7 @@ """ from dataclasses import FrozenInstanceError -from typing import Any, Literal +from typing import Any, Literal, cast import pytest from inline_snapshot import snapshot @@ -142,6 +142,15 @@ def test_claim_rejects_mismatched_result_type_literal() -> None: assert str(exc_info.value) == snapshot("_OtherTagResult.result_type must be Literal['task']") +def test_claim_rejects_method_outside_the_closed_verb_set() -> None: + """SDK-defined: claims attach to `tools/call` only (the Literal is the static gate); + an unchecked runtime value must not fold into tools/call parsing silently.""" + with pytest.raises(ValueError) as exc_info: + _claim(method=cast("Literal['tools/call']", "prompts/get")) + + assert str(exc_info.value) == snapshot("claims attach to ['tools/call'] only; got method 'prompts/get'") + + def test_claim_rejects_empty_protocol_versions() -> None: """SDK-defined: an empty version set could never activate; `None` is the spelling for "every modern version".""" diff --git a/tests/client/test_session_claims.py b/tests/client/test_session_claims.py index 133b826f5..403453d11 100644 --- a/tests/client/test_session_claims.py +++ b/tests/client/test_session_claims.py @@ -166,6 +166,19 @@ def test_claims_keyed_to_unadvertised_extension_rejected() -> None: ) +def test_empty_claim_sequence_rejected() -> None: + """SDK-defined: an empty claim set would make the ad-filter treat the identifier as + claim-bearing and drop it from the capability ad at every version; "claim-less" and + "advertises everywhere" stay the same thing by rejecting the empty spelling.""" + with pytest.raises(ValueError) as exc_info: + ClientSession(dispatcher=_RecordingDispatcher(), extensions={_TASKS_EXT: {}}, result_claims={_TASKS_EXT: []}) + + assert str(exc_info.value) == snapshot( + "result_claims['com.example/tasks'] is empty; an empty claim set would drop the " + "extension from the capability ad at every version — omit the key instead" + ) + + def test_empty_settings_count_as_an_advertised_extension() -> None: """SDK-defined: an extension advertised with empty settings ({}) is still an ad — claims keyed to it construct fine.""" @@ -228,6 +241,25 @@ async def test_legacy_adopt_clears_active_claims() -> None: assert dispatcher.calls[-1][0] == "tools/call" +@pytest.mark.anyio +async def test_modern_readopt_after_legacy_reactivates_claims() -> None: + """SDK-defined: adoption is re-entrant in both directions — after modern→legacy→ + modern the claims are active again and the adopt-built adapter routes claimed raws.""" + dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + _adopt_handshake(session) + assert session._call_tool_adapter is _CallToolResultAdapter + + _adopt_modern(session) + result = await session.call_tool("t", {}, allow_claimed=True) + + assert isinstance(result, _TaskResult) + assert session._call_tool_adapter is not _CallToolResultAdapter + + # ── The version-aware capability ad ───────────────────────────────────────── @@ -303,6 +335,22 @@ async def test_discover_probe_ad_includes_claim_identifiers_at_the_probe_version assert capabilities["extensions"] == {_TASKS_EXT: {}} +@pytest.mark.anyio +async def test_discover_probe_ad_drops_claim_identifiers_at_a_legacy_probe_version() -> None: + """SDK-defined: a lowlevel `send_discover` at a non-modern version string builds an + ad where no claim can be active, so the claim-bearing identifier drops coherently.""" + dispatcher = _RecordingDispatcher() + session = _claims_session(dispatcher, _task_claim()) + with anyio.fail_after(5): + async with session: + await session.send_discover(LATEST_HANDSHAKE_VERSION) + + [(_, params, _)] = dispatcher.calls + assert params is not None + capabilities = params["_meta"][CLIENT_CAPABILITIES_META_KEY] + assert "extensions" not in capabilities + + # ── Routing through the adopt-built adapter ───────────────────────────────── From 0f85f31ea8cb89b6ef7d4b827ac357081e5b24e4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 16:13:59 +0000 Subject: [PATCH 07/15] Fold ClientExtension instances into Client Client.extensions becomes Sequence[ClientExtension] | None: instances are validated and read once at construction (identifier guard and grammar, duplicate identifiers, cross-extension claim and binding conflicts named by their owning extensions) and folded into the session as the capability ad, the claims-by-identifier mapping, and the binding list. call_tool resolves claimed shapes transparently: the retry closure allows claimed results through the multi-round-trip driver, the owning claim's resolver finishes the call, and non-error resolver products get the same output-schema revalidation as the direct path. The dict form of extensions= is replaced; advertise() covers ad-only uses, documented in the migration guide. No extensions means a session constructed byte-identically to before. --- docs/migration.md | 41 +- docs_src/apps/tutorial001.py | 3 +- docs_src/extensions/tutorial004.py | 3 +- examples/stories/apps/README.md | 2 +- examples/stories/apps/client.py | 6 +- examples/stories/extensions/README.md | 2 +- examples/stories/extensions/client.py | 4 +- src/mcp/client/client.py | 119 ++++- tests/client/test_client_extensions.py | 530 +++++++++++++++++++++ tests/docs_src/test_apps.py | 5 +- tests/docs_src/test_extensions.py | 3 +- tests/server/mcpserver/test_extension.py | 3 +- tests/server/test_apps.py | 20 +- tests/server/test_extensions_capability.py | 5 +- 14 files changed, 712 insertions(+), 34 deletions(-) create mode 100644 tests/client/test_client_extensions.py diff --git a/docs/migration.md b/docs/migration.md index 047626ee2..c377d6742 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -469,11 +469,42 @@ extension handler can call `mcp.server.mcpserver.require_client_extension(ctx, i to reject a request with the `-32021` (missing required client capability) error when the client did not declare the extension. -Clients advertise extension support with the new `Client(extensions=...)` / -`ClientSession(extensions=...)` argument, mirrored into `ClientCapabilities.extensions`. -The extensions capability map is negotiated over `server/discover` (modern path); -a legacy `initialize` handshake does not carry it. Extensions are off by default -and never alter behaviour unless registered. +On the client, `Client(extensions=...)` takes a sequence of +`mcp.client.ClientExtension` instances. A client extension contributes its +capability ad (mirrored into `ClientCapabilities.extensions`), its result +claims (extra `tools/call` result shapes that `Client.call_tool` resolves +transparently through the claim's resolver), and its notification bindings +(handlers for vendor server notifications). The capability map rides +`server/discover` and every modern request's `_meta` envelope; a legacy +`initialize` handshake carries only the claim-less identifiers, since claimed +result shapes cannot be delivered on a legacy wire. Extensions are off by +default and never alter behaviour unless registered. (The low-level +`ClientSession(extensions=...)` keeps the raw identifier-to-settings dict.) + +Changed in the v2 pre-releases: earlier alphas took +`Client(extensions={identifier: settings})`, an advertisement-only dict. +Extensions now contribute behaviour — claims and notification handlers — not +just an ad, and a sequence of declaration objects is the shape that can carry +that. An ad-only entry becomes an `advertise()` call: + +**Before (v2 alphas):** + +```python +client = Client(server, extensions={"com.example/ui": {"mimeTypes": [...]}}) +``` + +**After:** + +```python +from mcp.client import advertise + +client = Client(server, extensions=[advertise("com.example/ui", {"mimeTypes": [...]})]) +``` + +`advertise()` is only for identifiers with no client-side behaviour. +For a behavioural extension — e.g. tasks, once its extension ships — construct +that extension's object instead; advertising an identifier you do not +implement asserts wire support you don't have. ### `McpError` renamed to `MCPError` diff --git a/docs_src/apps/tutorial001.py b/docs_src/apps/tutorial001.py index 79721c597..27d9eded8 100644 --- a/docs_src/apps/tutorial001.py +++ b/docs_src/apps/tutorial001.py @@ -1,4 +1,5 @@ from mcp import Client +from mcp.client import advertise from mcp.server.apps import APP_MIME_TYPE, EXTENSION_ID, Apps, client_supports_apps from mcp.server.mcpserver import MCPServer from mcp.server.mcpserver.context import Context @@ -32,7 +33,7 @@ def get_time(ctx: Context) -> str: async def main() -> None: - async with Client(mcp, extensions={EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}) as client: + async with Client(mcp, extensions=[advertise(EXTENSION_ID, {"mimeTypes": [APP_MIME_TYPE]})]) as client: result = await client.call_tool("get_time", {}) print(result.content) # [TextContent(text='2026-06-26T12:00:00Z')] diff --git a/docs_src/extensions/tutorial004.py b/docs_src/extensions/tutorial004.py index b08840705..7ad32052d 100644 --- a/docs_src/extensions/tutorial004.py +++ b/docs_src/extensions/tutorial004.py @@ -5,6 +5,7 @@ from pydantic import Field from mcp import Client +from mcp.client import advertise from mcp.server.context import ServerRequestContext from mcp.server.extension import Extension, MethodBinding from mcp.server.mcpserver import MCPServer, require_client_extension @@ -51,7 +52,7 @@ def methods(self) -> Sequence[MethodBinding]: async def main() -> None: - async with Client(mcp, extensions={EXTENSION_ID: {}}) as client: + async with Client(mcp, extensions=[advertise(EXTENSION_ID)]) as client: request = SearchRequest(params=SearchParams(query="mcp", limit=3)) result = await client.session.send_request(request, SearchResult) print(result.items) diff --git a/examples/stories/apps/README.md b/examples/stories/apps/README.md index dc180a0d3..40414737e 100644 --- a/examples/stories/apps/README.md +++ b/examples/stories/apps/README.md @@ -26,7 +26,7 @@ uv run python -m stories.apps.client --http `text/html;profile=mcp-app`. - `server.py` `client_supports_apps(ctx)` — SEP-2133 graceful degradation: a client that did not negotiate Apps gets a text-only result. -- `client.py` `Client(target, extensions={...})` — the client advertises Apps +- `client.py` `Client(target, extensions=[advertise(...)])` — the client advertises Apps support so the server returns the UI-enabled result, then reads the tool's `_meta.ui.resourceUri` and fetches that resource. diff --git a/examples/stories/apps/client.py b/examples/stories/apps/client.py index 8a238f469..dd79071b1 100644 --- a/examples/stories/apps/client.py +++ b/examples/stories/apps/client.py @@ -2,7 +2,7 @@ from mcp_types import TextContent, TextResourceContents -from mcp.client import Client +from mcp.client import Client, advertise from mcp.server.apps import APP_MIME_TYPE, EXTENSION_ID from stories._harness import Target, run_client @@ -10,7 +10,9 @@ async def main(target: Target, *, mode: str = "auto") -> None: # Advertise MCP Apps support so the server returns the UI-enabled result; a # client that omits this gets the text-only fallback (graceful degradation). - async with Client(target, mode=mode, extensions={EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}) as client: + async with Client( + target, mode=mode, extensions=[advertise(EXTENSION_ID, {"mimeTypes": [APP_MIME_TYPE]})] + ) as client: # The extensions capability map rides `server/discover` (modern only). On a # legacy connection (today's stdio) it is absent, so assert it only when present. if client.server_capabilities.extensions is not None: diff --git a/examples/stories/extensions/README.md b/examples/stories/extensions/README.md index 6d3da72c9..4668f990d 100644 --- a/examples/stories/extensions/README.md +++ b/examples/stories/extensions/README.md @@ -24,7 +24,7 @@ uv run python -m stories.extensions.client --http rejects clients that did not declare the extension with `-32021` (missing required client capability) and a machine-readable `requiredCapabilities` payload. -- `client.py` `Client(target, extensions={EXTENSION_ID: {}})` — the client-side +- `client.py` `Client(target, extensions=[advertise(EXTENSION_ID)])` — the client-side half of the negotiation; on 2026-07-28 it travels in the per-request `_meta` envelope. - `client.py` `client.session.send_request(...)` — vendor methods have no diff --git a/examples/stories/extensions/client.py b/examples/stories/extensions/client.py index 849586f6f..0bb033d7a 100644 --- a/examples/stories/extensions/client.py +++ b/examples/stories/extensions/client.py @@ -5,7 +5,7 @@ import mcp_types as types from mcp_types import TextContent -from mcp.client import Client +from mcp.client import Client, advertise from stories._harness import Target, run_client EXTENSION_ID = "com.example/catalog" @@ -28,7 +28,7 @@ class SearchResult(types.Result): async def main(target: Target, *, mode: str = "auto") -> None: # Declare the extension client-side so the server's `require_client_extension` # gate on `com.example/search` passes. - async with Client(target, mode=mode, extensions={EXTENSION_ID: {}}) as client: + async with Client(target, mode=mode, extensions=[advertise(EXTENSION_ID)]) as client: # The extensions capability map rides `server/discover` (modern only). On a # legacy connection it is absent, so assert it only when present. if client.server_capabilities.extensions is not None: diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 2b0cabe61..7b02d3466 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -5,7 +5,7 @@ import hashlib import logging import uuid -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field from typing import Any, Literal, TypeVar, cast @@ -36,6 +36,7 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, + Result, ServerCapabilities, ) from mcp_types.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS @@ -46,6 +47,7 @@ from mcp.client._probe import negotiate_auto from mcp.client._transport import Transport from mcp.client.caching import CacheConfig, CacheMode, ClientResponseCache, InMemoryResponseCacheStore +from mcp.client.extension import ClaimContext, ClientExtension, NotificationBinding, ResultClaim from mcp.client.session import ( ClientRequestContext, ClientSession, @@ -62,6 +64,7 @@ from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import Dispatcher, ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning, MCPError +from mcp.shared.extension import validate_extension_identifier from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.session import RequestResponder @@ -188,6 +191,70 @@ async def _no_inbound_client_notifications(_dctx: Any, _method: str, _params: Ma """ +@dataclass(frozen=True) +class _FoldedExtensions: + """`Client.extensions` instances folded into the shapes `ClientSession` consumes.""" + + ad: dict[str, dict[str, Any]] | None + claims: dict[str, tuple[ResultClaim[Any], ...]] | None + bindings: tuple[NotificationBinding[Any], ...] | None + by_model: Mapping[type[Result], ResultClaim[Any]] + + +def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExtensions: + """Validate extension instances and fold their contributions, once, at `Client` construction. + + Mirrors the server's consumption-time posture (`MCPServer._apply_extension`): a + per-instance identifier is validated here because no class attribute existed to + validate at definition time. `settings()` is read exactly once per extension and + the returned dict is held by reference. Duplicate `(method, resultType)` claims + and duplicate notification methods are rejected here, where both owning extensions + can be named — the session's own duplicate checks know only methods and tags. + """ + if not extensions: + return _FoldedExtensions(ad=None, claims=None, bindings=None, by_model={}) + ad: dict[str, dict[str, Any]] = {} + claims: dict[str, tuple[ResultClaim[Any], ...]] = {} + bindings: list[NotificationBinding[Any]] = [] + by_model: dict[type[Result], ResultClaim[Any]] = {} + claim_owners: dict[tuple[str, str], str] = {} + binding_owners: dict[str, str] = {} + for extension in extensions: + identifier = getattr(extension, "identifier", None) + if identifier is None: + raise ValueError( + f"{type(extension).__name__} has no `identifier`; a ClientExtension must set the " + "`identifier` class attribute (or assign one in `__init__`) before it can be used" + ) + validate_extension_identifier(identifier, owner=type(extension).__name__) + if identifier in ad: + raise ValueError(f"extension identifier {identifier!r} is passed more than once") + ad[identifier] = extension.settings() + extension_claims = tuple(extension.claims()) + for claim in extension_claims: + key = (claim.method, claim.result_type) + if key in claim_owners: + raise ValueError( + f"extensions {claim_owners[key]!r} and {identifier!r} both claim {claim.method!r} " + f"resultType {claim.result_type!r}; a wire tag can have only one resolver" + ) + claim_owners[key] = identifier + # Collision-free by construction: a model's `result_type` Literal pins it to + # exactly one tag, and each (method, tag) pair has exactly one owner. + by_model[claim.model] = claim + if extension_claims: + claims[identifier] = extension_claims + for binding in extension.notifications(): + if binding.method in binding_owners: + raise ValueError( + f"extensions {binding_owners[binding.method]!r} and {identifier!r} both bind " + f"notification method {binding.method!r}; a method can have only one observer" + ) + binding_owners[binding.method] = identifier + bindings.append(binding) + return _FoldedExtensions(ad=ad, claims=claims or None, bindings=tuple(bindings) or None, by_model=by_model) + + @dataclass class Client: """A high-level MCP client for connecting to MCP servers. @@ -268,9 +335,16 @@ async def main(): `read_resource` give up. Use `client.session.(..., allow_input_required=True)` to drive the loop manually instead.""" - extensions: dict[str, dict[str, Any]] | None = None - """SEP-2133 extension support to advertise under `ClientCapabilities.extensions` - (identifier -> settings), e.g. `{"io.modelcontextprotocol/ui": {"mimeTypes": [...]}}`.""" + extensions: Sequence[ClientExtension] | None = None + """Opt-in client extensions (SEP-2133). + + Each instance contributes its capability ad (advertised under + `ClientCapabilities.extensions`), its result claims (extra `tools/call` result + shapes that `call_tool` resolves transparently through the claim's resolver), + and its notification bindings. For an ad-only entry — an identifier plus + settings, no client-side behaviour — use `mcp.client.advertise(identifier, + settings)`. Each extension's `settings()` is read once, at construction; the + returned dict is held by reference.""" cache: CacheConfig | Literal[False] | None = None """Client-side response caching for the SEP-2549 cacheable methods (2026-07-28). @@ -286,6 +360,7 @@ async def main(): _exit_stack: AsyncExitStack | None = field(init=False, default=None) _connect: _Connector = field(init=False, repr=False, compare=False) _response_cache: ClientResponseCache | None = field(init=False, default=None, repr=False, compare=False) + _folded_extensions: _FoldedExtensions = field(init=False, repr=False, compare=False) def __post_init__(self) -> None: if self.mode not in ("legacy", "auto") and self.mode not in MODERN_PROTOCOL_VERSIONS: @@ -298,6 +373,8 @@ def __post_init__(self) -> None: f"mode must be 'legacy', 'auto', or one of {list(MODERN_PROTOCOL_VERSIONS)}; got {self.mode!r}{hint}" ) + self._folded_extensions = _fold_extensions(self.extensions) + srv = self.server if isinstance(srv, MCPServer): srv = srv._lowlevel_server # pyright: ignore[reportPrivateUsage] @@ -348,7 +425,9 @@ async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: message_handler=message_handler, client_info=self.client_info, elicitation_callback=self.elicitation_callback, - extensions=self.extensions, + extensions=self._folded_extensions.ad, + result_claims=self._folded_extensions.claims, + notification_bindings=self._folded_extensions.bindings, ) async def __aenter__(self) -> Client: @@ -611,6 +690,13 @@ async def call_tool( persist `request_state` across process restarts — use `client.session.call_tool(..., allow_input_required=True)`. + If the server returns a result shape claimed by one of this client's + `extensions`, the owning claim's resolver finishes the call and its + `CallToolResult` is returned — the claimed shape never surfaces here. + Resolver exceptions propagate as-is; the extension owns its error + vocabulary. To receive the claimed shape yourself, use + `client.session.call_tool(..., allow_claimed=True)`. + Args: name: The name of the tool to call. arguments: Arguments to pass to the tool. @@ -629,7 +715,7 @@ async def call_tool( MCPError: A callback returned `ErrorData` for an embedded input request. """ - async def retry(r: InputResponses | None, s: str | None) -> CallToolResult | InputRequiredResult: + async def retry(r: InputResponses | None, s: str | None) -> CallToolResult | InputRequiredResult | Result: return await self.session.call_tool( name, arguments, @@ -639,9 +725,28 @@ async def retry(r: InputResponses | None, s: str | None) -> CallToolResult | Inp request_state=s, meta=meta, allow_input_required=True, + # The driver's retry leg must also admit claimed shapes — the spec + # resolves multi-round-trip input before a claimed result, so a claim + # may terminate any round, not just the first. + allow_claimed=True, ) - return await self._drive_input_required(await retry(input_responses, request_state), retry) + result = await self._drive_input_required(await retry(input_responses, request_state), retry) + if isinstance(result, CallToolResult): + return result + # Only claimed shapes escape the parse (`_drive_input_required` never returns an + # `InputRequiredResult`), so the lookup is total; a KeyError here is an SDK bug. + claim = self._folded_extensions.by_model[type(result)] + final = await claim.resolve( + result, + ClaimContext(session=self.session, tool_name=name, read_timeout_seconds=read_timeout_seconds), + ) + if not final.is_error: + # The resolver's product gets the same output-schema revalidation as the + # direct path (`ClientSession.call_tool`'s own guard); isError results + # must not raise, also matching the direct path. + await self.session.validate_tool_result(name, final) + return final async def list_prompts( self, diff --git a/tests/client/test_client_extensions.py b/tests/client/test_client_extensions.py new file mode 100644 index 000000000..a67c2eec4 --- /dev/null +++ b/tests/client/test_client_extensions.py @@ -0,0 +1,530 @@ +"""`Client` + `ClientExtension` integration: folding extension declarations into the +session at construction, and `call_tool` driving claim resolvers transparently. + +Claimed-shape servers here are real `MCPServer`s whose SEP-2133 server extension +rewrites `tools/call` results via `intercept_tool_call` — the full public-API loop. +The in-process server can only deliver claimed fields the v2026 tools/call surface +keeps (`resultType`, `requestState`, `inputRequests`, `_meta`): the server-side +`serialize_server_result` drops anything else, so claimed payloads here ride +`requestState`. + +`tools/call` is never cached (`Client.call_tool` has no `_cached_fetch` weave and the +SEP-2549 cacheable verbs do not include it), so the claim path needs no cache tests. +""" + +import logging +from collections.abc import Awaitable, Callable, Sequence +from typing import Any, Literal + +import anyio +import mcp_types as types +import pytest +from inline_snapshot import snapshot +from mcp_types import CallToolResult, Result, TextContent +from mcp_types.version import LATEST_MODERN_VERSION +from pydantic import BaseModel +from typing_extensions import assert_type + +from mcp.client import ClaimContext, ClientExtension, NotificationBinding, ResultClaim, advertise +from mcp.client.client import Client +from mcp.client.session import ClientRequestContext, _CallToolResultAdapter +from mcp.server import Server, ServerRequestContext +from mcp.server.context import CallNext, HandlerResult +from mcp.server.extension import Extension +from mcp.server.mcpserver import Context, MCPServer + +pytestmark = pytest.mark.anyio + +_VOUCHER_EXT = "com.example/voucher" +_RIVAL_EXT = "com.example/rival" + +_NAME_SCHEMA = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} + + +def _name_elicitation() -> types.ElicitRequest: + return types.ElicitRequest( + params=types.ElicitRequestFormParams(message="What is your name?", requested_schema=_NAME_SCHEMA) + ) + + +class VoucherResult(Result): + """The claimed `tools/call` shape, tagged `voucher`; its payload rides `requestState` + (the only open payload-bearing field the in-process server's surface dump keeps).""" + + result_type: Literal["voucher"] = "voucher" + request_state: str | None = None + + +_Resolver = Callable[[VoucherResult, ClaimContext], Awaitable[CallToolResult]] + + +class _VoucherExtension(ClientExtension): + """Client half: claims the `voucher` tag with the supplied resolver.""" + + identifier = _VOUCHER_EXT + + def __init__(self, resolve: _Resolver) -> None: + self._resolve = resolve + + def claims(self) -> Sequence[ResultClaim[Any]]: + return [ResultClaim(result_type="voucher", model=VoucherResult, resolve=self._resolve)] + + +class _VoucherIssuer(Extension): + """Server half: rewrites every `tools/call` result into the vendor-claimed shape.""" + + identifier = _VOUCHER_EXT + + async def intercept_tool_call( + self, params: types.CallToolRequestParams, ctx: ServerRequestContext[Any, Any], call_next: CallNext + ) -> HandlerResult: + return {"resultType": "voucher", "requestState": "v-42"} + + +class _TwoRoundVoucherIssuer(Extension): + """Server half: demands input on the first round, then issues the claimed shape.""" + + identifier = _VOUCHER_EXT + + async def intercept_tool_call( + self, params: types.CallToolRequestParams, ctx: ServerRequestContext[Any, Any], call_next: CallNext + ) -> HandlerResult: + if params.input_responses is None: + return types.InputRequiredResult(input_requests={"user_name": _name_elicitation()}) + return {"resultType": "voucher", "requestState": "after-input"} + + +def _voucher_server(issuer: Extension | None = None) -> MCPServer: + """An `MCPServer` whose `issue` tool the server extension rewrites into the claimed shape.""" + server = MCPServer("vouchers", extensions=[issuer if issuer is not None else _VoucherIssuer()]) + + @server.tool() + def issue() -> CallToolResult: + """Issue a voucher.""" + raise NotImplementedError # the server extension short-circuits before the tool runs + + return server + + +def _structured_voucher_server() -> MCPServer: + """Like `_voucher_server`, but `issue` declares an output schema (`-> str`).""" + server = MCPServer("vouchers", extensions=[_VoucherIssuer()]) + + @server.tool() + def issue() -> str: + """Issue a voucher.""" + raise NotImplementedError # the server extension short-circuits before the tool runs + + return server + + +def _add_server() -> MCPServer: + """A plain claim-less server with one ordinary tool.""" + server = MCPServer("plain") + + @server.tool() + def add(a: int, b: int) -> int: + """Add two integers.""" + return a + b + + return server + + +# ── Construction-time validation ──────────────────────────────────────────── + + +def test_bare_extension_instance_is_rejected_with_the_fix_named() -> None: + """SDK-defined: an instance whose class never set `identifier` fails Client + construction with an error naming the type and the fix — not an AttributeError.""" + with pytest.raises(ValueError) as exc_info: + Client(_add_server(), extensions=[ClientExtension()]) + + assert str(exc_info.value) == snapshot( + "ClientExtension has no `identifier`; a ClientExtension must set the `identifier` " + "class attribute (or assign one in `__init__`) before it can be used" + ) + + +class _SelfAssignedBadId(ClientExtension): + """Assigns a malformed identifier in `__init__` — invisible at class definition.""" + + def __init__(self) -> None: + self.identifier = "not-prefixed" + + +def test_invalid_per_instance_identifier_raises_the_validators_error() -> None: + """SDK-defined: per-instance identifiers are validated when the Client consumes the + extension (no class attribute existed at definition time, mirroring the server's + posture); the shared validator's TypeError surfaces unwrapped.""" + with pytest.raises(TypeError) as exc_info: + Client(_add_server(), extensions=[_SelfAssignedBadId()]) + + assert str(exc_info.value) == snapshot( + "_SelfAssignedBadId.identifier must be a `vendor-prefix/name` string " + "(reverse-DNS prefix required), got 'not-prefixed'" + ) + + +def test_duplicate_extension_identifiers_are_rejected_naming_the_identifier() -> None: + """SDK-defined: one identifier cannot appear twice — there would be two settings + dicts for one capability-ad key.""" + with pytest.raises(ValueError) as exc_info: + Client(_add_server(), extensions=[advertise(_VOUCHER_EXT), advertise(_VOUCHER_EXT, {"a": 1})]) + + assert str(exc_info.value) == snapshot("extension identifier 'com.example/voucher' is passed more than once") + + +async def _unreachable_resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # construction-only extensions never resolve + + +class _RivalVoucherExtension(ClientExtension): + """A second extension claiming the same `voucher` tag (construction-conflict tests).""" + + identifier = _RIVAL_EXT + + def claims(self) -> Sequence[ResultClaim[Any]]: + return [ResultClaim(result_type="voucher", model=VoucherResult, resolve=_unreachable_resolve)] + + +def test_conflicting_claims_across_extensions_name_both_owners() -> None: + """SDK-defined: two extensions claiming the same (method, resultType) fail at + Client construction with both owning extensions named — the session's own + duplicate check knows only the method and tag, which cannot tell a user which + two of their extensions collide.""" + with pytest.raises(ValueError) as exc_info: + Client(_add_server(), extensions=[_VoucherExtension(_unreachable_resolve), _RivalVoucherExtension()]) + + assert str(exc_info.value) == snapshot( + "extensions 'com.example/voucher' and 'com.example/rival' both claim 'tools/call' " + "resultType 'voucher'; a wire tag can have only one resolver" + ) + + +class _EventParams(BaseModel): + seq: int + + +async def _unreachable_handler(params: _EventParams) -> None: + raise NotImplementedError # construction-only extensions never deliver + + +class _ObserverA(ClientExtension): + identifier = "com.example/observer-a" + + def notifications(self) -> Sequence[NotificationBinding[Any]]: + return [ + NotificationBinding( + method="notifications/vendor/event", params_type=_EventParams, handler=_unreachable_handler + ) + ] + + +class _ObserverB(ClientExtension): + identifier = "com.example/observer-b" + + def notifications(self) -> Sequence[NotificationBinding[Any]]: + return [ + NotificationBinding( + method="notifications/vendor/event", params_type=_EventParams, handler=_unreachable_handler + ) + ] + + +def test_conflicting_notification_bindings_name_both_owners() -> None: + """SDK-defined: two extensions binding the same notification method fail at Client + construction with both owning extensions named, for the same reason as claims.""" + with pytest.raises(ValueError) as exc_info: + Client(_add_server(), extensions=[_ObserverA(), _ObserverB()]) + + assert str(exc_info.value) == snapshot( + "extensions 'com.example/observer-a' and 'com.example/observer-b' both bind " + "notification method 'notifications/vendor/event'; a method can have only one observer" + ) + + +# ── settings() consumption ─────────────────────────────────────────────────── + + +class _CountingSettings(ClientExtension): + """Counts `settings()` reads to pin the read-once contract.""" + + identifier = "com.example/counted" + + def __init__(self) -> None: + self.reads = 0 + + def settings(self) -> dict[str, Any]: + self.reads += 1 + return {"read": self.reads} + + +async def test_settings_is_read_exactly_once_at_construction() -> None: + """SDK-defined: `settings()` is read once, at Client construction — connecting and + calling tools (each modern request re-stamps the capability ad) never re-reads it.""" + extension = _CountingSettings() + client = Client(_add_server(), extensions=[extension]) + assert extension.reads == 1 + + with anyio.fail_after(5): + async with client: + await client.call_tool("add", {"a": 1, "b": 2}) + await client.call_tool("add", {"a": 3, "b": 4}) + + assert extension.reads == 1 + + +async def test_settings_dict_is_held_by_reference_not_copied() -> None: + """SDK-defined: the dict `settings()` returns is held by reference, not copied — + mutating it between construction and connect changes the advertised ad (the same + aliasing the dict-form `extensions=` argument had).""" + observed: list[dict[str, dict[str, Any]] | None] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "probe" + assert ctx.session.client_params is not None + observed.append(ctx.session.client_params.capabilities.extensions) + return CallToolResult(content=[]) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="probe", input_schema={"type": "object"})]) + + server = Server("probe", on_call_tool=call_tool, on_list_tools=list_tools) + settings = {"tier": "bronze"} + client = Client(server, extensions=[advertise("com.example/loyalty", settings)]) + settings["tier"] = "gold" + + with anyio.fail_after(5): + async with client: + await client.call_tool("probe", {}) + + assert observed == [{"com.example/loyalty": {"tier": "gold"}}] + + +# ── extensions=None stays byte-identical ───────────────────────────────────── + + +@pytest.mark.parametrize("extensions", [None, ()], ids=["none", "empty"]) +async def test_no_extensions_keeps_tools_call_parsing_byte_identical( + extensions: Sequence[ClientExtension] | None, +) -> None: + """SDK-defined: `extensions=None` (and an empty sequence) leave the session exactly + as a claim-less client's — the tools/call adapter is the module-level constant by + identity, and an ordinary call behaves as before.""" + with anyio.fail_after(5): + async with Client(_add_server(), extensions=extensions) as client: + assert client.session._call_tool_adapter is _CallToolResultAdapter + result = await client.call_tool("add", {"a": 1, "b": 2}) + + assert result.structured_content == {"result": 3} + + +# ── The transparent claim path ─────────────────────────────────────────────── + + +async def test_claimed_result_resolves_transparently_to_the_resolvers_result() -> None: + """A server-claimed `tools/call` shape never surfaces: the owning claim's resolver + receives the parsed claim model and `Client.call_tool` returns the resolver's + `CallToolResult` object — the signature stays `-> CallToolResult` (the assert_type + below is checked by pyright).""" + received: list[VoucherResult] = [] + produced: list[CallToolResult] = [] + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + received.append(claimed) + product = CallToolResult(content=[TextContent(text=f"honored {claimed.request_state}")]) + produced.append(product) + return product + + with anyio.fail_after(5): + async with Client(_voucher_server(), extensions=[_VoucherExtension(resolve)]) as client: + result = await client.call_tool("issue", {}) + assert_type(result, CallToolResult) + + assert [claimed.request_state for claimed in received] == ["v-42"] + assert result is produced[0] + assert result.content == [TextContent(text="honored v-42")] + + +async def test_resolver_product_gets_the_direct_paths_output_schema_revalidation() -> None: + """The resolver's product passes through `validate_tool_result` exactly like a + directly-returned result: against the tool's output schema, missing structured + content raises the direct path's RuntimeError (the message below is the same + one `ClientSession.call_tool`'s own guard produces).""" + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + return CallToolResult(content=[TextContent(text="unstructured")]) + + async with Client(_structured_voucher_server(), extensions=[_VoucherExtension(resolve)]) as client: + with anyio.fail_after(5), pytest.raises(RuntimeError) as exc_info: + await client.call_tool("issue", {}) + + assert str(exc_info.value) == snapshot("Tool issue has an output schema but did not return structured content") + + +async def test_resolver_error_result_is_returned_not_raised() -> None: + """An `isError` resolver product skips output-schema revalidation and comes back + as-is — the same strictness as the direct path, which only revalidates successes. + The tool here declares an output schema, so revalidating would have raised.""" + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + return CallToolResult(content=[TextContent(text="voucher printer on fire")], is_error=True) + + with anyio.fail_after(5): + async with Client(_structured_voucher_server(), extensions=[_VoucherExtension(resolve)]) as client: + result = await client.call_tool("issue", {}) + + assert result.is_error + assert result.content == [TextContent(text="voucher printer on fire")] + + +async def test_resolver_receives_the_calls_claim_context() -> None: + """`ClaimContext` hands the resolver the client's own session object, the tool + name, and the per-call read timeout `call_tool` was given.""" + contexts: list[ClaimContext] = [] + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + contexts.append(ctx) + return CallToolResult(content=[]) + + with anyio.fail_after(5): + async with Client(_voucher_server(), extensions=[_VoucherExtension(resolve)]) as client: + await client.call_tool("issue", {}, read_timeout_seconds=7.0) + [ctx] = contexts + assert ctx.session is client.session + + assert ctx.tool_name == "issue" + assert ctx.read_timeout_seconds == 7.0 + + +class _VoucherRefused(Exception): + """Extension-owned error vocabulary.""" + + +async def test_resolver_exception_propagates_untouched() -> None: + """A resolver exception reaches the `call_tool` caller as the very object the + resolver raised — no wrapping, the extension owns its error vocabulary.""" + refusal = _VoucherRefused("the voucher is refused") + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + raise refusal + + async with Client(_voucher_server(), extensions=[_VoucherExtension(resolve)]) as client: + with anyio.fail_after(5), pytest.raises(_VoucherRefused) as exc_info: + await client.call_tool("issue", {}) + + assert exc_info.value is refusal + + +# ── Unclaimed results with extensions present ──────────────────────────────── + + +async def test_unclaimed_result_flows_through_unchanged_with_extensions_present() -> None: + """An ordinary `CallToolResult` is untouched by the claim machinery — the resolver + never runs and the result matches a claim-less client's.""" + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # this server never produces a claimed shape + + with anyio.fail_after(5): + async with Client(_add_server(), extensions=[_VoucherExtension(resolve)]) as client: + result = await client.call_tool("add", {"a": 1, "b": 2}) + + assert result.structured_content == {"result": 3} + + +async def test_input_required_then_plain_result_keeps_the_auto_loop_working() -> None: + """With a claim-bearing extension present, the auto loop on an unclaimed tool is + unchanged: input_required resolves via the elicitation callback and the plain + terminal result comes back; the resolver never runs.""" + server = MCPServer("mrtr") + + @server.tool() + async def greet(ctx: Context) -> str | types.InputRequiredResult: + responses = ctx.input_responses + if responses and "user_name" in responses: + answer = responses["user_name"] + assert isinstance(answer, types.ElicitResult) + assert answer.content is not None + return f"Hello, {answer.content['name']}!" + return types.InputRequiredResult(input_requests={"user_name": _name_elicitation()}) + + async def elicitation_callback( + context: ClientRequestContext, params: types.ElicitRequestParams + ) -> types.ElicitResult | types.ErrorData: + return types.ElicitResult(action="accept", content={"name": "Ada"}) + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # this server never produces a claimed shape + + with anyio.fail_after(5): + async with Client( + server, elicitation_callback=elicitation_callback, extensions=[_VoucherExtension(resolve)] + ) as client: + result = await client.call_tool("greet") + + assert result.content == [TextContent(text="Hello, Ada!")] + + +# ── The multi-round-trip + claimed interplay ───────────────────────────────── + + +async def test_input_required_then_claimed_result_on_retry_resolves_transparently() -> None: + """The retry-leg regression: a call that demands input first and returns a claimed + shape on the retry still resolves transparently. The driver's retry must admit + claimed shapes — multi-round-trip input resolves before a claimed result, so a + claim may terminate any round, not just the first.""" + prompted: list[str] = [] + received: list[VoucherResult] = [] + + async def elicitation_callback( + context: ClientRequestContext, params: types.ElicitRequestParams + ) -> types.ElicitResult | types.ErrorData: + assert isinstance(params, types.ElicitRequestFormParams) + prompted.append(params.message) + return types.ElicitResult(action="accept", content={"name": "Ada"}) + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + received.append(claimed) + return CallToolResult(content=[TextContent(text=f"honored {claimed.request_state}")]) + + server = _voucher_server(issuer=_TwoRoundVoucherIssuer()) + with anyio.fail_after(5): + async with Client( + server, elicitation_callback=elicitation_callback, extensions=[_VoucherExtension(resolve)] + ) as client: + result = await client.call_tool("issue", {}) + + assert prompted == ["What is your name?"] + assert [claimed.request_state for claimed in received] == ["after-input"] + assert result.content == [TextContent(text="honored after-input")] + + +# ── Notification bindings fold into the session ────────────────────────────── + + +class _CoreMethodObserver(ClientExtension): + """Binds a method the modern core tables already define (construction-legal; the + session warns once at adopt that it can never fire).""" + + identifier = "com.example/observer" + + def notifications(self) -> Sequence[NotificationBinding[Any]]: + return [ + NotificationBinding(method="notifications/message", params_type=_EventParams, handler=_unreachable_handler) + ] + + +async def test_notification_bindings_fold_into_the_session(caplog: pytest.LogCaptureFixture) -> None: + """The Client threads extension notification bindings into its session: a binding + for a core-known method draws the session's one-time gone-quiet warning at adopt. + (Delivery mechanics are session-tier covered in + test_session_notification_bindings.py; this pins the Client fold seam.)""" + with caplog.at_level(logging.WARNING, logger="client"): + async with Client(_add_server(), extensions=[_CoreMethodObserver()]): + pass + + expected = f"notification binding for 'notifications/message' will never fire at {LATEST_MODERN_VERSION}" + assert caplog.text.count(expected) == 1 diff --git a/tests/docs_src/test_apps.py b/tests/docs_src/test_apps.py index 02375f97a..8b692fcc9 100644 --- a/tests/docs_src/test_apps.py +++ b/tests/docs_src/test_apps.py @@ -7,6 +7,7 @@ from docs_src.apps import tutorial001, tutorial002, tutorial003 from mcp import Client +from mcp.client import advertise from mcp.server.apps import APP_MIME_TYPE, EXTENSION_ID # See test_index.py for why this is a per-module mark and not a conftest hook. @@ -34,7 +35,9 @@ async def test_the_ui_resource_is_served_as_the_app_mime_type() -> None: async def test_one_tool_two_answers() -> None: """tutorial001: the canonical degradation pattern: raw data for a client that negotiated Apps, a human sentence for one that did not.""" - async with Client(tutorial001.mcp, extensions={EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}) as ui_client: + async with Client( + tutorial001.mcp, extensions=[advertise(EXTENSION_ID, {"mimeTypes": [APP_MIME_TYPE]})] + ) as ui_client: rich = await ui_client.call_tool("get_time", {}) async with Client(tutorial001.mcp) as text_client: plain = await text_client.call_tool("get_time", {}) diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index cd412b85a..17defca9f 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -8,6 +8,7 @@ from docs_src.extensions import tutorial001, tutorial002, tutorial003, tutorial004, tutorial005 from mcp import Client, MCPError +from mcp.client import advertise from mcp.server.extension import Extension # See test_index.py for why this is a per-module mark and not a conftest hook. @@ -76,7 +77,7 @@ async def test_vendor_method_rejects_a_non_declaring_client_with_32021() -> None async def test_version_pinned_method_is_not_found_on_a_legacy_connection() -> None: """tutorial004: `protocol_versions={"2026-07-28"}` makes the method METHOD_NOT_FOUND at any other wire version; for a legacy client it doesn't exist.""" - async with Client(tutorial004.mcp, mode="legacy", extensions={tutorial004.EXTENSION_ID: {}}) as client: + async with Client(tutorial004.mcp, mode="legacy", extensions=[advertise(tutorial004.EXTENSION_ID)]) as client: request = tutorial004.SearchRequest(params=tutorial004.SearchParams(query="mcp")) with pytest.raises(MCPError) as exc_info: await client.session.send_request(request, tutorial004.SearchResult) diff --git a/tests/server/mcpserver/test_extension.py b/tests/server/mcpserver/test_extension.py index 8fec58e99..2ed093ac0 100644 --- a/tests/server/mcpserver/test_extension.py +++ b/tests/server/mcpserver/test_extension.py @@ -18,6 +18,7 @@ TextContent, ) +from mcp.client import advertise from mcp.client.client import Client from mcp.server.context import CallNext, HandlerResult, ServerRequestContext from mcp.server.extension import ( @@ -466,7 +467,7 @@ async def test_require_client_extension_passes_when_client_declared_it() -> None """SDK-defined: `require_client_extension` is a no-op when the client advertised the id.""" server = MCPServer("test", extensions=[_RequiresExt()]) - async with Client(server, extensions={_NEEDS_EXT: {}}) as client: + async with Client(server, extensions=[advertise(_NEEDS_EXT)]) as client: result = await client.call_tool("guarded", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) diff --git a/tests/server/test_apps.py b/tests/server/test_apps.py index 65908309a..262bdfe7a 100644 --- a/tests/server/test_apps.py +++ b/tests/server/test_apps.py @@ -14,6 +14,7 @@ from inline_snapshot import snapshot from mcp_types import CallToolResult, ReadResourceResult, TextContent, TextResourceContents +from mcp.client import advertise from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.server.apps import ( @@ -95,7 +96,7 @@ async def test_apps_tool_returns_rich_output_when_client_negotiated_apps() -> No branching on `client_supports_apps(ctx)`, drives both halves.""" server = _clock_server() - async with Client(server, extensions={EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}) as supports: + async with Client(server, extensions=[advertise(EXTENSION_ID, {"mimeTypes": [APP_MIME_TYPE]})]) as supports: rich = await supports.call_tool("get_time", {}) async with Client(server) as plain: fallback = await plain.call_tool("get_time", {}) @@ -104,7 +105,7 @@ async def test_apps_tool_returns_rich_output_when_client_negotiated_apps() -> No assert fallback.content == snapshot([TextContent(text="The time is 2026-06-26T00:00:00Z.")]) -async def _observed_client_supports_apps(extensions: dict[str, dict[str, Any]] | None) -> bool: +async def _observed_client_supports_apps(ui_settings: dict[str, Any] | None) -> bool: """Run one probe `tools/call` and report what `client_supports_apps` saw server-side. Exercises the lowlevel `ServerRequestContext` form, which reads the client's @@ -123,29 +124,30 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara return CallToolResult(content=[TextContent(text="ok")]) server = Server("probe", on_list_tools=list_tools, on_call_tool=call_tool) + extensions = None if ui_settings is None else [advertise(EXTENSION_ID, ui_settings)] async with Client(server, extensions=extensions) as client: await client.call_tool("probe", {}) return observed[0] @pytest.mark.parametrize( - ("extensions", "expected"), + ("ui_settings", "expected"), [ - pytest.param({EXTENSION_ID: {"mimeTypes": [APP_MIME_TYPE]}}, True, id="html-mime-listed"), - pytest.param({EXTENSION_ID: {"mimeTypes": (APP_MIME_TYPE,)}}, True, id="in-process-tuple-mime-types"), + pytest.param({"mimeTypes": [APP_MIME_TYPE]}, True, id="html-mime-listed"), + pytest.param({"mimeTypes": (APP_MIME_TYPE,)}, True, id="in-process-tuple-mime-types"), pytest.param(None, False, id="extension-not-declared"), - pytest.param({EXTENSION_ID: {"mimeTypes": ["application/x-other"]}}, False, id="html-mime-not-offered"), - pytest.param({EXTENSION_ID: {}}, False, id="mime-types-key-missing"), + pytest.param({"mimeTypes": ["application/x-other"]}, False, id="html-mime-not-offered"), + pytest.param({}, False, id="mime-types-key-missing"), ], ) async def test_client_supports_apps_from_lowlevel_request_context( - extensions: dict[str, dict[str, Any]] | None, expected: bool + ui_settings: dict[str, Any] | None, expected: bool ) -> None: """ext-apps: `client_supports_apps` is `True` only when the client declared the ui extension AND listed `text/html;profile=mcp-app` in its `mimeTypes` settings — a required field, so its absence means unsupported (the reference SDK's check is `uiCap?.mimeTypes?.includes(...)`).""" - assert await _observed_client_supports_apps(extensions) is expected + assert await _observed_client_supports_apps(ui_settings) is expected def test_apps_tool_rejects_non_ui_resource_uri() -> None: diff --git a/tests/server/test_extensions_capability.py b/tests/server/test_extensions_capability.py index 90f24be2b..49c81f58f 100644 --- a/tests/server/test_extensions_capability.py +++ b/tests/server/test_extensions_capability.py @@ -13,6 +13,7 @@ import pytest from inline_snapshot import snapshot +from mcp.client import advertise from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.server.extension import Extension @@ -82,7 +83,7 @@ async def list_tools( return types.ListToolsResult(tools=[types.Tool(name="probe", input_schema={"type": "object"})]) server = Server("checker", on_call_tool=call_tool, on_list_tools=list_tools) - async with Client(server, extensions={_EXTENSION_ID: {"mimeTypes": ["text/html"]}}) as client: + async with Client(server, extensions=[advertise(_EXTENSION_ID, {"mimeTypes": ["text/html"]})]) as client: await client.call_tool("probe", {}) assert supported == [True] @@ -105,7 +106,7 @@ async def list_tools( return types.ListToolsResult(tools=[types.Tool(name="probe", input_schema={"type": "object"})]) server = Server("checker", on_call_tool=call_tool, on_list_tools=list_tools) - async with Client(server, extensions={_EXTENSION_ID: {"mimeTypes": ["text/html"]}}) as client: + async with Client(server, extensions=[advertise(_EXTENSION_ID, {"mimeTypes": ["text/html"]})]) as client: await client.call_tool("probe", {}) assert supported == [False] From 1084616dee4abc6c096db465f5cbc0c91b6bf5d9 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 16:24:34 +0000 Subject: [PATCH 08/15] Refine extension fold errors and pin read-once for all declarations A Mapping passed as extensions= gets a migration error naming advertise() instead of an attribute error about str; a self-conflicting extension reads as one owner instead of "extensions 'a' and 'a'". Pins added: claims() and notifications() are read exactly once like settings(), and a claimed shape routes to its owning extension's resolver when two claim-bearing extensions are registered. --- src/mcp/client/client.py | 23 +++-- tests/client/test_client_extensions.py | 112 +++++++++++++++++++++++-- 2 files changed, 123 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 7b02d3466..00e9f4b9b 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -213,6 +213,11 @@ def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExt """ if not extensions: return _FoldedExtensions(ad=None, claims=None, bindings=None, by_model={}) + if isinstance(extensions, Mapping): + raise TypeError( + "extensions= takes a sequence of ClientExtension instances; the mapping form was " + "replaced — use advertise(identifier, settings) for advertise-only entries" + ) ad: dict[str, dict[str, Any]] = {} claims: dict[str, tuple[ResultClaim[Any], ...]] = {} bindings: list[NotificationBinding[Any]] = [] @@ -234,9 +239,14 @@ def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExt for claim in extension_claims: key = (claim.method, claim.result_type) if key in claim_owners: + owner = claim_owners[key] + both = ( + f"extension {identifier!r} claims" + if owner == identifier + else (f"extensions {owner!r} and {identifier!r} both claim") + ) raise ValueError( - f"extensions {claim_owners[key]!r} and {identifier!r} both claim {claim.method!r} " - f"resultType {claim.result_type!r}; a wire tag can have only one resolver" + f"{both} {claim.method!r} resultType {claim.result_type!r}; a wire tag can have only one resolver" ) claim_owners[key] = identifier # Collision-free by construction: a model's `result_type` Literal pins it to @@ -246,10 +256,13 @@ def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExt claims[identifier] = extension_claims for binding in extension.notifications(): if binding.method in binding_owners: - raise ValueError( - f"extensions {binding_owners[binding.method]!r} and {identifier!r} both bind " - f"notification method {binding.method!r}; a method can have only one observer" + owner = binding_owners[binding.method] + both = ( + f"extension {identifier!r} binds" + if owner == identifier + else (f"extensions {owner!r} and {identifier!r} both bind") ) + raise ValueError(f"{both} notification method {binding.method!r}; a method can have only one observer") binding_owners[binding.method] = identifier bindings.append(binding) return _FoldedExtensions(ad=ad, claims=claims or None, bindings=tuple(bindings) or None, by_model=by_model) diff --git a/tests/client/test_client_extensions.py b/tests/client/test_client_extensions.py index a67c2eec4..4bd703cd0 100644 --- a/tests/client/test_client_extensions.py +++ b/tests/client/test_client_extensions.py @@ -14,7 +14,7 @@ import logging from collections.abc import Awaitable, Callable, Sequence -from typing import Any, Literal +from typing import Any, Literal, cast import anyio import mcp_types as types @@ -133,6 +133,64 @@ def add(a: int, b: int) -> int: # ── Construction-time validation ──────────────────────────────────────────── +class _CouponResult(Result): + """A second claimed shape with its own tag, for multi-claim routing.""" + + result_type: Literal["coupon"] = "coupon" + + +async def _unreachable_coupon_resolve(claimed: _CouponResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # the wrong resolver for a voucher — must never run + + +class _CouponExtension(ClientExtension): + identifier = "com.example/coupons" + + def claims(self) -> Sequence[ResultClaim[Any]]: + return [ResultClaim(result_type="coupon", model=_CouponResult, resolve=_unreachable_coupon_resolve)] + + +class _SelfConflictingClaims(ClientExtension): + identifier = "com.example/twice" + + def claims(self) -> Sequence[ResultClaim[Any]]: + return [ + ResultClaim(result_type="twice", model=_TwiceResult, resolve=_unreachable_twice_resolve), + ResultClaim(result_type="twice", model=_TwiceResult, resolve=_unreachable_twice_resolve), + ] + + +class _TwiceResult(Result): + result_type: Literal["twice"] = "twice" + + +async def _unreachable_twice_resolve(claimed: _TwiceResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError + + +def test_mapping_extensions_get_the_migration_error() -> None: + """SDK-defined: the replaced dict form fails with a message naming the new shape, + not an attribute error about `str`.""" + with pytest.raises(TypeError) as exc_info: + Client(_add_server(), extensions=cast("Sequence[ClientExtension]", {"com.example/ui": {}})) + + assert str(exc_info.value) == snapshot( + "extensions= takes a sequence of ClientExtension instances; the mapping form was " + "replaced — use advertise(identifier, settings) for advertise-only entries" + ) + + +def test_one_extension_claiming_a_tag_twice_reads_as_one_owner() -> None: + """SDK-defined: a self-conflict names the one extension once instead of + "extensions 'a' and 'a'".""" + with pytest.raises(ValueError) as exc_info: + Client(_add_server(), extensions=[_SelfConflictingClaims()]) + + assert str(exc_info.value) == snapshot( + "extension 'com.example/twice' claims 'tools/call' resultType 'twice'; a wire tag can have only one resolver" + ) + + def test_bare_extension_instance_is_rejected_with_the_fix_named() -> None: """SDK-defined: an instance whose class never set `identifier` fails Client construction with an error naming the type and the fix — not an AttributeError.""" @@ -246,32 +304,54 @@ def test_conflicting_notification_bindings_name_both_owners() -> None: # ── settings() consumption ─────────────────────────────────────────────────── +class _CountedResult(Result): + result_type: Literal["counted"] = "counted" + + +async def _unreachable_counted_resolve(claimed: _CountedResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # never driven; exists so claims() has something to return + + class _CountingSettings(ClientExtension): - """Counts `settings()` reads to pin the read-once contract.""" + """Counts every declaration read to pin the read-once contract for all three.""" identifier = "com.example/counted" def __init__(self) -> None: self.reads = 0 + self.claims_reads = 0 + self.notifications_reads = 0 def settings(self) -> dict[str, Any]: self.reads += 1 return {"read": self.reads} + def claims(self) -> Sequence[ResultClaim[Any]]: + self.claims_reads += 1 + return [ResultClaim(result_type="counted", model=_CountedResult, resolve=_unreachable_counted_resolve)] + + def notifications(self) -> Sequence[NotificationBinding[Any]]: + self.notifications_reads += 1 + return [ + NotificationBinding(method="notifications/counted", params_type=_EventParams, handler=_unreachable_handler) + ] + -async def test_settings_is_read_exactly_once_at_construction() -> None: - """SDK-defined: `settings()` is read once, at Client construction — connecting and - calling tools (each modern request re-stamps the capability ad) never re-reads it.""" +async def test_declarations_are_read_exactly_once_at_construction() -> None: + """SDK-defined: `settings()`, `claims()`, and `notifications()` are each read once, + at Client construction — connecting and calling tools (each modern request re-stamps + the capability ad) never re-reads any of them, so a stateful extension cannot desync + the ad from the claims.""" extension = _CountingSettings() client = Client(_add_server(), extensions=[extension]) - assert extension.reads == 1 + assert (extension.reads, extension.claims_reads, extension.notifications_reads) == (1, 1, 1) with anyio.fail_after(5): async with client: await client.call_tool("add", {"a": 1, "b": 2}) await client.call_tool("add", {"a": 3, "b": 4}) - assert extension.reads == 1 + assert (extension.reads, extension.claims_reads, extension.notifications_reads) == (1, 1, 1) async def test_settings_dict_is_held_by_reference_not_copied() -> None: @@ -348,6 +428,24 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: assert result.content == [TextContent(text="honored v-42")] +async def test_claimed_shape_routes_to_its_owning_extensions_resolver() -> None: + """With two claim-bearing extensions registered, the parsed shape runs ITS owner's + resolver — the coupon extension (registered first) must never see a voucher.""" + received: list[VoucherResult] = [] + + async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: + received.append(claimed) + return CallToolResult(content=[TextContent(text="routed")]) + + extensions = [_CouponExtension(), _VoucherExtension(resolve)] + with anyio.fail_after(5): + async with Client(_voucher_server(), extensions=extensions) as client: + result = await client.call_tool("issue", {}) + + assert [claimed.request_state for claimed in received] == ["v-42"] + assert result.content == [TextContent(text="routed")] + + async def test_resolver_product_gets_the_direct_paths_output_schema_revalidation() -> None: """The resolver's product passes through `validate_tool_result` exactly like a directly-returned result: against the tool's output schema, missing structured From be8e160eb39e9529704f082dec2e07138b64ee32 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 16:34:23 +0000 Subject: [PATCH 09/15] Document client extensions and prove the loop in the interaction suite docs/advanced/extensions.md gains the client half: using an extension (construct, pass extensions=[...], call tools normally), writing one (identifier, claims with a resolver doing real follow-up sends, notifications, read-once settings), and extension verbs via Request subclasses with name_param. Two runnable tutorials back the page. Interaction tests prove the five-sentence story over the real harness: the both-ends claimed-result loop with a resolver follow-up, the off-switch (undeclared shape fails validation; legacy ad drops claim-bearing identifiers), per-request capability gating with -32021 refusal, and Mcp-Name from name_param observed on the modern HTTP wire. --- docs/advanced/extensions.md | 117 +++++++++-- docs_src/extensions/tutorial006.py | 69 +++++++ docs_src/extensions/tutorial007.py | 50 +++++ tests/docs_src/test_extensions.py | 44 +++- tests/interaction/_connect.py | 10 +- tests/interaction/_requirements.py | 58 ++++++ .../interaction/mcpserver/test_extensions.py | 190 ++++++++++++++++++ .../transports/test_hosting_http_modern.py | 65 +++++- 8 files changed, 585 insertions(+), 18 deletions(-) create mode 100644 docs_src/extensions/tutorial006.py create mode 100644 docs_src/extensions/tutorial007.py create mode 100644 tests/interaction/mcpserver/test_extensions.py diff --git a/docs/advanced/extensions.md b/docs/advanced/extensions.md index 024c7af69..5945cf64f 100644 --- a/docs/advanced/extensions.md +++ b/docs/advanced/extensions.md @@ -2,9 +2,10 @@ An **extension** is an opt-in bundle of MCP behaviour behind one identifier. -It can contribute tools, resources, and new request methods, and it can wrap `tools/call`. -The server advertises it under `capabilities.extensions`, the client opts in the same way, -and nothing changes for anyone who didn't ask for it. That is the contract ([SEP-2133](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2133)), and +On a server it can contribute tools, resources, and new request methods, and it can wrap +`tools/call`. On a client it can claim extra `tools/call` result shapes and observe vendor +notifications. Each side advertises under its own `capabilities.extensions`, and nothing +changes for anyone who didn't ask for it. That is the contract ([SEP-2133](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2133)), and it has one golden rule: **extensions are off by default**. ## Using an extension @@ -79,7 +80,7 @@ And `main()` is the proof, an in-memory client straight against `mcp`: An extension can register **new request methods**: its own verbs, served next to the spec's: -```python title="server.py" hl_lines="15-21 30 39-47" +```python title="server.py" hl_lines="16-22 31 40-48" --8<-- "docs_src/extensions/tutorial004.py" ``` @@ -108,15 +109,15 @@ runtime: The same file's `main()` is the whole client story, both halves of it: -```python title="server.py" hl_lines="53-57" +```python title="server.py" hl_lines="54-58" --8<-- "docs_src/extensions/tutorial004.py" ``` -* `Client(..., extensions={EXTENSION_ID: {}})` declares the extension. That map - becomes `ClientCapabilities.extensions`: on a 2026-07-28 connection it travels in - the per-request `_meta` envelope, so the server sees it on **every** request; on - a legacy connection it rides the `initialize` handshake. Server code doesn't care - which: `require_client_extension(ctx, ...)` and +* `Client(..., extensions=[advertise(EXTENSION_ID)])` declares the extension. The + declarations become `ClientCapabilities.extensions`: on a 2026-07-28 connection + the map travels in the per-request `_meta` envelope, so the server sees it on + **every** request; on a legacy connection it rides the `initialize` handshake. + Server code doesn't care which: `require_client_extension(ctx, ...)` and `ctx.session.check_client_capability(...)` read the right source on both paths. * Vendor methods drop one layer to `client.session.send_request(...)`; `Client` only grows first-class methods for spec verbs. `send_request` accepts any @@ -144,15 +145,103 @@ or veto a tool call: The hook wraps `tools/call` and nothing else. For every-message concerns, use [Middleware](middleware.md). That is what it is for. +## Using a client extension + +A **client extension** is the same contract from the consuming side: a bundle of +client-side behaviour behind one identifier. Pass instances to +`Client(extensions=[...])` and call tools normally: + +```python title="client.py" hl_lines="66-68" +--8<-- "docs_src/extensions/tutorial006.py" +``` + +`call_tool("buy", ...)` returns a plain `CallToolResult`, like every other call. What +the extension changed: the server may now answer `buy` with a `receipt` **result +shape** instead of a final result, and `Receipts` finishes it — here by redeeming the +receipt with a follow-up call — before `call_tool` returns. Nothing about the call +site moves. + +Drop the extension and none of this exists: a `receipt` shape arriving at a client +that didn't declare it fails validation, exactly as the spec requires for an +unrecognized `resultType`. Off by default, on both ends of the wire. + +To advertise an identifier with **no** client-side behaviour — the server gates on +the capability, the client does nothing, as in the search client above — use +`advertise()`: + +```python +from mcp.client import advertise + +client = Client(mcp, extensions=[advertise("com.example/search")]) +``` + +## Writing a client extension + +Subclass `ClientExtension` and override only what you need. Three contribution +kinds, each with a default: `settings()`, `claims()`, and `notifications()`. + +```python title="client.py" hl_lines="18-19 43-44 46-47" +--8<-- "docs_src/extensions/tutorial006.py" +``` + +* The identifier follows the same grammar as the server's, validated when the class + is defined. +* `claims()` returns `ResultClaim`s: a wire tag, the model that parses it, and the + resolver that finishes it. The model must pin the tag with + `result_type: Literal["receipt"]` and must not subclass a core result type — both + enforced when the claim is constructed. (The payload rides `requestState` here + because an `MCPServer` substituting a claimed shape serializes only the core + `tools/call` surface fields; a server on another SDK may send richer shapes.) +* The resolver receives the parsed model and a `ClaimContext`; `ctx.session` is the + same public handle as `client.session`, so follow-ups are ordinary session calls. + It returns the verb's normal `CallToolResult`. +* `settings()` is the value advertised at `ClientCapabilities.extensions[identifier]`, + read once at `Client` construction. + +`notifications()` declares vendor server notifications to observe: + +```python +def notifications(self) -> Sequence[NotificationBinding[Any]]: + return [NotificationBinding(method="notifications/receipts", params_type=ReceiptEvent, handler=self.on_receipt)] +``` + +The handler receives validated params, in arrival order. It observes; it cannot veto +or reply. + +Two quiet rules. Claims are active on 2026-07-28 connections only, and the capability +ad follows them: on a legacy connection the claims dissolve and the identifier drops +out of the ad in the same breath, so the client never advertises an extension whose +shapes it would reject. And when you want the claimed shape yourself instead of the +resolver, call `client.session.call_tool(..., allow_claimed=True)` — the escape hatch +`UnexpectedClaimedResult` names when a claimed shape reaches a session-tier caller +that didn't opt in. + +### Extension verbs + +An extension's own request methods need no client-side registration. A vendor request +type subclasses `mcp_types.Request` and goes through `client.session.send_request`, +as in [Serving your own methods](#serving-your-own-methods). One addition: when a +params key must ride the `Mcp-Name` header (extension specs such as tasks require +this for their verbs), the request type declares `name_param`: + +```python title="client.py" hl_lines="23-26 47-48" +--8<-- "docs_src/extensions/tutorial007.py" +``` + +The session mirrors `params["jobId"]` into `Mcp-Name` on every send path, and a +missing value fails loudly rather than silently omitting a required header. + ## What an extension cannot do -The contribution surface is **closed** on purpose: settings, tools, resources, -methods, one `tools/call` interceptor. An extension cannot: +The contribution surface is **closed** on purpose — on the server: settings, tools, +resources, methods, one `tools/call` interceptor; on the client: settings, result +claims, notification bindings. An extension cannot: -* **Reach into the server.** It declares data; it holds no server reference. +* **Reach into the host.** It declares data; it holds no server or client reference. * **Replace core behaviour.** Spec methods are rejected at construction, and `initialize` is reserved by the runner outright. -* **Register late.** After `MCPServer(...)` returns, the extension set is what it is. +* **Register late.** After `MCPServer(...)` or `Client(...)` returns, the extension + set is what it is. If you are fighting these walls, you are not writing an extension. You are writing a fork. The walls are the feature: a user reading `extensions=[Apps(), Stamps()]` diff --git a/docs_src/extensions/tutorial006.py b/docs_src/extensions/tutorial006.py new file mode 100644 index 000000000..ca7cb072f --- /dev/null +++ b/docs_src/extensions/tutorial006.py @@ -0,0 +1,69 @@ +from collections.abc import Sequence +from typing import Any, Literal + +import mcp_types as types + +from mcp import Client +from mcp.client import ClaimContext, ClientExtension, ResultClaim +from mcp.server.context import CallNext, HandlerResult, ServerRequestContext +from mcp.server.extension import Extension +from mcp.server.mcpserver import MCPServer + +EXTENSION_ID = "com.example/receipts" + + +class ReceiptResult(types.Result): + """The claimed result shape; `result_type` pins the wire tag.""" + + result_type: Literal["receipt"] = "receipt" + request_state: str + + +class ReceiptIssuer(Extension): + """Server half: answers `buy` with a receipt instead of a final result.""" + + identifier = EXTENSION_ID + + async def intercept_tool_call( + self, + params: types.CallToolRequestParams, + ctx: ServerRequestContext[Any, Any], + call_next: CallNext, + ) -> HandlerResult: + if params.name != "buy": + return await call_next(ctx) + return {"resultType": "receipt", "requestState": "r-117"} + + +class Receipts(ClientExtension): + """Client half: claims the `receipt` shape and supplies the code that finishes it.""" + + identifier = EXTENSION_ID + + def claims(self) -> Sequence[ResultClaim[Any]]: + return [ResultClaim(result_type="receipt", model=ReceiptResult, resolve=self._redeem)] + + async def _redeem(self, claimed: ReceiptResult, ctx: ClaimContext) -> types.CallToolResult: + return await ctx.session.call_tool("redeem", {"token": claimed.request_state}) + + +mcp = MCPServer("shop", extensions=[ReceiptIssuer()]) + + +@mcp.tool() +def buy(item: str) -> types.CallToolResult: + """Buy an item.""" + raise NotImplementedError # ReceiptIssuer answers `buy` before the tool runs + + +@mcp.tool() +def redeem(token: str) -> str: + """Exchange a receipt token for the goods.""" + return f"goods for {token}" + + +async def main() -> None: + async with Client(mcp, extensions=[Receipts()]) as client: + result = await client.call_tool("buy", {"item": "lamp"}) + print(result.content) + # [TextContent(text='goods for r-117')] diff --git a/docs_src/extensions/tutorial007.py b/docs_src/extensions/tutorial007.py new file mode 100644 index 000000000..37706ca21 --- /dev/null +++ b/docs_src/extensions/tutorial007.py @@ -0,0 +1,50 @@ +from collections.abc import Sequence +from typing import Any, Literal + +import mcp_types as types + +from mcp import Client +from mcp.client import advertise +from mcp.server.context import ServerRequestContext +from mcp.server.extension import Extension, MethodBinding +from mcp.server.mcpserver import MCPServer + +EXTENSION_ID = "com.example/jobs" + + +class JobParams(types.RequestParams): + job_id: str + + +class JobStatus(types.Result): + status: str + + +class JobStatusRequest(types.Request[JobParams, Literal["com.example/jobs.status"]]): + method: Literal["com.example/jobs.status"] = "com.example/jobs.status" + params: JobParams + name_param = "jobId" # params["jobId"] rides the Mcp-Name header + + +async def job_status(ctx: ServerRequestContext[Any, Any], params: JobParams) -> JobStatus: + return JobStatus(status=f"{params.job_id} is running") + + +class Jobs(Extension): + """An extension whose verb names its subject, so the header can route on it.""" + + identifier = EXTENSION_ID + + def methods(self) -> Sequence[MethodBinding]: + return [MethodBinding("com.example/jobs.status", JobParams, job_status)] + + +mcp = MCPServer("worker", extensions=[Jobs()]) + + +async def main() -> None: + async with Client(mcp, extensions=[advertise(EXTENSION_ID)]) as client: + request = JobStatusRequest(params=JobParams(job_id="job-7")) + result = await client.session.send_request(request, JobStatus) + print(result.status) + # job-7 is running diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index 17defca9f..346c2956e 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -5,8 +5,17 @@ import pytest from inline_snapshot import snapshot from mcp_types import METHOD_NOT_FOUND, MISSING_REQUIRED_CLIENT_CAPABILITY, TextContent - -from docs_src.extensions import tutorial001, tutorial002, tutorial003, tutorial004, tutorial005 +from pydantic import ValidationError + +from docs_src.extensions import ( + tutorial001, + tutorial002, + tutorial003, + tutorial004, + tutorial005, + tutorial006, + tutorial007, +) from mcp import Client, MCPError from mcp.client import advertise from mcp.server.extension import Extension @@ -94,3 +103,34 @@ async def test_interceptor_observes_the_call_and_passes_the_result_through( assert result.structured_content == {"result": 5} messages = [record.getMessage() for record in caplog.records if record.name == tutorial005.logger.name] assert messages == ["tool 'add' called"] + + +async def test_the_receipts_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: + """tutorial006: `main()` is the literal client program on the page — the claimed + `receipt` shape never surfaces, and the printed content is the redeemed result.""" + await tutorial006.main() + assert "goods for r-117" in capsys.readouterr().out + + +async def test_a_claimed_shape_fails_validation_without_the_extension() -> None: + """The page's off-by-default claim: a client that does not construct `Receipts` + rejects the `receipt` shape as invalid (spec: an unrecognized resultType is invalid).""" + async with Client(tutorial006.mcp) as client: + with pytest.raises(ValidationError): + await client.call_tool("buy", {"item": "lamp"}) + + +async def test_session_tier_allow_claimed_returns_the_raw_shape() -> None: + """The page's escape hatch: `client.session.call_tool(..., allow_claimed=True)` hands + back the parsed claim model instead of running the resolver.""" + async with Client(tutorial006.mcp, extensions=[tutorial006.Receipts()]) as client: + result = await client.session.call_tool("buy", {"item": "lamp"}, allow_claimed=True) + assert isinstance(result, tutorial006.ReceiptResult) + assert result.request_state == "r-117" + + +async def test_the_jobs_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: + """tutorial007: a vendor request type with `name_param` goes through + `client.session.send_request` with no registration and returns its typed result.""" + await tutorial007.main() + assert "job-7 is running" in capsys.readouterr().out diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 05b2d2277..5da269ba4 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -7,7 +7,7 @@ (session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. """ -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager from functools import partial from typing import Any, Protocol @@ -30,6 +30,7 @@ from starlette.routing import Mount, Route from mcp.client.client import Client +from mcp.client.extension import ClientExtension from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client @@ -70,6 +71,7 @@ def __call__( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, + extensions: Sequence[ClientExtension] | None = None, spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AbstractAsyncContextManager[Client]: ... @@ -85,6 +87,7 @@ async def connect_in_memory( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, + extensions: Sequence[ClientExtension] | None = None, spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server over the in-memory transport. @@ -103,6 +106,7 @@ async def connect_in_memory( message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, + extensions=extensions, ) as client: yield client @@ -122,6 +126,7 @@ async def connect_over_streamable_http( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, + extensions: Sequence[ClientExtension] | None = None, spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server's streamable HTTP app, entirely in process. @@ -156,6 +161,7 @@ async def connect_over_streamable_http( message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, + extensions=extensions, ) as client, ): yield client @@ -357,6 +363,7 @@ async def connect_over_sse( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, + extensions: Sequence[ClientExtension] | None = None, spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" @@ -390,5 +397,6 @@ def httpx_client_factory( message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, + extensions=extensions, ) as client: yield client diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index d376f0b9f..acbfd4bfb 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2384,6 +2384,51 @@ def __post_init__(self) -> None: ), ), # ═══════════════════════════════════════════════════════════════════════════ + # Extensions (SEP-2133): client-side result claims and the capability ad + # ═══════════════════════════════════════════════════════════════════════════ + "extensions:client:claimed-result-resolved": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic#resulttype", + behavior=( + "A tools/call answered with an extension-claimed resultType is finished by the owning " + "ClientExtension's claim resolver — which may send follow-up requests through the session it is " + "handed — and Client.call_tool returns the resolver's ordinary CallToolResult." + ), + added_in="2026-07-28", + ), + "extensions:client:claimed-result-undeclared-invalid": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic#resulttype", + behavior=( + "A resultType unrecognized by the client is invalid: a claimed shape delivered to a client that " + "did not construct the owning extension fails result validation (the supported set is core plus " + "declared claims, never more)." + ), + added_in="2026-07-28", + ), + "extensions:client:capability-ad:gates-server-behaviour": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic#resulttype", + behavior=( + "The per-request _meta capability ad carries each declared extension's identifier and settings, " + "and is what entitles the server to substitute that extension's claimed shapes: a server " + "extension gating on the ad sees the declared settings, and refuses a non-declaring client with " + "-32021 (missing required client capability)." + ), + added_in="2026-07-28", + ), + "extensions:client:capability-ad:legacy-omits-claimed": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic#resulttype", + behavior=( + "On a legacy connection no claim can activate, and the initialize capability ad omits " + "claim-bearing identifiers in the same breath (claim-less identifiers still advertise), so the " + "client never advertises an extension whose claimed shapes it would reject." + ), + removed_in="2026-07-28", + note=( + "The legacy-era half of the ad/claims coupling: only a handshake connection can exhibit it, so " + "the version window ends where the modern era begins." + ), + arm_exclusions=(ArmExclusion(reason="requires-session", transport="streamable-http-stateless"),), + ), + # ═══════════════════════════════════════════════════════════════════════════ # Transports (in-suite coverage) # ═══════════════════════════════════════════════════════════════════════════ "transport:streamable-http:stateful": Requirement( @@ -3341,6 +3386,19 @@ def __post_init__(self) -> None: transports=("streamable-http",), note="Only observable over streamable HTTP: headers are derived from the cached tool schema at the seam.", ), + "client-transport:http:vendor-name-param-header": Requirement( + source="sdk", + behavior=( + "A vendor request type declaring name_param mirrors that wire-params key into the Mcp-Name " + "header of its outgoing HTTP request, with no client-side registration of the method." + ), + added_in="2026-07-28", + transports=("streamable-http",), + note=( + "SDK mechanism honouring the per-extension Mcp-Name requirements (e.g. SEP-2663 mandates the " + "header for tasks/*); only observable over streamable HTTP, where headers exist." + ), + ), "client-transport:http:stateless-ignores-session-id": Requirement( source=f"{SPEC_2026_BASE_URL}/basic/transports#stateless-request-headers", behavior=( diff --git a/tests/interaction/mcpserver/test_extensions.py b/tests/interaction/mcpserver/test_extensions.py new file mode 100644 index 000000000..a327cca9a --- /dev/null +++ b/tests/interaction/mcpserver/test_extensions.py @@ -0,0 +1,190 @@ +"""Client extensions (SEP-2133) over the full client-server loop. + +The servers here are MCPServers whose server extension substitutes a claimed `tools/call` +shape via `intercept_tool_call`; the client declares the owning `ClientExtension` and its +claim resolver finishes the call. The in-process server's 2026 result surface keeps only +`resultType` / `requestState` / `inputRequests` / `_meta` on a claimed result, so claimed +payloads here ride `requestState`. +""" + +import json +from collections.abc import Awaitable, Callable, Sequence +from typing import Any, Literal + +import mcp_types as types +import pytest +from inline_snapshot import snapshot +from mcp_types import MISSING_REQUIRED_CLIENT_CAPABILITY, CallToolResult, Result, TextContent +from pydantic import ValidationError + +from mcp import MCPError +from mcp.client import ClaimContext, ClientExtension, ResultClaim, advertise +from mcp.server.context import CallNext, HandlerResult, ServerRequestContext +from mcp.server.extension import Extension +from mcp.server.mcpserver import Context, MCPServer, require_client_extension +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +_RECEIPTS = "com.example/receipts" +_FLAGS = "com.example/flags" + + +class ReceiptResult(Result): + """The claimed `tools/call` shape, tagged `receipt`; its payload rides `requestState`.""" + + result_type: Literal["receipt"] = "receipt" + request_state: str + + +_Resolver = Callable[[ReceiptResult, ClaimContext], Awaitable[CallToolResult]] + + +class Receipts(ClientExtension): + """Client half: claims the `receipt` tag with the test's resolver and settings.""" + + identifier = _RECEIPTS + + def __init__(self, resolve: _Resolver, settings: dict[str, Any] | None = None) -> None: + self._resolve = resolve + self._settings = {} if settings is None else settings + + def settings(self) -> dict[str, Any]: + return self._settings + + def claims(self) -> Sequence[ResultClaim[Any]]: + return [ResultClaim(result_type="receipt", model=ReceiptResult, resolve=self._resolve)] + + +class _ReceiptIssuer(Extension): + """Server half: answers `buy` with the claimed shape; every other tool passes through.""" + + identifier = _RECEIPTS + + async def intercept_tool_call( + self, params: types.CallToolRequestParams, ctx: ServerRequestContext[Any, Any], call_next: CallNext + ) -> HandlerResult: + if params.name != "buy": + return await call_next(ctx) + return {"resultType": "receipt", "requestState": "r-117"} + + +def _receipt_shop(issuer: Extension) -> MCPServer: + """An MCPServer whose `buy` tool the server extension rewrites into the claimed shape.""" + server = MCPServer("shop", extensions=[issuer]) + + @server.tool() + def buy(item: str) -> CallToolResult: + """Buy an item.""" + raise NotImplementedError # the server extension answers `buy` before the tool runs + + @server.tool() + def redeem(token: str) -> str: + """Exchange a receipt token for the goods.""" + return f"goods for {token}" + + return server + + +@requirement("extensions:client:claimed-result-resolved") +async def test_claimed_result_is_finished_by_the_owning_extensions_resolver(connect: Connect) -> None: + """The transparent claim path, both ends real: the server extension substitutes the + `receipt` shape, the client's claim resolver redeems it with a follow-up `tools/call` + through `ctx.session` — the same authority as `client.session` — and `call_tool` + returns the resolver's plain `CallToolResult`. The claimed shape never surfaces.""" + received: list[ReceiptResult] = [] + + async def redeem_receipt(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolResult: + received.append(claimed) + return await ctx.session.call_tool("redeem", {"token": claimed.request_state}) + + async with connect(_receipt_shop(_ReceiptIssuer()), extensions=[Receipts(redeem_receipt)]) as client: + result = await client.call_tool("buy", {"item": "lamp"}) + + assert [claimed.request_state for claimed in received] == ["r-117"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="goods for r-117")], structured_content={"result": "goods for r-117"}) + ) + + +@requirement("extensions:client:claimed-result-undeclared-invalid") +async def test_claimed_shape_fails_validation_for_a_client_without_the_extension(connect: Connect) -> None: + """Spec-mandated: an unrecognized `resultType` is invalid. A client that did not + construct the owning extension rejects the very shape the previous test resolves — + the request reaches the server, the substituted result fails client-side parsing.""" + async with connect(_receipt_shop(_ReceiptIssuer())) as client: + with pytest.raises(ValidationError): + await client.call_tool("buy", {"item": "lamp"}) + + +class _SettingsEchoIssuer(Extension): + """Server half for the ad tests: refuses non-declaring clients, then echoes the + declared settings back through the claimed payload.""" + + identifier = _RECEIPTS + + async def intercept_tool_call( + self, params: types.CallToolRequestParams, ctx: ServerRequestContext[Any, Any], call_next: CallNext + ) -> HandlerResult: + require_client_extension(ctx, _RECEIPTS) + client_params = ctx.session.client_params + assert client_params is not None # require_client_extension just read it + extensions = client_params.capabilities.extensions + assert extensions is not None + return {"resultType": "receipt", "requestState": json.dumps(extensions[_RECEIPTS], sort_keys=True)} + + +@requirement("extensions:client:capability-ad:gates-server-behaviour") +async def test_per_request_ad_carries_settings_and_gates_the_claimed_substitution(connect: Connect) -> None: + """The per-request `_meta` capability ad is the entitlement for claimed shapes: the + server extension's gate passes only for the declaring client, observes the declared + settings on the request, and the resolver receives them back through the payload. + A client declaring nothing is refused with -32021, not served the shape.""" + server = MCPServer("shop", extensions=[_SettingsEchoIssuer()]) + + @server.tool() + def buy(item: str) -> CallToolResult: + """Buy an item.""" + raise NotImplementedError # the server extension answers `buy` before the tool runs + + received: list[ReceiptResult] = [] + + async def keep(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolResult: + received.append(claimed) + return CallToolResult(content=[TextContent(text="done")]) + + async with connect(server, extensions=[Receipts(keep, settings={"tier": "gold"})]) as client: + result = await client.call_tool("buy", {"item": "lamp"}) + assert result.content == [TextContent(text="done")] + assert [json.loads(claimed.request_state) for claimed in received] == [{"tier": "gold"}] + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("buy", {"item": "lamp"}) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + + +async def _unreachable_resolve(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError # no claimed shape can be delivered on a legacy wire + + +@requirement("extensions:client:capability-ad:legacy-omits-claimed") +async def test_legacy_ad_omits_claim_bearing_identifiers_but_keeps_claim_less_ones(connect: Connect) -> None: + """On a legacy connection the claims dissolve and the ad follows them: the + claim-bearing identifier is absent from the initialize capability ad the server + sees, while an ad-only identifier on the same client still advertises.""" + server = MCPServer("introspector") + + @server.tool() + def declared(ctx: Context) -> list[str]: + """Report the extension identifiers the client advertised.""" + capabilities = ctx.client_capabilities + assert capabilities is not None + return sorted(capabilities.extensions or {}) + + client_extensions = [Receipts(_unreachable_resolve), advertise(_FLAGS)] + async with connect(server, extensions=client_extensions) as client: + result = await client.call_tool("declared", {}) + + assert result.structured_content == {"result": [_FLAGS]} diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index 3feed4fed..31de9cd57 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -9,7 +9,7 @@ import json from collections.abc import Callable -from typing import Any +from typing import Any, Literal import anyio import httpx @@ -30,7 +30,9 @@ JSONRPCResponse, ListToolsResult, PaginatedRequestParams, + Request, RequestParams, + Result, ServerCapabilities, TextContent, Tool, @@ -551,3 +553,64 @@ async def on_request(request: httpx.Request) -> None: before, after = tool_calls assert before.headers.get("mcp-param-region") == "x" assert not any(k.startswith("mcp-param-") for k in after.headers) + + +class _JobParams(RequestParams): + job_id: str + + +class _JobStatusRequest(Request[_JobParams, Literal["com.example/jobs.status"]]): + """A vendor (extension) request type that names its subject for the Mcp-Name header.""" + + method: Literal["com.example/jobs.status"] = "com.example/jobs.status" + name_param = "jobId" + + +class _JobStatusResult(Result): + status: str + + +@requirement("client-transport:http:vendor-name-param-header") +async def test_vendor_request_with_name_param_carries_mcp_name_on_the_wire() -> None: + """A vendor request sent through `send_request` carries `Mcp-Name` from its `name_param` key. + + The request type is never registered with the client; `send_request` reads the declared + `name_param` ("jobId"), mirrors the params value into the `Mcp-Name` header, and the value + stays in the body unchanged. Asserted at the wire because the client never surfaces the + outgoing headers. The server serves the vendor method through `add_request_handler`, so the + round trip also proves the typed result comes back without any client-side method table. + """ + + async def job_status(ctx: ServerRequestContext, params: _JobParams) -> _JobStatusResult: + assert params.job_id == "job-7" + return _JobStatusResult(status="running") + + server = _server() + server.add_request_handler("com.example/jobs.status", _JobParams, job_status) + + requests: list[httpx.Request] = [] + + async def on_request(request: httpx.Request) -> None: + requests.append(request) + + discover = DiscoverResult( + supported_versions=[LATEST_MODERN_VERSION], + capabilities=ServerCapabilities(), + server_info=Implementation(name="srv", version="0"), + ) + with anyio.fail_after(5): + async with ( + mounted_app(server, on_request=on_request) as (http, _), + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http), + mode=LATEST_MODERN_VERSION, + prior_discover=discover, + ) as client, + ): + request = _JobStatusRequest(params=_JobParams(job_id="job-7")) + result = await client.session.send_request(request, _JobStatusResult) + + assert result.status == "running" + [wire_request] = requests + assert wire_request.headers["mcp-name"] == "job-7" + assert json.loads(wire_request.content)["params"]["jobId"] == "job-7" From 02b8519fd090128b55116aa0c45543bbf2168b61 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 16:47:42 +0000 Subject: [PATCH 10/15] Claimed shapes carry vendor fields verbatim; say so honestly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A short-circuiting tools/call interceptor's dict is passed through by the runner as a trusted well-formed result — nothing strips vendor top-level fields on the way to the client. The docs, tutorials, and test fixtures previously claimed the opposite and string-packed payloads into requestState; the claim models now carry real vendor fields end to end (the settings-echo test gains a dict-typed field instead of JSON string-packing). Also: the closed-surface bullet now says notification bindings shadowed by core vocabulary go quiet rather than being rejected, the core-subclass rule is scoped to the verb's result types, and the notification-binding contribution kind gets a deferred manifest entry naming its session-tier coverage. --- docs/advanced/extensions.md | 13 +++++----- docs_src/extensions/tutorial006.py | 6 ++--- tests/client/test_client_extensions.py | 26 +++++++++---------- tests/docs_src/test_extensions.py | 2 +- tests/interaction/_requirements.py | 13 ++++++++++ .../interaction/mcpserver/test_extensions.py | 22 ++++++++-------- 6 files changed, 47 insertions(+), 35 deletions(-) diff --git a/docs/advanced/extensions.md b/docs/advanced/extensions.md index 5945cf64f..832366aac 100644 --- a/docs/advanced/extensions.md +++ b/docs/advanced/extensions.md @@ -188,10 +188,10 @@ kinds, each with a default: `settings()`, `claims()`, and `notifications()`. is defined. * `claims()` returns `ResultClaim`s: a wire tag, the model that parses it, and the resolver that finishes it. The model must pin the tag with - `result_type: Literal["receipt"]` and must not subclass a core result type — both - enforced when the claim is constructed. (The payload rides `requestState` here - because an `MCPServer` substituting a claimed shape serializes only the core - `tools/call` surface fields; a server on another SDK may send richer shapes.) + `result_type: Literal["receipt"]` and must not subclass the verb's core result + types — both enforced when the claim is constructed. Vendor fields like + `receipt_token` ride the wire as-is: a substituted shape reaches the client + verbatim. * The resolver receives the parsed model and a `ClaimContext`; `ctx.session` is the same public handle as `client.session`, so follow-ups are ordinary session calls. It returns the verb's normal `CallToolResult`. @@ -238,8 +238,9 @@ resources, methods, one `tools/call` interceptor; on the client: settings, resul claims, notification bindings. An extension cannot: * **Reach into the host.** It declares data; it holds no server or client reference. -* **Replace core behaviour.** Spec methods are rejected at construction, and - `initialize` is reserved by the runner outright. +* **Replace core behaviour.** Spec methods and core result tags are rejected at + construction (`initialize` is reserved by the runner outright); a notification + binding shadowed by core vocabulary goes quiet with a warning instead. * **Register late.** After `MCPServer(...)` or `Client(...)` returns, the extension set is what it is. diff --git a/docs_src/extensions/tutorial006.py b/docs_src/extensions/tutorial006.py index ca7cb072f..55f3c50c1 100644 --- a/docs_src/extensions/tutorial006.py +++ b/docs_src/extensions/tutorial006.py @@ -16,7 +16,7 @@ class ReceiptResult(types.Result): """The claimed result shape; `result_type` pins the wire tag.""" result_type: Literal["receipt"] = "receipt" - request_state: str + receipt_token: str class ReceiptIssuer(Extension): @@ -32,7 +32,7 @@ async def intercept_tool_call( ) -> HandlerResult: if params.name != "buy": return await call_next(ctx) - return {"resultType": "receipt", "requestState": "r-117"} + return {"resultType": "receipt", "receiptToken": "r-117"} class Receipts(ClientExtension): @@ -44,7 +44,7 @@ def claims(self) -> Sequence[ResultClaim[Any]]: return [ResultClaim(result_type="receipt", model=ReceiptResult, resolve=self._redeem)] async def _redeem(self, claimed: ReceiptResult, ctx: ClaimContext) -> types.CallToolResult: - return await ctx.session.call_tool("redeem", {"token": claimed.request_state}) + return await ctx.session.call_tool("redeem", {"token": claimed.receipt_token}) mcp = MCPServer("shop", extensions=[ReceiptIssuer()]) diff --git a/tests/client/test_client_extensions.py b/tests/client/test_client_extensions.py index 4bd703cd0..e45ff87cc 100644 --- a/tests/client/test_client_extensions.py +++ b/tests/client/test_client_extensions.py @@ -3,10 +3,8 @@ Claimed-shape servers here are real `MCPServer`s whose SEP-2133 server extension rewrites `tools/call` results via `intercept_tool_call` — the full public-API loop. -The in-process server can only deliver claimed fields the v2026 tools/call surface -keeps (`resultType`, `requestState`, `inputRequests`, `_meta`): the server-side -`serialize_server_result` drops anything else, so claimed payloads here ride -`requestState`. +A short-circuiting interceptor's dict reaches the client verbatim (the runner trusts +it as a well-formed result), so the claimed models carry vendor top-level fields. `tools/call` is never cached (`Client.call_tool` has no `_cached_fetch` weave and the SEP-2549 cacheable verbs do not include it), so the claim path needs no cache tests. @@ -48,11 +46,11 @@ def _name_elicitation() -> types.ElicitRequest: class VoucherResult(Result): - """The claimed `tools/call` shape, tagged `voucher`; its payload rides `requestState` - (the only open payload-bearing field the in-process server's surface dump keeps).""" + """The claimed `tools/call` shape, tagged `voucher`, carrying a vendor top-level field + (a short-circuiting server interceptor's dict reaches the client verbatim).""" result_type: Literal["voucher"] = "voucher" - request_state: str | None = None + voucher_code: str | None = None _Resolver = Callable[[VoucherResult, ClaimContext], Awaitable[CallToolResult]] @@ -78,7 +76,7 @@ class _VoucherIssuer(Extension): async def intercept_tool_call( self, params: types.CallToolRequestParams, ctx: ServerRequestContext[Any, Any], call_next: CallNext ) -> HandlerResult: - return {"resultType": "voucher", "requestState": "v-42"} + return {"resultType": "voucher", "voucherCode": "v-42"} class _TwoRoundVoucherIssuer(Extension): @@ -91,7 +89,7 @@ async def intercept_tool_call( ) -> HandlerResult: if params.input_responses is None: return types.InputRequiredResult(input_requests={"user_name": _name_elicitation()}) - return {"resultType": "voucher", "requestState": "after-input"} + return {"resultType": "voucher", "voucherCode": "after-input"} def _voucher_server(issuer: Extension | None = None) -> MCPServer: @@ -414,7 +412,7 @@ async def test_claimed_result_resolves_transparently_to_the_resolvers_result() - async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: received.append(claimed) - product = CallToolResult(content=[TextContent(text=f"honored {claimed.request_state}")]) + product = CallToolResult(content=[TextContent(text=f"honored {claimed.voucher_code}")]) produced.append(product) return product @@ -423,7 +421,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: result = await client.call_tool("issue", {}) assert_type(result, CallToolResult) - assert [claimed.request_state for claimed in received] == ["v-42"] + assert [claimed.voucher_code for claimed in received] == ["v-42"] assert result is produced[0] assert result.content == [TextContent(text="honored v-42")] @@ -442,7 +440,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: async with Client(_voucher_server(), extensions=extensions) as client: result = await client.call_tool("issue", {}) - assert [claimed.request_state for claimed in received] == ["v-42"] + assert [claimed.voucher_code for claimed in received] == ["v-42"] assert result.content == [TextContent(text="routed")] @@ -586,7 +584,7 @@ async def elicitation_callback( async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: received.append(claimed) - return CallToolResult(content=[TextContent(text=f"honored {claimed.request_state}")]) + return CallToolResult(content=[TextContent(text=f"honored {claimed.voucher_code}")]) server = _voucher_server(issuer=_TwoRoundVoucherIssuer()) with anyio.fail_after(5): @@ -596,7 +594,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: result = await client.call_tool("issue", {}) assert prompted == ["What is your name?"] - assert [claimed.request_state for claimed in received] == ["after-input"] + assert [claimed.voucher_code for claimed in received] == ["after-input"] assert result.content == [TextContent(text="honored after-input")] diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index 346c2956e..195bbf14a 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -126,7 +126,7 @@ async def test_session_tier_allow_claimed_returns_the_raw_shape() -> None: async with Client(tutorial006.mcp, extensions=[tutorial006.Receipts()]) as client: result = await client.session.call_tool("buy", {"item": "lamp"}, allow_claimed=True) assert isinstance(result, tutorial006.ReceiptResult) - assert result.request_state == "r-117" + assert result.receipt_token == "r-117" async def test_the_jobs_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index acbfd4bfb..2fbe360a2 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2428,6 +2428,19 @@ def __post_init__(self) -> None: ), arm_exclusions=(ArmExclusion(reason="requires-session", transport="streamable-http-stateless"),), ), + "extensions:client:notification-binding-delivery": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic#resulttype", + behavior=( + "A vendor server notification bound by a ClientExtension's NotificationBinding is validated " + "against the binding's params type and delivered to its handler in arrival order." + ), + added_in="2026-07-28", + deferred=( + "Covered at session tier by tests/client/test_session_notification_bindings.py: no public " + "server-side surface emits vendor-method notifications (ServerNotification is a closed union), " + "and HTTP-modern arrival additionally needs the subscriptions/listen client runtime." + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Transports (in-suite coverage) # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/interaction/mcpserver/test_extensions.py b/tests/interaction/mcpserver/test_extensions.py index a327cca9a..72d0e5777 100644 --- a/tests/interaction/mcpserver/test_extensions.py +++ b/tests/interaction/mcpserver/test_extensions.py @@ -2,12 +2,11 @@ The servers here are MCPServers whose server extension substitutes a claimed `tools/call` shape via `intercept_tool_call`; the client declares the owning `ClientExtension` and its -claim resolver finishes the call. The in-process server's 2026 result surface keeps only -`resultType` / `requestState` / `inputRequests` / `_meta` on a claimed result, so claimed -payloads here ride `requestState`. +claim resolver finishes the call. A short-circuiting interceptor's dict is passed through +verbatim (the runner trusts it as a well-formed result), so claimed shapes carry their +vendor fields end to end — the models below prove that with top-level vendor fields. """ -import json from collections.abc import Awaitable, Callable, Sequence from typing import Any, Literal @@ -32,10 +31,11 @@ class ReceiptResult(Result): - """The claimed `tools/call` shape, tagged `receipt`; its payload rides `requestState`.""" + """The claimed `tools/call` shape, tagged `receipt`, carrying vendor top-level fields.""" result_type: Literal["receipt"] = "receipt" - request_state: str + receipt_token: str + settings_echo: dict[str, Any] | None = None _Resolver = Callable[[ReceiptResult, ClaimContext], Awaitable[CallToolResult]] @@ -67,7 +67,7 @@ async def intercept_tool_call( ) -> HandlerResult: if params.name != "buy": return await call_next(ctx) - return {"resultType": "receipt", "requestState": "r-117"} + return {"resultType": "receipt", "receiptToken": "r-117"} def _receipt_shop(issuer: Extension) -> MCPServer: @@ -97,12 +97,12 @@ async def test_claimed_result_is_finished_by_the_owning_extensions_resolver(conn async def redeem_receipt(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolResult: received.append(claimed) - return await ctx.session.call_tool("redeem", {"token": claimed.request_state}) + return await ctx.session.call_tool("redeem", {"token": claimed.receipt_token}) async with connect(_receipt_shop(_ReceiptIssuer()), extensions=[Receipts(redeem_receipt)]) as client: result = await client.call_tool("buy", {"item": "lamp"}) - assert [claimed.request_state for claimed in received] == ["r-117"] + assert [claimed.receipt_token for claimed in received] == ["r-117"] assert result == snapshot( CallToolResult(content=[TextContent(text="goods for r-117")], structured_content={"result": "goods for r-117"}) ) @@ -132,7 +132,7 @@ async def intercept_tool_call( assert client_params is not None # require_client_extension just read it extensions = client_params.capabilities.extensions assert extensions is not None - return {"resultType": "receipt", "requestState": json.dumps(extensions[_RECEIPTS], sort_keys=True)} + return {"resultType": "receipt", "receiptToken": "echo", "settingsEcho": extensions[_RECEIPTS]} @requirement("extensions:client:capability-ad:gates-server-behaviour") @@ -157,7 +157,7 @@ async def keep(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolResult: async with connect(server, extensions=[Receipts(keep, settings={"tier": "gold"})]) as client: result = await client.call_tool("buy", {"item": "lamp"}) assert result.content == [TextContent(text="done")] - assert [json.loads(claimed.request_state) for claimed in received] == [{"tier": "gold"}] + assert [claimed.settings_echo for claimed in received] == [{"tier": "gold"}] async with connect(server) as client: with pytest.raises(MCPError) as exc_info: From 616e81bf1126c247c8e6e1400986fd2be7e6255f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 16:49:08 +0000 Subject: [PATCH 11/15] Move the identifier grammar matrix to the shared test home The accepted/rejected identifier parametrizations exercise the raw validator, which lives in mcp.shared.extension; the server suite keeps its class-level validation tests. --- tests/server/mcpserver/test_extension.py | 45 ---------------- tests/shared/test_extension.py | 65 +++++++++++++++++------- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/tests/server/mcpserver/test_extension.py b/tests/server/mcpserver/test_extension.py index 2ed093ac0..b6ff0283d 100644 --- a/tests/server/mcpserver/test_extension.py +++ b/tests/server/mcpserver/test_extension.py @@ -27,7 +27,6 @@ ResourceBinding, ToolBinding, compose_tool_call_interceptor, - validate_extension_identifier, ) from mcp.server.mcpserver import Context, MCPServer, require_client_extension from mcp.server.mcpserver.resources import TextResource @@ -363,50 +362,6 @@ async def test_version_pinned_method_is_method_not_found_at_a_disallowed_version assert exc_info.value.error.data == "com.example/pinned" -@pytest.mark.parametrize( - "identifier", - [ - "io.modelcontextprotocol/ui", - "com.example/my_ext", - "com.x-y.z2/n.a-b_c", - "example/x", - "a/b", - "com.example/9start", - ], -) -def test_grammar_conformant_extension_identifiers_are_accepted(identifier: str) -> None: - """Spec `_meta` key grammar: dot-separated labels (letter start, letter/digit end, - hyphens interior), a slash, then a name that starts and ends alphanumeric.""" - validate_extension_identifier(identifier, owner="T") - - -@pytest.mark.parametrize( - "identifier", - [ - "noprefix", - "-foo/bar", - ".leading/x", - "a..b/x", - "foo-/x", - "9foo/x", - "foo/-bar", - "foo/bar-", - "foo/", - "/bar", - "foo/ba r", - "io.modelcontextprotocol/ui\n", - "", - None, - 42, - ], -) -def test_malformed_extension_identifiers_are_rejected(identifier: Any) -> None: - """Spec `_meta` key grammar: malformed prefixes (bad label start/end, empty labels) - and malformed names are rejected, as are non-strings.""" - with pytest.raises(TypeError): - validate_extension_identifier(identifier, owner="T") - - @pytest.mark.parametrize("method", ["tools/list", "completion/complete"]) def test_method_binding_rejects_spec_methods(method: str) -> None: """SDK-defined: extension methods are additive — binding a spec-defined request method diff --git a/tests/shared/test_extension.py b/tests/shared/test_extension.py index 4a94c75b4..87383d28a 100644 --- a/tests/shared/test_extension.py +++ b/tests/shared/test_extension.py @@ -1,31 +1,60 @@ """Tests for `mcp.shared.extension` — the extension-identifier grammar shared by -the server and client extension surfaces. +the server and client extension surfaces.""" -The grammar matrix (accepted and rejected identifiers) lives with the original -server tests in `tests/server/mcpserver/test_extension.py`, which exercise the -same function via the server module's re-export. -""" +from typing import Any import pytest import mcp.server.extension import mcp.shared.extension - - -def test_validator_importable_from_shared_home() -> None: - """SDK-defined: the identifier grammar lives in `mcp.shared.extension` — one - source of truth for both the server and client extension surfaces.""" - mcp.shared.extension.validate_extension_identifier("com.example/thing", owner="T") - - -def test_validator_rejects_malformed_identifier_via_shared_path() -> None: - """SDK-defined: the shared-home function enforces the same `vendor-prefix/name` - grammar the server side always has.""" - with pytest.raises(TypeError): - mcp.shared.extension.validate_extension_identifier("noprefix", owner="T") +from mcp.shared.extension import validate_extension_identifier def test_server_extension_module_reexports_shared_validator() -> None: """SDK-defined: `mcp.server.extension.validate_extension_identifier` remains importable after the move and is the very same function object.""" assert mcp.server.extension.validate_extension_identifier is mcp.shared.extension.validate_extension_identifier + + +@pytest.mark.parametrize( + "identifier", + [ + "io.modelcontextprotocol/ui", + "com.example/my_ext", + "com.x-y.z2/n.a-b_c", + "example/x", + "a/b", + "com.example/9start", + ], +) +def test_grammar_conformant_extension_identifiers_are_accepted(identifier: str) -> None: + """Spec `_meta` key grammar: dot-separated labels (letter start, letter/digit end, + hyphens interior), a slash, then a name that starts and ends alphanumeric.""" + validate_extension_identifier(identifier, owner="T") + + +@pytest.mark.parametrize( + "identifier", + [ + "noprefix", + "-foo/bar", + ".leading/x", + "a..b/x", + "foo-/x", + "9foo/x", + "foo/-bar", + "foo/bar-", + "foo/", + "/bar", + "foo/ba r", + "io.modelcontextprotocol/ui\n", + "", + None, + 42, + ], +) +def test_malformed_extension_identifiers_are_rejected(identifier: Any) -> None: + """Spec `_meta` key grammar: malformed prefixes (bad label start/end, empty labels) + and malformed names are rejected, as are non-strings.""" + with pytest.raises(TypeError): + validate_extension_identifier(identifier, owner="T") From 15ff3c205a974037a88ce69ed2de3f6792e231a6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 17:04:12 +0000 Subject: [PATCH 12/15] Fix core-arm sentinel collision and close binding queues on every exit A claim tagged 'core' could collide with the adapter's internal routing sentinel and hijack ordinary tools/call parsing; the sentinel is now derived to never equal a claimed tag. Binding queues close in finally blocks so a raising task-group exit cannot leak them. The mapping-form migration error now also covers an empty dict; claim dedup keys by resultType alone (the verb is single-valued at construction, and the activation map already keys by tag); the one-time private-spelling tree-grep test is gone; docstring backticks and a stray __future__ import brought in line with repo conventions. --- src/mcp/client/client.py | 26 ++++++++--------- src/mcp/client/session.py | 34 +++++++++++++--------- tests/client/test_client_extensions.py | 6 ++-- tests/client/test_send_request_mcp_name.py | 1 - tests/client/test_session_claims.py | 31 +++++++++++++++++++- tests/client/test_session_promotions.py | 16 ---------- 6 files changed, 65 insertions(+), 49 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 00e9f4b9b..712934035 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -207,22 +207,22 @@ def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExt Mirrors the server's consumption-time posture (`MCPServer._apply_extension`): a per-instance identifier is validated here because no class attribute existed to validate at definition time. `settings()` is read exactly once per extension and - the returned dict is held by reference. Duplicate `(method, resultType)` claims - and duplicate notification methods are rejected here, where both owning extensions + the returned dict is held by reference. Duplicate resultType claims and + duplicate notification methods are rejected here, where both owning extensions can be named — the session's own duplicate checks know only methods and tags. """ - if not extensions: - return _FoldedExtensions(ad=None, claims=None, bindings=None, by_model={}) if isinstance(extensions, Mapping): raise TypeError( "extensions= takes a sequence of ClientExtension instances; the mapping form was " "replaced — use advertise(identifier, settings) for advertise-only entries" ) + if not extensions: + return _FoldedExtensions(ad=None, claims=None, bindings=None, by_model={}) ad: dict[str, dict[str, Any]] = {} claims: dict[str, tuple[ResultClaim[Any], ...]] = {} bindings: list[NotificationBinding[Any]] = [] by_model: dict[type[Result], ResultClaim[Any]] = {} - claim_owners: dict[tuple[str, str], str] = {} + claim_owners: dict[str, str] = {} binding_owners: dict[str, str] = {} for extension in extensions: identifier = getattr(extension, "identifier", None) @@ -237,20 +237,18 @@ def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExt ad[identifier] = extension.settings() extension_claims = tuple(extension.claims()) for claim in extension_claims: - key = (claim.method, claim.result_type) - if key in claim_owners: - owner = claim_owners[key] + tag = claim.result_type + if tag in claim_owners: + owner = claim_owners[tag] both = ( f"extension {identifier!r} claims" if owner == identifier else (f"extensions {owner!r} and {identifier!r} both claim") ) - raise ValueError( - f"{both} {claim.method!r} resultType {claim.result_type!r}; a wire tag can have only one resolver" - ) - claim_owners[key] = identifier - # Collision-free by construction: a model's `result_type` Literal pins it to - # exactly one tag, and each (method, tag) pair has exactly one owner. + raise ValueError(f"{both} resultType {tag!r}; a wire tag can have only one resolver") + claim_owners[tag] = identifier + # One model, one tag: the model's result_type Literal is pinned to exactly + # this tag at claim construction, so the type-keyed index cannot collide. by_model[claim.model] = claim if extension_claims: claims[identifier] = extension_claims diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 948ac2e0c..ed95165ac 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -241,6 +241,9 @@ def _build_call_tool_adapter( if not active: return _CallToolResultAdapter tags = frozenset(active) + core_arm = "core" + while core_arm in tags: # the routing sentinel must never collide with a claimed tag + core_arm += "-" def _route(value: Any) -> str: # pydantic hands the discriminator either the raw inbound dict or an @@ -251,9 +254,9 @@ def _route(value: Any) -> str: tag = cast("dict[str, Any]", value).get("resultType") else: tag = getattr(value, "result_type", None) - return tag if isinstance(tag, str) and tag in tags else "core" + return tag if isinstance(tag, str) and tag in tags else core_arm - arms: list[Any] = [Annotated[types.CallToolResult | types.InputRequiredResult, Tag("core")]] + arms: list[Any] = [Annotated[types.CallToolResult | types.InputRequiredResult, Tag(core_arm)]] arms += [Annotated[claim.model, Tag(tag)] for tag, claim in active.items()] # reduce(or_, ...) builds the Union dynamically; PEP-646 star-unpack needs py3.11+. return TypeAdapter(Annotated[reduce(or_, arms), Discriminator(_route)]) @@ -270,7 +273,7 @@ def _index_claims( advertised extension and no wire tag may be claimed twice. """ indexed: dict[str, tuple[ResultClaim[Any], ...]] = {} - seen: set[tuple[str, str]] = set() + seen: set[str] = set() for identifier, claims in (result_claims or {}).items(): if extensions is None or identifier not in extensions: raise ValueError( @@ -283,10 +286,9 @@ def _index_claims( "extension from the capability ad at every version — omit the key instead" ) for claim in claims: - key = (claim.method, claim.result_type) - if key in seen: - raise ValueError(f"duplicate result claim for {claim.method!r} resultType {claim.result_type!r}") - seen.add(key) + if claim.result_type in seen: + raise ValueError(f"duplicate result claim for resultType {claim.result_type!r}") + seen.add(claim.result_type) indexed[identifier] = tuple(claims) return indexed @@ -411,8 +413,10 @@ async def __aenter__(self) -> Self: # Shield the group's own scope (a new one would break LIFO exit) # so a pending outer cancellation cannot re-fire inside __aexit__. task_group.cancel_scope.shield = True - await task_group.__aexit__(None, None, None) - self._close_binding_queues() + try: + await task_group.__aexit__(None, None, None) + finally: + self._close_binding_queues() raise return self @@ -425,8 +429,10 @@ async def __aexit__( # Exit must not block: cancel the dispatcher, binding consumers, and in-flight callbacks. assert self._task_group is not None self._task_group.cancel_scope.cancel() - result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - self._close_binding_queues() + try: + result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + finally: + self._close_binding_queues() await resync_tracer() return result @@ -975,15 +981,15 @@ async def call_tool( allow_input_required: When ``False`` (default), an `InputRequiredResult` from the server raises `RuntimeError`; when ``True``, it is returned so the caller can resolve the requests and retry. - allow_claimed: When ``False`` (default), a claimed (extension) result - shape raises `UnexpectedClaimedResult`; when ``True``, the parsed + allow_claimed: When `False` (default), a claimed (extension) result + shape raises `UnexpectedClaimedResult`; when `True`, the parsed claim model is returned for the caller to handle. Raises: RuntimeError: If the server returns an `InputRequiredResult` and ``allow_input_required`` is ``False``. UnexpectedClaimedResult: If a claimed result shape parses and - ``allow_claimed`` is ``False``; carries the parsed value. + `allow_claimed` is `False`; carries the parsed value. """ result = await self.send_request( types.CallToolRequest( diff --git a/tests/client/test_client_extensions.py b/tests/client/test_client_extensions.py index e45ff87cc..802d8a66f 100644 --- a/tests/client/test_client_extensions.py +++ b/tests/client/test_client_extensions.py @@ -185,7 +185,7 @@ def test_one_extension_claiming_a_tag_twice_reads_as_one_owner() -> None: Client(_add_server(), extensions=[_SelfConflictingClaims()]) assert str(exc_info.value) == snapshot( - "extension 'com.example/twice' claims 'tools/call' resultType 'twice'; a wire tag can have only one resolver" + "extension 'com.example/twice' claims resultType 'twice'; a wire tag can have only one resolver" ) @@ -252,8 +252,8 @@ def test_conflicting_claims_across_extensions_name_both_owners() -> None: Client(_add_server(), extensions=[_VoucherExtension(_unreachable_resolve), _RivalVoucherExtension()]) assert str(exc_info.value) == snapshot( - "extensions 'com.example/voucher' and 'com.example/rival' both claim 'tools/call' " - "resultType 'voucher'; a wire tag can have only one resolver" + "extensions 'com.example/voucher' and 'com.example/rival' both claim resultType " + "'voucher'; a wire tag can have only one resolver" ) diff --git a/tests/client/test_send_request_mcp_name.py b/tests/client/test_send_request_mcp_name.py index 8cd045787..3963445e7 100644 --- a/tests/client/test_send_request_mcp_name.py +++ b/tests/client/test_send_request_mcp_name.py @@ -7,7 +7,6 @@ `send_request` typing: a `Request[...]` subclass passes without a cast. """ -from __future__ import annotations from collections.abc import Mapping from typing import Any, Literal diff --git a/tests/client/test_session_claims.py b/tests/client/test_session_claims.py index 403453d11..8f275ef14 100644 --- a/tests/client/test_session_claims.py +++ b/tests/client/test_session_claims.py @@ -139,7 +139,7 @@ def test_duplicate_claim_tag_across_extensions_rejected() -> None: result_claims={_TASKS_EXT: [_task_claim()], _AD_ONLY_EXT: [_task_claim()]}, ) - assert str(exc_info.value) == snapshot("duplicate result claim for 'tools/call' resultType 'task'") + assert str(exc_info.value) == snapshot("duplicate result claim for resultType 'task'") def test_claims_keyed_to_unadvertised_extension_rejected() -> None: @@ -351,6 +351,35 @@ async def test_discover_probe_ad_drops_claim_identifiers_at_a_legacy_probe_versi assert "extensions" not in capabilities +class _CoreTaggedResult(Result): + """A claim whose wire tag collides with the adapter's internal routing sentinel.""" + + result_type: Literal["core"] = "core" + payload: str = "" + + +async def _resolve_core_tagged(result: _CoreTaggedResult, ctx: ClaimContext) -> CallToolResult: + raise NotImplementedError + + +@pytest.mark.anyio +async def test_claim_tagged_core_cannot_hijack_core_parsing() -> None: + """SDK-defined: "core" is not protocol vocabulary, so a claim may use it as a wire + tag — and the adapter's internal routing sentinel must not collide: ordinary tool + results still parse as core results, and a claimed `core` raw routes to the model.""" + claim = ResultClaim(result_type="core", model=_CoreTaggedResult, resolve=_resolve_core_tagged) + dispatcher = _RecordingDispatcher(tool_result={"resultType": "core", "payload": "p-1"}) + session = ClientSession(dispatcher=dispatcher, extensions={_TASKS_EXT: {}}, result_claims={_TASKS_EXT: [claim]}) + with anyio.fail_after(5): + async with session: + _adopt_modern(session) + ordinary = session._call_tool_adapter.validate_python(_COMPLETE_TOOL_RESULT) + claimed = await session.call_tool("t", {}, allow_claimed=True) + + assert isinstance(ordinary, CallToolResult) + assert isinstance(claimed, _CoreTaggedResult) + + # ── Routing through the adopt-built adapter ───────────────────────────────── diff --git a/tests/client/test_session_promotions.py b/tests/client/test_session_promotions.py index f944ef7e0..dfceca302 100644 --- a/tests/client/test_session_promotions.py +++ b/tests/client/test_session_promotions.py @@ -1,7 +1,5 @@ """`dispatch_input_request` and `validate_tool_result` are public `ClientSession` API.""" -import re -from pathlib import Path import mcp_types as types import pytest @@ -68,17 +66,3 @@ async def test_validate_tool_result_raises_on_schema_mismatch() -> None: # Stable SDK prefix only: the message tail is jsonschema text that shifts with the dependency. with pytest.raises(RuntimeError, match="Invalid structured content returned by tool t"): await client.session.validate_tool_result("t", CallToolResult(content=[], structured_content={"x": "no"})) - - -def _spell_private(name: str) -> str: - return f"_{name}" - - -def test_no_private_spelling_references_remain() -> None: - """The promotions are renames, not aliases — the old private names are gone from `src/`.""" - pattern = re.compile(f"{_spell_private('dispatch_input_request')}|{_spell_private('validate_tool_result')}") - src = Path(__file__).resolve().parents[2] / "src" - offenders = [ - (path.name, match) for path in sorted(src.rglob("*.py")) for match in pattern.findall(path.read_text()) - ] - assert not offenders From 815784660302f875bad336e52c1c269801be7495 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 17:36:54 +0000 Subject: [PATCH 13/15] Tighten comments and docstrings across the extension surface Comments state only non-inferable constraints, one line where possible; docstrings follow Google format with a single-sentence summary and tightened Raises sections; development narration and restated-code comments are gone. Prose sticks to ASCII. --- docs/advanced/extensions.md | 23 ++-- docs/migration.md | 14 +- examples/stories/custom_methods/README.md | 3 +- examples/stories/custom_methods/client.py | 3 +- src/mcp-types/mcp_types/_types.py | 8 +- src/mcp/client/client.py | 46 ++----- src/mcp/client/extension.py | 115 ++++------------ src/mcp/client/session.py | 112 +++++---------- src/mcp/server/extension.py | 3 +- src/mcp/shared/extension.py | 7 +- tests/client/test_client_extensions.py | 129 ++++++------------ tests/client/test_extension.py | 93 ++++--------- tests/client/test_send_request_mcp_name.py | 39 ++---- tests/client/test_session_claims.py | 123 +++++------------ .../test_session_notification_bindings.py | 55 +++----- tests/client/test_session_promotions.py | 4 +- tests/docs_src/test_extensions.py | 12 +- tests/interaction/_requirements.py | 4 +- .../interaction/mcpserver/test_extensions.py | 42 ++---- .../transports/test_hosting_http_modern.py | 12 +- tests/shared/test_extension.py | 12 +- tests/types/test_request_name_param.py | 2 +- 22 files changed, 250 insertions(+), 611 deletions(-) diff --git a/docs/advanced/extensions.md b/docs/advanced/extensions.md index 832366aac..3646fdf0c 100644 --- a/docs/advanced/extensions.md +++ b/docs/advanced/extensions.md @@ -157,16 +157,16 @@ client-side behaviour behind one identifier. Pass instances to `call_tool("buy", ...)` returns a plain `CallToolResult`, like every other call. What the extension changed: the server may now answer `buy` with a `receipt` **result -shape** instead of a final result, and `Receipts` finishes it — here by redeeming the -receipt with a follow-up call — before `call_tool` returns. Nothing about the call +shape** instead of a final result, and `Receipts` finishes it (here by redeeming the +receipt with a follow-up call) before `call_tool` returns. Nothing about the call site moves. Drop the extension and none of this exists: a `receipt` shape arriving at a client that didn't declare it fails validation, exactly as the spec requires for an unrecognized `resultType`. Off by default, on both ends of the wire. -To advertise an identifier with **no** client-side behaviour — the server gates on -the capability, the client does nothing, as in the search client above — use +To advertise an identifier with **no** client-side behaviour (the server gates on +the capability, the client does nothing, as in the search client above), use `advertise()`: ```python @@ -189,7 +189,7 @@ kinds, each with a default: `settings()`, `claims()`, and `notifications()`. * `claims()` returns `ResultClaim`s: a wire tag, the model that parses it, and the resolver that finishes it. The model must pin the tag with `result_type: Literal["receipt"]` and must not subclass the verb's core result - types — both enforced when the claim is constructed. Vendor fields like + types; both are enforced when the claim is constructed. Vendor fields like `receipt_token` ride the wire as-is: a substituted shape reaches the client verbatim. * The resolver receives the parsed model and a `ClaimContext`; `ctx.session` is the @@ -210,11 +210,10 @@ or reply. Two quiet rules. Claims are active on 2026-07-28 connections only, and the capability ad follows them: on a legacy connection the claims dissolve and the identifier drops -out of the ad in the same breath, so the client never advertises an extension whose -shapes it would reject. And when you want the claimed shape yourself instead of the -resolver, call `client.session.call_tool(..., allow_claimed=True)` — the escape hatch -`UnexpectedClaimedResult` names when a claimed shape reaches a session-tier caller -that didn't opt in. +out of the ad with them, so the client never advertises an extension whose shapes it +would reject. And when you want the claimed shape yourself instead of the resolver, +call `client.session.call_tool(..., allow_claimed=True)`; without that flag, a +claimed shape reaching a session-tier caller raises `UnexpectedClaimedResult`. ### Extension verbs @@ -233,8 +232,8 @@ missing value fails loudly rather than silently omitting a required header. ## What an extension cannot do -The contribution surface is **closed** on purpose — on the server: settings, tools, -resources, methods, one `tools/call` interceptor; on the client: settings, result +The contribution surface is **closed** on purpose. On the server: settings, tools, +resources, methods, one `tools/call` interceptor. On the client: settings, result claims, notification bindings. An extension cannot: * **Reach into the host.** It declares data; it holds no server or client reference. diff --git a/docs/migration.md b/docs/migration.md index c377d6742..eb6f68a71 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -483,9 +483,9 @@ default and never alter behaviour unless registered. (The low-level Changed in the v2 pre-releases: earlier alphas took `Client(extensions={identifier: settings})`, an advertisement-only dict. -Extensions now contribute behaviour — claims and notification handlers — not -just an ad, and a sequence of declaration objects is the shape that can carry -that. An ad-only entry becomes an `advertise()` call: +Extensions now contribute behaviour (claims and notification handlers), not +just an ad, so the argument is a sequence of declaration objects. An ad-only +entry becomes an `advertise()` call: **Before (v2 alphas):** @@ -501,10 +501,10 @@ from mcp.client import advertise client = Client(server, extensions=[advertise("com.example/ui", {"mimeTypes": [...]})]) ``` -`advertise()` is only for identifiers with no client-side behaviour. -For a behavioural extension — e.g. tasks, once its extension ships — construct -that extension's object instead; advertising an identifier you do not -implement asserts wire support you don't have. +`advertise()` is only for identifiers with no client-side behaviour. For a +behavioural extension (e.g. tasks, once its extension ships), construct that +extension's object instead; advertising an identifier you do not implement +asserts wire support you don't have. ### `McpError` renamed to `MCPError` diff --git a/examples/stories/custom_methods/README.md b/examples/stories/custom_methods/README.md index 817437b3a..96d223fef 100644 --- a/examples/stories/custom_methods/README.md +++ b/examples/stories/custom_methods/README.md @@ -29,8 +29,7 @@ uv run python -m stories.custom_methods.client --http collide with a future spec method. - `client.py` `client.session.send_request(...)` — `Client` only exposes spec verbs, so vendor methods go through the underlying `ClientSession`. - `send_request` accepts any `types.Request` subclass, so the vendor request - passes as-is, no cast. + `send_request` accepts any `types.Request` subclass. ## Caveats diff --git a/examples/stories/custom_methods/client.py b/examples/stories/custom_methods/client.py index 5282d584d..7bf27dd76 100644 --- a/examples/stories/custom_methods/client.py +++ b/examples/stories/custom_methods/client.py @@ -27,8 +27,7 @@ async def main(target: Target, *, mode: str = "auto") -> None: # `Client` only exposes spec-defined verbs, so vendor methods have to drop one # layer to `client.session` today — there is no `Client`-level API for them # yet, and whether `.session` stays public is undecided. `send_request` - # accepts any `Request` subclass; the unknown method skips the per-spec - # result-validation registry. + # accepts any `Request` subclass. request = SearchRequest(params=SearchParams(query="mcp", limit=3)) result = await client.session.send_request(request, SearchResult) assert result.items == ["mcp-0", "mcp-1", "mcp-2"], result diff --git a/src/mcp-types/mcp_types/_types.py b/src/mcp-types/mcp_types/_types.py index 2b5698770..9c0516836 100644 --- a/src/mcp-types/mcp_types/_types.py +++ b/src/mcp-types/mcp_types/_types.py @@ -129,10 +129,10 @@ class Request(MCPModel, Generic[RequestParamsT, MethodT]): params: RequestParamsT name_param: ClassVar[str | None] = None - """Wire-params key mirrored into the `Mcp-Name` header on sends (SEP-2243 - family; SEP-2663 requires it for tasks/*). The request type declares; the - host emits. Subclasses override by bare assignment (`name_param = "taskId"`) - — re-annotating as `ClassVar[str]` trips pyright's ClassVar invariance.""" + """Wire-params key mirrored into the `Mcp-Name` header on sends; SEP-2663 requires it for tasks/*. + + Subclasses override by bare assignment: re-annotating as `ClassVar` trips pyright's invariance check. + """ class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 712934035..c2b34768c 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -202,19 +202,11 @@ class _FoldedExtensions: def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExtensions: - """Validate extension instances and fold their contributions, once, at `Client` construction. - - Mirrors the server's consumption-time posture (`MCPServer._apply_extension`): a - per-instance identifier is validated here because no class attribute existed to - validate at definition time. `settings()` is read exactly once per extension and - the returned dict is held by reference. Duplicate resultType claims and - duplicate notification methods are rejected here, where both owning extensions - can be named — the session's own duplicate checks know only methods and tags. - """ + """Fold extension contributions at construction, naming both owners on duplicate tags or methods.""" if isinstance(extensions, Mapping): raise TypeError( - "extensions= takes a sequence of ClientExtension instances; the mapping form was " - "replaced — use advertise(identifier, settings) for advertise-only entries" + "extensions= takes a sequence of ClientExtension instances. The mapping form was " + "replaced: use advertise(identifier, settings) for advertise-only entries" ) if not extensions: return _FoldedExtensions(ad=None, claims=None, bindings=None, by_model={}) @@ -247,8 +239,7 @@ def _fold_extensions(extensions: Sequence[ClientExtension] | None) -> _FoldedExt ) raise ValueError(f"{both} resultType {tag!r}; a wire tag can have only one resolver") claim_owners[tag] = identifier - # One model, one tag: the model's result_type Literal is pinned to exactly - # this tag at claim construction, so the type-keyed index cannot collide. + # Each model pins its result_type Literal to one tag, so this index cannot collide. by_model[claim.model] = claim if extension_claims: claims[identifier] = extension_claims @@ -349,13 +340,9 @@ async def main(): extensions: Sequence[ClientExtension] | None = None """Opt-in client extensions (SEP-2133). - Each instance contributes its capability ad (advertised under - `ClientCapabilities.extensions`), its result claims (extra `tools/call` result - shapes that `call_tool` resolves transparently through the claim's resolver), - and its notification bindings. For an ad-only entry — an identifier plus - settings, no client-side behaviour — use `mcp.client.advertise(identifier, - settings)`. Each extension's `settings()` is read once, at construction; the - returned dict is held by reference.""" + Each instance contributes its capability ad, its result claims (resolved + transparently by `call_tool`), and its notification bindings. For an + ad-only entry use `mcp.client.advertise(identifier, settings)`.""" cache: CacheConfig | Literal[False] | None = None """Client-side response caching for the SEP-2549 cacheable methods (2026-07-28). @@ -701,11 +688,9 @@ async def call_tool( persist `request_state` across process restarts — use `client.session.call_tool(..., allow_input_required=True)`. - If the server returns a result shape claimed by one of this client's - `extensions`, the owning claim's resolver finishes the call and its - `CallToolResult` is returned — the claimed shape never surfaces here. - Resolver exceptions propagate as-is; the extension owns its error - vocabulary. To receive the claimed shape yourself, use + Result shapes claimed by this client's `extensions` are finished by the + owning claim's resolver, whose `CallToolResult` is returned; resolver + exceptions propagate as-is. To receive the claimed shape yourself, use `client.session.call_tool(..., allow_claimed=True)`. Args: @@ -736,26 +721,21 @@ async def retry(r: InputResponses | None, s: str | None) -> CallToolResult | Inp request_state=s, meta=meta, allow_input_required=True, - # The driver's retry leg must also admit claimed shapes — the spec - # resolves multi-round-trip input before a claimed result, so a claim - # may terminate any round, not just the first. + # Input rounds resolve before a claimed result, so a claim may end any round. allow_claimed=True, ) result = await self._drive_input_required(await retry(input_responses, request_state), retry) if isinstance(result, CallToolResult): return result - # Only claimed shapes escape the parse (`_drive_input_required` never returns an - # `InputRequiredResult`), so the lookup is total; a KeyError here is an SDK bug. + # Only claimed shapes reach this point, so the lookup is total. claim = self._folded_extensions.by_model[type(result)] final = await claim.resolve( result, ClaimContext(session=self.session, tool_name=name, read_timeout_seconds=read_timeout_seconds), ) if not final.is_error: - # The resolver's product gets the same output-schema revalidation as the - # direct path (`ClientSession.call_tool`'s own guard); isError results - # must not raise, also matching the direct path. + # Match the direct path: revalidate the output schema, but never for isError results. await self.session.validate_tool_result(name, final) return final diff --git a/src/mcp/client/extension.py b/src/mcp/client/extension.py index d7cdb2522..adc15c1b7 100644 --- a/src/mcp/client/extension.py +++ b/src/mcp/client/extension.py @@ -1,11 +1,8 @@ """Opt-in extension interface for MCP clients. -To make an extension: subclass `ClientExtension`, set `identifier`, and -override whichever of `settings()` / `claims()` / `notifications()` apply. To -use one: pass instances to `Client(extensions=[...])` — the client folds the -declarations into its own machinery; the extension never receives the client. -To advertise an extension identifier with no client-side behaviour, use -`advertise()`. +Subclass `ClientExtension`, set `identifier`, override the hooks you need, and +pass instances to `Client(extensions=[...])`. For an identifier-only +capability ad, use `advertise()`. """ from __future__ import annotations @@ -33,7 +30,7 @@ ] _CLAIM_METHODS: Final[frozenset[str]] = frozenset({"tools/call"}) -"""The closed set of verbs a claim may attach to (widened with the `method` Literal).""" +"""The closed set of verbs a claim may attach to; widen together with the `method` Literal.""" ClaimedT = TypeVar("ClaimedT", bound=Result) NotifyParamsT = TypeVar("NotifyParamsT", bound=BaseModel) @@ -41,12 +38,7 @@ @dataclass(frozen=True, kw_only=True) class ClaimContext: - """Host-injected context for one `ResultClaim.resolve` call. - - `session` is the sanctioned public low-level handle — the same one users - already reach via `client.session`; the resolver gets no `Client` and no - new authority. - """ + """Host-injected context for one `ResultClaim.resolve` call.""" session: ClientSession tool_name: str @@ -57,25 +49,10 @@ class ClaimContext: class ResultClaim(Generic[ClaimedT]): """One extra result shape on one spec verb, keyed by the wire `resultType`. - A claim is active only while the declaring extension is constructed in AND - the negotiated version admits it; otherwise parsing stays byte-identical to - a claim-less client, so an undeclared shape still fails validation — the - supported `resultType` set is always core plus declared claims. - - `resolve` finishes a claimed result on the transparent path: it may send - follow-ups through `ctx.session` and must return the verb's ordinary - result. It is required — a claim nothing can finish would be useless. A - package that wants explicit-only handling ships a resolver that raises a - typed error naming `session.call_tool(allow_claimed=True)`, which is also - how callers reach the undriven shape per-call. - - `model` must declare `result_type` as a Literal of exactly the claimed tag, - and must not subclass a core result type — a core subclass would satisfy - the session's isinstance branches and bypass claim routing. `protocol_versions`, - when set, restricts the claim to a subset of the modern protocol revisions; - `None` (the default) means every modern version. The modern floor is - structural, not a restriction: claimed shapes cannot be delivered on a - legacy wire. All of this is enforced at construction. + Active only while the declaring extension is constructed into the client and + the negotiated protocol version admits it. `resolve` finishes a claimed + result, may send follow-ups through `ctx.session`, and must return the + verb's ordinary result. All field constraints are enforced at construction. """ result_type: str @@ -105,15 +82,10 @@ def __post_init__(self) -> None: class UnexpectedClaimedResult(RuntimeError): - """A claimed (extension) result shape arrived on a `call_tool` that did not opt in. - - Raised by `ClientSession.call_tool` when a claimed shape parses and - `allow_claimed` is False. By the time this raises the server may have - durably created state (e.g. a task) — the parsed value is carried as - `result` so the caller can reach its id to clean up, not just read a - message. To handle claimed shapes, pass the owning extension to - `Client(extensions=[...])` (the transparent path) or call with - `allow_claimed=True` and handle the shape yourself. + """A claimed (extension) result arrived on a `call_tool` that did not opt in. + + The parsed value is carried as `result`; the server may already hold state it + references. Opt in via `Client(extensions=[...])` or `allow_claimed=True`. """ def __init__(self, result: Result) -> None: @@ -127,23 +99,11 @@ def __init__(self, result: Result) -> None: @dataclass(frozen=True, kw_only=True) class NotificationBinding(Generic[NotifyParamsT]): - """Deliver server notifications for `method` to `handler` (unbound methods stay silently dropped). - - Observation-only: the handler receives validated params, returns None, and - cannot short-circuit anything. Delivery is per-binding serialized through a - bounded FIFO — one consumer task per binding, so a handler sees events in - arrival order and may do session I/O without deadlocking the in-process - dispatch path; on overflow the oldest event is dropped with a warning - (observation semantics make the drop acceptable). - - There is deliberately no spec-table check at construction: bindings are - consulted only for methods the negotiated version's core tables do NOT - know, so they are additive by construction. If a future core version - adopts the method, the binding goes quiet — detected and warned once at - activation, not per delivery — instead of import-erroring every package. - - `method` is the bare wire name (e.g. `notifications/tasks`); `params_type` - validates the notification params before `handler` runs. + """Deliver server notifications for `method` (the bare wire name) to `handler`. + + Observation-only: validated params arrive in order through a bounded queue, + dropping the oldest with a warning on overflow. Methods the negotiated + version's core tables handle are never delivered to bindings. """ method: str @@ -152,15 +112,9 @@ class NotificationBinding(Generic[NotifyParamsT]): class ClientExtension: - """Base class for an opt-in client extension. Override only what you need. - - Mirror of `mcp.server.extension.Extension` in feel: a closed declarative - surface, fixed at construction, that never receives the client. The - contribution kinds are the ones a 2026 client actually has — there is - deliberately no served-request kind (servers do not initiate requests) and - no open interceptor (the only sanctioned augmentation is extension - `resultType` values, and a claim already names its owner, so composition - and ordering questions dissolve by construction). + """Base class for an opt-in client extension; override only what you need. + + The surface is declarative, fixed at construction, and never receives the client. """ #: Reverse-DNS extension identifier, advertised under `ClientCapabilities.extensions`. @@ -168,25 +122,16 @@ class ClientExtension: def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - # Validate a class-level `identifier` at definition time. A subclass may - # instead assign `identifier` in `__init__` (per-instance ids); that case - # is validated when the extension is consumed, since no class attribute - # exists to inspect here. + # Per-instance identifiers (assigned in __init__) are validated at consumption instead. if (identifier := cls.__dict__.get("identifier")) is not None: validate_extension_identifier(identifier, owner=cls.__name__) def settings(self) -> dict[str, Any]: """Per-extension settings advertised at `ClientCapabilities.extensions[identifier]`. - Read ONCE at `Client` construction — dynamic per-request settings are - out of scope. An empty dict (the default) advertises the extension with - no settings. - - A claim-bearing extension's identifier is advertised only at protocol - versions where at least one of its claims is active: the ad and the - claims dissolve together, so the client never advertises an extension - on a request whose claimed result shapes it would reject. Claim-less - extensions advertise at every version. + Read once at `Client` construction. A claim-bearing extension is + advertised only at protocol versions where at least one of its claims + is active. """ return {} @@ -200,7 +145,7 @@ def notifications(self) -> Sequence[NotificationBinding[Any]]: class _AdvertiseOnly(ClientExtension): - """Ad-only extension returned by `advertise()`: an identifier plus captured settings.""" + """Ad-only extension returned by `advertise()`.""" def __init__(self, identifier: str, settings: dict[str, Any]) -> None: self.identifier = identifier @@ -213,12 +158,8 @@ def settings(self) -> dict[str, Any]: def advertise(identifier: str, settings: dict[str, Any] | None = None) -> ClientExtension: """Advertise an extension identifier (with optional settings) and nothing else. - Returns an extension that contributes only the capability ad: no claims, no - notification bindings. The identifier is validated eagerly, at this call. - - WARNING: advertising an extension you do not implement asserts wire support - you don't have — for behavioral extensions (e.g. tasks) construct the real - extension object instead. + Advertising an extension you do not implement asserts wire support you do + not have; for behavioral extensions construct the real extension instead. """ validate_extension_identifier(identifier, owner="advertise") return _AdvertiseOnly(identifier, {} if settings is None else settings) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index ed95165ac..422c74f6a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -194,8 +194,7 @@ async def _default_logging_callback( ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) -# Declared against the wide tools/call parse union so the session's adopt-built -# claim adapters (which add `Result` arms) share one attribute type with it. +# Typed against the wide parse union so adopt-built claim adapters share this attribute type. _CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult | types.Result] = TypeAdapter( types.CallToolResult | types.InputRequiredResult ) @@ -217,11 +216,7 @@ def _claim_active(claim: ResultClaim[Any], version: str) -> bool: def _active_claims_at( claims_by_extension: Mapping[str, tuple[ResultClaim[Any], ...]], version: str ) -> dict[str, ResultClaim[Any]]: - """Claims active at `version`, keyed by wire tag (unique across extensions by construction). - - Empty at any legacy version: no claim is active there, so both `adopt()` arms - share this one rule. - """ + """Claims active at `version`, keyed by wire tag; empty at any legacy version.""" return { claim.result_type: claim for claims in claims_by_extension.values() @@ -233,11 +228,7 @@ def _active_claims_at( def _build_call_tool_adapter( active: Mapping[str, ResultClaim[Any]], ) -> TypeAdapter[types.CallToolResult | types.InputRequiredResult | types.Result]: - """Discriminated tools/call adapter: a core arm plus one arm per active claim. - - Zero active claims returns the module-level `_CallToolResultAdapter` itself, keeping - the no-extensions parse path byte-identical. - """ + """Build a discriminated tools/call adapter: a core arm plus one arm per active claim.""" if not active: return _CallToolResultAdapter tags = frozenset(active) @@ -246,10 +237,8 @@ def _build_call_tool_adapter( core_arm += "-" def _route(value: Any) -> str: - # pydantic hands the discriminator either the raw inbound dict or an - # already-built model (revalidation). A non-string or unknown tag stays on - # the core arm, so a malformed `resultType` fails core validation instead - # of blowing up the discriminator lookup. + # pydantic hands the discriminator either the raw dict or an already-built model. + # Unknown or non-string tags route to the core arm and fail core validation there. if isinstance(value, dict): tag = cast("dict[str, Any]", value).get("resultType") else: @@ -258,7 +247,7 @@ def _route(value: Any) -> str: arms: list[Any] = [Annotated[types.CallToolResult | types.InputRequiredResult, Tag(core_arm)]] arms += [Annotated[claim.model, Tag(tag)] for tag, claim in active.items()] - # reduce(or_, ...) builds the Union dynamically; PEP-646 star-unpack needs py3.11+. + # reduce(or_) rather than Union star-unpack, which needs py3.11+. return TypeAdapter(Annotated[reduce(or_, arms), Discriminator(_route)]) @@ -266,12 +255,7 @@ def _index_claims( result_claims: Mapping[str, Sequence[ResultClaim[Any]]] | None, extensions: dict[str, dict[str, Any]] | None, ) -> dict[str, tuple[ResultClaim[Any], ...]]: - """Validate and copy the claims-by-extension mapping. - - The mapping keys ARE the claim/ad association: at adopt the capability ad and the - claims dissolve together per extension identifier, so a key must name an - advertised extension and no wire tag may be claimed twice. - """ + """Validate and copy the claims-by-extension mapping.""" indexed: dict[str, tuple[ResultClaim[Any], ...]] = {} seen: set[str] = set() for identifier, claims in (result_claims or {}).items(): @@ -282,8 +266,8 @@ def _index_claims( ) if not claims: raise ValueError( - f"result_claims[{identifier!r}] is empty; an empty claim set would drop the " - "extension from the capability ad at every version — omit the key instead" + f"result_claims[{identifier!r}] is empty and would drop the extension from " + "the capability ad at every version. Omit the key instead" ) for claim in claims: if claim.result_type in seen: @@ -322,11 +306,8 @@ class ClientSession: callbacks. Transport `Exception` items reach `message_handler` only when the session builds its own dispatcher from a stream pair. - Extension contributions enter here too: `result_claims` (keyed by the - advertising identifier in `extensions`, so the capability ad and the claims - dissolve together) fold into tools/call parsing at `adopt()`, and - `notification_bindings` observe vendor notifications through per-binding - bounded FIFOs. + Extension `result_claims` fold into tools/call parsing at `adopt()`; + `notification_bindings` observe vendor notifications via bounded FIFOs. """ def __init__( @@ -394,8 +375,7 @@ async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() try: - # Queues exist before the dispatcher can deliver: _on_notify may run as - # soon as the dispatcher starts, and its enqueue indexes this dict. + # Queues must exist before the dispatcher starts: _on_notify enqueues into this dict. for binding in self._notification_bindings.values(): send, receive = anyio.create_memory_object_stream[BaseModel](_NOTIFICATION_QUEUE_SIZE) self._binding_queues[binding.method] = (send, receive) @@ -437,8 +417,7 @@ async def __aexit__( return result def _close_binding_queues(self) -> None: - # Memory object streams warn at garbage collection unless closed; the consumers - # die by task-group cancellation, so both ends are closed here (close is idempotent). + # Unclosed memory object streams warn at garbage collection; close is idempotent. for send, receive in self._binding_queues.values(): send.close() receive.close() @@ -447,18 +426,13 @@ def _close_binding_queues(self) -> None: async def _deliver_bound_notifications( self, binding: NotificationBinding[Any], receive: MemoryObjectReceiveStream[BaseModel] ) -> None: - """Serialized consumer for one binding's FIFO. - - Spawn-decoupled from the dispatcher so the handler may do session I/O without - deadlocking in-process delivery; dies with the session's task group. - """ + """Consume one binding's FIFO, decoupled from the dispatcher so handlers can do session I/O.""" while True: params = await receive.receive() try: await binding.handler(params) except Exception: - # Same containment contract as the notification callbacks in `_on_notify`: - # a raising handler costs only that delivery. + # A raising handler costs only that delivery, as in _on_notify. logger.exception("notification binding handler for %r raised", binding.method) async def send_request( @@ -477,15 +451,13 @@ async def send_request( Raises: MCPError: Error response, read timeout, or connection closed. RuntimeError: Called before entering the context manager. - ValueError: The request type declares `name_param` but the params - carry no string value under that key for the `Mcp-Name` header. + ValueError: The request declares `name_param` but its params carry no string name. """ data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] opts: CallOptions = {} self._stamp(data, opts) - # Presence-keyed so the stamp's NAME_BEARING_METHODS rows win by ordering; - # a missing/non-string name fails loud rather than omitting a MUST header. + # The stamp runs first, so its NAME_BEARING_METHODS rows win; a missing name fails loud. headers = opts.setdefault("headers", {}) if (key := type(request).name_param) is not None and MCP_NAME_HEADER not in headers: params_data: dict[str, Any] = data.get("params") or {} @@ -533,11 +505,7 @@ async def send_notification(self, notification: types.ClientNotification) -> Non def _build_capabilities(self, version: str) -> types.ClientCapabilities: """Build the capability ad for a wire speaking `version`. - An identifier with result claims contributes to the ad only when at least one - of its claims is active at `version` — the ad and the claims dissolve - together, so the client never advertises an extension on a request whose - claimed result shapes it would reject. Claim-less identifiers always - contribute; when every identifier drops, the ad omits `extensions` entirely. + Identifiers with no active claim drop, so the client never advertises result shapes it would reject. """ extensions = self._extensions if extensions is not None and self._result_claims: @@ -576,8 +544,7 @@ async def initialize(self) -> types.InitializeResult: types.InitializeRequest( params=types.InitializeRequestParams( protocol_version=LATEST_HANDSHAKE_VERSION, - # The handshake can only negotiate legacy versions, where no claim is - # active — every claim-bearing identifier drops from this ad. + # The handshake negotiates only legacy versions, where no claim is active. capabilities=self._build_capabilities(LATEST_HANDSHAKE_VERSION), client_info=self._client_info, ), @@ -624,14 +591,12 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: self._initialize_result = result self._discover_result = None self._negotiated_version = version - # Assigned fresh in both arms (re-adopt safe): empty at any legacy version by the - # one activation rule. Claims tagged with core vocabulary are unconstructible - # (`ResultClaim.__post_init__`), so activation needs no core-tag exclusion. + # Both arms reach here, so re-adoption resets cleanly; legacy versions activate no claims. + # Core-vocabulary tags are unconstructible (ResultClaim.__post_init__), so no exclusion needed. self._active_claims = _active_claims_at(self._result_claims, version) self._call_tool_adapter = _build_call_tool_adapter(self._active_claims) for method in self._notification_bindings: - # Bindings are consulted only for methods core does not know (`_on_notify`'s - # KeyError branch), so a core-known binding can never fire — say so once here. + # Bindings are consulted only for methods core does not know, so this one can never fire. if (method, version) in _methods.SERVER_NOTIFICATIONS: logger.warning( "notification binding for %r will never fire at %s: the core protocol defines this method", @@ -981,15 +946,13 @@ async def call_tool( allow_input_required: When ``False`` (default), an `InputRequiredResult` from the server raises `RuntimeError`; when ``True``, it is returned so the caller can resolve the requests and retry. - allow_claimed: When `False` (default), a claimed (extension) result - shape raises `UnexpectedClaimedResult`; when `True`, the parsed - claim model is returned for the caller to handle. + allow_claimed: When `False` (default), a claimed extension result raises + `UnexpectedClaimedResult`; when `True`, the parsed claim model is returned. Raises: RuntimeError: If the server returns an `InputRequiredResult` and ``allow_input_required`` is ``False``. - UnexpectedClaimedResult: If a claimed result shape parses and - `allow_claimed` is `False`; carries the parsed value. + UnexpectedClaimedResult: Claimed result with `allow_claimed` False; carries the parsed value. """ result = await self.send_request( types.CallToolRequest( @@ -1009,8 +972,7 @@ async def call_tool( if isinstance(result, types.CallToolResult) and not result.is_error: await self.validate_tool_result(name, result) - # Driver-innermost ordering: the input_required arm stays first — a claimed - # shape exits the multi-round-trip driver as terminal. + # The input_required arm stays first; a claimed shape is terminal for the multi-round-trip driver. if isinstance(result, types.InputRequiredResult) and not allow_input_required: raise _input_required_unexpected("call_tool") if not isinstance(result, types.CallToolResult | types.InputRequiredResult) and not allow_claimed: @@ -1027,13 +989,8 @@ def _resolve_param_headers(self, name: str, arguments: Mapping[str, Any]) -> dic async def validate_tool_result(self, name: str, result: types.CallToolResult) -> None: """Revalidate a `CallToolResult` against the tool's declared output schema. - Fetches the tool listing first when `name` has no cached schema. Tools - without an output schema (or not listed by the server) pass without - validation. - Raises: - RuntimeError: The result's structured content is missing or does - not conform to the tool's output schema. + RuntimeError: Structured content is missing or does not conform to the schema. """ if name not in self._tool_output_schemas: # refresh output schema cache @@ -1245,8 +1202,7 @@ async def dispatch_input_request( 2026-07-28 multi-round-trip driver, which dispatches the embedded `InputRequiredResult.input_requests` through the same callbacks. - Returns the callback's `InputResponse`, or `ErrorData` when the - callback declines — the refusal path; callers must handle that arm. + Returns the callback's `InputResponse`, or `ErrorData` when the callback declines. """ match request: case types.CreateMessageRequest(params=p): @@ -1265,8 +1221,7 @@ async def _on_notify( try: notification = cast(types.ServerNotification, _methods.parse_server_notification(method, version, params)) except KeyError: - # Methods the negotiated version's core tables do not know are offered to - # the notification bindings; core-known methods structurally never get here. + # Only methods unknown to the negotiated version's core tables reach the bindings. binding = self._notification_bindings.get(method) if binding is None: logger.debug("dropped %r: not defined at %s", method, version) @@ -1274,18 +1229,15 @@ async def _on_notify( try: bound_params = binding.params_type.model_validate(params or {}) except ValidationError: - # Mirrors the core notification arm below: warn and drop. logger.warning("Failed to validate notification: %s", method, exc_info=True) return send, receive = self._binding_queues[method] try: - # Never awaits: DirectDispatcher awaits _on_notify inline in the peer's - # notify(), so blocking here would deadlock in-process servers. + # Must not await: DirectDispatcher calls _on_notify inline; blocking deadlocks in-process servers. send.send_nowait(bound_params) except anyio.WouldBlock: - # Bounded FIFO: evict the oldest queued event (observation semantics - # tolerate the loss). No checkpoint since the failed send, so the - # buffer is still full and the eviction cannot itself block. + # Evict the oldest event; no checkpoint since the failed send, + # so the buffer is still full and the retry cannot block. receive.receive_nowait() logger.warning("notification queue for %r is full; dropped the oldest event", method) send.send_nowait(bound_params) diff --git a/src/mcp/server/extension.py b/src/mcp/server/extension.py index 11705943d..e9c62610e 100644 --- a/src/mcp/server/extension.py +++ b/src/mcp/server/extension.py @@ -29,8 +29,7 @@ from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext -# The identifier grammar moved to `mcp.shared.extension` (the client extension -# surface shares it); re-exported here for existing importers. +# Re-exported from `mcp.shared.extension` (shared with the client surface) for existing importers. from mcp.shared.extension import validate_extension_identifier as validate_extension_identifier if TYPE_CHECKING: diff --git a/src/mcp/shared/extension.py b/src/mcp/shared/extension.py index f275e2436..283e9ba89 100644 --- a/src/mcp/shared/extension.py +++ b/src/mcp/shared/extension.py @@ -1,9 +1,4 @@ -"""Extension-identifier grammar shared by the server and client extension surfaces. - -Server extensions (`mcp.server.extension`) and client extensions -(`mcp.client.extension`) carry the same kind of identifier; this module is the -one source of truth for its validation. -""" +"""Extension-identifier grammar shared by the server and client extension surfaces.""" from __future__ import annotations diff --git a/tests/client/test_client_extensions.py b/tests/client/test_client_extensions.py index 802d8a66f..f80cfe884 100644 --- a/tests/client/test_client_extensions.py +++ b/tests/client/test_client_extensions.py @@ -1,13 +1,5 @@ -"""`Client` + `ClientExtension` integration: folding extension declarations into the -session at construction, and `call_tool` driving claim resolvers transparently. - -Claimed-shape servers here are real `MCPServer`s whose SEP-2133 server extension -rewrites `tools/call` results via `intercept_tool_call` — the full public-API loop. -A short-circuiting interceptor's dict reaches the client verbatim (the runner trusts -it as a well-formed result), so the claimed models carry vendor top-level fields. - -`tools/call` is never cached (`Client.call_tool` has no `_cached_fetch` weave and the -SEP-2549 cacheable verbs do not include it), so the claim path needs no cache tests. +"""`Client` + `ClientExtension` integration: extension declarations fold into the session at +construction, and `call_tool` drives claim resolvers transparently against real `MCPServer`s. """ import logging @@ -46,8 +38,7 @@ def _name_elicitation() -> types.ElicitRequest: class VoucherResult(Result): - """The claimed `tools/call` shape, tagged `voucher`, carrying a vendor top-level field - (a short-circuiting server interceptor's dict reaches the client verbatim).""" + """The claimed `tools/call` shape, tagged `voucher`, carrying a vendor top-level field.""" result_type: Literal["voucher"] = "voucher" voucher_code: str | None = None @@ -128,17 +119,15 @@ def add(a: int, b: int) -> int: return server -# ── Construction-time validation ──────────────────────────────────────────── +# Construction-time validation class _CouponResult(Result): - """A second claimed shape with its own tag, for multi-claim routing.""" - result_type: Literal["coupon"] = "coupon" async def _unreachable_coupon_resolve(claimed: _CouponResult, ctx: ClaimContext) -> CallToolResult: - raise NotImplementedError # the wrong resolver for a voucher — must never run + raise NotImplementedError # the wrong resolver for a voucher; must never run class _CouponExtension(ClientExtension): @@ -167,20 +156,18 @@ async def _unreachable_twice_resolve(claimed: _TwiceResult, ctx: ClaimContext) - def test_mapping_extensions_get_the_migration_error() -> None: - """SDK-defined: the replaced dict form fails with a message naming the new shape, - not an attribute error about `str`.""" + """SDK-defined: the replaced dict form fails with a message naming the new shape.""" with pytest.raises(TypeError) as exc_info: Client(_add_server(), extensions=cast("Sequence[ClientExtension]", {"com.example/ui": {}})) assert str(exc_info.value) == snapshot( - "extensions= takes a sequence of ClientExtension instances; the mapping form was " - "replaced — use advertise(identifier, settings) for advertise-only entries" + "extensions= takes a sequence of ClientExtension instances. The mapping form was " + "replaced: use advertise(identifier, settings) for advertise-only entries" ) def test_one_extension_claiming_a_tag_twice_reads_as_one_owner() -> None: - """SDK-defined: a self-conflict names the one extension once instead of - "extensions 'a' and 'a'".""" + """SDK-defined: a self-conflict names the one extension once, not as a pair.""" with pytest.raises(ValueError) as exc_info: Client(_add_server(), extensions=[_SelfConflictingClaims()]) @@ -190,8 +177,7 @@ def test_one_extension_claiming_a_tag_twice_reads_as_one_owner() -> None: def test_bare_extension_instance_is_rejected_with_the_fix_named() -> None: - """SDK-defined: an instance whose class never set `identifier` fails Client - construction with an error naming the type and the fix — not an AttributeError.""" + """SDK-defined: an instance whose class never set `identifier` fails construction naming the type and the fix.""" with pytest.raises(ValueError) as exc_info: Client(_add_server(), extensions=[ClientExtension()]) @@ -202,16 +188,14 @@ def test_bare_extension_instance_is_rejected_with_the_fix_named() -> None: class _SelfAssignedBadId(ClientExtension): - """Assigns a malformed identifier in `__init__` — invisible at class definition.""" + """Assigns a malformed identifier in `__init__`, invisible at class definition.""" def __init__(self) -> None: self.identifier = "not-prefixed" def test_invalid_per_instance_identifier_raises_the_validators_error() -> None: - """SDK-defined: per-instance identifiers are validated when the Client consumes the - extension (no class attribute existed at definition time, mirroring the server's - posture); the shared validator's TypeError surfaces unwrapped.""" + """SDK-defined: per-instance identifiers are validated when the Client consumes the extension.""" with pytest.raises(TypeError) as exc_info: Client(_add_server(), extensions=[_SelfAssignedBadId()]) @@ -222,8 +206,7 @@ def test_invalid_per_instance_identifier_raises_the_validators_error() -> None: def test_duplicate_extension_identifiers_are_rejected_naming_the_identifier() -> None: - """SDK-defined: one identifier cannot appear twice — there would be two settings - dicts for one capability-ad key.""" + """SDK-defined: one identifier cannot appear twice across the extensions sequence.""" with pytest.raises(ValueError) as exc_info: Client(_add_server(), extensions=[advertise(_VOUCHER_EXT), advertise(_VOUCHER_EXT, {"a": 1})]) @@ -231,12 +214,10 @@ def test_duplicate_extension_identifiers_are_rejected_naming_the_identifier() -> async def _unreachable_resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: - raise NotImplementedError # construction-only extensions never resolve + raise NotImplementedError class _RivalVoucherExtension(ClientExtension): - """A second extension claiming the same `voucher` tag (construction-conflict tests).""" - identifier = _RIVAL_EXT def claims(self) -> Sequence[ResultClaim[Any]]: @@ -244,10 +225,7 @@ def claims(self) -> Sequence[ResultClaim[Any]]: def test_conflicting_claims_across_extensions_name_both_owners() -> None: - """SDK-defined: two extensions claiming the same (method, resultType) fail at - Client construction with both owning extensions named — the session's own - duplicate check knows only the method and tag, which cannot tell a user which - two of their extensions collide.""" + """SDK-defined: two extensions claiming the same tag fail at construction with both owners named.""" with pytest.raises(ValueError) as exc_info: Client(_add_server(), extensions=[_VoucherExtension(_unreachable_resolve), _RivalVoucherExtension()]) @@ -262,7 +240,7 @@ class _EventParams(BaseModel): async def _unreachable_handler(params: _EventParams) -> None: - raise NotImplementedError # construction-only extensions never deliver + raise NotImplementedError class _ObserverA(ClientExtension): @@ -288,8 +266,7 @@ def notifications(self) -> Sequence[NotificationBinding[Any]]: def test_conflicting_notification_bindings_name_both_owners() -> None: - """SDK-defined: two extensions binding the same notification method fail at Client - construction with both owning extensions named, for the same reason as claims.""" + """SDK-defined: two extensions binding the same notification method fail with both owners named.""" with pytest.raises(ValueError) as exc_info: Client(_add_server(), extensions=[_ObserverA(), _ObserverB()]) @@ -299,7 +276,7 @@ def test_conflicting_notification_bindings_name_both_owners() -> None: ) -# ── settings() consumption ─────────────────────────────────────────────────── +# settings() consumption class _CountedResult(Result): @@ -307,12 +284,10 @@ class _CountedResult(Result): async def _unreachable_counted_resolve(claimed: _CountedResult, ctx: ClaimContext) -> CallToolResult: - raise NotImplementedError # never driven; exists so claims() has something to return + raise NotImplementedError class _CountingSettings(ClientExtension): - """Counts every declaration read to pin the read-once contract for all three.""" - identifier = "com.example/counted" def __init__(self) -> None: @@ -336,10 +311,7 @@ def notifications(self) -> Sequence[NotificationBinding[Any]]: async def test_declarations_are_read_exactly_once_at_construction() -> None: - """SDK-defined: `settings()`, `claims()`, and `notifications()` are each read once, - at Client construction — connecting and calling tools (each modern request re-stamps - the capability ad) never re-reads any of them, so a stateful extension cannot desync - the ad from the claims.""" + """SDK-defined: each declaration method is read exactly once, at Client construction, never again.""" extension = _CountingSettings() client = Client(_add_server(), extensions=[extension]) assert (extension.reads, extension.claims_reads, extension.notifications_reads) == (1, 1, 1) @@ -353,9 +325,7 @@ async def test_declarations_are_read_exactly_once_at_construction() -> None: async def test_settings_dict_is_held_by_reference_not_copied() -> None: - """SDK-defined: the dict `settings()` returns is held by reference, not copied — - mutating it between construction and connect changes the advertised ad (the same - aliasing the dict-form `extensions=` argument had).""" + """SDK-defined: the settings dict is held by reference, so mutating it before connect changes the ad.""" observed: list[dict[str, dict[str, Any]] | None] = [] async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -381,16 +351,14 @@ async def list_tools( assert observed == [{"com.example/loyalty": {"tier": "gold"}}] -# ── extensions=None stays byte-identical ───────────────────────────────────── +# extensions=None stays byte-identical @pytest.mark.parametrize("extensions", [None, ()], ids=["none", "empty"]) async def test_no_extensions_keeps_tools_call_parsing_byte_identical( extensions: Sequence[ClientExtension] | None, ) -> None: - """SDK-defined: `extensions=None` (and an empty sequence) leave the session exactly - as a claim-less client's — the tools/call adapter is the module-level constant by - identity, and an ordinary call behaves as before.""" + """SDK-defined: `extensions=None` and an empty sequence leave the session exactly as a claim-less client's.""" with anyio.fail_after(5): async with Client(_add_server(), extensions=extensions) as client: assert client.session._call_tool_adapter is _CallToolResultAdapter @@ -399,14 +367,11 @@ async def test_no_extensions_keeps_tools_call_parsing_byte_identical( assert result.structured_content == {"result": 3} -# ── The transparent claim path ─────────────────────────────────────────────── +# The transparent claim path async def test_claimed_result_resolves_transparently_to_the_resolvers_result() -> None: - """A server-claimed `tools/call` shape never surfaces: the owning claim's resolver - receives the parsed claim model and `Client.call_tool` returns the resolver's - `CallToolResult` object — the signature stays `-> CallToolResult` (the assert_type - below is checked by pyright).""" + """A claimed shape never surfaces: the resolver gets the parsed model and `call_tool` returns its product.""" received: list[VoucherResult] = [] produced: list[CallToolResult] = [] @@ -427,8 +392,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: async def test_claimed_shape_routes_to_its_owning_extensions_resolver() -> None: - """With two claim-bearing extensions registered, the parsed shape runs ITS owner's - resolver — the coupon extension (registered first) must never see a voucher.""" + """With two claim-bearing extensions registered, the parsed shape runs its owner's resolver only.""" received: list[VoucherResult] = [] async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: @@ -445,10 +409,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: async def test_resolver_product_gets_the_direct_paths_output_schema_revalidation() -> None: - """The resolver's product passes through `validate_tool_result` exactly like a - directly-returned result: against the tool's output schema, missing structured - content raises the direct path's RuntimeError (the message below is the same - one `ClientSession.call_tool`'s own guard produces).""" + """The resolver's product is revalidated against the tool's output schema exactly like a direct result.""" async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: return CallToolResult(content=[TextContent(text="unstructured")]) @@ -461,9 +422,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: async def test_resolver_error_result_is_returned_not_raised() -> None: - """An `isError` resolver product skips output-schema revalidation and comes back - as-is — the same strictness as the direct path, which only revalidates successes. - The tool here declares an output schema, so revalidating would have raised.""" + """An `isError` resolver product skips output-schema revalidation and comes back as-is.""" async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: return CallToolResult(content=[TextContent(text="voucher printer on fire")], is_error=True) @@ -477,8 +436,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: async def test_resolver_receives_the_calls_claim_context() -> None: - """`ClaimContext` hands the resolver the client's own session object, the tool - name, and the per-call read timeout `call_tool` was given.""" + """`ClaimContext` carries the client's own session object, the tool name, and the per-call read timeout.""" contexts: list[ClaimContext] = [] async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: @@ -500,8 +458,7 @@ class _VoucherRefused(Exception): async def test_resolver_exception_propagates_untouched() -> None: - """A resolver exception reaches the `call_tool` caller as the very object the - resolver raised — no wrapping, the extension owns its error vocabulary.""" + """A resolver exception reaches the `call_tool` caller as the very object raised, unwrapped.""" refusal = _VoucherRefused("the voucher is refused") async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: @@ -514,12 +471,11 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: assert exc_info.value is refusal -# ── Unclaimed results with extensions present ──────────────────────────────── +# Unclaimed results with extensions present async def test_unclaimed_result_flows_through_unchanged_with_extensions_present() -> None: - """An ordinary `CallToolResult` is untouched by the claim machinery — the resolver - never runs and the result matches a claim-less client's.""" + """An ordinary `CallToolResult` is untouched by the claim machinery; the resolver never runs.""" async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: raise NotImplementedError # this server never produces a claimed shape @@ -532,9 +488,7 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: async def test_input_required_then_plain_result_keeps_the_auto_loop_working() -> None: - """With a claim-bearing extension present, the auto loop on an unclaimed tool is - unchanged: input_required resolves via the elicitation callback and the plain - terminal result comes back; the resolver never runs.""" + """With a claim-bearing extension present, the input_required auto loop on an unclaimed tool is unchanged.""" server = MCPServer("mrtr") @server.tool() @@ -564,14 +518,11 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: assert result.content == [TextContent(text="Hello, Ada!")] -# ── The multi-round-trip + claimed interplay ───────────────────────────────── +# The multi-round-trip + claimed interplay async def test_input_required_then_claimed_result_on_retry_resolves_transparently() -> None: - """The retry-leg regression: a call that demands input first and returns a claimed - shape on the retry still resolves transparently. The driver's retry must admit - claimed shapes — multi-round-trip input resolves before a claimed result, so a - claim may terminate any round, not just the first.""" + """A call that demands input first and returns a claimed shape on the retry still resolves transparently.""" prompted: list[str] = [] received: list[VoucherResult] = [] @@ -598,12 +549,11 @@ async def resolve(claimed: VoucherResult, ctx: ClaimContext) -> CallToolResult: assert result.content == [TextContent(text="honored after-input")] -# ── Notification bindings fold into the session ────────────────────────────── +# Notification bindings fold into the session class _CoreMethodObserver(ClientExtension): - """Binds a method the modern core tables already define (construction-legal; the - session warns once at adopt that it can never fire).""" + """Binds a method the modern core tables already define.""" identifier = "com.example/observer" @@ -614,10 +564,7 @@ def notifications(self) -> Sequence[NotificationBinding[Any]]: async def test_notification_bindings_fold_into_the_session(caplog: pytest.LogCaptureFixture) -> None: - """The Client threads extension notification bindings into its session: a binding - for a core-known method draws the session's one-time gone-quiet warning at adopt. - (Delivery mechanics are session-tier covered in - test_session_notification_bindings.py; this pins the Client fold seam.)""" + """The Client threads extension bindings into its session; a core-known binding draws the one-time warning.""" with caplog.at_level(logging.WARNING, logger="client"): async with Client(_add_server(), extensions=[_CoreMethodObserver()]): pass diff --git a/tests/client/test_extension.py b/tests/client/test_extension.py index 927399b43..cd1894788 100644 --- a/tests/client/test_extension.py +++ b/tests/client/test_extension.py @@ -1,10 +1,4 @@ -"""Tests for the client extension vocabulary (`mcp.client.extension`). - -Everything here is construction-time: claims, notification bindings, the -`ClientExtension` base class, and the `advertise()` factory. No session or -client is ever opened — the classes are pure declarations, and every -validation rule fires before an instance exists. -""" +"""Construction-time tests for `mcp.client.extension`; no session is ever opened.""" from dataclasses import FrozenInstanceError from typing import Any, Literal, cast @@ -25,8 +19,6 @@ class _TaskResult(Result): - """A well-formed claimed shape: `result_type` is a Literal of the claimed tag.""" - result_type: Literal["task"] = "task" task_id: str = "t-1" @@ -36,14 +28,10 @@ class _UntaggedResult(Result): class _PlainStringTagResult(Result): - """`result_type` declared as a plain `str`, not a Literal.""" - result_type: str = "task" class _OtherTagResult(Result): - """`result_type` is a Literal of a tag other than the one claimed.""" - result_type: Literal["other"] = "other" @@ -56,19 +44,15 @@ class _ClaimedInputRequiredResult(InputRequiredResult): async def _resolve(result: Result, ctx: ClaimContext) -> CallToolResult: - raise NotImplementedError # construction-only tests never drive a claim + raise NotImplementedError def _claim(model: type[Result] = _TaskResult, **kwargs: Any) -> ResultClaim[Result]: return ResultClaim(result_type="task", model=model, resolve=_resolve, **kwargs) -# ── ResultClaim construction ──────────────────────────────────────────────── - - def test_claim_with_literal_discriminated_model_constructs() -> None: - """SDK-defined: a claim whose model carries `result_type: Literal[]` - constructs, defaulting to the `tools/call` verb at every modern version.""" + """SDK-defined: a model tagged with the claimed Literal constructs, defaulting to `tools/call` everywhere.""" claim = ResultClaim(result_type="task", model=_TaskResult, resolve=_resolve) assert claim.result_type == "task" @@ -79,8 +63,7 @@ def test_claim_with_literal_discriminated_model_constructs() -> None: def test_claim_accepts_modern_protocol_versions() -> None: - """SDK-defined: a non-None `protocol_versions` is accepted when it is a subset of - the modern protocol revisions.""" + """SDK-defined: a non-None `protocol_versions` subset of the modern revisions is accepted.""" versions = frozenset(MODERN_PROTOCOL_VERSIONS) claim = _claim(protocol_versions=versions) @@ -89,8 +72,7 @@ def test_claim_accepts_modern_protocol_versions() -> None: def test_claim_rejects_core_result_type_vocabulary() -> None: - """SDK-defined: "complete" and "input_required" are core protocol vocabulary — - a claim cannot re-key the shapes the session itself routes on.""" + """SDK-defined: a claim cannot re-key the core tags 'complete' and 'input_required'.""" messages: dict[str, str] = {} for result_type in ("complete", "input_required"): with pytest.raises(ValueError) as exc_info: @@ -107,8 +89,7 @@ def test_claim_rejects_core_result_type_vocabulary() -> None: @pytest.mark.parametrize("model", [_ClaimedCallToolResult, _ClaimedInputRequiredResult]) def test_claim_rejects_model_subclassing_core_result_types(model: type[Result]) -> None: - """SDK-defined: a claim model subclassing `CallToolResult` or `InputRequiredResult` - would satisfy the session's isinstance branches and bypass claim routing.""" + """SDK-defined: a claim model subclassing a core result type is rejected; it would bypass claim routing.""" with pytest.raises(ValueError) as exc_info: _claim(model=model) @@ -116,8 +97,7 @@ def test_claim_rejects_model_subclassing_core_result_types(model: type[Result]) def test_claim_rejects_model_without_result_type_field() -> None: - """SDK-defined: the claim model must declare the discriminating `result_type` - field; without it the claimed shape could never be routed.""" + """SDK-defined: the claim model must declare the discriminating `result_type` field.""" with pytest.raises(ValueError) as exc_info: _claim(model=_UntaggedResult) @@ -125,8 +105,7 @@ def test_claim_rejects_model_without_result_type_field() -> None: def test_claim_rejects_plain_str_result_type_field() -> None: - """SDK-defined: a plain `str` tag would let one model validate any claimed shape; - the field must be a Literal of exactly the claimed tag.""" + """SDK-defined: the model's `result_type` must be a Literal of the claimed tag, not a plain `str`.""" with pytest.raises(ValueError) as exc_info: _claim(model=_PlainStringTagResult) @@ -134,8 +113,7 @@ def test_claim_rejects_plain_str_result_type_field() -> None: def test_claim_rejects_mismatched_result_type_literal() -> None: - """SDK-defined: the model's Literal tag must equal the claim's `result_type` — - a mismatch would register the model under a tag it refuses to validate.""" + """SDK-defined: the model's Literal tag must equal the claim's `result_type`.""" with pytest.raises(ValueError) as exc_info: _claim(model=_OtherTagResult) @@ -143,8 +121,7 @@ def test_claim_rejects_mismatched_result_type_literal() -> None: def test_claim_rejects_method_outside_the_closed_verb_set() -> None: - """SDK-defined: claims attach to `tools/call` only (the Literal is the static gate); - an unchecked runtime value must not fold into tools/call parsing silently.""" + """SDK-defined: claims attach to `tools/call` only, even for values that dodge the static Literal gate.""" with pytest.raises(ValueError) as exc_info: _claim(method=cast("Literal['tools/call']", "prompts/get")) @@ -152,8 +129,7 @@ def test_claim_rejects_method_outside_the_closed_verb_set() -> None: def test_claim_rejects_empty_protocol_versions() -> None: - """SDK-defined: an empty version set could never activate; `None` is the - spelling for "every modern version".""" + """SDK-defined: an empty version set is rejected; `None` is the spelling for every modern version.""" with pytest.raises(ValueError) as exc_info: _claim(protocol_versions=frozenset()) @@ -161,8 +137,7 @@ def test_claim_rejects_empty_protocol_versions() -> None: def test_claim_rejects_non_modern_protocol_versions() -> None: - """SDK-defined: claimed shapes cannot be delivered on a legacy wire, so a - non-None version set must be a subset of the modern protocol revisions.""" + """SDK-defined: a non-None version set must be a subset of the modern protocol revisions.""" messages: list[str] = [] for versions in ( frozenset({"2025-11-25"}), @@ -186,28 +161,23 @@ def test_claim_rejects_non_modern_protocol_versions() -> None: def test_result_claim_is_frozen() -> None: - """SDK-defined: claims are immutable declarations — mutating one after - construction raises.""" + """SDK-defined: claims are immutable; mutating one after construction raises.""" claim = _claim() with pytest.raises(FrozenInstanceError): setattr(claim, "result_type", "other") # direct assignment is also a type error -# ── NotificationBinding construction ──────────────────────────────────────── - - class _TaskNotificationParams(BaseModel): task_id: str async def _on_task(params: _TaskNotificationParams) -> None: - raise NotImplementedError # construction-only tests never deliver + raise NotImplementedError def test_notification_binding_constructs() -> None: - """SDK-defined: a binding is a bare declaration — wire method name, params - model, async observer — with no construction-time validation.""" + """SDK-defined: a binding is a bare declaration with no construction-time validation.""" binding = NotificationBinding(method="notifications/tasks", params_type=_TaskNotificationParams, handler=_on_task) assert binding.method == "notifications/tasks" @@ -216,10 +186,7 @@ def test_notification_binding_constructs() -> None: def test_notification_binding_accepts_core_known_method() -> None: - """SDK-defined: deliberately NO spec-table check at construction — bindings are - consulted only for methods core does not know, so they are additive by - construction, and an import-time table check would break packages whenever a - core version adopts a method.""" + """SDK-defined: deliberately no spec-table check at construction, so packages survive core adopting a method.""" binding = NotificationBinding( method="notifications/progress", params_type=_TaskNotificationParams, handler=_on_task ) @@ -228,20 +195,15 @@ def test_notification_binding_accepts_core_known_method() -> None: def test_notification_binding_is_frozen() -> None: - """SDK-defined: bindings are immutable declarations — mutating one after - construction raises.""" + """SDK-defined: bindings are immutable; mutating one after construction raises.""" binding = NotificationBinding(method="notifications/tasks", params_type=_TaskNotificationParams, handler=_on_task) with pytest.raises(FrozenInstanceError): setattr(binding, "method", "notifications/other") # direct assignment is also a type error -# ── ClientExtension subclassing ───────────────────────────────────────────── - - def test_extension_defaults_advertise_nothing() -> None: - """SDK-defined: a minimal subclass overrides nothing — empty settings, no - claims, no notification bindings.""" + """SDK-defined: a minimal subclass advertises empty settings, no claims, and no bindings.""" class _MinimalExt(ClientExtension): identifier = "com.example/minimal" @@ -265,8 +227,7 @@ class _MinimalExt(ClientExtension): ], ) def test_grammar_conformant_identifiers_accepted_at_class_definition(identifier: str) -> None: - """Spec `_meta` key grammar: dot-separated labels (letter start, letter/digit end, - hyphens interior), a slash, then a name that starts and ends alphanumeric.""" + """Spec `_meta` key grammar: conformant `vendor-prefix/name` identifiers are accepted.""" cls = type("_GoodExt", (ClientExtension,), {"identifier": identifier}) assert cls.identifier == identifier @@ -292,16 +253,13 @@ def test_grammar_conformant_identifiers_accepted_at_class_definition(identifier: ], ) def test_malformed_identifier_rejected_at_class_definition(identifier: Any) -> None: - """SDK-defined: SEP-2133 requires a `vendor-prefix/name` identifier, enforced the - moment the subclass is defined — same grammar and helper as the server side.""" + """SDK-defined: the SEP-2133 `vendor-prefix/name` grammar is enforced the moment the subclass is defined.""" with pytest.raises(TypeError): type("_BadExt", (ClientExtension,), {"identifier": identifier}) def test_subclass_without_identifier_allowed_at_definition() -> None: - """SDK-defined: a subclass that sets no class-level `identifier` (an abstract-ish - intermediate base, or one assigning per-instance ids in `__init__`) is allowed at - definition time; the identifier is validated when the extension is consumed.""" + """SDK-defined: a subclass with no class-level `identifier` is allowed; validation waits for consumption.""" class _AbstractishExt(ClientExtension): """Intermediate base; concrete subclasses supply the identifier.""" @@ -312,12 +270,8 @@ class _ConcreteExt(_AbstractishExt): assert _ConcreteExt.identifier == "com.example/concrete" -# ── advertise() factory ───────────────────────────────────────────────────── - - def test_advertise_serves_captured_settings() -> None: - """SDK-defined: `advertise()` returns an ad-only extension whose `settings()` - override serves the captured settings.""" + """SDK-defined: `advertise()` returns an ad-only extension serving the captured settings.""" ext = advertise("com.example/flags", {"enabled": True}) assert isinstance(ext, ClientExtension) @@ -336,7 +290,6 @@ def test_advertise_defaults_to_empty_settings() -> None: @pytest.mark.parametrize("identifier", ["noprefix", "foo/", ""]) def test_advertise_validates_identifier_eagerly(identifier: str) -> None: - """SDK-defined: `advertise()` validates the identifier at the call, not at some - later consumption point — a bad ad-only id fails where it is written.""" + """SDK-defined: `advertise()` validates the identifier eagerly, at the call site.""" with pytest.raises(TypeError): advertise(identifier) diff --git a/tests/client/test_send_request_mcp_name.py b/tests/client/test_send_request_mcp_name.py index 3963445e7..408810814 100644 --- a/tests/client/test_send_request_mcp_name.py +++ b/tests/client/test_send_request_mcp_name.py @@ -1,12 +1,6 @@ -"""`ClientSession.send_request` mirrors `Request.name_param` into the `Mcp-Name` header. - -The modern stamp emits `Mcp-Name` for the core `NAME_BEARING_METHODS` table; the -`name_param` delta covers every other send path (vendor request types, the -legacy handshake stamp), keyed on header presence so the stamp's table rows -always win by ordering. The vendor sends below also pin the widened -`send_request` typing: a `Request[...]` subclass passes without a cast. -""" - +"""`ClientSession.send_request` mirrors `Request.name_param` into the `Mcp-Name` +header on send paths the core `NAME_BEARING_METHODS` table does not cover. The +vendor sends also pin the widened `send_request` typing (no cast needed).""" from collections.abc import Mapping from typing import Any, Literal @@ -127,8 +121,7 @@ def _headers(opts: CallOptions) -> dict[str, str]: @pytest.mark.anyio async def test_vendor_name_param_emits_mcp_name_on_the_modern_path() -> None: - """A vendor request type declaring `name_param` gets `Mcp-Name` on a modern - wire even though its method is not in `NAME_BEARING_METHODS`.""" + """A vendor `name_param` emits `Mcp-Name` on a modern wire even outside `NAME_BEARING_METHODS`.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -140,8 +133,7 @@ async def test_vendor_name_param_emits_mcp_name_on_the_modern_path() -> None: @pytest.mark.anyio async def test_vendor_name_param_emits_mcp_name_on_the_handshake_path() -> None: - """The handshake stamp sets no `Mcp-Name` at all, so for a legacy wire the - delta is the responsible emitter — emission is era-unconditional.""" + """The handshake stamp sets no `Mcp-Name`, so on a legacy wire the delta is the emitter.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -155,8 +147,7 @@ async def test_vendor_name_param_emits_mcp_name_on_the_handshake_path() -> None: @pytest.mark.anyio async def test_name_value_passes_through_encode_header_value() -> None: - """A name that cannot ride as a plain ASCII header value is base64-sentinel - encoded (spec MUST for `Mcp-Name`).""" + """A non-ASCII name is base64-sentinel encoded, a spec MUST for `Mcp-Name`.""" name = "wídget ✨" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): @@ -170,9 +161,7 @@ async def test_name_value_passes_through_encode_header_value() -> None: @pytest.mark.anyio async def test_core_tools_call_header_comes_from_the_stamp_alone() -> None: - """Core `tools/call` is unchanged: the modern stamp emits `Mcp-Name` from the - table; `CallToolRequest` declares no `name_param`, and on a legacy wire core - methods stay headerless exactly as today.""" + """Core `tools/call` is unchanged: the modern stamp emits the header; legacy stays headerless.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -187,8 +176,7 @@ async def test_core_tools_call_header_comes_from_the_stamp_alone() -> None: @pytest.mark.anyio async def test_stamp_table_rows_win_over_name_param_by_ordering() -> None: - """Header-presence keying: when the modern stamp already emitted `Mcp-Name` - from the core table, a `name_param` on the request type does not overwrite it.""" + """A stamp-emitted `Mcp-Name` wins; `name_param` never overwrites an existing header.""" dispatcher = _RecordingDispatcher() request = _ShadowCallToolRequest(params={"name": "real-tool", "customKey": "other-value"}) with anyio.fail_after(5): @@ -201,8 +189,7 @@ async def test_stamp_table_rows_win_over_name_param_by_ordering() -> None: @pytest.mark.anyio async def test_vendor_name_param_emits_mcp_name_on_the_preconnect_path() -> None: - """Emission is era-unconditional: a lowlevel caller that never adopts any - version (the preconnect stamp) still gets `Mcp-Name` from `name_param`.""" + """Emission is era-unconditional: a session that never adopts still emits `Mcp-Name`.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -213,6 +200,7 @@ async def test_vendor_name_param_emits_mcp_name_on_the_preconnect_path() -> None @pytest.mark.anyio async def test_missing_name_value_fails_loud_naming_method_and_key() -> None: + """A missing name value raises ValueError naming the method and key, before the wire.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -225,6 +213,7 @@ async def test_missing_name_value_fails_loud_naming_method_and_key() -> None: @pytest.mark.anyio async def test_non_string_name_value_fails_loud() -> None: + """A non-string name value raises the same ValueError as a missing one.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -237,8 +226,7 @@ async def test_non_string_name_value_fails_loud() -> None: @pytest.mark.anyio async def test_absent_params_fails_loud_not_attribute_error() -> None: - """`exclude_none` drops a None params entirely; the delta still answers with - the documented ValueError, not an AttributeError on the missing key.""" + """Absent params still raise the documented ValueError, not an AttributeError.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: @@ -251,8 +239,7 @@ async def test_absent_params_fails_loud_not_attribute_error() -> None: @pytest.mark.anyio async def test_request_without_name_param_sends_no_mcp_name() -> None: - """No `name_param` and a method outside `NAME_BEARING_METHODS`: neither - emitter produces an `Mcp-Name` header, on either era's stamp.""" + """No `name_param` and a method outside the core table emits no `Mcp-Name` on either era.""" dispatcher = _RecordingDispatcher() with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: diff --git a/tests/client/test_session_claims.py b/tests/client/test_session_claims.py index 8f275ef14..21cf2fa69 100644 --- a/tests/client/test_session_claims.py +++ b/tests/client/test_session_claims.py @@ -1,11 +1,6 @@ -"""`ClientSession` result claims: construction validation, adopt-time activation, -the discriminated tools/call adapter, the version-aware capability ad, and the -`allow_claimed` escape hatch. - -Claims activate at `adopt()` — the stamp-swap moment — and only at modern -protocol versions; everywhere else parsing stays byte-identical to a claim-less -session (the zero-claims adapter is the module-level constant, by identity). -""" +"""`ClientSession` result claims: construction validation, activation at modern +adopts only, claimed-result routing, the version-aware capability ad, and the +`allow_claimed` escape hatch.""" from collections.abc import Mapping from typing import Any, Literal @@ -126,12 +121,8 @@ def _adopt_handshake(session: ClientSession) -> None: ) -# ── Construction-time validation ──────────────────────────────────────────── - - def test_duplicate_claim_tag_across_extensions_rejected() -> None: - """SDK-defined: two claims on the same (method, resultType) — even from different - extensions — could not be routed apart, so construction fails.""" + """SDK-defined: two claims on the same resultType cannot be routed apart, so construction fails.""" with pytest.raises(ValueError) as exc_info: ClientSession( dispatcher=_RecordingDispatcher(), @@ -143,9 +134,7 @@ def test_duplicate_claim_tag_across_extensions_rejected() -> None: def test_claims_keyed_to_unadvertised_extension_rejected() -> None: - """SDK-defined: a claim rides its extension's capability ad — a `result_claims` key - with no `extensions` entry (including extensions=None) advertises nothing the - server could ever act on, so construction fails.""" + """SDK-defined: a `result_claims` key with no `extensions` entry advertises nothing, so construction fails.""" messages: list[str] = [] for extensions in (None, {_AD_ONLY_EXT: {"flag": True}}): with pytest.raises(ValueError) as exc_info: @@ -167,33 +156,25 @@ def test_claims_keyed_to_unadvertised_extension_rejected() -> None: def test_empty_claim_sequence_rejected() -> None: - """SDK-defined: an empty claim set would make the ad-filter treat the identifier as - claim-bearing and drop it from the capability ad at every version; "claim-less" and - "advertises everywhere" stay the same thing by rejecting the empty spelling.""" + """SDK-defined: an empty claim list is rejected at construction; a claim-less extension omits the key.""" with pytest.raises(ValueError) as exc_info: ClientSession(dispatcher=_RecordingDispatcher(), extensions={_TASKS_EXT: {}}, result_claims={_TASKS_EXT: []}) assert str(exc_info.value) == snapshot( - "result_claims['com.example/tasks'] is empty; an empty claim set would drop the " - "extension from the capability ad at every version — omit the key instead" + "result_claims['com.example/tasks'] is empty and would drop the extension from " + "the capability ad at every version. Omit the key instead" ) def test_empty_settings_count_as_an_advertised_extension() -> None: - """SDK-defined: an extension advertised with empty settings ({}) is still an ad — - claims keyed to it construct fine.""" + """SDK-defined: empty settings ({}) still count as an ad, so claims keyed to the extension construct.""" session = _claims_session(_RecordingDispatcher(), _task_claim()) assert isinstance(session, ClientSession) -# ── Activation at adopt() ─────────────────────────────────────────────────── - - def test_without_claims_the_call_tool_adapter_is_the_module_constant() -> None: - """SDK-defined: the no-extensions parse path stays byte-identical — with zero - active claims the session holds the module-level adapter itself, not a rebuild, - before and after either adopt arm.""" + """SDK-defined: with zero active claims the session holds the module-level adapter by identity.""" session = ClientSession(dispatcher=_RecordingDispatcher()) assert session._call_tool_adapter is _CallToolResultAdapter @@ -208,9 +189,8 @@ def test_without_claims_the_call_tool_adapter_is_the_module_constant() -> None: async def test_modern_adopt_activates_claims_and_routes_claimed_results( protocol_versions: frozenset[str] | None, ) -> None: - """SDK-defined: at a modern adopt, a claim active at the negotiated version (None = - every modern version; an explicit subset containing it) routes the claimed raw to - the claim model.""" + """SDK-defined: at a modern adopt, a claim active at the negotiated version routes + the claimed raw to the claim model.""" dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) session = _claims_session(dispatcher, _task_claim(protocol_versions=protocol_versions)) with anyio.fail_after(5): @@ -224,8 +204,7 @@ async def test_modern_adopt_activates_claims_and_routes_claimed_results( @pytest.mark.anyio async def test_legacy_adopt_clears_active_claims() -> None: - """SDK-defined: re-adopt safe — a session that adopts modern then legacy fully - clears its active claims, restoring the module-level adapter by identity.""" + """SDK-defined: a legacy adopt clears active claims and restores the module-level adapter.""" dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -237,14 +216,13 @@ async def test_legacy_adopt_clears_active_claims() -> None: assert session._call_tool_adapter is _CallToolResultAdapter with pytest.raises(ValidationError): await session.call_tool("t", {}, allow_claimed=True) - # The rejection came from response parsing — the request did reach the wire. + # Rejected at response parsing; the request did reach the wire. assert dispatcher.calls[-1][0] == "tools/call" @pytest.mark.anyio async def test_modern_readopt_after_legacy_reactivates_claims() -> None: - """SDK-defined: adoption is re-entrant in both directions — after modern→legacy→ - modern the claims are active again and the adopt-built adapter routes claimed raws.""" + """SDK-defined: a modern re-adopt after legacy reactivates the claims.""" dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -260,14 +238,9 @@ async def test_modern_readopt_after_legacy_reactivates_claims() -> None: assert session._call_tool_adapter is not _CallToolResultAdapter -# ── The version-aware capability ad ───────────────────────────────────────── - - @pytest.mark.anyio async def test_legacy_initialize_ad_drops_claim_bearing_identifiers() -> None: - """SDK-defined: the legacy handshake can never negotiate a modern version, so no - claim can activate — a claim-bearing identifier drops from the initialize ad while - ad-only identifiers ride along.""" + """SDK-defined: the legacy initialize ad drops claim-bearing identifiers; ad-only ones ride along.""" dispatcher = _RecordingDispatcher() session = ClientSession( dispatcher=dispatcher, @@ -285,8 +258,7 @@ async def test_legacy_initialize_ad_drops_claim_bearing_identifiers() -> None: @pytest.mark.anyio async def test_legacy_ad_omits_extensions_entirely_when_every_identifier_drops() -> None: - """SDK-defined: when the filter drops every identifier, the ad omits the - `extensions` key — an empty extensions object advertises nothing.""" + """SDK-defined: when every identifier drops, the ad omits the `extensions` key entirely.""" dispatcher = _RecordingDispatcher() session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -300,8 +272,7 @@ async def test_legacy_ad_omits_extensions_entirely_when_every_identifier_drops() @pytest.mark.anyio async def test_modern_adopt_ad_includes_active_claim_identifiers() -> None: - """SDK-defined: the modern stamp's per-request `_meta` ad includes a claim-bearing - identifier when its claims are active at the adopted version.""" + """SDK-defined: the modern per-request `_meta` ad includes identifiers whose claims are active.""" dispatcher = _RecordingDispatcher() session = ClientSession( dispatcher=dispatcher, @@ -321,8 +292,7 @@ async def test_modern_adopt_ad_includes_active_claim_identifiers() -> None: @pytest.mark.anyio async def test_discover_probe_ad_includes_claim_identifiers_at_the_probe_version() -> None: - """SDK-defined: `send_discover` builds its `_meta` ad at the probe version — modern, - so claim-bearing identifiers contribute.""" + """SDK-defined: `send_discover` builds its `_meta` ad at the probe version, where claims are active.""" dispatcher = _RecordingDispatcher() session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -337,8 +307,7 @@ async def test_discover_probe_ad_includes_claim_identifiers_at_the_probe_version @pytest.mark.anyio async def test_discover_probe_ad_drops_claim_identifiers_at_a_legacy_probe_version() -> None: - """SDK-defined: a lowlevel `send_discover` at a non-modern version string builds an - ad where no claim can be active, so the claim-bearing identifier drops coherently.""" + """SDK-defined: at a legacy probe version no claim can be active, so the identifier drops.""" dispatcher = _RecordingDispatcher() session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -364,9 +333,7 @@ async def _resolve_core_tagged(result: _CoreTaggedResult, ctx: ClaimContext) -> @pytest.mark.anyio async def test_claim_tagged_core_cannot_hijack_core_parsing() -> None: - """SDK-defined: "core" is not protocol vocabulary, so a claim may use it as a wire - tag — and the adapter's internal routing sentinel must not collide: ordinary tool - results still parse as core results, and a claimed `core` raw routes to the model.""" + """SDK-defined: a claim may use "core" as its wire tag without colliding with core parsing.""" claim = ResultClaim(result_type="core", model=_CoreTaggedResult, resolve=_resolve_core_tagged) dispatcher = _RecordingDispatcher(tool_result={"resultType": "core", "payload": "p-1"}) session = ClientSession(dispatcher=dispatcher, extensions={_TASKS_EXT: {}}, result_claims={_TASKS_EXT: [claim]}) @@ -380,14 +347,10 @@ async def test_claim_tagged_core_cannot_hijack_core_parsing() -> None: assert isinstance(claimed, _CoreTaggedResult) -# ── Routing through the adopt-built adapter ───────────────────────────────── - - @pytest.mark.anyio @pytest.mark.parametrize("with_claims", [True, False]) async def test_unknown_result_type_fails_validation_with_and_without_claims(with_claims: bool) -> None: - """SDK-defined: a resultType outside the active claim set routes to the core arm - and fails core validation — exactly the claim-less session's behaviour.""" + """SDK-defined: a resultType outside the active claim set fails core validation, claims or not.""" raw = {"resultType": "weird", "taskId": "t-1"} dispatcher = _RecordingDispatcher(tool_result=raw) session = _claims_session(dispatcher, _task_claim()) if with_claims else ClientSession(dispatcher=dispatcher) @@ -396,15 +359,13 @@ async def test_unknown_result_type_fails_validation_with_and_without_claims(with _adopt_modern(session) with pytest.raises(ValidationError): await session.call_tool("t", {}, allow_claimed=True) - # The rejection came from response parsing — the request did reach the wire. + # Rejected at response parsing; the request did reach the wire. assert dispatcher.calls[-1][0] == "tools/call" @pytest.mark.anyio async def test_non_string_result_type_fails_core_validation_not_discrimination() -> None: - """SDK-defined: a malformed (non-string) resultType stays on the core arm — the - discriminator never uses it as a lookup key, so the failure is today's - ValidationError, not a TypeError.""" + """SDK-defined: a non-string resultType stays on the core arm and fails as ValidationError, not TypeError.""" raw: dict[str, Any] = {"resultType": {"nested": True}} dispatcher = _RecordingDispatcher(tool_result=raw) session = _claims_session(dispatcher, _task_claim()) @@ -413,14 +374,12 @@ async def test_non_string_result_type_fails_core_validation_not_discrimination() _adopt_modern(session) with pytest.raises(ValidationError): await session.call_tool("t", {}, allow_claimed=True) - # The rejection came from response parsing — the request did reach the wire. + # Rejected at response parsing; the request did reach the wire. assert dispatcher.calls[-1][0] == "tools/call" def test_adopt_built_adapter_revalidates_model_instances() -> None: - """SDK-defined: pydantic hands the callable discriminator either a raw dict or an - already-built model (revalidation); both route — a claim instance to its arm, a - core instance to the core arm.""" + """SDK-defined: the adopt-built adapter routes already-built model instances as well as raw dicts.""" session = _claims_session(_RecordingDispatcher(), _task_claim()) _adopt_modern(session) adapter = session._call_tool_adapter @@ -433,8 +392,7 @@ def test_adopt_built_adapter_revalidates_model_instances() -> None: @pytest.mark.anyio async def test_input_required_routes_to_core_arm_with_claims_active() -> None: - """Spec-mandated: `input_required` is core vocabulary — active claims leave the - multi-round-trip arm untouched.""" + """Spec-mandated: `input_required` is core vocabulary; active claims leave that arm untouched.""" raw = {"resultType": "input_required", "requestState": "s-1"} session = _claims_session(_RecordingDispatcher(tool_result=raw), _task_claim()) with anyio.fail_after(5): @@ -446,14 +404,10 @@ async def test_input_required_routes_to_core_arm_with_claims_active() -> None: assert result.request_state == "s-1" -# ── allow_claimed ──────────────────────────────────────────────────────────── - - @pytest.mark.anyio async def test_claimed_result_raises_unexpected_claimed_result_by_default() -> None: - """SDK-defined: without `allow_claimed=True` a claimed shape raises, carrying the - parsed value — the server may have durably created state (e.g. a task), and the - carried result is how the caller reaches its id to clean up.""" + """SDK-defined: without `allow_claimed` a claimed shape raises, carrying the parsed + result so the caller can clean up any server-side state it references.""" dispatcher = _RecordingDispatcher(tool_result=_CLAIMED_TASK_RESULT) session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -461,7 +415,7 @@ async def test_claimed_result_raises_unexpected_claimed_result_by_default() -> N _adopt_modern(session) with pytest.raises(UnexpectedClaimedResult) as exc_info: await session.call_tool("t", {}) - # The shape parsed and then raised — the request did reach the wire. + # The shape parsed and then raised; the request did reach the wire. assert dispatcher.calls[-1][0] == "tools/call" assert isinstance(exc_info.value.result, _TaskResult) @@ -475,8 +429,7 @@ async def test_claimed_result_raises_unexpected_claimed_result_by_default() -> N @pytest.mark.anyio async def test_call_tool_result_path_identical_under_both_allow_claimed_values() -> None: - """SDK-defined: `allow_claimed` only affects claimed shapes — an ordinary - CallToolResult comes back identical with the flag on or off.""" + """SDK-defined: `allow_claimed` only affects claimed shapes; ordinary results come back identical.""" dispatcher = _RecordingDispatcher() session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -491,9 +444,7 @@ async def test_call_tool_result_path_identical_under_both_allow_claimed_values() @pytest.mark.anyio async def test_call_tool_overload_matrix_narrows_statically() -> None: - """SDK-defined: the allow_input_required x allow_claimed overload matrix — each - combination narrows to its documented return union (assert_type is checked by - pyright; the canned CallToolResult satisfies every combination at runtime).""" + """SDK-defined: each flag combination narrows `call_tool` to its documented return union under pyright.""" dispatcher = _RecordingDispatcher() session = _claims_session(dispatcher, _task_claim()) with anyio.fail_after(5): @@ -511,13 +462,7 @@ async def test_call_tool_overload_matrix_narrows_statically() -> None: assert [type(r) for r in (r1, r2, r3, r4)] == [CallToolResult] * 4 -# ── The pinned dependency ──────────────────────────────────────────────────── - - def test_claimed_raw_passes_v2026_tools_call_surface_validation() -> None: - """Pins the claim path's load-bearing dependency: a tools/call raw with an unknown - resultType passes `validate_server_result` at 2026-07-28 because the v2026 surface - InputRequiredResult keeps resultType open with optional fields. If mcp-types ever - tightens that surface, claimed results would be rejected before the session's - claim adapter — this failure is the signal, not a silent break.""" + """Pins the claim path's dependency: an unknown resultType passes `validate_server_result` + at 2026-07-28; this failing is the signal that mcp-types tightened the surface.""" validate_server_result("tools/call", LATEST_MODERN_VERSION, {"resultType": "task", "taskId": "t-1"}) diff --git a/tests/client/test_session_notification_bindings.py b/tests/client/test_session_notification_bindings.py index 900817be4..2bed2bd64 100644 --- a/tests/client/test_session_notification_bindings.py +++ b/tests/client/test_session_notification_bindings.py @@ -1,10 +1,6 @@ -"""`ClientSession` notification bindings: per-binding serialized delivery through a -bounded FIFO, spawn-decoupled from the dispatcher so handlers may do session I/O -without deadlocking the in-process (DirectDispatcher) path. - -Bindings are consulted only for methods the negotiated version's core tables do -NOT know; a binding for a core-known method goes quiet, warned once at adopt(). -""" +"""`ClientSession` notification bindings: serialized per-binding delivery through a +bounded FIFO, consulted only for methods the negotiated version's core tables do +not know.""" import logging @@ -56,8 +52,7 @@ async def _noop_handler(params: _EventParams) -> None: def test_duplicate_binding_method_rejected() -> None: - """SDK-defined: two bindings on one wire method could not be routed apart, so - construction fails.""" + """SDK-defined: two bindings on one wire method cannot be routed apart, so construction fails.""" client_side, _ = create_direct_dispatcher_pair() binding = NotificationBinding(method=_VENDOR_METHOD, params_type=_EventParams, handler=_noop_handler) @@ -69,8 +64,7 @@ def test_duplicate_binding_method_rejected() -> None: @pytest.mark.anyio async def test_bound_vendor_notifications_are_delivered_in_order() -> None: - """SDK-defined: one consumer per binding serializes delivery — events arrive at the - handler in the order the server sent them.""" + """SDK-defined: one consumer per binding delivers events in the order the server sent them.""" delivered: list[int] = [] done = anyio.Event() @@ -97,9 +91,7 @@ async def on_event(params: _EventParams) -> None: @pytest.mark.anyio async def test_binding_handler_may_do_session_io_without_deadlock() -> None: - """SDK-defined: delivery is spawn-decoupled from the dispatcher, so a handler that - awaits session I/O completes even on the in-process path, where the peer's - notify() awaits `_on_notify` inline.""" + """SDK-defined: delivery is spawn-decoupled, so a handler may await session I/O without deadlock.""" pongs: list[EmptyResult] = [] done = anyio.Event() @@ -125,16 +117,8 @@ async def on_event(params: _EventParams) -> None: @pytest.mark.anyio async def test_overflow_drops_oldest_event_with_a_warning(caplog: pytest.LogCaptureFixture) -> None: - """SDK-defined: the per-binding FIFO is bounded; on overflow the OLDEST queued - event is dropped with a warning and the new event is enqueued (observation - semantics tolerate the loss; enqueueing never blocks the dispatcher). - - Steps: - 1. Deliver event 0 and block the consumer inside its handler. - 2. Fill the queue with events 1.._NOTIFICATION_QUEUE_SIZE. - 3. One more event overflows: event 1 is evicted, with a warning. - 4. Release the consumer; everything still queued is delivered in order. - """ + """SDK-defined: on overflow the bounded FIFO drops the oldest queued event with a + warning; everything still queued delivers in order.""" delivered: list[int] = [] consumer_blocked = anyio.Event() gate = anyio.Event() @@ -173,9 +157,7 @@ async def on_event(params: _EventParams) -> None: async def test_invalid_params_are_warned_and_dropped_without_reaching_handler( caplog: pytest.LogCaptureFixture, ) -> None: - """SDK-defined: params failing the binding's model are warned and dropped — - mirroring the core notification ValidationError arm — and the handler never runs - for them; later valid events still deliver.""" + """SDK-defined: params failing the binding's model are warned and dropped; later valid events deliver.""" delivered: list[int] = [] done = anyio.Event() @@ -202,8 +184,7 @@ async def on_event(params: _EventParams) -> None: @pytest.mark.anyio async def test_unbound_vendor_notification_keeps_the_debug_drop(caplog: pytest.LogCaptureFixture) -> None: - """SDK-defined: a vendor method with no binding keeps today's behaviour — a debug - log and a silent drop.""" + """SDK-defined: a vendor method with no binding keeps the debug-log-and-drop behaviour.""" caplog.set_level(logging.DEBUG, logger="client") client_side, server_side = create_direct_dispatcher_pair() @@ -224,9 +205,8 @@ async def test_unbound_vendor_notification_keeps_the_debug_drop(caplog: pytest.L async def test_core_known_method_never_reaches_binding_and_warns_once_at_adopt( caplog: pytest.LogCaptureFixture, ) -> None: - """SDK-defined: bindings are consulted only for methods core does not know at the - negotiated version — a binding for `notifications/message` goes quiet (the typed - logging callback still runs), warned exactly once at adopt().""" + """SDK-defined: a binding for a core-known method never fires and warns once at + adopt(); the typed callback still runs.""" logged: list[types.LoggingMessageNotificationParams] = [] async def logging_callback(params: types.LoggingMessageNotificationParams) -> None: @@ -243,13 +223,12 @@ async def on_message(params: BaseModel) -> None: await tg.start(server_side.run, _server_on_request, _server_on_notify) async with session: _adopt_modern(session) - # The in-process peer awaits _on_notify inline, so the typed callback ran - # by the time notify() returns. + # In-process notify() awaits _on_notify inline, so the typed callback has already run. await server_side.notify("notifications/message", {"level": "info", "data": "hello"}) server_side.close() assert [params.data for params in logged] == ["hello"] - # The bound handler never ran — a delivery would have logged its NotImplementedError. + # The bound handler never ran; a delivery would have logged its NotImplementedError. assert "notification binding handler" not in caplog.text expected = f"notification binding for 'notifications/message' will never fire at {LATEST_MODERN_VERSION}" assert caplog.text.count(expected) == 1 @@ -257,8 +236,7 @@ async def on_message(params: BaseModel) -> None: @pytest.mark.anyio async def test_handler_exception_is_contained_and_later_events_deliver(caplog: pytest.LogCaptureFixture) -> None: - """SDK-defined: a raising handler costs only that delivery — the consumer logs the - exception and keeps serving subsequent events.""" + """SDK-defined: a raising handler costs only that delivery; later events still deliver.""" delivered: list[int] = [] done = anyio.Event() @@ -287,8 +265,7 @@ async def on_event(params: _EventParams) -> None: @pytest.mark.anyio async def test_binding_delivery_works_without_adopt() -> None: - """SDK-defined: bindings need no negotiated version — pre-handshake sessions fall - back to the default version tables, where a vendor method is just as unknown.""" + """SDK-defined: bindings deliver pre-handshake, under the default version tables.""" delivered: list[int] = [] done = anyio.Event() diff --git a/tests/client/test_session_promotions.py b/tests/client/test_session_promotions.py index dfceca302..e4a62732b 100644 --- a/tests/client/test_session_promotions.py +++ b/tests/client/test_session_promotions.py @@ -1,6 +1,5 @@ """`dispatch_input_request` and `validate_tool_result` are public `ClientSession` API.""" - import mcp_types as types import pytest from mcp_types import ( @@ -34,8 +33,7 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: @pytest.mark.anyio async def test_dispatch_input_request_returns_error_data_on_refusal() -> None: - """The `ErrorData` arm is the refusal path: with no callback registered, the - default callback declines and the caller receives the error, not a raise.""" + """With no callback registered, refusal comes back as `ErrorData`, not a raise.""" client_side, _server_side = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_side) ctx = ClientRequestContext(session=session, request_id="r-1") diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index 195bbf14a..a9dc3a1bd 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -106,23 +106,20 @@ async def test_interceptor_observes_the_call_and_passes_the_result_through( async def test_the_receipts_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial006: `main()` is the literal client program on the page — the claimed - `receipt` shape never surfaces, and the printed content is the redeemed result.""" + """tutorial006: `main()` runs as printed and the output is the redeemed result, never the claimed shape.""" await tutorial006.main() assert "goods for r-117" in capsys.readouterr().out async def test_a_claimed_shape_fails_validation_without_the_extension() -> None: - """The page's off-by-default claim: a client that does not construct `Receipts` - rejects the `receipt` shape as invalid (spec: an unrecognized resultType is invalid).""" + """The page's off-by-default claim: a client without `Receipts` rejects the `receipt` shape as invalid.""" async with Client(tutorial006.mcp) as client: with pytest.raises(ValidationError): await client.call_tool("buy", {"item": "lamp"}) async def test_session_tier_allow_claimed_returns_the_raw_shape() -> None: - """The page's escape hatch: `client.session.call_tool(..., allow_claimed=True)` hands - back the parsed claim model instead of running the resolver.""" + """The page's escape hatch: `allow_claimed=True` returns the parsed claim model, not the resolved result.""" async with Client(tutorial006.mcp, extensions=[tutorial006.Receipts()]) as client: result = await client.session.call_tool("buy", {"item": "lamp"}, allow_claimed=True) assert isinstance(result, tutorial006.ReceiptResult) @@ -130,7 +127,6 @@ async def test_session_tier_allow_claimed_returns_the_raw_shape() -> None: async def test_the_jobs_client_program_runs_as_shown(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial007: a vendor request type with `name_param` goes through - `client.session.send_request` with no registration and returns its typed result.""" + """tutorial007: a vendor request with `name_param` round-trips `send_request` with no registration.""" await tutorial007.main() assert "job-7 is running" in capsys.readouterr().out diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 2fbe360a2..dfa66e22b 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2390,8 +2390,8 @@ def __post_init__(self) -> None: source=f"{SPEC_2026_BASE_URL}/basic#resulttype", behavior=( "A tools/call answered with an extension-claimed resultType is finished by the owning " - "ClientExtension's claim resolver — which may send follow-up requests through the session it is " - "handed — and Client.call_tool returns the resolver's ordinary CallToolResult." + "ClientExtension's claim resolver, and Client.call_tool returns the resolver's ordinary " + "CallToolResult. The resolver may send follow-up requests through the session it is handed." ), added_in="2026-07-28", ), diff --git a/tests/interaction/mcpserver/test_extensions.py b/tests/interaction/mcpserver/test_extensions.py index 72d0e5777..205a7fd6e 100644 --- a/tests/interaction/mcpserver/test_extensions.py +++ b/tests/interaction/mcpserver/test_extensions.py @@ -1,11 +1,5 @@ -"""Client extensions (SEP-2133) over the full client-server loop. - -The servers here are MCPServers whose server extension substitutes a claimed `tools/call` -shape via `intercept_tool_call`; the client declares the owning `ClientExtension` and its -claim resolver finishes the call. A short-circuiting interceptor's dict is passed through -verbatim (the runner trusts it as a well-formed result), so claimed shapes carry their -vendor fields end to end — the models below prove that with top-level vendor fields. -""" +"""Client extensions (SEP-2133) over the full client-server loop: a server extension +substitutes a claimed `tools/call` shape and the declaring client's `ClientExtension` resolves it.""" from collections.abc import Awaitable, Callable, Sequence from typing import Any, Literal @@ -31,8 +25,6 @@ class ReceiptResult(Result): - """The claimed `tools/call` shape, tagged `receipt`, carrying vendor top-level fields.""" - result_type: Literal["receipt"] = "receipt" receipt_token: str settings_echo: dict[str, Any] | None = None @@ -42,7 +34,7 @@ class ReceiptResult(Result): class Receipts(ClientExtension): - """Client half: claims the `receipt` tag with the test's resolver and settings.""" + """Client half: claims the `receipt` shape with the test's resolver and settings.""" identifier = _RECEIPTS @@ -71,7 +63,6 @@ async def intercept_tool_call( def _receipt_shop(issuer: Extension) -> MCPServer: - """An MCPServer whose `buy` tool the server extension rewrites into the claimed shape.""" server = MCPServer("shop", extensions=[issuer]) @server.tool() @@ -89,10 +80,8 @@ def redeem(token: str) -> str: @requirement("extensions:client:claimed-result-resolved") async def test_claimed_result_is_finished_by_the_owning_extensions_resolver(connect: Connect) -> None: - """The transparent claim path, both ends real: the server extension substitutes the - `receipt` shape, the client's claim resolver redeems it with a follow-up `tools/call` - through `ctx.session` — the same authority as `client.session` — and `call_tool` - returns the resolver's plain `CallToolResult`. The claimed shape never surfaces.""" + """The owning extension's claim resolver redeems the substituted `receipt` through + `ctx.session`, and `call_tool` returns the resolver's plain `CallToolResult`.""" received: list[ReceiptResult] = [] async def redeem_receipt(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolResult: @@ -110,17 +99,15 @@ async def redeem_receipt(claimed: ReceiptResult, ctx: ClaimContext) -> CallToolR @requirement("extensions:client:claimed-result-undeclared-invalid") async def test_claimed_shape_fails_validation_for_a_client_without_the_extension(connect: Connect) -> None: - """Spec-mandated: an unrecognized `resultType` is invalid. A client that did not - construct the owning extension rejects the very shape the previous test resolves — - the request reaches the server, the substituted result fails client-side parsing.""" + """Spec-mandated: an unrecognized `resultType` is invalid, so a client without the + owning extension fails to parse the claimed shape.""" async with connect(_receipt_shop(_ReceiptIssuer())) as client: with pytest.raises(ValidationError): await client.call_tool("buy", {"item": "lamp"}) class _SettingsEchoIssuer(Extension): - """Server half for the ad tests: refuses non-declaring clients, then echoes the - declared settings back through the claimed payload.""" + """Server half: requires the declaring client, then echoes its declared settings.""" identifier = _RECEIPTS @@ -129,7 +116,7 @@ async def intercept_tool_call( ) -> HandlerResult: require_client_extension(ctx, _RECEIPTS) client_params = ctx.session.client_params - assert client_params is not None # require_client_extension just read it + assert client_params is not None extensions = client_params.capabilities.extensions assert extensions is not None return {"resultType": "receipt", "receiptToken": "echo", "settingsEcho": extensions[_RECEIPTS]} @@ -137,10 +124,8 @@ async def intercept_tool_call( @requirement("extensions:client:capability-ad:gates-server-behaviour") async def test_per_request_ad_carries_settings_and_gates_the_claimed_substitution(connect: Connect) -> None: - """The per-request `_meta` capability ad is the entitlement for claimed shapes: the - server extension's gate passes only for the declaring client, observes the declared - settings on the request, and the resolver receives them back through the payload. - A client declaring nothing is refused with -32021, not served the shape.""" + """The per-request `_meta` capability ad gates the claimed substitution: declared + settings reach the resolver and a non-declaring client is refused with -32021.""" server = MCPServer("shop", extensions=[_SettingsEchoIssuer()]) @server.tool() @@ -171,9 +156,8 @@ async def _unreachable_resolve(claimed: ReceiptResult, ctx: ClaimContext) -> Cal @requirement("extensions:client:capability-ad:legacy-omits-claimed") async def test_legacy_ad_omits_claim_bearing_identifiers_but_keeps_claim_less_ones(connect: Connect) -> None: - """On a legacy connection the claims dissolve and the ad follows them: the - claim-bearing identifier is absent from the initialize capability ad the server - sees, while an ad-only identifier on the same client still advertises.""" + """On a legacy connection the claim-bearing identifier drops out of the initialize + capability ad while an ad-only identifier still advertises.""" server = MCPServer("introspector") @server.tool() diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index 31de9cd57..f01502bcd 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -560,8 +560,6 @@ class _JobParams(RequestParams): class _JobStatusRequest(Request[_JobParams, Literal["com.example/jobs.status"]]): - """A vendor (extension) request type that names its subject for the Mcp-Name header.""" - method: Literal["com.example/jobs.status"] = "com.example/jobs.status" name_param = "jobId" @@ -572,14 +570,8 @@ class _JobStatusResult(Result): @requirement("client-transport:http:vendor-name-param-header") async def test_vendor_request_with_name_param_carries_mcp_name_on_the_wire() -> None: - """A vendor request sent through `send_request` carries `Mcp-Name` from its `name_param` key. - - The request type is never registered with the client; `send_request` reads the declared - `name_param` ("jobId"), mirrors the params value into the `Mcp-Name` header, and the value - stays in the body unchanged. Asserted at the wire because the client never surfaces the - outgoing headers. The server serves the vendor method through `add_request_handler`, so the - round trip also proves the typed result comes back without any client-side method table. - """ + """`send_request` mirrors an unregistered vendor request's `name_param` value into the + `Mcp-Name` header while the body keeps the params key unchanged.""" async def job_status(ctx: ServerRequestContext, params: _JobParams) -> _JobStatusResult: assert params.job_id == "job-7" diff --git a/tests/shared/test_extension.py b/tests/shared/test_extension.py index 87383d28a..fd9192554 100644 --- a/tests/shared/test_extension.py +++ b/tests/shared/test_extension.py @@ -1,5 +1,4 @@ -"""Tests for `mcp.shared.extension` — the extension-identifier grammar shared by -the server and client extension surfaces.""" +"""The extension-identifier grammar in `mcp.shared.extension`, shared by server and client.""" from typing import Any @@ -11,8 +10,7 @@ def test_server_extension_module_reexports_shared_validator() -> None: - """SDK-defined: `mcp.server.extension.validate_extension_identifier` remains - importable after the move and is the very same function object.""" + """SDK-defined: `mcp.server.extension` re-exports the shared validator as the same function object.""" assert mcp.server.extension.validate_extension_identifier is mcp.shared.extension.validate_extension_identifier @@ -28,8 +26,7 @@ def test_server_extension_module_reexports_shared_validator() -> None: ], ) def test_grammar_conformant_extension_identifiers_are_accepted(identifier: str) -> None: - """Spec `_meta` key grammar: dot-separated labels (letter start, letter/digit end, - hyphens interior), a slash, then a name that starts and ends alphanumeric.""" + """Spec `_meta` key grammar: conformant `vendor-prefix/name` identifiers are accepted.""" validate_extension_identifier(identifier, owner="T") @@ -54,7 +51,6 @@ def test_grammar_conformant_extension_identifiers_are_accepted(identifier: str) ], ) def test_malformed_extension_identifiers_are_rejected(identifier: Any) -> None: - """Spec `_meta` key grammar: malformed prefixes (bad label start/end, empty labels) - and malformed names are rejected, as are non-strings.""" + """Spec `_meta` key grammar: malformed prefixes, malformed names, and non-strings are rejected.""" with pytest.raises(TypeError): validate_extension_identifier(identifier, owner="T") diff --git a/tests/types/test_request_name_param.py b/tests/types/test_request_name_param.py index 8666f2fca..c8efe4fbb 100644 --- a/tests/types/test_request_name_param.py +++ b/tests/types/test_request_name_param.py @@ -1,4 +1,4 @@ -"""`Request.name_param` — the wire-params key a request type declares for `Mcp-Name` emission.""" +"""`Request.name_param`: the wire-params key a request type declares for `Mcp-Name` emission.""" from typing import Literal From 99fec3c6434481c97fbc9251a9f8a1341ef6b866 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 18:36:03 +0000 Subject: [PATCH 14/15] Address review feedback on claim construction and notification ordering ResultClaim now rejects models that do not subclass Result at runtime (the typing bound only constrains checked callers) and models whose fields alias requestState or inputRequests, which the core result surface types and would reject before the claim adapter runs. The notification ordering contract is stated honestly: per-binding serial delivery in dispatch order, with near-simultaneous notifications on stream transports possibly dispatched out of wire order. The receipts tutorial server now gates substitution on the client declaring the extension, the stale custom-methods README caveat points at NotificationBinding, and the capability-ad docstring no longer overstates which identifiers drop. --- docs/advanced/extensions.md | 11 +++++---- docs_src/extensions/tutorial006.py | 3 ++- examples/stories/custom_methods/README.md | 8 +++---- src/mcp/client/extension.py | 20 +++++++++++++--- src/mcp/client/session.py | 4 +++- tests/client/test_extension.py | 28 +++++++++++++++++++++++ tests/docs_src/test_extensions.py | 8 +++---- tests/interaction/_requirements.py | 7 +++++- 8 files changed, 70 insertions(+), 19 deletions(-) diff --git a/docs/advanced/extensions.md b/docs/advanced/extensions.md index 3646fdf0c..0358ba5a8 100644 --- a/docs/advanced/extensions.md +++ b/docs/advanced/extensions.md @@ -151,7 +151,7 @@ A **client extension** is the same contract from the consuming side: a bundle of client-side behaviour behind one identifier. Pass instances to `Client(extensions=[...])` and call tools normally: -```python title="client.py" hl_lines="66-68" +```python title="client.py" hl_lines="67-69" --8<-- "docs_src/extensions/tutorial006.py" ``` @@ -161,8 +161,9 @@ shape** instead of a final result, and `Receipts` finishes it (here by redeeming receipt with a follow-up call) before `call_tool` returns. Nothing about the call site moves. -Drop the extension and none of this exists: a `receipt` shape arriving at a client -that didn't declare it fails validation, exactly as the spec requires for an +Drop the extension and none of this exists: the server's gate refuses a client +that did not declare it (error -32021), and a claimed shape from a server that +skips the gate fails validation, exactly as the spec requires for an unrecognized `resultType`. Off by default, on both ends of the wire. To advertise an identifier with **no** client-side behaviour (the server gates on @@ -180,7 +181,7 @@ client = Client(mcp, extensions=[advertise("com.example/search")]) Subclass `ClientExtension` and override only what you need. Three contribution kinds, each with a default: `settings()`, `claims()`, and `notifications()`. -```python title="client.py" hl_lines="18-19 43-44 46-47" +```python title="client.py" hl_lines="18-19 44-45 47-48" --8<-- "docs_src/extensions/tutorial006.py" ``` @@ -205,7 +206,7 @@ def notifications(self) -> Sequence[NotificationBinding[Any]]: return [NotificationBinding(method="notifications/receipts", params_type=ReceiptEvent, handler=self.on_receipt)] ``` -The handler receives validated params, in arrival order. It observes; it cannot veto +The handler receives validated params one at a time, in dispatch order. It observes; it cannot veto or reply. Two quiet rules. Claims are active on 2026-07-28 connections only, and the capability diff --git a/docs_src/extensions/tutorial006.py b/docs_src/extensions/tutorial006.py index 55f3c50c1..05ffbcb9d 100644 --- a/docs_src/extensions/tutorial006.py +++ b/docs_src/extensions/tutorial006.py @@ -7,7 +7,7 @@ from mcp.client import ClaimContext, ClientExtension, ResultClaim from mcp.server.context import CallNext, HandlerResult, ServerRequestContext from mcp.server.extension import Extension -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import MCPServer, require_client_extension EXTENSION_ID = "com.example/receipts" @@ -32,6 +32,7 @@ async def intercept_tool_call( ) -> HandlerResult: if params.name != "buy": return await call_next(ctx) + require_client_extension(ctx, EXTENSION_ID) return {"resultType": "receipt", "receiptToken": "r-117"} diff --git a/examples/stories/custom_methods/README.md b/examples/stories/custom_methods/README.md index 96d223fef..75f150202 100644 --- a/examples/stories/custom_methods/README.md +++ b/examples/stories/custom_methods/README.md @@ -34,10 +34,10 @@ uv run python -m stories.custom_methods.client --http ## Caveats - The TypeScript SDK's equivalent example also shows a custom server→client - **notification** (`acme/searchProgress`). The Python client currently drops - any notification whose method is not in the spec registry - (`ClientSession._on_notify` → `KeyError` → silent drop), and there is no - `set_notification_handler` analogue. That half is omitted here. + **notification** (`acme/searchProgress`). The Python client can observe + vendor notifications via `NotificationBinding` (see + `docs/advanced/extensions.md`). That half is omitted here because the + lowlevel server has no surface for emitting vendor notifications yet. ## Spec diff --git a/src/mcp/client/extension.py b/src/mcp/client/extension.py index adc15c1b7..78c283f91 100644 --- a/src/mcp/client/extension.py +++ b/src/mcp/client/extension.py @@ -32,6 +32,9 @@ _CLAIM_METHODS: Final[frozenset[str]] = frozenset({"tools/call"}) """The closed set of verbs a claim may attach to; widen together with the `method` Literal.""" +_RESERVED_WIRE_ALIASES: Final[frozenset[str]] = frozenset({"requestState", "inputRequests"}) +"""Typed optional fields of the core result surface that pre-validates every inbound result.""" + ClaimedT = TypeVar("ClaimedT", bound=Result) NotifyParamsT = TypeVar("NotifyParamsT", bound=BaseModel) @@ -66,8 +69,17 @@ def __post_init__(self) -> None: raise ValueError(f"claims attach to {sorted(_CLAIM_METHODS)} only; got method {self.method!r}") if self.result_type in CORE_RESULT_TYPES: raise ValueError(f"resultType {self.result_type!r} is core protocol vocabulary") + if Result not in self.model.__mro__: # runtime guard; the ClaimedT bound only constrains checked callers + raise ValueError(f"{self.model.__name__} must subclass mcp_types.Result") if issubclass(self.model, CallToolResult | InputRequiredResult): raise ValueError("claim models must not subclass core result types") + for name, model_field in self.model.model_fields.items(): + if (model_field.alias or name) in _RESERVED_WIRE_ALIASES: + raise ValueError( + f"{self.model.__name__}.{name} aliases {model_field.alias or name!r}, a typed field " + "of the core result surface; a colliding value would fail core validation before " + "the claim adapter runs" + ) field = self.model.model_fields.get("result_type") if field is None or get_args(field.annotation) != (self.result_type,): raise ValueError(f"{self.model.__name__}.result_type must be Literal[{self.result_type!r}]") @@ -101,9 +113,11 @@ def __init__(self, result: Result) -> None: class NotificationBinding(Generic[NotifyParamsT]): """Deliver server notifications for `method` (the bare wire name) to `handler`. - Observation-only: validated params arrive in order through a bounded queue, - dropping the oldest with a warning on overflow. Methods the negotiated - version's core tables handle are never delivered to bindings. + Observation-only: validated params arrive one at a time per binding, in + dispatch order, through a bounded queue that drops the oldest with a warning + on overflow. Stream transports dispatch each notification independently, so + near-simultaneous notifications may be dispatched out of wire order. Methods + the negotiated version's core tables handle are never delivered to bindings. """ method: str diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 422c74f6a..e6ae766d9 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -505,7 +505,9 @@ async def send_notification(self, notification: types.ClientNotification) -> Non def _build_capabilities(self, version: str) -> types.ClientCapabilities: """Build the capability ad for a wire speaking `version`. - Identifiers with no active claim drop, so the client never advertises result shapes it would reject. + Claim-bearing identifiers whose claims are all inactive at `version` drop, so + the client never advertises result shapes it would reject; claim-less + identifiers always advertise. """ extensions = self._extensions if extensions is not None and self._result_claims: diff --git a/tests/client/test_extension.py b/tests/client/test_extension.py index cd1894788..a53b4e240 100644 --- a/tests/client/test_extension.py +++ b/tests/client/test_extension.py @@ -120,6 +120,34 @@ def test_claim_rejects_mismatched_result_type_literal() -> None: assert str(exc_info.value) == snapshot("_OtherTagResult.result_type must be Literal['task']") +class _NotAResult(BaseModel): + result_type: Literal["plain"] = "plain" + + +class _ReservedAliasResult(Result): + result_type: Literal["clash"] = "clash" + request_state: dict[str, Any] = {} + + +def test_claim_rejects_model_not_subclassing_result() -> None: + """SDK-defined: a plain BaseModel cannot be a claim model; the session returns `Result` values.""" + with pytest.raises(ValueError) as exc_info: + ResultClaim(result_type="plain", model=cast("type[Result]", _NotAResult), resolve=_resolve) + + assert str(exc_info.value) == snapshot("_NotAResult must subclass mcp_types.Result") + + +def test_claim_rejects_model_aliasing_core_surface_fields() -> None: + """SDK-defined: a field aliasing requestState or inputRequests would fail core pre-validation.""" + with pytest.raises(ValueError) as exc_info: + ResultClaim(result_type="clash", model=_ReservedAliasResult, resolve=_resolve) + + assert str(exc_info.value) == snapshot( + "_ReservedAliasResult.request_state aliases 'requestState', a typed field of the core " + "result surface; a colliding value would fail core validation before the claim adapter runs" + ) + + def test_claim_rejects_method_outside_the_closed_verb_set() -> None: """SDK-defined: claims attach to `tools/call` only, even for values that dodge the static Literal gate.""" with pytest.raises(ValueError) as exc_info: diff --git a/tests/docs_src/test_extensions.py b/tests/docs_src/test_extensions.py index a9dc3a1bd..2a141337b 100644 --- a/tests/docs_src/test_extensions.py +++ b/tests/docs_src/test_extensions.py @@ -5,7 +5,6 @@ import pytest from inline_snapshot import snapshot from mcp_types import METHOD_NOT_FOUND, MISSING_REQUIRED_CLIENT_CAPABILITY, TextContent -from pydantic import ValidationError from docs_src.extensions import ( tutorial001, @@ -111,11 +110,12 @@ async def test_the_receipts_client_program_runs_as_shown(capsys: pytest.CaptureF assert "goods for r-117" in capsys.readouterr().out -async def test_a_claimed_shape_fails_validation_without_the_extension() -> None: - """The page's off-by-default claim: a client without `Receipts` rejects the `receipt` shape as invalid.""" +async def test_a_client_without_the_extension_is_refused_by_the_gate() -> None: + """The page's off-by-default claim: the server's capability gate refuses a non-declaring client.""" async with Client(tutorial006.mcp) as client: - with pytest.raises(ValidationError): + with pytest.raises(MCPError) as exc_info: await client.call_tool("buy", {"item": "lamp"}) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY async def test_session_tier_allow_claimed_returns_the_raw_shape() -> None: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index dfa66e22b..ada4b7fa0 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2403,6 +2403,11 @@ def __post_init__(self) -> None: "declared claims, never more)." ), added_in="2026-07-28", + note=( + "Known leniency: the monolith result surface still accepts an unknown tag when the payload " + "also parses as a complete core result (open result_type, extras ignored). Rejecting tags " + "outside core plus active claims is a tracked follow-up ruling." + ), ), "extensions:client:capability-ad:gates-server-behaviour": Requirement( source=f"{SPEC_2026_BASE_URL}/basic#resulttype", @@ -2432,7 +2437,7 @@ def __post_init__(self) -> None: source=f"{SPEC_2026_BASE_URL}/basic#resulttype", behavior=( "A vendor server notification bound by a ClientExtension's NotificationBinding is validated " - "against the binding's params type and delivered to its handler in arrival order." + "against the binding's params type and delivered to its handler serially, in dispatch order." ), added_in="2026-07-28", deferred=( From 53a538a623f29144771118815e66b7dfa6f8b793 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 30 Jun 2026 20:00:29 +0000 Subject: [PATCH 15/15] Check every alias form when reserving core surface wire keys A claim model could route a reserved key through validation_alias, serialization_alias, or an AliasChoices entry and still hit the core pre-validation dead end the reservation exists to prevent. _wire_keys collects every top-level key a field can read from or write to, and the reservation checks the full set. --- src/mcp/client/extension.py | 27 +++++++++++++--- tests/client/test_extension.py | 58 +++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/extension.py b/src/mcp/client/extension.py index 78c283f91..a813475e5 100644 --- a/src/mcp/client/extension.py +++ b/src/mcp/client/extension.py @@ -13,7 +13,8 @@ from mcp_types import CORE_RESULT_TYPES, CallToolResult, InputRequiredResult, Result from mcp_types.version import MODERN_PROTOCOL_VERSIONS -from pydantic import BaseModel +from pydantic import AliasChoices, AliasPath, BaseModel +from pydantic.fields import FieldInfo from mcp.shared.extension import validate_extension_identifier @@ -35,6 +36,22 @@ _RESERVED_WIRE_ALIASES: Final[frozenset[str]] = frozenset({"requestState", "inputRequests"}) """Typed optional fields of the core result surface that pre-validates every inbound result.""" + +def _wire_keys(name: str, field: FieldInfo) -> frozenset[str]: + """Every top-level wire key this field can read from or write to.""" + keys = {field.alias or name} + if field.serialization_alias: + keys.add(field.serialization_alias) + validation_alias = field.validation_alias + choices = validation_alias.choices if isinstance(validation_alias, AliasChoices) else [validation_alias] + for choice in choices: + if isinstance(choice, AliasPath): + choice = choice.path[0] + if isinstance(choice, str): + keys.add(choice) + return frozenset(keys) + + ClaimedT = TypeVar("ClaimedT", bound=Result) NotifyParamsT = TypeVar("NotifyParamsT", bound=BaseModel) @@ -74,11 +91,11 @@ def __post_init__(self) -> None: if issubclass(self.model, CallToolResult | InputRequiredResult): raise ValueError("claim models must not subclass core result types") for name, model_field in self.model.model_fields.items(): - if (model_field.alias or name) in _RESERVED_WIRE_ALIASES: + for clash in sorted(_wire_keys(name, model_field) & _RESERVED_WIRE_ALIASES): raise ValueError( - f"{self.model.__name__}.{name} aliases {model_field.alias or name!r}, a typed field " - "of the core result surface; a colliding value would fail core validation before " - "the claim adapter runs" + f"{self.model.__name__}.{name} aliases {clash!r}, a typed field of the core " + "result surface; a colliding value would fail core validation before the " + "claim adapter runs" ) field = self.model.model_fields.get("result_type") if field is None or get_args(field.annotation) != (self.result_type,): diff --git a/tests/client/test_extension.py b/tests/client/test_extension.py index a53b4e240..c30708137 100644 --- a/tests/client/test_extension.py +++ b/tests/client/test_extension.py @@ -7,13 +7,15 @@ from inline_snapshot import snapshot from mcp_types import CallToolResult, InputRequiredResult, Result from mcp_types.version import MODERN_PROTOCOL_VERSIONS -from pydantic import BaseModel +from pydantic import AliasChoices, AliasPath, BaseModel, Field +from pydantic.fields import FieldInfo from mcp.client.extension import ( ClaimContext, ClientExtension, NotificationBinding, ResultClaim, + _wire_keys, advertise, ) @@ -148,6 +150,60 @@ def test_claim_rejects_model_aliasing_core_surface_fields() -> None: ) +class _ValidationAliasResult(Result): + result_type: Literal["va"] = "va" + vendor_state: dict[str, Any] | None = Field(default=None, validation_alias="requestState") + + +class _SerializationAliasResult(Result): + result_type: Literal["sa"] = "sa" + vendor_state: dict[str, Any] | None = Field(default=None, serialization_alias="inputRequests") + + +class _AliasChoicesResult(Result): + result_type: Literal["ac"] = "ac" + vendor_state: dict[str, Any] | None = Field( + default=None, validation_alias=AliasChoices("vendorKey", "requestState") + ) + + +class _AliasPathResult(Result): + result_type: Literal["ap"] = "ap" + vendor_state: dict[str, Any] | None = Field( + default=None, validation_alias=AliasChoices(AliasPath("requestState", "nested")) + ) + + +def test_wire_keys_for_a_bare_field_is_just_its_name() -> None: + """SDK-defined: a field with no aliases reads and writes only its own name.""" + assert _wire_keys("plain", FieldInfo(annotation=str)) == frozenset({"plain"}) + + +def test_claim_rejects_reserved_aliases_in_every_alias_form() -> None: + """SDK-defined: validation_alias, serialization_alias, and AliasChoices routes to a reserved key are all caught.""" + messages: dict[str, str] = {} + for model in (_ValidationAliasResult, _SerializationAliasResult, _AliasChoicesResult, _AliasPathResult): + with pytest.raises(ValueError) as exc_info: + ResultClaim(result_type=model.model_fields["result_type"].default, model=model, resolve=_resolve) + messages[model.__name__] = str(exc_info.value) + + assert messages == snapshot( + { + "_ValidationAliasResult": "_ValidationAliasResult.vendor_state aliases " + "'requestState', a typed field of the core result surface; a colliding value would fail " + "core validation before the claim adapter runs", + "_SerializationAliasResult": "_SerializationAliasResult.vendor_state aliases " + "'inputRequests', a typed field of the core result surface; a colliding value would fail " + "core validation before the claim adapter runs", + "_AliasChoicesResult": "_AliasChoicesResult.vendor_state aliases 'requestState', a typed field of the core " + "result surface; a colliding value would fail core validation before the claim adapter runs", + "_AliasPathResult": "_AliasPathResult.vendor_state aliases " + "'requestState', a typed field of the core result surface; a colliding value would fail " + "core validation before the claim adapter runs", + } + ) + + def test_claim_rejects_method_outside_the_closed_verb_set() -> None: """SDK-defined: claims attach to `tools/call` only, even for values that dodge the static Literal gate.""" with pytest.raises(ValueError) as exc_info: