diff --git a/.github/.release-please-manifest-v1.json b/.github/.release-please-manifest-v1.json new file mode 100644 index 00000000000..d82abb2aa56 --- /dev/null +++ b/.github/.release-please-manifest-v1.json @@ -0,0 +1,3 @@ +{ + ".": "1.36.0" +} diff --git a/.github/.release-please-manifest-v2.json b/.github/.release-please-manifest-v2.json deleted file mode 100644 index 0739396e62f..00000000000 --- a/.github/.release-please-manifest-v2.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - ".": "2.0.0-alpha.0" -} diff --git a/.github/release-please-config-v2.json b/.github/release-please-config-v2.json deleted file mode 100644 index 6947d9e15a4..00000000000 --- a/.github/release-please-config-v2.json +++ /dev/null @@ -1,62 +0,0 @@ -{ - "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json", - "packages": { - ".": { - "release-type": "python", - "versioning": "prerelease", - "prerelease": true, - "prerelease-type": "alpha", - "package-name": "google-adk", - "include-component-in-tag": false, - "skip-github-release": true, - "changelog-path": "CHANGELOG-v2.md", - "changelog-sections": [ - { - "type": "feat", - "section": "Features" - }, - { - "type": "fix", - "section": "Bug Fixes" - }, - { - "type": "perf", - "section": "Performance Improvements" - }, - { - "type": "refactor", - "section": "Code Refactoring" - }, - { - "type": "docs", - "section": "Documentation" - }, - { - "type": "test", - "section": "Tests", - "hidden": true - }, - { - "type": "build", - "section": "Build System", - "hidden": true - }, - { - "type": "ci", - "section": "CI/CD", - "hidden": true - }, - { - "type": "style", - "section": "Styles", - "hidden": true - }, - { - "type": "chore", - "section": "Miscellaneous Chores", - "hidden": true - } - ] - } - } -} diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml deleted file mode 100644 index f703422f76a..00000000000 --- a/.github/workflows/check-file-contents.yml +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "Check file contents" - -on: - pull_request: - paths: - - '**.py' - -permissions: - contents: read - -jobs: - check-file-contents: - runs-on: ubuntu-latest - steps: - - name: Checkout Code - uses: actions/checkout@v6 - with: - fetch-depth: 2 - - - name: Check for logger pattern in all changed Python files - run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) - if [ -n "$CHANGED_FILES" ]; then - echo "Changed Python files to check:" - echo "$CHANGED_FILES" - echo "" - - # Check for 'logger = logging.getLogger(__name__)' in changed .py files. - # The grep command will exit with a non-zero status code if the pattern is not found. - # We invert the exit code with ! so the step succeeds if the pattern is NOT found. - set +e - FILES_WITH_FORBIDDEN_LOGGER=$(grep -lE 'logger = logging\.getLogger\(__name__\)' $CHANGED_FILES) - GREP_EXIT_CODE=$? - set -e - - # grep exits with 0 if matches are found, 1 if no matches are found. - # A non-zero exit code other than 1 indicates an error. - if [ $GREP_EXIT_CODE -eq 0 ]; then - echo "❌ Found forbidden use of 'logger = logging.getLogger(__name__)'. Please use 'logger = logging.getLogger('google_adk.' + __name__)' instead." - echo "The following files contain the forbidden pattern:" - echo "$FILES_WITH_FORBIDDEN_LOGGER" - exit 1 - elif [ $GREP_EXIT_CODE -eq 1 ]; then - echo "✅ No instances of 'logger = logging.getLogger(__name__)' found in changed Python files." - fi - else - echo "✅ No relevant Python files found." - fi - - - name: Check for import pattern in certain changed Python files - run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) - if [ -n "$CHANGED_FILES" ]; then - echo "Changed Python files to check:" - echo "$CHANGED_FILES" - echo "" - - # Use grep -L to find files that DO NOT contain the pattern. - # This command will output a list of non-compliant files. - FILES_MISSING_IMPORT=$(grep -L 'from __future__ import annotations' $CHANGED_FILES || true) - - # Check if the list of non-compliant files is empty - if [ -z "$FILES_MISSING_IMPORT" ]; then - echo "✅ All modified Python files include 'from __future__ import annotations'." - exit 0 - else - echo "❌ The following files are missing 'from __future__ import annotations':" - echo "$FILES_MISSING_IMPORT" - echo "This import is required to allow forward references in type annotations without quotes." - exit 1 - fi - else - echo "✅ No relevant Python files found." - fi - - - name: Check for import from cli package in certain changed Python files - run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) - if [ -n "$CHANGED_FILES" ]; then - echo "Changed Python files to check:" - echo "$CHANGED_FILES" - echo "" - - set +e - FILES_WITH_FORBIDDEN_IMPORT=$(grep -lE '^from.*\bcli\b.*import.*$' $CHANGED_FILES) - GREP_EXIT_CODE=$? - set -e - - if [[ $GREP_EXIT_CODE -eq 0 ]]; then - echo "❌ Do not import from the cli package outside of the cli package. If you need to reuse the code elsewhere, please move the code outside of the cli package." - echo "The following files contain the forbidden pattern:" - echo "$FILES_WITH_FORBIDDEN_IMPORT" - exit 1 - else - echo "✅ No instances of importing from the cli package found in relevant changed Python files." - fi - else - echo "✅ No relevant Python files found." - fi diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml new file mode 100644 index 00000000000..1ae99891fe2 --- /dev/null +++ b/.github/workflows/continuous-integration.yml @@ -0,0 +1,276 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Continuous Integration + +on: + push: + branches: [main, v1] + paths: + - '**.py' + - '.pre-commit-config.yaml' + - 'pyproject.toml' + - 'tests/**' + pull_request: + branches: [main, v1] + paths: + - '**.py' + - '.pre-commit-config.yaml' + - 'pyproject.toml' + - 'tests/**' + +permissions: + contents: read + +jobs: + # 1. Code format and linting (Linter) + lint: + name: Pre-commit Linter + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v6 + + - name: Run pre-commit checks + uses: pre-commit/action@v3.0.1 + + # 2. Static type analysis (Mypy Check with Matrix) + # Compares new changes against the target base branch dynamically to support v1. + type-check: + name: Mypy Check (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12', '3.13'] + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Generate Baseline + env: + TARGET_BRANCH: ${{ github.base_ref || github.ref_name }} + run: | + # Switch to target base branch to generate baseline + git checkout origin/$TARGET_BRANCH + + git checkout ${{ github.sha }} -- pyproject.toml + + # Install dependencies for target branch + uv venv .venv + source .venv/bin/activate + uv sync --all-extras + + # Run mypy, filter for errors only, remove line numbers, and sort + # We ignore exit code (|| true) because we expect errors on baseline + uv run mypy . | grep "error:" | sed 's/:\([0-9]\+\):/::/g' | sort > baseline_errors.txt || true + echo "Found $(wc -l < baseline_errors.txt) errors on $TARGET_BRANCH." + + - name: Check PR Branch + run: | + # Switch back to the PR commit + git checkout ${{ github.sha }} + + # Re-sync dependencies in case the PR changed them + source .venv/bin/activate + uv sync --all-extras + + # Run mypy on PR code, apply same processing + uv run mypy . | grep "error:" | sed 's/:\([0-9]\+\):/::/g' | sort > pr_errors.txt || true + echo "Found $(wc -l < pr_errors.txt) errors on PR branch." + + - name: Compare and Fail on New Errors + run: | + # 'comm -13' suppresses unique lines in file1 (baseline) and common lines, + # leaving only lines unique to file2 (PR) -> The new errors. + comm -13 baseline_errors.txt pr_errors.txt > new_errors.txt + + if [ -s new_errors.txt ]; then + echo "::error::The following NEW mypy errors were introduced:" + cat new_errors.txt + exit 1 + else + echo "Great job! No new mypy errors introduced." + fi + + # 3. Unit testing (Unit Tests with Matrix) + unit-test: + name: Unit Tests (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + timeout-minutes: 10 + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v7 + + - name: Install dependencies + run: | + uv venv .venv + source .venv/bin/activate + uv sync --extra test + + - name: Run unit tests with pytest + run: | + source .venv/bin/activate + pytest tests/unittests \ + -n auto \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + + # 4. Custom file content compliance checks (PR only) + compliance-check: + name: File Content Compliance + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + steps: + - name: Checkout Code + uses: actions/checkout@v6 + with: + # Fetch full history (depth: 0) instead of shallow clone (depth: 2) to ensure + # git diff origin/${base_ref}...HEAD can reliably find the merge base, + # preventing fatal git errors on deep PRs or when the target branch has progressed. + fetch-depth: 0 + + - name: Check for logger pattern in all changed Python files + run: | + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) + if [ -n "$CHANGED_FILES" ]; then + echo "Changed Python files to check:" + echo "$CHANGED_FILES" + echo "" + + # Check for 'logger = logging.getLogger(__name__)' in changed .py files. + set +e + FILES_WITH_FORBIDDEN_LOGGER=$(grep -lE 'logger = logging\.getLogger\(__name__\)' $CHANGED_FILES) + GREP_EXIT_CODE=$? + set -e + + if [ $GREP_EXIT_CODE -eq 0 ]; then + echo "❌ Found forbidden use of 'logger = logging.getLogger(__name__)'. Please use 'logger = logging.getLogger('google_adk.' + __name__)' instead." + echo "The following files contain the forbidden pattern:" + echo "$FILES_WITH_FORBIDDEN_LOGGER" + exit 1 + elif [ $GREP_EXIT_CODE -eq 1 ]; then + echo "✅ No instances of 'logger = logging.getLogger(__name__)' found in changed Python files." + fi + else + echo "✅ No relevant Python files found." + fi + + - name: Check for import pattern in certain changed Python files + run: | + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) + if [ -n "$CHANGED_FILES" ]; then + echo "Changed Python files to check:" + echo "$CHANGED_FILES" + echo "" + + # Use grep -L to find files that DO NOT contain the pattern. + FILES_MISSING_IMPORT=$(grep -L 'from __future__ import annotations' $CHANGED_FILES || true) + + if [ -z "$FILES_MISSING_IMPORT" ]; then + echo "✅ All modified Python files include 'from __future__ import annotations'." + exit 0 + else + echo "❌ The following files are missing 'from __future__ import annotations':" + echo "$FILES_MISSING_IMPORT" + echo "This import is required to allow forward references in type annotations without quotes." + exit 1 + fi + else + echo "✅ No relevant Python files found." + fi + + - name: Check for import from cli package in certain changed Python files + run: | + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) + if [ -n "$CHANGED_FILES" ]; then + echo "Changed Python files to check:" + echo "$CHANGED_FILES" + echo "" + + set +e + FILES_WITH_FORBIDDEN_IMPORT=$(grep -lE '^from.*\bcli\b.*import.*$' $CHANGED_FILES) + GREP_EXIT_CODE=$? + set -e + + if [[ $GREP_EXIT_CODE -eq 0 ]]; then + echo "❌ Do not import from the cli package outside of the cli package. If you need to reuse the code elsewhere, please move the code outside of the cli package." + echo "The following files contain the forbidden pattern:" + echo "$FILES_WITH_FORBIDDEN_IMPORT" + exit 1 + else + echo "✅ No instances of importing from the cli package found in relevant changed Python files." + fi + else + echo "✅ No relevant Python files found." + fi + + - name: Check for hardcoded googleapis.com endpoints + run: | + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) + if [ -n "$CHANGED_FILES" ]; then + echo "Checking for hardcoded endpoints in: $CHANGED_FILES" + + # 1. Identify files containing any googleapis.com URL. + set +e + FILES_WITH_ENDPOINTS=$(grep -lE 'https?://[a-zA-Z0-9.-]+\.googleapis\.com' $CHANGED_FILES) + + # 2. From those, identify files that are MISSING the required mTLS version. + if [ -n "$FILES_WITH_ENDPOINTS" ]; then + FILES_MISSING_MTLS=$(grep -L '.mtls.googleapis.com' $FILES_WITH_ENDPOINTS) + fi + set -e + + if [ -n "$FILES_MISSING_MTLS" ]; then + echo "❌ Found hardcoded googleapis.com endpoints without mTLS support." + echo "The following files must define both standard and mTLS (.mtls.googleapis.com) endpoints" + echo "to support dynamic endpoint selection as required by security policy:" + echo "$FILES_MISSING_MTLS" + echo "" + echo "To fix this, please follow these steps:" + echo "1. Initialize an AuthorizedSession with your credentials." + echo "2. Use 'mtls.has_default_client_cert_source() from google-auth' to check for available client certificates." + echo "3. If certificates are present, use 'session.configure_mtls_channel()'." + echo "4. Dynamically select the '.mtls.' variant of the endpoint when mTLS is active." + exit 1 + else + echo "✅ All hardcoded endpoints have corresponding mTLS definitions or no endpoints found." + fi + fi diff --git a/.github/workflows/gemini-dispatch.yml b/.github/workflows/gemini-dispatch.yml deleted file mode 100644 index 9c2bf8ec9dc..00000000000 --- a/.github/workflows/gemini-dispatch.yml +++ /dev/null @@ -1,189 +0,0 @@ -name: '🔀 Gemini Dispatch' - -on: - pull_request_review_comment: - types: - - 'created' - pull_request_review: - types: - - 'submitted' - issue_comment: - types: - - 'created' - -defaults: - run: - shell: 'bash' - -jobs: - debugger: - if: |- - ${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }} - runs-on: 'ubuntu-latest' - permissions: - contents: 'read' - steps: - - name: 'Print context for debugging' - env: - DEBUG_event_name: '${{ github.event_name }}' - DEBUG_event__action: '${{ github.event.action }}' - DEBUG_event__comment__author_association: '${{ github.event.comment.author_association }}' - DEBUG_event__issue__author_association: '${{ github.event.issue.author_association }}' - DEBUG_event__pull_request__author_association: '${{ github.event.pull_request.author_association }}' - DEBUG_event__review__author_association: '${{ github.event.review.author_association }}' - DEBUG_event: '${{ toJSON(github.event) }}' - run: |- - env | grep '^DEBUG_' - - dispatch: - # Only trigger if user types @gemini-cli and author association is OWNER, MEMBER, or COLLABORATOR - if: |- - github.event.sender.type == 'User' && - startsWith(github.event.comment.body || github.event.review.body, '@gemini-cli') && - contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association || github.event.review.author_association) - runs-on: 'ubuntu-latest' - permissions: - contents: 'read' - issues: 'write' - pull-requests: 'write' - outputs: - command: '${{ steps.extract_command.outputs.command }}' - request: '${{ steps.extract_command.outputs.request }}' - additional_context: '${{ steps.extract_command.outputs.additional_context }}' - issue_number: '${{ github.event.pull_request.number || github.event.issue.number }}' - steps: - - name: 'Mint identity token' - id: 'mint_identity_token' - if: |- - ${{ vars.APP_ID }} - uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 - with: - app-id: '${{ vars.APP_ID }}' - private-key: '${{ secrets.APP_PRIVATE_KEY }}' - permission-contents: 'read' - permission-issues: 'write' - permission-pull-requests: 'write' - - - name: 'Extract command' - id: 'extract_command' - uses: 'actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd' # ratchet:actions/github-script@v8.0.0 - env: - REQUEST: '${{ github.event.comment.body || github.event.review.body }}' - IS_PR: '${{ !!(github.event.pull_request || github.event.issue.pull_request) }}' - with: - script: | - const request = process.env.REQUEST; - const isPr = process.env.IS_PR === 'true'; - core.setOutput('request', request); - - // Ensure request is on a PR targeting the main branch - let baseRef = ''; - if (context.eventName === 'pull_request_review' || context.eventName === 'pull_request_review_comment') { - baseRef = context.payload.pull_request.base.ref; - } else if (context.eventName === 'issue_comment' && context.payload.issue.pull_request) { - const pr = await github.rest.pulls.get({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.payload.issue.number - }); - baseRef = pr.data.base.ref; - } - - if (isPr && baseRef !== 'main') { - console.log(`Skipping: PR targets '${baseRef}', but only 'main' is allowed.`); - core.setOutput('command', 'fallthrough'); - return; - } - - if (request.startsWith("@gemini-cli /review")) { - if (isPr) { - core.setOutput('command', 'review'); - const additionalContext = request.replace(/^@gemini-cli \/review/, '').trim(); - core.setOutput('additional_context', additionalContext); - } else { - core.setOutput('command', 'fallthrough'); - } - } else if (request.startsWith("@gemini-cli")) { - const additionalContext = request.replace(/^@gemini-cli/, '').trim(); - core.setOutput('command', 'invoke'); - core.setOutput('additional_context', additionalContext); - } else { - core.setOutput('command', 'fallthrough'); - } - - - name: 'Acknowledge request' - env: - GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' - ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' - MESSAGE: |- - 🤖 Hi @${{ github.actor }}, I've received your request, and I'm working on it now! You can track my progress [in the logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for more details. - REPOSITORY: '${{ github.repository }}' - run: |- - gh issue comment "${ISSUE_NUMBER}" \ - --body "${MESSAGE}" \ - --repo "${REPOSITORY}" - - review: - needs: 'dispatch' - if: |- - ${{ needs.dispatch.outputs.command == 'review' }} - uses: './.github/workflows/gemini-review.yml' - permissions: - contents: 'read' - id-token: 'write' - issues: 'write' - pull-requests: 'write' - with: - additional_context: '${{ needs.dispatch.outputs.additional_context }}' - secrets: 'inherit' - - invoke: - needs: 'dispatch' - if: |- - ${{ needs.dispatch.outputs.command == 'invoke' }} - uses: './.github/workflows/gemini-invoke.yml' - permissions: - contents: 'read' - id-token: 'write' - issues: 'write' - pull-requests: 'write' - with: - additional_context: '${{ needs.dispatch.outputs.additional_context }}' - secrets: 'inherit' - - fallthrough: - needs: - - 'dispatch' - - 'review' - - 'invoke' - if: |- - ${{ always() && !cancelled() && (failure() || needs.dispatch.outputs.command == 'fallthrough') }} - runs-on: 'ubuntu-latest' - permissions: - contents: 'read' - issues: 'write' - pull-requests: 'write' - steps: - - name: 'Mint identity token' - id: 'mint_identity_token' - if: |- - ${{ vars.APP_ID }} - uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 - with: - app-id: '${{ vars.APP_ID }}' - private-key: '${{ secrets.APP_PRIVATE_KEY }}' - permission-contents: 'read' - permission-issues: 'write' - permission-pull-requests: 'write' - - - name: 'Send failure comment' - env: - GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' - ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' - MESSAGE: |- - 🤖 I'm sorry @${{ github.actor }}, but I was unable to process your request. Please [see the logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for more details. - REPOSITORY: '${{ github.repository }}' - run: |- - gh issue comment "${ISSUE_NUMBER}" \ - --body "${MESSAGE}" \ - --repo "${REPOSITORY}" diff --git a/.github/workflows/gemini-invoke.yml b/.github/workflows/gemini-invoke.yml deleted file mode 100644 index 5138d6f7294..00000000000 --- a/.github/workflows/gemini-invoke.yml +++ /dev/null @@ -1,104 +0,0 @@ -name: '▶️ Gemini Invoke' - -on: - workflow_call: - inputs: - additional_context: - type: 'string' - description: 'Any additional context from the request' - required: false - -concurrency: - group: '${{ github.workflow }}-invoke-${{ github.event_name }}-${{ github.event.pull_request.number || github.event.issue.number }}' - cancel-in-progress: false - -defaults: - run: - shell: 'bash' - -jobs: - invoke: - runs-on: 'ubuntu-latest' - permissions: - contents: 'read' - id-token: 'write' - issues: 'write' - pull-requests: 'write' - steps: - - name: 'Mint identity token' - id: 'mint_identity_token' - if: |- - ${{ vars.APP_ID }} - uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 - with: - app-id: '${{ vars.APP_ID }}' - private-key: '${{ secrets.APP_PRIVATE_KEY }}' - permission-contents: 'read' - permission-issues: 'write' - permission-pull-requests: 'write' - - - name: 'Checkout Code' - uses: 'actions/checkout@v4' # ratchet:exclude - - - name: 'Run Gemini CLI' - id: 'run_gemini' - uses: 'google-github-actions/run-gemini-cli@v0' # ratchet:exclude - env: - TITLE: '${{ github.event.pull_request.title || github.event.issue.title }}' - DESCRIPTION: '${{ github.event.pull_request.body || github.event.issue.body }}' - EVENT_NAME: '${{ github.event_name }}' - GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' - IS_PULL_REQUEST: '${{ !!github.event.pull_request }}' - ISSUE_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' - REPOSITORY: '${{ github.repository }}' - ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' - # Required to allow the Gemini CLI to process files in the ephemeral GitHub Actions runner - GEMINI_CLI_TRUST_WORKSPACE: 'true' - with: - gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' - gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' - gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' - gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' - gemini_api_key: '${{ secrets.GOOGLE_API_KEY }}' - gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' - gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' - gemini_model: '${{ vars.GEMINI_MODEL }}' - google_api_key: '${{ secrets.GOOGLE_API_KEY }}' - use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' - use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' - upload_artifacts: '${{ vars.UPLOAD_ARTIFACTS }}' - workflow_name: 'gemini-invoke' - # Assistant workflows can be triggered by comments on either Issues or PRs. - # We explicitly map both fields so the CLI can correctly categorize the interaction. - github_pr_number: '${{ github.event.pull_request.number }}' - github_issue_number: '${{ github.event.issue.number }}' - settings: |- - { - "model": { - "maxSessionTurns": 25 - }, - "telemetry": { - "enabled": true, - "target": "local", - "outfile": ".gemini/telemetry.log" - }, - "mcpServers": { - "github": { - "command": "docker", - "args": [ - "run", - "-i", - "--rm", - "-e", - "GITHUB_PERSONAL_ACCESS_TOKEN", - "ghcr.io/github/github-mcp-server:v0.27.0" - ], - "env": { - "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" - } - } - } - } - prompt: |- - /gemini-invoke - [IMPORTANT] Do not generate execution plans and do not ask for approval (such as suggesting `@gemini-cli /approve`). Perform the requested task or answer the question directly and immediately. diff --git a/.github/workflows/gemini-review.yml b/.github/workflows/gemini-review.yml deleted file mode 100644 index 9c1b1bf4424..00000000000 --- a/.github/workflows/gemini-review.yml +++ /dev/null @@ -1,100 +0,0 @@ -name: '🔎 Gemini Review' - -on: - workflow_call: - inputs: - additional_context: - type: 'string' - description: 'Any additional context from the request' - required: false - -concurrency: - group: '${{ github.workflow }}-review-${{ github.event_name }}-${{ github.event.pull_request.number || github.event.issue.number }}' - cancel-in-progress: true - -defaults: - run: - shell: 'bash' - -jobs: - review: - runs-on: 'ubuntu-latest' - timeout-minutes: 7 - permissions: - contents: 'read' - id-token: 'write' - issues: 'write' - pull-requests: 'write' - steps: - - name: 'Mint identity token' - id: 'mint_identity_token' - if: |- - ${{ vars.APP_ID }} - uses: 'actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf' # ratchet:actions/create-github-app-token@v2 - with: - app-id: '${{ vars.APP_ID }}' - private-key: '${{ secrets.APP_PRIVATE_KEY }}' - permission-contents: 'read' - permission-issues: 'write' - permission-pull-requests: 'write' - - - name: 'Checkout repository' - uses: 'actions/checkout@v4' # ratchet:exclude - - - name: 'Run Gemini pull request review' - uses: 'google-github-actions/run-gemini-cli@v0' # ratchet:exclude - id: 'gemini_pr_review' - env: - GITHUB_TOKEN: '${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}' - ISSUE_TITLE: '${{ github.event.pull_request.title || github.event.issue.title }}' - ISSUE_BODY: '${{ github.event.pull_request.body || github.event.issue.body }}' - PULL_REQUEST_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' - REPOSITORY: '${{ github.repository }}' - ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' - GEMINI_API_KEY: '${{ secrets.GOOGLE_API_KEY }}' - # Required to allow the Gemini CLI to process files in the ephemeral GitHub Actions runner - GEMINI_CLI_TRUST_WORKSPACE: 'true' - with: - gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' - gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' - gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' - gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' - gemini_api_key: '${{ secrets.GOOGLE_API_KEY }}' - gemini_cli_version: '${{ vars.GEMINI_CLI_VERSION }}' - gemini_debug: '${{ fromJSON(vars.GEMINI_DEBUG || vars.ACTIONS_STEP_DEBUG || false) }}' - gemini_model: '${{ vars.GEMINI_MODEL }}' - google_api_key: '${{ secrets.GOOGLE_API_KEY }}' - use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' - use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' - upload_artifacts: '${{ vars.UPLOAD_ARTIFACTS }}' - workflow_name: 'gemini-review' - # Explicitly set the PR number to handle `issue_comment` triggers (which GitHub treats as issues, not PRs) - github_pr_number: '${{ github.event.pull_request.number || github.event.issue.number }}' - settings: |- - { - "model": { - "maxSessionTurns": 25 - }, - "telemetry": { - "enabled": true, - "target": "local", - "outfile": ".gemini/telemetry.log" - }, - "mcpServers": { - "github": { - "command": "docker", - "args": [ - "run", - "-i", - "--rm", - "-e", - "GITHUB_PERSONAL_ACCESS_TOKEN", - "ghcr.io/github/github-mcp-server:v0.27.0" - ], - "env": { - "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ steps.mint_identity_token.outputs.token || secrets.GITHUB_TOKEN || github.token }}" - } - } - } - } - prompt: 'Please use the pull_request_read tool to read pull request #${{ github.event.pull_request.number || github.event.issue.number }}. Analyze the code for bugs, security issues, and best practices. Then, use the add_comment_to_pending_review and pull_request_review_write tools to post your review directly on pull request #${{ github.event.pull_request.number || github.event.issue.number }}.' diff --git a/.github/workflows/mypy-new-errors.yml b/.github/workflows/mypy-new-errors.yml deleted file mode 100644 index 2d3c8aebc28..00000000000 --- a/.github/workflows/mypy-new-errors.yml +++ /dev/null @@ -1,77 +0,0 @@ -name: Mypy New Error Check - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - - -permissions: - contents: read - -jobs: - mypy-diff: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.10', '3.11', '3.12', '3.13',] - steps: - - name: Checkout code - uses: actions/checkout@v6 - with: - fetch-depth: 0 - - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - - - name: Install uv - uses: astral-sh/setup-uv@v7 - - - name: Generate Baseline (Main) - run: | - # Switch to main branch to generate baseline - git checkout origin/main - - git checkout ${{ github.sha }} -- pyproject.toml - - # Install dependencies for main - uv venv .venv - source .venv/bin/activate - uv sync --all-extras - - # Run mypy, filter for errors only, remove line numbers (file:123: -> file::), and sort - # We ignore exit code (|| true) because we expect errors on main - uv run mypy . | grep "error:" | sed 's/:\([0-9]\+\):/::/g' | sort > main_errors.txt || true - - echo "Found $(wc -l < main_errors.txt) errors on main." - - - name: Check PR Branch - run: | - # Switch back to the PR commit - git checkout ${{ github.sha }} - - # Re-sync dependencies in case the PR changed them - source .venv/bin/activate - uv sync --all-extras - - # Run mypy on PR code, apply same processing - uv run mypy . | grep "error:" | sed 's/:\([0-9]\+\):/::/g' | sort > pr_errors.txt || true - - echo "Found $(wc -l < pr_errors.txt) errors on PR branch." - - - name: Compare and Fail on New Errors - run: | - # 'comm -13' suppresses unique lines in file1 (main) and common lines, - # leaving only lines unique to file2 (PR) -> The new errors. - comm -13 main_errors.txt pr_errors.txt > new_errors.txt - - if [ -s new_errors.txt ]; then - echo "::error::The following NEW mypy errors were introduced:" - cat new_errors.txt - exit 1 - else - echo "Great job! No new mypy errors introduced." - fi diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index 23a032f87a6..00000000000 --- a/.github/workflows/pre-commit.yml +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: Pre-commit Checks - -on: - push: - branches: [main, v2] - paths: - - '**.py' - - '.pre-commit-config.yaml' - - 'pyproject.toml' - pull_request: - branches: [main, v2] - paths: - - '**.py' - - '.pre-commit-config.yaml' - - 'pyproject.toml' - -permissions: - contents: read - -jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - name: Checkout Code - uses: actions/checkout@v6 - - - name: Run pre-commit checks - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml deleted file mode 100644 index 6e204a8e675..00000000000 --- a/.github/workflows/python-unit-tests.yml +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: Python Unit Tests - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - - - name: Install the latest version of uv - uses: astral-sh/setup-uv@v7 - - - name: Install dependencies - run: | - uv venv .venv - source .venv/bin/activate - uv sync --extra test - - - name: Run unit tests with pytest - run: | - source .venv/bin/activate - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py diff --git a/.github/workflows/release-v2-cherry-pick.yml b/.github/workflows/release-v1-cherry-pick.yml similarity index 75% rename from .github/workflows/release-v2-cherry-pick.yml rename to .github/workflows/release-v1-cherry-pick.yml index f5641a55b3d..91e858590bc 100644 --- a/.github/workflows/release-v2-cherry-pick.yml +++ b/.github/workflows/release-v1-cherry-pick.yml @@ -1,7 +1,7 @@ -# Step 3 (v2, optional): Cherry-picks a commit from v2 to the release/v2-candidate branch. +# Step 3 (v1, optional): Cherry-picks a commit from v1 to the release/v1-candidate branch. # Use between step 1 and step 4 to include bug fixes in an in-progress release. # Note: Does NOT auto-trigger release-please to preserve manual changelog edits. -name: "Release v2: Cherry-pick" +name: "Release v1: Cherry-pick" on: workflow_dispatch: @@ -20,7 +20,7 @@ jobs: steps: - uses: actions/checkout@v6 with: - ref: release/v2-candidate + ref: release/v1-candidate fetch-depth: 0 - name: Configure git @@ -30,14 +30,14 @@ jobs: - name: Cherry-pick commit run: | - echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/v2-candidate" + echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/v1-candidate" git cherry-pick ${INPUTS_COMMIT_SHA} env: INPUTS_COMMIT_SHA: ${{ inputs.commit_sha }} - name: Push changes run: | - git push origin release/v2-candidate - echo "Successfully cherry-picked commit to release/v2-candidate" + git push origin release/v1-candidate + echo "Successfully cherry-picked commit to release/v1-candidate" echo "Note: Release Please is NOT auto-triggered to preserve manual changelog edits." - echo "Run release-v2-please.yml manually if you want to regenerate the changelog." + echo "Run release-v1-please.yml manually if you want to regenerate the changelog." diff --git a/.github/workflows/release-v1-cut.yml b/.github/workflows/release-v1-cut.yml new file mode 100644 index 00000000000..7a35b6a1205 --- /dev/null +++ b/.github/workflows/release-v1-cut.yml @@ -0,0 +1,46 @@ +# Step 1 (v1): Starts the v1 release process by creating a release/v1-candidate branch. +# Generates a changelog PR for review (step 2). +name: "Release v1: Cut" + +on: + workflow_dispatch: + inputs: + commit_sha: + description: 'Commit SHA to cut from (leave empty for latest v1)' + required: false + type: string + +permissions: + contents: write + actions: write + +jobs: + cut-release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + ref: ${{ inputs.commit_sha || 'v1' }} + + - name: Check for existing release/v1-candidate branch + env: + GH_TOKEN: ${{ github.token }} + run: | + if git ls-remote --exit-code --heads origin release/v1-candidate &>/dev/null; then + echo "Error: release/v1-candidate branch already exists" + echo "Please finalize or delete the existing release candidate before starting a new one" + exit 1 + fi + + - name: Create and push release/v1-candidate branch + run: | + git checkout -b release/v1-candidate + git push origin release/v1-candidate + echo "Created branch: release/v1-candidate" + + - name: Trigger Release Please + env: + GH_TOKEN: ${{ github.token }} + run: | + gh workflow run release-v1-please.yml --repo ${{ github.repository }} --ref release/v1-candidate + echo "Triggered Release Please workflow for v1" diff --git a/.github/workflows/release-v2-finalize.yml b/.github/workflows/release-v1-finalize.yml similarity index 71% rename from .github/workflows/release-v2-finalize.yml rename to .github/workflows/release-v1-finalize.yml index c8b10209489..58ce1705630 100644 --- a/.github/workflows/release-v2-finalize.yml +++ b/.github/workflows/release-v1-finalize.yml @@ -1,12 +1,12 @@ -# Step 4 (v2): Triggers when the changelog PR is merged to release/v2-candidate. -# Records last-release-sha and renames release/v2-candidate to release/v{version}. -name: "Release v2: Finalize" +# Step 4 (v1): Triggers when the changelog PR is merged to release/v1-candidate. +# Records last-release-sha and renames release/v1-candidate to release/v{version}. +name: "Release v1: Finalize" on: pull_request: types: [closed] branches: - - release/v2-candidate + - release/v1-candidate permissions: contents: write @@ -32,7 +32,7 @@ jobs: - uses: actions/checkout@v6 if: steps.check.outputs.is_release_pr == 'true' with: - ref: release/v2-candidate + ref: release/v1-candidate token: ${{ secrets.RELEASE_PAT }} fetch-depth: 0 @@ -40,7 +40,7 @@ jobs: if: steps.check.outputs.is_release_pr == 'true' id: version run: | - VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v2.json) + VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v1.json) echo "version=$VERSION" >> $GITHUB_OUTPUT echo "Extracted version: $VERSION" @@ -56,21 +56,21 @@ jobs: - name: Record last-release-sha for release-please if: steps.check.outputs.is_release_pr == 'true' run: | - git fetch origin v2 - CUT_SHA=$(git merge-base origin/v2 HEAD) - echo "Release was cut from v2 at: $CUT_SHA" + git fetch origin v1 + CUT_SHA=$(git merge-base origin/v1 HEAD) + echo "Release was cut from v1 at: $CUT_SHA" jq --arg sha "$CUT_SHA" '. + {"last-release-sha": $sha}' \ - .github/release-please-config-v2.json > tmp.json && mv tmp.json .github/release-please-config-v2.json - git add .github/release-please-config-v2.json - git commit -m "chore: update last-release-sha for next v2 release" - git push origin release/v2-candidate + .github/release-please-config-v1.json > tmp.json && mv tmp.json .github/release-please-config-v1.json + git add .github/release-please-config-v1.json + git commit -m "chore: update last-release-sha for next v1 release" + git push origin release/v1-candidate - - name: Rename release/v2-candidate to release/v{version} + - name: Rename release/v1-candidate to release/v{version} if: steps.check.outputs.is_release_pr == 'true' run: | VERSION="v${STEPS_VERSION_OUTPUTS_VERSION}" - git push origin "release/v2-candidate:refs/heads/release/$VERSION" ":release/v2-candidate" - echo "Renamed release/v2-candidate to release/$VERSION" + git push origin "release/v1-candidate:refs/heads/release/$VERSION" ":release/v1-candidate" + echo "Renamed release/v1-candidate to release/$VERSION" env: STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} @@ -79,6 +79,7 @@ jobs: env: GH_TOKEN: ${{ github.token }} run: | + gh label create "autorelease: tagged" --color "EDEDED" --description "Tagged release" || true gh pr edit ${{ github.event.pull_request.number }} \ --remove-label "autorelease: pending" \ --add-label "autorelease: tagged" diff --git a/.github/workflows/release-v2-please.yml b/.github/workflows/release-v1-please.yml similarity index 60% rename from .github/workflows/release-v2-please.yml rename to .github/workflows/release-v1-please.yml index d659ca1c4e3..9f1344dd31f 100644 --- a/.github/workflows/release-v2-please.yml +++ b/.github/workflows/release-v1-please.yml @@ -1,10 +1,10 @@ -# Runs release-please to create/update a PR with version bump and changelog for v2. -# Triggered only by workflow_dispatch (from release-v2-cut.yml). +# Runs release-please to create/update a PR with version bump and changelog for v1. +# Triggered only by workflow_dispatch (from release-v1-cut.yml). # Does NOT auto-run on push to preserve manual changelog edits after cherry-picks. -name: "Release v2: Please" +name: "Release v1: Please" on: - # Only run via workflow_dispatch (triggered by release-v2-cut.yml) + # Only run via workflow_dispatch (triggered by release-v1-cut.yml) workflow_dispatch: permissions: @@ -15,27 +15,27 @@ jobs: release-please: runs-on: ubuntu-latest steps: - - name: Check if release/v2-candidate still exists + - name: Check if release/v1-candidate still exists id: check env: GH_TOKEN: ${{ github.token }} run: | - if gh api repos/${{ github.repository }}/branches/release/v2-candidate --silent 2>/dev/null; then + if gh api repos/${{ github.repository }}/branches/release/v1-candidate --silent 2>/dev/null; then echo "exists=true" >> $GITHUB_OUTPUT else - echo "release/v2-candidate branch no longer exists, skipping" + echo "release/v1-candidate branch no longer exists, skipping" echo "exists=false" >> $GITHUB_OUTPUT fi - uses: actions/checkout@v6 if: steps.check.outputs.exists == 'true' with: - ref: release/v2-candidate + ref: release/v1-candidate - uses: googleapis/release-please-action@v4 if: steps.check.outputs.exists == 'true' with: token: ${{ secrets.RELEASE_PAT }} - config-file: .github/release-please-config-v2.json - manifest-file: .github/.release-please-manifest-v2.json - target-branch: release/v2-candidate + config-file: .github/release-please-config-v1.json + manifest-file: .github/.release-please-manifest-v1.json + target-branch: release/v1-candidate diff --git a/.github/workflows/release-v2-publish.yml b/.github/workflows/release-v1-publish.yml similarity index 81% rename from .github/workflows/release-v2-publish.yml rename to .github/workflows/release-v1-publish.yml index 41edc78d9e4..a4f3b1419fa 100644 --- a/.github/workflows/release-v2-publish.yml +++ b/.github/workflows/release-v1-publish.yml @@ -1,8 +1,8 @@ -# Step 6 (v2): Builds and publishes the v2 package to PyPI from a release/v{version} branch. -# Reads version from .release-please-manifest-v2.json, converts to PEP 440, +# Step 6 (v1): Builds and publishes the v1 package to PyPI from a release/v{version} branch. +# Reads version from .release-please-manifest-v1.json, converts to PEP 440, # updates version.py, then builds and publishes. -# Creates a merge-back PR (step 7) to sync release changes to v2. -name: "Release v2: Publish to PyPi" +# Creates a merge-back PR (step 7) to sync release changes to v1. +name: "Release v1: Publish to PyPi" on: workflow_dispatch: @@ -18,7 +18,7 @@ jobs: - name: Validate branch run: | if [[ ! "${GITHUB_REF_NAME}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+ ]]; then - echo "Error: Must run from a release/v* branch (e.g., release/v2.0.0-alpha.2)" + echo "Error: Must run from a release/v* branch (e.g., release/v1.34.1)" exit 1 fi @@ -27,15 +27,15 @@ jobs: - name: Extract version from manifest and convert to PEP 440 id: version run: | - VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v2.json) + VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v1.json) echo "semver=$VERSION" >> $GITHUB_OUTPUT echo "Semver version: $VERSION" # Convert semver pre-release to PEP 440: - # 2.0.0-alpha.1 -> 2.0.0a1 - # 2.0.0-beta.1 -> 2.0.0b1 - # 2.0.0-rc.1 -> 2.0.0rc1 - # 2.0.0 -> 2.0.0 (no change for stable) + # 1.35.0-alpha.1 -> 1.35.0a1 + # 1.35.0-beta.1 -> 1.35.0b1 + # 1.35.0-rc.1 -> 1.35.0rc1 + # 1.35.0 -> 1.35.0 (no change for stable) PEP440=$(echo "$VERSION" | sed -E 's/-alpha\./a/; s/-beta\./b/; s/-rc\./rc/') echo "pep440=$PEP440" >> $GITHUB_OUTPUT echo "PEP 440 version: $PEP440" @@ -73,7 +73,7 @@ jobs: PEP440_VERSION: ${{ steps.version.outputs.pep440 }} run: | gh pr create \ - --base v2 \ + --base v1 \ --head "${GITHUB_REF_NAME}" \ - --title "chore: merge release v${PEP440_VERSION} to v2" \ - --body "Syncs version bump and CHANGELOG from release v${SEMVER_VERSION} to v2." + --title "chore: merge release v${PEP440_VERSION} to v1" \ + --body "Syncs version bump and CHANGELOG from release v${SEMVER_VERSION} to v1." diff --git a/.github/workflows/release-v2-cut.yml b/.github/workflows/release-v2-cut.yml deleted file mode 100644 index 52af5bf038a..00000000000 --- a/.github/workflows/release-v2-cut.yml +++ /dev/null @@ -1,46 +0,0 @@ -# Step 1 (v2): Starts the v2 release process by creating a release/v2-candidate branch. -# Generates a changelog PR for review (step 2). -name: "Release v2: Cut" - -on: - workflow_dispatch: - inputs: - commit_sha: - description: 'Commit SHA to cut from (leave empty for latest v2)' - required: false - type: string - -permissions: - contents: write - actions: write - -jobs: - cut-release: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - with: - ref: ${{ inputs.commit_sha || 'v2' }} - - - name: Check for existing release/v2-candidate branch - env: - GH_TOKEN: ${{ github.token }} - run: | - if git ls-remote --exit-code --heads origin release/v2-candidate &>/dev/null; then - echo "Error: release/v2-candidate branch already exists" - echo "Please finalize or delete the existing release candidate before starting a new one" - exit 1 - fi - - - name: Create and push release/v2-candidate branch - run: | - git checkout -b release/v2-candidate - git push origin release/v2-candidate - echo "Created branch: release/v2-candidate" - - - name: Trigger Release Please - env: - GH_TOKEN: ${{ github.token }} - run: | - gh workflow run release-v2-please.yml --repo ${{ github.repository }} --ref release/v2-candidate - echo "Triggered Release Please workflow for v2" diff --git a/.github/workflows/v2-sync.yml b/.github/workflows/v2-sync.yml deleted file mode 100644 index c627f40d46b..00000000000 --- a/.github/workflows/v2-sync.yml +++ /dev/null @@ -1,59 +0,0 @@ -# Automatically creates a PR to merge main (v1) into v2 to keep v2 up to date. -# The oncall is responsible for reviewing and merging the sync PR. -name: "Sync: main -> v2" - -on: - schedule: - - cron: '0 6 * * *' # Daily at 6am UTC - workflow_dispatch: - -permissions: - contents: write - pull-requests: write - -jobs: - sync: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - with: - ref: v2 - fetch-depth: 0 - token: ${{ secrets.RELEASE_PAT }} - - - name: Check for new commits on main - id: check - run: | - git fetch origin main - BEHIND=$(git rev-list --count HEAD..origin/main) - echo "behind=$BEHIND" >> $GITHUB_OUTPUT - if [ "$BEHIND" -eq 0 ]; then - echo "v2 is up to date with main, nothing to sync" - else - echo "v2 is $BEHIND commit(s) behind main" - fi - - - name: Check for existing sync PR - if: steps.check.outputs.behind != '0' - id: existing - env: - GH_TOKEN: ${{ github.token }} - run: | - PR=$(gh pr list --base v2 --head main --state open --json number --jq '.[0].number // empty') - if [ -n "$PR" ]; then - echo "Sync PR #$PR already exists, skipping" - echo "exists=true" >> $GITHUB_OUTPUT - else - echo "exists=false" >> $GITHUB_OUTPUT - fi - - - name: Create sync PR - if: steps.check.outputs.behind != '0' && steps.existing.outputs.exists == 'false' - env: - GH_TOKEN: ${{ secrets.RELEASE_PAT }} - run: | - gh pr create \ - --base v2 \ - --head main \ - --title "chore: sync main -> v2" \ - --body "Automated sync of v1 changes from main into v2. The oncall is responsible for reviewing and merging this PR. Resolve conflicts in favor of the v2 implementation." diff --git a/CHANGELOG.md b/CHANGELOG.md index 08799f7a133..69714eb0317 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,54 @@ # Changelog +## [1.36.0](https://github.com/google/adk-python/compare/v1.35.2...v1.36.0) (2026-06-22) + + +### Features + +* **interactions:** support for GenAI SDK upgraded to 2.9+ ([#6191](https://github.com/google/adk-python/issues/6191)) ([7a9152a](https://github.com/google/adk-python/commit/7a9152a382b05a2a733adbc9bde25dacb02893a2)) + +## [1.35.2](https://github.com/google/adk-python/compare/v1.35.1...v1.35.2) (2026-06-17) + + +### Bug Fixes + +* remove live event buffering in runner ([#6151](https://github.com/google/adk-python/issues/6151)) ([afe4083](https://github.com/google/adk-python/commit/afe408376a9c12fc3b206df234a1655f565c826c)) + +## [1.35.0](https://github.com/google/adk-python/compare/v1.34.2...v1.35.0) (2026-06-15) + + +### Features + +* **live:** Handle input transcription differently for Gemini Live 3.1 models ([#6045](https://github.com/google/adk-python/issues/6045)) ([ecfdaf5](https://github.com/google/adk-python/commit/ecfdaf5f5e7accfbb4294cb8cc56c910dad2b1a8)) + + +### Bug Fixes + +* add missing Gemini imports in base_llm_flow ([#5943](https://github.com/google/adk-python/issues/5943)) ([6d027b4](https://github.com/google/adk-python/commit/6d027b4ce8bc1c5d288b02e1e3819917117038ec)) +* **flows:** Reset reconnect attempts on connection success ([#6042](https://github.com/google/adk-python/issues/6042)) ([87abf23](https://github.com/google/adk-python/commit/87abf230dbc21b49fa5606e18627c7f62df0d37b)) +* **models:** Default grounding metadata for Gemini 3.1 live ([#6018](https://github.com/google/adk-python/issues/6018)) ([fafafb3](https://github.com/google/adk-python/commit/fafafb38e1027a5cfe185357f6b8a107bbd3779e)) +* Only send grounding_metadata for 3.1 live at the end of each turn ([#6129](https://github.com/google/adk-python/issues/6129)) ([847a259](https://github.com/google/adk-python/commit/847a259cd89a7b720582ae46f6856d6a8c8000b7)) +* **streaming:** Ensure final partial=False frame is always yielded ([#6096](https://github.com/google/adk-python/issues/6096)) ([6e59c61](https://github.com/google/adk-python/commit/6e59c61104718d12b1265958c9c2992eee65abbf)) +* Support generalized history config injection for Gemini 3.1 Live on Vertex AI ([#5999](https://github.com/google/adk-python/issues/5999)) ([aafd97f](https://github.com/google/adk-python/commit/aafd97f6f0ae114b0ca772b4f5176602e3677e79)) + +## [1.34.2](https://github.com/google/adk-python/compare/v1.34.1...v1.34.2) (2026-06-01) + + +### Bug Fixes + +* Fix bug where grounding metadata in Gemini 3.1 live was being silently discarded ([9b6b9e9](https://github.com/google/adk-python/commit/9b6b9e976300ab223c77075554d6cd66ce1179ff)) +* fix input and output transcription finished events for Gemini v3.1 ([13763d7](https://github.com/google/adk-python/commit/13763d71f883b215dae08feb3f042869b9cd5d18)) +* **tools:** Prevent session drop on MCP tool error ([1fd406b](https://github.com/google/adk-python/commit/1fd406b90ae00c59d84093c33bc04530825bc760)) + +## [1.34.1](https://github.com/google/adk-python/compare/v1.34.0...v1.34.1) (2026-05-22) + + +### Bug Fixes + +* Fix bug where grounding metadata in Gemini 3.1 live was being silently discarded ([9b6b9e9](https://github.com/google/adk-python/commit/9b6b9e976300ab223c77075554d6cd66ce1179ff)) +* fix input and output transcription finished events for Gemini v3.1 ([13763d7](https://github.com/google/adk-python/commit/13763d71f883b215dae08feb3f042869b9cd5d18)) +* **tools:** Prevent session drop on MCP tool error ([1fd406b](https://github.com/google/adk-python/commit/1fd406b90ae00c59d84093c33bc04530825bc760)) + ## [1.34.0](https://github.com/google/adk-python/compare/v1.33.0...v1.34.0) (2026-05-18) diff --git a/contributing/samples/interactions_api/agent.py b/contributing/samples/interactions_api/agent.py index 908a8539482..5561bfbd898 100644 --- a/contributing/samples/interactions_api/agent.py +++ b/contributing/samples/interactions_api/agent.py @@ -12,19 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Agent definition for testing the Interactions API integration. - -NOTE: The Interactions API does NOT support mixing custom function calling tools -with built-in tools in the same agent. To work around this limitation, we use -bypass_multi_tools_limit=True on GoogleSearchTool, which converts the built-in -google_search to a function calling tool (via GoogleSearchAgentTool). - -The bypass is only triggered when len(agent.tools) > 1, so we include multiple -tools in the agent (GoogleSearchTool + get_current_weather). - -With bypass_multi_tools_limit=True and multiple tools, all tools become function -calling tools, which allows mixing google_search with custom function tools. -""" +"""Agent definition for testing the Interactions API integration.""" from google.adk.agents.llm_agent import Agent from google.adk.models.google_llm import Gemini @@ -74,10 +62,7 @@ def get_current_weather(city: str) -> dict: } -# Main agent with google_search (via bypass) and custom function tools -# Using bypass_multi_tools_limit=True converts google_search to a function calling tool. -# We need len(tools) > 1 to trigger the bypass, so we include get_current_weather directly. -# This allows mixing google_search with custom function tools via the Interactions API. +# Main agent with google_search built-in tool and custom function tools # # NOTE: code_executor is not compatible with function calling mode because the model # tries to call a function (e.g., run_code) instead of outputting code in markdown. @@ -99,7 +84,7 @@ def get_current_weather(city: str) -> dict: Be concise and helpful in your responses. Always confirm what you did. """, tools=[ - GoogleSearchTool(bypass_multi_tools_limit=True), + GoogleSearchTool(), get_current_weather, ], ) diff --git a/contributing/samples/interactions_api/main.py b/contributing/samples/interactions_api/main.py index a776f31ea94..8b40c3c12ea 100644 --- a/contributing/samples/interactions_api/main.py +++ b/contributing/samples/interactions_api/main.py @@ -16,17 +16,11 @@ This script tests the following features: 1. Basic text generation -2. Google Search tool (via bypass_multi_tools_limit) +2. Google Search tool 3. Multi-turn conversations with stateful interactions 4. Google Search tool (additional coverage) 5. Custom function tool (get_current_weather) -NOTE: The Interactions API does NOT support mixing custom function calling tools -with built-in tools. To work around this, we use bypass_multi_tools_limit=True -on GoogleSearchTool, which converts it to a function calling tool (via -GoogleSearchAgentTool). The bypass only triggers when len(agent.tools) > 1, -so we include both GoogleSearchTool and get_current_weather in the agent. - NOTE: Code execution via UnsafeLocalCodeExecutor is not compatible with function calling mode because the model tries to call a function instead of outputting code in markdown. @@ -41,7 +35,6 @@ import logging from pathlib import Path import time -from typing import Optional from dotenv import load_dotenv from google.adk.agents.run_config import RunConfig @@ -49,6 +42,7 @@ from google.adk.runners import InMemoryRunner from google.adk.runners import Runner from google.genai import types +import httpx from .agent import root_agent @@ -67,7 +61,8 @@ async def call_agent_async( prompt: str, agent_name: str = "", show_interaction_id: bool = True, -) -> tuple[str, Optional[str]]: + additional_parts: list[types.Part] | None = None, +) -> tuple[str, str | None]: """Call the agent asynchronously with the user's prompt. Args: @@ -77,13 +72,16 @@ async def call_agent_async( prompt: The prompt to send agent_name: The expected agent name for filtering responses show_interaction_id: Whether to show interaction IDs in output + additional_parts: Optional list of additional content parts (e.g. files) Returns: A tuple of (response_text, interaction_id) """ - content = types.Content( - role="user", parts=[types.Part.from_text(text=prompt)] - ) + parts = [types.Part.from_text(text=prompt)] + if additional_parts: + parts.extend(additional_parts) + + content = types.Content(role="user", parts=parts) final_response_text = "" last_interaction_id = None @@ -264,6 +262,39 @@ async def test_custom_function_tool(runner: Runner, session_id: str): return interaction_id +async def test_pdf_summarization(runner: Runner, session_id: str) -> str | None: + """Test PDF summarization using the Interactions API.""" + print("\n" + "=" * 60) + print("TEST 6: PDF Summarization") + print("=" * 60) + + url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf" + print(f"Downloading {url}...") + async with httpx.AsyncClient() as client: + response = await client.get( + url, headers={"User-Agent": "Mozilla/5.0"}, follow_redirects=True + ) + response.raise_for_status() + pdf_bytes = response.content + + pdf_part = types.Part.from_bytes(data=pdf_bytes, mime_type="application/pdf") + response, interaction_id = await call_agent_async( + runner, + USER_ID, + session_id, + "Please summarize the attached PDF document.", + additional_parts=[pdf_part], + ) + + assert response, "Expected a non-empty response" + assert len(response) > 0, f"Expected summary in response: {response}" + assert ( + "gemini" in response.lower() or "multimodal" in response.lower() + ), f"Expected summary of PDF in response: {response}" + print("PASSED: PDF Summarization works") + return interaction_id + + def check_interactions_api_available() -> bool: """Check if the interactions API is available in the SDK.""" try: @@ -311,6 +342,7 @@ async def run_all_tests(): await test_multi_turn_conversation(runner, session.id) await test_google_search_tool(runner, session.id) await test_custom_function_tool(runner, session.id) + await test_pdf_summarization(runner, session.id) print("\n" + "=" * 60) print("ALL TESTS PASSED (Interactions API)") diff --git a/pyproject.toml b/pyproject.toml index 0e7c449b9f8..8a7048d10a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "google-cloud-spanner>=3.56,<4", # For Spanner database "google-cloud-speech>=2.30,<3", # For Audio Transcription "google-cloud-storage>=2.18,<4", # For GCS Artifact service - "google-genai>=1.72,<2", # Google GenAI SDK + "google-genai>=2.9,<3", # Google GenAI SDK "graphviz>=0.20.2,<1", # Graphviz for graph rendering "httpx>=0.27,<1", # HTTP client library "jsonschema>=4.23,<5", # Agent Builder config validation diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index e059cd957db..8126ac5bf3f 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -247,6 +247,9 @@ class RunConfig(BaseModel): session_resumption: Optional[types.SessionResumptionConfig] = None """Configures session resumption mechanism. Only support transparent session resumption mode now.""" + history_config: Optional[types.HistoryConfig] = None + """Configures the exchange of history between the client and the server.""" + context_window_compression: Optional[types.ContextWindowCompressionConfig] = ( None ) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index b5f51f28256..5f67e166073 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -138,7 +138,7 @@ class FeatureConfig: FeatureStage.WIP, default_on=False ), FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig( - FeatureStage.EXPERIMENTAL, default_on=False + FeatureStage.EXPERIMENTAL, default_on=True ), FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db897637c31..9c0731e6a2b 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -39,6 +39,8 @@ from ...auth.auth_tool import AuthConfig from ...events.event import Event from ...models.base_llm_connection import BaseLlmConnection +from ...models.google_llm import Gemini +from ...models.google_llm import GoogleLLMVariant from ...models.llm_request import LlmRequest from ...models.llm_response import LlmResponse from ...telemetry import tracing @@ -47,6 +49,7 @@ from ...telemetry.tracing import tracer from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext +from ...utils import model_name_utils from ...utils.context_utils import Aclosing from .audio_cache_manager import AudioCacheManager from .functions import build_auth_request_event @@ -54,6 +57,11 @@ # Prefix used by toolset auth credential IDs TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' + +class _ReconnectSentinel(Event): + """Internal sentinel event to signal a silent reconnection request.""" + + if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent from ...models.base_llm import BaseLlm @@ -514,13 +522,47 @@ async def run_live( llm_request.live_connect_config.session_resumption.handle = ( invocation_context.live_session_resumption_handle ) - llm_request.live_connect_config.session_resumption.transparent = True + # Only set transparent=True for Vertex AI backend, as the Gemini API + # backend explicitly rejects it. + if ( + isinstance(llm, Gemini) + and llm._api_backend == GoogleLLMVariant.VERTEX_AI # pylint: disable=protected-access + ): + session_resumption = ( + llm_request.live_connect_config.session_resumption + ) + if session_resumption.transparent is None: + session_resumption.transparent = True + + # When seeding a fresh connection with prior conversation history, set + # initial_history_in_client_content to True. This tells the Live server + # that the provided history already includes the model's past responses, + # preventing the server from generating duplicate responses for those replayed turns. + if ( + llm_request.contents + and not invocation_context.live_session_resumption_handle + ): + if not llm_request.live_connect_config: + llm_request.live_connect_config = types.LiveConnectConfig() + if not llm_request.live_connect_config.history_config: + llm_request.live_connect_config.history_config = ( + types.HistoryConfig() + ) + if ( + llm_request.live_connect_config.history_config.initial_history_in_client_content + is None + ): + llm_request.live_connect_config.history_config.initial_history_in_client_content = ( + True + ) logger.info( 'Establishing live connection for agent: %s', invocation_context.agent.name, ) async with llm.connect(llm_request) as llm_connection: + # Reset attempt counter on successful connection. + attempt = 1 # Skip sending history if we are resuming a session. The server # already has the state associated with the resumption handle. if ( @@ -540,6 +582,7 @@ async def run_live( self._send_to_model(llm_connection, invocation_context) ) + should_reconnect = False try: async with Aclosing( self._receive_from_model( @@ -550,8 +593,9 @@ async def run_live( ) ) as agen: async for event in agen: - # Reset attempt counter on successful communication. - attempt = 1 + if isinstance(event, _ReconnectSentinel): + should_reconnect = True + break # Empty event means the queue is closed. if not event: break @@ -632,6 +676,9 @@ async def run_live( await send_task except asyncio.CancelledError: pass + if should_reconnect: + continue + break except (ConnectionClosed, ConnectionClosedOK) as e: # If we have a session resumption handle, we attempt to reconnect. # This handle is updated dynamically during the session. @@ -770,9 +817,9 @@ def get_author_for_event(llm_response: LlmResponse) -> str: if llm_response.go_away: logger.info(f'Received go away signal: {llm_response.go_away}') # The server signals that it will close the connection soon. - # We proactively raise ConnectionClosed to trigger the reconnection - # logic in run_live, which will use the latest session handle. - raise ConnectionClosed(None, None) + # We yield a sentinel event to request reconnection internally. + yield _ReconnectSentinel(author='system') + return model_response_event = Event( id=Event.new_id(), @@ -976,6 +1023,7 @@ async def _postprocess_async( not llm_response.content and not llm_response.error_code and not llm_response.interrupted + and not llm_response.grounding_metadata ): return @@ -1040,6 +1088,7 @@ async def _postprocess_live( and not llm_response.output_transcription and not llm_response.usage_metadata and not llm_response.live_session_resumption_update + and not llm_response.grounding_metadata ): return diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 8e9bfa514ca..50f03d0bf17 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -25,6 +25,7 @@ from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest +from ...utils import model_name_utils from ...utils.output_schema_utils import can_use_output_schema_with_tools from ._base_llm_processor import BaseLlmRequestProcessor @@ -78,15 +79,25 @@ def _build_basic_request( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) + active_model_name = ( + getattr(getattr(agent, 'canonical_live_model', None), 'model', None) + or llm_request.model + ) + is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name) llm_request.live_connect_config.enable_affective_dialog = ( - invocation_context.run_config.enable_affective_dialog + None + if is_gemini_31 + else invocation_context.run_config.enable_affective_dialog ) llm_request.live_connect_config.proactivity = ( - invocation_context.run_config.proactivity + None if is_gemini_31 else invocation_context.run_config.proactivity ) llm_request.live_connect_config.session_resumption = ( invocation_context.run_config.session_resumption ) + llm_request.live_connect_config.history_config = ( + invocation_context.run_config.history_config + ) llm_request.live_connect_config.context_window_compression = ( invocation_context.run_config.context_window_compression ) diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 11ed8386e11..215635968b7 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -50,6 +50,9 @@ def __init__( self._output_transcription_text: str = '' self._api_backend = api_backend self._model_version = model_version + self._is_gemini_3_1_flash_live = model_name_utils.is_gemini_3_1_flash_live( + model_version + ) async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. @@ -80,10 +83,30 @@ async def send_history(self, history: list[types.Content]): ] if contents: + # Gemini Enterprise Agent Platform does not support history_config in the SDK. + # To initialize a live session with prior history without hitting a 1007 + # protocol error (invalid role mid-session), we consolidate previous multi-turn + # interactions into a unified contextual preamble on a single user role turn. + if ( + self._is_gemini_3_1_flash_live + and self._api_backend != GoogleLLMVariant.GEMINI_API + ): + collapsed_text = 'Previous conversation history:\n' + for c in contents: + text_parts = ''.join(p.text for p in c.parts if p.text) + collapsed_text += f'[{c.role}]: {text_parts}\n' + contents = [ + types.Content( + role='user', parts=[types.Part.from_text(text=collapsed_text)] + ) + ] + logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send_client_content( turns=contents, - turn_complete=contents[-1].role == 'user', + turn_complete=True + if self._is_gemini_3_1_flash_live + else (contents[-1].role == 'user'), ) else: logger.info('no content is sent') @@ -108,10 +131,11 @@ async def send_content(self, content: types.Content): ) else: logger.debug('Sending LLM new content %s', content) - is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live( - self._model_version - ) - if is_gemini_31 and len(content.parts) == 1 and content.parts[0].text: + if ( + self._is_gemini_3_1_flash_live + and len(content.parts) == 1 + and content.parts[0].text + ): logger.debug('Using send_realtime_input for Gemini 3.1 text input') await self._gemini_session.send_realtime_input( text=content.parts[0].text @@ -133,10 +157,7 @@ async def send_realtime(self, input: RealtimeInput): if isinstance(input, types.Blob): # The blob is binary and is very large. So let's not log it. logger.debug('Sending LLM Blob.') - is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live( - self._model_version - ) - if is_gemini_31: + if self._is_gemini_3_1_flash_live: if input.mime_type and input.mime_type.startswith('audio/'): await self._gemini_session.send_realtime_input(audio=input) elif input.mime_type and input.mime_type.startswith('image/'): @@ -159,7 +180,12 @@ async def send_realtime(self, input: RealtimeInput): else: raise ValueError('Unsupported input type: %s' % type(input)) - def __build_full_text_response(self, text: str): + def __build_full_text_response( + self, + text: str, + is_thought: bool = False, + grounding_metadata: types.GroundingMetadata | None = None, + ): """Builds a full text response. The text should not be partial and the returned LlmResponse is not @@ -167,15 +193,23 @@ def __build_full_text_response(self, text: str): Args: text: The text to be included in the response. + is_thought: Whether the text is a thought. + grounding_metadata: The grounding metadata to include. Returns: An LlmResponse containing the full text. """ + part = types.Part.from_text(text=text) + if is_thought: + part.thought = True + return LlmResponse( content=types.Content( role='model', - parts=[types.Part.from_text(text=text)], + parts=[part], ), + grounding_metadata=grounding_metadata, + partial=False, live_session_id=self._gemini_session.session_id, ) @@ -187,7 +221,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: """ text = '' + is_thought = False tool_call_parts = [] + pending_grounding_metadata = None async with Aclosing(self._gemini_session.receive()) as agen: # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate # partial content and emit responses as needed. @@ -203,6 +239,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) if message.server_content: content = message.server_content.model_turn + if message.server_content.grounding_metadata: + pending_grounding_metadata = ( + message.server_content.grounding_metadata + ) # Standalone grounding_metadata event (when content is empty) if ( @@ -215,6 +255,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) if content and content.parts: @@ -223,51 +266,81 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) # grounding_metadata is yielded again at turn_complete, # so avoid duplicating it here if turn_complete is true. if not message.server_content.turn_complete: - llm_response.grounding_metadata = ( - message.server_content.grounding_metadata - ) - if content.parts[0].text: - text += content.parts[0].text - llm_response.partial = True - # don't yield the merged text event when receiving audio data - elif text and not content.parts[0].inline_data: - yield self.__build_full_text_response(text) + if message.server_content.grounding_metadata is not None: + llm_response.grounding_metadata = ( + message.server_content.grounding_metadata + ) + has_inline_data = any(p.inline_data for p in content.parts) + for part in content.parts: + if part.text: + current_is_thought = getattr(part, 'thought', False) + if text and current_is_thought != is_thought: + yield self.__build_full_text_response(text, is_thought) + text = '' + is_thought = False + + text += part.text + is_thought = current_is_thought + llm_response.partial = True + if ( + text + and not any(p.text for p in content.parts) + and not has_inline_data + ): + yield self.__build_full_text_response(text, is_thought) text = '' yield llm_response # Note: in some cases, tool_call may arrive before # generation_complete, causing transcription to appear after # tool_call in the session log. if message.server_content.input_transcription: - if message.server_content.input_transcription.text: - self._input_transcription_text += ( - message.server_content.input_transcription.text - ) - yield LlmResponse( - input_transcription=types.Transcription( - text=message.server_content.input_transcription.text, - finished=False, - ), - partial=True, - model_version=self._model_version, - live_session_id=live_session_id, - ) - # finished=True and partial transcription may happen in the same - # message. - if message.server_content.input_transcription.finished: - yield LlmResponse( - input_transcription=types.Transcription( - text=self._input_transcription_text, - finished=True, - ), - partial=False, - model_version=self._model_version, - live_session_id=live_session_id, - ) - self._input_transcription_text = '' + # Gemini 3.1 Flash Live only sends a single final input + # transcription + if self._is_gemini_3_1_flash_live: + if message.server_content.input_transcription.text: + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=True, + ), + partial=False, + model_version=self._model_version, + live_session_id=live_session_id, + ) + else: + if message.server_content.input_transcription.text: + self._input_transcription_text += ( + message.server_content.input_transcription.text + ) + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=False, + ), + partial=True, + model_version=self._model_version, + live_session_id=live_session_id, + ) + # finished=True and partial transcription may happen in the same + # message. + if message.server_content.input_transcription.finished: + yield LlmResponse( + input_transcription=types.Transcription( + text=self._input_transcription_text, + finished=True, + ), + partial=False, + model_version=self._model_version, + live_session_id=live_session_id, + ) + self._input_transcription_text = '' if message.server_content.output_transcription: if message.server_content.output_transcription.text: self._output_transcription_text += ( @@ -293,10 +366,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: live_session_id=live_session_id, ) self._output_transcription_text = '' - # The Gemini API might not send a transcription finished signal. + # The Gemini API or Vertex AI might not send a transcription finished signal. # Instead, we rely on generation_complete, turn_complete or # interrupted signals to flush any pending transcriptions. - if self._api_backend == GoogleLLMVariant.GEMINI_API and ( + if ( message.server_content.interrupted or message.server_content.turn_complete or message.server_content.generation_complete @@ -324,9 +397,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) self._output_transcription_text = '' if message.server_content.turn_complete: + g_metadata_to_yield = pending_grounding_metadata if text: - yield self.__build_full_text_response(text) + yield self.__build_full_text_response( + text, is_thought, g_metadata_to_yield + ) text = '' + is_thought = False + g_metadata_to_yield = None if tool_call_parts: logger.debug('Returning aggregated tool_call_parts') yield LlmResponse( @@ -338,9 +416,18 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield LlmResponse( turn_complete=True, interrupted=message.server_content.interrupted, - grounding_metadata=message.server_content.grounding_metadata, + grounding_metadata=message.server_content.grounding_metadata + or g_metadata_to_yield + or ( + types.GroundingMetadata() + if self._is_gemini_3_1_flash_live + else None + ), model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) break # in case of empty content or parts, we still surface it @@ -371,10 +458,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: # deadlocking the conversation. Other models (e.g. 2.5-pro, # native-audio) send turn_complete after tool calls, so buffer # and merge them into a single response at turn_complete. - if ( - model_name_utils.is_gemini_3_1_flash_live(self._model_version) - and tool_call_parts - ): + if self._is_gemini_3_1_flash_live and tool_call_parts: logger.debug( 'Yielding tool_call_parts immediately for Gemini 3.1 live tool' ' call' diff --git a/src/google/adk/models/interactions_utils.py b/src/google/adk/models/interactions_utils.py index 89ffe6be71b..044081bb759 100644 --- a/src/google/adk/models/interactions_utils.py +++ b/src/google/adk/models/interactions_utils.py @@ -35,21 +35,51 @@ import logging from typing import Any from typing import AsyncGenerator -from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from google.genai.interactions import AudioContentParam +from google.genai.interactions import CodeExecutionCallStep +from google.genai.interactions import CodeExecutionCallStepParam +from google.genai.interactions import CodeExecutionResultStep +from google.genai.interactions import CodeExecutionResultStepParam +from google.genai.interactions import ContentParam +from google.genai.interactions import DocumentContentParam +from google.genai.interactions import ErrorEvent +from google.genai.interactions import FunctionCallStep +from google.genai.interactions import FunctionCallStepParam +from google.genai.interactions import FunctionParam +from google.genai.interactions import FunctionResultStep +from google.genai.interactions import FunctionResultStepParam +from google.genai.interactions import GenerationConfigParam +from google.genai.interactions import GoogleSearchResultStep +from google.genai.interactions import ImageContentParam +from google.genai.interactions import Interaction +from google.genai.interactions import InteractionCompletedEvent +from google.genai.interactions import InteractionCreatedEvent +from google.genai.interactions import InteractionSSEEvent +from google.genai.interactions import InteractionStatusUpdate +from google.genai.interactions import ModelOutputStep +from google.genai.interactions import ModelOutputStepParam +from google.genai.interactions import Step +from google.genai.interactions import StepDelta +from google.genai.interactions import StepParam +from google.genai.interactions import StepStart +from google.genai.interactions import StepStop +from google.genai.interactions import TextContentParam +from google.genai.interactions import ThoughtStep +from google.genai.interactions import ThoughtStepParam +from google.genai.interactions import ToolParam +from google.genai.interactions import UserInputStepParam +from google.genai.interactions import VideoContentParam +from pydantic import BaseModel +from typing_extensions import deprecated if TYPE_CHECKING: from google.genai import Client - from google.genai._interactions.types.interaction import Output - from google.genai._interactions.types.tool_param import ToolParam - from google.genai._interactions.types.turn_param import TurnParam - from google.genai.interactions_types import Interaction - from google.genai.interactions_types import InteractionSSEEvent - from .llm_request import LlmRequest - from .llm_response import LlmResponse +from .llm_request import LlmRequest +from .llm_response import LlmResponse logger = logging.getLogger('google_adk.' + __name__) @@ -57,8 +87,8 @@ def _extract_stream_interaction_id( - event: 'InteractionSSEEvent', -) -> Optional[str]: + event: InteractionSSEEvent, +) -> str | None: """Extract the interaction ID from an Interactions SSE event. Different SSE lifecycle events expose the interaction ID on different @@ -67,26 +97,37 @@ def _extract_stream_interaction_id( google-genai builds may also yield a legacy ``interaction`` event with a top-level ``id``. """ - from google.genai._interactions.types.interaction_complete_event import InteractionCompleteEvent - from google.genai._interactions.types.interaction_start_event import InteractionStartEvent - from google.genai._interactions.types.interaction_status_update import InteractionStatusUpdate - if isinstance(event, InteractionStatusUpdate): return event.interaction_id - if isinstance(event, (InteractionStartEvent, InteractionCompleteEvent)): + if isinstance(event, (InteractionCreatedEvent, InteractionCompletedEvent)): return event.interaction.id - try: - if event.event_type == 'interaction': - return event.id - except AttributeError: - pass + if isinstance(event, Interaction): + return event.id return None -def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]: +def _encode_base64_string(data: bytes) -> str: + """Encode bytes to a base64 string.""" + return base64.b64encode(data).decode('utf-8') + + +def _wrap_content_param_in_step( + content_param: ContentParam, role: str +) -> StepParam: + """Wraps a ContentParam into a UserInputStepParam or ModelOutputStepParam.""" + if role == 'model': + return ModelOutputStepParam(type='model_output', content=[content_param]) + return UserInputStepParam(type='user_input', content=[content_param]) + + +@deprecated( + 'convert_part_to_interaction_content is deprecated and will be removed in' + ' future versions' +) +def convert_part_to_interaction_content(part: types.Part) -> dict | None: """Convert a types.Part to an interaction content dict. Args: @@ -213,45 +254,180 @@ def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]: return None -def convert_content_to_turn(content: types.Content) -> TurnParam: - """Convert a types.Content to a TurnParam dict for interactions API. +def _convert_part_to_interaction_content( + part: types.Part, + role: str = 'user', +) -> StepParam | None: + """Convert a types.Part to an interaction content dict. + + Args: + part: The Part object to convert. + role: The role to wrap the content in ('user' or 'model'). + + Returns: + A StepParam dict representing the interaction content, or None if + the part type is not supported. + """ + if part.text is not None: + return _wrap_content_param_in_step( + TextContentParam(type='text', text=part.text), role + ) + elif part.function_call is not None: + return FunctionCallStepParam( + type='function_call', + id=part.function_call.id or '', + name=part.function_call.name or '', + arguments=part.function_call.args or {}, + ) + elif part.function_response is not None: + + # genai.types.FunctionResponse specifies that + # an error response should be inside an error key + func_resp = part.function_response.response + is_error = False + if isinstance(func_resp, dict) and 'error' in func_resp: + is_error = True + + # Pass the function response through to the interactions API. + # Dict and list values are passed directly — the Interactions API handles + # JSON serialization internally. Pre-serializing with json.dumps() would + # cause double-escaping. + if not isinstance(func_resp, (dict, str, list)): + func_resp = str(func_resp) + logger.debug( + 'Converting function_response: name=%s, call_id=%s', + part.function_response.name, + part.function_response.id, + ) + return FunctionResultStepParam( + type='function_result', + name=part.function_response.name or '', + call_id=part.function_response.id or '', + result=func_resp, + is_error=is_error, + ) + elif part.inline_data is not None: + mime_type = part.inline_data.mime_type or '' + # The interactions API requires inline data to be a base64 encoded string + # when serialized to JSON, otherwise openapi_dumps will raise a TypeError. + data = part.inline_data.data + if isinstance(data, bytes): + data = _encode_base64_string(data) + + if mime_type.startswith('image/'): + return _wrap_content_param_in_step( + ImageContentParam(type='image', data=data, mime_type=mime_type), role + ) + elif mime_type.startswith('audio/'): + return _wrap_content_param_in_step( + AudioContentParam(type='audio', data=data, mime_type=mime_type), role + ) + elif mime_type.startswith('video/'): + return _wrap_content_param_in_step( + VideoContentParam(type='video', data=data, mime_type=mime_type), role + ) + else: + return _wrap_content_param_in_step( + DocumentContentParam(type='document', data=data, mime_type=mime_type), + role, + ) + elif part.file_data is not None: + mime_type = part.file_data.mime_type or '' + if mime_type.startswith('image/'): + return _wrap_content_param_in_step( + ImageContentParam( + type='image', uri=part.file_data.file_uri, mime_type=mime_type + ), + role, + ) + elif mime_type.startswith('audio/'): + return _wrap_content_param_in_step( + AudioContentParam( + type='audio', uri=part.file_data.file_uri, mime_type=mime_type + ), + role, + ) + elif mime_type.startswith('video/'): + return _wrap_content_param_in_step( + VideoContentParam( + type='video', uri=part.file_data.file_uri, mime_type=mime_type + ), + role, + ) + else: + return _wrap_content_param_in_step( + DocumentContentParam( + type='document', uri=part.file_data.file_uri, mime_type=mime_type + ), + role, + ) + elif part.thought: + # part.thought is a boolean indicating this is a thought part + # ThoughtContentParam expects 'signature' (base64 encoded bytes) + thought_result = ThoughtStepParam(type='thought') + if part.thought_signature is not None: + thought_result['signature'] = _encode_base64_string( + part.thought_signature + ) + return thought_result + elif part.code_execution_result is not None: + is_error = part.code_execution_result.outcome in ( + types.Outcome.OUTCOME_FAILED, + types.Outcome.OUTCOME_DEADLINE_EXCEEDED, + ) + return CodeExecutionResultStepParam( + type='code_execution_result', + call_id='', + result=part.code_execution_result.output or '', + is_error=is_error, + ) + elif part.executable_code is not None: + return CodeExecutionCallStepParam( + type='code_execution_call', + id='', + arguments={ + 'code': part.executable_code.code, + 'language': part.executable_code.language, + }, + ) + return None + + +def _convert_content_to_step(content: types.Content) -> list[StepParam]: + """Convert a types.Content to a list of StepParam dicts for interactions API. Args: content: The Content object to convert. Returns: - A TurnParam dictionary for the interactions API. + A list of StepParam dictionaries for the interactions API. """ - contents = [] + steps: list[StepParam] = [] + + role = content.role or 'user' if content.parts: for part in content.parts: - interaction_content = convert_part_to_interaction_content(part) + interaction_content = _convert_part_to_interaction_content(part, role) if interaction_content: - contents.append(interaction_content) + steps.append(interaction_content) - return { - 'role': content.role or 'user', - 'content': contents, - } + return steps -def convert_contents_to_turns( +def _convert_contents_to_steps( contents: list[types.Content], -) -> list[TurnParam]: +) -> list[StepParam]: """Convert a list of Content objects to interactions API input format. Args: contents: The list of Content objects to convert. Returns: - A list of TurnParam dictionaries for the interactions API. + A list of StepParam dictionaries for the interactions API. """ - turns = [] - for content in contents: - turn = convert_content_to_turn(content) - if turn['content']: # Only add turns with content - turns.append(turn) - return turns + return [ + step for content in contents for step in _convert_content_to_step(content) + ] def convert_tools_config_to_interactions_format( @@ -276,7 +452,7 @@ def convert_tools_config_to_interactions_format( # Handle function declarations if tool.function_declarations: for func_decl in tool.function_declarations: - func_tool: dict[str, Any] = { + func_tool: FunctionParam = { 'type': 'function', 'name': func_decl.name, } @@ -288,14 +464,14 @@ def convert_tools_config_to_interactions_format( props = {} for k, v in func_decl.parameters.properties.items(): props[k] = v.model_dump(exclude_none=True) - func_tool['parameters'] = { + + params_dict: dict[str, object] = { 'type': 'object', 'properties': props, } if func_decl.parameters.required: - func_tool['parameters']['required'] = list( - func_decl.parameters.required - ) + params_dict['required'] = list(func_decl.parameters.required) + func_tool['parameters'] = params_dict elif func_decl.parameters_json_schema: func_tool['parameters'] = func_decl.parameters_json_schema interaction_tools.append(func_tool) @@ -319,115 +495,127 @@ def convert_tools_config_to_interactions_format( return interaction_tools -def convert_interaction_output_to_part(output: Output) -> Optional[types.Part]: - """Convert an interaction output content to a types.Part. +def _function_result_to_response( + result: BaseModel | dict[str, Any] | list[Any] | str, +) -> dict[str, Any]: + """Convert a FunctionResultStep result into a FunctionResponse dict. + + The Interactions API types the result as a model, a list of content blocks, + or a plain string, but types.FunctionResponse.response requires a dict. A + dict is returned as-is; other non-dict shapes are wrapped under a 'result' + key. + """ + if isinstance(result, dict): + return result + if isinstance(result, BaseModel): + return result.model_dump() + if isinstance(result, list): + items: list[Any] = [] + for item in result: + if isinstance(item, BaseModel): + items.append(item.model_dump()) + else: + items.append(item) + return {'result': items} + return {'result': result} + + +def _convert_interaction_step_to_parts(step: Step) -> list[types.Part]: + """Convert an interaction output content to a list of types.Part. Args: output: The interaction output object to convert. Returns: - A types.Part object, or None if the output type is not supported. + A list of types.Part objects. """ - if not hasattr(output, 'type'): - return None - - output_type = output.type - - if output_type == 'text': - return types.Part.from_text(text=output.text or '') - elif output_type == 'function_call': + if isinstance(step, ModelOutputStep): + if not step.content: + return [] + + parts = [] + for content in step.content: + if content.type == 'text': + parts.append(types.Part.from_text(text=content.text)) + elif content.type in ['image', 'audio', 'document', 'video']: + if content.data: + parts.append( + types.Part( + inline_data=types.Blob( + data=content.data, + mime_type=content.mime_type, + ) + ) + ) + elif content.uri: + parts.append( + types.Part( + file_data=types.FileData( + file_uri=content.uri, + mime_type=content.mime_type, + ) + ) + ) + return parts + elif isinstance(step, FunctionCallStep): logger.debug( 'Converting function_call output: name=%s, id=%s', - output.name, - output.id, + step.name, + step.id, ) - thought_signature = None - thought_sig_value = getattr(output, 'thought_signature', None) - if thought_sig_value and isinstance(thought_sig_value, str): - # Decode base64 string back to bytes - thought_signature = base64.b64decode(thought_sig_value) - return types.Part( - function_call=types.FunctionCall( - id=output.id, - name=output.name, - args=output.arguments or {}, - ), - thought_signature=thought_signature, - ) - elif output_type == 'function_result': - result = output.result - # Handle different result formats - if isinstance(result, str): - result_value = result - elif hasattr(result, 'items'): - result_value = result.items - else: - result_value = result - return types.Part( - function_response=types.FunctionResponse( - id=output.call_id, - response=result_value, + return [ + types.Part( + function_call=types.FunctionCall( + id=step.id, + name=step.name, + args=step.arguments or {}, + ), ) - ) - elif output_type == 'image': - if output.data: - return types.Part( - inline_data=types.Blob( - data=output.data, - mime_type=output.mime_type, - ) - ) - elif output.uri: - return types.Part( - file_data=types.FileData( - file_uri=output.uri, - mime_type=output.mime_type, - ) - ) - elif output_type == 'audio': - if output.data: - return types.Part( - inline_data=types.Blob( - data=output.data, - mime_type=output.mime_type, - ) - ) - elif output.uri: - return types.Part( - file_data=types.FileData( - file_uri=output.uri, - mime_type=output.mime_type, - ) - ) - elif output_type == 'thought': + ] + elif isinstance(step, FunctionResultStep): + return [ + types.Part( + function_response=types.FunctionResponse( + id=step.call_id or '', + response=_function_result_to_response(step.result), + ) + ) + ] + elif isinstance(step, ThoughtStep): # ThoughtContent has a 'signature' attribute, not 'thought' # These are internal model reasoning and typically not exposed as Parts # Skip thought outputs for now - return None - elif output_type == 'code_execution_result': - return types.Part( - code_execution_result=types.CodeExecutionResult( - output=output.result or '', - outcome=types.Outcome.OUTCOME_FAILED - if output.is_error - else types.Outcome.OUTCOME_OK, + return [] + elif isinstance(step, CodeExecutionResultStep): + return [ + types.Part( + code_execution_result=types.CodeExecutionResult( + output=step.result or '', + outcome=types.Outcome.OUTCOME_FAILED + if step.is_error + else types.Outcome.OUTCOME_OK, + ) ) - ) - elif output_type == 'code_execution_call': - args = output.arguments or {} - return types.Part( - executable_code=types.ExecutableCode( - code=args.get('code', ''), - language=args.get('language', 'PYTHON'), + ] + elif isinstance(step, CodeExecutionCallStep): + args = step.arguments + return [ + types.Part( + executable_code=types.ExecutableCode( + code=args.code, + language=types.Language.PYTHON + if args.language and args.language.lower() == 'python' + else types.Language.LANGUAGE_UNSPECIFIED, + ) ) - ) - elif output_type == 'google_search_result': + ] + elif isinstance(step, GoogleSearchResultStep): # For google search results, we create a text part with the results - if output.result: - results_text = '\n'.join(str(r) for r in output.result if r) - return types.Part.from_text(text=results_text) + if step.result: + results_text = '\n'.join(str(r) for r in step.result if r) + return [types.Part.from_text(text=results_text)] - return None + return [] def convert_interaction_to_llm_response( @@ -443,13 +631,15 @@ def convert_interaction_to_llm_response( """ from .llm_response import LlmResponse - # Check for errors + # Check for errors. Lifecycle SSE events carry a partial interaction + # (InteractionSseEventInteraction) that has no 'error' attribute. if interaction.status == 'failed': error_msg = 'Unknown error' error_code = 'UNKNOWN_ERROR' - if interaction.error: - error_msg = interaction.error.message or error_msg - error_code = interaction.error.code or error_code + error = getattr(interaction, 'error', None) + if error: + error_msg = error.message or error_msg + error_code = error.code or error_code return LlmResponse( error_code=error_code, error_message=error_msg, @@ -458,11 +648,11 @@ def convert_interaction_to_llm_response( # Convert outputs to Content parts parts = [] - if interaction.outputs: - for output in interaction.outputs: - part = convert_interaction_output_to_part(output) - if part: - parts.append(part) + if interaction.steps: + for step in interaction.steps: + step_parts = _convert_interaction_step_to_parts(step) + if step_parts: + parts.extend(step_parts) content = None if parts: @@ -502,8 +692,8 @@ def convert_interaction_to_llm_response( def convert_interaction_event_to_llm_response( event: InteractionSSEEvent, aggregated_parts: list[types.Part], - interaction_id: Optional[str] = None, -) -> Optional[LlmResponse]: + interaction_id: str | None = None, +) -> LlmResponse | None: """Convert an InteractionSSEEvent to an LlmResponse for streaming. Args: @@ -514,19 +704,34 @@ def convert_interaction_event_to_llm_response( Returns: LlmResponse if this event produces one, None otherwise. """ - from .llm_response import LlmResponse - event_type = getattr(event, 'event_type', None) + if isinstance(event, StepStart): + + # Streaming function calls follow a sequence of events (https://ai.google.dev/gemini-api/docs/interactions-breaking-changes-may-2026#streaming): + # 1. StepStart: Delivers the function id and name. + # 2. StepDelta (multiple): Streams arguments as raw JSON strings via arguments. + # 3. StepStop: Signals the end of the step, where arguments are finalized and parsed. + if isinstance(event.step, FunctionCallStep): + fc = types.FunctionCall( + id=event.step.id, + name=event.step.name, + partial_args=[], + ) + part = types.Part(function_call=fc) + aggregated_parts.append(part) - if event_type == 'content.delta': - delta = event.delta - if delta is None: - return None + return LlmResponse( + content=types.Content(role='model', parts=[part]), + partial=True, + turn_complete=False, + interaction_id=interaction_id, + ) - delta_type = getattr(delta, 'type', None) + elif isinstance(event, StepDelta): + delta = event.delta - if delta_type == 'text': - text = delta.text or '' + if delta.type == 'text': + text = delta.text if text: part = types.Part.from_text(text=text) aggregated_parts.append(part) @@ -537,93 +742,121 @@ def convert_interaction_event_to_llm_response( interaction_id=interaction_id, ) - elif delta_type == 'function_call': - # Function calls are typically sent as complete units - # DON'T yield immediately - add to aggregated_parts only. - # The function_call will be yielded in the final response which has - # the correct interaction_id. If we yield here, interaction_id may be - # None because SSE streams the id later in the 'interaction' event. - if delta.name: - thought_signature = None - thought_sig_value = getattr(delta, 'thought_signature', None) - if thought_sig_value and isinstance(thought_sig_value, str): - # Decode base64 string back to bytes - thought_signature = base64.b64decode(thought_sig_value) - part = types.Part( - function_call=types.FunctionCall( - id=delta.id or '', - name=delta.name, - args=delta.arguments or {}, - ), - thought_signature=thought_signature, - ) - aggregated_parts.append(part) - # Return None - function_call will be in the final aggregated response - return None - - elif delta_type == 'image': - if delta.data or delta.uri: - if delta.data: + elif delta.type == 'image': + data = delta.data + uri = delta.uri + mime_type = delta.mime_type + if data or uri: + if data: part = types.Part( inline_data=types.Blob( - data=delta.data, - mime_type=delta.mime_type, + data=data, + mime_type=mime_type, ) ) else: part = types.Part( file_data=types.FileData( - file_uri=delta.uri, - mime_type=delta.mime_type, + file_uri=uri, + mime_type=mime_type, ) ) aggregated_parts.append(part) return LlmResponse( content=types.Content(role='model', parts=[part]), - partial=False, + partial=True, turn_complete=False, interaction_id=interaction_id, ) - elif event_type == 'content.stop': - # Content streaming finished, return aggregated content - if aggregated_parts: - return LlmResponse( - content=types.Content(role='model', parts=list(aggregated_parts)), - partial=False, - turn_complete=False, - interaction_id=interaction_id, - ) + elif delta.type == 'arguments_delta': + if aggregated_parts: + last_part = aggregated_parts[-1] + if last_part.function_call: + delta_args = delta.arguments + if ( + delta_args is not None + and last_part.function_call.partial_args is not None + ): + last_part.function_call.partial_args.append( + types.PartialArg(string_value=delta_args) + ) + + chunk_part = types.Part( + function_call=types.FunctionCall( + name=last_part.function_call.name, + partial_args=[types.PartialArg(string_value=delta_args)], + ) + ) + return LlmResponse( + content=types.Content(role='model', parts=[chunk_part]), + partial=True, + turn_complete=False, + interaction_id=interaction_id, + ) + + elif isinstance(event, StepStop): + if aggregated_parts and aggregated_parts[-1].function_call: + fc = aggregated_parts[-1].function_call + if fc.partial_args is not None: + arg_str = ''.join(pa.string_value or '' for pa in fc.partial_args) + + args = {} + if arg_str: + try: + args = json.loads(arg_str) + except json.JSONDecodeError as e: + logger.error( + 'Failed to parse function call args: %s. arg_str: %s', + e, + arg_str, + ) + fc.args = args + fc.partial_args = None + return LlmResponse( + error_code='JSON_PARSE_ERROR', + error_message='Failed to parse function call arguments', + turn_complete=True, + finish_reason=types.FinishReason.STOP, + interaction_id=interaction_id, + ) + + fc.args = args + fc.partial_args = None - elif event_type == 'interaction': - # Final interaction event with complete data - return convert_interaction_to_llm_response(event) + return None - elif event_type == 'interaction.status_update': - status = getattr(event, 'status', None) - if status in ('completed', 'requires_action'): + elif isinstance(event, InteractionCompletedEvent): + # Final aggregated response + if aggregated_parts: return LlmResponse( - content=types.Content(role='model', parts=list(aggregated_parts)) - if aggregated_parts - else None, + content=types.Content(role='model', parts=aggregated_parts), partial=False, turn_complete=True, finish_reason=types.FinishReason.STOP, interaction_id=interaction_id, ) - elif status == 'failed': - error = getattr(event, 'error', None) + # If no streaming parts were collected, convert the final interaction directly + return convert_interaction_to_llm_response(event.interaction) + + elif isinstance(event, Interaction): + # Fallback for legacy interaction events without lifecycle + return convert_interaction_to_llm_response(event) + + elif isinstance(event, InteractionStatusUpdate): + if event.status == 'failed': return LlmResponse( - error_code=error.code if error else 'UNKNOWN_ERROR', - error_message=error.message if error else 'Unknown error', + error_code='UNKNOWN_ERROR', + error_message='Unknown error', turn_complete=True, interaction_id=interaction_id, ) - elif event_type == 'error': + elif isinstance(event, ErrorEvent): + error = event.error return LlmResponse( - error_code=getattr(event, 'code', 'UNKNOWN_ERROR'), - error_message=getattr(event, 'message', 'Unknown error'), + error_code=error.code if error else 'UNKNOWN_ERROR', + error_message=error.message if error else 'Unknown error', turn_complete=True, interaction_id=interaction_id, ) @@ -633,7 +866,7 @@ def convert_interaction_event_to_llm_response( def build_generation_config( config: types.GenerateContentConfig, -) -> dict[str, Any]: +) -> GenerationConfigParam: """Build generation config dict for interactions API. Args: @@ -642,7 +875,7 @@ def build_generation_config( Returns: A dictionary containing generation configuration parameters. """ - generation_config: dict[str, Any] = {} + generation_config: GenerationConfigParam = {} if config.temperature is not None: generation_config['temperature'] = config.temperature if config.top_p is not None: @@ -662,7 +895,7 @@ def build_generation_config( def extract_system_instruction( config: types.GenerateContentConfig, -) -> Optional[str]: +) -> str | None: """Extract system instruction as a string from config. Args: @@ -679,9 +912,10 @@ def extract_system_instruction( elif isinstance(config.system_instruction, types.Content): # Extract text from Content texts = [] - for part in config.system_instruction.parts: - if part.text: - texts.append(part.text) + if config.system_instruction.parts: + for part in config.system_instruction.parts: + if part.text: + texts.append(part.text) return '\n'.join(texts) if texts else None return None @@ -707,18 +941,18 @@ def _build_tool_log(tool: ToolParam) -> str: def build_interactions_request_log( model: str, - input_turns: list[TurnParam], - system_instruction: Optional[str], - tools: Optional[list[ToolParam]], - generation_config: Optional[dict[str, Any]], - previous_interaction_id: Optional[str], + input_steps: list[StepParam], + system_instruction: str | None, + tools: list[ToolParam] | None, + generation_config: dict[str, object] | None, + previous_interaction_id: str | None, stream: bool, ) -> str: """Build a log string for an interactions API request. Args: model: The model name. - input_turns: The input turns to send. + input_steps: The input steps to send. system_instruction: The system instruction. tools: The tools configuration. generation_config: The generation config. @@ -728,11 +962,11 @@ def build_interactions_request_log( Returns: A formatted log string describing the request. """ - # Format input turns for logging - turns_logs = [] - for turn in input_turns: - role = turn.get('role', 'unknown') - contents = turn.get('content', []) + # Format input steps for logging + steps_logs = [] + for step in input_steps: + role = step.get('role', 'unknown') + contents = step.get('content', []) content_strs = [] for content in contents: content_type = content.get('type', 'unknown') @@ -755,7 +989,7 @@ def build_interactions_request_log( content_strs.append(f'function_result[{call_id}]: {result}') else: content_strs.append(f'{content_type}: ...') - turns_logs.append(f' [{role}]: {", ".join(content_strs)}') + steps_logs.append(f' [{role}]: {", ".join(content_strs)}') # Format tools for logging tools_logs = [] @@ -781,8 +1015,8 @@ def build_interactions_request_log( Generation Config: {config_str} ----------------------------------------------------------- -Input Turns: -{_NEW_LINE.join(turns_logs) if turns_logs else '(none)'} +Input Steps: +{_NEW_LINE.join(steps_logs) if steps_logs else '(none)'} ----------------------------------------------------------- Tools: {_NEW_LINE.join(tools_logs) if tools_logs else '(none)'} @@ -805,17 +1039,17 @@ def build_interactions_response_log(interaction: Interaction) -> str: # Extract outputs outputs_logs = [] - if hasattr(interaction, 'outputs') and interaction.outputs: - for output in interaction.outputs: - output_type = getattr(output, 'type', 'unknown') + if hasattr(interaction, 'steps') and interaction.steps: + for step in interaction.steps: + output_type = getattr(step, 'type', 'unknown') if output_type == 'text': - text = getattr(output, 'text', '') + text = getattr(step, 'text', '') if len(text) > 300: text = text[:300] + '...' outputs_logs.append(f' text: "{text}"') elif output_type == 'function_call': - name = getattr(output, 'name', '') - args = getattr(output, 'arguments', {}) + name = getattr(step, 'name', '') + args = getattr(step, 'arguments', {}) outputs_logs.append(f' function_call: {name}({json.dumps(args)})') else: outputs_logs.append(f' {output_type}: ...') @@ -868,7 +1102,7 @@ def build_interactions_event_log(event: InteractionSSEEvent) -> str: details = [] - if event_type == 'content.delta': + if event_type == 'step.delta': delta = getattr(event, 'delta', None) if delta: delta_type = getattr(delta, 'type', 'unknown') @@ -884,11 +1118,11 @@ def build_interactions_event_log(event: InteractionSSEEvent) -> str: else: details.append(f'{delta_type}: ...') - elif event_type == 'interaction.status_update': + elif event_type in ('interaction.completed', 'interaction.requires_action'): status = getattr(event, 'status', 'unknown') details.append(f'status: {status}') - elif event_type == 'error': + elif event_type == 'interaction.error': code = getattr(event, 'code', 'unknown') message = getattr(event, 'message', 'unknown') details.append(f'error: {code} - {message}') @@ -906,12 +1140,8 @@ def _get_latest_user_contents( For interactions API with previous_interaction_id, we only need to send the current turn's messages since prior history is maintained by - the interaction chain. - - Special handling for function_result: When the user content contains a - function_result (response to a model's function_call), we must also include - the preceding model content with the function_call. The Interactions API - needs both the function_call and function_result to properly match call_ids. + the interaction chain. The preceding model turn with the function_call + is already encapsulated in the previous_interaction_id state. Args: contents: The full list of content messages. @@ -923,41 +1153,16 @@ def _get_latest_user_contents( return [] # Find the latest continuous user messages from the end - latest_user_contents = [] - for content in reversed(contents): + latest_user_contents: list[types.Content] = [] + for i in range(len(contents) - 1, -1, -1): + content = contents[i] if content.role == 'user': - latest_user_contents.insert(0, content) + latest_user_contents.append(content) else: # Stop when we hit a non-user message break - # Check if the user contents contain a function_result - has_function_result = False - for content in latest_user_contents: - if content.parts: - for part in content.parts: - if part.function_response is not None: - has_function_result = True - break - if has_function_result: - break - - # If we have a function_result, we also need the preceding model content - # with the function_call so the API can match the call_id - if has_function_result and len(contents) > len(latest_user_contents): - # Get the index where user contents start - user_start_idx = len(contents) - len(latest_user_contents) - if user_start_idx > 0: - # Check if the content before user contents is a model turn with - # function_call - preceding_content = contents[user_start_idx - 1] - if preceding_content.role == 'model' and preceding_content.parts: - for part in preceding_content.parts: - if part.function_call is not None: - # Include the model's function_call turn before user's - # function_result - return [preceding_content] + latest_user_contents - + latest_user_contents.reverse() return latest_user_contents @@ -983,7 +1188,6 @@ async def generate_content_via_interactions( Yields: LlmResponse objects converted from interaction responses. """ - from .llm_response import LlmResponse # When previous_interaction_id is set, only send the latest continuous # user messages (the current turn) instead of full conversation history @@ -992,7 +1196,7 @@ async def generate_content_via_interactions( contents = _get_latest_user_contents(contents) # Convert contents to interactions API format - input_turns = convert_contents_to_turns(contents) + input_steps = _convert_contents_to_steps(contents) interaction_tools = convert_tools_config_to_interactions_format( llm_request.config ) @@ -1013,8 +1217,8 @@ async def generate_content_via_interactions( logger.debug( build_interactions_request_log( - model=llm_request.model, - input_turns=input_turns, + model=llm_request.model or '', + input_steps=input_steps, system_instruction=system_instruction, tools=interaction_tools if interaction_tools else None, generation_config=generation_config if generation_config else None, @@ -1024,13 +1228,13 @@ async def generate_content_via_interactions( ) # Track the current interaction ID from responses - current_interaction_id: Optional[str] = None + current_interaction_id: str | None = None if stream: # Streaming mode responses = await api_client.aio.interactions.create( model=llm_request.model, - input=input_turns, + input=input_steps, stream=True, system_instruction=system_instruction, tools=interaction_tools if interaction_tools else None, @@ -1052,21 +1256,11 @@ async def generate_content_via_interactions( if llm_response: yield llm_response - # Final aggregated response - if aggregated_parts: - yield LlmResponse( - content=types.Content(role='model', parts=aggregated_parts), - partial=False, - turn_complete=True, - finish_reason=types.FinishReason.STOP, - interaction_id=current_interaction_id, - ) - else: # Non-streaming mode interaction = await api_client.aio.interactions.create( model=llm_request.model, - input=input_turns, + input=input_steps, stream=False, system_instruction=system_instruction, tools=interaction_tools if interaction_tools else None, diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index c921f197c33..333034565ff 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -81,6 +81,12 @@ class LlmResponse(BaseModel): Only used for streaming mode. """ + turn_complete_reason: Optional[types.TurnCompleteReason] = None + """The reason why the turn is complete. + + Only used for streaming mode. + """ + finish_reason: Optional[types.FinishReason] = None """The finish reason of the response.""" diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 850c26bbba8..397bb3aca4b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -66,10 +66,6 @@ logger = logging.getLogger('google_adk.' + __name__) -def _is_tool_call_or_response(event: Event) -> bool: - return bool(event.get_function_calls() or event.get_function_responses()) - - def _get_function_responses_from_content( content: types.Content, ) -> list[types.FunctionResponse]: @@ -80,21 +76,6 @@ def _get_function_responses_from_content( ] -def _is_transcription(event: Event) -> bool: - return ( - event.input_transcription is not None - or event.output_transcription is not None - ) - - -def _has_non_empty_transcription_text( - transcription: types.Transcription, -) -> bool: - return bool( - transcription and transcription.text and transcription.text.strip() - ) - - def _apply_run_config_custom_metadata( event: Event, run_config: RunConfig | None ) -> None: @@ -873,22 +854,6 @@ async def _exec_with_plugin( yield early_exit_event else: # Step 2: Otherwise continue with normal execution - # Note for live/bidi: - # the transcription may arrive later than the action(function call - # event and thus function response event). In this case, the order of - # transcription and function call event will be wrong if we just - # append as it arrives. To address this, we should check if there is - # transcription going on. If there is transcription going on, we - # should hold on appending the function call event until the - # transcription is finished. The transcription in progress can be - # identified by checking if the transcription event is partial. When - # the next transcription event is not partial, it means the previous - # transcription is finished. Then if there is any buffered function - # call event, we should append them after this finished(non-partial) - # transcription event. - buffered_events: list[Event] = [] - is_transcribing: bool = False - async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: _apply_run_config_custom_metadata( @@ -906,50 +871,14 @@ async def _exec_with_plugin( ) if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(output_event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-partial event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription - ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event - ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=invocation_context.session, event=buffered_event - ) - yield buffered_event # yield buffered events to caller - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) + # Skip partial transcriptions for Live + if event.partial is not True and self._should_append_event( + event, is_live_call + ): + logger.debug('Appending live event: %s', output_event) + await self.session_service.append_event( + session=invocation_context.session, event=output_event + ) else: if event.partial is not True: await self.session_service.append_event( diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index 0ad63044d45..5249423cd1e 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -130,6 +130,12 @@ async def start(self) -> ClientSession: if not self._task: self._task = asyncio.create_task(self._run()) + def _retrieve_exception(t: asyncio.Task): + if not t.cancelled(): + t.exception() + + self._task.add_done_callback(_retrieve_exception) + await self._ready_event.wait() if self._task.cancelled(): diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index b2e032e0d19..dbb3a08193c 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -172,4 +172,5 @@ def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool: """ if not model_string: return False - return model_string.startswith('gemini-3.1-flash-live') + model_name = extract_model_name(model_string) + return model_name.startswith('gemini-3.1-flash-live') diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py index af6957b8fef..4ffd63b001d 100644 --- a/src/google/adk/utils/streaming_utils.py +++ b/src/google/adk/utils/streaming_utils.py @@ -349,61 +349,60 @@ def close(self) -> Optional[LlmResponse]: Returns: The aggregated LlmResponse. """ + if not self._response: + return None + + candidate = ( + self._response.candidates[0] if self._response.candidates else None + ) + + finish_reason = self._finish_reason + if not finish_reason and candidate: + finish_reason = candidate.finish_reason + + error_code = None + error_message = None + if finish_reason and finish_reason != types.FinishReason.STOP: + error_code = finish_reason + error_message = candidate.finish_message if candidate else None + elif not candidate and self._response.prompt_feedback: + error_code = self._response.prompt_feedback.block_reason + error_message = self._response.prompt_feedback.block_reason_message + # ========== Progressive SSE Streaming (new feature) ========== if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): - # Always generate final aggregated response in progressive mode - if self._response and self._response.candidates: - # Flush any remaining buffers to complete the sequence - self._flush_text_buffer_to_sequence() - self._flush_function_call_to_sequence() - - # Use the parts sequence which preserves original ordering - final_parts = self._parts_sequence - - if final_parts: - candidate = self._response.candidates[0] - finish_reason = self._finish_reason or candidate.finish_reason - - return LlmResponse( - content=types.ModelContent(parts=final_parts), - grounding_metadata=self._grounding_metadata, - citation_metadata=self._citation_metadata, - error_code=None - if finish_reason == types.FinishReason.STOP - else finish_reason, - error_message=None - if finish_reason == types.FinishReason.STOP - else candidate.finish_message, - usage_metadata=self._usage_metadata, - finish_reason=finish_reason, - partial=False, - ) - - return None + self._flush_text_buffer_to_sequence() + self._flush_function_call_to_sequence() + + final_parts = self._parts_sequence + content = types.ModelContent(parts=final_parts) if final_parts else None - # ========== Non-Progressive SSE Streaming (old behavior) ========== - if ( - (self._text or self._thought_text) - and self._response - and self._response.candidates - ): - parts = [] - if self._thought_text: - parts.append(types.Part(text=self._thought_text, thought=True)) - if self._text: - parts.append(types.Part.from_text(text=self._text)) - candidate = self._response.candidates[0] return LlmResponse( - content=types.ModelContent(parts=parts), + content=content, grounding_metadata=self._grounding_metadata, citation_metadata=self._citation_metadata, - error_code=None - if candidate.finish_reason == types.FinishReason.STOP - else candidate.finish_reason, - error_message=None - if candidate.finish_reason == types.FinishReason.STOP - else candidate.finish_message, + error_code=error_code, + error_message=error_message, usage_metadata=self._usage_metadata, + finish_reason=finish_reason, + partial=False, ) - return None + # ========== Non-Progressive SSE Streaming (old behavior) ========== + parts = [] + if self._thought_text: + parts.append(types.Part(text=self._thought_text, thought=True)) + if self._text: + parts.append(types.Part.from_text(text=self._text)) + content = types.ModelContent(parts=parts) if parts else None + + return LlmResponse( + content=content, + grounding_metadata=self._grounding_metadata, + citation_metadata=self._citation_metadata, + error_code=error_code, + error_message=error_message, + usage_metadata=self._usage_metadata, + finish_reason=finish_reason, + partial=False, + ) diff --git a/src/google/adk/version.py b/src/google/adk/version.py index cf2713c03f2..14080175ffa 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.34.0" +__version__ = "1.36.0" diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index ce2e83b6f72..2b8eb92d3a6 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -17,12 +17,15 @@ from unittest import mock from unittest.mock import AsyncMock +from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import Agent from google.adk.agents.run_config import RunConfig from google.adk.events.event import Event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback +from google.adk.flows.llm_flows.base_llm_flow import _ReconnectSentinel from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini +from google.adk.models.google_llm import GoogleLLMVariant from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin @@ -30,6 +33,7 @@ from google.adk.tools.google_search_tool import GoogleSearchTool from google.genai import types import pytest +from websockets.exceptions import ConnectionClosed from ... import testing_utils @@ -490,8 +494,6 @@ async def call(self, **kwargs): @pytest.mark.asyncio async def test_run_live_reconnects_on_connection_closed(): """Test that run_live reconnects when ConnectionClosed occurs.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() mock_connection = mock.AsyncMock() @@ -558,7 +560,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_reconnects_on_api_error(): """Test that run_live reconnects when APIError occurs.""" - from google.adk.agents.live_request_queue import LiveRequestQueue from google.genai.errors import APIError real_model = Gemini() @@ -626,7 +627,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_skips_send_history_on_resumption(): """Test that run_live skips send_history when resuming a session.""" - from google.adk.agents.live_request_queue import LiveRequestQueue real_model = Gemini() mock_connection = mock.AsyncMock() @@ -684,7 +684,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_live_session_resumption_go_away(): """Test that go_away triggers reconnection.""" - from google.adk.agents.live_request_queue import LiveRequestQueue real_model = Gemini() mock_connection = mock.AsyncMock() @@ -730,21 +729,27 @@ async def mock_receive_2(): ) as mock_connect: mock_connect.return_value.__aenter__ = mock_aenter + yielded_events = [] try: - async for _ in flow.run_live(invocation_context): - pass + async for event in flow.run_live(invocation_context): + yielded_events.append(event) except StopError: pass # Verify that we attempted to connect twice (initial + reconnect after go_away). assert mock_connect.call_count == 2 + # Verify that the internal _ReconnectSentinel is not leaked/yielded to the caller. + assert not any(isinstance(e, _ReconnectSentinel) for e in yielded_events) + + # Verify we yielded the expected response after reconnection. + assert len(yielded_events) == 1 + assert yielded_events[0].content.parts[0].text == 'hi' + @pytest.mark.asyncio async def test_run_live_no_reconnect_without_handle(): """Test that run_live does not reconnect when handle is missing.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() mock_connection = mock.AsyncMock() @@ -786,8 +791,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_limit(): """Test that run_live stops reconnecting after 5 attempts.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() @@ -796,17 +799,17 @@ async def test_run_live_reconnect_limit(): async def mock_connect_impl(*args, **kwargs): nonlocal connection_cnt connection_cnt += 1 + if connection_cnt > 1: + raise ConnectionClosed(None, None) conn = mock.AsyncMock() async def mock_receive(): - if connection_cnt == 1: - # Yield handle only on the first connection. - yield LlmResponse( - live_session_resumption_update=types.LiveServerSessionResumptionUpdate( - new_handle='test_handle' - ), - turn_complete=True, - ) + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ), + turn_complete=True, + ) # All subsequent receives (and all receives on later connections) fail. raise ConnectionClosed(None, None) @@ -842,10 +845,8 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_reset_attempt(): - """Test that attempt counter is reset on successful communication.""" - from google.adk.agents.live_request_queue import LiveRequestQueue + """Test that attempt counter is reset on successful connection establishment.""" from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS - from websockets.exceptions import ConnectionClosed real_model = Gemini() @@ -854,23 +855,29 @@ async def test_run_live_reconnect_reset_attempt(): async def mock_connect_impl(*args, **kwargs): nonlocal connection_cnt connection_cnt += 1 - conn = mock.AsyncMock() + # Establish connection successfully on attempts 1, 2, and 5 + if connection_cnt in (1, 2, 5): + conn = mock.AsyncMock() - async def mock_receive(): - if connection_cnt <= 2: - # Yield handle on the first two connections. - yield LlmResponse( - live_session_resumption_update=types.LiveServerSessionResumptionUpdate( - new_handle='test_handle' - ), - turn_complete=True, - ) - # All subsequent receives fail. + async def mock_receive(): + if connection_cnt == 1: + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ), + turn_complete=True, + ) + else: + if False: + yield + raise ConnectionClosed(None, None) + + conn.receive = mock.Mock(side_effect=mock_receive) + return conn + else: + # Failed connection establishments on other attempts raise ConnectionClosed(None, None) - conn.receive = mock.Mock(side_effect=mock_receive) - return conn - agent = Agent(name='test_agent', model=real_model) invocation_context = await testing_utils.create_invocation_context( agent=agent @@ -891,9 +898,13 @@ async def mock_receive(): async for _ in flow.run_live(invocation_context): pass - # We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts - # Total calls = 2 + 5 = 7 - assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2 + # Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed. + # Connection 2: succeeds (resets to 1), receive raises ConnectionClosed. + # Connection 3: fails (attempt becomes 2) + # Connection 4: fails (attempt becomes 3) + # Connection 5: succeeds (resets to 1), receive raises ConnectionClosed. + # Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates. + assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5 @pytest.mark.asyncio @@ -987,7 +998,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_clears_resumption_handle_on_transfer(): """Test that run_live clears session resumption handles when transferring to another agent.""" - from google.adk.agents.live_request_queue import LiveRequestQueue agent = Agent(name='test_agent') invocation_context = await testing_utils.create_invocation_context( @@ -1069,3 +1079,345 @@ async def mock_run_live_sub_agent(child_ctx, *args, **kwargs): assert ( invocation_context.run_config.session_resumption.handle == 'test_handle' ) + + +@pytest.mark.asyncio +async def test_postprocess_live_yields_grounding_metadata_only(): + """Test that _postprocess_live yields LlmResponse with only grounding_metadata.""" + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + flow = BaseLlmFlowForTesting() + + llm_request = LlmRequest() + grounding_metadata = types.GroundingMetadata( + web_search_queries=['test query'], + ) + llm_response = LlmResponse(grounding_metadata=grounding_metadata) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + + events = [] + async for event in flow._postprocess_live( + invocation_context, llm_request, llm_response, model_response_event + ): + events.append(event) + + assert len(events) == 1 + assert events[0].grounding_metadata == grounding_metadata + + +@pytest.mark.asyncio +async def test_postprocess_async_yields_grounding_metadata_only(): + """Test that _postprocess_async yields LlmResponse with only grounding_metadata.""" + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + flow = BaseLlmFlowForTesting() + + llm_request = LlmRequest() + grounding_metadata = types.GroundingMetadata( + web_search_queries=['test query'], + ) + llm_response = LlmResponse(grounding_metadata=grounding_metadata) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + + events = [] + async for event in flow._postprocess_async( + invocation_context, llm_request, llm_response, model_response_event + ): + events.append(event) + + assert len(events) == 1 + assert events[0].grounding_metadata == grounding_metadata + + +@pytest.mark.asyncio +async def test_run_live_reconnect_does_not_set_transparent(): + """Test that run_live reconnect does not set transparent=True.""" + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + invocation_context.run_config = RunConfig() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + async def mock_preprocess(ctx, req): + req.live_connect_config.session_resumption = ( + ctx.run_config.session_resumption + ) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + mock_connection_2 = mock.AsyncMock() + + class StopTestError(Exception): + pass + + async def mock_receive_2(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch.object( + Gemini, '_api_backend', new_callable=mock.PropertyMock + ) as mock_backend: + mock_backend.return_value = GoogleLLMVariant.GEMINI_API + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 2 + second_call_req = mock_connect.call_args_list[1][0][0] + session_resump = ( + second_call_req.live_connect_config.session_resumption + ) + assert session_resump.transparent is None + + +@pytest.mark.asyncio +async def test_run_live_reconnect_sets_transparent_for_vertex(): + """Test that run_live reconnect sets transparent=True for vertex backend.""" + + real_model = Gemini( + model='projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp' + ) + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + invocation_context.run_config = RunConfig() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + async def mock_preprocess(ctx, req): + req.live_connect_config.session_resumption = ( + ctx.run_config.session_resumption + ) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + mock_connection_2 = mock.AsyncMock() + + class StopTestError(Exception): + pass + + async def mock_receive_2(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 2 + second_call_req = mock_connect.call_args_list[1][0][0] + session_resump = second_call_req.live_connect_config.session_resumption + assert session_resump.transparent + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'api_backend', + [ + GoogleLLMVariant.GEMINI_API, + GoogleLLMVariant.VERTEX_AI, + ], +) +async def test_run_live_history_config_set_for_all_backends(api_backend): + """Test that run_live sets history_config for all backends.""" + + real_model = Gemini(model='gemini-3.1-flash-live-preview') + mock_connection = mock.AsyncMock() + + class StopTestError(Exception): + pass + + async def mock_receive(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + async def mock_preprocess(ctx, req): + req.model = 'gemini-3.1-flash-live-preview' + req.contents = [ + types.Content(parts=[types.Part.from_text(text='history')]) + ] + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + with mock.patch.object( + Gemini, '_api_backend', new_callable=mock.PropertyMock + ) as mock_backend: + mock_backend.return_value = api_backend + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 1 + called_req = mock_connect.call_args[0][0] + assert called_req.live_connect_config is not None + assert called_req.live_connect_config.history_config is not None + assert ( + called_req.live_connect_config.history_config.initial_history_in_client_content + is True + ) + + +@pytest.mark.asyncio +async def test_run_live_respects_explicit_initial_history_in_client_content_false(): + """Test that run_live respects explicit initial_history_in_client_content=False in RunConfig.""" + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + run_config = RunConfig( + history_config=types.HistoryConfig( + initial_history_in_client_content=False + ) + ) + invocation_context.run_config = run_config + + flow = BaseLlmFlowForTesting() + + async def mock_preprocess(ctx, req): + req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] + from google.adk.flows.llm_flows.basic import _build_basic_request + + _build_basic_request(ctx, req) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + class StopTestError(Exception): + pass + + async def mock_receive(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 1 + call_req = mock_connect.call_args[0][0] + assert call_req.live_connect_config.history_config is not None + assert ( + call_req.live_connect_config.history_config.initial_history_in_client_content + is False + ) diff --git a/tests/unittests/integrations/crewai/test_crewai_tool.py b/tests/unittests/integrations/crewai/test_crewai_tool.py index f7f9bfe0bd6..eda884da600 100644 --- a/tests/unittests/integrations/crewai/test_crewai_tool.py +++ b/tests/unittests/integrations/crewai/test_crewai_tool.py @@ -16,9 +16,11 @@ import pytest -# Skip entire module if Python < 3.10 (must be before crewai_tool import) +# Skip the module when the optional crewai dependency is not installed. Guard on +# the third-party dep itself rather than the adk wrapper, so a real import bug in +# crewai_tool surfaces as a failure instead of being silently skipped. pytest.importorskip( - "google.adk.integrations.crewai.crewai_tool", reason="Requires Python 3.10+" + "crewai.tools", reason="Requires crewai (google-adk[extensions])" ) from google.adk.agents.context import Context diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 133a4557388..54f8706a964 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -285,11 +285,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_interrupt( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on interrupt signal.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -345,7 +351,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 5 assert responses[4].interrupted is True @@ -365,11 +371,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_generation_complete( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on generation_complete signal.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -425,7 +437,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 4 @@ -444,11 +456,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_turn_complete( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on interrupt or complete signals.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -504,7 +522,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 5 assert responses[4].turn_complete is True @@ -867,6 +885,7 @@ async def test_receive_grounding_metadata_standalone( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) mock_message.usage_metadata = None @@ -911,6 +930,7 @@ async def test_receive_grounding_metadata_with_content( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) mock_message.usage_metadata = None @@ -981,6 +1001,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_metadata_msg = mock.create_autospec( types.LiveServerMessage, instance=True @@ -1001,6 +1022,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio( mock_turn_complete_content.interrupted = False mock_turn_complete_content.input_transcription = None mock_turn_complete_content.output_transcription = None + mock_turn_complete_content.generation_complete = False mock_turn_complete_msg = mock.create_autospec( types.LiveServerMessage, instance=True @@ -1240,3 +1262,482 @@ async def mock_receive_generator(): content_response = next((r for r in responses if r.content), None) assert content_response is not None assert content_response.content == mock_content + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_pending( + gemini_connection, mock_gemini_session +): + """Test that grounding metadata in partial chunks is pending and yielded on full text.""" + grounding_metadata = types.GroundingMetadata( + web_search_queries=['stock price of google'], + ) + + def make_msg(text=None, g_meta=None, tc=False): + msg = mock.Mock( + usage_metadata=None, + tool_call=None, + session_resumption_update=None, + go_away=None, + ) + msg.server_content = mock.Mock( + interrupted=False, + input_transcription=None, + output_transcription=None, + generation_complete=False, + turn_complete=tc, + grounding_metadata=g_meta, + model_turn=types.Content( + role='model', parts=[types.Part.from_text(text=text)] + ) + if text + else None, + ) + return msg + + msg1 = make_msg(text='hello', g_meta=grounding_metadata) + msg2 = make_msg(text=' world') + msg3 = make_msg(tc=True) + + async def gen(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=gen()) + + responses = [resp async for resp in gemini_connection.receive()] + + # Expected responses: + # 1. Msg 1 partial (hello) with grounding_metadata + # 2. Msg 2 partial ( world) without grounding_metadata + # 3. Full text response (hello world) with PENDING grounding_metadata + # 4. Turn complete response without grounding_metadata (already cleared) + assert len(responses) == 4 + + assert responses[0].content.parts[0].text == 'hello' + assert responses[0].partial is True + assert responses[0].grounding_metadata == grounding_metadata + + assert responses[1].content.parts[0].text == ' world' + assert responses[1].partial is True + assert responses[1].grounding_metadata is None + + assert responses[2].content.parts[0].text == 'hello world' + assert responses[2].partial is False + assert responses[2].grounding_metadata == grounding_metadata + + assert responses[3].turn_complete is True + assert responses[3].grounding_metadata is None + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse.""" + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = None + mock_server_content.grounding_metadata = None + mock_server_content.turn_complete = True + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].turn_complete is True + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason_standalone_grounding( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse for standalone grounding metadata.""" + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = None + mock_server_content.grounding_metadata = mock.create_autospec( + types.GroundingMetadata, instance=True + ) + mock_server_content.turn_complete = False + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].grounding_metadata is not None + assert responses[0].turn_complete is None + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason_with_content( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse when model turn has content parts.""" + mock_content = types.Content( + role='model', + parts=[types.Part.from_text(text='hello')], + ) + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = mock_content + mock_server_content.grounding_metadata = None + mock_server_content.turn_complete = False + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].content == mock_content + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_multiplexed_parts( + gemini_connection, mock_gemini_session +): + """Test receive with multiplexed inline data and text content.""" + mock_content = types.Content( + role='model', + parts=[ + types.Part( + inline_data=types.Blob(data=b'audio_data', mime_type='audio/pcm') + ), + types.Part.from_text(text='transcription text'), + ], + ) + mock_server_content = mock.Mock() + mock_server_content.model_turn = mock_content + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.turn_complete = False + mock_server_content.grounding_metadata = None + + mock_message = mock.AsyncMock() + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + assert responses + content_response = next((r for r in responses if r.content), None) + assert content_response is not None + assert content_response.content == mock_content + assert content_response.partial is True + + +@pytest.mark.asyncio +async def test_send_history_gemini_31_turn_complete(mock_gemini_session): + """Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True.""" + from google.adk.models.google_llm import GoogleLLMVariant + + conn = GeminiLlmConnection( + mock_gemini_session, + api_backend=GoogleLLMVariant.GEMINI_API, + model_version='gemini-3.1-flash-live-preview', + ) + mock_gemini_session.send_client_content = mock.AsyncMock() + + mock_contents = [ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]), + types.Content(role='model', parts=[types.Part.from_text(text='hello')]), + ] + await conn.send_history(mock_contents) + + mock_gemini_session.send_client_content.assert_called_once_with( + turns=mock_contents, + turn_complete=True, + ) + + +@pytest.mark.asyncio +async def test_send_history_collapse_vertex_ai(mock_gemini_session): + """Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend.""" + from google.adk.models.google_llm import GoogleLLMVariant + + conn = GeminiLlmConnection( + mock_gemini_session, + api_backend=GoogleLLMVariant.VERTEX_AI, + model_version='gemini-3.1-flash-live-preview', + ) + mock_gemini_session.send_client_content = mock.AsyncMock() + + mock_contents = [ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]), + types.Content(role='model', parts=[types.Part.from_text(text='hello')]), + ] + await conn.send_history(mock_contents) + + assert mock_gemini_session.send_client_content.call_count == 1 + called_turns = mock_gemini_session.send_client_content.call_args.kwargs[ + 'turns' + ] + assert len(called_turns) == 1 + assert called_turns[0].role == 'user' + assert 'Previous conversation history:' in called_turns[0].parts[0].text + assert '[user]: hi' in called_turns[0].parts[0].text + assert '[model]: hello' in called_turns[0].parts[0].text + assert ( + mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] + is True + ) + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_default_gemini_3_1( + mock_gemini_session, +): + """Verify grounding_metadata defaults to empty GroundingMetadata for Gemini 3.1.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-3.1-flash-live-preview', + ) + + def make_msg(text=None, tc=False, tool_call=None): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = tool_call + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = None + msg.server_content.output_transcription = None + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = ( + types.Content(role='model', parts=[types.Part.from_text(text=text)]) + if text + else None + ) + return msg + + # 1. Content event + msg1 = make_msg(text='hello') + # 2. Tool call event (yields immediately for Gemini 3.1) + function_call = types.FunctionCall(name='foo', args={}) + tool_call = mock.create_autospec(types.LiveServerToolCall, instance=True) + tool_call.function_calls = [function_call] + msg2 = make_msg(tool_call=tool_call) + # 3. Turn complete event + msg3 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + responses = [resp async for resp in conn.receive()] + # Expected: + # responses[0] -> partial content response for msg1 (has no grounding_metadata) + # responses[1] -> full text response for msg1 (has no grounding_metadata) + # responses[2] -> tool call response for msg2 (has no grounding_metadata) + # responses[3] -> turn_complete response for msg3 (has grounding_metadata) + assert len(responses) == 4 + assert responses[0].content.parts[0].text == 'hello' + assert responses[0].grounding_metadata is None + assert responses[0].partial is True + assert responses[1].content.parts[0].text == 'hello' + assert responses[1].grounding_metadata is None + assert responses[1].partial is False + assert responses[2].content.parts[0].function_call.name == 'foo' + assert responses[2].grounding_metadata is None + assert responses[3].turn_complete is True + assert isinstance(responses[3].grounding_metadata, types.GroundingMetadata) + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_default_non_gemini_3_1( + mock_gemini_session, +): + """Verify grounding_metadata stays None for non-Gemini 3.1 models.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-2.5-flash-live', + ) + + def make_msg(text=None, tc=False): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = None + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = None + msg.server_content.output_transcription = None + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = ( + types.Content(role='model', parts=[types.Part.from_text(text=text)]) + if text + else None + ) + return msg + + msg1 = make_msg(text='hello') + msg2 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + responses = [resp async for resp in conn.receive()] + assert len(responses) == 3 + assert responses[0].content.parts[0].text == 'hello' + assert responses[0].grounding_metadata is None + assert responses[0].partial is True + assert responses[1].content.parts[0].text == 'hello' + assert responses[1].grounding_metadata is None + assert responses[1].partial is False + assert responses[2].turn_complete is True + assert responses[2].grounding_metadata is None + + +@pytest.mark.asyncio +async def test_receive_input_transcription_gemini_3_1( + mock_gemini_session, +): + """Verify input_transcription yields finished=True immediately for Gemini 3.1.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-3.1-flash-live-preview', + ) + + def make_msg( + input_text=None, output_text=None, output_finished=False, tc=False + ): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = None + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = ( + types.Transcription(text=input_text, finished=False) + if input_text + else None + ) + msg.server_content.output_transcription = ( + types.Transcription(text=output_text, finished=output_finished) + if output_text + else None + ) + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = None + return msg + + msg1 = make_msg(input_text='Hello') + msg2 = make_msg(output_text='Hi there!', output_finished=True) + msg3 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in conn.receive()] + + assert len(responses) == 4 + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is True + assert responses[0].partial is False + + assert responses[1].output_transcription.text == 'Hi there!' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + + assert responses[2].output_transcription.text == 'Hi there!' + assert responses[2].output_transcription.finished is True + assert responses[2].partial is False + + assert responses[3].turn_complete is True diff --git a/tests/unittests/models/test_interactions_utils.py b/tests/unittests/models/test_interactions_utils.py index 118a925ab6e..65019b1c9eb 100644 --- a/tests/unittests/models/test_interactions_utils.py +++ b/tests/unittests/models/test_interactions_utils.py @@ -20,16 +20,27 @@ from datetime import datetime from datetime import timezone import json -from types import SimpleNamespace from unittest.mock import MagicMock from google.adk.models import interactions_utils from google.adk.models.llm_request import LlmRequest +from google.genai import interactions from google.genai import types -from google.genai._interactions.types.interaction import Interaction -from google.genai._interactions.types.interaction_complete_event import InteractionCompleteEvent -from google.genai._interactions.types.interaction_start_event import InteractionStartEvent -from google.genai._interactions.types.interaction_status_update import InteractionStatusUpdate +from google.genai.interactions import CodeExecutionResultStep +from google.genai.interactions import FunctionCallStep +from google.genai.interactions import FunctionResultStep +from google.genai.interactions import ImageContent +from google.genai.interactions import Interaction +from google.genai.interactions import InteractionCompletedEvent +from google.genai.interactions import InteractionCreatedEvent +from google.genai.interactions import InteractionSseEventInteraction +from google.genai.interactions import ModelOutputStep +from google.genai.interactions import StepDelta +from google.genai.interactions import StepStart +from google.genai.interactions import StepStop +from google.genai.interactions import TextContent +from google.genai.interactions import ThoughtStep +from google.genai.interactions import Usage import pytest @@ -73,21 +84,6 @@ def __init__(self, events: list[object]): self.aio = _FakeAio(events) -def _build_function_call_delta_event( - *, function_id: str, name: str, arguments: dict[str, object] -) -> SimpleNamespace: - """Build a version-agnostic content.delta event for a function call.""" - return SimpleNamespace( - event_type='content.delta', - delta=SimpleNamespace( - type='function_call', - id=function_id, - name=name, - arguments=arguments, - ), - ) - - def _build_llm_request() -> LlmRequest: """Build a minimal request for interactions streaming tests.""" return LlmRequest( @@ -102,69 +98,75 @@ def _build_llm_request() -> LlmRequest: ) -def _build_lifecycle_streamed_events() -> list[object]: +@pytest.fixture +def fc_step() -> FunctionCallStep: + """Fixture providing a basic FunctionCallStep.""" + return FunctionCallStep( + type='function_call', + id='call_1', + name='get_weather', + arguments={'city': 'Tokyo'}, + ) + + +def _build_lifecycle_streamed_events(fc_step: FunctionCallStep) -> list[object]: """Build streamed events with lifecycle updates carrying the ID.""" - now = datetime.now(timezone.utc) + now = datetime.now(timezone.utc).isoformat() + + interaction = InteractionSseEventInteraction( + id='interaction_123', + created=now, + updated=now, + status='requires_action', + steps=[fc_step], + ) + return [ - InteractionStartEvent( - event_type='interaction.start', - interaction=Interaction( - id='interaction_123', - created=now, - updated=now, - status='in_progress', - ), + InteractionCreatedEvent( + event_type='interaction.created', + interaction=interaction, ), - _build_function_call_delta_event( - function_id='call_1', - name='get_weather', - arguments={'city': 'Tokyo'}, - ), - InteractionStatusUpdate( - event_type='interaction.status_update', - interaction_id='interaction_123', - status='requires_action', + InteractionCompletedEvent( + event_type='interaction.completed', + interaction=interaction, ), ] -def _build_complete_streamed_events() -> list[object]: +def _build_complete_streamed_events(fc_step: FunctionCallStep) -> list[object]: """Build streamed events with the ID on an interaction.complete event.""" - now = datetime.now(timezone.utc) + now = datetime.now(timezone.utc).isoformat() + + interaction = InteractionSseEventInteraction( + id='interaction_complete_123', + created=now, + updated=now, + status='requires_action', + steps=[fc_step], + ) + return [ - _build_function_call_delta_event( - function_id='call_1', - name='get_weather', - arguments={'city': 'Tokyo'}, - ), - InteractionCompleteEvent( - event_type='interaction.complete', - interaction=Interaction( - id='interaction_complete_123', - created=now, - updated=now, - status='requires_action', - ), + InteractionCompletedEvent( + event_type='interaction.completed', + interaction=interaction, ), ] -def _build_legacy_streamed_events() -> list[object]: +def _build_legacy_streamed_events(fc_step: FunctionCallStep) -> list[object]: """Build streamed events with the ID on the legacy interaction event.""" + now = datetime.now(timezone.utc).isoformat() + + interaction = Interaction( + id='interaction_legacy_123', + created=now, + updated=now, + status='requires_action', + steps=[fc_step], + ) + return [ - _build_function_call_delta_event( - function_id='call_1', - name='get_weather', - arguments={'city': 'Tokyo'}, - ), - SimpleNamespace( - event_type='interaction', - id='interaction_legacy_123', - status='requires_action', - error=None, - outputs=None, - usage=None, - ), + interaction, ] @@ -194,13 +196,27 @@ async def _collect_function_call_interaction_ids( class TestConvertPartToInteractionContent: - """Tests for convert_part_to_interaction_content.""" + """Tests for _convert_part_to_interaction_content.""" def test_text_part(self): """Test converting a text Part.""" part = types.Part(text='Hello, world!') - result = interactions_utils.convert_part_to_interaction_content(part) - assert result == {'type': 'text', 'text': 'Hello, world!'} + result = interactions_utils._convert_part_to_interaction_content(part) + assert result == { + 'type': 'user_input', + 'content': [{'type': 'text', 'text': 'Hello, world!'}], + } + + def test_text_part_model_role(self): + """Test converting a text Part for model role.""" + part = types.Part(text='Hello, user!') + result = interactions_utils._convert_part_to_interaction_content( + part, role='model' + ) + assert result == { + 'type': 'model_output', + 'content': [{'type': 'text', 'text': 'Hello, user!'}], + } def test_function_call_part(self): """Test converting a function call Part.""" @@ -211,7 +227,7 @@ def test_function_call_part(self): args={'city': 'London'}, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { 'type': 'function_call', 'id': 'call_123', @@ -227,12 +243,12 @@ def test_function_call_part_no_id(self): args={'city': 'London'}, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result['id'] == '' assert result['name'] == 'get_weather' - def test_function_call_part_with_thought_signature(self): - """Test converting a function call Part with thought_signature.""" + def test_function_call_part_thought_signature_dropped(self): + """Thought signatures are not sent on interactions function call steps.""" part = types.Part( function_call=types.FunctionCall( id='call_456', @@ -241,17 +257,14 @@ def test_function_call_part_with_thought_signature(self): ), thought_signature=b'test_signature_bytes', ) - result = interactions_utils.convert_part_to_interaction_content(part) - assert result['type'] == 'function_call' - assert result['id'] == 'call_456' - assert result['name'] == 'my_tool' - assert result['arguments'] == {'doc': 'content'} - # thought_signature should be base64 encoded - assert 'thought_signature' in result - - assert ( - base64.b64decode(result['thought_signature']) == b'test_signature_bytes' - ) + result = interactions_utils._convert_part_to_interaction_content(part) + assert result == { + 'type': 'function_call', + 'id': 'call_456', + 'name': 'my_tool', + 'arguments': {'doc': 'content'}, + } + assert 'signature' not in result def test_function_call_part_without_thought_signature(self): """Test converting a function call Part without thought_signature.""" @@ -262,10 +275,10 @@ def test_function_call_part_without_thought_signature(self): args={}, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result['type'] == 'function_call' - # thought_signature should not be present - assert 'thought_signature' not in result + # signature should not be present + assert 'signature' not in result def test_function_response_dict(self): """Test converting a function response Part with dict response.""" @@ -276,13 +289,15 @@ def test_function_response_dict(self): response={'temperature': 20, 'condition': 'sunny'}, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result['type'] == 'function_result' assert result['call_id'] == 'call_123' assert result['name'] == 'get_weather' - # Dict should be passed through directly (not JSON-serialized). - assert result['result'] == {'temperature': 20, 'condition': 'sunny'} - assert isinstance(result['result'], dict) + # Dict should be passed through directly (not JSON-serialized) + assert result['result'] == { + 'temperature': 20, + 'condition': 'sunny', + } def test_function_response_simple(self): """Test converting a function response Part with simple response.""" @@ -293,13 +308,30 @@ def test_function_response_simple(self): response={'message': 'Weather is sunny'}, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result['type'] == 'function_result' assert result['call_id'] == 'call_123' assert result['name'] == 'check_weather' - # Dict should be passed through directly (not JSON-serialized). + # Dict should be JSON serialized assert result['result'] == {'message': 'Weather is sunny'} + def test_convert_part_to_interaction_content_function_response_error(self): + part = types.Part( + function_response=types.FunctionResponse( + name='my_function', + id='call_123', + response={'error': 'something went wrong'}, + ) + ) + result = interactions_utils._convert_part_to_interaction_content(part) + assert result == interactions.FunctionResultStepParam( + type='function_result', + name='my_function', + call_id='call_123', + result={'error': 'something went wrong'}, + is_error=True, + ) + def test_function_response_dict_not_double_serialized(self): """Regression test: avoid double-serializing bash tool outputs. @@ -320,7 +352,7 @@ def test_function_response_dict_not_double_serialized(self): response=bash_response, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) # The result value must be the dict itself, NOT a JSON string. assert isinstance(result['result'], dict) assert result['result'] == bash_response @@ -337,11 +369,16 @@ def test_inline_data_image(self): mime_type='image/png', ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { - 'type': 'image', - 'data': b'image_data', - 'mime_type': 'image/png', + 'type': 'user_input', + 'content': [{ + 'type': 'image', + 'data': ( + 'aW1hZ2VfZGF0YQ==' + ), # base64.b64encode(b'image_data').decode('utf-8') + 'mime_type': 'image/png', + }], } def test_inline_data_audio(self): @@ -352,11 +389,16 @@ def test_inline_data_audio(self): mime_type='audio/mp3', ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { - 'type': 'audio', - 'data': b'audio_data', - 'mime_type': 'audio/mp3', + 'type': 'user_input', + 'content': [{ + 'type': 'audio', + 'data': ( + 'YXVkaW9fZGF0YQ==' + ), # base64.b64encode(b'audio_data').decode('utf-8') + 'mime_type': 'audio/mp3', + }], } def test_inline_data_video(self): @@ -367,11 +409,16 @@ def test_inline_data_video(self): mime_type='video/mp4', ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { - 'type': 'video', - 'data': b'video_data', - 'mime_type': 'video/mp4', + 'type': 'user_input', + 'content': [{ + 'type': 'video', + 'data': ( + 'dmlkZW9fZGF0YQ==' + ), # base64.b64encode(b'video_data').decode('utf-8') + 'mime_type': 'video/mp4', + }], } def test_inline_data_document(self): @@ -382,11 +429,16 @@ def test_inline_data_document(self): mime_type='application/pdf', ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { - 'type': 'document', - 'data': b'doc_data', - 'mime_type': 'application/pdf', + 'type': 'user_input', + 'content': [{ + 'type': 'document', + 'data': ( + 'ZG9jX2RhdGE=' + ), # base64.b64encode(b'doc_data').decode('utf-8') + 'mime_type': 'application/pdf', + }], } def test_file_data_image(self): @@ -397,11 +449,14 @@ def test_file_data_image(self): mime_type='image/png', ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { - 'type': 'image', - 'uri': 'gs://bucket/image.png', - 'mime_type': 'image/png', + 'type': 'user_input', + 'content': [{ + 'type': 'image', + 'uri': 'gs://bucket/image.png', + 'mime_type': 'image/png', + }], } def test_text_with_thought_flag(self): @@ -410,22 +465,25 @@ def test_text_with_thought_flag(self): # When text is present, the convert function returns text type (not thought) # because text check comes before thought check in the implementation part = types.Part(text='Let me think about this...', thought=True) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) # Text content is returned as-is (thought flag not represented in output) - assert result == {'type': 'text', 'text': 'Let me think about this...'} + assert result == { + 'type': 'user_input', + 'content': [{'type': 'text', 'text': 'Let me think about this...'}], + } def test_thought_only_part(self): """Test converting a thought-only Part with signature.""" signature_bytes = b'test-thought-signature' part = types.Part(thought=True, thought_signature=signature_bytes) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) expected_signature = base64.b64encode(signature_bytes).decode('utf-8') assert result == {'type': 'thought', 'signature': expected_signature} def test_thought_only_part_without_signature(self): """Test converting a thought-only Part without signature.""" part = types.Part(thought=True) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == {'type': 'thought'} def test_code_execution_result(self): @@ -436,7 +494,7 @@ def test_code_execution_result(self): outcome=types.Outcome.OUTCOME_OK, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { 'type': 'code_execution_result', 'call_id': '', @@ -452,7 +510,7 @@ def test_code_execution_result_with_error(self): outcome=types.Outcome.OUTCOME_FAILED, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { 'type': 'code_execution_result', 'call_id': '', @@ -468,7 +526,7 @@ def test_code_execution_result_deadline_exceeded(self): outcome=types.Outcome.OUTCOME_DEADLINE_EXCEEDED, ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { 'type': 'code_execution_result', 'call_id': '', @@ -484,7 +542,7 @@ def test_executable_code(self): language='PYTHON', ) ) - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result == { 'type': 'code_execution_call', 'id': '', @@ -497,12 +555,12 @@ def test_executable_code(self): def test_empty_part(self): """Test converting an empty Part returns None.""" part = types.Part() - result = interactions_utils.convert_part_to_interaction_content(part) + result = interactions_utils._convert_part_to_interaction_content(part) assert result is None -class TestConvertContentToTurn: - """Tests for convert_content_to_turn.""" +class TestConvertContentToStep: + """Tests for _convert_content_to_step.""" def test_user_content(self): """Test converting user content.""" @@ -510,11 +568,11 @@ def test_user_content(self): role='user', parts=[types.Part(text='Hello!')], ) - result = interactions_utils.convert_content_to_turn(content) - assert result == { - 'role': 'user', + result = interactions_utils._convert_content_to_step(content) + assert result == [{ + 'type': 'user_input', 'content': [{'type': 'text', 'text': 'Hello!'}], - } + }] def test_model_content(self): """Test converting model content.""" @@ -522,11 +580,11 @@ def test_model_content(self): role='model', parts=[types.Part(text='Hi there!')], ) - result = interactions_utils.convert_content_to_turn(content) - assert result == { - 'role': 'model', + result = interactions_utils._convert_content_to_step(content) + assert result == [{ + 'type': 'model_output', 'content': [{'type': 'text', 'text': 'Hi there!'}], - } + }] def test_multiple_parts(self): """Test converting content with multiple parts.""" @@ -539,30 +597,60 @@ def test_multiple_parts(self): ), ], ) - result = interactions_utils.convert_content_to_turn(content) - assert result['role'] == 'user' - assert len(result['content']) == 2 - assert result['content'][0] == {'type': 'text', 'text': 'Look at this:'} - assert result['content'][1]['type'] == 'image' + result = interactions_utils._convert_content_to_step(content) + assert len(result) == 2 + assert result[0]['type'] == 'user_input' + assert result[0]['content'][0] == {'type': 'text', 'text': 'Look at this:'} + assert result[1]['type'] == 'user_input' + assert result[1]['content'][0]['type'] == 'image' + + def test_interleaved_parts(self): + """Test converting content with interleaved text and media parts.""" + content = types.Content( + role='user', + parts=[ + types.Part(text='First:'), + types.Part( + inline_data=types.Blob(data=b'img1', mime_type='image/png') + ), + types.Part(text='Second:'), + types.Part( + inline_data=types.Blob(data=b'img2', mime_type='image/jpeg') + ), + types.Part(text='End'), + ], + ) + result = interactions_utils._convert_content_to_step(content) + assert len(result) == 5 + assert result[0]['type'] == 'user_input' + assert result[0]['content'][0] == {'type': 'text', 'text': 'First:'} + assert result[1]['type'] == 'user_input' + assert result[1]['content'][0]['type'] == 'image' + assert result[2]['type'] == 'user_input' + assert result[2]['content'][0] == {'type': 'text', 'text': 'Second:'} + assert result[3]['type'] == 'user_input' + assert result[3]['content'][0]['type'] == 'image' + assert result[4]['type'] == 'user_input' + assert result[4]['content'][0] == {'type': 'text', 'text': 'End'} def test_default_role(self): """Test that default role is 'user' when not specified.""" content = types.Content(parts=[types.Part(text='Hi')]) - result = interactions_utils.convert_content_to_turn(content) - assert result['role'] == 'user' + result = interactions_utils._convert_content_to_step(content) + assert result[0]['type'] == 'user_input' -class TestConvertContentsToTurns: - """Tests for convert_contents_to_turns.""" +class TestConvertContentsToSteps: + """Tests for convert_contents_to_steps.""" def test_single_content(self): """Test converting a list with single content.""" contents = [ types.Content(role='user', parts=[types.Part(text='What is 2+2?')]), ] - result = interactions_utils.convert_contents_to_turns(contents) + result = interactions_utils._convert_contents_to_steps(contents) assert len(result) == 1 - assert result[0]['role'] == 'user' + assert result[0]['type'] == 'user_input' assert result[0]['content'][0]['text'] == 'What is 2+2?' def test_multi_turn_conversation(self): @@ -572,11 +660,11 @@ def test_multi_turn_conversation(self): types.Content(role='model', parts=[types.Part(text='Hello!')]), types.Content(role='user', parts=[types.Part(text='How are you?')]), ] - result = interactions_utils.convert_contents_to_turns(contents) + result = interactions_utils._convert_contents_to_steps(contents) assert len(result) == 3 - assert result[0]['role'] == 'user' - assert result[1]['role'] == 'model' - assert result[2]['role'] == 'user' + assert result[0]['type'] == 'user_input' + assert result[1]['type'] == 'model_output' + assert result[2]['type'] == 'user_input' def test_empty_content_skipped(self): """Test that empty contents are skipped.""" @@ -584,13 +672,13 @@ def test_empty_content_skipped(self): types.Content(role='user', parts=[types.Part(text='Hi')]), types.Content(role='model', parts=[]), # Empty parts ] - result = interactions_utils.convert_contents_to_turns(contents) + result = interactions_utils._convert_contents_to_steps(contents) # Only the first content should be included assert len(result) == 1 class TestConvertToolsConfig: - """Tests for convert_tools_config_to_interactions_format.""" + """Tests for _convert_tools_config_to_interactions_format.""" def test_function_declaration(self): """Test converting function declarations.""" @@ -651,133 +739,184 @@ def test_no_tools(self): assert result == [] -class TestConvertInteractionOutputToPart: - """Tests for convert_interaction_output_to_part.""" +class TestConvertInteractionOutputToParts: + """Tests for convert_interaction_output_to_parts.""" def test_text_output(self): """Test converting text output.""" - output = MagicMock() - output.type = 'text' - output.text = 'Hello!' - result = interactions_utils.convert_interaction_output_to_part(output) + output = ModelOutputStep( + type='model_output', content=[TextContent(type='text', text='Hello!')] + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.text == 'Hello!' def test_function_call_output(self): """Test converting function call output.""" - output = MagicMock() - output.type = 'function_call' - output.id = 'call_123' - output.name = 'get_weather' - output.arguments = {'city': 'London'} - result = interactions_utils.convert_interaction_output_to_part(output) + output = FunctionCallStep( + type='function_call', + id='call_123', + name='get_weather', + arguments={'city': 'London'}, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.function_call.id == 'call_123' assert result.function_call.name == 'get_weather' assert result.function_call.args == {'city': 'London'} - def test_function_call_output_with_thought_signature(self): - """Test converting function call output with thought_signature.""" - output = MagicMock( - spec=['type', 'id', 'name', 'arguments', 'thought_signature'] - ) - output.type = 'function_call' - output.id = 'call_sig_123' - output.name = 'gemini3_tool' - output.arguments = {'content': 'hello'} - # thought_signature is base64 encoded in the output - output.thought_signature = base64.b64encode(b'gemini3_signature').decode( - 'utf-8' - ) - result = interactions_utils.convert_interaction_output_to_part(output) - assert result.function_call.id == 'call_sig_123' - assert result.function_call.name == 'gemini3_tool' - assert result.function_call.args == {'content': 'hello'} - # thought_signature should be decoded back to bytes - assert result.thought_signature == b'gemini3_signature' - def test_function_call_output_without_thought_signature(self): """Test converting function call output without thought_signature.""" - output = MagicMock(spec=['type', 'id', 'name', 'arguments']) - output.type = 'function_call' - output.id = 'call_no_sig' - output.name = 'regular_tool' - output.arguments = {} - result = interactions_utils.convert_interaction_output_to_part(output) + output = FunctionCallStep( + type='function_call', + id='call_no_sig', + name='regular_tool', + arguments={}, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.function_call.id == 'call_no_sig' assert result.function_call.name == 'regular_tool' # thought_signature should be None assert result.thought_signature is None - def test_function_result_output_with_items_list(self): - """Test converting function result output with items list. - - The implementation handles the case where result has an 'items' attribute - that returns a list-like structure. This test validates that path. - """ - output = MagicMock() - output.type = 'function_result' - output.call_id = 'call_123' - # Create a mock that has .items returning a dict (for FunctionResponse) - output.result = MagicMock() - output.result.items = {'weather': 'Sunny'} # items attribute returns dict - result = interactions_utils.convert_interaction_output_to_part(output) + def test_function_result_output(self): + """Test converting function result output.""" + output = FunctionResultStep( + type='function_result', + call_id='call_123', + result={'weather': 'Sunny'}, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.function_response.id == 'call_123' assert result.function_response.response == {'weather': 'Sunny'} + def test_function_result_output_preserves_none_values(self): + """None values in a dict result must not be dropped.""" + output = FunctionResultStep( + type='function_result', + call_id='call_none', + result={'data': None, 'ok': True}, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None + assert result.function_response.response == {'data': None, 'ok': True} + + def test_function_result_output_string(self): + """A plain string result is wrapped under a 'result' key.""" + output = FunctionResultStep( + type='function_result', + call_id='call_str', + result='plain text', + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None + assert result.function_response.response == {'result': 'plain text'} + + def test_function_result_output_list(self): + """A list result of content blocks is wrapped under a 'result' key.""" + output = FunctionResultStep( + type='function_result', + call_id='call_list', + result=[{'type': 'text', 'text': 'hi'}], + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None + wrapped = result.function_response.response['result'] + assert wrapped[0]['type'] == 'text' + assert wrapped[0]['text'] == 'hi' + def test_image_output_with_data(self): """Test converting image output with inline data.""" - output = MagicMock() - output.type = 'image' - output.data = b'image_bytes' - output.uri = None - output.mime_type = 'image/png' - result = interactions_utils.convert_interaction_output_to_part(output) + output = ModelOutputStep( + type='model_output', + content=[ + ImageContent( + type='image', + data=base64.b64encode(b'image_bytes').decode('utf-8'), + mime_type='image/png', + ) + ], + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.inline_data.data == b'image_bytes' assert result.inline_data.mime_type == 'image/png' def test_image_output_with_uri(self): """Test converting image output with URI.""" - output = MagicMock() - output.type = 'image' - output.data = None - output.uri = 'gs://bucket/image.png' - output.mime_type = 'image/png' - result = interactions_utils.convert_interaction_output_to_part(output) + output = ModelOutputStep( + type='model_output', + content=[ + ImageContent( + type='image', + uri='gs://bucket/image.png', + mime_type='image/png', + ) + ], + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.file_data.file_uri == 'gs://bucket/image.png' assert result.file_data.mime_type == 'image/png' def test_code_execution_result_output(self): """Test converting code execution result output.""" - output = MagicMock() - output.type = 'code_execution_result' - output.result = 'Output from code' - output.is_error = False # Indicate successful execution - result = interactions_utils.convert_interaction_output_to_part(output) + output = CodeExecutionResultStep( + type='code_execution_result', + call_id='', + result='Output from code', + is_error=False, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.code_execution_result.output == 'Output from code' assert result.code_execution_result.outcome == types.Outcome.OUTCOME_OK def test_code_execution_result_error_output(self): """Test converting code execution result output with error.""" - output = MagicMock() - output.type = 'code_execution_result' - output.result = 'Error: division by zero' - output.is_error = True # Indicate failed execution - result = interactions_utils.convert_interaction_output_to_part(output) + output = CodeExecutionResultStep( + type='code_execution_result', + call_id='', + result='Error: division by zero', + is_error=True, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None assert result.code_execution_result.output == 'Error: division by zero' assert result.code_execution_result.outcome == types.Outcome.OUTCOME_FAILED - def test_thought_output_returns_none(self): - """Test that thought output returns None (not exposed as Part).""" - output = MagicMock() - output.type = 'thought' - output.signature = 'thinking...' - result = interactions_utils.convert_interaction_output_to_part(output) - assert result is None + def test_thought_output_returns_empty(self): + """Test that thought output returns empty list (not exposed as Part).""" + output = ThoughtStep(type='thought', signature='thinking...') + result = interactions_utils._convert_interaction_step_to_parts(output) + assert result == [] def test_no_type_attribute(self): """Test handling output without type attribute.""" output = MagicMock(spec=[]) # No 'type' attribute - result = interactions_utils.convert_interaction_output_to_part(output) - assert result is None + result = interactions_utils._convert_interaction_step_to_parts(output) + assert result == [] + + def test_code_execution_call_output_uppercase_python(self): + """Test converting code execution call output with uppercase PYTHON.""" + from google.genai.interactions import CodeExecutionCallStep + + mock_args = MagicMock() + mock_args.code = 'print("hello")' + mock_args.language = 'PYTHON' + + output = CodeExecutionCallStep.model_construct( + type='code_execution_call', + id='', + arguments=mock_args, + ) + result_list = interactions_utils._convert_interaction_step_to_parts(output) + result = result_list[0] if result_list else None + assert result is not None + assert result.executable_code.code == 'print("hello")' + assert result.executable_code.language == types.Language.PYTHON class TestConvertInteractionToLlmResponse: @@ -785,18 +924,19 @@ class TestConvertInteractionToLlmResponse: def test_successful_text_response(self): """Test converting a successful text response.""" - interaction = MagicMock() - interaction.id = 'interaction_123' - interaction.status = 'completed' - text_output = MagicMock() - text_output.type = 'text' - text_output.text = 'The answer is 4.' - interaction.outputs = [text_output] - interaction.usage = MagicMock() - interaction.usage.total_input_tokens = 10 - interaction.usage.total_output_tokens = 5 - interaction.error = None - + interaction = Interaction( + id='interaction_123', + status='completed', + created=datetime.now(timezone.utc).isoformat(), + updated=datetime.now(timezone.utc).isoformat(), + steps=[ + ModelOutputStep( + type='model_output', + content=[TextContent(type='text', text='The answer is 4.')], + ) + ], + usage=Usage(total_input_tokens=10, total_output_tokens=5), + ) result = interactions_utils.convert_interaction_to_llm_response(interaction) assert result.interaction_id == 'interaction_123' @@ -808,13 +948,14 @@ def test_successful_text_response(self): def test_failed_response(self): """Test converting a failed response.""" - interaction = MagicMock() - interaction.id = 'interaction_123' - interaction.status = 'failed' - interaction.outputs = [] - interaction.error = MagicMock() - interaction.error.code = 'INVALID_REQUEST' - interaction.error.message = 'Bad request' + interaction = Interaction( + id='interaction_123', + status='failed', + created=datetime.now(timezone.utc).isoformat(), + updated=datetime.now(timezone.utc).isoformat(), + steps=[], + ) + interaction.error = MagicMock(code='INVALID_REQUEST', message='Bad request') result = interactions_utils.convert_interaction_to_llm_response(interaction) @@ -824,18 +965,20 @@ def test_failed_response(self): def test_requires_action_response(self): """Test converting a requires_action response (function call).""" - interaction = MagicMock() - interaction.id = 'interaction_123' - interaction.status = 'requires_action' - fc_output = MagicMock() - fc_output.type = 'function_call' - fc_output.id = 'call_1' - fc_output.name = 'get_weather' - fc_output.arguments = {'city': 'Paris'} - interaction.outputs = [fc_output] - interaction.usage = None - interaction.error = None - + interaction = Interaction( + id='interaction_123', + status='requires_action', + created=datetime.now(timezone.utc).isoformat(), + updated=datetime.now(timezone.utc).isoformat(), + steps=[ + FunctionCallStep( + type='function_call', + id='call_1', + name='get_weather', + arguments={'city': 'Paris'}, + ) + ], + ) result = interactions_utils.convert_interaction_to_llm_response(interaction) assert result.interaction_id == 'interaction_123' @@ -1030,12 +1173,11 @@ class TestConvertInteractionEventToLlmResponse: def test_text_delta_event(self): """Test converting a text delta event.""" - event = MagicMock() - event.event_type = 'content.delta' - event.delta = MagicMock() - event.delta.type = 'text' - event.delta.text = 'Hello world' - + event = StepDelta( + event_type='step.delta', + index=0, + delta={'type': 'text', 'text': 'Hello world'}, + ) aggregated_parts = [] result = interactions_utils.convert_interaction_event_to_llm_response( event, aggregated_parts, interaction_id='int_123' @@ -1047,111 +1189,172 @@ def test_text_delta_event(self): assert result.interaction_id == 'int_123' assert len(aggregated_parts) == 1 - def test_function_call_delta_with_thought_signature(self): - """Test converting a function call delta with thought_signature.""" - event = MagicMock() - event.event_type = 'content.delta' - event.delta = MagicMock( - spec=['type', 'id', 'name', 'arguments', 'thought_signature'] - ) - event.delta.type = 'function_call' - event.delta.id = 'fc_delta_123' - event.delta.name = 'streaming_tool' - event.delta.arguments = {'param': 'value'} - # thought_signature is base64 encoded in the delta - event.delta.thought_signature = base64.b64encode(b'delta_signature').decode( - 'utf-8' + def test_image_delta_with_data(self): + """Test converting an image delta with inline data.""" + event = StepDelta( + event_type='step.delta', + index=0, + delta={ + 'type': 'image', + 'data': base64.b64encode(b'image_bytes').decode('utf-8'), + 'mime_type': 'image/png', + }, ) - aggregated_parts = [] result = interactions_utils.convert_interaction_event_to_llm_response( - event, aggregated_parts, interaction_id='int_456' + event, aggregated_parts, interaction_id='int_img' ) - # Function calls return None (added to aggregated_parts only) - assert result is None + assert result is not None + assert result.partial + assert result.content.parts[0].inline_data.data == b'image_bytes' assert len(aggregated_parts) == 1 - fc_part = aggregated_parts[0] - assert fc_part.function_call.id == 'fc_delta_123' - assert fc_part.function_call.name == 'streaming_tool' - assert fc_part.function_call.args == {'param': 'value'} - # thought_signature should be decoded back to bytes - assert fc_part.thought_signature == b'delta_signature' - - def test_function_call_delta_without_thought_signature(self): - """Test converting a function call delta without thought_signature.""" + + def test_unknown_event_type_returns_none(self): + """Test that unknown event types return None.""" event = MagicMock() - event.event_type = 'content.delta' - event.delta = MagicMock(spec=['type', 'id', 'name', 'arguments']) - event.delta.type = 'function_call' - event.delta.id = 'fc_no_sig' - event.delta.name = 'regular_tool' - event.delta.arguments = {} + event.event_type = 'some_unknown_event' # Unknown event type aggregated_parts = [] result = interactions_utils.convert_interaction_event_to_llm_response( - event, aggregated_parts, interaction_id='int_789' + event, aggregated_parts, interaction_id='int_other' ) - # Function calls return None assert result is None - assert len(aggregated_parts) == 1 - fc_part = aggregated_parts[0] - assert fc_part.function_call.name == 'regular_tool' - # thought_signature should be None - assert fc_part.thought_signature is None - - def test_function_call_delta_without_name_skipped(self): - """Test that function call delta without name is skipped.""" - event = MagicMock() - event.event_type = 'content.delta' - event.delta = MagicMock(spec=['type', 'id', 'name', 'arguments']) - event.delta.type = 'function_call' - event.delta.id = 'fc_no_name' - event.delta.name = None # No name - event.delta.arguments = {} + assert not aggregated_parts - aggregated_parts = [] + def test_completed_event_failed_partial_interaction(self): + """A failed lifecycle event with a partial interaction does not crash.""" + event = InteractionCompletedEvent( + event_type='interaction.completed', + interaction=InteractionSseEventInteraction( + id='int_failed', + status='failed', + steps=[], + ), + ) result = interactions_utils.convert_interaction_event_to_llm_response( - event, aggregated_parts, interaction_id='int_000' + event, aggregated_parts=[], interaction_id='int_failed' + ) + assert result is not None + assert result.error_code == 'UNKNOWN_ERROR' + assert result.interaction_id == 'int_failed' + + def test_function_call_streaming_flow(self): + """Test the complete streaming flow for function calls (Start, Delta, Stop).""" + # 1. StepStart + start_event = StepStart( + event_type='step.start', + index=0, + step=FunctionCallStep( + type='function_call', + id='call_1', + name='get_weather', + arguments={}, + ), + ) + aggregated_parts: list[types.Part] = [] + result1 = interactions_utils.convert_interaction_event_to_llm_response( + start_event, aggregated_parts, interaction_id='int_123' ) - # Should be skipped (no name) - assert result is None - assert not aggregated_parts + assert result1 is not None + assert result1.partial is True + assert len(aggregated_parts) == 1 + fc = aggregated_parts[-1].function_call + assert fc + assert fc.name == 'get_weather' + assert fc.id == 'call_1' + assert fc.partial_args == [] + + # 2. StepDelta + delta_event1 = StepDelta( + event_type='step.delta', + index=0, + delta={'type': 'arguments_delta', 'arguments': '{"city": '}, + ) + result2 = interactions_utils.convert_interaction_event_to_llm_response( + delta_event1, aggregated_parts, interaction_id='int_123' + ) - def test_image_delta_with_data(self): - """Test converting an image delta with inline data.""" - event = MagicMock() - event.event_type = 'content.delta' - event.delta = MagicMock() - event.delta.type = 'image' - event.delta.data = b'image_bytes' - event.delta.uri = None - event.delta.mime_type = 'image/png' + assert result2 is not None + assert result2.partial is True + assert ( + result2.content.parts[0].function_call.partial_args[0].string_value + == '{"city": ' + ) - aggregated_parts = [] - result = interactions_utils.convert_interaction_event_to_llm_response( - event, aggregated_parts, interaction_id='int_img' + delta_event2 = StepDelta( + event_type='step.delta', + index=0, + delta={'type': 'arguments_delta', 'arguments': '"Paris"}'}, + ) + result3 = interactions_utils.convert_interaction_event_to_llm_response( + delta_event2, aggregated_parts, interaction_id='int_123' ) - assert result is not None - assert not result.partial - assert result.content.parts[0].inline_data.data == b'image_bytes' - assert len(aggregated_parts) == 1 + assert result3 is not None + assert len(aggregated_parts[0].function_call.partial_args) == 2 - def test_unknown_event_type_returns_none(self): - """Test that unknown event types return None.""" - event = MagicMock() - event.event_type = 'some_unknown_event' # Unknown event type + # 3. StepStop + stop_event = StepStop( + event_type='step.stop', + index=0, + ) + result4 = interactions_utils.convert_interaction_event_to_llm_response( + stop_event, aggregated_parts, interaction_id='int_123' + ) + assert result4 is None + assert aggregated_parts[0].function_call.args == {'city': 'Paris'} + assert aggregated_parts[0].function_call.partial_args is None + + def test_function_call_streaming_json_parse_error(self, caplog): + """Test function call streaming returns an error response on JSON parse error.""" + # 1. StepStart + start_event = StepStart( + event_type='step.start', + index=0, + step=FunctionCallStep( + type='function_call', + id='call_err', + name='bad_json_tool', + arguments={}, + ), + ) aggregated_parts = [] + interactions_utils.convert_interaction_event_to_llm_response( + start_event, aggregated_parts, interaction_id='int_err' + ) + + # 2. StepDelta (invalid JSON) + delta_event = StepDelta( + event_type='step.delta', + index=0, + delta={'type': 'arguments_delta', 'arguments': '{"broken": "json'}, + ) + interactions_utils.convert_interaction_event_to_llm_response( + delta_event, aggregated_parts, interaction_id='int_err' + ) + + # 3. StepStop + stop_event = StepStop( + event_type='step.stop', + index=0, + ) result = interactions_utils.convert_interaction_event_to_llm_response( - event, aggregated_parts, interaction_id='int_other' + stop_event, aggregated_parts, interaction_id='int_err' ) - assert result is None - assert not aggregated_parts + # Assert an error LlmResponse is returned + assert result is not None + assert result.error_code == 'JSON_PARSE_ERROR' + assert result.error_message == 'Failed to parse function call arguments' + assert result.turn_complete is True + assert result.interaction_id == 'int_err' + + # The logging check can remain to ensure the raw exception is still logged. + assert 'Failed to parse function call args' in caplog.text @pytest.mark.parametrize( @@ -1159,7 +1362,7 @@ def test_unknown_event_type_returns_none(self): [ pytest.param( _build_lifecycle_streamed_events, - ['interaction_123', 'interaction_123'], + ['interaction_123'], id='lifecycle-events', ), pytest.param( @@ -1175,11 +1378,12 @@ def test_unknown_event_type_returns_none(self): ], ) def test_generate_content_via_interactions_stream_extracts_interaction_id( - streamed_events_factory: Callable[[], list[object]], + streamed_events_factory: Callable[[FunctionCallStep], list[object]], expected_ids: list[str], + fc_step: FunctionCallStep, ): """Streamed interaction IDs should be preserved across event variants.""" - streamed_events = streamed_events_factory() + streamed_events = streamed_events_factory(fc_step) assert ( asyncio.run(_collect_function_call_interaction_ids(streamed_events)) diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index d77b13e5385..409243a09e7 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -34,36 +34,36 @@ def test_streaming(): mock_model = testing_utils.MockModel.create([response1]) root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[], ) runner = testing_utils.InMemoryRunner( - root_agent=root_agent, response_modalities=['AUDIO'] + root_agent=root_agent, response_modalities=["AUDIO"] ) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm') + blob=types.Blob(data=b"\x00\xFF", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' + assert res_events is not None, "Expected a list of events, got None." assert ( len(res_events) > 0 - ), 'Expected at least one response, but got an empty list.' + ), "Expected at least one response, but got an empty list." def test_live_streaming_function_call_single(): """Test live streaming with a single function call response.""" # Create a function call response function_call = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco', 'unit': 'celsius'} + name="get_weather", args={"location": "San Francisco", "unit": "celsius"} ) # Create LLM responses: function call followed by turn completion response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -73,16 +73,16 @@ def test_live_streaming_function_call_single(): mock_model = testing_utils.MockModel.create([response1, response2]) # Mock function that would be called - def get_weather(location: str, unit: str = 'celsius') -> dict: + def get_weather(location: str, unit: str = "celsius") -> dict: return { - 'temperature': 22, - 'condition': 'sunny', - 'location': location, - 'unit': unit, + "temperature": 22, + "condition": "sunny", + "location": location, + "unit": unit, } root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -136,14 +136,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'What is the weather in San Francisco?', mime_type='audio/pcm' + data=b"What is the weather in San Francisco?", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got a function call event function_call_found = False @@ -152,19 +152,19 @@ async def consume_responses(session: testing_utils.Session): for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': + if part.function_call and part.function_call.name == "get_weather": function_call_found = True - assert part.function_call.args['location'] == 'San Francisco' - assert part.function_call.args['unit'] == 'celsius' + assert part.function_call.args["location"] == "San Francisco" + assert part.function_call.args["unit"] == "celsius" elif ( part.function_response - and part.function_response.name == 'get_weather' + and part.function_response.name == "get_weather" ): function_response_found = True - assert part.function_response.response['temperature'] == 22 - assert part.function_response.response['condition'] == 'sunny' + assert part.function_response.response["temperature"] == 22 + assert part.function_response.response["condition"] == "sunny" - assert function_call_found, 'Expected a function call event.' + assert function_call_found, "Expected a function call event." # Note: In live streaming, function responses might be handled differently, # so we check for the function call which is the primary indicator of function calling working @@ -173,19 +173,19 @@ def test_live_streaming_function_call_multiple(): """Test live streaming with multiple function calls in sequence.""" # Create multiple function call responses function_call1 = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco'} + name="get_weather", args={"location": "San Francisco"} ) function_call2 = types.Part.from_function_call( - name='get_time', args={'timezone': 'PST'} + name="get_time", args={"timezone": "PST"} ) # Create LLM responses: two function calls followed by turn completion response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call1]), + content=types.Content(role="model", parts=[function_call1]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[function_call2]), + content=types.Content(role="model", parts=[function_call2]), turn_complete=False, ) response3 = LlmResponse( @@ -196,13 +196,13 @@ def test_live_streaming_function_call_multiple(): # Mock functions def get_weather(location: str) -> dict: - return {'temperature': 22, 'condition': 'sunny', 'location': location} + return {"temperature": 22, "condition": "sunny", "location": location} def get_time(timezone: str) -> dict: - return {'time': '14:30', 'timezone': timezone} + return {"time": "14:30", "timezone": timezone} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather, get_time], ) @@ -255,14 +255,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'What is the weather and current time?', mime_type='audio/pcm' + data=b"What is the weather and current time?", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check function calls weather_call_found = False @@ -272,33 +272,33 @@ async def consume_responses(session: testing_utils.Session): if event.content and event.content.parts: for part in event.content.parts: if part.function_call: - if part.function_call.name == 'get_weather': + if part.function_call.name == "get_weather": weather_call_found = True - assert part.function_call.args['location'] == 'San Francisco' - elif part.function_call.name == 'get_time': + assert part.function_call.args["location"] == "San Francisco" + elif part.function_call.name == "get_time": time_call_found = True - assert part.function_call.args['timezone'] == 'PST' + assert part.function_call.args["timezone"] == "PST" # In live streaming, we primarily check that function calls are generated correctly assert ( weather_call_found or time_call_found - ), 'Expected at least one function call.' + ), "Expected at least one function call." def test_live_streaming_function_call_parallel(): """Test live streaming with parallel function calls.""" # Create parallel function calls in the same response function_call1 = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco'} + name="get_weather", args={"location": "San Francisco"} ) function_call2 = types.Part.from_function_call( - name='get_weather', args={'location': 'New York'} + name="get_weather", args={"location": "New York"} ) # Create LLM response with parallel function calls response1 = LlmResponse( content=types.Content( - role='model', parts=[function_call1, function_call2] + role="model", parts=[function_call1, function_call2] ), turn_complete=False, ) @@ -310,11 +310,11 @@ def test_live_streaming_function_call_parallel(): # Mock function def get_weather(location: str) -> dict: - temperatures = {'San Francisco': 22, 'New York': 15} - return {'temperature': temperatures.get(location, 20), 'location': location} + temperatures = {"San Francisco": 22, "New York": 15} + return {"temperature": temperatures.get(location, 20), "location": location} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -367,14 +367,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'Compare weather in SF and NYC', mime_type='audio/pcm' + data=b"Compare weather in SF and NYC", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check parallel function calls sf_call_found = False @@ -383,28 +383,28 @@ async def consume_responses(session: testing_utils.Session): for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': - location = part.function_call.args['location'] - if location == 'San Francisco': + if part.function_call and part.function_call.name == "get_weather": + location = part.function_call.args["location"] + if location == "San Francisco": sf_call_found = True - elif location == 'New York': + elif location == "New York": nyc_call_found = True assert ( sf_call_found and nyc_call_found - ), 'Expected both location function calls.' + ), "Expected both location function calls." def test_live_streaming_function_call_with_error(): """Test live streaming with function call that returns an error.""" # Create a function call response function_call = types.Part.from_function_call( - name='get_weather', args={'location': 'Invalid Location'} + name="get_weather", args={"location": "Invalid Location"} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -415,12 +415,12 @@ def test_live_streaming_function_call_with_error(): # Mock function that returns an error for invalid locations def get_weather(location: str) -> dict: - if location == 'Invalid Location': - return {'error': 'Location not found'} - return {'temperature': 22, 'condition': 'sunny', 'location': location} + if location == "Invalid Location": + return {"error": "Location not found"} + return {"temperature": 22, "condition": "sunny", "location": location} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -473,37 +473,37 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'What is weather in Invalid Location?', mime_type='audio/pcm' + data=b"What is weather in Invalid Location?", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got the function call (error handling happens at execution time) function_call_found = False for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': + if part.function_call and part.function_call.name == "get_weather": function_call_found = True - assert part.function_call.args['location'] == 'Invalid Location' + assert part.function_call.args["location"] == "Invalid Location" - assert function_call_found, 'Expected function call event with error case.' + assert function_call_found, "Expected function call event with error case." def test_live_streaming_function_call_sync_tool(): """Test live streaming with synchronous function call.""" # Create a function call response function_call = types.Part.from_function_call( - name='calculate', args={'x': 5, 'y': 3} + name="calculate", args={"x": 5, "y": 3} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -514,10 +514,10 @@ def test_live_streaming_function_call_sync_tool(): # Mock sync function def calculate(x: int, y: int) -> dict: - return {'result': x + y, 'operation': 'addition'} + return {"result": x + y, "operation": "addition"} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[calculate], ) @@ -569,37 +569,37 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Calculate 5 plus 3', mime_type='audio/pcm') + blob=types.Blob(data=b"Calculate 5 plus 3", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check function call function_call_found = False for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'calculate': + if part.function_call and part.function_call.name == "calculate": function_call_found = True - assert part.function_call.args['x'] == 5 - assert part.function_call.args['y'] == 3 + assert part.function_call.args["x"] == 5 + assert part.function_call.args["y"] == 3 - assert function_call_found, 'Expected calculate function call event.' + assert function_call_found, "Expected calculate function call event." def test_live_streaming_simple_streaming_tool(): """Test live streaming with a simple streaming tool (non-video).""" # Create a function call response for the streaming tool function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'AAPL'} + name="monitor_stock_price", args={"stock_symbol": "AAPL"} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -612,18 +612,18 @@ def test_live_streaming_simple_streaming_tool(): async def monitor_stock_price(stock_symbol: str): """Mock streaming tool that monitors stock prices.""" # Simulate some streaming updates - yield f'Stock {stock_symbol} price: $150' + yield f"Stock {stock_symbol} price: $150" await asyncio.sleep(0.1) - yield f'Stock {stock_symbol} price: $155' + yield f"Stock {stock_symbol} price: $155" await asyncio.sleep(0.1) - yield f'Stock {stock_symbol} price: $160' + yield f"Stock {stock_symbol} price: $160" def stop_streaming(function_name: str): """Stop the streaming tool.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, stop_streaming], ) @@ -675,13 +675,13 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor AAPL stock price', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor AAPL stock price", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got the streaming tool function call function_call_found = False @@ -690,26 +690,26 @@ async def consume_responses(session: testing_utils.Session): for part in event.content.parts: if ( part.function_call - and part.function_call.name == 'monitor_stock_price' + and part.function_call.name == "monitor_stock_price" ): function_call_found = True - assert part.function_call.args['stock_symbol'] == 'AAPL' + assert part.function_call.args["stock_symbol"] == "AAPL" assert ( function_call_found - ), 'Expected monitor_stock_price function call event.' + ), "Expected monitor_stock_price function call event." def test_live_streaming_video_streaming_tool(): """Test live streaming with a video streaming tool.""" # Create a function call response for the video streaming tool function_call = types.Part.from_function_call( - name='monitor_video_stream', args={} + name="monitor_video_stream", args={} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -727,13 +727,13 @@ async def monitor_video_stream(input_stream: LiveRequestQueue): try: # Try to get a frame from the queue with timeout live_req = await asyncio.wait_for(input_stream.get(), timeout=0.1) - if live_req.blob and live_req.blob.mime_type == 'image/jpeg': + if live_req.blob and live_req.blob.mime_type == "image/jpeg": frame_count += 1 - yield f'Processed frame {frame_count}: detected 2 people' + yield f"Processed frame {frame_count}: detected 2 people" except asyncio.TimeoutError: # No more frames, simulate detection anyway for testing frame_count += 1 - yield f'Simulated frame {frame_count}: detected 1 person' + yield f"Simulated frame {frame_count}: detected 1 person" await asyncio.sleep(0.1) def stop_streaming(function_name: str): @@ -741,7 +741,7 @@ def stop_streaming(function_name: str): pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_video_stream, stop_streaming], ) @@ -795,19 +795,19 @@ async def consume_responses(session: testing_utils.Session): # Send some mock video frames live_request_queue.send_realtime( - blob=types.Blob(data=b'fake_jpeg_data_1', mime_type='image/jpeg') + blob=types.Blob(data=b"fake_jpeg_data_1", mime_type="image/jpeg") ) live_request_queue.send_realtime( - blob=types.Blob(data=b'fake_jpeg_data_2', mime_type='image/jpeg') + blob=types.Blob(data=b"fake_jpeg_data_2", mime_type="image/jpeg") ) live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor video stream', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor video stream", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got the video streaming tool function call function_call_found = False @@ -816,32 +816,32 @@ async def consume_responses(session: testing_utils.Session): for part in event.content.parts: if ( part.function_call - and part.function_call.name == 'monitor_video_stream' + and part.function_call.name == "monitor_video_stream" ): function_call_found = True assert ( function_call_found - ), 'Expected monitor_video_stream function call event.' + ), "Expected monitor_video_stream function call event." def test_live_streaming_stop_streaming_tool(): """Test live streaming with stop_streaming functionality.""" # Create function calls for starting and stopping a streaming tool start_function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'TSLA'} + name="monitor_stock_price", args={"stock_symbol": "TSLA"} ) stop_function_call = types.Part.from_function_call( - name='stop_streaming', args={'function_name': 'monitor_stock_price'} + name="stop_streaming", args={"function_name": "monitor_stock_price"} ) # Create LLM responses: start streaming, then stop streaming response1 = LlmResponse( - content=types.Content(role='model', parts=[start_function_call]), + content=types.Content(role="model", parts=[start_function_call]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[stop_function_call]), + content=types.Content(role="model", parts=[stop_function_call]), turn_complete=False, ) response3 = LlmResponse( @@ -853,17 +853,17 @@ def test_live_streaming_stop_streaming_tool(): # Mock streaming tool and stop function async def monitor_stock_price(stock_symbol: str): """Mock streaming tool that monitors stock prices.""" - yield f'Started monitoring {stock_symbol}' + yield f"Started monitoring {stock_symbol}" while True: # Infinite stream (would be stopped by stop_streaming) - yield f'Stock {stock_symbol} price update' + yield f"Stock {stock_symbol} price update" await asyncio.sleep(0.1) def stop_streaming(function_name: str): """Stop the streaming tool.""" - return f'Stopped streaming for {function_name}' + return f"Stopped streaming for {function_name}" root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, stop_streaming], ) @@ -915,13 +915,13 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor TSLA and then stop', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor TSLA and then stop", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got both function calls monitor_call_found = False @@ -931,34 +931,34 @@ async def consume_responses(session: testing_utils.Session): if event.content and event.content.parts: for part in event.content.parts: if part.function_call: - if part.function_call.name == 'monitor_stock_price': + if part.function_call.name == "monitor_stock_price": monitor_call_found = True - assert part.function_call.args['stock_symbol'] == 'TSLA' - elif part.function_call.name == 'stop_streaming': + assert part.function_call.args["stock_symbol"] == "TSLA" + elif part.function_call.name == "stop_streaming": stop_call_found = True assert ( - part.function_call.args['function_name'] - == 'monitor_stock_price' + part.function_call.args["function_name"] + == "monitor_stock_price" ) - assert monitor_call_found, 'Expected monitor_stock_price function call event.' - assert stop_call_found, 'Expected stop_streaming function call event.' + assert monitor_call_found, "Expected monitor_stock_price function call event." + assert stop_call_found, "Expected stop_streaming function call event." def test_live_streaming_multiple_streaming_tools(): """Test live streaming with multiple streaming tools running simultaneously.""" # Create function calls for multiple streaming tools stock_function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'NVDA'} + name="monitor_stock_price", args={"stock_symbol": "NVDA"} ) video_function_call = types.Part.from_function_call( - name='monitor_video_stream', args={} + name="monitor_video_stream", args={} ) # Create LLM responses: start both streaming tools response1 = LlmResponse( content=types.Content( - role='model', parts=[stock_function_call, video_function_call] + role="model", parts=[stock_function_call, video_function_call] ), turn_complete=False, ) @@ -971,22 +971,22 @@ def test_live_streaming_multiple_streaming_tools(): # Mock streaming tools async def monitor_stock_price(stock_symbol: str): """Mock streaming tool that monitors stock prices.""" - yield f'Stock {stock_symbol} price: $800' + yield f"Stock {stock_symbol} price: $800" await asyncio.sleep(0.1) - yield f'Stock {stock_symbol} price: $805' + yield f"Stock {stock_symbol} price: $805" async def monitor_video_stream(input_stream: LiveRequestQueue): """Mock video streaming tool.""" - yield 'Video monitoring started' + yield "Video monitoring started" await asyncio.sleep(0.1) - yield 'Detected motion in video stream' + yield "Detected motion in video stream" def stop_streaming(function_name: str): """Stop the streaming tool.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, monitor_video_stream, stop_streaming], ) @@ -1039,14 +1039,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'Monitor both stock and video', mime_type='audio/pcm' + data=b"Monitor both stock and video", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got both streaming tool function calls stock_call_found = False @@ -1056,39 +1056,38 @@ async def consume_responses(session: testing_utils.Session): if event.content and event.content.parts: for part in event.content.parts: if part.function_call: - if part.function_call.name == 'monitor_stock_price': + if part.function_call.name == "monitor_stock_price": stock_call_found = True - assert part.function_call.args['stock_symbol'] == 'NVDA' - elif part.function_call.name == 'monitor_video_stream': + assert part.function_call.args["stock_symbol"] == "NVDA" + elif part.function_call.name == "monitor_video_stream": video_call_found = True - assert stock_call_found, 'Expected monitor_stock_price function call event.' - assert video_call_found, 'Expected monitor_video_stream function call event.' + assert stock_call_found, "Expected monitor_stock_price function call event." + assert video_call_found, "Expected monitor_video_stream function call event." -def test_live_streaming_buffered_function_call_yielded_during_transcription(): - """Test that function calls buffered during transcription are yielded. +def test_live_streaming_function_call_yielded_before_finished_transcription(): + """Test that function calls arriving during live transcription are yielded immediately. - This tests the fix for the bug where function_call and function_response - events were buffered during active transcription but never yielded to the - caller. The fix ensures buffered events are yielded after transcription ends. + This verifies that tool call events are not buffered and are permitted to + arrive in the stream before the final completed transcription event. """ function_call = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco'} + name="get_weather", args={"location": "San Francisco"} ) response1 = LlmResponse( - input_transcription=types.Transcription(text='Show'), + input_transcription=types.Transcription(text="Show"), partial=True, # ← Triggers is_transcribing = True ) response2 = LlmResponse( content=types.Content( - role='model', parts=[function_call] + role="model", parts=[function_call] ), # ← Gets buffered turn_complete=False, ) response3 = LlmResponse( - input_transcription=types.Transcription(text='Show me the weather'), + input_transcription=types.Transcription(text="Show me the weather"), partial=False, # ← Transcription ends, buffered events yielded ) response4 = LlmResponse( @@ -1100,10 +1099,10 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription(): ) def get_weather(location: str) -> dict: - return {'temperature': 22, 'location': location} + return {"temperature": 22, "location": location} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -1154,41 +1153,50 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Show me the weather', mime_type='audio/pcm') + blob=types.Blob(data=b"Show me the weather", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." - function_call_found = False - function_response_found = False + function_call_index = -1 + finished_transcription_index = -1 - for event in res_events: + for idx, event in enumerate(res_events): if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': - function_call_found = True - assert part.function_call.args['location'] == 'San Francisco' + if part.function_call and part.function_call.name == "get_weather": + function_call_index = idx + assert part.function_call.args["location"] == "San Francisco" if ( part.function_response - and part.function_response.name == 'get_weather' + and part.function_response.name == "get_weather" ): - function_response_found = True - assert part.function_response.response['temperature'] == 22 + assert part.function_response.response["temperature"] == 22 + if ( + event.input_transcription + and event.input_transcription.text == "Show me the weather" + ): + finished_transcription_index = idx - assert function_call_found, 'Buffered function_call event was not yielded.' + assert function_call_index != -1, "Function call event was not yielded." assert ( - function_response_found - ), 'Buffered function_response event was not yielded.' + finished_transcription_index != -1 + ), "Finished transcription event was not yielded." + assert function_call_index < finished_transcription_index, ( + f"Expected function call (at index {function_call_index}) to arrive" + " before finished transcription (at index" + f" {finished_transcription_index})." + ) def test_live_streaming_text_content_persisted_in_session(): """Test that user text content sent via send_content is persisted in session.""" response1 = LlmResponse( content=types.Content( - role='model', parts=[types.Part(text='Hello! How can I help you?')] + role="model", parts=[types.Part(text="Hello! How can I help you?")] ), turn_complete=True, ) @@ -1196,7 +1204,7 @@ def test_live_streaming_text_content_persisted_in_session(): mock_model = testing_utils.MockModel.create([response1]) root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[], ) @@ -1253,19 +1261,19 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() # Send text content (not audio blob) - user_text = 'Hello, this is a test message' + user_text = "Hello, this is a test message" live_request_queue.send_content( - types.Content(role='user', parts=[types.Part(text=user_text)]) + types.Content(role="user", parts=[types.Part(text=user_text)]) ) res_events, session = runner.run_live_and_get_session(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' + assert res_events is not None, "Expected a list of events, got None." # Check that user text content was persisted in the session user_content_found = False for event in session.events: - if event.author == 'user' and event.content: + if event.author == "user" and event.content: for part in event.content.parts: if part.text and user_text in part.text: user_content_found = True @@ -1273,7 +1281,7 @@ async def consume_responses(session: testing_utils.Session): assert user_content_found, ( f'Expected user text content "{user_text}" to be persisted in session. ' - f'Session events: {[e.content for e in session.events]}' + f"Session events: {[e.content for e in session.events]}" ) @@ -1328,16 +1336,16 @@ def test_input_streaming_tool_registered_lazily_with_stream(): # tool is NOT registered before the model calls it. text_response = LlmResponse( content=types.Content( - role='model', - parts=[types.Part(text='Processing...')], + role="model", + parts=[types.Part(text="Processing...")], ), turn_complete=False, ) function_call = types.Part.from_function_call( - name='monitor_video_stream', args={} + name="monitor_video_stream", args={} ) call_response = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) done_response = LlmResponse(turn_complete=True) @@ -1354,10 +1362,10 @@ async def monitor_video_stream( """Record whether input_stream was provided.""" nonlocal stream_state_during_call stream_state_during_call = input_stream is not None - yield 'monitoring started' + yield "monitoring started" root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_video_stream], ) @@ -1378,7 +1386,7 @@ def capturing_method(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test_data', mime_type='audio/pcm') + blob=types.Blob(data=b"test_data", mime_type="audio/pcm") ) # Collect events and check that the tool is NOT registered before @@ -1403,7 +1411,7 @@ async def consume(session: testing_utils.Session): and not response.get_function_calls() ): not_registered_before_call = ( - active is None or 'monitor_video_stream' not in active + active is None or "monitor_video_stream" not in active ) if len(collected) >= 4: return @@ -1413,28 +1421,28 @@ async def consume(session: testing_utils.Session): # Tool should not be registered before the model calls it. assert ( not_registered_before_call is True - ), 'Expected tool to NOT be registered before the model calls it' + ), "Expected tool to NOT be registered before the model calls it" # When the model calls the tool, input_stream should be provided. assert ( stream_state_during_call is True - ), 'Expected input_stream to be provided to the streaming tool when called' + ), "Expected input_stream to be provided to the streaming tool when called" def test_stop_streaming_resets_stream_to_none(): """Test that stop_streaming sets stream back to None.""" start_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'GOOG'} + name="monitor_stock_price", args={"stock_symbol": "GOOG"} ) stop_call = types.Part.from_function_call( - name='stop_streaming', args={'function_name': 'monitor_stock_price'} + name="stop_streaming", args={"function_name": "monitor_stock_price"} ) response1 = LlmResponse( - content=types.Content(role='model', parts=[start_call]), + content=types.Content(role="model", parts=[start_call]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[stop_call]), + content=types.Content(role="model", parts=[stop_call]), turn_complete=False, ) response3 = LlmResponse(turn_complete=True) @@ -1445,17 +1453,17 @@ async def monitor_stock_price( stock_symbol: str, ) -> AsyncGenerator[str, None]: """Yield periodic price updates for the given stock symbol.""" - yield f'Monitoring {stock_symbol}' + yield f"Monitoring {stock_symbol}" while True: await asyncio.sleep(0.1) - yield f'{stock_symbol} price update' + yield f"{stock_symbol} price update" def stop_streaming(function_name: str) -> None: """Stop a running streaming tool by name.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, stop_streaming], ) @@ -1479,7 +1487,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor GOOG then stop', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor GOOG then stop", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue, max_responses=4) @@ -1487,32 +1495,32 @@ def capturing_create(*args, **kwargs) -> Any: # Verify both function calls were processed. call_names = _collect_function_call_names(res_events) assert ( - 'monitor_stock_price' in call_names - ), 'Expected monitor_stock_price function call.' + "monitor_stock_price" in call_names + ), "Expected monitor_stock_price function call." assert ( - 'stop_streaming' in call_names - ), 'Expected stop_streaming function call.' + "stop_streaming" in call_names + ), "Expected stop_streaming function call." # Verify that stop_streaming reset the stream to None. assert ( captured_child_context is not None - ), 'Expected child invocation context to be captured' + ), "Expected child invocation context to be captured" active_tools = captured_child_context.active_streaming_tools or {} assert ( - 'monitor_stock_price' in active_tools - ), 'Expected monitor_stock_price in active_streaming_tools' + "monitor_stock_price" in active_tools + ), "Expected monitor_stock_price in active_streaming_tools" assert ( - active_tools['monitor_stock_price'].stream is None - ), 'Expected stream to be reset to None after stop_streaming' + active_tools["monitor_stock_price"].stream is None + ), "Expected stream to be reset to None after stop_streaming" def test_output_streaming_tool_registered_lazily_without_stream(): """Test that output-streaming tools are registered lazily when called, with stream=None.""" function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'GOOG'} + name="monitor_stock_price", args={"stock_symbol": "GOOG"} ) response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse(turn_complete=True) @@ -1523,10 +1531,10 @@ async def monitor_stock_price( stock_symbol: str, ) -> AsyncGenerator[str, None]: """Yield periodic price updates.""" - yield f'price for {stock_symbol}' + yield f"price for {stock_symbol}" root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price], ) @@ -1548,7 +1556,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test', mime_type='audio/pcm') + blob=types.Blob(data=b"test", mime_type="audio/pcm") ) runner.run_live(live_request_queue, max_responses=3) @@ -1558,11 +1566,11 @@ def capturing_create(*args, **kwargs) -> Any: assert captured_child_context is not None active_tools = captured_child_context.active_streaming_tools or {} assert ( - 'monitor_stock_price' in active_tools - ), 'Expected output-streaming tool to be registered when called' + "monitor_stock_price" in active_tools + ), "Expected output-streaming tool to be registered when called" assert ( - active_tools['monitor_stock_price'].stream is None - ), 'Expected stream to be None for output-streaming tool' + active_tools["monitor_stock_price"].stream is None + ), "Expected stream to be None for output-streaming tool" def _run_single_tool_live( @@ -1581,7 +1589,7 @@ def _run_single_tool_live( name=func_name, args=func_args or {} ) response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse(turn_complete=True) @@ -1589,7 +1597,7 @@ def _run_single_tool_live( mock_model = testing_utils.MockModel.create([response1, response2]) root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[tool_func], ) @@ -1609,7 +1617,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test', mime_type='audio/pcm') + blob=types.Blob(data=b"test", mime_type="audio/pcm") ) runner.run_live(live_request_queue, max_responses=max_responses) @@ -1625,42 +1633,42 @@ async def monitor_video_stream( input_stream: LiveRequestQueue, ) -> AsyncGenerator[str, None]: """Simulate an input-streaming tool.""" - yield 'started' + yield "started" active_tools = _run_single_tool_live( - monitor_video_stream, 'monitor_video_stream' + monitor_video_stream, "monitor_video_stream" ) assert ( - 'monitor_video_stream' in active_tools - ), 'Expected input-streaming tool to be registered when called' + "monitor_video_stream" in active_tools + ), "Expected input-streaming tool to be registered when called" # Stream should be a LiveRequestQueue, not None. assert ( - active_tools['monitor_video_stream'].stream is not None - ), 'Expected .stream to be set for input-streaming tool' + active_tools["monitor_video_stream"].stream is not None + ), "Expected .stream to be set for input-streaming tool" assert isinstance( - active_tools['monitor_video_stream'].stream, LiveRequestQueue - ), 'Expected .stream to be a LiveRequestQueue instance' + active_tools["monitor_video_stream"].stream, LiveRequestQueue + ), "Expected .stream to be a LiveRequestQueue instance" def test_input_streaming_tool_stream_recreated_after_stop(): """Test that re-invoking an input-streaming tool after stop creates a new stream.""" - start_call = types.Part.from_function_call(name='monitor_video', args={}) + start_call = types.Part.from_function_call(name="monitor_video", args={}) stop_call = types.Part.from_function_call( - name='stop_streaming', args={'function_name': 'monitor_video'} + name="stop_streaming", args={"function_name": "monitor_video"} ) - restart_call = types.Part.from_function_call(name='monitor_video', args={}) + restart_call = types.Part.from_function_call(name="monitor_video", args={}) response1 = LlmResponse( - content=types.Content(role='model', parts=[start_call]), + content=types.Content(role="model", parts=[start_call]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[stop_call]), + content=types.Content(role="model", parts=[stop_call]), turn_complete=False, ) response3 = LlmResponse( - content=types.Content(role='model', parts=[restart_call]), + content=types.Content(role="model", parts=[restart_call]), turn_complete=False, ) response4 = LlmResponse(turn_complete=True) @@ -1677,17 +1685,17 @@ async def monitor_video( """Simulate an input-streaming tool that tracks invocation count.""" nonlocal call_count call_count += 1 - yield f'started (call {call_count})' + yield f"started (call {call_count})" while True: await asyncio.sleep(0.1) - yield 'frame' + yield "frame" def stop_streaming(function_name: str) -> None: """Stop a running streaming tool by name.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_video, stop_streaming], ) @@ -1707,7 +1715,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test', mime_type='audio/pcm') + blob=types.Blob(data=b"test", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue, max_responses=8) @@ -1719,16 +1727,16 @@ def capturing_create(*args, **kwargs) -> Any: fc.name for event in res_events for fc in event.get_function_calls() ] assert ( - call_names.count('monitor_video') >= 2 - ), f'Expected monitor_video called at least twice, got: {call_names}' + call_names.count("monitor_video") >= 2 + ), f"Expected monitor_video called at least twice, got: {call_names}" # After re-invocation, stream should be set again (not None). assert captured_child_context is not None active_tools = captured_child_context.active_streaming_tools or {} - assert 'monitor_video' in active_tools + assert "monitor_video" in active_tools assert ( - active_tools['monitor_video'].stream is not None - ), 'Expected .stream to be recreated after stop + re-invocation' + active_tools["monitor_video"].stream is not None + ), "Expected .stream to be recreated after stop + re-invocation" def test_async_gen_with_input_stream_wrong_annotation_gets_no_stream(): @@ -1739,22 +1747,22 @@ async def my_tool(input_stream: str) -> AsyncGenerator[str, None]: """Simulate an async generator whose input_stream is typed as str.""" nonlocal received_input_stream received_input_stream = input_stream - yield f'got: {input_stream}' + yield f"got: {input_stream}" active_tools = _run_single_tool_live( - my_tool, 'my_tool', func_args={'input_stream': 'some_value'} + my_tool, "my_tool", func_args={"input_stream": "some_value"} ) assert ( - 'my_tool' in active_tools - ), 'Expected async generator tool to be registered' + "my_tool" in active_tools + ), "Expected async generator tool to be registered" # Stream should be None because annotation is str, not LiveRequestQueue. - assert active_tools['my_tool'].stream is None, ( - 'Expected .stream to be None when input_stream annotation is not' - ' LiveRequestQueue' + assert active_tools["my_tool"].stream is None, ( + "Expected .stream to be None when input_stream annotation is not" + " LiveRequestQueue" ) # The tool should have received the model-provided arg value, not a # LiveRequestQueue. assert ( - received_input_stream == 'some_value' - ), 'Expected input_stream to be the model-provided string value' + received_input_stream == "some_value" + ), "Expected input_stream to be the model-provided string value" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index a94b2eb8852..f7e16014ffc 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -588,6 +588,9 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( self, ): """Verify that sessions from different loops are cleaned up without calling aclose().""" + from google.adk.features import FeatureName + from google.adk.features._feature_registry import temporary_feature_override + manager = MCPSessionManager(self.mock_stdio_connection_params) # 1. Simulate a session created in a "different" loop @@ -617,8 +620,11 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( mock_wait_for.return_value = new_session mock_session_context_class.return_value = AsyncMock() - # 3. Call create_session - session = await manager.create_session() + # 3. Call create_session with flag off to hit wait_for branch + with temporary_feature_override( + FeatureName._MCP_GRACEFUL_ERROR_HANDLING, False + ): + session = await manager.create_session() # 4. Verify results assert session == new_session @@ -969,8 +975,8 @@ class TestMCPGracefulErrorHandlingFlagContract: loudly so we don't silently break GE's rollout. """ - def test_default_state_is_off_so_cl_is_a_noop(self): - """The CL must be a no-op until GE explicitly enables it.""" + def test_default_state_is_on(self): + """The fix must be enabled by default.""" import os from google.adk.features import FeatureName @@ -981,34 +987,34 @@ def test_default_state_is_off_so_cl_is_a_noop(self): saved = {k: os.environ.pop(k) for k in (enable, disable) if k in os.environ} try: assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True ) finally: os.environ.update(saved) - def test_env_var_enable_flips_flag_on_at_runtime(self): - """The env var GE will set must turn the fix on without a rebuild.""" + def test_env_var_disable_flips_flag_off_at_runtime(self): + """The env var must turn the fix off without a rebuild.""" import os from google.adk.features import FeatureName from google.adk.features import is_feature_enabled - enable = "ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING" - saved = os.environ.pop(enable, None) + disable = "ADK_DISABLE_MCP_GRACEFUL_ERROR_HANDLING" + saved = os.environ.pop(disable, None) try: - os.environ[enable] = "1" + os.environ[disable] = "1" assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False ) # And once it's removed, we revert. Confirms the value is read # live from os.environ on every call (no caching, no binary push). - del os.environ[enable] + del os.environ[disable] assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True ) finally: if saved is not None: - os.environ[enable] = saved + os.environ[disable] = saved def test_env_var_disable_acts_as_kill_switch(self): """The disable env var lets consumers turn off without a rebuild.""" diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index bb2654c3db2..f962fb6f5c2 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -16,6 +16,7 @@ from google.adk.utils.model_name_utils import extract_model_name from google.adk.utils.model_name_utils import is_gemini_1_model +from google.adk.utils.model_name_utils import is_gemini_3_1_flash_live from google.adk.utils.model_name_utils import is_gemini_eap_or_2_or_above from google.adk.utils.model_name_utils import is_gemini_model from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled @@ -338,3 +339,28 @@ def test_default_is_disabled(self, monkeypatch): def test_true_enables_check_bypass(self, monkeypatch): monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') assert is_gemini_model_id_check_disabled() is True + + +class TestIsGemini31FlashLive: + """Test the is_gemini_3_1_flash_live function.""" + + def test_is_gemini_3_1_flash_live_simple_name(self): + """Test with simple model name format.""" + assert is_gemini_3_1_flash_live('gemini-3.1-flash-live') is True + assert is_gemini_3_1_flash_live('gemini-3.1-flash-live-preview') is True + assert is_gemini_3_1_flash_live('gemini-3.1-pro-live') is False + assert is_gemini_3_1_flash_live('gemini-2.5-flash-live') is False + + def test_is_gemini_3_1_flash_live_path_based_name(self): + """Test with path-based format (Vertex AI etc.).""" + vertex_path = 'projects/123/locations/us-central1/publishers/google/models/gemini-3.1-flash-live' + assert is_gemini_3_1_flash_live(vertex_path) is True + vertex_path_preview = 'projects/123/locations/us-central1/publishers/google/models/gemini-3.1-flash-live-preview' + assert is_gemini_3_1_flash_live(vertex_path_preview) is True + non_live_path = 'projects/123/locations/us-central1/publishers/google/models/gemini-3.1-flash' + assert is_gemini_3_1_flash_live(non_live_path) is False + + def test_is_gemini_3_1_flash_live_edge_cases(self): + """Test edge cases.""" + assert is_gemini_3_1_flash_live(None) is False + assert is_gemini_3_1_flash_live('') is False diff --git a/tests/unittests/utils/test_streaming_utils.py b/tests/unittests/utils/test_streaming_utils.py index 6b68789bf08..4cb81ed9ba4 100644 --- a/tests/unittests/utils/test_streaming_utils.py +++ b/tests/unittests/utils/test_streaming_utils.py @@ -184,25 +184,105 @@ async def test_close_with_error(self): assert closed_response.error_message == "Recitation error" @pytest.mark.asyncio - async def test_process_response_with_none_content(self): - """Test that StreamingResponseAggregator handles content=None.""" - aggregator = streaming_utils.StreamingResponseAggregator() - response = types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=types.Content(parts=[]), - finish_reason=types.FinishReason.STOP, - ) - ] - ) - results = [] - async for r in aggregator.process_response(response): - results.append(r) - assert len(results) == 1 - assert results[0].content is not None + @pytest.mark.parametrize("use_progressive_sse", [True, False]) + async def test_empty_content_produces_empty_final_frame( + self, use_progressive_sse + ): + """A candidate with an empty parts list produces an empty final frame.""" + with temporary_feature_override( + FeatureName.PROGRESSIVE_SSE_STREAMING, use_progressive_sse + ): + aggregator = streaming_utils.StreamingResponseAggregator() + response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content(parts=[]), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + results = [] + async for r in aggregator.process_response(response): + results.append(r) + closed_response = aggregator.close() - closed_response = aggregator.close() - assert closed_response is None + assert len(results) == 1 + assert results[0].content is not None + assert closed_response is not None + assert closed_response.partial is False + assert closed_response.content is None + assert closed_response.finish_reason == types.FinishReason.STOP + + @pytest.mark.asyncio + @pytest.mark.parametrize("use_progressive_sse", [True, False]) + async def test_prompt_feedback_block_returns_error_frame( + self, use_progressive_sse + ): + """A prompt-level safety block produces a final frame with the error code.""" + with temporary_feature_override( + FeatureName.PROGRESSIVE_SSE_STREAMING, use_progressive_sse + ): + aggregator = streaming_utils.StreamingResponseAggregator() + response = types.GenerateContentResponse( + prompt_feedback=types.GenerateContentResponsePromptFeedback( + block_reason=types.BlockedReason.SAFETY, + block_reason_message="Blocked by safety", + ) + ) + results = [] + async for r in aggregator.process_response(response): + results.append(r) + closed_response = aggregator.close() + + assert len(results) == 1 + assert closed_response is not None + assert closed_response.partial is False + assert closed_response.error_code == types.BlockedReason.SAFETY + assert closed_response.error_message == "Blocked by safety" + assert closed_response.content is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("use_progressive_sse", [True, False]) + async def test_pure_function_call_behavior_differs_by_mode( + self, use_progressive_sse + ): + """A pure function call yields the part in progressive mode and an empty frame otherwise.""" + with temporary_feature_override( + FeatureName.PROGRESSIVE_SSE_STREAMING, use_progressive_sse + ): + aggregator = streaming_utils.StreamingResponseAggregator() + response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name="my_tool", + args={"x": 1}, + ) + ) + ] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + results = [] + async for r in aggregator.process_response(response): + results.append(r) + closed_response = aggregator.close() + + assert closed_response is not None + assert closed_response.partial is False + + if use_progressive_sse: + assert closed_response.content is not None + assert len(closed_response.content.parts) == 1 + assert closed_response.content.parts[0].function_call.name == "my_tool" + else: + assert closed_response.content is None @pytest.mark.asyncio @pytest.mark.parametrize(