diff --git a/docs/reference/cli.md b/docs/reference/cli.md index b65f8256ac..ae0d6c675f 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -294,11 +294,15 @@ Usage: sqlmesh invalidate [OPTIONS] ENVIRONMENT of the janitor process. Options: - -s, --sync Wait for the environment to be deleted before returning. If not - specified, the environment will be deleted asynchronously by the - janitor process. This option requires a connection to the data - warehouse. - --help Show this message and exit. + -s, --sync Wait for the environment to be deleted before returning. + If not specified, the environment will be deleted + asynchronously by the janitor process. This option + requires a connection to the data warehouse. + --cleanup-snapshots + After invalidating, synchronously delete unreferenced + physical snapshot tables formerly referenced by this + environment. + --help Show this message and exit. ``` ## janitor @@ -313,14 +317,16 @@ Usage: sqlmesh janitor [OPTIONS] Options: --ignore-ttl Cleanup snapshots that are not referenced in any environment, - regardless of when they're set to expire. Has - no effect when --environment is specified. + regardless of when they're set to expire. + When --environment is specified, cleanup is scoped to + snapshots formerly referenced by that environment. --force-delete Delete expired environment and snapshot state records even when the physical table or view drops fail. Any objects that could not be dropped become orphaned and must be removed manually. -e, --environment TEXT - Scope cleanup to a single expired environment. Global - snapshot and interval compaction are skipped. + Scope cleanup to a single expired environment. With + --ignore-ttl, snapshot cleanup is scoped to snapshots + formerly referenced by this environment. --help Show this message and exit. 
``` diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index b3c7a7027b..95969a0f23 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -620,6 +620,11 @@ def run(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any is_flag=True, help="Wait for the environment to be deleted before returning. If not specified, the environment will be deleted asynchronously by the janitor process. This option requires a connection to the data warehouse.", ) +@click.option( + "--cleanup-snapshots", + is_flag=True, + help="After invalidating, synchronously delete unreferenced physical snapshot tables formerly referenced by this environment.", +) @click.pass_context @error_handler @cli_analytics @@ -633,7 +638,7 @@ def invalidate(ctx: click.Context, environment: str, **kwargs: t.Any) -> None: @click.option( "--ignore-ttl", is_flag=True, - help="Cleanup snapshots that are not referenced in any environment, regardless of when they're set to expire. Has no effect when --environment is specified.", + help="Cleanup snapshots that are not referenced in any environment, regardless of when they're set to expire. When --environment is specified, cleanup is scoped to snapshots formerly referenced by that environment.", ) @click.option( "--force-delete", @@ -645,7 +650,7 @@ def invalidate(ctx: click.Context, environment: str, **kwargs: t.Any) -> None: "--environment", "-e", default=None, - help="Scope cleanup to a single expired environment. Global snapshot and interval compaction are skipped.", + help="Scope cleanup to a single expired environment. 
With --ignore-ttl, snapshot cleanup is scoped to snapshots formerly referenced by this environment.", ) @click.pass_context @error_handler diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 5902977331..21f4281c78 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -99,6 +99,7 @@ Snapshot, SnapshotEvaluator, SnapshotFingerprint, + SnapshotId, missing_intervals, to_table_mapping, ) @@ -108,7 +109,10 @@ StateReader, StateSync, ) -from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots +from sqlmesh.core.janitor import ( + cleanup_expired_views, + delete_expired_snapshots, +) from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, @@ -1847,17 +1851,24 @@ def apply( ) @python_api_analytics - def invalidate_environment(self, name: str, sync: bool = False) -> None: + def invalidate_environment( + self, name: str, sync: bool = False, cleanup_snapshots: bool = False + ) -> None: """Invalidates the target environment by setting its expiration timestamp to now. Args: name: The name of the environment to invalidate. sync: If True, the call blocks until the environment is deleted. Otherwise, the environment will be deleted asynchronously by the janitor process. + cleanup_snapshots: If True, immediately deletes unreferenced physical snapshot tables that were + formerly referenced by this environment. Cleanup runs synchronously regardless of sync. 
""" name = Environment.sanitize_name(name) self.state_sync.invalidate_environment(name) - if sync: + if cleanup_snapshots: + self._run_janitor(ignore_ttl=True, environment=name) + self.console.log_success(f"Environment '{name}' deleted.") + elif sync: self._cleanup_environments(name=name) self.console.log_success(f"Environment '{name}' deleted.") else: @@ -2993,6 +3004,16 @@ def _run_janitor( current_ts = now_timestamp() failures: t.List[str] = [] + target_snapshot_ids: t.Set[SnapshotId] = set() + if environment is not None and ignore_ttl: + expired_environments = self.state_sync.get_expired_environments( + current_ts=current_ts, name=environment + ) + if expired_environments: + expired_env = self.state_reader.get_environment(expired_environments[0].name) + if expired_env: + target_snapshot_ids = {s.snapshot_id for s in expired_env.snapshots} + # Clean up expired environments by removing their views and schemas failures.extend( self._cleanup_environments( @@ -3013,6 +3034,27 @@ def _run_janitor( ) ) self.state_sync.compact_intervals() + elif ( + ignore_ttl + and target_snapshot_ids + and not self.state_reader.get_environment(environment) + ): + self.console.log_warning( + "Scoped snapshot cleanup will permanently delete unreferenced physical snapshot " + f"tables formerly referenced by environment '{environment}'." 
+ ) + failures.extend( + delete_expired_snapshots( + self.state_sync, + self.snapshot_evaluator, + current_ts=current_ts, + ignore_ttl=ignore_ttl, + force_delete=force_delete, + console=self.console, + batch_size=self.config.janitor.expired_snapshots_batch_size, + target_snapshot_ids=target_snapshot_ids, + ) + ) if failures: failure_string = "\n - ".join(failures) diff --git a/sqlmesh/core/janitor.py b/sqlmesh/core/janitor.py index 92d889e276..11950ca13f 100644 --- a/sqlmesh/core/janitor.py +++ b/sqlmesh/core/janitor.py @@ -8,7 +8,7 @@ from sqlmesh.core.console import Console from sqlmesh.core.dialect import schema_ from sqlmesh.core.environment import Environment -from sqlmesh.core.snapshot import SnapshotEvaluator +from sqlmesh.core.snapshot import SnapshotEvaluator, SnapshotIdLike from sqlmesh.core.state_sync import StateSync from sqlmesh.core.state_sync.common import ( logger, @@ -127,6 +127,7 @@ def delete_expired_snapshots( ignore_ttl: bool = False, force_delete: bool = False, batch_size: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, console: t.Optional[Console] = None, ) -> t.List[str]: """Delete all expired snapshots in batches. 
@@ -153,6 +154,7 @@ def delete_expired_snapshots( current_ts=current_ts, ignore_ttl=ignore_ttl, batch_size=batch_size, + target_snapshot_ids=target_snapshot_ids, ): end_info = ( f"updated_ts={batch.batch_range.end.updated_ts}" @@ -184,6 +186,7 @@ def delete_expired_snapshots( end=batch.batch_range.end, ), ignore_ttl=ignore_ttl, + target_snapshot_ids=target_snapshot_ids, ) logger.info("Cleaned up expired snapshots batch") num_expired_snapshots += len(batch.expired_snapshot_ids) diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 5c35be5ccb..6f5023304f 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -308,6 +308,7 @@ def get_expired_snapshots( batch_range: ExpiredBatchRange, current_ts: t.Optional[int] = None, ignore_ttl: bool = False, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Optional[ExpiredSnapshotBatch]: """Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier). @@ -315,6 +316,8 @@ def get_expired_snapshots( current_ts: Timestamp used to evaluate expiration. ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). batch_range: The range of the batch to fetch. + target_snapshot_ids: If provided, only consider snapshots with these IDs. Useful for + scoped cleanup after environment invalidation. Returns: A batch describing expired snapshots or None if no snapshots are pending cleanup. @@ -368,6 +371,7 @@ def delete_expired_snapshots( batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> None: """Removes expired snapshots. @@ -379,6 +383,8 @@ def delete_expired_snapshots( ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment current_ts: Timestamp used to evaluate expiration. 
+ target_snapshot_ids: If provided, only delete snapshots with these IDs. Useful for + scoped cleanup after environment invalidation. """ @abc.abstractmethod diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index 77f3fc6ba5..edf74e03f9 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -113,12 +113,14 @@ def delete_expired_snapshots( batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> None: self.snapshot_cache.clear() self.state_sync.delete_expired_snapshots( batch_range=batch_range, ignore_ttl=ignore_ttl, current_ts=current_ts, + target_snapshot_ids=target_snapshot_ids, ) def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index 6308c0c29d..199e4daf16 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -16,6 +16,7 @@ from sqlmesh.core.snapshot import ( Snapshot, SnapshotId, + SnapshotIdLike, SnapshotTableCleanupTask, SnapshotTableInfo, ) @@ -288,6 +289,7 @@ def iter_expired_snapshot_batches( current_ts: int, ignore_ttl: bool = False, batch_size: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Iterator[ExpiredSnapshotBatch]: """Yields expired snapshot batches. 
@@ -306,6 +308,7 @@ def iter_expired_snapshot_batches( current_ts=current_ts, ignore_ttl=ignore_ttl, batch_range=batch_range, + target_snapshot_ids=target_snapshot_ids, ) if batch is None: diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 572e54b7f1..8fb732e17c 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -267,6 +267,7 @@ def get_expired_snapshots( batch_range: ExpiredBatchRange, current_ts: t.Optional[int] = None, ignore_ttl: bool = False, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Optional[ExpiredSnapshotBatch]: current_ts = current_ts or now_timestamp() return self.snapshot_state.get_expired_snapshots( @@ -274,6 +275,7 @@ def get_expired_snapshots( current_ts=current_ts, ignore_ttl=ignore_ttl, batch_range=batch_range, + target_snapshot_ids=target_snapshot_ids, ) def get_expired_environments( @@ -287,11 +289,13 @@ def delete_expired_snapshots( batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> None: batch = self.get_expired_snapshots( ignore_ttl=ignore_ttl, current_ts=current_ts, batch_range=batch_range, + target_snapshot_ids=target_snapshot_ids, ) if batch and batch.expired_snapshot_ids: self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 9b4337b504..287a69013b 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -170,6 +170,7 @@ def get_expired_snapshots( current_ts: int, ignore_ttl: bool, batch_range: ExpiredBatchRange, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Optional[ExpiredSnapshotBatch]: expired_query = exp.select("name", "identifier", "version", "updated_ts").from_( self.snapshots_table @@ -180,6 +181,16 @@ def 
get_expired_snapshots( (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts ) + if target_snapshot_ids is not None: + target_conditions = list( + snapshot_id_filter( + self.engine_adapter, + target_snapshot_ids, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + ) + expired_query = expired_query.where(exp.or_(*target_conditions)) + expired_query = expired_query.where(batch_range.where_filter) promoted_snapshot_ids = { diff --git a/tests/core/integration/test_aux_commands.py b/tests/core/integration/test_aux_commands.py index 7de585576d..1624b94af1 100644 --- a/tests/core/integration/test_aux_commands.py +++ b/tests/core/integration/test_aux_commands.py @@ -481,6 +481,124 @@ def test_invalidating_environment(sushi_context: Context): assert start_schemas - schemas_after_janitor == {"sushi__dev"} +def test_invalidate_environment_cleanup_snapshots_scoped(tmp_path: Path): + """Test that --cleanup-snapshots only deletes unreferenced snapshots from the invalidated env.""" + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col") + (models_dir / "model2.sql").write_text("MODEL(name test.model2, kind FULL); SELECT 2 AS col") + + ctx = Context( + paths=[tmp_path], + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + ) + + # Apply both models to prod and dev. + ctx.plan("prod", no_prompts=True, auto_apply=True) + ctx.plan("dev", no_prompts=True, auto_apply=True, include_unmodified=True) + + prod_env = ctx.state_sync.get_environment("prod") + dev_env = ctx.state_sync.get_environment("dev") + assert prod_env is not None + assert dev_env is not None + + prod_snapshot_ids = {s.snapshot_id for s in prod_env.snapshots} + dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots} + + # In a virtual environment, dev shares snapshots with prod. + # Shared snapshots must NOT be deleted when invalidating dev with --cleanup-snapshots. 
+ shared_snapshot_ids = prod_snapshot_ids & dev_snapshot_ids + + ctx.invalidate_environment("dev", cleanup_snapshots=True) + + # The dev environment record should be gone. + assert ctx.state_sync.get_environment("dev") is None + + # Shared snapshots (also in prod) must still exist. + remaining_snapshots = ctx.state_sync.get_snapshots(list(shared_snapshot_ids)) + assert set(remaining_snapshots.keys()) == shared_snapshot_ids + + # Prod environment should be unaffected. + assert ctx.state_sync.get_environment("prod") is not None + + +def test_invalidate_environment_cleanup_snapshots_warns_and_drops_physical_tables( + tmp_path: Path, mocker: MockerFixture +): + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col") + + ctx = Context( + paths=[tmp_path], + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + ) + + ctx.plan("dev", no_prompts=True, auto_apply=True) + snapshot = ctx.get_snapshot("test.model1") + assert snapshot is not None + snapshot_ids = [snapshot.snapshot_id] + physical_table_names = [ + snapshot.table_name(is_deployable=False), + snapshot.table_name(is_deployable=True), + ] + assert any(ctx.engine_adapter.table_exists(table_name) for table_name in physical_table_names) + + warning_mock = mocker.patch.object(ctx.console, "log_warning") + + ctx.invalidate_environment("dev", cleanup_snapshots=True) + + warning_mock.assert_any_call( + "Scoped snapshot cleanup will permanently delete unreferenced physical snapshot tables " + "formerly referenced by environment 'dev'." 
+ ) + assert ctx.state_sync.get_environment("dev") is None + assert not ctx.state_sync.get_snapshots(snapshot_ids) + assert not any( + ctx.engine_adapter.table_exists(table_name) for table_name in physical_table_names + ) + + +def test_janitor_environment_ignore_ttl_cleans_only_scoped_snapshots( + tmp_path: Path, mocker: MockerFixture +): + models_dir = tmp_path / "models" + models_dir.mkdir() + model_path = models_dir / "model1.sql" + model_path.write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col") + + ctx = Context( + paths=[tmp_path], + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + ) + + ctx.plan("dev_a", no_prompts=True, auto_apply=True) + dev_a_snapshot = ctx.get_snapshot("test.model1") + assert dev_a_snapshot is not None + + model_path.write_text("MODEL(name test.model1, kind FULL); SELECT 2 AS col") + ctx.load() + ctx.plan("dev_b", no_prompts=True, auto_apply=True) + dev_b_snapshot = ctx.get_snapshot("test.model1") + assert dev_b_snapshot is not None + assert dev_a_snapshot.snapshot_id != dev_b_snapshot.snapshot_id + + ctx.invalidate_environment("dev_a") + ctx.invalidate_environment("dev_b") + warning_mock = mocker.patch.object(ctx.console, "log_warning") + + ctx.run_janitor(ignore_ttl=True, environment="dev_a") + + warning_mock.assert_any_call( + "Scoped snapshot cleanup will permanently delete unreferenced physical snapshot tables " + "formerly referenced by environment 'dev_a'." 
+ ) + assert ctx.state_sync.get_environment("dev_a") is None + assert ctx.state_sync.get_environment("dev_b") is not None + assert not ctx.state_sync.get_snapshots([dev_a_snapshot.snapshot_id]) + assert ctx.state_sync.get_snapshots([dev_b_snapshot.snapshot_id]) + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_evaluate_uncategorized_snapshot(init_and_plan_context: t.Callable): context, plan = init_and_plan_context("examples/sushi") diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 348a883fd5..cc1f58bf66 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -59,10 +59,12 @@ def _get_cleanup_tasks( *, limit: int = 1000, ignore_ttl: bool = False, + target_snapshot_ids: t.Optional[t.Collection[SnapshotId]] = None, ) -> t.List[SnapshotTableCleanupTask]: batch = state_sync.get_expired_snapshots( ignore_ttl=ignore_ttl, batch_range=ExpiredBatchRange.init_batch_range(batch_size=limit), + target_snapshot_ids=target_snapshot_ids, ) return [] if batch is None else batch.cleanup_tasks @@ -1351,6 +1353,105 @@ def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snaps assert not state_sync.get_snapshots(all_snapshots) +def test_delete_expired_snapshots_scoped_to_target_ids( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +): + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + 
state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-03") + state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-03") + + assert _get_cleanup_tasks( + state_sync, + ignore_ttl=True, + target_snapshot_ids=[snapshot_a.snapshot_id], + ) == [SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False)] + + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange.all_batch_range(), + ignore_ttl=True, + target_snapshot_ids=[snapshot_a.snapshot_id], + ) + + assert not state_sync.get_snapshots([snapshot_a]) + assert state_sync.get_snapshots([snapshot_b]) + assert not get_snapshot_intervals(snapshot_a) + assert get_snapshot_intervals(snapshot_b) + + +def test_get_expired_snapshots_scoped_excludes_referenced_snapshots( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + prod_env = Environment( + name="prod", + snapshots=[snapshot_b.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(prod_env) + state_sync.finalize(prod_env) + + batch = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + target_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + assert batch is not None + assert batch.expired_snapshot_ids == {snapshot_a.snapshot_id} + assert batch.cleanup_tasks == [ + 
SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False) + ] + + def test_get_expired_snapshot_batch(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): now_ts = now_timestamp() @@ -4220,3 +4321,144 @@ def test_state_version_is_too_old( match="The current state belongs to an old version of SQLMesh that is no longer supported. Please upgrade to 0.134.0 first before upgrading to.*", ): state_sync.migrate(skip_backup=True) + + +def test_get_expired_snapshots_scoped_to_target_ids( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + """Test that get_expired_snapshots with target_snapshot_ids only returns snapshots in the target set.""" + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + # Both snapshots are expired (no active environments). + # When scoped to only snapshot_a, only snapshot_a should be returned. + batch = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + target_snapshot_ids=[snapshot_a.snapshot_id], + ) + assert batch is not None + assert batch.expired_snapshot_ids == {snapshot_a.snapshot_id} + assert [t.snapshot.name for t in batch.cleanup_tasks] == [snapshot_a.name] + + # snapshot_b should still exist because it was not in the target set. 
+ batch_all = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + ) + assert batch_all is not None + assert snapshot_b.snapshot_id in batch_all.expired_snapshot_ids + + +def test_get_expired_snapshots_scoped_excludes_shared_snapshots( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + """Test that scoped cleanup respects protection: snapshots shared with other environments are not deleted.""" + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + # Promote snapshot_b to another active environment (prod-like). + prod_env = Environment( + name="prod", + snapshots=[snapshot_b.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(prod_env) + state_sync.finalize(prod_env) + + # Even though snapshot_b is in the target set, it should NOT be returned + # because it is still referenced by prod_env. + batch = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + target_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + assert batch is not None + # Only snapshot_a is exclusively owned (not referenced by any active environment). 
+ assert batch.expired_snapshot_ids == {snapshot_a.snapshot_id} + assert [t.snapshot.name for t in batch.cleanup_tasks] == [snapshot_a.name] + + +def test_delete_expired_snapshots_scoped( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + """Test that delete_expired_snapshots with target_snapshot_ids only deletes scoped snapshots.""" + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + # Delete only snapshot_a via scoped cleanup. + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange.all_batch_range(), + ignore_ttl=True, + target_snapshot_ids=[snapshot_a.snapshot_id], + ) + + # snapshot_a should be deleted, snapshot_b should remain. + assert not state_sync.get_snapshots([snapshot_a]) + assert state_sync.get_snapshots([snapshot_b])