nas-burnin/app/burnin.py
Brandon Walter 6c20e57fd8
Some checks are pending
Security scan / pip-audit (push) Waiting to run
Security scan / bandit (push) Waiting to run
Security scan / gitleaks (push) Waiting to run
fix: live pool re-check before start_job + drop dead run_badblocks (1.0.0-29)
Closes the last open Codex finding (#5) and removes one piece of dead
code Codex flagged in passing.

#5 — Live pool re-check before burn-in start:
  Before this change, _is_unlocked compared the operator's unlock grant
  against the cached drives.pool_* row. If a drive was imported into a
  pool, mounted, or had ZFS labels written between the operator's
  unlock click and the next ~12s poll, burn-in could still start
  against the stale identity and silently destroy the new pool.

  start_job now calls a fresh ssh_client.fresh_pool_check_for_drive()
  immediately after the cached gate. That helper re-runs the three
  detection probes (zpool list -vHP / lsblk zfs_member / findmnt) over
  a fresh SSH session and returns the live answer for one devname.
  If it differs from cached state we invalidate any existing unlock
  grant and raise PoolMemberError with the FRESH pool name so the UI
  reflects current reality. If fresh shows free but cached said locked,
  the drive came free since the last poll — log it and allow.

  Cost: ~200ms per burn-in start. For batch starts of 12 drives, that's
  2.4s extra latency — cheap against destroying a freshly-imported pool.

Dead code removal:
  ssh_client.run_badblocks() — no callers since 1.0.0-13 when the SSH
  badblocks logic was inlined into burnin._stage_surface_validate_ssh
  (with the asyncssh-signal-doesn't-actually-kill workaround). Removing
  the dead function also lets us drop the now-unused
  `from typing import Callable` import.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-02 21:29:11 -04:00

1667 lines
66 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Burn-in orchestrator.
Manages a FIFO queue of burn-in jobs capped at MAX_PARALLEL_BURNINS concurrent
executions. Each job runs stages sequentially; a failed stage aborts the job.
State is persisted to SQLite throughout — DB is source of truth.
On startup:
- Any 'running' jobs from a previous run are marked 'unknown' (interrupted).
- Any 'queued' jobs are re-enqueued automatically.
Cancellation:
- cancel_job() sets DB state to 'cancelled'.
- Running stage coroutines check _is_cancelled() at POLL_INTERVAL boundaries
and abort within a few seconds of the cancel request.
"""
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
import aiosqlite
from app.config import settings
from app.truenas import TrueNASClient
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Stage definitions
# ---------------------------------------------------------------------------
# Preset stage lists, keyed by burn-in profile name. start_job() uses
# STAGE_ORDER[profile] whenever the caller does not pass a custom
# stage_order. Every preset starts with "precheck" and ends with
# "final_check".
STAGE_ORDER: dict[str, list[str]] = {
    # Legacy
    "quick": ["precheck", "short_smart", "io_validate", "final_check"],
    # Single-stage selectable profiles
    "surface": ["precheck", "surface_validate", "final_check"],
    "short": ["precheck", "short_smart", "final_check"],
    "long": ["precheck", "long_smart", "final_check"],
    # Two-stage combos
    "surface_short": ["precheck", "surface_validate", "short_smart", "final_check"],
    "surface_long": ["precheck", "surface_validate", "long_smart", "final_check"],
    "short_long": ["precheck", "short_smart", "long_smart", "final_check"],
    # All three
    "full": ["precheck", "surface_validate", "short_smart", "long_smart", "final_check"],
}
# Per-stage base weights used to compute overall job % progress dynamically.
# Keys are the same stage names that appear in STAGE_ORDER.
_STAGE_BASE_WEIGHTS: dict[str, int] = {
    "precheck": 5,
    "surface_validate": 65,
    "short_smart": 12,
    "long_smart": 13,
    "io_validate": 10,
    "final_check": 5,
}
POLL_INTERVAL = 5.0 # seconds between progress checks during active stages
# ---------------------------------------------------------------------------
# Module-level state (initialized in init())
# ---------------------------------------------------------------------------
# Concurrency limiter for running jobs; None until init() creates it from
# settings.max_parallel_burnins.
_semaphore: asyncio.Semaphore | None = None
# TrueNAS API client handed to init(); None until then.
_client: TrueNASClient | None = None
# Live job tracking — keeps a strong reference to every _run_job task so it
# isn't garbage-collected (asyncio.create_task only keeps a weak ref) and so
# cancel_job / check_stuck_jobs can actually unwedge a stuck task.
# Keyed by burn-in job id.
_active_tasks: dict[int, "asyncio.Task"] = {}
# Remote PID of any long-running SSH child process (currently only badblocks)
# so we can kill it via a fresh SSH session — proc.kill() over asyncssh sends
# a "signal" channel request that OpenSSH sshd ignores by default, leaving
# the remote process running and proc.wait() hanging forever.
# Keyed by burn-in job id; consumed by _kill_remote_process().
_remote_pids: dict[int, int] = {}
def _now() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
@asynccontextmanager
async def _db():
    """Yield an aiosqlite connection to settings.db_path.

    A 10 000 ms busy_timeout is applied first so that a writer waits for the
    lock instead of immediately raising 'database is locked' under
    contention. The connection is closed when the context exits.
    """
    async with aiosqlite.connect(settings.db_path) as conn:
        await conn.execute("PRAGMA busy_timeout=10000")
        yield conn
# ---------------------------------------------------------------------------
# Init + startup reconciliation
# ---------------------------------------------------------------------------
async def init(client: TrueNASClient) -> None:
    """Initialise module state and reconcile jobs left over from a prior run.

    Stores the TrueNAS client, creates the concurrency semaphore sized by
    settings.max_parallel_burnins, marks every job still recorded as
    'running' as 'unknown', and re-spawns a runner task for each job that
    is still 'queued' (oldest first).
    """
    global _semaphore, _client
    _semaphore = asyncio.Semaphore(settings.max_parallel_burnins)
    _client = client

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
            await db.execute(pragma)

        # Jobs that were mid-flight when the process stopped cannot be
        # resumed; record them as interrupted.
        interrupted_at = _now()
        await db.execute(
            "UPDATE burnin_jobs SET state='unknown', finished_at=? WHERE state='running'",
            (interrupted_at,),
        )

        # Jobs that never started are simply queued again, in creation order.
        pending = await db.execute(
            "SELECT id FROM burnin_jobs WHERE state='queued' ORDER BY created_at"
        )
        queued = []
        for record in await pending.fetchall():
            queued.append(record["id"])
        await db.commit()

    for queued_id in queued:
        _spawn_run_job(queued_id)

    log.info("Burn-in orchestrator ready (max_concurrent=%d)", settings.max_parallel_burnins)
def _spawn_run_job(job_id: int) -> "asyncio.Task":
    """Schedule a _run_job task and keep a strong reference to it.

    Plain asyncio.create_task() only leaves a weak reference behind, so the
    task can be GC'd before it ever runs. Storing it in _active_tasks also
    lets cancel_job / check_stuck_jobs cancel it directly.

    Returns the newly created task. A done-callback removes the
    bookkeeping entries for this job once the task finishes.
    """
    task = asyncio.create_task(_run_job(job_id))
    _active_tasks[job_id] = task

    def _cleanup(t: "asyncio.Task") -> None:
        # Remove only if it's still us — avoid clobbering a re-enqueued task.
        if _active_tasks.get(job_id) is t:
            _active_tasks.pop(job_id, None)
        # NOTE(review): the original nesting of this line could not be
        # recovered; it is assumed to run unconditionally — confirm.
        _remote_pids.pop(job_id, None)

    task.add_done_callback(_cleanup)
    return task
async def _kill_remote_process(job_id: int) -> None:
    """Best-effort `kill -9` of the remote PID recorded for this job.

    asyncssh's proc.kill() sends an SSH 'signal' channel request which
    OpenSSH's sshd does not honor by default, so the kill is issued as a
    shell command over a fresh SSH session instead. The PID entry is
    removed from _remote_pids whether or not the kill succeeds. Failures
    are logged and never raised.
    """
    remote_pid = _remote_pids.pop(job_id, None)
    if not remote_pid:
        return

    kill_cmd = f"kill -9 {remote_pid} 2>/dev/null || true"
    try:
        from app import ssh_client

        async with await ssh_client._connect() as conn:
            # Bound the wait so an unresponsive host cannot wedge the caller.
            await asyncio.wait_for(conn.run(kill_cmd, check=False), timeout=10)
        log.info("Remote-killed PID %d for job %d", remote_pid, job_id)
    except Exception as exc:
        log.warning("Failed to remote-kill PID %d for job %d: %s", remote_pid, job_id, exc)
# ---------------------------------------------------------------------------
# Pool-drive unlock state
# ---------------------------------------------------------------------------
#
# Drives that ZFS reports as belonging to an active zpool (including the
# boot pool) are locked from burn-in until the operator explicitly unlocks
# them via POST /api/v1/drives/{id}/unlock. Grants live in memory only —
# a container restart wipes them, which is the right default for "this is
# very dangerous." TTL is bounded so an unlock you forgot about can't sit
# armed indefinitely.
import time as _time
from dataclasses import dataclass
# Lifetime of an unlock grant, added to the current time when it is issued.
UNLOCK_TTL_SECONDS = 600 # 10 minutes
# pool_name value that marks a drive as a boot-pool member.
BOOT_POOL_NAME = "boot-pool"
# Text the operator must supply as confirm_token to unlock a boot-pool drive.
BOOT_POOL_CONFIRM_TOKEN = "DESTROY BOOT POOL"
# pool_role value that selects the exported-pool confirmation token.
EXPORTED_POOL_ROLE = "exported"
EXPORTED_CONFIRM_TOKEN = "DESTROY EXPORTED POOL"
# pool_role value that selects the mounted-filesystem confirmation token.
MOUNTED_ROLE = "mounted"
MOUNTED_CONFIRM_TOKEN = "DESTROY MOUNTED FILESYSTEM"
@dataclass
class _UnlockGrant:
    """An operator-issued, time-bounded permission to burn-in a pool drive.

    The grant is BOUND to the (pool_name, pool_role) observed at unlock
    time. If a subsequent poll reclassifies the drive — e.g. it was
    "(exported)" when unlocked but is now in active pool "tank", or it
    used to be a cache vdev and now shows as data — the grant is
    invalidated. Otherwise the operator's "I confirm this exported drive
    is decommissioned" judgement would silently authorise destruction
    of a live pool.
    """
    # Absolute unix timestamp (time.time() scale) after which the grant is void.
    expiry: float
    # Pool name recorded for the drive when the grant was issued.
    pool_name: str
    # Pool role recorded for the drive when the grant was issued (may be None).
    pool_role: str | None


# Active grants keyed by drive id. In-memory only: written by
# grant_pool_unlock() and reaped by _is_unlocked() / unlock_expiry().
_unlock_grants: dict[int, _UnlockGrant] = {}
class PoolMemberError(Exception):
    """Signals that a drive belongs to a pool and has no valid unlock.

    Raised by start_job. The offending drive id, pool name and pool role
    are kept as attributes so callers can report them.
    """

    def __init__(self, drive_id: int, pool_name: str, pool_role: str | None):
        self.drive_id = drive_id
        self.pool_name = pool_name
        self.pool_role = pool_role

        # Boot-pool membership is called out in capitals in the message.
        pool_label = "BOOT POOL" if pool_name == BOOT_POOL_NAME else "pool"
        role_suffix = f" ({pool_role})" if pool_role else ""
        message = (
            f"Drive is part of {pool_label} '{pool_name}'{role_suffix}. "
            "Unlock required before burn-in."
        )
        super().__init__(message)
def _is_unlocked(drive_id: int, current_pool_name: str | None,
                 current_pool_role: str | None) -> bool:
    """Report whether the drive holds a live unlock grant for its current pool.

    Returns True only when a grant exists, has not expired, and was issued
    for exactly the (pool_name, pool_role) passed in. An expired or
    mismatched grant is removed as a side effect; a mismatch is also logged.
    """
    grant = _unlock_grants.get(drive_id)
    if grant is None:
        return False

    expired = _time.time() >= grant.expiry
    if expired:
        _unlock_grants.pop(drive_id, None)
        return False

    granted_identity = (grant.pool_name, grant.pool_role)
    if granted_identity == (current_pool_name, current_pool_role):
        return True

    # Pool identity changed since unlock — the drive may now belong to a
    # different (or live) pool. Drop the grant so the operator has to
    # unlock again against the current state.
    _unlock_grants.pop(drive_id, None)
    log.warning(
        "Invalidating unlock grant for drive_id=%d: pool changed from "
        "(%s, %s) to (%s, %s)",
        drive_id, grant.pool_name, grant.pool_role,
        current_pool_name, current_pool_role,
    )
    return False
def unlock_expiry(drive_id: int, current_pool_name: str | None,
                  current_pool_role: str | None) -> float | None:
    """Return the absolute expiry of an active grant, or None.

    Same identity-binding semantics as _is_unlocked: a grant whose stored
    pool identity no longer matches the current row is treated as expired
    and reaped. This is what the dashboard reads to decide whether to show
    the unlocked-Burn-In affordance vs the locked-Unlock affordance.

    The validity decision is delegated to _is_unlocked so that both entry
    points share one gate. In particular, a grant reaped here because the
    pool identity changed now produces the same warning log line as one
    reaped by start_job, instead of disappearing silently.

    Args:
        drive_id: id of the drive row.
        current_pool_name: pool name currently recorded for the drive.
        current_pool_role: pool role currently recorded for the drive.

    Returns:
        The grant's unix expiry timestamp, or None when there is no valid
        grant.
    """
    # Capture the grant first: _is_unlocked removes it when it is invalid.
    grant = _unlock_grants.get(drive_id)
    if grant is None:
        return None
    if not _is_unlocked(drive_id, current_pool_name, current_pool_role):
        return None
    return grant.expiry
async def grant_pool_unlock(drive_id: int, confirm_token: str,
                            operator: str, reason: str) -> float:
    """Check the operator's confirmation and arm a time-limited unlock grant.

    Raises ValueError when the reason is shorter than 5 characters, the
    operator is blank, the drive does not exist, the drive is not in a
    pool, or confirm_token does not match the expected token. On success
    an audit event is committed, the in-memory grant is stored, and the
    unix expiry timestamp is returned.
    """
    if not reason or len(reason.strip()) < 5:
        raise ValueError("A reason of at least 5 characters is required.")
    if not operator or not operator.strip():
        raise ValueError("Operator name is required.")

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        lookup = await db.execute(
            "SELECT pool_name, pool_role, devname FROM drives WHERE id=?",
            (drive_id,),
        )
        drive = await lookup.fetchone()
        if not drive:
            raise ValueError("Drive not found.")

        pool_name, pool_role = drive["pool_name"], drive["pool_role"]
        if not pool_name:
            raise ValueError(
                "This drive is not part of any pool — no unlock needed."
            )

        # Boot-pool / exported / mounted-fs each get a dedicated,
        # harder-to-fat-finger token; an active data pool just needs its
        # own name typed. The audit event type follows the same split.
        if pool_name == BOOT_POOL_NAME:
            expected, evt = BOOT_POOL_CONFIRM_TOKEN, "boot_pool_drive_unlocked"
        elif pool_role == EXPORTED_POOL_ROLE:
            expected, evt = EXPORTED_CONFIRM_TOKEN, "exported_pool_drive_unlocked"
        elif pool_role == MOUNTED_ROLE:
            expected, evt = MOUNTED_CONFIRM_TOKEN, "mounted_drive_unlocked"
        else:
            expected, evt = pool_name, "pool_drive_unlocked"

        supplied = (confirm_token or "").strip()
        if supplied != expected:
            raise ValueError("Confirmation token does not match.")

        audit_message = (
            f"Unlocked {pool_name} drive {drive['devname']} for burn-in: {reason.strip()}"
        )
        await db.execute(
            """INSERT INTO audit_events
            (event_type, drive_id, burnin_job_id, operator, message)
            VALUES (?,?,?,?,?)""",
            (evt, drive_id, None, operator.strip(), audit_message),
        )
        await db.commit()

    # The grant is armed only after the audit row has been committed: if
    # the commit raises we never get here, so there is no unaudited active
    # unlock. It records the (pool_name, pool_role) read above so a later
    # reclassification of the drive invalidates it (see _is_unlocked).
    expiry = _time.time() + UNLOCK_TTL_SECONDS
    _unlock_grants[drive_id] = _UnlockGrant(
        expiry=expiry,
        pool_name=pool_name,
        pool_role=pool_role,
    )
    log.warning(
        "Pool-drive unlock granted: drive_id=%d pool=%s role=%s "
        "operator=%s reason=%r",
        drive_id, pool_name, pool_role, operator, reason,
    )
    return expiry
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def start_job(drive_id: int, profile: str, operator: str,
                    stage_order: list[str] | None = None) -> int:
    """Create and enqueue a burn-in job. Returns the new job ID.

    If stage_order is provided (e.g. ["short_smart","long_smart","surface_validate"]),
    the job runs those stages in that order (precheck and final_check are
    always prepended/appended). Otherwise the preset STAGE_ORDER[profile]
    is used.

    Args:
        drive_id: id of the row in the drives table.
        profile: profile name stored on the job; also the STAGE_ORDER key
            when stage_order is None.
        operator: name recorded on the job and its audit event.
        stage_order: optional custom list of middle stages.

    Returns:
        The id of the newly inserted burnin_jobs row.

    Raises:
        ValueError: the drive does not exist, or it already has a queued
            or running job.
        PoolMemberError: the drive is in a pool (per the cached row or the
            live SSH check) and has no matching unlock grant.
        KeyError: stage_order is None and profile is not in STAGE_ORDER.
    """
    now = _now()

    # Build the actual stage list
    if stage_order is not None:
        stages = ["precheck", *stage_order, "final_check"]
    else:
        stages = STAGE_ORDER[profile]

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA foreign_keys=ON")

        # Reject duplicate active burn-in for same drive
        cur = await db.execute(
            "SELECT COUNT(*) FROM burnin_jobs WHERE drive_id=? AND state IN ('queued','running')",
            (drive_id,),
        )
        if (await cur.fetchone())[0] > 0:
            raise ValueError("Drive already has an active burn-in job")

        cur = await db.execute(
            "SELECT pool_name, pool_role, devname FROM drives WHERE id=?", (drive_id,)
        )
        drow = await cur.fetchone()

        # Fail fast on an unknown drive. Without this the INSERT below
        # either trips an integrity error that would be reported as
        # "already has an active burn-in job", or creates a job that
        # _run_job abandons in the 'queued' state because it cannot find
        # the drive row.
        if drow is None:
            raise ValueError("Drive not found.")

        # Pool-membership gate: locked unless the operator explicitly
        # unlocked this drive via /api/v1/drives/{id}/unlock recently.
        # _is_unlocked also checks that the grant's stored (pool_name,
        # pool_role) still matches the cached row — a grant issued for an
        # exported drive doesn't carry over if the drive turns out to be
        # in an active pool on the next poll.
        if drow["pool_name"] and not _is_unlocked(
            drive_id, drow["pool_name"], drow["pool_role"]
        ):
            raise PoolMemberError(drive_id, drow["pool_name"], drow["pool_role"])

        # Re-check pool state OVER SSH right now, not against the cached
        # row. Defends against the poll window in which a drive could have
        # been imported into a pool, mounted, or had ZFS labels written
        # between the operator's unlock and the Start click.
        from app import ssh_client as _ssh
        if _ssh.is_configured():
            fresh = await _ssh.fresh_pool_check_for_drive(drow["devname"])
            cached = (
                {"pool": drow["pool_name"], "role": drow["pool_role"]}
                if drow["pool_name"] else None
            )
            if fresh != cached:
                # State changed since the last poll. Any unlock grant was
                # bound to the stale identity, so drop it either way.
                _unlock_grants.pop(drive_id, None)
                fresh_pool = fresh["pool"] if fresh else None
                fresh_role = fresh["role"] if fresh else None
                if fresh_pool:
                    # The drive is in a pool right now: refuse, reporting
                    # the FRESH identity so the operator sees reality.
                    raise PoolMemberError(drive_id, fresh_pool, fresh_role)
                # The fresh check shows the drive free although the cached
                # row said it was in a pool — it was just removed. Safe to
                # start; the stale grant is already gone so old
                # confirmations cannot be reused.
                log.warning(
                    "Live pool check for drive_id=%d (%s): cached=%s "
                    "fresh=None — drive came free since last poll, "
                    "allowing burn-in",
                    drive_id, drow["devname"], cached,
                )

        # Create job. The partial unique index uniq_active_burnin_per_drive
        # (database.py) is the actual race-stopper here: if two concurrent
        # /api/v1/burnin/start calls both pass the SELECT-COUNT check above,
        # only one INSERT can win; the loser raises IntegrityError, which
        # we surface with the same ValueError as the inline duplicate check.
        try:
            cur = await db.execute(
                """INSERT INTO burnin_jobs (drive_id, profile, state, percent, operator, created_at)
                VALUES (?,?,?,?,?,?) RETURNING id""",
                (drive_id, profile, "queued", 0, operator, now),
            )
            job_id = (await cur.fetchone())["id"]
        except aiosqlite.IntegrityError as exc:
            # Keep the underlying constraint failure attached for debugging.
            raise ValueError("Drive already has an active burn-in job") from exc

        # Create stage rows in the desired execution order
        for stage_name in stages:
            await db.execute(
                "INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)",
                (job_id, stage_name, "pending"),
            )
        await db.execute(
            """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
            VALUES (?,?,?,?,?)""",
            ("burnin_queued", drive_id, job_id, operator, f"Queued {profile} burn-in"),
        )
        await db.commit()

    _spawn_run_job(job_id)
    log.info("Burn-in job %d queued (drive_id=%d profile=%s operator=%s)",
             job_id, drive_id, profile, operator)
    return job_id
async def cancel_job(job_id: int, operator: str) -> bool:
    """Cancel a queued or running job.

    Returns False when the job does not exist or is not in the 'queued' or
    'running' state; otherwise marks the job and its unfinished stages as
    cancelled, records an audit event, stops the job's remote process and
    task, and returns True.
    """
    cancellable_states = ("queued", "running")

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        lookup = await db.execute(
            "SELECT state, drive_id FROM burnin_jobs WHERE id=?", (job_id,)
        )
        job = await lookup.fetchone()
        if not job or job["state"] not in cancellable_states:
            return False

        await db.execute(
            "UPDATE burnin_jobs SET state='cancelled', finished_at=? WHERE id=?",
            (_now(), job_id),
        )
        await db.execute(
            "UPDATE burnin_stages SET state='cancelled' WHERE burnin_job_id=? AND state IN ('pending','running')",
            (job_id,),
        )
        audit_row = ("burnin_cancelled", job["drive_id"], job_id, operator, "Cancelled by operator")
        await db.execute(
            """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
            VALUES (?,?,?,?,?)""",
            audit_row,
        )
        await db.commit()

    # Order matters: the remote child is killed first so a proc.wait() in
    # the running task can return, and only then is the task cancelled to
    # unblock whatever else it is awaiting.
    await _kill_remote_process(job_id)
    runner = _active_tasks.get(job_id)
    if runner and not runner.done():
        runner.cancel()

    log.info("Burn-in job %d cancelled by %s", job_id, operator)
    return True
# ---------------------------------------------------------------------------
# Job runner
# ---------------------------------------------------------------------------
async def _thermal_gate_ok() -> bool:
    """True if it's thermally safe to start a new burn-in.

    Checks the peak recorded temperature of drives that currently have a
    'running' burn-in job against settings.temp_warn_c.

    Returns:
        True when no running drive has a recorded temperature, when the
        peak is below the warning threshold, or when the check itself
        fails. The gate deliberately fails open so a database problem
        never blocks the queue; the failure is logged rather than
        swallowed silently.
    """
    try:
        async with _db() as db:
            cur = await db.execute("""
                SELECT MAX(d.temperature_c)
                FROM drives d
                JOIN burnin_jobs bj ON bj.drive_id = d.id
                WHERE bj.state = 'running' AND d.temperature_c IS NOT NULL
            """)
            row = await cur.fetchone()
        # MAX() over zero rows yields NULL, i.e. nothing is running hot.
        max_temp = row[0] if row else None
        return max_temp is None or max_temp < settings.temp_warn_c
    except Exception as exc:
        # Never block on error, but leave a trace so a broken gate is visible.
        log.warning("Thermal gate check failed, not blocking: %s", exc)
        return True
# Strong references to fire-and-forget notification tasks. The event loop
# only keeps weak references to tasks, so an unreferenced task can be
# garbage-collected before it finishes; each task removes itself when done.
_notify_tasks: set["asyncio.Task"] = set()


async def _run_job(job_id: int) -> None:
    """Acquire semaphore slot, execute all stages, persist final state.

    Sequence:
      1. Wait (up to 3 minutes) for the thermal gate to clear.
      2. Take a semaphore slot; stop if the job was cancelled meanwhile.
      3. Mark the job 'running' and read its stage list from the database.
      4. Run the stages.
      5. Unless something else already moved the job out of 'running',
         write the final state ('passed', 'failed', or 'unknown' when the
         task itself was cancelled), push an SSE alert, and schedule the
         completion notification.

    Returns nothing. If the job row or its drive row cannot be found the
    function returns without changing the job's state.
    """
    assert _semaphore is not None, "burnin.init() not called"

    # Adaptive thermal gate: wait before competing for a slot if running drives
    # are already at or above the warning threshold. This prevents layering a
    # new burn-in on top of a thermally-stressed system. Gives up after 3 min
    # and proceeds anyway so jobs don't queue indefinitely.
    for _attempt in range(18):  # 18 × 10 s = 3 min max
        if await _thermal_gate_ok():
            break
        if _attempt == 0:
            log.info(
                "Thermal gate: job %d waiting — running drive temps at or above %d°C",
                job_id, settings.temp_warn_c,
            )
        await asyncio.sleep(10)
    else:
        log.warning("Thermal gate timed out for job %d — proceeding anyway", job_id)

    async with _semaphore:
        if await _is_cancelled(job_id):
            return

        # Transition queued → running
        async with _db() as db:
            await db.execute("PRAGMA journal_mode=WAL")
            row = await (await db.execute(
                "SELECT drive_id, profile FROM burnin_jobs WHERE id=?", (job_id,)
            )).fetchone()
            if not row:
                return
            drive_id, profile = row[0], row[1]

            cur = await db.execute("SELECT devname, serial, model FROM drives WHERE id=?", (drive_id,))
            devname_row = await cur.fetchone()
            if not devname_row:
                return
            devname = devname_row[0]
            drive_serial = devname_row[1]
            drive_model = devname_row[2]

            await db.execute(
                "UPDATE burnin_jobs SET state='running', started_at=? WHERE id=?",
                (_now(), job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
                VALUES (?,?,?,(SELECT operator FROM burnin_jobs WHERE id=?),?)""",
                ("burnin_started", drive_id, job_id, job_id, f"Started {profile} burn-in on {devname}"),
            )
            # Read stage order from DB (respects any custom order set at job creation)
            stage_cur = await db.execute(
                "SELECT stage_name FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
                (job_id,),
            )
            job_stages = [r[0] for r in await stage_cur.fetchall()]
            await db.commit()

        _push_update()
        log.info("Burn-in started", extra={"job_id": job_id, "devname": devname, "profile": profile})

        success = False
        error_text = None
        was_cancelled = False
        try:
            success = await _execute_stages(job_id, job_stages, devname, drive_id)
        except asyncio.CancelledError:
            # Swallowed on purpose so the final state can still be recorded.
            was_cancelled = True
        except Exception as exc:
            error_text = str(exc)
            log.exception("Burn-in raised exception", extra={"job_id": job_id, "devname": devname})

        # If the job has already moved to a terminal state — by cancel_job
        # ('cancelled') or check_stuck_jobs ('unknown') — leave it alone. The
        # task may have been cancelled mid-stage; finalizing as 'failed' would
        # clobber that audit-meaningful terminal state.
        async with _db() as db:
            cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
            cur_row = await cur.fetchone()
            if cur_row and cur_row[0] != "running":
                return

        # Cancellation arriving here means the asyncio task was cancelled
        # by something other than cancel_job/check_stuck_jobs (shutdown,
        # uvicorn reload, future code paths). The DB still says 'running',
        # so we have to write *some* terminal state, but classifying the
        # interrupted job as 'failed' would lie — we don't actually know
        # whether the underlying SMART/badblocks work passed or not.
        if was_cancelled:
            final_state = "unknown"
        else:
            final_state = "passed" if success else "failed"

        async with _db() as db:
            await db.execute("PRAGMA journal_mode=WAL")
            await db.execute(
                "UPDATE burnin_jobs SET state=?, percent=?, finished_at=?, error_text=? WHERE id=?",
                (final_state, 100 if success else None, _now(), error_text, job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
                VALUES (?,?,?,(SELECT operator FROM burnin_jobs WHERE id=?),?)""",
                (f"burnin_{final_state}", drive_id, job_id, job_id,
                 f"Burn-in {final_state} on {devname}"),
            )
            await db.commit()

        # Build SSE alert for browser notifications
        alert = {
            "state": final_state,
            "job_id": job_id,
            "devname": devname,
            "serial": drive_serial,
            "model": drive_model,
            "error_text": error_text,
        }
        _push_update(alert=alert)
        log.info("Burn-in finished", extra={"job_id": job_id, "devname": devname, "state": final_state})

        # Fire webhook + immediate email in background (non-blocking)
        try:
            from app import notifier

            # Both lookups share one connection.
            bad_blocks = 0
            async with _db() as ndb:
                ndb.row_factory = aiosqlite.Row
                job_cur = await ndb.execute(
                    "SELECT profile, operator FROM burnin_jobs WHERE id=?", (job_id,)
                )
                job_row = await job_cur.fetchone()
                if job_row:
                    # Get bad_blocks count from surface_validate stage if present
                    bb_cur = await ndb.execute(
                        "SELECT bad_blocks FROM burnin_stages WHERE burnin_job_id=? AND stage_name='surface_validate'",
                        (job_id,)
                    )
                    bb_row = await bb_cur.fetchone()
                    if bb_row and bb_row[0]:
                        bad_blocks = bb_row[0]

            if job_row:
                notify_task = asyncio.create_task(notifier.notify_job_complete(
                    job_id=job_id,
                    devname=devname,
                    serial=drive_serial,
                    model=drive_model,
                    state=final_state,
                    profile=job_row["profile"],
                    operator=job_row["operator"],
                    error_text=error_text,
                    bad_blocks=bad_blocks,
                ))
                # Hold a reference until the task completes so it cannot be
                # collected mid-flight.
                _notify_tasks.add(notify_task)
                notify_task.add_done_callback(_notify_tasks.discard)
        except Exception as exc:
            log.error("Failed to schedule notifications: %s", exc)
async def _execute_stages(job_id: int, stages: list[str], devname: str, drive_id: int) -> bool:
    """Run the given stages in order, stopping at the first that does not pass.

    Returns True only when every stage returned success. Returns False if
    the job is found cancelled before a stage, if a stage raises (the stage
    is recorded as failed with the exception text), or if a stage returns a
    falsy result (recorded as cancelled when the job is cancelled,
    otherwise as failed).
    """
    for stage_name in stages:
        if await _is_cancelled(job_id):
            return False

        await _start_stage(job_id, stage_name)
        _push_update()

        try:
            ok = await _dispatch_stage(job_id, stage_name, devname, drive_id)
        except Exception as exc:
            log.error(
                "Stage raised exception: %s", exc,
                extra={"job_id": job_id, "devname": devname, "stage": stage_name},
            )
            await _finish_stage(job_id, stage_name, success=False, error_text=str(exc))
            _push_update()
            return False

        # A stage that stopped because of a cancellation is recorded as
        # cancelled rather than failed.
        aborted_by_cancel = (not ok) and await _is_cancelled(job_id)
        if aborted_by_cancel:
            await _cancel_stage(job_id, stage_name)
        else:
            await _finish_stage(job_id, stage_name, success=ok)

        await _recalculate_progress(job_id)
        _push_update()

        if not ok:
            return False

    return True
async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool:
    """Run the implementation that belongs to stage_name.

    Args:
        job_id: burn-in job the stage belongs to.
        stage_name: one of the stage names used in STAGE_ORDER.
        devname: device name of the drive under test.
        drive_id: id of the drive row.

    Returns:
        The boolean result of the stage implementation.

    Raises:
        ValueError: stage_name is not a known stage. An unknown name used
            to fall through and count as a pass, so a typo in a custom
            stage list produced a passing stage that tested nothing. The
            caller (_execute_stages) records a raised exception as a
            failed stage.
    """
    if stage_name == "precheck":
        return await _stage_precheck(job_id, drive_id)
    if stage_name == "short_smart":
        return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id)
    if stage_name == "long_smart":
        return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id)
    if stage_name == "surface_validate":
        return await _stage_surface_validate(job_id, devname, drive_id)
    if stage_name == "io_validate":
        return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds)
    if stage_name == "final_check":
        return await _stage_final_check(job_id, devname, drive_id)
    # Fail closed: never report success for a stage that did not run.
    raise ValueError(f"Unknown burn-in stage: {stage_name!r}")
# ---------------------------------------------------------------------------
# Individual stage implementations
# ---------------------------------------------------------------------------
async def _stage_precheck(job_id: int, drive_id: int) -> bool:
    """Gate on the drive's recorded SMART health and temperature.

    Returns False when the drive row is missing, when smart_health is
    "FAILED", or when the recorded temperature is above
    settings.temp_crit_c (the last two also set a stage error). Otherwise
    pauses for one second and returns True.
    """
    async with _db() as db:
        lookup = await db.execute(
            "SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,)
        )
        record = await lookup.fetchone()

    if not record:
        return False

    health = record[0]
    temp = record[1]

    failure_reason = None
    if health == "FAILED":
        failure_reason = "Drive SMART health is FAILED — refusing to burn in"
    elif temp and temp > settings.temp_crit_c:
        failure_reason = f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit"

    if failure_reason is not None:
        await _set_stage_error(job_id, "precheck", failure_reason)
        return False

    await asyncio.sleep(1)  # Simulate brief check
    return True
async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str,
                            drive_id: int | None = None) -> bool:
    """Run a SMART test over SSH when SSH is configured, else via the REST API.

    drive_id is only forwarded to the SSH implementation.
    """
    from app import ssh_client

    use_ssh = ssh_client.is_configured()
    if not use_ssh:
        return await _stage_smart_test_api(job_id, devname, test_type, stage_name)
    return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id)
async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool:
    """SMART test through the TrueNAS REST client (mock / dev mode).

    Starts the test, then polls the client's SMART job list every
    POLL_INTERVAL seconds. Returns True on state "SUCCESS"; returns False
    when the burn-in job is cancelled (after a best-effort abort), when the
    TrueNAS job disappears from the list, or on state "FAILED"/"ABORTED"
    (which also records a stage error).
    """
    tn_job_id = await _client.start_smart_test([devname], test_type)
    failure_states = ("FAILED", "ABORTED")

    while True:
        if await _is_cancelled(job_id):
            # Best effort: the burn-in is being abandoned either way.
            try:
                await _client.abort_job(tn_job_id)
            except Exception:
                pass
            return False

        job = None
        for candidate in await _client.get_smart_jobs():
            if candidate["id"] == tn_job_id:
                job = candidate
                break
        if not job:
            return False

        state = job["state"]
        pct = job["progress"]["percent"]
        await _update_stage_percent(job_id, stage_name, pct)
        await _recalculate_progress(job_id, None)
        _push_update()

        if state == "SUCCESS":
            return True
        if state in failure_states:
            reason = job.get("error") or f"SMART {test_type} test failed"
            await _set_stage_error(job_id, stage_name, reason)
            return False

        await asyncio.sleep(POLL_INTERVAL)
async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str,
                                drive_id: int | None) -> bool:
    """SSH path for SMART test — runs smartctl directly on TrueNAS.

    Starts the test through ssh_client, then polls its progress every
    POLL_INTERVAL seconds until it reports "passed" or "failed" or the job
    is cancelled.

    Returns True when the test passes and, if drive_id is given, the
    follow-up attribute check reports no failures. Returns False when the
    test cannot be started, is cancelled, reports "failed", or the
    attribute check lists failures. A poll that raises is logged and
    retried on the next interval.
    """
    from app import ssh_client
    # Start the test
    try:
        startup = await ssh_client.start_smart_test(devname, test_type)
        await _append_stage_log(job_id, stage_name, startup + "\n")
    except Exception as exc:
        await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}")
        return False
    # Brief pause to let the test register in smartctl output
    await asyncio.sleep(3)
    # Throttle log_text appends — every poll on a multi-hour long_smart bloated
    # log_text to 50+ MB and triggered SQLite "database is locked" because each
    # COALESCE-then-append rewrites the whole column. Append every ~60s, on the
    # first poll, and on any state change.
    LOG_EVERY_N_POLLS = 12
    poll_count = 0
    last_state: str | None = None
    # Poll until complete
    while True:
        if await _is_cancelled(job_id):
            # Best-effort abort of the drive's self-test; errors are ignored
            # because the stage is being abandoned regardless.
            try:
                await ssh_client.abort_smart_test(devname)
            except Exception:
                pass
            return False
        await asyncio.sleep(POLL_INTERVAL)
        try:
            progress = await ssh_client.poll_smart_progress(devname)
        except Exception as exc:
            # A failed poll does not fail the stage; it is retried next cycle.
            log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id})
            await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n")
            continue
        # Only successful polls count toward the log throttle.
        poll_count += 1
        state_changed = progress["state"] != last_state
        last_state = progress["state"]
        if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed:
            await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n")
        if progress["state"] == "running":
            # Progress is reported as percent remaining; convert to percent
            # done and never go below zero.
            pct = max(0, 100 - progress["percent_remaining"])
            await _update_stage_percent(job_id, stage_name, pct)
            await _recalculate_progress(job_id)
            _push_update()
        elif progress["state"] == "passed":
            await _update_stage_percent(job_id, stage_name, 100)
            # Run attribute check
            if drive_id is not None:
                try:
                    attrs = await ssh_client.get_smart_attributes(devname)
                    await _store_smart_attrs(drive_id, attrs)
                    await _store_smart_raw_output(drive_id, test_type, attrs["raw_output"])
                    if attrs["failures"]:
                        # Attribute failures fail the stage even though the
                        # self-test itself passed.
                        error = "SMART attribute failures: " + "; ".join(attrs["failures"])
                        await _set_stage_error(job_id, stage_name, error)
                        return False
                    if attrs["warnings"]:
                        await _append_stage_log(
                            job_id, stage_name,
                            "[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
                        )
                except Exception as exc:
                    # Failing to fetch or store attributes does not fail
                    # the stage.
                    log.warning("Failed to retrieve SMART attributes: %s", exc)
            await _recalculate_progress(job_id)
            _push_update()
            return True
        elif progress["state"] == "failed":
            await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed")
            return False
        # "unknown" → keep polling
async def _badblocks_available() -> bool:
    """Report whether `which badblocks` succeeds on the remote host.

    Any error while connecting or running the command counts as "not
    available".
    """
    from app import ssh_client

    try:
        async with await ssh_client._connect() as conn:
            probe = await conn.run("which badblocks", check=False)
            found = probe.returncode == 0
            return found
    except Exception:
        return False
async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool:
    """
    Surface validation stage — picks an implementation and delegates to it.

    Selection order:
      1. SSH not configured:
         → simulated timed progress (dev/mock mode).
      2. devname starts with "nvme" and nvme-cli is found on the host:
         → `nvme format -s 1 /dev/{devname}` (_stage_surface_validate_nvme).
      3. badblocks is found on the host:
         → badblocks over SSH (_stage_surface_validate_ssh).
      4. Otherwise (badblocks missing, e.g. TrueNAS CORE/FreeBSD):
         → TrueNAS disk.wipe API (_stage_surface_validate_truenas).

    Returns whatever the selected implementation returns.
    """
    from app import ssh_client

    if not ssh_client.is_configured():
        return await _stage_timed_simulate(
            job_id, "surface_validate", settings.surface_validate_seconds
        )

    is_nvme = devname.startswith("nvme")
    if is_nvme and await _nvme_cli_available():
        return await _stage_surface_validate_nvme(job_id, devname, drive_id)

    if await _badblocks_available():
        return await _stage_surface_validate_ssh(job_id, devname, drive_id)

    # No badblocks binary on the host: record why we are switching to the
    # native wipe API before handing over to it.
    fallback_note = (
        "[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — "
        "using TrueNAS disk.wipe API (FULL write pass).\n\n"
    )
    await _append_stage_log(job_id, "surface_validate", fallback_note)
    return await _stage_surface_validate_truenas(job_id, devname, drive_id)
async def _nvme_cli_available() -> bool:
    """Report whether the `nvme` CLI is on the remote host's PATH.

    Runs ``which nvme`` over a fresh SSH session. Any exception while
    connecting or probing is reported as "not available".
    """
    from app import ssh_client

    try:
        async with await ssh_client._connect() as session:
            probe = await session.run("which nvme", check=False)
            found = probe.returncode == 0
        return found
    except Exception:
        return False
async def _stage_surface_validate_nvme(job_id: int, devname: str,
                                       drive_id: int) -> bool:
    """NVMe destructive surface test via `nvme format -s 1`.

    `-s 1` selects Secure Erase Setting 1 (User Data Erase) of the NVMe
    Format NVM command; Cryptographic Erase would be `-s 2`. Either way all
    data on the device is destroyed. This path is used instead of badblocks
    so the whole LBA space does not have to be written from the host.

    After a successful format the SMART attributes are re-read:
      * health == "FAILED"                → stage fails
      * any failure not prefixed with "SSH error:" → stage fails
      * "SSH error:" failures / exceptions → logged, soft pass
      * warnings                           → logged only

    Args:
        job_id: burn-in job whose "surface_validate" stage row is updated.
        devname: device name without the /dev/ prefix.
        drive_id: not used by this implementation; accepted for parity with
            the other surface_validate implementations.

    Returns:
        True when the format exited 0 and the post-format check found no
        real failure; False otherwise (stage error text is set first).
    """
    from app import ssh_client

    await _append_stage_log(
        job_id, "surface_validate",
        f"[START] nvme format -s 1 /dev/{devname}\n"
        f"[NOTE] Secure erase (SES=1, User Data Erase) — "
        f"destroys all data on /dev/{devname}.\n\n"
    )
    cmd = f"nvme format -s 1 --force /dev/{devname}"
    try:
        async with await ssh_client._connect() as conn:
            r = await asyncio.wait_for(
                conn.run(cmd, check=False), timeout=600
            )
    except Exception as exc:
        # asyncio.TimeoutError (and a few others) stringify to "", which
        # used to produce a blank error message; fall back to the type name.
        detail = str(exc) or type(exc).__name__
        await _append_stage_log(
            job_id, "surface_validate", f"\n[SSH error] {detail}\n"
        )
        await _set_stage_error(
            job_id, "surface_validate", f"NVMe format SSH error: {detail}"
        )
        return False

    output = (r.stdout or "") + (r.stderr or "")
    await _append_stage_log(job_id, "surface_validate", output + "\n")
    if r.returncode != 0:
        await _set_stage_error(
            job_id, "surface_validate",
            f"nvme format exited {r.returncode}: {output.strip()[:200]}"
        )
        return False

    # Sanity-check post-format SMART health — fail on FAILED health, fail on
    # real SMART attribute failures, log warnings but don't fail. A transport
    # error here is treated as a soft pass (log + continue) so a single
    # SSH blip after a successful format doesn't undo the work.
    try:
        attrs = await ssh_client.get_smart_attributes(devname)
        all_failures = attrs.get("failures") or []
        ssh_only_failures = [
            f for f in all_failures if f.startswith("SSH error:")
        ]
        real_failures = [
            f for f in all_failures if not f.startswith("SSH error:")
        ]
        if attrs.get("health") == "FAILED":
            await _set_stage_error(
                job_id, "surface_validate",
                "NVMe SMART health FAILED after format",
            )
            return False
        if real_failures:
            await _set_stage_error(
                job_id, "surface_validate",
                "NVMe SMART attribute failures after format: "
                + "; ".join(real_failures),
            )
            return False
        if ssh_only_failures:
            await _append_stage_log(
                job_id, "surface_validate",
                "[WARN] post-format SMART check had SSH errors "
                "(soft-passing): " + "; ".join(ssh_only_failures) + "\n",
            )
        if attrs.get("warnings"):
            await _append_stage_log(
                job_id, "surface_validate",
                "[WARN] " + "; ".join(attrs["warnings"]) + "\n",
            )
    except Exception as exc:
        log.warning("Post-format SMART check error on %s: %s", devname, exc)
        await _append_stage_log(
            job_id, "surface_validate",
            f"[WARN] post-format SMART check raised: {exc}\n",
        )

    await _update_stage_percent(job_id, "surface_validate", 100)
    await _recalculate_progress(job_id)
    _push_update()
    return True
async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool:
    """Run badblocks over SSH, streaming output to stage log.

    stdout lines consisting only of digits are counted as bad blocks;
    stderr lines matching "<n>% done" drive the stage percentage (capped at
    99 until the run finishes). Output is persisted to the stage log in
    batches of roughly 20 lines. The run is killed out-of-band (via
    _kill_remote_process) when the bad-block count exceeds
    settings.bad_block_threshold or the job is cancelled.

    Args:
        job_id: burn-in job whose "surface_validate" stage row is updated.
        devname: device name without the /dev/ prefix.
        drive_id: not used by this implementation; accepted for parity with
            the other surface_validate implementations.

    Returns:
        True when the run ended with the bad-block count at or below the
        threshold; False on SSH error, task cancellation, or too many bad
        blocks.

    NOTE(review): the badblocks exit status is not inspected, so a run that
    exits early without printing bad blocks is reported as a pass — confirm
    whether the caller covers that case.
    """
    import re as _re

    from app import ssh_client

    await _append_stage_log(
        job_id, "surface_validate",
        f"[START] badblocks -wsv -b {settings.surface_validate_block_size} "
        f"-c {settings.surface_validate_block_buffer} "
        f"-p {settings.surface_validate_passes} /dev/{devname}\n"
        f"[NOTE] This is a DESTRUCTIVE write test. "
        f"All data on /dev/{devname} will be overwritten.\n\n"
    )

    progress_re = _re.compile(r"([\d.]+)%\s+done")
    bad_blocks_total = 0
    # Lines received but not yet written to the DB. Flushed lines are dropped
    # so memory stays bounded and nothing is ever stored twice or skipped
    # (the previous modulo-20 bookkeeping could lose the trailing lines,
    # including the [ABORTED] marker).
    pending_log: list[str] = []

    async def _flush_log(min_lines: int = 1) -> None:
        """Persist buffered output once at least *min_lines* are pending."""
        if len(pending_log) < max(1, min_lines):
            return
        chunk = "".join(pending_log)
        # Clear before awaiting so the sibling drain cannot re-flush these.
        pending_log.clear()
        await _append_stage_log(job_id, "surface_validate", chunk)

    try:
        async with await ssh_client._connect() as conn:
            # Wrap in `sh -c 'echo PID:$$; exec ...'` so we get the remote
            # PID on the first stdout line. asyncssh's proc.kill() sends an
            # SSH signal request that OpenSSH's sshd ignores by default, so
            # we need the PID to issue an out-of-band `kill -9` over a fresh
            # session when we want to abort.
            #
            # Block geometry is operator-tunable (Settings → Burn-in):
            #   -b N  block size in bytes   (settings.surface_validate_block_size)
            #   -c N  blocks held per IO    (settings.surface_validate_block_buffer)
            #   -p N  pass count            (settings.surface_validate_passes)
            bb_args = (
                f"-wsv "
                f"-b {settings.surface_validate_block_size} "
                f"-c {settings.surface_validate_block_buffer} "
                f"-p {settings.surface_validate_passes}"
            )
            cmd = (
                f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'"
            )
            async with conn.create_process(cmd) as proc:
                pid_seen = False

                async def _drain(stream, is_stderr: bool):
                    nonlocal bad_blocks_total, pid_seen
                    async for raw in stream:
                        line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace")
                        # First stdout line is "PID:<n>" from the wrapping shell.
                        # Capture it and don't append it to the user-visible log.
                        if not is_stderr and not pid_seen and line.startswith("PID:"):
                            pid_seen = True
                            try:
                                _remote_pids[job_id] = int(line[4:].strip())
                                log.info(
                                    "Captured remote PID %d for job %d (badblocks)",
                                    _remote_pids[job_id], job_id,
                                )
                            except ValueError:
                                pass
                            continue
                        pending_log.append(line)
                        if is_stderr:
                            m = progress_re.search(line)
                            if m:
                                pct = min(99, int(float(m.group(1))))
                                await _update_stage_percent(job_id, "surface_validate", pct)
                                await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks_total)
                                await _recalculate_progress(job_id)
                                _push_update()
                        else:
                            stripped = line.strip()
                            if stripped and stripped.isdigit():
                                bad_blocks_total += 1
                        # Append to DB log in chunks
                        await _flush_log(20)
                        # Abort on bad block threshold
                        if bad_blocks_total > settings.bad_block_threshold:
                            await _kill_remote_process(job_id)
                            pending_log.append(
                                f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded "
                                f"threshold ({settings.bad_block_threshold})\n"
                            )
                            return
                        if await _is_cancelled(job_id):
                            await _kill_remote_process(job_id)
                            return

                drain_results = await asyncio.gather(
                    _drain(proc.stdout, False),
                    _drain(proc.stderr, True),
                    return_exceptions=True,
                )
                # return_exceptions=True keeps one failing drain from tearing
                # down the other, but the failure should still be visible.
                for outcome in drain_results:
                    if isinstance(outcome, Exception):
                        log.warning(
                            "badblocks output drain failed for job %d: %r",
                            job_id, outcome,
                        )
                # Bound proc.wait so a remote process that ignored our kill
                # signal (or that we never managed to kill) can't pin this
                # task in the semaphore forever. Closing the connection on
                # exit will deliver SIGPIPE to the remote on its next write.
                try:
                    await asyncio.wait_for(proc.wait(), timeout=15)
                except asyncio.TimeoutError:
                    log.warning(
                        "proc.wait() timed out for job %d — abandoning channel",
                        job_id,
                    )
                # Persist whatever has not been written yet (each line is
                # stored exactly once).
                await _flush_log()
    except asyncio.CancelledError:
        return False
    except Exception as exc:
        # Include any still-buffered output so it is not lost with the error.
        unsaved = "".join(pending_log)
        pending_log.clear()
        await _append_stage_log(
            job_id, "surface_validate", f"{unsaved}\n[SSH error] {exc}\n"
        )
        await _set_stage_error(job_id, "surface_validate", f"SSH badblocks error: {exc}")
        return False

    await _update_stage_bad_blocks(job_id, "surface_validate", bad_blocks_total)
    if bad_blocks_total > settings.bad_block_threshold:
        await _set_stage_error(
            job_id, "surface_validate",
            f"Surface validate FAILED: {bad_blocks_total} bad block(s) found "
            f"(threshold: {settings.bad_block_threshold})"
        )
        return False
    return True
async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool:
    """
    Surface validation via the TrueNAS disk.wipe REST API.

    Used when badblocks is unavailable on the host. Starts a FULL wipe job,
    polls it every POLL_INTERVAL seconds, then runs a post-wipe SMART
    attribute check (only when SSH is configured and drive_id is given).

    A post-wipe check that fails purely for SSH transport reasons (every
    failure prefixed "SSH error:") is logged and soft-passed, matching the
    NVMe path and the final_check stage: a connectivity blip must not
    invalidate a wipe that already completed.

    Args:
        job_id: burn-in job whose "surface_validate" stage row is updated.
        devname: disk name passed to the wipe API and to smartctl.
        drive_id: drives.id used to store the post-wipe SMART attributes.

    Returns:
        True when the wipe reached SUCCESS and the SMART check found no
        real failure; False on cancellation, wipe failure, a missing job,
        or real SMART failures.
    """
    from app import ssh_client

    await _append_stage_log(
        job_id, "surface_validate",
        f"[START] TrueNAS disk.wipe FULL — {devname}\n"
        f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n"
    )

    # Start the wipe job
    try:
        tn_job_id = await _client.wipe_disk(devname, "FULL")
    except Exception as exc:
        await _set_stage_error(job_id, "surface_validate", f"Failed to start disk.wipe: {exc}")
        return False

    await _append_stage_log(
        job_id, "surface_validate",
        f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n"
    )

    # Poll until complete
    log_flush_counter = 0
    while True:
        if await _is_cancelled(job_id):
            try:
                await _client.abort_job(tn_job_id)
            except Exception as exc:
                # Best effort: the job is cancelled on our side regardless,
                # but leave a trace that the remote wipe may still be running.
                log.warning(
                    "Failed to abort TrueNAS wipe job %s: %s", tn_job_id, exc,
                    extra={"job_id": job_id},
                )
            return False

        await asyncio.sleep(POLL_INTERVAL)

        try:
            job = await _client.get_job(tn_job_id)
        except Exception as exc:
            log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id})
            await _append_stage_log(job_id, "surface_validate", f"[poll error] {exc}\n")
            continue

        if not job:
            await _set_stage_error(job_id, "surface_validate", f"Wipe job {tn_job_id} not found")
            return False

        state = job.get("state", "")
        # "progress" may be present but null; treat that like an empty dict.
        progress = job.get("progress") or {}
        pct = int(progress.get("percent", 0) or 0)
        desc = progress.get("description", "")

        await _update_stage_percent(job_id, "surface_validate", min(pct, 99))
        await _recalculate_progress(job_id)
        _push_update()

        # Log progress description every ~5 polls to avoid DB spam
        log_flush_counter += 1
        if desc and log_flush_counter % 5 == 0:
            await _append_stage_log(job_id, "surface_validate", f"[{pct}%] {desc}\n")

        if state == "SUCCESS":
            await _update_stage_percent(job_id, "surface_validate", 100)
            await _append_stage_log(
                job_id, "surface_validate",
                f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n"
            )
            # Post-wipe SMART check — catch any sectors that failed under write stress
            if ssh_client.is_configured() and drive_id is not None:
                await _append_stage_log(
                    job_id, "surface_validate",
                    "[CHECK] Running post-wipe SMART attribute check...\n"
                )
                try:
                    attrs = await ssh_client.get_smart_attributes(devname)
                    failures = attrs["failures"]
                    ssh_only = bool(failures) and all(
                        f.startswith("SSH error:") for f in failures
                    )
                    if ssh_only:
                        # Transport problem, not a drive problem: don't
                        # overwrite stored SMART data and don't fail the wipe.
                        await _append_stage_log(
                            job_id, "surface_validate",
                            "[WARN] Post-wipe SMART check failed (non-fatal): "
                            + "; ".join(failures) + "\n"
                        )
                        return True
                    await _store_smart_attrs(drive_id, attrs)
                    if failures:
                        error = "Post-wipe SMART check: " + "; ".join(failures)
                        await _set_stage_error(job_id, "surface_validate", error)
                        return False
                    if attrs["warnings"]:
                        await _append_stage_log(
                            job_id, "surface_validate",
                            "[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
                        )
                    await _append_stage_log(
                        job_id, "surface_validate",
                        f"[CHECK] SMART health: {attrs['health']} — no critical attributes.\n"
                    )
                except Exception as exc:
                    log.warning("Post-wipe SMART check failed: %s", exc)
                    await _append_stage_log(
                        job_id, "surface_validate",
                        f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n"
                    )
            return True

        elif state in ("FAILED", "ABORTED", "ERROR"):
            error_msg = job.get("error") or f"Disk wipe failed (state={state})"
            await _set_stage_error(
                job_id, "surface_validate",
                f"TrueNAS disk.wipe FAILED: {error_msg}"
            )
            return False
        # RUNNING or WAITING — keep polling
async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool:
    """Simulate a timed stage with progress updates (mock / dev mode).

    Progress is elapsed wall time as a percentage of *duration_seconds*,
    re-published every POLL_INTERVAL seconds until it reaches 100.

    Args:
        job_id: burn-in job to update.
        stage_name: stage row whose percent is updated.
        duration_seconds: simulated length of the stage. A value <= 0
            completes immediately (previously 0 raised ZeroDivisionError and
            a negative value never terminated).

    Returns:
        True once 100% is reached, False if the job was cancelled first.
    """
    start = time.monotonic()
    while True:
        if await _is_cancelled(job_id):
            return False
        if duration_seconds > 0:
            elapsed = time.monotonic() - start
            pct = min(100, int(elapsed / duration_seconds * 100))
        else:
            pct = 100
        await _update_stage_percent(job_id, stage_name, pct)
        await _recalculate_progress(job_id, None)
        _push_update()
        if pct >= 100:
            return True
        await asyncio.sleep(POLL_INTERVAL)
async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool:
    """
    Decide whether the drive passes the final check.

    With SSH configured and a drive_id: read SMART attributes over SSH. If
    every reported failure is an "SSH error:" entry the read is retried up
    to two more times, 30s apart; if it is still SSH-only the stage
    soft-passes so a connectivity blip cannot invalidate earlier stages.
    Otherwise the attributes are stored and the stage fails on FAILED
    health or any failure entry.

    Without SSH (or if the SSH path raises): fall back to the smart_health
    column of the drives table and fail when the row is missing or FAILED.
    """
    await asyncio.sleep(1)
    from app import ssh_client

    def _transport_only(entries: list[str]) -> bool:
        # True only for a non-empty list made up entirely of SSH errors.
        if not entries:
            return False
        return all(entry.startswith("SSH error:") for entry in entries)

    use_ssh = ssh_client.is_configured() and drive_id is not None
    if use_ssh:
        try:
            smart = await ssh_client.get_smart_attributes(devname)
            retries_used = 0
            while retries_used < 2 and _transport_only(smart.get("failures") or []):
                retries_used += 1
                log.warning(
                    "final_check SSH unreachable (attempt %d/3); retrying in 30s",
                    retries_used,
                    extra={"job_id": job_id, "devname": devname},
                )
                await asyncio.sleep(30)
                smart = await ssh_client.get_smart_attributes(devname)

            failures = smart.get("failures") or []
            if _transport_only(failures):
                log.warning(
                    "final_check soft-pass: SSH unreachable after retries; prior stages stand",
                    extra={"job_id": job_id, "devname": devname, "ssh_error": failures},
                )
                return True

            await _store_smart_attrs(drive_id, smart)
            if smart["health"] != "FAILED" and not failures:
                return True
            reasons = failures or [f"SMART health: {smart['health']}"]
            await _set_stage_error(
                job_id, "final_check", "Final check failed: " + "; ".join(reasons)
            )
            return False
        except Exception as exc:
            log.warning("SSH final_check raised, falling back to DB check: %s", exc)

    # DB check (mock mode fallback)
    health_query = "SELECT smart_health FROM drives WHERE devname=?"
    async with _db() as conn:
        cursor = await conn.execute(health_query, (devname,))
        record = await cursor.fetchone()

    if record and record[0] != "FAILED":
        return True
    await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in")
    return False
# ---------------------------------------------------------------------------
# DB helpers
# ---------------------------------------------------------------------------
async def _is_cancelled(job_id: int) -> bool:
    """Return True when the burnin_jobs row for *job_id* is in state 'cancelled'.

    A missing row counts as not cancelled.
    """
    state_query = "SELECT state FROM burnin_jobs WHERE id=?"
    async with _db() as conn:
        cursor = await conn.execute(state_query, (job_id,))
        record = await cursor.fetchone()
    if not record:
        return False
    return bool(record[0] == "cancelled")
async def _start_stage(job_id: int, stage_name: str) -> None:
    """Mark a stage as running and record it as the job's current stage.

    Sets the stage row's state to 'running' with a fresh started_at, and
    writes *stage_name* into burnin_jobs.stage_name, in one commit.
    """
    mark_stage_running = (
        "UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?"
    )
    set_current_stage = "UPDATE burnin_jobs SET stage_name=? WHERE id=?"

    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(mark_stage_running, (_now(), job_id, stage_name))
        await conn.execute(set_current_stage, (stage_name, job_id))
        await conn.commit()
async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None:
    """Close out a stage row as 'passed' or 'failed'.

    Records finished_at, the duration since the stored started_at (None if
    that timestamp is missing or unparsable), and percent (100 on success,
    NULL on failure). error_text is only written when one is supplied, so
    an error the stage recorded itself is preserved.
    """
    finished_at = _now()
    final_state = "passed" if success else "failed"
    final_percent = 100 if success else None

    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        cursor = await conn.execute(
            "SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?",
            (job_id, stage_name),
        )
        started_row = await cursor.fetchone()

        elapsed = None
        if started_row and started_row[0]:
            try:
                began = datetime.fromisoformat(started_row[0])
                # Naive timestamps are interpreted as UTC.
                if began.tzinfo is None:
                    began = began.replace(tzinfo=timezone.utc)
                elapsed = (datetime.now(timezone.utc) - began).total_seconds()
            except Exception:
                pass

        if error_text is None:
            # No new error supplied: leave the existing error_text untouched.
            await conn.execute(
                "UPDATE burnin_stages\n"
                "SET state=?, percent=?, finished_at=?, duration_seconds=?\n"
                "WHERE burnin_job_id=? AND stage_name=?",
                (final_state, final_percent, finished_at, elapsed, job_id, stage_name),
            )
        else:
            await conn.execute(
                "UPDATE burnin_stages\n"
                "SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=?\n"
                "WHERE burnin_job_id=? AND stage_name=?",
                (final_state, final_percent, finished_at, elapsed, error_text, job_id, stage_name),
            )
        await conn.commit()
async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None:
    """Write *pct* into the percent column of the given job's stage row."""
    statement = "UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?"
    params = (pct, job_id, stage_name)
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(statement, params)
        await conn.commit()
async def _cancel_stage(job_id: int, stage_name: str) -> None:
    """Mark the given stage row 'cancelled' and stamp its finished_at."""
    cancelled_at = _now()
    statement = (
        "UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?"
    )
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(statement, (cancelled_at, job_id, stage_name))
        await conn.commit()
async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None:
    """Concatenate *text* onto the stage row's log_text (NULL is treated as '')."""
    statement = (
        "UPDATE burnin_stages\n"
        "SET log_text = COALESCE(log_text, '') || ?\n"
        "WHERE burnin_job_id=? AND stage_name=?"
    )
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(statement, (text, job_id, stage_name))
        await conn.commit()
async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None:
    """Write *count* into the bad_blocks column of the given job's stage row."""
    statement = "UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?"
    params = (count, job_id, stage_name)
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(statement, params)
        await conn.commit()
async def _store_smart_attrs(drive_id: int, attrs: dict) -> None:
    """Save the latest SMART result for a drive as JSON in drives.smart_attrs.

    The stored object has the keys "health", "attrs", "warnings" and
    "failures"; attribute keys are stringified so they are valid JSON keys.
    """
    import json

    attribute_map = attrs.get("attributes", {})
    payload = {
        "health": attrs.get("health", "UNKNOWN"),
        "attrs": {str(key): value for key, value in attribute_map.items()},
        "warnings": attrs.get("warnings", []),
        "failures": attrs.get("failures", []),
    }
    encoded = json.dumps(payload)

    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (encoded, drive_id))
        await conn.commit()
async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None:
    """Save raw smartctl output on the smart_tests row for this drive and test type.

    *test_type* is lower-cased before matching the row.
    """
    statement = "UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?"
    normalised_type = test_type.lower()
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(statement, (raw, drive_id, normalised_type))
        await conn.commit()
async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None:
    """Write *error_text* into the error_text column of the given job's stage row."""
    statement = "UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?"
    params = (error_text, job_id, stage_name)
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(statement, params)
        await conn.commit()
async def _recalculate_progress(job_id: int, profile: str | None = None) -> None:
    """Recompute the job's overall percent from its stage rows.

    Each stage contributes its weight from _STAGE_BASE_WEIGHTS (5 when not
    listed): fully when 'passed', proportionally to its percent when
    'running'. The result, together with the name of the last running
    stage (or None), is written to burnin_jobs. Nothing is written when the
    job has no stages or the weights sum to zero.

    *profile* is unused and kept only for call compatibility.
    """
    async with _db() as conn:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        cursor = await conn.execute(
            "SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
            (job_id,),
        )
        stage_rows = await cursor.fetchall()
        if not stage_rows:
            return

        weights = [_STAGE_BASE_WEIGHTS.get(r["stage_name"], 5) for r in stage_rows]
        weight_sum = sum(weights)
        if weight_sum == 0:
            return

        earned = 0.0
        running_stage = None
        for stage_row, weight in zip(stage_rows, weights):
            stage_state = stage_row["state"]
            if stage_state == "passed":
                earned += weight
            elif stage_state == "running":
                earned += weight * (stage_row["percent"] or 0) / 100
                running_stage = stage_row["stage_name"]

        overall = int(earned / weight_sum * 100)
        await conn.execute(
            "UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?",
            (overall, running_stage, job_id),
        )
        await conn.commit()
# ---------------------------------------------------------------------------
# SSE push
# ---------------------------------------------------------------------------
def _push_update(alert: dict | None = None) -> None:
"""Notify SSE subscribers that data has changed, with optional browser notification payload."""
try:
from app import poller
poller._notify_subscribers(alert=alert)
except Exception:
pass
# ---------------------------------------------------------------------------
# Stuck-job detection (called by poller every ~5 cycles)
# ---------------------------------------------------------------------------
async def check_stuck_jobs() -> None:
    """Mark jobs that have been 'running' beyond stuck_job_hours as 'unknown'.

    For every burnin_jobs row in state 'running' whose started_at is older
    than settings.stuck_job_hours, the job and its running stage rows are
    set to 'unknown', an audit_events row of type 'burnin_stuck' is
    inserted, the remote process is killed and the local asyncio task is
    cancelled. Subscribers are then notified. Returns without side effects
    when no job is stuck.
    """
    threshold_seconds = settings.stuck_job_hours * 3600

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        # Age is computed in SQL: julianday difference is in days, so
        # multiply by 86400 to compare against a threshold in seconds.
        cur = await db.execute("""
SELECT bj.id, bj.drive_id, d.devname, bj.started_at
FROM burnin_jobs bj
JOIN drives d ON d.id = bj.drive_id
WHERE bj.state = 'running'
AND bj.started_at IS NOT NULL
AND (julianday('now') - julianday(bj.started_at)) * 86400 > ?
""", (threshold_seconds,))
        stuck = await cur.fetchall()

        if not stuck:
            return

        # One timestamp for every row touched in this sweep.
        now = _now()
        for row in stuck:
            job_id, drive_id, devname, started_at = row[0], row[1], row[2], row[3]
            log.critical(
                "Stuck burn-in detected — marking unknown",
                extra={"job_id": job_id, "devname": devname, "started_at": started_at},
            )
            await db.execute(
                "UPDATE burnin_jobs SET state='unknown', finished_at=? WHERE id=?",
                (now, job_id),
            )
            # Only stages still 'running' are touched; finished stages keep
            # their recorded outcome.
            await db.execute(
                """UPDATE burnin_stages SET state='unknown', finished_at=?
WHERE burnin_job_id=? AND state='running'""",
                (now, job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
VALUES (?,?,?,?,?)""",
                ("burnin_stuck", drive_id, job_id, "system",
                 f"Job stuck for >{settings.stuck_job_hours}h — automatically marked unknown"),
            )
        # Single commit for all stuck jobs found in this sweep.
        await db.commit()

    # Actually unstick the running tasks so they release their semaphore slot.
    # Without this the DB state becomes 'unknown' but the asyncio task keeps
    # holding the slot forever — which is the bug that left subsequent jobs
    # permanently 'queued' until container restart.
    for row in stuck:
        job_id = row[0]
        await _kill_remote_process(job_id)
        task = _active_tasks.get(job_id)
        if task and not task.done():
            task.cancel()

    _push_update()
    log.warning("Marked %d stuck job(s) as unknown", len(stuck))