Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/google/adk/workflow/utils/_retry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def _should_retry_node(
return False

attempt_count = node_state.attempt_count
max_attempts = retry_config.max_attempts or 5
max_attempts = (
retry_config.max_attempts if retry_config.max_attempts is not None else 5
)

# attempt_count starts at 1 for the original request.
# So if attempt_count >= max_attempts, we have reached the limit.
Expand Down
34 changes: 34 additions & 0 deletions tests/unittests/workflow/utils/test_retry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.adk.workflow._node_state import NodeState
from google.adk.workflow._retry_config import RetryConfig
from google.adk.workflow.utils._retry_utils import _get_retry_delay
from google.adk.workflow.utils._retry_utils import _should_retry_node
import pytest


Expand Down Expand Up @@ -68,3 +69,36 @@ def test_adds_jitter_when_enabled(self):

assert all(5.0 <= d <= 15.0 for d in delays)
assert len(set(delays)) > 1


class TestShouldRetryNode:

def test_no_config_never_retries(self):
"""Without a retry config, a node is never retried."""
assert _should_retry_node(RuntimeError(), None, NodeState(attempt_count=1)) is False

@pytest.mark.parametrize("max_attempts", [0, 1])
def test_max_attempts_zero_or_one_disables_retries(self, max_attempts):
"""max_attempts of 0 or 1 means no retries (per RetryConfig docs).

A falsy-coalescing default (``max_attempts or 5``) wrongly treated an
explicit ``0`` as unset and allowed 5 attempts.
"""
config = RetryConfig(max_attempts=max_attempts)

assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=1)) is False

def test_retries_until_max_attempts(self):
"""A node is retried while attempt_count is below max_attempts."""
config = RetryConfig(max_attempts=3)

assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=1)) is True
assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=2)) is True
assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=3)) is False

def test_unset_max_attempts_defaults_to_five(self):
"""When max_attempts is unset (None), the default of 5 applies."""
config = RetryConfig()

assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=4)) is True
assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=5)) is False