Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/mcp/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ async def serve_loop(
# next request (spec: SHOULD NOT, not MUST NOT) sees the initialized
# state instead of failing the init-gate.
inline_methods=frozenset({"initialize"}),
drain_inbound_on_read_eof=getattr(read_stream, "drain_inbound_on_read_eof", False),
)
connection = Connection.for_loop(dispatcher, session_id=session_id)
await serve_connection(
Expand Down
4 changes: 4 additions & 0 deletions src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))

read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
# Redirected stdin reaches EOF immediately after the final JSON-RPC line.
# Stdio must keep stdout alive long enough for already-accepted request
# responses to flush; other transports keep immediate EOF cancellation.
read_stream.drain_inbound_on_read_eof = True
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

async def stdin_reader():
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/shared/_context_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ async def __aexit__(
class ContextReceiveStream(Generic[T]):
"""Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``."""

__slots__ = ("_inner", "last_context")
__slots__ = ("_inner", "last_context", "drain_inbound_on_read_eof")

def __init__(self, inner: MemoryObjectReceiveStream[_Envelope[T]]) -> None:
self._inner = inner
self.last_context: contextvars.Context | None = None
self.drain_inbound_on_read_eof = False

async def receive(self) -> T:
ctx, item = await self._inner.receive()
Expand Down
54 changes: 48 additions & 6 deletions src/mcp/shared/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@
_SHUTDOWN_WRITE_TIMEOUT: float = 1
"""Tighter bound for the shutdown-arm error write so a wedged transport can't hold session close."""

_DRAIN_INBOUND_ON_EOF_TIMEOUT: float = 5
"""Bound for letting already-accepted inbound requests write responses after read EOF."""

_DRAIN_INBOUND_ON_EOF_POLL_INTERVAL: float = 0.01
"""Polling interval while waiting for accepted inbound requests to finish."""

TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext)

PeerCancelMode = Literal["interrupt", "signal"]
Expand Down Expand Up @@ -251,6 +257,7 @@ def __init__(
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None,
drain_inbound_on_read_eof: bool = False,
) -> None:
"""Wire a dispatcher over a transport's `SessionMessage` stream pair.

Expand All @@ -265,6 +272,10 @@ def __init__(
on_stream_exception: Observer for `Exception` items on the read
stream; without it they are debug-logged and dropped. Awaited
inline in the read loop, so a slow observer stalls dispatch.
drain_inbound_on_read_eof: Let already-accepted inbound request
response writes finish after read EOF before cancelling the run
task group. Intended for stdio EOF after redirected input;
default transport-close semantics remain immediate cancellation.
"""
self._read_stream = read_stream
self._write_stream = write_stream
Expand All @@ -281,10 +292,12 @@ def __init__(
"""Observer for ``Exception`` items on the read stream. Mutable so a session can
bind it after the dispatcher is built (e.g. ``ClientSession`` routing into
``message_handler``); only consulted inside ``run()`` so pre-enter assignment is safe."""
self._drain_inbound_on_read_eof = drain_inbound_on_read_eof

self._next_id = 0
self._pending: dict[RequestId, _Pending] = {}
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
self._active_inbound_requests = 0
self._tg: anyio.abc.TaskGroup | None = None
self._running = False
self._closed = False
Expand Down Expand Up @@ -471,6 +484,8 @@ async def run(
self._running = False
self._closed = True
self._fan_out_closed()
if self._drain_inbound_on_read_eof:
await self._drain_active_inbound_requests()
finally:
# Cancel in-flight handlers; otherwise the task-group join
# waits on handlers whose callers are already gone.
Expand Down Expand Up @@ -525,17 +540,24 @@ async def _dispatch_request(
sender_ctx: contextvars.Context | None,
) -> None:
progress_token = progress_token_from_params(req.params)
self._active_inbound_requests += 1
try:
transport_ctx = self._transport_builder(metadata)
except Exception:
# A raising builder must cost only this message, not the connection.
# Track its spawned error response so stdio EOF drain waits for this
# already-accepted request outcome too.
logger.exception("transport_builder raised; rejecting request %r", req.id)
self._spawn(
self._write_error,
req.id,
ErrorData(code=INTERNAL_ERROR, message="transport context unavailable"),
sender_ctx=sender_ctx,
)

async def _reject_builder_failure() -> None:
try:
await self._write_error(
req.id, ErrorData(code=INTERNAL_ERROR, message="transport context unavailable")
)
finally:
self._active_inbound_requests = max(0, self._active_inbound_requests - 1)

self._spawn(_reject_builder_failure, sender_ctx=sender_ctx)
return
dctx = _JSONRPCDispatchContext(
transport=transport_ctx,
Expand Down Expand Up @@ -659,6 +681,24 @@ def _fan_out_closed(self) -> None:
pass
self._pending.clear()

async def _drain_active_inbound_requests(self) -> None:
"""Let accepted inbound requests finish response writes after read EOF.

A redirected-stdin stdio transport can reach EOF immediately after the last
request is accepted. Treating EOF as immediate shutdown cancels handlers
before their JSON-RPC responses reach stdout. Keep the write side open
briefly so already-accepted requests can produce responses, then let the
caller cancel any stragglers.
"""
with anyio.move_on_after(_DRAIN_INBOUND_ON_EOF_TIMEOUT) as scope:
while self._active_inbound_requests:
await anyio.sleep(_DRAIN_INBOUND_ON_EOF_POLL_INTERVAL)
if scope.cancelled_caught:
logger.warning(
"timed out waiting for %d inbound request(s) to finish after read EOF",
self._active_inbound_requests,
)

async def _handle_request(
self,
req: JSONRPCRequest,
Expand Down Expand Up @@ -722,6 +762,8 @@ async def _handle_request(
await self._write_error(req.id, ErrorData(code=0, message=str(e)))
if self._raise_handler_exceptions:
raise
finally:
self._active_inbound_requests = max(0, self._active_inbound_requests - 1)
# No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id.

def _allocate_id(self) -> int:
Expand Down
81 changes: 81 additions & 0 deletions tests/shared/test_jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,87 @@ async def caller() -> None:
s.close()


@pytest.mark.anyio
async def test_opt_in_read_eof_drains_accepted_inbound_request_response():
"""Read EOF must not cancel a request that was already accepted.

This covers redirected-stdin stdio servers: EOF can arrive immediately after
the final JSON-RPC request is read, while the tool handler still has an
await point before its response write.
"""
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage](32)
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
c2s_recv, s2c_send, drain_inbound_on_read_eof=True
)
handler_started = anyio.Event()

async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
handler_started.set()
await anyio.sleep(0.05)
return {"ok": True}

async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None:
pass

try:
async with anyio.create_task_group() as tg:
await tg.start(server.run, on_request, on_notify)
await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="slow")))
await handler_started.wait()

# Simulate stdin EOF after the request has been accepted but before
# the handler has finished and written its response.
c2s_send.close()

with anyio.fail_after(5):
response = await s2c_recv.receive()
assert response.message == JSONRPCResponse(jsonrpc="2.0", id=1, result={"ok": True})
tg.cancel_scope.cancel()
finally:
for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):
s.close()


@pytest.mark.anyio
async def test_opt_in_read_eof_drains_transport_builder_rejection_response():
"""The stdio EOF drain also covers spawned rejection writes before a handler exists."""
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage](32)

def reject(_metadata: MessageMetadata) -> TransportContext:
raise RuntimeError("no context")

server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
c2s_recv,
s2c_send,
transport_builder=reject,
drain_inbound_on_read_eof=True,
)

async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
raise NotImplementedError

async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None:
pass

try:
async with anyio.create_task_group() as tg:
await tg.start(server.run, on_request, on_notify)
await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="slow")))
c2s_send.close()

with anyio.fail_after(5):
response = await s2c_recv.receive()
assert isinstance(response.message, JSONRPCError)
assert response.message.id == 1
assert response.message.error.code == INTERNAL_ERROR
tg.cancel_scope.cancel()
finally:
for stream in (c2s_send, c2s_recv, s2c_send, s2c_recv):
stream.close()


@pytest.mark.anyio
async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed():
"""Iterating a closed receive end is EOF, not a crash (stateless SHTTP closes it during teardown)."""
Expand Down
Loading