Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,20 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re

request_id = str(message.id)

# Reject duplicate in-flight request ids: `_request_streams` is keyed by
# request id, so a second concurrent request with the same id would
# silently overwrite the first one's routing slot and cross-wire their
# responses (one request receives the other's payload, the other hangs).
# The spec requires ids to be unique within a session; ids may still be
# reused once the earlier request has completed. See #3060.
if request_id in self._request_streams:
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
response = self._create_error_response(
f"Bad Request: Request id {request_id} is already in flight for this session",
HTTPStatus.BAD_REQUEST,
)
await response(scope, receive, send)
return

if self.is_json_response_enabled:
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
Expand Down Expand Up @@ -597,19 +611,28 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
finally:
await self._clean_up_memory_streams(request_id)
else:
# Mint the priming event before any per-request state exists:
# `EventStore.store_event` is user code and may raise, in which
# case the outer handler returns a 500 with nothing to clean up.
# Still strictly precedes dispatch, so storage order == wire order.
priming_event = await self._mint_priming_event(request_id, protocol_version)

sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
self._sse_stream_writers[request_id] = sse_stream_writer
# Reserve the routing slot before any await so nothing separates
# the duplicate-id guard above from registration; a concurrent
# POST reusing this id while the event store persists the priming
# event could otherwise pass the guard and overwrite the slot.
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
request_stream_reader = self._request_streams[request_id][1]

# Mint the priming event before dispatch so storage order matches
# wire order. `EventStore.store_event` is user code and may raise,
# in which case the outer handler returns a 500; release the slot
# reservation on the way out so the id does not stay in flight.
try:
priming_event = await self._mint_priming_event(request_id, protocol_version)
except BaseException:
await self._clean_up_memory_streams(request_id)
raise

sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
self._sse_stream_writers[request_id] = sse_stream_writer

headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
Expand Down
229 changes: 227 additions & 2 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any
from typing import Any, cast
from unittest.mock import MagicMock
from urllib.parse import urlparse

Expand Down Expand Up @@ -49,7 +49,6 @@
from mcp.server import Server, ServerRequestContext
from mcp.server.streamable_http import (
GET_STREAM_KEY,
MCP_PROTOCOL_VERSION_HEADER,
MCP_SESSION_ID_HEADER,
SESSION_ID_PATTERN,
EventCallback,
Expand All @@ -63,6 +62,7 @@
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._compat import resync_tracer
from mcp.shared._context_streams import create_context_streams
from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
from tests.interaction.transports import StreamingASGITransport
Expand Down Expand Up @@ -94,6 +94,14 @@ def first_sse_data(response: httpx.Response) -> dict[str, Any]:
raise ValueError("No data event in SSE response") # pragma: no cover


async def next_sse_data(lines: AsyncIterator[str]) -> dict[str, Any]:
"""Return the next SSE `data:` payload from a live line iterator, parsed as JSON."""
while True:
line = await anext(lines)
if line.startswith("data: "):
return json.loads(line.removeprefix("data: "))


def extract_protocol_version_from_sse(response: httpx.Response) -> str:
    """Read the negotiated protocol version out of an SSE initialization response.

    The version is taken from `result.protocolVersion` of the payload that
    `first_sse_data` returns for the response.
    """
    init_payload = first_sse_data(response)
    init_result = init_payload["result"]
    return init_result["protocolVersion"]
Expand Down Expand Up @@ -680,6 +688,223 @@ async def test_response(basic_app: Starlette) -> None:
assert tools_response.headers.get("Content-Type") == "text/event-stream"


@pytest.mark.anyio
async def test_duplicate_in_flight_request_id_rejected(basic_app: Starlette) -> None:
    """Reusing the id of a request that is still in flight on the session yields a 400.

    The transport routes responses per request id, so letting a second concurrent
    request share an id would overwrite the first request's routing slot and mix
    up the two responses (see #3060). The test checks that the duplicate is
    refused and that the original request still runs to completion.
    """
    base_headers = {
        "Accept": "application/json, text/event-stream",
        "Content-Type": "application/json",
    }

    def tool_call(request_id: int, tool_name: str) -> dict[str, Any]:
        """Build an argument-less `tools/call` JSON-RPC request."""
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": {}},
        }

    async with make_client(basic_app) as client:
        init_response = await client.post("/mcp", headers=base_headers, json=INIT_REQUEST)
        assert init_response.status_code == 200
        session_headers = {
            **base_headers,
            MCP_SESSION_ID_HEADER: init_response.headers[MCP_SESSION_ID_HEADER],
            MCP_PROTOCOL_VERSION_HEADER: extract_protocol_version_from_sse(init_response),
        }

        # Request A parks server-side on the lock, which keeps id 1 in flight.
        blocking_call = tool_call(1, "wait_for_lock_with_notification")
        async with client.stream(
            "POST",
            "/mcp",
            headers=session_headers,
            json=blocking_call,
        ) as response_a:
            assert response_a.status_code == 200
            lines_a = response_a.aiter_lines()

            # Seeing the tool's first notification proves request A is running.
            with anyio.fail_after(5):
                first_notification = await next_sse_data(lines_a)
            assert first_notification["params"]["data"] == "First notification before lock"

            # A second request with id 1 while A is still in flight must be refused.
            response_b = await client.post(
                "/mcp",
                headers=session_headers,
                json=tool_call(1, "test_tool"),
            )
            assert response_b.status_code == 400
            error = response_b.json()["error"]
            assert error["code"] == INVALID_REQUEST
            assert "already in flight" in error["message"]

            # Request A must be untouched: free the lock and let it finish.
            release_response = await client.post(
                "/mcp",
                headers=session_headers,
                json=tool_call(2, "release_lock"),
            )
            assert release_response.status_code == 200

            with anyio.fail_after(5):
                second_notification = await next_sse_data(lines_a)
                final = await next_sse_data(lines_a)
            assert second_notification["params"]["data"] == "Second notification after lock"
            assert final["id"] == 1
            assert final["result"]["content"][0]["text"] == "Completed"


@pytest.mark.anyio
async def test_request_id_reuse_after_completion_allowed(basic_app: Starlette) -> None:
    """Once a request has completed, its id may be used again on the same session.

    Routing is only ambiguous when two requests with the same id overlap in time.
    Sequential reuse, which some deployed clients rely on by sending every request
    with id 1, has to keep working (see #3060).
    """
    base_headers = {
        "Accept": "application/json, text/event-stream",
        "Content-Type": "application/json",
    }
    repeated_call = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": {"name": "test_tool", "arguments": {}},
    }

    async with make_client(basic_app) as client:
        init_response = await client.post("/mcp", headers=base_headers, json=INIT_REQUEST)
        assert init_response.status_code == 200
        session_headers = {
            **base_headers,
            MCP_SESSION_ID_HEADER: init_response.headers[MCP_SESSION_ID_HEADER],
            MCP_PROTOCOL_VERSION_HEADER: extract_protocol_version_from_sse(init_response),
        }

        # Send the identical request twice, back to back; each must succeed.
        for _attempt in range(2):
            reply = await client.post("/mcp", headers=session_headers, json=repeated_call)
            assert reply.status_code == 200
            payload = first_sse_data(reply)
            assert payload["id"] == 1
            assert payload["result"]["content"][0]["text"] == "Called test_tool"


@pytest.mark.anyio
async def test_duplicate_in_flight_request_id_rejected_during_priming() -> None:
    """The duplicate-id guard holds while the event store persists the priming event.

    With an event store configured, the SSE branch awaits ``EventStore.store_event``
    to mint the priming event. The routing slot is reserved before that await, so a
    concurrent POST reusing the id during persistence is still rejected rather than
    slipping past the guard and overwriting the slot (see #3060).

    The test drives this with an event store that blocks inside the priming
    write for one specific stream id, sends a duplicate request while the first
    one is parked there, and then releases the store so the first can finish.
    """

    class GatedEventStore(SimpleEventStore):
        """Blocks the priming write for the gated stream until released."""

        def __init__(self) -> None:
            super().__init__()
            # Set by `store_event` once the gated priming write has started.
            self.entered = anyio.Event()
            # Set by the test to let the gated priming write proceed.
            self.release = anyio.Event()

        async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId:
            # Only the write for stream "gated-1" with no message is gated
            # (the priming write this test targets); everything else goes
            # straight through to the parent store.
            if stream_id == "gated-1" and message is None:
                self.entered.set()
                await self.release.wait()
            return await super().store_event(stream_id, message)

    def first_message_sse_data(response: httpx.Response) -> dict[str, Any]:
        """First non-empty SSE data payload; skips the empty-data priming event."""
        for line in response.text.splitlines():
            # `.strip()` filters out `data: ` lines that carry no payload.
            if line.startswith("data: ") and line.removeprefix("data: ").strip():
                return json.loads(line.removeprefix("data: "))
        raise ValueError("No message data event in SSE response")  # pragma: no cover

    store = GatedEventStore()
    async with running_app(event_store=store) as app:
        async with make_client(app) as client:
            # Priming events are only minted for protocol >= 2025-11-25, so
            # negotiate the latest version rather than INIT_REQUEST's pinned one.
            init_params: dict[str, Any] = {
                **cast(dict[str, Any], INIT_REQUEST["params"]),
                "protocolVersion": types.LATEST_PROTOCOL_VERSION,
            }
            init_request: dict[str, Any] = {**INIT_REQUEST, "params": init_params}
            response = await client.post(
                "/mcp",
                headers={
                    "Accept": "application/json, text/event-stream",
                    "Content-Type": "application/json",
                },
                json=init_request,
            )
            assert response.status_code == 200
            # Session headers reused by both the first and the duplicate request.
            headers = {
                "Accept": "application/json, text/event-stream",
                "Content-Type": "application/json",
                MCP_SESSION_ID_HEADER: response.headers[MCP_SESSION_ID_HEADER],
                MCP_PROTOCOL_VERSION_HEADER: first_message_sse_data(response)["result"]["protocolVersion"],
            }
            # The request id matches the stream id that GatedEventStore gates on.
            call: dict[str, Any] = {
                "jsonrpc": "2.0",
                "id": "gated-1",
                "method": "tools/call",
                "params": {"name": "test_tool", "arguments": {}},
            }
            # Filled in by the background task once the first POST returns.
            results: dict[str, httpx.Response] = {}

            async def post_first() -> None:
                results["first"] = await client.post("/mcp", headers=headers, json=call)

            async with anyio.create_task_group() as tg:
                tg.start_soon(post_first)
                # The first request is now suspended inside the event store's
                # priming write: past the duplicate-id guard, response not started.
                with anyio.fail_after(5):
                    await store.entered.wait()

                # Bounded so a regression fails fast: without the early slot
                # reservation, this POST would also suspend in the gated event
                # store and the test would hang instead of failing.
                with anyio.fail_after(5):
                    duplicate = await client.post("/mcp", headers=headers, json=call)
                assert duplicate.status_code == 400
                error = duplicate.json()["error"]
                assert error["code"] == INVALID_REQUEST
                assert "already in flight" in error["message"]

                # Unblock the priming write so the first request can finish
                # before the task group exits.
                store.release.set()

            # The first request is unaffected and completes with its own result.
            assert results["first"].status_code == 200
            body = first_message_sse_data(results["first"])
            assert body["id"] == "gated-1"
            assert body["result"]["content"][0]["text"] == "Called test_tool"


@pytest.mark.anyio
async def test_json_response(json_app: Starlette) -> None:
"""With JSON response mode enabled, requests are answered with application/json bodies."""
Expand Down
Loading