diff --git a/src/google/adk/workflow/utils/_retry_utils.py b/src/google/adk/workflow/utils/_retry_utils.py index 5a220dd261..d9c0c0870b 100644 --- a/src/google/adk/workflow/utils/_retry_utils.py +++ b/src/google/adk/workflow/utils/_retry_utils.py @@ -32,7 +32,9 @@ def _should_retry_node( return False attempt_count = node_state.attempt_count - max_attempts = retry_config.max_attempts or 5 + max_attempts = ( + retry_config.max_attempts if retry_config.max_attempts is not None else 5 + ) # attempt_count starts at 1 for the original request. # So if attempt_count >= max_attempts, we have reached the limit. diff --git a/tests/unittests/workflow/utils/test_retry_utils.py b/tests/unittests/workflow/utils/test_retry_utils.py index 3685139c42..4544c10bd7 100644 --- a/tests/unittests/workflow/utils/test_retry_utils.py +++ b/tests/unittests/workflow/utils/test_retry_utils.py @@ -17,6 +17,7 @@ from google.adk.workflow._node_state import NodeState from google.adk.workflow._retry_config import RetryConfig from google.adk.workflow.utils._retry_utils import _get_retry_delay +from google.adk.workflow.utils._retry_utils import _should_retry_node import pytest @@ -68,3 +69,36 @@ def test_adds_jitter_when_enabled(self): assert all(5.0 <= d <= 15.0 for d in delays) assert len(set(delays)) > 1 + + +class TestShouldRetryNode: + + def test_no_config_never_retries(self): + """Without a retry config, a node is never retried.""" + assert _should_retry_node(RuntimeError(), None, NodeState(attempt_count=1)) is False + + @pytest.mark.parametrize("max_attempts", [0, 1]) + def test_max_attempts_zero_or_one_disables_retries(self, max_attempts): + """max_attempts of 0 or 1 means no retries (per RetryConfig docs). + + A falsy-coalescing default (``max_attempts or 5``) wrongly treated an + explicit ``0`` as unset and allowed 5 attempts. + """ + config = RetryConfig(max_attempts=max_attempts) + + assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=1)) is False + + def test_retries_until_max_attempts(self): + """A node is retried while attempt_count is below max_attempts.""" + config = RetryConfig(max_attempts=3) + + assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=1)) is True + assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=2)) is True + assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=3)) is False + + def test_unset_max_attempts_defaults_to_five(self): + """When max_attempts is unset (None), the default of 5 applies.""" + config = RetryConfig() + + assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=4)) is True + assert _should_retry_node(RuntimeError(), config, NodeState(attempt_count=5)) is False