Patch notes (descriptive text only; patch tools skip everything before the
first "diff --git" line).

src/mcp/server/streamable_http.py
  * _handle_post_request now answers with a 400 "Bad Request" error response
    when the incoming request id is already a key of `_request_streams`.
  * In the SSE branch the `_request_streams` entry is created before
    `_mint_priming_event` is awaited. If that await raises,
    `_clean_up_memory_streams(request_id)` is awaited and the exception is
    re-raised.

tests/shared/test_streamable_http.py
  * MCP_PROTOCOL_VERSION_HEADER is imported from mcp.shared.inbound instead of
    mcp.server.streamable_http, and `cast` is imported from typing.
  * Adds the `next_sse_data` helper and three tests: duplicate in-flight id
    rejected, sequential id reuse allowed, duplicate rejected while the event
    store is still persisting the priming event.

NOTE(review): this copy was rebuilt from a whitespace-collapsed paste. Leading
indentation and the position of blank context lines were inferred from Python
syntax and from the hunk line counts -- confirm against the working tree
before applying.

diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py
index d316345c7..43c0dc813 100644
--- a/src/mcp/server/streamable_http.py
+++ b/src/mcp/server/streamable_http.py
@@ -549,6 +549,20 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
 
             request_id = str(message.id)
 
+            # Reject duplicate in-flight request ids: `_request_streams` is keyed by
+            # request id, so a second concurrent request with the same id would
+            # silently overwrite the first one's routing slot and cross-wire their
+            # responses (one request receives the other's payload, the other hangs).
+            # The spec requires ids to be unique within a session; ids may still be
+            # reused once the earlier request has completed. See #3060.
+            if request_id in self._request_streams:
+                response = self._create_error_response(
+                    f"Bad Request: Request id {request_id} is already in flight for this session",
+                    HTTPStatus.BAD_REQUEST,
+                )
+                await response(scope, receive, send)
+                return
+
             if self.is_json_response_enabled:
                 self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
                     REQUEST_STREAM_BUFFER_SIZE
@@ -597,19 +611,28 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
                 finally:
                     await self._clean_up_memory_streams(request_id)
             else:
-                # Mint the priming event before any per-request state exists:
-                # `EventStore.store_event` is user code and may raise, in which
-                # case the outer handler returns a 500 with nothing to clean up.
-                # Still strictly precedes dispatch, so storage order == wire order.
-                priming_event = await self._mint_priming_event(request_id, protocol_version)
-
-                sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
-                self._sse_stream_writers[request_id] = sse_stream_writer
+                # Reserve the routing slot before any await so nothing separates
+                # the duplicate-id guard above from registration; a concurrent
+                # POST reusing this id while the event store persists the priming
+                # event could otherwise pass the guard and overwrite the slot.
                 self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
                     REQUEST_STREAM_BUFFER_SIZE
                 )
                 request_stream_reader = self._request_streams[request_id][1]
 
+                # Mint the priming event before dispatch so storage order matches
+                # wire order. `EventStore.store_event` is user code and may raise,
+                # in which case the outer handler returns a 500; release the slot
+                # reservation on the way out so the id does not stay in flight.
+                try:
+                    priming_event = await self._mint_priming_event(request_id, protocol_version)
+                except BaseException:
+                    await self._clean_up_memory_streams(request_id)
+                    raise
+
+                sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
+                self._sse_stream_writers[request_id] = sse_stream_writer
+
                 headers = {
                     "Cache-Control": "no-cache, no-transform",
                     "Connection": "keep-alive",
diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py
index cbce222ec..7cb4443e4 100644
--- a/tests/shared/test_streamable_http.py
+++ b/tests/shared/test_streamable_http.py
@@ -11,7 +11,7 @@
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, cast
 from unittest.mock import MagicMock
 from urllib.parse import urlparse
 
@@ -49,7 +49,6 @@
 from mcp.server import Server, ServerRequestContext
 from mcp.server.streamable_http import (
     GET_STREAM_KEY,
-    MCP_PROTOCOL_VERSION_HEADER,
     MCP_SESSION_ID_HEADER,
     SESSION_ID_PATTERN,
     EventCallback,
@@ -63,6 +62,7 @@
 from mcp.server.transport_security import TransportSecuritySettings
 from mcp.shared._compat import resync_tracer
 from mcp.shared._context_streams import create_context_streams
+from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER
 from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage
 from mcp.shared.session import RequestResponder
 from tests.interaction.transports import StreamingASGITransport
@@ -94,6 +94,14 @@ def first_sse_data(response: httpx.Response) -> dict[str, Any]:
     raise ValueError("No data event in SSE response")  # pragma: no cover
 
 
+async def next_sse_data(lines: AsyncIterator[str]) -> dict[str, Any]:
+    """Return the next SSE `data:` payload from a live line iterator, parsed as JSON."""
+    while True:
+        line = await anext(lines)
+        if line.startswith("data: "):
+            return json.loads(line.removeprefix("data: "))
+
+
 def extract_protocol_version_from_sse(response: httpx.Response) -> str:
     """Extract the negotiated protocol version from an SSE initialization response."""
     return first_sse_data(response)["result"]["protocolVersion"]
@@ -680,6 +688,223 @@ async def test_response(basic_app: Starlette) -> None:
         assert tools_response.headers.get("Content-Type") == "text/event-stream"
 
 
+@pytest.mark.anyio
+async def test_duplicate_in_flight_request_id_rejected(basic_app: Starlette) -> None:
+    """A request whose id is already in flight on the session is rejected with 400.
+
+    The per-request routing in the transport is keyed by request id, so a second
+    concurrent request with the same id would overwrite the in-flight request's
+    routing slot and cross-wire the two responses (see #3060). The duplicate is
+    rejected and the in-flight request completes unaffected.
+    """
+    async with make_client(basic_app) as client:
+        response = await client.post(
+            "/mcp",
+            headers={
+                "Accept": "application/json, text/event-stream",
+                "Content-Type": "application/json",
+            },
+            json=INIT_REQUEST,
+        )
+        assert response.status_code == 200
+        headers = {
+            "Accept": "application/json, text/event-stream",
+            "Content-Type": "application/json",
+            MCP_SESSION_ID_HEADER: response.headers[MCP_SESSION_ID_HEADER],
+            MCP_PROTOCOL_VERSION_HEADER: extract_protocol_version_from_sse(response),
+        }
+
+        # Request A blocks server-side on the lock, keeping its id in flight.
+        async with client.stream(
+            "POST",
+            "/mcp",
+            headers=headers,
+            json={
+                "jsonrpc": "2.0",
+                "id": 1,
+                "method": "tools/call",
+                "params": {"name": "wait_for_lock_with_notification", "arguments": {}},
+            },
+        ) as response_a:
+            assert response_a.status_code == 200
+            lines_a = response_a.aiter_lines()
+            # The tool's first notification confirms request A is in flight.
+            with anyio.fail_after(5):
+                notification = await next_sse_data(lines_a)
+            assert notification["params"]["data"] == "First notification before lock"
+
+            # A second request reusing id 1 while A is in flight is rejected.
+            response_b = await client.post(
+                "/mcp",
+                headers=headers,
+                json={
+                    "jsonrpc": "2.0",
+                    "id": 1,
+                    "method": "tools/call",
+                    "params": {"name": "test_tool", "arguments": {}},
+                },
+            )
+            assert response_b.status_code == 400
+            error = response_b.json()["error"]
+            assert error["code"] == INVALID_REQUEST
+            assert "already in flight" in error["message"]
+
+            # Request A is unaffected: release the lock and it completes normally.
+            release_response = await client.post(
+                "/mcp",
+                headers=headers,
+                json={
+                    "jsonrpc": "2.0",
+                    "id": 2,
+                    "method": "tools/call",
+                    "params": {"name": "release_lock", "arguments": {}},
+                },
+            )
+            assert release_response.status_code == 200
+
+            with anyio.fail_after(5):
+                notification = await next_sse_data(lines_a)
+                final = await next_sse_data(lines_a)
+            assert notification["params"]["data"] == "Second notification after lock"
+            assert final["id"] == 1
+            assert final["result"]["content"][0]["text"] == "Completed"
+
+
+@pytest.mark.anyio
+async def test_request_id_reuse_after_completion_allowed(basic_app: Starlette) -> None:
+    """A request id can be reused once the earlier request with that id has completed.
+
+    Only concurrent requests with the same id are ambiguous to route; sequential
+    reuse (which some deployed clients rely on, sending every request with id 1)
+    keeps working (see #3060).
+    """
+    async with make_client(basic_app) as client:
+        response = await client.post(
+            "/mcp",
+            headers={
+                "Accept": "application/json, text/event-stream",
+                "Content-Type": "application/json",
+            },
+            json=INIT_REQUEST,
+        )
+        assert response.status_code == 200
+        headers = {
+            "Accept": "application/json, text/event-stream",
+            "Content-Type": "application/json",
+            MCP_SESSION_ID_HEADER: response.headers[MCP_SESSION_ID_HEADER],
+            MCP_PROTOCOL_VERSION_HEADER: extract_protocol_version_from_sse(response),
+        }
+
+        for _ in range(2):
+            response = await client.post(
+                "/mcp",
+                headers=headers,
+                json={
+                    "jsonrpc": "2.0",
+                    "id": 1,
+                    "method": "tools/call",
+                    "params": {"name": "test_tool", "arguments": {}},
+                },
+            )
+            assert response.status_code == 200
+            body = first_sse_data(response)
+            assert body["id"] == 1
+            assert body["result"]["content"][0]["text"] == "Called test_tool"
+
+
+@pytest.mark.anyio
+async def test_duplicate_in_flight_request_id_rejected_during_priming() -> None:
+    """The duplicate-id guard holds while the event store persists the priming event.
+
+    With an event store configured, the SSE branch awaits ``EventStore.store_event``
+    to mint the priming event. The routing slot is reserved before that await, so a
+    concurrent POST reusing the id during persistence is still rejected rather than
+    slipping past the guard and overwriting the slot (see #3060).
+    """
+
+    class GatedEventStore(SimpleEventStore):
+        """Blocks the priming write for the gated stream until released."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.entered = anyio.Event()
+            self.release = anyio.Event()
+
+        async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId:
+            if stream_id == "gated-1" and message is None:
+                self.entered.set()
+                await self.release.wait()
+            return await super().store_event(stream_id, message)
+
+    def first_message_sse_data(response: httpx.Response) -> dict[str, Any]:
+        """First non-empty SSE data payload; skips the empty-data priming event."""
+        for line in response.text.splitlines():
+            if line.startswith("data: ") and line.removeprefix("data: ").strip():
+                return json.loads(line.removeprefix("data: "))
+        raise ValueError("No message data event in SSE response")  # pragma: no cover
+
+    store = GatedEventStore()
+    async with running_app(event_store=store) as app:
+        async with make_client(app) as client:
+            # Priming events are only minted for protocol >= 2025-11-25, so
+            # negotiate the latest version rather than INIT_REQUEST's pinned one.
+            init_params: dict[str, Any] = {
+                **cast(dict[str, Any], INIT_REQUEST["params"]),
+                "protocolVersion": types.LATEST_PROTOCOL_VERSION,
+            }
+            init_request: dict[str, Any] = {**INIT_REQUEST, "params": init_params}
+            response = await client.post(
+                "/mcp",
+                headers={
+                    "Accept": "application/json, text/event-stream",
+                    "Content-Type": "application/json",
+                },
+                json=init_request,
+            )
+            assert response.status_code == 200
+            headers = {
+                "Accept": "application/json, text/event-stream",
+                "Content-Type": "application/json",
+                MCP_SESSION_ID_HEADER: response.headers[MCP_SESSION_ID_HEADER],
+                MCP_PROTOCOL_VERSION_HEADER: first_message_sse_data(response)["result"]["protocolVersion"],
+            }
+            call: dict[str, Any] = {
+                "jsonrpc": "2.0",
+                "id": "gated-1",
+                "method": "tools/call",
+                "params": {"name": "test_tool", "arguments": {}},
+            }
+            results: dict[str, httpx.Response] = {}
+
+            async def post_first() -> None:
+                results["first"] = await client.post("/mcp", headers=headers, json=call)
+
+            async with anyio.create_task_group() as tg:
+                tg.start_soon(post_first)
+                # The first request is now suspended inside the event store's
+                # priming write: past the duplicate-id guard, response not started.
+                with anyio.fail_after(5):
+                    await store.entered.wait()
+
+                # Bounded so a regression fails fast: without the early slot
+                # reservation, this POST would also suspend in the gated event
+                # store and the test would hang instead of failing.
+                with anyio.fail_after(5):
+                    duplicate = await client.post("/mcp", headers=headers, json=call)
+                assert duplicate.status_code == 400
+                error = duplicate.json()["error"]
+                assert error["code"] == INVALID_REQUEST
+                assert "already in flight" in error["message"]
+
+                store.release.set()
+
+            # The first request is unaffected and completes with its own result.
+            assert results["first"].status_code == 200
+            body = first_message_sse_data(results["first"])
+            assert body["id"] == "gated-1"
+            assert body["result"]["content"][0]["text"] == "Called test_tool"
+
+
 @pytest.mark.anyio
 async def test_json_response(json_app: Starlette) -> None:
     """With JSON response mode enabled, requests are answered with application/json bodies."""