First slice of the planned tech-debt cleanup. burnin.py was 1667 lines and growing; staged extraction gives smaller diffs to review and a clear bisect target if anything regresses. Mechanical move only — no behaviour change. The two extracted modules: * app/burnin/unlock.py — _UnlockGrant, _unlock_grants, PoolMemberError, is_unlocked / unlock_expiry / grant_pool_unlock, plus the four *_TOKEN constants and UNLOCK_TTL_SECONDS. Owns its module-level state; opens its own DB connection in grant_pool_unlock so it doesn't depend on the parent package's _db() helper. * app/burnin/kill.py — _remote_pids dict and the kill_remote_process / set_remote_pid / clear_remote_pid / get_remote_pid helpers. Pulled out of __init__.py so the asyncssh-ignores-signals workaround lives next to the state it operates on. app/burnin/__init__.py re-exports every public symbol the rest of the app imports — `from app import burnin; burnin.start_job(...)`, `burnin.PoolMemberError`, `burnin.UNLOCK_TTL_SECONDS`, etc. all keep working unchanged. Internal aliases `_remote_pids` and `_unlock_grants` on the package root point at the SAME dict objects in the submodules, so existing in-package mutations (set in stages, cleared in cleanup callbacks) work without rewrite. Test fix: tests/test_unlock_flow.py:test_expired_grant_returns_false monkey-patches UNLOCK_TTL_SECONDS. The package-root alias is bound at import time and won't propagate back to the submodule's read site, so the test now patches `app.burnin.unlock.UNLOCK_TTL_SECONDS` directly. Verification: 44/44 unit tests pass in container; /health 200; container boots clean. routes.py, mailer.py, poller.py untouched — the public API is identical. Future: extract stages, task, _common in subsequent versions. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1485 lines
60 KiB
Python
1485 lines
60 KiB
Python
"""
|
||
Burn-in orchestrator.
|
||
|
||
Manages a FIFO queue of burn-in jobs capped at settings.max_parallel_burnins concurrent
|
||
executions. Each job runs stages sequentially; a failed stage aborts the job.
|
||
|
||
State is persisted to SQLite throughout — DB is source of truth.
|
||
|
||
On startup:
|
||
- Any 'running' jobs from a previous run are marked 'unknown' (interrupted).
|
||
- Any 'queued' jobs are re-enqueued automatically.
|
||
|
||
Cancellation:
|
||
- cancel_job() sets the DB state to 'cancelled', kills the job's remote
  child process, and cancels the job's asyncio task.
- Running stage coroutines also check _is_cancelled() at POLL_INTERVAL
  boundaries and abort within a few seconds of the cancel request.
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime, timezone
|
||
|
||
import aiosqlite
|
||
|
||
from app.config import settings
|
||
from app.truenas import TrueNASClient
|
||
|
||
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Stage definitions
# ---------------------------------------------------------------------------

# Every preset profile is bracketed by "precheck" and "final_check"; profiles
# differ only in the test stages in between. start_job() applies the same
# bracketing to operator-supplied custom stage orders.
STAGE_ORDER: dict[str, list[str]] = {
    profile: ["precheck", *core_stages, "final_check"]
    for profile, core_stages in (
        # Legacy
        ("quick", ("short_smart", "io_validate")),
        # Single-stage selectable profiles
        ("surface", ("surface_validate",)),
        ("short", ("short_smart",)),
        ("long", ("long_smart",)),
        # Two-stage combos
        ("surface_short", ("surface_validate", "short_smart")),
        ("surface_long", ("surface_validate", "long_smart")),
        ("short_long", ("short_smart", "long_smart")),
        # All three
        ("full", ("surface_validate", "short_smart", "long_smart")),
    )
}

# Relative weight of each stage when the overall job % progress is computed.
# The values sum to 110, so they are relative shares rather than percentages;
# presumably the progress recalculation normalises over the stages a job
# actually contains (not visible here — confirm in _recalculate_progress).
_STAGE_BASE_WEIGHTS: dict[str, int] = dict(
    precheck=5,
    surface_validate=65,
    short_smart=12,
    long_smart=13,
    io_validate=10,
    final_check=5,
)

# Seconds to wait between progress polls while a stage is active.
POLL_INTERVAL = 5.0
|
||
|
||
# ---------------------------------------------------------------------------
# Module-level state (_semaphore / _client are assigned by init())
# ---------------------------------------------------------------------------

# Concurrency limiter; init() sizes it from settings.max_parallel_burnins.
_semaphore: asyncio.Semaphore | None = None
# TrueNAS client handed to init(); used by the REST-API stage paths.
_client: TrueNASClient | None = None

# job_id -> live _run_job task. Holding a strong reference stops the task from
# being garbage-collected (asyncio.create_task() leaves only a weak one) and
# lets cancel_job / check_stuck_jobs cancel a wedged task directly.
# _spawn_run_job adds entries; its done-callback removes them.
_active_tasks: dict[int, asyncio.Task] = {}
|
||
|
||
# Remote-PID kill machinery and pool-drive unlock state live in their own
# submodules (app/burnin/kill.py and app/burnin/unlock.py). The bindings below
# re-export the names the rest of the app reaches for via `burnin.<name>`, plus
# the legacy private aliases (_kill_remote_process, _is_unlocked) used by code
# that predates the split.
#
# These are plain module-level bindings made once, at import time:
#   * _remote_pids / _unlock_grants are the SAME dict objects as in the
#     submodules, so in-place mutation (pop, item assignment) through either
#     name is visible through the other.
#   * Rebinding a name here (e.g. monkeypatching UNLOCK_TTL_SECONDS in a test)
#     does NOT propagate to the submodule that reads it; patch
#     app.burnin.unlock.<name> directly instead.
from . import kill as _kill  # noqa: E402
from . import unlock as _unlock  # noqa: E402

# Shared mutable state (same objects as in the submodules).
_remote_pids = _kill._remote_pids
_unlock_grants = _unlock._unlock_grants

# Public unlock API re-exported for callers using burnin.<name>.
PoolMemberError = _unlock.PoolMemberError
UNLOCK_TTL_SECONDS = _unlock.UNLOCK_TTL_SECONDS
BOOT_POOL_NAME = _unlock.BOOT_POOL_NAME
BOOT_POOL_CONFIRM_TOKEN = _unlock.BOOT_POOL_CONFIRM_TOKEN
EXPORTED_POOL_ROLE = _unlock.EXPORTED_POOL_ROLE
EXPORTED_CONFIRM_TOKEN = _unlock.EXPORTED_CONFIRM_TOKEN
MOUNTED_ROLE = _unlock.MOUNTED_ROLE
MOUNTED_CONFIRM_TOKEN = _unlock.MOUNTED_CONFIRM_TOKEN

unlock_expiry = _unlock.unlock_expiry
grant_pool_unlock = _unlock.grant_pool_unlock
_is_unlocked = _unlock.is_unlocked  # legacy private name
_kill_remote_process = _kill.kill_remote_process
|
||
|
||
|
||
def _now() -> str:
|
||
return datetime.now(timezone.utc).isoformat()
|
||
|
||
|
||
@asynccontextmanager
async def _db():
    """Yield an aiosqlite connection to settings.db_path with busy_timeout=10000.

    The 10 s busy timeout makes writers wait for a competing lock instead of
    failing immediately with 'database is locked'. This helper does not set
    journal_mode itself: callers that care issue ``PRAGMA journal_mode=WAL``
    (init() does so at startup), and SQLite keeps WAL mode persistently in
    the database file once it is enabled. The connection is closed when the
    ``async with`` block exits.
    """
    async with aiosqlite.connect(settings.db_path) as conn:
        await conn.execute("PRAGMA busy_timeout=10000")
        yield conn
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Init + startup reconciliation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def init(client: TrueNASClient) -> None:
    """Initialise orchestrator state and reconcile jobs left by a previous run.

    - Creates the concurrency semaphore (settings.max_parallel_burnins slots)
      and stores *client* for the TrueNAS-API stage paths.
    - Marks every job still in 'running' as 'unknown' (it was interrupted).
    - Spawns a runner task for every job still 'queued', in creation order.
    """
    global _semaphore, _client
    _client = client
    _semaphore = asyncio.Semaphore(settings.max_parallel_burnins)

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
            await db.execute(pragma)

        # Anything still 'running' lost its task when the process stopped.
        await db.execute(
            "UPDATE burnin_jobs SET state='unknown', finished_at=? WHERE state='running'",
            (_now(),),
        )

        # Collect still-queued jobs, oldest first, so they restart in order.
        cursor = await db.execute(
            "SELECT id FROM burnin_jobs WHERE state='queued' ORDER BY created_at"
        )
        pending_ids = [record["id"] for record in await cursor.fetchall()]
        await db.commit()

    for pending_id in pending_ids:
        _spawn_run_job(pending_id)

    log.info("Burn-in orchestrator ready (max_concurrent=%d)", settings.max_parallel_burnins)
|
||
|
||
|
||
def _spawn_run_job(job_id: int) -> "asyncio.Task":
    """Start a _run_job task for *job_id* and register it in _active_tasks.

    The event loop holds only a weak reference to a task created with
    asyncio.create_task(), so without the _active_tasks entry the task could
    be garbage-collected before it runs. The registry also gives cancel_job /
    check_stuck_jobs a handle to cancel it directly.

    When the task finishes, its registry entry and any remote PID recorded
    for the job are dropped, but only if the registry still points at this
    task. That way a newer task for the same job id is never clobbered.
    """
    runner = asyncio.create_task(_run_job(job_id))
    _active_tasks[job_id] = runner

    def _forget(done: "asyncio.Task") -> None:
        if _active_tasks.get(job_id) is not done:
            return  # superseded by a re-enqueued task; leave its state alone
        del _active_tasks[job_id]
        _remote_pids.pop(job_id, None)

    runner.add_done_callback(_forget)
    return runner
|
||
|
||
|
||
# _kill_remote_process is re-exported above from .kill — the original
|
||
# definition was extracted to app/burnin/kill.py in 1.0.0-30.
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Public API
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def start_job(drive_id: int, profile: str, operator: str,
                    stage_order: list[str] | None = None) -> int:
    """Create and enqueue a burn-in job, returning the new job ID.

    Args:
        drive_id: ``drives.id`` of the target drive.
        profile: Profile name stored on the job row. When *stage_order* is
            None it also selects the preset stage list from STAGE_ORDER.
        operator: Name recorded on the job and on its audit event.
        stage_order: Optional custom middle stages, e.g.
            ["short_smart", "long_smart", "surface_validate"]. "precheck" is
            always prepended and "final_check" appended. The entries are not
            validated here.

    Returns:
        The id of the inserted ``burnin_jobs`` row. A runner task is spawned
        for it before returning.

    Raises:
        KeyError: *stage_order* is None and *profile* is not in STAGE_ORDER.
        ValueError: the drive already has a queued/running job (caught either
            by the COUNT check or by the unique-index IntegrityError).
        PoolMemberError: the drive's cached row shows pool membership with no
            matching unlock grant, or the live SSH re-check shows the drive is
            now in a pool.
    """
    now = _now()

    # Build the actual stage list
    if stage_order is not None:
        stages = ["precheck"] + list(stage_order) + ["final_check"]
    else:
        stages = STAGE_ORDER[profile]

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA foreign_keys=ON")

        # Reject duplicate active burn-in for same drive. This is the fast,
        # friendly path; the unique index hit below is what closes the race.
        cur = await db.execute(
            "SELECT COUNT(*) FROM burnin_jobs WHERE drive_id=? AND state IN ('queued','running')",
            (drive_id,),
        )
        if (await cur.fetchone())[0] > 0:
            raise ValueError("Drive already has an active burn-in job")

        # Pool-membership gate: locked unless the operator explicitly
        # unlocked this drive via /api/v1/drives/{id}/unlock recently.
        # _is_unlocked also checks that the grant's stored (pool_name,
        # pool_role) still matches the live row — a grant issued for an
        # exported drive doesn't carry over if the drive turns out to be
        # in an active pool on the next poll.
        cur = await db.execute(
            "SELECT pool_name, pool_role, devname FROM drives WHERE id=?", (drive_id,)
        )
        drow = await cur.fetchone()
        if drow and drow["pool_name"] and not _is_unlocked(
            drive_id, drow["pool_name"], drow["pool_role"]
        ):
            raise PoolMemberError(drive_id, drow["pool_name"], drow["pool_role"])

        # Closes Codex finding #5: re-check pool state OVER SSH right now,
        # not against the cached row. Defends against the poll window during
        # which a drive could have been imported into a pool, mounted, or had
        # ZFS labels written between when the operator unlocked it and when
        # they clicked Start. Adds one SSH round-trip per start; cheap
        # against the cost of destroying a freshly-imported pool.
        if drow:
            from app import ssh_client as _ssh
            if _ssh.is_configured():
                fresh = await _ssh.fresh_pool_check_for_drive(drow["devname"])
                # Build the cached view in the same shape as `fresh` so the two
                # compare directly. NOTE(review): assumes
                # fresh_pool_check_for_drive returns {"pool": ..., "role": ...}
                # or None — confirm in ssh_client.
                cached = (
                    {"pool": drow["pool_name"], "role": drow["pool_role"]}
                    if drow["pool_name"] else None
                )
                if fresh != cached:
                    # State changed since the last poll. Invalidate any
                    # unlock grant (it was bound to stale identity). This pops
                    # from the dict shared with app/burnin/unlock.py.
                    _unlock_grants.pop(drive_id, None)
                    fresh_pool = fresh["pool"] if fresh else None
                    fresh_role = fresh["role"] if fresh else None
                    # Drive is (now) in a pool: refuse with a descriptive
                    # error so the operator knows to wait for the next poll.
                    if fresh_pool:
                        raise PoolMemberError(drive_id, fresh_pool, fresh_role)
                    # Fresh check shows free while the cached row said pooled:
                    # the drive was just removed from a pool. Safe to start;
                    # the stale grant was already dropped above so old
                    # confirmations can't be reused.
                    log.warning(
                        "Live pool check for drive_id=%d (%s): cached=%s "
                        "fresh=None — drive came free since last poll, "
                        "allowing burn-in",
                        drive_id, drow["devname"], cached,
                    )

        # Create job. The partial unique index uniq_active_burnin_per_drive
        # (database.py) is the actual race-stopper here: if two concurrent
        # /api/v1/burnin/start calls both pass the SELECT-COUNT check above,
        # only one INSERT can win; the loser raises IntegrityError, which
        # we surface with the same ValueError as the inline duplicate check.
        try:
            cur = await db.execute(
                """INSERT INTO burnin_jobs (drive_id, profile, state, percent, operator, created_at)
                   VALUES (?,?,?,?,?,?) RETURNING id""",
                (drive_id, profile, "queued", 0, operator, now),
            )
            job_id = (await cur.fetchone())["id"]
        except aiosqlite.IntegrityError:
            raise ValueError("Drive already has an active burn-in job")

        # Create stage rows in the desired execution order. _run_job later
        # reads them back ORDER BY id, so insertion order is execution order.
        for stage_name in stages:
            await db.execute(
                "INSERT INTO burnin_stages (burnin_job_id, stage_name, state) VALUES (?,?,?)",
                (job_id, stage_name, "pending"),
            )

        await db.execute(
            """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
               VALUES (?,?,?,?,?)""",
            ("burnin_queued", drive_id, job_id, operator, f"Queued {profile} burn-in"),
        )
        await db.commit()

    # Spawn only after the commit so the runner sees the job and stage rows.
    _spawn_run_job(job_id)
    log.info("Burn-in job %d queued (drive_id=%d profile=%s operator=%s)",
             job_id, drive_id, profile, operator)
    return job_id
|
||
|
||
|
||
async def cancel_job(job_id: int, operator: str) -> bool:
    """Cancel a queued or running job.

    Marks the job 'cancelled', marks its pending/running stages 'cancelled',
    and writes a 'burnin_cancelled' audit event. After committing, it kills
    the job's remote child process and cancels the job's asyncio task if
    that task is still live.

    Returns:
        True if the job was cancelled; False if it does not exist or is not
        in 'queued'/'running'.
    """
    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")

        job = await (await db.execute(
            "SELECT state, drive_id FROM burnin_jobs WHERE id=?", (job_id,)
        )).fetchone()
        if job is None or job["state"] not in ("queued", "running"):
            return False

        cancelled_at = _now()
        await db.execute(
            "UPDATE burnin_jobs SET state='cancelled', finished_at=? WHERE id=?",
            (cancelled_at, job_id),
        )
        await db.execute(
            "UPDATE burnin_stages SET state='cancelled' WHERE burnin_job_id=? AND state IN ('pending','running')",
            (job_id,),
        )
        await db.execute(
            """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
               VALUES (?,?,?,?,?)""",
            ("burnin_cancelled", job["drive_id"], job_id, operator, "Cancelled by operator"),
        )
        await db.commit()

    # Kill the remote child first so a proc.wait() inside the running task can
    # return, then cancel the task itself so any other pending await unblocks.
    await _kill_remote_process(job_id)
    runner = _active_tasks.get(job_id)
    if runner is not None and not runner.done():
        runner.cancel()

    log.info("Burn-in job %d cancelled by %s", job_id, operator)
    return True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Job runner
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _thermal_gate_ok() -> bool:
    """Return True if it's thermally safe to start another burn-in.

    Looks up the hottest drive among those with a 'running' burn-in job and
    compares it with settings.temp_warn_c. Returns True when no running
    drive reports a temperature, or when the peak is below the threshold.

    Fail-open: any error is logged (with traceback) and treated as "safe",
    so a DB hiccup never stalls the queue. It is no longer silently swallowed.
    """
    try:
        async with _db() as db:
            cur = await db.execute("""
                SELECT MAX(d.temperature_c)
                FROM drives d
                JOIN burnin_jobs bj ON bj.drive_id = d.id
                WHERE bj.state = 'running' AND d.temperature_c IS NOT NULL
            """)
            row = await cur.fetchone()
            max_temp = row[0] if row and row[0] is not None else None
            return max_temp is None or max_temp < settings.temp_warn_c
    except Exception:
        log.warning("Thermal gate check failed — allowing job to proceed", exc_info=True)
        return True  # Never block on error
|
||
|
||
|
||
async def _run_job(job_id: int) -> None:
|
||
"""Acquire semaphore slot, execute all stages, persist final state."""
|
||
assert _semaphore is not None, "burnin.init() not called"
|
||
|
||
# Adaptive thermal gate: wait before competing for a slot if running drives
|
||
# are already at or above the warning threshold. This prevents layering a
|
||
# new burn-in on top of a thermally-stressed system. Gives up after 3 min
|
||
# and proceeds anyway so jobs don't queue indefinitely.
|
||
for _attempt in range(18): # 18 × 10 s = 3 min max
|
||
if await _thermal_gate_ok():
|
||
break
|
||
if _attempt == 0:
|
||
log.info(
|
||
"Thermal gate: job %d waiting — running drive temps at or above %d°C",
|
||
job_id, settings.temp_warn_c,
|
||
)
|
||
await asyncio.sleep(10)
|
||
else:
|
||
log.warning("Thermal gate timed out for job %d — proceeding anyway", job_id)
|
||
|
||
async with _semaphore:
|
||
if await _is_cancelled(job_id):
|
||
return
|
||
|
||
# Transition queued → running
|
||
async with _db() as db:
|
||
await db.execute("PRAGMA journal_mode=WAL")
|
||
row = await (await db.execute(
|
||
"SELECT drive_id, profile FROM burnin_jobs WHERE id=?", (job_id,)
|
||
)).fetchone()
|
||
if not row:
|
||
return
|
||
drive_id, profile = row[0], row[1]
|
||
|
||
cur = await db.execute("SELECT devname, serial, model FROM drives WHERE id=?", (drive_id,))
|
||
devname_row = await cur.fetchone()
|
||
if not devname_row:
|
||
return
|
||
devname = devname_row[0]
|
||
drive_serial = devname_row[1]
|
||
drive_model = devname_row[2]
|
||
|
||
await db.execute(
|
||
"UPDATE burnin_jobs SET state='running', started_at=? WHERE id=?",
|
||
(_now(), job_id),
|
||
)
|
||
await db.execute(
|
||
"""INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
|
||
VALUES (?,?,?,(SELECT operator FROM burnin_jobs WHERE id=?),?)""",
|
||
("burnin_started", drive_id, job_id, job_id, f"Started {profile} burn-in on {devname}"),
|
||
)
|
||
# Read stage order from DB (respects any custom order set at job creation)
|
||
stage_cur = await db.execute(
|
||
"SELECT stage_name FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
|
||
(job_id,),
|
||
)
|
||
job_stages = [r[0] for r in await stage_cur.fetchall()]
|
||
await db.commit()
|
||
|
||
_push_update()
|
||
log.info("Burn-in started", extra={"job_id": job_id, "devname": devname, "profile": profile})
|
||
|
||
success = False
|
||
error_text = None
|
||
was_cancelled = False
|
||
try:
|
||
success = await _execute_stages(job_id, job_stages, devname, drive_id)
|
||
except asyncio.CancelledError:
|
||
was_cancelled = True
|
||
except Exception as exc:
|
||
error_text = str(exc)
|
||
log.exception("Burn-in raised exception", extra={"job_id": job_id, "devname": devname})
|
||
|
||
# If the job has already moved to a terminal state — by cancel_job
|
||
# ('cancelled') or check_stuck_jobs ('unknown') — leave it alone. The
|
||
# task may have been cancelled mid-stage; finalizing as 'failed' would
|
||
# clobber that audit-meaningful terminal state.
|
||
async with _db() as db:
|
||
cur = await db.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
|
||
cur_row = await cur.fetchone()
|
||
if cur_row and cur_row[0] != "running":
|
||
return
|
||
|
||
# Cancellation arriving here means the asyncio task was cancelled
|
||
# by something other than cancel_job/check_stuck_jobs (shutdown,
|
||
# uvicorn reload, future code paths). The DB still says 'running',
|
||
# so we have to write *some* terminal state, but classifying the
|
||
# interrupted job as 'failed' would lie — we don't actually know
|
||
# whether the underlying SMART/badblocks work passed or not.
|
||
if was_cancelled:
|
||
final_state = "unknown"
|
||
else:
|
||
final_state = "passed" if success else "failed"
|
||
async with _db() as db:
|
||
await db.execute("PRAGMA journal_mode=WAL")
|
||
await db.execute(
|
||
"UPDATE burnin_jobs SET state=?, percent=?, finished_at=?, error_text=? WHERE id=?",
|
||
(final_state, 100 if success else None, _now(), error_text, job_id),
|
||
)
|
||
await db.execute(
|
||
"""INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
|
||
VALUES (?,?,?,(SELECT operator FROM burnin_jobs WHERE id=?),?)""",
|
||
(f"burnin_{final_state}", drive_id, job_id, job_id,
|
||
f"Burn-in {final_state} on {devname}"),
|
||
)
|
||
await db.commit()
|
||
|
||
# Build SSE alert for browser notifications
|
||
alert = {
|
||
"state": final_state,
|
||
"job_id": job_id,
|
||
"devname": devname,
|
||
"serial": drive_serial,
|
||
"model": drive_model,
|
||
"error_text": error_text,
|
||
}
|
||
_push_update(alert=alert)
|
||
log.info("Burn-in finished", extra={"job_id": job_id, "devname": devname, "state": final_state})
|
||
|
||
# Fire webhook + immediate email in background (non-blocking)
|
||
try:
|
||
from app import notifier
|
||
cur2 = None
|
||
async with _db() as db2:
|
||
db2.row_factory = aiosqlite.Row
|
||
cur2 = await db2.execute(
|
||
"SELECT profile, operator FROM burnin_jobs WHERE id=?", (job_id,)
|
||
)
|
||
job_row = await cur2.fetchone()
|
||
if job_row:
|
||
# Get bad_blocks count from surface_validate stage if present
|
||
bad_blocks = 0
|
||
async with _db() as db3:
|
||
cur3 = await db3.execute(
|
||
"SELECT bad_blocks FROM burnin_stages WHERE burnin_job_id=? AND stage_name='surface_validate'",
|
||
(job_id,)
|
||
)
|
||
bb_row = await cur3.fetchone()
|
||
if bb_row and bb_row[0]:
|
||
bad_blocks = bb_row[0]
|
||
asyncio.create_task(notifier.notify_job_complete(
|
||
job_id=job_id,
|
||
devname=devname,
|
||
serial=drive_serial,
|
||
model=drive_model,
|
||
state=final_state,
|
||
profile=job_row["profile"],
|
||
operator=job_row["operator"],
|
||
error_text=error_text,
|
||
bad_blocks=bad_blocks,
|
||
))
|
||
except Exception as exc:
|
||
log.error("Failed to schedule notifications: %s", exc)
|
||
|
||
|
||
async def _execute_stages(job_id: int, stages: list[str], devname: str, drive_id: int) -> bool:
    """Run *stages* in order for a job; return True only if every stage passed.

    - The job's cancel flag is checked before each stage; if set, returns False.
    - A stage that raises is recorded as failed (with the exception text),
      logged with its traceback, and ends the run.
    - A stage that returns False while the job is cancelled is recorded as
      'cancelled' rather than 'failed'.
    - Progress is recalculated after each completed stage; the first
      non-passing stage ends the run.
    """
    for stage_name in stages:
        if await _is_cancelled(job_id):
            return False

        await _start_stage(job_id, stage_name)
        _push_update()

        try:
            ok = await _dispatch_stage(job_id, stage_name, devname, drive_id)
        except Exception as exc:
            # log.exception keeps the traceback, which is what you need to
            # debug a stage implementation that crashed.
            log.exception(
                "Stage raised exception: %s", exc,
                extra={"job_id": job_id, "devname": devname, "stage": stage_name},
            )
            await _finish_stage(job_id, stage_name, success=False, error_text=str(exc))
            _push_update()
            return False

        if not ok and await _is_cancelled(job_id):
            # Stage was aborted due to cancellation — mark it cancelled, not failed
            await _cancel_stage(job_id, stage_name)
        else:
            await _finish_stage(job_id, stage_name, success=ok)
        await _recalculate_progress(job_id)
        _push_update()

        if not ok:
            return False

    return True
|
||
|
||
|
||
async def _dispatch_stage(job_id: int, stage_name: str, devname: str, drive_id: int) -> bool:
|
||
if stage_name == "precheck":
|
||
return await _stage_precheck(job_id, drive_id)
|
||
elif stage_name == "short_smart":
|
||
return await _stage_smart_test(job_id, devname, "SHORT", "short_smart", drive_id)
|
||
elif stage_name == "long_smart":
|
||
return await _stage_smart_test(job_id, devname, "LONG", "long_smart", drive_id)
|
||
elif stage_name == "surface_validate":
|
||
return await _stage_surface_validate(job_id, devname, drive_id)
|
||
elif stage_name == "io_validate":
|
||
return await _stage_timed_simulate(job_id, "io_validate", settings.io_validate_seconds)
|
||
elif stage_name == "final_check":
|
||
return await _stage_final_check(job_id, devname, drive_id)
|
||
return True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Individual stage implementations
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _stage_precheck(job_id: int, drive_id: int) -> bool:
    """Gate destructive work on the drive's cached SMART health and temperature.

    The stage fails, recording a stage error, when health is "FAILED" or the
    temperature exceeds settings.temp_crit_c. It also fails, without an error
    message, when the drive row is missing. Otherwise it pauses for one second
    and passes.
    """
    async with _db() as db:
        cur = await db.execute(
            "SELECT smart_health, temperature_c FROM drives WHERE id=?", (drive_id,)
        )
        drive = await cur.fetchone()

    if not drive:
        return False

    health = drive[0]
    temp = drive[1]

    problem = None
    if health == "FAILED":
        problem = "Drive SMART health is FAILED — refusing to burn in"
    elif temp and temp > settings.temp_crit_c:
        problem = f"Drive temperature {temp}°C exceeds {settings.temp_crit_c}°C limit"

    if problem is not None:
        await _set_stage_error(job_id, "precheck", problem)
        return False

    await asyncio.sleep(1)  # brief pause standing in for a real check
    return True
|
||
|
||
|
||
async def _stage_smart_test(job_id: int, devname: str, test_type: str, stage_name: str,
                            drive_id: int | None = None) -> bool:
    """Run a SMART self-test stage over SSH when configured, else via the TrueNAS API.

    *drive_id* is forwarded only to the SSH path, which uses it to store the
    post-test SMART attributes.
    """
    from app import ssh_client

    if not ssh_client.is_configured():
        return await _stage_smart_test_api(job_id, devname, test_type, stage_name)
    return await _stage_smart_test_ssh(job_id, devname, test_type, stage_name, drive_id)
|
||
|
||
|
||
async def _stage_smart_test_api(job_id: int, devname: str, test_type: str, stage_name: str) -> bool:
    """Run a SMART self-test through the TrueNAS REST API (mock / dev mode).

    Starts the test, then polls the TrueNAS SMART job list every
    POLL_INTERVAL seconds and mirrors the job's percent into the stage row.
    Returns True on SUCCESS. On FAILED / ABORTED it returns False with the
    job's error text. It also fails with an explicit error if the job is
    missing from the list. If the burn-in is cancelled, the TrueNAS job is
    aborted on a best-effort basis and False is returned.
    """
    tn_job_id = await _client.start_smart_test([devname], test_type)

    while True:
        if await _is_cancelled(job_id):
            try:
                await _client.abort_job(tn_job_id)
            except Exception as exc:
                # Best effort: the burn-in is cancelled either way, but leave
                # a trace instead of swallowing the failure silently.
                log.warning(
                    "Failed to abort TrueNAS SMART job %s: %s", tn_job_id, exc,
                    extra={"job_id": job_id},
                )
            return False

        jobs = await _client.get_smart_jobs()
        job = next((j for j in jobs if j["id"] == tn_job_id), None)

        if not job:
            # Record why the stage failed; previously it failed with no reason.
            await _set_stage_error(
                job_id, stage_name,
                f"TrueNAS SMART job {tn_job_id} not found — cannot track {test_type} test",
            )
            return False

        state = job["state"]
        pct = job["progress"]["percent"]

        await _update_stage_percent(job_id, stage_name, pct)
        await _recalculate_progress(job_id, None)
        _push_update()

        if state == "SUCCESS":
            return True
        elif state in ("FAILED", "ABORTED"):
            await _set_stage_error(job_id, stage_name,
                                   job.get("error") or f"SMART {test_type} test failed")
            return False

        await asyncio.sleep(POLL_INTERVAL)
|
||
|
||
|
||
async def _stage_smart_test_ssh(job_id: int, devname: str, test_type: str, stage_name: str,
                                drive_id: int | None) -> bool:
    """Run a SMART self-test by driving smartctl on the TrueNAS host over SSH.

    Flow: start the test, wait 3 s, then poll every POLL_INTERVAL seconds
    until smartctl reports "passed" or "failed". Any other state (e.g.
    "unknown") keeps polling. Progress is mirrored into the stage row while
    the state is "running".

    On "passed", if *drive_id* is given, SMART attributes are fetched and
    stored. Attribute failures fail the stage; warnings are only appended to
    the stage log. An error while fetching attributes is logged and does not
    fail the stage.

    Cancelling the job aborts the remote test on a best-effort basis.

    Returns:
        True if the test (and attribute check) passed, False otherwise.
    """
    from app import ssh_client

    # Start the test
    try:
        startup = await ssh_client.start_smart_test(devname, test_type)
        await _append_stage_log(job_id, stage_name, startup + "\n")
    except Exception as exc:
        await _set_stage_error(job_id, stage_name, f"Failed to start SMART test via SSH: {exc}")
        return False

    # Brief pause to let the test register in smartctl output
    await asyncio.sleep(3)

    # Throttle log_text appends. Appending on every poll of a multi-hour
    # long_smart bloated log_text to 50+ MB and triggered SQLite "database is
    # locked", because each COALESCE-then-append rewrites the whole column.
    # Append every ~60s, on the first poll, and on any state change.
    # Poll errors get the same throttle: during an SSH outage they would
    # otherwise append a line every POLL_INTERVAL seconds.
    LOG_EVERY_N_POLLS = 12
    poll_count = 0
    poll_error_count = 0  # consecutive failed polls; reset on a successful poll
    last_state: str | None = None

    # Poll until complete
    while True:
        if await _is_cancelled(job_id):
            try:
                await ssh_client.abort_smart_test(devname)
            except Exception as exc:
                # Best effort: the job is cancelled either way.
                log.warning("Failed to abort SMART test on %s: %s", devname, exc,
                            extra={"job_id": job_id})
            return False

        await asyncio.sleep(POLL_INTERVAL)

        try:
            progress = await ssh_client.poll_smart_progress(devname)
        except Exception as exc:
            poll_error_count += 1
            log.warning("SSH SMART poll failed: %s", exc, extra={"job_id": job_id})
            if poll_error_count == 1 or poll_error_count % LOG_EVERY_N_POLLS == 0:
                await _append_stage_log(job_id, stage_name, f"[poll error] {exc}\n")
            continue
        poll_error_count = 0

        poll_count += 1
        state_changed = progress["state"] != last_state
        last_state = progress["state"]
        if poll_count == 1 or poll_count % LOG_EVERY_N_POLLS == 0 or state_changed:
            await _append_stage_log(job_id, stage_name, progress["output"] + "\n---\n")

        if progress["state"] == "running":
            pct = max(0, 100 - progress["percent_remaining"])
            await _update_stage_percent(job_id, stage_name, pct)
            await _recalculate_progress(job_id)
            _push_update()

        elif progress["state"] == "passed":
            await _update_stage_percent(job_id, stage_name, 100)
            # Run attribute check
            if drive_id is not None:
                try:
                    attrs = await ssh_client.get_smart_attributes(devname)
                    await _store_smart_attrs(drive_id, attrs)
                    await _store_smart_raw_output(drive_id, test_type, attrs["raw_output"])
                    if attrs["failures"]:
                        error = "SMART attribute failures: " + "; ".join(attrs["failures"])
                        await _set_stage_error(job_id, stage_name, error)
                        return False
                    if attrs["warnings"]:
                        await _append_stage_log(
                            job_id, stage_name,
                            "[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
                        )
                except Exception as exc:
                    log.warning("Failed to retrieve SMART attributes: %s", exc)
            await _recalculate_progress(job_id)
            _push_update()
            return True

        elif progress["state"] == "failed":
            await _set_stage_error(job_id, stage_name, f"SMART {test_type} test failed")
            return False

        # "unknown" → keep polling
|
||
|
||
|
||
async def _badblocks_available() -> bool:
    """Report whether `badblocks` is on the remote host's PATH (Linux / SCALE).

    Runs `which badblocks` over a fresh SSH connection. Any connection or
    command error is treated as "not available", so an SSH failure is
    indistinguishable from a missing binary here.
    """
    from app import ssh_client

    try:
        async with await ssh_client._connect() as conn:
            probe = await conn.run("which badblocks", check=False)
    except Exception:
        return False
    return probe.returncode == 0
|
||
|
||
|
||
async def _stage_surface_validate(job_id: int, devname: str, drive_id: int) -> bool:
    """Surface-validation stage; picks an implementation for the environment.

    Routing, in priority order:

    * No SSH configured — simulated timed progress (dev / mock mode).
    * NVMe device and nvme-cli present (TrueNAS SCALE) —
      `nvme format -s 1 /dev/{devname}` cryptographic erase. Far faster than
      badblocks on NVMe and exercises the controller's secure-erase path.
    * badblocks present (TrueNAS SCALE / Linux) — badblocks write test run
      directly over SSH.
    * Otherwise (assumed TrueNAS CORE / FreeBSD) — TrueNAS REST API
      disk.wipe FULL job plus a post-wipe SMART check.

    NOTE: the tool probes report False on any SSH error as well, so an SSH
    failure during probing also routes to the disk.wipe path.
    """
    from app import ssh_client

    if not ssh_client.is_configured():
        return await _stage_timed_simulate(job_id, "surface_validate", settings.surface_validate_seconds)

    if devname.startswith("nvme") and await _nvme_cli_available():
        return await _stage_surface_validate_nvme(job_id, devname, drive_id)

    if await _badblocks_available():
        return await _stage_surface_validate_ssh(job_id, devname, drive_id)

    # No badblocks on the host: fall back to the native wipe API.
    notice = (
        "[INFO] badblocks not found on host (TrueNAS CORE/FreeBSD) — "
        "using TrueNAS disk.wipe API (FULL write pass).\n\n"
    )
    await _append_stage_log(job_id, "surface_validate", notice)
    return await _stage_surface_validate_truenas(job_id, devname, drive_id)
|
||
|
||
|
||
async def _nvme_cli_available() -> bool:
    """Report whether nvme-cli (`nvme`) is on the remote host's PATH.

    Runs `which nvme` over a fresh SSH connection. Any connection or
    command error is treated as "not available".
    """
    from app import ssh_client

    try:
        async with await ssh_client._connect() as conn:
            lookup = await conn.run("which nvme", check=False)
    except Exception:
        return False
    return lookup.returncode == 0
|
||
|
||
|
||
async def _stage_surface_validate_nvme(job_id: int, devname: str,
                                       drive_id: int) -> bool:
    """NVMe destructive surface test via ``nvme format -s 1`` (User Data Erase).

    ``-s``/``--ses`` is the Secure Erase Setting of the NVMe Format NVM
    command: 1 requests a User Data Erase, 2 would request a Cryptographic
    Erase. Either way the controller erases the namespace itself, so the host
    does not have to write the entire LBA space the way badblocks would —
    faster, and without the extra flash wear.

    After a successful format the drive's SMART data is re-read over SSH.
    A FAILED health or any SMART failure that is not an SSH-transport error
    fails the stage. SSH-transport failures and SMART warnings are only
    logged, so a connectivity blip right after the format does not undo it.

    Args:
        job_id: burnin_jobs.id whose surface_validate stage is running.
        devname: device name under /dev (e.g. "nvme0n1").
        drive_id: drives.id; not used by this path (kept for signature parity
            with the other surface_validate variants).

    Returns:
        True on success; False when the format could not be run, exited
        non-zero, or the post-format SMART check failed. The stage error_text
        is set before every False return.
    """
    from app import ssh_client

    stage = "surface_validate"

    await _append_stage_log(
        job_id, stage,
        f"[START] nvme format -s 1 /dev/{devname}\n"
        f"[NOTE] User Data Erase (SES=1) — destroys all data on /dev/{devname}.\n\n"
    )

    cmd = f"nvme format -s 1 --force /dev/{devname}"
    try:
        async with await ssh_client._connect() as conn:
            r = await asyncio.wait_for(
                conn.run(cmd, check=False), timeout=600
            )
    except Exception as exc:
        # TimeoutError stringifies to "", which made the stored error useless.
        detail = str(exc) or type(exc).__name__
        await _append_stage_log(job_id, stage, f"\n[SSH error] {detail}\n")
        await _set_stage_error(job_id, stage, f"NVMe format SSH error: {detail}")
        return False

    output = (r.stdout or "") + (r.stderr or "")
    await _append_stage_log(job_id, stage, output + "\n")

    if r.returncode != 0:
        await _set_stage_error(
            job_id, stage,
            f"nvme format exited {r.returncode}: {output.strip()[:200]}"
        )
        return False

    # Post-format SMART sanity check. Only the fetch itself is soft-failed:
    # a transport error after a successful format must not undo the work, but
    # an error raised while *recording* a real SMART failure must not be
    # turned into a pass either, so the checks below stay outside the try.
    try:
        attrs = await ssh_client.get_smart_attributes(devname)
    except Exception as exc:
        log.warning("Post-format SMART check error on %s: %s", devname, exc)
        await _append_stage_log(
            job_id, stage,
            f"[WARN] post-format SMART check raised: {exc}\n",
        )
        attrs = None

    if attrs is not None:
        failures = attrs.get("failures") or []
        ssh_only_failures = [f for f in failures if f.startswith("SSH error:")]
        real_failures = [f for f in failures if not f.startswith("SSH error:")]

        if attrs.get("health") == "FAILED":
            await _set_stage_error(
                job_id, stage,
                "NVMe SMART health FAILED after format",
            )
            return False
        if real_failures:
            await _set_stage_error(
                job_id, stage,
                "NVMe SMART attribute failures after format: "
                + "; ".join(real_failures),
            )
            return False
        if ssh_only_failures:
            await _append_stage_log(
                job_id, stage,
                "[WARN] post-format SMART check had SSH errors "
                "(soft-passing): " + "; ".join(ssh_only_failures) + "\n",
            )
        if attrs.get("warnings"):
            await _append_stage_log(
                job_id, stage,
                "[WARN] " + "; ".join(attrs["warnings"]) + "\n",
            )

    await _update_stage_percent(job_id, stage, 100)
    await _recalculate_progress(job_id)
    _push_update()
    return True
|
||
|
||
|
||
async def _stage_surface_validate_ssh(job_id: int, devname: str, drive_id: int) -> bool:
    """Run a destructive ``badblocks -w`` pass over SSH, streaming output to the stage log.

    The command is wrapped as ``sh -c 'echo PID:$$; exec badblocks ...'`` so the
    first stdout line carries the remote PID. That PID is stored in
    ``_remote_pids[job_id]`` so ``_kill_remote_process`` can abort the run.
    Block geometry (-b / -c / -p) comes from ``settings.surface_validate_*``.

    Both output streams are drained concurrently:

    * stdout lines that are bare integers are counted as bad blocks;
    * stderr lines matching ``<n>% done`` update the stage percent (capped at
      99), the stage bad-block counter and the job's overall progress;
    * output lines are appended to the stage log in chunks of at least 20
      lines; the remainder is flushed after the process exits.

    The remote process is killed as soon as the bad-block count exceeds
    ``settings.bad_block_threshold`` or the job is cancelled.

    Args:
        job_id: burnin_jobs.id owning the surface_validate stage.
        devname: device name under /dev.
        drive_id: drives.id; not used by this path (kept for signature parity
            with the other surface_validate variants).

    Returns:
        True only if the run finished with a bad-block count at or below the
        threshold and the job was not cancelled; False otherwise. The stage
        error_text is set first on threshold and SSH failures.
    """
    import re

    from app import ssh_client

    stage = "surface_validate"
    log_chunk_lines = 20

    # Block geometry is operator-tunable (Settings → Burn-in):
    #   -b N  block size in bytes        (settings.surface_validate_block_size)
    #   -c N  blocks held per IO         (settings.surface_validate_block_buffer)
    #   -p N  pass count                 (settings.surface_validate_passes)
    bb_args = (
        f"-wsv "
        f"-b {settings.surface_validate_block_size} "
        f"-c {settings.surface_validate_block_buffer} "
        f"-p {settings.surface_validate_passes}"
    )
    await _append_stage_log(
        job_id, stage,
        f"[START] badblocks {bb_args} /dev/{devname}\n"
        f"[NOTE] This is a DESTRUCTIVE write test. "
        f"All data on /dev/{devname} will be overwritten.\n\n"
    )

    # asyncssh's proc.kill() sends an SSH signal request that OpenSSH's sshd
    # ignores by default. The wrapper echoes the remote PID first so an
    # out-of-band `kill -9` can be issued over a fresh session when aborting.
    cmd = f"sh -c 'echo PID:$$; exec badblocks {bb_args} /dev/{devname}'"
    progress_re = re.compile(r"([\d.]+)%\s+done")

    output_lines: list[str] = []
    flushed = 0  # index of the first output line not yet written to the stage log
    bad_blocks_total = 0
    pid_seen = False

    async def _flush_log(force: bool) -> None:
        """Append not-yet-logged output lines to the stage log.

        Without *force*, waits until at least ``log_chunk_lines`` lines are
        pending. ``flushed`` is advanced BEFORE the await: the other drain can
        run while the append is in flight, and must neither re-send these lines
        nor skip the next chunk.
        """
        nonlocal flushed
        pending = len(output_lines) - flushed
        if pending <= 0 or (not force and pending < log_chunk_lines):
            return
        chunk = "".join(output_lines[flushed:])
        flushed = len(output_lines)
        await _append_stage_log(job_id, stage, chunk)

    async def _drain(stream, is_stderr: bool) -> None:
        """Consume one output stream until EOF, abort threshold, or cancellation."""
        nonlocal bad_blocks_total, pid_seen
        async for raw in stream:
            line = raw if isinstance(raw, str) else raw.decode("utf-8", errors="replace")

            # First stdout line is "PID:<n>" from the wrapping shell.
            # Capture it and keep it out of the user-visible log.
            if not is_stderr and not pid_seen and line.startswith("PID:"):
                pid_seen = True
                try:
                    _remote_pids[job_id] = int(line[4:].strip())
                    log.info(
                        "Captured remote PID %d for job %d (badblocks)",
                        _remote_pids[job_id], job_id,
                    )
                except ValueError:
                    pass
                continue

            output_lines.append(line)

            if is_stderr:
                match = progress_re.search(line)
                if match:
                    try:
                        pct = min(99, int(float(match.group(1))))
                    except ValueError:
                        pct = None  # e.g. a lone "." matched [\d.]+
                    if pct is not None:
                        await _update_stage_percent(job_id, stage, pct)
                        await _update_stage_bad_blocks(job_id, stage, bad_blocks_total)
                        await _recalculate_progress(job_id)
                        _push_update()
            else:
                stripped = line.strip()
                if stripped and stripped.isdigit():
                    bad_blocks_total += 1

            await _flush_log(force=False)

            if bad_blocks_total > settings.bad_block_threshold:
                await _kill_remote_process(job_id)
                output_lines.append(
                    f"\n[ABORTED] {bad_blocks_total} bad block(s) exceeded "
                    f"threshold ({settings.bad_block_threshold})\n"
                )
                return

            if await _is_cancelled(job_id):
                await _kill_remote_process(job_id)
                return

    try:
        async with await ssh_client._connect() as conn:
            async with conn.create_process(cmd) as proc:
                outcomes = await asyncio.gather(
                    _drain(proc.stdout, False),
                    _drain(proc.stderr, True),
                    return_exceptions=True,
                )
                for outcome in outcomes:
                    if isinstance(outcome, BaseException):
                        log.warning(
                            "badblocks output drain failed for job %d: %r",
                            job_id, outcome,
                        )
                # Bound proc.wait so a remote process that ignored our kill
                # (or that we never managed to kill) can't pin this task in
                # the semaphore forever. Closing the connection on exit will
                # deliver SIGPIPE to the remote on its next write.
                try:
                    await asyncio.wait_for(proc.wait(), timeout=15)
                except asyncio.TimeoutError:
                    log.warning(
                        "proc.wait() timed out for job %d — abandoning channel",
                        job_id,
                    )
        # Write only what the chunked flushes have not already stored, so
        # log_text is not duplicated.
        await _flush_log(force=True)
    except asyncio.CancelledError:
        return False
    except Exception as exc:
        await _append_stage_log(job_id, stage, f"\n[SSH error] {exc}\n")
        await _set_stage_error(job_id, stage, f"SSH badblocks error: {exc}")
        return False

    await _update_stage_bad_blocks(job_id, stage, bad_blocks_total)

    if bad_blocks_total > settings.bad_block_threshold:
        await _set_stage_error(
            job_id, stage,
            f"Surface validate FAILED: {bad_blocks_total} bad block(s) found "
            f"(threshold: {settings.bad_block_threshold})"
        )
        return False

    # A cancelled run was killed before badblocks finished, so the surface was
    # not fully validated. Report it like the other stage variants report a
    # cancellation (False) instead of as a pass.
    if await _is_cancelled(job_id):
        return False

    return True
|
||
|
||
|
||
async def _stage_surface_validate_truenas(job_id: int, devname: str, drive_id: int) -> bool:
    """Surface validation via the TrueNAS CORE ``disk.wipe`` REST API.

    Used on FreeBSD (TrueNAS CORE), where badblocks is unavailable. Starts a
    FULL wipe job through ``_client.wipe_disk`` and polls it every
    ``POLL_INTERVAL`` seconds, mirroring its percent (capped at 99 until done)
    into the stage. On success, runs a post-wipe SMART attribute check over
    SSH when SSH is configured and *drive_id* is known.

    Args:
        job_id: burnin_jobs.id owning the surface_validate stage.
        devname: disk name as passed to ``disk.wipe``.
        drive_id: drives.id used to persist post-wipe SMART data; the SMART
            check is skipped when it is None.

    Returns:
        True when the wipe succeeded and the post-wipe check did not find
        real SMART failures. False on cancellation (the TrueNAS job is aborted
        best-effort), when the wipe cannot be started or found, when it ends
        FAILED/ABORTED/ERROR, or when the post-wipe check fails.
    """
    stage = "surface_validate"

    await _append_stage_log(
        job_id, stage,
        f"[START] TrueNAS disk.wipe FULL — {devname}\n"
        f"[NOTE] DESTRUCTIVE: all data on {devname} will be overwritten.\n\n"
    )

    try:
        tn_job_id = await _client.wipe_disk(devname, "FULL")
    except Exception as exc:
        await _set_stage_error(job_id, stage, f"Failed to start disk.wipe: {exc}")
        return False

    await _append_stage_log(
        job_id, stage,
        f"[JOB] TrueNAS wipe job started (job_id={tn_job_id})\n"
    )

    polls = 0
    while True:
        if await _is_cancelled(job_id):
            try:
                await _client.abort_job(tn_job_id)
            except Exception as exc:
                # Best-effort: the local stage is cancelled either way.
                log.warning(
                    "Failed to abort TrueNAS wipe job %s: %s", tn_job_id, exc,
                    extra={"job_id": job_id},
                )
            return False

        await asyncio.sleep(POLL_INTERVAL)

        try:
            job = await _client.get_job(tn_job_id)
        except Exception as exc:
            log.warning("Wipe job poll failed: %s", exc, extra={"job_id": job_id})
            await _append_stage_log(job_id, stage, f"[poll error] {exc}\n")
            continue

        if not job:
            await _set_stage_error(job_id, stage, f"Wipe job {tn_job_id} not found")
            return False

        state = job.get("state", "")
        # Guard against a null "progress" object as well as null fields in it;
        # job.get("progress", {}) alone returns None when the key holds null.
        progress = job.get("progress") or {}
        pct = int(progress.get("percent") or 0)
        desc = progress.get("description") or ""

        await _update_stage_percent(job_id, stage, min(pct, 99))
        await _recalculate_progress(job_id)
        _push_update()

        # Log the progress description only every 5th poll to avoid DB spam.
        polls += 1
        if desc and polls % 5 == 0:
            await _append_stage_log(job_id, stage, f"[{pct}%] {desc}\n")

        if state == "SUCCESS":
            await _update_stage_percent(job_id, stage, 100)
            await _recalculate_progress(job_id)
            _push_update()
            await _append_stage_log(
                job_id, stage,
                f"\n[DONE] Wipe job {tn_job_id} completed successfully.\n"
            )
            return await _truenas_post_wipe_smart_check(job_id, devname, drive_id)

        if state in ("FAILED", "ABORTED", "ERROR"):
            error_msg = job.get("error") or f"Disk wipe failed (state={state})"
            await _set_stage_error(
                job_id, stage,
                f"TrueNAS disk.wipe FAILED: {error_msg}"
            )
            return False
        # RUNNING or WAITING — keep polling


async def _truenas_post_wipe_smart_check(job_id: int, devname: str,
                                         drive_id: int | None) -> bool:
    """Post-wipe SMART check for the TrueNAS surface_validate path.

    Catches sectors that failed under write stress. Skipped (returns True)
    when SSH is not configured or *drive_id* is None. Only the SMART fetch is
    soft-failed. Failures that are purely SSH-transport errors ("SSH error:"
    prefix) are logged and soft-passed without overwriting the drive's stored
    SMART data. Any other SMART failure fails the stage.
    """
    from app import ssh_client

    stage = "surface_validate"
    if not (ssh_client.is_configured() and drive_id is not None):
        return True

    await _append_stage_log(
        job_id, stage,
        "[CHECK] Running post-wipe SMART attribute check...\n"
    )
    try:
        attrs = await ssh_client.get_smart_attributes(devname)
    except Exception as exc:
        log.warning("Post-wipe SMART check failed: %s", exc)
        await _append_stage_log(
            job_id, stage,
            f"[WARN] Post-wipe SMART check failed (non-fatal): {exc}\n"
        )
        return True

    failures = attrs.get("failures") or []
    ssh_only_failures = [f for f in failures if f.startswith("SSH error:")]
    real_failures = [f for f in failures if not f.startswith("SSH error:")]

    if ssh_only_failures and not real_failures:
        # A transport blip must not invalidate a multi-hour wipe.
        await _append_stage_log(
            job_id, stage,
            "[WARN] Post-wipe SMART check had SSH errors (soft-passing): "
            + "; ".join(ssh_only_failures) + "\n"
        )
        return True

    try:
        await _store_smart_attrs(drive_id, attrs)
    except Exception as exc:
        # Persisting the snapshot is bookkeeping; it must not decide the verdict.
        log.warning("Storing post-wipe SMART attrs failed: %s", exc)

    if real_failures:
        await _set_stage_error(
            job_id, stage,
            "Post-wipe SMART check: " + "; ".join(real_failures)
        )
        return False

    if attrs.get("warnings"):
        await _append_stage_log(
            job_id, stage,
            "[WARNING] " + "; ".join(attrs["warnings"]) + "\n"
        )
    await _append_stage_log(
        job_id, stage,
        f"[CHECK] SMART health: {attrs.get('health', 'UNKNOWN')} — no critical attributes.\n"
    )
    return True
|
||
|
||
|
||
async def _stage_timed_simulate(job_id: int, stage_name: str, duration_seconds: int) -> bool:
    """Simulate a timed stage with progress updates (mock / dev mode).

    Every ``POLL_INTERVAL`` seconds, writes the elapsed fraction of
    *duration_seconds* (as 0-100) to the stage, recalculates job progress and
    notifies SSE subscribers.

    A non-positive *duration_seconds* completes immediately at 100% instead of
    raising ZeroDivisionError (or reporting a negative percent).

    Returns:
        True once the stage reaches 100%; False if the job is cancelled first.
    """
    start = time.monotonic()

    while True:
        if await _is_cancelled(job_id):
            return False

        if duration_seconds <= 0:
            pct = 100
        else:
            elapsed = time.monotonic() - start
            pct = min(100, int(elapsed / duration_seconds * 100))

        await _update_stage_percent(job_id, stage_name, pct)
        await _recalculate_progress(job_id, None)
        _push_update()

        if pct >= 100:
            return True

        await asyncio.sleep(POLL_INTERVAL)
|
||
|
||
|
||
async def _stage_final_check(job_id: int, devname: str, drive_id: int | None = None) -> bool:
    """Final verdict on the drive once every other stage has run.

    With SSH configured and a known *drive_id*, reads SMART attributes live.
    If every reported failure is an SSH-transport error ("SSH error:" prefix),
    it retries twice, 30 s apart. If the host is still unreachable, it
    soft-passes: a transient connectivity failure must not invalidate a prior
    multi-day surface_validate. Otherwise it persists the attributes and fails
    on FAILED health or any reported failure.

    Without SSH (mock mode), or when the SSH path raises, it falls back to the
    ``drives.smart_health`` value stored for *devname*. A missing row or a
    FAILED value fails the stage.
    """
    await asyncio.sleep(1)
    from app import ssh_client

    def _only_transport_errors(problems: list[str]) -> bool:
        if not problems:
            return False
        return all(p.startswith("SSH error:") for p in problems)

    if ssh_client.is_configured() and drive_id is not None:
        try:
            snapshot = await ssh_client.get_smart_attributes(devname)
            for attempt_no in (1, 2):
                if not _only_transport_errors(snapshot.get("failures") or []):
                    break
                log.warning(
                    "final_check SSH unreachable (attempt %d/3); retrying in 30s",
                    attempt_no,
                    extra={"job_id": job_id, "devname": devname},
                )
                await asyncio.sleep(30)
                snapshot = await ssh_client.get_smart_attributes(devname)

            problems = snapshot.get("failures") or []
            if _only_transport_errors(problems):
                log.warning(
                    "final_check soft-pass: SSH unreachable after retries; prior stages stand",
                    extra={"job_id": job_id, "devname": devname, "ssh_error": problems},
                )
                return True

            await _store_smart_attrs(drive_id, snapshot)
            health = snapshot["health"]
            if health != "FAILED" and not problems:
                return True
            reasons = problems or [f"SMART health: {health}"]
            await _set_stage_error(job_id, "final_check",
                                   "Final check failed: " + "; ".join(reasons))
            return False
        except Exception as exc:
            log.warning("SSH final_check raised, falling back to DB check: %s", exc)

    # Mock mode, or the SSH path raised: trust the last persisted health value.
    async with _db() as db:
        cursor = await db.execute(
            "SELECT smart_health FROM drives WHERE devname=?", (devname,)
        )
        record = await cursor.fetchone()

    if record and record[0] != "FAILED":
        return True
    await _set_stage_error(job_id, "final_check", "Drive SMART health is FAILED after burn-in")
    return False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# DB helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _is_cancelled(job_id: int) -> bool:
    """Return True when the job's burnin_jobs row has state 'cancelled'.

    A missing row counts as not cancelled.
    """
    async with _db() as conn:
        cursor = await conn.execute("SELECT state FROM burnin_jobs WHERE id=?", (job_id,))
        found = await cursor.fetchone()
    return found is not None and found[0] == "cancelled"
|
||
|
||
|
||
async def _start_stage(job_id: int, stage_name: str) -> None:
    """Mark one stage 'running' and make it the job's current stage.

    Stamps ``started_at`` on the burnin_stages row and copies *stage_name*
    into burnin_jobs.stage_name; both updates go out in a single commit.
    """
    stage_sql = "UPDATE burnin_stages SET state='running', started_at=? WHERE burnin_job_id=? AND stage_name=?"
    job_sql = "UPDATE burnin_jobs SET stage_name=? WHERE id=?"
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(stage_sql, (_now(), job_id, stage_name))
        await conn.execute(job_sql, (stage_name, job_id))
        await conn.commit()
|
||
|
||
|
||
async def _finish_stage(job_id: int, stage_name: str, success: bool, error_text: str | None = None) -> None:
    """Close out a stage row as 'passed' or 'failed'.

    Stamps finished_at and derives duration_seconds from the row's started_at
    (naive timestamps are read as UTC; an unparseable value leaves the
    duration NULL). percent becomes 100 on success and NULL on failure.
    error_text is written only when supplied, so passing None keeps whatever
    the stage already recorded via ``_set_stage_error``.
    """
    finished = _now()
    outcome = "passed" if success else "failed"

    def _seconds_since(stamp) -> float | None:
        try:
            began = datetime.fromisoformat(stamp)
            if began.tzinfo is None:
                began = began.replace(tzinfo=timezone.utc)
            return (datetime.now(timezone.utc) - began).total_seconds()
        except Exception:
            return None

    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        cursor = await conn.execute(
            "SELECT started_at FROM burnin_stages WHERE burnin_job_id=? AND stage_name=?",
            (job_id, stage_name),
        )
        found = await cursor.fetchone()
        elapsed = _seconds_since(found[0]) if found and found[0] else None
        final_percent = 100 if success else None

        if error_text is None:
            await conn.execute(
                """UPDATE burnin_stages
                   SET state=?, percent=?, finished_at=?, duration_seconds=?
                   WHERE burnin_job_id=? AND stage_name=?""",
                (outcome, final_percent, finished, elapsed, job_id, stage_name),
            )
        else:
            await conn.execute(
                """UPDATE burnin_stages
                   SET state=?, percent=?, finished_at=?, duration_seconds=?, error_text=?
                   WHERE burnin_job_id=? AND stage_name=?""",
                (outcome, final_percent, finished, elapsed, error_text, job_id, stage_name),
            )
        await conn.commit()
|
||
|
||
|
||
async def _update_stage_percent(job_id: int, stage_name: str, pct: int) -> None:
    """Store *pct* in the percent column of one burnin_stages row."""
    sql = "UPDATE burnin_stages SET percent=? WHERE burnin_job_id=? AND stage_name=?"
    params = (pct, job_id, stage_name)
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(sql, params)
        await conn.commit()
|
||
|
||
|
||
async def _cancel_stage(job_id: int, stage_name: str) -> None:
    """Mark one stage row 'cancelled' and stamp its finished_at.

    Other columns (percent, log_text, error_text, ...) are left untouched.
    """
    cancelled_at = _now()
    sql = "UPDATE burnin_stages SET state='cancelled', finished_at=? WHERE burnin_job_id=? AND stage_name=?"
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(sql, (cancelled_at, job_id, stage_name))
        await conn.commit()
|
||
|
||
|
||
async def _append_stage_log(job_id: int, stage_name: str, text: str) -> None:
    """Concatenate *text* onto a stage row's log_text (NULL is treated as '').

    The concatenation is done in SQL, so the existing log is never read back
    into Python.
    """
    sql = """UPDATE burnin_stages
             SET log_text = COALESCE(log_text, '') || ?
             WHERE burnin_job_id=? AND stage_name=?"""
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(sql, (text, job_id, stage_name))
        await conn.commit()
|
||
|
||
|
||
async def _update_stage_bad_blocks(job_id: int, stage_name: str, count: int) -> None:
    """Overwrite the bad_blocks counter of one burnin_stages row with *count*."""
    sql = "UPDATE burnin_stages SET bad_blocks=? WHERE burnin_job_id=? AND stage_name=?"
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(sql, (count, job_id, stage_name))
        await conn.commit()
|
||
|
||
|
||
async def _store_smart_attrs(drive_id: int, attrs: dict) -> None:
    """Save a SMART snapshot as JSON in drives.smart_attrs.

    The blob has keys "health" (default "UNKNOWN"), "attrs" (the snapshot's
    "attributes" mapping with keys stringified, since JSON object keys must be
    strings), "warnings" and "failures" (default empty lists).
    """
    import json

    raw_attributes = attrs.get("attributes", {})
    payload = {
        "health": attrs.get("health", "UNKNOWN"),
        "attrs": {str(key): value for key, value in raw_attributes.items()},
        "warnings": attrs.get("warnings", []),
        "failures": attrs.get("failures", []),
    }
    blob = json.dumps(payload)

    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("UPDATE drives SET smart_attrs=? WHERE id=?", (blob, drive_id))
        await conn.commit()
|
||
|
||
|
||
async def _store_smart_raw_output(drive_id: int, test_type: str, raw: str) -> None:
    """Save raw smartctl output into smart_tests.raw_output.

    Matches rows by drive_id and the lower-cased *test_type*.
    """
    sql = "UPDATE smart_tests SET raw_output=? WHERE drive_id=? AND test_type=?"
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(sql, (raw, drive_id, test_type.lower()))
        await conn.commit()
|
||
|
||
|
||
async def _set_stage_error(job_id: int, stage_name: str, error_text: str) -> None:
    """Overwrite error_text on one burnin_stages row; state is not changed."""
    sql = "UPDATE burnin_stages SET error_text=? WHERE burnin_job_id=? AND stage_name=?"
    async with _db() as conn:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute(sql, (error_text, job_id, stage_name))
        await conn.commit()
|
||
|
||
|
||
async def _recalculate_progress(job_id: int, profile: str | None = None) -> None:
    """Recompute the job's overall percent and current stage from its stage rows.

    Each stage weighs ``_STAGE_BASE_WEIGHTS[stage_name]`` (5 for unknown names).
    A 'passed' stage contributes its full weight, a 'running' stage a
    percent-scaled share, and any other state nothing. burnin_jobs.stage_name
    becomes the last 'running' stage in id order, or NULL when none is running.
    Nothing is written if the job has no stage rows or the weights sum to zero.

    *profile* is ignored; it is kept for backward compatibility with callers.
    """
    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")

        cursor = await db.execute(
            "SELECT stage_name, state, percent FROM burnin_stages WHERE burnin_job_id=? ORDER BY id",
            (job_id,),
        )
        rows = await cursor.fetchall()
        if not rows:
            return

        weights = [_STAGE_BASE_WEIGHTS.get(r["stage_name"], 5) for r in rows]
        total_weight = sum(weights)
        if total_weight == 0:
            return

        completed = 0.0
        running_stage = None
        for weight, r in zip(weights, rows):
            if r["state"] == "passed":
                completed += weight
            elif r["state"] == "running":
                completed += weight * (r["percent"] or 0) / 100
                running_stage = r["stage_name"]

        overall = int(completed / total_weight * 100)
        await db.execute(
            "UPDATE burnin_jobs SET percent=?, stage_name=? WHERE id=?",
            (overall, running_stage, job_id),
        )
        await db.commit()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# SSE push
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _push_update(alert: dict | None = None) -> None:
|
||
"""Notify SSE subscribers that data has changed, with optional browser notification payload."""
|
||
try:
|
||
from app import poller
|
||
poller._notify_subscribers(alert=alert)
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Stuck-job detection (called by poller every ~5 cycles)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def check_stuck_jobs() -> None:
    """Mark jobs 'running' longer than ``settings.stuck_job_hours`` as 'unknown'.

    For each such job (age measured from burnin_jobs.started_at):

    * the job row, and any of its stage rows still 'running', are set to
      'unknown' with finished_at stamped;
    * a 'burnin_stuck' audit event attributed to "system" is recorded.

    After the commit, each job's remote process is killed and its asyncio task
    (if still alive) is cancelled so it releases its concurrency slot. A
    failure to kill one job's remote process is logged and does not prevent
    that task, or any other stuck job's task, from being cancelled.
    """
    threshold_seconds = settings.stuck_job_hours * 3600

    async with _db() as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")

        cur = await db.execute("""
            SELECT bj.id, bj.drive_id, d.devname, bj.started_at
            FROM burnin_jobs bj
            JOIN drives d ON d.id = bj.drive_id
            WHERE bj.state = 'running'
              AND bj.started_at IS NOT NULL
              AND (julianday('now') - julianday(bj.started_at)) * 86400 > ?
        """, (threshold_seconds,))
        stuck = await cur.fetchall()

        if not stuck:
            return

        now = _now()
        for row in stuck:
            job_id, drive_id, devname, started_at = row[0], row[1], row[2], row[3]
            log.critical(
                "Stuck burn-in detected — marking unknown",
                extra={"job_id": job_id, "devname": devname, "started_at": started_at},
            )
            await db.execute(
                "UPDATE burnin_jobs SET state='unknown', finished_at=? WHERE id=?",
                (now, job_id),
            )
            await db.execute(
                """UPDATE burnin_stages SET state='unknown', finished_at=?
                   WHERE burnin_job_id=? AND state='running'""",
                (now, job_id),
            )
            await db.execute(
                """INSERT INTO audit_events (event_type, drive_id, burnin_job_id, operator, message)
                   VALUES (?,?,?,?,?)""",
                ("burnin_stuck", drive_id, job_id, "system",
                 f"Job stuck for >{settings.stuck_job_hours}h — automatically marked unknown"),
            )

        await db.commit()

    # Actually unstick the running tasks so they release their semaphore slot.
    # Without this the DB state becomes 'unknown' but the asyncio task keeps
    # holding the slot forever, leaving later jobs 'queued' until restart.
    # The kill is isolated per job: one SSH failure must not leave this or any
    # later stuck task holding its slot.
    for row in stuck:
        job_id = row[0]
        try:
            await _kill_remote_process(job_id)
        except Exception:
            log.exception("Failed to kill remote process for stuck job %d", job_id)
        task = _active_tasks.get(job_id)
        if task and not task.done():
            task.cancel()

    _push_update()
    log.warning("Marked %d stuck job(s) as unknown", len(stuck))
|