Closes the last open Codex finding (#5) and removes one piece of dead code Codex flagged in passing. #5 — Live pool re-check before burn-in start: Before this change, _is_unlocked compared the operator's unlock grant against the cached drives.pool_* row. If a drive was imported into a pool, mounted, or had ZFS labels written between the operator's unlock click and the next ~12s poll, burn-in could still start against the stale identity and silently destroy the new pool. start_job now calls ssh_client.fresh_pool_check_for_drive() immediately after the cached gate. That helper re-runs the three detection probes (zpool list -vHP / lsblk zfs_member / findmnt) over a fresh SSH session and returns the live answer for one devname. If the live answer differs from the cached state, we invalidate any existing unlock grant; if the live answer shows the drive in a pool, we raise PoolMemberError with the FRESH pool name so the UI reflects current reality. If the fresh check shows the drive free but the cached row said locked, the drive came free since the last poll — log it and allow. Cost: ~200ms per burn-in start. For batch starts of 12 drives, that's 2.4s of extra latency — cheap compared with destroying a freshly-imported pool. Dead code removal: ssh_client.run_badblocks() — no callers since 1.0.0-13, when the SSH badblocks logic was inlined into burnin._stage_surface_validate_ssh (with the asyncssh-signal-doesn't-actually-kill workaround). Removing the dead function also lets us drop the now-unused `from typing import Callable` import. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1667 lines
66 KiB
Python
1667 lines
66 KiB
Python
"""
|
||
Burn-in orchestrator.
|
||
|
||
Manages a FIFO queue of burn-in jobs capped at MAX_PARALLEL_BURNINS concurrent
|
||
executions. Each job runs stages sequentially; a failed stage aborts the job.
|
||
|
||
State is persisted to SQLite throughout — DB is source of truth.
|
||
|
||
On startup:
|
||
- Any 'running' jobs from a previous run are marked 'unknown' (interrupted).
|
||
- Any 'queued' jobs are re-enqueued automatically.
|
||
|
||
Cancellation:
|
||
- cancel_job() sets DB state to 'cancelled'.
|
||
- Running stage coroutines check _is_cancelled() at POLL_INTERVAL boundaries
|
||
and abort within a few seconds of the cancel request.
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime, timezone
|
||
|
||
import aiosqlite
|
||
|
||
from app.config import settings
|
||
from app.truenas import TrueNASClient
|
||
|
||
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Stage definitions
# ---------------------------------------------------------------------------

# Profile name -> ordered stage list. Every preset is bracketed by
# "precheck" first and "final_check" last; only the work stages in between
# differ from profile to profile.
STAGE_ORDER: dict[str, list[str]] = {
    profile_name: ["precheck", *work_stages, "final_check"]
    for profile_name, work_stages in (
        # Legacy
        ("quick", ("short_smart", "io_validate")),
        # Single work stage
        ("surface", ("surface_validate",)),
        ("short", ("short_smart",)),
        ("long", ("long_smart",)),
        # Two work stages
        ("surface_short", ("surface_validate", "short_smart")),
        ("surface_long", ("surface_validate", "long_smart")),
        ("short_long", ("short_smart", "long_smart")),
        # All three work stages
        ("full", ("surface_validate", "short_smart", "long_smart")),
    )
}

# Relative per-stage weights used to compute a job's overall % progress
# dynamically.
_STAGE_BASE_WEIGHTS: dict[str, int] = dict(
    precheck=5,
    surface_validate=65,
    short_smart=12,
    long_smart=13,
    io_validate=10,
    final_check=5,
)

POLL_INTERVAL = 5.0  # seconds between progress polls while a stage is active
|
||
|
||
# ---------------------------------------------------------------------------
# Module-level state (populated by init())
# ---------------------------------------------------------------------------

_semaphore: asyncio.Semaphore | None = None
_client: TrueNASClient | None = None

# job_id -> its live _run_job task. The event loop only holds a weak
# reference to a task returned by asyncio.create_task, so this dict is what
# keeps it from being garbage-collected; it is also the handle that
# cancel_job / check_stuck_jobs use to interrupt a wedged task.
_active_tasks: dict[int, "asyncio.Task"] = {}

# job_id -> remote PID of a long-running SSH child (today: badblocks only).
# asyncssh's proc.kill() sends an SSH "signal" channel request that OpenSSH
# sshd ignores by default, so the remote process would keep running and
# proc.wait() would hang forever. Tracking the PID lets us kill it from a
# fresh SSH session instead.
_remote_pids: dict[int, int] = {}
|
||
|
||
|
||
def _now() -> str:
|
||
return datetime.now(timezone.utc).isoformat()
|
||
|
||
|
||
@asynccontextmanager
async def _db():
    """Yield an aiosqlite connection that waits on locks instead of failing.

    ``busy_timeout=10000`` makes a writer wait up to 10 seconds for the
    database lock rather than raising 'database is locked' straight away
    when several coroutines contend. Journal mode is NOT set here; callers
    issue ``PRAGMA journal_mode=WAL`` themselves where they need it.
    """
    connection_cm = aiosqlite.connect(settings.db_path)
    async with connection_cm as conn:
        await conn.execute("PRAGMA busy_timeout=10000")
        yield conn
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Init + startup reconciliation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def init(client: TrueNASClient) -> None:
    """Set up module state and reconcile jobs left behind by a previous run.

    - Creates the concurrency semaphore (settings.max_parallel_burnins slots)
      and stores the TrueNAS client.
    - Jobs still marked 'running' were interrupted by the restart and are
      moved to 'unknown' with a finished_at timestamp.
    - Jobs still 'queued' are re-spawned in created_at order.
    """
    global _semaphore, _client
    _semaphore = asyncio.Semaphore(settings.max_parallel_burnins)
    _client = client

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
            await db.execute(pragma)

        # We cannot know how far an interrupted job got, so it is 'unknown'.
        await db.execute(
            "UPDATE burnin_jobs SET state='unknown', finished_at=? WHERE state='running'",
            (_now(),),
        )

        # Jobs that never started get picked up again, oldest first.
        pending_cur = await db.execute(
            "SELECT id FROM burnin_jobs WHERE state='queued' ORDER BY created_at"
        )
        pending_ids = [rec["id"] for rec in await pending_cur.fetchall()]
        await db.commit()

    for pending_id in pending_ids:
        _spawn_run_job(pending_id)

    log.info("Burn-in orchestrator ready (max_concurrent=%d)", settings.max_parallel_burnins)
|
||
|
||
|
||
def _spawn_run_job(job_id: int) -> "asyncio.Task":
|
||
"""Schedule a _run_job task and keep a strong reference to it.
|
||
|
||
Plain asyncio.create_task() only leaves a weak reference behind, so the
|
||
task can be GC'd before it ever runs. Storing it in _active_tasks also
|
||
lets cancel_job / check_stuck_jobs cancel it directly.
|
||
"""
|
||
task = asyncio.create_task(_run_job(job_id))
|
||
_active_tasks[job_id] = task
|
||
|
||
def _cleanup(t: "asyncio.Task") -> None:
|
||
# Remove only if it's still us — avoid clobbering a re-enqueued task.
|
||
if _active_tasks.get(job_id) is t:
|
||
_active_tasks.pop(job_id, None)
|
||
_remote_pids.pop(job_id, None)
|
||
|
||
task.add_done_callback(_cleanup)
|
||
return task
|
||
|
||
|
||
async def _kill_remote_process(job_id: int) -> None:
    """SIGKILL the remote PID tracked for this job, if there is one.

    asyncssh's proc.kill() sends an SSH 'signal' channel request, which
    OpenSSH's sshd does not honor by default. Running /bin/kill from a
    brand-new session is the reliable way to stop the process. The PID is
    removed from _remote_pids before the attempt. Failures (connection
    errors, the 10 s timeout) are logged and otherwise ignored.
    """
    remote_pid = _remote_pids.pop(job_id, None)
    if not remote_pid:
        return

    kill_cmd = f"kill -9 {remote_pid} 2>/dev/null || true"
    try:
        from app import ssh_client

        async with await ssh_client._connect() as conn:
            await asyncio.wait_for(conn.run(kill_cmd, check=False), timeout=10)
            log.info("Remote-killed PID %d for job %d", remote_pid, job_id)
    except Exception as exc:
        log.warning("Failed to remote-kill PID %d for job %d: %s", remote_pid, job_id, exc)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Pool-drive unlock state
# ---------------------------------------------------------------------------
#
# A drive that ZFS reports as a member of an active zpool — the boot pool
# included — cannot be burned in until an operator unlocks it through
# POST /api/v1/drives/{id}/unlock. Grants are held in memory only, so a
# container restart revokes every one of them (the right default for
# something this destructive). Each grant also carries a bounded TTL so a
# forgotten unlock cannot stay armed indefinitely.

import time as _time
from dataclasses import dataclass

UNLOCK_TTL_SECONDS = 600  # 10 minutes

# Pool classifications that demand a dedicated confirmation phrase rather
# than just the pool name.
BOOT_POOL_NAME = "boot-pool"
BOOT_POOL_CONFIRM_TOKEN = "DESTROY BOOT POOL"

EXPORTED_POOL_ROLE = "exported"
EXPORTED_CONFIRM_TOKEN = "DESTROY EXPORTED POOL"

MOUNTED_ROLE = "mounted"
MOUNTED_CONFIRM_TOKEN = "DESTROY MOUNTED FILESYSTEM"
|
||
|
||
|
||
@dataclass
|
||
class _UnlockGrant:
|
||
"""An operator-issued, time-bounded permission to burn-in a pool drive.
|
||
|
||
The grant is BOUND to the (pool_name, pool_role) observed at unlock
|
||
time. If a subsequent poll reclassifies the drive — e.g. it was
|
||
"(exported)" when unlocked but is now in active pool "tank", or it
|
||
used to be a cache vdev and now shows as data — the grant is
|
||
invalidated. Otherwise the operator's "I confirm this exported drive
|
||
is decommissioned" judgement would silently authorise destruction
|
||
of a live pool.
|
||
"""
|
||
expiry: float
|
||
pool_name: str
|
||
pool_role: str | None
|
||
|
||
|
||
_unlock_grants: dict[int, _UnlockGrant] = {}
|
||
|
||
|
||
class PoolMemberError(Exception):
    """Raised by start_job when a drive is in a zpool and not unlocked.

    Carries the drive id and the pool identity that blocked the start so
    callers can render it without re-parsing the message.
    """

    def __init__(self, drive_id: int, pool_name: str, pool_role: str | None):
        self.drive_id = drive_id
        self.pool_name = pool_name
        self.pool_role = pool_role

        pool_label = "BOOT POOL" if pool_name == BOOT_POOL_NAME else "pool"
        role_suffix = f" ({pool_role})" if pool_role else ""
        super().__init__(
            f"Drive is part of {pool_label} '{pool_name}'{role_suffix}. "
            "Unlock required before burn-in."
        )
|
||
|
||
|
||
def _is_unlocked(drive_id: int, current_pool_name: str | None,
                 current_pool_role: str | None) -> bool:
    """Report whether drive_id holds a live grant for its current pool identity.

    Expired grants are discarded silently. A grant whose recorded
    (pool_name, pool_role) no longer matches the current values is
    discarded with a warning: the drive may now belong to a different (or
    live) pool, and the operator has to re-unlock against what is actually
    there now.
    """
    grant = _unlock_grants.get(drive_id)
    if grant is None:
        return False

    if _time.time() >= grant.expiry:
        _unlock_grants.pop(drive_id, None)
        return False

    if (grant.pool_name, grant.pool_role) == (current_pool_name, current_pool_role):
        return True

    _unlock_grants.pop(drive_id, None)
    log.warning(
        "Invalidating unlock grant for drive_id=%d: pool changed from "
        "(%s, %s) to (%s, %s)",
        drive_id, grant.pool_name, grant.pool_role,
        current_pool_name, current_pool_role,
    )
    return False
|
||
|
||
|
||
def unlock_expiry(drive_id: int, current_pool_name: str | None,
                  current_pool_role: str | None) -> float | None:
    """Return the absolute expiry of an active grant, or None.

    Same identity-binding semantics as _is_unlocked: a grant whose stored
    pool identity no longer matches the current row is treated as expired
    and reaped. This is what the dashboard reads to decide whether to show
    the unlocked-Burn-In affordance vs the locked-Unlock affordance.

    Because the dashboard calls this on every refresh, it is usually the
    first reader to notice an identity change. It therefore logs the
    invalidation too; otherwise the grant would vanish with no trace and
    _is_unlocked would never get to log it.
    """
    grant = _unlock_grants.get(drive_id)
    if grant is None:
        return None
    if _time.time() >= grant.expiry:
        _unlock_grants.pop(drive_id, None)
        return None
    if grant.pool_name != current_pool_name or grant.pool_role != current_pool_role:
        _unlock_grants.pop(drive_id, None)
        log.warning(
            "Invalidating unlock grant for drive_id=%d: pool changed from "
            "(%s, %s) to (%s, %s)",
            drive_id, grant.pool_name, grant.pool_role,
            current_pool_name, current_pool_role,
        )
        return None
    return grant.expiry
|
||
|
||
|
||
def _unlock_token_and_event(pool_name: str, pool_role: str | None) -> tuple[str, str]:
    """Return (expected confirmation token, audit event type) for a pool identity.

    Boot-pool, exported-pool and mounted-filesystem drives get dedicated,
    harder-to-fat-finger tokens; active data pools only need their pool name
    typed. The boot-pool name check takes precedence over any role. Keeping
    token and event in one place means the two cannot drift apart.
    """
    if pool_name == BOOT_POOL_NAME:
        return BOOT_POOL_CONFIRM_TOKEN, "boot_pool_drive_unlocked"
    if pool_role == EXPORTED_POOL_ROLE:
        return EXPORTED_CONFIRM_TOKEN, "exported_pool_drive_unlocked"
    if pool_role == MOUNTED_ROLE:
        return MOUNTED_CONFIRM_TOKEN, "mounted_drive_unlocked"
    return pool_name, "pool_drive_unlocked"


async def grant_pool_unlock(drive_id: int, confirm_token: str,
                            operator: str, reason: str) -> float:
    """Validate confirmation token + reason and grant a time-limited unlock.

    Raises ValueError on bad confirm_token, missing reason, missing
    operator, unknown drive, or a drive that is not actually in a pool.
    Returns the unix expiry timestamp on success.

    The audit row is committed before the in-memory grant is armed, and the
    grant is bound to the (pool_name, pool_role) read from the DB here.
    """
    if not reason or len(reason.strip()) < 5:
        raise ValueError("A reason of at least 5 characters is required.")
    if not operator or not operator.strip():
        raise ValueError("Operator name is required.")
    # Use the same normalised values for the audit row and the log line.
    reason_text = reason.strip()
    operator_name = operator.strip()

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        cur = await db.execute(
            "SELECT pool_name, pool_role, devname FROM drives WHERE id=?",
            (drive_id,),
        )
        row = await cur.fetchone()
        if not row:
            raise ValueError("Drive not found.")
        pool_name = row["pool_name"]
        pool_role = row["pool_role"]
        if not pool_name:
            raise ValueError(
                "This drive is not part of any pool — no unlock needed."
            )

        expected_token, event_type = _unlock_token_and_event(pool_name, pool_role)
        if (confirm_token or "").strip() != expected_token:
            raise ValueError("Confirmation token does not match.")

        await db.execute(
            """INSERT INTO audit_events
               (event_type, drive_id, burnin_job_id, operator, message)
               VALUES (?,?,?,?,?)""",
            (event_type, drive_id, None, operator_name,
             f"Unlocked {pool_name} drive {row['devname']} for burn-in: {reason_text}"),
        )
        await db.commit()

    # Arm the in-memory grant ONLY after the audit row is durable: if the
    # commit above raises we never reach this point, so there are no
    # unaudited active unlocks. Binding the grant to the observed
    # (pool_name, pool_role) lets _is_unlocked / unlock_expiry drop it when
    # a later poll reclassifies the drive.
    expiry = _time.time() + UNLOCK_TTL_SECONDS
    _unlock_grants[drive_id] = _UnlockGrant(
        expiry=expiry,
        pool_name=pool_name,
        pool_role=pool_role,
    )

    log.warning(
        "Pool-drive unlock granted: drive_id=%d pool=%s role=%s "
        "operator=%s reason=%r",
        drive_id, pool_name, pool_role, operator_name, reason_text,
    )
    return expiry
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Public API
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def start_job(drive_id: int, profile: str, operator: str,
                    stage_order: list[str] | None = None) -> int:
    """Create and enqueue a burn-in job. Returns the new job ID.

    If stage_order is provided (e.g. ["short_smart", "long_smart",
    "surface_validate"]), the job runs those stages in that order; precheck
    and final_check are always prepended/appended. Otherwise the preset
    STAGE_ORDER[profile] is used.

    Raises:
        ValueError: unknown profile or stage name, drive not found, or the
            drive already has a queued/running burn-in job.
        PoolMemberError: the drive belongs to a pool, per the cached row or
            the live SSH re-check, and has no matching unlock grant.

    Exceptions raised by the live SSH pool check are not caught here. The
    gate fails closed rather than falling back to the cached row.
    """
    now = _now()

    # Build the stage list and reject names _dispatch_stage cannot run
    # (its known stages are exactly the keys of _STAGE_BASE_WEIGHTS).
    # An unrecognised stage must never be allowed to count as "passed".
    if stage_order is not None:
        unknown = [name for name in stage_order if name not in _STAGE_BASE_WEIGHTS]
        if unknown:
            raise ValueError(f"Unknown burn-in stage(s): {unknown}")
        stages = ["precheck", *stage_order, "final_check"]
    else:
        stages = STAGE_ORDER.get(profile)
        if stages is None:
            raise ValueError(f"Unknown burn-in profile: {profile!r}")

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA foreign_keys=ON")

        # Reject duplicate active burn-in for same drive
        cur = await db.execute(
            "SELECT COUNT(*) FROM burnin_jobs WHERE drive_id=? AND state IN ('queued','running')",
            (drive_id,),
        )
        if (await cur.fetchone())[0] > 0:
            raise ValueError("Drive already has an active burn-in job")

        cur = await db.execute(
            "SELECT pool_name, pool_role, devname FROM drives WHERE id=?", (drive_id,)
        )
        drow = await cur.fetchone()
        if drow is None:
            # Previously this skipped every pool gate and went straight to
            # the INSERT. That either tripped the foreign key (reported as the
            # misleading duplicate-job error below) or created a job that
            # _run_job abandons (it returns early when the drive row is
            # missing), leaving it 'queued' forever.
            raise ValueError("Drive not found.")

        # Normalised cached identity: "no pool" is (None, None).
        cached_pool = drow["pool_name"] or None
        cached_role = drow["pool_role"] if cached_pool else None

        # Pool-membership gate: locked unless the operator explicitly
        # unlocked this drive via /api/v1/drives/{id}/unlock recently.
        # _is_unlocked also checks that the grant's stored (pool_name,
        # pool_role) still matches the cached row — a grant issued for an
        # exported drive doesn't carry over if the drive turns out to be
        # in an active pool on the next poll.
        if cached_pool and not _is_unlocked(drive_id, drow["pool_name"], drow["pool_role"]):
            raise PoolMemberError(drive_id, drow["pool_name"], drow["pool_role"])

        # Re-check pool state OVER SSH right now, not against the cached row
        # (Codex finding #5). This covers the ~12s poll window in which a
        # drive could have been imported, mounted, or had ZFS labels written
        # after the operator unlocked it. It adds ~200ms per start.
        from app import ssh_client as _ssh
        if _ssh.is_configured():
            fresh = await _ssh.fresh_pool_check_for_drive(drow["devname"])
            # Compare only the pool identity, so extra keys in the probe
            # result cannot make an unchanged drive look changed.
            fresh_pool = (fresh or {}).get("pool") or None
            fresh_role = (fresh or {}).get("role") if fresh_pool else None
            if (fresh_pool, fresh_role) != (cached_pool, cached_role):
                # Identity changed since the last poll: any grant was bound
                # to the stale identity, so drop it either way.
                _unlock_grants.pop(drive_id, None)
                if fresh_pool:
                    raise PoolMemberError(drive_id, fresh_pool, fresh_role)
                # Fresh shows free while the cache said locked: the drive
                # left its pool since the last poll. Safe to start.
                log.warning(
                    "Live pool check for drive_id=%d (%s): cached=(%s, %s) "
                    "fresh=None — drive came free since last poll, "
                    "allowing burn-in",
                    drive_id, drow["devname"], cached_pool, cached_role,
                )

        # Create job. The partial unique index uniq_active_burnin_per_drive
        # (database.py) is the actual race-stopper here: if two concurrent
        # /api/v1/burnin/start calls both pass the SELECT-COUNT check above,
        # only one INSERT can win; the loser raises IntegrityError, which
        # we surface with the same ValueError as the inline duplicate check.
        try:
            cur = await db.execute(
                """INSERT INTO burnin_jobs (drive_id, profile, state, percent, operator, created_at)
                   VALUES (?,?,?,?,?,?) RETURNING id""",
                (drive_id, profile, "queued", 0, operator, now),
            )
            job_id = (await cur.fetchone())["id"]
        except aiosqlite.IntegrityError:
            raise ValueError("Drive already has an active burn-in job")

        # Create stage rows in the desired execution order
        for stage_name in stages:
            await db.execute(
                "INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)",
                (job_id, stage_name, "pending"),
            )

        await db.execute(
            """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
               VALUES (?,?,?,?,?)""",
            ("burnin_queued", drive_id, job_id, operator, f"Queued {profile} burn-in"),
        )
        await db.commit()

    _spawn_run_job(job_id)
    log.info("Burn-in job %d queued (drive_id=%d profile=%s operator=%s)",
             job_id, drive_id, profile, operator)
    return job_id
|
||
|
||
|
||
async def cancel_job(job_id: int, operator: str) -> bool:
    """Cancel a queued or running job. Returns True if state was changed.

    The job and its pending/running stages are marked 'cancelled' and an
    audit row is written in one commit. Afterwards, any tracked remote child
    process is killed and the job's asyncio task is cancelled. Returns False,
    touching nothing, if the job does not exist or is already terminal.
    """
    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")

        job_cur = await db.execute(
            "SELECT state, drive_id FROM burnin_jobs WHERE id=?", (job_id,)
        )
        job = await job_cur.fetchone()
        if not job or job["state"] not in ("queued", "running"):
            return False

        cancelled_at = _now()
        await db.execute(
            "UPDATE burnin_jobs SET state='cancelled', finished_at=? WHERE id=?",
            (cancelled_at, job_id),
        )
        await db.execute(
            "UPDATE burnin_stages SET state='cancelled' WHERE burnin_job_id=? AND state IN ('pending','running')",
            (job_id,),
        )
        await db.execute(
            """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
               VALUES (?,?,?,?,?)""",
            ("burnin_cancelled", job["drive_id"], job_id, operator, "Cancelled by operator"),
        )
        await db.commit()

    # Order matters: kill the remote child first so a proc.wait() inside the
    # running task can return, then cancel the task to unblock other awaits.
    await _kill_remote_process(job_id)
    running_task = _active_tasks.get(job_id)
    if running_task is not None and not running_task.done():
        running_task.cancel()

    log.info("Burn-in job %d cancelled by %s", job_id, operator)
    return True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Job runner
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _thermal_gate_ok() -> bool:
    """True if it's thermally safe to start a new burn-in.

    Takes the highest temperature_c among drives that currently have a
    'running' burn-in job and compares it to settings.temp_warn_c. If no
    running drive reports a temperature, the gate is open.

    Best-effort by design: any error opens the gate so a DB hiccup never
    stalls the queue. The error is logged instead of being swallowed
    silently.
    """
    try:
        async with _db() as db:
            cur = await db.execute("""
                SELECT MAX(d.temperature_c)
                FROM drives d
                JOIN burnin_jobs bj ON bj.drive_id = d.id
                WHERE bj.state = 'running' AND d.temperature_c IS NOT NULL
            """)
            row = await cur.fetchone()
        max_temp = row[0] if row and row[0] is not None else None
        return max_temp is None or max_temp < settings.temp_warn_c
    except Exception as exc:
        log.warning("Thermal gate check failed; not blocking start: %s", exc)
        return True  # Never block on error
|
||
|
||
|
||
# Strong references to fire-and-forget notification tasks. The event loop
# keeps only a weak reference to a task returned by asyncio.create_task(),
# so an unreferenced task can be garbage-collected before it finishes —
# the same reason _spawn_run_job keeps _active_tasks.
_notify_tasks: set["asyncio.Task"] = set()


async def _run_job(job_id: int) -> None:
    """Acquire a semaphore slot, execute all stages, persist the final state.

    Raises:
        RuntimeError: if init() has not been called (an explicit check, so it
            still fires when Python runs with -O and asserts are stripped).
    """
    if _semaphore is None:
        raise RuntimeError("burnin.init() not called")

    # Adaptive thermal gate: wait before competing for a slot if running drives
    # are already at or above the warning threshold. This prevents layering a
    # new burn-in on top of a thermally-stressed system. Gives up after 3 min
    # and proceeds anyway so jobs don't queue indefinitely.
    for attempt in range(18):  # 18 × 10 s = 3 min max
        if await _thermal_gate_ok():
            break
        if attempt == 0:
            log.info(
                "Thermal gate: job %d waiting — running drive temps at or above %d°C",
                job_id, settings.temp_warn_c,
            )
        await asyncio.sleep(10)
    else:
        log.warning("Thermal gate timed out for job %d — proceeding anyway", job_id)

    async with _semaphore:
        if await _is_cancelled(job_id):
            return

        # Transition queued → running
        async with _db() as db:
            await db.execute("PRAGMA journal_mode=WAL")
            row = await (await db.execute(
                "SELECT drive_id, profile FROM burnin_jobs WHERE id=?", (job_id,)
            )).fetchone()
            if not row:
                log.warning("Burn-in job %d not found in DB; not starting", job_id)
                return
            drive_id, profile = row[0], row[1]

            cur = await db.execute("SELECT devname, serial, model FROM drives WHERE id=?", (drive_id,))
            devname_row = await cur.fetchone()
            if not devname_row:
                log.warning("Burn-in job %d: drive_id=%d not found; not starting", job_id, drive_id)
                return
            devname, drive_serial, drive_model = devname_row[0], devname_row[1], devname_row[2]

            await db.execute(
                "UPDATE burnin_jobs SET state='running', started_at=? WHERE id=?",
                (_now(), job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
                   VALUES (?,?,?,(SELECT operator FROM burnin_jobs WHERE id=?),?)""",
                ("burnin_started", drive_id, job_id, job_id, f"Started {profile} burn-in on {devname}"),
            )
            # Read stage order from DB (respects any custom order set at job creation)
            stage_cur = await db.execute(
                "SELECT stage_name FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
                (job_id,),
            )
            job_stages = [r[0] for r in await stage_cur.fetchall()]
            await db.commit()

        _push_update()
        log.info("Burn-in started", extra={"job_id": job_id, "devname": devname, "profile": profile})

        success = False
        error_text = None
        was_cancelled = False
        try:
            success = await _execute_stages(job_id, job_stages, devname, drive_id)
        except asyncio.CancelledError:
            was_cancelled = True
        except Exception as exc:
            error_text = str(exc)
            log.exception("Burn-in raised exception", extra={"job_id": job_id, "devname": devname})

        # If the job has already moved to a terminal state — by cancel_job
        # ('cancelled') or check_stuck_jobs ('unknown') — leave it alone. The
        # task may have been cancelled mid-stage; finalizing as 'failed' would
        # clobber that audit-meaningful terminal state.
        async with _db() as db:
            cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
            cur_row = await cur.fetchone()
        if cur_row and cur_row[0] != "running":
            return

        # Cancellation arriving here means the asyncio task was cancelled
        # by something other than cancel_job/check_stuck_jobs (shutdown,
        # uvicorn reload, future code paths). The DB still says 'running',
        # so we have to write *some* terminal state, but classifying the
        # interrupted job as 'failed' would lie — we don't actually know
        # whether the underlying SMART/badblocks work passed or not.
        if was_cancelled:
            final_state = "unknown"
        else:
            final_state = "passed" if success else "failed"

        async with _db() as db:
            await db.execute("PRAGMA journal_mode=WAL")
            await db.execute(
                "UPDATE burnin_jobs SET state=?, percent=?, finished_at=?, error_text=? WHERE id=?",
                (final_state, 100 if success else None, _now(), error_text, job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
                   VALUES (?,?,?,(SELECT operator FROM burnin_jobs WHERE id=?),?)""",
                (f"burnin_{final_state}", drive_id, job_id, job_id,
                 f"Burn-in {final_state} on {devname}"),
            )
            await db.commit()

        # Build SSE alert for browser notifications
        alert = {
            "state": final_state,
            "job_id": job_id,
            "devname": devname,
            "serial": drive_serial,
            "model": drive_model,
            "error_text": error_text,
        }
        _push_update(alert=alert)
        log.info("Burn-in finished", extra={"job_id": job_id, "devname": devname, "state": final_state})

        await _schedule_completion_notification(
            job_id, devname, drive_serial, drive_model, final_state, error_text,
        )


async def _schedule_completion_notification(job_id: int, devname: str, serial, model,
                                            final_state: str, error_text: str | None) -> None:
    """Fire the webhook + immediate email for a finished job in the background.

    Reads the job's profile/operator and the surface_validate bad_blocks
    count (0 if absent), then schedules notifier.notify_job_complete as a
    task tracked in _notify_tasks. Best-effort: any error is logged and
    never affects the job's recorded outcome.
    """
    try:
        from app import notifier

        async with _db() as db:
            db.row_factory = aiosqlite.Row
            cur = await db.execute(
                "SELECT profile, operator FROM burnin_jobs WHERE id=?", (job_id,)
            )
            job_row = await cur.fetchone()
            if not job_row:
                return
            # Get bad_blocks count from surface_validate stage if present
            cur = await db.execute(
                "SELECT bad_blocks FROM burnin_stages WHERE burnin_job_id=? AND stage_name='surface_validate'",
                (job_id,),
            )
            bb_row = await cur.fetchone()
        bad_blocks = bb_row[0] if bb_row and bb_row[0] else 0

        task = asyncio.create_task(notifier.notify_job_complete(
            job_id=job_id,
            devname=devname,
            serial=serial,
            model=model,
            state=final_state,
            profile=job_row["profile"],
            operator=job_row["operator"],
            error_text=error_text,
            bad_blocks=bad_blocks,
        ))
        _notify_tasks.add(task)
        task.add_done_callback(_notify_tasks.discard)
    except Exception as exc:
        log.error("Failed to schedule notifications: %s", exc)
|
||
|
||
|
||
async def _execute_stages(job_id: int, stages: list[str], devname: str, drive_id: int) -> bool:
    """Run the stages in order, stopping at the first one that does not pass.

    Returns True only when every stage passed. Before each stage the job's
    cancel flag is checked. A stage that raises is finished as failed with
    the exception text. A stage that returns False while the job is
    cancelled is recorded as cancelled rather than failed.
    """
    for stage in stages:
        if await _is_cancelled(job_id):
            return False

        await _start_stage(job_id, stage)
        _push_update()

        try:
            passed = await _dispatch_stage(job_id, stage, devname, drive_id)
        except Exception as exc:
            log.error("Stage raised exception: %s", exc, extra={"job_id": job_id, "devname": devname, "stage": stage})
            await _finish_stage(job_id, stage, success=False, error_text=str(exc))
            _push_update()
            return False

        # The cancel flag is consulted only when the stage did not pass.
        if passed or not await _is_cancelled(job_id):
            await _finish_stage(job_id, stage, success=passed)
        else:
            await _cancel_stage(job_id, stage)
        await _recalculate_progress(job_id)
        _push_update()

        if not passed:
            return False

    return True
|
||
|
||
|
||
async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool:
    """Run a single named stage and return whether it passed.

    An unrecognised stage name now FAILS with a stage error. The previous
    fallback returned True, so a typo in a custom stage order made the job
    report "passed" for work that never ran.
    """
    if stage_name == "precheck":
        return await _stage_precheck(job_id, drive_id)
    elif stage_name == "short_smart":
        return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id)
    elif stage_name == "long_smart":
        return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id)
    elif stage_name == "surface_validate":
        return await _stage_surface_validate(job_id, devname, drive_id)
    elif stage_name == "io_validate":
        return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds)
    elif stage_name == "final_check":
        return await _stage_final_check(job_id, devname, drive_id)

    log.error("Unknown burn-in stage %r for job %d", stage_name, job_id)
    await _set_stage_error(job_id, stage_name, f"Unknown stage '{stage_name}' — cannot run")
    return False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Individual stage implementations
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _stage_precheck(job_id: int, drive_id: int) -> bool:
    """Check SMART health and temperature before starting destructive work.

    Reads the drive's cached smart_health and temperature_c. The stage fails,
    with an explanatory stage error, if the drive row is missing, health is
    "FAILED", or the temperature exceeds settings.temp_crit_c.
    """
    async with _db() as db:
        cur = await db.execute(
            "SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,)
        )
        row = await cur.fetchone()

    if not row:
        # Previously failed with no error text, leaving the operator guessing.
        await _set_stage_error(job_id, "precheck", "Drive record not found — refusing to burn in")
        return False

    health, temp = row[0], row[1]

    if health == "FAILED":
        await _set_stage_error(job_id, "precheck", "Drive SMART health is FAILED — refusing to burn in")
        return False

    if temp and temp > settings.temp_crit_c:
        await _set_stage_error(job_id, "precheck", f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit")
        return False

    await asyncio.sleep(1)  # Simulate brief check
    return True
|
||
|
||
|
||
async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str,
                            drive_id: int | None = None) -> bool:
    """Run a SMART self-test and report whether it passed.

    Goes through SSH (smartctl on the host) when ssh_client is configured;
    otherwise uses the TrueNAS REST API. drive_id is only forwarded to the
    SSH implementation.
    """
    from app import ssh_client

    if not ssh_client.is_configured():
        return await _stage_smart_test_api(job_id, devname, test_type, stage_name)
    return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id)
|
||
|
||
|
||
async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool:
    """TrueNAS REST API path for SMART test (mock / dev mode).

    Starts the test through ``_client``, then polls the TrueNAS SMART job list
    every ``POLL_INTERVAL`` seconds. Each poll copies the job's percent into
    the stage row, until the job reaches a terminal state or the burn-in job
    is cancelled (cancellation also tries to abort the TrueNAS job).

    Returns:
        True when the TrueNAS job reports ``SUCCESS``. False on start failure,
        FAILED/ABORTED, a job that disappeared, or cancellation. Every False
        except cancellation records a reason via ``_set_stage_error``.
    """
    try:
        tn_job_id = await _client.start_smart_test([devname], test_type)
    except Exception as exc:
        # Mirror the SSH path: a start failure is a stage error with a reason,
        # not an exception escaping the stage.
        await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via API: {exc}")
        return False

    while True:
        if await _is_cancelled(job_id):
            try:
                await _client.abort_job(tn_job_id)
            except Exception as exc:
                # Best-effort: the burn-in is cancelled regardless.
                log.debug("abort_job(%s) failed: %s", tn_job_id, exc, extra={"job_id": job_id})
            return False

        jobs = await _client.get_smart_jobs()
        job = next((j for j in jobs if j["id"] == tn_job_id), None)

        if not job:
            await _set_stage_error(job_id, stage_name, f"SMART {test_type} job {tn_job_id} not found")
            return False

        state = job["state"]
        # Tolerate a missing/None progress block instead of raising.
        pct = (job.get("progress") or {}).get("percent") or 0

        await _update_stage_percent(job_id, stage_name, pct)
        await _recalculate_progress(job_id, None)
        _push_update()

        if state == "SUCCESS":
            return True
        if state in ("FAILED", "ABORTED"):
            await _set_stage_error(job_id, stage_name,
                                   job.get("error") or f"SMART {test_type} test failed")
            return False

        await asyncio.sleep(POLL_INTERVAL)
|
||
|
||
|
||
async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str,
                                drive_id: int | None) -> bool:
    """SSH path for SMART test — runs smartctl directly on TrueNAS.

    Starts the self-test, then polls its progress every ``POLL_INTERVAL``
    seconds until smartctl reports passed or failed, or the job is cancelled.
    Cancellation also asks the host to abort the test.

    When the test passes and *drive_id* is given, the current SMART attributes
    are read and stored:
    - Real attribute failures fail the stage.
    - Warnings are only logged.
    - A read that produced nothing but SSH transport errors soft-passes,
      because a connectivity blip must not undo a test that already passed.

    Returns:
        True when the self-test passed and no real attribute failures were
        found, False otherwise.
    """
    from app import ssh_client

    # Start the test
    try:
        startup = await ssh_client.start_smart_test(devname, test_type)
        await _append_stage_log(job_id, stage_name, startup + "\n")
    except Exception as exc:
        await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}")
        return False

    # Brief pause to let the test register in smartctl output
    await asyncio.sleep(3)

    # Throttle log_text appends. Appending on every poll of a multi-hour
    # long_smart bloated log_text to 50+ MB and triggered SQLite "database is
    # locked", because each COALESCE-then-append rewrites the whole column.
    # Append every ~60s, on the first poll, and on any state change.
    LOG_EVERY_N_POLLS = 12
    poll_count = 0
    last_state: str | None = None

    # Poll until complete
    while True:
        if await _is_cancelled(job_id):
            try:
                await ssh_client.abort_smart_test(devname)
            except Exception as exc:
                # Best-effort: the burn-in is cancelled regardless.
                log.debug("abort_smart_test(%s) failed: %s", devname, exc, extra={"job_id": job_id})
            return False

        await asyncio.sleep(POLL_INTERVAL)

        try:
            progress = await ssh_client.poll_smart_progress(devname)
        except Exception as exc:
            log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id})
            await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n")
            continue

        state = progress["state"]
        poll_count += 1
        state_changed = state != last_state
        last_state = state
        if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed:
            await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n")

        if state == "running":
            pct = max(0, 100 - progress["percent_remaining"])
            await _update_stage_percent(job_id, stage_name, pct)
            await _recalculate_progress(job_id)
            _push_update()

        elif state == "passed":
            await _update_stage_percent(job_id, stage_name, 100)
            # Run attribute check
            if drive_id is not None:
                try:
                    attrs = await ssh_client.get_smart_attributes(devname)
                    failures = attrs.get("failures") or []
                    real_failures = [f for f in failures if not f.startswith("SSH error:")]
                    if failures and not real_failures:
                        # Transport errors only. Soft-pass, and don't store
                        # the error-only snapshot over the cached SMART data.
                        await _append_stage_log(
                            job_id, stage_name,
                            "[WARNING] post-test SMART attribute read had SSH errors "
                            "(soft-passing): " + "; ".join(failures) + "\n",
                        )
                    else:
                        await _store_smart_attrs(drive_id, attrs)
                        # A missing raw_output must not skip the failure check below.
                        raw = attrs.get("raw_output")
                        if raw is not None:
                            await _store_smart_raw_output(drive_id, test_type, raw)
                        if real_failures:
                            error = "SMART attribute failures: " + "; ".join(real_failures)
                            await _set_stage_error(job_id, stage_name, error)
                            return False
                        if attrs.get("warnings"):
                            await _append_stage_log(
                                job_id, stage_name,
                                "[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
                            )
                except Exception as exc:
                    log.warning("Failed to retrieve SMART attributes: %s", exc)
            await _recalculate_progress(job_id)
            _push_update()
            return True

        elif state == "failed":
            await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed")
            return False
        # "unknown" → keep polling
|
||
|
||
|
||
async def _badblocks_available() -> bool:
    """Check if badblocks is installed on the remote host (Linux/SCALE only).

    Runs ``which badblocks`` over a fresh SSH connection.

    Returns:
        True when the probe exits 0. Any SSH error also yields False. That
        sends the surface-validate dispatcher down the TrueNAS disk.wipe
        path, so the error is logged: otherwise a transient SSH failure on
        SCALE looks exactly like running on CORE.
    """
    from app import ssh_client

    try:
        async with await ssh_client._connect() as conn:
            result = await conn.run("which badblocks", check=False)
            return result.returncode == 0
    except Exception as exc:
        log.warning("badblocks availability probe failed (treating as unavailable): %s", exc)
        return False
|
||
|
||
|
||
async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool:
    """
    Surface-validation stage. Routes to one implementation, checked in order:

    - No SSH configured: simulated timed progress (dev / mock mode).
    - NVMe device and nvme-cli present on the host:
      `nvme format -s 1 --force` (SES=1 user-data erase). Seconds instead
      of hours, and it exercises the controller's erase path rather than
      only user-LBA writes.
    - badblocks present (TrueNAS SCALE / Linux): destructive
      `badblocks -wsv` over SSH with the operator-tuned geometry.
    - Otherwise (TrueNAS CORE / FreeBSD): the TrueNAS disk.wipe FULL job
      via the REST API, followed by a SMART attribute check.
    """
    from app import ssh_client

    if not ssh_client.is_configured():
        return await _stage_timed_simulate(job_id, "surface_validate", settings.surface_validate_seconds)

    is_nvme = devname.startswith("nvme")
    if is_nvme and await _nvme_cli_available():
        return await _stage_surface_validate_nvme(job_id, devname, drive_id)

    if await _badblocks_available():
        return await _stage_surface_validate_ssh(job_id, devname, drive_id)

    # No badblocks on the host: fall back to TrueNAS's own wipe job.
    await _append_stage_log(
        job_id, "surface_validate",
        "[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — "
        "using TrueNAS disk.wipe API (FULL write pass).\n\n"
    )
    return await _stage_surface_validate_truenas(job_id, devname, drive_id)
|
||
|
||
|
||
async def _nvme_cli_available() -> bool:
    """Check if nvme-cli is installed on the remote host.

    Runs ``which nvme`` over a fresh SSH connection.

    Returns:
        True when the probe exits 0. Any SSH error also yields False, so the
        error is logged rather than hidden behind a silent fallback to the
        non-NVMe surface-validate paths.
    """
    from app import ssh_client

    try:
        async with await ssh_client._connect() as conn:
            r = await conn.run("which nvme", check=False)
            return r.returncode == 0
    except Exception as exc:
        log.warning("nvme-cli availability probe failed (treating as unavailable): %s", exc)
        return False
|
||
|
||
|
||
async def _stage_surface_validate_nvme(job_id: int, devname: str,
                                       drive_id: int) -> bool:
    """NVMe destructive surface test via `nvme format -s 1 --force`.

    ``-s 1`` selects Secure Erase Setting 1, a *user-data* erase. SES=2 would
    be the cryptographic erase. Every namespace LBA is erased by the
    controller, which is far faster than a full badblocks write pass and adds
    no extra write wear.

    Afterwards SMART attributes are re-read and stored. The stage fails on
    FAILED health or on real attribute failures. If that read produces only
    SSH transport errors, or raises, it soft-passes: a single SSH blip must
    not undo a successful format.

    Returns:
        True on a successful format that passes the post-format check.
        False on cancellation (checked before the irreversible erase), SSH
        error, timeout, non-zero exit, or post-format SMART failure.
    """
    import shlex

    from app import ssh_client

    stage = "surface_validate"
    format_timeout_s = 600

    # The erase is irreversible: honour a cancel that arrived before we got here.
    if await _is_cancelled(job_id):
        return False

    await _append_stage_log(
        job_id, stage,
        f"[START] nvme format -s 1 /dev/{devname}\n"
        f"[NOTE] User-data erase (SES=1) — destroys all data on /dev/{devname}.\n\n"
    )

    # The device path goes into a remote shell command line, so quote it.
    cmd = f"nvme format -s 1 --force {shlex.quote('/dev/' + devname)}"
    try:
        async with await ssh_client._connect() as conn:
            r = await asyncio.wait_for(
                conn.run(cmd, check=False), timeout=format_timeout_s
            )
    except asyncio.TimeoutError:
        # str(TimeoutError()) is empty and used to yield "NVMe format SSH error: ".
        reason = f"nvme format did not finish within {format_timeout_s}s"
        await _append_stage_log(job_id, stage, f"\n[SSH error] {reason}\n")
        await _set_stage_error(job_id, stage, f"NVMe format SSH error: {reason}")
        return False
    except Exception as exc:
        await _append_stage_log(job_id, stage, f"\n[SSH error] {exc}\n")
        await _set_stage_error(job_id, stage, f"NVMe format SSH error: {exc}")
        return False

    output = (r.stdout or "") + (r.stderr or "")
    await _append_stage_log(job_id, stage, output + "\n")

    if r.returncode != 0:
        await _set_stage_error(
            job_id, stage,
            f"nvme format exited {r.returncode}: {output.strip()[:200]}"
        )
        return False

    # Post-format SMART sanity check, with the same parity as the other
    # surface_validate paths.
    try:
        attrs = await ssh_client.get_smart_attributes(devname)
        failures = attrs.get("failures") or []
        ssh_only_failures = [f for f in failures if f.startswith("SSH error:")]
        real_failures = [f for f in failures if not f.startswith("SSH error:")]
        if not ssh_only_failures:
            # Persist the post-format snapshot, like the badblocks/wipe paths do.
            await _store_smart_attrs(drive_id, attrs)
        if attrs.get("health") == "FAILED":
            await _set_stage_error(job_id, stage, "NVMe SMART health FAILED after format")
            return False
        if real_failures:
            await _set_stage_error(
                job_id, stage,
                "NVMe SMART attribute failures after format: " + "; ".join(real_failures),
            )
            return False
        if ssh_only_failures:
            await _append_stage_log(
                job_id, stage,
                "[WARN] post-format SMART check had SSH errors "
                "(soft-passing): " + "; ".join(ssh_only_failures) + "\n",
            )
        if attrs.get("warnings"):
            await _append_stage_log(
                job_id, stage,
                "[WARN] " + "; ".join(attrs["warnings"]) + "\n",
            )
    except Exception as exc:
        log.warning("Post-format SMART check error on %s: %s", devname, exc)
        await _append_stage_log(
            job_id, stage,
            f"[WARN] post-format SMART check raised: {exc}\n",
        )

    await _update_stage_percent(job_id, stage, 100)
    await _recalculate_progress(job_id)
    _push_update()
    return True
|
||
|
||
|
||
async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool:
    """Run badblocks over SSH, streaming output to the stage log.

    The remote command is ``sh -c 'echo PID:$$; exec badblocks ...'``. The
    first stdout line carries the remote PID, which is stored in
    ``_remote_pids`` so ``_kill_remote_process`` can abort out-of-band.

    stdout and stderr are drained concurrently:
    - stderr ``N% done`` lines drive the stage percent.
    - Purely numeric stdout lines are counted as bad blocks.

    Output is appended to the stage log in chunks of at least 20 lines.

    Args:
        job_id: Burn-in job owning the ``surface_validate`` stage.
        devname: Device name; ``/dev/`` is prepended.
        drive_id: Not used by this implementation.

    Returns:
        True only when the bad-block count stays within
        ``settings.bad_block_threshold``, the job was not cancelled, and
        badblocks did not exit non-zero without reporting any bad blocks
        (which would mean it never actually tested the surface). False
        otherwise, including on SSH errors.
    """
    import re

    from app import ssh_client

    stage = "surface_validate"
    threshold = settings.bad_block_threshold

    await _append_stage_log(
        job_id, stage,
        f"[START] badblocks -wsv -b {settings.surface_validate_block_size} "
        f"-c {settings.surface_validate_block_buffer} "
        f"-p {settings.surface_validate_passes} /dev/{devname}\n"
        f"[NOTE] This is a DESTRUCTIVE write test. "
        f"All data on /dev/{devname} will be overwritten.\n\n"
    )

    # Block geometry is operator-tunable (Settings → Burn-in):
    #   -b N  block size in bytes  (settings.surface_validate_block_size)
    #   -c N  blocks held per IO   (settings.surface_validate_block_buffer)
    #   -p N  pass count           (settings.surface_validate_passes)
    # Defaults preserve original behavior (-b 4096 -c 64 -p 1).
    bb_args = (
        f"-wsv "
        f"-b {settings.surface_validate_block_size} "
        f"-c {settings.surface_validate_block_buffer} "
        f"-p {settings.surface_validate_passes}"
    )
    # asyncssh's proc.kill() sends an SSH signal request that OpenSSH's sshd
    # ignores by default. So we capture the remote PID (first stdout line) and
    # issue an out-of-band `kill -9` over a fresh session when aborting.
    cmd = f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'"
    progress_re = re.compile(r"([\d.]+)%\s+done")

    output_lines: list[str] = []
    flushed = 0               # output_lines[:flushed] are already in log_text
    bad_blocks_total = 0
    pid_seen = False
    aborted = False           # threshold exceeded; remote kill issued
    cancelled = False         # operator cancel observed; remote kill issued
    returncode: int | None = None

    async def _flush(min_pending: int) -> None:
        """Append unstored lines to the stage log once at least *min_pending* are pending.

        ``flushed`` is advanced synchronously before the await. The two
        concurrent drains therefore never write a line twice or skip one. The
        old ``len % 20 == 0`` check on the shared list could do both when the
        drains interleaved.
        """
        nonlocal flushed
        pending = len(output_lines) - flushed
        if pending <= 0 or pending < min_pending:
            return
        chunk = "".join(output_lines[flushed:])
        flushed = len(output_lines)
        await _append_stage_log(job_id, stage, chunk)

    async def _drain(stream, is_stderr: bool) -> None:
        """Consume one output stream until EOF, threshold abort, or cancellation."""
        nonlocal bad_blocks_total, pid_seen, aborted, cancelled
        async for raw in stream:
            line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace")

            # First stdout line is "PID:<n>" from the wrapping shell.
            # Capture it and keep it out of the user-visible log.
            if not is_stderr and not pid_seen and line.startswith("PID:"):
                pid_seen = True
                try:
                    _remote_pids[job_id] = int(line[4:].strip())
                    log.info(
                        "Captured remote PID %d for job %d (badblocks)",
                        _remote_pids[job_id], job_id,
                    )
                except ValueError:
                    pass
                continue

            output_lines.append(line)

            if is_stderr:
                m = progress_re.search(line)
                if m:
                    # Capped at 99: _finish_stage writes 100 on success.
                    pct = min(99, int(float(m.group(1))))
                    await _update_stage_percent(job_id, stage, pct)
                    await _update_stage_bad_blocks(job_id, stage, bad_blocks_total)
                    await _recalculate_progress(job_id)
                    _push_update()
            else:
                stripped = line.strip()
                if stripped and stripped.isdigit():
                    bad_blocks_total += 1

            await _flush(20)

            if aborted or cancelled:
                # The other drain already issued the kill and logged why.
                return

            # Abort on bad block threshold
            if bad_blocks_total > threshold:
                aborted = True  # set before awaiting so the other drain sees it
                await _kill_remote_process(job_id)
                output_lines.append(
                    f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded "
                    f"threshold ({settings.bad_block_threshold})\n"
                )
                return

            if await _is_cancelled(job_id):
                cancelled = True
                await _kill_remote_process(job_id)
                return

    try:
        async with await ssh_client._connect() as conn:
            async with conn.create_process(cmd) as proc:
                results = await asyncio.gather(
                    _drain(proc.stdout, False),
                    _drain(proc.stderr, True),
                    return_exceptions=True,
                )
                for res in results:
                    if isinstance(res, BaseException):
                        log.warning("badblocks output drain failed for job %d: %r", job_id, res)
                # Bound proc.wait so a remote process that ignored our kill
                # (or that we never managed to kill) can't pin this task in
                # the semaphore forever. Closing the connection on exit
                # delivers SIGPIPE to the remote on its next write.
                try:
                    await asyncio.wait_for(proc.wait(), timeout=15)
                except asyncio.TimeoutError:
                    log.warning(
                        "proc.wait() timed out for job %d — abandoning channel",
                        job_id,
                    )
                # None if the wait timed out.
                returncode = proc.returncode

        # Write whatever the 20-line chunks haven't yet (each line exactly once).
        await _flush(1)

    except asyncio.CancelledError:
        return False
    except Exception as exc:
        await _append_stage_log(job_id, stage, f"\n[SSH error] {exc}\n")
        await _set_stage_error(job_id, stage, f"SSH badblocks error: {exc}")
        return False

    await _update_stage_bad_blocks(job_id, stage, bad_blocks_total)

    if aborted or bad_blocks_total > threshold:
        await _set_stage_error(
            job_id, stage,
            f"Surface validate FAILED: {bad_blocks_total} bad block(s) found "
            f"(threshold: {settings.bad_block_threshold})"
        )
        return False

    # A cancelled run was killed part-way; it must never report a pass.
    if cancelled or await _is_cancelled(job_id):
        return False

    # A non-zero exit with nothing counted means badblocks never completed a
    # scan (e.g. it could not open the device); don't treat that as a pass.
    if returncode not in (0, None) and bad_blocks_total == 0:
        detail = "".join(output_lines[-5:]).strip()[:200]
        await _set_stage_error(
            job_id, stage,
            f"badblocks exited {returncode} before completing: {detail}",
        )
        return False

    return True
|
||
|
||
|
||
async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool:
    """
    Surface validation via TrueNAS CORE disk.wipe REST API.
    Used on FreeBSD (TrueNAS CORE) where badblocks is unavailable.

    Starts a FULL wipe job and polls it every ``POLL_INTERVAL`` seconds,
    capping the shown percent at 99 until it reports SUCCESS. Then, if SSH
    is configured, it runs a SMART attribute check to catch sectors that
    failed under write stress:
    - Real attribute failures fail the stage.
    - A read that returns only SSH transport errors, or raises, is
      non-fatal. A connectivity blip must not undo a multi-hour wipe.

    Returns:
        True on a successful wipe that passes the post-wipe check. False on
        start failure, job failure/abort, a vanished job, cancellation, or
        real post-wipe SMART failures.
    """
    from app import ssh_client

    stage = "surface_validate"

    await _append_stage_log(
        job_id, stage,
        f"[START] TrueNAS disk.wipe FULL — {devname}\n"
        f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n"
    )

    # Start the wipe job
    try:
        tn_job_id = await _client.wipe_disk(devname, "FULL")
    except Exception as exc:
        await _set_stage_error(job_id, stage, f"Failed to start disk.wipe: {exc}")
        return False

    await _append_stage_log(
        job_id, stage,
        f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n"
    )

    # Poll until complete
    log_flush_counter = 0
    while True:
        if await _is_cancelled(job_id):
            try:
                await _client.abort_job(tn_job_id)
            except Exception as exc:
                # Best-effort: the burn-in is cancelled regardless.
                log.debug("abort_job(%s) failed: %s", tn_job_id, exc, extra={"job_id": job_id})
            return False

        await asyncio.sleep(POLL_INTERVAL)

        try:
            job = await _client.get_job(tn_job_id)
        except Exception as exc:
            log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id})
            await _append_stage_log(job_id, stage, f"[poll error] {exc}\n")
            continue

        if not job:
            await _set_stage_error(job_id, stage, f"Wipe job {tn_job_id} not found")
            return False

        state = job.get("state", "")
        # "progress" may be present but None; don't crash on it.
        progress = job.get("progress") or {}
        pct = int(progress.get("percent") or 0)
        desc = progress.get("description") or ""

        await _update_stage_percent(job_id, stage, min(pct, 99))
        await _recalculate_progress(job_id)
        _push_update()

        # Log progress description every ~5 polls to avoid DB spam
        log_flush_counter += 1
        if desc and log_flush_counter % 5 == 0:
            await _append_stage_log(job_id, stage, f"[{pct}%] {desc}\n")

        if state == "SUCCESS":
            await _update_stage_percent(job_id, stage, 100)
            await _append_stage_log(
                job_id, stage,
                f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n"
            )
            # Post-wipe SMART check — catch any sectors that failed under write stress
            if ssh_client.is_configured() and drive_id is not None:
                await _append_stage_log(
                    job_id, stage,
                    "[CHECK] Running post-wipe SMART attribute check...\n"
                )
                try:
                    attrs = await ssh_client.get_smart_attributes(devname)
                    failures = attrs.get("failures") or []
                    real_failures = [f for f in failures if not f.startswith("SSH error:")]
                    if failures and not real_failures:
                        # Transport errors only: soft-pass, and don't replace the
                        # cached SMART snapshot with an error-only one.
                        await _append_stage_log(
                            job_id, stage,
                            "[WARN] Post-wipe SMART check had SSH errors (non-fatal): "
                            + "; ".join(failures) + "\n"
                        )
                    else:
                        await _store_smart_attrs(drive_id, attrs)
                        if real_failures:
                            error = "Post-wipe SMART check: " + "; ".join(real_failures)
                            await _set_stage_error(job_id, stage, error)
                            return False
                        if attrs.get("warnings"):
                            await _append_stage_log(
                                job_id, stage,
                                "[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
                            )
                        await _append_stage_log(
                            job_id, stage,
                            f"[CHECK] SMART health: {attrs.get('health', 'UNKNOWN')} — no critical attributes.\n"
                        )
                except Exception as exc:
                    log.warning("Post-wipe SMART check failed: %s", exc)
                    await _append_stage_log(
                        job_id, stage,
                        f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n"
                    )
            await _recalculate_progress(job_id)
            _push_update()
            return True

        if state in ("FAILED", "ABORTED", "ERROR"):
            error_msg = job.get("error") or f"Disk wipe failed (state={state})"
            await _set_stage_error(
                job_id, stage,
                f"TrueNAS disk.wipe FAILED: {error_msg}"
            )
            return False
        # RUNNING or WAITING — keep polling
|
||
|
||
|
||
async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool:
    """Simulate a timed stage with progress updates (mock / dev mode).

    Progress rises linearly with wall-clock time over *duration_seconds* and
    is published every ``POLL_INTERVAL`` seconds.

    Returns:
        True once 100 % is reached; False if the job is cancelled first.
        A non-positive *duration_seconds* completes on the first iteration
        instead of raising ZeroDivisionError.
    """
    start = time.monotonic()

    while True:
        if await _is_cancelled(job_id):
            return False

        if duration_seconds <= 0:
            pct = 100
        else:
            elapsed = time.monotonic() - start
            pct = min(100, int(elapsed / duration_seconds * 100))

        await _update_stage_percent(job_id, stage_name, pct)
        await _recalculate_progress(job_id, None)
        _push_update()

        if pct >= 100:
            return True

        await asyncio.sleep(POLL_INTERVAL)
|
||
|
||
|
||
async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool:
    """
    Verify drive passed all tests.
    SSH mode: read SMART attributes over SSH and check critical attributes.
    Mock mode (or if the SSH path raises): check the SMART health field in DB.

    A transient SSH connectivity failure here must NOT invalidate a prior
    multi-day surface_validate. SSH-only failures are retried twice, 30s
    apart, then soft-passed.

    The DB fallback looks the drive up by *drive_id* when given (devnames
    can be reassigned), else by *devname*. A missing row is reported as
    such rather than as a FAILED health.
    """
    await asyncio.sleep(1)
    from app import ssh_client

    def _ssh_only(failures: list[str]) -> bool:
        return bool(failures) and all(f.startswith("SSH error:") for f in failures)

    if ssh_client.is_configured() and drive_id is not None:
        try:
            attrs = await ssh_client.get_smart_attributes(devname)
            for attempt in range(2):
                if not _ssh_only(attrs.get("failures") or []):
                    break
                log.warning(
                    "final_check SSH unreachable (attempt %d/3); retrying in 30s",
                    attempt + 1,
                    extra={"job_id": job_id, "devname": devname},
                )
                await asyncio.sleep(30)
                attrs = await ssh_client.get_smart_attributes(devname)

            failures = attrs.get("failures") or []
            if _ssh_only(failures):
                log.warning(
                    "final_check soft-pass: SSH unreachable after retries; prior stages stand",
                    extra={"job_id": job_id, "devname": devname, "ssh_error": failures},
                )
                return True

            await _store_smart_attrs(drive_id, attrs)
            if attrs["health"] == "FAILED" or failures:
                msg = failures or [f"SMART health: {attrs['health']}"]
                await _set_stage_error(job_id, "final_check",
                                       "Final check failed: " + "; ".join(msg))
                return False
            return True
        except Exception as exc:
            log.warning("SSH final_check raised, falling back to DB check: %s", exc)

    # DB check (mock mode fallback)
    async with _db() as db:
        if drive_id is not None:
            cur = await db.execute("SELECT smart_health FROM drives WHERE id=?", (drive_id,))
        else:
            cur = await db.execute("SELECT smart_health FROM drives WHERE devname=?", (devname,))
        row = await cur.fetchone()

    if not row:
        await _set_stage_error(
            job_id, "final_check",
            f"Drive {devname} not found in database — cannot verify SMART health",
        )
        return False

    if row[0] == "FAILED":
        await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in")
        return False

    return True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# DB helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _is_cancelled(job_id: int) -> bool:
    """Return True when the job's persisted state in ``burnin_jobs`` is ``'cancelled'``.

    The state is read fresh from the database on every call. An unknown
    *job_id* counts as not cancelled.
    """
    async with _db() as db:
        cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
        found = await cur.fetchone()
    return found is not None and found[0] == "cancelled"
|
||
|
||
|
||
async def _start_stage(job_id: int, stage_name: str) -> None:
    """Mark *stage_name* as running and make it the job's current stage.

    Both updates are committed together in one transaction.
    """
    started = _now()
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        stage_update = "UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?"
        await db.execute(stage_update, (started, job_id, stage_name))
        await db.execute("UPDATE burnin_jobs SET stage_name=? WHERE id=?", (stage_name, job_id))
        await db.commit()
|
||
|
||
|
||
async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None:
    """Close out one stage: state, percent, finish time, and duration.

    Args:
        job_id: Owning burn-in job.
        stage_name: Stage row to update.
        success: True marks the stage ``passed`` with percent 100. False marks
            it ``failed`` with percent NULL.
        error_text: When given, replaces the stage's ``error_text``. When None,
            whatever the stage already recorded (e.g. via ``_set_stage_error``)
            is preserved.
    """
    now = _now()
    state = "passed" if success else "failed"
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        cur = await db.execute(
            "SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?",
            (job_id, stage_name),
        )
        row = await cur.fetchone()

        duration = None
        if row and row[0]:
            try:
                start = datetime.fromisoformat(row[0])
                if start.tzinfo is None:
                    # Naive timestamps are treated as UTC.
                    start = start.replace(tzinfo=timezone.utc)
                duration = (datetime.now(timezone.utc) - start).total_seconds()
            except (TypeError, ValueError):
                # Unparsable started_at: leave duration NULL rather than fail the close-out.
                duration = None

        # COALESCE(?, error_text) keeps the existing value when error_text is
        # None. One statement replaces two near-duplicate UPDATEs that had to
        # be kept in sync.
        await db.execute(
            """UPDATE burnin_stages
               SET state=?, percent=?, finished_at=?, duration_seconds=?,
                   error_text=COALESCE(?, error_text)
               WHERE burnin_job_id=? AND stage_name=?""",
            (state, 100 if success else None, now, duration, error_text, job_id, stage_name),
        )
        await db.commit()
|
||
|
||
|
||
async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None:
    """Persist the progress percentage of one stage of *job_id*."""
    params = (pct, job_id, stage_name)
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute(
            "UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?",
            params,
        )
        await db.commit()
|
||
|
||
|
||
async def _cancel_stage(job_id: int, stage_name: str) -> None:
    """Mark one stage of *job_id* as ``cancelled`` and stamp its finish time."""
    cancelled_at = _now()
    sql = "UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?"
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute(sql, (cancelled_at, job_id, stage_name))
        await db.commit()
|
||
|
||
|
||
async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None:
    """Append *text* to the ``log_text`` column of one ``burnin_stages`` row.

    The concatenation happens inside SQLite, so a NULL column is treated as
    empty. Each call rewrites the whole column value, which is why callers
    batch their appends.
    """
    sql = """UPDATE burnin_stages
             SET log_text = COALESCE(log_text, '') || ?
             WHERE burnin_job_id=? AND stage_name=?"""
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute(sql, (text, job_id, stage_name))
        await db.commit()
|
||
|
||
|
||
async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None:
    """Overwrite the recorded bad-block count for one stage of *job_id*."""
    params = (count, job_id, stage_name)
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute(
            "UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?",
            params,
        )
        await db.commit()
|
||
|
||
|
||
async def _store_smart_attrs(drive_id: int, attrs: dict) -> None:
    """Save the latest SMART attribute result to ``drives.smart_attrs`` as JSON.

    The stored object has keys ``health`` (default ``"UNKNOWN"``), ``attrs``,
    ``warnings`` and ``failures``. Attribute IDs are stringified because JSON
    object keys must be strings.
    """
    import json

    attribute_map = attrs.get("attributes", {})
    payload = {
        "health": attrs.get("health", "UNKNOWN"),
        "attrs": {str(attr_id): value for attr_id, value in attribute_map.items()},
        "warnings": attrs.get("warnings", []),
        "failures": attrs.get("failures", []),
    }
    blob = json.dumps(payload)

    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id))
        await db.commit()
|
||
|
||
|
||
async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None:
    """Store raw smartctl output in ``smart_tests.raw_output``.

    *test_type* is lower-cased to match the stored value. The WHERE clause
    matches on drive and test type only, so every such row is updated.
    NOTE(review): confirm smart_tests holds one row per (drive, type).
    """
    lowered = test_type.lower()
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute(
            "UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?",
            (raw, drive_id, lowered),
        )
        await db.commit()
|
||
|
||
|
||
async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None:
    """Replace (not append to) the ``error_text`` of one stage of *job_id*."""
    params = (error_text, job_id, stage_name)
    async with _db() as db:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute(
            "UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?",
            params,
        )
        await db.commit()
|
||
|
||
|
||
async def _recalculate_progress(job_id: int, profile: str | None = None) -> None:
    """Recompute a job's overall percent from its stage rows and store it.

    Each stage contributes its ``_STAGE_BASE_WEIGHTS`` weight (default 5):
    - A passed stage counts in full.
    - A running stage counts in proportion to its percent.

    The job's ``stage_name`` is set to the running stage, or None if no stage
    is running. *profile* is unused and kept only for call compatibility.
    """
    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")

        cur = await db.execute(
            "SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
            (job_id,),
        )
        rows = await cur.fetchall()
        if not rows:
            return

        def weight_of(stage_row) -> int:
            return _STAGE_BASE_WEIGHTS.get(stage_row["stage_name"], 5)

        total_weight = sum(map(weight_of, rows))
        if total_weight == 0:
            return

        done = 0.0
        running_stage = None
        for stage_row in rows:
            status = stage_row["state"]
            if status == "passed":
                done += weight_of(stage_row)
            elif status == "running":
                done += weight_of(stage_row) * (stage_row["percent"] or 0) / 100
                running_stage = stage_row["stage_name"]

        overall = int(done / total_weight * 100)
        await db.execute(
            "UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?",
            (overall, running_stage, job_id),
        )
        await db.commit()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# SSE push
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _push_update(alert: dict | None = None) -> None:
|
||
"""Notify SSE subscribers that data has changed, with optional browser notification payload."""
|
||
try:
|
||
from app import poller
|
||
poller._notify_subscribers(alert=alert)
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Stuck-job detection (called by poller every ~5 cycles)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def check_stuck_jobs() -> None:
    """Mark jobs that have been 'running' beyond stuck_job_hours as 'unknown'.

    For each job whose ``started_at`` is older than
    ``settings.stuck_job_hours``:
    - The job and its running stages are set to ``unknown``.
    - A ``burnin_stuck`` audit event is written.
    - The remote process is killed.
    - The local asyncio task is cancelled so it releases its semaphore slot.

    A failure to kill one remote process is logged and does not stop the
    remaining tasks from being cancelled.
    """
    threshold_seconds = settings.stuck_job_hours * 3600

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")

        cur = await db.execute("""
            SELECT bj.id, bj.drive_id, d.devname, bj.started_at
            FROM burnin_jobs bj
            JOIN drives d ON d.id = bj.drive_id
            WHERE bj.state = 'running'
              AND bj.started_at IS NOT NULL
              AND (julianday('now') - julianday(bj.started_at)) * 86400 > ?
        """, (threshold_seconds,))
        stuck = await cur.fetchall()

        if not stuck:
            return

        now = _now()
        for row in stuck:
            job_id, drive_id, devname, started_at = row[0], row[1], row[2], row[3]
            log.critical(
                "Stuck burn-in detected — marking unknown",
                extra={"job_id": job_id, "devname": devname, "started_at": started_at},
            )
            await db.execute(
                "UPDATE burnin_jobs SET state='unknown', finished_at=? WHERE id=?",
                (now, job_id),
            )
            await db.execute(
                """UPDATE burnin_stages SET state='unknown', finished_at=?
                   WHERE burnin_job_id=? AND state='running'""",
                (now, job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
                   VALUES (?,?,?,?,?)""",
                ("burnin_stuck", drive_id, job_id, "system",
                 f"Job stuck for >{settings.stuck_job_hours}h — automatically marked unknown"),
            )

        await db.commit()

    # Actually unstick the running tasks so they release their semaphore slot.
    # Without this the DB state becomes 'unknown' but the asyncio task keeps
    # holding the slot forever, which left later jobs permanently 'queued'
    # until a container restart.
    for row in stuck:
        job_id = row[0]
        try:
            await _kill_remote_process(job_id)
        except Exception as exc:
            # Don't let one failed kill skip cancelling this and later tasks:
            # freeing the slot is the point of this sweep.
            log.warning("Failed to kill remote process for stuck job %d: %s", job_id, exc)
        task = _active_tasks.get(job_id)
        if task and not task.done():
            task.cancel()

    _push_update()
    log.warning("Marked %d stuck job(s) as unknown", len(stuck))
|