diff --git a/app/burnin/__init__.py b/app/burnin/__init__.py index 62e869c..681e5f0 100644 --- a/app/burnin/__init__.py +++ b/app/burnin/__init__.py @@ -29,36 +29,17 @@ from app.truenas import TrueNASClient log = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Stage definitions -# --------------------------------------------------------------------------- +# Stage configuration + DB helpers extracted to _common.py in 1.0.0-31. +from ._common import ( # noqa: E402 + STAGE_ORDER, _STAGE_BASE_WEIGHTS, POLL_INTERVAL, + _now, _db, + _is_cancelled, + _start_stage, _finish_stage, _cancel_stage, _set_stage_error, + _update_stage_percent, _update_stage_bad_blocks, _append_stage_log, + _store_smart_attrs, _store_smart_raw_output, + _recalculate_progress, _push_update, +) -STAGE_ORDER: dict[str, list[str]] = { - # Legacy - "quick": ["precheck", "short_smart", "io_validate", "final_check"], - # Single-stage selectable profiles - "surface": ["precheck", "surface_validate", "final_check"], - "short": ["precheck", "short_smart", "final_check"], - "long": ["precheck", "long_smart", "final_check"], - # Two-stage combos - "surface_short": ["precheck", "surface_validate", "short_smart", "final_check"], - "surface_long": ["precheck", "surface_validate", "long_smart", "final_check"], - "short_long": ["precheck", "short_smart", "long_smart", "final_check"], - # All three - "full": ["precheck", "surface_validate", "short_smart", "long_smart", "final_check"], -} - -# Per-stage base weights used to compute overall job % progress dynamically -_STAGE_BASE_WEIGHTS: dict[str, int] = { - "precheck": 5, - "surface_validate": 65, - "short_smart": 12, - "long_smart": 13, - "io_validate": 10, - "final_check": 5, -} - -POLL_INTERVAL = 5.0 # seconds between progress checks during active stages # --------------------------------------------------------------------------- # Module-level state (initialized in init()) @@ -97,17 +78,7 @@ _is_unlocked = _unlock.is_unlocked # legacy private name _kill_remote_process = _kill.kill_remote_process -def _now() -> str: - return datetime.now(timezone.utc).isoformat() - - -@asynccontextmanager -async def _db(): - """Open a WAL-mode connection with busy_timeout so writers wait for the lock - instead of immediately raising 'database is locked' under contention.""" - async with aiosqlite.connect(settings.db_path) as db: - await db.execute("PRAGMA busy_timeout=10000") - yield db +# _now() and _db() are re-exported from _common above. # --------------------------------------------------------------------------- @@ -533,891 +504,31 @@ async def _execute_stages(job_id: int, stages: list[str], devname: str, drive_id return True -async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool: - if stage_name == "precheck": - return await _stage_precheck(job_id, drive_id) - elif stage_name == "short_smart": - return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id) - elif stage_name == "long_smart": - return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id) - elif stage_name == "surface_validate": - return await _stage_surface_validate(job_id, devname, drive_id) - elif stage_name == "io_validate": - return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds) - elif stage_name == "final_check": - return await _stage_final_check(job_id, devname, drive_id) - return True +# Per-stage implementations and the dispatch router live in stages.py. +from .stages import ( # noqa: E402 + _dispatch_stage, + _badblocks_available, + _nvme_cli_available, + _stage_precheck, + _stage_smart_test, + _stage_smart_test_api, + _stage_smart_test_ssh, + _stage_surface_validate, + _stage_surface_validate_nvme, + _stage_surface_validate_ssh, + _stage_surface_validate_truenas, + _stage_timed_simulate, + _stage_final_check, +) -# --------------------------------------------------------------------------- -# Individual stage implementations -# --------------------------------------------------------------------------- - -async def _stage_precheck(job_id: int, drive_id: int) -> bool: - """Check SMART health and temperature before starting destructive work.""" - async with _db() as db: - cur = await db.execute( - "SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,) - ) - row = await cur.fetchone() - - if not row: - return False - - health, temp = row[0], row[1] - - if health == "FAILED": - await _set_stage_error(job_id, "precheck", "Drive SMART health is FAILED — refusing to burn in") - return False - - if temp and temp > settings.temp_crit_c: - await _set_stage_error(job_id, "precheck", f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit") - return False - - await asyncio.sleep(1) # Simulate brief check - return True - - -async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str, - drive_id: int | None = None) -> bool: - """Start a SMART test. Uses SSH if configured, TrueNAS REST API otherwise.""" - from app import ssh_client - if ssh_client.is_configured(): - return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id) - return await _stage_smart_test_api(job_id, devname, test_type, stage_name) - - -async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool: - """TrueNAS REST API path for SMART test (mock / dev mode).""" - tn_job_id = await _client.start_smart_test([devname], test_type) - - while True: - if await _is_cancelled(job_id): - try: - await _client.abort_job(tn_job_id) - except Exception: - pass - return False - - jobs = await _client.get_smart_jobs() - job = next((j for j in jobs if j["id"] == tn_job_id), None) - - if not job: - return False - - state = job["state"] - pct = job["progress"]["percent"] - - await _update_stage_percent(job_id, stage_name, pct) - await _recalculate_progress(job_id, None) - _push_update() - - if state == "SUCCESS": - return True - elif state in ("FAILED", "ABORTED"): - await _set_stage_error(job_id, stage_name, - job.get("error") or f"SMART {test_type} test failed") - return False - - await asyncio.sleep(POLL_INTERVAL) - - -async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str, - drive_id: int | None) -> bool: - """SSH path for SMART test — runs smartctl directly on TrueNAS.""" - from app import ssh_client - - # Start the test - try: - startup = await ssh_client.start_smart_test(devname, test_type) - await _append_stage_log(job_id, stage_name, startup + "\n") - except Exception as exc: - await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}") - return False - - # Brief pause to let the test register in smartctl output - await asyncio.sleep(3) - - # Throttle log_text appends — every poll on a multi-hour long_smart bloated - # log_text to 50+ MB and triggered SQLite "database is locked" because each - # COALESCE-then-append rewrites the whole column. Append every ~60s, on the - # first poll, and on any state change. - LOG_EVERY_N_POLLS = 12 - poll_count = 0 - last_state: str | None = None - - # Poll until complete - while True: - if await _is_cancelled(job_id): - try: - await ssh_client.abort_smart_test(devname) - except Exception: - pass - return False - - await asyncio.sleep(POLL_INTERVAL) - - try: - progress = await ssh_client.poll_smart_progress(devname) - except Exception as exc: - log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id}) - await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n") - continue - - poll_count += 1 - state_changed = progress["state"] != last_state - last_state = progress["state"] - if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed: - await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n") - - if progress["state"] == "running": - pct = max(0, 100 - progress["percent_remaining"]) - await _update_stage_percent(job_id, stage_name, pct) - await _recalculate_progress(job_id) - _push_update() - - elif progress["state"] == "passed": - await _update_stage_percent(job_id, stage_name, 100) - # Run attribute check - if drive_id is not None: - try: - attrs = await ssh_client.get_smart_attributes(devname) - await _store_smart_attrs(drive_id, attrs) - await _store_smart_raw_output(drive_id, test_type, attrs["raw_output"]) - if attrs["failures"]: - error = "SMART attribute failures: " + "; ".join(attrs["failures"]) - await _set_stage_error(job_id, stage_name, error) - return False - if attrs["warnings"]: - await _append_stage_log( - job_id, stage_name, - "[WARNING] " + "; ".join(attrs["warnings"]) + "\n" - ) - except Exception as exc: - log.warning("Failed to retrieve SMART attributes: %s", exc) - await _recalculate_progress(job_id) - _push_update() - return True - - elif progress["state"] == "failed": - await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed") - return False - # "unknown" → keep polling - - -async def _badblocks_available() -> bool: - """Check if badblocks is installed on the remote host (Linux/SCALE only).""" - from app import ssh_client - try: - async with await ssh_client._connect() as conn: - result = await conn.run("which badblocks", check=False) - return result.returncode == 0 - except Exception: - return False - - -async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool: - """ - Surface validation stage — auto-routes to the right implementation: - - 1. NVMe device + SSH + nvme-cli available (TrueNAS SCALE): - → `nvme format -s 1 /dev/{devname}` (cryptographic erase). - Far faster than badblocks on NVMe (seconds vs hours) and - exercises the controller's secure-erase path, not just user-LBA - writes. - 2. SSH configured + badblocks available (TrueNAS SCALE / Linux): - → badblocks -wsv -b N -c N -p N /dev/{devname} directly over SSH. - 3. SSH configured + badblocks NOT available (TrueNAS CORE / FreeBSD): - → uses TrueNAS REST API disk.wipe FULL job + post-wipe SMART check. - 4. No SSH: - → simulated timed progress (dev/mock mode). - """ - from app import ssh_client - if ssh_client.is_configured(): - if devname.startswith("nvme") and await _nvme_cli_available(): - return await _stage_surface_validate_nvme(job_id, devname, drive_id) - if await _badblocks_available(): - return await _stage_surface_validate_ssh(job_id, devname, drive_id) - # TrueNAS CORE/FreeBSD: badblocks not available — use native wipe API - await _append_stage_log( - job_id, "surface_validate", - "[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — " - "using TrueNAS disk.wipe API (FULL write pass).\n\n" - ) - return await _stage_surface_validate_truenas(job_id, devname, drive_id) - return await _stage_timed_simulate(job_id, "surface_validate", settings.surface_validate_seconds) - - -async def _nvme_cli_available() -> bool: - """Check if nvme-cli is installed on the remote host.""" - from app import ssh_client - try: - async with await ssh_client._connect() as conn: - r = await conn.run("which nvme", check=False) - return r.returncode == 0 - except Exception: - return False - - -async def _stage_surface_validate_nvme(job_id: int, devname: str, - drive_id: int) -> bool: - """NVMe destructive surface test via `nvme format -s 1` (crypto erase). - - Crypto-erase nukes the data encryption key on the drive's controller, - rendering all stored data unrecoverable in milliseconds; the actual - flash is then implicitly trim-able. This is the canonical destructive - burn-in for NVMe — badblocks would write the entire LBA space, which - is slower AND wears the flash unnecessarily. - - Post-format we re-read SMART attributes; the drive should report all - counters reset (life used + spare) and PASSED health. - """ - from app import ssh_client - - await _append_stage_log( - job_id, "surface_validate", - f"[START] nvme format -s 1 /dev/{devname}\n" - f"[NOTE] Cryptographic erase — destroys all data on /dev/{devname}.\n\n" - ) - - cmd = f"nvme format -s 1 --force /dev/{devname}" - try: - async with await ssh_client._connect() as conn: - r = await asyncio.wait_for( - conn.run(cmd, check=False), timeout=600 - ) - except Exception as exc: - await _append_stage_log( - job_id, "surface_validate", f"\n[SSH error] {exc}\n" - ) - await _set_stage_error( - job_id, "surface_validate", f"NVMe format SSH error: {exc}" - ) - return False - - output = (r.stdout or "") + (r.stderr or "") - await _append_stage_log(job_id, "surface_validate", output + "\n") - - if r.returncode != 0: - await _set_stage_error( - job_id, "surface_validate", - f"nvme format exited {r.returncode}: {output.strip()[:200]}" - ) - return False - - # Sanity-check post-format SMART health. Mirrors the surface_validate - # SSH path's check parity — fail on FAILED health, fail on real - # SMART attribute failures, log warnings but don't fail. A transport - # error here is treated as a soft pass (log + continue) so a single - # SSH blip after a successful format doesn't undo the work. - try: - attrs = await ssh_client.get_smart_attributes(devname) - ssh_only_failures = [ - f for f in (attrs.get("failures") or []) if f.startswith("SSH error:") - ] - real_failures = [ - f for f in (attrs.get("failures") or []) if not f.startswith("SSH error:") - ] - if attrs.get("health") == "FAILED": - await _set_stage_error( - job_id, "surface_validate", - "NVMe SMART health FAILED after format", - ) - return False - if real_failures: - await _set_stage_error( - job_id, "surface_validate", - "NVMe SMART attribute failures after format: " - + "; ".join(real_failures), - ) - return False - if ssh_only_failures: - await _append_stage_log( - job_id, "surface_validate", - "[WARN] post-format SMART check had SSH errors " - "(soft-passing): " + "; ".join(ssh_only_failures) + "\n", - ) - if attrs.get("warnings"): - await _append_stage_log( - job_id, "surface_validate", - "[WARN] " + "; ".join(attrs["warnings"]) + "\n", - ) - except Exception as exc: - log.warning("Post-format SMART check error on %s: %s", devname, exc) - await _append_stage_log( - job_id, "surface_validate", - f"[WARN] post-format SMART check raised: {exc}\n", - ) - - await _update_stage_percent(job_id, "surface_validate", 100) - await _recalculate_progress(job_id) - _push_update() - return True - - -async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool: - """Run badblocks over SSH, streaming output to stage log.""" - from app import ssh_client - - await _append_stage_log( - job_id, "surface_validate", - f"[START] badblocks -wsv -b {settings.surface_validate_block_size} " - f"-c {settings.surface_validate_block_buffer} " - f"-p {settings.surface_validate_passes} /dev/{devname}\n" - f"[NOTE] This is a DESTRUCTIVE write test. " - f"All data on /dev/{devname} will be overwritten.\n\n" - ) - - def _is_cancelled_sync() -> bool: - # Synchronous version — we check the DB state flag set by cancel_job() - import asyncio - loop = asyncio.get_event_loop() - try: - return loop.run_until_complete(_is_cancelled(job_id)) - except Exception: - return False - - last_logged_pct = [-1] - - def on_progress(pct: int, bad_blocks: int, line: str) -> None: - nonlocal last_logged_pct - # Write to log (fire-and-forget via asyncio.create_task from sync context) - # The log append is done in the async flush below - pass - - accumulated_lines: list[str] = [] - - async def on_progress_async(pct: int, bad_blocks: int, line: str) -> None: - accumulated_lines.append(line) - # Flush to DB and update progress every ~25 lines to avoid excessive DB writes - if len(accumulated_lines) % 25 == 0: - await _append_stage_log(job_id, "surface_validate", "".join(accumulated_lines[-25:])) - await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks) - await _update_stage_percent(job_id, "surface_validate", pct) - await _recalculate_progress(job_id) - _push_update() - if await _is_cancelled(job_id): - raise asyncio.CancelledError - - # Run badblocks — we adapt the callback pattern to async by collecting then flushing - result = {"bad_blocks": 0, "output": "", "aborted": False} - try: - # The actual streaming; we handle progress via the accumulated_lines pattern - bad_blocks_total = 0 - output_lines: list[str] = [] - - async with await ssh_client._connect() as conn: - # Wrap in `sh -c 'echo PID:$$; exec ...'` so we get the remote - # PID on the first stdout line. asyncssh's proc.kill() sends an - # SSH signal request that OpenSSH's sshd ignores by default, so - # we need the PID to issue an out-of-band `kill -9` over a fresh - # session when we want to abort. - # - # Block geometry is operator-tunable (Settings → Burn-in): - # -b N block size in bytes (settings.surface_validate_block_size) - # -c N blocks held per IO (settings.surface_validate_block_buffer) - # -p N pass count (settings.surface_validate_passes) - # Defaults preserve original behavior (-b 4096 -c 64 -p 1). - bb_args = ( - f"-wsv " - f"-b {settings.surface_validate_block_size} " - f"-c {settings.surface_validate_block_buffer} " - f"-p {settings.surface_validate_passes}" - ) - cmd = ( - f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'" - ) - async with conn.create_process(cmd) as proc: - import re as _re - - pid_seen = False - - async def _drain(stream, is_stderr: bool): - nonlocal bad_blocks_total, pid_seen - async for raw in stream: - line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace") - - # First stdout line is "PID:" from the wrapping shell. - # Capture it and don't append it to the user-visible log. - if not is_stderr and not pid_seen and line.startswith("PID:"): - pid_seen = True - try: - _remote_pids[job_id] = int(line[4:].strip()) - log.info( - "Captured remote PID %d for job %d (badblocks)", - _remote_pids[job_id], job_id, - ) - except ValueError: - pass - continue - - output_lines.append(line) - - if is_stderr: - m = _re.search(r"([\d.]+)%\s+done", line) - if m: - pct = min(99, int(float(m.group(1)))) - await _update_stage_percent(job_id, "surface_validate", pct) - await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks_total) - await _recalculate_progress(job_id) - _push_update() - else: - stripped = line.strip() - if stripped and stripped.isdigit(): - bad_blocks_total += 1 - - # Append to DB log in chunks - if len(output_lines) % 20 == 0: - chunk = "".join(output_lines[-20:]) - await _append_stage_log(job_id, "surface_validate", chunk) - - # Abort on bad block threshold - if bad_blocks_total > settings.bad_block_threshold: - await _kill_remote_process(job_id) - output_lines.append( - f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded " - f"threshold ({settings.bad_block_threshold})\n" - ) - return - - if await _is_cancelled(job_id): - await _kill_remote_process(job_id) - return - - await asyncio.gather( - _drain(proc.stdout, False), - _drain(proc.stderr, True), - return_exceptions=True, - ) - # Bound proc.wait so a remote process that ignored our kill - # signal (or that we never managed to kill) can't pin this - # task in the semaphore forever. Closing the connection on - # exit will deliver SIGPIPE to the remote on its next write. - try: - await asyncio.wait_for(proc.wait(), timeout=15) - except asyncio.TimeoutError: - log.warning( - "proc.wait() timed out for job %d — abandoning channel", - job_id, - ) - - # Flush only lines we haven't already written in 20-line chunks. - # Previously we appended the FULL accumulated output here too, - # doubling the stored log_text size for every surface_validate - # stage and pushing app.db into hundreds of MB. - flushed_count = (len(output_lines) // 20) * 20 - tail = "".join(output_lines[flushed_count:]) - if tail: - await _append_stage_log(job_id, "surface_validate", tail) - result["bad_blocks"] = bad_blocks_total - result["output"] = "".join(output_lines) # in-memory only, not re-stored - result["aborted"] = bad_blocks_total > settings.bad_block_threshold - - except asyncio.CancelledError: - return False - except Exception as exc: - await _append_stage_log(job_id, "surface_validate", f"\n[SSH error] {exc}\n") - await _set_stage_error(job_id, "surface_validate", f"SSH badblocks error: {exc}") - return False - - await _update_stage_bad_blocks(job_id, "surface_validate", result["bad_blocks"]) - - if result["aborted"] or result["bad_blocks"] > settings.bad_block_threshold: - await _set_stage_error( - job_id, "surface_validate", - f"Surface validate FAILED: {result['bad_blocks']} bad block(s) found " - f"(threshold: {settings.bad_block_threshold})" - ) - return False - - return True - - -async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool: - """ - Surface validation via TrueNAS CORE disk.wipe REST API. - Used on FreeBSD (TrueNAS CORE) where badblocks is unavailable. - - Sends a FULL write-zero pass across the entire disk, polls progress, - then runs a post-wipe SMART attribute check to catch reallocated sectors. - """ - from app import ssh_client - - await _append_stage_log( - job_id, "surface_validate", - f"[START] TrueNAS disk.wipe FULL — {devname}\n" - f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n" - ) - - # Start the wipe job - try: - tn_job_id = await _client.wipe_disk(devname, "FULL") - except Exception as exc: - await _set_stage_error(job_id, "surface_validate", f"Failed to start disk.wipe: {exc}") - return False - - await _append_stage_log( - job_id, "surface_validate", - f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n" - ) - - # Poll until complete - log_flush_counter = 0 - while True: - if await _is_cancelled(job_id): - try: - await _client.abort_job(tn_job_id) - except Exception: - pass - return False - - await asyncio.sleep(POLL_INTERVAL) - - try: - job = await _client.get_job(tn_job_id) - except Exception as exc: - log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id}) - await _append_stage_log(job_id, "surface_validate", f"[poll error] {exc}\n") - continue - - if not job: - await _set_stage_error(job_id, "surface_validate", f"Wipe job {tn_job_id} not found") - return False - - state = job.get("state", "") - pct = int(job.get("progress", {}).get("percent", 0) or 0) - desc = job.get("progress", {}).get("description", "") - - await _update_stage_percent(job_id, "surface_validate", min(pct, 99)) - await _recalculate_progress(job_id) - _push_update() - - # Log progress description every ~5 polls to avoid DB spam - log_flush_counter += 1 - if desc and log_flush_counter % 5 == 0: - await _append_stage_log(job_id, "surface_validate", f"[{pct}%] {desc}\n") - - if state == "SUCCESS": - await _update_stage_percent(job_id, "surface_validate", 100) - await _append_stage_log( - job_id, "surface_validate", - f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n" - ) - # Post-wipe SMART check — catch any sectors that failed under write stress - if ssh_client.is_configured() and drive_id is not None: - await _append_stage_log( - job_id, "surface_validate", - "[CHECK] Running post-wipe SMART attribute check...\n" - ) - try: - attrs = await ssh_client.get_smart_attributes(devname) - await _store_smart_attrs(drive_id, attrs) - if attrs["failures"]: - error = "Post-wipe SMART check: " + "; ".join(attrs["failures"]) - await _set_stage_error(job_id, "surface_validate", error) - return False - if attrs["warnings"]: - await _append_stage_log( - job_id, "surface_validate", - "[WARNING] " + "; ".join(attrs["warnings"]) + "\n" - ) - await _append_stage_log( - job_id, "surface_validate", - f"[CHECK] SMART health: {attrs['health']} — no critical attributes.\n" - ) - except Exception as exc: - log.warning("Post-wipe SMART check failed: %s", exc) - await _append_stage_log( - job_id, "surface_validate", - f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n" - ) - return True - - elif state in ("FAILED", "ABORTED", "ERROR"): - error_msg = job.get("error") or f"Disk wipe failed (state={state})" - await _set_stage_error( - job_id, "surface_validate", - f"TrueNAS disk.wipe FAILED: {error_msg}" - ) - return False - # RUNNING or WAITING — keep polling - - -async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool: - """Simulate a timed stage with progress updates (mock / dev mode).""" - start = time.monotonic() - - while True: - if await _is_cancelled(job_id): - return False - - elapsed = time.monotonic() - start - pct = min(100, int(elapsed / duration_seconds * 100)) - - await _update_stage_percent(job_id, stage_name, pct) - await _recalculate_progress(job_id, None) - _push_update() - - if pct >= 100: - return True - - await asyncio.sleep(POLL_INTERVAL) - - -async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool: - """ - Verify drive passed all tests. - SSH mode: run smartctl -a and check critical attributes. - Mock mode: check SMART health field in DB. - - A transient SSH connectivity failure here must NOT invalidate a prior - multi-day surface_validate. Retry SSH-only failures, then soft-pass. - """ - await asyncio.sleep(1) - from app import ssh_client - - def _ssh_only(failures: list[str]) -> bool: - return bool(failures) and all(f.startswith("SSH error:") for f in failures) - - if ssh_client.is_configured() and drive_id is not None: - try: - attrs = await ssh_client.get_smart_attributes(devname) - for attempt in range(2): - if not _ssh_only(attrs.get("failures") or []): - break - log.warning( - "final_check SSH unreachable (attempt %d/3); retrying in 30s", - attempt + 1, - extra={"job_id": job_id, "devname": devname}, - ) - await asyncio.sleep(30) - attrs = await ssh_client.get_smart_attributes(devname) - - failures = attrs.get("failures") or [] - if _ssh_only(failures): - log.warning( - "final_check soft-pass: SSH unreachable after retries; prior stages stand", - extra={"job_id": job_id, "devname": devname, "ssh_error": failures}, - ) - return True - - await _store_smart_attrs(drive_id, attrs) - if attrs["health"] == "FAILED" or failures: - msg = failures or [f"SMART health: {attrs['health']}"] - await _set_stage_error(job_id, "final_check", - "Final check failed: " + "; ".join(msg)) - return False - return True - except Exception as exc: - log.warning("SSH final_check raised, falling back to DB check: %s", exc) - - # DB check (mock mode fallback) - async with _db() as db: - cur = await db.execute( - "SELECT smart_health FROM drives WHERE devname=?", (devname,) - ) - row = await cur.fetchone() - - if not row or row[0] == "FAILED": - await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in") - return False - - return True # --------------------------------------------------------------------------- # DB helpers # --------------------------------------------------------------------------- -async def _is_cancelled(job_id: int) -> bool: - async with _db() as db: - cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,)) - row = await cur.fetchone() - return bool(row and row[0] == "cancelled") - - -async def _start_stage(job_id: int, stage_name: str) -> None: - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - "UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?", - (_now(), job_id, stage_name), - ) - await db.execute( - "UPDATE burnin_jobs SET stage_name=? WHERE id=?", - (stage_name, job_id), - ) - await db.commit() - - -async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None: - now = _now() - state = "passed" if success else "failed" - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - cur = await db.execute( - "SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?", - (job_id, stage_name), - ) - row = await cur.fetchone() - duration = None - if row and row[0]: - try: - start = datetime.fromisoformat(row[0]) - if start.tzinfo is None: - start = start.replace(tzinfo=timezone.utc) - duration = (datetime.now(timezone.utc) - start).total_seconds() - except Exception: - pass - - # Only overwrite error_text if one is passed; otherwise preserve what the stage already wrote - if error_text is not None: - await db.execute( - """UPDATE burnin_stages - SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=? - WHERE burnin_job_id=? AND stage_name=?""", - (state, 100 if success else None, now, duration, error_text, job_id, stage_name), - ) - else: - await db.execute( - """UPDATE burnin_stages - SET state=?, percent=?, finished_at=?, duration_seconds=? - WHERE burnin_job_id=? AND stage_name=?""", - (state, 100 if success else None, now, duration, job_id, stage_name), - ) - await db.commit() - - -async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None: - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - "UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?", - (pct, job_id, stage_name), - ) - await db.commit() - - -async def _cancel_stage(job_id: int, stage_name: str) -> None: - now = _now() - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - "UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?", - (now, job_id, stage_name), - ) - await db.commit() - - -async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None: - """Append text to the log_text column of a burnin_stages row.""" - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - """UPDATE burnin_stages - SET log_text = COALESCE(log_text, '') || ? - WHERE burnin_job_id=? AND stage_name=?""", - (text, job_id, stage_name), - ) - await db.commit() - - -async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None: - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - "UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?", - (count, job_id, stage_name), - ) - await db.commit() - - -async def _store_smart_attrs(drive_id: int, attrs: dict) -> None: - """Persist latest SMART attribute dict to drives.smart_attrs (JSON).""" - import json - # Convert int keys to str for JSON serialisation - serialisable = {str(k): v for k, v in attrs.get("attributes", {}).items()} - blob = json.dumps({ - "health": attrs.get("health", "UNKNOWN"), - "attrs": serialisable, - "warnings": attrs.get("warnings", []), - "failures": attrs.get("failures", []), - }) - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id)) - await db.commit() - - -async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None: - """Store raw smartctl output in smart_tests.raw_output.""" - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - "UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?", - (raw, drive_id, test_type.lower()), - ) - await db.commit() - - -async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None: - async with _db() as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute( - "UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?", - (error_text, job_id, stage_name), - ) - await db.commit() - - -async def _recalculate_progress(job_id: int, profile: str | None = None) -> None: - """Recompute overall job % from actual stage rows. profile param is unused (kept for compat).""" - async with _db() as db: - db.row_factory = aiosqlite.Row - await db.execute("PRAGMA journal_mode=WAL") - - cur = await db.execute( - "SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id", - (job_id,), - ) - stages = await cur.fetchall() - if not stages: - return - - total_weight = sum(_STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) for s in stages) - if total_weight == 0: - return - - completed = 0.0 - current = None - for s in stages: - w = _STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) - st = s["state"] - if st == "passed": - completed += w - elif st == "running": - completed += w * (s["percent"] or 0) / 100 - current = s["stage_name"] - - pct = int(completed / total_weight * 100) - await db.execute( - "UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?", - (pct, current, job_id), - ) - await db.commit() - - -# --------------------------------------------------------------------------- -# SSE push -# --------------------------------------------------------------------------- - -def _push_update(alert: dict | None = None) -> None: - """Notify SSE subscribers that data has changed, with optional browser notification payload.""" - try: - from app import poller - poller._notify_subscribers(alert=alert) - except Exception: - pass +# DB helpers / progress / SSE re-exported from _common above. # --------------------------------------------------------------------------- diff --git a/app/burnin/_common.py b/app/burnin/_common.py new file mode 100644 index 0000000..385ed9c --- /dev/null +++ b/app/burnin/_common.py @@ -0,0 +1,277 @@ +"""Shared helpers for the burnin package. + +Lives below stages.py / task.py / __init__.py — these all import from +here. _common itself imports nothing from sibling burnin modules so we +stay free of circular-import landmines. + +Owns: + * Stage configuration constants (STAGE_ORDER, _STAGE_BASE_WEIGHTS, + POLL_INTERVAL). + * The connection-helper context manager `_db()` and the `_now()` ISO + timestamp helper used everywhere. + * Per-stage DB mutators called by stage implementations and by the + job orchestrator (`_start_stage`, `_finish_stage`, `_cancel_stage`, + `_set_stage_error`, `_update_stage_percent`, + `_update_stage_bad_blocks`, `_append_stage_log`). + * Drive-row mutators for SMART caches + (`_store_smart_attrs`, `_store_smart_raw_output`). + * The job-state read (`_is_cancelled`) + progress aggregator + (`_recalculate_progress`). + * SSE notifier (`_push_update`). +""" + +from __future__ import annotations + +import json +import logging +from contextlib import asynccontextmanager +from datetime import datetime, timezone + +import aiosqlite + +from app.config import settings + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Stage configuration +# --------------------------------------------------------------------------- + +STAGE_ORDER: dict[str, list[str]] = { + # Legacy + "quick": ["precheck", "short_smart", "io_validate", "final_check"], + # Single-stage selectable profiles + "surface": ["precheck", "surface_validate", "final_check"], + "short": ["precheck", "short_smart", "final_check"], + "long": ["precheck", "long_smart", "final_check"], + # Two-stage combos + "surface_short": ["precheck", "surface_validate", "short_smart", "final_check"], + "surface_long": ["precheck", "surface_validate", "long_smart", "final_check"], + "short_long": ["precheck", "short_smart", "long_smart", "final_check"], + # All three + "full": ["precheck", "surface_validate", "short_smart", "long_smart", "final_check"], +} + +# Per-stage base weights used to compute overall job % progress dynamically +_STAGE_BASE_WEIGHTS: dict[str, int] = { + "precheck": 5, + "surface_validate": 65, + "short_smart": 12, + "long_smart": 13, + "io_validate": 10, + "final_check": 5, +} + +POLL_INTERVAL = 5.0 # seconds between progress checks during active stages + + +# --------------------------------------------------------------------------- +# Connection helpers +# --------------------------------------------------------------------------- + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +@asynccontextmanager +async def _db(): + """Open a WAL-mode connection with busy_timeout so writers wait for the lock + instead of immediately raising 'database is locked' under contention.""" + async with aiosqlite.connect(settings.db_path) as db: + await db.execute("PRAGMA busy_timeout=10000") + yield db + + +# --------------------------------------------------------------------------- +# Job / stage DB mutators +# --------------------------------------------------------------------------- + +async def _is_cancelled(job_id: int) -> bool: + async with _db() as db: + cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,)) + row = await cur.fetchone() + return bool(row and row[0] == "cancelled") + + +async def _start_stage(job_id: int, stage_name: str) -> None: + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + "UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?", + (_now(), job_id, stage_name), + ) + await db.execute( + "UPDATE burnin_jobs SET stage_name=? WHERE id=?", + (stage_name, job_id), + ) + await db.commit() + + +async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None: + now = _now() + state = "passed" if success else "failed" + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + cur = await db.execute( + "SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?", + (job_id, stage_name), + ) + row = await cur.fetchone() + duration = None + if row and row[0]: + try: + start = datetime.fromisoformat(row[0]) + if start.tzinfo is None: + start = start.replace(tzinfo=timezone.utc) + duration = (datetime.now(timezone.utc) - start).total_seconds() + except Exception: + pass + + # Only overwrite error_text if one is passed; otherwise preserve what the stage already wrote + if error_text is not None: + await db.execute( + """UPDATE burnin_stages + SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=? + WHERE burnin_job_id=? AND stage_name=?""", + (state, 100 if success else None, now, duration, error_text, job_id, stage_name), + ) + else: + await db.execute( + """UPDATE burnin_stages + SET state=?, percent=?, finished_at=?, duration_seconds=? + WHERE burnin_job_id=? AND stage_name=?""", + (state, 100 if success else None, now, duration, job_id, stage_name), + ) + await db.commit() + + +async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None: + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + "UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?", + (pct, job_id, stage_name), + ) + await db.commit() + + +async def _cancel_stage(job_id: int, stage_name: str) -> None: + now = _now() + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + "UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?", + (now, job_id, stage_name), + ) + await db.commit() + + +async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None: + """Append text to the log_text column of a burnin_stages row.""" + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + """UPDATE burnin_stages + SET log_text = COALESCE(log_text, '') || ? + WHERE burnin_job_id=? AND stage_name=?""", + (text, job_id, stage_name), + ) + await db.commit() + + +async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None: + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + "UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?", + (count, job_id, stage_name), + ) + await db.commit() + + +async def _store_smart_attrs(drive_id: int, attrs: dict) -> None: + """Persist latest SMART attribute dict to drives.smart_attrs (JSON).""" + # Convert int keys to str for JSON serialisation + serialisable = {str(k): v for k, v in attrs.get("attributes", {}).items()} + blob = json.dumps({ + "health": attrs.get("health", "UNKNOWN"), + "attrs": serialisable, + "warnings": attrs.get("warnings", []), + "failures": attrs.get("failures", []), + }) + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id)) + await db.commit() + + +async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None: + """Store raw smartctl output in smart_tests.raw_output.""" + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + "UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?", + (raw, drive_id, test_type.lower()), + ) + await db.commit() + + +async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None: + async with _db() as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute( + "UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?", + (error_text, job_id, stage_name), + ) + await db.commit() + + +async def _recalculate_progress(job_id: int, profile: str | None = None) -> None: + """Recompute overall job % from actual stage rows. profile param is unused (kept for compat).""" + async with _db() as db: + db.row_factory = aiosqlite.Row + await db.execute("PRAGMA journal_mode=WAL") + + cur = await db.execute( + "SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id", + (job_id,), + ) + stages = await cur.fetchall() + if not stages: + return + + total_weight = sum(_STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) for s in stages) + if total_weight == 0: + return + + completed = 0.0 + current = None + for s in stages: + w = _STAGE_BASE_WEIGHTS.get(s["stage_name"], 5) + st = s["state"] + if st == "passed": + completed += w + elif st == "running": + completed += w * (s["percent"] or 0) / 100 + current = s["stage_name"] + + pct = int(completed / total_weight * 100) + await db.execute( + "UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?", + (pct, current, job_id), + ) + await db.commit() + + +# --------------------------------------------------------------------------- +# SSE notifier +# --------------------------------------------------------------------------- + +def _push_update(alert: dict | None = None) -> None: + """Notify SSE subscribers that data has changed, with optional browser notification payload.""" + try: + from app import poller + poller._notify_subscribers(alert=alert) + except Exception: + pass diff --git a/app/burnin/stages.py b/app/burnin/stages.py new file mode 100644 index 0000000..e135e01 --- /dev/null +++ b/app/burnin/stages.py @@ -0,0 +1,735 @@ +"""Per-stage burn-in implementations. + +Each ``_stage_*`` function runs to completion or returns False. They share +state (DB, helpers, configuration) via ``app.burnin._common`` and pull +the live ``TrueNASClient`` instance lazily from the package root so the +extraction stays free of circular imports at module load. + +``_dispatch_stage`` is the per-stage_name router used by the orchestrator +in ``app.burnin.__init__._execute_stages``. +""" + +from __future__ import annotations + +import asyncio +import logging +import time + +from app.config import settings + +from . import kill +from ._common import ( + POLL_INTERVAL, + _append_stage_log, + _db, + _is_cancelled, + _push_update, + _recalculate_progress, + _set_stage_error, + _store_smart_attrs, + _store_smart_raw_output, + _update_stage_bad_blocks, + _update_stage_percent, +) + +log = logging.getLogger(__name__) + + +def _get_client(): + """Lazy access to the TrueNASClient set by ``burnin.init()``. Lives on + the package root for backward compat with routes.py which reaches + for ``burnin._client`` directly.""" + from app import burnin + return burnin._client + + +async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool: + if stage_name == "precheck": + return await _stage_precheck(job_id, drive_id) + elif stage_name == "short_smart": + return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id) + elif stage_name == "long_smart": + return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id) + elif stage_name == "surface_validate": + return await _stage_surface_validate(job_id, devname, drive_id) + elif stage_name == "io_validate": + return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds) + elif stage_name == "final_check": + return await _stage_final_check(job_id, devname, drive_id) + return True + + +# --------------------------------------------------------------------------- +# Individual stage implementations +# --------------------------------------------------------------------------- + +async def _stage_precheck(job_id: int, drive_id: int) -> bool: + """Check SMART health and temperature before starting destructive work.""" + async with _db() as db: + cur = await db.execute( + "SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,) + ) + row = await cur.fetchone() + + if not row: + return False + + health, temp = row[0], row[1] + + if health == "FAILED": + await _set_stage_error(job_id, "precheck", "Drive SMART health is FAILED — refusing to burn in") + return False + + if temp and temp > settings.temp_crit_c: + await _set_stage_error(job_id, "precheck", f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit") + return False + + await asyncio.sleep(1) # Simulate brief check + return True + + +async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str, + drive_id: int | None = None) -> bool: + """Start a SMART test. Uses SSH if configured, TrueNAS REST API otherwise.""" + from app import ssh_client + if ssh_client.is_configured(): + return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id) + return await _stage_smart_test_api(job_id, devname, test_type, stage_name) + + +async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool: + """TrueNAS REST API path for SMART test (mock / dev mode).""" + tn_job_id = await _get_client().start_smart_test([devname], test_type) + + while True: + if await _is_cancelled(job_id): + try: + await _get_client().abort_job(tn_job_id) + except Exception: + pass + return False + + jobs = await _get_client().get_smart_jobs() + job = next((j for j in jobs if j["id"] == tn_job_id), None) + + if not job: + return False + + state = job["state"] + pct = job["progress"]["percent"] + + await _update_stage_percent(job_id, stage_name, pct) + await _recalculate_progress(job_id, None) + _push_update() + + if state == "SUCCESS": + return True + elif state in ("FAILED", "ABORTED"): + await _set_stage_error(job_id, stage_name, + job.get("error") or f"SMART {test_type} test failed") + return False + + await asyncio.sleep(POLL_INTERVAL) + + +async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str, + drive_id: int | None) -> bool: + """SSH path for SMART test — runs smartctl directly on TrueNAS.""" + from app import ssh_client + + # Start the test + try: + startup = await ssh_client.start_smart_test(devname, test_type) + await _append_stage_log(job_id, stage_name, startup + "\n") + except Exception as exc: + await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}") + return False + + # Brief pause to let the test register in smartctl output + await asyncio.sleep(3) + + # Throttle log_text appends — every poll on a multi-hour long_smart bloated + # log_text to 50+ MB and triggered SQLite "database is locked" because each + # COALESCE-then-append rewrites the whole column. Append every ~60s, on the + # first poll, and on any state change. + LOG_EVERY_N_POLLS = 12 + poll_count = 0 + last_state: str | None = None + + # Poll until complete + while True: + if await _is_cancelled(job_id): + try: + await ssh_client.abort_smart_test(devname) + except Exception: + pass + return False + + await asyncio.sleep(POLL_INTERVAL) + + try: + progress = await ssh_client.poll_smart_progress(devname) + except Exception as exc: + log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id}) + await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n") + continue + + poll_count += 1 + state_changed = progress["state"] != last_state + last_state = progress["state"] + if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed: + await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n") + + if progress["state"] == "running": + pct = max(0, 100 - progress["percent_remaining"]) + await _update_stage_percent(job_id, stage_name, pct) + await _recalculate_progress(job_id) + _push_update() + + elif progress["state"] == "passed": + await _update_stage_percent(job_id, stage_name, 100) + # Run attribute check + if drive_id is not None: + try: + attrs = await ssh_client.get_smart_attributes(devname) + await _store_smart_attrs(drive_id, attrs) + await _store_smart_raw_output(drive_id, test_type, attrs["raw_output"]) + if attrs["failures"]: + error = "SMART attribute failures: " + "; ".join(attrs["failures"]) + await _set_stage_error(job_id, stage_name, error) + return False + if attrs["warnings"]: + await _append_stage_log( + job_id, stage_name, + "[WARNING] " + "; ".join(attrs["warnings"]) + "\n" + ) + except Exception as exc: + log.warning("Failed to retrieve SMART attributes: %s", exc) + await _recalculate_progress(job_id) + _push_update() + return True + + elif progress["state"] == "failed": + await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed") + return False + # "unknown" → keep polling + + +async def _badblocks_available() -> bool: + """Check if badblocks is installed on the remote host (Linux/SCALE only).""" + from app import ssh_client + try: + async with await ssh_client._connect() as conn: + result = await conn.run("which badblocks", check=False) + return result.returncode == 0 + except Exception: + return False + + +async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool: + """ + Surface validation stage — auto-routes to the right implementation: + + 1. NVMe device + SSH + nvme-cli available (TrueNAS SCALE): + → `nvme format -s 1 /dev/{devname}` (cryptographic erase). + Far faster than badblocks on NVMe (seconds vs hours) and + exercises the controller's secure-erase path, not just user-LBA + writes. + 2. SSH configured + badblocks available (TrueNAS SCALE / Linux): + → badblocks -wsv -b N -c N -p N /dev/{devname} directly over SSH. + 3. SSH configured + badblocks NOT available (TrueNAS CORE / FreeBSD): + → uses TrueNAS REST API disk.wipe FULL job + post-wipe SMART check. + 4. No SSH: + → simulated timed progress (dev/mock mode). + """ + from app import ssh_client + if ssh_client.is_configured(): + if devname.startswith("nvme") and await _nvme_cli_available(): + return await _stage_surface_validate_nvme(job_id, devname, drive_id) + if await _badblocks_available(): + return await _stage_surface_validate_ssh(job_id, devname, drive_id) + # TrueNAS CORE/FreeBSD: badblocks not available — use native wipe API + await _append_stage_log( + job_id, "surface_validate", + "[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — " + "using TrueNAS disk.wipe API (FULL write pass).\n\n" + ) + return await _stage_surface_validate_truenas(job_id, devname, drive_id) + return await _stage_timed_simulate(job_id, "surface_validate", settings.surface_validate_seconds) + + +async def _nvme_cli_available() -> bool: + """Check if nvme-cli is installed on the remote host.""" + from app import ssh_client + try: + async with await ssh_client._connect() as conn: + r = await conn.run("which nvme", check=False) + return r.returncode == 0 + except Exception: + return False + + +async def _stage_surface_validate_nvme(job_id: int, devname: str, + drive_id: int) -> bool: + """NVMe destructive surface test via `nvme format -s 1` (crypto erase). + + Crypto-erase nukes the data encryption key on the drive's controller, + rendering all stored data unrecoverable in milliseconds; the actual + flash is then implicitly trim-able. This is the canonical destructive + burn-in for NVMe — badblocks would write the entire LBA space, which + is slower AND wears the flash unnecessarily. + + Post-format we re-read SMART attributes; the drive should report all + counters reset (life used + spare) and PASSED health. + """ + from app import ssh_client + + await _append_stage_log( + job_id, "surface_validate", + f"[START] nvme format -s 1 /dev/{devname}\n" + f"[NOTE] Cryptographic erase — destroys all data on /dev/{devname}.\n\n" + ) + + cmd = f"nvme format -s 1 --force /dev/{devname}" + try: + async with await ssh_client._connect() as conn: + r = await asyncio.wait_for( + conn.run(cmd, check=False), timeout=600 + ) + except Exception as exc: + await _append_stage_log( + job_id, "surface_validate", f"\n[SSH error] {exc}\n" + ) + await _set_stage_error( + job_id, "surface_validate", f"NVMe format SSH error: {exc}" + ) + return False + + output = (r.stdout or "") + (r.stderr or "") + await _append_stage_log(job_id, "surface_validate", output + "\n") + + if r.returncode != 0: + await _set_stage_error( + job_id, "surface_validate", + f"nvme format exited {r.returncode}: {output.strip()[:200]}" + ) + return False + + # Sanity-check post-format SMART health. Mirrors the surface_validate + # SSH path's check parity — fail on FAILED health, fail on real + # SMART attribute failures, log warnings but don't fail. A transport + # error here is treated as a soft pass (log + continue) so a single + # SSH blip after a successful format doesn't undo the work. + try: + attrs = await ssh_client.get_smart_attributes(devname) + ssh_only_failures = [ + f for f in (attrs.get("failures") or []) if f.startswith("SSH error:") + ] + real_failures = [ + f for f in (attrs.get("failures") or []) if not f.startswith("SSH error:") + ] + if attrs.get("health") == "FAILED": + await _set_stage_error( + job_id, "surface_validate", + "NVMe SMART health FAILED after format", + ) + return False + if real_failures: + await _set_stage_error( + job_id, "surface_validate", + "NVMe SMART attribute failures after format: " + + "; ".join(real_failures), + ) + return False + if ssh_only_failures: + await _append_stage_log( + job_id, "surface_validate", + "[WARN] post-format SMART check had SSH errors " + "(soft-passing): " + "; ".join(ssh_only_failures) + "\n", + ) + if attrs.get("warnings"): + await _append_stage_log( + job_id, "surface_validate", + "[WARN] " + "; ".join(attrs["warnings"]) + "\n", + ) + except Exception as exc: + log.warning("Post-format SMART check error on %s: %s", devname, exc) + await _append_stage_log( + job_id, "surface_validate", + f"[WARN] post-format SMART check raised: {exc}\n", + ) + + await _update_stage_percent(job_id, "surface_validate", 100) + await _recalculate_progress(job_id) + _push_update() + return True + + +async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool: + """Run badblocks over SSH, streaming output to stage log.""" + from app import ssh_client + + await _append_stage_log( + job_id, "surface_validate", + f"[START] badblocks -wsv -b {settings.surface_validate_block_size} " + f"-c {settings.surface_validate_block_buffer} " + f"-p {settings.surface_validate_passes} /dev/{devname}\n" + f"[NOTE] This is a DESTRUCTIVE write test. " + f"All data on /dev/{devname} will be overwritten.\n\n" + ) + + def _is_cancelled_sync() -> bool: + # Synchronous version — we check the DB state flag set by cancel_job() + import asyncio + loop = asyncio.get_event_loop() + try: + return loop.run_until_complete(_is_cancelled(job_id)) + except Exception: + return False + + last_logged_pct = [-1] + + def on_progress(pct: int, bad_blocks: int, line: str) -> None: + nonlocal last_logged_pct + # Write to log (fire-and-forget via asyncio.create_task from sync context) + # The log append is done in the async flush below + pass + + accumulated_lines: list[str] = [] + + async def on_progress_async(pct: int, bad_blocks: int, line: str) -> None: + accumulated_lines.append(line) + # Flush to DB and update progress every ~25 lines to avoid excessive DB writes + if len(accumulated_lines) % 25 == 0: + await _append_stage_log(job_id, "surface_validate", "".join(accumulated_lines[-25:])) + await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks) + await _update_stage_percent(job_id, "surface_validate", pct) + await _recalculate_progress(job_id) + _push_update() + if await _is_cancelled(job_id): + raise asyncio.CancelledError + + # Run badblocks — we adapt the callback pattern to async by collecting then flushing + result = {"bad_blocks": 0, "output": "", "aborted": False} + try: + # The actual streaming; we handle progress via the accumulated_lines pattern + bad_blocks_total = 0 + output_lines: list[str] = [] + + async with await ssh_client._connect() as conn: + # Wrap in `sh -c 'echo PID:$$; exec ...'` so we get the remote + # PID on the first stdout line. asyncssh's proc.kill() sends an + # SSH signal request that OpenSSH's sshd ignores by default, so + # we need the PID to issue an out-of-band `kill -9` over a fresh + # session when we want to abort. + # + # Block geometry is operator-tunable (Settings → Burn-in): + # -b N block size in bytes (settings.surface_validate_block_size) + # -c N blocks held per IO (settings.surface_validate_block_buffer) + # -p N pass count (settings.surface_validate_passes) + # Defaults preserve original behavior (-b 4096 -c 64 -p 1). + bb_args = ( + f"-wsv " + f"-b {settings.surface_validate_block_size} " + f"-c {settings.surface_validate_block_buffer} " + f"-p {settings.surface_validate_passes}" + ) + cmd = ( + f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'" + ) + async with conn.create_process(cmd) as proc: + import re as _re + + pid_seen = False + + async def _drain(stream, is_stderr: bool): + nonlocal bad_blocks_total, pid_seen + async for raw in stream: + line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace") + + # First stdout line is "PID:" from the wrapping shell. + # Capture it and don't append it to the user-visible log. + if not is_stderr and not pid_seen and line.startswith("PID:"): + pid_seen = True + try: + kill.set_remote_pid(job_id, int(line[4:].strip())) + log.info( + "Captured remote PID %d for job %d (badblocks)", + kill.get_remote_pid(job_id), job_id, + ) + except ValueError: + pass + continue + + output_lines.append(line) + + if is_stderr: + m = _re.search(r"([\d.]+)%\s+done", line) + if m: + pct = min(99, int(float(m.group(1)))) + await _update_stage_percent(job_id, "surface_validate", pct) + await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks_total) + await _recalculate_progress(job_id) + _push_update() + else: + stripped = line.strip() + if stripped and stripped.isdigit(): + bad_blocks_total += 1 + + # Append to DB log in chunks + if len(output_lines) % 20 == 0: + chunk = "".join(output_lines[-20:]) + await _append_stage_log(job_id, "surface_validate", chunk) + + # Abort on bad block threshold + if bad_blocks_total > settings.bad_block_threshold: + await kill.kill_remote_process(job_id) + output_lines.append( + f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded " + f"threshold ({settings.bad_block_threshold})\n" + ) + return + + if await _is_cancelled(job_id): + await kill.kill_remote_process(job_id) + return + + await asyncio.gather( + _drain(proc.stdout, False), + _drain(proc.stderr, True), + return_exceptions=True, + ) + # Bound proc.wait so a remote process that ignored our kill + # signal (or that we never managed to kill) can't pin this + # task in the semaphore forever. Closing the connection on + # exit will deliver SIGPIPE to the remote on its next write. + try: + await asyncio.wait_for(proc.wait(), timeout=15) + except asyncio.TimeoutError: + log.warning( + "proc.wait() timed out for job %d — abandoning channel", + job_id, + ) + + # Flush only lines we haven't already written in 20-line chunks. + # Previously we appended the FULL accumulated output here too, + # doubling the stored log_text size for every surface_validate + # stage and pushing app.db into hundreds of MB. + flushed_count = (len(output_lines) // 20) * 20 + tail = "".join(output_lines[flushed_count:]) + if tail: + await _append_stage_log(job_id, "surface_validate", tail) + result["bad_blocks"] = bad_blocks_total + result["output"] = "".join(output_lines) # in-memory only, not re-stored + result["aborted"] = bad_blocks_total > settings.bad_block_threshold + + except asyncio.CancelledError: + return False + except Exception as exc: + await _append_stage_log(job_id, "surface_validate", f"\n[SSH error] {exc}\n") + await _set_stage_error(job_id, "surface_validate", f"SSH badblocks error: {exc}") + return False + + await _update_stage_bad_blocks(job_id, "surface_validate", result["bad_blocks"]) + + if result["aborted"] or result["bad_blocks"] > settings.bad_block_threshold: + await _set_stage_error( + job_id, "surface_validate", + f"Surface validate FAILED: {result['bad_blocks']} bad block(s) found " + f"(threshold: {settings.bad_block_threshold})" + ) + return False + + return True + + +async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool: + """ + Surface validation via TrueNAS CORE disk.wipe REST API. + Used on FreeBSD (TrueNAS CORE) where badblocks is unavailable. + + Sends a FULL write-zero pass across the entire disk, polls progress, + then runs a post-wipe SMART attribute check to catch reallocated sectors. + """ + from app import ssh_client + + await _append_stage_log( + job_id, "surface_validate", + f"[START] TrueNAS disk.wipe FULL — {devname}\n" + f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n" + ) + + # Start the wipe job + try: + tn_job_id = await _get_client().wipe_disk(devname, "FULL") + except Exception as exc: + await _set_stage_error(job_id, "surface_validate", f"Failed to start disk.wipe: {exc}") + return False + + await _append_stage_log( + job_id, "surface_validate", + f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n" + ) + + # Poll until complete + log_flush_counter = 0 + while True: + if await _is_cancelled(job_id): + try: + await _get_client().abort_job(tn_job_id) + except Exception: + pass + return False + + await asyncio.sleep(POLL_INTERVAL) + + try: + job = await _get_client().get_job(tn_job_id) + except Exception as exc: + log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id}) + await _append_stage_log(job_id, "surface_validate", f"[poll error] {exc}\n") + continue + + if not job: + await _set_stage_error(job_id, "surface_validate", f"Wipe job {tn_job_id} not found") + return False + + state = job.get("state", "") + pct = int(job.get("progress", {}).get("percent", 0) or 0) + desc = job.get("progress", {}).get("description", "") + + await _update_stage_percent(job_id, "surface_validate", min(pct, 99)) + await _recalculate_progress(job_id) + _push_update() + + # Log progress description every ~5 polls to avoid DB spam + log_flush_counter += 1 + if desc and log_flush_counter % 5 == 0: + await _append_stage_log(job_id, "surface_validate", f"[{pct}%] {desc}\n") + + if state == "SUCCESS": + await _update_stage_percent(job_id, "surface_validate", 100) + await _append_stage_log( + job_id, "surface_validate", + f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n" + ) + # Post-wipe SMART check — catch any sectors that failed under write stress + if ssh_client.is_configured() and drive_id is not None: + await _append_stage_log( + job_id, "surface_validate", + "[CHECK] Running post-wipe SMART attribute check...\n" + ) + try: + attrs = await ssh_client.get_smart_attributes(devname) + await _store_smart_attrs(drive_id, attrs) + if attrs["failures"]: + error = "Post-wipe SMART check: " + "; ".join(attrs["failures"]) + await _set_stage_error(job_id, "surface_validate", error) + return False + if attrs["warnings"]: + await _append_stage_log( + job_id, "surface_validate", + "[WARNING] " + "; ".join(attrs["warnings"]) + "\n" + ) + await _append_stage_log( + job_id, "surface_validate", + f"[CHECK] SMART health: {attrs['health']} — no critical attributes.\n" + ) + except Exception as exc: + log.warning("Post-wipe SMART check failed: %s", exc) + await _append_stage_log( + job_id, "surface_validate", + f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n" + ) + return True + + elif state in ("FAILED", "ABORTED", "ERROR"): + error_msg = job.get("error") or f"Disk wipe failed (state={state})" + await _set_stage_error( + job_id, "surface_validate", + f"TrueNAS disk.wipe FAILED: {error_msg}" + ) + return False + # RUNNING or WAITING — keep polling + + +async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool: + """Simulate a timed stage with progress updates (mock / dev mode).""" + start = time.monotonic() + + while True: + if await _is_cancelled(job_id): + return False + + elapsed = time.monotonic() - start + pct = min(100, int(elapsed / duration_seconds * 100)) + + await _update_stage_percent(job_id, stage_name, pct) + await _recalculate_progress(job_id, None) + _push_update() + + if pct >= 100: + return True + + await asyncio.sleep(POLL_INTERVAL) + + +async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool: + """ + Verify drive passed all tests. + SSH mode: run smartctl -a and check critical attributes. + Mock mode: check SMART health field in DB. + + A transient SSH connectivity failure here must NOT invalidate a prior + multi-day surface_validate. Retry SSH-only failures, then soft-pass. + """ + await asyncio.sleep(1) + from app import ssh_client + + def _ssh_only(failures: list[str]) -> bool: + return bool(failures) and all(f.startswith("SSH error:") for f in failures) + + if ssh_client.is_configured() and drive_id is not None: + try: + attrs = await ssh_client.get_smart_attributes(devname) + for attempt in range(2): + if not _ssh_only(attrs.get("failures") or []): + break + log.warning( + "final_check SSH unreachable (attempt %d/3); retrying in 30s", + attempt + 1, + extra={"job_id": job_id, "devname": devname}, + ) + await asyncio.sleep(30) + attrs = await ssh_client.get_smart_attributes(devname) + + failures = attrs.get("failures") or [] + if _ssh_only(failures): + log.warning( + "final_check soft-pass: SSH unreachable after retries; prior stages stand", + extra={"job_id": job_id, "devname": devname, "ssh_error": failures}, + ) + return True + + await _store_smart_attrs(drive_id, attrs) + if attrs["health"] == "FAILED" or failures: + msg = failures or [f"SMART health: {attrs['health']}"] + await _set_stage_error(job_id, "final_check", + "Final check failed: " + "; ".join(msg)) + return False + return True + except Exception as exc: + log.warning("SSH final_check raised, falling back to DB check: %s", exc) + + # DB check (mock mode fallback) + async with _db() as db: + cur = await db.execute( + "SELECT smart_health FROM drives WHERE devname=?", (devname,) + ) + row = await cur.fetchone() + + if not row or row[0] == "FAILED": + await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in") + return False + + return True diff --git a/app/config.py b/app/config.py index 90bd3bd..3ee5f30 100644 --- a/app/config.py +++ b/app/config.py @@ -83,7 +83,7 @@ class Settings(BaseSettings): ssh_key: str = "" # PEM private key content (paste full key including headers) # Application version — used by the /api/v1/updates/check endpoint - app_version: str = "1.0.0-30" + app_version: str = "1.0.0-31" # ---- Authentication (1.0.0-22) ---- # session_secret: HMAC key for signing session cookies. Empty = generate