"""Burn-in lifecycle tests covering the DB helpers in burnin._common, plus the public surface of start_job + cancel_job that doesn't require spinning up _run_job (which would need a mocked TrueNASClient + SSH). These are the safety net Codex flagged was missing — the orchestration paths were entirely untested before this. Run inside the container image so app deps (aiosqlite, pydantic-settings, bcrypt) are present. """ from __future__ import annotations import os import tempfile import unittest import aiosqlite async def _setup_temp_db() -> str: """Same pattern as test_unlock_flow.py — temp DB + init_db, returning the path. Caller must unlink in tearDown.""" fd, path = tempfile.mkstemp(suffix=".db") os.close(fd) from app.config import settings settings.db_path = path from app.database import init_db await init_db() # Seed two drives so start_job has something to attach to. async with aiosqlite.connect(path) as db: await db.execute(""" INSERT INTO drives (truenas_disk_id, devname, serial, model, size_bytes, temperature_c, smart_health, last_seen_at, last_polled_at) VALUES ('id-1', 'sda', 'SER1', 'TestModel', 1000, 30, 'PASSED', '2026-05-03T00:00:00+00:00', '2026-05-03T00:00:00+00:00') """) await db.execute(""" INSERT INTO drives (truenas_disk_id, devname, serial, model, size_bytes, temperature_c, smart_health, last_seen_at, last_polled_at) VALUES ('id-2', 'sdb', 'SER2', 'TestModel', 1000, 30, 'PASSED', '2026-05-03T00:00:00+00:00', '2026-05-03T00:00:00+00:00') """) await db.commit() return path class TestCommonHelpers(unittest.IsolatedAsyncioTestCase): """The per-stage DB mutators in app.burnin._common — pure SQLite writes, no asyncssh, no orchestration. Trivially regression-testable.""" async def asyncSetUp(self): self.db_path = await _setup_temp_db() # Insert a queued job + 2 stages we can mutate. async with aiosqlite.connect(self.db_path) as db: cur = await db.execute( """INSERT INTO burnin_jobs (drive_id, profile, state, percent, operator, created_at) VALUES (?,?,?,?,?,?) RETURNING id""", (1, "full", "running", 0, "test", "2026-05-03T00:00:00+00:00"), ) self.job_id = (await cur.fetchone())[0] for stage_name in ("precheck", "surface_validate", "final_check"): await db.execute( "INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)", (self.job_id, stage_name, "pending"), ) await db.commit() async def asyncTearDown(self): try: os.unlink(self.db_path) except OSError: pass async def test_start_stage_marks_running(self): from app.burnin import _common await _common._start_stage(self.job_id, "precheck") async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row cur = await db.execute( "SELECT state, started_at FROM burnin_stages " "WHERE burnin_job_id=? AND stage_name='precheck'", (self.job_id,), ) row = await cur.fetchone() self.assertEqual(row["state"], "running") self.assertIsNotNone(row["started_at"]) async def test_finish_stage_success_records_duration(self): from app.burnin import _common await _common._start_stage(self.job_id, "precheck") await _common._finish_stage(self.job_id, "precheck", success=True) async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row cur = await db.execute( "SELECT state, percent, duration_seconds FROM burnin_stages " "WHERE burnin_job_id=? AND stage_name='precheck'", (self.job_id,), ) row = await cur.fetchone() self.assertEqual(row["state"], "passed") self.assertEqual(row["percent"], 100) # Duration is float seconds since started_at — should be tiny but >0. self.assertIsNotNone(row["duration_seconds"]) self.assertGreaterEqual(row["duration_seconds"], 0) async def test_finish_stage_failure_carries_error_text(self): from app.burnin import _common await _common._start_stage(self.job_id, "surface_validate") await _common._finish_stage( self.job_id, "surface_validate", success=False, error_text="mock failure", ) async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row cur = await db.execute( "SELECT state, percent, error_text FROM burnin_stages " "WHERE burnin_job_id=? AND stage_name='surface_validate'", (self.job_id,), ) row = await cur.fetchone() self.assertEqual(row["state"], "failed") self.assertIsNone(row["percent"]) self.assertEqual(row["error_text"], "mock failure") async def test_finish_stage_preserves_existing_error(self): """When called with error_text=None, the existing column value from _set_stage_error must be preserved (not overwritten with NULL). This is the bug that 1.0.0-12-ish fixed.""" from app.burnin import _common await _common._start_stage(self.job_id, "surface_validate") await _common._set_stage_error( self.job_id, "surface_validate", "set by stage", ) await _common._finish_stage( self.job_id, "surface_validate", success=False, error_text=None, ) async with aiosqlite.connect(self.db_path) as db: cur = await db.execute( "SELECT error_text FROM burnin_stages " "WHERE burnin_job_id=? AND stage_name='surface_validate'", (self.job_id,), ) row = await cur.fetchone() self.assertEqual(row[0], "set by stage") async def test_recalculate_progress_weights_correctly(self): from app.burnin import _common # Mark precheck passed, surface_validate at 50% running. await _common._start_stage(self.job_id, "precheck") await _common._finish_stage(self.job_id, "precheck", success=True) await _common._start_stage(self.job_id, "surface_validate") await _common._update_stage_percent(self.job_id, "surface_validate", 50) await _common._recalculate_progress(self.job_id) async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row cur = await db.execute( "SELECT percent, stage_name FROM burnin_jobs WHERE id=?", (self.job_id,), ) row = await cur.fetchone() # Weights: precheck=5, surface=65, final=5. Total = 75 across these # 3 stages. Completed = 5 (precheck) + 32.5 (half of 65) = 37.5. # 37.5 / 75 = 50%. self.assertEqual(row["percent"], 50) self.assertEqual(row["stage_name"], "surface_validate") async def test_is_cancelled_reads_job_state(self): from app.burnin import _common self.assertFalse(await _common._is_cancelled(self.job_id)) async with aiosqlite.connect(self.db_path) as db: await db.execute( "UPDATE burnin_jobs SET state='cancelled' WHERE id=?", (self.job_id,), ) await db.commit() self.assertTrue(await _common._is_cancelled(self.job_id)) async def test_append_stage_log_concatenates(self): from app.burnin import _common await _common._append_stage_log(self.job_id, "precheck", "alpha\n") await _common._append_stage_log(self.job_id, "precheck", "beta\n") async with aiosqlite.connect(self.db_path) as db: cur = await db.execute( "SELECT log_text FROM burnin_stages " "WHERE burnin_job_id=? AND stage_name='precheck'", (self.job_id,), ) row = await cur.fetchone() self.assertEqual(row[0], "alpha\nbeta\n") class TestStartCancelJob(unittest.IsolatedAsyncioTestCase): """start_job + cancel_job touch the burnin orchestrator state. We spawn _run_job tasks that try to acquire the semaphore — we cancel immediately after to avoid running real burn-in stages. The real test value here is "did start_job create the right DB rows" and "does cancel_job mark them correctly.""" async def asyncSetUp(self): self.db_path = await _setup_temp_db() # Initialise burnin without a real TrueNASClient — pass None. # _run_job will hit the assert at top, but the test cancels # before _run_job's first await actually runs. from app import burnin burnin._unlock_grants.clear() burnin._active_tasks.clear() import asyncio burnin._semaphore = asyncio.Semaphore(2) burnin._client = None # unused by start_job itself async def asyncTearDown(self): # Cancel any outstanding tasks so they don't bleed into later tests. from app import burnin for t in list(burnin._active_tasks.values()): t.cancel() try: os.unlink(self.db_path) except OSError: pass async def test_start_job_inserts_queued_row_and_stages(self): from app import burnin job_id = await burnin.start_job(1, "surface", "test") async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row cur = await db.execute( "SELECT state, profile, operator FROM burnin_jobs WHERE id=?", (job_id,), ) row = await cur.fetchone() cur = await db.execute( "SELECT stage_name FROM burnin_stages " "WHERE burnin_job_id=? ORDER BY id", (job_id,), ) stages = [r[0] for r in await cur.fetchall()] # State should be queued OR running (the spawned _run_job may # have raced into the semaphore by now). self.assertIn(row["state"], ("queued", "running")) self.assertEqual(row["profile"], "surface") self.assertEqual(row["operator"], "test") # surface profile = precheck + surface_validate + final_check. self.assertEqual(stages, ["precheck", "surface_validate", "final_check"]) async def test_start_job_rejects_duplicate_active(self): from app import burnin await burnin.start_job(1, "surface", "test") # Second start on the same drive should be refused at the # ValueError level (caught by the inline duplicate check or by # the partial unique index). with self.assertRaises(ValueError): await burnin.start_job(1, "surface", "test") async def test_cancel_job_marks_state(self): from app import burnin job_id = await burnin.start_job(1, "surface", "test") ok = await burnin.cancel_job(job_id, "test") self.assertTrue(ok) async with aiosqlite.connect(self.db_path) as db: cur = await db.execute( "SELECT state FROM burnin_jobs WHERE id=?", (job_id,) ) row = await cur.fetchone() self.assertEqual(row[0], "cancelled") async def test_cancel_job_returns_false_for_terminal_state(self): from app import burnin # Create a passed job manually async with aiosqlite.connect(self.db_path) as db: cur = await db.execute( """INSERT INTO burnin_jobs (drive_id, profile, state, operator, created_at) VALUES (?,?,?,?,?) RETURNING id""", (2, "surface", "passed", "x", "2026-05-03T00:00:00+00:00"), ) job_id = (await cur.fetchone())[0] await db.commit() ok = await burnin.cancel_job(job_id, "test") self.assertFalse(ok) class TestRateLimiter(unittest.TestCase): """The generic rate-limit class added in 1.0.0-33 for the unlock + password-change endpoints.""" def test_register_allows_under_threshold(self): from app.auth import _RateLimiter rl = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60) self.assertEqual(rl.register(("k", "alice")), "ok") self.assertEqual(rl.register(("k", "alice")), "ok") def test_register_trips_at_threshold(self): from app.auth import _RateLimiter rl = _RateLimiter("test", threshold=3, window_s=60, lockout_s=60) self.assertEqual(rl.register(("k", "alice")), "ok") self.assertEqual(rl.register(("k", "alice")), "ok") # 3rd attempt brings us to threshold — counts as the trip. self.assertEqual(rl.register(("k", "alice")), "now_locked_out") # 4th sees the lockout from the prior call. self.assertEqual(rl.register(("k", "alice")), "locked_out") def test_clear_removes_counter_and_lockout(self): from app.auth import _RateLimiter rl = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60) rl.register(("k", "alice")) rl.register(("k", "alice")) # trips self.assertIsNotNone(rl.locked_until(("k", "alice"))) rl.clear(("k", "alice")) self.assertIsNone(rl.locked_until(("k", "alice"))) # Subsequent register should start fresh. self.assertEqual(rl.register(("k", "alice")), "ok") def test_separate_keys_dont_interfere(self): from app.auth import _RateLimiter rl = _RateLimiter("test", threshold=2, window_s=60, lockout_s=60) rl.register(("k", "alice")) rl.register(("k", "alice")) # trips alice # Bob's attempt should be allowed and untouched by alice's lockout. self.assertEqual(rl.register(("k", "bob")), "ok") self.assertIsNone(rl.locked_until(("k", "bob"))) if __name__ == "__main__": unittest.main()