From 3ec27f2e03ab19d12cb8c212c79faf8996930285 Mon Sep 17 00:00:00 2001 From: mdevolde Date: Sat, 27 Jun 2026 12:39:17 +0200 Subject: [PATCH 1/7] test: split tests in categs, add bench tests --- pyproject.toml | 25 +++- pytest.ini | 5 + tests/benchmarks/__init__.py | 1 + tests/benchmarks/conftest.py | 17 +++ tests/benchmarks/test_bench_check.py | 81 ++++++++++++ tests/integration/__init__.py | 1 + tests/integration/conftest.py | 17 +++ tests/{ => integration}/test_api_public.py | 8 +- tests/{ => integration}/test_cli.py | 127 +++++-------------- tests/{ => integration}/test_config.py | 35 +---- tests/integration/test_download.py | 94 ++++++++++++++ tests/{ => integration}/test_match.py | 10 +- tests/{ => integration}/test_server_local.py | 2 +- tests/property/__init__.py | 1 + tests/property/conftest.py | 17 +++ tests/property/test_prop_config.py | 85 +++++++++++++ tests/property/test_prop_safe_zip.py | 74 +++++++++++ tests/property/test_prop_utils.py | 20 +++ tests/unit/__init__.py | 1 + tests/unit/conftest.py | 17 +++ tests/unit/test_cli_args.py | 60 +++++++++ tests/unit/test_config_validation.py | 37 ++++++ tests/{ => unit}/test_download.py | 101 +-------------- tests/{ => unit}/test_safe_zip.py | 2 +- uv.lock | 48 +++++++ 25 files changed, 650 insertions(+), 236 deletions(-) create mode 100644 tests/benchmarks/__init__.py create mode 100644 tests/benchmarks/conftest.py create mode 100644 tests/benchmarks/test_bench_check.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/conftest.py rename tests/{ => integration}/test_api_public.py (93%) rename tests/{ => integration}/test_cli.py (68%) rename tests/{ => integration}/test_config.py (81%) create mode 100644 tests/integration/test_download.py rename tests/{ => integration}/test_match.py (95%) rename tests/{ => integration}/test_server_local.py (98%) create mode 100644 tests/property/__init__.py create mode 100644 tests/property/conftest.py create mode 100644 
tests/property/test_prop_config.py create mode 100644 tests/property/test_prop_safe_zip.py create mode 100644 tests/property/test_prop_utils.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/test_cli_args.py create mode 100644 tests/unit/test_config_validation.py rename tests/{ => unit}/test_download.py (82%) rename tests/{ => unit}/test_safe_zip.py (99%) diff --git a/pyproject.toml b/pyproject.toml index 47957ec..182c0cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,9 @@ changelog = "https://github.com/jxmorris12/language_tool_python/blob/master/CHAN [dependency-groups] tests = [ "pytest", + "pytest-benchmark", "pytest-cov", + "hypothesis", ] docs = [ @@ -141,9 +143,10 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/*.py" = [ - "S101", # Need to use assert statements in tests - "SLF001" # Need to use private members of the library for testing +"tests/**/*.py" = [ + "S101", # Need to use assert statements in tests + "SLF001", # Need to use private members of the library for testing + "RUF001", # LanguageTool output contains typographic quotes (‘’“”) ] "src/language_tool_python/__main__.py" = ["T201"] # Allow usage of print in the CLI entry point @@ -170,3 +173,19 @@ warn_return_any = true warn_unreachable = true warn_unused_configs = true warn_unused_ignores = true + +[[tool.mypy.overrides]] +module = ["tests.benchmarks.*"] +# pytest-benchmark is untyped; relax Any restrictions for benchmark files only +disallow_any_unimported = false +disallow_any_expr = false +disallow_any_explicit = false +disallow_any_decorated = false + +[[tool.mypy.overrides]] +module = ["tests.property.*"] +# hypothesis is untyped; relax Any restrictions for property test files only +disallow_any_unimported = false +disallow_any_expr = false +disallow_any_explicit = false +disallow_any_decorated = false diff --git a/pytest.ini b/pytest.ini index 8170a2f..2da9216 100644 --- a/pytest.ini +++ 
b/pytest.ini @@ -1,6 +1,11 @@ [pytest] addopts = -vra --cov=src --cov-report=html --cov-report=xml testpaths = tests +markers = + unit: fast, isolated tests with no external dependencies + integration: tests that require a live LanguageTool server or network + property: property-based tests using Hypothesis + perf: performance benchmark tests using pytest-benchmark [coverage:run] source = src diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000..1dff63c --- /dev/null +++ b/tests/benchmarks/__init__.py @@ -0,0 +1 @@ +"""Benchmark tests for the language_tool_python library.""" diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py new file mode 100644 index 0000000..6cebb40 --- /dev/null +++ b/tests/benchmarks/conftest.py @@ -0,0 +1,17 @@ +"""Configuration for the benchmark test suite.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + + +def pytest_collection_modifyitems( + items: list[pytest.Item], +) -> None: + """Apply the 'perf' marker to all tests collected from this directory.""" + benchmarks_dir = Path(__file__).parent + for item in items: + if item.path.is_relative_to(benchmarks_dir): + item.add_marker(pytest.mark.perf) diff --git a/tests/benchmarks/test_bench_check.py b/tests/benchmarks/test_bench_check.py new file mode 100644 index 0000000..2cca935 --- /dev/null +++ b/tests/benchmarks/test_bench_check.py @@ -0,0 +1,81 @@ +"""Benchmark tests for LanguageTool grammar checking performance. + +Run with: pytest tests/benchmarks/ -v +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import language_tool_python + +if TYPE_CHECKING: + from collections.abc import Generator + + from pytest_benchmark.fixture import BenchmarkFixture + +_SHORT_TEXT = "This is a sentence with some erors in it. 
" +_MEDIUM_TEXT = (_SHORT_TEXT * 20).strip() +_LONG_TEXT = (_SHORT_TEXT * 100).strip() + + +@pytest.fixture(scope="module") +def tool() -> Generator[language_tool_python.LanguageTool, None, None]: + """Provide a LanguageTool instance shared across benchmarks in this module.""" + with language_tool_python.LanguageTool("en-US") as t: + yield t + + +@pytest.fixture(scope="module") +def cached_tool() -> Generator[language_tool_python.LanguageTool, None, None]: + """Provide a pipeline-caching LanguageTool instance for cache benchmarks.""" + with language_tool_python.LanguageTool( + "en-US", + config={"cacheSize": 1000, "pipelineCaching": True}, + ) as t: + yield t + + +def test_bench_check_short_text( + benchmark: BenchmarkFixture, + tool: language_tool_python.LanguageTool, +) -> None: + """Benchmark grammar checking on a short sentence (~38 characters).""" + benchmark(tool.check, _SHORT_TEXT) + + +def test_bench_check_medium_text( + benchmark: BenchmarkFixture, + tool: language_tool_python.LanguageTool, +) -> None: + """Benchmark grammar checking on medium-length text (~840 characters).""" + benchmark(tool.check, _MEDIUM_TEXT) + + +def test_bench_check_long_text( + benchmark: BenchmarkFixture, + tool: language_tool_python.LanguageTool, +) -> None: + """Benchmark grammar checking on long text (~4200 characters).""" + benchmark(tool.check, _LONG_TEXT) + + +def test_bench_correct_short_text( + benchmark: BenchmarkFixture, + tool: language_tool_python.LanguageTool, +) -> None: + """Benchmark automatic text correction on a short sentence.""" + benchmark(tool.correct, _SHORT_TEXT) + + +def test_bench_check_with_pipeline_cache( + benchmark: BenchmarkFixture, + cached_tool: language_tool_python.LanguageTool, +) -> None: + """Benchmark grammar checking with pipeline caching enabled. + + Compare with test_bench_check_short_text to measure cache speedup. 
+ """ + benchmark(cached_tool.check, _SHORT_TEXT) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..7ac8aa9 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for the language_tool_python library.""" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..3628ef0 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,17 @@ +"""Configuration for the integration test suite.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + + +def pytest_collection_modifyitems( + items: list[pytest.Item], +) -> None: + """Apply the 'integration' marker to all tests collected from this directory.""" + integration_dir = Path(__file__).parent + for item in items: + if item.path.is_relative_to(integration_dir): + item.add_marker(pytest.mark.integration) diff --git a/tests/test_api_public.py b/tests/integration/test_api_public.py similarity index 93% rename from tests/test_api_public.py rename to tests/integration/test_api_public.py index 1788741..307c5d6 100644 --- a/tests/test_api_public.py +++ b/tests/integration/test_api_public.py @@ -1,4 +1,4 @@ -"""Tests for the public API functionality.""" +"""Integration tests for the public API functionality.""" import pytest @@ -34,7 +34,7 @@ def test_remote_es() -> None: 'INCORRECT_EXPRESSIONS', 'rule_issue_type': 'grammar', 'sentence': 'LanguageTool le ayudará a afrentar algunas dificultades propias de la escritura.'}), Match({'rule_id': 'PRON_HABER_PARTICIPIO', - 'message': 'El v. \u2018haber\u2019 se escribe con hache.', + 'message': 'El v. ‘haber’ se escribe con hache.', 'replacements': ['ha'], 'offset_in_context': 43, 'context': '...ificultades propias de la escritura. 
Se a hecho un esfuerzo para detectar errores...', 'offset': 107, 'error_length': 1, 'category': @@ -50,8 +50,8 @@ def test_remote_es() -> None: 'misspelling', 'sentence': 'Se a hecho un esfuerzo para detectar errores tipográficos, ortograficos y incluso gramaticales.'}), Match({'rule_id': 'Y_E_O_U', 'message': 'Cuando precede a palabras - que comienzan por \u2018i\u2019, la conjunción \u2018y\u2019 se - transforma en \u2018e\u2019.', 'replacements': ['e'], + que comienzan por ‘i’, la conjunción ‘y’ se + transforma en ‘e’.', 'replacements': ['e'], 'offset_in_context': 43, 'context': '...ctar errores tipográficos, ortograficos y incluso gramaticales. También algunos e...', 'offset': 176, 'error_length': 1, 'category': 'GRAMMAR', 'rule_issue_type': diff --git a/tests/test_cli.py b/tests/integration/test_cli.py similarity index 68% rename from tests/test_cli.py rename to tests/integration/test_cli.py index 18da908..d101f66 100644 --- a/tests/test_cli.py +++ b/tests/integration/test_cli.py @@ -1,4 +1,4 @@ -"""Tests for the command-line interface (CLI) functionality.""" +"""Integration tests for the CLI using real LanguageTool server instances.""" import io import sys @@ -7,7 +7,39 @@ import pytest import language_tool_python -from language_tool_python.__main__ import main, parse_args +from language_tool_python.__main__ import main + + +def main_with_stdin(argv: list[str], stdin: str) -> int: + """Execute the main CLI with simulated stdin input. + + :param argv: Command-line arguments to pass to the main function. + :param stdin: Input text to simulate as stdin. + :return: Exit code returned by the main function. + :rtype: int + """ + old_stdin = sys.stdin + sys.stdin = io.StringIO(stdin) + try: + return main(argv) + finally: + sys.stdin = old_stdin + + +@pytest.fixture(scope="module") +def remote_server() -> Generator[tuple[str, int], None, None]: + """Fixture that provides a remote LanguageTool server for testing. 
+ + This fixture initializes a LanguageTool instance and yields its host and port, + ensuring proper cleanup after all tests in the module complete. + + :return: A tuple containing the server host and port (host, port). + :rtype: Generator[Tuple[str, int], None, None] + """ + with language_tool_python.LanguageTool("en-US") as tool: + host = tool._host + port = tool._port + yield host, port @pytest.mark.parametrize( @@ -89,22 +121,6 @@ def test_cli_exit_codes( assert code != 0 -@pytest.fixture(scope="module") -def remote_server() -> Generator[tuple[str, int], None, None]: - """Fixture that provides a remote LanguageTool server for testing. - - This fixture initializes a LanguageTool instance and yields its host and port, - ensuring proper cleanup after all tests in the module complete. - - :return: A tuple containing the server host and port (host, port). - :rtype: Generator[Tuple[str, int], None, None] - """ - with language_tool_python.LanguageTool("en-US") as tool: - host = tool._host - port = tool._port - yield host, port - - def test_cli_remote_ok(remote_server: tuple[str, int]) -> None: """Test the CLI with a remote server using valid input text. @@ -155,78 +171,3 @@ def test_cli_remote_error(remote_server: tuple[str, int]) -> None: "This is noot okay.\n", ) assert code != 0 - - -def test_parse_args_enabled_only_with_enable_categories() -> None: - """Test that --enabled-only is accepted when only --enable-categories is provided. - - :raises AssertionError: If parse_args raises an error for this valid combination. - """ - args = parse_args(["-l", "en-US", "--enabled-only", "-E", "TYPOS", "file.txt"]) - assert args.enabled_only is True - assert args.enable_categories == {"TYPOS"} - - -def test_parse_args_enabled_only_rejects_disable_categories() -> None: - """Test that --enabled-only cannot be combined with --disable-categories. - - :raises SystemExit: Expected, as argparse calls sys.exit on error. 
- """ - with pytest.raises(SystemExit): - parse_args( - ["-l", "en-US", "--enabled-only", "-e", "RULE", "-D", "TYPOS", "file.txt"] - ) - - -def test_parse_args_enabled_only_requires_enable_or_enable_categories() -> None: - """Test that --enabled-only requires at least --enable or --enable-categories. - - :raises SystemExit: Expected, as argparse calls sys.exit on error. - """ - with pytest.raises(SystemExit): - parse_args(["-l", "en-US", "--enabled-only", "file.txt"]) - - -def test_parse_args_categories() -> None: - """Test that --disable-categories and --enable-categories are parsed correctly. - - :raises AssertionError: If the parsed category sets do not match the expected - values. - """ - args = parse_args( - ["-l", "en-US", "-D", "TYPOS,GRAMMAR", "-E", "PUNCTUATION", "file.txt"] - ) - assert args.disable_categories == {"TYPOS", "GRAMMAR"} - assert args.enable_categories == {"PUNCTUATION"} - - -def test_parse_args_categories_multiple_flags() -> None: - """Test that repeated -D/-E flags accumulate into the same set. - - :raises AssertionError: If the category sets do not accumulate correctly. - """ - args = parse_args( - ["-l", "en-US", "-D", "TYPOS", "-D", "GRAMMAR", "-E", "PUNCTUATION", "file.txt"] - ) - assert args.disable_categories == {"TYPOS", "GRAMMAR"} - assert args.enable_categories == {"PUNCTUATION"} - - -def main_with_stdin(argv: list[str], stdin: str) -> int: - """Execute the main CLI with simulated stdin input. - - This utility function temporarily replaces sys.stdin with a StringIO object - containing the provided input, executes the main CLI function, and then restores the - original stdin. - - :param argv: Command-line arguments to pass to the main function. - :param stdin: Input text to simulate as stdin. - :return: Exit code returned by the main function. 
- :rtype: int - """ - old_stdin = sys.stdin - sys.stdin = io.StringIO(stdin) - try: - return main(argv) - finally: - sys.stdin = old_stdin diff --git a/tests/test_config.py b/tests/integration/test_config.py similarity index 81% rename from tests/test_config.py rename to tests/integration/test_config.py index 087d5d3..3a2ba42 100644 --- a/tests/test_config.py +++ b/tests/integration/test_config.py @@ -1,4 +1,4 @@ -"""Tests for the configuration options of LanguageTool.""" +"""Integration tests for LanguageTool configuration options (require a local server).""" import re import time @@ -6,7 +6,6 @@ import pytest import language_tool_python -from language_tool_python.config_file import ConfigValue, LanguageToolConfig from language_tool_python.exceptions import LanguageToolError @@ -175,35 +174,3 @@ def test_disabled_rule_in_config() -> None: text = "He realised that the organization was in jeopardy." matches = tool.check(text) assert len(matches) == 0 - - -@pytest.mark.parametrize( - "config", - [ - {"blockedReferrers": "example.com\ntrustXForwardForHeader=true"}, - {"disabledRuleIds": ["MORFOLOGIK_RULE_EN_US", "SAFE\rrequestLimit=0"]}, - {"lang-en\ntrustXForwardForHeader": "true"}, - {"lang-en": "custom-word\nrequestLimit=0"}, - ], -) -def test_config_rejects_line_break_injection(config: dict[str, ConfigValue]) -> None: - """Test that config serialization cannot be escaped with CR/LF characters.""" - with pytest.raises(ValueError, match="cannot contain line breaks"): - LanguageToolConfig(config) - - -@pytest.mark.parametrize( - "config", - [ - {"blockedReferrers": "example.com\\"}, - {"disabledRuleIds": ["MORFOLOGIK_RULE_EN_US", "SAFE\\"]}, - {"lang-en\\": "true"}, - {"lang-en": "custom-word\\"}, - ], -) -def test_config_rejects_odd_trailing_backslashes( - config: dict[str, ConfigValue], -) -> None: - """Test that config serialization cannot escape the line ending with a backslash.""" - with pytest.raises(ValueError, match="odd number of backslashes"): - 
LanguageToolConfig(config) diff --git a/tests/integration/test_download.py b/tests/integration/test_download.py new file mode 100644 index 0000000..fd74f50 --- /dev/null +++ b/tests/integration/test_download.py @@ -0,0 +1,94 @@ +"""Integration tests for LanguageTool download and version management (real network).""" + +from datetime import datetime, timedelta, timezone + +import pytest + +import language_tool_python +from language_tool_python.exceptions import LanguageToolError, PathError + + +def test_install_inexistent_version() -> None: + """Test errors when downloading a non-existent LanguageTool version. + + This test verifies that the tool correctly handles invalid version numbers by + raising a LanguageToolError when trying to initialize with a version that does not + exist. + + :raises AssertionError: If LanguageToolError is not raised for an invalid version. + """ + with pytest.raises(LanguageToolError): + language_tool_python.LanguageTool(language_tool_download_version="0.0") + + +def test_install_too_old_version() -> None: + """Test that attempting to download a too-old LanguageTool version raises an error. + + This test verifies that the tool correctly handles versions that are no longer + supported by raising a PathError when trying to initialize with an outdated version. + + :raises AssertionError: If PathError is not raised for a too-old version. + """ + with pytest.raises(PathError): + language_tool_python.LanguageTool(language_tool_download_version="3.9") + + +def test_inexistent_language() -> None: + """Test that creating a LanguageTag with an invalid language code raises an error. + + This test verifies that the LanguageTag constructor correctly validates language + codes and raises a ValueError when given a language code that is not supported. + + :raises AssertionError: If ValueError is not raised for an invalid language code. 
+ """ + with ( + language_tool_python.LanguageTool("en-US") as tool, + pytest.raises(ValueError, match="unsupported language"), + ): + language_tool_python.LanguageTag("xx-XX", tool._get_languages()) + + +def test_install_oldest_supported_version() -> None: + """Test that downloading the oldest supported LanguageTool version works correctly. + + This test verifies that the tool can successfully download and initialize with the + oldest version that is still supported. + + :raises AssertionError: If the tool fails to initialize with the oldest supported + version. + """ + try: + with language_tool_python.LanguageTool( + "en-US", + language_tool_download_version="4.0", + ) as tool: + assert tool.language_tool_download_version == "4.0" + except LanguageToolError: + pytest.fail("Failed to download or initialize the oldest supported version.") + + +def test_install_snapshot_version() -> None: + """Test that downloading the snapshot version of LanguageTool works correctly. + + This test verifies that the tool can successfully download and initialize with the + snapshot from three days ago. + + :raises AssertionError: If the tool fails to initialize with the snapshot version. + """ + try: + with language_tool_python.LanguageTool( + "en-US", + language_tool_download_version=( + (datetime.now(timezone.utc) - timedelta(days=3)).strftime("%Y%m%d") + ), + ) as tool: + assert tool.language_tool_download_version == ( + datetime.now(timezone.utc) - timedelta(days=3) + ).strftime("%Y%m%d") + except LanguageToolError: + pytest.skip( + ( + "Failed to download or initialize the snapshot version. This may be " + "due to a missing snapshot for the expected date." 
+ ), + ) diff --git a/tests/test_match.py b/tests/integration/test_match.py similarity index 95% rename from tests/test_match.py rename to tests/integration/test_match.py index a81713d..074f5f8 100644 --- a/tests/test_match.py +++ b/tests/integration/test_match.py @@ -1,4 +1,4 @@ -"""Tests for the Match functionality of LanguageTool.""" +"""Integration tests for the Match functionality of LanguageTool.""" from typing import TypedDict @@ -122,15 +122,15 @@ def test_match() -> None: expected format. """ with language_tool_python.LanguageTool("en-US") as tool: - text = "A sentence with a error in the Hitchhiker\u2019s Guide tot he Galaxy" + text = "A sentence with a error in the Hitchhiker’s Guide tot he Galaxy" matches = tool.check(text) assert len(matches) == EXPECTED_MATCH_COUNT assert str(matches[0]) == ( "Offset 16, length 1, Rule ID: EN_A_VS_AN\n" - "Message: Use “an” instead of \u2018a\u2019 if the following word starts " - "with a vowel sound, e.g. \u2018an article\u2019, \u2018an hour\u2019.\n" + "Message: Use “an” instead of ‘a’ if the following word starts " + "with a vowel sound, e.g. ‘an article’, ‘an hour’.\n" "Suggestion: an\n" - "A sentence with a error in the Hitchhiker\u2019s Guide tot he ..." + "A sentence with a error in the Hitchhiker’s Guide tot he ..." 
"\n ^" ) diff --git a/tests/test_server_local.py b/tests/integration/test_server_local.py similarity index 98% rename from tests/test_server_local.py rename to tests/integration/test_server_local.py index 3731952..31b4b47 100644 --- a/tests/test_server_local.py +++ b/tests/integration/test_server_local.py @@ -1,4 +1,4 @@ -"""Tests for the local server functionality of LanguageTool.""" +"""Integration tests for the local server functionality of LanguageTool.""" from __future__ import annotations diff --git a/tests/property/__init__.py b/tests/property/__init__.py new file mode 100644 index 0000000..2e5bdfe --- /dev/null +++ b/tests/property/__init__.py @@ -0,0 +1 @@ +"""Property-based tests for the language_tool_python library.""" diff --git a/tests/property/conftest.py b/tests/property/conftest.py new file mode 100644 index 0000000..6297f26 --- /dev/null +++ b/tests/property/conftest.py @@ -0,0 +1,17 @@ +"""Configuration for the property-based test suite.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + + +def pytest_collection_modifyitems( + items: list[pytest.Item], +) -> None: + """Apply the 'property' marker to all tests collected from this directory.""" + property_dir = Path(__file__).parent + for item in items: + if item.path.is_relative_to(property_dir): + item.add_marker(pytest.mark.property) diff --git a/tests/property/test_prop_config.py b/tests/property/test_prop_config.py new file mode 100644 index 0000000..41b95a9 --- /dev/null +++ b/tests/property/test_prop_config.py @@ -0,0 +1,85 @@ +"""Property-based tests for LanguageToolConfig input validation. + +These tests use Hypothesis to verify that injection-protection invariants +hold for any input, not just the handwritten examples in unit tests. 
+""" + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from language_tool_python.config_file import LanguageToolConfig + +_LINEBREAK_CHARS = ["\n", "\r", "\r\n"] + + +@given( + before=st.text(), + linebreak=st.sampled_from(_LINEBREAK_CHARS), + after=st.text(), +) +@settings(max_examples=200) +def test_prop_config_value_with_linebreak_always_raises( + before: str, + linebreak: str, + after: str, +) -> None: + """Any config value containing CR or LF must raise ValueError. + + The string is constructed as ``before + linebreak + after`` to guarantee + the presence of a line-break character without relying on filter(). + + :param before: Arbitrary text before the line-break. + :param linebreak: A CR, LF, or CRLF sequence. + :param after: Arbitrary text after the line-break. + :raises AssertionError: If ValueError is not raised. + """ + value = before + linebreak + after + with pytest.raises(ValueError, match="line breaks"): + LanguageToolConfig({"blockedReferrers": value}) + + +@given( + prefix=st.text(alphabet=st.characters(blacklist_characters="\r\n\\")), + count=st.integers(min_value=1, max_value=5), +) +@settings(max_examples=200) +def test_prop_config_odd_trailing_backslashes_always_raise( + prefix: str, + count: int, +) -> None: + r"""Any config value ending with an odd number of backslashes must raise ValueError. + + The value is constructed as ``prefix + '\\' * (2*count - 1)`` to guarantee + the trailing backslash count is always odd (1, 3, 5, 7, or 9). + + :param prefix: A string with no backslashes or line-break characters. + :param count: Determines the odd backslash count: ``2*count - 1``. + :raises AssertionError: If ValueError is not raised. 
+ """ + value = prefix + "\\" * (2 * count - 1) + with pytest.raises(ValueError, match="backslash"): + LanguageToolConfig({"blockedReferrers": value}) + + +@given( + key_before=st.text(alphabet=st.characters(blacklist_characters="\r\n")), + linebreak=st.sampled_from(_LINEBREAK_CHARS), + key_after=st.text(alphabet=st.characters(blacklist_characters="\r\n")), +) +@settings(max_examples=200) +def test_prop_config_key_with_linebreak_always_raises( + key_before: str, + linebreak: str, + key_after: str, +) -> None: + """Any config key containing CR or LF must raise ValueError. + + :param key_before: Text before the line-break in the key. + :param linebreak: A CR, LF, or CRLF sequence. + :param key_after: Text after the line-break in the key. + :raises AssertionError: If ValueError is not raised. + """ + key = key_before + linebreak + key_after + with pytest.raises(ValueError, match="line breaks"): + LanguageToolConfig({key: "valid_value"}) diff --git a/tests/property/test_prop_safe_zip.py b/tests/property/test_prop_safe_zip.py new file mode 100644 index 0000000..ae3f992 --- /dev/null +++ b/tests/property/test_prop_safe_zip.py @@ -0,0 +1,74 @@ +"""Property-based tests for the safe ZIP extractor path-traversal protection.""" + +import contextlib +import io +import shutil +import uuid +import zipfile +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from language_tool_python._internals.safe_zip import SafeZipExtractor +from language_tool_python.exceptions import PathError + +_TRAVERSAL_PREFIXES = ["../", "..\\", "/", "C:/", "..\\..\\"] + + +def _make_zip_payload(files: dict[str, bytes]) -> bytes: + """Create an in-memory ZIP payload for testing.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + for name, data in files.items(): + zf.writestr(name, data) + return buf.getvalue() + + +@contextmanager +def 
_temp_dir() -> Iterator[Path]: + """Create a temporary dir inside the project workspace to avoid perm issues.""" + root = Path.cwd() / ".test_prop_safe_zip_tmp" + path = root / uuid.uuid4().hex + path.mkdir(parents=True) + try: + yield path + finally: + shutil.rmtree(path, ignore_errors=True) + with contextlib.suppress(OSError): + root.rmdir() + + +@given( + traversal=st.sampled_from(_TRAVERSAL_PREFIXES), + suffix=st.text( + alphabet=st.characters(whitelist_categories=("Ll", "Lu", "Nd")), + min_size=1, + ), +) +@settings(max_examples=300) +def test_prop_safe_zip_path_traversal_always_rejected( + traversal: str, + suffix: str, +) -> None: + """Any ZIP member whose name begins with a path-traversal prefix must be rejected. + + Checks that ``SafeZipExtractor`` raises ``PathError`` for filenames like + ``../evil``, ``/etc/passwd``, or ``C:/Windows/file`` regardless of the suffix. + + :param traversal: A path-traversal prefix (e.g. ``../``, ``/``). + :param suffix: Alphanumeric suffix appended after the traversal prefix. + :raises AssertionError: If ``PathError`` is not raised for the unsafe member name. + """ + filename = traversal + suffix + payload = _make_zip_payload({filename: b"payload"}) + + with ( + _temp_dir() as dest, + zipfile.ZipFile(io.BytesIO(payload)) as zf, + pytest.raises(PathError, match="Unsafe ZIP member"), + ): + SafeZipExtractor().extractall(zf, dest) diff --git a/tests/property/test_prop_utils.py b/tests/property/test_prop_utils.py new file mode 100644 index 0000000..4bc7f20 --- /dev/null +++ b/tests/property/test_prop_utils.py @@ -0,0 +1,20 @@ +"""Property-based tests for the LanguageTool utility functions.""" + +from hypothesis import given, settings +from hypothesis import strategies as st + +import language_tool_python + + +@given(text=st.text()) +@settings(max_examples=500) +def test_prop_correct_with_empty_matches_is_identity(text: str) -> None: + """correct(text, []) must always return the original text unchanged. 
+ + This verifies the fundamental contract that applying zero corrections + is a no-op, regardless of the text content (empty, unicode, emojis...). + + :param text: Arbitrary string generated by Hypothesis. + :raises AssertionError: If the corrected text differs from the input. + """ + assert language_tool_python.utils.correct(text, []) == text diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..1733800 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the language_tool_python library.""" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..65fcecd --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,17 @@ +"""Configuration for the unit test suite.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + + +def pytest_collection_modifyitems( + items: list[pytest.Item], +) -> None: + """Apply the 'unit' marker to all tests collected from this directory.""" + unit_dir = Path(__file__).parent + for item in items: + if item.path.is_relative_to(unit_dir): + item.add_marker(pytest.mark.unit) diff --git a/tests/unit/test_cli_args.py b/tests/unit/test_cli_args.py new file mode 100644 index 0000000..6f79066 --- /dev/null +++ b/tests/unit/test_cli_args.py @@ -0,0 +1,60 @@ +"""Unit tests for the CLI argument parser.""" + +import pytest + +from language_tool_python.__main__ import parse_args + + +def test_parse_args_enabled_only_with_enable_categories() -> None: + """Test that --enabled-only is accepted when only --enable-categories is provided. + + :raises AssertionError: If parse_args raises an error for this valid combination. 
+ """ + args = parse_args(["-l", "en-US", "--enabled-only", "-E", "TYPOS", "file.txt"]) + assert args.enabled_only is True + assert args.enable_categories == {"TYPOS"} + + +def test_parse_args_enabled_only_rejects_disable_categories() -> None: + """Test that --enabled-only cannot be combined with --disable-categories. + + :raises SystemExit: Expected, as argparse calls sys.exit on error. + """ + with pytest.raises(SystemExit): + parse_args( + ["-l", "en-US", "--enabled-only", "-e", "RULE", "-D", "TYPOS", "file.txt"] + ) + + +def test_parse_args_enabled_only_requires_enable_or_enable_categories() -> None: + """Test that --enabled-only requires at least --enable or --enable-categories. + + :raises SystemExit: Expected, as argparse calls sys.exit on error. + """ + with pytest.raises(SystemExit): + parse_args(["-l", "en-US", "--enabled-only", "file.txt"]) + + +def test_parse_args_categories() -> None: + """Test that --disable-categories and --enable-categories are parsed correctly. + + :raises AssertionError: If the parsed category sets do not match the expected + values. + """ + args = parse_args( + ["-l", "en-US", "-D", "TYPOS,GRAMMAR", "-E", "PUNCTUATION", "file.txt"] + ) + assert args.disable_categories == {"TYPOS", "GRAMMAR"} + assert args.enable_categories == {"PUNCTUATION"} + + +def test_parse_args_categories_multiple_flags() -> None: + """Test that repeated -D/-E flags accumulate into the same set. + + :raises AssertionError: If the category sets do not accumulate correctly. 
+ """ + args = parse_args( + ["-l", "en-US", "-D", "TYPOS", "-D", "GRAMMAR", "-E", "PUNCTUATION", "file.txt"] + ) + assert args.disable_categories == {"TYPOS", "GRAMMAR"} + assert args.enable_categories == {"PUNCTUATION"} diff --git a/tests/unit/test_config_validation.py b/tests/unit/test_config_validation.py new file mode 100644 index 0000000..aa46834 --- /dev/null +++ b/tests/unit/test_config_validation.py @@ -0,0 +1,37 @@ +"""Unit tests for LanguageToolConfig input validation and injection protection.""" + +import pytest + +from language_tool_python.config_file import ConfigValue, LanguageToolConfig + + +@pytest.mark.parametrize( + "config", + [ + {"blockedReferrers": "example.com\ntrustXForwardForHeader=true"}, + {"disabledRuleIds": ["MORFOLOGIK_RULE_EN_US", "SAFE\rrequestLimit=0"]}, + {"lang-en\ntrustXForwardForHeader": "true"}, + {"lang-en": "custom-word\nrequestLimit=0"}, + ], +) +def test_config_rejects_line_break_injection(config: dict[str, ConfigValue]) -> None: + """Test that config serialization cannot be escaped with CR/LF characters.""" + with pytest.raises(ValueError, match="cannot contain line breaks"): + LanguageToolConfig(config) + + +@pytest.mark.parametrize( + "config", + [ + {"blockedReferrers": "example.com\\"}, + {"disabledRuleIds": ["MORFOLOGIK_RULE_EN_US", "SAFE\\"]}, + {"lang-en\\": "true"}, + {"lang-en": "custom-word\\"}, + ], +) +def test_config_rejects_odd_trailing_backslashes( + config: dict[str, ConfigValue], +) -> None: + """Test that config serialization cannot escape the line ending with a backslash.""" + with pytest.raises(ValueError, match="odd number of backslashes"): + LanguageToolConfig(config) diff --git a/tests/test_download.py b/tests/unit/test_download.py similarity index 82% rename from tests/test_download.py rename to tests/unit/test_download.py index d6de0f0..ad8ec2d 100644 --- a/tests/test_download.py +++ b/tests/unit/test_download.py @@ -1,4 +1,7 @@ -"""Tests for the download/language functionality of LanguageTool.""" 
+"""Unit tests for download logic, URL construction, HTTP handling, and integrity checks. + +These tests use mocks and monkeypatching to avoid real network requests. +""" import contextlib import hashlib @@ -10,7 +13,7 @@ import zipfile from collections.abc import Iterator from contextlib import contextmanager -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from pathlib import Path from unittest.mock import patch @@ -23,7 +26,7 @@ _LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, LocalLanguageTool, ) -from language_tool_python.exceptions import LanguageToolError, PathError +from language_tool_python.exceptions import PathError EXPECTED_DOWNLOAD_BYTES_OVERRIDE = 123 @@ -87,52 +90,9 @@ def workspace_temp_dir() -> Iterator[Path]: root.rmdir() -def test_install_inexistent_version() -> None: - """Test errors when downloading a non-existent LanguageTool version. - - This test verifies that the tool correctly handles invalid version numbers by - raising a LanguageToolError when trying to initialize with a version that does not - exist. - - :raises AssertionError: If LanguageToolError is not raised for an invalid version. - """ - with pytest.raises(LanguageToolError): - language_tool_python.LanguageTool(language_tool_download_version="0.0") - - -def test_install_too_old_version() -> None: - """Test that attempting to download a too-old LanguageTool version raises an error. - - This test verifies that the tool correctly handles versions that are no longer - supported by raising a PathError when trying to initialize with an outdated version. - - :raises AssertionError: If PathError is not raised for a too-old version. - """ - with pytest.raises(PathError): - language_tool_python.LanguageTool(language_tool_download_version="3.9") - - -def test_inexistent_language() -> None: - """Test that creating a LanguageTag with an invalid language code raises an error. 
- - This test verifies that the LanguageTag constructor correctly validates language - codes and raises a ValueError when given a language code that is not supported. - - :raises AssertionError: If ValueError is not raised for an invalid language code. - """ - with ( - language_tool_python.LanguageTool("en-US") as tool, - pytest.raises(ValueError, match="unsupported language"), - ): - language_tool_python.LanguageTag("xx-XX", tool._get_languages()) - - def test_http_get_403_forbidden() -> None: """Test that http_get raises PathError when receiving a 403 Forbidden status code. - This test verifies that the function correctly handles forbidden access errors when - attempting to download files. - :raises AssertionError: If PathError is not raised for a 403 status code. """ mock_response = MockDownloadResponse(b"", status_code=403) @@ -153,9 +113,6 @@ def test_http_get_403_forbidden() -> None: def test_http_get_other_error_codes() -> None: """Test PathError handling for unexpected HTTP status codes. - This test verifies that the function correctly handles different HTTP error codes - like 500 (Internal Server Error), 503 (Service Unavailable), etc. - :raises AssertionError: If PathError is not raised for error status codes. """ error_codes = [500, 502, 503, 504] @@ -562,49 +519,3 @@ def test_latest_snapshot_download_renames_archive_root_to_current_date( local_language_tool.download() get_mock.assert_not_called() - - -def test_install_oldest_supported_version() -> None: - """Test that downloading the oldest supported LanguageTool version works correctly. - - This test verifies that the tool can successfully download and initialize with the - oldest version that is still supported. - - :raises AssertionError: If the tool fails to initialize with the oldest supported - version. 
- """ - try: - with language_tool_python.LanguageTool( - "en-US", - language_tool_download_version="4.0", - ) as tool: - assert tool.language_tool_download_version == "4.0" - except LanguageToolError: - pytest.fail("Failed to download or initialize the oldest supported version.") - - -def test_install_snapshot_version() -> None: - """Test that downloading the snapshot version of LanguageTool works correctly. - - This test verifies that the tool can successfully download and initialize with the - snapshot of yesterday. - - :raises AssertionError: If the tool fails to initialize with the snapshot version. - """ - try: - with language_tool_python.LanguageTool( - "en-US", - language_tool_download_version=( - (datetime.now(timezone.utc) - timedelta(days=3)).strftime("%Y%m%d") - ), - ) as tool: - assert tool.language_tool_download_version == ( - datetime.now(timezone.utc) - timedelta(days=3) - ).strftime("%Y%m%d") - except LanguageToolError: - pytest.skip( - ( - "Failed to download or initialize the snapshot version. This may be " - "due to a missing snapshot for the expected date." 
- ), - ) diff --git a/tests/test_safe_zip.py b/tests/unit/test_safe_zip.py similarity index 99% rename from tests/test_safe_zip.py rename to tests/unit/test_safe_zip.py index 9741845..328c7d4 100644 --- a/tests/test_safe_zip.py +++ b/tests/unit/test_safe_zip.py @@ -1,4 +1,4 @@ -"""Tests for safe ZIP extraction.""" +"""Unit tests for safe ZIP extraction.""" import contextlib import hashlib diff --git a/uv.lock b/uv.lock index dcdd4ab..07ecf5e 100644 --- a/uv.lock +++ b/uv.lock @@ -392,6 +392,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/b2/50e9b292b5cac13e9e81272c7171301abc753a60460d21505b606e15cf21/furo-2025.12.19-py3-none-any.whl", hash = "sha256:bb0ead5309f9500130665a26bee87693c41ce4dbdff864dbfb6b0dae4673d24f", size = 339262, upload-time = "2025-12-19T17:34:38.905Z" }, ] +[[package]] +name = "hypothesis" +version = "6.155.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/aa/9a91a4addf285702a98713da44b3581799539426436617bfb8914478c166/hypothesis-6.155.6.tar.gz", hash = "sha256:7569e1897690336c85d49d8391b49ec6ab83d951009515bfc29faebbac286cf5", size = 478038, upload-time = "2026-06-19T13:21:23.379Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/a9/4c17e962c2e9cbc314bb579ed2e2b2da45d7b6b942aab6948d14d85abfea/hypothesis-6.155.6-py3-none-any.whl", hash = "sha256:a96d9a29f6bbc8ccac39dd84e140892da76765464929f401a4181b90c20c9ad1", size = 544521, upload-time = "2026-06-19T13:21:20.934Z" }, +] + [[package]] name = "idna" version = "3.18" @@ -457,7 +470,9 @@ quality = [ { name = "ruff" }, ] tests = [ + { name = "hypothesis" }, { name = "pytest" }, + { name = "pytest-benchmark" }, { name = "pytest-cov" }, ] types = [ @@ -486,7 +501,9 @@ quality = [ { name = "ruff", specifier = "==0.15.16" }, ] tests = [ + { name = "hypothesis" }, { name = 
"pytest" }, + { name = "pytest-benchmark" }, { name = "pytest-cov" }, ] types = [ @@ -788,6 +805,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, ] +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, +] + [[package]] name = "pygments" version = "2.20.0" @@ -815,6 +841,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/24/25/1de2678b631f5a49215c6c96fff41ba892b0a34df68d6d80292b1b48aa7f/pytest-9.1.1-py3-none-any.whl", hash = "sha256:37a86b45efb9a47a61a36449063e8e18d0cab3161329fc099eb21783169c4f0c", size = 386536, upload-time = "2026-06-19T10:58:31.347Z" }, ] +[[package]] +name = "pytest-benchmark" +version = "5.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" }, +] + [[package]] name = "pytest-cov" version = "7.1.0" @@ -887,6 +926,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/07/2ebca9b11fb9be7340a818d8d6f63feaebb146be2c4afbd6061701d6df6e/snowballstemmer-3.1.1-py3-none-any.whl", hash = "sha256:7e207fa178741da09cdee59d3ecec3827ad5f92b1fc5c9ff3755b639f71f5752", size = 104164, upload-time = "2026-06-03T00:56:38.614Z" }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "soupsieve" version = "2.8.4" From 456eb875259bca59ff41b3b9fcc26fd1f7b2c637 Mon Sep 17 00:00:00 2001 From: mdevolde Date: Sat, 27 Jun 2026 19:19:17 +0200 Subject: [PATCH 2/7] test: add unit tests to increase coverage --- .github/workflows/test.yml | 2 +- .gitignore | 3 + pyproject.toml | 14 +- pytest.ini | 2 +- src/language_tool_python/__main__.py | 4 +- src/language_tool_python/_internals/compat.py | 4 +- src/language_tool_python/config_file.py | 2 +- src/language_tool_python/download_lt.py | 2 +- tests/unit/test_api_types.py | 64 +++ tests/unit/test_cli_unit.py | 195 
++++++++ tests/unit/test_config_unit.py | 235 ++++++++++ tests/unit/test_download_unit.py | 305 +++++++++++++ tests/unit/test_internals_utils.py | 224 ++++++++++ tests/unit/test_language_tag.py | 168 +++++++ tests/unit/test_match.py | 421 ++++++++++++++++++ tests/unit/test_utils.py | 218 +++++++++ 16 files changed, 1843 insertions(+), 20 deletions(-) create mode 100644 tests/unit/test_api_types.py create mode 100644 tests/unit/test_cli_unit.py create mode 100644 tests/unit/test_config_unit.py create mode 100644 tests/unit/test_download_unit.py create mode 100644 tests/unit/test_internals_utils.py create mode 100644 tests/unit/test_language_tag.py create mode 100644 tests/unit/test_match.py create mode 100644 tests/unit/test_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ac8a0a3..9b4c7f7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,7 +35,7 @@ jobs: - os: ubuntu-26.04 python-version: "3.14" - os: ubuntu-26.04 - python-version: "3.15.0-beta.2" + python-version: "3.15.0-beta.3" - os: macos-26 python-version: "3.14" - os: windows-2025 diff --git a/.gitignore b/.gitignore index e15106e..3f03577 100644 --- a/.gitignore +++ b/.gitignore @@ -204,6 +204,9 @@ cython_debug/ # Ruff stuff: .ruff_cache/ +# Pytest tmp_path base directory (project-relative to avoid Windows temp permission issues) +.pytest_tmp/ + # PyPI configuration file .pypirc diff --git a/pyproject.toml b/pyproject.toml index 182c0cd..e540e9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,18 +174,8 @@ warn_unreachable = true warn_unused_configs = true warn_unused_ignores = true -[[tool.mypy.overrides]] -module = ["tests.benchmarks.*"] -# pytest-benchmark is untyped; relax Any restrictions for benchmark files only -disallow_any_unimported = false -disallow_any_expr = false -disallow_any_explicit = false -disallow_any_decorated = false - [[tool.mypy.overrides]] module = ["tests.property.*"] -# hypothesis is untyped; relax Any 
restrictions for property test files only -disallow_any_unimported = false -disallow_any_expr = false -disallow_any_explicit = false +# hypothesis decorators contain Any expressions, so we need to disable the following checks for tests using hypothesis disallow_any_decorated = false +disallow_any_expr = false diff --git a/pytest.ini b/pytest.ini index 2da9216..cbd524a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = -vra --cov=src --cov-report=html --cov-report=xml +addopts = -vra --cov=src --cov-report=html --cov-report=xml --basetemp=.pytest_tmp testpaths = tests markers = unit: fast, isolated tests with no external dependencies diff --git a/src/language_tool_python/__main__.py b/src/language_tool_python/__main__.py index 8ee561a..8300df2 100644 --- a/src/language_tool_python/__main__.py +++ b/src/language_tool_python/__main__.py @@ -58,7 +58,7 @@ def _read_project_version(pyproject: Path) -> str: __version__ = version("language_tool_python") # If the package is not installed in the environment, # read the version from pyproject.toml -except PackageNotFoundError: +except PackageNotFoundError: # pragma: no cover project_root = Path(__file__).resolve().parent.parent pyproject = project_root / "pyproject.toml" __version__ = _read_project_version(pyproject) @@ -258,7 +258,7 @@ def __call__( cli_args.disable_categories.update(rule_values) elif self.dest == "enable_categories": cli_args.enable_categories.update(rule_values) - else: + else: # pragma: no cover err = f"unexpected rules destination: {self.dest}" raise ValueError(err) diff --git a/src/language_tool_python/_internals/compat.py b/src/language_tool_python/_internals/compat.py index be623b1..9dc0631 100644 --- a/src/language_tool_python/_internals/compat.py +++ b/src/language_tool_python/_internals/compat.py @@ -13,11 +13,11 @@ if sys.version_info >= (3, 11): from tomllib import loads as toml_loads else: - from tomli import loads as toml_loads + from tomli import loads as toml_loads 
# pragma: no cover if sys.version_info >= (3, 13): from warnings import deprecated else: - from typing_extensions import deprecated + from typing_extensions import deprecated # pragma: no cover __all__ = ["deprecated", "toml_loads"] diff --git a/src/language_tool_python/config_file.py b/src/language_tool_python/config_file.py index 8b73dbe..c7e703a 100644 --- a/src/language_tool_python/config_file.py +++ b/src/language_tool_python/config_file.py @@ -158,7 +158,7 @@ def _path_validator(v: PathLike[str] | str) -> None: if not p.exists(): err = f"path does not exist: {p}" raise PathError(err) - if not p.is_file() and not p.is_dir(): + if not p.is_file() and not p.is_dir(): # pragma: no cover err = f"path is not a file/directory: {p}" raise PathError(err) diff --git a/src/language_tool_python/download_lt.py b/src/language_tool_python/download_lt.py index 65defa0..1ec3330 100644 --- a/src/language_tool_python/download_lt.py +++ b/src/language_tool_python/download_lt.py @@ -385,7 +385,7 @@ def download(self) -> None: :raises NotImplementedError: Always, unless implemented by a subclass. 
""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover def _get_remote_zip( self, diff --git a/tests/unit/test_api_types.py b/tests/unit/test_api_types.py new file mode 100644 index 0000000..af2b7ce --- /dev/null +++ b/tests/unit/test_api_types.py @@ -0,0 +1,64 @@ +"""Unit tests for _internals/api_types.py TypeGuard helpers.""" + +from language_tool_python._internals.api_types import ( + is_check_response, + is_language_info, +) + + +def test_is_language_info_valid() -> None: + """Accepts a well-formed LanguageInfo dict.""" + assert is_language_info({"code": "en", "longCode": "en-US", "name": "English"}) + + +def test_is_language_info_not_dict() -> None: + """Rejects non-dict values.""" + assert not is_language_info("not a dict") + assert not is_language_info(42) + assert not is_language_info(None) + assert not is_language_info(["code", "longCode", "name"]) + + +def test_is_language_info_missing_field() -> None: + """Rejects dicts with missing required fields.""" + assert not is_language_info({"code": "en", "longCode": "en-US"}) + assert not is_language_info({"code": "en", "name": "English"}) + assert not is_language_info({}) + + +def test_is_language_info_wrong_type() -> None: + """Rejects dicts with non-string field values.""" + assert not is_language_info({"code": 1, "longCode": "en-US", "name": "English"}) + assert not is_language_info({"code": "en", "longCode": None, "name": "English"}) + + +def test_is_check_response_valid() -> None: + """Accepts a well-formed CheckResponse dict.""" + assert is_check_response( + { + "matches": [], + "language": {"code": "en"}, + "warnings": {"incompleteResults": False}, + } + ) + + +def test_is_check_response_not_dict() -> None: + """Rejects non-dict values.""" + assert not is_check_response("not a dict") + assert not is_check_response(None) + assert not is_check_response(123) + + +def test_is_check_response_missing_field() -> None: + """Rejects dicts with missing required fields.""" + assert not 
is_check_response({"matches": [], "language": {}}) + assert not is_check_response({"matches": [], "warnings": {}}) + assert not is_check_response({}) + + +def test_is_check_response_wrong_type() -> None: + """Rejects dicts with wrong field types.""" + assert not is_check_response({"matches": "[]", "language": {}, "warnings": {}}) + assert not is_check_response({"matches": [], "language": "en", "warnings": {}}) + assert not is_check_response({"matches": [], "language": {}, "warnings": "none"}) diff --git a/tests/unit/test_cli_unit.py b/tests/unit/test_cli_unit.py new file mode 100644 index 0000000..33af934 --- /dev/null +++ b/tests/unit/test_cli_unit.py @@ -0,0 +1,195 @@ +"""Unit tests for the CLI helper functions in __main__.py.""" + +from __future__ import annotations + +import io +from pathlib import Path + +import pytest + +from language_tool_python.__main__ import ( + CliArgs, + _read_project_version, + get_input_text, + get_remote_server, + get_rules, + get_text, + parse_args, + print_exception, +) + + +class TestGetRules: + """Tests for the get_rules() rule-string parser.""" + + def test_comma_separated(self) -> None: + """Comma-separated rule IDs are returned as a set.""" + assert get_rules("RULE_A,RULE_B") == {"RULE_A", "RULE_B"} + + def test_uppercases(self) -> None: + """Rule IDs are uppercased.""" + assert get_rules("rule_a") == {"RULE_A"} + + def test_hyphen_allowed(self) -> None: + """Hyphens inside rule IDs are preserved.""" + assert get_rules("MORFOLOGIK-RULE") == {"MORFOLOGIK-RULE"} + + def test_whitespace_separated(self) -> None: + """Whitespace-separated rule IDs are each returned.""" + assert get_rules("RULE_A RULE_B") == {"RULE_A", "RULE_B"} + + def test_empty_string(self) -> None: + """Empty input returns an empty set.""" + assert get_rules("") == set() + + +class TestParseArgsEnabledOnly: + """Tests for the --enabled-only CLI argument validation.""" + + def test_enabled_only_with_disable_raises(self) -> None: + """--enabled-only combined with 
--disable causes SystemExit.""" + with pytest.raises(SystemExit): + parse_args( + [ + "-l", + "en-US", + "--enabled-only", + "-e", + "RULE", + "-d", + "OTHER", + "file.txt", + ] + ) + + def test_enabled_only_with_enable_passes(self) -> None: + """--enabled-only with --enable is accepted.""" + args = parse_args(["-l", "en-US", "--enabled-only", "-e", "RULE", "file.txt"]) + assert args.enabled_only is True + assert "RULE" in args.enable + + +class TestGetRemoteServer: + """Tests for the get_remote_server() URL builder.""" + + def _args(self, host: str | None = None, port: str | None = None) -> CliArgs: + """Build a minimal CliArgs with only remote_host/remote_port set.""" + args = CliArgs() + args.remote_host = host + args.remote_port = port + return args + + def test_no_host_returns_none(self) -> None: + """Returns None when no remote host is set.""" + assert get_remote_server(self._args()) is None + + def test_host_without_port(self) -> None: + """Returns the host name alone when no port is given.""" + assert get_remote_server(self._args(host="localhost")) == "localhost" + + def test_host_with_port(self) -> None: + """Returns host:port when both are provided.""" + result = get_remote_server(self._args(host="localhost", port="8081")) + assert result == "localhost:8081" + + +class TestPrintException: + """Tests for the print_exception() stderr printer.""" + + def test_without_debug_prints_to_stderr( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Without debug=True, only the message is printed to stderr.""" + print_exception(ValueError("test error"), debug=False) + assert "test error" in capsys.readouterr().err + + def test_with_debug_prints_traceback( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """With debug=True, the full traceback is printed to stderr.""" + try: + msg = "original error" + raise ValueError(msg) + except ValueError: + print_exception(ValueError("current error"), debug=True) + captured = capsys.readouterr() + assert 
"ValueError" in captured.err + + +class TestGetText: + """Tests for the get_text() file reader.""" + + def test_reads_file(self, tmp_path: Path) -> None: + """File content is returned as-is when no ignore pattern is given.""" + f = tmp_path / "test.txt" + f.write_text("hello world\n", encoding="utf-8") + result = get_text(str(f), encoding="utf-8", ignore=None) + assert result == "hello world\n" + + def test_ignore_replaces_matching_lines(self, tmp_path: Path) -> None: + """Lines matching the ignore regex are replaced with a newline.""" + f = tmp_path / "test.txt" + f.write_text("keep this\n# skip this\nkeep too\n", encoding="utf-8") + result = get_text(str(f), encoding="utf-8", ignore=r"#.*") + assert "# skip this" not in result + assert "keep this" in result + assert "keep too" in result + + def test_no_ignore_keeps_all(self, tmp_path: Path) -> None: + """All lines are kept when no ignore pattern is set.""" + f = tmp_path / "test.txt" + f.write_text("line1\nline2\n", encoding="utf-8") + result = get_text(str(f), encoding=None, ignore=None) + assert result == "line1\nline2\n" + + +class TestGetInputText: + """Tests for the get_input_text() stdin/file dispatcher.""" + + def _args( + self, ignore_lines: str | None = None, encoding: str | None = None + ) -> CliArgs: + """Build a minimal CliArgs with only ignore_lines/encoding set.""" + args = CliArgs() + args.ignore_lines = ignore_lines + args.encoding = encoding + return args + + def test_reads_from_file(self, tmp_path: Path) -> None: + """Regular filename is read from disk.""" + f = tmp_path / "input.txt" + f.write_text("test content", encoding="utf-8") + result = get_input_text(str(f), self._args()) + assert result == "test content" + + def test_reads_from_stdin(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Filename '-' reads from stdin.""" + monkeypatch.setattr("sys.stdin", io.StringIO("stdin content")) + result = get_input_text("-", self._args()) + assert result == "stdin content" + + def 
test_stdin_with_ignore_lines(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Matching lines from stdin are suppressed when ignore_lines is set.""" + monkeypatch.setattr("sys.stdin", io.StringIO("keep\n# skip\nkeep2\n")) + result = get_input_text("-", self._args(ignore_lines=r"#.*")) + assert "# skip" not in result + assert "keep" in result + + def test_uses_encoding(self, tmp_path: Path) -> None: + """Non-UTF-8 files are decoded with the specified encoding.""" + f = tmp_path / "latin.txt" + content = "caf\xe9" + f.write_bytes(content.encode("latin-1")) + result = get_input_text(str(f), self._args(encoding="latin-1")) + assert "caf" in result + + +class TestReadProjectVersion: + """Tests for _read_project_version().""" + + def test_reads_version_from_pyproject(self) -> None: + """Version string is read from the project's pyproject.toml.""" + pyproject = Path(__file__).parent.parent.parent / "pyproject.toml" + version = _read_project_version(pyproject) + assert isinstance(version, str) + assert version.count(".") >= 1 diff --git a/tests/unit/test_config_unit.py b/tests/unit/test_config_unit.py new file mode 100644 index 0000000..2005a0d --- /dev/null +++ b/tests/unit/test_config_unit.py @@ -0,0 +1,235 @@ +"""Unit tests for config_file.py encoders, validators, and LanguageToolConfig.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from language_tool_python.config_file import ( + LanguageToolConfig, + _bool_encoder, + _comma_list_encoder, + _encode_config, + _int_encoder, + _is_lang_key, + _number_encoder, + _path_encoder, + _path_validator, +) +from language_tool_python.exceptions import PathError + + +class TestBoolEncoder: + """Tests for the _bool_encoder() function.""" + + def test_true(self) -> None: + """True is encoded as the string 'true'.""" + assert _bool_encoder(v=True) == "true" + + def test_false(self) -> None: + """False is encoded as the string 'false'.""" + assert _bool_encoder(v=False) == "false" + + def 
test_truthy_int(self) -> None: + """A truthy integer is encoded as 'true'.""" + assert _bool_encoder(1) == "true" + + def test_falsy_int(self) -> None: + """A falsy integer is encoded as 'false'.""" + assert _bool_encoder(0) == "false" + + +class TestIntEncoder: + """Tests for the _int_encoder() function.""" + + def test_positive(self) -> None: + """A positive integer is converted to its decimal string.""" + assert _int_encoder(42) == "42" + + def test_zero(self) -> None: + """Zero is converted to '0'.""" + assert _int_encoder(0) == "0" + + +class TestNumberEncoder: + """Tests for the _number_encoder() function.""" + + def test_integer(self) -> None: + """An integer value is rendered as a float string.""" + assert _number_encoder(5) == "5.0" + + def test_float(self) -> None: + """A float value is rendered with its decimal part.""" + assert _number_encoder(3.14) == "3.14" + + +class TestCommaListEncoder: + """Tests for the _comma_list_encoder() function.""" + + def test_string_passthrough(self) -> None: + """A plain string is returned unchanged.""" + assert _comma_list_encoder("a,b,c") == "a,b,c" + + def test_list_joined(self) -> None: + """A list of strings is joined with commas.""" + assert _comma_list_encoder(["a", "b", "c"]) == "a,b,c" + + def test_tuple_joined(self) -> None: + """A tuple of strings is joined with commas.""" + assert _comma_list_encoder(("x", "y")) == "x,y" + + def test_single_item(self) -> None: + """A single-element list returns the element without a comma.""" + assert _comma_list_encoder(["only"]) == "only" + + +class TestPathEncoder: + """Tests for the _path_encoder() function.""" + + def test_path_object(self, tmp_path: Path) -> None: + """A Path object is encoded to a string containing the path components.""" + result = _path_encoder(tmp_path / "model") + assert "model" in result + + def test_backslash_escaped(self) -> None: + """Windows backslashes in path strings are escaped or converted.""" + p = Path("C:\\Users\\test\\model") + result 
= _path_encoder(p) + assert "\\\\" in result or "/" in result + + +class TestPathValidator: + """Tests for the _path_validator() function.""" + + def test_existing_file(self, tmp_path: Path) -> None: + """An existing file path passes validation without error.""" + f = tmp_path / "file.txt" + f.write_text("content") + _path_validator(f) + + def test_existing_directory(self, tmp_path: Path) -> None: + """An existing directory path passes validation without error.""" + _path_validator(tmp_path) + + def test_nonexistent_raises(self, tmp_path: Path) -> None: + """A path that does not exist raises PathError.""" + with pytest.raises(PathError, match="does not exist"): + _path_validator(tmp_path / "nonexistent.txt") + + +class TestIsLangKey: + """Tests for the _is_lang_key() predicate.""" + + def test_lang_code_format(self) -> None: + """A key of the form 'lang-XX' is recognized as a language key.""" + assert _is_lang_key("lang-en") is True + + def test_lang_code_dict_path_format(self) -> None: + """A key of the form 'lang-XX-dictPath' is recognized as a language key.""" + assert _is_lang_key("lang-en-dictPath") is True + + def test_not_lang_prefix(self) -> None: + """A key without the 'lang-' prefix is not a language key.""" + assert _is_lang_key("cacheSize") is False + + def test_lang_only_no_code(self) -> None: + """'lang-' with no language code is not a valid language key.""" + assert _is_lang_key("lang-") is False + + def test_lang_too_many_parts(self) -> None: + """A key with more than three parts is not a valid language key.""" + assert _is_lang_key("lang-en-dictPath-extra") is False + + +class TestEncodeConfig: + """Tests for the _encode_config() dict encoder.""" + + def test_int_option(self) -> None: + """An integer option value is encoded as its decimal string.""" + result = _encode_config({"cacheSize": 1000}) + assert result == {"cacheSize": "1000"} + + def test_bool_option(self) -> None: + """A boolean option value is encoded as 'true' or 'false'.""" + result = 
_encode_config({"pipelineCaching": True}) + assert result == {"pipelineCaching": "true"} + + def test_number_option(self) -> None: + """A float option value is encoded as its float string.""" + result = _encode_config({"maxErrorsPerWordRate": 0.5}) + assert result == {"maxErrorsPerWordRate": "0.5"} + + def test_list_option(self) -> None: + """A list option value is encoded as a comma-separated string.""" + result = _encode_config({"blockedReferrers": ["a.com", "b.com"]}) + assert result == {"blockedReferrers": "a.com,b.com"} + + def test_lang_code_option(self) -> None: + """A language-code option is passed through without modification.""" + result = _encode_config({"lang-en": "custom-word"}) + assert result == {"lang-en": "custom-word"} + + def test_lang_dict_path_option(self, tmp_path: Path) -> None: + """A language dict-path option is accepted when the path exists.""" + result = _encode_config({"lang-en-dictPath": str(tmp_path)}) + assert "lang-en-dictPath" in result + + def test_unknown_key_raises(self) -> None: + """An unrecognized config key raises ValueError.""" + with pytest.raises(ValueError, match="unexpected key"): + _encode_config({"unknownKey": "value"}) + + def test_wrong_type_raises(self) -> None: + """A value of the wrong type for a known key raises TypeError.""" + with pytest.raises(TypeError, match="invalid type"): + _encode_config({"cacheSize": "not_an_int"}) + + def test_path_validator_called(self, tmp_path: Path) -> None: + """A path-type config option with a nonexistent path raises PathError.""" + nonexistent = tmp_path / "no_such_model" + with pytest.raises(PathError, match="does not exist"): + _encode_config({"languageModel": str(nonexistent)}) + + +class TestLanguageToolConfig: + """Tests for the LanguageToolConfig class.""" + + def test_empty_config_raises(self) -> None: + """Constructing with an empty dict raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + LanguageToolConfig({}) + + def 
test_valid_config_creates_file(self) -> None: + """A valid config creates a temporary .properties file on disk.""" + cfg = LanguageToolConfig({"cacheSize": 500}) + assert cfg.path + assert Path(cfg.path).exists() + + def test_config_file_content(self) -> None: + """The .properties file contains the expected key=value pair.""" + cfg = LanguageToolConfig({"cacheSize": 500}) + content = Path(cfg.path).read_text(encoding="utf-8") + assert "cacheSize=500" in content + + def test_multiple_options(self) -> None: + """Multiple config options all appear in the .properties file.""" + cfg = LanguageToolConfig({"cacheSize": 100, "pipelineCaching": True}) + content = Path(cfg.path).read_text(encoding="utf-8") + assert "cacheSize=100" in content + assert "pipelineCaching=true" in content + + def test_config_dict_stored(self) -> None: + """The encoded config is stored on the .config attribute.""" + cfg = LanguageToolConfig({"cacheSize": 200}) + assert cfg.config == {"cacheSize": "200"} + + def test_boolean_config(self) -> None: + """A boolean config value is encoded as 'true' or 'false'.""" + cfg = LanguageToolConfig({"premiumOnly": False}) + assert cfg.config == {"premiumOnly": "false"} + + def test_list_config(self) -> None: + """A list config value is encoded as a comma-separated string.""" + cfg = LanguageToolConfig({"disabledRuleIds": ["RULE_A", "RULE_B"]}) + assert cfg.config["disabledRuleIds"] == "RULE_A,RULE_B" diff --git a/tests/unit/test_download_unit.py b/tests/unit/test_download_unit.py new file mode 100644 index 0000000..9111496 --- /dev/null +++ b/tests/unit/test_download_unit.py @@ -0,0 +1,305 @@ +"""Unit tests for download_lt.py helpers (no network, no Java required). + +Note: test_download.py calls importlib.reload(download_lt) which invalidates +static class imports. We access classes via the module object (updated in-place +by reload) to ensure isinstance checks work regardless of test ordering. 
+""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import pytest + +import language_tool_python.download_lt as _dl +from language_tool_python.exceptions import PathError + +if TYPE_CHECKING: + from pathlib import Path + +_JAVA_8_MINOR = 8 +_JAVA_17_MAJOR = 17 +_JAVA_21_MAJOR = 21 +_SHA256_HEX_LENGTH = 64 +_KIBIBYTE = 1024 + + +def return_42(_: object) -> int: + """Return 42, used for monkeypatching.""" + return 42 + + +class TestLoadsManifest: + """Tests for the _loads_manifest() TOML parser.""" + + def test_valid_toml_returns_dict(self) -> None: + """Valid TOML input returns a dict.""" + result = _dl._loads_manifest('[hashes]\n"6.8" = "abc"\n') + assert isinstance(result, dict) + + def test_empty_toml(self) -> None: + """Empty TOML input returns an empty dict.""" + result = _dl._loads_manifest("") + assert result == {} + + +class TestLoadExpectedDownloadSha256: + """Tests for _load_expected_download_sha256().""" + + def test_valid_manifest(self) -> None: + """A well-formed hash entry is parsed to version → hash mapping.""" + sha = "a" * _SHA256_HEX_LENGTH + result = _dl._load_expected_download_sha256(f'"6.8" = "{sha}"\n') + assert result["6.8"] == sha + + def test_non_dict_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A manifest that does not parse to a dict raises PathError.""" + monkeypatch.setattr( + "language_tool_python.download_lt._loads_manifest", + return_42, + ) + with pytest.raises(PathError, match="expected a TOML table"): + _dl._load_expected_download_sha256("anything") + + def test_non_string_value_raises(self) -> None: + """A non-string hash value in the manifest raises PathError.""" + with pytest.raises(PathError, match="expected string keys and values"): + _dl._load_expected_download_sha256('"6.8" = 42\n') + + +class TestValidateDownloadSize: + """Tests for the _validate_download_size() Content-Length checker.""" + + def test_none_returns_none(self) -> None: + """None input (missing header) 
returns None.""" + assert _dl._validate_download_size(None) is None + + def test_valid_size(self) -> None: + """A numeric size string is converted to an int.""" + assert _dl._validate_download_size("1024") == _KIBIBYTE + + def test_zero_is_valid(self) -> None: + """Zero is a valid content-length.""" + assert _dl._validate_download_size("0") == 0 + + def test_invalid_string_raises(self) -> None: + """A non-numeric string raises PathError.""" + with pytest.raises(PathError, match="Invalid Content-Length"): + _dl._validate_download_size("notanumber") + + def test_negative_raises(self) -> None: + """A negative value raises PathError.""" + with pytest.raises(PathError, match="Invalid Content-Length"): + _dl._validate_download_size("-1") + + def test_too_large_raises(self) -> None: + """A size exceeding the maximum raises PathError.""" + with pytest.raises(PathError, match="Refusing to download"): + _dl._validate_download_size(str(512 * 1024 * 1024 + 1)) + + +class TestParseJavaVersion: + """Tests for _parse_java_version() version string parsing.""" + + def test_old_format_quoted(self) -> None: + """The old 'java version "1.8.0_N"' format is parsed to (1, 8).""" + text = 'java version "1.8.0_292"' + major, minor = _dl._parse_java_version(text) + assert major == 1 + assert minor == _JAVA_8_MINOR + + def test_new_format_17(self) -> None: + """The new 'openjdk N.M.P' format is parsed to (17, 0).""" + text = "openjdk 17.0.1 2021-10-19" + major, minor = _dl._parse_java_version(text) + assert major == _JAVA_17_MAJOR + assert minor == 0 + + def test_new_format_21(self) -> None: + """The new quoted 'openjdk version "21.0.2"' format is parsed to (21, ...).""" + text = 'openjdk version "21.0.2" 2024-01-16' + major, _ = _dl._parse_java_version(text) + assert major == _JAVA_21_MAJOR + + def test_unparseable_raises(self) -> None: + """A string that matches no known pattern causes SystemExit.""" + with pytest.raises(SystemExit, match="Could not parse"): + _dl._parse_java_version("not 
a java version string") + + def test_multiline_output(self) -> None: + """Multiline java -version output is parsed from the first line.""" + text = ( + 'openjdk version "21.0.2" 2024-01-16\n' + "OpenJDK Runtime Environment (build 21.0.2+13)\n" + "OpenJDK 64-Bit Server VM (build 21.0.2+13, mixed mode, sharing)\n" + ) + major, _ = _dl._parse_java_version(text) + assert major == _JAVA_21_MAJOR + + +class TestLocalLanguageToolFromVersionName: + """Tests for LocalLanguageTool.from_version_name() factory method.""" + + def test_release_version(self) -> None: + """An 'X.Y' string returns a ReleaseLocalLanguageTool instance.""" + lt = _dl.LocalLanguageTool.from_version_name("6.8") + assert isinstance(lt, _dl.ReleaseLocalLanguageTool) + + def test_snapshot_date_version(self) -> None: + """A 'YYYYMMDD' string returns a SnapshotLocalLanguageTool instance.""" + lt = _dl.LocalLanguageTool.from_version_name("20240101") + assert isinstance(lt, _dl.SnapshotLocalLanguageTool) + + def test_snapshot_latest(self) -> None: + """'latest' returns a SnapshotLocalLanguageTool instance.""" + lt = _dl.LocalLanguageTool.from_version_name("latest") + assert isinstance(lt, _dl.SnapshotLocalLanguageTool) + + def test_unknown_format_raises(self) -> None: + """An unrecognized version string raises ValueError.""" + with pytest.raises(ValueError, match="Unknown LanguageTool version"): + _dl.LocalLanguageTool.from_version_name("unknown-format") + + def test_default_version(self) -> None: + """Calling without arguments returns the default release version.""" + lt = _dl.LocalLanguageTool.from_version_name() + assert isinstance(lt, _dl.ReleaseLocalLanguageTool) + + +class TestLocalLanguageToolFromPath: + """Tests for LocalLanguageTool.from_path() directory-name parser.""" + + def test_valid_release_path(self, tmp_path: Path) -> None: + """A 'LanguageTool-X.Y' directory name returns a ReleaseLocalLanguageTool.""" + d = tmp_path / "LanguageTool-6.8" + lt = _dl.LocalLanguageTool.from_path(d) + assert 
isinstance(lt, _dl.ReleaseLocalLanguageTool) + + def test_valid_snapshot_path(self, tmp_path: Path) -> None: + """A 'LanguageTool-YYYYMMDD' directory returns a SnapshotLocalLanguageTool.""" + d = tmp_path / "LanguageTool-20240101" + lt = _dl.LocalLanguageTool.from_path(d) + assert isinstance(lt, _dl.SnapshotLocalLanguageTool) + + def test_invalid_path_raises(self, tmp_path: Path) -> None: + """A directory name without the expected pattern raises ValueError.""" + d = tmp_path / "not-a-lt-dir" + with pytest.raises(ValueError, match="Could not determine"): + _dl.LocalLanguageTool.from_path(d) + + +class TestReleaseLocalLanguageTool: + """Tests for ReleaseLocalLanguageTool attributes and ordering.""" + + def test_version_name(self) -> None: + """The version_name attribute reflects the version given at construction.""" + lt = _dl.ReleaseLocalLanguageTool("6.8") + assert lt.version_name == "6.8" + + def test_eq(self) -> None: + """Two instances with the same version are equal.""" + a = _dl.ReleaseLocalLanguageTool("6.8") + b = _dl.ReleaseLocalLanguageTool("6.8") + assert a == b + + def test_neq(self) -> None: + """Instances with different versions are not equal.""" + a = _dl.ReleaseLocalLanguageTool("6.8") + b = _dl.ReleaseLocalLanguageTool("6.7") + assert a != b + + def test_lt(self) -> None: + """An older version is less than a newer version.""" + old = _dl.ReleaseLocalLanguageTool("6.7") + new = _dl.ReleaseLocalLanguageTool("6.8") + assert old < new + + def test_hash(self) -> None: + """Equal instances produce the same hash.""" + a = _dl.ReleaseLocalLanguageTool("6.8") + b = _dl.ReleaseLocalLanguageTool("6.8") + assert hash(a) == hash(b) + + def test_in_set(self) -> None: + """Duplicate instances collapse to one element in a set.""" + s = {_dl.ReleaseLocalLanguageTool("6.8"), _dl.ReleaseLocalLanguageTool("6.8")} + assert len(s) == 1 + + def test_download_url_new_version(self) -> None: + """The download URL for a recent version contains the version string.""" + lt = 
_dl.ReleaseLocalLanguageTool("6.8") + assert "6.8" in lt.download_url + + def test_download_url_old_version_uses_archive(self) -> None: + """The download URL for an old version also contains the version string.""" + lt = _dl.ReleaseLocalLanguageTool("4.0") + assert "4.0" in lt.download_url + + +class TestSnapshotLocalLanguageTool: + """Tests for SnapshotLocalLanguageTool attributes and equality.""" + + def test_version_name_date(self) -> None: + """A date-format version name is stored as-is.""" + lt = _dl.SnapshotLocalLanguageTool("20240101") + assert lt.version_name == "20240101" + + def test_version_name_latest_expands_to_date(self) -> None: + """'latest' expands to an 8-digit date string.""" + lt = _dl.SnapshotLocalLanguageTool("latest") + assert re.match(r"^\d{8}$", lt.version_name) + + def test_eq(self) -> None: + """Two instances with the same date are equal.""" + a = _dl.SnapshotLocalLanguageTool("20240101") + b = _dl.SnapshotLocalLanguageTool("20240101") + assert a == b + + def test_neq(self) -> None: + """Instances with different dates are not equal.""" + a = _dl.SnapshotLocalLanguageTool("20240101") + b = _dl.SnapshotLocalLanguageTool("20240201") + assert a != b + + def test_hash(self) -> None: + """Equal instances produce the same hash.""" + a = _dl.SnapshotLocalLanguageTool("20240101") + b = _dl.SnapshotLocalLanguageTool("20240101") + assert hash(a) == hash(b) + + +class TestGetZipHash: + """Tests for _get_zip_hash() SHA-256 lookup.""" + + def test_bypass_env_returns_none_with_warning( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """LTP_BYPASS_VERIFIED_DOWNLOADS=true skips verification with a warning.""" + monkeypatch.setenv("LTP_BYPASS_VERIFIED_DOWNLOADS", "true") + with pytest.warns(RuntimeWarning, match="bypassed"): + result = _dl._get_zip_hash("6.8") + assert result is None + + def test_known_version_returns_hash(self) -> None: + """A version present in the integrity manifest returns a 64-char hex hash.""" + if not 
_dl._EXPECTED_DOWNLOAD_SHA256: + pytest.skip("No known hashes in manifest") + version_name = next(iter(_dl._EXPECTED_DOWNLOAD_SHA256)) + result = _dl._get_zip_hash(version_name) + assert result is not None + assert len(result) == _SHA256_HEX_LENGTH + + def test_unknown_version_returns_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A version absent from the manifest returns None.""" + monkeypatch.delenv("LTP_BYPASS_VERIFIED_DOWNLOADS", raising=False) + result = _dl._get_zip_hash("0.0") + assert result is None + + def test_invalid_hash_in_env_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + """An invalid SHA-256 value in LTP_DOWNLOAD_SHA256 raises PathError.""" + monkeypatch.setenv("LTP_DOWNLOAD_SHA256", "not-a-valid-sha256") + with pytest.raises(PathError, match="Invalid SHA-256"): + _dl._get_zip_hash("6.8") diff --git a/tests/unit/test_internals_utils.py b/tests/unit/test_internals_utils.py new file mode 100644 index 0000000..7a41cb3 --- /dev/null +++ b/tests/unit/test_internals_utils.py @@ -0,0 +1,224 @@ +"""Unit tests for language_tool_python._internals.utils.""" + +from __future__ import annotations + +import subprocess +import sys +import time +from typing import TYPE_CHECKING + +import psutil +import pytest + +from language_tool_python._internals.utils import ( + get_env_float, + get_env_int, + get_language_tool_download_path, + get_locale_language, + kill_process_force, + parse_url, + version_tuple, +) +from language_tool_python.exceptions import PathError + +if TYPE_CHECKING: + from pathlib import Path + +_DEFAULT_INT = 42 +_ENV_INT_VALUE = 100 +_DEFAULT_FLOAT = 1.5 + + +class TestParseUrl: + """Tests for parse_url() scheme normalisation.""" + + def test_full_url_unchanged(self) -> None: + """A complete http URL is returned as-is.""" + assert parse_url("http://localhost:8081") == "http://localhost:8081" + + def test_https_url_unchanged(self) -> None: + """A complete https URL is returned as-is.""" + assert 
parse_url("https://example.com") == "https://example.com" + + def test_adds_http_scheme(self) -> None: + """A host:port string without a scheme gets http:// prepended.""" + result = parse_url("localhost:8081") + assert result.startswith("http://") + assert "localhost" in result + + def test_canonical_form(self) -> None: + """An already-complete URL with trailing slash is returned unchanged.""" + assert parse_url("http://localhost:8081/") == "http://localhost:8081/" + + +class TestGetEnvInt: + """Tests for get_env_int() environment variable reader.""" + + def test_returns_default_when_absent(self, monkeypatch: pytest.MonkeyPatch) -> None: + """The default is returned when the variable is not set.""" + monkeypatch.delenv("TEST_INT_VAR", raising=False) + assert get_env_int("TEST_INT_VAR", _DEFAULT_INT) == _DEFAULT_INT + + def test_reads_valid_value(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A valid integer string in the environment is returned as an int.""" + monkeypatch.setenv("TEST_INT_VAR", str(_ENV_INT_VALUE)) + assert get_env_int("TEST_INT_VAR", 0) == _ENV_INT_VALUE + + def test_raises_on_non_integer(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A non-numeric string raises PathError.""" + monkeypatch.setenv("TEST_INT_VAR", "notanint") + with pytest.raises(PathError, match="Invalid integer"): + get_env_int("TEST_INT_VAR", 0) + + def test_raises_on_zero(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Zero is not a valid positive integer and raises PathError.""" + monkeypatch.setenv("TEST_INT_VAR", "0") + with pytest.raises(PathError, match="Invalid integer"): + get_env_int("TEST_INT_VAR", 0) + + def test_raises_on_negative(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A negative integer string raises PathError.""" + monkeypatch.setenv("TEST_INT_VAR", "-5") + with pytest.raises(PathError, match="Invalid integer"): + get_env_int("TEST_INT_VAR", 0) + + +class TestGetEnvFloat: + """Tests for get_env_float() environment variable reader.""" + 
+ def test_returns_default_when_absent(self, monkeypatch: pytest.MonkeyPatch) -> None: + """The default is returned when the variable is not set.""" + monkeypatch.delenv("TEST_FLOAT_VAR", raising=False) + assert get_env_float("TEST_FLOAT_VAR", _DEFAULT_FLOAT) == _DEFAULT_FLOAT + + def test_reads_valid_value(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A valid float string is returned as a float.""" + monkeypatch.setenv("TEST_FLOAT_VAR", "3.14") + assert get_env_float("TEST_FLOAT_VAR", 0.0) == pytest.approx(3.14) + + def test_raises_on_non_float(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A non-numeric string raises PathError.""" + monkeypatch.setenv("TEST_FLOAT_VAR", "notafloat") + with pytest.raises(PathError, match="Invalid float"): + get_env_float("TEST_FLOAT_VAR", 0.0) + + def test_raises_on_zero(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Zero is not a valid positive float and raises PathError.""" + monkeypatch.setenv("TEST_FLOAT_VAR", "0.0") + with pytest.raises(PathError, match="Invalid float"): + get_env_float("TEST_FLOAT_VAR", 1.0) + + def test_raises_on_negative(self, monkeypatch: pytest.MonkeyPatch) -> None: + """A negative float string raises PathError.""" + monkeypatch.setenv("TEST_FLOAT_VAR", "-1.0") + with pytest.raises(PathError, match="Invalid float"): + get_env_float("TEST_FLOAT_VAR", 1.0) + + def test_raises_on_inf(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Infinity is not a valid positive float and raises PathError.""" + monkeypatch.setenv("TEST_FLOAT_VAR", "inf") + with pytest.raises(PathError, match="Invalid float"): + get_env_float("TEST_FLOAT_VAR", 1.0) + + +class TestGetLanguageToolDownloadPath: + """Tests for get_language_tool_download_path() path resolver.""" + + def test_returns_path( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """The returned path exists and is a directory.""" + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + path = get_language_tool_download_path() + assert 
path.exists() + assert path.is_dir() + + def test_creates_directory( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """A non-existent directory under LTP_PATH is created on first use.""" + new_dir = tmp_path / "new_subdir" + monkeypatch.setenv("LTP_PATH", str(new_dir)) + path = get_language_tool_download_path() + assert path.exists() + + def test_default_path_in_home(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Without LTP_PATH, the default path contains 'language_tool_python'.""" + monkeypatch.delenv("LTP_PATH", raising=False) + path = get_language_tool_download_path() + assert "language_tool_python" in str(path) + + +class TestGetLocaleLanguage: + """Tests for get_locale_language() system locale lookup.""" + + def test_returns_string(self) -> None: + """The function returns a non-empty string.""" + result = get_locale_language() + assert isinstance(result, str) + assert len(result) > 0 + + +class TestKillProcessForce: + """Tests for kill_process_force() process terminator.""" + + def test_raises_when_no_args(self) -> None: + """Calling with neither pid nor proc raises ValueError.""" + with pytest.raises(ValueError, match="Must pass either pid or proc"): + kill_process_force() + + def test_kills_by_pid(self) -> None: + """A process is terminated when its pid is given.""" + proc = subprocess.Popen( + [sys.executable, "-c", "import time; time.sleep(60)"], + ) + kill_process_force(pid=proc.pid) + proc.wait(timeout=5) + + def test_kills_by_proc(self) -> None: + """A process is terminated when a psutil.Process object is given.""" + proc = subprocess.Popen( + [sys.executable, "-c", "import time; time.sleep(60)"], + ) + ps_proc = psutil.Process(proc.pid) + kill_process_force(proc=ps_proc) + proc.wait(timeout=5) + + def test_kills_process_with_children(self) -> None: + """A process and its children are all terminated.""" + parent = subprocess.Popen( + [ + sys.executable, + "-c", + ( + "import subprocess, sys, time; " + 
"subprocess.Popen([sys.executable, '-c', " + "'import time; time.sleep(60)']); " + "time.sleep(60)" + ), + ], + ) + time.sleep(0.3) + kill_process_force(pid=parent.pid) + parent.wait(timeout=10) + + def test_nonexistent_pid_is_silent(self) -> None: + """A nonexistent pid is silently ignored.""" + kill_process_force(pid=999999999) + + +class TestVersionTuple: + """Tests for version_tuple() version string parser.""" + + def test_parses_version(self) -> None: + """A 'X.Y' version string is parsed to a (X, Y) int tuple.""" + assert version_tuple("6.8") == (6, 8) + + def test_parses_version_with_zeros(self) -> None: + """A 'X.0' version string is parsed correctly.""" + assert version_tuple("4.0") == (4, 0) + + def test_raises_on_invalid_format(self) -> None: + """A version string without a dot raises ValueError.""" + with pytest.raises(ValueError, match="not enough values"): + version_tuple("invalid") diff --git a/tests/unit/test_language_tag.py b/tests/unit/test_language_tag.py new file mode 100644 index 0000000..7ac4d3e --- /dev/null +++ b/tests/unit/test_language_tag.py @@ -0,0 +1,168 @@ +"""Unit tests for LanguageTag normalization and comparison.""" + +import pytest + +from language_tool_python.language_tag import LanguageTag + +_LANGS = ["en-US", "en-GB", "en", "de-DE", "fr-FR", "pt-BR"] + +_SET_SIZE_TWO = 2 + + +def _tag(tag: str, languages: list[str] = _LANGS) -> LanguageTag: + """Construct a LanguageTag against _LANGS by default.""" + return LanguageTag(tag, languages) + + +class TestInit: + """Tests for basic LanguageTag initialization and normalization.""" + + def test_exact_match(self) -> None: + """An exact match in the language list is returned unchanged.""" + lt = _tag("en-US") + assert lt.normalized_tag == "en-US" + + def test_underscore_normalized_to_dash(self) -> None: + """Underscore locale separators are converted to dashes.""" + lt = _tag("en_US") + assert lt.normalized_tag == "en-US" + + def test_case_insensitive(self) -> None: + """Tag lookup is 
case-insensitive.""" + lt = _tag("EN-us") + assert lt.normalized_tag == "en-US" + + def test_tag_stored(self) -> None: + """The original (pre-normalization) tag is preserved.""" + lt = _tag("en-US") + assert lt.tag == "en-US" + + def test_languages_stored(self) -> None: + """The language list is accessible on the tag object.""" + lt = _tag("en-US") + assert "en-US" in lt.languages + + +class TestNormalizePosix: + """Tests for POSIX/C locale fallback behaviour.""" + + def test_c_locale_falls_back_to_en_us(self) -> None: + """'C' locale resolves to en-US when available.""" + lt = _tag("C") + assert lt.normalized_tag == "en-US" + + def test_posix_locale_falls_back_to_en_us(self) -> None: + """'POSIX' locale resolves to en-US when available.""" + lt = _tag("POSIX") + assert lt.normalized_tag == "en-US" + + def test_c_dot_variant(self) -> None: + """'C.UTF-8' resolves to en-US when available.""" + lt = _tag("C.UTF-8") + assert lt.normalized_tag == "en-US" + + def test_posix_prefers_en_gb_when_no_en_us(self) -> None: + """'C' locale falls back to en-GB when en-US is absent.""" + lt = LanguageTag("C", ["en-GB", "fr-FR"]) + assert lt.normalized_tag == "en-GB" + + def test_posix_falls_to_en_when_no_en_us_or_gb(self) -> None: + """'C' locale falls back to bare 'en' when no regional variant exists.""" + lt = LanguageTag("C", ["en", "fr-FR"]) + assert lt.normalized_tag == "en" + + def test_posix_raises_when_no_english(self) -> None: + """'C' locale raises ValueError when no English variant is available.""" + with pytest.raises(ValueError, match="unsupported language"): + LanguageTag("C", ["de-DE", "fr-FR"]) + + +class TestNormalizeFallback: + """Tests for regex-based region-stripping fallback.""" + + def test_language_only_matches_base(self) -> None: + """A bare language code matches the base language entry.""" + lt = _tag("en") + assert lt.normalized_tag == "en" + + def test_regex_fallback_to_base_language(self) -> None: + """An exact-match tag is returned as-is.""" + lt = 
_tag("pt-BR") + assert lt.normalized_tag == "pt-BR" + + def test_regex_fallback_strips_region(self) -> None: + """A tag with an unavailable region falls back to the base language.""" + lt = LanguageTag("en-AU", ["en", "de-DE"]) + assert lt.normalized_tag == "en" + + def test_empty_tag_raises(self) -> None: + """An empty tag string raises ValueError.""" + with pytest.raises(ValueError, match="empty language tag"): + _tag("") + + def test_unsupported_tag_raises(self) -> None: + """A tag with no match raises ValueError.""" + with pytest.raises(ValueError, match="unsupported language"): + _tag("zz-ZZ") + + def test_unmatched_pattern_raises(self) -> None: + """A non-language-like string raises ValueError.""" + with pytest.raises(ValueError, match="unsupported language"): + _tag("123invalid") + + +class TestComparisons: + """Tests for LanguageTag equality, ordering, and hashing.""" + + def test_eq_same_tag(self) -> None: + """Two tags with the same value are equal.""" + assert _tag("en-US") == _tag("en-US") + + def test_eq_with_string(self) -> None: + """A LanguageTag equals its normalized string.""" + assert _tag("en-US") == "en-US" + + def test_eq_not_equal(self) -> None: + """Tags with different values are not equal.""" + assert _tag("en-US") != _tag("de-DE") + + def test_eq_not_implemented_for_non_str(self) -> None: + """Comparing with a non-string returns NotImplemented.""" + assert _tag("en-US").__eq__(42) is NotImplemented + + def test_lt_ordering(self) -> None: + """Tags are ordered lexicographically by their normalized value.""" + assert _tag("de-DE") < _tag("en-US") + + def test_lt_not_implemented_for_non_str(self) -> None: + """Less-than comparison with a non-string returns NotImplemented.""" + assert _tag("en-US").__lt__(42) is NotImplemented + + def test_hash_equal_tags(self) -> None: + """Equal tags produce the same hash.""" + assert hash(_tag("en-US")) == hash(_tag("en-US")) + + def test_hash_different_tags(self) -> None: + """Different tags produce 
different hashes (high probability).""" + assert hash(_tag("en-US")) != hash(_tag("de-DE")) + + def test_in_set(self) -> None: + """Two distinct tags result in a two-element set.""" + s = {_tag("en-US"), _tag("de-DE")} + assert len(s) == _SET_SIZE_TWO + + +class TestStrRepr: + """Tests for LanguageTag string representations.""" + + def test_str_returns_normalized(self) -> None: + """str() returns the normalized tag.""" + assert str(_tag("en-US")) == "en-US" + + def test_repr_format(self) -> None: + """repr() uses the canonical angle-bracket format.""" + assert repr(_tag("en-US")) == '' + + def test_total_ordering_gt(self) -> None: + """Greater-than comparison works via total_ordering.""" + assert _tag("en-US") > _tag("de-DE") diff --git a/tests/unit/test_match.py b/tests/unit/test_match.py new file mode 100644 index 0000000..fdccbc8 --- /dev/null +++ b/tests/unit/test_match.py @@ -0,0 +1,421 @@ +"""Unit tests for the Match class and related helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from language_tool_python.match import ( + Match, + _four_byte_char_positions, + _get_match_ordered_dict, + is_check_match, +) + +if TYPE_CHECKING: + from language_tool_python._internals.api_types import CheckMatch + +_DEFAULT_OFFSET = 8 +_DEFAULT_LENGTH = 4 +_DEFAULT_CONTEXT_OFFSET = 8 +_NUM_MATCH_FIELDS = 10 + + +def _make_attrib( # noqa: PLR0913 + *, + message: str = "Possible spelling mistake.", + short_message: str = "Spelling mistake", + replacements: list[str] | None = None, + offset: int = 8, + length: int = 4, + context_text: str = "This is noot okay.", + context_offset: int = 8, + sentence: str = "This is noot okay.", + rule_id: str = "MORFOLOGIK_RULE_EN_US", + rule_desc: str = "Possible spelling mistake", + issue_type: str = "misspelling", + category_id: str = "TYPOS", + category_name: str = "Possible Typo", +) -> CheckMatch: + repl_list: list[str] = replacements if replacements is not None else ["not", "noon"] + 
return { + "message": message, + "shortMessage": short_message, + "replacements": [{"value": r} for r in repl_list], + "offset": offset, + "length": length, + "context": {"text": context_text, "offset": context_offset, "length": length}, + "sentence": sentence, + "type": {"typeName": "Other"}, + "rule": { + "id": rule_id, + "description": rule_desc, + "issueType": issue_type, + "category": {"id": category_id, "name": category_name}, + }, + "ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + + +def _make_match(text: str = "This is noot okay.", **kwargs: object) -> Match: + return Match(_make_attrib(**kwargs), text) # type: ignore[arg-type] + + +class TestMatchInit: + """Tests for Match.__init__() attribute mapping.""" + + def test_basic_attributes(self) -> None: + """Default attributes are populated correctly from the attrib dict.""" + m = _make_match() + assert m.rule_id == "MORFOLOGIK_RULE_EN_US" + assert m.message == "Possible spelling mistake." + assert m.replacements == ["not", "noon"] + assert m.offset == _DEFAULT_OFFSET + assert m.error_length == _DEFAULT_LENGTH + assert m.category == "TYPOS" + assert m.rule_issue_type == "misspelling" + assert m.sentence == "This is noot okay." + + def test_context_attributes(self) -> None: + """Context text and offset are set from the nested context dict.""" + m = _make_match() + assert m.context == "This is noot okay." 
+ assert m.offset_in_context == _DEFAULT_CONTEXT_OFFSET + + def test_unicode_normalization(self) -> None: + """Message text is NFKC-normalized on construction.""" + # "fi" (U+FB01 LATIN SMALL LIGATURE FI) → "fi" + m = _make_match(message="find the error") + assert m.message == "find the error" + + def test_four_byte_char_adjustment(self) -> None: + """A 4-byte emoji before the match shifts the Python offset by 1.""" + # "🌅" at position 0 is 1 Python char but 2 Java chars + # Java offset 3 → Python offset 2 ("🌅 he" → 'h' is at index 2) + text = "🌅 hello world" + attrib = _make_attrib( + offset=3, + length=5, + context_text="🌅 hello world", + context_offset=3, + sentence="🌅 hello world", + ) + m = Match(attrib, text) + adjusted_offset = 2 + assert m.offset == adjusted_offset + + def test_no_adjustment_without_four_byte_chars(self) -> None: + """Offsets are unchanged when no 4-byte characters precede the match.""" + text = "Hello world today" + expected_offset = 6 + m = Match( + _make_attrib( + offset=expected_offset, + length=5, + context_text=text, + context_offset=expected_offset, + sentence=text, + ), + text, + ) + assert m.offset == expected_offset + + def test_same_text_reuses_cache(self) -> None: + """Two matches on the same text share the cached position list.""" + text = "Same text here." 
+ explicit_offset = 5 + m1 = Match(_make_attrib(context_text=text, sentence=text), text) + m2 = Match( + _make_attrib( + context_text=text, + sentence=text, + offset=explicit_offset, + length=_DEFAULT_LENGTH, + context_offset=explicit_offset, + ), + text, + ) + assert text == Match.PREVIOUS_MATCHES_TEXT + assert m1.offset == _DEFAULT_OFFSET + assert m2.offset == explicit_offset + + +class TestFourByteCharPositions: + """Tests for _four_byte_char_positions() helper.""" + + def test_empty_string(self) -> None: + """An empty string has no 4-byte char positions.""" + assert _four_byte_char_positions("") == [] + + def test_ascii_only(self) -> None: + """A pure-ASCII string has no 4-byte char positions.""" + assert _four_byte_char_positions("hello") == [] + + def test_emoji_at_start(self) -> None: + """An emoji at position 0 is reported at index 0.""" + assert _four_byte_char_positions("🌅abc") == [0] + + def test_multiple_emojis(self) -> None: + """Two consecutive emojis are reported at their UTF-16 (Java) offsets.""" + positions = _four_byte_char_positions("🌅🎉abc") + assert positions == [0, 2] + + def test_emoji_in_middle(self) -> None: + """An emoji in the middle of ASCII text is reported at the correct index.""" + positions = _four_byte_char_positions("ab🌅cd") + assert positions == [2] + + +class TestMatchOrderedDict: + """Tests for _get_match_ordered_dict() field-type registry.""" + + def test_returns_all_keys(self) -> None: + """All expected field names are returned in order.""" + d = _get_match_ordered_dict() + expected_keys = [ + "rule_id", + "message", + "replacements", + "offset_in_context", + "context", + "offset", + "error_length", + "category", + "rule_issue_type", + "sentence", + ] + assert list(d.keys()) == expected_keys + + def test_value_types(self) -> None: + """Field types are the expected Python built-ins.""" + d = _get_match_ordered_dict() + assert d["offset"] is int + assert d["rule_id"] is str + assert d["replacements"] is list + + +class TestIsCheckMatch: 
+ """Tests for the is_check_match() type-guard.""" + + def test_valid_check_match(self) -> None: + """A fully populated attrib dict is recognised as a CheckMatch.""" + assert is_check_match(_make_attrib()) + + def test_not_dict(self) -> None: + """Non-dict values are rejected.""" + assert not is_check_match("not a dict") + assert not is_check_match(None) + assert not is_check_match(42) + + def test_missing_field(self) -> None: + """A dict missing a required field is rejected.""" + attrib = dict(_make_attrib()) + del attrib["message"] + assert not is_check_match(attrib) + + def test_wrong_type(self) -> None: + """A dict with a field of the wrong type is rejected.""" + attrib = dict(_make_attrib()) + attrib["offset"] = "not_an_int" + assert not is_check_match(attrib) + + +class TestMatchStr: + """Tests for Match.__str__() human-readable formatter.""" + + def test_str_contains_rule_id(self) -> None: + """The rule ID is present in the string representation.""" + m = _make_match() + s = str(m) + assert "MORFOLOGIK_RULE_EN_US" in s + + def test_str_contains_message(self) -> None: + """The error message is present in the string representation.""" + m = _make_match() + assert "Possible spelling mistake" in str(m) + + def test_str_contains_suggestions(self) -> None: + """Replacement suggestions are present in the string representation.""" + m = _make_match() + assert "not" in str(m) + + def test_str_no_message_skips_message_line(self) -> None: + """A match with no message omits the Message line.""" + m = _make_match(message="") + assert "Message" not in str(m) + + def test_str_no_replacements_skips_suggestion(self) -> None: + """A match with no replacements omits the Suggestion line.""" + m = _make_match(replacements=[]) + assert "Suggestion" not in str(m) + + +class TestMatchRepr: + """Tests for Match.__repr__() machine-readable formatter.""" + + def test_repr_contains_class_name(self) -> None: + """The class name 'Match(' appears in the repr.""" + m = _make_match() + 
assert "Match(" in repr(m) + + def test_repr_contains_rule_id(self) -> None: + """The rule ID appears in the repr.""" + m = _make_match() + assert "MORFOLOGIK_RULE_EN_US" in repr(m) + + +class TestMatchedText: + """Tests for the matched_text property.""" + + def test_matched_text_extracts_correctly(self) -> None: + """matched_text returns the exact text slice at offset/length.""" + m = _make_match() + assert m.matched_text == "noot" + + +class TestGetLineAndColumn: + """Tests for Match.get_line_and_column().""" + + def test_single_line(self) -> None: + """A single-line text returns line 1 and a positive column.""" + text = "This is noot okay." + m = _make_match(text=text) + line, col = m.get_line_and_column(text) + assert line == 1 + assert col > 0 + + def test_context_not_in_text_raises(self) -> None: + """Passing unrelated text raises ValueError.""" + m = _make_match() + with pytest.raises(ValueError, match="does not match the context"): + m.get_line_and_column("completely different text here blah blah") + + +class TestSelectReplacement: + """Tests for Match.select_replacement() replacement narrower.""" + + def test_select_valid_index(self) -> None: + """Selecting index 1 keeps only the second replacement.""" + m = _make_match() + m.select_replacement(1) + assert m.replacements == ["noon"] + + def test_select_first(self) -> None: + """Selecting index 0 keeps only the first replacement.""" + m = _make_match() + m.select_replacement(0) + assert m.replacements == ["not"] + + def test_negative_index_raises(self) -> None: + """A negative index raises ValueError.""" + m = _make_match() + with pytest.raises(ValueError, match="numbered from 0"): + m.select_replacement(-1) + + def test_out_of_bounds_raises(self) -> None: + """An out-of-range index raises ValueError.""" + m = _make_match() + with pytest.raises(ValueError, match="numbered from 0"): + m.select_replacement(99) + + def test_no_replacements_raises(self) -> None: + """Selecting when there are no replacements 
raises ValueError.""" + m = _make_match(replacements=[]) + with pytest.raises(ValueError, match="no suggestions"): + m.select_replacement(0) + + +class TestMatchComparisons: + """Tests for Match equality, ordering, and NotImplemented handling.""" + + def test_eq_equal_matches(self) -> None: + """Two matches built from the same attrib dict are equal.""" + m1 = _make_match() + m2 = _make_match() + assert m1 == m2 + + def test_eq_different_offset(self) -> None: + """Matches with different offsets are not equal.""" + m1 = _make_match() + m2 = _make_match(offset=0, context_offset=0) + assert m1 != m2 + + def test_eq_not_implemented_for_non_match(self) -> None: + """Comparing a Match with a non-Match returns NotImplemented.""" + m = _make_match() + assert m.__eq__("not a match") is NotImplemented + + def test_lt(self) -> None: + """A match at an earlier offset is less than one at a later offset.""" + text = "This is noot okay, and also baaad." + m_early = Match( + _make_attrib( + offset=0, + length=_DEFAULT_LENGTH, + context_text=text, + context_offset=0, + sentence=text, + ), + text, + ) + m_later = Match( + _make_attrib( + offset=_DEFAULT_OFFSET, + length=_DEFAULT_LENGTH, + context_text=text, + context_offset=_DEFAULT_OFFSET, + sentence=text, + ), + text, + ) + assert m_early < m_later + + def test_lt_not_implemented_for_non_match(self) -> None: + """Less-than comparison with a non-Match returns NotImplemented.""" + m = _make_match() + assert m.__lt__("not a match") is NotImplemented + + +class TestMatchIter: + """Tests for Match.__iter__() field-value iterator.""" + + def test_iter_yields_all_values(self) -> None: + """Iterating a match yields exactly _NUM_MATCH_FIELDS values.""" + m = _make_match() + values = list(m) + assert len(values) == _NUM_MATCH_FIELDS + + def test_iter_first_is_rule_id(self) -> None: + """The first value yielded by the iterator is the rule_id.""" + m = _make_match() + assert next(iter(m)) == "MORFOLOGIK_RULE_EN_US" + + +class TestMatchSetAttr: 
+ """Tests for Match.__setattr__() type-coercing setter.""" + + def test_setattr_known_key_coerces_type(self) -> None: + """Setting a known field with a string coerces it to the declared type.""" + m = _make_match() + new_offset = 5 + m.offset = "5" # type: ignore[assignment] + assert m.offset == new_offset + assert isinstance(m.offset, int) + + def test_setattr_unknown_key_is_ignored(self) -> None: + """Setting an unknown field is silently ignored.""" + m = _make_match() + m.__setattr__("nonexistent_key", "value") + assert not hasattr(m, "nonexistent_key") + + +class TestMatchGetAttr: + """Tests for Match.__getattr__() unknown-attribute guard.""" + + def test_getattr_unknown_key_raises(self) -> None: + """Accessing an unknown attribute raises AttributeError.""" + m = _make_match() + with pytest.raises(AttributeError, match="no attribute"): + _ = m.completely_unknown diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000..aa44343 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,218 @@ +"""Unit tests for language_tool_python.utils (classify_matches, correct).""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from language_tool_python.match import Match +from language_tool_python.utils import TextStatus, classify_matches, correct + +if TYPE_CHECKING: + from language_tool_python._internals.api_types import CheckMatch + + +def _make_match( + rule_id: str = "RULE", + offset: int = 0, + length: int = 4, + replacements: list[str] | None = None, +) -> Match: + attrib: CheckMatch = { + "message": "Error", + "shortMessage": "", + "replacements": [{"value": r} for r in (replacements or [])], + "offset": offset, + "length": length, + "context": {"text": "text here.", "offset": offset, "length": length}, + "sentence": "text here.", + "type": {"typeName": "Other"}, + "rule": { + "id": rule_id, + "description": "desc", + "issueType": "misspelling", + "category": {"id": "TYPOS", 
"name": "Typos"}, + }, + "ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + return Match(attrib, "text here.") + + +class TestClassifyMatches: + """Tests for classify_matches() match-set status classifier.""" + + def test_no_matches_returns_correct(self) -> None: + """An empty match list is classified as CORRECT.""" + assert classify_matches([]) == TextStatus.CORRECT + + def test_matches_with_replacements_returns_faulty(self) -> None: + """A match that has a replacement is classified as FAULTY.""" + m = _make_match(replacements=["fix"]) + assert classify_matches([m]) == TextStatus.FAULTY + + def test_matches_without_replacements_returns_garbage(self) -> None: + """A match without any replacement is classified as GARBAGE.""" + m = _make_match(replacements=[]) + assert classify_matches([m]) == TextStatus.GARBAGE + + def test_mixed_filters_to_faulty(self) -> None: + """A mix of matches with and without replacements is classified as FAULTY.""" + m_with = _make_match(replacements=["fix"]) + m_without = _make_match(replacements=[]) + assert classify_matches([m_with, m_without]) == TextStatus.FAULTY + + def test_all_without_replacements_is_garbage(self) -> None: + """Multiple matches all lacking replacements are classified as GARBAGE.""" + matches = [_make_match(replacements=[]) for _ in range(3)] + assert classify_matches(matches) == TextStatus.GARBAGE + + +class TestCorrect: + """Tests for correct() auto-correction function.""" + + def test_no_matches_returns_unchanged(self) -> None: + """Text with no matches is returned unchanged.""" + assert correct("hello world", []) == "hello world" + + def test_single_correction(self) -> None: + """A single match with a replacement is applied to the text.""" + m = _make_match(offset=0, length=4, replacements=["text"]) + result = correct("text here.", [m]) + assert result == "text here." 
+ + def test_correction_replaces_error(self) -> None: + """A misspelled word is replaced by the first suggested correction.""" + text = "Helo world" + attrib: CheckMatch = { + "message": "Misspelling", + "shortMessage": "", + "replacements": [{"value": "Hello"}], + "offset": 0, + "length": 4, + "context": {"text": text, "offset": 0, "length": 4}, + "sentence": text, + "type": {"typeName": "Other"}, + "rule": { + "id": "SPELL", + "description": "Spelling", + "issueType": "misspelling", + "category": {"id": "TYPOS", "name": "Typos"}, + }, + "ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + m = Match(attrib, text) + result = correct(text, [m]) + assert result == "Hello world" + + def test_match_without_replacement_is_skipped(self) -> None: + """A match with no replacement leaves the text unchanged.""" + m = _make_match(offset=0, length=4, replacements=[]) + assert correct("text here.", [m]) == "text here." + + def test_overlapping_match_skips_mismatched_error(self) -> None: + """The second of two overlapping matches is skipped when offset drifts.""" + # First match replaces "aa" (offset 0, len 2) with "xxxxxx" (longer). + # Second match overlaps at offset 1, len 2 ("ab"). After the first + # replacement expands the text, the second match's expected text no + # longer sits at the right position → continue branch is hit. 
+ text = "aabbc" + attrib1: CheckMatch = { + "message": "e", + "shortMessage": "", + "replacements": [{"value": "xxxxxx"}], + "offset": 0, + "length": 2, + "context": {"text": text, "offset": 0, "length": 2}, + "sentence": text, + "type": {"typeName": "Other"}, + "rule": { + "id": "R", + "description": "d", + "issueType": "misspelling", + "category": {"id": "C", "name": "C"}, + }, + "ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + attrib2: CheckMatch = { + "message": "e", + "shortMessage": "", + "replacements": [{"value": "y"}], + "offset": 1, + "length": 2, + "context": {"text": text, "offset": 1, "length": 2}, + "sentence": text, + "type": {"typeName": "Other"}, + "rule": { + "id": "R", + "description": "d", + "issueType": "misspelling", + "category": {"id": "C", "name": "C"}, + }, + "ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + m1 = Match(attrib1, text) + m2 = Match(attrib2, text) + result = correct(text, [m1, m2]) + assert result == "xxxxxxbbc" + + def test_correct_adjusts_offset_for_length_change(self) -> None: + """A length-changing replacement shifts the offset for subsequent matches.""" + text = "A b c" + attrib1: CheckMatch = { + "message": "err", + "shortMessage": "", + "replacements": [{"value": "AAA"}], + "offset": 0, + "length": 1, + "context": {"text": text, "offset": 0, "length": 1}, + "sentence": text, + "type": {"typeName": "Other"}, + "rule": { + "id": "R", + "description": "d", + "issueType": "misspelling", + "category": {"id": "C", "name": "C"}, + }, + "ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + attrib2: CheckMatch = { + "message": "err", + "shortMessage": "", + "replacements": [{"value": "BBB"}], + "offset": 2, + "length": 1, + "context": {"text": text, "offset": 2, "length": 1}, + "sentence": text, + "type": {"typeName": "Other"}, + "rule": { + "id": "R", + "description": "d", + "issueType": "misspelling", + "category": {"id": "C", "name": "C"}, + }, + 
"ignoreForIncompleteSentence": False, + "contextForSureMatch": 0, + } + m1 = Match(attrib1, text) + m2 = Match(attrib2, text) + result = correct(text, [m1, m2]) + assert result == "AAA BBB c" + + +@pytest.mark.parametrize( + ("status", "value"), + [ + (TextStatus.CORRECT, "correct"), + (TextStatus.FAULTY, "faulty"), + (TextStatus.GARBAGE, "garbage"), + ], +) +def test_text_status_values(status: TextStatus, value: str) -> None: + """TextStatus enum values match expected strings.""" + assert status.value == value From 2c5fe230528e93c9b0a9800c8801ea9dc216b314 Mon Sep 17 00:00:00 2001 From: mdevolde Date: Wed, 1 Jul 2026 14:52:27 +0300 Subject: [PATCH 3/7] test: increase coverage by adding unit tests --- Makefile | 9 +- make.bat | 9 +- src/language_tool_python/__main__.py | 6 +- src/language_tool_python/_internals/compat.py | 2 + .../_internals/safe_zip.py | 39 +- src/language_tool_python/config_file.py | 2 + src/language_tool_python/download_lt.py | 13 +- src/language_tool_python/language_tag.py | 4 + src/language_tool_python/server.py | 2 +- tests/unit/test_cli_unit.py | 243 +++- tests/unit/test_download.py | 63 + tests/unit/test_download_unit.py | 359 +++++- tests/unit/test_internals_utils.py | 6 - tests/unit/test_match.py | 37 +- tests/unit/test_safe_zip.py | 90 ++ tests/unit/test_server_unit.py | 1089 +++++++++++++++++ tests/unit/test_utils.py | 16 +- 17 files changed, 1922 insertions(+), 67 deletions(-) create mode 100644 tests/unit/test_server_unit.py diff --git a/Makefile b/Makefile index 537d29c..c321966 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: default install format fix ruff-check mypy-check check test doc publish +.PHONY: default install format fix ruff-check mypy-check check test doc publish clean UV := $(shell command -v uv 2>/dev/null || true) ifeq ($(UV),) @@ -6,7 +6,7 @@ $(warning uv not found. 
Install uv (curl -LsSf https://astral.sh/uv/install.sh | endif default: - @echo "Usage: make [install|format|fix|ruff-check|mypy-check|check|test|doc|publish]" + @echo "Usage: make [install|format|fix|ruff-check|mypy-check|check|test|doc|publish|clean]" @exit 1 install: @@ -36,7 +36,10 @@ doc: uv run --group docs --locked sphinx-build -M html docs/source docs/build publish: - rm -rf dist/ + make clean uv build uvx twine check dist/* uv publish + +clean: + git clean -xfd --exclude .venv diff --git a/make.bat b/make.bat index 6f0e194..7ad5bde 100644 --- a/make.bat +++ b/make.bat @@ -16,8 +16,9 @@ if "%1"=="check" goto check if "%1"=="test" goto test if "%1"=="doc" goto doc if "%1"=="publish" goto publish +if "%1"=="clean" goto clean -echo Usage: make.bat [install^|format^|fix^|ruff-check^|mypy-check^|check^|test^|doc^|publish] +echo Usage: make.bat [install^|format^|fix^|ruff-check^|mypy-check^|check^|test^|doc^|publish^|clean] exit /b 1 :install @@ -58,7 +59,7 @@ call uv run --group docs --locked sphinx-build -M html docs/source docs/build exit /b %errorlevel% :publish -if exist dist\ rmdir /s /q dist\ +call .\make.bat clean uv build if errorlevel 1 exit /b %errorlevel% @@ -68,3 +69,7 @@ if errorlevel 1 exit /b %errorlevel% uv publish exit /b %errorlevel% + +:clean +git clean -xfd --exclude .venv +exit /b %errorlevel% diff --git a/src/language_tool_python/__main__.py b/src/language_tool_python/__main__.py index 8300df2..8b5cc27 100644 --- a/src/language_tool_python/__main__.py +++ b/src/language_tool_python/__main__.py @@ -58,7 +58,7 @@ def _read_project_version(pyproject: Path) -> str: __version__ = version("language_tool_python") # If the package is not installed in the environment, # read the version from pyproject.toml -except PackageNotFoundError: # pragma: no cover +except PackageNotFoundError: # pragma: no cover # package installed in test env project_root = Path(__file__).resolve().parent.parent pyproject = project_root / "pyproject.toml" __version__ = 
_read_project_version(pyproject) @@ -258,7 +258,7 @@ def __call__( cli_args.disable_categories.update(rule_values) elif self.dest == "enable_categories": cli_args.enable_categories.update(rule_values) - else: # pragma: no cover + else: # pragma: no cover # defensive: all known dest values are handled above err = f"unexpected rules destination: {self.dest}" raise ValueError(err) @@ -449,5 +449,5 @@ def main(argv: Sequence[str] | None = None) -> int: return status -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover raise SystemExit(main()) diff --git a/src/language_tool_python/_internals/compat.py b/src/language_tool_python/_internals/compat.py index 9dc0631..57109c6 100644 --- a/src/language_tool_python/_internals/compat.py +++ b/src/language_tool_python/_internals/compat.py @@ -13,11 +13,13 @@ if sys.version_info >= (3, 11): from tomllib import loads as toml_loads else: + # Python < 3.11 fallback, cov CI runs on 3.11+, so this branch is never executed. from tomli import loads as toml_loads # pragma: no cover if sys.version_info >= (3, 13): from warnings import deprecated else: + # Python < 3.13 fallback, cov CI runs on 3.13+, so this branch is never executed. from typing_extensions import deprecated # pragma: no cover __all__ = ["deprecated", "toml_loads"] diff --git a/src/language_tool_python/_internals/safe_zip.py b/src/language_tool_python/_internals/safe_zip.py index aa2c757..dac6023 100644 --- a/src/language_tool_python/_internals/safe_zip.py +++ b/src/language_tool_python/_internals/safe_zip.py @@ -179,7 +179,9 @@ def _normalize_member_path(self, filename: str) -> PurePosixPath: member_path = PurePosixPath(*parts) - if member_path.is_absolute() or any(part == ".." for part in member_path.parts): + if ( # pragma: no cover # parts validated; PurePosixPath always relative + member_path.is_absolute() or any(part == ".." for part in member_path.parts) + ): err = f"Unsafe ZIP member path: {filename!r}." 
raise PathError(err) @@ -256,7 +258,7 @@ def _zip_target(self, destination: Path, member_path: PurePosixPath) -> Path: if destination_resolved != target_resolved and ( destination_resolved not in target_resolved.parents - ): + ): # pragma: no cover # TOCTOU: escape needs concurrent modification err = f"Unsafe ZIP member path: {str(member_path)!r}." raise PathError(err) @@ -457,13 +459,13 @@ def _ensure_safe_parent(self, destination: Path, target: Path) -> None: if destination_resolved != parent_resolved and ( destination_resolved not in parent_resolved.parents - ): + ): # pragma: no cover # TOCTOU: escape needs concurrent modification err = f"Unsafe ZIP extraction parent path: {target.parent}." raise PathError(err) try: relative_parent = target.parent.relative_to(destination) - except ValueError as e: + except ValueError as e: # pragma: no cover # caught by check above err = f"Unsafe ZIP extraction parent path: {target.parent}." raise PathError(err) from e @@ -472,11 +474,11 @@ def _ensure_safe_parent(self, destination: Path, target: Path) -> None: for part in relative_parent.parts: current = current / part - if current.is_symlink(): + if current.is_symlink(): # pragma: no cover # TOCTOU: mkdir'd above err = f"Refusing to extract through symlinked directory: {current}." raise PathError(err) - if not current.is_dir(): + if not current.is_dir(): # pragma: no cover # TOCTOU: mkdir'd above err = f"Refusing to extract through non-directory path: {current}." raise PathError(err) @@ -484,7 +486,7 @@ def _ensure_safe_parent(self, destination: Path, target: Path) -> None: if destination_resolved != current_resolved and ( destination_resolved not in current_resolved.parents - ): + ): # pragma: no cover # TOCTOU: escape needs concurrent modification err = f"Unsafe ZIP extraction directory path: {current}." raise PathError(err) @@ -504,11 +506,11 @@ def _copy_member( :type target: Path :raises PathError: If the target is unsafe or size checks fail. 
""" - if target.exists() or target.is_symlink(): + if target.exists() or target.is_symlink(): # pragma: no cover # TOCTOU err = f"Refusing to overwrite existing path while extracting ZIP: {target}." raise PathError(err) - if target.parent.is_symlink(): + if target.parent.is_symlink(): # pragma: no cover # TOCTOU: parent safe above err = ( f"Refusing to extract into symlinked parent directory: {target.parent}." ) @@ -578,13 +580,13 @@ def _extract_to_private_directory( destination.mkdir(parents=True, exist_ok=True) - if destination.is_symlink(): + if destination.is_symlink(): # pragma: no cover # TOCTOU: mkdtemp dir err = f"Refusing to extract into symlinked destination: {destination}." raise PathError(err) destination_resolved = destination.resolve(strict=True) - if not destination_resolved.is_dir(): + if not destination_resolved.is_dir(): # pragma: no cover # TOCTOU: mkdir'd err = f"ZIP extraction destination is not a directory: {destination}." raise PathError(err) @@ -592,7 +594,9 @@ def _extract_to_private_directory( target = self._zip_target(destination, member_path) if member.is_dir(): - if target.exists() and not target.is_dir(): + if ( # pragma: no cover # dup-path check catches file-vs-dir above + target.exists() and not target.is_dir() + ): err = ( f"Refusing to overwrite existing path while extracting ZIP: " f"{target}." @@ -602,7 +606,7 @@ def _extract_to_private_directory( target.mkdir(parents=True, exist_ok=True) self._ensure_safe_parent(destination, target) - if target.is_symlink(): + if target.is_symlink(): # pragma: no cover # TOCTOU: mkdir'd above err = ( f"Refusing to create or use symlinked ZIP directory: {target}." 
) @@ -610,8 +614,9 @@ def _extract_to_private_directory( target_resolved = target.resolve(strict=True) - if destination_resolved != target_resolved and ( - destination_resolved not in target_resolved.parents + if ( # pragma: no cover # TOCTOU: mkdir'd dir escaped + destination_resolved != target_resolved + and destination_resolved not in target_resolved.parents ): err = f"Unsafe ZIP directory path after creation: {target}." raise PathError(err) @@ -687,7 +692,7 @@ def _extractall_to_directory( final_directory_resolved = final_directory.resolve(strict=True) - if not final_directory_resolved.is_dir(): + if not final_directory_resolved.is_dir(): # pragma: no cover # TOCTOU err = ( f"ZIP extraction destination is not a directory: {final_directory}." ) @@ -695,7 +700,7 @@ def _extractall_to_directory( destinations: list[tuple[Path, Path]] = [] for child in extract_dir.iterdir(): - if child.is_symlink(): + if child.is_symlink(): # pragma: no cover # symlinks rejected above err = f"Refusing to move symlinked extracted path: {child}." raise PathError(err) diff --git a/src/language_tool_python/config_file.py b/src/language_tool_python/config_file.py index c7e703a..6d066d7 100644 --- a/src/language_tool_python/config_file.py +++ b/src/language_tool_python/config_file.py @@ -158,6 +158,8 @@ def _path_validator(v: PathLike[str] | str) -> None: if not p.exists(): err = f"path does not exist: {p}" raise PathError(err) + # Defensive: a path that exists but is neither file nor directory (e.g. socket, + # device node, FIFO) cannot be created portably in unit tests. 
if not p.is_file() and not p.is_dir(): # pragma: no cover err = f"path is not a file/directory: {p}" raise PathError(err) diff --git a/src/language_tool_python/download_lt.py b/src/language_tool_python/download_lt.py index 1ec3330..51c5b5a 100644 --- a/src/language_tool_python/download_lt.py +++ b/src/language_tool_python/download_lt.py @@ -385,6 +385,7 @@ def download(self) -> None: :raises NotImplementedError: Always, unless implemented by a subclass. """ + # Unreachable: ABC prevents direct instantiation of this abstract method. raise NotImplementedError # pragma: no cover def _get_remote_zip( @@ -618,7 +619,7 @@ def version_name(self) -> str: :rtype: str :raises NotImplementedError: Always, unless implemented by a subclass. """ - raise NotImplementedError + raise NotImplementedError # pragma: no cover # abstract body @property @abstractmethod @@ -633,7 +634,7 @@ def version_into(self) -> tuple[int, int] | datetime: :rtype: tuple[int, int] | datetime.datetime :raises NotImplementedError: Always, unless implemented by a subclass. """ - raise NotImplementedError + raise NotImplementedError # pragma: no cover # abstract body @property @abstractmethod @@ -647,7 +648,7 @@ def download_url(self) -> str: :rtype: str :raises NotImplementedError: Always, unless implemented by a subclass. """ - raise NotImplementedError + raise NotImplementedError # pragma: no cover # abstract body def __eq__(self, other: object) -> bool: """Check equality between two LocalLanguageTool instances. 
@@ -751,7 +752,7 @@ def download(self) -> None: return if self not in self.get_installed_versions(): - with ( + with ( # pragma: no cover # integration: HTTP download + extraction tempfile.TemporaryDirectory(dir=download_folder) as temp_dir, tempfile.NamedTemporaryFile( suffix=".zip", @@ -905,7 +906,9 @@ def download(self) -> None: raise PathError(err) expected_dir = download_folder / f"LanguageTool-{self.version_name}" - if expected_dir.exists() or expected_dir.is_symlink(): + if ( # pragma: no cover # TOCTOU: dir appears between check and rename + expected_dir.exists() or expected_dir.is_symlink() + ): err = ( "Refusing to overwrite existing LanguageTool snapshot " f"directory: {expected_dir}." diff --git a/src/language_tool_python/language_tag.py b/src/language_tool_python/language_tag.py index c035ec2..acd1198 100644 --- a/src/language_tool_python/language_tag.py +++ b/src/language_tool_python/language_tag.py @@ -36,6 +36,10 @@ class LanguageTag: def __init__(self, tag: str, languages: Iterable[str]) -> None: """Initialize a LanguageTag instance. + :param tag: The language tag to normalize. + :type tag: str + :param languages: An iterable of supported language tags. + :type languages: collections.abc.Iterable[str] :raises ValueError: If the tag is empty or unsupported. 
""" self.tag = tag diff --git a/src/language_tool_python/server.py b/src/language_tool_python/server.py index 023b06d..1a5caf3 100644 --- a/src/language_tool_python/server.py +++ b/src/language_tool_python/server.py @@ -94,7 +94,7 @@ def _decode_response_content(response: requests.Response) -> str: content: object = response.content if isinstance(content, bytes): return content.decode() - return str(content) + return str(content) # pragma: no cover # requests always returns bytes class LanguageTool: diff --git a/tests/unit/test_cli_unit.py b/tests/unit/test_cli_unit.py index 33af934..4252258 100644 --- a/tests/unit/test_cli_unit.py +++ b/tests/unit/test_cli_unit.py @@ -3,6 +3,7 @@ from __future__ import annotations import io +import logging from pathlib import Path import pytest @@ -14,9 +15,14 @@ get_remote_server, get_rules, get_text, + main, parse_args, print_exception, + process_file, ) +from language_tool_python.exceptions import LanguageToolError + +NUMBER_OF_DOTS_IN_VERSION = 2 # e.g. 
"3.4.0" has two dots class TestGetRules: @@ -101,7 +107,9 @@ def test_without_debug_prints_to_stderr( ) -> None: """Without debug=True, only the message is printed to stderr.""" print_exception(ValueError("test error"), debug=False) - assert "test error" in capsys.readouterr().err + result = capsys.readouterr() + assert "test error" in result.err + assert "ValueError" not in result.err def test_with_debug_prints_traceback( self, capsys: pytest.CaptureFixture[str] @@ -112,8 +120,9 @@ def test_with_debug_prints_traceback( raise ValueError(msg) except ValueError: print_exception(ValueError("current error"), debug=True) - captured = capsys.readouterr() - assert "ValueError" in captured.err + result = capsys.readouterr() + assert "original error" in result.err + assert "ValueError" in result.err class TestGetText: @@ -174,6 +183,7 @@ def test_stdin_with_ignore_lines(self, monkeypatch: pytest.MonkeyPatch) -> None: result = get_input_text("-", self._args(ignore_lines=r"#.*")) assert "# skip" not in result assert "keep" in result + assert "keep2" in result def test_uses_encoding(self, tmp_path: Path) -> None: """Non-UTF-8 files are decoded with the specified encoding.""" @@ -192,4 +202,229 @@ def test_reads_version_from_pyproject(self) -> None: pyproject = Path(__file__).parent.parent.parent / "pyproject.toml" version = _read_project_version(pyproject) assert isinstance(version, str) - assert version.count(".") >= 1 + assert version.count(".") == NUMBER_OF_DOTS_IN_VERSION + + +class _MockMatch: + """Minimal match object for process_file unit tests.""" + + def __init__( + self, + rule_id: str = "RULE", + message: str = "A suggestion.", + replacements: list[str] | None = None, + ) -> None: + self.rule_id = rule_id + self.message = message + self.replacements: list[str] = replacements or [] + + def get_line_and_column(self, _text: str) -> tuple[int, int]: + return (1, 0) + + +class _MockLangTool: + """Minimal LanguageTool mock for process_file unit tests.""" + + 
_last_instance: _MockLangTool | None = None + + def __init__(self, **_kw: object) -> None: + _MockLangTool._last_instance = self + self.disabled_rules: set[str] = set() + self.enabled_rules: set[str] = set() + self.disabled_categories: set[str] = set() + self.enabled_categories: set[str] = set() + self.enabled_rules_only: bool = False + self.picky: bool = False + self._spellcheck_disabled: bool = False + self._matches: list[_MockMatch] = [] + + def __enter__(self) -> _MockLangTool: + return self + + def __exit__(self, *_: object) -> None: + pass + + def disable_spellchecking(self) -> None: + self._spellcheck_disabled = True + + def check(self, _text: str) -> list[_MockMatch]: + return self._matches + + def correct(self, text: str) -> str: + return text + " (corrected)" + + +class _RaisingLangTool: + """LanguageTool mock that raises LanguageToolError on context entry.""" + + def __init__(self, **_kw: object) -> None: + pass + + def __enter__(self) -> _RaisingLangTool: + err = "server failed" + raise LanguageToolError(err) + + def __exit__(self, *_: object) -> None: + pass + + +def _parse_file_args(filename: str, **overrides: object) -> CliArgs: + """Build CliArgs from parse_args defaults with optional field overrides.""" + args = parse_args([filename]) + for k, v in overrides.items(): + setattr(args, k, v) + return args + + +class TestProcessFile: + """Tests for process_file() with LanguageTool mocked.""" + + def test_prints_filename_to_stderr_for_multiple_files( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], + ) -> None: + """Filename is printed to stderr when processing multiple files.""" + f = tmp_path / "a.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MockLangTool, + ) + args = _parse_file_args(str(f), files=[str(f), "other.txt"]) + process_file(str(f), args, None) + assert str(f) in capsys.readouterr().err + + def 
test_returns_zero_on_file_not_found( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """Returns 0 when get_input_text raises FileNotFoundError.""" + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MockLangTool, + ) + missing = str(tmp_path / "does_not_exist.txt") + result = process_file(missing, _parse_file_args(missing), None) + assert result == 0 + + def test_disables_spellcheck_when_flag_off( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """disable_spellchecking() is called when spell_check=False.""" + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr(_MockLangTool, "_last_instance", None) + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MockLangTool, + ) + process_file(str(f), _parse_file_args(str(f), spell_check=False), None) + assert _MockLangTool._last_instance is not None + assert _MockLangTool._last_instance._spellcheck_disabled + + def test_sets_picky_when_flag_on( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """Picky is set to True on the tool when args.picky=True.""" + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr(_MockLangTool, "_last_instance", None) + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MockLangTool, + ) + process_file(str(f), _parse_file_args(str(f), picky=True), None) + assert _MockLangTool._last_instance is not None + assert _MockLangTool._last_instance.picky is True + + def test_apply_prints_corrected_text_and_returns_zero( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], + ) -> None: + """--apply prints corrected text to stdout and returns 0.""" + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MockLangTool, + ) + result = process_file(str(f), 
_parse_file_args(str(f), apply=True), None) + assert result == 0 + assert "corrected" in capsys.readouterr().out + + def test_returns_zero_on_language_tool_error( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """Returns 0 when LanguageTool raises LanguageToolError on entry.""" + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _RaisingLangTool, + ) + result = process_file(str(f), _parse_file_args(str(f)), None) + assert result == 0 + + def test_prints_match_and_returns_two_when_issues_found( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], + ) -> None: + """Match details are printed and status 2 is returned when issues are found.""" + + class _MatchingLangTool(_MockLangTool): + def check(self, _text: str) -> list[_MockMatch]: + return [ + _MockMatch( + rule_id="SOME_RULE", + message="Fix this.", + replacements=["fix"], + ), + ] + + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MatchingLangTool, + ) + status_issues = 2 + result = process_file(str(f), _parse_file_args(str(f)), None) + assert result == status_issues + out = capsys.readouterr().out + assert "SOME_RULE" in out + assert "fix" in out + + +class TestMain: + """Tests for main() with LanguageTool mocked.""" + + def test_verbose_flag_sets_debug_logging( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """--verbose sets the root logger level to DEBUG.""" + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _MockLangTool, + ) + root = logging.getLogger() + monkeypatch.setattr(root, "level", root.level) + result = main(["--verbose", str(f)]) + assert result == 0 + assert root.level == logging.DEBUG diff --git 
a/tests/unit/test_download.py b/tests/unit/test_download.py index ad8ec2d..ba92a96 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -18,6 +18,7 @@ from unittest.mock import patch import pytest +import requests import language_tool_python from language_tool_python.download_lt import ( @@ -479,6 +480,68 @@ def test_snapshot_download_renames_archive_root_to_requested_date( get_mock.assert_not_called() +def test_http_get_404_raises_path_error() -> None: + """_do_download raises PathError when the server returns 404 Not Found.""" + mock_response = MockDownloadResponse(b"", status_code=404) + mock_response.headers = {} + out_file = io.BytesIO() + local_language_tool = LocalLanguageTool.from_version_name() + with ( + patch( + "language_tool_python.download_lt.requests.get", + return_value=mock_response, + ), + pytest.raises(PathError, match="Could not find at URL"), + ): + local_language_tool._get_remote_zip(out_file) + + +def test_http_get_timeout_raises_timeout_error() -> None: + """_do_download raises TimeoutError when the HTTP request times out.""" + out_file = io.BytesIO() + local_language_tool = LocalLanguageTool.from_version_name() + with ( + patch( + "language_tool_python.download_lt.requests.get", + side_effect=requests.exceptions.Timeout("timed out"), + ), + pytest.raises(TimeoutError, match="timed out"), + ): + local_language_tool._get_remote_zip(out_file) + + +def test_snapshot_download_raises_when_archive_has_multiple_root_dirs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """download() raises PathError when the snapshot archive has multiple root dirs.""" + payload = make_zip_payload( + { + "Dir1/file.txt": b"content1", + "Dir2/file.txt": b"content2", + } + ) + local_language_tool = LocalLanguageTool.from_version_name("20240102") + monkeypatch.setattr( + language_tool_python.download_lt, + "_confirm_java_compatibility", + skip_java_compatibility_check, + ) + with ( + workspace_temp_dir() as temp_dir, + patch( + 
"language_tool_python.download_lt.requests.get", + return_value=MockDownloadResponse(payload), + ), + ): + monkeypatch.setattr( + language_tool_python.download_lt, + "get_language_tool_download_path", + lambda: temp_dir, + ) + with pytest.raises(PathError, match="Expected snapshot archive"): + local_language_tool.download() + + def test_latest_snapshot_download_renames_archive_root_to_current_date( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/unit/test_download_unit.py b/tests/unit/test_download_unit.py index 9111496..dc17699 100644 --- a/tests/unit/test_download_unit.py +++ b/tests/unit/test_download_unit.py @@ -8,12 +8,15 @@ from __future__ import annotations import re +import subprocess +from datetime import datetime, timezone from typing import TYPE_CHECKING import pytest import language_tool_python.download_lt as _dl -from language_tool_python.exceptions import PathError +from language_tool_python.config_file import LanguageToolConfig +from language_tool_python.exceptions import JavaError, PathError if TYPE_CHECKING: from pathlib import Path @@ -23,6 +26,10 @@ _JAVA_21_MAJOR = 21 _SHA256_HEX_LENGTH = 64 _KIBIBYTE = 1024 +_SNAPSHOT_TEST_VERSION = "20240315" +_SNAPSHOT_TEST_YEAR = 2024 +_SNAPSHOT_TEST_MONTH = 3 +_SNAPSHOT_TEST_DAY = 15 def return_42(_: object) -> int: @@ -30,6 +37,30 @@ def return_42(_: object) -> int: return 42 +def _which_none(_name: str) -> str | None: + """Stub for shutil.which that always returns None (Java not found).""" + return None + + +def _which_java(_name: str) -> str: + """Stub for shutil.which that always returns a fake Java path.""" + return "/usr/bin/java" + + +def _check_output_java8(*_args: object, **_kw: object) -> str: + """Stub for subprocess.check_output returning a Java 1.8 version string.""" + return 'java version "1.8.0_292"' + + +def _check_output_java17(*_args: object, **_kw: object) -> str: + """Stub for subprocess.check_output returning a Java 17 version string.""" + return "openjdk 17.0.1 
2021-10-19" + + +def _noop(_v: object) -> None: + """No-operation stub for monkeypatching void functions.""" + + class TestLoadsManifest: """Tests for the _loads_manifest() TOML parser.""" @@ -303,3 +334,329 @@ def test_invalid_hash_in_env_raises(self, monkeypatch: pytest.MonkeyPatch) -> No monkeypatch.setenv("LTP_DOWNLOAD_SHA256", "not-a-valid-sha256") with pytest.raises(PathError, match="Invalid SHA-256"): _dl._get_zip_hash("6.8") + + +class TestConfirmJavaCompatibility: + """Tests for _confirm_java_compatibility() Java version checker.""" + + def test_no_java_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Raises ModuleNotFoundError when Java cannot be found on PATH.""" + monkeypatch.setattr("language_tool_python.download_lt.which", _which_none) + with pytest.raises(ModuleNotFoundError, match="No java install"): + _dl._confirm_java_compatibility("6.8") + + def test_java_8_fails_for_current_lt(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Java 1.8 raises SystemError when current LT requires Java >= 17.""" + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + monkeypatch.setattr(subprocess, "check_output", _check_output_java8) + with pytest.raises(SystemError, match="requires"): + _dl._confirm_java_compatibility("6.8") + + def test_java_8_fails_for_old_lt(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Java 1.8 raises SystemError for an old LT version requiring Java >= 9.""" + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + monkeypatch.setattr(subprocess, "check_output", _check_output_java8) + with pytest.raises(SystemError, match="requires"): + _dl._confirm_java_compatibility("5.0") + + def test_java_17_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Java 17 satisfies the current LT requirement without error.""" + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + monkeypatch.setattr(subprocess, "check_output", _check_output_java17) + 
_dl._confirm_java_compatibility("6.8") + + +class TestGetInstalledVersions: + """Tests for LocalLanguageTool.get_installed_versions().""" + + def test_returns_empty_when_no_lt_dirs( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Returns an empty list when no LanguageTool directories are present.""" + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + versions = _dl.LocalLanguageTool.get_installed_versions() + assert versions == [] + + def test_returns_installed_versions( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Returns a sorted list of installed LocalLanguageTool instances.""" + (tmp_path / "LanguageTool-6.8").mkdir() + (tmp_path / "LanguageTool-6.7").mkdir() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + versions = _dl.LocalLanguageTool.get_installed_versions() + version_names = [v.version_name for v in versions] + assert "6.8" in version_names + assert "6.7" in version_names + + +class TestGetLatestInstalledVersion: + """Tests for LocalLanguageTool.get_latest_installed_version().""" + + def test_returns_none_when_no_versions( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Returns None when no LanguageTool versions are installed.""" + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + result = _dl.LocalLanguageTool.get_latest_installed_version() + assert result is None + + def test_returns_latest_version( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Returns the highest-ordered installed version.""" + (tmp_path / "LanguageTool-6.7").mkdir() + (tmp_path / "LanguageTool-6.8").mkdir() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + result = _dl.LocalLanguageTool.get_latest_installed_version() + assert result is not None + assert result.version_name == "6.8" + + +class TestGetDirectoryPath: + """Tests for LocalLanguageTool.get_directory_path().""" + + def test_no_lt_dirs_raises( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + 
"""Raises FileNotFoundError when no LanguageTool directories exist.""" + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + lt = _dl.ReleaseLocalLanguageTool("6.8") + with pytest.raises(FileNotFoundError, match="LanguageTool not found"): + lt.get_directory_path() + + def test_version_not_found_raises( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Raises FileNotFoundError when the requested version is not installed.""" + (tmp_path / "LanguageTool-6.7").mkdir() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + lt = _dl.ReleaseLocalLanguageTool("6.8") + with pytest.raises(FileNotFoundError, match=r"6\.8"): + lt.get_directory_path() + + def test_returns_matching_directory( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Returns the directory path when the version directory exists.""" + lt_dir = tmp_path / "LanguageTool-6.8" + lt_dir.mkdir() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + lt = _dl.ReleaseLocalLanguageTool("6.8") + assert lt.get_directory_path() == lt_dir + + +class TestGetJarPath: + """Tests for LocalLanguageTool.get_jar_path().""" + + def test_no_jar_raises( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Raises FileNotFoundError when no JAR file exists in the LT directory.""" + (tmp_path / "LanguageTool-6.8").mkdir() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + lt = _dl.ReleaseLocalLanguageTool("6.8") + with pytest.raises(FileNotFoundError, match="JAR not found"): + lt.get_jar_path() + + def test_finds_server_jar( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Returns the path to languagetool-server.jar when it exists.""" + lt_dir = tmp_path / "LanguageTool-6.8" + lt_dir.mkdir() + jar = lt_dir / "languagetool-server.jar" + jar.touch() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + lt = _dl.ReleaseLocalLanguageTool("6.8") + assert lt.get_jar_path() == jar + + +class TestGetServerCmd: + """Tests for LocalLanguageTool.get_server_cmd().""" + + 
def test_no_java_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Raises JavaError when Java is not found on PATH.""" + monkeypatch.setattr("language_tool_python.download_lt.which", _which_none) + lt = _dl.ReleaseLocalLanguageTool("6.8") + with pytest.raises(JavaError, match="can't find Java"): + lt.get_server_cmd() + + def test_without_port_or_config( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """The command contains java and jar path when no extras are given.""" + lt_dir = tmp_path / "LanguageTool-6.8" + lt_dir.mkdir() + (lt_dir / "languagetool-server.jar").touch() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + lt = _dl.ReleaseLocalLanguageTool("6.8") + cmd = lt.get_server_cmd() + assert "java" in cmd[0] + assert "languagetool-server.jar" in cmd[2] + + def test_with_port(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """The generated command includes -p when a port is specified.""" + lt_dir = tmp_path / "LanguageTool-6.8" + lt_dir.mkdir() + (lt_dir / "languagetool-server.jar").touch() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + lt = _dl.ReleaseLocalLanguageTool("6.8") + cmd = lt.get_server_cmd(port=8081) + assert "-p" in cmd + assert "8081" in cmd + + def test_with_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """The generated command includes --config when a config object is given.""" + lt_dir = tmp_path / "LanguageTool-6.8" + lt_dir.mkdir() + (lt_dir / "languagetool-server.jar").touch() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + lt = _dl.ReleaseLocalLanguageTool("6.8") + config = LanguageToolConfig({"cacheSize": 500}) + cmd = lt.get_server_cmd(config=config) + assert "--config" in cmd + + +class TestLocalLanguageToolComparisons: + 
"""Tests for LocalLanguageTool.__eq__ and __lt__ cross-type behaviours.""" + + def test_eq_with_non_lt_returns_not_implemented(self) -> None: + """__eq__ returns NotImplemented for non-LocalLanguageTool objects.""" + lt = _dl.ReleaseLocalLanguageTool("6.8") + result = lt.__eq__("not-a-lt") + assert result is NotImplemented + + def test_snapshot_lt_release_is_false(self) -> None: + """A Snapshot is never less than a Release (Snapshots are always newer).""" + snap = _dl.SnapshotLocalLanguageTool("20240101") + rel = _dl.ReleaseLocalLanguageTool("6.8") + assert not (snap < rel) + + def test_release_lt_snapshot_is_true(self) -> None: + """A Release is always less than a Snapshot (Snapshots are always newer).""" + rel = _dl.ReleaseLocalLanguageTool("6.8") + snap = _dl.SnapshotLocalLanguageTool("20240101") + assert rel < snap + + def test_lt_with_third_subclass_returns_not_implemented(self) -> None: + """__lt__ returns NotImplemented for a third LocalLanguageTool subclass.""" + + class _ThirdLT(_dl.LocalLanguageTool): + @property + def version_name(self) -> str: + return "third" + + @property + def version_into(self) -> tuple[int, int] | datetime: + return (0, 0) + + @property + def download_url(self) -> str: + return "" + + def download(self) -> None: + pass + + rel = _dl.ReleaseLocalLanguageTool("6.8") + assert rel.__lt__(_ThirdLT()) is NotImplemented + + def test_lt_same_type_mismatched_version_into_returns_not_implemented(self) -> None: + """__lt__ returns NotImplemented when version_into types differ.""" + + class _MixedLT(_dl.LocalLanguageTool): + def __init__(self, *, use_dt: bool) -> None: + self._use_dt = use_dt + + @property + def version_name(self) -> str: + return "mixed" + + @property + def version_into(self) -> tuple[int, int] | datetime: + if self._use_dt: + return datetime(2024, 1, 1, tzinfo=timezone.utc) + return (0, 0) + + @property + def download_url(self) -> str: + return "" + + def download(self) -> None: + pass + + a = _MixedLT(use_dt=False) + b = 
_MixedLT(use_dt=True) + assert a.__lt__(b) is NotImplemented + + +class TestSnapshotVersionInto: + """Tests for SnapshotLocalLanguageTool.version_into property.""" + + def test_returns_datetime_for_date_string(self) -> None: + """A date version string is converted to the correct datetime.""" + lt = _dl.SnapshotLocalLanguageTool(_SNAPSHOT_TEST_VERSION) + vi = lt.version_into + assert isinstance(vi, datetime) + assert vi.year == _SNAPSHOT_TEST_YEAR + assert vi.month == _SNAPSHOT_TEST_MONTH + assert vi.day == _SNAPSHOT_TEST_DAY + + +class TestReleaseDownloadUrlEdgeCases: + """Tests for ReleaseLocalLanguageTool.download_url edge cases.""" + + def test_version_below_4_0_raises(self) -> None: + """Versions below 4.0 are not supported and raise PathError.""" + lt = _dl.ReleaseLocalLanguageTool("3.9") + with pytest.raises(PathError, match="no longer supported"): + _ = lt.download_url + + def test_version_6_0_uses_release_url(self) -> None: + """Versions 6.0-6.6 use the main release download URL.""" + lt = _dl.ReleaseLocalLanguageTool("6.0") + assert "6.0" in lt.download_url + + +class TestDownloadEarlyReturn: + """Tests for LocalLanguageTool.download() early-return via LTP_JAR_DIR_PATH.""" + + def test_release_download_skips_when_jar_dir_set( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """download() exits without network access when LTP_JAR_DIR_PATH is set.""" + monkeypatch.setattr( + "language_tool_python.download_lt._confirm_java_compatibility", + _noop, + ) + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + monkeypatch.setenv("LTP_JAR_DIR_PATH", str(tmp_path)) + _dl.ReleaseLocalLanguageTool("6.8").download() + + def test_snapshot_download_skips_when_jar_dir_set( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Snapshot download() skips the network when LTP_JAR_DIR_PATH is set.""" + monkeypatch.setattr( + "language_tool_python.download_lt._confirm_java_compatibility", + _noop, + ) + monkeypatch.setenv("LTP_PATH", 
str(tmp_path)) + monkeypatch.setenv("LTP_JAR_DIR_PATH", str(tmp_path)) + _dl.SnapshotLocalLanguageTool("20240101").download() + + def test_release_download_skips_when_already_installed( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """download() skips the network when the release is already installed.""" + monkeypatch.setattr( + "language_tool_python.download_lt._confirm_java_compatibility", + _noop, + ) + monkeypatch.delenv("LTP_JAR_DIR_PATH", raising=False) + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + (tmp_path / "LanguageTool-6.8").mkdir() + _dl.ReleaseLocalLanguageTool("6.8").download() diff --git a/tests/unit/test_internals_utils.py b/tests/unit/test_internals_utils.py index 7a41cb3..4f23034 100644 --- a/tests/unit/test_internals_utils.py +++ b/tests/unit/test_internals_utils.py @@ -114,12 +114,6 @@ def test_raises_on_negative(self, monkeypatch: pytest.MonkeyPatch) -> None: with pytest.raises(PathError, match="Invalid float"): get_env_float("TEST_FLOAT_VAR", 1.0) - def test_raises_on_inf(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Infinity is not a valid positive float and raises PathError.""" - monkeypatch.setenv("TEST_FLOAT_VAR", "inf") - with pytest.raises(PathError, match="Invalid float"): - get_env_float("TEST_FLOAT_VAR", 1.0) - class TestGetLanguageToolDownloadPath: """Tests for get_language_tool_download_path() path resolver.""" diff --git a/tests/unit/test_match.py b/tests/unit/test_match.py index fdccbc8..56a384d 100644 --- a/tests/unit/test_match.py +++ b/tests/unit/test_match.py @@ -8,8 +8,8 @@ from language_tool_python.match import ( Match, - _four_byte_char_positions, _get_match_ordered_dict, + four_byte_char_positions, is_check_match, ) @@ -60,7 +60,7 @@ def _make_attrib( # noqa: PLR0913 def _make_match(text: str = "This is noot okay.", **kwargs: object) -> Match: - return Match(_make_attrib(**kwargs), text) # type: ignore[arg-type] + return Match(_make_attrib(**kwargs), four_byte_char_positions(text)) # 
type: ignore[arg-type] class TestMatchInit: @@ -102,7 +102,7 @@ def test_four_byte_char_adjustment(self) -> None: context_offset=3, sentence="🌅 hello world", ) - m = Match(attrib, text) + m = Match(attrib, four_byte_char_positions(text)) adjusted_offset = 2 assert m.offset == adjusted_offset @@ -118,15 +118,16 @@ def test_no_adjustment_without_four_byte_chars(self) -> None: context_offset=expected_offset, sentence=text, ), - text, + four_byte_char_positions(text), ) assert m.offset == expected_offset - def test_same_text_reuses_cache(self) -> None: - """Two matches on the same text share the cached position list.""" + def test_same_text_reuses_positions(self) -> None: + """Two matches sharing a precomputed positions list both get correct offsets.""" text = "Same text here." explicit_offset = 5 - m1 = Match(_make_attrib(context_text=text, sentence=text), text) + four_byte_positions = four_byte_char_positions(text) + m1 = Match(_make_attrib(context_text=text, sentence=text), four_byte_positions) m2 = Match( _make_attrib( context_text=text, @@ -135,36 +136,35 @@ def test_same_text_reuses_cache(self) -> None: length=_DEFAULT_LENGTH, context_offset=explicit_offset, ), - text, + four_byte_positions, ) - assert text == Match.PREVIOUS_MATCHES_TEXT assert m1.offset == _DEFAULT_OFFSET assert m2.offset == explicit_offset class TestFourByteCharPositions: - """Tests for _four_byte_char_positions() helper.""" + """Tests for four_byte_char_positions() helper.""" def test_empty_string(self) -> None: """An empty string has no 4-byte char positions.""" - assert _four_byte_char_positions("") == [] + assert four_byte_char_positions("") == [] def test_ascii_only(self) -> None: """A pure-ASCII string has no 4-byte char positions.""" - assert _four_byte_char_positions("hello") == [] + assert four_byte_char_positions("hello") == [] def test_emoji_at_start(self) -> None: """An emoji at position 0 is reported at index 0.""" - assert _four_byte_char_positions("🌅abc") == [0] + assert 
four_byte_char_positions("🌅abc") == [0] def test_multiple_emojis(self) -> None: - """Two consecutive emojis are reported at their Python indices.""" - positions = _four_byte_char_positions("🌅🎉abc") + """Two consecutive emojis are reported with adjusted indices.""" + positions = four_byte_char_positions("🌅🎉abc") assert positions == [0, 2] def test_emoji_in_middle(self) -> None: """An emoji in the middle of ASCII text is reported at the correct index.""" - positions = _four_byte_char_positions("ab🌅cd") + positions = four_byte_char_positions("ab🌅cd") assert positions == [2] @@ -350,6 +350,7 @@ def test_eq_not_implemented_for_non_match(self) -> None: def test_lt(self) -> None: """A match at an earlier offset is less than one at a later offset.""" text = "This is noot okay, and also baaad." + positions = four_byte_char_positions(text) m_early = Match( _make_attrib( offset=0, @@ -358,7 +359,7 @@ def test_lt(self) -> None: context_offset=0, sentence=text, ), - text, + positions, ) m_later = Match( _make_attrib( @@ -368,7 +369,7 @@ def test_lt(self) -> None: context_offset=_DEFAULT_OFFSET, sentence=text, ), - text, + positions, ) assert m_early < m_later diff --git a/tests/unit/test_safe_zip.py b/tests/unit/test_safe_zip.py index 328c7d4..c0ff944 100644 --- a/tests/unit/test_safe_zip.py +++ b/tests/unit/test_safe_zip.py @@ -6,6 +6,7 @@ import io import shutil import stat +import unittest.mock import uuid import zipfile from collections.abc import Iterator @@ -532,3 +533,92 @@ def test_safe_extract_checks_total_compression_ratio_after_all_members() -> None assert ( temp_dir / "LanguageTool" / "already-compressed.bin" ).read_bytes() == already_compressed + + +def test_normalize_member_path_empty_name_raises() -> None: + """_normalize_member_path rejects an empty filename.""" + with pytest.raises(PathError, match="Unsafe ZIP member name"): + SafeZipExtractor()._normalize_member_path("") + + +def test_normalize_member_path_control_char_raises() -> None: + 
"""_normalize_member_path rejects a filename containing a control character.""" + with pytest.raises(PathError, match="Unsafe ZIP member name"): + SafeZipExtractor()._normalize_member_path("foo\x01bar") + + +def test_validate_member_type_explicit_regular_file_passes() -> None: + """_validate_member_type accepts a ZipInfo with an explicit S_IFREG mode.""" + member = zipfile.ZipInfo("LanguageTool/file.txt") + member.external_attr = stat.S_IFREG << 16 + SafeZipExtractor()._validate_member_type(member) + + +def test_validate_member_compression_ratio_zero_compress_size_raises() -> None: + """_validate_member_compression_ratio rejects a member with zero compressed size.""" + member = zipfile.ZipInfo("LanguageTool/file.txt") + member.compress_size = 0 + member.file_size = 100 + with pytest.raises(PathError, match="zero compressed size"): + SafeZipExtractor()._validate_member_compression_ratio(member) + + +def test_validate_total_compression_ratio_zero_compressed_skips() -> None: + """_validate_total_compression_ratio returns early when total_compressed is zero.""" + SafeZipExtractor()._validate_total_compression_ratio(0, 0) + + +def test_validate_member_sizes_negative_compress_size_raises() -> None: + """_validate_member_sizes rejects a member with a negative compressed size.""" + member = zipfile.ZipInfo("LanguageTool/file.txt") + member.compress_size = -1 + member.file_size = 100 + with pytest.raises(PathError, match="Invalid ZIP member size"): + SafeZipExtractor()._validate_member_sizes(member) + + +def test_validate_member_sizes_negative_file_size_raises() -> None: + """_validate_member_sizes rejects a member with a negative uncompressed size.""" + member = zipfile.ZipInfo("LanguageTool/file.txt") + member.compress_size = 100 + member.file_size = -1 + with pytest.raises(PathError, match="Invalid ZIP member size"): + SafeZipExtractor()._validate_member_sizes(member) + + +def _open_returning_large(_m: object, _mode: str = "r") -> io.BytesIO: + """Fake ZipFile.open that 
yields more bytes than any small declared file_size.""" + return io.BytesIO(b"hello world - content longer than 3 bytes") + + +def _open_returning_small(_m: object, _mode: str = "r") -> io.BytesIO: + """Fake ZipFile.open that yields only 2 bytes regardless of declared size.""" + return io.BytesIO(b"hi") + + +def test_copy_member_raises_when_content_exceeds_declared_size() -> None: + """_copy_member raises when decompressed bytes exceed the declared file_size.""" + payload = make_zip_payload({"LanguageTool/file.txt": b"hello world"}) + with zipfile.ZipFile(io.BytesIO(payload)) as zf: + member = zf.infolist()[0] + member.file_size = 3 + with ( + unittest.mock.patch.object(zf, "open", new=_open_returning_large), + workspace_temp_dir() as temp_dir, + pytest.raises(PathError, match="expanded beyond declared size"), + ): + SafeZipExtractor()._copy_member(zf, member, temp_dir / "file.txt") + + +def test_copy_member_raises_when_content_is_less_than_declared_size() -> None: + """_copy_member raises when fewer bytes are read than the declared file_size.""" + payload = make_zip_payload({"LanguageTool/file.txt": b"hello world"}) + with zipfile.ZipFile(io.BytesIO(payload)) as zf: + member = zf.infolist()[0] + member.file_size = 1000 + with ( + unittest.mock.patch.object(zf, "open", new=_open_returning_small), + workspace_temp_dir() as temp_dir, + pytest.raises(PathError, match="extracted size mismatch"), + ): + SafeZipExtractor()._copy_member(zf, member, temp_dir / "file.txt") diff --git a/tests/unit/test_server_unit.py b/tests/unit/test_server_unit.py new file mode 100644 index 0000000..2374bcc --- /dev/null +++ b/tests/unit/test_server_unit.py @@ -0,0 +1,1089 @@ +"""Unit tests for server.py — no Java, no network required.""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest +import requests + +from language_tool_python.exceptions import ( + LanguageToolError, + PathError, + 
RateLimitError, + ServerError, +) +from language_tool_python.language_tag import LanguageTag +from language_tool_python.server import ( + LanguageTool, + LanguageToolPublicAPI, + _kill_processes, + _terminate_server_at_exit, +) + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + +_LANGUAGES: set[str] = {"en", "en-US", "fr", "auto"} +_DEFAULT_URL = "http://localhost:8081/v2/" +_DEFAULT_PORT = 8081 +_NEXT_PORT = 8082 + +# Typed response stubs used in patch.object calls (avoids dict[Any, Any] errors) +_INVALID_CHECK_SHAPE: dict[str, str] = {"bad": "shape"} +_BAD_LANG_ITEM: dict[str, str] = {"wrong": "keys"} +_NON_LIST_RESPONSE: dict[str, str] = {"not": "a list"} +_VALID_CHECK_EMPTY: dict[str, object] = {"matches": [], "language": {}, "warnings": {}} +_VALID_LANG_LIST: list[dict[str, str]] = [ + {"code": "en", "longCode": "en-US", "name": "English"}, +] +_HTTP_RATE_LIMIT_STATUS = 426 + + +class _MockSession(requests.Session): + """requests.Session subclass with configurable exception and response injection.""" + + def __init__( + self, + get_exc: Exception | None = None, + post_exc: Exception | None = None, + get_response: requests.Response | None = None, + post_response: requests.Response | None = None, + ) -> None: + """Initialise session with optional exception and response stubs.""" + super().__init__() + self._get_exc = get_exc + self._post_exc = post_exc + self._get_response = get_response + self._post_response = post_response + + def get( # type: ignore[override] + self, + _url: str | bytes, + **_kw: object, + ) -> requests.Response: + """Raise configured exception or return configured/empty response.""" + if self._get_exc is not None: + raise self._get_exc + return ( + self._get_response + if self._get_response is not None + else requests.Response() + ) + + def post( # type: ignore[override] + self, + _url: str | bytes, + _data: object = None, + _json: object = None, + **_kw: object, + ) -> requests.Response: + """Raise 
configured exception or return configured/empty response.""" + if self._post_exc is not None: + raise self._post_exc + return ( + self._post_response + if self._post_response is not None + else requests.Response() + ) + + +class _MockProcess: + """Minimal stand-in for subprocess.Popen[str] used in server tests.""" + + def __init__(self, poll_return: int | None = None) -> None: + """Initialise with a fixed poll return value.""" + self.pid: int = 12345 + self.stdin: None = None + self.returncode: int | None = poll_return + self._poll_return = poll_return + + def poll(self) -> int | None: + """Return the configured poll value.""" + return self._poll_return + + def wait(self, timeout: float | None = None) -> int: # noqa: ARG002 + """Pretend to wait and return 0.""" + return 0 + + def terminate(self) -> None: + """No-op terminate.""" + + +class _MockLocalLT: + """Mock LocalLanguageTool with configurable directory and PathError on cmd.""" + + def __init__(self, directory: Path | None = None) -> None: + """Initialise with an optional directory to return from get_directory_path.""" + self._directory = directory + + def download(self) -> None: + """No-op download.""" + + def get_directory_path(self) -> Path: + """Return the configured directory or raise RuntimeError if unset.""" + if self._directory is None: + err = "no directory configured on mock" + raise RuntimeError(err) + return self._directory + + def get_server_cmd( + self, + _port: int | None, + _config: object, + ) -> list[str]: + """Raise PathError when called.""" + err = "jar not found" + raise PathError(err) + + +class _MockLocalLTWithCmd: + """Mock LocalLanguageTool that returns a working server command.""" + + def download(self) -> None: + """No-op download.""" + + def get_directory_path(self) -> None: + """No-op — not needed for start_local_server tests.""" + + def get_server_cmd(self, _port: int | None, _config: object) -> list[str]: + """Return a fake Java command.""" + return ["java", "-jar", "lt.jar"] + + 
+class _MockMatchWithOffset: + """Minimal stand-in for Match carrying only an offset attribute.""" + + def __init__(self, offset: int = 0) -> None: + """Initialise with an offset.""" + self.offset = offset + + +def _make_json_response(body: bytes, status_code: int = 200) -> requests.Response: + """Build a requests.Response with the given body bytes and status code.""" + r = requests.Response() + r._content = body + # prevent close() from accessing r.raw (which is None) + r._content_consumed = True # type: ignore[attr-defined] + r.status_code = status_code + r.encoding = "utf-8" + return r + + +def _bare_lt(**attrs: object) -> LanguageTool: + """Return a LanguageTool instance created without calling __init__.""" + lt: LanguageTool = object.__new__(LanguageTool) + lt._server = None + lt._session = _MockSession() + lt._remote = False + lt._new_spellings_persist = True + lt._new_spellings = None + lt._url = _DEFAULT_URL + lt._port = _DEFAULT_PORT + lt._host = "127.0.0.1" + lt._config = None + lt._local_language_tool = None + lt._proxies = None + lt._mother_tongue = None + lt._language = LanguageTag("en", _LANGUAGES) + lt._disabled_rules = set[str]() + lt._enabled_rules = set[str]() + lt._disabled_categories = set[str]() + lt._enabled_categories = set[str]() + lt._enabled_rules_only = False + lt._preferred_variants = set[str]() + lt._picky = False + lt._premium_username = None + lt._premium_key = None + lt._language_tool_download_version = "6.8" + lt._available_ports = list[int]() + for k, v in attrs.items(): + setattr(lt, k, v) + return lt + + +def _start_fails_once_then_succeeds() -> Callable[[], None]: + """Return a callable that raises ServerError on the first call only.""" + calls: list[int] = [0] + + def _start() -> None: + calls[0] += 1 + if calls[0] == 1: + err = "port busy" + raise ServerError(err) + + return _start + + +class TestLanguageToolConstructorValidation: + """Tests for __init__ parameter validation.""" + + def 
test_raises_when_remote_server_and_config_combined(self) -> None: + """Combining remote_server and config raises ValueError immediately.""" + with pytest.raises(ValueError, match="Cannot use both"): + LanguageTool(remote_server="http://fake/", config={"k": "v"}) + + def test_raises_when_proxies_used_without_remote_server(self) -> None: + """Proxies without remote_server raises ValueError immediately.""" + with pytest.raises(ValueError, match="Proxies can only be used"): + LanguageTool(proxies={"http": "http://proxy/"}) + + def test_session_proxies_updated_when_proxies_and_remote_server_given( + self, + ) -> None: + """Session proxies are updated when proxies and remote_server are provided.""" + with ( + patch( + "language_tool_python.server.get_locale_language", + return_value="en", + ), + patch.object(LanguageTool, "_get_languages", return_value=_LANGUAGES), + LanguageTool( + remote_server="http://fake/", + proxies={"http": "http://proxy/"}, + ) as lt, + ): + assert lt._proxies == {"http": "http://proxy/"} + + def test_failsafe_language_used_when_locale_detection_fails(self) -> None: + """FAILSAFE_LANGUAGE is used when get_locale_language raises ValueError.""" + with ( + patch( + "language_tool_python.server.get_locale_language", + side_effect=ValueError("no locale"), + ), + patch.object(LanguageTool, "_get_languages", return_value=_LANGUAGES), + LanguageTool(remote_server="http://fake/") as lt, + ): + assert str(lt.language) == "en" + + +class TestLanguageToolDel: + """Tests for __del__ resource warning.""" + + def test_del_warns_and_calls_close_when_server_still_alive(self) -> None: + """ResourceWarning is emitted and close() called when server is alive.""" + lt = _bare_lt() + lt._server = _MockProcess(poll_return=None) # type: ignore[assignment] + with ( + warnings.catch_warnings(record=True) as caught, + ): + warnings.simplefilter("always") + with patch.object(lt, "close") as mock_close: + lt.__del__() + assert any(issubclass(w.category, ResourceWarning) for w 
in caught) + mock_close.assert_called_once() + lt._server = None # prevent GC __del__ re-triggering + + +class TestLanguageToolRepr: + """Tests for __repr__.""" + + def test_repr_contains_language_and_mother_tongue(self) -> None: + """__repr__ includes the class name, language, and mother tongue.""" + lt = _bare_lt(_mother_tongue="fr") + r = repr(lt) + assert "LanguageTool" in r + assert "en" in r + assert "fr" in r + + +class TestLanguageToolProperties: + """Tests for all property getters and setters.""" + + def test_language_getter_returns_language_tag(self) -> None: + """Language getter returns the stored LanguageTag.""" + lt = _bare_lt() + assert lt.language == LanguageTag("en", _LANGUAGES) + + def test_language_setter_updates_tag_and_clears_rules(self) -> None: + """Language setter updates _language and clears rule sets.""" + lt = _bare_lt() + lt._disabled_rules = {"OLD"} + lt._enabled_rules = {"ALSO_OLD"} + with patch.object(lt, "_get_languages", return_value=_LANGUAGES): + lt.language = "fr" + assert str(lt.language) == "fr" + assert lt._disabled_rules == set() + assert lt._enabled_rules == set() + + def test_mother_tongue_getter_returns_language_tag_when_set(self) -> None: + """Mother_tongue getter wraps the stored string in a LanguageTag.""" + lt = _bare_lt(_mother_tongue="fr") + with patch.object(lt, "_get_languages", return_value=_LANGUAGES): + mt = lt.mother_tongue + assert mt is not None + assert str(mt) == "fr" + + def test_mother_tongue_setter_stores_value(self) -> None: + """Mother_tongue setter stores the provided value.""" + lt = _bare_lt() + lt.mother_tongue = "fr" + assert lt._mother_tongue == "fr" + + def test_proxies_getter_returns_stored_value(self) -> None: + """Proxies getter returns _proxies.""" + lt = _bare_lt(_proxies={"http": "http://p/"}) + assert lt.proxies == {"http": "http://p/"} + + def test_proxies_setter_raises_when_local_server(self) -> None: + """Proxies setter raises ValueError when server is local.""" + lt = 
_bare_lt(_remote=False) + with pytest.raises(ValueError, match="Proxies can only be used"): + lt.proxies = {"http": "http://proxy/"} + + def test_proxies_setter_updates_session_when_remote(self) -> None: + """Proxies setter updates session proxies on a remote server.""" + lt = _bare_lt(_remote=True) + lt.proxies = {"http": "http://proxy/"} + assert lt._proxies == {"http": "http://proxy/"} + + def test_proxies_setter_clears_proxies_when_set_to_none(self) -> None: + """Proxies setter clears session proxies when set to None.""" + lt = _bare_lt(_remote=True) + lt.proxies = None + assert lt._proxies is None + + def test_disabled_rules_setter(self) -> None: + """Disabled_rules setter replaces the set.""" + lt = _bare_lt() + lt.disabled_rules = {"RULE_X"} + assert lt._disabled_rules == {"RULE_X"} + + def test_disabled_categories_setter(self) -> None: + """Disabled_categories setter replaces the set.""" + lt = _bare_lt() + lt.disabled_categories = {"CAT_A"} + assert lt._disabled_categories == {"CAT_A"} + + def test_enabled_categories_setter(self) -> None: + """Enabled_categories setter replaces the set.""" + lt = _bare_lt() + lt.enabled_categories = {"CAT_B"} + assert lt._enabled_categories == {"CAT_B"} + + def test_enabled_rules_only_getter(self) -> None: + """Enabled_rules_only getter returns the stored bool.""" + lt = _bare_lt(_enabled_rules_only=True) + assert lt.enabled_rules_only is True + + def test_preferred_variants_getter(self) -> None: + """Preferred_variants getter returns the stored set.""" + lt = _bare_lt(_preferred_variants={"en-US"}) + assert lt.preferred_variants == {"en-US"} + + def test_preferred_variants_setter(self) -> None: + """Preferred_variants setter replaces the set.""" + lt = _bare_lt() + lt.preferred_variants = {"en-GB"} + assert lt._preferred_variants == {"en-GB"} + + def test_picky_getter(self) -> None: + """Picky getter returns the stored bool.""" + lt = _bare_lt(_picky=True) + assert lt.picky is True + + def test_picky_setter(self) -> None: 
+ """Picky setter stores the provided bool.""" + lt = _bare_lt() + lt.picky = True + assert lt._picky is True + + def test_premium_username_getter(self) -> None: + """Premium_username getter returns the stored value.""" + lt = _bare_lt(_premium_username="alice") + assert lt.premium_username == "alice" + + def test_premium_username_setter(self) -> None: + """Premium_username setter stores the provided value.""" + lt = _bare_lt() + lt.premium_username = "alice" + assert lt._premium_username == "alice" + + def test_premium_key_getter(self) -> None: + """Premium_key getter returns the stored value.""" + lt = _bare_lt(_premium_key="secret") + assert lt.premium_key == "secret" + + def test_premium_key_setter(self) -> None: + """Premium_key setter stores the provided value.""" + lt = _bare_lt() + lt.premium_key = "secret" + assert lt._premium_key == "secret" + + def test_config_getter(self) -> None: + """Config getter returns _config.""" + lt = _bare_lt() + assert lt.config is None + + def test_url_getter(self) -> None: + """Url getter returns _url.""" + lt = _bare_lt() + assert lt.url == _DEFAULT_URL + + def test_is_remote_getter(self) -> None: + """Is_remote getter returns _remote.""" + lt = _bare_lt(_remote=True) + assert lt.is_remote is True + + def test_host_getter(self) -> None: + """Host getter returns _host.""" + lt = _bare_lt() + assert lt.host == "127.0.0.1" + + def test_port_getter(self) -> None: + """Port getter returns _port.""" + lt = _bare_lt() + assert lt.port == _DEFAULT_PORT + + def test_disabled_rules_getter(self) -> None: + """Disabled_rules getter returns the stored set.""" + lt = _bare_lt(_disabled_rules={"R1"}) + assert lt.disabled_rules == {"R1"} + + def test_enabled_rules_getter(self) -> None: + """Enabled_rules getter returns the stored set.""" + lt = _bare_lt(_enabled_rules={"R2"}) + assert lt.enabled_rules == {"R2"} + + def test_enabled_rules_setter(self) -> None: + """Enabled_rules setter replaces the set.""" + lt = _bare_lt() + 
lt.enabled_rules = {"R3"} + assert lt._enabled_rules == {"R3"} + + def test_disabled_categories_getter(self) -> None: + """Disabled_categories getter returns the stored set.""" + lt = _bare_lt(_disabled_categories={"C1"}) + assert lt.disabled_categories == {"C1"} + + def test_enabled_categories_getter(self) -> None: + """Enabled_categories getter returns the stored set.""" + lt = _bare_lt(_enabled_categories={"C2"}) + assert lt.enabled_categories == {"C2"} + + def test_enabled_rules_only_setter(self) -> None: + """Enabled_rules_only setter stores the provided bool.""" + lt = _bare_lt() + lt.enabled_rules_only = True + assert lt._enabled_rules_only is True + + def test_language_tool_download_version_getter(self) -> None: + """Language_tool_download_version getter returns the stored version string.""" + lt = _bare_lt() + assert lt.language_tool_download_version == "6.8" + + +class TestLanguageToolClose: + """Tests for close() terminate and unregister paths.""" + + def test_close_terminates_server_when_alive_and_local(self) -> None: + """_terminate_server() is called when the server is alive and not remote.""" + mock_proc = _MockProcess(poll_return=None) + lt = _bare_lt(_remote=False) + lt._server = mock_proc # type: ignore[assignment] + with patch.object(lt, "_terminate_server") as mock_term: + lt.close() + mock_term.assert_called_once() + lt._server = None + + def test_close_unregisters_spellings_when_not_persisted(self) -> None: + """_unregister_spellings() is called when new_spellings_persist is False.""" + lt = _bare_lt(_new_spellings_persist=False, _new_spellings=["hello"]) + with patch.object(lt, "_unregister_spellings") as mock_unreg: + lt.close() + mock_unreg.assert_called_once() + + +class TestLanguageToolCheck: + """Tests for check() error branches.""" + + def test_raises_server_error_when_query_returns_none(self) -> None: + """Check() raises ServerError when _query_server returns None.""" + lt = _bare_lt() + with ( + patch.object(lt, "_query_server", 
return_value=None), + pytest.raises(ServerError, match="No response received"), + ): + lt.check("hello") + + def test_raises_server_error_when_response_has_invalid_shape(self) -> None: + """Check() raises ServerError when the response fails is_check_response.""" + lt = _bare_lt() + with ( + patch.object(lt, "_query_server", return_value=_INVALID_CHECK_SHAPE), + pytest.raises(ServerError, match="Invalid response"), + ): + lt.check("hello") + + def test_returns_empty_match_list_on_valid_response(self) -> None: + """Check() returns an empty list when the server returns zero matches.""" + lt = _bare_lt() + with patch.object(lt, "_query_server", return_value=_VALID_CHECK_EMPTY): + result = lt.check("hello") + assert result == [] + + +class TestCheckMatchingRegions: + """Tests for check_matching_regions().""" + + def test_returns_empty_list_when_pattern_matches_nothing(self) -> None: + """Returns [] immediately when the pattern produces no regions.""" + lt = _bare_lt() + result = lt.check_matching_regions("hello world", r"\d+") + assert result == [] + + def test_returns_adjusted_matches_when_pattern_matches(self) -> None: + """Matches are offset-adjusted and sorted when the pattern finds regions.""" + mock_match = _MockMatchWithOffset(offset=2) + match_list: list[_MockMatchWithOffset] = [mock_match] + lt = _bare_lt() + with patch.object(lt, "check", return_value=match_list): + results = lt.check_matching_regions("hello world", r"hello") + expected_offset = 2 + assert len(results) == 1 + assert results[0].offset == expected_offset + + +class TestCreateParams: + """Tests for _create_params() optional parameter branches.""" + + def test_optional_params_included_when_attributes_set(self) -> None: + """All optional _create_params branches fire when attributes are set.""" + lt = _bare_lt( + _mother_tongue="fr", + _disabled_rules={"RULE1"}, + _preferred_variants={"en-US"}, + _picky=True, + _premium_username="user@test", + _premium_key="key123", + ) + with patch.object(lt, 
"_get_languages", return_value=_LANGUAGES): + params = lt._create_params("hello") + + assert params.get("motherTongue") == "fr" + assert "RULE1" in (params.get("disabledRules") or "") + assert "en-US" in (params.get("preferredVariants") or "") + assert params.get("level") == "picky" + assert params.get("username") == "user@test" + assert params.get("apiKey") == "key123" + + def test_enabled_rules_categories_and_enabled_only_included(self) -> None: + """Enabled rules, enabled-only, and category params are included when set.""" + lt = _bare_lt( + _enabled_rules={"RULE_EN"}, + _enabled_rules_only=True, + _disabled_categories={"CAT_DIS"}, + _enabled_categories={"CAT_EN"}, + ) + with patch.object(lt, "_get_languages", return_value=_LANGUAGES): + params = lt._create_params("hello") + + assert "RULE_EN" in (params.get("enabledRules") or "") + assert params.get("enabledOnly") == "true" + assert "CAT_DIS" in (params.get("disabledCategories") or "") + assert "CAT_EN" in (params.get("enabledCategories") or "") + + +class TestSpellchecking: + """Tests for enable_spellchecking().""" + + def test_enable_spellchecking_removes_typos_category(self) -> None: + """Enable_spellchecking() removes TYPOS from disabled_categories.""" + lt = _bare_lt(_disabled_categories={"TYPOS", "OTHER"}) + lt.enable_spellchecking() + assert "TYPOS" not in lt._disabled_categories + assert "OTHER" in lt._disabled_categories + + def test_disable_spellchecking_adds_typos_category(self) -> None: + """Disable_spellchecking() adds TYPOS to disabled_categories.""" + lt = _bare_lt(_disabled_categories=set[str]()) + lt.disable_spellchecking() + assert "TYPOS" in lt._disabled_categories + + +class TestCorrect: + """Tests for correct().""" + + def test_correct_applies_check_results_to_text(self) -> None: + """Correct() delegates to check() and returns the corrected text.""" + no_matches: list[_MockMatchWithOffset] = [] + lt = _bare_lt() + with patch.object(lt, "check", return_value=no_matches): + result = 
lt.correct("hello") + assert result == "hello" + + +class TestGetValidSpellingFilePath: + """Tests for _get_valid_spelling_file_path() error branches.""" + + def test_raises_when_local_language_tool_not_initialized(self) -> None: + """Raises PathError when _local_language_tool is None.""" + lt = _bare_lt(_local_language_tool=None) + with pytest.raises( + PathError, match="LocalLanguageTool instance is not initialized" + ): + lt._get_valid_spelling_file_path() + + def test_raises_file_not_found_when_spelling_file_missing( + self, tmp_path: Path + ) -> None: + """Raises FileNotFoundError when the spelling file does not exist.""" + mock_llt = _MockLocalLT(directory=tmp_path) + lt = _bare_lt() + lt._local_language_tool = mock_llt # type: ignore[assignment] + with pytest.raises(FileNotFoundError, match="Failed to find"): + lt._get_valid_spelling_file_path() + + def test_auto_language_defaults_to_en_and_raises_file_not_found( + self, tmp_path: Path + ) -> None: + """Auto language logs debug, defaults to en, then raises if file missing.""" + mock_llt = _MockLocalLT(directory=tmp_path) + lt = _bare_lt() + lt._language = LanguageTag("auto", {"auto", "en"}) + lt._local_language_tool = mock_llt # type: ignore[assignment] + with pytest.raises(FileNotFoundError, match="Failed to find"): + lt._get_valid_spelling_file_path() + + def test_returns_path_when_spelling_file_exists(self, tmp_path: Path) -> None: + """Returns the spelling file path when the file exists.""" + spelling_path = ( + tmp_path + / "org" + / "languagetool" + / "resource" + / "en" + / "hunspell" + / "spelling.txt" + ) + spelling_path.parent.mkdir(parents=True) + spelling_path.write_text("existing\n", encoding="utf-8") + mock_llt = _MockLocalLT(directory=tmp_path) + lt = _bare_lt() + lt._local_language_tool = mock_llt # type: ignore[assignment] + result = lt._get_valid_spelling_file_path() + assert result == spelling_path + + +class TestRegisterSpellingsBody: + """Tests for _register_spellings() body when new 
spellings are present.""" + + def test_writes_new_spellings_to_file(self, tmp_path: Path) -> None: + """_register_spellings() appends new words not already in the file.""" + spelling_path = ( + tmp_path + / "org" + / "languagetool" + / "resource" + / "en" + / "hunspell" + / "spelling.txt" + ) + spelling_path.parent.mkdir(parents=True) + spelling_path.write_text("existing\n", encoding="utf-8") + lt = _bare_lt(_new_spellings=["newword"]) + lt._local_language_tool = _MockLocalLT(directory=tmp_path) # type: ignore[assignment] + lt._register_spellings() + content = spelling_path.read_text(encoding="utf-8") + assert "newword" in content + + +class TestUnregisterSpellingsBody: + """Tests for _unregister_spellings() body when new spellings are present.""" + + def test_removes_spellings_from_file(self, tmp_path: Path) -> None: + """_unregister_spellings() removes the registered words from the file.""" + spelling_path = ( + tmp_path + / "org" + / "languagetool" + / "resource" + / "en" + / "hunspell" + / "spelling.txt" + ) + spelling_path.parent.mkdir(parents=True) + spelling_path.write_text("existing\nnewword\n", encoding="utf-8") + lt = _bare_lt(_new_spellings=["newword"]) + lt._local_language_tool = _MockLocalLT(directory=tmp_path) # type: ignore[assignment] + lt._unregister_spellings() + content = spelling_path.read_text(encoding="utf-8") + assert "newword" not in content + assert "existing" in content + + +class TestRegisterUnregisterSpellings: + """Tests for early-return paths in spelling registration methods.""" + + def test_register_spellings_returns_early_when_new_spellings_is_none( + self, + ) -> None: + """_register_spellings() returns immediately when _new_spellings is None.""" + lt = _bare_lt(_new_spellings=None) + lt._register_spellings() # must not raise + + def test_unregister_spellings_returns_early_when_new_spellings_is_none( + self, + ) -> None: + """_unregister_spellings() returns immediately when _new_spellings is None.""" + lt = 
_bare_lt(_new_spellings=None) + lt._unregister_spellings() # must not raise + + +class TestGetLanguages: + """Tests for _get_languages() error branches.""" + + def test_raises_when_query_returns_none(self) -> None: + """Raises ServerError when _query_server returns None.""" + lt = _bare_lt() + with ( + patch.object(lt, "_start_server_if_needed"), + patch.object(lt, "_query_server", return_value=None), + pytest.raises(ServerError, match="No response received"), + ): + lt._get_languages() + + def test_raises_when_list_item_fails_is_language_info(self) -> None: + """Raises ServerError when a list item does not pass is_language_info.""" + lt = _bare_lt() + bad_items: list[dict[str, str]] = [_BAD_LANG_ITEM] + with ( + patch.object(lt, "_start_server_if_needed"), + patch.object(lt, "_query_server", return_value=bad_items), + pytest.raises(ServerError, match="Unexpected response format"), + ): + lt._get_languages() + + def test_raises_when_response_is_not_a_list(self) -> None: + """Raises ServerError when response is not a list.""" + lt = _bare_lt() + with ( + patch.object(lt, "_start_server_if_needed"), + patch.object(lt, "_query_server", return_value=_NON_LIST_RESPONSE), + pytest.raises(ServerError, match="Unexpected response format"), + ): + lt._get_languages() + + def test_returns_language_set_from_valid_list_response(self) -> None: + """Returns language codes when the server returns a valid language list.""" + lt = _bare_lt() + with ( + patch.object(lt, "_start_server_if_needed"), + patch.object(lt, "_query_server", return_value=_VALID_LANG_LIST), + ): + langs = lt._get_languages() + assert "en" in langs + assert "en-US" in langs + assert "auto" in langs + + +class TestStartServerIfNeeded: + """Tests for _start_server_if_needed().""" + + def test_calls_start_on_free_port_when_server_not_alive_and_not_remote( + self, + ) -> None: + """_start_server_on_free_port() is called when server is dead and local.""" + lt = _bare_lt(_server=None, _remote=False) + with 
patch.object(lt, "_start_server_on_free_port") as mock_start: + lt._start_server_if_needed() + mock_start.assert_called_once() + + +class TestQueryServer: + """Tests for _query_server() network-error handling.""" + + def test_raises_language_tool_error_on_oserror_when_remote(self) -> None: + """Raises LanguageToolError when session.get raises OSError (remote).""" + lt = _bare_lt(_remote=True) + lt._session = _MockSession(get_exc=OSError("connection refused")) + with pytest.raises(LanguageToolError, match="connection refused"): + lt._query_server("http://fake/", num_tries=1) + + def test_restarts_local_server_before_raising_on_oserror(self) -> None: + """Terminate and start are called when local server gets OSError.""" + lt = _bare_lt(_remote=False) + lt._session = _MockSession(get_exc=OSError("connection refused")) + with ( + patch.object(lt, "_terminate_server") as mock_term, + patch.object(lt, "_start_local_server") as mock_start, + pytest.raises(LanguageToolError, match="connection refused"), + ): + lt._query_server("http://fake/", num_tries=1) + mock_term.assert_called_once() + mock_start.assert_called_once() + + +class TestStartServerOnFreePort: + """Tests for _start_server_on_free_port() port-retry logic.""" + + def test_retries_with_next_port_when_first_port_busy(self) -> None: + """Port is incremented and server restarted when first attempt fails.""" + lt = _bare_lt() + lt._port = _DEFAULT_PORT + lt._available_ports = [_NEXT_PORT] + with patch.object( + lt, + "_start_local_server", + side_effect=_start_fails_once_then_succeeds(), + ): + lt._start_server_on_free_port() + assert lt._port == _NEXT_PORT + + def test_raises_when_no_ports_remain(self) -> None: + """ServerError is re-raised when no more ports are available.""" + lt = _bare_lt() + lt._available_ports = list[int]() + err = "all ports exhausted" + with ( + patch.object(lt, "_start_local_server", side_effect=ServerError(err)), + pytest.raises(ServerError, match=err), + ): + 
lt._start_server_on_free_port() + + +class TestQueryServerResponseHandling: + """Tests for _query_server() JSON-parsing and POST-method paths.""" + + def test_returns_parsed_json_on_successful_get(self) -> None: + """Returns parsed JSON dict when the GET response contains valid JSON.""" + r = _make_json_response(b'{"ok": true}') + lt = _bare_lt() + lt._session = _MockSession(get_response=r) + result = lt._query_server("http://fake/", num_tries=1) + assert result == {"ok": True} + + def test_post_method_returns_parsed_json(self) -> None: + """POST method routes through session.post and returns parsed JSON.""" + r = _make_json_response(b'{"key": "value"}') + lt = _bare_lt() + lt._session = _MockSession(post_response=r) + result = lt._query_server("http://fake/", method="post") + assert result == {"key": "value"} + + def test_raises_language_tool_error_on_invalid_json(self) -> None: + """LanguageToolError is raised when the response body is not valid JSON.""" + r = _make_json_response(b"not json", status_code=200) + lt = _bare_lt(_remote=True) + lt._session = _MockSession(get_response=r) + with pytest.raises(LanguageToolError): + lt._query_server("http://fake/", num_tries=1) + + def test_raises_rate_limit_error_on_http_426(self) -> None: + """RateLimitError is raised on HTTP 426 with non-JSON body.""" + r = _make_json_response(b"rate limited", status_code=_HTTP_RATE_LIMIT_STATUS) + lt = _bare_lt(_remote=True) + lt._session = _MockSession(get_response=r) + with pytest.raises(RateLimitError): + lt._query_server("http://fake/", num_tries=1) + + def test_returns_none_with_zero_tries(self) -> None: + """Returns None immediately when num_tries=0 (loop never executes).""" + lt = _bare_lt() + result = lt._query_server("http://fake/", num_tries=0) + assert result is None + + +class TestKillProcesses: + """Tests for _kill_processes() iteration and wait logic.""" + + def test_suppresses_no_such_process_and_calls_wait(self) -> None: + """_kill_processes() iterates, suppresses 
NoSuchProcess, and calls p.wait().""" + mock_proc = _MockProcess() + _kill_processes([mock_proc]) # type: ignore[list-item] + + +class TestTerminateServer: + """Tests for _terminate_server() body.""" + + def test_kills_process_removes_from_list_and_clears_server(self) -> None: + """_terminate_server() kills the process, removes it, and sets _server=None.""" + mock_proc = _MockProcess(poll_return=None) + running: list[object] = [mock_proc] + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + with ( + patch("language_tool_python.server._RUNNING_SERVER_PROCESSES", running), + patch("language_tool_python.server._kill_processes") as mock_kill, + ): + lt._terminate_server() + assert lt._server is None + mock_kill.assert_called_once() + assert mock_proc not in running + + def test_closes_stdin_when_present(self) -> None: + """_terminate_server() calls stdin.close() when stdin is not None.""" + + class _MockStdin: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + mock_stdin = _MockStdin() + mock_proc = _MockProcess(poll_return=None) + mock_proc.stdin = mock_stdin # type: ignore[assignment] + running: list[object] = [mock_proc] + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + with ( + patch("language_tool_python.server._RUNNING_SERVER_PROCESSES", running), + patch("language_tool_python.server._kill_processes"), + ): + lt._terminate_server() + assert mock_stdin.closed + + +class TestStartLocalServer: + """Tests for _start_local_server() PathError branch.""" + + def test_raises_path_error_when_get_server_cmd_fails(self) -> None: + """Wraps get_server_cmd PathError with 'Failed to find LanguageTool'.""" + lt = _bare_lt() + with ( + patch( + "language_tool_python.server.LocalLanguageTool.from_version_name", + return_value=_MockLocalLT(), + ), + pytest.raises(PathError, match="Failed to find LanguageTool"), + ): + lt._start_local_server() + + +class TestStartLocalServerSuccess: + 
"""Tests for _start_local_server() success path.""" + + def test_spawns_process_and_schedules_wait_for_ready(self) -> None: + """_start_local_server() spawns Popen and calls _wait_for_server_ready.""" + mock_proc = _MockProcess(poll_return=None) + running: list[object] = [] + lt = _bare_lt() + with ( + patch( + "language_tool_python.server.LocalLanguageTool.from_version_name", + return_value=_MockLocalLTWithCmd(), + ), + patch( + "language_tool_python.server.subprocess.Popen", return_value=mock_proc + ), + patch("language_tool_python.server._RUNNING_SERVER_PROCESSES", running), + patch.object(lt, "_wait_for_server_ready"), + ): + lt._start_local_server() + assert mock_proc in running + lt._server = None + + +class TestWaitForServerReady: + """Tests for _wait_for_server_ready() error branches.""" + + def test_raises_when_server_is_none(self) -> None: + """Raises ServerError when _server is None.""" + lt = _bare_lt(_server=None) + with pytest.raises(ServerError, match="Server process is not initialized"): + lt._wait_for_server_ready() + + def test_raises_when_server_process_exits_early(self) -> None: + """Raises ServerError when server poll() returns a non-None exit code.""" + mock_proc = _MockProcess(poll_return=1) + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + with pytest.raises(ServerError, match="exited early"): + lt._wait_for_server_ready() + + def test_raises_when_timeout_expires_before_server_responds(self) -> None: + """Raises ServerError when the server does not respond within timeout.""" + mock_proc = _MockProcess(poll_return=None) + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + with pytest.raises(ServerError, match="did not become ready"): + lt._wait_for_server_ready(timeout=0) + lt._server = None # prevent GC __del__ re-triggering + + def test_returns_when_server_responds_ok(self) -> None: + """Returns without error when the healthcheck endpoint responds HTTP 200.""" + mock_proc = 
_MockProcess(poll_return=None) + r = _make_json_response(b"OK", status_code=200) + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + lt._session = _MockSession(get_response=r) + lt._wait_for_server_ready(timeout=10) + lt._server = None + + def test_sleeps_when_server_not_ready_yet(self) -> None: + """time.sleep() is called when server responds but r.ok is False.""" + mock_proc = _MockProcess(poll_return=None) + r = _make_json_response(b"not ready", status_code=503) + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + lt._session = _MockSession(get_response=r) + times: list[float] = [0.0, 0.0, 2.0] + with ( + patch("language_tool_python.server.time.time", side_effect=times), + patch("language_tool_python.server.time.sleep") as mock_sleep, + pytest.raises(ServerError, match="did not become ready"), + ): + lt._wait_for_server_ready(timeout=1) + mock_sleep.assert_called_once_with(0.2) + lt._server = None + + +class TestTerminateServerAtExit: + """Tests for the atexit handler.""" + + def test_logs_and_kills_when_processes_are_running(self) -> None: + """Logger and _kill_processes are called when processes exist.""" + mock_proc = _MockProcess() + with ( + patch( + "language_tool_python.server._RUNNING_SERVER_PROCESSES", + [mock_proc], + ), + patch("language_tool_python.server._kill_processes") as mock_kill, + ): + _terminate_server_at_exit() + mock_kill.assert_called_once() + + +class TestLanguageToolConstructorLocalServer: + """Tests for __init__ branches that start a local server or register spellings.""" + + def test_constructor_starts_local_server_when_no_remote(self) -> None: + """Local server startup is triggered when no remote_server is given.""" + with ( + patch("language_tool_python.server.get_locale_language", return_value="en"), + patch.object(LanguageTool, "_get_languages", return_value=_LANGUAGES), + patch.object(LanguageTool, "_start_server_on_free_port") as mock_start, + LanguageTool() as _lt, + ): + 
mock_start.assert_called_once() + + def test_constructor_registers_new_spellings_when_provided(self) -> None: + """New spellings are registered when the new_spellings arg is non-empty.""" + with ( + patch("language_tool_python.server.get_locale_language", return_value="en"), + patch.object(LanguageTool, "_get_languages", return_value=_LANGUAGES), + patch.object(LanguageTool, "_start_server_on_free_port"), + patch.object(LanguageTool, "_register_spellings") as mock_reg, + LanguageTool(new_spellings=["hello"]) as _lt, + ): + mock_reg.assert_called_once() + + +class TestLanguageToolPublicAPI: + """Tests for LanguageToolPublicAPI constructor.""" + + def test_initializes_with_public_api_remote_server(self) -> None: + """LanguageToolPublicAPI sets is_remote=True via the public API URL.""" + with ( + patch("language_tool_python.server.get_locale_language", return_value="en"), + patch.object(LanguageTool, "_get_languages", return_value=_LANGUAGES), + LanguageToolPublicAPI() as lt, + ): + assert lt.is_remote is True diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index aa44343..b449da9 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -6,7 +6,7 @@ import pytest -from language_tool_python.match import Match +from language_tool_python.match import Match, four_byte_char_positions from language_tool_python.utils import TextStatus, classify_matches, correct if TYPE_CHECKING: @@ -37,7 +37,7 @@ def _make_match( "ignoreForIncompleteSentence": False, "contextForSureMatch": 0, } - return Match(attrib, "text here.") + return Match(attrib, four_byte_char_positions("text here.")) class TestClassifyMatches: @@ -103,7 +103,7 @@ def test_correction_replaces_error(self) -> None: "ignoreForIncompleteSentence": False, "contextForSureMatch": 0, } - m = Match(attrib, text) + m = Match(attrib, four_byte_char_positions(text)) result = correct(text, [m]) assert result == "Hello world" @@ -155,8 +155,9 @@ def test_overlapping_match_skips_mismatched_error(self) 
-> None: "ignoreForIncompleteSentence": False, "contextForSureMatch": 0, } - m1 = Match(attrib1, text) - m2 = Match(attrib2, text) + positions = four_byte_char_positions(text) + m1 = Match(attrib1, positions) + m2 = Match(attrib2, positions) result = correct(text, [m1, m2]) assert result == "xxxxxxbbc" @@ -199,8 +200,9 @@ def test_correct_adjusts_offset_for_length_change(self) -> None: "ignoreForIncompleteSentence": False, "contextForSureMatch": 0, } - m1 = Match(attrib1, text) - m2 = Match(attrib2, text) + positions = four_byte_char_positions(text) + m1 = Match(attrib1, positions) + m2 = Match(attrib2, positions) result = correct(text, [m1, m2]) assert result == "AAA BBB c" From 3cc19d6b9cea0d2d5f7fe3705571db84d965ae1f Mon Sep 17 00:00:00 2001 From: mdevolde Date: Wed, 1 Jul 2026 20:38:43 +0300 Subject: [PATCH 4/7] test: add test to cover new changes from #210 #211 #212 #213 --- tests/unit/test_match.py | 15 +++++++++++++++ tests/unit/test_server_unit.py | 22 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tests/unit/test_match.py b/tests/unit/test_match.py index 56a384d..db36637 100644 --- a/tests/unit/test_match.py +++ b/tests/unit/test_match.py @@ -106,6 +106,21 @@ def test_four_byte_char_adjustment(self) -> None: adjusted_offset = 2 assert m.offset == adjusted_offset + def test_four_byte_char_after_match_is_not_counted(self) -> None: + """A 4-byte emoji after the match does not shift the offset.""" + # "🌅" at position 6 is after the match at offset 0, so the adjustment + # loop must break on its first iteration instead of counting it. 
+ text = "hello 🌅" + attrib = _make_attrib( + offset=0, + length=5, + context_text=text, + context_offset=0, + sentence=text, + ) + m = Match(attrib, four_byte_char_positions(text)) + assert m.offset == 0 + def test_no_adjustment_without_four_byte_chars(self) -> None: """Offsets are unchanged when no 4-byte characters precede the match.""" text = "Hello world today" diff --git a/tests/unit/test_server_unit.py b/tests/unit/test_server_unit.py index 2374bcc..d635c69 100644 --- a/tests/unit/test_server_unit.py +++ b/tests/unit/test_server_unit.py @@ -882,6 +882,16 @@ def test_raises_rate_limit_error_on_http_426(self) -> None: with pytest.raises(RateLimitError): lt._query_server("http://fake/", num_tries=1) + def test_raises_rate_limit_error_on_http_426_with_valid_json(self) -> None: + """RateLimitError is raised on HTTP 426 even when the body is valid JSON.""" + r = _make_json_response( + b'{"error": "rate limited"}', status_code=_HTTP_RATE_LIMIT_STATUS + ) + lt = _bare_lt(_remote=True) + lt._session = _MockSession(get_response=r) + with pytest.raises(RateLimitError): + lt._query_server("http://fake/", num_tries=1) + def test_returns_none_with_zero_tries(self) -> None: """Returns None immediately when num_tries=0 (loop never executes).""" lt = _bare_lt() @@ -939,6 +949,18 @@ def close(self) -> None: lt._terminate_server() assert mock_stdin.closed + def test_does_not_raise_when_process_not_in_running_list(self) -> None: + """_terminate_server() does not raise if the process is absent from the list.""" + mock_proc = _MockProcess(poll_return=None) + lt = _bare_lt() + lt._server = mock_proc # type: ignore[assignment] + with ( + patch("language_tool_python.server._RUNNING_SERVER_PROCESSES", []), + patch("language_tool_python.server._kill_processes"), + ): + lt._terminate_server() + assert lt._server is None + class TestStartLocalServer: """Tests for _start_local_server() PathError branch.""" From ead694b75dd3638583014dc8e606d63f8987809a Mon Sep 17 00:00:00 2001 From: 
mdevolde Date: Thu, 2 Jul 2026 00:00:22 +0300 Subject: [PATCH 5/7] test: implement py testing good practices --- tests/benchmarks/test_bench_check.py | 32 ++- tests/integration/test_api_public.py | 104 +++++---- tests/integration/test_cli.py | 4 +- tests/integration/test_config.py | 61 ++++-- tests/integration/test_download.py | 15 -- tests/integration/test_server_local.py | 38 ++-- tests/property/test_prop_config.py | 91 +++++++- tests/property/test_prop_language_tag.py | 114 ++++++++++ tests/property/test_prop_safe_zip.py | 151 +++++++++---- tests/property/test_prop_utils.py | 40 ++++ tests/unit/test_cli_args.py | 16 +- tests/unit/test_cli_unit.py | 79 ++++++- tests/unit/test_config_unit.py | 221 +++++++++++++------ tests/unit/test_download.py | 149 +++++-------- tests/unit/test_download_unit.py | 25 +++ tests/unit/test_internals_utils.py | 97 ++++++--- tests/unit/test_language_tag.py | 68 +++--- tests/unit/test_match.py | 105 ++++++--- tests/unit/test_safe_zip.py | 260 +++++++++-------------- tests/unit/test_server_unit.py | 93 +++++++- 20 files changed, 1178 insertions(+), 585 deletions(-) create mode 100644 tests/property/test_prop_language_tag.py diff --git a/tests/benchmarks/test_bench_check.py b/tests/benchmarks/test_bench_check.py index 2cca935..9b34d67 100644 --- a/tests/benchmarks/test_bench_check.py +++ b/tests/benchmarks/test_bench_check.py @@ -1,6 +1,10 @@ """Benchmark tests for LanguageTool grammar checking performance. Run with: pytest tests/benchmarks/ -v + +A JVM is required to run these benchmarks. If the local LanguageTool JAR cache is +empty, the first ``LanguageTool(...)`` call in this module triggers a real download +of the LanguageTool archive over the network. 
""" from __future__ import annotations @@ -23,18 +27,36 @@ @pytest.fixture(scope="module") def tool() -> Generator[language_tool_python.LanguageTool, None, None]: - """Provide a LanguageTool instance shared across benchmarks in this module.""" + """Provide a LanguageTool instance shared across benchmarks in this module. + + Performs one warm-up ``check()`` call before yielding, since pytest.ini does + not configure ``--benchmark-warmup``: without it, whichever benchmark happens + to run first in this module would absorb the server's cold-start JIT cost. + """ with language_tool_python.LanguageTool("en-US") as t: + t.check("warm-up") yield t @pytest.fixture(scope="module") def cached_tool() -> Generator[language_tool_python.LanguageTool, None, None]: - """Provide a pipeline-caching LanguageTool instance for cache benchmarks.""" + """Provide a pipeline-caching LanguageTool instance for cache benchmarks. + + ``cacheSize=1000`` sets the maximum number of previously checked sentences the + server keeps in memory, ``pipelineCaching=True`` additionally caches the + internal per-language analysis pipeline (tokenizer, tagger, etc.) so it is not + rebuilt on every request. Together they let repeated checks of the same + sentence skip most of the analysis work. + + Performs one warm-up ``check()`` call (on text distinct from the benchmarked + sentence, so it does not itself pre-populate the cache for that sentence) + before yielding, for the same cold-start reason as the ``tool`` fixture above. + """ with language_tool_python.LanguageTool( "en-US", config={"cacheSize": 1000, "pipelineCaching": True}, ) as t: + t.check("warm-up, unrelated to any benchmarked sentence") yield t @@ -76,6 +98,10 @@ def test_bench_check_with_pipeline_cache( ) -> None: """Benchmark grammar checking with pipeline caching enabled. - Compare with test_bench_check_short_text to measure cache speedup. 
+ Every round checks the same ``_SHORT_TEXT`` on the same server instance, so + (after the fixture's warm-up call) all rounds but the very first are cache + hits. Compare the resulting numbers with ``test_bench_check_short_text`` + (same text, no caching configured) to estimate the cache's speedup, this + test alone does not exercise a cache miss/hit contrast within itself. """ benchmark(cached_tool.check, _SHORT_TEXT) diff --git a/tests/integration/test_api_public.py b/tests/integration/test_api_public.py index 307c5d6..1d16672 100644 --- a/tests/integration/test_api_public.py +++ b/tests/integration/test_api_public.py @@ -1,69 +1,65 @@ """Integration tests for the public API functionality.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest import language_tool_python from language_tool_python.exceptions import RateLimitError +if TYPE_CHECKING: + from language_tool_python.match import Match + +_ES_TEXT = ( + "Escriba un texto aquí. LanguageTool le ayudará a afrentar " + "algunas dificultades propias de la escritura. Se a hecho un esfuerzo " + "para detectar errores tipográficos, ortograficos y incluso " + "gramaticales. También algunos errores de estilo, a grosso modo." +) -def test_remote_es() -> None: - """Test the public API with Spanish language text. - This test verifies that the LanguageToolPublicAPI correctly identifies various - errors in a Spanish text sample. +@pytest.fixture(scope="module") +def es_matches() -> list[Match]: + """Check ``_ES_TEXT`` once against the public API and share the result. - :raises AssertionError: If the detected matches do not match the expected output. + The check is only performed once (module-scoped) to limit the number of + requests sent to the public API. If the request is rate-limited, every test + depending on this fixture is skipped silently rather than failing. """ try: with language_tool_python.LanguageToolPublicAPI("es") as tool: - es_text = ( - "Escriba un texto aquí. 
LanguageTool le ayudará a afrentar " - "algunas dificultades propias de la escritura. Se a hecho un esfuerzo " - "para detectar errores tipográficos, ortograficos y incluso " - "gramaticales. También algunos errores de estilo, a grosso modo." - ) - matches = tool.check(es_text) - assert ( - str(matches) - == """[Match({'rule_id': 'AFRENTAR_DIFICULTADES', 'message': 'Confusión - entre «afrontar» y «afrentar».', 'replacements': ['afrontar'], - 'offset_in_context': 43, 'context': '...n texto aquí. - - LanguageTool le ayudará a afrentar algunas dificultades propias de la - escr...', 'offset': 49, 'error_length': 8, 'category': - 'INCORRECT_EXPRESSIONS', 'rule_issue_type': 'grammar', 'sentence': - 'LanguageTool le ayudará a afrentar algunas dificultades propias de - la escritura.'}), Match({'rule_id': 'PRON_HABER_PARTICIPIO', - 'message': 'El v. ‘haber’ se escribe con hache.', - 'replacements': ['ha'], 'offset_in_context': 43, 'context': - '...ificultades propias de la escritura. Se a hecho un esfuerzo para - detectar errores...', 'offset': 107, 'error_length': 1, 'category': - 'MISSPELLING', 'rule_issue_type': 'misspelling', 'sentence': 'Se a - hecho un esfuerzo para detectar errores tipográficos, ortograficos y - incluso gramaticales.'}), Match({'rule_id': 'MORFOLOGIK_RULE_ES', - 'message': 'Se ha encontrado un posible error ortográfico.', - 'replacements': ['ortográficos', 'ortográficas', 'ortográfico', - 'orográficos', 'ortografiaos', 'ortografíeos'], 'offset_in_context': - 43, 'context': '...rzo para detectar errores tipográficos, - ortograficos y incluso gramaticales. 
También algunos...', 'offset': - 163, 'error_length': 12, 'category': 'TYPOS', 'rule_issue_type': - 'misspelling', 'sentence': 'Se a hecho un esfuerzo para detectar - errores tipográficos, ortograficos y incluso gramaticales.'}), - Match({'rule_id': 'Y_E_O_U', 'message': 'Cuando precede a palabras - que comienzan por ‘i’, la conjunción ‘y’ se - transforma en ‘e’.', 'replacements': ['e'], - 'offset_in_context': 43, 'context': '...ctar errores tipográficos, - ortograficos y incluso gramaticales. También algunos e...', 'offset': - 176, 'error_length': 1, 'category': 'GRAMMAR', 'rule_issue_type': - 'grammar', 'sentence': 'Se a hecho un esfuerzo para detectar errores - tipográficos, ortograficos y incluso gramaticales.'}), - Match({'rule_id': 'GROSSO_MODO', 'message': 'Esta expresión latina se - usa sin preposición.', 'replacements': ['grosso modo'], - 'offset_in_context': 43, 'context': '...les. También algunos errores - de estilo, a grosso modo.', 'offset': 235, 'error_length': 13, - 'category': 'GRAMMAR', 'rule_issue_type': 'grammar', 'sentence': - 'También algunos errores de estilo, a grosso modo.'})] - """ - ) + return tool.check(_ES_TEXT) except RateLimitError: pytest.skip("Rate limit exceeded for public API.") + + +@pytest.mark.parametrize( + ("rule_id", "category", "offset"), + [ + ("AFRENTAR_DIFICULTADES", "INCORRECT_EXPRESSIONS", 49), + ("PRON_HABER_PARTICIPIO", "MISSPELLING", 107), + ("MORFOLOGIK_RULE_ES", "TYPOS", 163), + ("Y_E_O_U", "GRAMMAR", 176), + ("GROSSO_MODO", "GRAMMAR", 235), + ], +) +def test_remote_es( + es_matches: list[Match], + rule_id: str, + category: str, + offset: int, +) -> None: + """Test that the public API detects a specific known error in Spanish text. + + LanguageTool rules can change over time, so this asserts on individual match + fields (rule_id, category, offset) rather than requiring the entire response + to match a frozen snapshot exactly. + + :raises AssertionError: If no match with the expected fields is found. 
+ """ + assert any( + m.rule_id == rule_id and m.category == category and m.offset == offset + for m in es_matches + ) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index d101f66..4d257d9 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -37,9 +37,7 @@ def remote_server() -> Generator[tuple[str, int], None, None]: :rtype: Generator[Tuple[str, int], None, None] """ with language_tool_python.LanguageTool("en-US") as tool: - host = tool._host - port = tool._port - yield host, port + yield tool.host, tool.port @pytest.mark.parametrize( diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py index 3a2ba42..2175f6c 100644 --- a/tests/integration/test_config.py +++ b/tests/integration/test_config.py @@ -135,29 +135,60 @@ def test_config_caching() -> None: This test verifies that LanguageTool's caching mechanism (cacheSize and pipelineCaching) significantly improves performance when checking the same text - multiple times. The test measures the time difference between the first and second - checks to ensure caching provides a substantial speedup. + multiple times. The test measures the time difference between an uncached and a + cached check to ensure caching provides a substantial speedup. + + This is inherently a timing-sensitive test and could still be flaky under heavy + machine load, so it: (1) performs a warm-up check on unrelated text before + timing, to exclude one-off JIT/connection-setup costs from the measurement + without pre-populating the cache for the text under test, and (2) repeats the + timed comparison up to ``_ATTEMPTS`` times, succeeding as soon as one attempt + shows the expected speedup, instead of requiring every attempt to pass. :raises AssertionError: If caching does not provide the expected performance - improvement. + improvement in any attempt. 
""" + speedup_factor = 5.0 + attempts = 3 + with language_tool_python.LanguageTool( "en-US", config={"cacheSize": 1000, "pipelineCaching": True}, ) as tool: + tool.check("warm-up text unrelated to the cached sentence below") + s = "hello darkness my old frend" - t1 = time.time() - tool.check(s) - t2 = time.time() - tool.check(s) - t3 = time.time() - - # This is a silly test that says: caching should speed up a grammar-checking - # by a factor of speed_factor when checking the same sentence twice. It - # theoretically could be very flaky. - # But in practice I've observed speedup of around 250x (6.76s to 0.028s). - speedup_factor = 10.0 - assert (t2 - t1) / speedup_factor > (t3 - t2) + for _ in range(attempts): + t1 = time.time() + tool.check(s) + t2 = time.time() + tool.check(s) + t3 = time.time() + + # In practice, speedups of around 250x (6.76s to 0.028s) have been observed. + if (t2 - t1) / speedup_factor > (t3 - t2): + return + + pytest.fail( + f"Caching did not provide the expected speedup in {attempts} attempts." + ) + + +def test_inexistent_language() -> None: + """Test that creating a LanguageTag with an invalid language code raises an error. + + This test verifies that the LanguageTag constructor correctly validates language + codes and raises a ValueError when given a language code that is not supported. + A real server is required here to obtain the list of supported languages via + ``tool._get_languages()``. + + :raises AssertionError: If ValueError is not raised for an invalid language code. 
+ """ + with ( + language_tool_python.LanguageTool("en-US") as tool, + pytest.raises(ValueError, match="unsupported language"), + ): + language_tool_python.LanguageTag("xx-XX", tool._get_languages()) def test_disabled_rule_in_config() -> None: diff --git a/tests/integration/test_download.py b/tests/integration/test_download.py index fd74f50..7f406fc 100644 --- a/tests/integration/test_download.py +++ b/tests/integration/test_download.py @@ -33,21 +33,6 @@ def test_install_too_old_version() -> None: language_tool_python.LanguageTool(language_tool_download_version="3.9") -def test_inexistent_language() -> None: - """Test that creating a LanguageTag with an invalid language code raises an error. - - This test verifies that the LanguageTag constructor correctly validates language - codes and raises a ValueError when given a language code that is not supported. - - :raises AssertionError: If ValueError is not raised for an invalid language code. - """ - with ( - language_tool_python.LanguageTool("en-US") as tool, - pytest.raises(ValueError, match="unsupported language"), - ): - language_tool_python.LanguageTag("xx-XX", tool._get_languages()) - - def test_install_oldest_supported_version() -> None: """Test that downloading the oldest supported LanguageTool version works correctly. diff --git a/tests/integration/test_server_local.py b/tests/integration/test_server_local.py index 31b4b47..fb21a95 100644 --- a/tests/integration/test_server_local.py +++ b/tests/integration/test_server_local.py @@ -58,7 +58,6 @@ def test_process_starts_and_stops_on_close() -> None: # Make sure process stopped after close() was called. time.sleep(0.5) # Give some time for process to stop after close() call. assert proc.poll() is not None, "tool._server should stop running after deletion" - # remember --> if poll is None: # p.subprocess is alive def test_local_client_server_connection() -> None: @@ -72,7 +71,7 @@ def test_local_client_server_connection() -> None: server. 
""" with language_tool_python.LanguageTool("en-US", host="127.0.0.1") as tool1: - url = f"http://{tool1._host}:{tool1._port}/" + url = f"http://{tool1.host}:{tool1.port}/" with language_tool_python.LanguageTool("en-US", remote_server=url) as tool2: assert len(tool2.check("helo darknes my old frend")) @@ -104,28 +103,31 @@ def test_session_only_new_spellings() -> None: initial_checksum = hashlib.sha256(initial_spelling_file_contents.encode()) new_spellings = ["word1", "word2", "word3"] - with language_tool_python.LanguageTool( - "en-US", - new_spellings=new_spellings, - new_spellings_persist=False, - ) as tool: - tool.enabled_rules_only = True - tool.enabled_rules = {"MORFOLOGIK_RULE_EN_US"} - matches = tool.check(" ".join(new_spellings)) + try: + with language_tool_python.LanguageTool( + "en-US", + new_spellings=new_spellings, + new_spellings_persist=False, + ) as tool: + tool.enabled_rules_only = True + tool.enabled_rules = {"MORFOLOGIK_RULE_EN_US"} + matches = tool.check(" ".join(new_spellings)) + + with spelling_file_path.open("r", encoding="utf-8") as spelling_file: + subsequent_spelling_file_contents = spelling_file.read() + subsequent_checksum = hashlib.sha256(subsequent_spelling_file_contents.encode()) - with spelling_file_path.open("r", encoding="utf-8") as spelling_file: - subsequent_spelling_file_contents = spelling_file.read() - subsequent_checksum = hashlib.sha256(subsequent_spelling_file_contents.encode()) - - if initial_checksum != subsequent_checksum: + assert not matches + assert initial_checksum.hexdigest() == subsequent_checksum.hexdigest() + finally: + # Restore unconditionally (not just when a diff was detected) so that a + # failed assertion above can never leave the shared spelling file corrupted + # for the rest of the test session. 
with spelling_file_path.open( "w", encoding="utf-8", newline="\n" ) as spelling_file: spelling_file.write(initial_spelling_file_contents) - assert not matches - assert initial_checksum.hexdigest() == subsequent_checksum.hexdigest() - def test_new_spellins_in_es() -> None: """Test that new spellings are recognized in Spanish language. diff --git a/tests/property/test_prop_config.py b/tests/property/test_prop_config.py index 41b95a9..3cbe857 100644 --- a/tests/property/test_prop_config.py +++ b/tests/property/test_prop_config.py @@ -4,14 +4,30 @@ hold for any input, not just the handwritten examples in unit tests. """ +import string + import pytest from hypothesis import given, settings from hypothesis import strategies as st -from language_tool_python.config_file import LanguageToolConfig +from language_tool_python.config_file import ( + _CONFIG_SCHEMA, + LanguageToolConfig, + _bool_encoder, + _encode_config, + _int_encoder, + _is_lang_key, +) _LINEBREAK_CHARS = ["\n", "\r", "\r\n"] +_INT_KEYS = [ + key for key, spec in _CONFIG_SCHEMA.items() if spec.encoder is _int_encoder +] +_BOOL_KEYS = [ + key for key, spec in _CONFIG_SCHEMA.items() if spec.encoder is _bool_encoder +] + @given( before=st.text(), @@ -83,3 +99,76 @@ def test_prop_config_key_with_linebreak_always_raises( key = key_before + linebreak + key_after with pytest.raises(ValueError, match="line breaks"): LanguageToolConfig({key: "valid_value"}) + + +@given(code=st.text(alphabet=string.ascii_lowercase, min_size=1, max_size=10)) +@settings(max_examples=200) +def test_prop_is_lang_key_accepts_any_lang_prefixed_code(code: str) -> None: + """Any 'lang-' key is recognized as a language key. + + :param code: Non-empty lowercase code with no '-' inside it. + :raises AssertionError: If _is_lang_key does not accept the generated key. 
+ """ + assert _is_lang_key(f"lang-{code}") is True + + +@given(key=st.text().filter(lambda s: not s.startswith("lang-"))) +@settings(max_examples=200) +def test_prop_is_lang_key_rejects_non_lang_prefixed_keys(key: str) -> None: + """Any key not starting with 'lang-' is never recognized as a language key. + + :param key: Arbitrary text not starting with the 'lang-' prefix. + :raises AssertionError: If _is_lang_key incorrectly accepts the generated key. + """ + assert _is_lang_key(key) is False + + +@given( + key=st.sampled_from(_INT_KEYS), + value=st.integers(min_value=-1_000_000, max_value=1_000_000), +) +@settings(max_examples=200) +def test_prop_int_schema_key_round_trips_through_encode_config( + key: str, + value: int, +) -> None: + """Any int-typed schema key round-trips through _encode_config as str(value). + + :param key: A schema key whose encoder is _int_encoder. + :param value: An arbitrary integer value. + :raises AssertionError: If the encoded value does not equal str(value). + """ + result = _encode_config({key: value}) + assert result[key] == str(value) + + +@given(key=st.sampled_from(_BOOL_KEYS), value=st.booleans()) +@settings(max_examples=200) +def test_prop_bool_schema_key_round_trips_through_encode_config( + key: str, + value: bool, +) -> None: + """Any bool-typed schema key round-trips through _encode_config as 'true'/'false'. + + :param key: A schema key whose encoder is _bool_encoder. + :param value: An arbitrary boolean value. + :raises AssertionError: If the encoded value does not match the expected string. + """ + result = _encode_config({key: value}) + assert result[key] == ("true" if value else "false") + + +@given( + key=st.text(alphabet=string.ascii_letters, min_size=1, max_size=20).filter( + lambda s: not s.startswith("lang-") and s not in _CONFIG_SCHEMA, + ), +) +@settings(max_examples=200) +def test_prop_unknown_key_always_raises(key: str) -> None: + """Any key that is neither lang-* nor in the schema always raises ValueError. 
+ + :param key: A generated key guaranteed to be outside lang-* and the schema. + :raises AssertionError: If ValueError is not raised for the unknown key. + """ + with pytest.raises(ValueError, match="unexpected key"): + _encode_config({key: "value"}) diff --git a/tests/property/test_prop_language_tag.py b/tests/property/test_prop_language_tag.py new file mode 100644 index 0000000..b5357bd --- /dev/null +++ b/tests/property/test_prop_language_tag.py @@ -0,0 +1,114 @@ +"""Property-based tests for LanguageTag normalization and ordering.""" + +from __future__ import annotations + +import string + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from language_tool_python.language_tag import LanguageTag + +_LANGS = ["en-US", "en-GB", "en", "de-DE", "fr-FR", "pt-BR"] + +_valid_tags = st.sampled_from(_LANGS) + + +@given(tag=_valid_tags) +@settings(max_examples=200) +def test_prop_normalize_is_idempotent(tag: str) -> None: + """Normalizing an already-normalized tag returns the same normalized tag. + + :param tag: A tag drawn from the supported language list. + :raises AssertionError: If re-normalizing the normalized tag changes it. + """ + first = LanguageTag(tag, _LANGS) + second = LanguageTag(first.normalized_tag, _LANGS) + assert second.normalized_tag == first.normalized_tag + + +@given( + tag=_valid_tags, + upper=st.booleans(), + use_underscore=st.booleans(), +) +@settings(max_examples=200) +def test_prop_normalize_invariant_to_case_and_separator( + tag: str, + upper: bool, + use_underscore: bool, +) -> None: + """Case and -/_ separator variants of a supported tag normalize identically. + + :param tag: A tag drawn from the supported language list. + :param upper: Whether to uppercase the variant before normalizing. + :param use_underscore: Whether to replace '-' with '_' in the variant. + :raises AssertionError: If the variant normalizes differently from the original. 
+ """ + variant = tag.upper() if upper else tag.lower() + if use_underscore: + variant = variant.replace("-", "_") + baseline = LanguageTag(tag, _LANGS) + variant_tag = LanguageTag(variant, _LANGS) + assert variant_tag.normalized_tag == baseline.normalized_tag + + +@given(a=_valid_tags, b=_valid_tags, c=_valid_tags) +@settings(max_examples=200) +def test_prop_total_ordering_is_consistent(a: str, b: str, c: str) -> None: + """__lt__/__eq__/__hash__ satisfy total-ordering invariants for any triplet. + + Checks irreflexivity of '<', consistency between '==' and '<', hash equality + for equal tags, and transitivity of '<'. + + :param a: First tag drawn from the supported language list. + :param b: Second tag drawn from the supported language list. + :param c: Third tag drawn from the supported language list. + :raises AssertionError: If any total-ordering invariant is violated. + """ + tag_a = LanguageTag(a, _LANGS) + tag_b = LanguageTag(b, _LANGS) + tag_c = LanguageTag(c, _LANGS) + + assert not (tag_a < tag_a) # noqa: PLR0124 # irreflexivity check, not a typo + + if tag_a == tag_b: + assert not (tag_a < tag_b) + assert not (tag_b < tag_a) + assert hash(tag_a) == hash(tag_b) + else: + assert (tag_a < tag_b) != (tag_b < tag_a) # trichotomy for distinct tags + + if tag_a < tag_b < tag_c: + assert tag_a < tag_c # transitivity + + +@st.composite +def _unsupported_tag(draw: st.DrawFn) -> str: + """Generate a string guaranteed to be rejected by LanguageTag(..., _LANGS). + + A 4-6 letter lowercase word with no '-'/'_' separator can never satisfy + ``LanguageTag._LANGUAGE_RE`` (which requires the whole tag to be either a bare + 2-3 letter code, or a 2-3 letter code plus a '-'/'_' plus a 2-letter region), + and is also excluded from the POSIX/C-locale special case. 
+ """ + word = draw( + st.text(alphabet=string.ascii_lowercase, min_size=4, max_size=6), + ) + if word == "posix": + word = "posixx" + return word + + +@given(tag=_unsupported_tag()) +@settings(max_examples=200) +def test_prop_unsupported_tag_always_rejected(tag: str) -> None: + """A tag outside the supported set (and outside the LANGUAGE_RE shape) always + raises ValueError. + + :param tag: An adversarially generated, guaranteed-unsupported tag. + :raises AssertionError: If ValueError is not raised for the unsupported tag. + """ # noqa: D205 + with pytest.raises(ValueError, match="unsupported language"): + LanguageTag(tag, _LANGS) diff --git a/tests/property/test_prop_safe_zip.py b/tests/property/test_prop_safe_zip.py index ae3f992..371a0a9 100644 --- a/tests/property/test_prop_safe_zip.py +++ b/tests/property/test_prop_safe_zip.py @@ -1,13 +1,12 @@ """Property-based tests for the safe ZIP extractor path-traversal protection.""" -import contextlib +from __future__ import annotations + import io -import shutil -import uuid +import tempfile import zipfile -from collections.abc import Iterator -from contextlib import contextmanager -from pathlib import Path +from pathlib import Path, PurePosixPath +from typing import TYPE_CHECKING import pytest from hypothesis import given, settings @@ -16,7 +15,11 @@ from language_tool_python._internals.safe_zip import SafeZipExtractor from language_tool_python.exceptions import PathError -_TRAVERSAL_PREFIXES = ["../", "..\\", "/", "C:/", "..\\..\\"] +if TYPE_CHECKING: + from hypothesis.strategies import DrawFn + +_TRAVERSAL_SEGMENTS = st.sampled_from([".."] * 3 + ["."]) +_SEP = st.sampled_from(["/", "\\"]) def _make_zip_payload(files: dict[str, bytes]) -> bytes: @@ -28,47 +31,117 @@ def _make_zip_payload(files: dict[str, bytes]) -> bytes: return buf.getvalue() -@contextmanager -def _temp_dir() -> Iterator[Path]: - """Create a temporary dir inside the project workspace to avoid perm issues.""" - root = Path.cwd() / 
".test_prop_safe_zip_tmp" - path = root / uuid.uuid4().hex - path.mkdir(parents=True) - try: - yield path - finally: - shutil.rmtree(path, ignore_errors=True) - with contextlib.suppress(OSError): - root.rmdir() +@st.composite +def adversarial_member_names(draw: DrawFn) -> str: + """Generate adversarial ZIP member names built from unsafe path segments. + Combines repeated ``..`` traversal segments, mixed separators, absolute + paths, Windows drive letters, and UNC paths, so the strategy is not limited + to a small fixed set of prefixes. + """ + depth = draw(st.integers(min_value=1, max_value=4)) + segs = [draw(_TRAVERSAL_SEGMENTS) for _ in range(depth)] + sep = draw(_SEP) + leaf = draw( + st.text( + alphabet=st.characters( + exclude_categories=["Cs"], + exclude_characters="\x00", + ), + min_size=1, + max_size=20, + ), + ) + style = draw(st.sampled_from(["prefix", "embedded", "absolute", "drive", "unc"])) + if style == "prefix": + return sep.join([*segs, leaf]) + if style == "embedded": + return sep.join(["safe", *segs, leaf]) + if style == "absolute": + return sep + leaf + if style == "drive": + return draw(st.sampled_from("CDZ")) + ":" + sep + leaf + return "\\\\server\\share\\" + leaf -@given( - traversal=st.sampled_from(_TRAVERSAL_PREFIXES), - suffix=st.text( - alphabet=st.characters(whitelist_categories=("Ll", "Lu", "Nd")), - min_size=1, - ), -) -@settings(max_examples=300) -def test_prop_safe_zip_path_traversal_always_rejected( - traversal: str, - suffix: str, -) -> None: - """Any ZIP member whose name begins with a path-traversal prefix must be rejected. - Checks that ``SafeZipExtractor`` raises ``PathError`` for filenames like - ``../evil``, ``/etc/passwd``, or ``C:/Windows/file`` regardless of the suffix. +@given(filename=adversarial_member_names()) +@settings(max_examples=300, deadline=None) +def test_prop_safe_zip_path_traversal_always_rejected(filename: str) -> None: + """Any adversarial ZIP member name must be rejected by SafeZipExtractor. 
+ + Checks that ``SafeZipExtractor`` raises ``PathError`` for a wide range of + unsafe filenames (traversal, absolute paths, drive letters, UNC paths) + rather than a small fixed set of hand-picked prefixes. - :param traversal: A path-traversal prefix (e.g. ``../``, ``/``). - :param suffix: Alphanumeric suffix appended after the traversal prefix. + A fresh temporary directory is created per example instead of using a + pytest fixture, since function-scoped fixtures are not reset between + Hypothesis-generated examples within the same test call. + + :param filename: An adversarially generated ZIP member name. :raises AssertionError: If ``PathError`` is not raised for the unsafe member name. """ - filename = traversal + suffix payload = _make_zip_payload({filename: b"payload"}) with ( - _temp_dir() as dest, + tempfile.TemporaryDirectory() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zf, pytest.raises(PathError, match="Unsafe ZIP member"), ): - SafeZipExtractor().extractall(zf, dest) + SafeZipExtractor().extractall(zf, Path(temp_dir) / "destination") + + +@given( + member_path=st.lists( + st.text( + alphabet=st.characters(categories=["Ll", "Lu", "Nd"]), + min_size=1, + max_size=10, + ), + min_size=1, + max_size=5, + ).map(lambda parts: PurePosixPath(*parts)), +) +@settings(max_examples=200) +def test_prop_zip_target_always_inside_destination(member_path: PurePosixPath) -> None: + """``_zip_target`` must always resolve inside the given destination. + + Exercises ``_zip_target`` directly (no ZIP I/O) so a large number of + examples can be run quickly. + + :param member_path: An already-normalized, safe relative POSIX path. + :raises AssertionError: If the resolved target escapes the destination. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + destination = Path(temp_dir) / "destination" + destination.mkdir() + + target = SafeZipExtractor()._zip_target(destination, member_path) + + resolved_destination = destination.resolve(strict=True) + resolved_target = target.resolve(strict=False) + assert ( + resolved_target == resolved_destination + or resolved_destination in resolved_target.parents + ) + + +@given(filename=adversarial_member_names()) +@settings(max_examples=300) +def test_prop_normalize_member_path_always_rejects_or_stays_relative( + filename: str, +) -> None: + """``_normalize_member_path`` must either reject or return a safe relative path. + + Exercises ``_normalize_member_path`` directly (no ZIP or filesystem I/O), + so a large number of adversarial examples can be checked quickly. + + :param filename: An adversarially generated ZIP member name. + :raises AssertionError: If a returned path is absolute or escapes upward. + """ + extractor = SafeZipExtractor() + try: + normalized = extractor._normalize_member_path(filename) + except PathError: + return + assert not normalized.is_absolute() + assert ".." not in normalized.parts diff --git a/tests/property/test_prop_utils.py b/tests/property/test_prop_utils.py index 4bc7f20..1a6134a 100644 --- a/tests/property/test_prop_utils.py +++ b/tests/property/test_prop_utils.py @@ -1,9 +1,12 @@ """Property-based tests for the LanguageTool utility functions.""" +from __future__ import annotations + from hypothesis import given, settings from hypothesis import strategies as st import language_tool_python +from language_tool_python.utils import TextStatus, classify_matches @given(text=st.text()) @@ -18,3 +21,40 @@ def test_prop_correct_with_empty_matches_is_identity(text: str) -> None: :raises AssertionError: If the corrected text differs from the input. 
""" assert language_tool_python.utils.correct(text, []) == text + + +class _StubMatch: + """Minimal stand-in for Match, classify_matches only reads .replacements.""" + + def __init__(self, replacements: list[str]) -> None: + self.replacements = replacements + + +@given( + replacements_per_match=st.lists( + st.lists(st.text(min_size=1), max_size=3), + max_size=5, + ), +) +@settings(max_examples=300) +def test_prop_classify_matches_invariants( + replacements_per_match: list[list[str]], +) -> None: + """classify_matches() follows its documented CORRECT/GARBAGE/FAULTY contract. + + An empty match list is always CORRECT; when every match has no replacements + it is always GARBAGE; when at least one match has a non-empty replacements + list it is always FAULTY. + + :param replacements_per_match: One replacements list per generated match. + :raises AssertionError: If the classification does not follow the contract. + """ + matches = [_StubMatch(replacements) for replacements in replacements_per_match] + status = classify_matches(matches) # type: ignore[arg-type] + + if not matches: + assert status == TextStatus.CORRECT + elif any(match.replacements for match in matches): + assert status == TextStatus.FAULTY + else: + assert status == TextStatus.GARBAGE diff --git a/tests/unit/test_cli_args.py b/tests/unit/test_cli_args.py index 6f79066..bb87026 100644 --- a/tests/unit/test_cli_args.py +++ b/tests/unit/test_cli_args.py @@ -6,10 +6,7 @@ def test_parse_args_enabled_only_with_enable_categories() -> None: - """Test that --enabled-only is accepted when only --enable-categories is provided. - - :raises AssertionError: If parse_args raises an error for this valid combination. 
- """ + """Test that --enabled-only is accepted with only --enable-categories provided.""" args = parse_args(["-l", "en-US", "--enabled-only", "-E", "TYPOS", "file.txt"]) assert args.enabled_only is True assert args.enable_categories == {"TYPOS"} @@ -36,11 +33,7 @@ def test_parse_args_enabled_only_requires_enable_or_enable_categories() -> None: def test_parse_args_categories() -> None: - """Test that --disable-categories and --enable-categories are parsed correctly. - - :raises AssertionError: If the parsed category sets do not match the expected - values. - """ + """Test that --disable-categories and --enable-categories are parsed correctly.""" args = parse_args( ["-l", "en-US", "-D", "TYPOS,GRAMMAR", "-E", "PUNCTUATION", "file.txt"] ) @@ -49,10 +42,7 @@ def test_parse_args_categories() -> None: def test_parse_args_categories_multiple_flags() -> None: - """Test that repeated -D/-E flags accumulate into the same set. - - :raises AssertionError: If the category sets do not accumulate correctly. - """ + """Test that repeated -D/-E flags accumulate into the same set.""" args = parse_args( ["-l", "en-US", "-D", "TYPOS", "-D", "GRAMMAR", "-E", "PUNCTUATION", "file.txt"] ) diff --git a/tests/unit/test_cli_unit.py b/tests/unit/test_cli_unit.py index 4252258..7c66063 100644 --- a/tests/unit/test_cli_unit.py +++ b/tests/unit/test_cli_unit.py @@ -5,9 +5,13 @@ import io import logging from pathlib import Path +from typing import TYPE_CHECKING import pytest +if TYPE_CHECKING: + from collections.abc import Iterator + from language_tool_python.__main__ import ( CliArgs, _read_project_version, @@ -279,6 +283,17 @@ def _parse_file_args(filename: str, **overrides: object) -> CliArgs: class TestProcessFile: """Tests for process_file() with LanguageTool mocked.""" + @pytest.fixture(autouse=True) + def _reset_last_instance(self) -> Iterator[None]: + """Reset _MockLangTool._last_instance before and after each test. 
+ + Without this, only tests that explicitly reset the class attribute could + reliably assert on the instance created by the test they belong to. + """ + _MockLangTool._last_instance = None + yield + _MockLangTool._last_instance = None + def test_prints_filename_to_stderr_for_multiple_files( self, monkeypatch: pytest.MonkeyPatch, @@ -318,7 +333,6 @@ def test_disables_spellcheck_when_flag_off( """disable_spellchecking() is called when spell_check=False.""" f = tmp_path / "text.txt" f.write_text("hello", encoding="utf-8") - monkeypatch.setattr(_MockLangTool, "_last_instance", None) monkeypatch.setattr( "language_tool_python.__main__.LanguageTool", _MockLangTool, @@ -335,7 +349,6 @@ def test_sets_picky_when_flag_on( """Picky is set to True on the tool when args.picky=True.""" f = tmp_path / "text.txt" f.write_text("hello", encoding="utf-8") - monkeypatch.setattr(_MockLangTool, "_last_instance", None) monkeypatch.setattr( "language_tool_python.__main__.LanguageTool", _MockLangTool, @@ -424,7 +437,61 @@ def test_verbose_flag_sets_debug_logging( _MockLangTool, ) root = logging.getLogger() - monkeypatch.setattr(root, "level", root.level) - result = main(["--verbose", str(f)]) - assert result == 0 - assert root.level == logging.DEBUG + original_level = root.level + try: + result = main(["--verbose", str(f)]) + assert result == 0 + assert root.level == logging.DEBUG + finally: + root.setLevel(original_level) + + def test_status_is_max_across_multiple_files( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """main() aggregates status via max() across all processed files.""" + + class _ConditionalLangTool(_MockLangTool): + def check(self, text: str) -> list[_MockMatch]: + if text == "bad text": + return [_MockMatch(rule_id="SOME_RULE")] + return [] + + clean_file = tmp_path / "clean.txt" + clean_file.write_text("good text", encoding="utf-8") + bad_file = tmp_path / "bad.txt" + bad_file.write_text("bad text", encoding="utf-8") + + 
monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _ConditionalLangTool, + ) + status_issues = 2 + result = main([str(clean_file), str(bad_file)]) + assert result == status_issues + + def test_remote_server_propagates_to_language_tool_constructor( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + ) -> None: + """--remote-host/--remote-port flow through to the LanguageTool constructor.""" + f = tmp_path / "text.txt" + f.write_text("hello", encoding="utf-8") + captured_kwargs: list[dict[str, object]] = [] + + class _CapturingLangTool(_MockLangTool): + def __init__(self, **kw: object) -> None: + super().__init__(**kw) + captured_kwargs.append(kw) + + monkeypatch.setattr( + "language_tool_python.__main__.LanguageTool", + _CapturingLangTool, + ) + + main(["--remote-host", "example.test", "--remote-port", "8081", str(f)]) + + assert len(captured_kwargs) == 1 + assert captured_kwargs[0]["remote_server"] == "example.test:8081" diff --git a/tests/unit/test_config_unit.py b/tests/unit/test_config_unit.py index 2005a0d..22b2b8d 100644 --- a/tests/unit/test_config_unit.py +++ b/tests/unit/test_config_unit.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING import pytest @@ -16,40 +17,45 @@ _number_encoder, _path_encoder, _path_validator, + _reject_line_breaks, ) from language_tool_python.exceptions import PathError +if TYPE_CHECKING: + from collections.abc import Callable, Iterator, Mapping -class TestBoolEncoder: - """Tests for the _bool_encoder() function.""" - - def test_true(self) -> None: - """True is encoded as the string 'true'.""" - assert _bool_encoder(v=True) == "true" + from language_tool_python.config_file import ConfigValue - def test_false(self) -> None: - """False is encoded as the string 'false'.""" - assert _bool_encoder(v=False) == "false" - def test_truthy_int(self) -> None: - """A truthy integer is encoded as 'true'.""" - assert _bool_encoder(1) == "true" +class 
TestBoolEncoder: + """Tests for the _bool_encoder() function.""" - def test_falsy_int(self) -> None: - """A falsy integer is encoded as 'false'.""" - assert _bool_encoder(0) == "false" + @pytest.mark.parametrize( + ("value", "expected"), + [ + (True, "true"), + (False, "false"), + (1, "true"), + (0, "false"), + ], + ids=["true", "false", "truthy_int", "falsy_int"], + ) + def test_encodes_bool_value(self, value: bool, expected: str) -> None: + """Truthy/falsy values are encoded as lowercase 'true'/'false'.""" + assert _bool_encoder(value) == expected class TestIntEncoder: """Tests for the _int_encoder() function.""" - def test_positive(self) -> None: - """A positive integer is converted to its decimal string.""" - assert _int_encoder(42) == "42" - - def test_zero(self) -> None: - """Zero is converted to '0'.""" - assert _int_encoder(0) == "0" + @pytest.mark.parametrize( + ("value", "expected"), + [(42, "42"), (0, "0")], + ids=["positive", "zero"], + ) + def test_encodes_int_value(self, value: int, expected: str) -> None: + """Integers are converted to their decimal string representation.""" + assert _int_encoder(value) == expected class TestNumberEncoder: @@ -67,21 +73,21 @@ def test_float(self) -> None: class TestCommaListEncoder: """Tests for the _comma_list_encoder() function.""" - def test_string_passthrough(self) -> None: - """A plain string is returned unchanged.""" - assert _comma_list_encoder("a,b,c") == "a,b,c" - - def test_list_joined(self) -> None: - """A list of strings is joined with commas.""" - assert _comma_list_encoder(["a", "b", "c"]) == "a,b,c" - - def test_tuple_joined(self) -> None: - """A tuple of strings is joined with commas.""" - assert _comma_list_encoder(("x", "y")) == "x,y" - - def test_single_item(self) -> None: - """A single-element list returns the element without a comma.""" - assert _comma_list_encoder(["only"]) == "only" + @pytest.mark.parametrize( + ("value", "expected"), + [ + ("a,b,c", "a,b,c"), + (["a", "b", "c"], "a,b,c"), + 
(("x", "y"), "x,y"), + (["only"], "only"), + ], + ids=["string_passthrough", "list_joined", "tuple_joined", "single_item"], + ) + def test_encodes_comma_list_value( + self, value: str | list[str] | tuple[str, ...], expected: str + ) -> None: + """Strings pass through unchanged; iterables are comma-joined.""" + assert _comma_list_encoder(value) == expected class TestPathEncoder: @@ -118,28 +124,62 @@ def test_nonexistent_raises(self, tmp_path: Path) -> None: _path_validator(tmp_path / "nonexistent.txt") -class TestIsLangKey: - """Tests for the _is_lang_key() predicate.""" - - def test_lang_code_format(self) -> None: - """A key of the form 'lang-XX' is recognized as a language key.""" - assert _is_lang_key("lang-en") is True - - def test_lang_code_dict_path_format(self) -> None: - """A key of the form 'lang-XX-dictPath' is recognized as a language key.""" - assert _is_lang_key("lang-en-dictPath") is True +class TestRejectLineBreaks: + """Tests for the _reject_line_breaks() config-value validator.""" + + @pytest.mark.parametrize( + "value", + ["line\nbreak", "line\rbreak", "line\r\nbreak"], + ids=["lf", "cr", "crlf"], + ) + def test_raises_on_line_break(self, value: str) -> None: + """Any value containing a line-break character raises ValueError.""" + with pytest.raises(ValueError, match="line breaks"): + _reject_line_breaks("field", value) + + @pytest.mark.parametrize( + "value", + ["one\\", "three\\\\\\"], + ids=["one", "three"], + ) + def test_raises_on_odd_trailing_backslashes(self, value: str) -> None: + """A value ending with an odd number of backslashes raises ValueError.""" + with pytest.raises(ValueError, match="odd number of backslashes"): + _reject_line_breaks("field", value) + + @pytest.mark.parametrize( + "value", + ["no backslash at all", "two\\\\", "four\\\\\\\\"], + ids=["none", "two", "four"], + ) + def test_accepts_even_trailing_backslashes(self, value: str) -> None: + """A value ending with an even number of backslashes does not raise.""" + 
_reject_line_breaks("field", value) # must not raise - def test_not_lang_prefix(self) -> None: - """A key without the 'lang-' prefix is not a language key.""" - assert _is_lang_key("cacheSize") is False - def test_lang_only_no_code(self) -> None: - """'lang-' with no language code is not a valid language key.""" - assert _is_lang_key("lang-") is False +class TestIsLangKey: + """Tests for the _is_lang_key() predicate.""" - def test_lang_too_many_parts(self) -> None: - """A key with more than three parts is not a valid language key.""" - assert _is_lang_key("lang-en-dictPath-extra") is False + @pytest.mark.parametrize( + ("key", "expected"), + [ + ("lang-en", True), + ("lang-en-dictPath", True), + ("cacheSize", False), + ("lang-", False), + ("lang-en-dictPath-extra", False), + ], + ids=[ + "lang_code_format", + "lang_code_dict_path_format", + "not_lang_prefix", + "lang_only_no_code", + "lang_too_many_parts", + ], + ) + def test_is_lang_key(self, key: str, expected: bool) -> None: + """_is_lang_key() correctly classifies each key shape.""" + assert _is_lang_key(key) is expected class TestEncodeConfig: @@ -195,41 +235,86 @@ def test_path_validator_called(self, tmp_path: Path) -> None: class TestLanguageToolConfig: """Tests for the LanguageToolConfig class.""" + @pytest.fixture() # noqa: PT001 # bare @pytest.fixture resolves to Any under mypy strict here + def make_config( + self, + ) -> Iterator[Callable[[Mapping[str, ConfigValue]], LanguageToolConfig]]: + """Build a LanguageToolConfig factory that deletes its temp files afterwards. + + LanguageToolConfig() creates a real temporary file (normally cleaned up via + an atexit hook that only runs at interpreter shutdown), so without this + fixture, every test in this class would leak a file for the rest of the + test session. 
+ """ + created: list[LanguageToolConfig] = [] + + def _make(config: Mapping[str, ConfigValue]) -> LanguageToolConfig: + cfg = LanguageToolConfig(config) + created.append(cfg) + return cfg + + yield _make + + for cfg in created: + Path(cfg.path).unlink(missing_ok=True) + def test_empty_config_raises(self) -> None: """Constructing with an empty dict raises ValueError.""" with pytest.raises(ValueError, match="cannot be empty"): LanguageToolConfig({}) - def test_valid_config_creates_file(self) -> None: + def test_valid_config_creates_file( + self, make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig] + ) -> None: """A valid config creates a temporary .properties file on disk.""" - cfg = LanguageToolConfig({"cacheSize": 500}) + cfg = make_config({"cacheSize": 500}) assert cfg.path assert Path(cfg.path).exists() - def test_config_file_content(self) -> None: + def test_config_file_content( + self, make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig] + ) -> None: """The .properties file contains the expected key=value pair.""" - cfg = LanguageToolConfig({"cacheSize": 500}) + cfg = make_config({"cacheSize": 500}) content = Path(cfg.path).read_text(encoding="utf-8") assert "cacheSize=500" in content - def test_multiple_options(self) -> None: + def test_multiple_options( + self, make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig] + ) -> None: """Multiple config options all appear in the .properties file.""" - cfg = LanguageToolConfig({"cacheSize": 100, "pipelineCaching": True}) + cfg = make_config({"cacheSize": 100, "pipelineCaching": True}) content = Path(cfg.path).read_text(encoding="utf-8") assert "cacheSize=100" in content assert "pipelineCaching=true" in content - def test_config_dict_stored(self) -> None: + def test_config_dict_stored( + self, make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig] + ) -> None: """The encoded config is stored on the .config attribute.""" - cfg = 
LanguageToolConfig({"cacheSize": 200}) + cfg = make_config({"cacheSize": 200}) assert cfg.config == {"cacheSize": "200"} - def test_boolean_config(self) -> None: + def test_boolean_config( + self, make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig] + ) -> None: """A boolean config value is encoded as 'true' or 'false'.""" - cfg = LanguageToolConfig({"premiumOnly": False}) + cfg = make_config({"premiumOnly": False}) assert cfg.config == {"premiumOnly": "false"} - def test_list_config(self) -> None: + def test_list_config( + self, make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig] + ) -> None: """A list config value is encoded as a comma-separated string.""" - cfg = LanguageToolConfig({"disabledRuleIds": ["RULE_A", "RULE_B"]}) + cfg = make_config({"disabledRuleIds": ["RULE_A", "RULE_B"]}) assert cfg.config["disabledRuleIds"] == "RULE_A,RULE_B" + + def test_lang_dict_path_end_to_end( + self, + make_config: Callable[[Mapping[str, ConfigValue]], LanguageToolConfig], + tmp_path: Path, + ) -> None: + """A full LanguageToolConfig with a lang-xx-dictPath key writes it to disk.""" + cfg = make_config({"lang-en-dictPath": str(tmp_path)}) + content = Path(cfg.path).read_text(encoding="utf-8") + assert "lang-en-dictPath=" in content diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index ba92a96..abf4a01 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -3,16 +3,11 @@ These tests use mocks and monkeypatching to avoid real network requests. 
""" -import contextlib import hashlib -import importlib import io import re -import shutil -import uuid import zipfile from collections.abc import Iterator -from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path from unittest.mock import patch @@ -21,6 +16,7 @@ import requests import language_tool_python +from language_tool_python._internals.utils import get_env_int from language_tool_python.download_lt import ( _LTP_BYPASS_VERIFIED_DOWNLOADS_ENV_VAR, _LTP_DOWNLOAD_SHA256_ENV_VAR, @@ -77,25 +73,8 @@ def skip_java_compatibility_check(_language_tool_version: str) -> None: """Skip Java compatibility checks in download-only tests.""" -@contextmanager -def workspace_temp_dir() -> Iterator[Path]: - """Create a temporary directory inside the repository workspace.""" - root = Path.cwd() / ".test_download_tmp" - path = root / uuid.uuid4().hex - path.mkdir(parents=True) - try: - yield path - finally: - shutil.rmtree(path, ignore_errors=True) - with contextlib.suppress(OSError): - root.rmdir() - - def test_http_get_403_forbidden() -> None: - """Test that http_get raises PathError when receiving a 403 Forbidden status code. - - :raises AssertionError: If PathError is not raised for a 403 status code. - """ + """Test that http_get raises PathError on a 403 Forbidden status code.""" mock_response = MockDownloadResponse(b"", status_code=403) mock_response.headers = {} @@ -111,27 +90,22 @@ def test_http_get_403_forbidden() -> None: local_language_tool._get_remote_zip(out_file) -def test_http_get_other_error_codes() -> None: - """Test PathError handling for unexpected HTTP status codes. - - :raises AssertionError: If PathError is not raised for error status codes. 
- """ - error_codes = [500, 502, 503, 504] - - for error_code in error_codes: - mock_response = MockDownloadResponse(b"", status_code=error_code) - mock_response.headers = {} +@pytest.mark.parametrize("error_code", [500, 502, 503, 504]) +def test_http_get_other_error_codes(error_code: int) -> None: + """Test PathError handling for unexpected HTTP status codes.""" + mock_response = MockDownloadResponse(b"", status_code=error_code) + mock_response.headers = {} - out_file = io.BytesIO() - local_language_tool = LocalLanguageTool.from_version_name() - with ( - patch( - "language_tool_python.download_lt.requests.get", - return_value=mock_response, - ), - pytest.raises(PathError, match=f"Failed to download.*{error_code}"), - ): - local_language_tool._get_remote_zip(out_file) + out_file = io.BytesIO() + local_language_tool = LocalLanguageTool.from_version_name() + with ( + patch( + "language_tool_python.download_lt.requests.get", + return_value=mock_response, + ), + pytest.raises(PathError, match=f"Failed to download.*{error_code}"), + ): + local_language_tool._get_remote_zip(out_file) def test_http_get_rejects_oversized_content_length( @@ -159,19 +133,12 @@ def test_max_download_bytes_uses_env_override( monkeypatch: pytest.MonkeyPatch, ) -> None: """Test that the download size limit can be configured from the environment.""" - try: - with monkeypatch.context() as env: - env.setenv( - _LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, str(EXPECTED_DOWNLOAD_BYTES_OVERRIDE) - ) - importlib.reload(language_tool_python.download_lt) + monkeypatch.setenv( + _LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, str(EXPECTED_DOWNLOAD_BYTES_OVERRIDE) + ) - assert ( - language_tool_python.download_lt._MAX_DOWNLOAD_BYTES - == EXPECTED_DOWNLOAD_BYTES_OVERRIDE - ) - finally: - importlib.reload(language_tool_python.download_lt) + result = get_env_int(_LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, 1) + assert result == EXPECTED_DOWNLOAD_BYTES_OVERRIDE def test_http_get_rejects_oversized_stream( @@ -248,7 +215,9 @@ def 
test_latest_snapshot_uses_latest_download_url_and_current_date( "https://example.test/snapshots/", ) - FixedDatetime.current_datetime = datetime(2024, 5, 14, tzinfo=timezone.utc) + monkeypatch.setattr( + FixedDatetime, "current_datetime", datetime(2024, 5, 14, tzinfo=timezone.utc) + ) monkeypatch.setattr(language_tool_python.download_lt, "datetime", FixedDatetime) local_language_tool = LocalLanguageTool.from_version_name("latest") @@ -442,6 +411,7 @@ def test_http_get_bypass_skips_sha256_verification( def test_snapshot_download_renames_archive_root_to_requested_date( monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, ) -> None: """Test that date-pinned snapshots are installed under the requested date name.""" requested_snapshot = "20240101" @@ -454,24 +424,21 @@ def test_snapshot_download_renames_archive_root_to_requested_date( "_confirm_java_compatibility", skip_java_compatibility_check, ) + monkeypatch.setattr( + language_tool_python.download_lt, + "get_language_tool_download_path", + lambda: tmp_path, + ) - with ( - workspace_temp_dir() as temp_dir, - patch( - "language_tool_python.download_lt.requests.get", - return_value=MockDownloadResponse(payload), - ), + with patch( + "language_tool_python.download_lt.requests.get", + return_value=MockDownloadResponse(payload), ): - monkeypatch.setattr( - language_tool_python.download_lt, - "get_language_tool_download_path", - lambda: temp_dir, - ) local_language_tool.download() - expected_dir = temp_dir / f"LanguageTool-{requested_snapshot}" + expected_dir = tmp_path / f"LanguageTool-{requested_snapshot}" assert (expected_dir / "languagetool-server.jar").read_bytes() == b"jar" - assert not (temp_dir / "LanguageTool-6.9-SNAPSHOT").exists() + assert not (tmp_path / "LanguageTool-6.9-SNAPSHOT").exists() assert local_language_tool.get_directory_path() == expected_dir with patch("language_tool_python.download_lt.requests.get") as get_mock: @@ -512,6 +479,7 @@ def test_http_get_timeout_raises_timeout_error() -> None: def 
test_snapshot_download_raises_when_archive_has_multiple_root_dirs( monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, ) -> None: """download() raises PathError when the snapshot archive has multiple root dirs.""" payload = make_zip_payload( @@ -526,31 +494,33 @@ def test_snapshot_download_raises_when_archive_has_multiple_root_dirs( "_confirm_java_compatibility", skip_java_compatibility_check, ) + monkeypatch.setattr( + language_tool_python.download_lt, + "get_language_tool_download_path", + lambda: tmp_path, + ) with ( - workspace_temp_dir() as temp_dir, patch( "language_tool_python.download_lt.requests.get", return_value=MockDownloadResponse(payload), ), + pytest.raises(PathError, match="Expected snapshot archive"), ): - monkeypatch.setattr( - language_tool_python.download_lt, - "get_language_tool_download_path", - lambda: temp_dir, - ) - with pytest.raises(PathError, match="Expected snapshot archive"): - local_language_tool.download() + local_language_tool.download() def test_latest_snapshot_download_renames_archive_root_to_current_date( monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, ) -> None: """Test that latest snapshots are installed under the current date name.""" current_snapshot_date = "20240514" payload = make_zip_payload( {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"}, ) - FixedDatetime.current_datetime = datetime(2024, 5, 14, tzinfo=timezone.utc) + monkeypatch.setattr( + FixedDatetime, "current_datetime", datetime(2024, 5, 14, tzinfo=timezone.utc) + ) monkeypatch.setattr(language_tool_python.download_lt, "datetime", FixedDatetime) local_language_tool = LocalLanguageTool.from_version_name("latest") monkeypatch.setattr( @@ -558,24 +528,21 @@ def test_latest_snapshot_download_renames_archive_root_to_current_date( "_confirm_java_compatibility", skip_java_compatibility_check, ) + monkeypatch.setattr( + language_tool_python.download_lt, + "get_language_tool_download_path", + lambda: tmp_path, + ) - with ( - workspace_temp_dir() as temp_dir, 
- patch( - "language_tool_python.download_lt.requests.get", - return_value=MockDownloadResponse(payload), - ), + with patch( + "language_tool_python.download_lt.requests.get", + return_value=MockDownloadResponse(payload), ): - monkeypatch.setattr( - language_tool_python.download_lt, - "get_language_tool_download_path", - lambda: temp_dir, - ) local_language_tool.download() - expected_dir = temp_dir / f"LanguageTool-{current_snapshot_date}" + expected_dir = tmp_path / f"LanguageTool-{current_snapshot_date}" assert (expected_dir / "languagetool-server.jar").read_bytes() == b"jar" - assert not (temp_dir / "LanguageTool-6.9-SNAPSHOT").exists() + assert not (tmp_path / "LanguageTool-6.9-SNAPSHOT").exists() assert local_language_tool.get_directory_path() == expected_dir with patch("language_tool_python.download_lt.requests.get") as get_mock: diff --git a/tests/unit/test_download_unit.py b/tests/unit/test_download_unit.py index dc17699..b518a53 100644 --- a/tests/unit/test_download_unit.py +++ b/tests/unit/test_download_unit.py @@ -57,6 +57,11 @@ def _check_output_java17(*_args: object, **_kw: object) -> str: return "openjdk 17.0.1 2021-10-19" +def _check_output_java9(*_args: object, **_kw: object) -> str: + """Stub for subprocess.check_output returning a Java 9 version string.""" + return 'java version "9.0.4"' + + def _noop(_v: object) -> None: """No-operation stub for monkeypatching void functions.""" @@ -365,6 +370,12 @@ def test_java_17_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(subprocess, "check_output", _check_output_java17) _dl._confirm_java_compatibility("6.8") + def test_java_9_passes_for_old_lt(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Java 9 satisfies an old LT version's lower requirement without error.""" + monkeypatch.setattr("language_tool_python.download_lt.which", _which_java) + monkeypatch.setattr(subprocess, "check_output", _check_output_java9) + _dl._confirm_java_compatibility("5.0") + class 
TestGetInstalledVersions: """Tests for LocalLanguageTool.get_installed_versions().""" @@ -389,6 +400,20 @@ def test_returns_installed_versions( assert "6.8" in version_names assert "6.7" in version_names + def test_skips_malformed_version_directory( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """A directory matching 'LanguageTool-*' but with an unparseable version name + is silently skipped (the ValueError from from_path() is suppressed) rather + than propagating. + """ # noqa: D205 + (tmp_path / "LanguageTool-not-a-real-version").mkdir() + (tmp_path / "LanguageTool-6.8").mkdir() + monkeypatch.setenv("LTP_PATH", str(tmp_path)) + versions = _dl.LocalLanguageTool.get_installed_versions() + version_names = [v.version_name for v in versions] + assert version_names == ["6.8"] + class TestGetLatestInstalledVersion: """Tests for LocalLanguageTool.get_latest_installed_version().""" diff --git a/tests/unit/test_internals_utils.py b/tests/unit/test_internals_utils.py index 4f23034..de8cef1 100644 --- a/tests/unit/test_internals_utils.py +++ b/tests/unit/test_internals_utils.py @@ -2,9 +2,6 @@ from __future__ import annotations -import subprocess -import sys -import time from typing import TYPE_CHECKING import psutil @@ -153,6 +150,33 @@ def test_returns_string(self) -> None: assert len(result) > 0 +class _MockPsutilProcess: + """Minimal stand-in for psutil.Process used in kill_process_force tests.""" + + def __init__( + self, + pid: int, + children: list[_MockPsutilProcess] | None = None, + *, + raise_on_kill: bool = False, + ) -> None: + """Initialise with a pid, optional children, and an optional kill() failure.""" + self.pid = pid + self.killed = False + self._children = children or [] + self._raise_on_kill = raise_on_kill + + def children(self, *, recursive: bool = False) -> list[_MockPsutilProcess]: # noqa: ARG002 + """Return the configured child processes.""" + return self._children + + def kill(self) -> None: + """Mark the process as killed, 
or raise NoSuchProcess if so configured.""" + if self._raise_on_kill: + raise psutil.NoSuchProcess(self.pid) + self.killed = True + + class TestKillProcessForce: """Tests for kill_process_force() process terminator.""" @@ -161,44 +185,51 @@ def test_raises_when_no_args(self) -> None: with pytest.raises(ValueError, match="Must pass either pid or proc"): kill_process_force() - def test_kills_by_pid(self) -> None: + def test_kills_by_pid(self, monkeypatch: pytest.MonkeyPatch) -> None: """A process is terminated when its pid is given.""" - proc = subprocess.Popen( - [sys.executable, "-c", "import time; time.sleep(60)"], - ) - kill_process_force(pid=proc.pid) - proc.wait(timeout=5) + mock_proc = _MockPsutilProcess(pid=123) + + def _fake_process(_pid: int) -> _MockPsutilProcess: + return mock_proc + + monkeypatch.setattr(psutil, "Process", _fake_process) + kill_process_force(pid=123) + assert mock_proc.killed def test_kills_by_proc(self) -> None: """A process is terminated when a psutil.Process object is given.""" - proc = subprocess.Popen( - [sys.executable, "-c", "import time; time.sleep(60)"], - ) - ps_proc = psutil.Process(proc.pid) - kill_process_force(proc=ps_proc) - proc.wait(timeout=5) + mock_proc = _MockPsutilProcess(pid=123) + kill_process_force(proc=mock_proc) # type: ignore[arg-type] + assert mock_proc.killed def test_kills_process_with_children(self) -> None: """A process and its children are all terminated.""" - parent = subprocess.Popen( - [ - sys.executable, - "-c", - ( - "import subprocess, sys, time; " - "subprocess.Popen([sys.executable, '-c', " - "'import time; time.sleep(60)']); " - "time.sleep(60)" - ), - ], - ) - time.sleep(0.3) - kill_process_force(pid=parent.pid) - parent.wait(timeout=10) - - def test_nonexistent_pid_is_silent(self) -> None: + child = _MockPsutilProcess(pid=456) + parent = _MockPsutilProcess(pid=123, children=[child]) + kill_process_force(proc=parent) # type: ignore[arg-type] + assert child.killed + assert parent.killed + + def 
test_nonexistent_pid_is_silent(self, monkeypatch: pytest.MonkeyPatch) -> None: """A nonexistent pid is silently ignored.""" - kill_process_force(pid=999999999) + + def _raise_no_such_process(pid: int) -> _MockPsutilProcess: + raise psutil.NoSuchProcess(pid) + + monkeypatch.setattr(psutil, "Process", _raise_no_such_process) + kill_process_force(pid=999999999) # must not raise + + def test_suppresses_no_such_process_on_child_race(self) -> None: + """A child that vanishes before kill() (NoSuchProcess) is silently skipped.""" + child = _MockPsutilProcess(pid=456, raise_on_kill=True) + parent = _MockPsutilProcess(pid=123, children=[child]) + kill_process_force(proc=parent) # type: ignore[arg-type] # must not raise + assert parent.killed + + def test_suppresses_no_such_process_on_parent_race(self) -> None: + """A parent that vanishes before kill() (NoSuchProcess) is silently skipped.""" + parent = _MockPsutilProcess(pid=123, raise_on_kill=True) + kill_process_force(proc=parent) # type: ignore[arg-type] # must not raise class TestVersionTuple: diff --git a/tests/unit/test_language_tag.py b/tests/unit/test_language_tag.py index 7ac4d3e..01ba5a3 100644 --- a/tests/unit/test_language_tag.py +++ b/tests/unit/test_language_tag.py @@ -46,19 +46,14 @@ def test_languages_stored(self) -> None: class TestNormalizePosix: """Tests for POSIX/C locale fallback behaviour.""" - def test_c_locale_falls_back_to_en_us(self) -> None: - """'C' locale resolves to en-US when available.""" - lt = _tag("C") - assert lt.normalized_tag == "en-US" - - def test_posix_locale_falls_back_to_en_us(self) -> None: - """'POSIX' locale resolves to en-US when available.""" - lt = _tag("POSIX") - assert lt.normalized_tag == "en-US" - - def test_c_dot_variant(self) -> None: - """'C.UTF-8' resolves to en-US when available.""" - lt = _tag("C.UTF-8") + @pytest.mark.parametrize( + "tag", + ["C", "POSIX", "C.UTF-8"], + ids=["c_locale", "posix_locale", "c_dot_variant"], + ) + def 
test_posix_like_tag_falls_back_to_en_us(self, tag: str) -> None: + """POSIX-like locale tags resolve to en-US when available.""" + lt = _tag(tag) assert lt.normalized_tag == "en-US" def test_posix_prefers_en_gb_when_no_en_us(self) -> None: @@ -80,35 +75,36 @@ def test_posix_raises_when_no_english(self) -> None: class TestNormalizeFallback: """Tests for regex-based region-stripping fallback.""" - def test_language_only_matches_base(self) -> None: - """A bare language code matches the base language entry.""" - lt = _tag("en") - assert lt.normalized_tag == "en" - - def test_regex_fallback_to_base_language(self) -> None: - """An exact-match tag is returned as-is.""" - lt = _tag("pt-BR") - assert lt.normalized_tag == "pt-BR" + @pytest.mark.parametrize( + ("tag", "expected"), + [("en", "en"), ("pt-BR", "pt-BR")], + ids=["language_only_matches_base", "regex_fallback_to_base_language"], + ) + def test_normalizes_against_default_languages( + self, tag: str, expected: str + ) -> None: + """A bare or exact-match tag normalizes as expected against _LANGS.""" + lt = _tag(tag) + assert lt.normalized_tag == expected def test_regex_fallback_strips_region(self) -> None: """A tag with an unavailable region falls back to the base language.""" lt = LanguageTag("en-AU", ["en", "de-DE"]) assert lt.normalized_tag == "en" - def test_empty_tag_raises(self) -> None: - """An empty tag string raises ValueError.""" - with pytest.raises(ValueError, match="empty language tag"): - _tag("") - - def test_unsupported_tag_raises(self) -> None: - """A tag with no match raises ValueError.""" - with pytest.raises(ValueError, match="unsupported language"): - _tag("zz-ZZ") - - def test_unmatched_pattern_raises(self) -> None: - """A non-language-like string raises ValueError.""" - with pytest.raises(ValueError, match="unsupported language"): - _tag("123invalid") + @pytest.mark.parametrize( + ("tag", "match"), + [ + ("", "empty language tag"), + ("zz-ZZ", "unsupported language"), + ("123invalid", 
"unsupported language"), + ], + ids=["empty_tag", "unsupported_tag", "unmatched_pattern"], + ) + def test_invalid_tag_raises(self, tag: str, match: str) -> None: + """Empty, unsupported, or unmatched tags all raise ValueError.""" + with pytest.raises(ValueError, match=match): + _tag(tag) class TestComparisons: diff --git a/tests/unit/test_match.py b/tests/unit/test_match.py index db36637..d58b58c 100644 --- a/tests/unit/test_match.py +++ b/tests/unit/test_match.py @@ -160,27 +160,26 @@ def test_same_text_reuses_positions(self) -> None: class TestFourByteCharPositions: """Tests for four_byte_char_positions() helper.""" - def test_empty_string(self) -> None: - """An empty string has no 4-byte char positions.""" - assert four_byte_char_positions("") == [] - - def test_ascii_only(self) -> None: - """A pure-ASCII string has no 4-byte char positions.""" - assert four_byte_char_positions("hello") == [] - - def test_emoji_at_start(self) -> None: - """An emoji at position 0 is reported at index 0.""" - assert four_byte_char_positions("🌅abc") == [0] - - def test_multiple_emojis(self) -> None: - """Two consecutive emojis are reported with adjusted indices.""" - positions = four_byte_char_positions("🌅🎉abc") - assert positions == [0, 2] - - def test_emoji_in_middle(self) -> None: - """An emoji in the middle of ASCII text is reported at the correct index.""" - positions = four_byte_char_positions("ab🌅cd") - assert positions == [2] + @pytest.mark.parametrize( + ("text", "expected"), + [ + ("", []), + ("hello", []), + ("🌅abc", [0]), + ("🌅🎉abc", [0, 2]), + ("ab🌅cd", [2]), + ], + ids=[ + "empty_string", + "ascii_only", + "emoji_at_start", + "multiple_emojis", + "emoji_in_middle", + ], + ) + def test_four_byte_char_positions(self, text: str, expected: list[int]) -> None: + """4-byte encoded character positions are reported correctly.""" + assert four_byte_char_positions(text) == expected class TestMatchOrderedDict: @@ -218,11 +217,14 @@ def test_valid_check_match(self) -> None: """A 
fully populated attrib dict is recognised as a CheckMatch.""" assert is_check_match(_make_attrib()) - def test_not_dict(self) -> None: + @pytest.mark.parametrize( + "value", + ["not a dict", None, 42], + ids=["string", "none", "int"], + ) + def test_not_dict(self, value: object) -> None: """Non-dict values are rejected.""" - assert not is_check_match("not a dict") - assert not is_check_match(None) - assert not is_check_match(42) + assert not is_check_match(value) def test_missing_field(self) -> None: """A dict missing a required field is rejected.""" @@ -266,6 +268,21 @@ def test_str_no_replacements_skips_suggestion(self) -> None: m = _make_match(replacements=[]) assert "Suggestion" not in str(m) + def test_str_indicator_line_position_and_length(self) -> None: + """The '^^^^' indicator line is padded and sized to offset_in_context/length.""" + context_text = "This is noot okay." + context_offset = 8 + length = 4 + m = _make_match( + text=context_text, + context_text=context_text, + context_offset=context_offset, + length=length, + ) + lines = str(m).splitlines() + indicator_line = lines[-1] + assert indicator_line == " " * context_offset + "^" * length + class TestMatchRepr: """Tests for Match.__repr__() machine-readable formatter.""" @@ -280,6 +297,44 @@ def test_repr_contains_rule_id(self) -> None: m = _make_match() assert "MORFOLOGIK_RULE_EN_US" in repr(m) + def test_repr_full_structure_and_field_order(self) -> None: + """The repr contains every field, in the documented order, exactly once.""" + m = _make_match() + r = repr(m) + assert r.startswith("Match({") + assert r.endswith("})") + + expected_field_order = [ + "rule_id", + "message", + "replacements", + "offset_in_context", + "context", + "offset", + "error_length", + "category", + "rule_issue_type", + "sentence", + ] + field_positions = [r.index(f"'{field}':") for field in expected_field_order] + assert field_positions == sorted(field_positions) + + expected_repr = ( + "Match({" + "'rule_id': 
'MORFOLOGIK_RULE_EN_US', " + "'message': 'Possible spelling mistake.', " + "'replacements': ['not', 'noon'], " + "'offset_in_context': 8, " + "'context': 'This is noot okay.', " + "'offset': 8, " + "'error_length': 4, " + "'category': 'TYPOS', " + "'rule_issue_type': 'misspelling', " + "'sentence': 'This is noot okay.'" + "})" + ) + assert r == expected_repr + class TestMatchedText: """Tests for the matched_text property.""" diff --git a/tests/unit/test_safe_zip.py b/tests/unit/test_safe_zip.py index c0ff944..b6f35bd 100644 --- a/tests/unit/test_safe_zip.py +++ b/tests/unit/test_safe_zip.py @@ -1,16 +1,10 @@ """Unit tests for safe ZIP extraction.""" -import contextlib import hashlib -import importlib import io -import shutil import stat import unittest.mock -import uuid import zipfile -from collections.abc import Iterator -from contextlib import contextmanager from pathlib import Path import pytest @@ -66,21 +60,7 @@ def make_symlink_or_skip( pytest.skip(f"Cannot create symlink for this test: {error}") -@contextmanager -def workspace_temp_dir() -> Iterator[Path]: - """Create a temporary directory inside the repository workspace.""" - root = Path.cwd() / ".test_safe_zip_tmp" - path = root / uuid.uuid4().hex - path.mkdir(parents=True) - try: - yield path - finally: - shutil.rmtree(path, ignore_errors=True) - with contextlib.suppress(OSError): - root.rmdir() - - -def test_safe_extract_allows_regular_zip() -> None: +def test_safe_extract_allows_regular_zip(tmp_path: Path) -> None: """Test that a regular ZIP is extracted by the safe extractor.""" payload = make_zip_payload( { @@ -89,65 +69,40 @@ def test_safe_extract_allows_regular_zip() -> None: }, ) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - SafeZipExtractor().extractall(zip_file, temp_dir) + with zipfile.ZipFile(io.BytesIO(payload)) as zip_file: + SafeZipExtractor().extractall(zip_file, tmp_path) assert ( - temp_dir / "LanguageTool-6.9-SNAPSHOT" / 
"languagetool-server.jar" + tmp_path / "LanguageTool-6.9-SNAPSHOT" / "languagetool-server.jar" ).read_bytes() == b"jar" -def test_safe_zip_limits_use_env_overrides( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Test that safe ZIP limits can be configured from the environment.""" - try: - with monkeypatch.context() as env: - env.setenv( - safe_zip.LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES_ENV_VAR, - str(EXPECTED_MAX_ARCHIVE_BYTES), - ) - env.setenv( - safe_zip.LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES_ENV_VAR, - str(EXPECTED_MAX_EXTRACTED_BYTES), - ) - env.setenv( - safe_zip.LTP_SAFE_ZIP_MAX_MEMBERS_ENV_VAR, str(EXPECTED_MAX_MEMBERS) - ) - env.setenv( - safe_zip.LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES_ENV_VAR, - str(EXPECTED_MAX_MEMBER_EXTRACTED_BYTES), - ) - env.setenv( - safe_zip.LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO_ENV_VAR, - str(EXPECTED_MAX_MEMBER_COMPRESSION_RATIO), - ) - env.setenv( - safe_zip.LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO_ENV_VAR, - str(EXPECTED_MAX_TOTAL_COMPRESSION_RATIO), - ) - - importlib.reload(safe_zip) - limits = safe_zip.SafeZipLimits() - - assert limits.max_archive_bytes == EXPECTED_MAX_ARCHIVE_BYTES - assert limits.max_extracted_bytes == EXPECTED_MAX_EXTRACTED_BYTES - assert limits.max_members == EXPECTED_MAX_MEMBERS - assert ( - limits.max_member_extracted_bytes == EXPECTED_MAX_MEMBER_EXTRACTED_BYTES - ) - assert ( - limits.max_member_compression_ratio - == EXPECTED_MAX_MEMBER_COMPRESSION_RATIO - ) - assert ( - limits.max_total_compression_ratio - == EXPECTED_MAX_TOTAL_COMPRESSION_RATIO - ) - finally: - importlib.reload(safe_zip) +def test_safe_zip_limits_defaults_wired_to_module_constants() -> None: + """Test that SafeZipLimits() defaults are wired to the module DEFAULT_* constants. 
+ + The module-level ``DEFAULT_*`` constants are computed once at import time via + :func:`get_env_int`/:func:`get_env_float` (whose environment-variable-override + branch is covered directly by ``TestGetEnvInt``/``TestGetEnvFloat`` in + ``test_internals_utils.py``, and by + ``test_safe_zip_float_env_rejects_non_finite_values`` below). This test instead + verifies the downstream wiring: that each ``SafeZipLimits`` field default is the + corresponding module constant, without needing to reload the module (which would + leak state across tests and require careful cleanup). + """ + limits = SafeZipLimits() + assert limits.max_archive_bytes == safe_zip.DEFAULT_MAX_ARCHIVE_BYTES + assert limits.max_extracted_bytes == safe_zip.DEFAULT_MAX_EXTRACTED_BYTES + assert limits.max_members == safe_zip.DEFAULT_MAX_MEMBERS + assert ( + limits.max_member_extracted_bytes == safe_zip.DEFAULT_MAX_MEMBER_EXTRACTED_BYTES + ) + assert ( + limits.max_member_compression_ratio + == safe_zip.DEFAULT_MAX_MEMBER_COMPRESSION_RATIO + ) + assert ( + limits.max_total_compression_ratio + == safe_zip.DEFAULT_MAX_TOTAL_COMPRESSION_RATIO + ) @pytest.mark.parametrize("configured", ["nan", "inf"]) @@ -190,19 +145,19 @@ def test_safe_zip_float_env_rejects_non_finite_values( ) def test_safe_extract_rejects_unsafe_member_names( filename: str, + tmp_path: Path, ) -> None: """Test that unsafe ZIP member names are rejected.""" payload = make_zip_payload({filename: b"nope"}) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="Unsafe ZIP member"), ): - SafeZipExtractor().extractall(zip_file, temp_dir) + SafeZipExtractor().extractall(zip_file, tmp_path) -def test_safe_extract_rejects_duplicate_member_paths() -> None: +def test_safe_extract_rejects_duplicate_member_paths(tmp_path: Path) -> None: """Test that duplicate ZIP member paths are rejected before extraction.""" buffer = io.BytesIO() with zipfile.ZipFile(buffer, "w") as zip_file: @@ 
-211,14 +166,13 @@ def test_safe_extract_rejects_duplicate_member_paths() -> None: buffer.seek(0) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(buffer) as zip_file, pytest.raises(PathError, match="duplicate ZIP member path"), ): - SafeZipExtractor().extractall(zip_file, temp_dir) + SafeZipExtractor().extractall(zip_file, tmp_path) -def test_safe_extract_rejects_file_directory_conflict() -> None: +def test_safe_extract_rejects_file_directory_conflict(tmp_path: Path) -> None: """Test that archives reject file-and-child path conflicts.""" payload = make_zip_payload( { @@ -228,14 +182,15 @@ def test_safe_extract_rejects_file_directory_conflict() -> None: ) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match=r"below file path|file over directory path"), ): - SafeZipExtractor().extractall(zip_file, temp_dir) + SafeZipExtractor().extractall(zip_file, tmp_path) -def test_safe_extract_rejects_file_directory_conflict_in_reverse_order() -> None: +def test_safe_extract_rejects_file_directory_conflict_in_reverse_order( + tmp_path: Path, +) -> None: """Test that archives cannot replace a directory path with a file path.""" payload = make_zip_payload( { @@ -245,14 +200,13 @@ def test_safe_extract_rejects_file_directory_conflict_in_reverse_order() -> None ) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match=r"below file path|file over directory path"), ): - SafeZipExtractor().extractall(zip_file, temp_dir) + SafeZipExtractor().extractall(zip_file, tmp_path) -def test_safe_extract_rejects_zip_symlink() -> None: +def test_safe_extract_rejects_zip_symlink(tmp_path: Path) -> None: """Test that ZIP symlink entries are rejected.""" member = zipfile.ZipInfo("LanguageTool/link") member.create_system = 3 @@ -260,24 +214,20 @@ def test_safe_extract_rejects_zip_symlink() -> None: payload = make_zip_payload_from_info(member, b"target") 
with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="symlink"), ): - SafeZipExtractor().extractall(zip_file, temp_dir) + SafeZipExtractor().extractall(zip_file, tmp_path) -def test_safe_extract_rejects_symlinked_destination() -> None: +def test_safe_extract_rejects_symlinked_destination(tmp_path: Path) -> None: """Test that the final destination itself cannot be a symlink.""" payload = make_zip_payload({"LanguageTool/file.txt": b"jar"}) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - real_destination = temp_dir / "real-destination" + with zipfile.ZipFile(io.BytesIO(payload)) as zip_file: + real_destination = tmp_path / "real-destination" real_destination.mkdir() - destination_link = temp_dir / "destination-link" + destination_link = tmp_path / "destination-link" make_symlink_or_skip( real_destination, destination_link, @@ -288,23 +238,20 @@ def test_safe_extract_rejects_symlinked_destination() -> None: SafeZipExtractor().extractall( zip_file, destination_link, - work_dir=temp_dir / "work", + work_dir=tmp_path / "work", ) assert not (real_destination / "LanguageTool").exists() -def test_safe_extract_rejects_existing_symlink_in_destination() -> None: +def test_safe_extract_rejects_existing_symlink_in_destination(tmp_path: Path) -> None: """Test that an existing destination symlink cannot redirect extracted content.""" payload = make_zip_payload({"LanguageTool/file.txt": b"jar"}) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - destination = temp_dir / "destination" + with zipfile.ZipFile(io.BytesIO(payload)) as zip_file: + destination = tmp_path / "destination" destination.mkdir() - outside = temp_dir / "outside" + outside = tmp_path / "outside" outside.mkdir() make_symlink_or_skip( outside, @@ -319,34 +266,31 @@ def test_safe_extract_rejects_existing_symlink_in_destination() -> None: 
SafeZipExtractor().extractall( zip_file, destination, - work_dir=temp_dir / "work", + work_dir=tmp_path / "work", ) assert not (outside / "file.txt").exists() -def test_safe_extract_rejects_symlinked_work_dir() -> None: +def test_safe_extract_rejects_symlinked_work_dir(tmp_path: Path) -> None: """Test that the private extraction work directory cannot be a symlink.""" payload = make_zip_payload({"LanguageTool/file.txt": b"jar"}) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - work_target = temp_dir / "work-target" + with zipfile.ZipFile(io.BytesIO(payload)) as zip_file: + work_target = tmp_path / "work-target" work_target.mkdir() - work_link = temp_dir / "work-link" + work_link = tmp_path / "work-link" make_symlink_or_skip(work_target, work_link, target_is_directory=True) with pytest.raises(PathError, match="private extraction directory"): SafeZipExtractor().extractall( zip_file, - temp_dir / "destination", + tmp_path / "destination", work_dir=work_link, ) -def test_safe_extract_rejects_special_zip_member_type() -> None: +def test_safe_extract_rejects_special_zip_member_type(tmp_path: Path) -> None: """Test that non-file, non-directory ZIP entries are rejected.""" member = zipfile.ZipInfo("LanguageTool/fifo") member.create_system = 3 @@ -354,14 +298,13 @@ def test_safe_extract_rejects_special_zip_member_type() -> None: payload = make_zip_payload_from_info(member, b"") with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="unsupported ZIP member type"), ): - SafeZipExtractor().extractall(zip_file, temp_dir) + SafeZipExtractor().extractall(zip_file, tmp_path) -def test_safe_extract_allows_multiple_safe_roots() -> None: +def test_safe_extract_allows_multiple_safe_roots(tmp_path: Path) -> None: """Test that safe extraction does not require a LanguageTool-specific root.""" payload = make_zip_payload( { @@ -370,27 +313,21 @@ def 
test_safe_extract_allows_multiple_safe_roots() -> None: }, ) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - destination = temp_dir / "destination" - work_dir = temp_dir / "work" + with zipfile.ZipFile(io.BytesIO(payload)) as zip_file: + destination = tmp_path / "destination" + work_dir = tmp_path / "work" SafeZipExtractor().extractall(zip_file, destination, work_dir=work_dir) assert (destination / "first" / "file.txt").read_bytes() == b"one" assert (destination / "second" / "file.txt").read_bytes() == b"two" -def test_safe_extract_rejects_existing_destination_path() -> None: +def test_safe_extract_rejects_existing_destination_path(tmp_path: Path) -> None: """Test that extraction never overwrites an existing final destination path.""" payload = make_zip_payload({"file.txt": b"new"}) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - destination = temp_dir / "destination" + with zipfile.ZipFile(io.BytesIO(payload)) as zip_file: + destination = tmp_path / "destination" destination.mkdir() existing_file = destination / "file.txt" existing_file.write_bytes(b"old") @@ -399,13 +336,13 @@ def test_safe_extract_rejects_existing_destination_path() -> None: SafeZipExtractor().extractall( zip_file, destination, - work_dir=temp_dir / "work", + work_dir=tmp_path / "work", ) assert existing_file.read_bytes() == b"old" -def test_safe_extract_rejects_too_many_members() -> None: +def test_safe_extract_rejects_too_many_members(tmp_path: Path) -> None: """Test that ZIP archives with too many entries are rejected.""" payload = make_zip_payload( { @@ -416,27 +353,25 @@ def test_safe_extract_rejects_too_many_members() -> None: extractor = SafeZipExtractor(SafeZipLimits(max_members=1)) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="Maximum allowed member count"), ): - extractor.extractall(zip_file, 
temp_dir) + extractor.extractall(zip_file, tmp_path) -def test_safe_extract_rejects_too_much_uncompressed_data() -> None: +def test_safe_extract_rejects_too_much_uncompressed_data(tmp_path: Path) -> None: """Test that ZIP archives with too much uncompressed data are rejected.""" payload = make_zip_payload({"LanguageTool/file.txt": b"four"}) extractor = SafeZipExtractor(SafeZipLimits(max_extracted_bytes=3)) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="Maximum allowed extracted size"), ): - extractor.extractall(zip_file, temp_dir) + extractor.extractall(zip_file, tmp_path) -def test_safe_extract_rejects_oversized_member_during_copy() -> None: +def test_safe_extract_rejects_oversized_member_during_copy(tmp_path: Path) -> None: """Test that per-member extracted size limits are enforced while copying.""" payload = make_zip_payload({"LanguageTool/file.txt": b"four"}) extractor = SafeZipExtractor( @@ -447,27 +382,27 @@ def test_safe_extract_rejects_oversized_member_during_copy() -> None: ) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="ZIP member larger"), ): - extractor.extractall(zip_file, temp_dir) + extractor.extractall(zip_file, tmp_path) -def test_safe_extract_rejects_too_much_compressed_data() -> None: +def test_safe_extract_rejects_too_much_compressed_data(tmp_path: Path) -> None: """Test that local ZIP extraction also applies the compressed-size limit.""" payload = make_zip_payload({"LanguageTool/file.txt": b"data"}) extractor = SafeZipExtractor(SafeZipLimits(max_archive_bytes=1)) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="compressed member bytes"), ): - extractor.extractall(zip_file, temp_dir) + extractor.extractall(zip_file, tmp_path) -def test_safe_extract_rejects_suspicious_member_compression_ratio() -> None: +def 
test_safe_extract_rejects_suspicious_member_compression_ratio( + tmp_path: Path, +) -> None: """Test that a single member with an abusive compression ratio is rejected.""" payload = make_deflated_zip_payload({"LanguageTool/file.txt": b"A" * 4096}) extractor = SafeZipExtractor( @@ -478,14 +413,15 @@ def test_safe_extract_rejects_suspicious_member_compression_ratio() -> None: ) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="suspicious compression ratio"), ): - extractor.extractall(zip_file, temp_dir) + extractor.extractall(zip_file, tmp_path) -def test_safe_extract_rejects_suspicious_total_compression_ratio() -> None: +def test_safe_extract_rejects_suspicious_total_compression_ratio( + tmp_path: Path, +) -> None: """Test that an archive with an abusive total compression ratio is rejected.""" payload = make_deflated_zip_payload({"LanguageTool/file.txt": b"A" * 4096}) extractor = SafeZipExtractor( @@ -496,14 +432,15 @@ def test_safe_extract_rejects_suspicious_total_compression_ratio() -> None: ) with ( - workspace_temp_dir() as temp_dir, zipfile.ZipFile(io.BytesIO(payload)) as zip_file, pytest.raises(PathError, match="suspicious total compression ratio"), ): - extractor.extractall(zip_file, temp_dir) + extractor.extractall(zip_file, tmp_path) -def test_safe_extract_checks_total_compression_ratio_after_all_members() -> None: +def test_safe_extract_checks_total_compression_ratio_after_all_members( + tmp_path: Path, +) -> None: """Test that total ratio checks are based on the final archive ratio.""" already_compressed = b"".join( hashlib.sha256(index.to_bytes(4, "big")).digest() for index in range(2048) @@ -521,17 +458,14 @@ def test_safe_extract_checks_total_compression_ratio_after_all_members() -> None ), ) - with ( - workspace_temp_dir() as temp_dir, - zipfile.ZipFile(io.BytesIO(payload)) as zip_file, - ): - extractor.extractall(zip_file, temp_dir) + with zipfile.ZipFile(io.BytesIO(payload)) 
as zip_file: + extractor.extractall(zip_file, tmp_path) assert ( - temp_dir / "LanguageTool" / "compressible.txt" + tmp_path / "LanguageTool" / "compressible.txt" ).read_bytes() == b"A" * 4096 assert ( - temp_dir / "LanguageTool" / "already-compressed.bin" + tmp_path / "LanguageTool" / "already-compressed.bin" ).read_bytes() == already_compressed @@ -596,7 +530,7 @@ def _open_returning_small(_m: object, _mode: str = "r") -> io.BytesIO: return io.BytesIO(b"hi") -def test_copy_member_raises_when_content_exceeds_declared_size() -> None: +def test_copy_member_raises_when_content_exceeds_declared_size(tmp_path: Path) -> None: """_copy_member raises when decompressed bytes exceed the declared file_size.""" payload = make_zip_payload({"LanguageTool/file.txt": b"hello world"}) with zipfile.ZipFile(io.BytesIO(payload)) as zf: @@ -604,13 +538,14 @@ def test_copy_member_raises_when_content_exceeds_declared_size() -> None: member.file_size = 3 with ( unittest.mock.patch.object(zf, "open", new=_open_returning_large), - workspace_temp_dir() as temp_dir, pytest.raises(PathError, match="expanded beyond declared size"), ): - SafeZipExtractor()._copy_member(zf, member, temp_dir / "file.txt") + SafeZipExtractor()._copy_member(zf, member, tmp_path / "file.txt") -def test_copy_member_raises_when_content_is_less_than_declared_size() -> None: +def test_copy_member_raises_when_content_is_less_than_declared_size( + tmp_path: Path, +) -> None: """_copy_member raises when fewer bytes are read than the declared file_size.""" payload = make_zip_payload({"LanguageTool/file.txt": b"hello world"}) with zipfile.ZipFile(io.BytesIO(payload)) as zf: @@ -618,7 +553,6 @@ def test_copy_member_raises_when_content_is_less_than_declared_size() -> None: member.file_size = 1000 with ( unittest.mock.patch.object(zf, "open", new=_open_returning_small), - workspace_temp_dir() as temp_dir, pytest.raises(PathError, match="extracted size mismatch"), ): - SafeZipExtractor()._copy_member(zf, member, temp_dir / 
"file.txt") + SafeZipExtractor()._copy_member(zf, member, tmp_path / "file.txt") diff --git a/tests/unit/test_server_unit.py b/tests/unit/test_server_unit.py index d635c69..fc6cb2a 100644 --- a/tests/unit/test_server_unit.py +++ b/tests/unit/test_server_unit.py @@ -1,4 +1,4 @@ -"""Unit tests for server.py — no Java, no network required.""" +"""Unit tests for server.py, no Java, no network required.""" from __future__ import annotations @@ -148,7 +148,7 @@ def download(self) -> None: """No-op download.""" def get_directory_path(self) -> None: - """No-op — not needed for start_local_server tests.""" + """No-op, not needed for start_local_server tests.""" def get_server_cmd(self, _port: int | None, _config: object) -> list[str]: """Return a fake Java command.""" @@ -261,6 +261,22 @@ def test_failsafe_language_used_when_locale_detection_fails(self) -> None: ): assert str(lt.language) == "en" + def test_explicit_host_is_used_instead_of_localhost(self) -> None: + """An explicitly passed host is stored instead of the localhost default.""" + explicit_host = "192.0.2.1" + with ( + patch( + "language_tool_python.server.get_locale_language", + return_value="en", + ), + patch.object(LanguageTool, "_get_languages", return_value=_LANGUAGES), + LanguageTool( + remote_server="http://fake/", + host=explicit_host, + ) as lt, + ): + assert lt.host == explicit_host + class TestLanguageToolDel: """Tests for __del__ resource warning.""" @@ -325,6 +341,24 @@ def test_mother_tongue_setter_stores_value(self) -> None: lt.mother_tongue = "fr" assert lt._mother_tongue == "fr" + def test_language_setter_raises_for_unsupported_tag(self) -> None: + """Language setter raises ValueError immediately for an unsupported tag.""" + lt = _bare_lt() + with ( + patch.object(lt, "_get_languages", return_value=_LANGUAGES), + pytest.raises(ValueError, match="unsupported language"), + ): + lt.language = "zz-ZZ" + + def test_mother_tongue_getter_raises_for_unsupported_tag(self) -> None: + """Mother_tongue 
getter raises ValueError when the stored tag is unsupported.""" + lt = _bare_lt(_mother_tongue="zz-ZZ") + with ( + patch.object(lt, "_get_languages", return_value=_LANGUAGES), + pytest.raises(ValueError, match="unsupported language"), + ): + _ = lt.mother_tongue + def test_proxies_getter_returns_stored_value(self) -> None: """Proxies getter returns _proxies.""" lt = _bare_lt(_proxies={"http": "http://p/"}) @@ -548,6 +582,32 @@ def test_returns_adjusted_matches_when_pattern_matches(self) -> None: assert len(results) == 1 assert results[0].offset == expected_offset + def test_sorts_matches_from_multiple_regions_by_final_offset(self) -> None: + """Matches from a later region can sort before matches from an earlier one. + + The first region ("AAA") is mocked to return a match with a large local + offset, while the second region ("BBB") is mocked to return a match with a + small local offset. After offset adjustment, the second region's match ends + up earlier in the text than the first region's, so the insertion order + (region order) genuinely differs from the final sorted order, this + exercises the ``sorted(..., key=_match_offset)`` call, not just a no-op sort + on an already-ordered single-region result. 
+ """ + lt = _bare_lt() + + def _check_side_effect(region_text: str) -> list[_MockMatchWithOffset]: + if region_text == "AAA": + return [_MockMatchWithOffset(offset=10)] + return [_MockMatchWithOffset(offset=1)] + + with patch.object(lt, "check", side_effect=_check_side_effect): + results = lt.check_matching_regions("AAA BBB", r"\w+") + + # region "AAA" (start_offset=0) -> final offset 0+10=10 + # region "BBB" (start_offset=4) -> final offset 4+1=5 + # insertion order is [10, 5]; sorted order must be [5, 10] + assert [match.offset for match in results] == [5, 10] + class TestCreateParams: """Tests for _create_params() optional parameter branches.""" @@ -818,6 +878,35 @@ def test_restarts_local_server_before_raising_on_oserror(self) -> None: mock_term.assert_called_once() mock_start.assert_called_once() + def test_retries_and_succeeds_on_second_attempt(self) -> None: + """A transient OSError on the first attempt is retried and can succeed.""" + + class _FlakyThenOkSession(requests.Session): + def __init__(self, response: requests.Response) -> None: + super().__init__() + self._response = response + self.call_count = 0 + + def get( # type: ignore[override] + self, + _url: str | bytes, + **_kw: object, + ) -> requests.Response: + self.call_count += 1 + if self.call_count == 1: + err = "transient failure" + raise OSError(err) + return self._response + + expected_call_count = 2 + response = _make_json_response(b'{"ok": true}') + lt = _bare_lt(_remote=True) + session = _FlakyThenOkSession(response) + lt._session = session + result = lt._query_server("http://fake/", num_tries=2) + assert result == {"ok": True} + assert session.call_count == expected_call_count + class TestStartServerOnFreePort: """Tests for _start_server_on_free_port() port-retry logic.""" From fa081182c743def3676f3400ebbd6d9f2002c3b0 Mon Sep 17 00:00:00 2001 From: mdevolde Date: Thu, 2 Jul 2026 14:15:52 +0300 Subject: [PATCH 6/7] test: clean old docstring, add and precise some tests --- 
tests/property/test_prop_safe_zip.py | 69 ++++++++++++++++++++++------ tests/unit/test_download_unit.py | 7 +-- tests/unit/test_internals_utils.py | 15 ++++++ tests/unit/test_safe_zip.py | 57 ++++++++++++++++++++--- 4 files changed, 122 insertions(+), 26 deletions(-) diff --git a/tests/property/test_prop_safe_zip.py b/tests/property/test_prop_safe_zip.py index 371a0a9..2f1dffe 100644 --- a/tests/property/test_prop_safe_zip.py +++ b/tests/property/test_prop_safe_zip.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING import pytest -from hypothesis import given, settings +from hypothesis import example, given, settings from hypothesis import strategies as st from language_tool_python._internals.safe_zip import SafeZipExtractor @@ -20,6 +20,19 @@ _TRAVERSAL_SEGMENTS = st.sampled_from([".."] * 3 + ["."]) _SEP = st.sampled_from(["/", "\\"]) +_RESERVED_LEAVES = st.sampled_from( + [ + "CON", + "NUL", + "PRN", + "AUX", + "COM1", + "LPT1", + "trailing-space ", + "trailing-dot.", + "file.txt:stream", + ], +) def _make_zip_payload(files: dict[str, bytes]) -> bytes: @@ -36,8 +49,10 @@ def adversarial_member_names(draw: DrawFn) -> str: """Generate adversarial ZIP member names built from unsafe path segments. Combines repeated ``..`` traversal segments, mixed separators, absolute - paths, Windows drive letters, and UNC paths, so the strategy is not limited - to a small fixed set of prefixes. + paths, Windows drive letters, UNC paths, traversal sandwiched between + safe-looking components, and Windows-reserved/ADS-style leaf names, so the + strategy is not limited to a small fixed set of prefixes or to traversal + sitting only at the start of the path. 
""" depth = draw(st.integers(min_value=1, max_value=4)) segs = [draw(_TRAVERSAL_SEGMENTS) for _ in range(depth)] @@ -52,26 +67,48 @@ def adversarial_member_names(draw: DrawFn) -> str: max_size=20, ), ) - style = draw(st.sampled_from(["prefix", "embedded", "absolute", "drive", "unc"])) + style = draw( + st.sampled_from( + ["prefix", "embedded", "nested", "absolute", "drive", "unc", "reserved"], + ), + ) if style == "prefix": - return sep.join([*segs, leaf]) - if style == "embedded": - return sep.join(["safe", *segs, leaf]) - if style == "absolute": - return sep + leaf - if style == "drive": - return draw(st.sampled_from("CDZ")) + ":" + sep + leaf - return "\\\\server\\share\\" + leaf + name = sep.join([*segs, leaf]) + elif style == "embedded": + name = sep.join(["safe", *segs, leaf]) + elif style == "nested": + # Traversal sandwiched between two otherwise-safe-looking components, + # e.g. "safe/../../also-safe/leaf" rather than only leading traversal. + name = sep.join(["safe", *segs, "also-safe", leaf]) + elif style == "absolute": + name = sep + leaf + elif style == "drive": + name = draw(st.sampled_from("CDZ")) + ":" + sep + leaf + elif style == "unc": + name = sep * 2 + "server" + sep + "share" + sep + leaf + else: + name = sep.join(["safe", draw(_RESERVED_LEAVES)]) + return name @given(filename=adversarial_member_names()) @settings(max_examples=300, deadline=None) +@example(filename="../../../etc/passwd") +@example(filename="..\\..\\..\\Windows\\System32\\evil.dll") +@example(filename="safe/../../../etc/passwd") +@example(filename="\\\\server\\share\\..\\..\\evil") +@example(filename="safe/CON") +@example(filename="safe/file.txt:stream") +@example(filename="safe/trailing-dot.") def test_prop_safe_zip_path_traversal_always_rejected(filename: str) -> None: """Any adversarial ZIP member name must be rejected by SafeZipExtractor. 
Checks that ``SafeZipExtractor`` raises ``PathError`` for a wide range of - unsafe filenames (traversal, absolute paths, drive letters, UNC paths) - rather than a small fixed set of hand-picked prefixes. + unsafe filenames (traversal anywhere in the path, absolute paths, drive + letters, UNC paths, Windows-reserved/ADS-style names) rather than a small + fixed set of hand-picked prefixes. A handful of canonical zip-slip payloads + are pinned via ``@example`` so they are always checked regardless of the + Hypothesis random seed. A fresh temporary directory is created per example instead of using a pytest fixture, since function-scoped fixtures are not reset between @@ -127,6 +164,10 @@ def test_prop_zip_target_always_inside_destination(member_path: PurePosixPath) - @given(filename=adversarial_member_names()) @settings(max_examples=300) +@example(filename="../../../etc/passwd") +@example(filename="safe/../../../etc/passwd") +@example(filename="safe/CON") +@example(filename="safe/file.txt:stream") def test_prop_normalize_member_path_always_rejects_or_stays_relative( filename: str, ) -> None: diff --git a/tests/unit/test_download_unit.py b/tests/unit/test_download_unit.py index b518a53..2950085 100644 --- a/tests/unit/test_download_unit.py +++ b/tests/unit/test_download_unit.py @@ -1,9 +1,4 @@ -"""Unit tests for download_lt.py helpers (no network, no Java required). - -Note: test_download.py calls importlib.reload(download_lt) which invalidates -static class imports. We access classes via the module object (updated in-place -by reload) to ensure isinstance checks work regardless of test ordering. 
-""" +"""Unit tests for download_lt.py helpers (no network, no Java required).""" from __future__ import annotations diff --git a/tests/unit/test_internals_utils.py b/tests/unit/test_internals_utils.py index de8cef1..28eae3d 100644 --- a/tests/unit/test_internals_utils.py +++ b/tests/unit/test_internals_utils.py @@ -2,12 +2,14 @@ from __future__ import annotations +import locale from typing import TYPE_CHECKING import psutil import pytest from language_tool_python._internals.utils import ( + FAILSAFE_LANGUAGE, get_env_float, get_env_int, get_language_tool_download_path, @@ -149,6 +151,19 @@ def test_returns_string(self) -> None: assert isinstance(result, str) assert len(result) > 0 + def test_falls_back_to_failsafe_when_no_locale_is_configured( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """FAILSAFE_LANGUAGE is returned when both locale lookups yield no language.""" + + def _no_locale(*_args: object) -> tuple[None, None]: + return (None, None) + + monkeypatch.setattr(locale, "getlocale", _no_locale) + monkeypatch.setattr(locale, "getdefaultlocale", _no_locale) + assert get_locale_language() == FAILSAFE_LANGUAGE + class _MockPsutilProcess: """Minimal stand-in for psutil.Process used in kill_process_force tests.""" diff --git a/tests/unit/test_safe_zip.py b/tests/unit/test_safe_zip.py index b6f35bd..e1cd99b 100644 --- a/tests/unit/test_safe_zip.py +++ b/tests/unit/test_safe_zip.py @@ -13,12 +13,8 @@ from language_tool_python._internals.safe_zip import SafeZipExtractor, SafeZipLimits from language_tool_python.exceptions import PathError -EXPECTED_MAX_ARCHIVE_BYTES = 11 -EXPECTED_MAX_EXTRACTED_BYTES = 22 -EXPECTED_MAX_MEMBERS = 33 -EXPECTED_MAX_MEMBER_EXTRACTED_BYTES = 44 -EXPECTED_MAX_MEMBER_COMPRESSION_RATIO = 55.5 -EXPECTED_MAX_TOTAL_COMPRESSION_RATIO = 66.5 +_EXPECTED_INT_ENV_OVERRIDE = 12345 +_EXPECTED_FLOAT_ENV_OVERRIDE = 12.5 def make_zip_payload(files: dict[str, bytes]) -> bytes: @@ -105,6 +101,55 @@ def 
test_safe_zip_limits_defaults_wired_to_module_constants() -> None: ) +@pytest.mark.parametrize( + "env_var", + [ + safe_zip.LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES_ENV_VAR, + safe_zip.LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES_ENV_VAR, + safe_zip.LTP_SAFE_ZIP_MAX_MEMBERS_ENV_VAR, + safe_zip.LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES_ENV_VAR, + ], +) +def test_safe_zip_int_limit_env_var_overrides_default( + monkeypatch: pytest.MonkeyPatch, + env_var: str, +) -> None: + """Test that each safe-zip integer limit's dedicated env var overrides it. + + Each ``DEFAULT_MAX_*`` constant in ``safe_zip`` is computed once at import + time as ``get_env_int(<ENV_VAR>, ...)``. Reloading the module to prove the + override is wired up would leak global state across tests (see + ``test_max_download_bytes_uses_env_override`` in ``test_download.py`` for the + same reasoning applied to the download size limit), so this instead calls + ``get_env_int`` directly with each limit's real environment-variable name, + which is exactly what happens at import time. + """ + monkeypatch.setenv(env_var, str(_EXPECTED_INT_ENV_OVERRIDE)) + assert utils.get_env_int(env_var, 1) == _EXPECTED_INT_ENV_OVERRIDE + + +@pytest.mark.parametrize( + "env_var", + [ + safe_zip.LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO_ENV_VAR, + safe_zip.LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO_ENV_VAR, + ], +) +def test_safe_zip_float_limit_env_var_overrides_default( + monkeypatch: pytest.MonkeyPatch, + env_var: str, +) -> None: + """Test that each safe-zip float ratio limit's dedicated env var overrides it. + + Same reasoning as ``test_safe_zip_int_limit_env_var_overrides_default``, but + for the two ``float``-typed compression-ratio limits. 
+ """ + monkeypatch.setenv(env_var, str(_EXPECTED_FLOAT_ENV_OVERRIDE)) + assert utils.get_env_float(env_var, 1.0) == pytest.approx( + _EXPECTED_FLOAT_ENV_OVERRIDE, + ) + + @pytest.mark.parametrize("configured", ["nan", "inf"]) def test_safe_zip_float_env_rejects_non_finite_values( monkeypatch: pytest.MonkeyPatch, From 644bfc84218de12c1d665abe23516e0daaec9165 Mon Sep 17 00:00:00 2001 From: mdevolde Date: Thu, 2 Jul 2026 14:52:39 +0300 Subject: [PATCH 7/7] test: make bench optional, add targets in makefiles for test types --- Makefile | 16 ++++++++++++++-- make.bat | 22 +++++++++++++++++++++- tests/benchmarks/conftest.py | 14 +++++++++++++- 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index c321966..724e2a6 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: default install format fix ruff-check mypy-check check test doc publish clean +.PHONY: default install format fix ruff-check mypy-check check benchmark-test integration-test property-test unit-test test doc publish clean UV := $(shell command -v uv 2>/dev/null || true) ifeq ($(UV),) @@ -6,7 +6,7 @@ $(warning uv not found. 
Install uv (curl -LsSf https://astral.sh/uv/install.sh | endif default: - @echo "Usage: make [install|format|fix|ruff-check|mypy-check|check|test|doc|publish|clean]" + @echo "Usage: make [install|format|fix|ruff-check|mypy-check|check|benchmark-test|integration-test|property-test|unit-test|test|doc|publish|clean]" @exit 1 install: @@ -29,6 +29,18 @@ check: make ruff-check make mypy-check +benchmark-test: + uv run --group tests --locked pytest -m perf + +integration-test: + uv run --group tests --locked pytest -m integration + +property-test: + uv run --group tests --locked pytest -m property + +unit-test: + uv run --group tests --locked pytest -m unit + test: uv run --group tests --locked pytest diff --git a/make.bat b/make.bat index 7ad5bde..fc49e02 100644 --- a/make.bat +++ b/make.bat @@ -13,12 +13,16 @@ if "%1"=="fix" goto fix if "%1"=="ruff-check" goto ruff-check if "%1"=="mypy-check" goto mypy-check if "%1"=="check" goto check +if "%1"=="benchmark-test" goto benchmark-test +if "%1"=="integration-test" goto integration-test +if "%1"=="property-test" goto property-test +if "%1"=="unit-test" goto unit-test if "%1"=="test" goto test if "%1"=="doc" goto doc if "%1"=="publish" goto publish if "%1"=="clean" goto clean -echo Usage: make.bat [install^|format^|fix^|ruff-check^|mypy-check^|check^|test^|doc^|publish^|clean] +echo Usage: make.bat [install^|format^|fix^|ruff-check^|mypy-check^|check^|benchmark-test^|integration-test^|property-test^|unit-test^|test^|doc^|publish^|clean] exit /b 1 :install @@ -50,6 +54,22 @@ if errorlevel 1 exit /b %errorlevel% call :mypy-check exit /b %errorlevel% +:benchmark-test +uv run --group tests --locked pytest -m perf +exit /b %errorlevel% + +:integration-test +uv run --group tests --locked pytest -m integration +exit /b %errorlevel% + +:property-test +uv run --group tests --locked pytest -m property +exit /b %errorlevel% + +:unit-test +uv run --group tests --locked pytest -m unit +exit /b %errorlevel% + :test uv run --group tests 
--locked pytest exit /b %errorlevel% diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index 6cebb40..cff22dd 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -8,10 +8,22 @@ def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item], ) -> None: - """Apply the 'perf' marker to all tests collected from this directory.""" + """Apply the 'perf' marker to all tests collected from this directory. + + Benchmarks require a running JVM and are slow, so they're opt-in: skipped + unless explicitly selected with ``-m perf`` (or a markexpr referencing it). + """ benchmarks_dir = Path(__file__).parent + markexpr: str = config.getoption("markexpr") + run_perf = "perf" in markexpr + skip_perf = pytest.mark.skip( + reason="Benchmarks are opt-in, run with -m perf to include them." + ) for item in items: if item.path.is_relative_to(benchmarks_dir): item.add_marker(pytest.mark.perf) + if not run_perf: + item.add_marker(skip_perf)