"""Shared helpers for the burnin package. Lives below stages.py / task.py / __init__.py — these all import from here. _common itself imports nothing from sibling burnin modules so we stay free of circular-import landmines. Owns: * Stage configuration constants (STAGE_ORDER, _STAGE_BASE_WEIGHTS, POLL_INTERVAL). * The connection-helper context manager `_db()` and the `_now()` ISO timestamp helper used everywhere. * Per-stage DB mutators called by stage implementations and by the job orchestrator (`_start_stage`, `_finish_stage`, `_cancel_stage`, `_set_stage_error`, `_update_stage_percent`, `_update_stage_bad_blocks`, `_append_stage_log`). * Drive-row mutators for SMART caches (`_store_smart_attrs`, `_store_smart_raw_output`). * The job-state read (`_is_cancelled`) + progress aggregator (`_recalculate_progress`). * SSE notifier (`_push_update`). """ from __future__ import annotations import json import logging from contextlib import asynccontextmanager from datetime import datetime, timezone import aiosqlite from app.config import settings log = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Stage configuration # --------------------------------------------------------------------------- STAGE_ORDER: dict[str, list[str]] = { # Legacy "quick": ["precheck", "short_smart", "io_validate", "final_check"], # Single-stage selectable profiles "surface": ["precheck", "surface_validate", "final_check"], "short": ["precheck", "short_smart", "final_check"], "long": ["precheck", "long_smart", "final_check"], # Two-stage combos "surface_short": ["precheck", "surface_validate", "short_smart", "final_check"], "surface_long": ["precheck", "surface_validate", "long_smart", "final_check"], "short_long": ["precheck", "short_smart", "long_smart", "final_check"], # All three "full": ["precheck", "surface_validate", "short_smart", "long_smart", "final_check"], } # Per-stage base weights used to compute overall job % progress dynamically _STAGE_BASE_WEIGHTS: dict[str, int] = { "precheck": 5, "surface_validate": 65, "short_smart": 12, "long_smart": 13, "io_validate": 10, "final_check": 5, } POLL_INTERVAL = 5.0 # seconds between progress checks during active stages # --------------------------------------------------------------------------- # Connection helpers # --------------------------------------------------------------------------- def _now() -> str: return datetime.now(timezone.utc).isoformat() @asynccontextmanager async def _db(): """Open a WAL-mode connection with busy_timeout so writers wait for the lock instead of immediately raising 'database is locked' under contention.""" async with aiosqlite.connect(settings.db_path) as db: await db.execute("PRAGMA busy_timeout=10000") yield db # --------------------------------------------------------------------------- # Job / stage DB mutators # --------------------------------------------------------------------------- async def _is_cancelled(job_id: int) -> bool: async with _db() as db: cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,)) row = await cur.fetchone() return bool(row and row[0] == "cancelled") async def _start_stage(job_id: int, stage_name: str) -> None: async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( "UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?", (_now(), job_id, stage_name), ) await db.execute( "UPDATE burnin_jobs SET stage_name=? WHERE id=?", (stage_name, job_id), ) await db.commit() async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None: now = _now() state = "passed" if success else "failed" async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") cur = await db.execute( "SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?", (job_id, stage_name), ) row = await cur.fetchone() duration = None if row and row[0]: try: start = datetime.fromisoformat(row[0]) if start.tzinfo is None: start = start.replace(tzinfo=timezone.utc) duration = (datetime.now(timezone.utc) - start).total_seconds() except Exception: pass # Only overwrite error_text if one is passed; otherwise preserve what the stage already wrote if error_text is not None: await db.execute( """UPDATE burnin_stages SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=? WHERE burnin_job_id=? AND stage_name=?""", (state, 100 if success else None, now, duration, error_text, job_id, stage_name), ) else: await db.execute( """UPDATE burnin_stages SET state=?, percent=?, finished_at=?, duration_seconds=? WHERE burnin_job_id=? AND stage_name=?""", (state, 100 if success else None, now, duration, job_id, stage_name), ) await db.commit() async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None: async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( "UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?", (pct, job_id, stage_name), ) await db.commit() async def _cancel_stage(job_id: int, stage_name: str) -> None: now = _now() async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( "UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?", (now, job_id, stage_name), ) await db.commit() async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None: """Append text to the log_text column of a burnin_stages row.""" async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( """UPDATE burnin_stages SET log_text = COALESCE(log_text, '') || ? WHERE burnin_job_id=? AND stage_name=?""", (text, job_id, stage_name), ) await db.commit() async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None: async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( "UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?", (count, job_id, stage_name), ) await db.commit() async def _store_smart_attrs(drive_id: int, attrs: dict) -> None: """Persist latest SMART attribute dict to drives.smart_attrs (JSON).""" # Convert int keys to str for JSON serialisation serialisable = {str(k): v for k, v in attrs.get("attributes", {}).items()} blob = json.dumps({ "health": attrs.get("health", "UNKNOWN"), "attrs": serialisable, "warnings": attrs.get("warnings", []), "failures": attrs.get("failures", []), }) async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id)) await db.commit() async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None: """Store raw smartctl output in smart_tests.raw_output.""" async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( "UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?", (raw, drive_id, test_type.lower()), ) await db.commit() async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None: async with _db() as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute( "UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?", (error_text, job_id, stage_name), ) await db.commit() async def _recalculate_progress(job_id: int, profile: str | None = None) -> None: """Recompute overall job % from actual stage rows. profile param is unused (kept for compat).""" async with _db() as db: db.row_factory = aiosqlite.Row await db.execute("PRAGMA journal_mode=WAL") cur = await db.execute( "SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id", (job_id,), ) stages = await cur.fetchall() if not stages: return total_weight = sum(_STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) for s in stages) if total_weight == 0: return completed = 0.0 current = None for s in stages: w = _STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) st = s["state"] if st == "passed": completed += w elif st == "running": completed += w * (s["percent"] or 0) / 100 current = s["stage_name"] pct = int(completed / total_weight * 100) await db.execute( "UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?", (pct, current, job_id), ) await db.commit() # --------------------------------------------------------------------------- # SSE notifier # --------------------------------------------------------------------------- def _push_update(alert: dict | None = None) -> None: """Notify SSE subscribers that data has changed, with optional browser notification payload.""" try: from app import poller poller._notify_subscribers(alert=alert) except Exception: pass